/***************************************************************************
 *                                                                         *
 *                  (begin: Feb 20 2003)                                   *
 *                                                                         *
 *   Parallel IQPNNI - Important Quartet Puzzle with NNI                   *
 *                                                                         *
 *   Copyright (C) 2005 by Le Sy Vinh, Bui Quang Minh, Arndt von Haeseler  *
 *   Copyright (C) 2003-2004 by Le Sy Vinh, Arndt von Haeseler             *
 *   {vinh,minh}@cs.uni-duesseldorf.de                                     *
 *                                                                         *
 *   This program is free software; you can redistribute it and/or modify  *
 *   it under the terms of the GNU General Public License as published by  *
 *   the Free Software Foundation; either version 2 of the License, or     *
 *   (at your option) any later version.                                   *
 *                                                                         *
 *   This program is distributed in the hope that it will be useful,       *
 *   but WITHOUT ANY WARRANTY; without even the implied warranty of        *
 *   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the         *
 *   GNU General Public License for more details.                          *
 *                                                                         *
 *   You should have received a copy of the GNU General Public License     *
 *   along with this program; if not, write to the                         *
 *   Free Software Foundation, Inc.,                                       *
 *   59 Temple Place - Suite 330, Boston, MA  02111-1307, USA.             *
 ***************************************************************************/

// Implementation of the LiNd class (declared in lind.h).
//
//////////////////////////////////////////////////////////////////////

#include <iostream>
#include <string.h>

#include "outstream.h"
#include "lind.h"
#include "opturtree.h"
#include "model.h"
#include "ptnls.h"
#include "rate.h"
#include "ptnratecube.h"
#include "interface.h"

#ifdef _OPENMP
#include <omp.h>
#endif

// Partial likelihoods that fall below this value (but are still > 0) trigger
// a per-pattern rescale in cmpLiUniformRate()/cmpLiGammaRate(), keeping the
// stored values away from floating-point underflow.
const double SCALING_THRESHOLD = 1e-150;

// Debug switch: uncomment to compile in the TEST_NUMBER_DIFF instrumentation.
//#define TEST_NUMBER_DIFF

#ifdef TEST_NUMBER_DIFF
// NOTE(review): not referenced in this part of the file; presumably used by
// debug code elsewhere.
double min_num_diff = 1.0;
#endif

// scaling of each pattern: per-pattern sum of log(scale factor) added by the
// likelihood kernels whenever they rescale a pattern (only when non-NULL).
// Allocated and zeroed by createPtnLiInfo(), freed by deletePtnLiInfo().
double *ptn_logl = NULL;

// (Re)initialise the per-pattern log-scaling accumulator ptn_logl.
// The buffer (one double per pattern of ptnlist) is allocated on first use
// and every entry is then reset to zero.
// NOTE(review): an already-allocated buffer is reused without resizing, so
// this assumes the pattern count does not grow between calls — confirm.
void createPtnLiInfo() {
	const int numPtn = ptnlist.getNPtn();
	if (ptn_logl == NULL) {
		ptn_logl = new double[numPtn];
	}
	memset(ptn_logl, 0, numPtn * sizeof(double));
}

// Release the per-pattern log-scaling buffer and reset the pointer so a
// later createPtnLiInfo() allocates a fresh one.
void deletePtnLiInfo() {
	// ptn_logl is allocated with new double[...] in createPtnLiInfo(), so it
	// must be released with delete[]; scalar delete on an array is undefined
	// behaviour. delete[] on a NULL pointer is a no-op, so no check is needed.
	delete[] ptn_logl;
	ptn_logl = NULL;
}


//--------------------------------------------------------------------
// Constructor: allocates the partial-likelihood buffer liNdCube_ with room
// for nRate_ * nState_ * nPtn_ values. The gamma kernel uses all of them;
// the uniform-rate kernel uses only the first nState_ * nPtn_ entries.
// Node id and parent links are assigned later by set().
// NOTE(review): liNdCube_ is a raw buffer freed by the destructor; unless the
// class declaration (not visible here) disables copying, a copied LiNd would
// free it twice — confirm.
LiNd::LiNd () {
	isCmped_ = 0;
	isEx_ = -1; // stays -1 until set() assigns 0 (internal) or 1 (external)

	// Give the cached-state members a defined value so that anything run
	// before set() (e.g. createExDesNd) does not read uninitialised memory.
	isExDesNdCreated_ = 0;
	liscale = 0.0;

	int nPtn_ = ptnlist.getNPtn ();

	int nRate_ = myrate.getNRate ();
	int nState_ = mymodel.getNState ();
	Utl::getMem (liNdCube_, nRate_ * nState_ * nPtn_);
}


//--------------------------------------------------------------------
//set the id, ndType, parent node, and branch no for this node
void LiNd::set (int id, EX_IN ndType, int parNdNo, int parBrNo) {
	id_ = id;
	if (ndType == EX)
		isEx_ = 1;
	else
		isEx_ = 0;

	parNdNo_ = parNdNo;
	parBrNo_ = parBrNo;

	isExDesNdCreated_ = 0;
	isCmped_ = 0;
}


//--------------------------------------------------------------------
void LiNd::turnOffIsCmped () {
	isCmped_ = 0;
}

//--------------------------------------------------------------------
int LiNd::getId () {
	return id_;
}

// Access the cached list of external descendant node ids (filled by
// createExDesNd()). Returned by reference, so callers can modify it.
Vec<int> &LiNd::getExDesNd () {
	return this->exDesNdNoLs_;
}

//--------------------------------------------------------------------
//the the parent node, and branch no for this node
void LiNd::getPar (int &parNdNo, int &parBrNo) {
	parNdNo = parNdNo_;
	parBrNo = parBrNo_;
}

void LiNd::setParNdBr (int newParNdNo, int newParBrNo) {
	parNdNo_ = newParNdNo;
	parBrNo_ = newParBrNo;
}

// Debug dump to stdout: "<address> / <id> / <parent node> / <parent branch> / ".
void LiNd::writeInf () {
	const char *sep = " / ";
	std::cout << this << sep << id_;
	std::cout << sep << parNdNo_ << sep << parBrNo_ << sep << endl;
}


//-------------------------------------------------------------------
//reCompute the likelihood for all ptnNo, rateNo, stateNo at this node
void LiNd::reCmpLi () {
	isCmped_ = 0;
	cmpLi ();
}

// Scratch transition-probability matrices used by cmpLiUniformRate() for its
// two child branches. The plain pair serves the shared-rate case (and the
// site-specific case without multiple OpenMP threads); the *_OMP arrays hold
// one matrix per OpenMP thread and are allocated lazily in cmpLiUniformRate().
// NOTE(review): these are process-wide mutable globals; two concurrent calls
// of cmpLiUniformRate() from different threads would race on them.
DMat20 proChangeMat1_, proChangeMat2_;
DMat20 *proChangeMat1_OMP = NULL;
DMat20 *proChangeMat2_OMP = NULL;


// Zero the entries of lipart for the nucleotides that the ambiguity code
// ndStateNo_ rules out (e.g. BASE_AC zeroes G and T, leaving A and C as they
// were). Any value that is not one of the two- or three-base codes leaves
// lipart untouched.
void treatAmbiguousDNA(int ndStateNo_, PVec &lipart) {
	const int code = ndStateNo_;

	// codes that exclude A
	if (code == BASE_CG || code == BASE_CT || code == BASE_GT || code == BASE_CGT)
		lipart[BASE_A] = 0.0;

	// codes that exclude C
	if (code == BASE_AG || code == BASE_AT || code == BASE_GT || code == BASE_AGT)
		lipart[BASE_C] = 0.0;

	// codes that exclude G
	if (code == BASE_AC || code == BASE_AT || code == BASE_CT || code == BASE_ACT)
		lipart[BASE_G] = 0.0;

	// codes that exclude T
	if (code == BASE_AC || code == BASE_AG || code == BASE_CG || code == BASE_ACG)
		lipart[BASE_T] = 0.0;
}

//-------------------------------------------------------------------
// Compute this node's partial likelihood for every pattern and state when no
// discrete rate categories are used. In site-specific mode (isSiteSpec())
// each pattern's branch lengths are multiplied by myrate.getPtnRate(ptnNo_).
// Results go to liNdCube_, laid out as liNdCube_[ptnNo_ * nState_ + stateNo_].
// Stale children are computed first (recursively). If any entry of a
// pattern's vector drops below SCALING_THRESHOLD, the vector is divided by its
// maximum. log(max) is added to ptn_logl[ptnNo_] (when allocated) and,
// weighted by the pattern weight, to liscale (which also sums the children's
// liscale).
void LiNd::cmpLiUniformRate () {
	// cached result is still valid
	if (isCmped_ == 1)
		return;

	int nState_ = mymodel.getNState ();
	int nPtn_ = ptnlist.getNPtn ();


	bool site_spec = isSiteSpec();

	// In site-specific mode every thread needs its own pair of scratch matrices,
	// because matrices are recomputed per pattern inside the parallel loop.
	// NOTE(review): allocated once, sized by omp_threads at the time of the
	// first site-specific call; if omp_threads later grows, the per-thread
	// indexing below could run past the arrays. No matching delete[] is visible
	// here — confirm ownership.
#ifdef _OPENMP
	if (omp_threads > 1 && !proChangeMat1_OMP && site_spec) {
		proChangeMat1_OMP = new DMat20[omp_threads];
		proChangeMat2_OMP = new DMat20[omp_threads];
	}
#endif


	if (isEx_ == 1) { //the case of external case

		// leaves are never rescaled
		liscale = 0.0;

		// Leaf vector: indicator of the observed state. For an unknown or
		// ambiguous character all states get 1.0, narrowed to the compatible
		// bases for ambiguous nucleotide codes.
#ifdef _OPENMP
		#pragma omp parallel for  
#endif
		for (int ptnNo_ = 0; ptnNo_ < nPtn_; ptnNo_ ++) {
			// patterns whose rate sits on a bound are skipped; their entries in
			// liNdCube_ keep whatever values they already had
			if (site_spec && atBound(MIN_PTN_RATE, myrate.getPtnRate(ptnNo_), MAX_PTN_RATE))
				continue;

			// get the starting address of the li-cube
			PVec lipart = liNdCube_ + ptnNo_ * nState_;
			int ndStateNo_ = ptnlist.getBase(ptnNo_, id_);

			if (ndStateNo_ < nState_) {
				memset(lipart, 0, sizeof(LDOUBLE) * nState_);
				lipart[ndStateNo_] = 1.0;
			} else {
				for (int stateNo_ = 0; stateNo_ < nState_; stateNo_ ++)
					lipart[stateNo_] = 1.0;
				if (ndStateNo_ != BS_UNKNOWN && mymodel.getDataType () == NUCLEOTIDE)
					treatAmbiguousDNA(ndStateNo_, lipart);
			} //end of else
		} //end of for

	} else { //the case of internal node

		//first compute the likelihood for all children
		int chiNdNo1_, chiBrNo1_, chiNdNo2_, chiBrNo2_;
		opt_urtree.inNdArr_[id_].get2RemNeiNdBr (parNdNo_, chiNdNo1_, chiBrNo1_, chiNdNo2_, chiBrNo2_);
		double chiLen1_ = opt_urtree.getBrLen (chiBrNo1_);
		double chiLen2_ = opt_urtree.getBrLen (chiBrNo2_);

		// NOTE(review): redeclares (shadows) the outer nState_ with the same value
		int nState_ = mymodel.getNState ();

		LiNd *chiLi1_;
		chiLi1_ = &opt_urtree.liBrArr_[chiBrNo1_].getLiNd (chiNdNo1_);
		if (!chiLi1_->isCmped_) chiLi1_->cmpLiUniformRate ();

		LiNd *chiLi2_;
		chiLi2_ = &opt_urtree.liBrArr_[chiBrNo2_].getLiNd (chiNdNo2_);
		if (!chiLi2_->isCmped_) chiLi2_->cmpLiUniformRate ();

		// With a shared rate the two child-branch transition matrices are the
		// same for every pattern, so compute them once before the loop.
		if (!site_spec) {
			mymodel.cmpProbChange (chiLen1_, proChangeMat1_);
			mymodel.cmpProbChange (chiLen2_, proChangeMat2_);
		} 


		// initialize liscale = sum of childrens liscale
		double scaling = 0.0;

		// Felsenstein pruning step per pattern:
		//   L(s) = (sum_x P1[s][x] * L1(x)) * (sum_x P2[s][x] * L2(x))
#ifdef _OPENMP
		#pragma omp parallel for  reduction(+: scaling) 
#endif
		for (int ptnNo_ = 0; ptnNo_ < nPtn_; ptnNo_ ++) {
			if (site_spec && atBound(MIN_PTN_RATE, myrate.getPtnRate(ptnNo_), MAX_PTN_RATE))
				continue;

			// Pick the scratch matrices: the shared globals by default, or this
			// thread's private pair in multi-threaded site-specific mode.
			// NOTE(review): with site_spec && omp_threads <= 1 the shared globals
			// are written inside this parallel loop; that is only safe if the loop
			// really runs on one thread — confirm omp_threads matches the team size.
			DMat20 *proChangeMat1Ref = &proChangeMat1_;
			DMat20 *proChangeMat2Ref = &proChangeMat2_;

#ifdef _OPENMP
			if (site_spec && omp_threads > 1) {
				int thread_num = omp_get_thread_num();
				proChangeMat1Ref = &proChangeMat1_OMP[thread_num];
				proChangeMat2Ref = &proChangeMat2_OMP[thread_num];
			} 
#endif

			// site-specific mode: scale both branch lengths by this pattern's rate
			if (site_spec) {
				double ptnRate_ = myrate.getPtnRate(ptnNo_);
				mymodel.cmpProbChange (chiLen1_ * ptnRate_, *proChangeMat1Ref);
				mymodel.cmpProbChange (chiLen2_ * ptnRate_, *proChangeMat2Ref);
			}

			int stateNo_;
			bool should_scale = false;
			int start_addr = ptnNo_ * nState_;
			// get the starting address of the li-cube
			PVec lipart = liNdCube_ + start_addr;
			PVec chi1_lipart = chiLi1_->liNdCube_ + start_addr;
			PVec chi2_lipart = chiLi2_->liNdCube_ + start_addr;

			// For a leaf child, read its observed character directly so the inner
			// sum can be short-cut. BS_UNKNOWN + 1 is the "no shortcut" sentinel
			// (assumes it never equals a concrete state index — confirm).
			int exState1 = BS_UNKNOWN + 1;
			int exState2 = BS_UNKNOWN + 1;
			if (chiLi1_->isEx_) exState1 = ptnlist.getBase(ptnNo_, chiLi1_->id_);
			if (chiLi2_->isEx_) exState2 = ptnlist.getBase(ptnNo_, chiLi2_->id_);


			for (stateNo_ = 0; stateNo_ < nState_; stateNo_ ++) {

				int chiStateNo_;
				LDOUBLE chiLiBr1_ = 0.0;
				LDOUBLE chiLiBr2_ = 0.0;

				// concrete leaf state: single matrix entry; unknown leaf: 1.0
				// (row sum of P); otherwise the full sum over child states
				if (exState1 < nState_) 
					chiLiBr1_ = (*proChangeMat1Ref)[stateNo_][exState1];
				else if (exState1 == BS_UNKNOWN)
					chiLiBr1_ = 1.0;
				else 
				for (chiStateNo_ = 0; chiStateNo_ < nState_; chiStateNo_ ++) {
					chiLiBr1_ += chi1_lipart[chiStateNo_] *
						(*proChangeMat1Ref)[stateNo_][chiStateNo_];
				}

				if (exState2 < nState_) 
					chiLiBr2_ = (*proChangeMat2Ref)[stateNo_][exState2];
				else if (exState2 == BS_UNKNOWN)
					chiLiBr2_ = 1.0;
				else 
				for (chiStateNo_ = 0; chiStateNo_ < nState_; chiStateNo_ ++) {
					chiLiBr2_ += chi2_lipart[chiStateNo_] *
						(*proChangeMat2Ref)[stateNo_][chiStateNo_];
				}
				lipart[stateNo_] = chiLiBr1_ * chiLiBr2_;

				if (lipart[stateNo_] < SCALING_THRESHOLD && lipart[stateNo_] > 0.0)
					should_scale = true;

			}//end of for stateNo_

			if (!should_scale) continue;
			/* now do the scaling */

			// divide this pattern's vector by its largest entry (> 0 here, since
			// some entry was positive) and remember log of the factor
			LDOUBLE limax = lipart[0];
			for (stateNo_ = 1; stateNo_ < nState_; stateNo_ ++) 
				if (limax < lipart[stateNo_]) 
					limax = lipart[stateNo_];
			
			for (stateNo_ = 0; stateNo_ < nState_; stateNo_ ++) 
				lipart[stateNo_] /= limax;

			// sum over pattern counts
			double ptn_scaling = log(limax);
			if (ptn_logl) ptn_logl[ptnNo_] += ptn_scaling;
			scaling +=  ptn_scaling * ptnlist.weightArr_[ptnNo_];

		} //end of for ptnNo_
		liscale = chiLi1_->liscale + chiLi2_->liscale + scaling;

	} // end of else

	isCmped_ = 1;
}

//-----------------------------------------------------------
//-----------------------------------------------------------

// Scratch transition-probability matrices, one per rate category, for the two
// child branches. cmpLiGammaRate() fills them before its pattern loop and then
// only reads them inside it.
DMat20 proChangeCube1_[MAX_NUM_RATE];
DMat20 proChangeCube2_[MAX_NUM_RATE];

//-------------------------------------------------------------------
// Compute this node's partial likelihood for every pattern, rate category
// and state. Results go to liNdCube_, laid out as
//   liNdCube_[ptnNo_ * nRate_ * nState_ + rateNo_ * nState_ + stateNo_].
// Stale children are computed first (recursively). Each call that does real
// work increments opt_urtree.nCmpLi_. If any entry of a pattern drops below
// SCALING_THRESHOLD, the whole pattern block (all categories) is divided by
// its single maximum. log(max) is added to ptn_logl[ptnNo_] (when allocated)
// and, weighted by the pattern weight, to liscale.
void LiNd::cmpLiGammaRate () {
	//we do not need to reCmp again
	if (isCmped_ == 1)
		return;


	opt_urtree.nCmpLi_ ++;
	//then, compute the likelihood of this internal node
	int nRate_ = myrate.getNRate ();
	int nState_ = mymodel.getNState ();
	int rate_state = nRate_ * nState_;
	int nPtn_ = ptnlist.getNPtn ();

	if (isEx_ == 1) { //the case of external case
		
		// leaves are never rescaled
		liscale = 0.0;

		// Leaf vector, repeated for every rate category: indicator of the
		// observed state, or all ones (narrowed for ambiguous nucleotide codes).
#ifdef _OPENMP
		#pragma omp parallel for  
#endif
		for (int ptnNo_ = 0; ptnNo_ < nPtn_; ptnNo_ ++) {
			int ndStateNo_ = ptnlist.getBase(ptnNo_, id_);

			int start_addr = ptnNo_ * rate_state;

			if (ndStateNo_ < nState_) {
				memset(liNdCube_ + start_addr, 0, sizeof(LDOUBLE) * rate_state);
				for (int rateNo_ = 0; rateNo_ < nRate_; rateNo_ ++)
					liNdCube_[start_addr + rateNo_ * nState_ + ndStateNo_] = 1.0;
			} else
				for (int rateNo_ = 0; rateNo_ < nRate_; rateNo_ ++) {
					PVec lipart = liNdCube_ + (start_addr + rateNo_ * nState_);

					for (int stateNo_ = 0; stateNo_ < nState_; stateNo_ ++)
						lipart[stateNo_] = 1.0;

					if (ndStateNo_ != BS_UNKNOWN && mymodel.getDataType () == NUCLEOTIDE)
						treatAmbiguousDNA(ndStateNo_, lipart);
				} //end of for rateNo
		}
	} else { //the case of internal node

		//first compute the likelihood for all children
		int chiNdNo1_, chiBrNo1_, chiNdNo2_, chiBrNo2_;

		opt_urtree.inNdArr_[id_].get2RemNeiNdBr (parNdNo_, chiNdNo1_, chiBrNo1_, chiNdNo2_, chiBrNo2_);
		double chiLen1_ = opt_urtree.getBrLen (chiBrNo1_);
		double chiLen2_ = opt_urtree.getBrLen (chiBrNo2_);

		LiNd *chiLi1_;
		chiLi1_ = &opt_urtree.liBrArr_[chiBrNo1_].getLiNd (chiNdNo1_);
		if (!chiLi1_->isCmped_) chiLi1_->cmpLiGammaRate ();

		LiNd *chiLi2_;
		chiLi2_ = &opt_urtree.liBrArr_[chiBrNo2_].getLiNd (chiNdNo2_);
		if (!chiLi2_->isCmped_) chiLi2_->cmpLiGammaRate ();

		// Transition matrices per rate category, computed once before the
		// parallel loop. With NsSy heterogeneity the model produces a
		// category-specific matrix; otherwise the branch length is scaled by
		// the category rate.
		int rateNo_;
		for (rateNo_ = 0; rateNo_ < nRate_; rateNo_ ++) {
			if (myrate.isNsSyHeterogenous()) {
				mymodel.cmpProbChange (chiLen1_, rateNo_, proChangeCube1_[rateNo_]);
				mymodel.cmpProbChange (chiLen2_, rateNo_, proChangeCube2_[rateNo_]);
			} else {
				double rate_ = myrate.getRate (rateNo_);
				mymodel.cmpProbChange (chiLen1_ * rate_, proChangeCube1_[rateNo_]);
				mymodel.cmpProbChange (chiLen2_ * rate_, proChangeCube2_[rateNo_]);
			}
		}
		
		// initialize liscale = sum of childrens liscale
		double scaling = 0.0;

		// Felsenstein pruning step per pattern and rate category:
		//   L_r(s) = (sum_x P1_r[s][x] * L1_r(x)) * (sum_x P2_r[s][x] * L2_r(x))
#ifdef _OPENMP
		#pragma omp parallel for private(rateNo_) reduction(+: scaling)
#endif
		for (int ptnNo_ = 0; ptnNo_ < nPtn_; ptnNo_ ++) {

			int start_addr = ptnNo_ * rate_state;
			bool should_scale = false;

			for (rateNo_ = 0; rateNo_ < nRate_; rateNo_ ++) {
				int stateNo_;
				int local_addr = start_addr + rateNo_ * nState_;
				PVec lipart = liNdCube_ + local_addr;
				PVec chi1_lipart = chiLi1_->liNdCube_ + local_addr;
				PVec chi2_lipart = chiLi2_->liNdCube_ + local_addr;

				for (stateNo_ = 0; stateNo_ < nState_; stateNo_ ++) {

					int chiStateNo_;
					LDOUBLE chiLiBr1_ = 0.0;
					LDOUBLE chiLiBr2_ = 0.0;
					/*if (myrate.use_invar_site && rateNo_ == nRate_ - 1) {
						chiLiBr1_ += chi1_lipart[stateNo_] *
						             proChangeCube1_[rateNo_][stateNo_][stateNo_];
						chiLiBr2_ += chi2_lipart[stateNo_] * 
							proChangeCube2_[rateNo_][stateNo_][stateNo_];
					} else*/
					for (chiStateNo_ = 0; chiStateNo_ < nState_; chiStateNo_ ++) {
						chiLiBr1_ += chi1_lipart[chiStateNo_] *
						             proChangeCube1_[rateNo_][stateNo_][chiStateNo_];
						chiLiBr2_ += chi2_lipart[chiStateNo_] * 
							proChangeCube2_[rateNo_][stateNo_][chiStateNo_];
					}

					lipart[stateNo_] = chiLiBr1_ * chiLiBr2_;
					if (lipart[stateNo_] < SCALING_THRESHOLD && lipart[stateNo_] > 0.0) 
						should_scale = true;

				}//end of for stateNo_

			} // end of for rateNo_

			if (!should_scale) continue;
			/* now do the scaling */

			// One factor for the whole pattern block (all categories and states),
			// so the relative weights between rate categories are preserved.
			PVec lipart = liNdCube_ + start_addr;
			int loopit;
			LDOUBLE limax = lipart[0];
			for (loopit = 1; loopit < rate_state; loopit ++) 
				if (limax < lipart[loopit]) 
					limax = lipart[loopit];
			
			for (loopit = 0; loopit < rate_state; loopit ++) 
				lipart[loopit] /= limax;

			double ptn_scaling = log(limax);
			if (ptn_logl) ptn_logl[ptnNo_] += ptn_scaling;
			scaling += ptn_scaling * ptnlist.weightArr_[ptnNo_];

		} //end of for ptnNo_
		liscale = chiLi1_->liscale + chiLi2_->liscale + scaling;


	} // end of else

	isCmped_ = 1;
}

// True when the rate type is SITE_SPECIFIC and myrate reports that the
// specific rates have been optimized (isOptedSpecificRate()).
inline bool LiNd::isSiteSpec() {
	if (myrate.getType () != SITE_SPECIFIC)
		return false;
	return myrate.isOptedSpecificRate ();
}


//compute the likelihood for all ptnNo, rateNo, baseNo at this node
void LiNd::cmpLi () {
	RATE_TYPE rateType_ = myrate.getType ();
	if (myrate.isNsSyHeterogenous() || rateType_ == GAMMA)
		cmpLiGammaRate();
	else
		cmpLiUniformRate ();
}

//--------------------------------------------------------------------
// Build (once) the list of external node ids that lie below this node, i.e.
// reached through the child branches other than parBrNo_, and copy it into
// exDesNdNoLs. The list is cached in exDesNdNoLs_ until set() clears
// isExDesNdCreated_.
void LiNd::createExDesNd (Vec<int> &exDesNdNoLs) {
	if (isExDesNdCreated_ == 0) {
		if (isEx_ == 1) {
			// a leaf's list holds just its own id
			exDesNdNoLs_.set (1, 0);
			exDesNdNoLs_ += id_;
		} else {
			int chiBrNo1_, chiBrNo2_;
			opt_urtree.inNdArr_[id_].get2RemBr (parBrNo_, chiBrNo1_, chiBrNo2_);

			// presumably reserves room for all sequences with size 0 — see Vec::set
			exDesNdNoLs_.set (alignment.getNSeq(), 0);

			// concatenate the (recursively built) lists of both children
			Vec<int> chiExDesNdNoLs1_;
			opt_urtree.liBrArr_[chiBrNo1_].getRemLiNd (id_).createExDesNd (chiExDesNdNoLs1_);
			exDesNdNoLs_ += chiExDesNdNoLs1_;

			Vec<int> chiExDesNdNoLs2_;
			opt_urtree.liBrArr_[chiBrNo2_].getRemLiNd (id_).createExDesNd (chiExDesNdNoLs2_);
			exDesNdNoLs_ += chiExDesNdNoLs2_;
		}
	}

	// createExDesNd() passes exDesNdNoLs_ itself as the output argument; skip
	// the copy in that case so correctness never depends on Vec<int>::operator=
	// handling self-assignment.
	if (&exDesNdNoLs != &exDesNdNoLs_)
		exDesNdNoLs = exDesNdNoLs_;
	isExDesNdCreated_ = 1;
}

//--------------------------------------------------------------------
//create the external descendant nodes no list of this node
void LiNd::createExDesNd () {
	createExDesNd (exDesNdNoLs_);
}

//--------------------------------------------------------------------
// Does nothing: the likelihood buffer is allocated in the constructor.
void LiNd::setLimit () {
	// intentionally empty
}

//--------------------------------------------------------------------
// Release the memory owned by this node: the external-descendant list and
// the partial-likelihood buffer liNdCube_. Also invoked by the destructor.
void LiNd::release () {
	exDesNdNoLs_.release ();
	Utl::delMem (liNdCube_);
	// Clear the pointer so a second release() (e.g. an explicit call followed
	// by the destructor) cannot free the same buffer twice.
	liNdCube_ = NULL;
}

//-------------------------------------------------------------------
// Destructor: hands all owned buffers back through release().
LiNd::~LiNd () {
	this->release ();
}
