Commit | Line | Data |
---|---|---|
c5946158 BA |
#' @title R6 class with agghoo functions fit() and predict().
#'
#' @description
#' Class encapsulating the methods to run to obtain the best predictor
#' from the list of models (see 'Model' class).
#'
#' @importFrom R6 R6Class
#'
#' @export
AgghooCV <- R6::R6Class("AgghooCV",
  public = list(
    #' @description Create a new AgghooCV object.
    #' @param data Matrix or data.frame
    #' @param target Vector of targets (generally numeric or factor)
    #' @param task "regression" or "classification"
    #' @param gmodel Generic model returning a predictive function
    #' @param loss Function assessing the error of a prediction
    initialize = function(data, target, task, gmodel, loss) {
      private$data <- data
      private$target <- target
      private$task <- task
      private$gmodel <- gmodel
      private$loss <- loss
    },
    #' @description Fit an agghoo model.
    #' @param CV List describing cross-validation to run. Slots: \cr
    #'   - type: 'vfold' or 'MC' for Monte-Carlo (default: MC) \cr
    #'   - V: number of runs (default: 10) \cr
    #'   - test_size: percentage of data in the test dataset, for MC
    #'     (irrelevant for V-fold). Default: 0.2. \cr
    #'   - shuffle: whether or not to shuffle data before V-fold.
    #'     Irrelevant for Monte-Carlo; default: TRUE \cr
    #' Default (if NULL): type="MC", V=10, test_size=0.2
    fit = function(CV = NULL) {
      # Normalize / fill in CV defaults (package-internal helper).
      CV <- checkCV(CV)
      n <- nrow(private$data)
      shuffle_inds <- NULL
      if (CV$type == "vfold" && CV$shuffle)
        shuffle_inds <- sample(n, n)
      # Result: list of V predictive models (+ parameters for info)
      private$pmodels <- list()
      for (v in seq_len(CV$V)) {
        # Prepare train / test data and target, from full dataset.
        # dataHO: "data Hold-Out" etc.
        test_indices <- get_testIndices(n, CV, v, shuffle_inds)
        d <- splitTrainTest(private$data, private$target, test_indices)
        # Keep ALL models reaching the lowest hold-out error seen so far,
        # so that exact ties can be broken at random below.
        best_model <- NULL
        best_error <- Inf
        for (p in seq_len(private$gmodel$nmodels)) {
          model_pred <- private$gmodel$get(d$dataTrain, d$targetTrain, p)
          prediction <- model_pred(d$dataTest)
          error <- private$loss(prediction, d$targetTest)
          if (error <= best_error) {
            newModel <- list(model = model_pred,
                             param = private$gmodel$getParam(p))
            if (error == best_error)
              # Ex-aequo: append to the current list of best models
              best_model[[length(best_model) + 1]] <- newModel
            else {
              # Strictly better: restart the candidate list from here
              best_model <- list(newModel)
              best_error <- error
            }
          }
        }
        # Choose a model at random in case of ex-aequos
        private$pmodels[[v]] <- best_model[[sample(length(best_model), 1)]]
      }
    },
    #' @description Predict an agghoo model (after calling fit())
    #' @param X Matrix or data.frame to predict
    predict = function(X) {
      if (!is.matrix(X) && !is.data.frame(X))
        stop("X: matrix or data.frame")
      if (!is.list(private$pmodels)) {
        # fit() stores a list in private$pmodels; anything else means
        # the model was never fitted. Signal status, don't print().
        message("Please call $fit() method first")
        return (invisible(NULL))
      }
      V <- length(private$pmodels)
      # Probe with a single row to detect soft-classification output.
      # drop = FALSE keeps a 1-row matrix / 1-row data.frame uniformly,
      # avoiding R's dimension-dropping on matrix row subsetting.
      oneLineX <- X[1, , drop = FALSE]
      if (length(private$pmodels[[1]]$model(oneLineX)) >= 2)
        # Soft classification: average the per-class probability outputs
        return (Reduce("+", lapply(private$pmodels, function(m) m$model(X))) / V)
      n <- nrow(X)
      all_predictions <- as.data.frame(matrix(nrow = n, ncol = V))
      for (v in seq_len(V))
        all_predictions[, v] <- private$pmodels[[v]]$model(X)
      if (private$task == "regression")
        # Easy case: just average each row
        return (rowMeans(all_predictions))
      # "Hard" classification: majority vote across the V predictions
      apply(all_predictions, 1, function(row) {
        t <- table(row)
        # Next lines in case of ties (broken at random)
        tmax <- max(t)
        sample(names(t)[which(t == tmax)], 1)
      })
    },
    #' @description Return the list of V best parameters (after calling fit())
    getParams = function() {
      lapply(private$pmodels, function(m) m$param)
    }
  ),
  private = list(
    data = NULL,
    target = NULL,
    task = NULL,
    gmodel = NULL,
    loss = NULL,
    pmodels = NULL
  )
)