From: Benjamin Auder Date: Fri, 11 Jun 2021 16:42:56 +0000 (+0200) Subject: Improve details in test/compareToCV.R X-Git-Url: https://git.auder.net/variants/Chakart/doc/html/current/common.css?a=commitdiff_plain;h=e86bf24de23aabec7a1176b8a1d09ee3fda216e3;p=agghoo.git Improve details in test/compareToCV.R --- diff --git a/test/README b/test/README index 722eed7..987e10f 100644 --- a/test/README +++ b/test/README @@ -5,3 +5,13 @@ source("compareToCV.R") # rseed: >= 0 for reproducibility. compareToCV(data, target_column_index, rseed = -1) + +# Average over N runs: + +> compareMulti(iris, 5, N=100) +[1] "error agghoo vs. cross-validation:" +[1] 0.04266667 0.04566667 + +> compareMulti(PimaIndiansDiabetes, 9, N=100) +[1] "error agghoo vs. cross-validation:" +[1] 0.2579221 0.2645455 diff --git a/test/compareToCV.R b/test/compareToCV.R index 255c70c..d3b9011 100644 --- a/test/compareToCV.R +++ b/test/compareToCV.R @@ -92,14 +92,14 @@ standardCV <- function(data, target, task = NULL, gmodel = NULL, params = NULL, best_model[[ sample(length(best_model), 1) ]] } -compareToCV <- function(df, t_idx, task=NULL, rseed=-1, verbose=TRUE) { +compareToCV <- function(df, t_idx, task=NULL, rseed=-1, verbose=TRUE, ...) { if (rseed >= 0) set.seed(rseed) if (is.null(task)) task <- ifelse(is.numeric(df[,t_idx]), "regression", "classification") n <- nrow(df) test_indices <- sample( n, round(n / ifelse(n >= 500, 10, 5)) ) - a <- agghoo(df[-test_indices,-t_idx], df[-test_indices,t_idx], task) + a <- agghoo(df[-test_indices,-t_idx], df[-test_indices,t_idx], task, ...) a$fit() if (verbose) { print("Parameters:") @@ -112,7 +112,7 @@ compareToCV <- function(df, t_idx, task=NULL, rseed=-1, verbose=TRUE) { if (verbose) print(paste("error agghoo:", err_a)) # Compare with standard cross-validation: - s <- standardCV(df[-test_indices,-t_idx], df[-test_indices,t_idx], task) + s <- standardCV(df[-test_indices,-t_idx], df[-test_indices,t_idx], task, ...) if (verbose) print(paste( "Parameter:", s$param )) ps <- s$model(df[test_indices,-t_idx]) @@ -121,16 +121,16 @@ compareToCV <- function(df, t_idx, task=NULL, rseed=-1, verbose=TRUE) { mean(abs(ps - df[test_indices,t_idx]))) if (verbose) print(paste("error CV:", err_s)) - c(err_a, err_s) + invisible(c(err_a, err_s)) } library(parallel) -compareMulti <- function(df, t_idx, task = NULL, N = 100, nc = NA) { +compareMulti <- function(df, t_idx, task = NULL, N = 100, nc = NA, ...) { if (is.na(nc)) nc <- detectCores() errors <- mclapply(1:N, function(n) { - compareToCV(df, t_idx, task, n, verbose=FALSE) }, + compareToCV(df, t_idx, task, n, verbose=FALSE, ...) }, mc.cores = nc) print("error agghoo vs. cross-validation:") Reduce('+', errors) / N