/*
  This file is part of CDO. CDO is a collection of Operators to
  manipulate and analyse Climate model Data.

  Copyright (C) 2003-2020 Uwe Schulzweida, <uwe.schulzweida AT mpimet.mpg.de>
  See COPYING file for copying and redistribution conditions.

  This program is free software; you can redistribute it and/or modify
  it under the terms of the GNU General Public License as published by
  the Free Software Foundation; version 2 of the License.

  This program is distributed in the hope that it will be useful,
  but WITHOUT ANY WARRANTY; without even the implied warranty of
  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
  GNU General Public License for more details.
*/

/*
   This module contains the following operators:

      Vertstat   vertrange       Vertical range
      Vertstat   vertmin         Vertical minimum
      Vertstat   vertmax         Vertical maximum
      Vertstat   vertsum         Vertical sum
      Vertstat   vertint         Vertical integral
      Vertstat   vertmean        Vertical mean
      Vertstat   vertavg         Vertical average
      Vertstat   vertvar         Vertical variance
      Vertstat   vertvar1        Vertical variance [Normalize by (n-1)]
      Vertstat   vertstd         Vertical standard deviation
      Vertstat   vertstd1        Vertical standard deviation [Normalize by (n-1)]
*/

#include <cdi.h>

#include "cdo_options.h"
#include "functs.h"
#include "process_int.h"
#include "cdo_vlist.h"
#include "cdo_zaxis.h"
#include "param_conversion.h"
#include "pmlist.h"
#include "cdi_lockedIO.h"

// True when the given z-axis has type ZAXIS_SURFACE and consists of exactly one level.
// Note: the argument is expanded twice, so pass a plain identifier without side effects.
#define IS_SURFACE_LEVEL(zaxisID) (zaxisInqType(zaxisID) == ZAXIS_SURFACE && zaxisInqSize(zaxisID) == 1)

// Looks up a single-level surface z-axis for the given vlist.
//
// The z-axes of vlistID are scanned in index order and the first one that
// satisfies IS_SURFACE_LEVEL is returned. When the vlist has no such axis,
// the z-axis produced by zaxisFromName("surface") is returned instead.
int
getSurfaceID(int vlistID)
{
  const auto numZaxes = vlistNzaxis(vlistID);

  for (int slot = 0; slot < numZaxes; ++slot)
    {
      const auto candidateID = vlistZaxis(vlistID, slot);
      // A hit that equals -1 is treated as "not found", matching the sentinel used for the fallback.
      if (IS_SURFACE_LEVEL(candidateID) && candidateID != -1) return candidateID;
      if (IS_SURFACE_LEVEL(candidateID)) break;
    }

  // No usable surface axis in the vlist: fall back to a named surface axis.
  return zaxisFromName("surface");
}

// Redirects the z-axis slots of vlistID to the surface axis surfID.
//
// A slot is left alone only when it already refers to surfID and that axis is
// a single surface level; every other slot is changed via vlistChangeZaxisIndex.
// The number of slots is read once before the loop starts.
static void
setSurfaceID(const int vlistID, const int surfID)
{
  const auto numZaxes = vlistNzaxis(vlistID);

  for (int slot = 0; slot < numZaxes; ++slot)
    {
      const auto currentID = vlistZaxis(vlistID, slot);

      // The surface test is evaluated only when the IDs match (short-circuit).
      const bool alreadySurface = (currentID == surfID) && IS_SURFACE_LEVEL(currentID);
      if (alreadySurface) continue;

      vlistChangeZaxisIndex(vlistID, slot, surfID);
    }
}

static void
vertstatGetParameter(bool *weights, bool *genbounds)
{
  const auto pargc = operatorArgc();
  if (pargc)
    {
      const auto pargv = cdoGetOperArgv();

      KVList kvlist;
      kvlist.name = "VERTSTAT";
      if (kvlist.parseArguments(pargc, pargv) != 0) cdoAbort("Parse error!");
      if (Options::cdoVerbose) kvlist.print();

      for (const auto &kv : kvlist)
        {
          const auto &key = kv.key;
          if (kv.nvalues > 1) cdoAbort("Too many values for parameter key >%s<!", key.c_str());
          if (kv.nvalues < 1) cdoAbort("Missing value for parameter key >%s<!", key.c_str());
          const auto &value = kv.values[0];

          // clang-format off
          if      (key == "weights")   *weights = parameter2bool(value);
          else if (key == "genbounds") *genbounds = parameter2bool(value);
          else cdoAbort("Invalid parameter key >%s<!", key.c_str());
          // clang-format on
        }
    }
}

// Entry point of the Vertstat module.
//
// Reads every timestep of input stream 0, collapses the levels of each variable
// into a single field using the statistic selected by the operator name
// (range, min, max, sum, integral, mean/avg, variance, standard deviation),
// and writes one record per variable and timestep to output stream 1.
//
// process: handle forwarded to cdoInitialize().
// Returns: always nullptr.
void *
Vertstat(void *process)
{
  // Per-z-axis information used by the weighted operators.
  struct VertInfo
  {
    int zaxisID;               // z-axis ID taken from the input vlist
    int status;                // return value of getLayerThickness(); forced to 3 when weights are disabled
    int numlevel;              // number of levels of this z-axis
    Varray<double> thickness;  // per-level buffer filled by getLayerThickness()
    Varray<double> weights;    // per-level buffer filled by getLayerThickness()
  };

  cdoInitialize(process);

  // Register the operators: F1 is the statistic function, F2 flags operators that need layer weights.
  // clang-format off
                 cdoOperatorAdd("vertrange", func_range, 0, nullptr);
                 cdoOperatorAdd("vertmin",   func_min,   0, nullptr);
                 cdoOperatorAdd("vertmax",   func_max,   0, nullptr);
                 cdoOperatorAdd("vertsum",   func_sum,   0, nullptr);
  int VERTINT  = cdoOperatorAdd("vertint",   func_sum,   1, nullptr);
                 cdoOperatorAdd("vertmean",  func_mean,  1, nullptr);
                 cdoOperatorAdd("vertavg",   func_avg,   1, nullptr);
                 cdoOperatorAdd("vertvar",   func_var,   1, nullptr);
                 cdoOperatorAdd("vertvar1",  func_var1,  1, nullptr);
                 cdoOperatorAdd("vertstd",   func_std,   1, nullptr);
                 cdoOperatorAdd("vertstd1",  func_std1,  1, nullptr);
  // clang-format on

  const auto operatorID  = cdoOperatorID();
  const auto operfunc    = cdoOperatorF1(operatorID);
  const bool needWeights = cdoOperatorF2(operatorID);

  // Classify the selected statistic once; divisor is 1 for the (n-1)-normalized variants, else 0.
  const auto lrange  = (operfunc == func_range);
  const auto lmean   = (operfunc == func_mean || operfunc == func_avg);
  const auto lstd    = (operfunc == func_std || operfunc == func_std1);
  const auto lvarstd = (lstd || operfunc == func_var || operfunc == func_var1);
  const int  divisor = (operfunc == func_std1 || operfunc == func_var1);

  // Finalization routines: the first pair member takes a per-cell sample field, the second a constant count.
  auto vfarstdvar_func = lstd ? vfarstd : vfarvar;
  auto vfarcstdvar_func = lstd ? vfarcstd : vfarcvar;

  // int applyWeights = lmean;

  const auto streamID1 = cdoOpenRead(0);
  const auto vlistID1 = cdoStreamInqVlist(streamID1);

  // Flag only level 0 of every variable so that the copied output vlist carries the flagged level only.
  vlistClearFlag(vlistID1);
  const auto nvars = vlistNvars(vlistID1);
  for (int varID = 0; varID < nvars; varID++) vlistDefFlag(vlistID1, varID, 0, true);

  const auto vlistID2 = vlistCreate();
  cdoVlistCopyFlag(vlistID2, vlistID1);

  const auto taxisID1 = vlistInqTaxis(vlistID1);
  const auto taxisID2 = taxisDuplicate(taxisID1);
  vlistDefTaxis(vlistID2, taxisID2);

  // Point the z-axes of the output vlist at a surface axis.
  const auto surfID = getSurfaceID(vlistID1);
  setSurfaceID(vlistID2, surfID);

  // One VertInfo entry per input z-axis; the entries are only filled in (and only read) when needWeights is set.
  const auto nzaxis = vlistNzaxis(vlistID1);
  std::vector<VertInfo> vert(nzaxis);
  if (needWeights)
    {
      auto useweights = true;
      auto genbounds = false;
      // NOTE(review): this inner check is redundant, the enclosing branch already tests needWeights.
      if (needWeights) vertstatGetParameter(&useweights, &genbounds);

      if (!useweights)
        {
          genbounds = false;
          cdoPrint("Using constant vertical weights!");
        }

      for (int index = 0; index < nzaxis; ++index)
        {
          const auto zaxisID = vlistZaxis(vlistID1, index);
          const auto nlev = zaxisInqSize(zaxisID);
          vert[index].numlevel = 0;
          vert[index].status = 0;
          vert[index].zaxisID = zaxisID;
          // if (nlev > 1)
          {
            vert[index].numlevel = nlev;
            vert[index].thickness.resize(nlev);
            vert[index].weights.resize(nlev);
            vert[index].status = getLayerThickness(useweights, genbounds, index, zaxisID, nlev, vert[index].thickness.data(),
                                                   vert[index].weights.data());
          }
          // With weights disabled the status is overridden, which also suppresses the missing-bounds warning below.
          if (!useweights) vert[index].status = 3;
        }
    }

  const auto streamID2 = cdoOpenWrite(1);
  cdoDefVlist(streamID2, vlistID2);

  VarList varList;
  varListInit(varList, vlistID1);

  // Scratch field that receives every record with levelID != 0.
  Field field;

  // vars1: running result per variable; samp1: per-cell sample/weight sums (allocated lazily);
  // vars2: second accumulator, only needed for variance/std and range.
  FieldVector vars1(nvars), samp1(nvars), vars2;
  if (lvarstd || lrange) vars2.resize(nvars);

  for (int varID = 0; varID < nvars; varID++)
    {
      samp1[varID].grid = varList[varID].gridID;
      samp1[varID].missval = varList[varID].missval;
      samp1[varID].memType = MemType::Double;
      vars1[varID].grid = varList[varID].gridID;
      vars1[varID].missval = varList[varID].missval;
      vars1[varID].memType = MemType::Double;
      vars1[varID].resize(varList[varID].gridsize);
      if (lvarstd || lrange)
        {
          vars2[varID].grid = varList[varID].gridID;
          vars2[varID].missval = varList[varID].missval;
          vars2[varID].memType = MemType::Double;
          vars2[varID].resize(varList[varID].gridsize);
        }
    }

  // Process the input timestep by timestep until cdoStreamInqTimestep() reports no records.
  int tsID = 0;
  while (true)
    {
      const auto nrecs = cdoStreamInqTimestep(streamID1, tsID);
      if (nrecs == 0) break;

      taxisCopyTimestep(taxisID2, taxisID1);
      cdoDefTimestep(streamID2, tsID);

      // Accumulation pass: fold every record (one level of one variable) into its variable's accumulators.
      for (int recID = 0; recID < nrecs; recID++)
        {
          int varID, levelID;
          cdoInqRecord(streamID1, &varID, &levelID);

          auto &rsamp1 = samp1[varID];
          auto &rvars1 = vars1[varID];

          // nsamp counts the records seen for this variable in the current timestep (reset after writing).
          rvars1.nsamp++;
          if (lrange) vars2[varID].nsamp++;

          const auto gridsize = varList[varID].gridsize;
          const auto zaxisID = varList[varID].zaxisID;
          const auto nlev = varList[varID].nlevels;

          // Defaults of 1.0 mean "unweighted"; they are replaced from the VertInfo table when weights are needed.
          auto layer_weight = 1.0;
          auto layer_thickness = 1.0;
          if (needWeights)
            {
              for (int index = 0; index < nzaxis; ++index)
                if (vert[index].zaxisID == zaxisID)
                  {
                    // Warn once per multi-level variable (first timestep, first level) when status is 0;
                    // in that one case the defaults of 1.0 are kept, in all other cases the table values are used.
                    if (vert[index].status == 0 && tsID == 0 && levelID == 0 && nlev > 1)
                      {
                        cdoWarning("Layer bounds not available, using constant vertical weights for variable %s!",
                                   varList[varID].name);
                      }
                    else
                      {
                        layer_weight = vert[index].weights[levelID];
                        layer_thickness = vert[index].thickness[levelID];
                      }

                    break;
                  }
            }

          if (levelID == 0)
            {
              // Level 0 (re)initializes the accumulators: the record is read straight into vars1.
              cdoReadRecord(streamID1, rvars1);
              if (lrange)
                {
                  // For the range, vars2 starts as a copy of the first level.
                  vars2[varID].nmiss = rvars1.nmiss;
                  vars2[varID].vec_d = rvars1.vec_d;
                }

              // Scaling by exactly 1.0 is skipped.
              if (operatorID == VERTINT && IS_NOT_EQUAL(layer_thickness, 1.0)) vfarcmul(rvars1, layer_thickness);
              if (lmean && IS_NOT_EQUAL(layer_weight, 1.0)) vfarcmul(rvars1, layer_weight);

              if (lvarstd)
                {
                  // vars2 is initialized from vars1 via vfarmoq/vfarmoqw before vars1 itself is weighted
                  // (presumably the squared values — confirm in the field library).
                  if (IS_NOT_EQUAL(layer_weight, 1.0))
                    {
                      vfarmoqw(vars2[varID], rvars1, layer_weight);
                      vfarcmul(rvars1, layer_weight);
                    }
                  else
                    {
                      vfarmoq(vars2[varID], rvars1);
                    }
                }

              // The sample field is needed when there are missing values, when it already exists from an
              // earlier timestep, or for weighted operators; each cell starts at 0 (missing) or layer_weight.
              if (rvars1.nmiss || !rsamp1.empty() || needWeights)
                {
                  if (rsamp1.empty()) rsamp1.resize(gridsize);

                  for (size_t i = 0; i < gridsize; i++)
                    rsamp1.vec_d[i] = (DBL_IS_EQUAL(rvars1.vec_d[i], rvars1.missval)) ? 0.0 : layer_weight;
                }
            }
          else
            {
              // All other levels are read into the scratch field and merged into the accumulators.
              field.init(varList[varID]);
              cdoReadRecord(streamID1, field);

              if (operatorID == VERTINT && IS_NOT_EQUAL(layer_thickness, 1.0)) vfarcmul(field, layer_thickness);
              if (lmean && IS_NOT_EQUAL(layer_weight, 1.0)) vfarcmul(field, layer_weight);

              if (field.nmiss || !rsamp1.empty())
                {
                  // NOTE(review): nsamp was already incremented for the current record above, and the loop
                  // below adds layer_weight again for non-missing cells — confirm the initial count is intended.
                  // Assuming level 0 is read first, the weighted operators have allocated rsamp1 already, so this
                  // late allocation should only occur for operators whose final pass does not read samp1.
                  if (rsamp1.empty()) rsamp1.resize(gridsize, rvars1.nsamp);

                  // The scratch field may hold float or double data; compare against the missing value accordingly.
                  if (field.memType == MemType::Float)
                    {
                      for (size_t i = 0; i < gridsize; i++)
                        if (!DBL_IS_EQUAL(field.vec_f[i], (float)rvars1.missval)) rsamp1.vec_d[i] += layer_weight;
                    }
                  else
                    {
                      for (size_t i = 0; i < gridsize; i++)
                        if (!DBL_IS_EQUAL(field.vec_d[i], rvars1.missval)) rsamp1.vec_d[i] += layer_weight;
                    }
                }

              if (lvarstd)
                {
                  // Update both accumulators (vars1 and vars2), weighted or unweighted.
                  if (IS_NOT_EQUAL(layer_weight, 1.0))
                    {
                      vfarsumqw(vars2[varID], field, layer_weight);
                      vfarsumw(rvars1, field, layer_weight);
                    }
                  else
                    {
                      vfarsumsumq(rvars1, vars2[varID], field);
                    }
                }
              else if (lrange)
                {
                  // Range keeps two running fields in vars1 and vars2; they are subtracted in the final pass.
                  vfarmaxmin(rvars1, vars2[varID], field);
                }
              else
                {
                  // min, max, sum, integral and mean/avg accumulate through the generic field function.
                  vfarfun(rvars1, field, operfunc);
                }
            }
        }

      // Finalization pass: turn the accumulators into the result and write one record per variable.
      for (int varID = 0; varID < nvars; varID++)
        {
          const auto &rsamp1 = samp1[varID];
          auto &rvars1 = vars1[varID];

          // Variables without any record in this timestep are skipped.
          if (rvars1.nsamp)
            {
              if (lmean)
                {
                  // Divide by the per-cell sample field when present, otherwise by the record count.
                  if (!rsamp1.empty())
                    vfardiv(rvars1, rsamp1);
                  else
                    vfarcdiv(rvars1, (double) rvars1.nsamp);
                }
              else if (lvarstd)
                {
                  if (!rsamp1.empty())
                    vfarstdvar_func(rvars1, vars2[varID], rsamp1, divisor);
                  else
                    vfarcstdvar_func(rvars1, vars2[varID], rvars1.nsamp, divisor);
                }
              else if (lrange)
                {
                  vfarsub(rvars1, vars2[varID]);
                }

              // The result is always written as level 0 of the output variable.
              cdoDefRecord(streamID2, varID, 0);
              cdoWriteRecord(streamID2, rvars1);
              rvars1.nsamp = 0;
            }
        }

      tsID++;
    }

  cdoStreamClose(streamID2);
  cdoStreamClose(streamID1);

  vlistDestroy(vlistID2);

  cdoFinish();

  return nullptr;
}
