Fix typo
[agghoo.git] / R / R6_Agghoo.R
#' @title R6 class with agghoo functions fit() and predict().
#'
#' @description
#' Class encapsulating the methods to run to obtain the best predictor
#' from the list of models (see 'Model' class).
#'
#' @export
Agghoo <- R6::R6Class("Agghoo",
  public = list(
    #' @description Create a new Agghoo object.
    #' @param data Matrix or data.frame
    #' @param target Vector of targets (generally numeric or factor)
    #' @param task "regression" or "classification"
    #' @param gmodel Generic model returning a predictive function
    #' @param quality Function assessing the quality of a prediction;
    #'                quality(y1, y2) --> real number. Leave as NA (or NULL)
    #'                to use the built-in measure: accuracy for
    #'                classification, an (experimental) decreasing function
    #'                of the mean absolute error otherwise.
    initialize = function(data, target, task, gmodel, quality = NA) {
      private$data <- data
      private$target <- target
      private$task <- task
      private$gmodel <- gmodel
      # Test with is.function() rather than is.na(): is.na() emits a warning
      # when it is handed a closure.
      if (!is.function(quality)) {
        is_unset <- is.null(quality) ||
          (length(quality) == 1 && is.na(quality))
        if (!is_unset) {
          stop("quality: function(y1, y2) --> real number")
        }
        quality <- function(y1, y2) {
          # NOTE: if classif output is a probability matrix, adapt.
          if (task == "classification") {
            mean(y1 == y2)
          } else {
            atan(1.0 / (mean(abs(y1 - y2) + 0.01))) # experimental...
          }
        }
      }
      private$quality <- quality
    },

    #' @description Fit an agghoo model.
    #' @param CV List describing cross-validation to run. Slots:
    #'          - type: 'vfold' or 'MC' for Monte-Carlo (default: MC)
    #'          - V: number of runs (default: 10)
    #'          - test_size: percentage of data in the test dataset, for MC
    #'            (irrelevant for V-fold). Default: 0.2.
    #'          - shuffle: whether or not to shuffle data before V-fold.
    #'            Irrelevant for Monte-Carlo; default: TRUE
    #'          Missing slots are filled in with their default value.
    #' @param mode "agghoo" or "standard" (for usual cross-validation)
    fit = function(
      CV = list(type = "MC",
                V = 10,
                test_size = 0.2,
                shuffle = TRUE),
      mode = "agghoo"
    ) {
      if (!is.list(CV)) {
        stop("CV: list of type, V, [test_size], [shuffle]")
      }
      # Complete a partially specified CV list with the documented defaults,
      # so that e.g. list(type = "vfold", V = 5) does not fail on CV$shuffle.
      cv_defaults <- list(type = "MC", V = 10, test_size = 0.2, shuffle = TRUE)
      for (slot in names(cv_defaults)) {
        if (is.null(CV[[slot]])) {
          CV[[slot]] <- cv_defaults[[slot]]
        }
      }
      n <- nrow(private$data)
      shuffle_inds <- NA
      if (CV$type == "vfold" && CV$shuffle) {
        shuffle_inds <- sample(n, n)
      }
      if (mode == "agghoo") {
        # One (best model, performance) pair per cross-validation run.
        vperfs <- vector("list", CV$V)
        for (v in seq_len(CV$V)) {
          test_indices <- private$get_testIndices(CV, v, n, shuffle_inds)
          vperfs[[v]] <- private$get_modelPerf(test_indices)
        }
        private$run_res <- vperfs
      } else {
        # Standard cross-validation: keep the single model index with the
        # best average performance, then refit it on the whole dataset.
        # NOTE: with type "MC", get_testIndices() draws a fresh random split
        # on every call, so the models are not scored on identical splits.
        best_index <- 0
        # -Inf (not -1) so that a quality function returning values <= -1
        # still selects a model.
        best_perf <- -Inf
        for (p in seq_len(private$gmodel$nmodels)) {
          tot_perf <- 0
          for (v in seq_len(CV$V)) {
            test_indices <- private$get_testIndices(CV, v, n, shuffle_inds)
            perf <- private$get_modelPerf(test_indices, p)
            tot_perf <- tot_perf + perf / CV$V
          }
          if (tot_perf > best_perf) {
            # TODO: if ex-aequos: models list + choose at random
            best_index <- p
            best_perf <- tot_perf
          }
        }
        best_model <- private$gmodel$get(
          private$data, private$target, best_index
        )
        private$run_res <- list(list(model = best_model, perf = best_perf))
      }
    },

    #' @description Predict an agghoo model (after calling fit())
    #' @param X Matrix or data.frame to predict
    #' @param weight "uniform" (default) or "quality" to weight votes or
    #'               average models performances (TODO: bad idea?!)
    #' @return The predictions for the rows of X; NULL (invisibly), after
    #'         printing a message, if fit() has not been called yet.
    predict = function(X, weight = "uniform") {
      # run_res is NA until fit() stores a list in it. Do not call is.na() on
      # the fitted list: it yields one value per run, which `||` rejects.
      if (!is.list(private$run_res)) {
        print("Please call $fit() method first")
        return(invisible(NULL))
      }
      V <- length(private$run_res)
      if (V == 1) {
        # Standard CV:
        return(private$run_res[[1]]$model(X))
      }
      # Agghoo:
      if (weight == "uniform") {
        weights <- rep(1 / V, V)
      } else {
        perfs <- vapply(private$run_res, function(item) item$perf, numeric(1))
        perfs[perfs < 0] <- 0 # TODO: show a warning (with count of < 0...)
        total_weight <- sum(perfs)
        if (!is.finite(total_weight) || total_weight <= 0) {
          stop("predict: quality weights are undefined ",
               "(performances sum to zero or are not finite)")
        }
        weights <- perfs / total_weight
      }
      n <- nrow(X)
      # TODO: detect if output = probs matrix for classif (in this case, adapt?)
      # "probabilistic" agghoo prediction for a new x:
      # argMax({ predict(m_v, x), v in 1..V }) ...
      if (private$task == "regression") {
        # Column v holds the weighted predictions of run v; the aggregated
        # prediction is the row-wise sum.
        preds <- matrix(0, nrow = n, ncol = V)
        for (v in seq_len(V)) {
          preds[, v] <- weights[v] * private$run_res[[v]]$model(X)
        }
        return(rowSums(preds))
      }
      # Classification: weighted vote. votes[[i]] maps each predicted label
      # (as a string) to the total weight of the runs that predicted it.
      votes <- lapply(seq_len(n), function(i) list())
      parse_numeric <- FALSE
      for (v in seq_len(V)) {
        predictions <- private$run_res[[v]]$model(X)
        # Remember that labels were numeric, to convert back at the end.
        if (!parse_numeric && is.numeric(predictions)) {
          parse_numeric <- TRUE
        }
        labels <- as.character(predictions)
        for (i in seq_len(n)) {
          label <- labels[i]
          so_far <- votes[[i]][[label]]
          if (is.null(so_far)) {
            so_far <- 0
          }
          votes[[i]][[label]] <- so_far + weights[v]
        }
      }
      # TODO: if ex-aequos, random choice...
      # (which.max() keeps the first label that reached the maximum.)
      res <- vapply(votes, function(ballot) {
        tally <- unlist(ballot)
        names(tally)[which.max(tally)]
      }, character(1))
      if (parse_numeric) {
        res <- as.numeric(res)
      }
      res
    }
  ),
  private = list(
    data = NA,
    target = NA,
    task = NA,
    gmodel = NA,
    quality = NA,
    # NA until fit() is called; afterwards a list of list(model=, perf=).
    run_res = NA,

    # Row indices of the test set for run v (out of CV$V) on n observations.
    # V-fold: the v-th contiguous slice of 1..n, mapped through shuffle_inds
    # when CV$shuffle is set. Otherwise (Monte-Carlo): a random sample of
    # round(n * CV$test_size) indices, redrawn on every call.
    get_testIndices = function(CV, v, n, shuffle_inds) {
      if (CV$type == "vfold") {
        first_index <- round((v - 1) * n / CV$V) + 1
        last_index <- round(v * n / CV$V)
        test_indices <- first_index:last_index
        if (CV$shuffle) {
          test_indices <- shuffle_inds[test_indices]
        }
      } else {
        test_indices <- sample(n, round(n * CV$test_size))
      }
      test_indices
    },

    # Train on the rows outside test_indices and score on test_indices.
    # p >= 1: return the performance (a number) of model p alone.
    # p = 0 (default): try every model and return list(model=, perf=) for the
    # one with the highest performance.
    get_modelPerf = function(test_indices, p = 0) {
      getOnePerf <- function(p) {
        model_pred <- private$gmodel$get(dataHO, targetHO, p)
        prediction <- model_pred(testX)
        perf <- private$quality(prediction, testY)
        list(model = model_pred, perf = perf)
      }
      # drop = FALSE: keep a matrix / data.frame even with a single column
      # or a single test row.
      dataHO <- private$data[-test_indices, , drop = FALSE]
      testX <- private$data[test_indices, , drop = FALSE]
      targetHO <- private$target[-test_indices]
      testY <- private$target[test_indices]
      if (p >= 1) {
        # Standard CV: one model at a time
        return(getOnePerf(p)$perf)
      }
      # Agghoo: loop on all models
      best_model <- NULL
      best_perf <- -Inf
      for (p in seq_len(private$gmodel$nmodels)) {
        model_perf <- getOnePerf(p)
        if (model_perf$perf > best_perf) {
          # TODO: if ex-aequos: models list + choose at random
          best_model <- model_perf$model
          best_perf <- model_perf$perf
        }
      }
      list(model = best_model, perf = best_perf)
    }
  )
)