#' @title R6 class with agghoo functions fit() and predict().
#'
#' @description
#' Class encapsulating the methods to run to obtain the best predictor
#' from the list of models (see 'Model' class).
#'
#' @importFrom R6 R6Class
#'
#' @export
AgghooCV <- R6::R6Class("AgghooCV",
  public = list(
    #' @description Create a new AgghooCV object.
    #' @param data Matrix or data.frame
    #' @param target Vector of targets (generally numeric or factor)
    #' @param task "regression" or "classification"
    #' @param gmodel Generic model returning a predictive function
    #' @param quality Function assessing the quality of a prediction;
    #'                quality(y1, y2) --> real number
    initialize = function(data, target, task, gmodel, quality = NULL) {
      private$data <- data
      private$target <- target
      private$task <- task
      private$gmodel <- gmodel
      if (is.null(quality)) {
        # Default quality: accuracy for classification, a bounded
        # transform of the inverse mean absolute error for regression.
        quality <- function(y1, y2) {
          # NOTE: if classif output is a probability matrix, adapt.
          if (task == "classification")
            mean(y1 == y2)
          else
            atan(1.0 / (mean(abs(y1 - y2) + 0.01))) # experimental...
        }
      }
      private$quality <- quality
    },
    #' @description Fit an agghoo model.
    #' @param CV List describing cross-validation to run. Slots:
    #'          - type: 'vfold' or 'MC' for Monte-Carlo (default: MC)
    #'          - V: number of runs (default: 10)
    #'          - test_size: percentage of data in the test dataset, for MC
    #'            (irrelevant for V-fold). Default: 0.2.
    #'          - shuffle: whether or not to shuffle data before V-fold.
    #'            Irrelevant for Monte-Carlo; default: TRUE
    #'          Missing slots are filled with their default values.
    #' @param mode "agghoo" or "standard" (for usual cross-validation)
    fit = function(
      CV = list(type = "MC",
                V = 10,
                test_size = 0.2,
                shuffle = TRUE),
      mode = "agghoo"
    ) {
      if (!is.list(CV))
        stop("CV: list of type, V, [test_size], [shuffle]")
      # Backfill missing slots so a partial list (e.g. list(V = 5))
      # does not crash on NULL dereferences below.
      if (is.null(CV$type)) CV$type <- "MC"
      if (is.null(CV$V)) CV$V <- 10
      if (is.null(CV$test_size)) CV$test_size <- 0.2
      if (is.null(CV$shuffle)) CV$shuffle <- TRUE
      n <- nrow(private$data)
      shuffle_inds <- NA
      if (CV$type == "vfold" && CV$shuffle)
        shuffle_inds <- sample(n, n)
      if (mode == "agghoo") {
        # Agghoo: keep the best (model, perf) pair found on each of the
        # V folds/runs; predictions will aggregate over all of them.
        vperfs <- list()
        for (v in seq_len(CV$V)) {
          test_indices <- private$get_testIndices(CV, v, n, shuffle_inds)
          vperf <- private$get_modelPerf(test_indices)
          vperfs[[v]] <- vperf
        }
        private$run_res <- vperfs
      } else {
        # Standard cross-validation: pick the single model index with the
        # best average performance, then refit it on the full dataset.
        best_index <- 0
        best_perf <- -1
        for (p in seq_len(private$gmodel$nmodels)) {
          tot_perf <- 0
          for (v in seq_len(CV$V)) {
            test_indices <- private$get_testIndices(CV, v, n, shuffle_inds)
            perf <- private$get_modelPerf(test_indices, p)
            tot_perf <- tot_perf + perf / CV$V
          }
          if (tot_perf > best_perf) {
            # TODO: if ex-aequos: models list + choose at random
            best_index <- p
            best_perf <- tot_perf
          }
        }
        best_model <- private$gmodel$get(private$data, private$target, best_index)
        private$run_res <- list(list(model = best_model, perf = best_perf))
      }
    },
    #' @description Predict an agghoo model (after calling fit())
    #' @param X Matrix or data.frame to predict
    #' @param weight "uniform" (default) or "quality" to weight votes or
    #'               average models performances (TODO: bad idea?!)
    predict = function(X, weight = "uniform") {
      if (!is.matrix(X) && !is.data.frame(X))
        stop("X: matrix or data.frame")
      if (!is.list(private$run_res)) {
        print("Please call $fit() method first")
        # Bare 'return' was a no-op (it evaluates the function object and
        # falls through); exit explicitly with an invisible NULL.
        return (invisible(NULL))
      }
      V <- length(private$run_res)
      if (V == 1)
        # Standard CV:
        return (private$run_res[[1]]$model(X))
      # Agghoo:
      if (weight == "uniform")
        weights <- rep(1 / V, V)
      else {
        perfs <- sapply(private$run_res, function(item) item$perf)
        perfs[perfs < 0] <- 0 # TODO: show a warning (with count of < 0...)
        total_weight <- sum(perfs) # TODO: error if total_weight == 0
        weights <- perfs / total_weight
      }
      n <- nrow(X)
      # TODO: detect if output = probs matrix for classif (in this case, adapt?)
      # "Probabilistic" agghoo prediction for a new x:
      #   argMax({ predict(m_v, x), v in 1..V }) ...
      if (private$task == "classification") {
        # One vote tally (class label --> accumulated weight) per row of X.
        votes <- as.list(rep(NA, n))
        parse_numeric <- FALSE
      } else {
        # One column of weighted predictions per fold; filled in place
        # (the previous cbind-in-a-loop silently doubled the matrix width).
        preds <- matrix(0, nrow = n, ncol = V)
      }
      for (v in seq_len(V)) {
        predictions <- private$run_res[[v]]$model(X)
        if (private$task == "regression")
          preds[, v] <- weights[v] * predictions
        else {
          # Remember whether labels were numeric, to cast back at the end.
          if (!parse_numeric && is.numeric(predictions))
            parse_numeric <- TRUE
          for (i in seq_len(n)) {
            if (!is.list(votes[[i]]))
              votes[[i]] <- list()
            index <- as.character(predictions[i])
            if (is.null(votes[[i]][[index]]))
              votes[[i]][[index]] <- 0
            votes[[i]][[index]] <- votes[[i]][[index]] + weights[v]
          }
        }
      }
      if (private$task == "regression")
        return (rowSums(preds))
      # Classification: majority (weighted) vote per row.
      res <- character(n)
      for (i in seq_len(n)) {
        # TODO: if ex-aequos, random choice...
        ind_max <- which.max(unlist(votes[[i]]))
        res[i] <- names(votes[[i]])[ind_max]
      }
      if (parse_numeric)
        res <- as.numeric(res)
      res
    }
  ),
  private = list(
    data = NULL,
    target = NULL,
    task = NULL,
    gmodel = NULL,
    quality = NULL,
    run_res = NULL,
    # Indices of the test set for fold/run number v (out of CV$V).
    # V-fold: contiguous slice of (optionally shuffled) row indices;
    # Monte-Carlo: a fresh random subset of size round(n * test_size).
    get_testIndices = function(CV, v, n, shuffle_inds) {
      if (CV$type == "vfold") {
        first_index <- round((v - 1) * n / CV$V) + 1
        last_index <- round(v * n / CV$V)
        test_indices <- first_index:last_index
        if (CV$shuffle)
          test_indices <- shuffle_inds[test_indices]
      } else {
        test_indices <- sample(n, round(n * CV$test_size))
      }
      test_indices
    },
    # Train on the complement of test_indices and evaluate on the test set.
    # p >= 1: evaluate only model index p and return its perf (standard CV);
    # p == 0 (default): loop on all models, return best list(model, perf).
    get_modelPerf = function(test_indices, p = 0) {
      getOnePerf <- function(p) {
        model_pred <- private$gmodel$get(dataHO, targetHO, p)
        prediction <- model_pred(testX)
        perf <- private$quality(prediction, testY)
        list(model = model_pred, perf = perf)
      }
      dataHO <- private$data[-test_indices, ]
      testX <- private$data[test_indices, ]
      targetHO <- private$target[-test_indices]
      testY <- private$target[test_indices]
      # R will cast 1-dim matrices into vectors:
      if (!is.matrix(dataHO) && !is.data.frame(dataHO))
        dataHO <- as.matrix(dataHO)
      if (!is.matrix(testX) && !is.data.frame(testX))
        testX <- as.matrix(testX)
      if (p >= 1)
        # Standard CV: one model at a time
        return (getOnePerf(p)$perf)
      # Agghoo: loop on all models
      best_model <- NULL
      best_perf <- -1
      for (p in seq_len(private$gmodel$nmodels)) {
        model_perf <- getOnePerf(p)
        if (model_perf$perf > best_perf) {
          # TODO: if ex-aequos: models list + choose at random
          best_model <- model_perf$model
          best_perf <- model_perf$perf
        }
      }
      list(model = best_model, perf = best_perf)
    }
  )
)