Update in progress - unfinished
[agghoo.git] / R / R6_AgghooCV.R
CommitLineData
c5946158
BA
#' @title R6 class with agghoo functions fit() and predict().
#'
#' @description
#' Class encapsulating the methods to run to obtain the best predictor
#' from the list of models (see 'Model' class).
#'
#' @importFrom R6 R6Class
#'
#' @export
AgghooCV <- R6::R6Class("AgghooCV",
  public = list(
    #' @description Create a new AgghooCV object.
    #' @param data Matrix or data.frame
    #' @param target Vector of targets (generally numeric or factor)
    #' @param task "regression" or "classification"
    #' @param gmodel Generic model returning a predictive function
    #' @param loss Function assessing the error of a prediction
    initialize = function(data, target, task, gmodel, loss = NULL) {
      private$data <- data
      private$target <- target
      private$task <- task
      private$gmodel <- gmodel
      if (is.null(loss))
        loss <- private$defaultLoss
      private$loss <- loss
    },
    #' @description Fit an agghoo model.
    #' @param CV List describing cross-validation to run. Slots: \cr
    #'   - type: 'vfold' or 'MC' for Monte-Carlo (default: MC) \cr
    #'   - V: number of runs (default: 10) \cr
    #'   - test_size: percentage of data in the test dataset, for MC
    #'     (irrelevant for V-fold). Default: 0.2. \cr
    #'   - shuffle: whether or not to shuffle data before V-fold.
    #'     Irrelevant for Monte-Carlo; default: TRUE
    fit = function(
      CV = list(type = "MC",
                V = 10,
                test_size = 0.2,
                shuffle = TRUE)
    ) {
      if (!is.list(CV))
        stop("CV: list of type, V, [test_size], [shuffle]")
      # Fill in any missing slot with its default, so a partial list such as
      # list(type = "vfold") does not crash on NULL lookups below.
      CV <- modifyList(
        list(type = "MC", V = 10, test_size = 0.2, shuffle = TRUE),
        CV)
      n <- nrow(private$data)
      shuffle_inds <- NULL
      if (CV$type == "vfold" && CV$shuffle)
        shuffle_inds <- sample(n, n)
      # Result: list of V predictive models (+ parameters for info)
      private$pmodels <- list()
      for (v in seq_len(CV$V)) {
        # Prepare train / test data and target, from full dataset.
        # dataHO: "data Hold-Out" etc.
        test_indices <- private$get_testIndices(CV, v, n, shuffle_inds)
        # drop = FALSE keeps the matrix / data.frame shape even when the
        # subset has a single row or column (R would otherwise return a
        # vector, with wrong orientation for one-row multi-column data).
        dataHO <- private$data[-test_indices, , drop = FALSE]
        testX <- private$data[test_indices, , drop = FALSE]
        targetHO <- private$target[-test_indices]
        testY <- private$target[test_indices]
        best_model <- NULL
        best_error <- Inf
        for (p in seq_len(private$gmodel$nmodels)) {
          model_pred <- private$gmodel$get(dataHO, targetHO, p)
          prediction <- model_pred(testX)
          error <- private$loss(prediction, testY)
          if (error <= best_error) {
            newModel <- list(model = model_pred,
                             param = private$gmodel$getParam(p))
            if (error == best_error)
              # Ex-aequo: keep all candidates for random tie-breaking below
              best_model[[length(best_model) + 1]] <- newModel
            else {
              best_model <- list(newModel)
              best_error <- error
            }
          }
        }
        # Choose a model at random in case of ex-aequos
        private$pmodels[[v]] <- best_model[[ sample(length(best_model), 1) ]]
      }
    },
    #' @description Predict an agghoo model (after calling fit())
    #' @param X Matrix or data.frame to predict
    predict = function(X) {
      if (!is.matrix(X) && !is.data.frame(X))
        stop("X: matrix or data.frame")
      if (!is.list(private$pmodels)) {
        print("Please call $fit() method first")
        return (invisible(NULL))
      }
      V <- length(private$pmodels)
      # Probe with one row (kept 2D via drop = FALSE) to detect whether the
      # models output a single label or a vector of class probabilities.
      oneLineX <- X[1, , drop = FALSE]
      if (length(private$pmodels[[1]]$model(oneLineX)) >= 2)
        # Soft classification: average the V probability matrices
        return (Reduce("+", lapply(private$pmodels, function(m) m$model(X))) / V)
      n <- nrow(X)
      all_predictions <- as.data.frame(matrix(nrow = n, ncol = V))
      for (v in seq_len(V))
        all_predictions[, v] <- private$pmodels[[v]]$model(X)
      if (private$task == "regression")
        # Easy case: just average each row
        return (rowMeans(all_predictions))
      # "Hard" classification: majority vote per row
      apply(all_predictions, 1, function(row) {
        t <- table(row)
        # Next lines in case of ties (broken at random)
        tmax <- max(t)
        sample( names(t)[which(t == tmax)], 1 )
      })
    },
    #' @description Return the list of V best parameters (after calling fit())
    getParams = function() {
      lapply(private$pmodels, function(m) m$param)
    }
  ),
  private = list(
    data = NULL,
    target = NULL,
    task = NULL,
    gmodel = NULL,
    loss = NULL,
    pmodels = NULL,
    # Compute the test indices for run v out of n rows, following CV$type.
    get_testIndices = function(CV, v, n, shuffle_inds) {
      if (CV$type == "vfold") {
        # Slice indices (optionally shuffled)
        first_index <- round((v - 1) * n / CV$V) + 1
        last_index <- round(v * n / CV$V)
        test_indices <- first_index:last_index
        if (!is.null(shuffle_inds))
          test_indices <- shuffle_inds[test_indices]
      }
      else
        # Monte-Carlo cross-validation
        test_indices <- sample(n, round(n * CV$test_size))
      test_indices
    },
    # Default loss: misclassification rate / L1 distance, depending on the
    # task and on the shape of the prediction y1 and the target y2.
    defaultLoss = function(y1, y2) {
      if (private$task == "classification") {
        if (is.null(dim(y1)))
          # Standard case: "hard" classification
          mean(y1 != y2)
        else {
          # "Soft" classification: predict() outputs a probability matrix
          # In this case "target" could be in matrix form.
          if (!is.null(dim(y2)))
            mean(rowSums(abs(y1 - y2)))
          else {
            # Or not: y2 is a "factor".
            y2 <- as.character(y2)
            # NOTE: the user should provide target in matrix form because
            # matching y2 with columns is rather inefficient!
            names <- colnames(y1)
            positions <- list()
            for (idx in seq_along(names))
              positions[[ names[idx] ]] <- idx
            # L1 distance between the predicted probability row and the
            # one-hot encoding of the true class, so that this branch agrees
            # with the matrix-target branch above. (The previous code
            # subtracted the integer column index from the probabilities.)
            mean(vapply(
              seq_along(y2),
              function(idx) {
                onehot <- numeric(ncol(y1))
                onehot[ positions[[ y2[idx] ]] ] <- 1
                sum(abs(y1[idx, ] - onehot))
              },
              0))
          }
        }
      }
      else
        # Regression
        mean(abs(y1 - y2))
    }
  )
)