From d9a139b51ee2e71e13d67cb9d530834b15058617 Mon Sep 17 00:00:00 2001 From: Benjamin Auder Date: Tue, 8 Jun 2021 17:54:59 +0200 Subject: [PATCH] Adjustments / fixes... And add knn for regression --- DESCRIPTION | 3 ++- NAMESPACE | 7 +++++++ R/R6_AgghooCV.R | 27 ++++++++++++++++++--------- R/R6_Model.R | 33 ++++++++++++++++++++++++--------- R/agghoo.R | 26 ++++++++++++-------------- man/AgghooCV.Rd | 2 +- man/Model.Rd | 2 +- man/agghoo.Rd | 2 +- man/compareToStandard.Rd | 2 +- 9 files changed, 67 insertions(+), 37 deletions(-) diff --git a/DESCRIPTION b/DESCRIPTION index 7944819..2689a20 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -20,7 +20,8 @@ Imports: R6, caret, rpart, - randomForest + randomForest, + FNN Suggests: roxygen2 URL: https://git.auder.net/?p=agghoo.git diff --git a/NAMESPACE b/NAMESPACE index 14fb7a3..63138fb 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -4,3 +4,10 @@ export(AgghooCV) export(Model) export(agghoo) export(compareToStandard) +importFrom(FNN,knn.reg) +importFrom(R6,R6Class) +importFrom(caret,var_seq) +importFrom(class,knn) +importFrom(randomForest,randomForest) +importFrom(rpart,rpart) +importFrom(stats,ppr) diff --git a/R/R6_AgghooCV.R b/R/R6_AgghooCV.R index 4ceb289..dba42b4 100644 --- a/R/R6_AgghooCV.R +++ b/R/R6_AgghooCV.R @@ -4,6 +4,8 @@ #' Class encapsulating the methods to run to obtain the best predictor #' from the list of models (see 'Model' class). #' +#' @importFrom R6 R6Class +#' #' @export AgghooCV <- R6::R6Class("AgghooCV", public = list( @@ -14,12 +16,12 @@ AgghooCV <- R6::R6Class("AgghooCV", #' @param gmodel Generic model returning a predictive function #' @param quality Function assessing the quality of a prediction; #' quality(y1, y2) --> real number - initialize = function(data, target, task, gmodel, quality = NA) { + initialize = function(data, target, task, gmodel, quality = NULL) { private$data <- data private$target <- target private$task <- task private$gmodel <- gmodel - if (is.na(quality)) { + if (is.null(quality)) { quality <- function(y1, y2) { # NOTE: if classif output is a probability matrix, adapt. if (task == "classification") @@ -87,7 +89,9 @@ AgghooCV <- R6::R6Class("AgghooCV", #' @param weight "uniform" (default) or "quality" to weight votes or #' average models performances (TODO: bad idea?!) predict = function(X, weight="uniform") { - if (!is.list(private$run_res) || is.na(private$run_res)) { + if (!is.matrix(X) && !is.data.frame(X)) + stop("X: matrix or data.frame") + if (!is.list(private$run_res)) { print("Please call $fit() method first") return } @@ -146,12 +150,12 @@ AgghooCV <- R6::R6Class("AgghooCV", } ), private = list( - data = NA, - target = NA, - task = NA, - gmodel = NA, - quality = NA, - run_res = NA, + data = NULL, + target = NULL, + task = NULL, + gmodel = NULL, + quality = NULL, + run_res = NULL, get_testIndices = function(CV, v, n, shuffle_inds) { if (CV$type == "vfold") { first_index = round((v-1) * n / CV$V) + 1 @@ -175,6 +179,11 @@ AgghooCV <- R6::R6Class("AgghooCV", testX <- private$data[test_indices,] targetHO <- private$target[-test_indices] testY <- private$target[test_indices] + # R will cast 1-dim matrices into vectors: + if (!is.matrix(dataHO) && !is.data.frame(dataHO)) + dataHO <- as.matrix(dataHO) + if (!is.matrix(testX) && !is.data.frame(testX)) + testX <- as.matrix(testX) if (p >= 1) # Standard CV: one model at a time return (getOnePerf(p)$perf) diff --git a/R/R6_Model.R b/R/R6_Model.R index 9d7fc70..8912cdb 100644 --- a/R/R6_Model.R +++ b/R/R6_Model.R @@ -6,6 +6,13 @@ #' Parameters for cross-validation are either provided or estimated. #' Model family can be chosen among "rf", "tree", "ppr" and "knn" for now. #' +#' @importFrom FNN knn.reg +#' @importFrom class knn +#' @importFrom stats ppr +#' @importFrom randomForest randomForest +#' @importFrom rpart rpart +#' @importFrom caret var_seq +#' #' @export Model <- R6::R6Class("Model", public = list( @@ -18,8 +25,8 @@ Model <- R6::R6Class("Model", #' @param gmodel Generic model returning a predictive function; chosen #' automatically given data and target nature if not provided. #' @param params List of parameters for cross-validation (each defining a model) - initialize = function(data, target, task, gmodel = NA, params = NA) { - if (is.na(gmodel)) { + initialize = function(data, target, task, gmodel = NULL, params = NULL) { + if (is.null(gmodel)) { # (Generic) model not provided all_numeric <- is.numeric(as.matrix(data)) if (!all_numeric) @@ -30,7 +37,7 @@ Model <- R6::R6Class("Model", # Numerical data gmodel = ifelse(task == "regression", "ppr", "knn") } - if (is.na(params)) + if (is.null(params)) # Here, gmodel is a string (= its family), # because a custom model must be given with its parameters. params <- as.list(private$getParams(gmodel, data, target)) @@ -52,8 +59,8 @@ Model <- R6::R6Class("Model", ), private = list( # No need to expose model or parameters list - gmodel = NA, - params = NA, + gmodel = NULL, + params = NULL, # Main function: given a family, return a generic model, which in turn # will output a predictive model from data + target + params. getGmodel = function(family, task) { @@ -62,7 +69,7 @@ Model <- R6::R6Class("Model", require(rpart) method <- ifelse(task == "classification", "class", "anova") df <- data.frame(cbind(dataHO, target=targetHO)) - model <- rpart(target ~ ., df, method=method, control=list(cp=param)) + model <- rpart::rpart(target ~ ., df, method=method, control=list(cp=param)) function(X) predict(model, X) } } @@ -82,9 +89,17 @@ Model <- R6::R6Class("Model", } } else if (family == "knn") { - function(dataHO, targetHO, param) { - require(class) - function(X) class::knn(dataHO, X, cl=targetHO, k=param) + if (task == "classification") { + function(dataHO, targetHO, param) { + require(class) + function(X) class::knn(dataHO, X, cl=targetHO, k=param) + } + } + else { + function(dataHO, targetHO, param) { + require(FNN) + function(X) FNN::knn.reg(dataHO, X, y=targetHO, k=param)$pred + } } } }, diff --git a/R/agghoo.R b/R/agghoo.R index 92d061f..f3bc740 100644 --- a/R/agghoo.R +++ b/R/agghoo.R @@ -31,34 +31,32 @@ #' pc <- a_cla$predict(iris[,-5] + rnorm(600, sd=0.1)) #' #' @export -agghoo <- function(data, target, task = NA, gmodel = NA, params = NA, quality = NA) { +agghoo <- function(data, target, task = NULL, gmodel = NULL, params = NULL, quality = NULL) { # Args check: if (!is.data.frame(data) && !is.matrix(data)) stop("data: data.frame or matrix") - if (nrow(data) <= 1 || any(dim(data) == 0)) - stop("data: non-empty, >= 2 rows") if (!is.numeric(target) && !is.factor(target) && !is.character(target)) stop("target: numeric, factor or character vector") - if (!is.na(task)) + if (!is.null(task)) task = match.arg(task, c("classification", "regression")) if (is.character(gmodel)) - gmodel <- match.arg("knn", "ppr", "rf") - else if (!is.na(gmodel) && !is.function(gmodel)) + gmodel <- match.arg(gmodel, c("knn", "ppr", "rf", "tree")) + else if (!is.null(gmodel) && !is.function(gmodel)) # No further checks here: fingers crossed :) stop("gmodel: function(dataHO, targetHO, param) --> function(X) --> y") if (is.numeric(params) || is.character(params)) params <- as.list(params) - if (!is.na(params) && !is.list(params)) + if (!is.list(params) && !is.null(params)) stop("params: numerical, character, or list (passed to model)") - if (!is.na(gmodel) && !is.character(gmodel) && is.na(params)) + if (is.function(gmodel) && !is.list(params)) stop("params must be provided when using a custom model") - if (is.na(gmodel) && !is.na(params)) - stop("model must be provided when using custom params") - if (!is.na(quality) && !is.function(quality)) + if (is.list(params) && is.null(gmodel)) + stop("model (or family) must be provided when using custom params") + if (!is.null(quality) && !is.function(quality)) # No more checks here as well... TODO:? stop("quality: function(y1, y2) --> Real") - if (is.na(task)) { + if (is.null(task)) { if (is.numeric(target)) task = "regression" else @@ -76,10 +74,10 @@ agghoo <- function(data, target, task = NA, gmodel = NA, params = NA, quality = #' (TODO: extended, in another file, more tests - when faster code). #' #' @export -compareToStandard <- function(df, t_idx, task = NA, rseed = -1) { +compareToStandard <- function(df, t_idx, task = NULL, rseed = -1) { if (rseed >= 0) set.seed(rseed) - if (is.na(task)) + if (is.null(task)) task <- ifelse(is.numeric(df[,t_idx]), "regression", "classification") n <- nrow(df) test_indices <- sample( n, round(n / ifelse(n >= 500, 10, 5)) ) diff --git a/man/AgghooCV.Rd b/man/AgghooCV.Rd index 4d4cf78..75d1ab6 100644 --- a/man/AgghooCV.Rd +++ b/man/AgghooCV.Rd @@ -22,7 +22,7 @@ from the list of models (see 'Model' class). \subsection{Method \code{new()}}{ Create a new AgghooCV object. \subsection{Usage}{ -\if{html}{\out{
}}\preformatted{AgghooCV$new(data, target, task, gmodel, quality = NA)}\if{html}{\out{
}} +\if{html}{\out{
}}\preformatted{AgghooCV$new(data, target, task, gmodel, quality = NULL)}\if{html}{\out{
}} } \subsection{Arguments}{ diff --git a/man/Model.Rd b/man/Model.Rd index a16f8ae..32f3ca2 100644 --- a/man/Model.Rd +++ b/man/Model.Rd @@ -30,7 +30,7 @@ Model family can be chosen among "rf", "tree", "ppr" and "knn" for now. \subsection{Method \code{new()}}{ Create a new generic model. \subsection{Usage}{ -\if{html}{\out{
}}\preformatted{Model$new(data, target, task, gmodel = NA, params = NA)}\if{html}{\out{
}} +\if{html}{\out{
}}\preformatted{Model$new(data, target, task, gmodel = NULL, params = NULL)}\if{html}{\out{
}} } \subsection{Arguments}{ diff --git a/man/agghoo.Rd b/man/agghoo.Rd index 21afe5a..69dafed 100644 --- a/man/agghoo.Rd +++ b/man/agghoo.Rd @@ -4,7 +4,7 @@ \alias{agghoo} \title{agghoo} \usage{ -agghoo(data, target, task = NA, gmodel = NA, params = NA, quality = NA) +agghoo(data, target, task = NULL, gmodel = NULL, params = NULL, quality = NULL) } \arguments{ \item{data}{Data frame or matrix containing the data in lines.} diff --git a/man/compareToStandard.Rd b/man/compareToStandard.Rd index 5787de9..742ecaa 100644 --- a/man/compareToStandard.Rd +++ b/man/compareToStandard.Rd @@ -4,7 +4,7 @@ \alias{compareToStandard} \title{compareToStandard} \usage{ -compareToStandard(df, t_idx, task = NA, rseed = -1) +compareToStandard(df, t_idx, task = NULL, rseed = -1) } \description{ Temporary function to compare agghoo to CV -- 2.44.0