92d061f0a13790fc8ec5460b6619f4f60b63a51b
[agghoo.git] / R / agghoo.R
#' agghoo
#'
#' Run the agghoo procedure: validate the inputs, build a Model object (the
#' list of parameterized models) and wrap it in an AgghooCV object, which is
#' then trained with \code{$fit()} and used through \code{$predict()}.
#'
#' @param data Data frame or matrix containing the data in rows
#'   (at least 2 rows, non-empty).
#' @param target The target values to predict: a numeric, factor or character
#'   vector (generally), with one value per row of \code{data}.
#' @param task "classification" or "regression". Default:
#'   regression if target is numerical, classification otherwise.
#' @param gmodel A "generic model", which is a function returning a predict
#'   function (taking X as only argument) from the tuple
#'   (dataHO, targetHO, param), where 'HO' stands for 'Hold-Out',
#'   referring to cross-validation. Cross-validation is run on an array
#'   of 'param's. See params argument. Alternatively, the name of a built-in
#'   model: one of "knn", "ppr" or "rf" (partial matching allowed).
#'   Default: see R6::Model.
#' @param params A list of parameters. Often, one list cell is just a
#'   numerical value, but in general it could be of any type.
#'   Numeric or character vectors are converted with \code{as.list}.
#'   Required when \code{gmodel} is a custom function, and only allowed when
#'   \code{gmodel} is given. Default: see R6::Model.
#' @param quality A function assessing the quality of a prediction.
#'   Arguments are y1 and y2 (comparing a prediction to known values).
#'   Default: see R6::AgghooCV.
#'
#' @return An R6::AgghooCV object.
#'
#' @examples
#' # Regression:
#' a_reg <- agghoo(iris[,-c(2,5)], iris[,2])
#' a_reg$fit()
#' pr <- a_reg$predict(iris[,-c(2,5)] + rnorm(450, sd=0.1))
#' # Classification
#' a_cla <- agghoo(iris[,-5], iris[,5])
#' a_cla$fit(mode="standard")
#' pc <- a_cla$predict(iris[,-5] + rnorm(600, sd=0.1))
#'
#' @export
agghoo <- function(data, target, task = NA, gmodel = NA, params = NA,
                   quality = NA) {
  # NA is the "not supplied" sentinel for the optional arguments. A plain
  # is.na() cannot be used on them: it warns on functions and returns one
  # value per element on lists, which `&&` rejects (an error since R 4.3).
  is_unset <- function(x) is.atomic(x) && length(x) == 1L && is.na(x)

  # Args check:
  if (!is.data.frame(data) && !is.matrix(data)) {
    stop("data: data.frame or matrix")
  }
  if (nrow(data) <= 1 || any(dim(data) == 0)) {
    stop("data: non-empty, >= 2 rows")
  }
  if (!is.numeric(target) && !is.factor(target) && !is.character(target)) {
    stop("target: numeric, factor or character vector")
  }
  # Mismatched lengths would otherwise be silently recycled or misaligned.
  if (NROW(target) != nrow(data)) {
    stop("target: one value per row of data")
  }
  if (!is_unset(task)) {
    task <- match.arg(task, c("classification", "regression"))
  }
  if (is.character(gmodel)) {
    gmodel <- match.arg(gmodel, c("knn", "ppr", "rf"))
  } else if (!is_unset(gmodel) && !is.function(gmodel)) {
    # No further checks here: fingers crossed :)
    stop("gmodel: function(dataHO, targetHO, param) --> function(X) --> y")
  }
  if (is.numeric(params) || is.character(params)) {
    params <- as.list(params)
  }
  if (!is_unset(params) && !is.list(params)) {
    stop("params: numerical, character, or list (passed to model)")
  }
  if (!is_unset(gmodel) && !is.character(gmodel) && is_unset(params)) {
    stop("params must be provided when using a custom model")
  }
  if (is_unset(gmodel) && !is_unset(params)) {
    stop("model must be provided when using custom params")
  }
  if (!is_unset(quality) && !is.function(quality)) {
    # No more checks here as well... TODO:?
    stop("quality: function(y1, y2) --> Real")
  }

  if (is_unset(task)) {
    task <- if (is.numeric(target)) "regression" else "classification"
  }
  # Build Model object (= list of parameterized models)
  model <- Model$new(data, target, task, gmodel, params)
  # Return AgghooCV object, to run and predict
  AgghooCV$new(data, target, task, model, quality)
}
72
#' compareToStandard
#'
#' Temporary function to compare agghoo to standard cross-validation
#' (TODO: extend it, in another file, with more tests - once the code is
#' faster).
#'
#' A random test set (10\% of the rows when there are at least 500 of them,
#' 20\% otherwise, and at least one row) is held out. An agghoo object is
#' fitted on the remaining rows, first in "agghoo" mode and then in "standard"
#' mode, and the test error of each is printed: misclassification rate for
#' classification, mean absolute error for regression.
#'
#' @param df Data frame (or matrix) containing both features and target.
#' @param t_idx Column index of the target in \code{df}.
#' @param task "classification" or "regression". Default: regression if the
#'   target column is numeric, classification otherwise.
#' @param rseed Random seed, applied only when non-negative.
#'   Default: -1 (no seed set).
#'
#' @return The last printed line (standard CV error), as returned by print.
#'
#' @export
compareToStandard <- function(df, t_idx, task = NA, rseed = -1) {
  if (rseed >= 0) {
    set.seed(rseed)
  }
  target <- df[, t_idx]
  if (is.na(task)) {
    task <- if (is.numeric(target)) "regression" else "classification"
  }
  n <- nrow(df)
  # Keep at least one test row: with an empty index vector, df[-idx, ]
  # would select no row at all instead of every row.
  n_test <- max(1L, round(n / if (n >= 500) 10 else 5))
  test_indices <- sample(n, n_test)
  x_test <- df[test_indices, -t_idx]
  y_test <- target[test_indices]

  # Test error of a prediction vector. A plain if/else (instead of ifelse())
  # evaluates only the branch matching the task, so factor targets never
  # reach the arithmetic of the regression branch.
  test_error <- function(pred) {
    if (task == "classification") {
      mean(pred != y_test)
    } else {
      mean(abs(pred - y_test))
    }
  }

  a <- agghoo(df[-test_indices, -t_idx], target[-test_indices], task)
  a$fit(mode = "agghoo") # default mode
  print(paste("error agghoo", test_error(a$predict(x_test))))
  # Compare with standard cross-validation:
  a$fit(mode = "standard")
  print(paste("error CV", test_error(a$predict(x_test))))
}