#' @title R6 class with agghoo functions fit() and predict().
#'
#' @description
#' Class encapsulating the methods to run to obtain the best predictor
#' from the list of models (see 'Model' class).
#'
#' @importFrom R6 R6Class
#'
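#' @examples
#' \dontrun{
#' # Usage sketch only: 'X', 'y' and the Model constructor arguments are
#' # hypothetical here; see the 'Model' class for the actual interface.
#' gmodel <- Model$new(X, y, "regression")
#' a <- AgghooCV$new(X, y, "regression", gmodel)
#' a$fit()
#' predictions <- a$predict(X_new)
#' }
#'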
#' @export
AgghooCV <- R6::R6Class("AgghooCV",
  public = list(
    #' @description Create a new AgghooCV object.
    #' @param data Matrix or data.frame
    #' @param target Vector of targets (generally numeric or factor)
    #' @param task "regression" or "classification"
    #' @param gmodel Generic model returning a predictive function
    #' @param loss Function assessing the error of a prediction
    initialize = function(data, target, task, gmodel, loss = NULL) {
      private$data <- data
      private$target <- target
      private$task <- task
      private$gmodel <- gmodel
      if (is.null(loss))
        loss <- private$defaultLoss
      private$loss <- loss
    },
    #' @description Fit an agghoo model.
    #' @param CV List describing the cross-validation to run. Slots:
    #'   - type: 'vfold' or 'MC' for Monte-Carlo (default: MC)
    #'   - V: number of runs (default: 10)
    #'   - test_size: proportion of data in the test dataset, for MC
    #'     (irrelevant for V-fold). Default: 0.2.
    #'   - shuffle: whether or not to shuffle data before V-fold.
    #'     Irrelevant for Monte-Carlo; default: TRUE
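    #' @examples
    #' \dontrun{
    #' # Hypothetical call on an already-built AgghooCV object 'a':
    #' a$fit(CV = list(type = "vfold", V = 5, shuffle = FALSE))
    #' }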
    fit = function(
      CV = list(type = "MC",
                V = 10,
                test_size = 0.2,
                shuffle = TRUE)
    ) {
      if (!is.list(CV))
        stop("CV: list of type, V, [test_size], [shuffle]")
      # Fill any missing slot with its documented default:
      if (is.null(CV$type)) CV$type <- "MC"
      if (is.null(CV$V)) CV$V <- 10
      if (is.null(CV$test_size)) CV$test_size <- 0.2
      if (is.null(CV$shuffle)) CV$shuffle <- TRUE
      n <- nrow(private$data)
      shuffle_inds <- NULL
      if (CV$type == "vfold" && CV$shuffle)
        shuffle_inds <- sample(n, n)
      # Result: list of V predictive models (+ parameters, for information)
      private$pmodels <- list()
      for (v in seq_len(CV$V)) {
        # Prepare train / test data and target, from the full dataset.
        # dataHO: "data Hold-Out" etc.
        test_indices <- private$get_testIndices(CV, v, n, shuffle_inds)
        dataHO <- private$data[-test_indices, ]
        testX <- private$data[test_indices, ]
        targetHO <- private$target[-test_indices]
        testY <- private$target[test_indices]
        # [HACK] R will cast 1-dim matrices into vectors:
        if (!is.matrix(dataHO) && !is.data.frame(dataHO))
          dataHO <- as.matrix(dataHO)
        if (!is.matrix(testX) && !is.data.frame(testX))
          testX <- as.matrix(testX)
        best_model <- NULL
        best_error <- Inf
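        # Agghoo selection step: among all candidate parameters, keep those
        # reaching the minimal hold-out error on this split (ties are kept,
        # and one is drawn at random below).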
        for (p in seq_len(private$gmodel$nmodels)) {
          model_pred <- private$gmodel$get(dataHO, targetHO, p)
          prediction <- model_pred(testX)
          error <- private$loss(prediction, testY)
          if (error <= best_error) {
            newModel <- list(model = model_pred,
                             param = private$gmodel$getParam(p))
            if (error == best_error)
              best_model[[length(best_model) + 1]] <- newModel
            else {
              best_model <- list(newModel)
              best_error <- error
            }
          }
        }
        # Choose a model at random in case of ties
        private$pmodels[[v]] <- best_model[[ sample(length(best_model), 1) ]]
      }
    },
    #' @description Predict with the agghoo model (after calling fit())
    #' @param X Matrix or data.frame to predict
    predict = function(X) {
      if (!is.matrix(X) && !is.data.frame(X))
        stop("X: matrix or data.frame")
      if (!is.list(private$pmodels)) {
        message("Please call the $fit() method first")
        return (invisible(NULL))
      }
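      # Heuristic: a model returning >= 2 values on a single row is assumed
      # to output class probabilities ("soft" classification), in which case
      # the V probability matrices are averaged.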
      V <- length(private$pmodels)
      if (length(private$pmodels[[1]]$model(X[1, , drop = FALSE])) >= 2)
        # Soft classification:
        return (Reduce("+", lapply(private$pmodels, function(m) m$model(X))) / V)
      n <- nrow(X)
      all_predictions <- as.data.frame(matrix(nrow = n, ncol = V))
      for (v in seq_len(V))
        all_predictions[, v] <- private$pmodels[[v]]$model(X)
      if (private$task == "regression")
        # Easy case: just average each row
        return (rowMeans(all_predictions))
102 | # "Hard" classification: | |
103 | apply(all_predictions, 1, function(row) { | |
104 | t <- table(row) | |
105 | # Next lines in case of ties (broken at random) | |
106 | tmax <- max(t) | |
107 | sample( names(t)[which(t == tmax)], 1 ) | |
108 | }) | |
    },
    #' @description Return the list of V best parameters (after calling fit())
    getParams = function() {
      lapply(private$pmodels, function(m) m$param)
    }
  ),
  private = list(
    data = NULL,
    target = NULL,
    task = NULL,
    gmodel = NULL,
    loss = NULL,
    pmodels = NULL,
    get_testIndices = function(CV, v, n, shuffle_inds) {
      if (CV$type == "vfold") {
        # Slice indices (optionally shuffled)
        first_index <- round((v - 1) * n / CV$V) + 1
        last_index <- round(v * n / CV$V)
        test_indices <- first_index:last_index
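        # Note: the round() boundaries make the V consecutive slices
        # partition 1..n exactly, even when n is not a multiple of V.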
        if (!is.null(shuffle_inds))
          test_indices <- shuffle_inds[test_indices]
      }
      else
        # Monte-Carlo cross-validation
        test_indices <- sample(n, round(n * CV$test_size))
      test_indices
    },
    defaultLoss = function(y1, y2) {
      if (private$task == "classification") {
        if (is.null(dim(y1)))
          # Standard case: "hard" classification
          mean(y1 != y2)
        else {
          # "Soft" classification: predict() outputs a probability matrix.
          # In this case "target" could be in matrix form.
          if (!is.null(dim(y2)))
            mean(rowSums(abs(y1 - y2)))
          else {
            # Or not: y2 is a "factor".
            y2 <- as.character(y2)
            # NOTE: the user should provide the target in matrix form,
            # because matching y2 with columns is rather inefficient!
            names <- colnames(y1)
            positions <- list()
            for (idx in seq_along(names))
              positions[[ names[idx] ]] <- idx
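            # Mean L1 distance between each predicted probability row and
            # the one-hot encoding of the true class (mirrors the
            # matrix-form branch above):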
            mean(vapply(
              seq_along(y2),
              function(idx) {
                onehot <- rep(0, ncol(y1))
                onehot[ positions[[ y2[idx] ]] ] <- 1
                sum(abs(y1[idx, ] - onehot))
              },
              0))
          }
        }
      }
      else
        # Regression
        mean(abs(y1 - y2))
    }
  )
)