/*
  This file is part of CDO. CDO is a collection of Operators to
  manipulate and analyse Climate model Data.

  Copyright (C) 2003-2021 Uwe Schulzweida, <uwe.schulzweida AT mpimet.mpg.de>
  See COPYING file for copying and redistribution conditions.

  This program is free software; you can redistribute it and/or modify
  it under the terms of the GNU General Public License as published by
  the Free Software Foundation; version 2 of the License.

  This program is distributed in the hope that it will be useful,
  but WITHOUT ANY WARRANTY; without even the implied warranty of
  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
  GNU General Public License for more details.
*/

#ifdef HAVE_CONFIG_H
#include "config.h"  // restrict
#endif

#include <cstdio>
#include <cstring>
#include <cmath>
#include <cassert>

#include "compare.h"
#include "varray.h"
#include "constants.h"
#include "vertical_interp.h"

// Scale height of the exponential pressure profile used by height_to_pressure()
// and pressure_to_height(). It is negative so that exp(height / SCALEHEIGHT)
// decreases with increasing height (heights are given in meters).
constexpr double SCALEHEIGHT = -7000.;
// Pressure assigned to height zero by the same two conversions
// (presumably standard sea level pressure in Pa - TODO confirm unit).
constexpr double SCALESLP = 101325.0;

// Convert nphlev heights into pressures with an exponential profile:
//   phlev[k] = SCALESLP * exp(hlev[k] / SCALEHEIGHT)
// The heights in hlev have to be given in meters (MKS standard).
void
height_to_pressure(double *phlev, const double *hlev, const long nphlev)
{
  const auto toPressure = [](double height) { return SCALESLP * std::exp(height / SCALEHEIGHT); };

  for (long level = 0; level < nphlev; ++level) phlev[level] = toPressure(hlev[level]);
}

// Convert nphlev pressures into heights; inverse of the exponential profile
// used by height_to_pressure():
//   hlev[k] = log(plev[k] / SCALESLP) * SCALEHEIGHT
void
pressure_to_height(double *hlev, const double *plev, const long nphlev)
{
  const auto toHeight = [](double pressure) { return std::log(pressure / SCALESLP) * SCALEHEIGHT; };

  for (long level = 0; level < nphlev; ++level) hlev[level] = toHeight(plev[level]);
}

template <typename T>
void
vct_to_hybrid_pressure(T *restrict fullp, T *halfp, const double *restrict vct, const T *restrict ps, long nhlev, long ngp)
{
  assert(ps != nullptr);

  auto halfpres = halfp;
  for (long lh = 0; lh < nhlev; lh++)
    {
      const auto zp = vct[lh];
      const auto ze = vct[lh + nhlev + 1];
      for (long i = 0; i < ngp; i++) halfpres[i] = zp + ze * ps[i];
      halfpres += ngp;
    }
  array_copy(ngp, ps, halfpres);

  if (fullp)
    {
      halfpres = halfp;
      for (long i = 0; i < ngp * nhlev; i++) fullp[i] = 0.5 * (halfpres[i] + halfpres[i + ngp]);
    }
}

template void vct_to_hybrid_pressure(float *fullp, float *halfp, const double *vct, const float *ps, long nhlev, long ngp);
template void vct_to_hybrid_pressure(double *fullp, double *halfp, const double *vct, const double *ps, long nhlev, long ngp);

// Extrapolate the half level pressure to zero geopotential for ngp grid points
// (the result is named slp, presumably sea level pressure).
//   slp   : out, extrapolated pressure
//   halfp : half level pressure per grid point
//   fullp : full level pressure per grid point
//   geop  : geopotential per grid point
//   temp  : temperature per grid point
// NOTE(review): callers presumably pass the lowest model level and the surface
// geopotential - confirm against the call sites.
void
extrapolate_P(double *restrict slp, const double *restrict halfp, const double *restrict fullp, const double *restrict geop,
              const double *restrict temp, long ngp)
{
  constexpr double zlapse = 0.0065;
  const double zrg = 1.0 / PlanetGrav;

  for (long j = 0; j < ngp; ++j)
    {
      const double phi = geop[j];

      // Geopotential is (almost) zero: nothing to extrapolate.
      if (phi < 0.0001 && phi > -0.0001)
        {
          slp[j] = halfp[j];
          continue;
        }

      double alpha = PlanetRD * zlapse * zrg;
      double tstar = (1.0 + alpha * (halfp[j] / fullp[j] - 1.0)) * temp[j];

      // Damp very cold surface temperatures towards 255 K.
      if (tstar < 255.0) tstar = 0.5 * (255.0 + tstar);

      double tmsl = tstar + zlapse * zrg * phi;
      // Damp very warm temperatures towards 290.5 K and drop the lapse rate.
      if (tmsl > 290.5 && tstar > 290.5)
        {
          tstar = 0.5 * (290.5 + tstar);
          tmsl = tstar;
        }

      const bool sameTemperature = (tmsl - tstar < 0.000001) && (tstar - tmsl < 0.000001);
      if (sameTemperature)
        {
          alpha = 0.0;
        }
      else if (phi > 0.0001 || phi < -0.0001)
        {
          alpha = PlanetRD * (tmsl - tstar) / phi;
        }

      const double zprt = phi / (PlanetRD * tstar);
      const double zprtal = zprt * alpha;
      slp[j] = halfp[j] * std::exp(zprt * (1.0 - zprtal * (0.5 - zprtal / 3.0)));
    }
}

// Temperature at pressure pres for a target level at or beyond the lowest full level.
//   pres  : target pressure
//   halfp : half level pressure (callers in this file pass the last half level)
//   fullp : full level pressure (callers in this file pass the last full level)
//   geop  : geopotential at the grid point
//   temp  : temperature on the full level
// For pres <= halfp the result is a linear interpolation in pressure between temp
// (at fullp) and a clamped surface temperature (at halfp). Otherwise the unclamped
// surface temperature is extrapolated with a third order polynomial in
// alpha * log(pres / halfp).
static inline double
extrapolate_T(double pres, double halfp, double fullp, double geop, double temp)
{
  constexpr double zlapse = 0.0065;
  const double zrg = 1.0 / PlanetGrav;

  // Surface temperature estimate before any clamping.
  const double tsurfRaw = (1.0 + zlapse * PlanetRD * zrg * (halfp / fullp - 1.0)) * temp;

  if (pres <= halfp)
    {
      // Clamp the surface temperature, then interpolate linearly in pressure.
      double tsurf = tsurfRaw;
      if (tsurf < 255.0) tsurf = 0.5 * (255.0 + tsurf);

      const double tmslClamped = tsurf + zlapse * zrg * geop;
      if (tmslClamped > 290.5 && tsurf > 290.5) tsurf = 0.5 * (290.5 + tsurf);

      return ((halfp - pres) * temp + (pres - fullp) * tsurf) / (halfp - fullp);
    }

  // Target pressure exceeds halfp: work with the unclamped temperatures.
  const double tmslRaw = tsurfRaw + zlapse * zrg * geop;
  const double zhts = geop * zrg;

  double tmsl = tmslRaw;
  if (zhts > 2000. && tmslRaw > 298.)
    {
      tmsl = 298.;
      // Blend linearly between the raw value and 298 K for 2000 < zhts < 2500.
      if (zhts < 2500.) tmsl = 0.002 * ((2500. - zhts) * tmslRaw + (zhts - 2000.) * tmsl);
    }

  double zalph;
  if ((tmsl - tsurfRaw) < 0.000001)
    {
      zalph = 0.;
    }
  else if (geop > 0.0001 || geop < -0.0001)
    {
      zalph = PlanetRD * (tmsl - tsurfRaw) / geop;
    }
  else
    {
      zalph = PlanetRD * zlapse * zrg;
    }

  const double zalp = zalph * std::log(pres / halfp);
  return tsurfRaw * (1.0 + zalp * (1.0 + zalp * (0.5 + 0.16666666667 * zalp)));
}

// Extrapolated height value at pressure pres for a target level beyond the last half level.
//   pres  : target pressure
//   halfp : half level pressure (callers in this file pass the last half level)
//   fullp : full level pressure (callers in this file pass the last full level)
//   geop  : geopotential at the grid point
//   temp  : temperature on the full level
// The result is scaled by 1 / PlanetGrav before it is returned.
static inline double
extrapolate_Z(double pres, double halfp, double fullp, double geop, double temp)
{
  constexpr double zlapse = 0.0065;
  constexpr double ztlim = 290.5;
  const double zrg = 1.0 / PlanetGrav;

  double alpha = PlanetRD * zlapse * zrg;
  double tstar = (1.0 + alpha * (halfp / fullp - 1.0)) * temp;

  // Damp very cold surface temperatures towards 255 K.
  if (tstar < 255.0) tstar = 0.5 * (255.0 + tstar);

  double tmsl = tstar + zlapse * zrg * geop;

  if (tmsl > ztlim)
    {
      if (tstar > ztlim)
        {
          // Both too warm: damp tstar towards the limit and drop the lapse rate.
          tstar = 0.5 * (ztlim + tstar);
          tmsl = tstar;
        }
      else
        {
          // Only the extrapolated temperature is too warm: cap it.
          tmsl = ztlim;
        }
    }

  const bool sameTemperature = (tmsl - tstar < 0.000001) && (tstar - tmsl < 0.000001);
  if (sameTemperature)
    {
      alpha = 0.0;
    }
  else if (geop > 0.0001 || geop < -0.0001)
    {
      alpha = PlanetRD * (tmsl - tstar) / geop;
    }

  const double zalp = std::log(pres / halfp);
  const double zalpal = zalp * alpha;
  const double series = 1.0 + zalpal * (0.5 + zalpal / 6.0);

  return (geop - PlanetRD * tstar * zalp * series) * zrg;
}

template <typename T>
void
vertical_interp_T(const T *restrict geop, const T *restrict gt, T *pt, const T *restrict fullp, const T *restrict halfp,
                  const int *vertIndex, const double *restrict plev, long nplev, long ngp, long nhlev, double missval)
{
#ifdef _OPENMP
#pragma omp parallel for default(none) shared(geop, gt, pt, fullp, halfp, vertIndex, plev, nplev, ngp, nhlev, missval)
#endif
  for (long lp = 0; lp < nplev; lp++)
    {
      long nl, nh;
      const auto pres = plev[lp];
      const int *restrict vertIndexLev = vertIndex + lp * ngp;
      auto ptl = pt + lp * ngp;

      for (long i = 0; i < ngp; i++)
        {
          nl = vertIndexLev[i];
          if (nl < 0)
            ptl[i] = missval;
          else
            {
              if (nl > nhlev - 2)
                {
                  ptl[i] = extrapolate_T(pres, halfp[nhlev * ngp + i], fullp[(nhlev - 1) * ngp + i], geop[i],
                                         gt[(nhlev - 1) * ngp + i]);
                }
              else
                {
                  nh = nl + 1;
                  ptl[i] = gt[nl * ngp + i]
                           + (pres - fullp[nl * ngp + i]) * (gt[nh * ngp + i] - gt[nl * ngp + i])
                                 / (fullp[nh * ngp + i] - fullp[nl * ngp + i]);
                }
            }
        }
    }
}

// Explicit instantiation
template void vertical_interp_T(const float *restrict geop, const float *restrict gt, float *pt, const float *restrict fullp,
                                const float *restrict halfp, const int *vertIndex, const double *restrict plev, long nplev,
                                long ngp, long nhlev, double missval);
template void vertical_interp_T(const double *restrict geop, const double *restrict gt, double *pt, const double *restrict fullp,
                                const double *restrict halfp, const int *vertIndex, const double *restrict plev, long nplev,
                                long ngp, long nhlev, double missval);

template <typename T>
void
vertical_interp_Z(const T *restrict geop, const T *restrict gz, T *pz, const T *restrict fullp, const T *restrict halfp,
                  const int *vertIndex, const T *restrict gt, const double *restrict plev, long nplev, long ngp, long nhlev,
                  double missval)
{
  assert(geop != nullptr);
  assert(gz != nullptr);
  assert(pz != nullptr);
  assert(fullp != nullptr);
  assert(halfp != nullptr);

#ifdef _OPENMP
#pragma omp parallel for default(none) shared(geop, gz, pz, fullp, halfp, vertIndex, gt, plev, nplev, ngp, nhlev, missval)
#endif
  for (long lp = 0; lp < nplev; lp++)
    {
      long nl, nh;
      const auto pres = plev[lp];
      const int *restrict vertIndexLev = vertIndex + lp * ngp;
      auto pzl = pz + lp * ngp;

      for (long i = 0; i < ngp; i++)
        {
          nl = vertIndexLev[i];
          if (nl < 0)
            pzl[i] = missval;
          else
            {
              if (pres > halfp[(nl + 1) * ngp + i]) nl++;

              if (nl > nhlev - 1)
                {
                  pzl[i] = extrapolate_Z(pres, halfp[nhlev * ngp + i], fullp[(nhlev - 1) * ngp + i], geop[i],
                                         gt[(nhlev - 1) * ngp + i]);
                }
              else
                {
                  nh = nl + 1;
                  pzl[i] = gz[nl * ngp + i]
                           + (pres - halfp[nl * ngp + i]) * (gz[nh * ngp + i] - gz[nl * ngp + i])
                                 / (halfp[nh * ngp + i] - halfp[nl * ngp + i]);
                }
            }
        }
    }
}

// Explicit instantiation
template void vertical_interp_Z(const float *restrict geop, const float *restrict gz, float *pz, const float *restrict fullp,
                                const float *restrict halfp, const int *vertIndex, const float *restrict gt,
                                const double *restrict plev, long nplev, long ngp, long nhlev, double missval);
template void vertical_interp_Z(const double *restrict geop, const double *restrict gz, double *pz, const double *restrict fullp,
                                const double *restrict halfp, const int *vertIndex, const double *restrict gt,
                                const double *restrict plev, long nplev, long ngp, long nhlev, double missval);

template <typename T>
static inline double
vertical_interp_X_kernel(const T *restrict gt, const T *restrict hyb_press, long nl, double pres, long ngp, long nhlev,
                         double missval)
{
  const auto nh = nl + ngp;
  return (nl < 0) ? missval
                  : ((nh >= ngp * nhlev) ? gt[nl]
                                         : gt[nl] + (pres - hyb_press[nl]) * (gt[nh] - gt[nl]) / (hyb_press[nh] - hyb_press[nl]));
}

template <typename T>
void
vertical_interp_X(const T *restrict gt, T *pt, const T *hyb_press, const int *vertIndex, const double *restrict plev, long nplev,
                  long ngp, long nhlev, double missval)
{
  if (nplev > 3)
    {
#ifdef _OPENMP
#pragma omp parallel for default(none) shared(gt, pt, hyb_press, vertIndex, plev, nplev, ngp, nhlev, missval)
#endif
      for (long lp = 0; lp < nplev; lp++)
        {
          auto pres = plev[lp];
          const int *restrict vertIndexLev = vertIndex + lp * ngp;
          auto ptl = pt + lp * ngp;
          for (long i = 0; i < ngp; i++)
            {
              ptl[i] = vertical_interp_X_kernel(gt, hyb_press, vertIndexLev[i] * ngp + i, pres, ngp, nhlev, missval);
            }
        }
    }
  else
    {
      for (long lp = 0; lp < nplev; lp++)
        {
          auto pres = plev[lp];
          const int *restrict vertIndexLev = vertIndex + lp * ngp;
          auto *restrict ptl = pt + lp * ngp;
#ifdef _OPENMP
#pragma omp parallel for default(none) shared(gt, ptl, hyb_press, vertIndexLev, pres, ngp, nhlev, missval)
#endif
          for (long i = 0; i < ngp; i++)
            {
              ptl[i] = vertical_interp_X_kernel(gt, hyb_press, vertIndexLev[i] * ngp + i, pres, ngp, nhlev, missval);
            }
        }
    }
}

// Explicit instantiation
template void vertical_interp_X(const float *restrict gt, float *pt, const float *hyb_press, const int *vertIndex,
                                const double *restrict plev, long nplev, long ngp, long nhlev, double missval);
template void vertical_interp_X(const double *restrict gt, double *pt, const double *hyb_press, const int *vertIndex,
                                const double *restrict plev, long nplev, long ngp, long nhlev, double missval);

template <typename T>
void
gen_vert_index(int *vertIndex, const double *restrict plev, const T *restrict fullp, long ngp, long nplev, long nhlev,
               bool lreverse)
{
  varray_fill(ngp * nplev, vertIndex, 0);

#ifdef _OPENMP
#pragma omp parallel for default(none) shared(vertIndex, plev, fullp, ngp, nplev, nhlev, lreverse)
#endif
  for (long lp = 0; lp < nplev; lp++)
    {
      const T pres = plev[lp];
      auto *restrict vertIndexLev = vertIndex + lp * ngp;
      for (long lh = 0; lh < nhlev; lh++)
        {
          const auto *restrict fullpx = fullp + lh * ngp;
          if (lreverse)
            {
              for (long i = 0; i < ngp; i++)
                {
                  if (pres < fullpx[i]) vertIndexLev[i] = static_cast<int>(lh);
                }
            }
          else
            {
              for (long i = 0; i < ngp; i++)
                {
                  if (pres > fullpx[i]) vertIndexLev[i] = static_cast<int>(lh);
                }
            }
        }
    }
}

// Explicit instantiation
template void gen_vert_index(int *vertIndex, const double *plev, const float *fullp, long ngp, long nplev, long nhlev,
                             bool lreverse);
template void gen_vert_index(int *vertIndex, const double *plev, const double *fullp, long ngp, long nplev, long nhlev,
                             bool lreverse);

template <typename T>
void
gen_vert_index_mv(int *vertIndex, const double *restrict plev, long ngp, long nplev, const T *restrict psProg,
                  size_t *restrict pnmiss, bool lreverse)
{
#ifdef _OPENMP
#pragma omp parallel for default(none) shared(vertIndex, plev, ngp, nplev, psProg, pnmiss, lreverse)
#endif
  for (long lp = 0; lp < nplev; lp++)
    {
      pnmiss[lp] = 0;
      const T pres = plev[lp];
      auto *restrict vertIndexLev = vertIndex + lp * ngp;

      if (lreverse)
        {
          for (long i = 0; i < ngp; i++)
            {
              if (pres < psProg[i])
                {
                  vertIndexLev[i] = -1;
                  pnmiss[lp]++;
                }
            }
        }
      else
        {
          for (long i = 0; i < ngp; i++)
            {
              if (pres > psProg[i])
                {
                  vertIndexLev[i] = -1;
                  pnmiss[lp]++;
                }
            }
        }
    }
}

// Explicit instantiation
template void gen_vert_index_mv(int *vertIndex, const double *plev, long ngp, long nplev, const float *psProg, size_t *pnmiss,
                                bool lreverse);
template void gen_vert_index_mv(int *vertIndex, const double *plev, long ngp, long nplev, const double *psProg, size_t *pnmiss,
                                bool lreverse);
