// SVG Package - C++ Implementation for Spatial Statistics
//
// Author: Zaoqu Liu
// License: MIT
//
// This file provides optimized C++ implementations of key statistical
// functions used in SVG detection, particularly Moran's I calculation.
//
// Features:
// - Armadillo linear algebra for efficient matrix operations
// - Vectorized computations where possible
// - Optimized memory access patterns
//
// Note: We intentionally avoid OpenMP to prevent conflicts with R's
// parallel package (mclapply). Parallelization should be done at R level.

#include <RcppArmadillo.h>
// [[Rcpp::depends(RcppArmadillo)]]

using namespace Rcpp;
using namespace arma;

//' Fast Row-wise Moran's I Calculation
//'
//' Computes Moran's I statistic for multiple genes (rows) against a
//' precomputed spatial weights matrix.
//'
//' @param expr_matrix Numeric matrix (genes x spots)
//' @param W Spatial weights matrix (spots x spots); must be square with one
//'   row/column per column of \code{expr_matrix}.
//'
//' @return Numeric vector of Moran's I values, one per gene. \code{NA} is
//'   returned for genes whose centred sum of squares is below 1e-10
//'   (constant expression), and for every gene when \code{sum(W) == 0}.
//'
//' @details
//' This function uses the Armadillo linear algebra library.
//' It computes: I = (n/W_sum) * (z * W * z') / (z * z')
//' where z = x - mean(x) (a row vector) for each gene and n is the number
//' of spots.
//'
//' @keywords internal
// [[Rcpp::export]]
NumericVector moranI_cpp(const arma::mat& expr_matrix,
                         const arma::mat& W) {

    const int n_genes = static_cast<int>(expr_matrix.n_rows);
    const uword n_spots = expr_matrix.n_cols;

    // Fail early with a clear message instead of an Armadillo
    // "incompatible dimensions" error from inside the gene loop.
    if (W.n_rows != n_spots || W.n_cols != n_spots) {
        Rcpp::stop("W must be a square matrix with one row/column per spot");
    }

    // Every entry starts as NA, so each early exit below leaves a defined value.
    NumericVector I_values(n_genes, NA_REAL);

    // Pre-compute weight matrix sum (S0).
    const double W_sum = accu(W);
    if (W_sum == 0.0) {
        // The normalising factor n / W_sum would be Inf/NaN: I is undefined.
        return I_values;
    }
    const double scale = static_cast<double>(n_spots) / W_sum;

    for (int g = 0; g < n_genes; g++) {
        // Center the expression values of this gene.
        const rowvec x = expr_matrix.row(g);
        const double x_mean = mean(x);
        const rowvec z = x - x_mean;

        // Sum of squared deviations.
        const double ss = dot(z, z);
        if (ss < 1e-10) {
            // Zero variance - Moran's I undefined, keep NA.
            continue;
        }

        // Spatial cross-product: sum_i sum_j w_ij * z_i * z_j = z * W * z'
        const double cv = as_scalar(z * W * z.t());

        I_values[g] = scale * (cv / ss);
    }

    return I_values;
}


//' Fast Moran's I with Full Statistics
//'
//' Computes Moran's I along with its expected value and standard deviation
//' under the null hypothesis, using the Cliff & Ord variance formula that
//' includes the sample kurtosis of each gene (randomisation assumption).
//'
//' @param expr_matrix Numeric matrix (genes x spots)
//' @param W Spatial weights matrix (spots x spots). Row standardisation is
//'   not assumed; S0, S1 and S2 are computed from \code{W} as given.
//'
//' @return List with numeric vectors \code{observed}, \code{expected} and
//'   \code{sd}, one entry per gene. \code{observed} and \code{sd} are
//'   \code{NA} for constant genes (centred sum of squares below 1e-10) or
//'   when \code{sum(W) == 0}. \code{sd} is also \code{NA} when there are
//'   fewer than 4 spots (the variance divides by (n-1)(n-2)(n-3)), and
//'   \code{expected} is \code{NA} when there are fewer than 2 spots.
//'
//' @details
//' All weight-matrix statistics and the gene-independent parts of the
//' variance are computed once; the per-gene cost is dominated by the
//' W * z' product, which is O(n^2).
//'
//' @keywords internal
// [[Rcpp::export]]
List moranI_full_cpp(const arma::mat& expr_matrix,
                     const arma::mat& W) {

    const int n_genes = static_cast<int>(expr_matrix.n_rows);
    const uword n_spots = expr_matrix.n_cols;
    const double n = static_cast<double>(n_spots);

    if (W.n_rows != n_spots || W.n_cols != n_spots) {
        Rcpp::stop("W must be a square matrix with one row/column per spot");
    }

    // --- Gene-independent weight statistics (computed once, O(n^2)) ---

    // S0 = sum(W)
    const double W_sum = accu(W);
    const double W_sum_sq = W_sum * W_sum;

    // S1 = 0.5 * sum((W + W')^2)
    const mat W_symm = W + W.t();
    const double S1 = 0.5 * accu(W_symm % W_symm);

    // S2 = sum((rowSums(W) + colSums(W))^2)
    const vec margin_sums = sum(W, 1) + sum(W, 0).t();
    const double S2 = accu(margin_sums % margin_sums);

    // Expected I under the null; undefined for fewer than 2 spots.
    const double E_I = (n_spots >= 2) ? -1.0 / (n - 1.0) : NA_REAL;

    // Variance terms that do not depend on the gene (hoisted out of the loop).
    const double S4 = (n * n - 3.0 * n + 3.0) * S1 - n * S2 + 3.0 * W_sum_sq;
    const double S5 = (n * n - n) * S1 - 2.0 * n * S2 + 6.0 * W_sum_sq;
    const double denom = (n - 1.0) * (n - 2.0) * (n - 3.0) * W_sum * W_sum;

    // I needs W_sum != 0; the variance additionally needs denom != 0 (n >= 4).
    const bool I_defined = (W_sum != 0.0);
    const bool sd_defined = I_defined && n_spots >= 4;

    // Output vectors (NA until a value is computed)
    NumericVector I_obs_vec(n_genes, NA_REAL);
    NumericVector E_I_vec(n_genes, E_I);
    NumericVector sd_I_vec(n_genes, NA_REAL);

    if (I_defined) {
        // Loop over genes (consider R-level parallelization for large gene sets)
        for (int g = 0; g < n_genes; g++) {
            const rowvec x = expr_matrix.row(g);
            const double x_mean = mean(x);
            const rowvec z = x - x_mean;

            const double m2 = dot(z, z);  // sum(z^2)
            if (m2 < 1e-10) {
                continue;  // constant gene: leave observed/sd as NA
            }

            // Moran's I - compute W * z' first, then dot with z.
            const vec Wz = W * z.t();
            const double cv = dot(z, Wz.t());
            I_obs_vec[g] = (n / W_sum) * (cv / m2);

            if (!sd_defined) {
                continue;
            }

            // Variance calculation (Cliff & Ord formula).
            const double m4 = dot(z % z, z % z);  // sum(z^4)
            // Sample kurtosis of z.
            const double S3 = (m4 / n) / ((m2 / n) * (m2 / n));
            const double E_I2 = (n * S4 - S3 * S5) / denom;
            const double V_I = E_I2 - E_I * E_I;

            // Clamp tiny negative values produced by floating-point cancellation.
            sd_I_vec[g] = std::sqrt(std::max(V_I, 0.0));
        }
    }

    return List::create(
        Named("observed") = I_obs_vec,
        Named("expected") = E_I_vec,
        Named("sd") = sd_I_vec
    );
}


//' Fast Binarization using K-means (k=2)
//'
//' Binarizes each row of an expression matrix with one-dimensional k-means
//' (k = 2). Centres are initialised at the row minimum and maximum; spots
//' assigned to the cluster with the higher centre are labelled 1, all
//' others 0.
//'
//' @param expr_matrix Numeric matrix (genes x spots)
//' @param max_iter Maximum iterations for k-means. Values below 1 are
//'   treated as 1 so that at least one assignment pass is always made.
//'
//' @return Integer matrix (genes x spots) of binary values (0/1). Rows whose
//'   range is below 1e-10 (constant expression) are all 0.
//'
//' @keywords internal
// [[Rcpp::export]]
IntegerMatrix binarize_kmeans_cpp(const arma::mat& expr_matrix,
                                   int max_iter = 20) {

    const int n_genes = static_cast<int>(expr_matrix.n_rows);
    const int n_spots = static_cast<int>(expr_matrix.n_cols);

    IntegerMatrix result(n_genes, n_spots);

    // min()/max() are undefined on an empty row; there is nothing to label.
    if (n_spots == 0) {
        return result;
    }

    // A non-positive max_iter would skip the assignment step entirely and
    // leave the labels undefined, so always make at least one pass.
    const int n_iter = std::max(max_iter, 1);

    // Reused across genes: every entry is overwritten by the first
    // assignment pass before it is read.
    uvec assignments(n_spots);

    for (int g = 0; g < n_genes; g++) {
        const rowvec x = expr_matrix.row(g);

        // Initialize centers at min and max.
        // NOTE(review): NaN values are not handled specially; confirm that
        // callers pass finite expression values.
        double c1 = x.min();
        double c2 = x.max();

        if (std::abs(c2 - c1) < 1e-10) {
            // Constant expression - all zeros
            for (int j = 0; j < n_spots; j++) {
                result(g, j) = 0;
            }
            continue;
        }

        // K-means (Lloyd) iterations
        for (int iter = 0; iter < n_iter; iter++) {
            // Assignment step: nearest centre; ties go to cluster 0.
            for (int j = 0; j < n_spots; j++) {
                const double d1 = std::abs(x(j) - c1);
                const double d2 = std::abs(x(j) - c2);
                assignments(j) = (d2 < d1) ? 1 : 0;
            }

            // Update step: each centre becomes the mean of its cluster
            // (left unchanged if the cluster is empty).
            double sum1 = 0, sum2 = 0;
            int count1 = 0, count2 = 0;

            for (int j = 0; j < n_spots; j++) {
                if (assignments(j) == 0) {
                    sum1 += x(j);
                    count1++;
                } else {
                    sum2 += x(j);
                    count2++;
                }
            }

            const double new_c1 = (count1 > 0) ? sum1 / count1 : c1;
            const double new_c2 = (count2 > 0) ? sum2 / count2 : c2;

            // Check convergence
            if (std::abs(new_c1 - c1) < 1e-8 && std::abs(new_c2 - c2) < 1e-8) {
                break;
            }

            c1 = new_c1;
            c2 = new_c2;
        }

        // Label 1 = the cluster whose centre is higher.
        const uword high_cluster = (c2 > c1) ? 1 : 0;

        for (int j = 0; j < n_spots; j++) {
            result(g, j) = (assignments(j) == high_cluster) ? 1 : 0;
        }
    }

    return result;
}


//' Fast Distance Matrix Computation
//'
//' Computes the pairwise Euclidean distance matrix for spatial coordinates.
//'
//' @param coords Numeric matrix of coordinates (spots x dimensions). Any
//'   number of columns is supported (e.g. 2 for x/y, 3 for x/y/z).
//'
//' @return Symmetric (spots x spots) distance matrix with a zero diagonal
//'
//' @keywords internal
// [[Rcpp::export]]
arma::mat dist_matrix_cpp(const arma::mat& coords) {

    const uword n = coords.n_rows;
    const uword n_dims = coords.n_cols;
    mat D(n, n, fill::zeros);

    // Fill the upper triangle and mirror it; the diagonal stays zero.
    for (uword i = 0; i < n; i++) {
        for (uword j = i + 1; j < n; j++) {
            // Sum squared differences over every coordinate column rather than
            // assuming exactly two (a 1-column input would index out of range,
            // and extra columns would be silently ignored).
            double sq_dist = 0.0;
            for (uword dim = 0; dim < n_dims; dim++) {
                const double diff = coords(i, dim) - coords(j, dim);
                sq_dist += diff * diff;
            }
            const double d = std::sqrt(sq_dist);

            D(i, j) = d;
            D(j, i) = d;
        }
    }

    return D;
}


//' Build KNN Adjacency Matrix
//'
//' Constructs a symmetrised K-nearest-neighbour adjacency matrix: spots i
//' and j are connected if j is among the k nearest neighbours of i, or
//' i is among the k nearest neighbours of j.
//'
//' @param coords Numeric matrix of coordinates (spots x dimensions)
//' @param k Number of neighbours per spot. It is capped at
//'   \code{nrow(coords) - 1}; values of 0 or less give an empty graph.
//'
//' @return Binary (0/1) symmetric adjacency matrix with a zero diagonal.
//'   Ties in distance are broken by spot index (lower index first).
//'
//' @keywords internal
// [[Rcpp::export]]
arma::mat knn_adj_cpp(const arma::mat& coords, int k) {

    const uword n = coords.n_rows;
    mat adj(n, n, fill::zeros);

    if (n < 2 || k <= 0) {
        return adj;
    }
    const uword k_eff = std::min(static_cast<uword>(k), n - 1);

    const mat D = dist_matrix_cpp(coords);

    for (uword i = 0; i < n; i++) {
        // D is symmetric, so column i holds the same distances as row i
        // (and is contiguous in memory). A stable sort keeps tie-breaking
        // deterministic.
        const uvec sorted_idx = stable_sort_index(D.col(i));

        // Skip i explicitly: with duplicate coordinates another spot can also
        // be at distance 0 and sort ahead of i, so "position 0 is self" does
        // not hold and would create a self-loop.
        uword taken = 0;
        for (uword r = 0; r < n && taken < k_eff; r++) {
            const uword neighbor_idx = sorted_idx(r);
            if (neighbor_idx == i) {
                continue;
            }
            adj(i, neighbor_idx) = 1.0;
            adj(neighbor_idx, i) = 1.0;  // Make symmetric
            taken++;
        }
    }

    return adj;
}


//' Compute Contingency Table for Spatial Fisher's Exact Test
//'
//' Counts, over a list of edges, how often the two endpoints take the
//' values 0/0, 0/1, 1/0 and 1/1 in a binary vector, and computes a simple
//' odds ratio. The Fisher test p-value itself is not computed here; the
//' caller is expected to compute it in R from the returned counts.
//'
//' @param bin_vec Binary (0/1) vector for one gene. Edges with an endpoint
//'   whose value is not 0 or 1 (e.g. NA) are ignored.
//' @param from_idx Vector of source indices for edges (used as 0-based
//'   positions into \code{bin_vec})
//' @param to_idx Vector of target indices for edges (same length as
//'   \code{from_idx}, 0-based)
//'
//' @return List with contingency table counts (n00, n01, n10, n11) and
//'   odds_ratio = (n11 * n00) / max(n10 * n01, 1e-10)
//'
//' @keywords internal
// [[Rcpp::export]]
List fisher_spatial_cpp(const arma::ivec& bin_vec,
                        const arma::uvec& from_idx,
                        const arma::uvec& to_idx) {

    if (from_idx.n_elem != to_idx.n_elem) {
        Rcpp::stop("from_idx and to_idx must have the same length");
    }

    // counts[2 * b_from + b_to] holds n00, n01, n10, n11 in that order.
    int counts[4] = {0, 0, 0, 0};

    for (uword e = 0; e < from_idx.n_elem; e++) {
        const int b_from = bin_vec(from_idx(e));
        const int b_to = bin_vec(to_idx(e));

        // Non-binary values (e.g. NA_INTEGER) do not belong in the 2x2 table
        // and must not be lumped into n11.
        const bool from_ok = (b_from == 0 || b_from == 1);
        const bool to_ok = (b_to == 0 || b_to == 1);
        if (!from_ok || !to_ok) {
            continue;
        }
        counts[2 * b_from + b_to]++;
    }

    const int n00 = counts[0];
    const int n01 = counts[1];
    const int n10 = counts[2];
    const int n11 = counts[3];

    // Multiply in double: the int products n11 * n00 and n10 * n01 can
    // overflow (undefined behaviour) for large edge counts.
    const double or_num = static_cast<double>(n11) * static_cast<double>(n00);
    const double or_den = static_cast<double>(n10) * static_cast<double>(n01);
    const double OR = or_num / std::max(or_den, 1e-10);

    return List::create(
        Named("n00") = n00,
        Named("n01") = n01,
        Named("n10") = n10,
        Named("n11") = n11,
        Named("odds_ratio") = OR
    );
}
