#include "scalablebayesm.h"

#include <cmath>
#include <vector>

#include <RcppArmadillo.h>

using namespace Rcpp;
using namespace arma;

//prototypes
double llmnl(arma::vec const& beta, arma::vec const& y, arma::mat const& X);
arma::vec llmnlParallel(arma::mat const& beta, arma::vec const& y, arma::mat const& X);

//[[Rcpp::export]]
// Independence Metropolis sweep over every unit in Data$lgtdata.
//
// Data  : list with element "lgtdata", a list of units, each holding "y" and "X".
// draws : matrix whose rows are candidate coefficient vectors (one row per
//         iteration, k columns); the same candidate row is proposed to every unit.
// Mcmc  : list with element "keep", the thinning interval (must be > 0).
//
// Returns list(betadraw = cube of dim nunits x k x floor(nrow(draws)/keep)),
// where slice j holds every unit's current coefficients after iteration
// (j + 1) * keep.
//
// Each unit starts at the column means of `draws`. At iteration `rep` the
// candidate draws.row(rep) is accepted for unit i with probability
// min(1, exp(logp - logc)), where logp/logc are llmnl() evaluated at the
// candidate/current coefficients (llmnl presumably the MNL log-likelihood;
// defined elsewhere).
Rcpp::List rheteroMnlIndepMetrop_rcpp_loop(Rcpp::List const& Data, arma::mat const& draws, Rcpp::List const& Mcmc) {
  const int keep = Mcmc["keep"];
  // Guard before `ndraws / keep`: integer division by zero would abort the R session.
  if (keep <= 0) {
    Rcpp::stop("Mcmc$keep must be a positive integer");
  }
  const int ndraws = draws.n_rows;
  const int k = draws.n_cols;

  Rcpp::List lgtdata = Data["lgtdata"];
  const int nunits = lgtdata.length();

  // Convert each unit's R data to Armadillo exactly once; doing it inside the
  // sampling loop would repeat the allocation + copy ndraws * nunits times.
  std::vector<arma::vec> ys;
  std::vector<arma::mat> Xs;
  ys.reserve(nunits);
  Xs.reserve(nunits);
  for (int i = 0; i < nunits; i++) {
    Rcpp::List lgtdatai = lgtdata[i];
    ys.emplace_back(Rcpp::as<arma::vec>(lgtdatai["y"]));
    Xs.emplace_back(Rcpp::as<arma::mat>(lgtdatai["X"]));
  }

  // Every unit starts at the column means of the candidate draws.
  const arma::rowvec betainit = arma::mean(draws, 0);
  const arma::vec betainit_col = betainit.t();
  arma::mat betac = arma::repmat(betainit, nunits, 1);  // current state, one row per unit
  arma::vec logc(nunits);                                // llmnl at the current state
  for (int i = 0; i < nunits; i++) {
    logc(i) = llmnl(betainit_col, ys[i], Xs[i]);
  }

  // Integer division already floors; one slice per kept iteration.
  arma::cube betadraw(nunits, k, static_cast<arma::uword>(ndraws / keep));

  for (int rep = 0; rep < ndraws; rep++) {
    const arma::rowvec betap = draws.row(rep);
    const arma::vec betap_col = betap.t();  // transpose once per draw, not per unit
    for (int i = 0; i < nunits; i++) {
      const double logp = llmnl(betap_col, ys[i], Xs[i]);
      const double diff = logp - logc(i);
      // Kept in this form (not std::min) so a NaN diff still rejects the candidate.
      const double ratio = (diff > 0) ? 1.0 : std::exp(diff);
      // One uniform per unit per iteration, drawn unconditionally so the RNG
      // stream is consumed exactly as before.
      if (R::runif(0.0, 1.0) < ratio) {
        betac.row(i) = betap;
        logc(i) = logp;
      }
    }
    if ((rep + 1) % keep == 0) {
      // betac is the full nunits x k state; assigning it directly also works
      // when nunits == 0 (rows(0, nunits - 1) would underflow there).
      betadraw.slice((rep + 1) / keep - 1) = betac;
    }
  }

  return Rcpp::List::create(Rcpp::Named("betadraw") = betadraw);
}

