/*
  This file is part of CDO. CDO is a collection of Operators to
  manipulate and analyse Climate model Data.

  Copyright (C) 2003-2020 Uwe Schulzweida, <uwe.schulzweida AT mpimet.mpg.de>
  See COPYING file for copying and redistribution conditions.

  This program is free software; you can redistribute it and/or modify
  it under the terms of the GNU General Public License as published by
  the Free Software Foundation; version 2 of the License.

  This program is distributed in the hope that it will be useful,
  but WITHOUT ANY WARRANTY; without even the implied warranty of
  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
  GNU General Public License for more details.
*/

/*
   This module contains the following operators:

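      Diff       diff            Compare two datasets
      Diff       diffp           Compare two datasets (records labelled by parameter ID)
      Diff       diffn           Compare two datasets (records labelled by parameter name)
      Diff       diffc           Compare two datasets (records labelled by code number)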
*/

#include <map>
#include <algorithm>

#include <cdi.h>

#include "functs.h"
#include "process_int.h"
#include "cdo_vlist.h"
#include "param_conversion.h"
#include "mpmo_color.h"
#include "cdo_options.h"
#include "printinfo.h"
#include "pmlist.h"
#include "cdo_zaxis.h"


/*
  Make the memory types of every mapped variable pair agree.

  mapOfVarIDs maps a varID of varList1 to a varID of varList2. For each pair
  where one side is MemType::Float and the other MemType::Double, the Float
  side is promoted to MemType::Double. Keys outside [0, varList1.size()) are
  ignored; pairs are processed in ascending order of the varList1 varID.
*/
static void
varListSynchronizeMemtype(VarList &varList1, VarList &varList2, const std::map<int, int> &mapOfVarIDs)
{
  const int numVars1 = varList1.size();

  for (const auto &entry : mapOfVarIDs)
    {
      const int first = entry.first;
      if (first < 0 || first >= numVars1) continue;

      auto &memType1 = varList1[first].memType;
      auto &memType2 = varList2[entry.second].memType;

      if (memType1 == MemType::Float && memType2 == MemType::Double)
        memType1 = MemType::Double;
      else if (memType1 == MemType::Double && memType2 == MemType::Float)
        memType2 = MemType::Double;
    }
}

/*
  Compare one pair of values and fold the result into running statistics.

  ndiff  incremented when the pair counts as different
  dsgn   set when v1*v2 < 0 (opposite signs)
  zero   set when v1*v2 == 0
  absm   running maximum of |v1 - v2|
  relm   running maximum of |v1 - v2| / max(|v1|, |v2|), updated only when v1*v2 > 0

  A NaN on exactly one side, or (when hasMissvals) a missing value on exactly
  one side, counts as a difference and sets relm to 1. Pairs that are missing
  on both sides are ignored. Missing values are only tested when hasMissvals is true.
*/
template <typename T>
static void
diff_kernel(bool hasMissvals, const T v1, const T v2, const T missval1, const T missval2,
            size_t &ndiff, bool &dsgn, bool &zero, double &absm, double &relm)
{
  // NaN on one side only: always a full difference.
  if (std::isnan(v1) != std::isnan(v2))
    {
      ndiff++;
      relm = 1.0;
      return;
    }

  // Parenthesised so the macro expansion cannot interact with '&&'.
  const bool miss1 = hasMissvals && (DBL_IS_EQUAL(v1, missval1));
  const bool miss2 = hasMissvals && (DBL_IS_EQUAL(v2, missval2));

  if (miss1 || miss2)
    {
      // Exactly one side missing is a difference; both missing is a match.
      if (miss1 != miss2)
        {
          ndiff++;
          relm = 1.0;
        }
      return;
    }

  const double absdiff = std::fabs(v1 - v2);
  if (absdiff > 0.0) ndiff++;
  absm = std::max(absm, absdiff);

  const auto product = v1 * v2;
  if (product < 0.0)
    {
      dsgn = true;
    }
  else if (IS_EQUAL(product, 0.0))
    {
      zero = true;
    }
  else
    {
      relm = std::max(relm, absdiff / std::max(std::fabs(v1), std::fabs(v2)));
    }
}

/*
  Compare element i of field1 and field2 and update the running statistics.

  Dispatches on field1's memory type only. field2 is expected to use the same
  memory type (see varListSynchronizeMemtype). For float storage the missing
  values are narrowed to float before comparison.
*/
static void
diff_kernel(size_t i, bool hasMissvals, const Field &field1, const Field &field2, size_t &ndiff, bool &dsgn, bool &zero, double &absm, double &relm)
{
  if (field1.memType != MemType::Float)
    {
      diff_kernel(hasMissvals, field1.vec_d[i], field2.vec_d[i], field1.missval, field2.missval, ndiff, dsgn, zero, absm, relm);
      return;
    }

  const auto fmissval1 = static_cast<float>(field1.missval);
  const auto fmissval2 = static_cast<float>(field2.missval);
  diff_kernel(hasMissvals, field1.vec_f[i], field2.vec_f[i], fmissval1, fmissval2, ndiff, dsgn, zero, absm, relm);
}

/*
  Compact a complex-valued field in place so that its first gridsize entries
  hold element [2*k] of the original buffer. This assumes interleaved
  (real, imaginary) pairs, as the stride of 2 implies, so only the real parts
  are kept.
*/
static void
use_real_part(const size_t gridsize, Field &field)
{
  if (field.memType != MemType::Float)
    {
      auto &values = field.vec_d;
      for (size_t k = 0; k < gridsize; ++k) values[k] = values[2 * k];
    }
  else
    {
      auto &values = field.vec_f;
      for (size_t k = 0; k < gridsize; ++k) values[k] = values[2 * k];
    }
}

/*
  Parse the optional key=value operator arguments of the diff operators.

  Recognized keys (each must carry exactly one value):
    abslim   -> abslim   (double)
    abslim2  -> abslim2  (double)
    rellim   -> rellim   (double)
    maxcount -> maxcount (int)
    names    -> mapflag: left=1, right=2, intersect=3

  Outputs whose key is not given are left unchanged. Calls cdoAbort on a parse
  error, an unknown key, a wrong number of values, or an invalid names value.
*/
static void
diffGetParameter(double &abslim, double &abslim2, double &rellim, int &mapflag, int &maxcount)
{
  const auto pargc = operatorArgc();
  if (!pargc) return;

  const auto pargv = cdoGetOperArgv();

  KVList kvlist;
  kvlist.name = "DIFF";
  if (kvlist.parseArguments(pargc, pargv) != 0) cdoAbort("Parse error!");
  if (Options::cdoVerbose) kvlist.print();

  for (const auto &kv : kvlist)
    {
      const auto &key = kv.key;
      if (kv.nvalues > 1) cdoAbort("Too many values for parameter key >%s<!", key.c_str());
      if (kv.nvalues < 1) cdoAbort("Missing value for parameter key >%s<!", key.c_str());
      const auto &value = kv.values[0];

      // clang-format off
      if      (key == "abslim")   abslim = parameter2double(value);
      else if (key == "abslim2")  abslim2 = parameter2double(value);
      else if (key == "rellim")   rellim = parameter2double(value);
      else if (key == "maxcount") maxcount = parameter2int(value);
      else if (key == "names")
        {
          if      (value == "left")      mapflag = 1;
          else if (value == "right")     mapflag = 2;
          else if (value == "intersect") mapflag = 3;
          // Report the offending value too (it was passed before but had no format specifier).
          else cdoAbort("Invalid value >%s< for key >%s< (names=<left/right/intersect>)!", value.c_str(), key.c_str());
        }
      else cdoAbort("Invalid parameter key >%s<!", key.c_str());
      // clang-format on
    }
}

/*
  Diff - compare two datasets record by record.

  Operators:
    diff, diffp  report differing records labelled by parameter ID
    diffn        report differing records labelled by parameter name
    diffc        report differing records labelled by code number

  Optional parameters (parsed by diffGetParameter):
    abslim    absolute-difference threshold (default 0)
    abslim2   secondary absolute threshold used for the summary line (default 1e-3)
    rellim    relative-difference threshold (default 1)
    maxcount  stop after this many differing records (0 = no limit)
    names     left/right/intersect: build the variable mapping with vlistMap
              instead of checking the variable lists with vlistCompare

  A record differs when its maximum absolute difference exceeds abslim or its
  maximum relative difference reaches rellim. If any record differs,
  Options::cdoExitStatus is set to 1. Time steps are compared until either
  stream runs out; a warning is issued if the two streams have different
  numbers of time steps.
*/
void *
Diff(void *process)
{
  bool lhead = true;
  int nrecs = 0, nrecs2 = 0;
  int varID1 = -1, varID2 = -1;
  // Kept separate: the level printed belongs to stream1's zaxis, so it must
  // not be overwritten when stream2's records are inquired.
  int levelID1 = 0, levelID2 = 0;
  int ndrec = 0, nd2rec = 0, ngrec = 0;
  char paramstr[32];

  cdoInitialize(process);

  // clang-format off
  const auto DIFF  = cdoOperatorAdd("diff",  0, 0, nullptr);
  const auto DIFFP = cdoOperatorAdd("diffp", 0, 0, nullptr);
  const auto DIFFN = cdoOperatorAdd("diffn", 0, 0, nullptr);
  const auto DIFFC = cdoOperatorAdd("diffc", 0, 0, nullptr);
  // clang-format on

  const auto operatorID = cdoOperatorID();

  int mapflag = 0, maxcount = 0;
  double abslim = 0.0, abslim2 = 1.e-3, rellim = 1.0;
  diffGetParameter(abslim, abslim2, rellim, mapflag, maxcount);

  if (rellim < -1.e33 || rellim > 1.e+33) cdoAbort("Rel. limit out of range!");
  if (abslim < -1.e33 || abslim > 1.e+33) cdoAbort("Abs. limit out of range!");
  if (abslim2 < -1.e33 || abslim2 > 1.e+33) cdoAbort("Abs2. limit out of range!");

  const auto streamID1 = cdoOpenRead(0);
  const auto streamID2 = cdoOpenRead(1);

  const auto vlistID1 = cdoStreamInqVlist(streamID1);
  const auto vlistID2 = cdoStreamInqVlist(streamID2);

  // Map each varID of stream1 to its counterpart varID in stream2.
  const auto nvars = vlistNvars(vlistID1);
  std::map<int, int> mapOfVarIDs;

  if (mapflag == 0)
    {
      // Default: check the variable lists and pair variables by identical varID.
      vlistCompare(vlistID1, vlistID2, CMP_ALL);
      for (int varID = 0; varID < nvars; ++varID) mapOfVarIDs[varID] = varID;
    }
  else
    {
      vlistMap(vlistID1, vlistID2, CMP_ALL, mapflag, mapOfVarIDs);
    }

  VarList varList1, varList2;
  varListInit(varList1, vlistID1);
  varListInit(varList2, vlistID2);
  // diff_kernel dispatches on field1's memory type only, so paired variables must share one.
  varListSynchronizeMemtype(varList1, varList2, mapOfVarIDs);

  Field field1, field2;

  const auto taxisID = vlistInqTaxis(vlistID1);

  int indg = 0;
  int tsID = 0;
  while (true)
    {
      bool lstop = false;

      nrecs = cdoStreamInqTimestep(streamID1, tsID);
      const auto vdateString = dateToString(taxisInqVdate(taxisID));
      const auto vtimeString = timeToString(taxisInqVtime(taxisID));

      nrecs2 = cdoStreamInqTimestep(streamID2, tsID);

      if (nrecs == 0 || nrecs2 == 0) break;

      // Index of the next not-yet-inquired record of stream2 in this time step.
      int recID2next = 0;

      for (int recID = 0; recID < nrecs; recID++)
        {
          cdoInqRecord(streamID1, &varID1, &levelID1);

          const auto it = mapOfVarIDs.find(varID1);
          if (it == mapOfVarIDs.end())
            {
              // With names=right or names=intersect, unmapped stream1 variables are skipped.
              if (mapflag == 2 || mapflag == 3) continue;
              cdoAbort("Internal problem: varID1=%d not found!", varID1);
            }

          // Advance stream2 to the next record of the mapped variable; records passed
          // over are not read. An explicit flag is used so that a varID2 left over from
          // an earlier match is never mistaken for a match once stream2 is exhausted.
          const auto wantedVarID2 = it->second;
          bool found = false;
          while (recID2next < nrecs2)
            {
              cdoInqRecord(streamID2, &varID2, &levelID2);
              ++recID2next;
              if (varID2 == wantedVarID2)
                {
                  found = true;
                  break;
                }
            }

          if (!found) cdoAbort("Internal problem: varID2=%d not found in second stream!", wantedVarID2);

          indg += 1;

          const auto gridsize = varList1[varID1].gridsize;

          // checkrel = gridInqType(gridID) != GRID_SPECTRAL;
          const auto checkrel = true;

          cdiParamToString(varList1[varID1].param, paramstr, sizeof(paramstr));

          field1.init(varList1[varID1]);
          cdoReadRecord(streamID1, field1);
          if (varList1[varID1].nwpv == CDI_COMP) use_real_part(gridsize, field1);

          field2.init(varList2[varID2]);
          cdoReadRecord(streamID2, field2);
          if (varList2[varID2].nwpv == CDI_COMP) use_real_part(gridsize, field2);

          const auto hasMissvals = (field1.nmiss || field2.nmiss);
          size_t ndiff = 0;
          bool dsgn = false, zero = false;
          double absm = 0.0, relm = 0.0;

          for (size_t i = 0; i < gridsize; i++)
            {
              diff_kernel(i, hasMissvals, field1, field2, ndiff, dsgn, zero, absm, relm);
            }

          const bool recordDiffers = (absm > abslim || (checkrel && relm >= rellim));

          if (!Options::silentMode || Options::cdoVerbose)
            {
              if (recordDiffers || Options::cdoVerbose)
                {
                  if (lhead)
                    {
                      lhead = false;

                      fprintf(stdout, "               Date     Time   Level Gridsize    Miss ");
                      fprintf(stdout, "   Diff ");
                      fprintf(stdout, ": S Z  Max_Absdiff Max_Reldiff : ");

                      if (operatorID == DIFFN)
                        fprintf(stdout, "Parameter name");
                      else if (operatorID == DIFF || operatorID == DIFFP)
                        fprintf(stdout, "Parameter ID");
                      else if (operatorID == DIFFC)
                        fprintf(stdout, "Code number");

                      fprintf(stdout, "\n");
                    }

                  fprintf(stdout, "%6d ", indg);
                  fprintf(stdout, ":");

                  set_text_color(stdout, MAGENTA);
                  fprintf(stdout, "%s %s ", vdateString.c_str(), vtimeString.c_str());
                  reset_text_color(stdout);
                  set_text_color(stdout, GREEN);
                  fprintf(stdout, "%7g ", cdoZaxisInqLevel(varList1[varID1].zaxisID, levelID1));
                  fprintf(stdout, "%8zu %7zu ", gridsize, std::max(field1.nmiss, field2.nmiss));
                  fprintf(stdout, "%7zu ", ndiff);
                  reset_text_color(stdout);

                  fprintf(stdout, ":");
                  fprintf(stdout, " %c %c ", dsgn ? 'T' : 'F', zero ? 'T' : 'F');
                  set_text_color(stdout, BLUE);
                  fprintf(stdout, "%#12.5g%#12.5g", absm, relm);
                  reset_text_color(stdout);
                  fprintf(stdout, " : ");

                  set_text_color(stdout, BRIGHT, GREEN);
                  if (operatorID == DIFFN)
                    fprintf(stdout, "%-11s", varList1[varID1].name);
                  else if (operatorID == DIFF || operatorID == DIFFP)
                    fprintf(stdout, "%-11s", paramstr);
                  else if (operatorID == DIFFC)
                    fprintf(stdout, "%4d", varList1[varID1].code);
                  reset_text_color(stdout);

                  fprintf(stdout, "\n");
                }
            }

          ngrec++;
          if (recordDiffers) ndrec++;
          if (absm > abslim2 || (checkrel && relm >= rellim)) nd2rec++;

          if (maxcount > 0 && ndrec >= maxcount)
            {
              lstop = true;
              break;
            }
        }

      if (lstop) break;

      tsID++;
    }

  if (ndrec > 0)
    {
      Options::cdoExitStatus = 1;

      set_text_color(stdout, BRIGHT, RED);
      fprintf(stdout, "  %d of %d records differ", ndrec, ngrec);
      reset_text_color(stdout);
      fprintf(stdout, "\n");

      if (ndrec != nd2rec && abslim < abslim2) fprintf(stdout, "  %d of %d records differ more than %g\n", nd2rec, ngrec, abslim2);
      //  fprintf(stdout, "  %d of %d records differ more then one thousandth\n", nprec, ngrec);
    }

  if (nrecs == 0 && nrecs2 > 0) cdoWarning("stream2 has more time steps than stream1!");
  if (nrecs > 0 && nrecs2 == 0) cdoWarning("stream1 has more time steps than stream2!");

  cdoStreamClose(streamID1);
  cdoStreamClose(streamID2);

  cdoFinish();

  return nullptr;
}
