/* Copyright (C) 2008 Xavier Pujol.

This file is part of the fplll Library.

The fplll Library is free software; you can redistribute it and/or modify
it under the terms of the GNU Lesser General Public License as published by
the Free Software Foundation; either version 2.1 of the License, or (at your
option) any later version.

The fplll Library is distributed in the hope that it will be useful, but
WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU Lesser General Public
License for more details.

You should have received a copy of the GNU Lesser General Public License
along with the fplll Library; see the file COPYING.  If not, write to
the Free Software Foundation, Inc., 59 Temple Place - Suite 330, Boston,
MA 02111-1307, USA. */

#include "solver.h"
#include "enumerate.h"
#include "topenum.h"
#include "tools.h"

/* Heuristic sanity check on the Gram-Schmidt diagonal values rdiag.
   For every pair of consecutive entries, the quotient rdiag[k] / rdiag[k + 1]
   is compared with 2.0. The first quotient that reaches 2.0 triggers a
   two-line warning on cerr, and the function then returns 1.
   Vectors with 5 entries or fewer are never checked.
   Returns 1 if the warning was printed, 0 otherwise. */
static int checkRDiag(FloatVect& rdiag) {
  Float limit, quotient;
  limit.set(2.0);

  // Small bases are skipped entirely.
  if (rdiag.size() <= 5)
    return 0;

  // k + 1 < size() visits exactly the indices 0 .. size() - 2.
  for (unsigned int k = 0; k + 1 < rdiag.size(); ++k) {
    quotient.div(rdiag[k], rdiag[k + 1]);
    if (quotient >= limit) {
      cerr << "Warning: matrix is not LLL-reduced or vectors are not linearly independant" << endl
           << "Warning: enumeration might be longer than expected" << endl;
      return 1;
    }
  }
  return 0;
}

/* Solves SVP (when intTarget is empty) or CVP (otherwise) by enumeration
   on the lattice spanned by the rows of intMatrix.

   intMatrix: d x n integer basis, one basis vector per row; d <= n is
              required (checked with STD_CHECK).
   intTarget: target vector with n entries for CVP, or empty for SVP.
   solCoord:  output. It is reset with zeroVect(solCoord, d) and receives the
              integer coefficients of the solution with respect to the rows
              of intMatrix.

   Side effects: changes the global Float precision via Float::setprec, and
   sets the members validMaxError, maxError and certifiedSol. Warnings and
   (when verbose) progress are written to cerr. */
void Solver::solveCVP(const IntMatrix& intMatrix,
  const IntVect& intTarget, IntVect& solCoord) {

  int d = intMatrix.GetNumRows(), n = intMatrix.GetNumCols(), prec = 53;
  double minprec, ulp;

  STD_CHECK(d <= n, "number of vectors > size of the vectors");

  // Sets the floating-point precision. minprec is the precision that the
  // error bounds below rely on (see the STD_CHECK and the `prec >= minprec`
  // test). AUTOMATIC_PRECISION adds a 10-bit margin on top of it.
  minprec = 2.0 + (log(static_cast<double>(d)) - log(SLV_EPS) +
            d * log(SLV_RHO)) / log(2.0);
  if (precision == AUTOMATIC_PRECISION)
    prec = max(prec, static_cast<int>(floor(minprec)) + 1 + 10);
  else if (precision != DEF_PRECISION)
    prec = max(prec, precision);
  Float::setprec(prec);
  ulp = pow(2.0, -prec);

  INFO("precision=" << prec << " minprec=" << minprec);
  STD_CHECK(evaluatorType == evalFast || prec >= minprec,
          "(Solver) Not enough precision, try with precision>="
          << ceil(minprec));

  // Allocates space for vectors and matrices in constructors
  IntVect intNewTarget;
  FloatMatrix floatMatrix(d, n), mu(d, d);
  FloatVect rdiag, target, targetCoord;
  Float basisMin, maxDist, invNormFactor, normFactor;
  Integer itmp1;
  Float ftmp1;

  // Computes the Gram-Schmidt orthogonalization in floating-point
  for (int i = 0; i < d; i++)
    for (int j = 0; j < n; j++)
      floatMatrix(i, j).set_z(intMatrix(i, j));
  gramSchmidt(floatMatrix, mu, rdiag, basisMin);
  checkRDiag(rdiag);
  zeroVect(solCoord, d);

  if (intTarget.empty()) {
    // Computes a bound for the enumeration
    // disabled - no error analysis yet
    /* minkowskiTheorem(maxDist, rdiag);
    if (basisMin < maxDist) maxDist.set(basisMin); */
    ftmp1.set(1.0001);
    maxDist.mul(basisMin, ftmp1);
  }
  else {
    // Applies Babai's algorithm repeatedly. Each pass subtracts the rounded
    // combination from intNewTarget and accumulates its coefficients in
    // solCoord. The loop stops once every coefficient returned by babai()
    // lies in [-1, 1]. The repetition is presumably needed because a single
    // floating-point pass can be inaccurate when the target is much longer
    // than the basis vectors.
    FloatVect babaiSol;
    intNewTarget = intTarget;
    target.resize(n);

    for (int loopIdx = 0;; loopIdx++) {
      // Warn at every power of two from 256 upwards.
      if (loopIdx >= 0x100 && ((loopIdx & (loopIdx - 1)) == 0))
        cerr << "Warning: possible infinite loop in Babai's algorithm" << endl;

      for (int i = 0; i < n; i++) {
        target[i].set_z(intNewTarget[i]);
      }
      babai(floatMatrix, mu, rdiag, target, babaiSol);
      int idx;
      for (idx = 0; idx < d && babaiSol[idx] >= -1 && babaiSol[idx] <= 1; idx++);
      if (idx == d) break;

      for (int i = 0; i < d; i++) {
        itmp1.set_f(babaiSol[i]);
        solCoord[i].add(solCoord[i], itmp1);
        for (int j = 0; j < n; j++)
          intNewTarget[j].submul(itmp1, intMatrix(i, j));
      }
    }
    INFO("BabaiSol=" << solCoord);
    // targetCoord describes the reduced target intNewTarget, not intTarget.
    getGSCoords(floatMatrix, mu, rdiag, target, targetCoord);

    /* Computes a very large bound to make the algorithm work
       until the first solution is found */
    maxDist.set(0.0);
    for (int i = 1; i < d; i++)
      maxDist.add(maxDist, rdiag[i]);
  }

  invNormFactor.set(1.0);
  normFactor.set(1.0);

  Enumerator enumerator(mu, rdiag, targetCoord, maxVolume, minLevel);
  // Initialized so that an unexpected evaluatorType can never leave an
  // indeterminate pointer behind if STD_CHECK does not terminate.
  Evaluator* evaluator = NULL;

  switch (evaluatorType) {
    case evalFast:
      evaluator = new Evaluator(mu, rdiag, targetCoord);
      break;
    case evalSmart:
      evaluator = new SmartEvaluator(mu, rdiag, targetCoord);
      break;
    case evalExact:
      // The enumeration works on targetCoord, which is derived from the
      // Babai-reduced target. Its solution coordinates are also an offset
      // that is added to the Babai coefficients at the end of this
      // function. The exact evaluator therefore needs that same reduced
      // target, intNewTarget. For SVP, intNewTarget is empty, just like
      // intTarget.
      evaluator = new ExactEvaluator(intMatrix, intNewTarget, normFactor,
        mu, rdiag, targetCoord);
      break;
    default:
      STD_CHECK(false, "(Solver) Invalid evaluator type " << evaluatorType);
      break;
  }

  // With enough precision, give the evaluator bounds on the floating-point
  // errors in rdiag and mu. These bounds grow with SLV_RHO^i and ulp.
  if (prec >= minprec) {
    evaluator->inputErrorDefined = true;
    for (int i = 0; i < d; i++) {
      ftmp1.set(d * 16.0 * pow(SLV_RHO, i) * ulp);
      ftmp1.mul(ftmp1, rdiag[i]);
      evaluator->maxDRdiag[i].set(ftmp1);
      ftmp1.set(d * 64.0 * pow(SLV_RHO, i) * ulp);
      evaluator->maxDMu[i].set(ftmp1);
    }
  }

  // Main loop of the enumeration
  while (enumerator.enumNext(maxDist)) {
    INFO("Subtree=" << enumerator.getSubTree());
    if (verbose) cerr << enumerator.getSubTree();

    evaluator->newSolFlag = false;

    /* Enumerates short vectors only in enumerator.getSubTree()
       (about maxVolume iterations or less) */
    enumerate(mu, rdiag, maxDist, *evaluator,
      targetCoord, enumerator.getSubTree());

    if (verbose) {
      ftmp1.set(evaluator->lastPartialDist);
      ftmp1.mul(ftmp1, invNormFactor);
      cerr << "\r                                                \r";
      if (evaluator->newSolFlag)
        cerr << "Solution norm^2=" << ftmp1
             << " value=" << evaluator->solCoord << endl;
    }
  }

  INTERNAL_CHECK(!evaluator->solCoord.empty(), "(Solver) No solution found");
  if (intTarget.empty()) {
    validMaxError = evaluator->getMaxError(ftmp1);
    maxError = ftmp1.get_d(GMP_RNDU);
  }
  else {
    validMaxError = false;
    maxError = -1.0;
  }
  certifiedSol = evaluator->certifiedSol();
  // The enumerated coordinates are relative to the Babai solution (zero for
  // SVP), so add them to the coefficients already in solCoord.
  for (int i = 0; i < d; i++) {
    itmp1.set_f(evaluator->solCoord[i]);
    solCoord[i].add(solCoord[i], itmp1);
  }

  delete evaluator;
}
