From: Benjamin Auder <benjamin.auder@somewhere>
Date: Fri, 11 Jun 2021 16:42:56 +0000 (+0200)
Subject: Improve details in test/compareToCV.R
X-Git-Url: https://git.auder.net/assets/css/DESCRIPTION?a=commitdiff_plain;h=e86bf24de23aabec7a1176b8a1d09ee3fda216e3;p=agghoo.git

Improve details in test/compareToCV.R
---

diff --git a/test/README b/test/README
index 722eed7..987e10f 100644
--- a/test/README
+++ b/test/README
@@ -5,3 +5,13 @@ source("compareToCV.R")
 
 # rseed: >= 0 for reproducibility.
 compareToCV(data, target_column_index, rseed = -1)
+
+# Average over N runs:
+
+> compareMulti(iris, 5, N=100)
+[1] "error agghoo vs. cross-validation:"
+[1] 0.04266667 0.04566667
+
+> compareMulti(PimaIndiansDiabetes, 9, N=100)
+[1] "error agghoo vs. cross-validation:"
+[1] 0.2579221 0.2645455
diff --git a/test/compareToCV.R b/test/compareToCV.R
index 255c70c..d3b9011 100644
--- a/test/compareToCV.R
+++ b/test/compareToCV.R
@@ -92,14 +92,14 @@ standardCV <- function(data, target, task = NULL, gmodel = NULL, params = NULL,
   best_model[[ sample(length(best_model), 1) ]]
 }
 
-compareToCV <- function(df, t_idx, task=NULL, rseed=-1, verbose=TRUE) {
+compareToCV <- function(df, t_idx, task=NULL, rseed=-1, verbose=TRUE, ...) {
   if (rseed >= 0)
     set.seed(rseed)
   if (is.null(task))
     task <- ifelse(is.numeric(df[,t_idx]), "regression", "classification")
   n <- nrow(df)
   test_indices <- sample( n, round(n / ifelse(n >= 500, 10, 5)) )
-  a <- agghoo(df[-test_indices,-t_idx], df[-test_indices,t_idx], task)
+  a <- agghoo(df[-test_indices,-t_idx], df[-test_indices,t_idx], task, ...)
   a$fit()
   if (verbose) {
     print("Parameters:")
@@ -112,7 +112,7 @@ compareToCV <- function(df, t_idx, task=NULL, rseed=-1, verbose=TRUE) {
   if (verbose)
     print(paste("error agghoo:", err_a))
   # Compare with standard cross-validation:
-  s <- standardCV(df[-test_indices,-t_idx], df[-test_indices,t_idx], task)
+  s <- standardCV(df[-test_indices,-t_idx], df[-test_indices,t_idx], task, ...)
   if (verbose)
     print(paste( "Parameter:", s$param ))
   ps <- s$model(df[test_indices,-t_idx])
@@ -121,16 +121,16 @@ compareToCV <- function(df, t_idx, task=NULL, rseed=-1, verbose=TRUE) {
                   mean(abs(ps - df[test_indices,t_idx])))
   if (verbose)
     print(paste("error CV:", err_s))
-  c(err_a, err_s)
+  invisible(c(err_a, err_s))
 }
 
 library(parallel)
-compareMulti <- function(df, t_idx, task = NULL, N = 100, nc = NA) {
+compareMulti <- function(df, t_idx, task = NULL, N = 100, nc = NA, ...) {
   if (is.na(nc))
     nc <- detectCores()
   errors <- mclapply(1:N,
                      function(n) {
-                       compareToCV(df, t_idx, task, n, verbose=FALSE) },
+                       compareToCV(df, t_idx, task, n, verbose=FALSE, ...) },
                      mc.cores = nc)
   print("error agghoo vs. cross-validation:")
   Reduce('+', errors) / N