Update in progress - unfinished
[agghoo.git] / R / agghoo.R
CommitLineData
c5946158
BA
#' agghoo
#'
#' Run the (core) agghoo procedure.
#' Arguments specify the list of models, their parameters and the
#' cross-validation settings, among others.
#'
#' @param data Data frame or matrix containing the data in lines
#'   (one observation per row).
#' @param target The target values to predict. Generally a vector with one
#'   entry per row of \code{data}, but possibly a matrix (with as many rows
#'   as \code{data} and more than one column) in the case of
#'   "soft classification".
#' @param task "classification" or "regression". Default:
#'   regression if target is numerical, classification otherwise.
#' @param gmodel A "generic model", which is a function returning a predict
#'   function (taking X as only argument) from the tuple
#'   (dataHO, targetHO, param), where 'HO' stands for 'Hold-Out',
#'   referring to cross-validation. Cross-validation is run on an array
#'   of 'param's. See params argument. A character string among
#'   "knn", "ppr", "rf" and "tree" is also accepted.
#'   Default: see R6::Model.
#' @param params A list of parameters. Often, one list cell is just a
#'   numerical value, but in general it could be of any type.
#'   A numeric or character vector is converted with \code{as.list}.
#'   Default: see R6::Model.
#' @param loss A function assessing the error of a prediction.
#'   Arguments are y1 and y2 (comparing a prediction to known values).
#'   loss(y1, y2) --> real number (error). Default: see R6::AgghooCV.
#'
#' @return
#' An R6::AgghooCV object o. Then, call o$fit() and finally o$predict(newData)
#'
#' @examples
#' # Regression:
#' a_reg <- agghoo(iris[,-c(2,5)], iris[,2])
#' a_reg$fit()
#' pr <- a_reg$predict(iris[,-c(2,5)] + rnorm(450, sd=0.1))
#' # Classification
#' a_cla <- agghoo(iris[,-5], iris[,5])
#' a_cla$fit()
#' pc <- a_cla$predict(iris[,-5] + rnorm(600, sd=0.1))
#'
#' @seealso Function \code{\link{compareTo}}
#'
#' @references
#' Guillaume Maillard, Sylvain Arlot, Matthieu Lerasle. "Aggregated hold-out".
#' Journal of Machine Learning Research 22(20):1--55, 2021.
#'
#' @export
agghoo <- function(data, target, task = NULL, gmodel = NULL,
                   params = NULL, loss = NULL) {
  # Args check:
  if (!is.data.frame(data) && !is.matrix(data)) {
    stop("data: data.frame or matrix")
  }
  if (is.data.frame(target) || is.matrix(target)) {
    # Two-dimensional target: needs one row per observation and at least
    # two columns (a single column should be passed as a plain vector).
    if (nrow(target) != nrow(data) || ncol(target) == 1) {
      stop("target probability matrix does not match data size")
    }
  } else if (!is.numeric(target) && !is.factor(target) &&
             !is.character(target)) {
    stop("target: numeric, factor or character vector")
  } else if (length(target) != nrow(data)) {
    # Same size requirement as the matrix case, so that a mismatch is
    # reported here rather than somewhere inside the model fitting code.
    stop("target: length does not match the number of rows of data")
  }
  if (!is.null(task)) {
    task <- match.arg(task, c("classification", "regression"))
  }
  if (is.character(gmodel)) {
    gmodel <- match.arg(gmodel, c("knn", "ppr", "rf", "tree"))
  } else if (!is.null(gmodel) && !is.function(gmodel)) {
    # No further checks here: fingers crossed :)
    stop("gmodel: function(dataHO, targetHO, param) --> function(X) --> y")
  }
  # Atomic parameter vectors become a list with one cell per value.
  if (is.numeric(params) || is.character(params)) {
    params <- as.list(params)
  }
  if (!is.list(params) && !is.null(params)) {
    stop("params: numerical, character, or list (passed to model)")
  }
  if (is.function(gmodel) && !is.list(params)) {
    stop("params must be provided when using a custom model")
  }
  if (is.list(params) && is.null(gmodel)) {
    stop("model (or family) must be provided when using custom params")
  }
  if (!is.null(loss) && !is.function(loss)) {
    # No more checks here as well... TODO:?
    stop("loss: function(y1, y2) --> Real")
  }

  # Infer the task from the target type when it was not given.
  if (is.null(task)) {
    if (is.numeric(target)) {
      task <- "regression"
    } else {
      task <- "classification"
    }
  }
  # Build Model object (= list of parameterized models)
  model <- Model$new(data, target, task, gmodel, params)
  # Return AgghooCV object, to run and predict
  AgghooCV$new(data, target, task, model, loss)
}