X-Git-Url: https://git.auder.net/?a=blobdiff_plain;f=R%2Fagghoo.R;h=48ac741f2ec8caf8648ac16ecb8a84fd3d5b2c5a;hb=HEAD;hp=cac2cf11dd1897c356d610d894affe959e2358ab;hpb=504afaadc783916dc126fb87ab9e067f302eb2c5;p=agghoo.git diff --git a/R/agghoo.R b/R/agghoo.R index cac2cf1..48ac741 100644 --- a/R/agghoo.R +++ b/R/agghoo.R @@ -1,6 +1,6 @@ #' agghoo #' -#' Run the agghoo procedure (or standard cross-validation). +#' Run the (core) agghoo procedure. #' Arguments specify the list of models, their parameters and the #' cross-validation settings, among others. #' @@ -34,48 +34,25 @@ #' a_cla$fit() #' pc <- a_cla$predict(iris[,-5] + rnorm(600, sd=0.1)) #' +#' @seealso Function \code{\link{compareTo}} +#' #' @references #' Guillaume Maillard, Sylvain Arlot, Matthieu Lerasle. "Aggregated hold-out". #' Journal of Machine Learning Research 22(20):1--55, 2021. #' #' @export -agghoo <- function(data, target, task = NULL, gmodel = NULL, params = NULL, loss = NULL) { +agghoo <- function( + data, target, task = NULL, gmodel = NULL, params = NULL, loss = NULL +) { # Args check: - if (!is.data.frame(data) && !is.matrix(data)) - stop("data: data.frame or matrix") - if (is.data.frame(target) || is.matrix(target)) { - if (nrow(target) != nrow(data) || ncol(target) == 1) - stop("target probability matrix does not match data size") - } - else if (!is.numeric(target) && !is.factor(target) && !is.character(target)) - stop("target: numeric, factor or character vector") - if (!is.null(task)) - task = match.arg(task, c("classification", "regression")) - if (is.character(gmodel)) - gmodel <- match.arg(gmodel, c("knn", "ppr", "rf", "tree")) - else if (!is.null(gmodel) && !is.function(gmodel)) - # No further checks here: fingers crossed :) - stop("gmodel: function(dataHO, targetHO, param) --> function(X) --> y") - if (is.numeric(params) || is.character(params)) - params <- as.list(params) - if (!is.list(params) && !is.null(params)) - stop("params: numerical, character, or list (passed to model)") - if (is.function(gmodel) && !is.list(params)) - stop("params must be provided when using a custom model") - if (is.list(params) && is.null(gmodel)) - stop("model (or family) must be provided when using custom params") - if (!is.null(loss) && !is.function(loss)) - # No more checks here as well... TODO:? - stop("loss: function(y1, y2) --> Real") + checkDaTa(data, target) + task <- checkTask(task, target) + modPar <- checkModPar(gmodel, params) + loss <- checkLoss(loss, task) - if (is.null(task)) { - if (is.numeric(target)) - task = "regression" - else - task = "classification" - } # Build Model object (= list of parameterized models) - model <- Model$new(data, target, task, gmodel, params) + model <- Model$new(data, target, task, modPar$gmodel, modPar$params) + # Return AgghooCV object, to run and predict AgghooCV$new(data, target, task, model, loss) }