#include "RcppArmadillo.h"
#include <algorithm>
#include <cmath>

// Linear ramp clamped to [0, 1]: yields 0 at x == a, 1 at x == b and
// interpolates linearly in between. When a == b the ramp degenerates into a
// step function that is 1 strictly above the threshold and 0 otherwise.
static inline double trap(double x, double a, double b){
  if(a != b){
    const double fraction = (x - a) / (b - a);
    return std::clamp(fraction, 0.0, 1.0);
  }
  // Zero-width ramp: dividing would be 0/0 or x/0, so use a hard step.
  return (x > a) ? 1.0 : 0.0;
}


//' make_swc
//'
//' Creates soil water content index from input precipitation,
//' potential evapotranspiration, field capacity and wilting point.
//'
//' The water balance is run as a simple bucket model per site: the store
//' starts at field capacity, gains rain, loses evapotranspiration scaled by
//' how full the store is, and is kept between wilting point and field
//' capacity. The time series is treated as cyclic (the first time step
//' follows the last one), so repeating the pass lets the start-up state
//' settle.
//'
//' This function is called internally by \code{get_input()}.
//'
//' @param rain matrix with ntimesteps rows, nsites columns, containing precipitation [kg*m^(-2)*s^(-1)]
//' @param pet matrix with ntimesteps rows, nsites columns, containing potential evapotranspiration [kg*m^(-2)*s^(-1)]
//' @param fc vector of length nsites containing the field capacity [cm^3/cm^3]
//' @param wp vector of length nsites containing the wilting point [cm^3/cm^3]; must not exceed \code{fc}
//' @param seconds number of seconds in a time step
//' @param iterations number of times to run through the data (no default is declared in this signature); must not be negative
//' @return matrix with ntimesteps rows, nsites columns, containing a soil water content index scaled from 0-100
//' @export
// [[Rcpp::export]]
arma::mat make_swc(arma::mat rain, arma::mat pet, Rcpp::NumericVector fc, Rcpp::NumericVector wp, int seconds, int iterations) {
  if(rain.n_cols != pet.n_cols ||
     rain.n_rows != pet.n_rows){
    Rcpp::stop("make_swc: rain and pet matrices do not have equal dimensions");
  }
  // Cast the (signed) R vector lengths so the comparison is unsigned on both sides.
  if(rain.n_cols != static_cast<arma::uword>(fc.length())){
    Rcpp::stop("make_swc: Field capacity site count does not line up with rain site count");
  }
  if(rain.n_cols != static_cast<arma::uword>(wp.length())){
    Rcpp::stop("make_swc: Wilting point site count does not line up with rain site count");
  }
  // A negative count would wrap to a huge unsigned value and never finish.
  if(iterations < 0){
    Rcpp::stop("make_swc: iterations must not be negative");
  }

  const size_t ntimesteps = rain.n_rows;
  const size_t nsites = rain.n_cols;

  // std::clamp is undefined when its lower bound exceeds its upper bound,
  // so reject such sites before any work is done.
  for(size_t site = 0; site < nsites; site++){
    if(wp[site] > fc[site]){
      Rcpp::stop("make_swc: Wilting point exceeds field capacity for at least one site");
    }
  }

  // Explicit zero fill keeps the result deterministic when iterations == 0.
  arma::mat swc(ntimesteps, nsites, arma::fill::zeros);
  if(ntimesteps == 0){
    // Nothing to simulate; also avoids the ntimesteps-1 underflow below.
    return swc;
  }
  const size_t last = ntimesteps - 1;

  // Each site only depends on its own column, so the site loop can be the
  // outermost one. That walks the column-major matrices contiguously and lets
  // the per-site limits be computed once.
  for(size_t site = 0; site < nsites; site++){
    // Limits are scaled by 1000 to sit on the same scale as rain * seconds
    // (presumably mm of water in a 1 m soil column -- TODO confirm).
    const double site_fc = fc[site]*1000;
    const double site_wp = wp[site]*1000;

    // Start from a full store; t == 0 reads this value as its predecessor.
    swc(last, site) = site_fc;

    for(int it = 0; it < iterations; it++){
      for(size_t t = 0; t < ntimesteps; t++){
        const size_t prev = (t == 0) ? last : t-1;
        const double prev_swc = swc(prev, site);
        // Fraction of potential evapotranspiration actually drawn from the store.
        const double Kswc = trap(prev_swc, site_wp, site_fc);
        double curr_swc = prev_swc + rain(t, site) * seconds - Kswc * pet(t, site) * seconds;
        curr_swc = std::clamp(curr_swc, site_wp, site_fc);
        swc(t, site) = curr_swc;
      }
    }

    // Convert the absolute store into a rounded 0-100 index.
    for(size_t t = 0; t < ntimesteps; t++){
      swc(t, site) = std::round(trap(swc(t, site), site_wp, site_fc) * 100);
    }
  }
  return swc;
}
