X-Git-Url: https://git.auder.net/?a=blobdiff_plain;f=R%2FcompareTo.R;h=28cb711d3e07bb287068e375c8ea86888fc26dec;hb=HEAD;hp=eb372dc0e43c57da9336c778d42295239b6d6bd6;hpb=a78bd1c09d34868df1cf11664bfc5cef4a9384d6;p=agghoo.git diff --git a/R/compareTo.R b/R/compareTo.R index eb372dc..0eb517c 100644 --- a/R/compareTo.R +++ b/R/compareTo.R @@ -12,7 +12,7 @@ standardCV_core <- function(data, target, task, gmodel, params, loss, CV) { list_testinds[[v]] <- get_testIndices(n, CV, v, shuffle_inds) gmodel <- agghoo::Model$new(data, target, task, gmodel, params) best_error <- Inf - best_model <- NULL + best_p <- NULL for (p in seq_len(gmodel$nmodels)) { error <- Reduce('+', lapply(seq_len(CV$V), function(v) { testIdx <- list_testinds[[v]] @@ -22,17 +22,55 @@ standardCV_core <- function(data, target, task, gmodel, params, loss, CV) { loss(prediction, d$targetTest) }) ) if (error <= best_error) { - newModel <- list(model=gmodel$get(data, target, p), - param=gmodel$getParam(p)) if (error == best_error) - best_model[[length(best_model)+1]] <- newModel + best_p[[length(best_p)+1]] <- p else { - best_model <- list(newModel) + best_p <- list(p) best_error <- error } } } - best_model[[ sample(length(best_model), 1) ]] + chosenP <- best_p[[ sample(length(best_p), 1) ]] + list(model=gmodel$get(data, target, chosenP), param=gmodel$getParam(chosenP)) +} + +#' CVvoting_core +#' +#' "voting" cross-validation method, added here as an example. +#' Parameters are described in ?agghoo and ?AgghooCV +CVvoting_core <- function(data, target, task, gmodel, params, loss, CV) { + CV <- checkCV(CV) + n <- nrow(data) + shuffle_inds <- NULL + if (CV$type == "vfold" && CV$shuffle) + shuffle_inds <- sample(n, n) + gmodel <- agghoo::Model$new(data, target, task, gmodel, params) + bestP <- rep(0, gmodel$nmodels) + for (v in seq_len(CV$V)) { + test_indices <- get_testIndices(n, CV, v, shuffle_inds) + d <- splitTrainTest(data, target, test_indices) + best_p <- NULL + best_error <- Inf + for (p in seq_len(gmodel$nmodels)) { + model_pred <- gmodel$get(d$dataTrain, d$targetTrain, p) + prediction <- model_pred(d$dataTest) + error <- loss(prediction, d$targetTest) + if (error <= best_error) { + if (error == best_error) + best_p[[length(best_p)+1]] <- p + else { + best_p <- list(p) + best_error <- error + } + } + } + for (p in best_p) + bestP[p] <- bestP[p] + 1 + } + # Choose a param at random in case of ex-aequos: + maxP <- max(bestP) + chosenP <- sample(which(bestP == maxP), 1) + list(model=gmodel$get(data, target, chosenP), param=gmodel$getParam(chosenP)) } #' standardCV_run @@ -40,8 +78,6 @@ standardCV_core <- function(data, target, task, gmodel, params, loss, CV) { #' Run and eval the standard cross-validation procedure. #' Parameters are rather explicit except "floss", which corresponds to the #' "final" loss function, applied to compute the error on testing dataset. -#' -#' @export standardCV_run <- function( dataTrain, dataTest, targetTrain, targetTest, floss, verbose, ... ) { @@ -61,13 +97,35 @@ standardCV_run <- function( invisible(err) } +#' CVvoting_run +#' +#' Run and eval the voting cross-validation procedure. +#' Parameters are rather explicit except "floss", which corresponds to the +#' "final" loss function, applied to compute the error on testing dataset. +CVvoting_run <- function( + dataTrain, dataTest, targetTrain, targetTest, floss, verbose, ... +) { + args <- list(...) + task <- checkTask(args$task, targetTrain) + modPar <- checkModPar(args$gmodel, args$params) + loss <- checkLoss(args$loss, task) + CV <- checkCV(args$CV) + s <- CVvoting_core( + dataTrain, targetTrain, task, modPar$gmodel, modPar$params, loss, CV) + if (verbose) + print(paste( "Parameter:", s$param )) + p <- s$model(dataTest) + err <- floss(p, targetTest) + if (verbose) + print(paste("error CV:", err)) + invisible(err) +} + #' agghoo_run #' #' Run and eval the agghoo procedure. #' Parameters are rather explicit except "floss", which corresponds to the #' "final" loss function, applied to compute the error on testing dataset. -#' -#' @export agghoo_run <- function( dataTrain, dataTest, targetTrain, targetTest, floss, verbose, ... ) { @@ -145,7 +203,6 @@ compareTo <- function( compareMulti <- function( data, target, method_s, N=100, nc=NA, floss=NULL, verbose=TRUE, ... ) { - require(parallel) if (is.na(nc)) nc <- parallel::detectCores()