Fix regression
[agghoo.git] / R / R6_AgghooCV.R
CommitLineData
c5946158
BA
#' @title R6 class with agghoo functions fit() and predict().
#'
#' @description
#' Class encapsulating the methods to run to obtain the best predictor
#' from the list of models (see 'Model' class).
#'
#' @importFrom R6 R6Class
#'
#' @export
AgghooCV <- R6::R6Class("AgghooCV",
  public = list(
    #' @description Create a new AgghooCV object.
    #' @param data Matrix or data.frame
    #' @param target Vector of targets (generally numeric or factor), or a
    #'   matrix with one row per observation (soft classification)
    #' @param task "regression" or "classification"
    #' @param gmodel Generic model returning a predictive function
    #' @param loss Function assessing the error of a prediction
    #'   (default: private$defaultLoss)
    initialize = function(data, target, task, gmodel, loss = NULL) {
      private$data <- data
      private$target <- target
      private$task <- task
      private$gmodel <- gmodel
      if (is.null(loss))
        loss <- private$defaultLoss
      private$loss <- loss
    },
    #' @description Fit an agghoo model.
    #' @param CV List describing cross-validation to run. Slots:
    #'   - type: 'vfold' or 'MC' for Monte-Carlo (default: MC)
    #'   - V: number of runs (default: 10)
    #'   - test_size: percentage of data in the test dataset, for MC
    #'     (irrelevant for V-fold). Default: 0.2.
    #'   - shuffle: whether or not to shuffle data before V-fold.
    #'     Irrelevant for Monte-Carlo; default: TRUE
    #'   Slots missing from CV take their default value.
    fit = function(
      CV = list(type = "MC",
                V = 10,
                test_size = 0.2,
                shuffle = TRUE)
    ) {
      if (!is.list(CV))
        stop("CV: list of type, V, [test_size], [shuffle]")
      # Fill omitted slots: otherwise e.g. list(type = "vfold", V = 5)
      # fails on `TRUE && NULL` below instead of using the documented default.
      defaults <- list(type = "MC", V = 10, test_size = 0.2, shuffle = TRUE)
      for (slot in names(defaults)) {
        if (is.null(CV[[slot]]))
          CV[[slot]] <- defaults[[slot]]
      }
      n <- nrow(private$data)
      # With more folds than rows, get_testIndices would build reversed
      # (and overlapping) index ranges such as 2:1.
      if (CV$type == "vfold" && CV$V > n)
        stop("CV: V cannot exceed the number of rows for V-fold")
      shuffle_inds <- NULL
      if (CV$type == "vfold" && CV$shuffle)
        shuffle_inds <- sample(n, n)
      # Result: list of V predictive models (+ parameters for info)
      private$pmodels <- list()
      for (v in seq_len(CV$V)) {
        # Prepare train / test data and target, from full dataset.
        # dataHO: "data Hold-Out" etc.
        test_indices <- private$get_testIndices(CV, v, n, shuffle_inds)
        # drop = FALSE keeps the 2-D shape even with a single column or a
        # single test row (plain `[i, ]` would collapse it to a vector).
        dataHO <- private$data[-test_indices, , drop = FALSE]
        testX <- private$data[test_indices, , drop = FALSE]
        if (is.null(dim(private$target))) {
          targetHO <- private$target[-test_indices]
          testY <- private$target[test_indices]
        } else {
          # Matrix target (one row per observation): slice rows, not cells.
          targetHO <- private$target[-test_indices, , drop = FALSE]
          testY <- private$target[test_indices, , drop = FALSE]
        }
        best_model <- list()
        best_error <- Inf
        for (p in seq_len(private$gmodel$nmodels)) {
          model_pred <- private$gmodel$get(dataHO, targetHO, p)
          prediction <- model_pred(testX)
          error <- private$loss(prediction, testY)
          if (error <= best_error) {
            newModel <- list(model = model_pred,
                             param = private$gmodel$getParam(p))
            if (error == best_error)
              best_model[[length(best_model) + 1]] <- newModel
            else {
              best_model <- list(newModel)
              best_error <- error
            }
          }
        }
        if (length(best_model) == 0)
          stop("No model selected: check gmodel$nmodels and the loss values")
        # Choose a model at random in case of ex-aequos
        private$pmodels[[v]] <- best_model[[ sample(length(best_model), 1) ]]
      }
    },
    #' @description Predict an agghoo model (after calling fit())
    #' @param X Matrix or data.frame to predict
    #' @return Averaged prediction matrix when the models output several
    #'   values per row (soft classification), row means for regression,
    #'   or the per-row majority label (character, ties broken at random)
    #'   for hard classification.
    predict = function(X) {
      if (!is.matrix(X) && !is.data.frame(X))
        stop("X: matrix or data.frame")
      if (!is.list(private$pmodels) || length(private$pmodels) == 0) {
        print("Please call $fit() method first")
        return (invisible(NULL))
      }
      V <- length(private$pmodels)
      # Probe the output shape on one row, kept 2-D: X[1, ] would drop a
      # matrix row to a length-ncol(X) vector, which models may reject or
      # answer with several values (mistaken for a probability row).
      if (length(private$pmodels[[1]]$model(X[1, , drop = FALSE])) >= 2)
        # Soft classification: average the probability matrices
        return (Reduce("+", lapply(private$pmodels, function(m) m$model(X))) / V)
      n <- nrow(X)
      all_predictions <- as.data.frame(matrix(nrow = n, ncol = V))
      for (v in seq_len(V))
        all_predictions[, v] <- private$pmodels[[v]]$model(X)
      if (private$task == "regression")
        # Easy case: just average each row
        return (rowMeans(all_predictions))
      # "Hard" classification: majority vote on each row
      apply(all_predictions, 1, function(row) {
        t <- table(row)
        # Next lines in case of ties (broken at random)
        tmax <- max(t)
        sample( names(t)[which(t == tmax)], 1 )
      })
    },
    #' @description Return the list of V best parameters (after calling fit())
    getParams = function() {
      lapply(private$pmodels, function(m) m$param)
    }
  ),
  private = list(
    data = NULL,
    target = NULL,
    task = NULL,
    gmodel = NULL,
    loss = NULL,
    pmodels = NULL,
    # Indices of the test rows for run v (out of n rows).
    get_testIndices = function(CV, v, n, shuffle_inds) {
      if (CV$type == "vfold") {
        # Slice indices (optionally shuffled)
        first_index <- round((v - 1) * n / CV$V) + 1
        last_index <- round(v * n / CV$V)
        test_indices <- first_index:last_index
        if (!is.null(shuffle_inds))
          test_indices <- shuffle_inds[test_indices]
      }
      else
        # Monte-Carlo cross-validation. Keep at least one test row: an empty
        # index vector would make `x[-test_indices, ]` select no rows at all.
        test_indices <- sample(n, max(1, round(n * CV$test_size)))
      test_indices
    },
    # Default loss: mean absolute error (regression), misclassification rate
    # (hard classification) or mean L1 distance between rows of the
    # predicted probability matrix and the target (soft classification).
    defaultLoss = function(y1, y2) {
      if (private$task == "classification") {
        if (is.null(dim(y1)))
          # Standard case: "hard" classification
          mean(y1 != y2)
        else {
          # "Soft" classification: predict() outputs a probability matrix
          # In this case "target" could be in matrix form.
          if (is.null(dim(y2))) {
            # Or not: y2 is a "factor". Convert it to a one-hot matrix whose
            # columns follow colnames(y1), then use the same distance.
            # NOTE: the user should provide target in matrix form to avoid
            # this conversion.
            y2 <- outer(as.character(y2), colnames(y1), "==") * 1
          }
          mean(rowSums(abs(y1 - y2)))
        }
      }
      else
        # Regression
        mean(abs(y1 - y2))
    }
  )
)