/*
  This file is part of CDO. CDO is a collection of Operators to
  manipulate and analyse Climate model Data.

  Copyright (C) 2003-2020 Uwe Schulzweida, <uwe.schulzweida AT mpimet.mpg.de>
  See COPYING file for copying and redistribution conditions.

  This program is free software; you can redistribute it and/or modify
  it under the terms of the GNU General Public License as published by
  the Free Software Foundation; version 2 of the License.

  This program is distributed in the hope that it will be useful,
  but WITHOUT ANY WARRANTY; without even the implied warranty of
  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
  GNU General Public License for more details.
*/

/*
   This module contains the following operators:

      Vertint    ap2pl           Model air pressure level to pressure level interpolation
*/

#include <cdi.h>

#include "cdo_options.h"
#include "process_int.h"
#include "cdo_vlist.h"
#include "field_vinterp.h"
#include "stdnametable.h"
#include "util_string.h"
#include "const.h"
#include "cdo_zaxis.h"
#include "param_conversion.h"


static bool
is_height_axis(int zaxisID)
{
  bool isHeight = false;
  if (zaxisInqType(zaxisID) == ZAXIS_REFERENCE)
    {
      char units[CDI_MAX_NAME], stdname[CDI_MAX_NAME];
      int length = CDI_MAX_NAME;
      cdiInqKeyString(zaxisID, CDI_GLOBAL, CDI_KEY_UNITS, units, &length);
      length = CDI_MAX_NAME;
      cdiInqKeyString(zaxisID, CDI_GLOBAL, CDI_KEY_STDNAME, stdname, &length);
      if (cdo_cmpstr(stdname, "height") && *units == 0) isHeight = true;
    }
  return isHeight;
}

// For every z-axis of vlistID1 that is a generalized height axis (see is_height_axis) with
// either nhlev or nhlev + 1 levels, replace the axis at the same index in vlistID2 by zaxisID2.
static void
change_height_zaxis(int nhlev, int vlistID1, int vlistID2, int zaxisID2)
{
  const int numZaxes = vlistNzaxis(vlistID1);
  for (int axisIndex = 0; axisIndex < numZaxes; ++axisIndex)
    {
      const int srcZaxisID = vlistZaxis(vlistID1, axisIndex);
      const int numLevels = zaxisInqSize(srcZaxisID);

      const bool levelCountMatches = (numLevels == nhlev) || (numLevels == nhlev + 1);
      if (!levelCountMatches) continue;

      if (is_height_axis(srcZaxisID)) vlistChangeZaxisIndex(vlistID2, axisIndex, zaxisID2);
    }
}

// Derives half-level pressure from full-level pressure. Both arrays are stored level by level,
// each level holding gridsize contiguous values.
//   half level 0            : 0
//   half level k (1..nf-1)  : mean of full levels k-1 and k
//   half level nhlevh-1     : copy of full level nhlevf-1
// Half levels between nhlevf and nhlevh-2 (if any) are left untouched.
template <typename T>
static void
calc_half_press(size_t gridsize, size_t nhlevf, const Varray<T> &full_press, size_t nhlevh, Varray<T> &half_press)
{
  for (size_t i = 0; i < gridsize; ++i) half_press[i] = 0;

#ifdef _OPENMP
#pragma omp parallel for default(none) shared(nhlevf, gridsize, full_press, half_press)
#endif
  for (size_t k = 1; k < nhlevf; ++k)
    {
      const size_t prevOffset = (k - 1) * gridsize;
      const size_t curOffset = k * gridsize;
      for (size_t i = 0; i < gridsize; ++i)
        half_press[curOffset + i] = 0.5 * (full_press[prevOffset + i] + full_press[curOffset + i]);
    }

  const size_t lastHalfOffset = (nhlevh - 1) * gridsize;
  const size_t lastFullOffset = (nhlevf - 1) * gridsize;
  for (size_t i = 0; i < gridsize; ++i) half_press[lastHalfOffset + i] = full_press[lastFullOffset + i];
}

// Precision dispatcher for calc_half_press: float storage when full_press.memType is
// MemType::Float, double storage for every other memType.
static void
calc_half_press(const Field3D &full_press, Field3D &half_press)
{
  const auto ngp = full_press.gridsize;
  const auto nfull = full_press.nlevels;
  const auto nhalf = half_press.nlevels;

  if (full_press.memType != MemType::Float)
    calc_half_press(ngp, nfull, full_press.vec_d, nhalf, half_press.vec_d);
  else
    calc_half_press(ngp, nfull, full_press.vec_f, nhalf, half_press.vec_f);
}

/*
  Entry point of the operators ap2pl, ap2plx, ap2hl, ap2hlx, ap2pl_lp and ap2plx_lp.

  The z-axis of the variable with standard name air_pressure defines the model levels. It is used
  when it is a generalized height axis (see is_height_axis) with more than one level; its level
  count is the number of full levels, and half levels have one level more. Every variable on such
  an axis with the full- or half-level count is vertically interpolated to the target levels given
  as operator arguments ("default" selects a built-in list). All other variables are copied
  unchanged.

  - ap2hl/ap2hlx: the target heights are converted to pressure with height2pressure first.
  - *_lp operators: the interpolation is done in ln(p).
  - *x operators: always extrapolate. The other operators extrapolate only when the environment
    variable EXTRAPOLATE is 1; otherwise target levels outside the model column are flagged
    through genindmiss.
  - Surface pressure comes from surface_air_pressure if present, else from the vertical sum of
    pressure_thickness, else from the last level of air_pressure.
*/
void *
Vertintap(void *process)
{
  enum
  {
    func_pl,
    func_hl
  };
  enum
  {
    type_lin,
    type_log
  };
  int varID, levelID;
  int nhlev = 0, nhlevf = 0, nhlevh = 0;
  int apressID = -1, dpressID = -1;
  int psID = -1;
  char stdname[CDI_MAX_NAME];
  bool extrapolate = false;

  cdoInitialize(process);

  // clang-format off
  const auto AP2PL     = cdoOperatorAdd("ap2pl",     func_pl, type_lin, "pressure levels in pascal");
  const auto AP2PLX    = cdoOperatorAdd("ap2plx",    func_pl, type_lin, "pressure levels in pascal");
  const auto AP2HL     = cdoOperatorAdd("ap2hl",     func_hl, type_lin, "height levels in meter");
  const auto AP2HLX    = cdoOperatorAdd("ap2hlx",    func_hl, type_lin, "height levels in meter");
  const auto AP2PL_LP  = cdoOperatorAdd("ap2pl_lp",  func_pl, type_log, "pressure levels in pascal");
  const auto AP2PLX_LP = cdoOperatorAdd("ap2plx_lp", func_pl, type_log, "pressure levels in pascal");
  // clang-format on

  const auto operatorID = cdoOperatorID();
  const bool useHeightLevel = cdoOperatorF1(operatorID) == func_hl;
  const bool useLogType = cdoOperatorF2(operatorID) == type_log;

  if (operatorID == AP2PL || operatorID == AP2HL || operatorID == AP2PL_LP)
    {
      const auto envstr = getenv("EXTRAPOLATE");
      // isdigit requires a value representable as unsigned char; a negative char is UB.
      if (envstr && isdigit(static_cast<unsigned char>(envstr[0])))
        {
          if (atoi(envstr) == 1) extrapolate = true;
          if (extrapolate) cdoPrint("Extrapolation of missing values enabled!");
        }
    }
  else if (operatorID == AP2PLX || operatorID == AP2HLX || operatorID == AP2PLX_LP)
    {
      extrapolate = true;
    }

  operatorInputArg(cdoOperatorEnter(operatorID));

  // Target levels: pressure in Pa, or height in m for the ap2hl* operators.
  std::vector<double> plev;
  if (operatorArgc() == 1 && cdoOperatorArgv(0) == "default")
    {
      if (useHeightLevel)
        plev = { 10, 50, 100, 500, 1000, 5000, 10000, 15000, 20000, 25000, 30000 };
      else
        plev = { 100000, 92500, 85000, 70000, 60000, 50000, 40000, 30000, 25000, 20000, 15000, 10000, 7000, 5000, 3000, 2000, 1000 };
    }
  else
    {
      plev = cdoArgvToFlt(cdoGetOperArgv());
    }

  int nplev = plev.size();

  const auto streamID1 = cdoOpenRead(0);

  const auto vlistID1 = cdoStreamInqVlist(streamID1);
  const auto vlistID2 = vlistDuplicate(vlistID1);

  const auto taxisID1 = vlistInqTaxis(vlistID1);
  const auto taxisID2 = taxisDuplicate(taxisID1);
  vlistDefTaxis(vlistID2, taxisID2);

  const auto gridsize = vlist_check_gridsize(vlistID1);

  // Output z-axis carrying the user-requested levels (the original, unconverted values).
  const auto zaxistype = useHeightLevel ? ZAXIS_HEIGHT : ZAXIS_PRESSURE;
  const auto zaxisIDp = zaxisCreate(zaxistype, nplev);
  zaxisDefLevels(zaxisIDp, plev.data());

  VarList varList1;
  varListInit(varList1, vlistID1);
  varListSetUniqueMemtype(varList1);
  const auto memtype = varList1[0].memType;

  const auto nvars = vlistNvars(vlistID1);

  // Identify the pressure variables by their (lower-cased) standard name.
  for (varID = 0; varID < nvars; varID++)
    {
      int length = CDI_MAX_NAME;
      cdiInqKeyString(vlistID1, varID, CDI_KEY_STDNAME, stdname, &length);
      cstrToLowerCase(stdname);

      // clang-format off
      if      (cdo_cmpstr(stdname, var_stdname(surface_air_pressure))) psID = varID;
      else if (cdo_cmpstr(stdname, var_stdname(air_pressure)))         apressID = varID;
      else if (cdo_cmpstr(stdname, var_stdname(pressure_thickness)))   dpressID = varID;
      // clang-format on
    }

  if (Options::cdoVerbose)
    {
      cdoPrint("Found:");
      // clang-format off
      if (-1 != psID)     cdoPrint("  %s -> %s", var_stdname(surface_air_pressure), varList1[psID].name);
      if (-1 != apressID) cdoPrint("  %s -> %s", var_stdname(air_pressure), varList1[apressID].name);
      if (-1 != dpressID) cdoPrint("  %s -> %s", var_stdname(pressure_thickness), varList1[dpressID].name);
      // clang-format on
    }

  if (apressID == -1) cdoAbort("%s not found!", var_stdname(air_pressure));

  // The z-axis of air_pressure is the model full-level axis when it is a multi-level generalized
  // height axis; half levels have one more level. The level values themselves are not required
  // to be numbered 1..nlevels.
  int zaxisIDh = -1;
  const auto nzaxis = vlistNzaxis(vlistID1);
  for (int i = 0; i < nzaxis; i++)
    {
      const auto zaxisID = vlistZaxis(vlistID1, i);
      if (zaxisID != varList1[apressID].zaxisID) continue;

      const auto nlevels = zaxisInqSize(zaxisID);
      if (nlevels > 1 && is_height_axis(zaxisID))
        {
          zaxisIDh = zaxisID;
          nhlev = nlevels;
          nhlevf = nhlev;
          nhlevh = nhlevf + 1;

          break;
        }
    }

  // Retarget height axes only when a model-level axis was found. With nhlev == 0, the helper would
  // also match single-level height axes. Those variables are not interpolated, so they would get
  // nplev output levels backed by 1-level buffers (out-of-bounds in the write loop).
  if (zaxisIDh != -1) change_height_zaxis(nhlev, vlistID1, vlistID2, zaxisIDp);

  VarList varList2;
  varListInit(varList2, vlistID2);
  varListSetMemtype(varList2, memtype);

  std::vector<bool> vars(nvars), varinterp(nvars);
  std::vector<std::vector<size_t>> varnmiss(nvars);
  Field3DVector vardata1(nvars), vardata2(nvars);

  // Interpolated variables need room for both their input (up to nhlevh) and output (nplev) levels.
  const auto maxlev = nhlevh > nplev ? nhlevh : nplev;

  std::vector<size_t> pnmiss;
  if (!extrapolate) pnmiss.resize(nplev);

  // check levels
  if (zaxisIDh != -1)
    {
      const auto nlev = zaxisInqSize(zaxisIDh);
      if (nlev != nhlev) cdoAbort("Internal error, wrong number of height level!");
    }

  std::vector<int> vert_index;
  Field ps_prog;
  Field3D full_press, half_press;
  if (zaxisIDh != -1 && gridsize > 0)
    {
      vert_index.resize(gridsize * nplev);

      CdoVar var3Df, var3Dh;
      var3Df.gridsize = gridsize;
      var3Df.nlevels = nhlevf;
      var3Df.memType = memtype;
      full_press.init(var3Df);

      var3Dh.gridsize = gridsize;
      var3Dh.nlevels = nhlevh;
      var3Dh.memType = memtype;
      half_press.init(var3Dh);
    }
  else
    cdoWarning("No 3D variable with generalized height level found!");

  // From here on plev holds pressure: convert target heights for the ap2hl* operators.
  if (useHeightLevel)
    {
      Varray<double> phlev(nplev);
      height2pressure(phlev.data(), plev.data(), nplev);

      if (Options::cdoVerbose)
        for (int i = 0; i < nplev; ++i) cdoPrint("level = %d   height = %g   pressure = %g", i + 1, plev[i], phlev[i]);

      plev = phlev;
    }

  if (useLogType)
    for (int k = 0; k < nplev; k++) plev[k] = std::log(plev[k]);

  for (varID = 0; varID < nvars; varID++)
    {
      const auto gridID = varList1[varID].gridID;
      const auto zaxisID = varList1[varID].zaxisID;
      const auto nlevels = varList1[varID].nlevels;

      if (gridInqType(gridID) == GRID_SPECTRAL) cdoAbort("Spectral data unsupported!");

      vardata1[varID].init(varList1[varID]);

      varinterp[varID]
          = (zaxisID == zaxisIDh || (is_height_axis(zaxisID) && zaxisIDh != -1 && (nlevels == nhlevh || nlevels == nhlevf)));

      if (varinterp[varID])
        {
          varnmiss[varID].resize(maxlev, 0);
          vardata2[varID].init(varList2[varID]);
        }
      else
        {
          if (is_height_axis(zaxisID) && zaxisIDh != -1 && nlevels > 1)
            cdoWarning("Parameter %d has wrong number of levels, skipped! (name=%s nlevel=%d)", varID + 1, varList1[varID].name, nlevels);

          varnmiss[varID].resize(nlevels);
        }
    }

  if (zaxisIDh != -1 && psID == -1)
    {
      if (dpressID != -1)
        cdoWarning("Surface pressure not found - set to vertical sum of %s!", var_stdname(pressure_thickness));
      else
        cdoWarning("Surface pressure not found - set to lower bound of %s!", var_stdname(air_pressure));
    }

  // Interpolated fields depend on the time-varying pressure, so they vary in time in the output.
  for (varID = 0; varID < nvars; ++varID)
    {
      if (varinterp[varID] && varList1[varID].timetype == TIME_CONSTANT) vlistDefVarTimetype(vlistID2, varID, TIME_VARYING);
    }

  const auto streamID2 = cdoOpenWrite(1);

  cdoDefVlist(streamID2, vlistID2);

  int tsID = 0;
  while (true)
    {
      const auto nrecs = cdoStreamInqTimestep(streamID1, tsID);
      if (nrecs == 0) break;

      for (varID = 0; varID < nvars; ++varID)
        {
          vars[varID] = false;
          const auto nlevels = varList1[varID].nlevels;
          for (levelID = 0; levelID < nlevels; levelID++) varnmiss[varID][levelID] = 0;
        }

      taxisCopyTimestep(taxisID2, taxisID1);
      cdoDefTimestep(streamID2, tsID);

      for (int recID = 0; recID < nrecs; recID++)
        {
          cdoInqRecord(streamID1, &varID, &levelID);
          cdoReadRecord(streamID1, vardata1[varID], levelID, &varnmiss[varID][levelID]);
          vars[varID] = true;
        }

      // Interpolated variables are written every timestep, even time-constant ones that are not
      // re-read (their data from the first timestep is still in vardata1).
      for (varID = 0; varID < nvars; varID++)
        if (varinterp[varID]) vars[varID] = true;

      if (zaxisIDh != -1)
        {
          // Surface pressure: surface_air_pressure, else sum of pressure_thickness over the full
          // levels, else the last level of air_pressure.
          if (psID != -1)
            {
              ps_prog.init(varList1[psID]);
              fieldCopy(vardata1[psID], ps_prog);
            }
          else if (dpressID != -1)
            {
              ps_prog.init(varList1[dpressID]);
              fieldFill(ps_prog, 0);
              for (int k = 0; k < nhlevf; ++k) fieldAdd(ps_prog, vardata1[dpressID], k);
            }
          else
            {
              ps_prog.init(varList1[apressID]);
              fieldCopy(vardata1[apressID], nhlevf - 1, ps_prog);
            }

          // check range of ps_prog
          const auto mm = fieldMinMax(ps_prog);
          if (mm.min < MIN_PS || mm.max > MAX_PS)
            cdoWarning("Surface pressure out of range (min=%g max=%g)!", mm.min, mm.max);

          fieldCopy(vardata1[apressID], full_press);

          calc_half_press(full_press, half_press);

          // The *_lp operators interpolate in ln(p); plev was already transformed above.
          if (useLogType)
            {
              if (memtype == MemType::Float)
                {
                  for (size_t i = 0; i < gridsize; i++) ps_prog.vec_f[i] = std::log(ps_prog.vec_f[i]);
                  for (size_t ki = 0; ki < nhlevh * gridsize; ki++) half_press.vec_f[ki] = std::log(half_press.vec_f[ki]);
                  for (size_t ki = 0; ki < nhlevf * gridsize; ki++) full_press.vec_f[ki] = std::log(full_press.vec_f[ki]);
                }
              else
                {
                  for (size_t i = 0; i < gridsize; i++) ps_prog.vec_d[i] = std::log(ps_prog.vec_d[i]);
                  for (size_t ki = 0; ki < nhlevh * gridsize; ki++) half_press.vec_d[ki] = std::log(half_press.vec_d[ki]);
                  for (size_t ki = 0; ki < nhlevf * gridsize; ki++) full_press.vec_d[ki] = std::log(full_press.vec_d[ki]);
                }
            }

          genind(vert_index, plev, full_press, gridsize);
          if (!extrapolate) genindmiss(vert_index, plev, gridsize, ps_prog, pnmiss);
        }

      for (varID = 0; varID < nvars; varID++)
        {
          if (vars[varID])
            {
              // Only non-interpolated time-constant fields are written once. Interpolated ones were
              // declared TIME_VARYING in the output and must appear in every timestep.
              if (tsID > 0 && !varinterp[varID] && varList1[varID].timetype == TIME_CONSTANT) continue;

              if (varinterp[varID])
                {
                  const auto nlevels = varList1[varID].nlevels;
                  if (nlevels != nhlevf && nlevels != nhlevh)
                    cdoAbort("Number of generalized height level differ from full/half level (param=%s)!", varList1[varID].name);

                  for (levelID = 0; levelID < nlevels; levelID++)
                    {
                      if (varnmiss[varID][levelID]) cdoAbort("Missing values unsupported for this operator!");
                    }

                  vertical_interp_X(nlevels, full_press, half_press, vardata1[varID], vardata2[varID], vert_index, plev, gridsize);

                  if (!extrapolate) varrayCopy(nplev, pnmiss, varnmiss[varID]);
                }

              for (levelID = 0; levelID < varList2[varID].nlevels; levelID++)
                {
                  cdoDefRecord(streamID2, varID, levelID);
                  cdoWriteRecord(streamID2, varinterp[varID] ? vardata2[varID] : vardata1[varID], levelID, varnmiss[varID][levelID]);
                }
            }
        }

      tsID++;
    }

  cdoStreamClose(streamID2);
  cdoStreamClose(streamID1);

  cdoFinish();

  return nullptr;
}
