X-Git-Url: https://git.auder.net/?p=agghoo.git;a=blobdiff_plain;f=agghoo.Rcheck%2F00_pkg_src%2Fagghoo%2FR%2Fagghoo.R;fp=agghoo.Rcheck%2F00_pkg_src%2Fagghoo%2FR%2Fagghoo.R;h=48ac741f2ec8caf8648ac16ecb8a84fd3d5b2c5a;hp=0000000000000000000000000000000000000000;hb=97f16440280a40a49c4898a75942e374880bfca3;hpb=3b8affec63125c3816a3d15f0f49776dc14867a2 diff --git a/agghoo.Rcheck/00_pkg_src/agghoo/R/agghoo.R b/agghoo.Rcheck/00_pkg_src/agghoo/R/agghoo.R new file mode 100644 index 0000000..48ac741 --- /dev/null +++ b/agghoo.Rcheck/00_pkg_src/agghoo/R/agghoo.R @@ -0,0 +1,58 @@ +#' agghoo +#' +#' Run the (core) agghoo procedure. +#' Arguments specify the list of models, their parameters and the +#' cross-validation settings, among others. +#' +#' @param data Data frame or matrix containing the data in lines. +#' @param target The target values to predict. Generally a vector, +#' but possibly a matrix in the case of "soft classification". +#' @param task "classification" or "regression". Default: +#' regression if target is numerical, classification otherwise. +#' @param gmodel A "generic model", which is a function returning a predict +#' function (taking X as only argument) from the tuple +#' (dataHO, targetHO, param), where 'HO' stands for 'Hold-Out', +#' referring to cross-validation. Cross-validation is run on an array +#' of 'param's. See params argument. Default: see R6::Model. +#' @param params A list of parameters. Often, one list cell is just a +#' numerical value, but in general it could be of any type. +#' Default: see R6::Model. +#' @param loss A function assessing the error of a prediction. +#' Arguments are y1 and y2 (comparing a prediction to known values). +#' loss(y1, y2) --> real number (error). Default: see R6::AgghooCV. +#' +#' @return +#' An R6::AgghooCV object o. Then, call o$fit() and finally o$predict(newData) +#' +#' @examples +#' # Regression: +#' a_reg <- agghoo(iris[,-c(2,5)], iris[,2]) +#' a_reg$fit() +#' pr <- a_reg$predict(iris[,-c(2,5)] + rnorm(450, sd=0.1)) +#' # Classification +#' a_cla <- agghoo(iris[,-5], iris[,5]) +#' a_cla$fit() +#' pc <- a_cla$predict(iris[,-5] + rnorm(600, sd=0.1)) +#' +#' @seealso Function \code{\link{compareTo}} +#' +#' @references +#' Guillaume Maillard, Sylvain Arlot, Matthieu Lerasle. "Aggregated hold-out". +#' Journal of Machine Learning Research 22(20):1--55, 2021. +#' +#' @export +agghoo <- function( + data, target, task = NULL, gmodel = NULL, params = NULL, loss = NULL +) { + # Args check: + checkDaTa(data, target) + task <- checkTask(task, target) + modPar <- checkModPar(gmodel, params) + loss <- checkLoss(loss, task) + + # Build Model object (= list of parameterized models) + model <- Model$new(data, target, task, modPar$gmodel, modPar$params) + + # Return AgghooCV object, to run and predict + AgghooCV$new(data, target, task, model, loss) +}