library(distfreereg)

### Grouping

# Fix the RNG state so the simulated response below is reproducible.
set.seed(20250610)

# Covariates with deliberately repeated rows, so that grouping has duplicate
# covariate values to collapse. The trailing comment on each row reads
#   <row number>: <group index under "natural" ordering>, <group index by x1>
df <- as.data.frame(matrix(
  c(1, 3,  # 1:  1, 1
    2, 3,  # 2:  3, 2
    2, 4,  # 3:  4, 2
    1, 4,  # 4:  2, 1
    1, 4,  # 5:  2, 1
    1, 4,  # 6:  2, 1
    1, 3,  # 7:  1, 1
    2, 4,  # 8:  4, 2
    2, 4,  # 9:  4, 2
    2, 3), # 10: 3, 2
  ncol = 2, byrow = TRUE, dimnames = list(NULL, c("x1", "x2"))
))

n <- nrow(df)

# Prepend a standard-normal response; the columns end up as y, x1, x2.
df <- data.frame(y = rnorm(n), df)

# Linear model handed to distfreereg() as the mean structure under test.
form <- y ~ x1 + x2
m <- lm(formula = form, data = df)

# Group observations that share a covariate row, with "natural" ordering.
dfr_form_lm_natural <- distfreereg(
  test_mean = m,
  ordering = "natural",
  group = TRUE,
  B = 10
)

dfr_form_lm_natural

# res_order should hold, for each observation, the index of its group.
stopifnot(identical(
  dfr_form_lm_natural[["res_order"]],
  c(1L, 3L, 4L, 2L, 2L, 2L, 1L, 4L, 4L, 3L)
))

# Rebuild epsp by hand from the transformed residuals and compare it with the
# value stored in the grouped fit. Each group's term is the sum of its
# transformed residuals divided by the square root of the number of residuals
# summed, then by sqrt(4) (4 being the number of distinct groups here).
# epsp[k] is the running total of the terms for groups 1..k.
transformed_natural <- dfr_form_lm_natural[["residuals"]][["transformed"]]
epsp_natural <- dfr_form_lm_natural[["epsp"]]

# Group 1: rows 1 and 7.
groupit_natural_1 <- sum(transformed_natural[c(1, 7)]) / sqrt(2) / sqrt(4)
stopifnot(all.equal(groupit_natural_1, epsp_natural[1]))

# Group 2: rows 4, 5 and 6.
groupit_natural_2 <- groupit_natural_1 +
  sum(transformed_natural[c(4, 5, 6)]) / sqrt(3) / sqrt(4)
stopifnot(all.equal(groupit_natural_2, epsp_natural[2]))

# Group 3: rows 2 and 10.
groupit_natural_3 <- groupit_natural_2 +
  sum(transformed_natural[c(2, 10)]) / sqrt(2) / sqrt(4)
stopifnot(all.equal(groupit_natural_3, epsp_natural[3]))

# Group 4: rows 3, 8 and 9.
groupit_natural_4 <- groupit_natural_3 +
  sum(transformed_natural[c(3, 8, 9)]) / sqrt(3) / sqrt(4)
stopifnot(all.equal(groupit_natural_4, epsp_natural[4]))


# Repeat the check when groups are formed from a subset of the covariate
# columns: ordering by x1 alone leaves two groups (x1 == 1 and x1 == 2).

dfr_form_lm_x1 <- distfreereg(
  test_mean = m,
  ordering = list("x1"),
  group = TRUE,
  B = 10
)

dfr_form_lm_x1

# res_order should now follow the x1 value of each observation.
stopifnot(identical(
  dfr_form_lm_x1[["res_order"]],
  c(1L, 2L, 2L, 1L, 1L, 1L, 1L, 2L, 2L, 2L)
))

# Hand-computed epsp: each group's term is the sum of its five transformed
# residuals divided by sqrt(5), then by sqrt(2) (two groups here); epsp[k] is
# the running total over groups 1..k.
transformed_x1 <- dfr_form_lm_x1[["residuals"]][["transformed"]]
epsp_x1 <- dfr_form_lm_x1[["epsp"]]

# Group 1: the rows with x1 == 1.
groupit_x1_1 <- sum(transformed_x1[c(1, 4, 5, 6, 7)]) / sqrt(5) / sqrt(2)
stopifnot(all.equal(groupit_x1_1, epsp_x1[1]))

# Group 2: the rows with x1 == 2.
groupit_x1_2 <- groupit_x1_1 +
  sum(transformed_x1[c(2, 3, 8, 9, 10)]) / sqrt(5) / sqrt(2)
stopifnot(all.equal(groupit_x1_2, epsp_x1[2]))
