From 7733758e823f6e783f965f5c7f7f80a1d4e5df3a Mon Sep 17 00:00:00 2001 From: Benjamin Auder Date: Thu, 24 Jun 2021 12:27:33 +0200 Subject: [PATCH] Add and debug CV-voting --- DESCRIPTION | 4 +- NAMESPACE | 1 + R/A_NAMESPACE.R | 3 ++ R/compareTo.R | 100 ++++++++++++++++++++++++++++---------------- man/CVvoting_run.Rd | 13 ++++++ 5 files changed, 82 insertions(+), 39 deletions(-) create mode 100644 man/CVvoting_run.Rd diff --git a/DESCRIPTION b/DESCRIPTION index 5d1842a..5e85d59 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -27,10 +27,10 @@ URL: https://git.auder.net/?p=agghoo.git License: MIT + file LICENSE RoxygenNote: 7.1.1 Collate: + 'compareTo.R' 'agghoo.R' 'R6_AgghooCV.R' 'R6_Model.R' - 'A_NAMESPACE.R' 'checks.R' - 'compareTo.R' 'utils.R' + 'A_NAMESPACE.R' diff --git a/NAMESPACE b/NAMESPACE index 1a5f2a0..74d8bd5 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -1,6 +1,7 @@ # Generated by roxygen2: do not edit by hand export(AgghooCV) +export(CVvoting_run) export(Model) export(agghoo) export(agghoo_run) diff --git a/R/A_NAMESPACE.R b/R/A_NAMESPACE.R index 6747346..0466833 100644 --- a/R/A_NAMESPACE.R +++ b/R/A_NAMESPACE.R @@ -1,4 +1,7 @@ +#' @include utils.R +#' @include checks.R #' @include R6_Model.R #' @include R6_AgghooCV.R #' @include agghoo.R +#' @include compareTo.R NULL diff --git a/R/compareTo.R b/R/compareTo.R index e6bf2b2..28cb711 100644 --- a/R/compareTo.R +++ b/R/compareTo.R @@ -1,3 +1,39 @@ +#' standardCV_core +#' +#' Cross-validation method, added here as an example. +#' Parameters are described in ?agghoo and ?AgghooCV +standardCV_core <- function(data, target, task, gmodel, params, loss, CV) { + n <- nrow(data) + shuffle_inds <- NULL + if (CV$type == "vfold" && CV$shuffle) + shuffle_inds <- sample(n, n) + list_testinds <- list() + for (v in seq_len(CV$V)) + list_testinds[[v]] <- get_testIndices(n, CV, v, shuffle_inds) + gmodel <- agghoo::Model$new(data, target, task, gmodel, params) + best_error <- Inf + best_p <- NULL + for (p in seq_len(gmodel$nmodels)) { + error <- Reduce('+', lapply(seq_len(CV$V), function(v) { + testIdx <- list_testinds[[v]] + d <- splitTrainTest(data, target, testIdx) + model_pred <- gmodel$get(d$dataTrain, d$targetTrain, p) + prediction <- model_pred(d$dataTest) + loss(prediction, d$targetTest) + }) ) + if (error <= best_error) { + if (error == best_error) + best_p[[length(best_p)+1]] <- p + else { + best_p <- list(p) + best_error <- error + } + } + } + chosenP <- best_p[[ sample(length(best_p), 1) ]] + list(model=gmodel$get(data, target, chosenP), param=gmodel$getParam(chosenP)) +} + #' CVvoting_core #' #' "voting" cross-validation method, added here as an example. @@ -8,8 +44,8 @@ CVvoting_core <- function(data, target, task, gmodel, params, loss, CV) { shuffle_inds <- NULL if (CV$type == "vfold" && CV$shuffle) shuffle_inds <- sample(n, n) - bestP <- rep(0, gmodel$nmodels) gmodel <- agghoo::Model$new(data, target, task, gmodel, params) + bestP <- rep(0, gmodel$nmodels) for (v in seq_len(CV$V)) { test_indices <- get_testIndices(n, CV, v, shuffle_inds) d <- splitTrainTest(data, target, test_indices) @@ -37,42 +73,6 @@ CVvoting_core <- function(data, target, task, gmodel, params, loss, CV) { list(model=gmodel$get(data, target, chosenP), param=gmodel$getParam(chosenP)) } -#' standardCV_core -#' -#' Cross-validation method, added here as an example. -#' Parameters are described in ?agghoo and ?AgghooCV -standardCV_core <- function(data, target, task, gmodel, params, loss, CV) { - n <- nrow(data) - shuffle_inds <- NULL - if (CV$type == "vfold" && CV$shuffle) - shuffle_inds <- sample(n, n) - list_testinds <- list() - for (v in seq_len(CV$V)) - list_testinds[[v]] <- get_testIndices(n, CV, v, shuffle_inds) - gmodel <- agghoo::Model$new(data, target, task, gmodel, params) - best_error <- Inf - best_p <- NULL - for (p in seq_len(gmodel$nmodels)) { - error <- Reduce('+', lapply(seq_len(CV$V), function(v) { - testIdx <- list_testinds[[v]] - d <- splitTrainTest(data, target, testIdx) - model_pred <- gmodel$get(d$dataTrain, d$targetTrain, p) - prediction <- model_pred(d$dataTest) - loss(prediction, d$targetTest) - }) ) - if (error <= best_error) { - if (error == best_error) - best_p[[length(best_p)+1]] <- p - else { - best_p <- list(p) - best_error <- error - } - } - } - chosenP <- best_p[[ sample(length(best_p), 1) ]] - list(model=gmodel$get(data, target, chosenP), param=gmodel$getParam(chosenP)) -} - #' standardCV_run #' #' Run and eval the standard cross-validation procedure. @@ -99,6 +99,32 @@ standardCV_run <- function( invisible(err) } +#' CVvoting_run +#' +#' Run and eval the voting cross-validation procedure. +#' Parameters are rather explicit except "floss", which corresponds to the +#' "final" loss function, applied to compute the error on testing dataset. +#' +#' @export +CVvoting_run <- function( + dataTrain, dataTest, targetTrain, targetTest, floss, verbose, ... +) { + args <- list(...) + task <- checkTask(args$task, targetTrain) + modPar <- checkModPar(args$gmodel, args$params) + loss <- checkLoss(args$loss, task) + CV <- checkCV(args$CV) + s <- CVvoting_core( + dataTrain, targetTrain, task, modPar$gmodel, modPar$params, loss, CV) + if (verbose) + print(paste( "Parameter:", s$param )) + p <- s$model(dataTest) + err <- floss(p, targetTest) + if (verbose) + print(paste("error CV:", err)) + invisible(err) +} + #' agghoo_run #' #' Run and eval the agghoo procedure. diff --git a/man/CVvoting_run.Rd b/man/CVvoting_run.Rd new file mode 100644 index 0000000..9aad2fe --- /dev/null +++ b/man/CVvoting_run.Rd @@ -0,0 +1,13 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/compareTo.R +\name{CVvoting_run} +\alias{CVvoting_run} +\title{CVvoting_run} +\usage{ +CVvoting_run(dataTrain, dataTest, targetTrain, targetTest, floss, verbose, ...) +} +\description{ +Run and eval the voting cross-validation procedure. +Parameters are rather explicit except "floss", which corresponds to the +"final" loss function, applied to compute the error on testing dataset. +} -- 2.44.0