diff --git a/R/compareTo.R b/R/compareTo.R
index 00e90a9..28cb711 100644
--- a/R/compareTo.R
+++ b/R/compareTo.R
@@ -1,3 +1,7 @@
+#' standardCV_core
+#'
+#' Cross-validation method, added here as an example.
+#' Parameters are described in ?agghoo and ?AgghooCV.
 standardCV_core <- function(data, target, task, gmodel, params, loss, CV) {
   n <- nrow(data)
   shuffle_inds <- NULL
@@ -8,7 +12,7 @@ standardCV_core <- function(data, target, task, gmodel, params, loss, CV) {
     list_testinds[[v]] <- get_testIndices(n, CV, v, shuffle_inds)
   gmodel <- agghoo::Model$new(data, target, task, gmodel, params)
   best_error <- Inf
-  best_model <- NULL
+  best_p <- NULL
   for (p in seq_len(gmodel$nmodels)) {
     error <- Reduce('+', lapply(seq_len(CV$V), function(v) {
       testIdx <- list_testinds[[v]]
@@ -18,27 +22,72 @@ standardCV_core <- function(data, target, task, gmodel, params, loss, CV) {
       loss(prediction, d$targetTest)
     }) )
     if (error <= best_error) {
-      newModel <- list(model=gmodel$get(data, target, p),
-                       param=gmodel$getParam(p))
       if (error == best_error)
-        best_model[[length(best_model)+1]] <- newModel
+        best_p[[length(best_p)+1]] <- p
       else {
-        best_model <- list(newModel)
+        best_p <- list(p)
         best_error <- error
       }
     }
   }
-#browser()
-  best_model[[ sample(length(best_model), 1) ]]
+  chosenP <- best_p[[ sample(length(best_p), 1) ]]
+  list(model=gmodel$get(data, target, chosenP), param=gmodel$getParam(chosenP))
 }
 
+#' CVvoting_core
+#'
+#' "Voting" cross-validation method, added here as an example.
+#' Parameters are described in ?agghoo and ?AgghooCV.
+CVvoting_core <- function(data, target, task, gmodel, params, loss, CV) {
+  CV <- checkCV(CV)
+  n <- nrow(data)
+  shuffle_inds <- NULL
+  if (CV$type == "vfold" && CV$shuffle)
+    shuffle_inds <- sample(n, n)
+  gmodel <- agghoo::Model$new(data, target, task, gmodel, params)
+  bestP <- rep(0, gmodel$nmodels)
+  for (v in seq_len(CV$V)) {
+    test_indices <- get_testIndices(n, CV, v, shuffle_inds)
+    d <- splitTrainTest(data, target, test_indices)
+    best_p <- NULL
+    best_error <- Inf
+    for (p in seq_len(gmodel$nmodels)) {
+      model_pred <- gmodel$get(d$dataTrain, d$targetTrain, p)
+      prediction <- model_pred(d$dataTest)
+      error <- loss(prediction, d$targetTest)
+      if (error <= best_error) {
+        if (error == best_error)
+          best_p[[length(best_p)+1]] <- p
+        else {
+          best_p <- list(p)
+          best_error <- error
+        }
+      }
+    }
+    for (p in best_p)
+      bestP[p] <- bestP[p] + 1
+  }
+  # Choose a parameter at random in case of ties:
+  maxP <- max(bestP)
+  chosenP <- sample(which(bestP == maxP), 1)
+  list(model=gmodel$get(data, target, chosenP), param=gmodel$getParam(chosenP))
+}
+
+#' standardCV_run
+#'
+#' Run and evaluate the standard cross-validation procedure.
+#' Parameters are rather explicit, except "floss", which corresponds to the
+#' "final" loss function, applied to compute the error on the testing dataset.
+#'
+#' @export
 standardCV_run <- function(
-  dataTrain, dataTest, targetTrain, targetTest, CV, floss, verbose, ...
+  dataTrain, dataTest, targetTrain, targetTest, floss, verbose, ...
 ) {
   args <- list(...)
   task <- checkTask(args$task, targetTrain)
   modPar <- checkModPar(args$gmodel, args$params)
   loss <- checkLoss(args$loss, task)
+  CV <- checkCV(args$CV)
   s <- standardCV_core(
     dataTrain, targetTrain, task, modPar$gmodel, modPar$params, loss, CV)
   if (verbose)
@@ -50,10 +99,47 @@ standardCV_run <- function(
   invisible(err)
 }
 
+#' CVvoting_run
+#'
+#' Run and evaluate the voting cross-validation procedure.
+#' Parameters are rather explicit, except "floss", which corresponds to the
+#' "final" loss function, applied to compute the error on the testing dataset.
+#'
+#' @export
+CVvoting_run <- function(
+  dataTrain, dataTest, targetTrain, targetTest, floss, verbose, ...
+) {
+  args <- list(...)
+  task <- checkTask(args$task, targetTrain)
+  modPar <- checkModPar(args$gmodel, args$params)
+  loss <- checkLoss(args$loss, task)
+  CV <- checkCV(args$CV)
+  s <- CVvoting_core(
+    dataTrain, targetTrain, task, modPar$gmodel, modPar$params, loss, CV)
+  if (verbose)
+    print(paste("Parameter:", s$param))
+  p <- s$model(dataTest)
+  err <- floss(p, targetTest)
+  if (verbose)
+    print(paste("error CV:", err))
+  invisible(err)
+}
+
+#' agghoo_run
+#'
+#' Run and evaluate the agghoo procedure.
+#' Parameters are rather explicit, except "floss", which corresponds to the
+#' "final" loss function, applied to compute the error on the testing dataset.
+#'
+#' @export
 agghoo_run <- function(
-  dataTrain, dataTest, targetTrain, targetTest, CV, floss, verbose, ...
+  dataTrain, dataTest, targetTrain, targetTest, floss, verbose, ...
 ) {
-  a <- agghoo(dataTrain, targetTrain, ...)
+  args <- list(...)
+  CV <- checkCV(args$CV)
+  # Remove the CV argument first, or agghoo() will complain about an unused arg:
+  args$CV <- NULL
+  a <- do.call(agghoo, c(list(data=dataTrain, target=targetTrain), args))
   a$fit(CV)
   if (verbose) {
     print("Parameters:")
@@ -66,7 +152,20 @@ agghoo_run <- function(
   invisible(err)
 }
 
-# ... arguments passed to method_s (agghoo, standard CV or else)
+#' compareTo
+#'
+#' Compare a list of learning methods (or run only one) on data/target.
+#'
+#' @param data Data matrix or data.frame
+#' @param target Target values (generally a vector)
+#' @param method_s Either a single function or a list of functions
+#'   (examples: agghoo_run, standardCV_run)
+#' @param rseed Seed of the random generator (-1 means "random seed")
+#' @param floss Loss function to compute the error on the testing dataset
+#' @param verbose TRUE to request methods to be verbose
+#' @param ... arguments passed to the method_s function(s)
+#'
+#' @export
 compareTo <- function(
   data, target, method_s, rseed=-1, floss=NULL, verbose=TRUE, ...
 ) {
@@ -75,7 +174,6 @@ compareTo <- function(
   n <- nrow(data)
   test_indices <- sample( n, round(n / ifelse(n >= 500, 10, 5)) )
   d <- splitTrainTest(data, target, test_indices)
-  CV <- checkCV(list(...)$CV)
 
   # Set error function to be used on model outputs (not in core method)
   task <- checkTask(list(...)$task, target)
@@ -87,8 +185,7 @@ compareTo <- function(
 
   # Run (and compare) all methods:
   runOne <- function(o) {
-    o(d$dataTrain, d$dataTest, d$targetTrain, d$targetTest,
-      CV, floss, verbose, ...)
+    o(d$dataTrain, d$dataTest, d$targetTrain, d$targetTest, floss, verbose, ...)
   }
   errors <- c()
   if (is.list(method_s))
@@ -98,10 +195,19 @@ compareTo <- function(
   invisible(errors)
 }
 
-# Run compareTo N times in parallel
-# ... : additional args to be passed to method_s
+#' compareMulti
+#'
+#' Run compareTo N times in parallel.
+#'
+#' @inheritParams compareTo
+#' @param N Number of calls to method(s)
+#' @param nc Number of cores. Set to parallel::detectCores() if undefined.
+#'   Set it to any value <= 1 to disable parallelism.
+#' @param verbose TRUE to print the task numbers and "Errors:" at the end.
+#'
+#' @export
 compareMulti <- function(
-  data, target, method_s, N=100, nc=NA, floss=NULL, ...
+  data, target, method_s, N=100, nc=NA, floss=NULL, verbose=TRUE, ...
 ) {
   require(parallel)
   if (is.na(nc))
@@ -109,7 +215,8 @@ compareMulti <- function(
 
   # "One" comparison for each method in method_s (list)
   compareOne <- function(n) {
-    print(n)
+    if (verbose)
+      print(n)
     compareTo(data, target, method_s, n, floss, verbose=FALSE, ...)
   }
 
@@ -118,6 +225,29 @@ compareMulti <- function(
   } else {
     lapply(1:N, compareOne)
   }
-  print("Errors:")
+  if (verbose)
+    print("Errors:")
   Reduce('+', errors) / N
 }
+
+#' compareRange
+#'
+#' Run compareMulti on several values of the parameter V.
+#'
+#' @inheritParams compareMulti
+#' @param V_range Values of V to be tested.
+#'
+#' @export
+compareRange <- function(
+  data, target, method_s, N=100, nc=NA, floss=NULL, V_range=c(10,15,20), ...
+) {
+  args <- list(...)
+  # Check CV once, storing it back so each run only overrides V (no warnings):
+  args$CV <- suppressWarnings( checkCV(args$CV) )
+  errors <- lapply(V_range, function(V) {
+    args$CV$V <- V
+    do.call(compareMulti, c(list(data=data, target=target, method_s=method_s,
+                                 N=N, nc=nc, floss=floss, verbose=FALSE), args))
+  })
+  print(paste(V_range, errors))
+}
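
After this patch, every runner shares the signature function(dataTrain, dataTest, targetTrain, targetTest, floss, verbose, ...): the CV descriptor and the other model settings (task, gmodel, params, loss) travel through "..." and are validated inside each runner by the check* helpers. Any function with that signature can therefore be passed as method_s to compareTo. Below is a minimal sketch of such a custom runner, a no-selection baseline that always fits the first model of the family; it mirrors the patterns of the runners above, but the function itself is hypothetical and not part of the patch:

fixedParam_run <- function(
  dataTrain, dataTest, targetTrain, targetTest, floss, verbose, ...
) {
  args <- list(...)
  task <- checkTask(args$task, targetTrain)
  modPar <- checkModPar(args$gmodel, args$params)
  # Build the model family, then skip selection: always take parameter index 1.
  gmodel <- agghoo::Model$new(dataTrain, targetTrain, task,
                              modPar$gmodel, modPar$params)
  model <- gmodel$get(dataTrain, targetTrain, 1)
  # Evaluate with the "final" loss on the held-out test set:
  err <- floss(model(dataTest), targetTest)
  if (verbose)
    print(paste("error baseline:", err))
  invisible(err)
}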
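
As a usage sketch of the new interface (the dataset and the CV settings below are illustrative assumptions, not taken from the patch; see ?agghoo for the supported model families and the defaults applied when gmodel/params are left unspecified):

library(agghoo)
data <- iris[, -5]   # predictors
target <- iris[, 5]  # classification target (a factor)
# Single comparison with a fixed seed; extra args (CV, task, ...) pass through "...":
compareTo(data, target, list(agghoo_run, standardCV_run),
          rseed=42, CV=list(type="vfold", V=10, shuffle=TRUE))
# Average the test errors over N=20 replications, on 2 cores:
compareMulti(data, target, list(agghoo_run, standardCV_run), N=20, nc=2)
# Same experiment over several values of V:
compareRange(data, target, agghoo_run, N=20, nc=2, V_range=c(5, 10))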