}
}
}
+ else
+ mean(abs(y1 - y2))
}
}
best_model[[ sample(length(best_model), 1) ]]
}
-compareToCV <- function(df, t_idx, task=NULL, rseed=-1, verbose=TRUE) {
+compareToCV <- function(df, t_idx, task=NULL, rseed=-1, verbose=TRUE, ...) {
if (rseed >= 0)
set.seed(rseed)
if (is.null(task))
task <- ifelse(is.numeric(df[,t_idx]), "regression", "classification")
n <- nrow(df)
test_indices <- sample( n, round(n / ifelse(n >= 500, 10, 5)) )
- a <- agghoo(df[-test_indices,-t_idx], df[-test_indices,t_idx], task)
+ data <- as.matrix(df[-test_indices,-t_idx])
+ target <- df[-test_indices,t_idx]
+ test <- as.matrix(df[test_indices,-t_idx])
+ a <- agghoo(data, target, task, ...)
a$fit()
if (verbose) {
print("Parameters:")
print(unlist(a$getParams()))
}
- pa <- a$predict(df[test_indices,-t_idx])
+ pa <- a$predict(test)
err_a <- ifelse(task == "classification",
mean(pa != df[test_indices,t_idx]),
mean(abs(pa - df[test_indices,t_idx])))
if (verbose)
print(paste("error agghoo:", err_a))
# Compare with standard cross-validation:
- s <- standardCV(df[-test_indices,-t_idx], df[-test_indices,t_idx], task)
+ s <- standardCV(data, target, task, ...)
if (verbose)
print(paste( "Parameter:", s$param ))
- ps <- s$model(df[test_indices,-t_idx])
+ ps <- s$model(test)
err_s <- ifelse(task == "classification",
mean(ps != df[test_indices,t_idx]),
mean(abs(ps - df[test_indices,t_idx])))
if (verbose)
print(paste("error CV:", err_s))
- c(err_a, err_s)
+ invisible(c(err_a, err_s))
}
library(parallel)
-compareMulti <- function(df, t_idx, task = NULL, N = 100, nc = NA) {
+compareMulti <- function(df, t_idx, task = NULL, N = 100, nc = NA, ...) {
if (is.na(nc))
nc <- detectCores()
- errors <- mclapply(1:N,
- function(n) {
- compareToCV(df, t_idx, task, n, verbose=FALSE) },
- mc.cores = nc)
+ compareOne <- function(n) {
+ print(n)
+ compareToCV(df, t_idx, task, n, verbose=FALSE, ...)
+ }
+ errors <- if (nc >= 2) {
+ mclapply(1:N, compareOne, mc.cores = nc)
+ } else {
+ lapply(1:N, compareOne)
+ }
print("error agghoo vs. cross-validation:")
Reduce('+', errors) / N
}