Adjustments / fixes... And add knn for regression
[agghoo.git] / R / agghoo.R
CommitLineData
c5946158
BA
#' agghoo
#'
#' Run the agghoo procedure. (...)
#'
#' The returned object is trained with \code{$fit()} (mode "agghoo" by
#' default, or "standard" for plain cross-validation) and then used to
#' predict new data with \code{$predict()}.
#'
#' @param data Data frame or matrix containing the data, one observation
#'   per row.
#' @param target The target values to predict: a numeric, factor or
#'   character vector (generally), with one value per row of \code{data}.
#' @param task "classification" or "regression" (partial matching allowed).
#'   Default: regression if target is numerical, classification otherwise.
#' @param gmodel A "generic model": either the name of a built-in model
#'   ("knn", "ppr", "rf" or "tree"; partial matching allowed), or a
#'   function returning a predict function (taking X as only argument)
#'   from the tuple (dataHO, targetHO, param), where 'HO' stands for
#'   'Hold-Out', referring to cross-validation. Cross-validation is run on
#'   an array of 'param's. See params argument. Default: see R6::Model.
#' @param params A list of parameters. A numeric or character vector is
#'   also accepted and converted with \code{as.list} (one cell per value).
#'   Often, one list cell is just a numerical value, but in general it could
#'   be of any type. Must be given when gmodel is a custom function, and
#'   requires gmodel to be given. Default: see R6::Model.
#' @param quality A function assessing the quality of a prediction.
#'   Arguments are y1 and y2 (comparing a prediction to known values).
#'   Default: see R6::AgghooCV.
#'
#' @return An R6::AgghooCV object.
#'
#' @examples
#' # Regression:
#' a_reg <- agghoo(iris[,-c(2,5)], iris[,2])
#' a_reg$fit()
#' pr <- a_reg$predict(iris[,-c(2,5)] + rnorm(450, sd=0.1))
#' # Classification
#' a_cla <- agghoo(iris[,-5], iris[,5])
#' a_cla$fit(mode="standard")
#' pc <- a_cla$predict(iris[,-5] + rnorm(600, sd=0.1))
#'
#' @export
agghoo <- function(data, target, task = NULL, gmodel = NULL,
                   params = NULL, quality = NULL) {
  # Args check:
  if (!is.data.frame(data) && !is.matrix(data)) {
    stop("data: data.frame or matrix")
  }
  if (!is.numeric(target) && !is.factor(target) && !is.character(target)) {
    stop("target: numeric, factor or character vector")
  }
  # One target value per observation. NROW() also handles a matrix target
  # (rows compared), so only a genuine size mismatch is rejected here,
  # before it can surface as an obscure failure during model fitting.
  if (NROW(target) != nrow(data)) {
    stop("target: length must equal nrow(data)")
  }
  if (!is.null(task)) {
    task <- match.arg(task, c("classification", "regression"))
  }
  if (is.character(gmodel)) {
    gmodel <- match.arg(gmodel, c("knn", "ppr", "rf", "tree"))
  } else if (!is.null(gmodel) && !is.function(gmodel)) {
    # No further checks here: fingers crossed :)
    stop("gmodel: function(dataHO, targetHO, param) --> function(X) --> y")
  }
  # A vector of candidate values becomes a list with one cell per value.
  if (is.numeric(params) || is.character(params)) {
    params <- as.list(params)
  }
  if (!is.list(params) && !is.null(params)) {
    stop("params: numerical, character, or list (passed to model)")
  }
  if (is.function(gmodel) && !is.list(params)) {
    stop("params must be provided when using a custom model")
  }
  if (is.list(params) && is.null(gmodel)) {
    stop("model (or family) must be provided when using custom params")
  }
  if (!is.null(quality) && !is.function(quality)) {
    # No more checks here as well... TODO:?
    stop("quality: function(y1, y2) --> Real")
  }

  if (is.null(task)) {
    task <- if (is.numeric(target)) "regression" else "classification"
  }
  # Build Model object (= list of parameterized models)
  model <- Model$new(data, target, task, gmodel, params)
  # Return AgghooCV object, to run and predict
  AgghooCV$new(data, target, task, model, quality)
}
70
#' compareToStandard
#'
#' Temporary function to compare agghoo to standard cross-validation
#' (TODO: extend, in another file, with more tests - once the code is faster).
#'
#' A random test set (a tenth of the rows when n >= 500, a fifth otherwise,
#' and at least one row) is held out. agghoo is trained on the remaining rows,
#' fitted first in "agghoo" mode and then in "standard" mode, and the test
#' error of each fit is printed: misclassification rate for classification,
#' mean absolute error for regression.
#'
#' @param df Data frame holding both the predictors and the target.
#' @param t_idx Column index of the target in \code{df}.
#' @param task "classification" or "regression" (partial matching allowed).
#'   Default: regression if the target column is numeric, classification
#'   otherwise.
#' @param rseed Random seed: \code{set.seed(rseed)} is called only when
#'   rseed >= 0 (default -1: no seeding).
#'
#' @return Invisibly, the printed "error CV" string (value of \code{print}).
#'
#' @export
compareToStandard <- function(df, t_idx, task = NULL, rseed = -1) {
  if (rseed >= 0) {
    set.seed(rseed)
  }
  target <- df[, t_idx]
  if (is.null(task)) {
    task <- if (is.numeric(target)) "regression" else "classification"
  } else {
    # Normalize exactly like agghoo() does, so the error metric below matches
    # the task agghoo actually trains for (e.g. task = "class").
    task <- match.arg(task, c("classification", "regression"))
  }
  # Test error: misclassification rate or mean absolute error.
  test_error <- function(pred, truth) {
    if (task == "classification") {
      mean(pred != truth)
    } else {
      mean(abs(pred - truth))
    }
  }
  n <- nrow(df)
  # Hold out at least one row: indexing with -integer(0) would select no
  # rows at all, leaving an empty training set.
  n_test <- max(1, round(n / if (n >= 500) 10 else 5))
  test_indices <- sample(n, n_test)
  # drop = FALSE keeps a two-column df from collapsing to a plain vector,
  # which agghoo() rejects ("data: data.frame or matrix").
  x_train <- df[-test_indices, -t_idx, drop = FALSE]
  x_test <- df[test_indices, -t_idx, drop = FALSE]
  y_train <- target[-test_indices]
  y_test <- target[test_indices]
  a <- agghoo(x_train, y_train, task = task)
  a$fit(mode = "agghoo") # default mode
  pa <- a$predict(x_test)
  print(paste("error agghoo", test_error(pa, y_test)))
  # Compare with standard cross-validation:
  a$fit(mode = "standard")
  ps <- a$predict(x_test)
  print(paste("error CV", test_error(ps, y_test)))
}