# R/R6_AgghooCV.R (agghoo package)
# NOTE: reorganization unfinished — some functions are not exported yet.
#' @title R6 class with agghoo functions fit() and predict().
#'
#' @description
#' Class encapsulating the methods to run to obtain the best predictor
#' from the list of models (see 'Model' class).
#'
#' @importFrom R6 R6Class
#'
#' @export
AgghooCV <- R6::R6Class("AgghooCV",
  public = list(
    #' @description Create a new AgghooCV object.
    #' @param data Matrix or data.frame
    #' @param target Vector of targets (generally numeric or factor)
    #' @param task "regression" or "classification"
    #' @param gmodel Generic model returning a predictive function
    #' @param loss Function assessing the error of a prediction
    initialize = function(data, target, task, gmodel, loss) {
      private$data <- data
      private$target <- target
      private$task <- task
      private$gmodel <- gmodel
      private$loss <- loss
    },
    #' @description Fit an agghoo model.
    #' @param CV List describing cross-validation to run. Slots: \cr
    #'          - type: 'vfold' or 'MC' for Monte-Carlo (default: MC) \cr
    #'          - V: number of runs (default: 10) \cr
    #'          - test_size: percentage of data in the test dataset, for MC
    #'            (irrelevant for V-fold). Default: 0.2. \cr
    #'          - shuffle: whether or not to shuffle data before V-fold.
    #'            Irrelevant for Monte-Carlo; default: TRUE \cr
    #'        Default (if NULL): type="MC", V=10, test_size=0.2
    fit = function(CV = NULL) {
      # Normalize / fill in defaults for the CV description (project helper).
      CV <- checkCV(CV)
      n <- nrow(private$data)
      shuffle_inds <- NULL
      if (CV$type == "vfold" && CV$shuffle)
        shuffle_inds <- sample(n, n)
      # Result: list of V predictive models (+ parameters for info)
      private$pmodels <- list()
      for (v in seq_len(CV$V)) {
        # Prepare train / test data and target, from full dataset.
        # dataHO: "data Hold-Out" etc.
        test_indices <- get_testIndices(n, CV, v, shuffle_inds)
        d <- splitTrainTest(private$data, private$target, test_indices)
        # Track the set of models achieving the (current) lowest hold-out
        # error; ties are kept so one can be drawn at random afterwards.
        best_model <- NULL
        best_error <- Inf
        for (p in seq_len(private$gmodel$nmodels)) {
          model_pred <- private$gmodel$get(d$dataTrain, d$targetTrain, p)
          prediction <- model_pred(d$dataTest)
          error <- private$loss(prediction, d$targetTest)
          if (error <= best_error) {
            newModel <- list(model=model_pred, param=private$gmodel$getParam(p))
            if (error == best_error)
              # Ex-aequo with current best: append to the candidates list
              best_model[[length(best_model)+1]] <- newModel
            else {
              # Strictly better: restart the candidates list
              best_model <- list(newModel)
              best_error <- error
            }
          }
        }
        # Choose a model at random in case of ex-aequos
        private$pmodels[[v]] <- best_model[[ sample(length(best_model), 1) ]]
      }
    },
    #' @description Predict an agghoo model (after calling fit())
    #' @param X Matrix or data.frame to predict
    predict = function(X) {
      if (!is.matrix(X) && !is.data.frame(X))
        stop("X: matrix or data.frame")
      if (!is.list(private$pmodels)) {
        # fit() has not been run yet: soft failure, return NULL invisibly
        print("Please call $fit() method first")
        return (invisible(NULL))
      }
      V <- length(private$pmodels)
      # Probe one row to detect the prediction type. drop = FALSE keeps the
      # original container (1-row matrix or 1-row data.frame) uniformly; in
      # particular a single-column data.frame no longer collapses to a vector.
      oneLineX <- X[1, , drop = FALSE]
      if (length(private$pmodels[[1]]$model(oneLineX)) >= 2)
        # Soft classification: average the per-class probability vectors
        return (Reduce("+", lapply(private$pmodels, function(m) m$model(X))) / V)
      n <- nrow(X)
      all_predictions <- as.data.frame(matrix(nrow=n, ncol=V))
      for (v in seq_len(V))
        all_predictions[,v] <- private$pmodels[[v]]$model(X)
      if (private$task == "regression")
        # Easy case: just average each row
        return (rowMeans(all_predictions))
      # "Hard" classification: per-row majority vote among the V predictions
      apply(all_predictions, 1, function(row) {
        t <- table(row)
        # Next lines in case of ties (broken at random)
        tmax <- max(t)
        sample( names(t)[which(t == tmax)], 1 )
      })
    },
    #' @description Return the list of V best parameters (after calling fit())
    getParams = function() {
      lapply(private$pmodels, function(m) m$param)
    }
  ),
  private = list(
    data = NULL,      # training data (matrix or data.frame)
    target = NULL,    # training targets
    task = NULL,      # "regression" or "classification"
    gmodel = NULL,    # generic model provider (see 'Model' class)
    loss = NULL,      # loss function(prediction, target) -> numeric
    pmodels = NULL    # list of V fitted (model, param) pairs, set by fit()
  )
)