From 7b5193cdf5eb7041710c52368764feeacbb36a7c Mon Sep 17 00:00:00 2001
From: Benjamin Auder <benjamin.auder@somewhere>
Date: Mon, 14 Jun 2021 16:57:25 +0200
Subject: [PATCH] Fix agghoo for tree / rpart

---
 R/R6_AgghooCV.R    |  3 ++-
 R/R6_Model.R       | 21 +++++++++++++--------
 test/compareToCV.R | 24 ++++++++++++++++--------
 3 files changed, 31 insertions(+), 17 deletions(-)

diff --git a/R/R6_AgghooCV.R b/R/R6_AgghooCV.R
index 2734d69..c555641 100644
--- a/R/R6_AgghooCV.R
+++ b/R/R6_AgghooCV.R
@@ -89,7 +89,8 @@ AgghooCV <- R6::R6Class("AgghooCV",
         return (invisible(NULL))
       }
       V <- length(private$pmodels)
-      if (length(private$pmodels[[1]]$model(X[1,])) >= 2)
+      oneLineX <- as.data.frame(t(as.matrix(X[1,])))
+      if (length(private$pmodels[[1]]$model(oneLineX)) >= 2)
         # Soft classification:
         return (Reduce("+", lapply(private$pmodels, function(m) m$model(X))) / V)
       n <- nrow(X)
diff --git a/R/R6_Model.R b/R/R6_Model.R
index 8fc2324..96f892d 100644
--- a/R/R6_Model.R
+++ b/R/R6_Model.R
@@ -73,9 +73,15 @@ Model <- R6::R6Class("Model",
         function(dataHO, targetHO, param) {
           require(rpart)
           method <- ifelse(task == "classification", "class", "anova")
+          if (is.null(colnames(dataHO)))
+            colnames(dataHO) <- paste0("V", 1:ncol(dataHO))
           df <- data.frame(cbind(dataHO, target=targetHO))
           model <- rpart::rpart(target ~ ., df, method=method, control=list(cp=param))
-          function(X) predict(model, X)
+          function(X) {
+            if (is.null(colnames(X)))
+              colnames(X) <- paste0("V", 1:ncol(X))
+            predict(model, as.data.frame(X))
+          }
         }
       }
       else if (family == "rf") {
@@ -115,18 +121,17 @@ Model <- R6::R6Class("Model",
         require(rpart)
         df <- data.frame(cbind(data, target=target))
         ctrl <- list(
+          cp = 0,
           minsplit = 2,
           minbucket = 1,
-          maxcompete = 0,
-          maxsurrogate = 0,
-          usesurrogate = 0,
-          xval = 0,
-          surrogatestyle = 0,
-          maxdepth = 30)
+          xval = 0)
         r <- rpart(target ~ ., df, method="class", control=ctrl)
         cps <- r$cptable[-1,1]
-        if (length(cps) <= 11)
+        if (length(cps) <= 11) {
+          if (length(cps == 0))
+            stop("No cross-validation possible: select another model")
           return (cps)
+        }
         step <- (length(cps) - 1) / 10
         cps[unique(round(seq(1, length(cps), step)))]
       }
diff --git a/test/compareToCV.R b/test/compareToCV.R
index a124dd2..276749b 100644
--- a/test/compareToCV.R
+++ b/test/compareToCV.R
@@ -101,23 +101,26 @@ compareToCV <- function(df, t_idx, task=NULL, rseed=-1, verbose=TRUE, ...) {
     task <- ifelse(is.numeric(df[,t_idx]), "regression", "classification")
   n <- nrow(df)
   test_indices <- sample( n, round(n / ifelse(n >= 500, 10, 5)) )
-  a <- agghoo(df[-test_indices,-t_idx], df[-test_indices,t_idx], task, ...)
+  data <- as.matrix(df[-test_indices,-t_idx])
+  target <- df[-test_indices,t_idx]
+  test <- as.matrix(df[test_indices,-t_idx])
+  a <- agghoo(data, target, task, ...)
   a$fit()
   if (verbose) {
     print("Parameters:")
     print(unlist(a$getParams()))
   }
-  pa <- a$predict(df[test_indices,-t_idx])
+  pa <- a$predict(test)
   err_a <- ifelse(task == "classification",
                   mean(pa != df[test_indices,t_idx]),
                   mean(abs(pa - df[test_indices,t_idx])))
   if (verbose)
     print(paste("error agghoo:", err_a))
   # Compare with standard cross-validation:
-  s <- standardCV(df[-test_indices,-t_idx], df[-test_indices,t_idx], task, ...)
+  s <- standardCV(data, target, task, ...)
   if (verbose)
     print(paste( "Parameter:", s$param ))
-  ps <- s$model(df[test_indices,-t_idx])
+  ps <- s$model(test)
   err_s <- ifelse(task == "classification",
                   mean(ps != df[test_indices,t_idx]),
                   mean(abs(ps - df[test_indices,t_idx])))
@@ -130,10 +133,15 @@ library(parallel)
 compareMulti <- function(df, t_idx, task = NULL, N = 100, nc = NA, ...) {
   if (is.na(nc))
     nc <- detectCores()
-  errors <- mclapply(1:N,
-                     function(n) {
-                       compareToCV(df, t_idx, task, n, verbose=FALSE, ...) },
-                     mc.cores = nc)
+  compareOne <- function(n) {
+    print(n)
+    compareToCV(df, t_idx, task, n, verbose=FALSE, ...)
+  }
+  errors <- if (nc >= 2) {
+    mclapply(1:N, compareOne, mc.cores = nc)
+  } else {
+    lapply(1:N, compareOne)
+  }
   print("error agghoo vs. cross-validation:")
   Reduce('+', errors) / N
 }
-- 
2.44.0