Preparing for CRAN upload
[agghoo.git] / R / R6_AgghooCV.R
CommitLineData
c5946158
BA
#' @title R6 class with agghoo functions fit() and predict().
#'
#' @description
#' Class encapsulating the methods to run to obtain the best predictor
#' from the list of models (see 'Model' class).
#'
#' @importFrom R6 R6Class
#'
#' @export
AgghooCV <- R6::R6Class("AgghooCV",
  public = list(
    #' @description Create a new AgghooCV object.
    #' This constructor applies no defaults itself: every argument must be
    #' supplied (the defaults mentioned below are not set here).
    #' @param data Matrix or data.frame
    #' @param target Vector of targets (generally numeric or factor)
    #' @param task "regression" or "classification".
    #'   Default: classification if target not numeric.
    #' @param gmodel Generic model returning a predictive function
    #'   Default: tree if mixed data, knn/ppr otherwise.
    #' @param loss Function assessing the error of a prediction
    #'   Default: error rate or mean(abs(error)).
    initialize = function(data, target, task, gmodel, loss) {
      private$data <- data
      private$target <- target
      private$task <- task
      private$gmodel <- gmodel
      private$loss <- loss
    },
    #' @description Fit an agghoo model.
    #' @param CV List describing cross-validation to run. Slots: \cr
    #'   - type: 'vfold' or 'MC' for Monte-Carlo (default: MC) \cr
    #'   - V: number of runs (default: 10) \cr
    #'   - test_size: percentage of data in the test dataset, for MC
    #'     (irrelevant for V-fold). Default: 0.2. \cr
    #'   - shuffle: whether or not to shuffle data before V-fold.
    #'     Irrelevant for Monte-Carlo; default: TRUE \cr
    #'   Default (if NULL): type="MC", V=10, test_size=0.2
    #' @return Invisible NULL. The V selected models are stored in the object
    #'   and used by predict() / getParams().
    fit = function(CV = NULL) {
      CV <- checkCV(CV)
      n <- nrow(private$data)
      shuffle_inds <- NULL
      if (CV$type == "vfold" && CV$shuffle) {
        shuffle_inds <- sample(n, n)
      }
      # Result: list of V predictive models (+ parameters for info).
      # Built locally and stored only once all runs succeeded, so a failing
      # fit() never leaves a partially-filled model list behind.
      pmodels <- vector("list", CV$V)
      for (v in seq_len(CV$V)) {
        # Prepare train / test data and target, from full dataset.
        test_indices <- get_testIndices(n, CV, v, shuffle_inds)
        d <- splitTrainTest(private$data, private$target, test_indices)
        best_models <- list()
        best_error <- Inf
        for (p in seq_len(private$gmodel$nmodels)) {
          model_pred <- private$gmodel$get(d$dataTrain, d$targetTrain, p)
          prediction <- model_pred(d$dataTest)
          error <- private$loss(prediction, d$targetTest)
          # A missing loss (NA/NaN) cannot be ranked: skip that candidate
          # rather than failing on `if (NA)`.
          if (is.na(error) || error > best_error) {
            next
          }
          new_model <- list(model = model_pred,
                            param = private$gmodel$getParam(p))
          if (error == best_error) {
            best_models[[length(best_models) + 1L]] <- new_model
          } else {
            best_models <- list(new_model)
            best_error <- error
          }
        }
        if (length(best_models) == 0L) {
          stop("fit(): no candidate model has a non-missing loss on run ", v,
               call. = FALSE)
        }
        # Choose a model at random in case of ties
        pmodels[[v]] <- best_models[[sample(length(best_models), 1L)]]
      }
      private$pmodels <- pmodels
      invisible(NULL)
    },
    #' @description Predict an agghoo model (after calling fit())
    #' @param X Matrix or data.frame to predict
    #' @return If the models return several values for a single row (soft
    #'   classification), the average of the V model outputs. For regression,
    #'   the row-wise mean of the V predictions. Otherwise (hard
    #'   classification), the majority-vote label of each row, with ties
    #'   broken at random.
    predict = function(X) {
      if (!is.matrix(X) && !is.data.frame(X)) {
        stop("X: matrix or data.frame")
      }
      if (!is.list(private$pmodels) || length(private$pmodels) == 0L) {
        print("Please call $fit() method first")
        return(invisible(NULL))
      }
      V <- length(private$pmodels)
      # drop = FALSE keeps a one-row matrix / data.frame (with its column
      # names), even when X has a single column.
      oneLineX <- X[1, , drop = FALSE]
      if (length(private$pmodels[[1]]$model(oneLineX)) >= 2) {
        # Soft classification: average the V outputs.
        return(Reduce("+", lapply(private$pmodels, function(m) m$model(X))) / V)
      }
      n <- nrow(X)
      all_predictions <- as.data.frame(matrix(nrow = n, ncol = V))
      for (v in seq_len(V)) {
        all_predictions[, v] <- private$pmodels[[v]]$model(X)
      }
      if (private$task == "regression") {
        # Easy case: just average each row
        return(rowMeans(all_predictions))
      }
      # "Hard" classification: majority vote per row.
      apply(all_predictions, 1, function(row) {
        votes <- table(row)
        # Ties are broken at random.
        winners <- names(votes)[which(votes == max(votes))]
        sample(winners, 1)
      })
    },
    #' @description Return the list of V best parameters (after calling fit())
    #' @return A list with one element per run: the parameter of the model
    #'   selected on that run.
    getParams = function() {
      lapply(private$pmodels, function(m) m$param)
    }
  ),
  private = list(
    data = NULL,     # training inputs (matrix or data.frame)
    target = NULL,   # training targets
    task = NULL,     # "regression" or "classification"
    gmodel = NULL,   # generic model: $nmodels, $get(), $getParam()
    loss = NULL,     # function(prediction, target) -> error
    pmodels = NULL   # after fit(): list of V list(model=, param=)
  )
)