X-Git-Url: https://git.auder.net/?p=agghoo.git;a=blobdiff_plain;f=R%2Fagghoo.R;h=cac2cf11dd1897c356d610d894affe959e2358ab;hp=f3bc74093823dc66d96eaf05299ed72c59ba8876;hb=504afaadc783916dc126fb87ab9e067f302eb2c5;hpb=15f48abea9c631d52317ff157c8af0dd4c7a67d3 diff --git a/R/agghoo.R b/R/agghoo.R index f3bc740..cac2cf1 100644 --- a/R/agghoo.R +++ b/R/agghoo.R @@ -1,9 +1,12 @@ #' agghoo #' -#' Run the agghoo procedure. (...) +#' Run the agghoo procedure (or standard cross-validation). +#' Arguments specify the list of models, their parameters and the +#' cross-validation settings, among others. #' #' @param data Data frame or matrix containing the data in lines. -#' @param target The target values to predict. Generally a vector. +#' @param target The target values to predict. Generally a vector, +#' but possibly a matrix in the case of "soft classification". #' @param task "classification" or "regression". Default: #' regression if target is numerical, classification otherwise. #' @param gmodel A "generic model", which is a function returning a predict @@ -14,11 +17,12 @@ #' @param params A list of parameters. Often, one list cell is just a #' numerical value, but in general it could be of any type. #' Default: see R6::Model. -#' @param quality A function assessing the quality of a prediction. +#' @param loss A function assessing the error of a prediction. #' Arguments are y1 and y2 (comparing a prediction to known values). -#' Default: see R6::AgghooCV. +#' loss(y1, y2) --> real number (error). Default: see R6::AgghooCV. #' -#' @return An R6::AgghooCV object. +#' @return +#' An R6::AgghooCV object o. Then, call o$fit() and finally o$predict(newData) #' #' @examples #' # Regression: @@ -27,15 +31,23 @@ #' pr <- a_reg$predict(iris[,-c(2,5)] + rnorm(450, sd=0.1)) #' # Classification #' a_cla <- agghoo(iris[,-5], iris[,5]) -#' a_cla$fit(mode="standard") +#' a_cla$fit() #' pc <- a_cla$predict(iris[,-5] + rnorm(600, sd=0.1)) #' +#' @references +#' Guillaume Maillard, Sylvain Arlot, Matthieu Lerasle. "Aggregated hold-out". +#' Journal of Machine Learning Research 22(20):1--55, 2021. +#' #' @export -agghoo <- function(data, target, task = NULL, gmodel = NULL, params = NULL, quality = NULL) { +agghoo <- function(data, target, task = NULL, gmodel = NULL, params = NULL, loss = NULL) { # Args check: if (!is.data.frame(data) && !is.matrix(data)) stop("data: data.frame or matrix") - if (!is.numeric(target) && !is.factor(target) && !is.character(target)) + if (is.data.frame(target) || is.matrix(target)) { + if (nrow(target) != nrow(data) || ncol(target) == 1) + stop("target probability matrix does not match data size") + } + else if (!is.numeric(target) && !is.factor(target) && !is.character(target)) stop("target: numeric, factor or character vector") if (!is.null(task)) task = match.arg(task, c("classification", "regression")) @@ -52,9 +64,9 @@ agghoo <- function(data, target, task = NULL, gmodel = NULL, params = NULL, qual stop("params must be provided when using a custom model") if (is.list(params) && is.null(gmodel)) stop("model (or family) must be provided when using custom params") - if (!is.null(quality) && !is.function(quality)) + if (!is.null(loss) && !is.function(loss)) # No more checks here as well... TODO:? - stop("quality: function(y1, y2) --> Real") + stop("loss: function(y1, y2) --> Real") if (is.null(task)) { if (is.numeric(target)) @@ -65,34 +77,5 @@ agghoo <- function(data, target, task = NULL, gmodel = NULL, params = NULL, qual # Build Model object (= list of parameterized models) model <- Model$new(data, target, task, gmodel, params) # Return AgghooCV object, to run and predict - AgghooCV$new(data, target, task, model, quality) -} - -#' compareToStandard -#' -#' Temporary function to compare agghoo to CV -#' (TODO: extended, in another file, more tests - when faster code). -#' -#' @export -compareToStandard <- function(df, t_idx, task = NULL, rseed = -1) { - if (rseed >= 0) - set.seed(rseed) - if (is.null(task)) - task <- ifelse(is.numeric(df[,t_idx]), "regression", "classification") - n <- nrow(df) - test_indices <- sample( n, round(n / ifelse(n >= 500, 10, 5)) ) - a <- agghoo(df[-test_indices,-t_idx], df[-test_indices,t_idx], task) - a$fit(mode="agghoo") #default mode - pa <- a$predict(df[test_indices,-t_idx]) - print(paste("error agghoo", - ifelse(task == "classification", - mean(p != df[test_indices,t_idx]), - mean(abs(pa - df[test_indices,t_idx]))))) - # Compare with standard cross-validation: - a$fit(mode="standard") - ps <- a$predict(df[test_indices,-t_idx]) - print(paste("error CV", - ifelse(task == "classification", - mean(ps != df[test_indices,t_idx]), - mean(abs(ps - df[test_indices,t_idx]))))) + AgghooCV$new(data, target, task, model, loss) }