#' quantbayes plotting utilities
#'
#' @description
#' Produces diagnostic plots: global density, overlay density, evidence matrix,
#' p_hat, and theta credible intervals.
#'
#' @details
#' The overlay curves come from draws simulated with \code{stats::rbeta()}
#' using shape parameters \code{1 + k} and \code{1 + (m - k)}. The peaks of
#' highlighted variants outside the overlay set are computed the same way.
#' Both therefore depend on the RNG state; call \code{set.seed()} beforehand
#' for reproducible plots.
#'
#' @param res Result from quant_es_core. Must provide:
#'   \itemize{
#'     \item \code{res$variants}: a data frame with columns
#'       \code{variant_id}, \code{k}, \code{m}, \code{theta_mean},
#'       \code{theta_lower} and \code{theta_upper}.
#'     \item \code{res$global}: a list with \code{mean_theta},
#'       \code{median_theta}, \code{lower_theta} and \code{upper_theta}.
#'   }
#' @param x_matrix Evidence matrix used for the run. Rows are selected by
#'   variant id (row names); each column is one evidence rule. Cell values
#'   0, 1 and missing are coloured distinctly.
#' @param top_n Number of variants for matrix and summary plots.
#' @param top_overlay Number of top variants (highest \code{theta_mean}) used
#'   in the overlay density.
#' @param highlight_points Optional list of highlighted variants. Each element
#'   is a list with these fields:
#'   \itemize{
#'     \item \code{id}: variant id.
#'     \item \code{colour}: point fill, default \code{"#ffbf00"}.
#'     \item \code{size}: point size, used as-is, default \code{3}.
#'   }
#'   Ids not found in \code{res$variants} are skipped. Variants sharing a
#'   colour share one legend entry.
#' @param palette10 Colour palette for overlay density lines; recycled if
#'   shorter than the number of overlay variants.
#' @param palette20 Colour palette for p_hat plot; recycled if shorter than
#'   the number of bars.
#' @param n_draws Number of posterior draws simulated per variant for the
#'   overlay densities and highlight peaks.
#'
#' @return A named list of ggplot objects: \code{p_global}, \code{p_overlay},
#'   \code{p_matrix}, \code{p_p_hat} and \code{p_theta_ci}.
#'
#' @importFrom grDevices colorRampPalette
#' @importFrom stats density rbeta
#' @importFrom dplyr %>% arrange desc slice_head pull group_by summarise filter mutate
#' @importFrom tidyr pivot_longer
#' @importFrom tibble tibble rownames_to_column
#' @import ggplot2
#'
#' @export
quant_es_plots <- function(
    res,
    x_matrix,
    top_n = 20,
    top_overlay = 10,
    highlight_points = NULL,
    palette10 = grDevices::colorRampPalette(
      c("#2f4356", "#656d87", "#f1e1d4", "#ffbf00", "#ee4035")
    )(10),
    palette20 = grDevices::colorRampPalette(
      c("#656d87", "#2f4356")
    )(20),
    n_draws = 2000
) {
  variants <- res$variants

  # list() evaluates its arguments left to right, so random draws happen in
  # the same order as before: overlay variants first, then highlights.
  list(
    p_global = qes_plot_global(variants, res$global),
    p_overlay = qes_plot_overlay(
      variants, top_overlay, highlight_points, palette10, n_draws
    ),
    p_matrix = qes_plot_matrix(variants, x_matrix, top_n),
    p_p_hat = qes_plot_p_hat(variants, top_n, palette20),
    p_theta_ci = qes_plot_theta_ci(variants, top_n)
  )
}

#' Simulate posterior theta draws for one variant row.
#'
#' Draws from Beta(1 + k, 1 + (m - k)) using the row's \code{k} and \code{m}.
#' @noRd
qes_theta_draws <- function(row, n_draws) {
  stats::rbeta(n_draws, 1 + row$k, 1 + (row$m - row$k))
}

#' Location and height of the kernel-density mode of a numeric vector.
#'
#' @return A one-row tibble with \code{peak_x} and \code{peak_y}.
#' @noRd
qes_density_peak <- function(theta) {
  d <- stats::density(theta)
  i <- which.max(d$y)
  tibble::tibble(peak_x = d$x[i], peak_y = d$y[i])
}

#' Density of theta_mean across all variants, annotated with global summaries.
#' @noRd
qes_plot_global <- function(variants, global) {
  ggplot2::ggplot(variants, ggplot2::aes(theta_mean)) +
    ggplot2::geom_density(fill = "#2f4356", alpha = 0.8, colour = "black") +
    ggplot2::theme_bw() +
    ggplot2::labs(
      title = "Global posterior theta distribution",
      subtitle = sprintf(
        "mean=%0.3f median=%0.3f CrI=%0.3f to %0.3f",
        global$mean_theta, global$median_theta,
        global$lower_theta, global$upper_theta
      ),
      x = "theta",
      y = "density"
    )
}

#' Build the table of highlighted peaks for the overlay plot.
#'
#' Peaks of variants already in \code{peak_top} are reused. Other known
#' variants get fresh draws. Unknown or invalid ids are skipped.
#' @noRd
qes_highlight_peaks <- function(highlight_points, variants, peak_top,
                                n_draws) {
  empty <- tibble::tibble(
    variant_id = character(),
    peak_x = numeric(),
    peak_y = numeric(),
    highlight_fill = character(),
    highlight_size = numeric()
  )
  # length(NULL) is 0, so this also covers a NULL argument.
  if (length(highlight_points) == 0) {
    return(empty)
  }

  peak_for <- function(id) {
    # length check keeps `||` scalar (R >= 4.3 errors on longer inputs).
    if (length(id) != 1L || is.na(id) || !id %in% variants$variant_id) {
      return(NULL)
    }
    if (id %in% peak_top$variant_id) {
      return(dplyr::filter(peak_top, variant_id == id))
    }
    row <- variants[variants$variant_id == id, , drop = FALSE]
    peak <- qes_density_peak(qes_theta_draws(row, n_draws))
    tibble::tibble(variant_id = id, peak_x = peak$peak_x, peak_y = peak$peak_y)
  }

  out <- purrr::map_dfr(highlight_points, function(h) {
    peak <- peak_for(h$id)
    if (is.null(peak) || nrow(peak) == 0) {
      return(NULL)
    }
    # Assigning NULL would silently drop the column and break the aes
    # mapping, so fall back to defaults.
    peak$highlight_fill <- if (is.null(h$colour)) "#ffbf00" else h$colour
    peak$highlight_size <- if (is.null(h$size)) 3 else h$size
    peak
  })

  if (nrow(out) == 0) empty else out
}

#' Global theta_mean density with per-variant posterior curves overlaid
#' for the top variants, plus optional highlighted peaks.
#' @noRd
qes_plot_overlay <- function(variants, top_overlay, highlight_points,
                             palette10, n_draws) {
  top_ids <- variants %>%
    dplyr::arrange(dplyr::desc(theta_mean)) %>%
    dplyr::slice_head(n = top_overlay) %>%
    dplyr::pull(variant_id)

  theta_draw_df <- purrr::map_dfr(top_ids, function(id) {
    row <- variants[variants$variant_id == id, , drop = FALSE]
    tibble::tibble(variant_id = id, theta = qes_theta_draws(row, n_draws))
  })

  # Make sure palette10 is long enough for the overlay variants.
  n_overlay <- length(unique(theta_draw_df$variant_id))
  if (length(palette10) < n_overlay) {
    palette10 <- rep(palette10, length.out = n_overlay)
  }

  # Unnamed data-frame results are unpacked into columns, so each group's
  # density is computed once.
  peak_top <- theta_draw_df %>%
    dplyr::group_by(variant_id) %>%
    dplyr::summarise(qes_density_peak(theta), .groups = "drop")

  highlight_df <- qes_highlight_peaks(
    highlight_points, variants, peak_top, n_draws
  )

  p <- ggplot2::ggplot() +
    ggplot2::geom_density(
      data = variants,
      ggplot2::aes(theta_mean),
      fill = "#2f4356",
      colour = "black",
      alpha = 1
    ) +
    # `linewidth` (ggplot2 >= 3.4.0) replaces the deprecated `size` for lines.
    ggplot2::geom_density(
      data = theta_draw_df,
      ggplot2::aes(theta, colour = variant_id),
      linewidth = 0.9,
      alpha = 0.55
    ) +
    ggplot2::scale_colour_manual(values = palette10) +
    ggplot2::theme_bw() +
    ggplot2::labs(
      title = "Posterior theta distribution with top candidates overlaid",
      subtitle = paste0("Top ", top_overlay, " variants"),
      x = "theta",
      y = "density"
    )

  if (nrow(highlight_df) == 0) {
    return(p)
  }

  # One legend key per distinct fill. Labels are built per fill so breaks
  # and labels always have equal length, even when colours repeat.
  fills <- unique(highlight_df$highlight_fill)
  fill_labels <- vapply(
    fills,
    function(f) {
      ids <- highlight_df$variant_id[highlight_df$highlight_fill %in% f]
      paste(unique(ids), collapse = ", ")
    },
    character(1),
    USE.NAMES = FALSE
  )

  # Highlighted peaks carry their own fill and size, independent of the
  # colour scale; identity scales use the supplied values literally.
  p +
    ggplot2::geom_point(
      data = highlight_df,
      ggplot2::aes(
        x = peak_x,
        y = peak_y,
        fill = highlight_fill,
        size = highlight_size
      ),
      inherit.aes = FALSE,
      shape = 21,
      colour = "black"
    ) +
    ggplot2::scale_fill_identity(
      guide = "legend",
      breaks = fills,
      labels = fill_labels
    ) +
    ggplot2::scale_size_identity()
}

#' Tile plot of the evidence matrix for the top_n variants by k.
#' @noRd
qes_plot_matrix <- function(variants, x_matrix, top_n) {
  evid_rank <- variants %>%
    dplyr::arrange(dplyr::desc(k)) %>%
    dplyr::slice_head(n = top_n) %>%
    dplyr::pull(variant_id)

  mat_long <- x_matrix[evid_rank, , drop = FALSE] %>%
    as.data.frame() %>%
    tibble::rownames_to_column("variant_id") %>%
    tidyr::pivot_longer(-variant_id, names_to = "rule", values_to = "value")

  ggplot2::ggplot(
    mat_long,
    ggplot2::aes(
      rule,
      factor(variant_id, levels = rev(evid_rank)),
      fill = factor(value)
    )
  ) +
    ggplot2::geom_tile(colour = "black") +
    ggplot2::scale_fill_manual(
      values = c("0" = "#ee4035", "1" = "#53c000", "NA" = "#ffbf00"),
      na.value = "#ffbf00"
    ) +
    ggplot2::theme_bw() +
    ggplot2::theme(
      axis.text.x = ggplot2::element_text(angle = 90, vjust = 0.5)
    ) +
    ggplot2::labs(
      title = "Evidence matrix",
      subtitle = paste0("Top ", top_n, " variants"),
      x = "evidence rule",
      y = "variant"
    )
}

#' Bar chart of the observed proportion k / m for the top_n variants.
#' @noRd
qes_plot_p_hat <- function(variants, top_n, palette20) {
  top_p <- variants %>%
    dplyr::mutate(p_hat = k / m) %>%
    dplyr::arrange(dplyr::desc(p_hat)) %>%
    dplyr::slice_head(n = top_n)

  top_p$variant_id <- factor(top_p$variant_id, levels = rev(top_p$variant_id))

  # Make sure palette20 is long enough for the p_hat bars.
  n_p <- length(unique(top_p$variant_id))
  if (length(palette20) < n_p) {
    palette20 <- rep(palette20, length.out = n_p)
  }

  ggplot2::ggplot(top_p, ggplot2::aes(variant_id, p_hat, fill = variant_id)) +
    ggplot2::geom_col(colour = "black") +
    ggplot2::scale_fill_manual(values = palette20) +
    ggplot2::coord_flip() +
    ggplot2::theme_bw() +
    ggplot2::labs(
      title = "Observed evidence proportion",
      subtitle = paste0("Top ", top_n, " variants"),
      x = "variant",
      y = "p_hat"
    ) +
    ggplot2::guides(fill = "none")
}

#' Point-range plot of theta credible intervals for the top_n variants.
#' @noRd
qes_plot_theta_ci <- function(variants, top_n) {
  top_ci <- variants %>%
    dplyr::arrange(dplyr::desc(theta_mean)) %>%
    dplyr::slice_head(n = top_n)

  top_ci$variant_id <- factor(
    top_ci$variant_id,
    levels = rev(top_ci$variant_id)
  )

  ggplot2::ggplot(top_ci, ggplot2::aes(variant_id, theta_mean)) +
    ggplot2::geom_pointrange(
      ggplot2::aes(ymin = theta_lower, ymax = theta_upper),
      colour = "#2f4356"
    ) +
    ggplot2::coord_flip() +
    ggplot2::theme_bw() +
    ggplot2::labs(
      title = "Posterior theta credible intervals",
      subtitle = paste0("Top ", top_n, " variants"),
      x = "variant",
      y = "theta"
    )
}
