X-Git-Url: https://git.auder.net/?a=blobdiff_plain;f=R%2Fagghoo.R;h=cac2cf11dd1897c356d610d894affe959e2358ab;hb=504afaadc783916dc126fb87ab9e067f302eb2c5;hp=92d061f0a13790fc8ec5460b6619f4f60b63a51b;hpb=cca5f1c67bd622fb7bc1279dfe4c3336d1446efd;p=agghoo.git diff --git a/R/agghoo.R b/R/agghoo.R index 92d061f..cac2cf1 100644 --- a/R/agghoo.R +++ b/R/agghoo.R @@ -1,9 +1,12 @@ #' agghoo #' -#' Run the agghoo procedure. (...) +#' Run the agghoo procedure (or standard cross-validation). +#' Arguments specify the list of models, their parameters and the +#' cross-validation settings, among others. #' #' @param data Data frame or matrix containing the data in lines. -#' @param target The target values to predict. Generally a vector. +#' @param target The target values to predict. Generally a vector, +#' but possibly a matrix in the case of "soft classification". #' @param task "classification" or "regression". Default: #' regression if target is numerical, classification otherwise. #' @param gmodel A "generic model", which is a function returning a predict @@ -14,11 +17,12 @@ #' @param params A list of parameters. Often, one list cell is just a #' numerical value, but in general it could be of any type. #' Default: see R6::Model. -#' @param quality A function assessing the quality of a prediction. +#' @param loss A function assessing the error of a prediction. #' Arguments are y1 and y2 (comparing a prediction to known values). -#' Default: see R6::AgghooCV. +#' loss(y1, y2) --> real number (error). Default: see R6::AgghooCV. #' -#' @return An R6::AgghooCV object. +#' @return +#' An R6::AgghooCV object o. Then, call o$fit() and finally o$predict(newData) #' #' @examples #' # Regression: @@ -27,38 +31,44 @@ #' pr <- a_reg$predict(iris[,-c(2,5)] + rnorm(450, sd=0.1)) #' # Classification #' a_cla <- agghoo(iris[,-5], iris[,5]) -#' a_cla$fit(mode="standard") +#' a_cla$fit() #' pc <- a_cla$predict(iris[,-5] + rnorm(600, sd=0.1)) #' +#' @references +#' Guillaume Maillard, Sylvain Arlot, Matthieu Lerasle. "Aggregated hold-out". +#' Journal of Machine Learning Research 22(20):1--55, 2021. +#' #' @export -agghoo <- function(data, target, task = NA, gmodel = NA, params = NA, quality = NA) { +agghoo <- function(data, target, task = NULL, gmodel = NULL, params = NULL, loss = NULL) { # Args check: if (!is.data.frame(data) && !is.matrix(data)) stop("data: data.frame or matrix") - if (nrow(data) <= 1 || any(dim(data) == 0)) - stop("data: non-empty, >= 2 rows") - if (!is.numeric(target) && !is.factor(target) && !is.character(target)) + if (is.data.frame(target) || is.matrix(target)) { + if (nrow(target) != nrow(data) || ncol(target) == 1) + stop("target probability matrix does not match data size") + } + else if (!is.numeric(target) && !is.factor(target) && !is.character(target)) stop("target: numeric, factor or character vector") - if (!is.na(task)) + if (!is.null(task)) task = match.arg(task, c("classification", "regression")) if (is.character(gmodel)) - gmodel <- match.arg("knn", "ppr", "rf") - else if (!is.na(gmodel) && !is.function(gmodel)) + gmodel <- match.arg(gmodel, c("knn", "ppr", "rf", "tree")) + else if (!is.null(gmodel) && !is.function(gmodel)) # No further checks here: fingers crossed :) stop("gmodel: function(dataHO, targetHO, param) --> function(X) --> y") if (is.numeric(params) || is.character(params)) params <- as.list(params) - if (!is.na(params) && !is.list(params)) + if (!is.list(params) && !is.null(params)) stop("params: numerical, character, or list (passed to model)") - if (!is.na(gmodel) && !is.character(gmodel) && is.na(params)) + if (is.function(gmodel) && !is.list(params)) stop("params must be provided when using a custom model") - if (is.na(gmodel) && !is.na(params)) - stop("model must be provided when using custom params") - if (!is.na(quality) && !is.function(quality)) + if (is.list(params) && is.null(gmodel)) + stop("model (or family) must be provided when using custom params") + if (!is.null(loss) && !is.function(loss)) # No more checks here as well... TODO:? - stop("quality: function(y1, y2) --> Real") + stop("loss: function(y1, y2) --> Real") - if (is.na(task)) { + if (is.null(task)) { if (is.numeric(target)) task = "regression" else @@ -67,34 +77,5 @@ agghoo <- function(data, target, task = NA, gmodel = NA, params = NA, quality = # Build Model object (= list of parameterized models) model <- Model$new(data, target, task, gmodel, params) # Return AgghooCV object, to run and predict - AgghooCV$new(data, target, task, model, quality) -} - -#' compareToStandard -#' -#' Temporary function to compare agghoo to CV -#' (TODO: extended, in another file, more tests - when faster code). -#' -#' @export -compareToStandard <- function(df, t_idx, task = NA, rseed = -1) { - if (rseed >= 0) - set.seed(rseed) - if (is.na(task)) - task <- ifelse(is.numeric(df[,t_idx]), "regression", "classification") - n <- nrow(df) - test_indices <- sample( n, round(n / ifelse(n >= 500, 10, 5)) ) - a <- agghoo(df[-test_indices,-t_idx], df[-test_indices,t_idx], task) - a$fit(mode="agghoo") #default mode - pa <- a$predict(df[test_indices,-t_idx]) - print(paste("error agghoo", - ifelse(task == "classification", - mean(p != df[test_indices,t_idx]), - mean(abs(pa - df[test_indices,t_idx]))))) - # Compare with standard cross-validation: - a$fit(mode="standard") - ps <- a$predict(df[test_indices,-t_idx]) - print(paste("error CV", - ifelse(task == "classification", - mean(ps != df[test_indices,t_idx]), - mean(abs(ps - df[test_indices,t_idx]))))) + AgghooCV$new(data, target, task, model, loss) }