48ac741f2ec8caf8648ac16ecb8a84fd3d5b2c5a
[agghoo.git] / R / agghoo.R
#' agghoo
#'
#' Entry point of the (core) agghoo procedure: checks the arguments, builds
#' the family of parameterized models, and returns the object used to fit
#' and predict. The arguments describe the models, their parameters and the
#' cross-validation settings, among others.
#'
#' @param data Data frame or matrix holding the data, one observation per
#'   line (row).
#' @param target Values to predict. Usually a vector, but possibly a matrix
#'   in the case of "soft classification".
#' @param task Either "classification" or "regression". Default:
#'   regression when target is numerical, classification otherwise.
#' @param gmodel A "generic model": a function which, from the tuple
#'   (dataHO, targetHO, param), returns a predict function taking X as its
#'   only argument. 'HO' stands for 'Hold-Out', referring to
#'   cross-validation, which is run on an array of 'param's (see the params
#'   argument). Default: see the R6 class Model.
#' @param params List of parameters. A list cell is often just a numerical
#'   value, but in general it can be of any type.
#'   Default: see the R6 class Model.
#' @param loss Function assessing the error of a prediction:
#'   loss(y1, y2) --> real number (error), where y1 and y2 are a prediction
#'   and the known values it is compared to.
#'   Default: see the R6 class AgghooCV.
#'
#' @return
#' An R6 AgghooCV object o. Call o$fit() first, then o$predict(newData).
#'
#' @examples
#' # Regression:
#' a_reg <- agghoo(iris[, -c(2, 5)], iris[, 2])
#' a_reg$fit()
#' pr <- a_reg$predict(iris[, -c(2, 5)] + rnorm(450, sd = 0.1))
#' # Classification
#' a_cla <- agghoo(iris[, -5], iris[, 5])
#' a_cla$fit()
#' pc <- a_cla$predict(iris[, -5] + rnorm(600, sd = 0.1))
#'
#' @seealso Function \code{\link{compareTo}}
#'
#' @references
#' Guillaume Maillard, Sylvain Arlot, Matthieu Lerasle. "Aggregated hold-out".
#' Journal of Machine Learning Research 22(20):1--55, 2021.
#'
#' @export
agghoo <- function(data, target, task = NULL, gmodel = NULL, params = NULL,
                   loss = NULL) {
  # Argument checks. The values returned by the check helpers are the ones
  # used from here on, in place of the raw task / gmodel / params / loss.
  checkDaTa(data, target)
  resolved_task <- checkTask(task, target)
  model_spec <- checkModPar(gmodel, params)
  resolved_loss <- checkLoss(loss, resolved_task)

  # Model object (= list of parameterized models).
  candidate_models <- Model$new(
    data, target, resolved_task, model_spec$gmodel, model_spec$params
  )

  # AgghooCV object, on which the caller runs fit and predict.
  AgghooCV$new(data, target, resolved_task, candidate_models, resolved_loss)
}