/*
  This file is part of CDO. CDO is a collection of Operators to manipulate and analyse Climate model Data.

  Author: Uwe Schulzweida

*/

#include <cdi.h>

#include "cdo_math.h"
#include "remap.h"
#include <mpim_grid.h>
#include "cdo_output.h" 

/*
  Collect the linear indices (j * nx + i) of the block of source cells centred on cell (ii, jj)
  of an nx x ny regular grid.

  Only k / 2 cells on each side of the centre are used, so an odd k yields (at most) a k x k block.
  Rows are clipped to [0, ny - 1]. Columns outside [0, nx - 1] wrap around when isCyclic is set and
  are dropped otherwise. If the requested column window is wider than the grid (more than nx
  columns), the whole row [0, nx - 1] is used instead, so no column index is ever emitted twice.

  psrcIndices must have room for at least k * k entries (for odd k).
  Returns the number of indices written to psrcIndices.
*/
static size_t
fill_src_indices(bool isCyclic, long nx, long ny, long ii, long jj, long k, size_t *psrcIndices)
{
  k /= 2;

  auto j0 = jj - k;
  auto jn = jj + k;
  auto i0 = ii - k;
  auto in = ii + k;
  if (j0 < 0) j0 = 0;
  if (jn >= ny) jn = ny - 1;
  // [i0, in] spans (in - i0 + 1) columns. As soon as that exceeds nx, a cyclic wrap would revisit a
  // column (the former test `> nx` duplicated one column when in - i0 == nx), so take the full row.
  if ((in - i0) >= nx)
    {
      i0 = 0;
      in = nx - 1;
    }

  size_t numIndices = 0;

  for (long j = j0; j <= jn; ++j)
    for (long i = i0; i <= in; ++i)
      {
        auto ix = i;
        if (isCyclic && ix < 0) ix += nx;
        if (isCyclic && ix >= nx) ix -= nx;
        if (ix >= 0 && ix < nx && j < ny) psrcIndices[numIndices++] = j * nx + ix;
      }

  return numIndices;
}

static void
store_distance_healpix(GridPointSearch &gps, double plon, double plat, knnWeightsType &knnWeights, size_t numIndices, size_t *indices, double *srcLons, double *srcLats)
{
  /*
    Distance filter for HEALPix candidates.

    The search point and every candidate (srcLons[n], srcLats[n]) are converted with gcLLtoXYZ.
    Candidates whose squareDistance to the search point is <= gps.searchRadius^2 are passed to
    knnWeights.storeDistance(indices[n], sqrt(squared distance)). knnWeights.checkDistance() is
    called once after all numIndices candidates have been examined.
  */
  double targetXYZ[3];
  gcLLtoXYZ(plon, plat, targetXYZ);
  const auto radiusSqr = cdo::sqr(gps.searchRadius);

  for (size_t n = 0; n < numIndices; ++n)
    {
      double candXYZ[3];
      gcLLtoXYZ(srcLons[n], srcLats[n], candXYZ);

      // The squared distance is deliberately rounded to float precision; presumably this makes
      // near-ties between candidates resolve reproducibly (NOTE(review): confirm intent).
      const double sqrDist = static_cast<float>(squareDistance(targetXYZ, candXYZ));
      if (sqrDist <= radiusSqr) knnWeights.storeDistance(indices[n], std::sqrt(sqrDist));
    }

  knnWeights.checkDistance();
}

static void
store_distance_reg2d(GridPointSearch &gps, double plon, double plat, knnWeightsType &knnWeights, size_t nx, size_t numIndices,
                     size_t *psrcIndices)
{
  /*
    Distance filter for candidates on a regular 2D grid.

    Each candidate's linear index (row * nx + column) is split into row and column. Its Cartesian
    position is built from the precomputed tables in gps (coslat/sinlat per row, coslon/sinlon per
    column). Candidates whose squareDistance to the search point is <= gps.searchRadius^2 are passed
    to knnWeights.storeDistance. knnWeights.checkDistance() is called once at the end.
  */
  const auto &cosLon = gps.coslon;
  const auto &sinLon = gps.sinlon;
  const auto &cosLat = gps.coslat;
  const auto &sinLat = gps.sinlat;

  double targetXYZ[3];
  gcLLtoXYZ(plon, plat, targetXYZ);
  const auto radiusSqr = cdo::sqr(gps.searchRadius);

  for (size_t n = 0; n < numIndices; ++n)
    {
      const auto cell = psrcIndices[n];
      const auto row = cell / nx;
      const auto col = cell % nx;

      double cellXYZ[3];
      cellXYZ[0] = cosLat[row] * cosLon[col];
      cellXYZ[1] = cosLat[row] * sinLon[col];
      cellXYZ[2] = sinLat[row];

      // Rounded to float precision on purpose, matching the HEALPix path (NOTE(review): confirm intent).
      const double sqrDist = static_cast<float>(squareDistance(targetXYZ, cellXYZ));
      if (sqrDist <= radiusSqr) knnWeights.storeDistance(cell, std::sqrt(sqrDist));
    }

  knnWeights.checkDistance();
}

static void
gridSearchPointHealpix(GridPointSearch &gps, double plon, double plat, knnWeightsType &knnWeights)
{
  /*
    Nearest-neighbour search on a HEALPix source grid.

    The candidates are the HEALPix cell containing (plon, plat) followed by the neighbours reported
    by hp_get_neighbours (negative entries are skipped), so at most 9 points are examined. Their
    centre coordinates come from hp_index_to_lonlat, and store_distance_healpix filters and stores
    them.

    Input:  plon, plat : longitude / latitude of the search point
    Output: knnWeights.m_indices, knnWeights.m_dist : indices of / distances to the closest points

    Aborts when more than 9 neighbours are requested, because no more candidates are considered.
  */
  const auto maxNeighbors = knnWeights.maxNeighbors();
  if (maxNeighbors > 9) cdo_abort("Max number of neighbors is 9!");

  // Reset the output arrays before collecting candidates.
  knnWeights.initIndices();
  knnWeights.initDist();

  auto cellIndex = hp_lonlat_to_index(gps.order, gps.nside, plon, plat);

  int64_t adjacent[8];
  hp_get_neighbours(gps.order, gps.nside, cellIndex, adjacent);

  size_t candidates[9];
  size_t numCandidates = 0;
  candidates[numCandidates++] = cellIndex;
  for (const auto cell : adjacent)
    if (cell >= 0) candidates[numCandidates++] = cell;

  double candLons[9], candLats[9];
  for (size_t n = 0; n < numCandidates; ++n) hp_index_to_lonlat(gps.order, gps.nside, candidates[n], &candLons[n], &candLats[n]);

  store_distance_healpix(gps, plon, plat, knnWeights, numCandidates, candidates, candLons, candLats);
}

/*
  Nearest-neighbour search on a regular 2D (lon/lat) source grid.

  First, the search longitude is shifted by +/- PI2 if it lies outside the range of the source
  cell-centre longitudes. Then rect_grid_search locates the enclosing cell (ii, jj). If it is found,
  a box of cells around it is gathered with fill_src_indices. The box width k is the smallest odd
  k >= 3 with (k - 2)^2 >= numNeighbors. store_distance_reg2d then filters and stores the candidates.

  If the point is not enclosed and gps.extrapolate is set, grid_search_square_reg_2d_NN is used
  instead. A negative return value keeps the neighbours it produced; any other value resets the
  first numNeighbors index slots to SIZE_MAX.

  Input:  plon, plat : longitude / latitude of the search point
  Output: knnWeights.m_indices, knnWeights.m_dist : indices of / distances to the closest points
*/
static void
gridSearchPointReg2d(GridPointSearch &gps, double plon, double plat, knnWeightsType &knnWeights)
{
  const auto numNeighbors = knnWeights.maxNeighbors();

  // Reset the output arrays before searching.
  knnWeights.initIndices();
  knnWeights.initDist();

  const auto &centerLons = gps.reg2d_center_lon;
  const auto &centerLats = gps.reg2d_center_lat;

  long nx = gps.dims[0];
  long ny = gps.dims[1];
  // Cyclic grids are searched with one extra (wrap-around) longitude column.
  size_t nxm = gps.isCyclic ? nx + 1 : nx;

  if (plon < centerLons[0]) plon += PI2;
  if (plon > centerLons[nxm - 1]) plon -= PI2;

  size_t ii, jj;
  if (rect_grid_search(ii, jj, plon, plat, nxm, ny, centerLons, centerLats))
    {
      // The wrap-around column is the same as column 0.
      if (gps.isCyclic && ii == (nxm - 1)) ii = 0;

      size_t k = 3;
      while (k < 10000 && numNeighbors > (k - 2) * (k - 2)) k += 2;

      // Use a small stack buffer when the box fits, otherwise fall back to a heap buffer.
      constexpr size_t MAX_SEARCH_CELLS = 25;
      size_t stackIndices[MAX_SEARCH_CELLS];
      std::vector<size_t> heapIndices;
      size_t *boxIndices = stackIndices;
      if (k * k > MAX_SEARCH_CELLS)
        {
          heapIndices.resize(k * k);
          boxIndices = heapIndices.data();
        }

      auto numBoxIndices = fill_src_indices(gps.isCyclic, nx, ny, ii, jj, k, boxIndices);
      store_distance_reg2d(gps, plon, plat, knnWeights, nx, numBoxIndices, boxIndices);
      return;
    }

  if (!gps.extrapolate) return;

  auto &nbrIndices = knnWeights.m_indices;
  auto &nbrDistance = knnWeights.m_dist;

  int searchResult;
  if (numNeighbors >= 4)
    {
      searchResult
          = grid_search_square_reg_2d_NN(nx, ny, nbrIndices.data(), nbrDistance.data(), plat, plon, centerLats, centerLons);
    }
  else
    {
      // Presumably the square search writes four entries; collect them in scratch space and keep
      // only the first numNeighbors (NOTE(review): confirm against grid_search_square_reg_2d_NN).
      size_t quadIndices[4];
      double quadDistance[4];
      for (size_t n = 0; n < numNeighbors; ++n) quadIndices[n] = SIZE_MAX;
      searchResult = grid_search_square_reg_2d_NN(nx, ny, quadIndices, quadDistance, plat, plon, centerLats, centerLons);
      if (searchResult < 0)
        for (size_t n = 0; n < numNeighbors; ++n)
          {
            nbrIndices[n] = quadIndices[n];
            nbrDistance[n] = quadDistance[n];
          }
    }

  if (searchResult >= 0)
    for (size_t n = 0; n < numNeighbors; ++n) nbrIndices[n] = SIZE_MAX;
}

void
grid_search_point(GridPointSearch &gps, double plon, double plat, knnWeightsType &knnWeights)
{
  /*
    Generic nearest-neighbour search through the point-search structure gps.

    More candidates than requested are fetched: 2 * numNeighbors, or numNeighbors + 8 when more than
    8 neighbours are requested, capped at gps.n. This lets equidistant points be resolved in favour
    of the smaller index. All candidates are passed to knnWeights.storeDistance, followed by
    knnWeights.checkDistance().

    Input variables:

      plat : latitude  of the search point
      plon : longitude of the search point

    Output variables:

      knnWeights.m_indices[numNeighbors] : index of each of the closest points
      knnWeights.m_dist[numNeighbors] : distance to each of the closest points
  */
  auto numNeighbors = knnWeights.maxNeighbors();

  // check some more points if distance is the same use the smaller index
  auto ndist = (numNeighbors > 8) ? numNeighbors + 8 : numNeighbors * 2;
  if (ndist > gps.n) ndist = gps.n;

  // The scratch buffers live in knnWeights and are reused across calls. Grow them whenever this call
  // needs more room: resizing only when empty would let a later, larger request overrun them.
  if (knnWeights.m_tmpIndices.size() < ndist) knnWeights.m_tmpIndices.resize(ndist);
  if (knnWeights.m_tmpDist.size() < ndist) knnWeights.m_tmpDist.resize(ndist);
  auto &indices = knnWeights.m_tmpIndices;
  auto &dist = knnWeights.m_tmpDist;

  size_t numIndices = 0;
  if (numNeighbors == 1)
    numIndices = grid_point_search_nearest(gps, plon, plat, indices.data(), dist.data());
  else
    numIndices = grid_point_search_qnearest(gps, plon, plat, ndist, indices.data(), dist.data());

  // Fewer points may be found than requested (e.g. small grid or limited search radius).
  ndist = numIndices;
  if (ndist < numNeighbors) numNeighbors = ndist;

  // Initialize distance and index arrays
  knnWeights.initIndices();
  knnWeights.initDist();
  for (size_t i = 0; i < ndist; ++i) knnWeights.storeDistance(indices[i], dist[i], numNeighbors);

  knnWeights.checkDistance();
}

void
grid_search_point_smooth(GridPointSearch &gps, double plon, double plat, knnWeightsType &knnWeights)
{
  /*
    Nearest-neighbour search variant used for smoothing.

    For up to 32 requested neighbours it behaves like grid_search_point: extra candidates are
    fetched and passed through knnWeights.storeDistance. For more than 32 neighbours, exactly
    numNeighbors candidates (capped at gps.n) are fetched and copied directly into
    knnWeights.m_indices / m_dist, and m_numNeighbors is set to the number found.
    knnWeights.checkDistance() is called at the end in both cases.

    Input variables:

      plat : latitude  of the search point
      plon : longitude of the search point

    Output variables:

      knnWeights.m_indices[numNeighbors] : index of each of the closest points
      knnWeights.m_dist[numNeighbors] : distance to each of the closest points
  */
  auto numNeighbors = knnWeights.maxNeighbors();
  auto checkDistance = (numNeighbors <= 32);

  // check some more points if distance is the same use the smaller index
  auto ndist = checkDistance ? ((numNeighbors > 8) ? numNeighbors + 8 : numNeighbors * 2) : numNeighbors;
  if (ndist > gps.n) ndist = gps.n;

  // The scratch buffers live in knnWeights and are reused across calls. Grow them whenever this call
  // needs more room: resizing only when empty would let a later, larger request overrun them.
  if (knnWeights.m_tmpIndices.size() < ndist) knnWeights.m_tmpIndices.resize(ndist);
  if (knnWeights.m_tmpDist.size() < ndist) knnWeights.m_tmpDist.resize(ndist);
  auto &indices = knnWeights.m_tmpIndices;
  auto &dist = knnWeights.m_tmpDist;

  size_t numIndices = 0;
  if (numNeighbors == 1)
    numIndices = grid_point_search_nearest(gps, plon, plat, indices.data(), dist.data());
  else
    numIndices = grid_point_search_qnearest(gps, plon, plat, ndist, indices.data(), dist.data());

  ndist = numIndices;

  if (checkDistance)
    {
      if (ndist < numNeighbors) numNeighbors = ndist;

      // Initialize distance and index arrays
      knnWeights.initIndices(numNeighbors);
      knnWeights.initDist(numNeighbors);
      for (size_t i = 0; i < ndist; ++i) knnWeights.storeDistance(indices[i], dist[i], numNeighbors);
    }
  else
    {
      // Many neighbours: take the search result as is.
      knnWeights.m_numNeighbors = ndist;
      for (size_t i = 0; i < ndist; ++i) knnWeights.m_indices[i] = indices[i];
      for (size_t i = 0; i < ndist; ++i) knnWeights.m_dist[i] = dist[i];
    }

  knnWeights.checkDistance();
}

void
remap_search_points(RemapSearch &rsearch, const LonLatPoint &llpoint, knnWeightsType &knnWeights)
{
  // Route the nearest-neighbour search to the routine matching the source grid type:
  // HEALPix and regular 2D grids have dedicated searches, everything else uses the generic one.
  auto &gps = rsearch.gps;
  const auto lon = llpoint.lon;
  const auto lat = llpoint.lat;

  if (rsearch.srcGrid->type == RemapGridType::HealPix)
    {
      gridSearchPointHealpix(gps, lon, lat, knnWeights);
      return;
    }

  if (rsearch.srcGrid->type == RemapGridType::Reg2D)
    {
      gridSearchPointReg2d(gps, lon, lat, knnWeights);
      return;
    }

  grid_search_point(gps, lon, lat, knnWeights);
}

static int
gridSearchSquareCurv2d(GridPointSearch &gps, RemapGrid *rgrid, size_t (&srcIndices)[4], double (&srcLats)[4], double (&srcLons)[4],
                       double plat, double plon)
{
  /*
    Find the four source points of a curvilinear grid that enclose (plon, plat).

    Starting from the nearest source point (column ic, row jc), point_in_quad is tried with the
    (column, row) pairs (ic-1, jc-1), (ic, jc-1), (ic-1, jc), (ic, jc). At the western edge ic-1
    wraps to nx-1 on cyclic grids and clamps to 0 otherwise; at the southern edge jc-1 clamps to 0.
    On the first hit, srcIndices/srcLons/srcLats are as filled in by point_in_quad and 1 is returned.

    If no enclosing quad is found, the point most likely lies in a box straddling a pole or outside
    the grid:
      - without rgrid->doExtrapolate, 0 is returned;
      - otherwise the 4 nearest points are queried into srcIndices (distances into srcLats). If all
        four are found, srcLats is overwritten with normalized inverse-distance weights
        (1 / (d + TINY), scaled to sum to 1) and -1 is returned, so that the caller does not compute
        bilinear weights. If fewer are found, 0 is returned.
  */
  for (auto &srcIndex : srcIndices) srcIndex = 0;

  size_t nearestIndex = 0;
  double nearestDist = 0.0;
  if (grid_point_search_nearest(gps, plon, plat, &nearestIndex, &nearestDist) > 0)
    {
      auto nx = rgrid->dims[0];
      auto ny = rgrid->dims[1];

      const size_t jc = nearestIndex / nx;
      const size_t ic = nearestIndex - jc * nx;
      const size_t iw = (ic > 0) ? ic - 1 : (rgrid->isCyclic ? nx - 1 : 0);
      const size_t js = (jc > 0) ? jc - 1 : 0;

      const size_t colCandidates[4] = { iw, ic, iw, ic };
      const size_t rowCandidates[4] = { js, js, jc, jc };

      for (int k = 0; k < 4; ++k)
        if (point_in_quad(rgrid->isCyclic, nx, ny, colCandidates[k], rowCandidates[k], srcIndices, srcLons, srcLats, plon, plat,
                          rgrid->cell_center_lon.data(), rgrid->cell_center_lat.data()))
          return 1;
    }

  if (!rgrid->doExtrapolate) return 0;

  // srcLats is reused: first for the distances of the 4 nearest points, then for their weights.
  size_t numQuery = 4;
  if (grid_point_search_qnearest(gps, plon, plat, numQuery, srcIndices, srcLats) != 4) return 0;

  double weightSum = 0.0;
  for (auto &value : srcLats)
    {
      value = 1.0 / (value + TINY);
      weightSum += value;
    }
  for (auto &value : srcLats) value /= weightSum;

  return -1;
}

int
remap_search_square(RemapSearch &rsearch, const LonLatPoint &llpoint, size_t (&srcIndices)[4], double (&srcLats)[4],
                    double (&srcLons)[4])
{
  // Select the enclosing-quad search for the source grid:
  //  - regular 2D grids use the dedicated regular search,
  //  - otherwise the point-search structure is used when it is active,
  //  - and the bin-based SCRIP search is the fallback.
  auto &srcGrid = rsearch.srcGrid;
  const auto lat = llpoint.lat;
  const auto lon = llpoint.lon;

  if (srcGrid->type == RemapGridType::Reg2D) return grid_search_square_reg_2d(srcGrid, srcIndices, srcLats, srcLons, lat, lon);

  if (rsearch.gps.in_use) return gridSearchSquareCurv2d(rsearch.gps, srcGrid, srcIndices, srcLats, srcLons, lat, lon);

  return grid_search_square_curv_2d_scrip(srcGrid, srcIndices, srcLats, srcLons, lat, lon, rsearch.srcBins);
}
