#' TSA Well Curves Plot
#'
#' Generates the individual curves for each well in the merged tsa data input.
#'     Options to create an average and standard deviation (sd) of the plot
#'     in addition to the individual curves. The average and sd can be
#'     smoothed; see \code{\link{TSA_average}} for details.
#'
#' @importFrom magrittr %>%
#' @importFrom dplyr group_by
#'
#' @inheritParams TSA_average
#' @param show_Tm logical; \code{show_Tm = TRUE} by default. When TRUE, the Tm
#'     is displayed on the plot. When FALSE, the Tm is not added to the plot.
#' @param Tm_label_nudge numeric; \code{Tm_label_nudge = 7.5} by default. The
#'     distance in the x direction to move the Tm label. This is used to
#'     prevent the label from covering data. Missing or non-numeric values
#'     are treated as 0. Ignored if \code{show_Tm = FALSE}.
#' @param show_average logical; \code{show_average = TRUE} by default.
#'     When TRUE, the average and sd are plotted as
#'     generated by \code{\link{TSA_average}}.
#' @param smooth logical; \code{smooth = TRUE} by default. When TRUE, each
#'     well is drawn as a line and the average and sd are smoothed by
#'     \code{\link{TSA_average}} using \code{smoother} (for
#'     \code{smoother = "gam"}, via \code{\link[mgcv]{gam}}). When FALSE,
#'     individual points are plotted (slows down rendering) and the
#'     unsmoothed average and sd are used.
#' @param plot_title,plot_subtitle character string, NA by default.
#'     User-specified titles that override the automatic naming.
#' @param separate_legend logical; \code{separate_legend = TRUE} by default.
#'     When TRUE, the ggplot2 legend is separated from the TSA curve.
#'     This is to help with readability. One ggplot is returned when FALSE.
#' @param smoother character; one of \code{c("gam","beta","none")},
#'     \code{"gam"} by default. Passed to \code{\link{TSA_average}} to select
#'     the aggregate smoother: \code{"gam"} uses \pkg{mgcv}, \code{"beta"}
#'     uses a natural cubic spline with Beta(a,a) interior knots centered at
#'     Tm, and \code{"none"} uses the unsmoothed average.
#' @param beta_shape numeric; shape parameter \eqn{a} for the Beta(a,a) knot
#'     placement when \code{smoother = "beta"}. \code{beta_shape = 3} by
#'     default. Passed to \code{TSA_average}.
#' @param beta_n_knots integer or \code{NULL}; number of interior knots when
#'     \code{smoother = "beta"}. If \code{NULL}, uses \code{beta_knots_frac}.
#'     Passed to \code{TSA_average}.
#' @param beta_knots_frac numeric in (0,1); fraction of unique temperatures used
#'     as interior knots when \code{smoother = "beta"} and \code{beta_n_knots}
#'     is \code{NULL}. \code{beta_knots_frac = 0.12} by default. Passed to
#'     \code{TSA_average}.
#' @param use_natural logical; if TRUE (default) uses natural cubic spline
#'     basis for the beta method. Passed to \code{TSA_average}.
#'
#' @details The logical flags \code{show_Tm}, \code{show_average},
#'     \code{smooth} and \code{separate_legend} also accept the strings
#'     \code{"TRUE"}/\code{"FALSE"} (case-insensitive); any other
#'     non-logical value, or NA, is treated as FALSE.
#'
#' @return By default, two ggplots are returned: one TSA curve and one legend
#'     key (when \code{separate_legend = TRUE}). When
#'     \code{separate_legend = FALSE}, a single ggplot is returned.
#'
#' @family TSA Plots
#' @seealso \code{\link{merge_TSA}} and \code{\link{normalize_fluorescence}}
#'     for preparing data. See \code{\link{TSA_average}} and
#'     \code{\link{get_legend}} for details on function parameters.
#' @examples
#' data("example_tsar_data")
#' check <- subset(example_tsar_data, condition_ID == "CA FL_PyxINE HCl")
#' TSA_wells_plot(check, y = "Fluorescence", smooth = TRUE, separate_legend = TRUE,
#'                smoother = "beta", beta_shape = 4, beta_knots_frac = 0.008)
#'
#' @export

TSA_wells_plot <- function(
    tsa_data,
    y = "RFU",
    show_Tm = TRUE,
    Tm_label_nudge = 7.5,
    show_average = TRUE,
    plot_title = NA,
    plot_subtitle = NA,
    smooth = TRUE,
    separate_legend = TRUE,
    # second-stage smoother for the aggregate curve/ribbon
    smoother = c("gam", "beta", "none"),
    beta_shape = 3,
    beta_n_knots = NULL,
    beta_knots_frac = 0.12,
    use_natural = TRUE
) {
  smoother <- match.arg(smoother)
  y <- match.arg(y, choices = c("Fluorescence", "RFU"))

  # Coerce a flag to a single TRUE/FALSE so it is safe inside `if` and `&&`.
  # Logical values pass through, strings match "true" case-insensitively,
  # and NA / zero-length / multi-element values become FALSE.
  .to_bool <- function(x) {
    isTRUE(if (is.logical(x)) x else tolower(as.character(x)) == "true")
  }
  show_Tm_flag <- .to_bool(show_Tm)
  show_average_flag <- .to_bool(show_average)
  smooth_flag <- .to_bool(smooth)
  sep_flag <- .to_bool(separate_legend)

  # Only the first value is used; missing or non-numeric nudges fall back to 0.
  Tm_label_nudge_num <- suppressWarnings(as.numeric(Tm_label_nudge))[1]
  if (!is.finite(Tm_label_nudge_num)) {
    Tm_label_nudge_num <- 0
  }

  if (!"well_ID" %in% names(tsa_data) || !"condition_ID" %in% names(tsa_data)) {
    stop("tsa_data must be a data frame merged by merge_TSA() or merge_norm()")
  }
  if (!y %in% names(tsa_data)) {
    if (y == "RFU") {
      stop("RFU column not found. Run normalize_fluorescence() or choose y='Fluorescence'.")
    }
    stop("Fluorescence column not found in tsa_data.")
  }

  # Base plot. The Tm label is placed at a quarter of the peak raw
  # fluorescence, or at a fixed 0.4 on the RFU scale.
  if (y == "Fluorescence") {
    TSA_curve <- ggplot(tsa_data, aes(x = Temperature, y = Fluorescence))
    tm_height <- max(tsa_data$Fluorescence, na.rm = TRUE) / 4
  } else {
    TSA_curve <- ggplot(tsa_data, aes(x = Temperature, y = RFU))
    tm_height <- 0.4
  }

  # Per-well curves: lines when smoothing, otherwise raw points.
  if (smooth_flag) {
    TSA_curve <- TSA_curve + geom_line(aes(color = well_ID), alpha = 0.95)
  } else {
    TSA_curve <- TSA_curve + geom_point(aes(color = well_ID), alpha = 0.5)
  }
  TSA_curve <- TSA_curve + theme_bw()

  # Optional Tm line/label
  if (show_Tm_flag) {
    # A single condition is expected; if several Tms come back, use the first.
    avg_tm <- suppressWarnings(as.numeric(TSA_Tms(tsa_data)$Avg_Tm))[1]
    if (is.finite(avg_tm)) {
      TSA_curve <- TSA_curve +
        geom_vline(xintercept = avg_tm, linetype = "dashed", color = "#BC9595") +
        annotate(
          "label",
          x = avg_tm + Tm_label_nudge_num,
          y = tm_height,
          label = paste0("Tm=", round(avg_tm, 2), "C")
        )
    }
  }

  # Titles: NULL or a single NA means "generate automatically".
  .is_unset <- function(x) is.null(x) || (length(x) == 1L && is.na(x))
  title <- if (!.is_unset(plot_title)) {
    plot_title
  } else if (TSA_proteins(tsa_data = tsa_data, n = TRUE) == 1) {
    paste0("Thermal Profile of ", TSA_proteins(tsa_data = tsa_data))
  } else {
    "Thermal Profile"
  }
  subtitle <- if (.is_unset(plot_subtitle)) {
    paste(c("With:", TSA_ligands(tsa_data = tsa_data)), collapse = " ")
  } else {
    plot_subtitle
  }
  TSA_curve <- TSA_curve + labs(title = title, subtitle = subtitle)

  # Aggregate average curve + sd ribbon. These layers map no colour/fill,
  # so adding them before the legend is split does not change the legend.
  if (show_average_flag) {
    tsa_average_df <- TSA_average(
      tsa_data = tsa_data,
      y = y,
      avg_smooth = smooth_flag,
      sd_smooth = smooth_flag,
      smoother = smoother,
      beta_shape = beta_shape,
      beta_n_knots = beta_n_knots,
      beta_knots_frac = beta_knots_frac,
      use_natural = use_natural
    )

    # Use the smoothed column when smoothing is on and it was produced;
    # otherwise fall back to the raw column (works for any smoother).
    .pick <- function(smoothed, raw) {
      if (smooth_flag && smoothed %in% names(tsa_average_df)) {
        tsa_average_df[[smoothed]]
      } else {
        tsa_average_df[[raw]]
      }
    }
    tsa_average_df$ymin_plot <- .pick("sd_min_smooth", "sd_min")
    tsa_average_df$ymax_plot <- .pick("sd_max_smooth", "sd_max")
    tsa_average_df$y_avg_plot <- .pick("avg_smooth", "average")

    TSA_curve <- TSA_curve +
      geom_ribbon(
        inherit.aes = FALSE,
        data = tsa_average_df,
        aes(x = Temperature, ymin = ymin_plot, ymax = ymax_plot),
        alpha = 0.4
      ) +
      geom_line(
        inherit.aes = FALSE,
        linetype = "dotdash",
        data = tsa_average_df,
        aes(x = Temperature, y = y_avg_plot)
      )
  }

  # Legend handling: split only once the plot is complete, so the returned
  # shape depends solely on the sanitised `sep_flag`.
  if (sep_flag) {
    legend_plot <- get_legend(TSA_curve)
    TSA_curve <- TSA_curve + theme(legend.position = "none")
    return(list(TSA_curve, legend_plot))
  }
  TSA_curve
}

# Null-coalescing helper: return `a` unless it is NULL, otherwise `b`.
# `b` is only evaluated when `a` is NULL (R's lazy argument evaluation).
# NOTE(review): base R >= 4.4 and rlang also define `%||%`; this definition
# is used instead of those wherever it is in scope.
`%||%` <- function(a, b) {
  if (is.null(a)) {
    return(b)
  }
  a
}


#' TSA Box Plot
#'
#' Generates a box and whiskers plot for each condition specified. This
#'     is used to compare Tm values between conditions in the data set.
#'     See \code{\link{Tm_difference}} for details.
#'
#' @inheritParams TSA_average
#' @inheritParams TSA_wells_plot
#' @param control_condition Either a condition_ID or NA; NA by default.
#'     When a valid condition ID is provided, a vertical line appears at the
#'     average Tm for the specified condition. When NA, this is skipped.
#'     An error is raised if the ID is not found in \code{tsa_data}.
#' @param color_by character string, either "Ligand" or "Protein".
#'     The condition category used to color the boxes within the box
#'     plot for comparison. This is represented in the legend.
#'     When NA, the category that varies more is chosen automatically
#'     (ligand count vs. protein count; ties go to "Protein").
#' @param label_by character string, either "Ligand" or "Protein".
#'     The condition category used to label the boxes on the axis.
#'     When NA, the default condition_ID labels are kept.
#' @return By default, a list of two ggplots is returned: the box plot and
#'     its legend key. When \code{separate_legend = FALSE}, one ggplot is
#'     returned.
#' @family TSA Plots
#' @seealso \code{\link{merge_TSA}}
#'     for preparing data. See \code{\link{Tm_difference}} and
#'     \code{\link{get_legend}} for details on function parameters.
#' @examples
#' data("example_tsar_data")
#' TSA_boxplot(example_tsar_data,
#'     color_by = "Protein",
#'     label_by = "Ligand", separate_legend = FALSE
#' )
#' @export

TSA_boxplot <- function(
    tsa_data,
    control_condition = NA, # Either a condition_ID or NA
    color_by = "Protein", # Set to NA to choose automatically.
    label_by = "Ligand", # Set to NA to keep condition_ID labels.
    separate_legend = TRUE # Logical
    ) {

    # NA is documented as "skip", but match.arg() rejects NA, so map a
    # single NA onto the "NA" sentinel choice before validating.
    .as_choice <- function(x) {
        if (length(x) == 1L && is.na(x)) "NA" else x
    }
    color_by <- match.arg(.as_choice(color_by),
        choices = c("Protein", "Ligand", "NA")
    )
    label_by <- match.arg(.as_choice(label_by),
        choices = c("Protein", "Ligand", "NA")
    )

    plot_data <- TSA_Tms(
        analysis_data = tsa_data,
        condition_average = FALSE
    )
    plot_data <- plot_data[!is.na(plot_data$Tm), ]

    # If the coloring category is not specified, use the one that varies more.
    if (!color_by %in% c("Ligand", "Protein")) {
        n_proteins <- TSA_proteins(plot_data, n = TRUE)
        n_ligands <- TSA_ligands(plot_data, n = TRUE)
        color_by <- if (n_ligands > n_proteins) "Ligand" else "Protein"
    }

    if (color_by == "Ligand") {
        tsa_plot <- ggplot(
            plot_data,
            aes(
                x = condition_ID,
                y = Tm,
                color = Ligand,
                label = well_ID
            )
        ) +
            geom_boxplot(alpha = 0.25) +
            geom_point(shape = 1) +
            scale_color_discrete(unique(tsa_data$Ligand),
                name = "Ligand"
            )
    }
    if (color_by == "Protein") {
        tsa_plot <- ggplot(
            plot_data,
            aes(
                x = condition_ID,
                y = Tm,
                color = Protein,
                label = well_ID
            )
        ) +
            geom_boxplot(alpha = 0.25) +
            geom_point(shape = 1) +
            scale_color_discrete(unique(tsa_data$Protein),
                name = "Protein"
            )
    }

    # Relabel the condition axis by ligand or protein; "NA" keeps condition_IDs.
    if (label_by == "Ligand") {
        tsa_plot <- tsa_plot +
            scale_x_discrete(
                breaks = plot_data$condition_ID,
                labels = plot_data$Ligand,
                name = label_by
            )
    }
    if (label_by == "Protein") {
        tsa_plot <- tsa_plot +
            scale_x_discrete(
                breaks = plot_data$condition_ID,
                labels = plot_data$Protein,
                name = label_by
            )
    }

    # Reference line at the control's average Tm (drawn as an hline; it
    # appears vertical after coord_flip()).
    if (!is.na(control_condition)) {
        if (!control_condition %in% condition_IDs(tsa_data)) {
            stop(paste(
                "condition_ID assigned to control_condition is not found in",
                "the TSA data. Use condition_IDs(tsa_data) to get the list of",
                "acceptable IDs"
            ))
        }
        ctrl_Tm <- TSA_Tms(
            analysis_data = tsa_data[tsa_data$condition_ID == control_condition, ]
        )$Avg_Tm
        tsa_plot <- tsa_plot +
            geom_hline(
                yintercept = ctrl_Tm,
                color = "grey",
                linetype = 2,
                alpha = 0.9
            )
    }

    tsa_plot <- tsa_plot +
        coord_flip() +
        theme_bw() +
        labs(y = expression("T"["m"] ~ "(" * degree * "C)"))

    if (separate_legend) {
        legend_plot <- get_legend(tsa_plot)
        tsa_plot <- tsa_plot + theme(legend.position = "none")
        return(list(tsa_plot, legend_plot))
    }
    tsa_plot
}

#' Compare TSA curves to control
#'
#'
#' Generate a number of plots based on the input data to compare the average
#' and standard deviation (sd) of each unique condition to a specified
#' control condition. To see all conditions use \code{condition_IDs(tsa_data)}.
#'
#' @inheritParams TSA_average
#' @inheritParams Tm_difference
#' @param show_Tm logical; \code{show_Tm = TRUE} by default. When TRUE, the Tm
#'     is displayed on the plot. When FALSE, the Tm is not added to the plot.
#' @param title_by character string; c("ligand", "protein", "both").
#'     Automatically names the plots by the specified condition category.
#'     Any other value titles each plot by its condition_ID.
#' @param digits integer; the number of decimal places to round the change in
#'     Tm relative to the control, displayed in the subtitle of each plot.
#' @param smoother character; one of \code{c("gam","beta","none")},
#'     \code{"gam"} by default. Passed to \code{\link{TSA_average}} to select
#'     the aggregate smoother: \code{"gam"} uses \pkg{mgcv}, \code{"beta"}
#'     uses a natural cubic spline with Beta(a,a) interior knots centered at
#'     Tm, and \code{"none"} uses the unsmoothed average.
#' @param smooth_conditions logical; \code{smooth_conditions = TRUE} by
#'     default. When TRUE, the average and sd of each non-control condition
#'     are smoothed with \code{smoother}; when FALSE, the unsmoothed values
#'     are drawn. The control curve always requests smoothing.
#' @param beta_shape numeric; shape parameter \eqn{a} for the Beta(a,a) knot
#'     placement when \code{smoother = "beta"}. \code{beta_shape = 3} by
#'     default. Passed to \code{TSA_average}.
#' @param beta_n_knots integer or \code{NULL}; number of interior knots when
#'     \code{smoother = "beta"}. If \code{NULL}, uses \code{beta_knots_frac}.
#'     Passed to \code{TSA_average}.
#' @param beta_knots_frac numeric in (0,1); fraction of unique temperatures used
#'     as interior knots when \code{smoother = "beta"} and \code{beta_n_knots}
#'     is \code{NULL}. \code{beta_knots_frac = 0.12} by default. Passed to
#'     \code{TSA_average}.
#' @param use_natural logical; if TRUE (default) uses natural cubic spline
#'     basis for the beta method. Passed to \code{TSA_average}.
#'
#' @return A \emph{named list} of ggplot objects. One plot per non-control
#'     condition is included, each overlaid against the control. A final plot for
#'     the control alone is appended and named \code{"Control: <control_condition>"}.
#'
#' @family TSA Plots
#' @seealso \code{\link{merge_TSA}} and \code{\link{normalize_fluorescence}}
#'     for preparing data. See \code{\link{TSA_average}} and
#'     \code{\link{get_legend}} for details on function parameters.
#'     See \code{\link{TSA_wells_plot}} for individual curves of the averaged
#'     conditions shown.
#' @examples
#' data("example_tsar_data")
#' TSA_compare_plot(example_tsar_data,
#'     y = "Fluorescence",
#'     control_condition = "CA FL_DMSO",
#'     smoother = "beta",
#'     beta_shape = 4,
#'     beta_knots_frac = 0.008
#' )
#' @export
#'
TSA_compare_plot <- function(
    tsa_data,
    control_condition,
    y = "Fluorescence",
    show_Tm = TRUE,
    title_by = "both",
    digits = 1,
    smoother = c("gam", "beta", "none"),
    smooth_conditions = TRUE,
    beta_shape = 3,
    beta_n_knots = NULL,
    beta_knots_frac = 0.12,
    use_natural = TRUE
) {
  smoother <- match.arg(smoother)
  y <- match.arg(y, choices = c("Fluorescence", "RFU"))

  if (!"well_ID" %in% names(tsa_data) || !"condition_ID" %in% names(tsa_data)) {
    stop("tsa_data must be a data frame merged by merge_TSA()")
  } else if (!control_condition %in% condition_IDs(tsa_data)) {
    stop("control_condition must be a value from tsa_data$condition_ID")
  }

  # Average one condition's wells with the shared smoother settings.
  .average_condition <- function(id, smooth_it) {
    TSA_average(
      tsa_data = tsa_data[tsa_data$condition_ID == id, ], y = y,
      avg_smooth = smooth_it, sd_smooth = smooth_it,
      smoother = smoother,
      beta_shape = beta_shape, beta_n_knots = beta_n_knots,
      beta_knots_frac = beta_knots_frac, use_natural = use_natural
    )
  }

  # Prefer the smoothed column when requested and present, else the raw one.
  .pick <- function(df, smoothed, raw, use_smooth) {
    if (use_smooth && smoothed %in% names(df)) df[[smoothed]] else df[[raw]]
  }

  Tms_df <- TSA_Tms(tsa_data)
  control_avg <- Tms_df$Avg_Tm[Tms_df$condition_ID == control_condition]
  ctrl_tm <- suppressWarnings(as.numeric(control_avg))[1]

  # --- control curve (shared base of every comparison plot) ---
  control_df <- .average_condition(control_condition, TRUE)
  control_df$y_plot <- .pick(control_df, "avg_smooth", "average", TRUE)
  control_df$ymin_rib <- .pick(control_df, "sd_min_smooth", "sd_min", TRUE)
  control_df$ymax_rib <- .pick(control_df, "sd_max_smooth", "sd_max", TRUE)

  control_curve <- ggplot(
    control_df,
    aes(x = Temperature, y = y_plot)
  ) +
    geom_ribbon(
      aes(ymin = ymin_rib, ymax = ymax_rib),
      alpha = 0.4, fill = "black"
    ) +
    geom_line(linetype = "dotdash", color = "black")

  # One colour per condition, indexed by the condition's position.
  all_ids <- condition_IDs(tsa_data)
  n_ids <- condition_IDs(tsa_data, n = TRUE)
  colfunc <- grDevices::colorRampPalette(c("red", "blue"))
  col_vect <- colfunc(n_ids)

  Tm_difference_DF <- Tm_difference(
    tsa_data = tsa_data,
    control_condition = control_condition
  )

  # Title text for a condition according to `title_by`; unrecognised values
  # fall back to the condition_ID itself.
  .label_for <- function(id) {
    rows <- Tm_difference_DF$condition_ID == id
    protein <- Tm_difference_DF$Protein[rows]
    ligand <- Tm_difference_DF$Ligand[rows]
    if (isTRUE(title_by == "ligand")) {
      ligand
    } else if (isTRUE(title_by == "protein")) {
      protein
    } else if (isTRUE(title_by == "both")) {
      paste0(protein, " + ", ligand)
    } else {
      id
    }
  }

  curve_list <- list()
  for (i in seq_len(n_ids)) {
    condition_i <- all_ids[i]
    # The control gets its own stand-alone plot appended after the loop.
    if (isTRUE(condition_i == control_condition)) next

    tm_avg_i <- Tms_df$Avg_Tm[Tms_df$condition_ID == condition_i]
    Tm_diff_i <- round(
      Tm_difference_DF$delta_Tm[Tm_difference_DF$condition_ID == condition_i],
      digits = digits
    )
    subtitle_i <- paste0("Tm = ", Tm_diff_i, "C")

    cond_df_i <- .average_condition(condition_i, smooth_conditions)
    cond_df_i$y_line <- .pick(cond_df_i, "avg_smooth", "average", smooth_conditions)
    cond_df_i$ymin_rib <- .pick(cond_df_i, "sd_min_smooth", "sd_min", smooth_conditions)
    cond_df_i$ymax_rib <- .pick(cond_df_i, "sd_max_smooth", "sd_max", smooth_conditions)

    diff_curve_i <- control_curve +
      geom_ribbon(
        data = cond_df_i,
        aes(x = Temperature, ymin = ymin_rib, ymax = ymax_rib),
        alpha = 0.4, fill = col_vect[i],
        inherit.aes = FALSE
      ) +
      geom_line(
        data = cond_df_i,
        aes(x = Temperature, y = y_line),
        color = col_vect[i],
        inherit.aes = FALSE
      )

    # Both Tm lines (control in black, condition in its own colour).
    if (isTRUE(show_Tm)) {
      cond_tm <- suppressWarnings(as.numeric(tm_avg_i))[1]
      if (is.finite(ctrl_tm)) {
        diff_curve_i <- diff_curve_i +
          geom_vline(xintercept = ctrl_tm, linetype = "dashed", color = "black")
      }
      if (is.finite(cond_tm)) {
        diff_curve_i <- diff_curve_i +
          geom_vline(xintercept = cond_tm, linetype = "dashed", color = col_vect[i])
      }
    }

    diff_curve_i <- diff_curve_i +
      theme_bw() +
      labs(title = .label_for(condition_i), subtitle = subtitle_i)

    curve_list[[as.character(condition_i)]] <- diff_curve_i
  }

  # -- stand-alone control plot, appended last
  control_only <- control_curve +
    labs(title = "Control", subtitle = .label_for(control_condition)) +
    theme_bw()
  if (isTRUE(show_Tm) && is.finite(ctrl_tm)) {
    control_only <- control_only +
      geom_vline(xintercept = ctrl_tm, linetype = "dashed", color = "#BC9595")
  }
  curve_list[[paste0("Control: ", control_condition)]] <- control_only

  curve_list
}
