/* Copyright (C) 2008 Xavier Pujol.

This file is part of the fplll Library.

The fplll Library is free software; you can redistribute it and/or modify
it under the terms of the GNU Lesser General Public License as published by
the Free Software Foundation; either version 2.1 of the License, or (at your
option) any later version.

The fplll Library is distributed in the hope that it will be useful, but
WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU Lesser General Public
License for more details.

You should have received a copy of the GNU Lesser General Public License
along with the fplll Library; see the file COPYING.  If not, write to
the Free Software Foundation, Inc., 59 Temple Place - Suite 330, Boston,
MA 02111-1307, USA. */

#include "tools.h"

/* Applies Minkowski's theorem / Hermite's bound to the lattice L:
       lambda_1(L)^2 <= gamma_d * det(L)^(2/d)
   where lambda_1(L)^2 is the squared length of a shortest nonzero vector
   and gamma_d = hermite_constant[d].
   Input:  rdiag  - d values whose product is det(L)^2; presumably the
                    squared Gram-Schmidt norms ||b*_i||^2 as produced by
                    gramSchmidt (TODO confirm for every caller)
   Output: result - gamma_d * det(L)^(2/d) */

void minkowskiTheorem(Float& result, const FloatVect& rdiag) {
  unsigned int d = rdiag.size();
  Float logdetsq, rtmp1;
  logdetsq.set(0.0);

  // log(det(L)^2) = sum_i log(rdiag[i]); summing logs instead of
  // multiplying the raw values avoids overflow/underflow of the product.
  for (unsigned int i = 0; i < d; i++) {
    rtmp1.log(rdiag[i]);
    logdetsq.add(logdetsq, rtmp1);
  }

  rtmp1.set(d);
  logdetsq.div(logdetsq, rtmp1);    // log(det(L)^2) / d
  result.exponential(logdetsq);     // det(L)^(2/d)  (positive exponent)
  rtmp1.set(hermite_constant[d]);   // gamma_d
  result.mul(result, rtmp1);        // gamma_d * det(L)^(2/d)
}

/* Volume of the d-dimensional ball of radius 1:
       V_d = pi^(d/2) / Gamma(1 + d/2)
   With m = floor(d/2):
     - even d = 2m:     V_d = pi^m / m!
     - odd  d = 2m + 1: Gamma(m + 3/2) = sqrt(pi) * prod_{k=0..m} (2k+1)/2,
                        so V_d = pi^m * prod_{k=0..m} 2/(2k+1)
   Input:  d
   Output: volume */

static void sphereVolume(Float& volume, int d) {
  const int half = d / 2;
  Float factor;

  // Common factor pi^m for both parities.
  volume.set(pow(PI, static_cast<double>(half)));

  if (d % 2 != 0) {
    // Odd dimension: multiply by 2/1 * 2/3 * ... * 2/(2m+1).
    for (int k = 0; k <= half; ++k) {
      factor.set(2.0 / static_cast<double>(2 * k + 1));
      volume.mul(volume, factor);
    }
    return;
  }

  // Even dimension: divide by m! one factor at a time.
  for (int k = 1; k <= half; ++k) {
    factor.set(static_cast<unsigned int>(k));
    volume.div(volume, factor);
  }
}

/* Estimates the cost of an enumeration.
   For every contiguous index block [i, j] with 0 <= i <= j < dmax it
   evaluates
       V_{j-i+1} * sqrt( prod_{k=i..j} A / rdiag[k] ),
   where V_m is the volume of the m-dimensional unit ball (sphereVolume).
   It returns cost_factor (= 20) times the largest such value.
   Input:  A     - presumably the squared enumeration radius
           rdiag - presumably the squared Gram-Schmidt norms; must hold at
                   least dmax entries
           dmax  - number of leading entries of rdiag taken into account
   Output: cost  - if dmax <= 0 no block is evaluated and cost is
                   cost_factor * (-1) = -20 (the initial sentinel). */

void costEstimate(Float& cost, const Float& A, const FloatVect& rdiag, int dmax) {
  const double cost_factor = 20.0;
  Float prod, det, maximum, rtmp1;

  BOUND_CHECK(dmax <= (int) rdiag.size(),
    "(costEstimate) dmax exceeds rdiag size");

  maximum.set(-1.0);

  for (int i = 0; i < dmax; i++) {
    // prod holds prod_{k=i..j} A / rdiag[k] and is extended by one factor
    // per j. The factors are multiplied in the same left-to-right order as
    // a from-scratch product, so the value is identical, but each (i, j)
    // costs O(1) instead of O(j - i).
    prod.set(1.0);
    for (int j = i; j < dmax; j++) {
      rtmp1.div(A, rdiag[j]);
      prod.mul(prod, rtmp1);

      det.sqrt(prod);
      sphereVolume(rtmp1, j - i + 1);
      det.mul(det, rtmp1);

      if (maximum <= det) maximum.set(det);
    }
  }

  rtmp1.set(cost_factor);
  cost.mul(maximum, rtmp1);
}

void gramSchmidt(const FloatMatrix& b, FloatMatrix& mu,
  FloatVect& rdiag, Float& basisMin) {

  int i, j, k, d = b.GetNumRows(), n = b.GetNumCols();
  FloatMatrix r(d, d);
  Float coeff;

  BOUND_CHECK(mu.GetNumRows() == d && mu.GetNumCols() == d,
    "(gramSchmidt) Invalid matrix size");
  if ((int) rdiag.size() != d) rdiag.resize(d);

  for (i = 0; i < d; i++) {
    for (j = 0; j <= i; j++) {
      coeff.set(0.0);
      for (k = 0; k < n; k++)
        coeff.addmul(b(i, k), b(j, k));
      if (i == j && (i == 0 || coeff <= basisMin))
        basisMin.set(coeff);
      for (k = 0; k < j; k++)
        coeff.submul(r(j, k), mu(i, k));
      r(i, j).set(coeff);
      mu(i, j).div(coeff, r(j, j));
    }
    rdiag[i].set(r(i, i));
  }
}

void gramSchmidt(const FloatMatrix& b, FloatMatrix& mu, FloatVect& rdiag) {
  Float basisMin;
  gramSchmidt(b, mu, rdiag, basisMin);
}

void getGSCoords(const FloatMatrix& matrix, const FloatMatrix& mu,
  const FloatVect& rdiag, const FloatVect& v, FloatVect& vcoord) {

  int d = matrix.GetNumRows();

  if (d != (int) vcoord.size()) vcoord.resize(d);
  BOUND_CHECK(mu.GetNumRows() == d && mu.GetNumCols() == d &&
    d == (int) v.size() && d == (int) rdiag.size(),
    "(getGSCoords) Incompatible sizes");

  for (int i = 0; i < d; i++) {
    vcoord[i].set(0.0);
    for (int j = 0; j < (int) v.size(); j++)
      vcoord[i].addmul(v[j], matrix(i, j));
    for (int j = 0; j < i; j++)
      vcoord[i].submul(mu(i, j), vcoord[j]);
  }
  for (int i = 0; i < d; i++)
    vcoord[i].div(vcoord[i], rdiag[i]);
}

void babai(const FloatMatrix& matrix, const FloatMatrix& mu,
  const FloatVect& rdiag, const FloatVect& target, FloatVect& targetcoord) {

  int d = matrix.GetNumRows();
  getGSCoords(matrix, mu, rdiag, target, targetcoord);
  for (int i = d - 1; i >= 0; i--) {
    targetcoord[i].rnd(targetcoord[i]);
    for (int j = 0; j < i; j++)
      targetcoord[j].submul(mu(i, j), targetcoord[i]);
  }
}
