# -----------------------------------------------------------------------------
# analyze_mm_categorical.R - Monte-Carlo summary for categorical moderation ---
# -----------------------------------------------------------------------------
# This reworked version avoids the earlier "unused arguments" error by matching
# helper signatures exactly.  It replicates the original public interface so
# no downstream code or tests need to change.
# -----------------------------------------------------------------------------

#' Summarise Monte-Carlo draws for categorical moderators
#'
#' Internal engine used by .make_moderation() when the moderator *W* is
#' categorical. It converts a matrix of MC draws into conditional indirect
#' effects, contrasts, path-level statistics, and overall effects for each
#' moderator group.
#'
#' Conditional indirect effects (and their contrasts) are reported only for
#' paths that contain at least one coefficient in `MP` (every path when `MP`
#' is NULL). The total indirect effect, however, always sums *all* indirect
#' paths found in the draws, so the total effect equals cp + total indirect.
#'
#' @param mc_result  A numeric matrix of draws (one row per draw, one named
#'                   column per parameter) *or* a list-like object (e.g. a
#'                   semmcci result) whose thetahatstar element is such a
#'                   matrix.
#' @param prepared_data Data frame returned by PrepareData() (or equivalent)
#'                      containing pre-processed variables and a W_info
#'                      attribute with group information.
#' @param MP       Character vector of path names that are moderated (e.g.,
#'                 "a1", "b1", "d1", "cp", "b_1_2").  If NULL no moderation
#'                 is applied.
#' @param ci_level Numeric scalar in (0, 1) giving the confidence level for
#'                 percentile intervals (default 0.95).
#' @param digits   Integer; number of decimal places to keep when rounding.
#' @param debug    Logical; if TRUE prints progress messages.
#'
#' @return A list with components
#'   * `type`: the string "categorical";
#'   * `conditional_IE`, `IE_contrasts`: per-group indirect effects and
#'     pairwise group contrasts ("hi - lo") for each reported path;
#'   * `conditional_overall`, `overall_contrasts`: the same for the total
#'     indirect effect and the total effect (total_indirect rows first);
#'   * `extra$path_levels`, `extra$path_contrasts`: per-group values and
#'     contrasts of each coefficient listed in `MP`.
#'   Each table holds Estimate, SE, the percentile CI bounds and a Sig flag,
#'   or is NULL when it would have no rows. Group labels are character.
#' @keywords internal
#' @noRd
analyze_mm_categorical <- function(mc_result, prepared_data,
                                   MP       = NULL,
                                   ci_level = 0.95,
                                   digits   = 8,
                                   debug    = FALSE) {

  # ---- local message helper --------------------------------------------
  msg <- function(fmt, ...) if (debug) message(sprintf(fmt, ...))

  # ---- 1. Validate & basic objects -------------------------------------
  # A semmcci-style result stores the draws in thetahatstar; [[ ]] avoids
  # the partial name matching `$` performs on lists.
  draws <- mc_result
  if (is.list(mc_result) && !is.data.frame(mc_result) &&
      !is.null(mc_result[["thetahatstar"]])) {
    draws <- mc_result[["thetahatstar"]]
  }
  theta <- as.matrix(draws)
  if (nrow(theta) == 0L) stop("Sampling matrix is empty.")
  if (!("cp" %in% colnames(theta))) {
    stop("Sampling matrix has no 'cp' column; cannot compute total effects.")
  }

  Winfo <- attr(prepared_data, "W_info")
  if (is.null(Winfo)) stop("prepared_data is missing W_info attribute.")
  grp_var <- if (!is.null(Winfo$factor_name)) Winfo$factor_name else Winfo$raw
  # Character labels so every list look-up below is by name; numeric codes
  # such as 0/1 would otherwise be used as positions ([[0]] errors).
  groups <- as.character(sort(unique(prepared_data[[grp_var]])))
  # All unordered pairs (lo, hi); combn() errors with fewer than 2 groups.
  grp_pairs <- if (length(groups) >= 2L) {
    utils::combn(groups, 2, simplify = FALSE)
  } else {
    list()
  }

  msg("Moderator groups: %s", paste(groups, collapse = ", "))
  msg("theta dim: %d * %d", nrow(theta), ncol(theta))

  # ---- 2. Local helpers --------------------------------------------------
  # Group-specific draws of one coefficient.
  apply_mod <- function(base, group) {
    .cat_apply_mod(
      theta         = theta,
      prepared_data = prepared_data,
      base          = base,
      group         = group,
      MP            = MP,
      grp_var       = grp_var
    )
  }
  # Evaluate f(g) for every group; result is a list named by group label.
  per_group <- function(f) stats::setNames(lapply(groups, f), groups)
  summarise_draws <- function(v) {
    .cat_summarize_vec(v, ci_level = ci_level, digits = digits)
  }
  pack <- function(lst, lbl) .cat_pack_df(lst, lbl, ci_level = ci_level)
  # One summary row per group; the first column is renamed to id_col.
  level_rows <- function(id_col, id, vals) {
    lapply(groups, function(g) {
      rec <- data.frame(id, Group = g, summarise_draws(vals[[g]]),
                        check.names = FALSE)
      names(rec)[1L] <- id_col
      rec
    })
  }
  # One summary row per group pair, summarising vals[[hi]] - vals[[lo]].
  contrast_rows <- function(id_col, id, vals) {
    lapply(grp_pairs, function(pr) {
      rec <- data.frame(id, Contrast = paste(pr[2], "-", pr[1]),
                        summarise_draws(vals[[pr[2]]] - vals[[pr[1]]]),
                        check.names = FALSE)
      names(rec)[1L] <- id_col
      rec
    })
  }

  # ---- 3. Parse paths ---------------------------------------------------
  paths_all <- .cat_get_indirect_paths(colnames(theta))
  reported  <- vapply(paths_all, function(p) {
    !length(MP) || any(p$coefs %in% MP)
  }, logical(1))

  msg("indirect paths (all / filtered): %d / %d",
      length(paths_all), sum(reported))

  # ---- 4. Each indirect path -------------------------------------------
  cond_IE <- IE_ct <- list()
  tot_ind_vals <- per_group(function(g) numeric(nrow(theta)))
  for (i in seq_along(paths_all)) {
    pth <- paths_all[[i]]
    # Indirect effect = product of the group-specific path coefficients.
    vals_g <- per_group(function(g) {
      Reduce(`*`, lapply(pth$coefs, apply_mod, group = g))
    })
    # Every path contributes to the total, reported or not.
    tot_ind_vals <- Map(`+`, tot_ind_vals, vals_g)
    if (reported[i]) {
      cond_IE <- c(cond_IE, level_rows("IE", pth$path_name, vals_g))
      IE_ct   <- c(IE_ct, contrast_rows("IE", pth$path_name, vals_g))
    }
  }

  # ---- 5. Path-level coefficients --------------------------------------
  path_lv <- path_ct <- list()
  for (base in MP) {
    coef_g  <- per_group(function(g) apply_mod(base, g))
    path_lv <- c(path_lv, level_rows("Path", base, coef_g))
    path_ct <- c(path_ct, contrast_rows("Path", base, coef_g))
  }

  # ---- 6. Total indirect & total effects --------------------------------
  dir_vals     <- per_group(function(g) apply_mod("cp", g))
  tot_eff_vals <- Map(`+`, dir_vals, tot_ind_vals)

  # Built in display order: all total_indirect rows, then total_effect rows.
  cond_overall <- c(level_rows("Effect", "total_indirect", tot_ind_vals),
                    level_rows("Effect", "total_effect",   tot_eff_vals))
  overall_ct   <- c(contrast_rows("Effect", "total_indirect", tot_ind_vals),
                    contrast_rows("Effect", "total_effect",   tot_eff_vals))

  # ---- 7. Return --------------------------------------------------------
  list(
    type                 = "categorical",
    conditional_IE       = pack(cond_IE,      "conditional_IE"),
    IE_contrasts         = pack(IE_ct,        "IE_contrasts"),
    conditional_overall  = pack(cond_overall, "conditional_overall"),
    overall_contrasts    = pack(overall_ct,   "overall_contrasts"),
    extra = list(
      path_levels    = pack(path_lv, "path_levels"),
      path_contrasts = pack(path_ct, "path_contrasts")
    )
  )
}



# ===== Helper functions =======================================================

#' Make CI column names like "2.5%CI.Lo" / "97.5%CI.Up"
#'
#' @param ci Confidence level in (0, 1).
#' @return Character vector: the lower-bound label(s) followed by the
#'         upper-bound label(s), percentages printed with one decimal place.
#' @keywords internal
#' @noRd
.cat_make_ci_names <- function(ci) {
  as_pct <- function(p) formatC(p * 100, format = "f", digits = 1)
  lower_lab <- paste0(as_pct((1 - ci) / 2), "%CI.Lo")
  upper_lab <- paste0(as_pct((1 + ci) / 2), "%CI.Up")
  c(lower_lab, upper_lab)
}

#' Summarise a numeric vector as a one-row table (mean, SD, percentile CI)
#'
#' @param v        Numeric vector (e.g. Monte-Carlo draws of one quantity).
#' @param ci_level Confidence level for the percentile interval.
#' @param digits   Decimal places passed to round().
#' @return A one-row data frame with columns Estimate (mean of `v`), SE
#'         (standard deviation of `v`) and the lower/upper quantiles named by
#'         .cat_make_ci_names(ci_level); every value rounded to `digits`.
#' @keywords internal
#' @noRd
.cat_summarize_vec <- function(v, ci_level = .95, digits = 8) {
  probs     <- c((1 - ci_level)/2, (1 + ci_level)/2)
  ci_bounds <- stats::quantile(v, probs, names = FALSE)

  cells <- list(mean(v), stats::sd(v), ci_bounds[1], ci_bounds[2])
  names(cells) <- c("Estimate", "SE", .cat_make_ci_names(ci_level))
  cells <- lapply(cells, round, digits)

  # check.names = FALSE keeps the "%" in the CI column names intact.
  do.call(data.frame, c(cells, check.names = FALSE))
}

#' Add a significance flag derived from the percentile CI
#'
#' A row gets Sig = "*" when its interval excludes zero, i.e. both bounds
#' are strictly positive or both strictly negative, and "" otherwise (NA if
#' a bound is NA). Inputs lacking either CI column are returned unchanged.
#'
#' @param df       Data frame holding the CI columns named by
#'                 .cat_make_ci_names(ci_level).
#' @param ci_level Confidence level used to build those column names.
#' @return `df`, with a Sig column added when the CI columns are present.
#' @keywords internal
#' @noRd
.cat_add_sig <- function(df, ci_level = .95) {
  ci_names <- .cat_make_ci_names(ci_level)
  if (!all(ci_names %in% names(df))) return(df)
  lo <- df[[ci_names[1]]]
  up <- df[[ci_names[2]]]
  # Compare signs rather than multiplying: lo * up underflows to 0 for very
  # small bounds (e.g. 1e-200 * 1e-200) and would hide a significant CI.
  df$Sig <- ifelse(sign(lo) * sign(up) > 0, "*", "")
  df
}

#' Stack a list of summary rows into one table and flag significance
#'
#' @param lst      List of data frames with identical columns.
#' @param lbl      Table label; not used in the computation, accepted so
#'                 existing callers that pass it keep working.
#' @param ci_level Confidence level forwarded to .cat_add_sig().
#' @return NULL when `lst` is empty; otherwise the row-bound data frame
#'         passed through .cat_add_sig().
#' @keywords internal
#' @noRd
.cat_pack_df <- function(lst, lbl, ci_level = .95) {
  if (length(lst) == 0L) {
    return(NULL)
  }
  stacked <- do.call(rbind, lst)
  .cat_add_sig(stacked, ci_level = ci_level)
}

#' Parse coefficient names into possible indirect paths
#'
#' Builds a directed graph from coefficient names:
#'   * "a<i>" / "a_<i>"  : X       -> M<i>diff
#'   * "b<i>" / "b_<i>"  : M<i>diff -> Y
#'   * "b_<i>_<j>"       : M<i>diff -> M<j>diff
#' Then enumerates, depth-first, every acyclic route from X to Y.
#'
#' @param col_names Character vector of parameter (column) names.
#' @return A (possibly empty) list of distinct paths. Each path is a list
#'   with `path_name` ("indirect_" + mediator indices joined by "_"),
#'   `coefs` (edge labels in traversal order) and `mediators` (mediator
#'   indices joined by a space).
#' @keywords internal
#' @noRd
.cat_get_indirect_paths <- function(col_names) {
  a_labs  <- grep("^a_?\\d+$",     col_names, value = TRUE)  # a1 / a_1
  b_labs  <- grep("^b_?\\d+$",     col_names, value = TRUE)  # b1 / b_1
  bb_labs <- grep("^b_\\d+_\\d+$", col_names, value = TRUE)  # b_1_2

  med_node <- function(idx) paste0("M", idx, "diff")
  one_edge <- function(src, tgt, label) {
    data.frame(src = src, tgt = tgt, label = label)
  }

  # Edge rows ordered X->M first, then M->Y, then M->M (this order fixes the
  # order in which routes are discovered).
  edge_rows <- c(
    lapply(a_labs, function(lab) {
      one_edge("X", med_node(sub("^a_?", "", lab)), lab)
    }),
    lapply(b_labs, function(lab) {
      one_edge(med_node(sub("^b_?", "", lab)), "Y", lab)
    }),
    lapply(bb_labs, function(lab) {
      ends <- regmatches(lab, gregexpr("\\d+", lab))[[1]]
      one_edge(med_node(ends[1]), med_node(ends[2]), lab)
    })
  )
  edges <- if (length(edge_rows)) {
    do.call(rbind, edge_rows)
  } else {
    data.frame(src = character(), tgt = character(), label = character())
  }
  graph <- split(edges, edges$src)

  # Depth-first extension of a partial route; nodes already on the route
  # are skipped so cycles cannot occur.
  extend <- function(route) {
    node <- route[length(route)]
    if (node == "Y") return(list(route))
    if (!(node %in% names(graph))) return(list())
    Reduce(function(acc, nxt) {
      if (nxt %in% route) acc else c(acc, extend(c(route, nxt)))
    }, graph[[node]]$tgt, list())
  }

  routes <- extend("X")
  if (length(routes) == 0L) return(list())

  describe <- function(nodes) {
    hops  <- seq_len(length(nodes) - 1)
    coefs <- vapply(hops, function(i) {
      edges$label[edges$src == nodes[i] & edges$tgt == nodes[i + 1]][1]
    }, character(1))
    visited <- nodes[-length(nodes)]
    meds    <- sub("^M(\\d+)diff$", "\\1",
                   visited[grepl("^M\\d+diff$", visited)])
    list(
      path_name = paste0("indirect_", paste(meds, collapse = "_")),
      coefs     = coefs,
      mediators = paste(meds, collapse = " ")
    )
  }
  unique(lapply(routes, describe))
}

#' Derive moderation-coefficient prefix
#'
#' Maps a base coefficient name to the stem of its interaction columns:
#' "cp" -> "cpw", "a1"/"a_1" -> "aw1", "b1"/"b_1" -> "bw1", "d1"/"d_1" ->
#' "dw1", "b_1_2" -> "bw_1_2", "d_1_2" -> "dw_1_2". Any other name is
#' returned unchanged.
#'
#' @param base Single coefficient name.
#' @return Character scalar prefix.
#' @keywords internal
#' @noRd
.cat_get_mod_prefix <- function(base) {
  if (base == "cp") return("cpw")
  # Each rule: match pattern, leading part to strip, replacement stem.
  # Rules are tried in order; the first match wins.
  rules <- list(
    c("^a_?\\d+$",     "^a_?", "aw"),
    c("^b_?\\d+$",     "^b_?", "bw"),
    c("^d_?\\d+$",     "^d_?", "dw"),
    c("^b_\\d+_\\d+$", "^b_",  "bw_"),
    c("^d_\\d+_\\d+$", "^d_",  "dw_")
  )
  for (rule in rules) {
    if (grepl(rule[1], base)) {
      return(paste0(rule[3], sub(rule[2], "", base)))
    }
  }
  base
}

#' Evaluate one coefficient's draws for a specific moderator group
#'
#' Returns theta[, base]. When `base` is listed in `MP`, every interaction
#' column named "<prefix>_W<k>" (prefix from .cat_get_mod_prefix(base)) is
#' added, weighted by the single value the dummy column W<k> of
#' `prepared_data` takes within `group`.
#'
#' @param theta         Draw matrix with named columns.
#' @param prepared_data Data frame holding `grp_var` and the W<k> dummies.
#' @param base          Column name of the base coefficient (e.g. "a1").
#' @param group         Group label to condition on.
#' @param MP            Character vector of moderated coefficient names.
#' @param grp_var       Name of the grouping column in `prepared_data`.
#' @return Numeric vector with one value per row of `theta`.
#' Stops if a required dummy column is missing or is not constant within
#' `group`.
#' @keywords internal
#' @noRd
.cat_apply_mod <- function(theta, prepared_data, base, group, MP, grp_var) {
  base_vec <- theta[, base, drop = TRUE]
  if (!length(MP) || !(base %in% MP)) return(base_vec)

  prefix <- .cat_get_mod_prefix(base)
  mods   <- grep(paste0("^", prefix, "_W\\d+$"),
                 colnames(theta), value = TRUE)
  if (!length(mods)) return(base_vec)

  # which() drops rows whose group is NA; a bare logical index would turn
  # them into NA values and break the single-value check below.
  in_group <- which(prepared_data[[grp_var]] == group)
  for (m in mods) {
    dm <- sub(".*_(W\\d+)$", "\\1", m)  # e.g., W1
    if (is.null(prepared_data[[dm]])) {
      stop(sprintf("Dummy column '%s' (needed for '%s') not found in prepared_data.",
                   dm, m))
    }
    wval <- unique(prepared_data[[dm]][in_group])
    if (length(wval) != 1L) {
      stop(sprintf("Dummy '%s' must take exactly one value in group '%s' (found %d).",
                   dm, as.character(group), length(wval)))
    }
    base_vec <- base_vec + theta[, m] * wval
  }
  base_vec
}

#' Fix names like 2.5.CI.Lo --> 2.5%CI.Lo
#'
#' Undoes make.names() mangling of CI column names, e.g. "X2.5.CI.Lo" or
#' "2.5.CI.Lo" become "2.5%CI.Lo". Other names are left untouched.
#'
#' @param df A data frame; anything else (including NULL) is returned as is.
#' @return `df` with repaired column names.
#' @keywords internal
#' @noRd
fix_pct_names <- function(df) {
  if (is.null(df) || !is.data.frame(df)) return(df)
  nms <- names(df)
  # "X2.5.CI.Lo" -> "2.5%CI.Lo"; capture group 2 is just "Lo"/"Up" (the
  # old pattern referenced the whole ".CI.Lo" group, yielding "%CI..CI.Lo").
  nms <- sub("^X([0-9.]+)\\.CI\\.(Lo|Up)", "\\1%CI.\\2", nms)
  # "2.5.CI.Lo" (no X prefix) -> "2.5%CI.Lo"
  nms <- sub("([0-9]+)\\.CI\\.(Lo|Up)", "\\1%CI.\\2", nms)
  names(df) <- nms
  df
}
