#include <algorithm>
#include <array>
#include <cmath>
#include <memory>
#include <vector>
#include <Rcpp.h>
#include "ramr.h"

// [[Rcpp::plugins(cpp20)]]
// [[Rcpp::plugins(openmp)]]

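// This function prepares input data for further processing:
//   1) makes a copy of raw methylation values ('raw')
//   2) transposes raw values row by row, dropping NaNs; optionally applies a
//      linear transformation that moves the extremes {0;1} into the open
//      interval (0,1); without the transformation, counts 0s and 1s per row
//      ('coef')
//   3) computes the per-row median ('coef') and keeps a row's valid values
//      (in 'out') and their count (in 'len') only if the median lies outside
//      [exclude_lower, exclude_upper]; otherwise the row's 'len' stays 0
//   4) arranges other vectors used in computations later.
// NB: <input.ranges> for rcpp_prepare_data must be sorted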
//
// TODO:
//   [ ] more efficient access to S4Vectors with raw values
//   [x] OpenMP
//   [ ] cache-friendly (845 samples seriously suck on Mac)
//   [ ] ...

// Core of the data preparation (see the file header for the overall workflow).
//
// Template parameter:
//   transform == 0 : values are used as is; per row, values flagged by isZero()/isOne()
//                    are counted into coef[r*NCOEF+0] / coef[r*NCOEF+1]
//   transform == 1 : values are mapped to x*(ncol-1)/ncol + 0.5/ncol; 0s/1s are not counted
// For every row r, coef[r*NCOEF+2] receives the median of the (transformed) non-NaN values.
// If that median is < exclude_lower or > exclude_upper, the values are stored in
// out[ncol*r ...] and len[r] is set to their number; otherwise len[r] stays 0.
//
// Returns a List (ncol, nrow, seqnames, seqrunlens, samples, attr "strandlevels") with the
// prepared containers attached as external-pointer attributes ("chr_xptr" ... "thr_xptr");
// the XPtrs own the containers from then on.
//
// Errors (Rcpp::stop): 'seqnames' and 'seqrunlens' differ in length, or 'chunks' is not a
// non-decreasing vector of at least two row boundaries within [0, nrow].
template<int transform>
Rcpp::List rcpp_prepare_data (Rcpp::IntegerVector &seqnames,                    // IntegerVector (factor) output of S4Vectors::runValue(Seqinfo::seqnames(<input.ranges>))
                              Rcpp::IntegerVector &seqrunlens,                  // IntegerVector output of S4Vectors::runLength(Seqinfo::seqnames(<input.ranges>))
                              Rcpp::IntegerVector &start,                       // IntegerVector output of BiocGenerics::start(<input.ranges>)
                              Rcpp::IntegerVector &strand,                      // IntegerVector (factor) output of S4Vectors::as.factor(BiocGenerics::strand(<input.ranges>))
                              Rcpp::DataFrame &mcols,                           // DataFrame output of as.data.frame(GenomicRanges::mcols(<input.ranges>), optional=TRUE)
                              Rcpp::DataFrame &coverage,                        // optional DataFrame with coverage data for binomial modelling of extremes {0;1}
                              double exclude_lower,                             // lower bound of range to exclude
                              double exclude_upper,                             // upper bound of range to exclude
                              Rcpp::IntegerVector &chunks)                      // start/end rows of chunks for parallel processing
{
  static_assert(transform==0 || transform==1, "rcpp_prepare_data: 'transform' must be 0 (identity) or 1 (linear)");

  // consts
  const size_t ncol = mcols.ncol();                                             // number of columns (samples)
  const size_t nrow = mcols.nrow();                                             // number of rows (genomic loci)

  // validate inputs here, before the parallel region, where Rcpp::stop() can safely unwind
  if (seqnames.size()!=seqrunlens.size())
    Rcpp::stop("'seqnames' and 'seqrunlens' must be of the same length");
  if (chunks.size()<2)
    Rcpp::stop("'chunks' must contain at least two row boundaries");
  for (size_t i=0; i<(size_t)chunks.size(); i++) {
    const int bound = chunks[i];                                                // NA_INTEGER is negative and is rejected too
    if (bound<0 || (size_t)bound>nrow || (i>0 && bound<chunks[i-1]))
      Rcpp::stop("'chunks' must be non-decreasing row boundaries within [0, nrow]");
  }

  // containers: owned by unique_ptr until they are handed over to R (XPtr) at the very end,
  // so an exception thrown in between (e.g., by an Rcpp column conversion) does not leak them
  auto chr  = std::make_unique<T_int>();                                        // chromosomes
  auto pos  = std::make_unique<T_int>(start.begin(), start.end());              // genomic positions
  auto str  = std::make_unique<T_int>(strand.begin(), strand.end());            // genomic strands
  auto raw  = std::make_unique<T_dbl>();                                        // flat vector with raw values from &mcols
  auto cov  = std::make_unique<T_int>();                                        // optional flat vector with coverage values from &coverage
  auto out  = std::make_unique<T_dbl>();                                        // vector to hold intermediate output values (e.g., transposed)
  auto len  = std::make_unique<T_int>();                                        // number of values stored in 'out' per row (0 for excluded rows)
  auto coef = std::make_unique<T_dbl>();                                        // vector to hold per-row results (e.g., median, Q1, Q3, parameters of fitted distribution)
  auto thr  = std::make_unique<T_int>(chunks.begin(), chunks.end());            // chunks of rows for multiple threads

  // fill 'chr' vector with seqname ids (expand the run-length encoding)
  chr->reserve(nrow);                                                           // reserve space as required
  for (size_t i=0; i<(size_t)seqnames.size(); i++)
    chr->resize(chr->size()+seqrunlens[i], seqnames[i]);
  chr->shrink_to_fit();

  // fill 'raw' with values from &mcols (column-major)
  raw->reserve(ncol*nrow);                                                      // reserve space as required
  for (size_t c=0; c<ncol; c++) {
    // convert ONCE: a coercing conversion creates a new vector every time,
    // so begin() and end() must come from the same converted object
    const Rcpp::NumericVector column = mcols[c];
    raw->insert(raw->end(), column.begin(), column.end());
  }
  raw->shrink_to_fit();

  // if given: fill 'cov' with values from &coverage (column-major)
  if ((size_t)coverage.size()==ncol && (size_t)coverage.nrows()==nrow) {
    cov->reserve(ncol*nrow);                                                    // reserve space as required
    for (size_t c=0; c<ncol; c++) {
      const Rcpp::IntegerVector column = coverage[c];                           // convert once (see above)
      cov->insert(cov->end(), column.begin(), column.end());
    }
    cov->shrink_to_fit();
  }

  // initialize 'len', 'coef', and 'out'
  len->resize(nrow);                                                            // init with 0
  coef->resize(nrow*NCOEF);                                                     // nrow times NCOEF to store them all continuously (all 0)
  out->resize(ncol*nrow, NA_REAL);                                              // init with NA_REAL - see if it breaks anything further. NB: default might be marginally faster

  // fast direct accessors
  const auto raw_data = raw->data();
  const auto out_data = out->data();
  const auto len_data = len->data();
  const auto coef_data = coef->data();
  const auto thr_data = thr->data();

  // linear transformation as described in https://pubmed.ncbi.nlm.nih.gov/16594767/
  // squeezes {0;1} extremes within (0,1) bounds of beta distribution
  const double a = ((double)ncol - 1) / ncol;                                   // coefficient for linear transformation
  const double b = 0.5 / ncol;                                                  // coefficient for linear transformation

  // number of chunks: chunk t spans rows [thr[t], thr[t+1])
  const int nchunks = (int)thr->size() - 1;

  // per-chunk scratch buffers (ncol values each), allocated up front so that nothing
  // inside the parallel region allocates (or can throw)
  std::vector<double> bufs(ncol*(size_t)nchunks);
  double* const bufs_data = bufs.data();

  // work-sharing loop over chunks: every chunk is processed even if OpenMP grants fewer
  // threads than requested, or if OpenMP is not available at all
  // TODO: make the transposition cache-friendly, 845 samples seriously suck on Mac
#pragma omp parallel for num_threads(nchunks) schedule(static, 1)
  for (int t=0; t<nchunks; t++) {
    const size_t row_from = thr_data[t];                                        // start of row chunk
    const size_t row_to = thr_data[t+1];                                        // end of row chunk
    double* const buf = bufs_data + ncol*(size_t)t;                             // this chunk's buffer to gather mcols[r,]

    // transpose 'raw' to 'out', counting 0/1, skipping NaNs; adjust 'len'
    for (size_t r=row_from; r<row_to; r++) {
      const auto q = coef_data + r*NCOEF;                                       // pointer to coef NCOEF-element array
      size_t l = 0;                                                             // number of elements actually gathered
      for (size_t c=0; c<ncol; c++) {                                           // column by column
        const auto raw_value = raw_data[r+nrow*c];                              // value to compare/transpose
        if (std::isnan(raw_value)) continue;                                    // NaNs are dropped
        if constexpr (transform==0) {                                           // use raw values == do not transform
          q[0] += isZero(raw_value);                                            // is it a 0?
          q[1] += isOne(raw_value);                                             // is it a 1?
          buf[l++] = raw_value;                                                 // gather it in the buffer
        } else {                                                                // transform==1: linear transformation
          buf[l++] = raw_value * a + b;                                         // transform, gather it in the buffer
        }
      }

      if (l==0) continue;                                                       // no valid values: row stays unused (len==0)

      // median
      const size_t hl = l/2;                                                    // half length
      std::nth_element(buf, buf+hl, buf+l);                                     // order up to l/2-th
      q[2] = buf[hl];                                                           // median for odd l
      if ((l&1)==0) {                                                           // if l is even
        std::nth_element(buf, buf+hl-1, buf+hl);                                // largest of the lower half goes to l/2-1-th
        q[2] = (q[2] + buf[hl-1])/2;                                            // median for even l
      }

      // keep only rows whose median lies outside [exclude_lower, exclude_upper]
      if (q[2]<exclude_lower || q[2]>exclude_upper) {
        std::copy(buf, buf+l, out_data+ncol*r);                                 // copy 'buf' to 'out'
        len_data[r] = l;                                                        // non-zero length marks the row for further analyses
      }
    }
  }

  // wrap and return the results
  Rcpp::List res = Rcpp::List::create(                                          // final List
    Rcpp::Named("ncol") = ncol,                                                 // number of columns (samples)
    Rcpp::Named("nrow") = nrow,                                                 // number of rows (genomic loci)
    Rcpp::Named("seqnames") = seqnames,                                         // integer IDs of seqnames
    Rcpp::Named("seqrunlens") = seqrunlens,                                     // running lengths of seqnames
    Rcpp::Named("samples") = mcols.names()                                      // sample names
  );
  res.attr("strandlevels") = strand.attr("levels");                             // strand levels

  // hand ownership of the containers over to R (finalized by the XPtrs)
  Rcpp::XPtr<T_int> chr_xptr(chr.release(), true);
  Rcpp::XPtr<T_int> pos_xptr(pos.release(), true);
  Rcpp::XPtr<T_int> str_xptr(str.release(), true);
  Rcpp::XPtr<T_dbl> raw_xptr(raw.release(), true);
  Rcpp::XPtr<T_int> cov_xptr(cov.release(), true);
  Rcpp::XPtr<T_dbl> out_xptr(out.release(), true);
  Rcpp::XPtr<T_int> len_xptr(len.release(), true);
  Rcpp::XPtr<T_dbl> coef_xptr(coef.release(), true);
  Rcpp::XPtr<T_int> thr_xptr(thr.release(), true);
  res.attr("chr_xptr") = chr_xptr;
  res.attr("pos_xptr") = pos_xptr;
  res.attr("str_xptr") = str_xptr;
  res.attr("raw_xptr") = raw_xptr;
  res.attr("cov_xptr") = cov_xptr;
  res.attr("out_xptr") = out_xptr;
  res.attr("len_xptr") = len_xptr;
  res.attr("coef_xptr") = coef_xptr;
  res.attr("thr_xptr") = thr_xptr;

  return(res);
}

// Exported entry point without value transformation: raw values are used as is
// and 0s/1s are counted per row. All arguments are forwarded unchanged to the
// rcpp_prepare_data template instantiated with transform == 0.
// [[Rcpp::export]]
Rcpp::List rcpp_prepare_data_identity (
    Rcpp::IntegerVector &seqnames, Rcpp::IntegerVector &seqrunlens, Rcpp::IntegerVector &start, Rcpp::IntegerVector &strand,
    Rcpp::DataFrame &mcols, Rcpp::DataFrame &coverage, double exclude_lower, double exclude_upper, Rcpp::IntegerVector &chunks)
{
  constexpr int identity_transform = 0;                                         // template tag: keep raw values
  return rcpp_prepare_data<identity_transform>(
    seqnames, seqrunlens, start, strand,                                        // genomic coordinates
    mcols, coverage,                                                            // methylation values and optional coverage
    exclude_lower, exclude_upper,                                               // median exclusion range
    chunks                                                                      // row chunks for parallel processing
  );
}

// Exported entry point with linear transformation: every value x becomes
// x*(ncol-1)/ncol + 0.5/ncol, squeezing the extremes {0;1} into (0,1). All
// arguments are forwarded unchanged to the rcpp_prepare_data template
// instantiated with transform == 1.
// [[Rcpp::export]]
Rcpp::List rcpp_prepare_data_linear (
    Rcpp::IntegerVector &seqnames, Rcpp::IntegerVector &seqrunlens, Rcpp::IntegerVector &start, Rcpp::IntegerVector &strand,
    Rcpp::DataFrame &mcols, Rcpp::DataFrame &coverage, double exclude_lower, double exclude_upper, Rcpp::IntegerVector &chunks)
{
  constexpr int linear_transform = 1;                                           // template tag: linear transformation
  return rcpp_prepare_data<linear_transform>(
    seqnames, seqrunlens, start, strand,                                        // genomic coordinates
    mcols, coverage,                                                            // methylation values and optional coverage
    exclude_lower, exclude_upper,                                               // median exclusion range
    chunks                                                                      // row chunks for parallel processing
  );
}


// #############################################################################
// test code and sourcing don't work on OS X
/*** R

*/
// #############################################################################

