Adjustments / fixes... And add knn for regression
[agghoo.git] / R / R6_AgghooCV.R
index 4ceb289..dba42b4 100644 (file)
@@ -4,6 +4,8 @@
 #' Class encapsulating the methods to run to obtain the best predictor
 #' from the list of models (see 'Model' class).
 #'
+#' @importFrom R6 R6Class
+#'
 #' @export
 AgghooCV <- R6::R6Class("AgghooCV",
   public = list(
@@ -14,12 +16,12 @@ AgghooCV <- R6::R6Class("AgghooCV",
     #' @param gmodel Generic model returning a predictive function
     #' @param quality Function assessing the quality of a prediction;
     #'                quality(y1, y2) --> real number
-    initialize = function(data, target, task, gmodel, quality = NA) {
+    initialize = function(data, target, task, gmodel, quality = NULL) {
       private$data <- data
       private$target <- target
       private$task <- task
       private$gmodel <- gmodel
-      if (is.na(quality)) {
+      if (is.null(quality)) {
         quality <- function(y1, y2) {
           # NOTE: if classif output is a probability matrix, adapt.
           if (task == "classification")
@@ -87,7 +89,9 @@ AgghooCV <- R6::R6Class("AgghooCV",
     #' @param weight "uniform" (default) or "quality" to weight votes or
     #'               average models performances (TODO: bad idea?!)
     predict = function(X, weight="uniform") {
-      if (!is.list(private$run_res) || is.na(private$run_res)) {
+      if (!is.matrix(X) && !is.data.frame(X))
+        stop("X: matrix or data.frame")
+      if (!is.list(private$run_res)) {
         print("Please call $fit() method first")
         return
       }
@@ -146,12 +150,12 @@ AgghooCV <- R6::R6Class("AgghooCV",
     }
   ),
   private = list(
-    data = NA,
-    target = NA,
-    task = NA,
-    gmodel = NA,
-    quality = NA,
-    run_res = NA,
+    data = NULL,
+    target = NULL,
+    task = NULL,
+    gmodel = NULL,
+    quality = NULL,
+    run_res = NULL,
     get_testIndices = function(CV, v, n, shuffle_inds) {
       if (CV$type == "vfold") {
         first_index = round((v-1) * n / CV$V) + 1
@@ -175,6 +179,11 @@ AgghooCV <- R6::R6Class("AgghooCV",
       testX <- private$data[test_indices,]
       targetHO <- private$target[-test_indices]
       testY <- private$target[test_indices]
+      # R will cast 1-dim matrices into vectors:
+      if (!is.matrix(dataHO) && !is.data.frame(dataHO))
+        dataHO <- as.matrix(dataHO)
+      if (!is.matrix(testX) && !is.data.frame(testX))
+        testX <- as.matrix(testX)
       if (p >= 1)
         # Standard CV: one model at a time
         return (getOnePerf(p)$perf)