/*
  This file is part of CDO. CDO is a collection of Operators to
  manipulate and analyse Climate model Data.

  Copyright (C) 2003-2020 Uwe Schulzweida, <uwe.schulzweida AT mpimet.mpg.de>
  See COPYING file for copying and redistribution conditions.

  This program is free software; you can redistribute it and/or modify
  it under the terms of the GNU General Public License as published by
  the Free Software Foundation; version 2 of the License.

  This program is distributed in the hope that it will be useful,
  but WITHOUT ANY WARRANTY; without even the implied warranty of
  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
  GNU General Public License for more details.
*/

/*
   This module contains the following operators:

      Smooth        smooth          Smooth grid points
      Smooth        smooth9         9 point smoothing
*/

#include <cdi.h>

#include "process_int.h"
#include "param_conversion.h"
#include "cdo_wtime.h"
#include <mpim_grid.h>
#include "constants.h"  // planet radius
#include "pmlist.h"
#include "cdo_options.h"
#include "progress.h"
#include "cimdOmp.h"
#include "grid_point_search.h"

// Identifiers of the supported weight forms; the value is kept in SmoothPoint::form.
enum
{
  FORM_LINEAR = 0
};

// Printable names of the weight forms, indexed by the FORM_* identifiers.
static const char *Form[] = {
  "linear",
};

struct SmoothPoint
{
  double arc_radius = 0.0;
  double radius = 1.0;
  double weight0 = 0.25;
  double weightR = 0.25;
  size_t maxpoints = SIZE_MAX;
  int form = FORM_LINEAR;
};

// Smooth one horizontal field with a weighted average over the neighbours
// returned by the grid point search.
//
//   gridID   grid of the field; a full point grid is generated from it to get
//            the cell center coordinates
//   missval  missing value of the field
//   array1   input field
//   array2   output field, every element is assigned
//   spoint   search radius, weights and maximum number of neighbours
//
// Returns the number of output points set to missval because the weight
// computation did not yield a single contributing neighbour.
template <typename T>
static size_t
smooth(int gridID, T missval, const Varray<T> &array1, Varray<T> &array2, const SmoothPoint &spoint)
{
  auto gridID0 = gridID;
  auto gridsize = gridInqSize(gridID);
  // Never ask for more neighbours than there are grid points
  auto numNeighbors = spoint.maxpoints;
  if (numNeighbors > gridsize) numNeighbors = gridsize;

  // mask[i] is true for points that carry a valid (non-missing) value
  Varray<uint8_t> mask(gridsize);
  for (size_t i = 0; i < gridsize; ++i) mask[i] = !DBL_IS_EQUAL(array1[i], missval);

  gridID = generate_full_point_grid(gridID);

  if (!gridHasCoordinates(gridID)) cdoAbort("Cell center coordinates missing!");

  Varray<double> xvals(gridsize), yvals(gridsize);
  gridInqXvals(gridID, xvals.data());
  gridInqYvals(gridID, yvals.data());

  // Convert lat/lon units if required
  cdo_grid_to_radian(gridID, CDI_XAXIS, gridsize, xvals.data(), "grid center lon");
  cdo_grid_to_radian(gridID, CDI_YAXIS, gridsize, yvals.data(), "grid center lat");

  // One weight container per thread, indexed by the thread number inside the loop
  std::vector<knnWeightsType> knnWeights;
  for (int i = 0; i < Threading::ompNumThreads; ++i) knnWeights.push_back(knnWeightsType(numNeighbors));

  auto start = Options::cdoVerbose ? cdo_get_wtime() : 0.0;

  bool xIsCyclic = false;
  size_t dims[2] = { gridsize, 0 };
  GridPointSearch gps;
  gridPointSearchCreate(gps, xIsCyclic, dims, gridsize, xvals, yvals);

  // An explicitly given arc radius takes precedence over the chord radius
  if (spoint.arc_radius > 0.0)
    gridPointSearchSetArcRadius(gps, spoint.arc_radius);
  else
    gridPointSearchSetChordRadius(gps, spoint.radius);

  if (Options::cdoVerbose) cdoPrint("Point search created: %.2f seconds", cdo_get_wtime() - start);

  // The progress indicator is only used in verbose mode
  if (Options::cdoVerbose) progress::init();

  start = Options::cdoVerbose ? cdo_get_wtime() : 0;

  size_t naddsMin = gridsize, naddsMax = 0;
  size_t nmissx = 0;
  double findex = 0;

#ifdef HAVE_OPENMP4
#pragma omp parallel for schedule(dynamic) default(none) reduction(+ : nmissx) reduction(min : naddsMin) reduction(max : naddsMax) \
  shared(findex, Options::cdoVerbose, knnWeights, spoint, mask, array1, array2, xvals, yvals, gps, gridsize, missval)
#endif
  for (size_t i = 0; i < gridsize; ++i)
    {
      const auto ompthID = cdo_omp_get_thread_num();

#ifdef _OPENMP
#pragma omp atomic
#endif
      findex++;
      // Only thread 0 reports the progress
      if (Options::cdoVerbose && ompthID == 0) progress::update(0, 1, findex / gridsize);

      gridSearchPointSmooth(gps, xvals[i], yvals[i], knnWeights[ompthID]);

      // Compute weights based on inverse distance if mask is false, eliminate those points
      const auto nadds = knnWeights[ompthID].computeWeights(mask, spoint.radius, spoint.weight0, spoint.weightR);
      naddsMin = std::min(naddsMin, nadds);
      naddsMax = std::max(naddsMax, nadds);
      if (nadds)
        {
          array2[i] = knnWeights[ompthID].arrayWeightsSum(array1);
        }
      else
        {
          // No contributing neighbour: the result is missing
          nmissx++;
          array2[i] = missval;
        }
    }

  // Finish the progress indicator under the same condition it was initialized with
  if (Options::cdoVerbose) progress::update(0, 1, 1);

  if (Options::cdoVerbose) cdoPrint("Point search nearest: %.2f seconds", cdo_get_wtime() - start);
  if (Options::cdoVerbose) cdoPrint("Min/Max points found: %zu/%zu", naddsMin, naddsMax);

  gridPointSearchDelete(gps);

  // Release the point grid only if a new one was generated above
  if (gridID0 != gridID) gridDestroy(gridID);

  return nmissx;
}

// Dispatch the point smoothing to the single or double precision data of the
// field and store the resulting number of missing values in field2.
static void
smooth(const Field &field1, Field &field2, const SmoothPoint &spoint)
{
  const bool isSinglePrecision = (field1.memType == MemType::Float);

  if (isSinglePrecision)
    {
      const auto missvalF = static_cast<float>(field1.missval);
      field2.nmiss = smooth(field1.grid, missvalF, field1.vec_f, field2.vec_f, spoint);
    }
  else
    {
      field2.nmiss = smooth(field1.grid, field1.missval, field1.vec_d, field2.vec_d, spoint);
    }
}

// Add the contribution of one neighbour to the running sums of smooth9():
// avg receives the weighted value, divavg the weight itself.
// Points whose mask entry is false are skipped.
template <typename T>
static inline void
smooth9_sum(size_t ij, const std::vector<uint8_t> &mask, double sfac, const Varray<T> &array, double &avg, double &divavg)
{
  if (!mask[ij]) return;

  avg += sfac * array[ij];
  divavg += sfac;
}

template <typename T>
static size_t
smooth9(int gridID, T missval, const Varray<T> &array1, Varray<T> &array2)
{
  const auto gridsize = gridInqSize(gridID);
  const auto nlon = gridInqXsize(gridID);
  const auto nlat = gridInqYsize(gridID);
  const auto grid_is_cyclic = gridIsCircular(gridID);

  std::vector<uint8_t> mask(gridsize);

  for (size_t i = 0; i < gridsize; ++i) mask[i] = !DBL_IS_EQUAL(missval, array1[i]);

  size_t nmiss = 0;
  for (size_t i = 0; i < nlat; i++)
    {
      for (size_t j = 0; j < nlon; j++)
        {
          double avg = 0, divavg = 0;

          if ((i == 0) || (j == 0) || (i == (nlat - 1)) || (j == (nlon - 1)))
            {
              const auto ij = j + nlon * i;
              if (mask[ij])
                {
                  avg += array1[ij];
                  divavg += 1;
                  // upper left corner
                  if ((i != 0) && (j != 0))
                    smooth9_sum(((i - 1) * nlon) + j - 1, mask, 0.3, array1, avg, divavg);
                  else if (i != 0 && grid_is_cyclic)
                    smooth9_sum((i - 1) * nlon + j - 1 + nlon, mask, 0.3, array1, avg, divavg);

                  // upper cell
                  if (i != 0) smooth9_sum(((i - 1) * nlon) + j, mask, 0.5, array1, avg, divavg);

                  // upper right corner
                  if ((i != 0) && (j != (nlon - 1)))
                    smooth9_sum(((i - 1) * nlon) + j + 1, mask, 0.3, array1, avg, divavg);
                  else if ((i != 0) && grid_is_cyclic)
                    smooth9_sum((i - 1) * nlon + j + 1 - nlon, mask, 0.3, array1, avg, divavg);

                  // left cell
                  if (j != 0)
                    smooth9_sum(i * nlon + j - 1, mask, 0.5, array1, avg, divavg);
                  else if (grid_is_cyclic)
                    smooth9_sum(i * nlon - 1 + nlon, mask, 0.5, array1, avg, divavg);

                  // right cell
                  if (j != (nlon - 1))
                    smooth9_sum((i * nlon) + j + 1, mask, 0.5, array1, avg, divavg);
                  else if (grid_is_cyclic)
                    smooth9_sum(i * nlon + j + 1 - nlon, mask, 0.5, array1, avg, divavg);

                  // lower left corner
                  if (mask[ij] && ((i != (nlat - 1)) && (j != 0)))
                    smooth9_sum(((i + 1) * nlon + j - 1), mask, 0.3, array1, avg, divavg);
                  else if ((i != (nlat - 1)) && grid_is_cyclic)
                    smooth9_sum((i + 1) * nlon - 1 + nlon, mask, 0.3, array1, avg, divavg);

                  // lower cell
                  if (i != (nlat - 1)) smooth9_sum(((i + 1) * nlon) + j, mask, 0.5, array1, avg, divavg);

                  // lower right corner
                  if ((i != (nlat - 1)) && (j != (nlon - 1)))
                    smooth9_sum(((i + 1) * nlon) + j + 1, mask, 0.3, array1, avg, divavg);
                  else if ((i != (nlat - 1)) && grid_is_cyclic)
                    smooth9_sum(((i + 1) * nlon) + j + 1 - nlon, mask, 0.3, array1, avg, divavg);
                }
            }
          else if (mask[j + nlon * i])
            {
              avg += array1[j + nlon * i];
              divavg += 1;

              smooth9_sum(((i - 1) * nlon) + j - 1, mask, 0.3, array1, avg, divavg);
              smooth9_sum(((i - 1) * nlon) + j, mask, 0.5, array1, avg, divavg);
              smooth9_sum(((i - 1) * nlon) + j + 1, mask, 0.3, array1, avg, divavg);
              smooth9_sum(((i) *nlon) + j - 1, mask, 0.5, array1, avg, divavg);
              smooth9_sum((i * nlon) + j + 1, mask, 0.5, array1, avg, divavg);
              smooth9_sum(((i + 1) * nlon + j - 1), mask, 0.3, array1, avg, divavg);
              smooth9_sum(((i + 1) * nlon) + j, mask, 0.5, array1, avg, divavg);
              smooth9_sum(((i + 1) * nlon) + j + 1, mask, 0.3, array1, avg, divavg);
            }

          if (std::fabs(divavg) > 0)
            {
              array2[i * nlon + j] = avg / divavg;
            }
          else
            {
              array2[i * nlon + j] = missval;
              nmiss++;
            }
        }
    }

  return nmiss;
}

// Dispatch the 9 point smoothing to the single or double precision data of
// the field and store the resulting number of missing values in field2.
static void
smooth9(const Field &field1, Field &field2)
{
  const bool isSinglePrecision = (field1.memType == MemType::Float);

  if (isSinglePrecision)
    {
      const auto missvalF = static_cast<float>(field1.missval);
      field2.nmiss = smooth9(field1.grid, missvalF, field1.vec_f, field2.vec_f);
    }
  else
    {
      field2.nmiss = smooth9(field1.grid, field1.missval, field1.vec_d, field2.vec_d);
    }
}

// Convert an angle in degrees into the length of the matching great circle
// arc on a sphere of radius PlanetRadius; the result is divided by 1000
// (kilometres, presuming PlanetRadius is given in metres).
double
radiusDegToKm(const double radiusInDeg)
{
  const double circumference = 2.0 * PlanetRadius * M_PI;
  const double degPerCircleTimesThousand = 360.0 * 1000.0;

  return radiusInDeg * circumference / degPerCircleTimesThousand;
}

// Translate the value of the "form" parameter into a FORM_* identifier.
// Only "linear" is known; any other string aborts the operator.
static int
convert_form(const std::string &formstr)
{
  if (formstr != "linear") cdoAbort("form=%s unsupported!", formstr.c_str());

  return FORM_LINEAR;
}

// Read the key=value arguments of the "smooth" operator.
//
//   xnsmooth  receives the value of "nsmooth"
//   spoint    receives maxpoints, weight0, weightR, radius, arc_radius and form
//
// Parameters that are not given keep the value they had on entry. Each key
// must carry exactly one value; unknown keys abort the operator.
static void
smoothGetParameter(int &xnsmooth, SmoothPoint &spoint)
{
  const auto pargc = operatorArgc();
  if (!pargc) return;  // no arguments: keep all defaults

  const auto pargv = cdoGetOperArgv();

  KVList kvlist;
  kvlist.name = "SMOOTH";
  if (kvlist.parseArguments(pargc, pargv) != 0) cdoAbort("Parse error!");
  if (Options::cdoVerbose) kvlist.print();

  for (const auto &entry : kvlist)
    {
      const auto &name = entry.key;
      if (entry.nvalues > 1) cdoAbort("Too many values for parameter key >%s<!", name.c_str());
      if (entry.nvalues < 1) cdoAbort("Missing value for parameter key >%s<!", name.c_str());
      const auto &text = entry.values[0];

      if (name == "nsmooth")
        xnsmooth = parameter2int(text);
      else if (name == "maxpoints")
        spoint.maxpoints = parameter2sizet(text);
      else if (name == "weight0")
        spoint.weight0 = parameter2double(text);
      else if (name == "weightR")
        spoint.weightR = parameter2double(text);
      else if (name == "radius")
        spoint.radius = radius_str_to_deg(text.c_str());
      else if (name == "arc_radius")
        spoint.arc_radius = radius_str_to_deg(text.c_str());
      else if (name == "form")
        spoint.form = convert_form(text);
      else
        cdoAbort("Invalid parameter key >%s<!", name.c_str());
    }
}

// Abort the operator if the radius (in degrees) lies outside of 0 to 180.
// name is the parameter name shown in the error message.
static void
check_radius_range(double radius, const char *name)
{
  const bool belowRange = (radius < 0.0);
  const bool aboveRange = (radius > 180.0);

  if (belowRange || aboveRange) cdoAbort("%s=%g out of bounds (0-180 deg)!", name, radius);
}

// Operator entry point for "smooth" and "smooth9".
//
// Reads every record of the input stream, smooths the fields on supported
// grids (nsmooth passes for "smooth", one pass for "smooth9") and writes all
// records to the output stream. Fields on unsupported grids are copied
// unchanged after a warning.
void *
Smooth(void *process)
{
  int xnsmooth = 1;

  cdoInitialize(process);

  // clang-format off
  const auto SMOOTH  = cdoOperatorAdd("smooth",   0,   0, nullptr);
  const auto SMOOTH9 = cdoOperatorAdd("smooth9",  0,   0, nullptr);
  // clang-format on

  const auto operatorID = cdoOperatorID();

  // Only "smooth" takes parameters; "smooth9" runs with the defaults
  SmoothPoint spoint;
  if (operatorID == SMOOTH) smoothGetParameter(xnsmooth, spoint);

  check_radius_range(spoint.radius, "radius");
  check_radius_range(spoint.arc_radius, "arc_radius");

  const auto streamID1 = cdoOpenRead(0);

  const auto vlistID1 = cdoStreamInqVlist(streamID1);
  const auto vlistID2 = vlistDuplicate(vlistID1);

  const auto taxisID1 = vlistInqTaxis(vlistID1);
  const auto taxisID2 = taxisDuplicate(taxisID1);
  vlistDefTaxis(vlistID2, taxisID2);

  VarList varList1;
  varListInit(varList1, vlistID1);

  // varIDs[varID] is true for variables on a grid the selected operator can handle
  const auto nvars = vlistNvars(vlistID1);
  std::vector<bool> varIDs(nvars, false);

  for (int varID = 0; varID < nvars; ++varID)
    {
      const auto gridID = varList1[varID].gridID;
      const auto gridtype = gridInqType(gridID);
      if (gridtype == GRID_GAUSSIAN || gridtype == GRID_LONLAT || gridtype == GRID_CURVILINEAR || gridtype == GRID_PROJECTION
          || (operatorID == SMOOTH9 && gridtype == GRID_GENERIC && gridInqXsize(gridID) > 0 && gridInqYsize(gridID) > 0))
        {
          varIDs[varID] = true;
        }
      else if (operatorID == SMOOTH && gridtype == GRID_UNSTRUCTURED)
        {
          varIDs[varID] = true;
        }
      else
        {
          cdoWarning("Unsupported grid for variable %s", varList1[varID].name);
        }
    }

  // Limit the number of neighbours to the size of the largest grid
  const auto gridsizemax = vlistGridsizeMax(vlistID1);
  if (gridsizemax < spoint.maxpoints) spoint.maxpoints = gridsizemax;
  if (Options::cdoVerbose && operatorID == SMOOTH)
    {
      const auto &sp = spoint;
      // Report the radius that is actually used for the search (still in degrees here)
      if (sp.arc_radius > 0.0)
        cdoPrint("nsmooth = %d, maxpoints = %zu, arc_radius = %gdeg(%gkm), form = %s, weight0 = %g, weightR = %g", xnsmooth,
                 sp.maxpoints, sp.arc_radius, radiusDegToKm(sp.arc_radius), Form[sp.form], sp.weight0, sp.weightR);
      else
        cdoPrint("nsmooth = %d, maxpoints = %zu, radius = %gdeg(%gkm), form = %s, weight0 = %g, weightR = %g", xnsmooth,
                 sp.maxpoints, sp.radius, radiusDegToKm(sp.radius), Form[sp.form], sp.weight0, sp.weightR);
    }

  // From here on both radii are in radians
  spoint.radius *= DEG2RAD;
  spoint.arc_radius *= DEG2RAD;

  Field field1, field2;

  const auto streamID2 = cdoOpenWrite(1);
  cdoDefVlist(streamID2, vlistID2);

  int tsID = 0;
  while (true)
    {
      const auto nrecs = cdoStreamInqTimestep(streamID1, tsID);
      if (nrecs == 0) break;

      taxisCopyTimestep(taxisID2, taxisID1);
      cdoDefTimestep(streamID2, tsID);

      for (int recID = 0; recID < nrecs; recID++)
        {
          int varID, levelID;
          cdoInqRecord(streamID1, &varID, &levelID);
          field1.init(varList1[varID]);
          field2.init(varList1[varID]);
          cdoReadRecord(streamID1, field1);

          if (varIDs[varID])
            {
              for (int i = 0; i < xnsmooth; ++i)
                {
                  if (operatorID == SMOOTH)
                    smooth(field1, field2, spoint);
                  else if (operatorID == SMOOTH9)
                    smooth9(field1, field2);

                  // The result of this pass is the input of the next one and the record that is written
                  fieldCopy(field2, field1);
                }
            }

          cdoDefRecord(streamID2, varID, levelID);
          cdoWriteRecord(streamID2, field1);
        }

      tsID++;
    }

  cdoStreamClose(streamID2);
  cdoStreamClose(streamID1);

  cdoFinish();

  return nullptr;
}
