#' Count clonal mutations on one or several chromosomal copies
#'
#' This function counts the number of clonal mutations residing on a single or
#' multiple copies per genomic segment. Segments of equal copy number and
#' B-allele count are merged per chromosome.
#' @param nbObj combined SNV and CNV information as generated by
#' \code{\link{nbImport}}. The object is converted to a \code{data.table} and
#' keyed by \code{chrom}, \code{TCN}, \code{A} and \code{B} by reference, i.e.
#' the caller's object is modified as well.
#' @param min.cn minimal copy number.
#' @param max.cn maximal copy number; must be larger than \code{min.cn}.
#' @param chromosomes the chromosomes to be evaluated.
#' @return a data table with one row per chromosome and copy number state
#' (\code{chrom}, \code{TCN}, \code{A}, \code{B}) reporting the summed length
#' of its segments (\code{Seglength}), the outermost segment boundaries
#' (\code{Start}, \code{End}), the number of clonal mutations on all A-allele
#' copies (\code{n_mut_A}), the number of clonal mutations on all B-allele
#' copies (\code{n_mut_B}), the total number of clonal mutations including
#' clonal mutations on a single copy only (\code{n_mut_total_clonal}), the
#' estimated number of subclonal mutations (\code{n_mut_total_subclonal}), the
#' number of mutations in that state (\code{n_mut_total}), the estimated size
#' of the first-order clonal peak (\code{n_mut_firstpeak}) and the derived
#' fractions \code{p_sc}, \code{p_lc}, \code{p_ec} and \code{p_c}. The
#' \code{purity}, \code{ploidy} and \code{ID} attributes of \code{nbObj} are
#' copied to the result.
#' @examples
#' snvs <- system.file("extdata", "NBE15",
#'     "snvs_NBE15_somatic_snvs_conf_8_to_10.vcf",
#'     package = "LACHESIS"
#' )
#' s_data <- readVCF(vcf = snvs, vcf.source = "dkfz")
#' aceseq_cn <- system.file("extdata", "NBE15",
#'     "NBE15_comb_pro_extra2.51_1.txt",
#'     package = "LACHESIS"
#' )
#' c_data <- readCNV(aceseq_cn)
#' nb <- nbImport(cnv = c_data, snv = s_data, purity = 1, ploidy = 2.51)
#' cl_muts <- clonalMutationCounter(nb)
#' @import data.table
#' @importFrom stats dbinom dmultinom
#' @export
clonalMutationCounter <- function(nbObj = NULL, min.cn = 1, max.cn = 4,
                                  chromosomes = seq_len(22)) {
    # Declared only to silence R CMD check notes caused by data.table's
    # non-standard evaluation of column names.
    cn_end.y <- . <- cn_start.y <- chrom <- TCN <- n_mut_total <-
        n_mut_total_subclonal <- A <- B <- n_mut_firstpeak <-
        n_mut_total_clonal <- NULL

    if (is.null(nbObj)) {
        stop("Please provide an nbObj, as generated by nbImport.")
    }

    if (max.cn <= min.cn) {
        stop("max.cn must be larger than min.cn")
    }

    # Key the data by copy number state; the keyed join in
    # .count_segment_mutations relies on this key.
    data.table::setDT(x = nbObj, key = c("chrom", "TCN", "A", "B"))

    # One row per copy number state present on the requested chromosomes and
    # fulfilling min.cn <= TCN <= max.cn.
    countObj <- unique(
        nbObj[chrom %in% chromosomes & TCN >= min.cn & TCN <= max.cn],
        by = data.table::key(nbObj)
    )
    # Attach every distinct segment belonging to these states ...
    countObj <- merge(countObj, unique(nbObj, by = c(
        data.table::key(countObj),
        "cn_start", "cn_end"
    )))
    # ... and sum up the segment lengths, start and end position for each
    # copy number state per chromosome.
    countObj <- countObj[, .(
        Seglength = sum(cn_end.y - cn_start.y),
        Start = min(cn_start.y),
        End = max(cn_end.y)
    ), by = eval(data.table::key(countObj))]

    if (nrow(countObj) == 0) {
        stop(
            "No copy number states with min.cn <= TCN <= max.cn found on ",
            "the selected chromosomes."
        )
    }

    # Count clonal / subclonal mutations separately for every state.
    splt.countObj <- lapply(
        split(countObj, by = c("chrom", "TCN", "A", "B")),
        .count_segment_mutations,
        nbObj = nbObj
    )

    splt.countObj <- data.table::rbindlist(
        l = splt.countObj, use.names = TRUE,
        fill = TRUE
    )

    splt.countObj[, `:=`(
        p_sc = ifelse(n_mut_total > 0, n_mut_total_subclonal / n_mut_total,
            NA_real_
        ),
        p_lc = ifelse(A != 1 & B != 1 & n_mut_total > 0,
            n_mut_firstpeak / n_mut_total, 0
        ),
        p_ec = ifelse((TCN > 2 | (TCN == 2 & A != B)) & n_mut_total > 0,
            (n_mut_total_clonal - n_mut_firstpeak) / n_mut_total, 0
        ),
        p_c = ifelse((A == 1 | B == 1) & n_mut_total > 0,
            n_mut_firstpeak / n_mut_total, 0
        )
    )]

    attr(splt.countObj, "purity") <- attr(nbObj, "purity")
    attr(splt.countObj, "ploidy") <- attr(nbObj, "ploidy")
    attr(splt.countObj, "ID") <- attr(nbObj, "ID")

    return(splt.countObj)
}


# Count clonal, subclonal and first-peak mutations for one copy number state.
#
# splt:  a single row of the per-state summary built in clonalMutationCounter
#        (key columns chrom, TCN, A, B followed by Seglength, Start, End).
# nbObj: the mutation table keyed by chrom, TCN, A, B; its "purity" attribute
#        is used to compute the expected clonal VAFs.
# Returns splt with the columns n_mut_A, n_mut_B, n_mut_total_clonal,
# n_mut_total_subclonal, n_mut_total and n_mut_firstpeak added.
.count_segment_mutations <- function(splt, nbObj) {
    t_vaf <- NULL # column referenced through data.table NSE

    A <- as.numeric(as.character(splt$A))
    B <- as.numeric(as.character(splt$B))
    TCN <- as.numeric(as.character(splt$TCN))
    purity <- attr(nbObj, "purity")

    # Expected VAFs at the lower- and higher-order clonal peaks:
    clonal.vafs <- c(1, B, A) * purity / (purity * (A + B) + (1 - purity) * 2)
    clonal.vafs <- unique(sort(clonal.vafs[clonal.vafs > 0]))

    # Order (number of mutated copies) of each clonal peak, aligned with
    # clonal.vafs.
    clone.order <- sort(unique(c(1, B, A)))
    clone.order <- clone.order[clone.order > 0]

    # In order to avoid overestimation of the clonal peak due to subclonal
    # SNVs, we quantify the first-order clonal peak on its upper half only:
    # mutations below the lowest clonal VAF are set aside.
    all.muts <- nbObj[splt]
    measured.muts <- all.muts[t_vaf >= min(clonal.vafs)]
    excluded.muts <- all.muts[t_vaf < min(clonal.vafs)]

    if (nrow(measured.muts) == 0) {
        warning(sprintf(
            "No clonal VAFs for TCN = %s, A = %s on chromosome %s",
            TCN, A, as.character(splt$chrom)
        ))

        splt$n_mut_A <- 0
        splt$n_mut_B <- 0
        splt$n_mut_total_clonal <- 0
        splt$n_mut_total_subclonal <- nrow(excluded.muts)
        splt$n_mut_total <- nrow(all.muts)
        # Must be set explicitly; otherwise rbindlist(fill = TRUE) fills NA
        # and p_lc / p_c turn NA for this state.
        splt$n_mut_firstpeak <- 0

        return(splt)
    }

    # If there's only one mutation, assign it to its most likely state by
    # comparing its VAF with the clonal frequencies.
    if (nrow(measured.muts) == 1) {
        which.order <- clone.order[which.min(
            measured.muts[, (clonal.vafs - t_vaf)^2]
        )]
        if (which.order == 1) {
            # First-order peak is quantified on the upper half only, thus
            # multiply by 2; the mirrored mutation is removed from the
            # subclonal count.
            n_mut <- 2
            n_mut_subclonal <- nrow(excluded.muts) - 1
            n_mut_firstpeak <- n_mut
        } else {
            n_mut <- 1
            n_mut_subclonal <- nrow(excluded.muts)
            n_mut_firstpeak <- 0
        }
        if (which.order == A) {
            splt$n_mut_A <- n_mut
            splt$n_mut_B <- 0
        } else if (which.order == B) {
            splt$n_mut_A <- 0
            splt$n_mut_B <- n_mut
        } else {
            splt$n_mut_A <- 0
            splt$n_mut_B <- 0
        }
        splt$n_mut_total_clonal <- n_mut
        splt$n_mut_total_subclonal <- n_mut_subclonal
        splt$n_mut_total <- nrow(all.muts)
        splt$n_mut_firstpeak <- n_mut_firstpeak

        return(splt)
    }

    # On monosomic and heterozygous disomic regions, there is only one clonal
    # peak. Thus assign all mutations to that peak.
    if (TCN %in% c(1, 2) && A == 1) {
        # First-order peak is quantified on the upper half only, thus
        # multiply by 2.
        n_mut <- nrow(measured.muts) * 2
        # Distribute mutations equally between A and B allele.
        splt$n_mut_A <- n_mut / 2
        splt$n_mut_B <- if (B == 0) 0 else n_mut / 2
        splt$n_mut_total_clonal <- n_mut
        splt$n_mut_total_subclonal <- nrow(excluded.muts) - n_mut / 2
        splt$n_mut_total <- nrow(all.muts)
        splt$n_mut_firstpeak <- n_mut

        return(splt)
    }

    # For the remaining cases, estimate the sizes of the clonal peaks using a
    # binomial mixture model.
    rel.clone.size <- .peak_estimate(
        measured.muts = measured.muts,
        clonal.vafs = clonal.vafs
    )
    # Number of measured mutations attributed to the peak of the given order.
    peak_size <- function(order) {
        nrow(measured.muts) * rel.clone.size[which(clone.order == order)]
    }

    # First-order peak is quantified on the upper half only, thus multiply
    # by 2.
    n_mut_firstpeak <- peak_size(1) * 2

    if (A == B) { # distribute mutations equally to both alleles
        n_mut_A <- peak_size(A) / 2
        n_mut_B <- n_mut_A
        n_mut_total_clonal <- n_mut_A + n_mut_B + n_mut_firstpeak
    } else if (B == 1) {
        n_mut_A <- peak_size(A)
        # B's peak is the first-order peak, measured on its upper half only.
        n_mut_B <- peak_size(B) * 2
        n_mut_total_clonal <- n_mut_A + n_mut_B
    } else if (B == 0) {
        n_mut_A <- peak_size(A)
        n_mut_B <- 0
        n_mut_total_clonal <- n_mut_A + n_mut_B + n_mut_firstpeak
    } else {
        n_mut_A <- peak_size(A)
        n_mut_B <- peak_size(B)
        n_mut_total_clonal <- n_mut_A + n_mut_B + n_mut_firstpeak
    }
    # The lower half of the first-order peak was counted among the excluded
    # mutations; remove it from the subclonal estimate.
    n_mut_subclonal <- nrow(excluded.muts) - peak_size(1)

    splt$n_mut_A <- n_mut_A
    splt$n_mut_B <- n_mut_B
    splt$n_mut_total_clonal <- n_mut_total_clonal
    splt$n_mut_total_subclonal <- n_mut_subclonal
    splt$n_mut_total <- nrow(all.muts)
    splt$n_mut_firstpeak <- n_mut_firstpeak

    splt
}


.peak_estimate <- function(measured.muts = NULL, clonal.vafs = NULL) {
    # Estimate the relative sizes of the clonal peaks with a binomial mixture
    # model.
    #
    # Every measured mutation gets a binomial likelihood under each expected
    # clonal VAF, normalised over the peaks. Mixing weights are scanned on a
    # grid with step 0.01 over the probability simplex, and the weights with
    # the largest summed log-likelihood are returned.
    #
    # measured.muts: table with columns t_alt_count and t_depth (numeric, or
    #   character/factor representations of numbers).
    # clonal.vafs: expected VAFs of the clonal peaks (one to three values).
    # Returns a numeric vector with one mixing weight per entry of
    # clonal.vafs; the weights sum to 1.
    n.peaks <- length(clonal.vafs)
    if (n.peaks == 1) {
        # A single clonal peak receives all mutations.
        return(1)
    }
    if (n.peaks < 1 || n.peaks > 3) {
        stop("Peak estimation supports between one and three clonal VAFs.")
    }

    alt <- as.numeric(as.character(measured.muts[["t_alt_count"]]))
    depth <- as.numeric(as.character(measured.muts[["t_depth"]]))

    # Likelihood of each mutation (rows) under each peak (columns),
    # normalised so that every row sums to 1. This does not depend on the
    # mixing weights, so it is computed once rather than per grid point.
    lik <- outer(seq_along(alt), seq_len(n.peaks), function(i, j) {
        dbinom(x = alt[i], size = depth[i], prob = clonal.vafs[j])
    })
    lik <- lik / rowSums(lik)

    # Candidate mixing weights, one row per grid point. The first weight
    # varies slowest, matching the scan order of the previous implementation
    # (and hence its tie-breaking in which.max).
    steps <- seq(0, 1, 0.01)
    if (n.peaks == 2) {
        weights <- cbind(steps, 1 - steps)
    } else {
        idx1 <- rep(seq_along(steps), each = length(steps))
        idx2 <- rep(seq_along(steps), times = length(steps))
        # Filter on integer grid indices so that no point of the simplex is
        # dropped because the floating-point weight sum exceeds 1 slightly.
        keep <- idx1 + idx2 <= length(steps) + 1
        w1 <- steps[idx1[keep]]
        w2 <- steps[idx2[keep]]
        weights <- cbind(w1, w2, pmax(0, 1 - (w1 + w2)))
    }

    log.lik <- vapply(seq_len(nrow(weights)), function(r) {
        sum(log(lik %*% weights[r, ]))
    }, numeric(1))

    unname(weights[which.max(log.lik), ])
}
