X-Git-Url: https://git.auder.net/css/vendor/current/gitweb.js?a=blobdiff_plain;f=R%2FR6_AgghooCV.R;h=81ddbe1b33ce563ebb2b92eb41c92a4573629e0c;hb=c152ea666f61e36b095bf8b42ab99efe9eab2dba;hp=9cdf19ed5081aec44fc3e4ceb9a72dcb7db306b3;hpb=504afaadc783916dc126fb87ab9e067f302eb2c5;p=agghoo.git diff --git a/R/R6_AgghooCV.R b/R/R6_AgghooCV.R index 9cdf19e..81ddbe1 100644 --- a/R/R6_AgghooCV.R +++ b/R/R6_AgghooCV.R @@ -89,7 +89,8 @@ AgghooCV <- R6::R6Class("AgghooCV", return (invisible(NULL)) } V <- length(private$pmodels) - if (length(private$pmodels[[1]]$model(X[1,])) >= 2) + oneLineX <- t(as.matrix(X[1,])) + if (length(private$pmodels[[1]]$model(oneLineX)) >= 2) # Soft classification: return (Reduce("+", lapply(private$pmodels, function(m) m$model(X))) / V) n <- nrow(X) @@ -98,7 +99,7 @@ AgghooCV <- R6::R6Class("AgghooCV", all_predictions[,v] <- private$pmodels[[v]]$model(X) if (private$task == "regression") # Easy case: just average each row - rowSums(all_predictions) + return (rowMeans(all_predictions)) # "Hard" classification: apply(all_predictions, 1, function(row) { t <- table(row)