/*
  This file is part of CDO. CDO is a collection of Operators to
  manipulate and analyse Climate model Data.

  Copyright (C) 2003-2020 Uwe Schulzweida, <uwe.schulzweida AT mpimet.mpg.de>
  See COPYING file for copying and redistribution conditions.

  This program is free software; you can redistribute it and/or modify
  it under the terms of the GNU General Public License as published by
  the Free Software Foundation; version 2 of the License.

  This program is distributed in the hope that it will be useful,
  but WITHOUT ANY WARRANTY; without even the implied warranty of
  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
  GNU General Public License for more details.
*/

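/*
   This module contains the following operators:

        Timstat2        timcor      correlation in time of two data files on the same grid
        Timstat2        timcovar    covariance in time of two data files on the same grid
        Timstat2        timrmsd     root-mean-square difference in time of two data files on the same grid
*/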

#include <cdi.h>

#include "functs.h"
#include "process_int.h"
#include "cdo_vlist.h"

// correlation in time
static void
correlation_init(bool hasMissValues, size_t gridsize, const Varray<double> &x, const Varray<double> &y, double xmv, double ymv,
                 Varray<size_t> &nofvals, Varray<double> &work0, Varray<double> &work1, Varray<double> &work2, Varray<double> &work3, Varray<double> &work4)
{
  if (hasMissValues)
    {
#ifdef _OPENMP
#pragma omp parallel for default(shared)
#endif
      for (size_t i = 0; i < gridsize; ++i)
        {
          if ((!DBL_IS_EQUAL(x[i], xmv)) && (!DBL_IS_EQUAL(y[i], ymv)))
            {
              work0[i] += x[i];
              work1[i] += y[i];
              work2[i] += x[i] * x[i];
              work3[i] += y[i] * y[i];
              work4[i] += x[i] * y[i];
              nofvals[i]++;
            }
        }
    }
  else
    {
#ifdef HAVE_OPENMP4
#pragma omp parallel for simd default(shared)
#elif _OPENMP
#pragma omp parallel for default(shared)
#endif
      for (size_t i = 0; i < gridsize; ++i)
        {
          work0[i] += x[i];
          work1[i] += y[i];
          work2[i] += x[i] * x[i];
          work3[i] += y[i] * y[i];
          work4[i] += x[i] * y[i];
          nofvals[i]++;
        }
    }
}

/*
 * Turn the running sums from correlation_init into the temporal correlation coefficient.
 *
 * Inputs per grid point i (n = nofvals[i]):
 *   work0 = sum(x), work1 = sum(y), work2 = sum(x*x), work3 = sum(y*y), work4 = sum(x*y)
 * Result:
 *   cor = (Sxy - Sx*Sy/n) / sqrt((Sxx - Sx*Sx/n) * (Syy - Sy*Sy/n))
 * computed with the missing-value aware *MN macros and clamped to [-1, 1].
 * Points with n == 0 are set to xmv. The result overwrites work0.
 *
 * Returns the number of grid points whose result equals xmv.
 */
static size_t
correlation(size_t gridsize, double xmv, double ymv, const Varray<size_t> &nofvals, Varray<double> &work0, const Varray<double> &work1,
            const Varray<double> &work2, const Varray<double> &work3, const Varray<double> &work4)
{
  size_t nmiss = 0;

  for (size_t i = 0; i < gridsize; ++i)
    {
      // Not referenced directly below; presumably read by the MULMN/SUBMN/DIVMN/SQRTMN macros.
      const auto missval1 = xmv;
      const auto missval2 = ymv;
      double cor;
      const auto nvals = nofvals[i];
      if (nvals > 0)
        {
          const auto temp0 = MULMN(work0[i], work1[i]);
          const auto temp1 = SUBMN(work4[i], DIVMN(temp0, nvals));
          const auto temp2 = MULMN(work0[i], work0[i]);
          const auto temp3 = MULMN(work1[i], work1[i]);
          const auto temp4 = SUBMN(work2[i], DIVMN(temp2, nvals));
          const auto temp5 = SUBMN(work3[i], DIVMN(temp3, nvals));
          const auto temp6 = MULMN(temp4, temp5);

          cor = DIVMN(temp1, SQRTMN(temp6));
          // Clamp only genuine results (guards against rounding slightly outside [-1, 1]).
          // Clamping a missing value would turn e.g. -9e33 into -1.0, i.e. a fake
          // perfect anti-correlation that is also not counted as missing.
          if (!DBL_IS_EQUAL(cor, xmv)) cor = std::min(std::max(cor, -1.0), 1.0);

          if (DBL_IS_EQUAL(cor, xmv)) nmiss++;
        }
      else
        {
          nmiss++;
          cor = xmv;
        }

      work0[i] = cor;
    }

  return nmiss;
}

// covariance in time
static void
covariance_init(bool hasMissValues, size_t gridsize, const Varray<double> &x, const Varray<double> &y, double xmv, double ymv,
                Varray<size_t> &nofvals, Varray<double> &work0, Varray<double> &work1, Varray<double> &work2)
{
  if (hasMissValues)
    {
#ifdef _OPENMP
#pragma omp parallel for default(shared)
#endif
      for (size_t i = 0; i < gridsize; ++i)
        {
          if ((!DBL_IS_EQUAL(x[i], xmv)) && (!DBL_IS_EQUAL(y[i], ymv)))
            {
              work0[i] += x[i];
              work1[i] += y[i];
              work2[i] += x[i] * y[i];
              nofvals[i]++;
            }
        }
    }
  else
    {
#ifdef HAVE_OPENMP4
#pragma omp parallel for simd default(shared)
#elif _OPENMP
#pragma omp parallel for default(shared)
#endif
      for (size_t i = 0; i < gridsize; ++i)
        {
          work0[i] += x[i];
          work1[i] += y[i];
          work2[i] += x[i] * y[i];
          nofvals[i]++;
        }
    }
}

/*
 * Turn the running sums from covariance_init into the temporal covariance.
 *
 * Per grid point i with n = nofvals[i], work0 = sum(x), work1 = sum(y), work2 = sum(x*y):
 *   covar = Sxy/n - (Sx*Sy)/(n*n)
 * computed with the missing-value aware *MN macros (divides by n, not n-1).
 * Points with n == 0 are set to xmv. The result overwrites work0.
 *
 * Returns the number of grid points whose result equals xmv.
 */
static size_t
covariance(size_t gridsize, double xmv, double ymv, const Varray<size_t> &nofvals, Varray<double> &work0, const Varray<double> &work1, const Varray<double> &work2)
{
  size_t missCount = 0;

  for (size_t i = 0; i < gridsize; ++i)
    {
      const auto count = nofvals[i];
      if (count == 0)
        {
          // No valid pair was ever accumulated at this point.
          work0[i] = xmv;
          ++missCount;
          continue;
        }

      // Not referenced directly below; presumably read by the MULMN/SUBMN/DIVMN macros.
      const auto missval1 = xmv;
      const auto missval2 = ymv;

      const double dnvals = count;
      const auto meanProduct = DIVMN(MULMN(work0[i], work1[i]), dnvals * dnvals);
      const auto result = SUBMN(DIVMN(work2[i], dnvals), meanProduct);
      if (DBL_IS_EQUAL(result, xmv)) ++missCount;

      work0[i] = result;
    }

  return missCount;
}


// root-mean-square difference (RMSD) in time

/*
 * Add one timestep's squared differences (x - y)^2 to rmsd.
 *
 * A grid point contributes, and its nofvals counter is incremented, only when
 * x != xmv and y != ymv (the comparison is always performed, whether or not the
 * record reported missing values).
 */
static void
rmsd_init(size_t gridsize, const Varray<double> &x, const Varray<double> &y, double xmv, double ymv, Varray<size_t> &nofvals,
          Varray<double> &rmsd)
{
  for (size_t i = 0; i < gridsize; ++i)
    {
      if (DBL_IS_EQUAL(x[i], xmv) || DBL_IS_EQUAL(y[i], ymv)) continue;

      const auto diff = x[i] - y[i];
      rmsd[i] += diff * diff;
      nofvals[i]++;
    }
}

/*
 * Finalize the RMSD in place: rmsd[i] = sqrt(sum / nofvals[i]).
 * Grid points with no accumulated pairs are set to missval.
 *
 * Returns the number of grid points set to missval.
 */
static size_t
rmsd_compute(size_t gridsize, double missval, const Varray<size_t> &nofvals, Varray<double> &rmsd)
{
  size_t missCount = 0;

  for (size_t i = 0; i < gridsize; ++i)
    {
      const auto count = nofvals[i];
      if (count == 0)
        {
          rmsd[i] = missval;
          ++missCount;
          continue;
        }

      rmsd[i] = std::sqrt(rmsd[i] / static_cast<double>(count));
    }

  return missCount;
}

void *
Timstat2(void *process)
{
  int64_t vdate = 0;
  int vtime = 0;

  cdoInitialize(process);

  // clang-format off
  cdoOperatorAdd("timcor",   func_cor,   5, nullptr);
  cdoOperatorAdd("timcovar", func_covar, 3, nullptr);
  cdoOperatorAdd("timrmsd",  func_rmsd,  1, nullptr);
  // clang-format on

  const auto operatorID = cdoOperatorID();
  const auto operfunc = cdoOperatorF1(operatorID);
  const auto nwork = cdoOperatorF2(operatorID);
  const auto timeIsConst = (operfunc == func_rmsd);

  operatorCheckArgc(0);

  const auto streamID1 = cdoOpenRead(0);
  const auto streamID2 = cdoOpenRead(1);

  const auto vlistID1 = cdoStreamInqVlist(streamID1);
  const auto vlistID2 = cdoStreamInqVlist(streamID2);
  const auto vlistID3 = vlistDuplicate(vlistID1);

  vlistCompare(vlistID1, vlistID2, CMP_ALL);

  VarList varList1, varList2;
  varListInit(varList1, vlistID1);
  varListInit(varList2, vlistID2);

  const auto nvars = vlistNvars(vlistID1);
  const auto nrecs1 = vlistNrecs(vlistID1);
  std::vector<int> recVarID(nrecs1), recLevelID(nrecs1);

  const auto taxisID1 = vlistInqTaxis(vlistID1);
  // const auto taxisID2 = vlistInqTaxis(vlistID2);
  const auto taxisID3 = taxisDuplicate(taxisID1);

  if (timeIsConst)
    for (int varID = 0; varID < nvars; ++varID)
      vlistDefVarTimetype(vlistID3, varID, TIME_CONSTANT);

  vlistDefTaxis(vlistID3, taxisID3);
  const auto streamID3 = cdoOpenWrite(2);
  cdoDefVlist(streamID3, vlistID3);

  const auto gridsizemax = vlistGridsizeMax(vlistID1);
  Varray<double> array1(gridsizemax), array2(gridsizemax);

  Varray4D<double> work(nvars);
  Varray3D<size_t> nofvals(nvars);

  for (int varID = 0; varID < nvars; varID++)
    {
      const auto gridsize = varList1[varID].gridsize;
      const auto nlevs = varList1[varID].nlevels;

      work[varID].resize(nlevs);
      nofvals[varID].resize(nlevs);

      for (int levelID = 0; levelID < nlevs; levelID++)
        {
          nofvals[varID][levelID].resize(gridsize, 0);
          work[varID][levelID].resize(nwork);
          for (int i = 0; i < nwork; i++) work[varID][levelID][i].resize(gridsize, 0.0);
        }
    }

  int tsID = 0;
  while (true)
    {
      const auto nrecs = cdoStreamInqTimestep(streamID1, tsID);
      if (nrecs == 0) break;

      vdate = taxisInqVdate(taxisID1);
      vtime = taxisInqVtime(taxisID1);

      const auto nrecs2 = cdoStreamInqTimestep(streamID2, tsID);
      if (nrecs != nrecs2) cdoWarning("Input streams have different number of records!");

      for (int recID = 0; recID < nrecs; recID++)
        {
          int varID, levelID;
          cdoInqRecord(streamID1, &varID, &levelID);
          cdoInqRecord(streamID2, &varID, &levelID);

          if (tsID == 0)
            {
              recVarID[recID] = varID;
              recLevelID[recID] = levelID;
            }

          const auto gridsize = varList1[varID].gridsize;
          const auto missval1 = varList1[varID].missval;
          const auto missval2 = varList2[varID].missval;

          size_t nmiss1 = 0, nmiss2 = 0;
          cdoReadRecord(streamID1, &array1[0], &nmiss1);
          cdoReadRecord(streamID2, &array2[0], &nmiss1);
          bool hasMissValues = (nmiss1 > 0 || nmiss2 > 0);

          auto &rwork = work[varID][levelID];
          auto &rnofvals = nofvals[varID][levelID];

          if (operfunc == func_cor)
            {
              correlation_init(hasMissValues, gridsize, array1, array2, missval1, missval2, rnofvals,
                               rwork[0], rwork[1], rwork[2], rwork[3], rwork[4]);
            }
          else if (operfunc == func_covar)
            {
              covariance_init(hasMissValues, gridsize, array1, array2, missval1, missval2, rnofvals,
                              rwork[0], rwork[1], rwork[2]);
            }
          else if (operfunc == func_rmsd)
            {
              rmsd_init(gridsize, array1, array2, missval1, missval2, rnofvals, rwork[0]);
            }
        }

      tsID++;
    }

  tsID = 0;
  taxisDefVdate(taxisID3, vdate);
  taxisDefVtime(taxisID3, vtime);
  cdoDefTimestep(streamID3, tsID);

  for (int recID = 0; recID < nrecs1; recID++)
    {
      const auto varID = recVarID[recID];
      const auto levelID = recLevelID[recID];

      const auto gridsize = varList1[varID].gridsize;
      const auto missval1 = varList1[varID].missval;
      const auto missval2 = varList2[varID].missval;

      auto &rwork = work[varID][levelID];
      auto &rnofvals = nofvals[varID][levelID];
      size_t nmiss = 0;

      if (operfunc == func_cor)
        {
          nmiss = correlation(gridsize, missval1, missval2, rnofvals, rwork[0], rwork[1], rwork[2], rwork[3], rwork[4]);
        }
      else if (operfunc == func_covar)
        {
          nmiss = covariance(gridsize, missval1, missval2, rnofvals, rwork[0], rwork[1], rwork[2]);
        }
      else if (operfunc == func_rmsd)
        {
          nmiss = rmsd_compute(gridsize, missval1, rnofvals, rwork[0]);
        }

      cdoDefRecord(streamID3, varID, levelID);
      cdoWriteRecord(streamID3, rwork[0].data(), nmiss);
    }

  cdoStreamClose(streamID3);
  cdoStreamClose(streamID2);
  cdoStreamClose(streamID1);

  cdoFinish();

  return nullptr;
}
