X-Git-Url: https://git.auder.net/?p=agghoo.git;a=blobdiff_plain;f=R%2FcompareTo.R;fp=R%2FcompareTo.R;h=536d2eec8526504e0c4f7fcd883e54badea58f49;hp=00e90a9ea7ddaac9ea3a24326ceb21d5eb0e01d0;hb=17ea2f13e0c32c107db20677750bd7a98bb7e0f8;hpb=afa676609daba103e43d6d4654560ca4c1c9b38b diff --git a/R/compareTo.R b/R/compareTo.R index 00e90a9..536d2ee 100644 --- a/R/compareTo.R +++ b/R/compareTo.R @@ -1,3 +1,7 @@ +#' standardCV_core +#' +#' Cross-validation method, added here as an example. +#' Parameters are described in ?agghoo and ?AgghooCV standardCV_core <- function(data, target, task, gmodel, params, loss, CV) { n <- nrow(data) shuffle_inds <- NULL @@ -28,17 +32,24 @@ standardCV_core <- function(data, target, task, gmodel, params, loss, CV) { } } } -#browser() best_model[[ sample(length(best_model), 1) ]] } +#' standardCV_run +#' +#' Run and eval the standard cross-validation procedure. +#' Parameters are rather explicit except "floss", which corresponds to the +#' "final" loss function, applied to compute the error on testing dataset. +#' +#' @export standardCV_run <- function( - dataTrain, dataTest, targetTrain, targetTest, CV, floss, verbose, ... + dataTrain, dataTest, targetTrain, targetTest, floss, verbose, ... ) { args <- list(...) task <- checkTask(args$task, targetTrain) modPar <- checkModPar(args$gmodel, args$params) loss <- checkLoss(args$loss, task) + CV <- checkCV(args$CV) s <- standardCV_core( dataTrain, targetTrain, task, modPar$gmodel, modPar$params, loss, CV) if (verbose) @@ -50,10 +61,21 @@ standardCV_run <- function( invisible(err) } +#' agghoo_run +#' +#' Run and eval the agghoo procedure. +#' Parameters are rather explicit except "floss", which corresponds to the +#' "final" loss function, applied to compute the error on testing dataset. +#' +#' @export agghoo_run <- function( - dataTrain, dataTest, targetTrain, targetTest, CV, floss, verbose, ... + dataTrain, dataTest, targetTrain, targetTest, floss, verbose, ... ) { - a <- agghoo(dataTrain, targetTrain, ...) + args <- list(...) + CV <- checkCV(args$CV) + # Must remove CV arg, or agghoo will complain "error: unused arg" + args$CV <- NULL + a <- do.call(agghoo, c(list(data=dataTrain, target=targetTrain), args)) a$fit(CV) if (verbose) { print("Parameters:") @@ -66,7 +88,20 @@ agghoo_run <- function( invisible(err) } -# ... arguments passed to method_s (agghoo, standard CV or else) +#' compareTo +#' +#' Compare a list of learning methods (or run only one), on data/target. +#' +#' @param data Data matrix or data.frame +#' @param target Target vector (generally) +#' @param method_s Either a single function, or a list +#' (examples: agghoo_run, standardCV_run) +#' @param rseed Seed of the random generator (-1 means "random seed") +#' @param floss Loss function to compute the error on testing dataset. +#' @param verbose TRUE to request methods to be verbose. +#' @param ... arguments passed to method_s function(s) +#' +#' @export compareTo <- function( data, target, method_s, rseed=-1, floss=NULL, verbose=TRUE, ... ) { @@ -75,7 +110,6 @@ compareTo <- function( n <- nrow(data) test_indices <- sample( n, round(n / ifelse(n >= 500, 10, 5)) ) d <- splitTrainTest(data, target, test_indices) - CV <- checkCV(list(...)$CV) # Set error function to be used on model outputs (not in core method) task <- checkTask(list(...)$task, target) @@ -87,8 +121,7 @@ compareTo <- function( # Run (and compare) all methods: runOne <- function(o) { - o(d$dataTrain, d$dataTest, d$targetTrain, d$targetTest, - CV, floss, verbose, ...) + o(d$dataTrain, d$dataTest, d$targetTrain, d$targetTest, floss, verbose, ...) } errors <- c() if (is.list(method_s)) @@ -98,10 +131,19 @@ compareTo <- function( invisible(errors) } -# Run compareTo N times in parallel -# ... : additional args to be passed to method_s +#' compareMulti +#' +#' Run compareTo N times in parallel. +#' +#' @inheritParams compareTo +#' @param N Number of calls to method(s) +#' @param nc Number of cores. Set to parallel::detectCores() if undefined. +#' Set it to any value <=1 to say "no parallelism". +#' @param verbose TRUE to print task numbers and "Errors:" in the end. +#' +#' @export compareMulti <- function( - data, target, method_s, N=100, nc=NA, floss=NULL, ... + data, target, method_s, N=100, nc=NA, floss=NULL, verbose=TRUE, ... ) { require(parallel) if (is.na(nc)) @@ -109,7 +151,8 @@ compareMulti <- function( # "One" comparison for each method in method_s (list) compareOne <- function(n) { - print(n) + if (verbose) + print(n) compareTo(data, target, method_s, n, floss, verbose=FALSE, ...) } @@ -118,6 +161,29 @@ compareMulti <- function( } else { lapply(1:N, compareOne) } - print("Errors:") + if (verbose) + print("Errors:") Reduce('+', errors) / N } + +#' compareRange +#' +#' Run compareMulti on several values of the parameter V. +#' +#' @inheritParams compareMulti +#' @param V_range Values of V to be tested. +#' +#' @export +compareRange <- function( + data, target, method_s, N=100, nc=NA, floss=NULL, V_range=c(10,15,20,), ... +) { + args <- list(...) + # Avoid warnings if V is left unspecified: + CV <- suppressWarnings( checkCV(args$CV) ) + errors <- lapply(V_range, function(V) { + args$CV$V <- V + do.call(compareMulti, c(list(data=data, target=target, method_s=method_s, + N=N, nc=nc, floss=floss, verbose=F), args)) + }) + print(paste(V_range, errors)) +}