# helper: flag genes passing a low-count filter while always retaining the
# requested marker genes; intended to be run before DE analysis.
#' Build a logical keep-mask of genes
#'
#' @param dge object with a `counts` matrix, a `samples` data frame and row
#'   names (e.g. an edgeR DGEList).
#' @param group_col column of `dge$samples` holding the sample groups; only
#'   used when `normalize = TRUE`.
#' @param filter numeric of length 2. With `normalize = TRUE` the values are
#'   passed to [edgeR::filterByExpr()] as `min.count` and `large.n`;
#'   otherwise a gene is kept when more than `filter[2]` samples have counts
#'   above `filter[1]`.
#' @param normalize logical, use [edgeR::filterByExpr()] (TRUE) or the plain
#'   count threshold (FALSE).
#' @param markers NULL or a vector of gene symbols kept regardless of the
#'   expression filter.
#' @param gene_id column of org.Hs.eg.db the symbols are mapped to; should
#'   match the identifier type of `rownames(dge)`.
#' @return A logical vector with one element per row of `dge`.
filterGenes <- function(dge, group_col, filter = c(10, 10), normalize = TRUE,
                        markers = NULL, gene_id = "SYMBOL") {
  if (!is.null(markers)) {
    stopifnot("Please provide a vector of gene symbols!" = is.vector(markers))
    # translate symbols into the identifier type used by rownames(dge)
    markers <- AnnotationDbi::select(
      org.Hs.eg.db::org.Hs.eg.db,
      keys = markers,
      columns = gene_id,
      keytype = "SYMBOL"
    )
  }
  # with markers = NULL this lookup yields NULL, i.e. no forced genes
  is_marker <- rownames(dge) %in% markers[[gene_id]]

  passes_filter <- if (normalize == TRUE) {
    edgeR::filterByExpr(
      dge$counts,
      group = dge$samples[[group_col]],
      min.count = filter[1],
      large.n = filter[2]
    )
  } else {
    Matrix::rowSums(dge$counts > filter[1]) > filter[2]
  }

  passes_filter | is_marker
}

# helper: make design and apply voom, lmfit and treat
#' return DGEList containing vfit by limma::voom (if normalize = TRUE) and
#' tfit by limma::treat
#'
#' Samples whose `group_col` value matches the regular expression
#' `target_group` are relabelled `make.names(target_group)` in
#' `dge$samples$group`. All other samples become "Others" when `group = TRUE`,
#' or keep `make.names()` of their own `group_col` value otherwise. The model
#' `~0 + group (+ batch columns)` is fitted with [limma::lmFit()] and tested
#' with [limma::treat()], using limma-trend when `normalize = FALSE`.
#'
#' @inheritParams de_analysis
#' @param ... may contain `contrast_mat`, a contrast matrix whose row names
#'   are identical to the design column names (group levels without the
#'   "group" prefix). By default `target_group` is contrasted against each
#'   other group.
#' @return A DGEList containing vfit and tfit
voom_fit_treat <- function(dge,
                           group_col,
                           target_group,
                           normalize = TRUE,
                           group = FALSE,
                           lfc = 0,
                           p = 0.05,
                           batch = NULL,
                           summary = TRUE,
                           ...) {
  stopifnot(
    "Please provide column names as batch!" =
      is.null(batch) || is.vector(batch)
  )

  stopifnot(
    is.logical(normalize),
    is.logical(summary),
    is.logical(group),
    is.numeric(lfc),
    is.numeric(p)
  )

  ## group samples into binary groups if group = TRUE,
  ## otherwise use as multiple groups as is
  target <- make.names(target_group)
  is_target <- grepl(target_group, dge$samples[[group_col]])
  others <- if (group) "Others" else make.names(dge$samples[[group_col]])
  dge$samples$group <- ifelse(is_target, target, others)

  ## make design: no intercept, one column per group (+ batch covariates)
  form <- formula(paste(c("~0", "group", batch), collapse = "+"))
  design <- model.matrix(form, dge$samples)
  ## strip only the leading "group" prefix added by model.matrix, so batch
  ## columns or levels that merely contain "group" are left intact
  colnames(design) <- sub("^group", "", colnames(design))
  rownames(design) <- colnames(dge)

  ## make contrast matrix
  dots <- list(...)
  if ("contrast_mat" %in% names(dots)) {
    contrast.mat <- dots[["contrast_mat"]]
    ## check contrast.mat validity
    stopifnot(
      "contrast.mat must be a matrix!" = is.matrix(contrast.mat),
      "contrast.mat levels/rownames don't match design matrix!" =
        identical(rownames(contrast.mat), colnames(design))
    )
  } else {
    if (!any(is_target)) {
      stop("target_group does not match any sample in group_col!")
    }
    comparators <- setdiff(dge$samples$group, target)
    if (length(comparators) == 0) {
      stop("No samples outside target_group to compare against!")
    }
    ## target_group vs all the rest respectively
    ## if group = TRUE, it's target_group vs Others
    contrast.mat <- limma::makeContrasts(
      contrasts = sprintf("%s-%s", target, make.names(comparators)),
      levels = design
    )
  }

  ## voom transform when dge holds raw counts; otherwise the counts are used
  ## as is with limma-trend (presumably already log-normalized expression —
  ## TODO confirm with callers)
  if (normalize) {
    dge$vfit <- limma::voom(dge, design = design)
  } else {
    dge$vfit <- dge$counts
  }

  ## linear regression fit
  fit <- limma::lmFit(dge$vfit, design = design)
  tfit <- limma::treat(limma::contrasts.fit(fit, contrasts = contrast.mat),
    lfc = lfc, trend = !normalize
  )
  ## summarize the total number of DEGs
  if (summary) {
    show(summary(limma::decideTests(tfit, lfc = lfc, p.value = p)))
  }

  dge$tfit <- tfit

  dge
}

# helper: return DEGs UP and DOWN list based on Rank Product
#' return DEGs UP and DOWN list based on Rank Product
#'
#' For every coefficient of `tfit` the full [limma::topTreat()] table is
#' extracted, with rows containing NA dropped. Per comparison, genes are
#' called UP when `logFC > lfc` and DOWN when `logFC < -lfc`, both requiring
#' `adj.P.Val < p`. The calls of all comparisons are combined with `assemble`
#' and restricted to genes present in every table. Those genes are ranked by
#' `Rank` within each comparison. A gene is kept when its sum of log10 ranks
#' is below the `thres` quantile of a null built from `nperm` uniformly
#' sampled ranks per comparison. When `Rank = "logFC"`, ranking uses
#' `-abs(logFC)` so the largest effects rank first. The null is random: call
#' `set.seed()` beforehand for reproducible results.
#'
#' @param tfit MArrayLM object generated by [limma::treat()]
#' @param lfc num, cutoff of logFC for DE analysis
#' @param p num, cutoff of p value for DE analysis
#' @param assemble 'intersect' or 'union', whether to select intersected or
#'                  union genes of different comparisons, default 'intersect'
#' @param Rank character, the variable for ranking DEGs, can be 'logFC',
#'             'adj.P.Val'..., default 'adj.P.Val'
#' @param nperm num, permutation runs of simulating the distribution
#' @param thres num, cutoff for rank product permutation test if
#'              feature_selection = "rankproduct", default 0.05
#' @param keep.top NULL or num, whether to keep top n DEGs of specific
#'                 comparison
#' @param keep.group NULL or pattern, specify the top DEGs of which comparison
#'                   or group to be kept (required when keep.top is set)
#' @param ... omitted
#' @return A list of "UP" and "DOWN" genes
DEGs_RP <- function(tfit, lfc = NULL, p = 0.05, assemble = "intersect",
                    Rank = "adj.P.Val", nperm = 1e5, thres = 0.05,
                    keep.top = NULL, keep.group = NULL,
                    ...) {
  if (is.null(lfc)) lfc <- tfit$treat.lfc

  stopifnot(
    is.character(assemble), is.character(Rank),
    is.numeric(lfc), is.numeric(p), is.numeric(nperm)
  )

  DEG <- list()
  UPs <- list()
  DWs <- list()
  DEGs <- list()
  for (i in seq_len(ncol(tfit))) {
    DEG[[i]] <- na.omit(limma::topTreat(tfit,
      coef = i,
      number = Inf,
      sort.by = "none"
    ))
    UPs[[i]] <- subset(DEG[[i]], logFC > lfc & adj.P.Val < p)
    DWs[[i]] <- subset(DEG[[i]], logFC < -lfc & adj.P.Val < p)
    ## keep the signed fold change in `lfc`; store -|logFC| in `logFC` so
    ## that ascending ranks put the strongest effects first
    DEG[[i]]$lfc <- DEG[[i]]$logFC
    DEG[[i]]$logFC <- -abs(DEG[[i]]$logFC)
  }

  ## genes present in every comparison table (ranks are compared across them)
  shared <- Reduce(intersect, lapply(DEG, rownames))

  ## rank-product test of the genes called in `calls` (UPs or DWs); returns
  ## the passing gene names sorted by rank product, smallest first
  rank_product <- function(calls) {
    genes <- Reduce(f = assemble, lapply(calls, rownames))
    genes <- intersect(genes, shared)
    if (length(genes) == 0) {
      return(genes)
    }
    ## null: one uniformly drawn rank per comparison, nperm times
    null_dist <- lapply(calls, function(x) {
      sample.int(length(genes), nperm, replace = TRUE)
    })
    null_rp <- rowSums(log10(do.call(cbind, null_dist)))
    obs_rp <- lapply(DEG, function(x) rank(x[genes, Rank])) |>
      do.call(what = cbind) |>
      log10() |>
      rowSums()
    names(obs_rp) <- genes
    obs_rp <- obs_rp[obs_rp < quantile(null_rp, thres)] ## keep top rank genes
    names(sort(obs_rp))
  }

  ## UP is evaluated before DOWN, so the RNG is consumed in the same order
  DEGs[["UP"]] <- rank_product(UPs)
  DEGs[["DOWN"]] <- rank_product(DWs)

  ## keep the top DEGs in specified comparison even if they didn't pass RP test
  if (!is.null(keep.top)) {
    targets <- if (is.null(keep.group)) {
      integer(0)
    } else {
      which(grepl(keep.group, colnames(tfit)))
    }
    if (length(targets) < 1) {
      stop("Please specify at least one valid comparison for keep.group!")
    }
    ## top `keep.top` gene names of comparison i ordered by `Rank`, limited to
    ## one sign of the fold change; fewer names (never NA) when fewer exist
    top_genes <- function(i, up) {
      tab <- DEG[[i]]
      tab <- tab[order(tab[[Rank]]), , drop = FALSE]
      genes <- rownames(tab)[if (up) tab$lfc > 0 else tab$lfc < 0]
      genes[seq_len(min(keep.top, length(genes)))]
    }
    up_top <- lapply(targets, top_genes, up = TRUE)
    DEGs[["UP"]] <- Reduce(union, c(list(DEGs[["UP"]]), up_top))
    dw_top <- lapply(targets, top_genes, up = FALSE)
    DEGs[["DOWN"]] <- Reduce(union, c(list(DEGs[["DOWN"]]), dw_top))
  }

  DEGs
}


# helper: return DEGs UP and DOWN list based on intersection or union
#' return DEGs UP and DOWN list based on intersection or union of comparisons
#'
#' For every coefficient of `tfit` the full [limma::topTreat()] table is
#' extracted, with rows containing NA dropped. Per comparison, genes are
#' called UP when `logFC > lfc` and DOWN when `logFC < -lfc`, both requiring
#' `adj.P.Val < p`. The calls are combined across comparisons with `assemble`.
#' Genes are then ordered by their mean rank of `Rank` over the comparisons in
#' which they were called; comparisons where a gene was not called are
#' ignored. When `Rank = "logFC"`, ranking uses `-abs(logFC)` so the largest
#' effects come first for both UP and DOWN genes.
#'
#' @inheritParams DEGs_RP
#' @return A list of "UP" and "DOWN" genes
DEGs_Group <- function(tfit, lfc = NULL, p = 0.05,
                       assemble = "intersect", Rank = "adj.P.Val",
                       keep.top = NULL, keep.group = NULL,
                       ...) {
  if (is.null(lfc)) lfc <- tfit$treat.lfc

  stopifnot(
    is.character(assemble), is.character(Rank),
    is.numeric(lfc), is.numeric(p)
  )

  ## screen genes with p value < cutoff
  DEG <- list()
  UPs <- list()
  DWs <- list()
  DEGs <- list()
  for (i in seq_len(ncol(tfit))) {
    tab <- na.omit(limma::topTreat(tfit, coef = i, number = Inf))
    ## keep the signed fold change in `lfc`; store -|logFC| in `logFC` so that
    ## ranking ascending by "logFC" puts the strongest effects first
    tab$lfc <- tab$logFC
    tab$logFC <- -abs(tab$logFC)
    DEG[[i]] <- tab
    UPs[[i]] <- tab[tab$lfc > lfc & tab$adj.P.Val < p, , drop = FALSE]
    DWs[[i]] <- tab[tab$lfc < -lfc & tab$adj.P.Val < p, , drop = FALSE]
  }

  ## order `genes` by mean rank of `Rank` across the tables in `calls`; a gene
  ## absent from a table gets an NA rank there, which the mean then skips
  order_by_rank <- function(genes, calls) {
    if (length(genes) == 0) {
      return(genes)
    }
    ranks <- lapply(calls, function(x) rank(x[genes, Rank], na.last = "keep"))
    mean_rank <- apply(do.call(cbind, ranks), 1, mean, na.rm = TRUE)
    genes[order(mean_rank)]
  }

  DEGs[["UP"]] <- order_by_rank(Reduce(f = assemble, lapply(UPs, rownames)), UPs)
  DEGs[["DOWN"]] <- order_by_rank(
    Reduce(f = assemble, lapply(DWs, rownames)), DWs
  )

  ## keep the top DEGs of the specified comparisons even if they were not
  ## selected above (only applied when assemble = "intersect")
  if (!is.null(keep.top) && assemble == "intersect") {
    targets <- if (is.null(keep.group)) {
      integer(0)
    } else {
      which(grepl(keep.group, colnames(tfit)))
    }
    if (length(targets) < 1) {
      stop("Please specify at least one valid comparison for keep.group!")
    }
    ## top `keep.top` gene names of comparison i ordered by `Rank`, limited to
    ## one sign of the fold change; fewer names (never NA) when fewer exist
    top_genes <- function(i, up) {
      tab <- DEG[[i]]
      tab <- tab[order(tab[[Rank]]), , drop = FALSE]
      genes <- rownames(tab)[if (up) tab$lfc > 0 else tab$lfc < 0]
      genes[seq_len(min(keep.top, length(genes)))]
    }
    up_top <- lapply(targets, top_genes, up = TRUE)
    DEGs[["UP"]] <- Reduce(union, c(list(DEGs[["UP"]]), up_top))
    dw_top <- lapply(targets, top_genes, up = FALSE)
    DEGs[["DOWN"]] <- Reduce(union, c(list(DEGs[["DOWN"]]), dw_top))
  }

  DEGs
}

# Column names referenced through non-standard evaluation in subset() calls
# in this file; declared so R CMD check does not flag them as undefined
# global variables.
utils::globalVariables(names = c("logFC", "adj.P.Val"))
