Update in progress - unfinished
[agghoo.git] / R / R6_AgghooCV.R
index 9cdf19e..ed9aa5c 100644 (file)
@@ -25,11 +25,11 @@ AgghooCV <- R6::R6Class("AgghooCV",
       private$loss <- loss
     },
     #' @description Fit an agghoo model.
-    #' @param CV List describing cross-validation to run. Slots:
-    #'          - type: 'vfold' or 'MC' for Monte-Carlo (default: MC)
-    #'          - V: number of runs (default: 10)
+    #' @param CV List describing cross-validation to run. Slots: \cr
+    #'          - type: 'vfold' or 'MC' for Monte-Carlo (default: MC) \cr
+    #'          - V: number of runs (default: 10) \cr
     #'          - test_size: percentage of data in the test dataset, for MC
-    #'            (irrelevant for V-fold). Default: 0.2.
+    #'            (irrelevant for V-fold). Default: 0.2. \cr
     #'          - shuffle: wether or not to shuffle data before V-fold.
     #'            Irrelevant for Monte-Carlo; default: TRUE
     fit = function(
@@ -89,7 +89,8 @@ AgghooCV <- R6::R6Class("AgghooCV",
         return (invisible(NULL))
       }
       V <- length(private$pmodels)
-      if (length(private$pmodels[[1]]$model(X[1,])) >= 2)
+      oneLineX <- t(as.matrix(X[1,]))
+      if (length(private$pmodels[[1]]$model(oneLineX)) >= 2)
         # Soft classification:
         return (Reduce("+", lapply(private$pmodels, function(m) m$model(X))) / V)
       n <- nrow(X)
@@ -98,7 +99,7 @@ AgghooCV <- R6::R6Class("AgghooCV",
         all_predictions[,v] <- private$pmodels[[v]]$model(X)
       if (private$task == "regression")
         # Easy case: just average each row
-        rowSums(all_predictions)
+        return (rowMeans(all_predictions))
       # "Hard" classification:
       apply(all_predictions, 1, function(row) {
         t <- table(row)