X-Git-Url: https://git.auder.net/?p=agghoo.git;a=blobdiff_plain;f=R%2FR6_AgghooCV.R;h=dba42b4edc450a2000c11929504a8b33be12182c;hp=4ceb2899bc2e71fcb1d029308af2a2a84718e082;hb=d9a139b51ee2e71e13d67cb9d530834b15058617;hpb=cca5f1c67bd622fb7bc1279dfe4c3336d1446efd diff --git a/R/R6_AgghooCV.R b/R/R6_AgghooCV.R index 4ceb289..dba42b4 100644 --- a/R/R6_AgghooCV.R +++ b/R/R6_AgghooCV.R @@ -4,6 +4,8 @@ #' Class encapsulating the methods to run to obtain the best predictor #' from the list of models (see 'Model' class). #' +#' @importFrom R6 R6Class +#' #' @export AgghooCV <- R6::R6Class("AgghooCV", public = list( @@ -14,12 +16,12 @@ AgghooCV <- R6::R6Class("AgghooCV", #' @param gmodel Generic model returning a predictive function #' @param quality Function assessing the quality of a prediction; #' quality(y1, y2) --> real number - initialize = function(data, target, task, gmodel, quality = NA) { + initialize = function(data, target, task, gmodel, quality = NULL) { private$data <- data private$target <- target private$task <- task private$gmodel <- gmodel - if (is.na(quality)) { + if (is.null(quality)) { quality <- function(y1, y2) { # NOTE: if classif output is a probability matrix, adapt. if (task == "classification") @@ -87,7 +89,9 @@ AgghooCV <- R6::R6Class("AgghooCV", #' @param weight "uniform" (default) or "quality" to weight votes or #' average models performances (TODO: bad idea?!) predict = function(X, weight="uniform") { - if (!is.list(private$run_res) || is.na(private$run_res)) { + if (!is.matrix(X) && !is.data.frame(X)) + stop("X: matrix or data.frame") + if (!is.list(private$run_res)) { print("Please call $fit() method first") return } @@ -146,12 +150,12 @@ AgghooCV <- R6::R6Class("AgghooCV", } ), private = list( - data = NA, - target = NA, - task = NA, - gmodel = NA, - quality = NA, - run_res = NA, + data = NULL, + target = NULL, + task = NULL, + gmodel = NULL, + quality = NULL, + run_res = NULL, get_testIndices = function(CV, v, n, shuffle_inds) { if (CV$type == "vfold") { first_index = round((v-1) * n / CV$V) + 1 @@ -175,6 +179,11 @@ AgghooCV <- R6::R6Class("AgghooCV", testX <- private$data[test_indices,] targetHO <- private$target[-test_indices] testY <- private$target[test_indices] + # R will cast 1-dim matrices into vectors: + if (!is.matrix(dataHO) && !is.data.frame(dataHO)) + dataHO <- as.matrix(dataHO) + if (!is.matrix(testX) && !is.data.frame(testX)) + testX <- as.matrix(testX) if (p >= 1) # Standard CV: one model at a time return (getOnePerf(p)$perf)