Adjustments / fixes... And add knn for regression
authorBenjamin Auder <benjamin.auder@somewhere>
Tue, 8 Jun 2021 15:54:59 +0000 (17:54 +0200)
committerBenjamin Auder <benjamin.auder@somewhere>
Tue, 8 Jun 2021 15:54:59 +0000 (17:54 +0200)
DESCRIPTION
NAMESPACE
R/R6_AgghooCV.R
R/R6_Model.R
R/agghoo.R
man/AgghooCV.Rd
man/Model.Rd
man/agghoo.Rd
man/compareToStandard.Rd

index 7944819..2689a20 100644 (file)
@@ -20,7 +20,8 @@ Imports:
     R6,
     caret,
     rpart,
-    randomForest
+    randomForest,
+    FNN
 Suggests:
     roxygen2
 URL: https://git.auder.net/?p=agghoo.git
index 14fb7a3..63138fb 100644 (file)
--- a/NAMESPACE
+++ b/NAMESPACE
@@ -4,3 +4,10 @@ export(AgghooCV)
 export(Model)
 export(agghoo)
 export(compareToStandard)
+importFrom(FNN,knn.reg)
+importFrom(R6,R6Class)
+importFrom(caret,var_seq)
+importFrom(class,knn)
+importFrom(randomForest,randomForest)
+importFrom(rpart,rpart)
+importFrom(stats,ppr)
index 4ceb289..dba42b4 100644 (file)
@@ -4,6 +4,8 @@
 #' Class encapsulating the methods to run to obtain the best predictor
 #' from the list of models (see 'Model' class).
 #'
+#' @importFrom R6 R6Class
+#'
 #' @export
 AgghooCV <- R6::R6Class("AgghooCV",
   public = list(
@@ -14,12 +16,12 @@ AgghooCV <- R6::R6Class("AgghooCV",
     #' @param gmodel Generic model returning a predictive function
     #' @param quality Function assessing the quality of a prediction;
     #'                quality(y1, y2) --> real number
-    initialize = function(data, target, task, gmodel, quality = NA) {
+    initialize = function(data, target, task, gmodel, quality = NULL) {
       private$data <- data
       private$target <- target
       private$task <- task
       private$gmodel <- gmodel
-      if (is.na(quality)) {
+      if (is.null(quality)) {
         quality <- function(y1, y2) {
           # NOTE: if classif output is a probability matrix, adapt.
           if (task == "classification")
@@ -87,7 +89,9 @@ AgghooCV <- R6::R6Class("AgghooCV",
     #' @param weight "uniform" (default) or "quality" to weight votes or
     #'               average models performances (TODO: bad idea?!)
     predict = function(X, weight="uniform") {
-      if (!is.list(private$run_res) || is.na(private$run_res)) {
+      if (!is.matrix(X) && !is.data.frame(X))
+        stop("X: matrix or data.frame")
+      if (!is.list(private$run_res)) {
         print("Please call $fit() method first")
         return
       }
@@ -146,12 +150,12 @@ AgghooCV <- R6::R6Class("AgghooCV",
     }
   ),
   private = list(
-    data = NA,
-    target = NA,
-    task = NA,
-    gmodel = NA,
-    quality = NA,
-    run_res = NA,
+    data = NULL,
+    target = NULL,
+    task = NULL,
+    gmodel = NULL,
+    quality = NULL,
+    run_res = NULL,
     get_testIndices = function(CV, v, n, shuffle_inds) {
       if (CV$type == "vfold") {
         first_index = round((v-1) * n / CV$V) + 1
@@ -175,6 +179,11 @@ AgghooCV <- R6::R6Class("AgghooCV",
       testX <- private$data[test_indices,]
       targetHO <- private$target[-test_indices]
       testY <- private$target[test_indices]
+      # R will cast 1-dim matrices into vectors:
+      if (!is.matrix(dataHO) && !is.data.frame(dataHO))
+        dataHO <- as.matrix(dataHO)
+      if (!is.matrix(testX) && !is.data.frame(testX))
+        testX <- as.matrix(testX)
       if (p >= 1)
         # Standard CV: one model at a time
         return (getOnePerf(p)$perf)
index 9d7fc70..8912cdb 100644 (file)
@@ -6,6 +6,13 @@
 #' Parameters for cross-validation are either provided or estimated.
 #' Model family can be chosen among "rf", "tree", "ppr" and "knn" for now.
 #'
+#' @importFrom FNN knn.reg
+#' @importFrom class knn
+#' @importFrom stats ppr
+#' @importFrom randomForest randomForest
+#' @importFrom rpart rpart
+#' @importFrom caret var_seq
+#'
 #' @export
 Model <- R6::R6Class("Model",
   public = list(
@@ -18,8 +25,8 @@ Model <- R6::R6Class("Model",
     #' @param gmodel Generic model returning a predictive function; chosen
     #'               automatically given data and target nature if not provided.
     #' @param params List of parameters for cross-validation (each defining a model)
-    initialize = function(data, target, task, gmodel = NA, params = NA) {
-      if (is.na(gmodel)) {
+    initialize = function(data, target, task, gmodel = NULL, params = NULL) {
+      if (is.null(gmodel)) {
         # (Generic) model not provided
         all_numeric <- is.numeric(as.matrix(data))
         if (!all_numeric)
@@ -30,7 +37,7 @@ Model <- R6::R6Class("Model",
           # Numerical data
           gmodel = ifelse(task == "regression", "ppr", "knn")
       }
-      if (is.na(params))
+      if (is.null(params))
         # Here, gmodel is a string (= its family),
         # because a custom model must be given with its parameters.
         params <- as.list(private$getParams(gmodel, data, target))
@@ -52,8 +59,8 @@ Model <- R6::R6Class("Model",
   ),
   private = list(
     # No need to expose model or parameters list
-    gmodel = NA,
-    params = NA,
+    gmodel = NULL,
+    params = NULL,
     # Main function: given a family, return a generic model, which in turn
     # will output a predictive model from data + target + params.
     getGmodel = function(family, task) {
@@ -62,7 +69,7 @@ Model <- R6::R6Class("Model",
           require(rpart)
           method <- ifelse(task == "classification", "class", "anova")
           df <- data.frame(cbind(dataHO, target=targetHO))
-          model <- rpart(target ~ ., df, method=method, control=list(cp=param))
+          model <- rpart::rpart(target ~ ., df, method=method, control=list(cp=param))
           function(X) predict(model, X)
         }
       }
@@ -82,9 +89,17 @@ Model <- R6::R6Class("Model",
         }
       }
       else if (family == "knn") {
-        function(dataHO, targetHO, param) {
-          require(class)
-          function(X) class::knn(dataHO, X, cl=targetHO, k=param)
+        if (task == "classification") {
+          function(dataHO, targetHO, param) {
+            require(class)
+            function(X) class::knn(dataHO, X, cl=targetHO, k=param)
+          }
+        }
+        else {
+          function(dataHO, targetHO, param) {
+            require(FNN)
+            function(X) FNN::knn.reg(dataHO, X, y=targetHO, k=param)$pred
+          }
         }
       }
     },
index 92d061f..f3bc740 100644 (file)
 #' pc <- a_cla$predict(iris[,-5] + rnorm(600, sd=0.1))
 #'
 #' @export
-agghoo <- function(data, target, task = NA, gmodel = NA, params = NA, quality = NA) {
+agghoo <- function(data, target, task = NULL, gmodel = NULL, params = NULL, quality = NULL) {
        # Args check:
   if (!is.data.frame(data) && !is.matrix(data))
     stop("data: data.frame or matrix")
-  if (nrow(data) <= 1 || any(dim(data) == 0))
-    stop("data: non-empty, >= 2 rows")
   if (!is.numeric(target) && !is.factor(target) && !is.character(target))
     stop("target: numeric, factor or character vector")
-  if (!is.na(task))
+  if (!is.null(task))
     task = match.arg(task, c("classification", "regression"))
   if (is.character(gmodel))
-    gmodel <- match.arg("knn", "ppr", "rf")
-  else if (!is.na(gmodel) && !is.function(gmodel))
+    gmodel <- match.arg(gmodel, c("knn", "ppr", "rf", "tree"))
+  else if (!is.null(gmodel) && !is.function(gmodel))
     # No further checks here: fingers crossed :)
     stop("gmodel: function(dataHO, targetHO, param) --> function(X) --> y")
   if (is.numeric(params) || is.character(params))
     params <- as.list(params)
-  if (!is.na(params) && !is.list(params))
+  if (!is.list(params) && !is.null(params))
     stop("params: numerical, character, or list (passed to model)")
-  if (!is.na(gmodel) && !is.character(gmodel) && is.na(params))
+  if (is.function(gmodel) && !is.list(params))
     stop("params must be provided when using a custom model")
-  if (is.na(gmodel) && !is.na(params))
-    stop("model must be provided when using custom params")
-  if (!is.na(quality) && !is.function(quality))
+  if (is.list(params) && is.null(gmodel))
+    stop("model (or family) must be provided when using custom params")
+  if (!is.null(quality) && !is.function(quality))
     # No more checks here as well... TODO:?
     stop("quality: function(y1, y2) --> Real")
 
-  if (is.na(task)) {
+  if (is.null(task)) {
     if (is.numeric(target))
       task = "regression"
     else
@@ -76,10 +74,10 @@ agghoo <- function(data, target, task = NA, gmodel = NA, params = NA, quality =
 #' (TODO: extended, in another file, more tests - when faster code).
 #'
 #' @export
-compareToStandard <- function(df, t_idx, task = NA, rseed = -1) {
+compareToStandard <- function(df, t_idx, task = NULL, rseed = -1) {
   if (rseed >= 0)
     set.seed(rseed)
-  if (is.na(task))
+  if (is.null(task))
     task <- ifelse(is.numeric(df[,t_idx]), "regression", "classification")
   n <- nrow(df)
   test_indices <- sample( n, round(n / ifelse(n >= 500, 10, 5)) )
index 4d4cf78..75d1ab6 100644 (file)
@@ -22,7 +22,7 @@ from the list of models (see 'Model' class).
 \subsection{Method \code{new()}}{
 Create a new AgghooCV object.
 \subsection{Usage}{
-\if{html}{\out{<div class="r">}}\preformatted{AgghooCV$new(data, target, task, gmodel, quality = NA)}\if{html}{\out{</div>}}
+\if{html}{\out{<div class="r">}}\preformatted{AgghooCV$new(data, target, task, gmodel, quality = NULL)}\if{html}{\out{</div>}}
 }
 
 \subsection{Arguments}{
index a16f8ae..32f3ca2 100644 (file)
@@ -30,7 +30,7 @@ Model family can be chosen among "rf", "tree", "ppr" and "knn" for now.
 \subsection{Method \code{new()}}{
 Create a new generic model.
 \subsection{Usage}{
-\if{html}{\out{<div class="r">}}\preformatted{Model$new(data, target, task, gmodel = NA, params = NA)}\if{html}{\out{</div>}}
+\if{html}{\out{<div class="r">}}\preformatted{Model$new(data, target, task, gmodel = NULL, params = NULL)}\if{html}{\out{</div>}}
 }
 
 \subsection{Arguments}{
index 21afe5a..69dafed 100644 (file)
@@ -4,7 +4,7 @@
 \alias{agghoo}
 \title{agghoo}
 \usage{
-agghoo(data, target, task = NA, gmodel = NA, params = NA, quality = NA)
+agghoo(data, target, task = NULL, gmodel = NULL, params = NULL, quality = NULL)
 }
 \arguments{
 \item{data}{Data frame or matrix containing the data in lines.}
index 5787de9..742ecaa 100644 (file)
@@ -4,7 +4,7 @@
 \alias{compareToStandard}
 \title{compareToStandard}
 \usage{
-compareToStandard(df, t_idx, task = NA, rseed = -1)
+compareToStandard(df, t_idx, task = NULL, rseed = -1)
 }
 \description{
 Temporary function to compare agghoo to CV