/***************************************************************************
 *                                                                         *
 *                  (begin: Feb 20 2003)                                   *
 *                                                                         *
 *   Parallel IQPNNI - Important Quartet Puzzle with NNI                   *
 *                                                                         *
 *   Copyright (C) 2005 by Le Sy Vinh, Bui Quang Minh, Arndt von Haeseler  *
 *   Copyright (C) 2003-2004 by Le Sy Vinh, Arndt von Haeseler             *
 *   {vinh,minh}@cs.uni-duesseldorf.de                                     *
 *                                                                         *
 *   This program is free software; you can redistribute it and/or modify  *
 *   it under the terms of the GNU General Public License as published by  *
 *   the Free Software Foundation; either version 2 of the License, or     *
 *   (at your option) any later version.                                   *
 *                                                                         *
 *   This program is distributed in the hope that it will be useful,       *
 *   but WITHOUT ANY WARRANTY; without even the implied warranty of        *
 *   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the         *
 *   GNU General Public License for more details.                          *
 *                                                                         *
 *   You should have received a copy of the GNU General Public License     *
 *   along with this program; if not, write to the                         *
 *   Free Software Foundation, Inc.,                                       *
 *   59 Temple Place - Suite 330, Boston, MA  02111-1307, USA.             *
 ***************************************************************************/

/***************************************************************************
                          model.cpp  -  description
                             -------------------
    begin                : Wed Mar 5 2003
    copyright            : (C) 2003 by 
    email                : vinh@cs.uni-duesseldorf.de
 ***************************************************************************/

/***************************************************************************
 *                                                                         *
 *   This program is free software; you can redistribute it and/or modify  *
 *   it under the terms of the GNU General Public License as published by  *
 *   the Free Software Foundation; either version 2 of the License, or     *
 *   (at your option) any later version.                                   *
 *                                                                         *
 ***************************************************************************/

#include <iostream>
#include "model.h"
#include "mat.h"
#include "constant.h"
#include "ali.h"
#include "dmat20.h"
#include "rate.h"

// The single global Model instance defined by this file.
Model mymodel;

// Flag defined in another translation unit. Model::init() applies its
// parameter starting values (tsTvRatio 4.0, pyPuRatio 1.0, GTR rates from
// the alignment) only when this is 0.
// NOTE(review): presumably non-zero means a continued/resumed run — confirm.
extern int isContinuous;


/**
	Constructor: no substitution model object is allocated yet (model_ stays
	NULL until setModelType() is called) and the tsTv-initialized flag is cleared.
*/
Model::Model () : model_(NULL), initialized_tsTv(false) {
}


//====================================
/** store the kind of sequence data analysed (init() distinguishes NUCLEOTIDE, CODON and AMINO_ACID) */
void Model::setDataType (DATA_TYPE dataType) {
	this->dataType_ = dataType;
}

//====================================
/** @return the kind of sequence data previously stored by setDataType() */
DATA_TYPE Model::getDataType () {
	return this->dataType_;
}

/** @return true when codon_model holds a defined model (anything other than UNDEF_MODEL) */
bool Model::isCodonAnalysisNext() {
	const bool undefinedCodonModel = (codon_model == UNDEF_MODEL);
	return !undefinedCodonModel;
}

//====================================
/**
	store the four nucleotide frequencies in stateFrqArr_
	(slots BASE_A, BASE_C, BASE_G and BASE_T; other slots are untouched)
*/
void Model::setStateFrq (double stateFrqA, double stateFrqC, double stateFrqG, double stateFrqT) {
	const int baseSlot[4] = { BASE_A, BASE_C, BASE_G, BASE_T };
	const double baseFrq[4] = { stateFrqA, stateFrqC, stateFrqG, stateFrqT };
	for (int nucleotide = 0; nucleotide < 4; nucleotide++)
		stateFrqArr_[baseSlot[nucleotide]] = baseFrq[nucleotide];
}

/** store how base frequencies are obtained; init() estimates them from the alignment when this is ESTIMATE */
void Model::setBaseFrqType (PAM_TYPE baseFrqType) {
	this->baseFrqType_ = baseFrqType;
}

//============================================
int Model::getNState () {
	return nState_;
}

//============================================
/** @return the substitution model type chosen via setModelType() */
MODEL Model::getType () {
	return this->modelType_;
}

/** @return how the transition/transversion ratio is obtained (e.g. ESTIMATE) */
PAM_TYPE Model::getTsTvRatioType () {
	return this->tsTvRatioType_;
}

/** @return how the pyPu ratio is obtained (e.g. ESTIMATE) */
PAM_TYPE Model::getPyPuRatioType () {
	return this->pyPuRatioType_;
}

/** @return how the general (GTR) rate parameters are obtained (e.g. ESTIMATE) */
PAM_TYPE Model::getGenPamType () {
	return this->genPamType_;
}

/** @return how base frequencies are obtained (e.g. ESTIMATE) */
PAM_TYPE Model::getBaseFrqType () {
	return this->baseFrqType_;
}

void Model::setNState (int nState) {
	nState_ = nState;
}

/**
	get the proportion of a class
*/
//inline double Model::getClassProb(int classNo) {
//}

/**
	get the ratio parm of a class
*/
double Model::getClassRatio(int classNo) {
	return ((CodonNY98*)model_) -> nsSyRatioVec[classNo];
}

//set the model of evolutional change
void Model::setModelType (const MODEL modelType) {
	modelType_ = modelType;

	if (modelType_ == HKY85 || modelType_ == TN93 || modelType_ == GTR) {
		if (modelType_ == HKY85)
			model_ = new HKY85M;
		else
			if (modelType_ == TN93)
				model_ = new TN93M;
			else
				model_ = new GTRM;

		nState_ = NUM_BASE;
	} else if (isCodonModel(modelType)) {
		nState_ = NUM_CODON;
		switch (modelType) {
		case Codon_GY94: model_ = new CodonModel; break;
		case Codon_YN98: model_ = new CodonYN98; break;
		case Codon_Pedersen98: model_ = new CodonPedersen98; break;
		case Codon_GTR: model_ = new CodonGTR; break;
		case Codon_PosRate: model_ = new CodonPosRate; break;
		case Codon_NY98:
			model_ = new CodonNY98(myrate.nsSy_classes, myrate.nsSy_ratio_type, 
				myrate.nsSy_ratio_val, myrate.nsSy_prob_val);
			break;
		default: break;
		}
	} else {
		model_ = new AminoM;
		nState_ = NUM_AMINO;
	}

	model_->setModel (modelType_);
}

//----------------------------------------------------------------------------------------------------------------------------------------
/**
	@return non-zero when the current model has parameters to estimate:
	HKY85 with an estimated tsTv ratio, TN93 with an estimated tsTv or pyPu
	ratio, GTR with estimated general parameters, or any codon model.
*/
int Model::isEstedPam () {
	bool nucleotidePamEstimated = false;
	switch (modelType_) {
	case HKY85:
		nucleotidePamEstimated = (tsTvRatioType_ == ESTIMATE);
		break;
	case TN93:
		nucleotidePamEstimated = (tsTvRatioType_ == ESTIMATE || pyPuRatioType_ == ESTIMATE);
		break;
	case GTR:
		nucleotidePamEstimated = (genPamType_ == ESTIMATE);
		break;
	default:
		break;
	}
	return nucleotidePamEstimated || isCodonModel(modelType_);
}


//----------------------------------------------------------------------------------------------------------------------------------------
//set the tsTvRatio
void Model::setTsTvRatio (const double tsTvRatio) {
	tsTvRatio_ = tsTvRatio;
}

//----------------------------------------------------------------------------------------------------------------------------------------
//set the tsTvRatio
void Model::setTsTvRatioType (const PAM_TYPE tsTvRatioType) {
	tsTvRatioType_ = tsTvRatioType;
}

//----------------------------------------------------------------------------------------------------------------------------------------
//set the pyPuRatio
void Model::setPyPuRatio (const double pyPuRatio) {
	pyPuRatio_ = pyPuRatio;
}

/*
void Model::setVParm (const double vp) {
	v_parm = vp;
}
*/

//----------------------------------------------------------------------------------------------------------------------------------------
//set the pyPuRatioType
void Model::setPyPuRatioType (const PAM_TYPE pyPuRatioType) {
	pyPuRatioType_ = pyPuRatioType;
}



//----------------------------------------------------------------------------------------------------------------------------------------
void Model::setGenPam (const double tsAG, const double tsCT,
                       const double tvAC, const double tvAT, const double tvCG, const double tvGT) {
	tsAG_  = tsAG;
	tsCT_  = tsCT;
	tvAC_  = tvAC;
	tvAT_  = tvAT;
	tvCG_  = tvCG;
	tvGT_  = tvGT;
}

//----------------------------------------------------------------------------------------------------------------------------------------
void Model::setGenPamType (const PAM_TYPE genPamType) {
	genPamType_ = genPamType;
}


//--------------------------------------------------------------------

/**
	Prepare the substitution model before use.

	NUCLEOTIDE data:
	- Base frequencies are estimated from the alignment when baseFrqType_ is ESTIMATE.
	- When isContinuous == 0, estimated parameters get starting values:
	  tsTvRatio 4.0, pyPuRatio 1.0, and GTR rates estimated from the alignment.
	- All values are then handed to model_.

	CODON data:
	- Codon frequencies always come from the alignment.
	- A negative tsTvRatio that is to be estimated becomes 4.0 when isContinuous == 0.

	Afterwards model_->init() runs. For AMINO_ACID data the state frequencies
	are then copied back from model_.
	Requires setModelType() to have been called (model_ is dereferenced).
*/
void Model::init () {
	switch (dataType_) {
	case NUCLEOTIDE:
		if (baseFrqType_ == ESTIMATE)
			alignment.estimateStateFrq (stateFrqArr_);
		model_->setStateFrq (stateFrqArr_);

		if (isContinuous == 0 && tsTvRatioType_ == ESTIMATE)
			tsTvRatio_ = 4.0;
		model_->setTsTvRatio (tsTvRatio_);

		if (isContinuous == 0 && pyPuRatioType_ == ESTIMATE)
			pyPuRatio_ = 1.0;
		model_->setPyPuRatio (pyPuRatio_);

		if (isContinuous == 0 && genPamType_ == ESTIMATE && modelType_ == GTR)
			alignment.estimateGenPam (tsAG_, tsCT_, tvAC_, tvAT_, tvCG_, tvGT_);
		model_->setGenPam (tsAG_, tsCT_, tvAC_, tvAT_, tvCG_, tvGT_);
		break;

	case CODON:
		alignment.estimateCodonFrq (stateFrqArr_);
		model_->setStateFrq (stateFrqArr_);
		// only a negative (unset) ratio is replaced, so a user-supplied value survives
		if (isContinuous == 0 && getTsTvRatioType () == ESTIMATE && tsTvRatio_ < 0)
			tsTvRatio_ = 4.0;
		model_->setTsTvRatio (tsTvRatio_);
		break;

	default:
		break;
	}

	model_->init ();

	// amino-acid frequencies are owned by the model object, so read them back
	if (dataType_ == AMINO_ACID)
		model_->getStateFrq (stateFrqArr_);
}

//================================================================
/** @return the transition/transversion ratio currently held by model_ (not the local tsTvRatio_ copy) */
double Model::getTsTvRatio () {
	const double currentRatio = model_->getTsTvRatio ();
	return currentRatio;
}

//================================================================
/** @return the pyPu ratio currently held by model_ (not the local pyPuRatio_ copy) */
double Model::getPyPuRatio () {
	const double currentRatio = model_->getPyPuRatio ();
	return currentRatio;
}

//================================================================
void Model::getGenPam (double &tsAG, double & tsCT,
                       double & tvAC, double & tvAT, double & tvCG, double & tvGT) {

	tsAG  = tsAG_;
	tsCT  = tsCT_;
	tvAC  = tvAC_;
	tvAT  = tvAT_;
	tvCG  = tvCG_;
	tvGT  = tvGT_;
}

//--------------------------------------------------------------------
/**
	Optimize the parameters of the sequence evolution model:
	1. the tsTv ratio (and, for TN93, the pyPu ratio)
	2. the general model parameters
	TN93 always delegates to model_, telling it which of its two ratios are estimated.
	Other models are optimized only when isEstedPam() reports something to estimate.
	@return the value returned by model_->optPam(), or 0 when nothing is optimized
*/
int Model::optPam () {
	if (modelType_ == TN93)
		return model_->optPam (tsTvRatioType_, pyPuRatioType_);

	return isEstedPam () ? model_->optPam () : 0;
}


//--------------------------------------------------------------------
//compute the probability mat
void Model::cmpProbChange (const double brLen, DMat20 &probMat) {
	model_->cmpProbChange (brLen, probMat);
}

//compute the probability mat, for codon-based Nielsen Yang 98 Model
void Model::cmpProbChange (const double brLen, const int classNo, DMat20 &probMat) {
	model_->cmpProbChange (brLen, classNo, probMat);
}

//--------------------------------------------------------------------
/*compute the probability of changing from stateNo1 into stateNo2
after a period of brLen / subRate_
*/

double Model::cmpProbChange (const int stateNo1, const int stateNo2, const double brLen) {
	return model_->cmpProbChange (stateNo1, stateNo2, brLen);
}


void Model::cmpProbChangeDerivatives (const double brLen, DMat20 &probMat, DMat20 &derv1, DMat20 &derv2) {
	model_->cmpProbChangeDerivatives (brLen, probMat, derv1, derv2);
}

void Model::cmpProbChangeDerivatives (const double brLen, int classNo, DMat20 &probMat,
                                      DMat20 &derv1, DMat20 &derv2) {
	model_->cmpProbChangeDerivatives (brLen, classNo, probMat, derv1, derv2);
}



//--------------------------------------------------------------------
/** copy the first nState_ state frequencies into stateFrqArr */
void Model::getStateFrq (DVec20 &stateFrqArr) {
	int state = 0;
	while (state < nState_) {
		stateFrqArr[state] = stateFrqArr_[state];
		++state;
	}
}

//--------------------------------------------------------------------
/** @return the substitution model type (same value as getType()) */
MODEL Model::getModelType () {
	return this->modelType_;
}



//--------------------------------------------------------------------
//return the state frequence
double Model::getStateFrq (const int stateNo) {
	return stateFrqArr_[stateNo];
}

/**
	set the iteration step
*/
void Model::setPamStep(int step) {
	nPamStep = step;
}

/**
	get the iteration step
*/
int Model::getPamStep() {
	return nPamStep;
}


//--------------------------------------------------------------------
/**
	delete the owned substitution model object, if any, and reset model_ to NULL
	so a second call is harmless
*/
void Model::release () {
	if (model_ == NULL)
		return;

	delete model_;
	model_ = NULL;
}

//--------------------------------------------------------------------
/** destructor: frees the owned substitution model object via release() */
Model::~Model () {
	this->release ();
}
