R (>= 3.5.0)
Imports:
R6,
- caret,
rpart,
- randomForest,
FNN
export(standardCV_run)
importFrom(FNN,knn.reg)
importFrom(R6,R6Class)
-importFrom(caret,var_seq)
importFrom(class,knn)
-importFrom(randomForest,randomForest)
importFrom(rpart,rpart)
importFrom(stats,ppr)
#' @description Create a new AgghooCV object.
#' @param data Matrix or data.frame
#' @param target Vector of targets (generally numeric or factor)
- #' @param task "regression" or "classification"
+ #' @param task "regression" or "classification".
+ #' Default: classification if target not numeric.
#' @param gmodel Generic model returning a predictive function
+ #' Default: tree if mixed data, knn/ppr otherwise.
#' @param loss Function assessing the error of a prediction
+ #' Default: error rate or mean(abs(error)).
initialize = function(data, target, task, gmodel, loss) {
private$data <- data
private$target <- target
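As a rough illustration of the defaults documented above, here is a sketch only (not the package's actual code; the helper name resolve_defaults is hypothetical):

# Hypothetical helper mirroring the documented defaults for task and loss.
resolve_defaults <- function(target, task = NULL, loss = NULL) {
  if (is.null(task))
    task <- if (is.numeric(target)) "regression" else "classification"
  if (is.null(loss)) {
    loss <- if (task == "classification")
      function(y1, y2) mean(y1 != y2)      # error rate
    else
      function(y1, y2) mean(abs(y1 - y2))  # mean absolute error
  }
  list(task = task, loss = loss)
}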
#' "Model" class, containing a (generic) learning function, which from
#' data + target [+ params] returns a prediction function X --> y.
#' Parameters for cross-validation are either provided or estimated.
-#' Model family can be chosen among "rf", "tree", "ppr" and "knn" for now.
+#' Model family can be chosen among "tree", "ppr" and "knn" for now.
#'
#' @importFrom FNN knn.reg
#' @importFrom class knn
#' @importFrom stats ppr
-#' @importFrom randomForest randomForest
#' @importFrom rpart rpart
-#' @importFrom caret var_seq
#'
#' @export
Model <- R6::R6Class("Model",
# (Generic) model not provided
all_numeric <- is.numeric(as.matrix(data))
if (!all_numeric)
- # At least one non-numeric column: use random forests or trees
- # TODO: 4 = arbitrary magic number...
- gmodel = ifelse(ncol(data) >= 4, "rf", "tree")
+ # At least one non-numeric column: use trees
+ gmodel = "tree"
else
# Numerical data
gmodel = ifelse(task == "regression", "ppr", "knn")
}
}
}
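For example (iris is used purely to illustrate the numeric check above):

# All-numeric columns keep a numeric matrix: default family is "ppr"/"knn".
is.numeric(as.matrix(iris[, 1:4]))  # TRUE
# The Species factor forces as.matrix() to character: default family is "tree".
is.numeric(as.matrix(iris))         # FALSE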
- else if (family == "rf") {
- function(dataHO, targetHO, param) {
- require(randomForest)
- if (task == "classification" && !is.factor(targetHO))
- targetHO <- as.factor(targetHO)
- model <- randomForest::randomForest(dataHO, targetHO, mtry=param)
- function(X) predict(model, X)
- }
- }
else if (family == "ppr") {
function(dataHO, targetHO, param) {
model <- stats::ppr(dataHO, targetHO, nterms=param)
step <- (length(cps) - 1) / 10
cps[unique(round(seq(1, length(cps), step)))]
}
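For instance, with a hypothetical vector of complexity values (e.g. taken from an rpart cptable), the two lines above keep roughly 10 of them, evenly spread over the grid:

# Hypothetical cp values, decreasing as in an rpart cptable with 37 rows:
cps <- exp(seq(log(0.5), log(0.001), length.out = 37))
step <- (length(cps) - 1) / 10
cps[unique(round(seq(1, length(cps), step)))]  # ~10 values covering the range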
- else if (family == "rf") {
- p <- ncol(data)
- # Use caret package to obtain the CV grid of mtry values
- require(caret)
- caret::var_seq(p, classification = (task == "classification"),
- len = min(10, p-1))
- }
else if (family == "ppr")
# This is nterms in ppr() function
1:10
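For reference, this grid value is passed straight to stats::ppr as nterms; a minimal standalone example (mtcars used only for illustration):

# nterms = 3 projection terms; any value from the 1:10 grid works the same way.
fit <- stats::ppr(as.matrix(mtcars[, -1]), mtcars$mpg, nterms = 3)
predict(fit, as.matrix(mtcars[1:5, -1]))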
+#' CVvoting_core
+#'
+#' "voting" cross-validation method, added here as an example.
+#' Parameters are described in ?agghoo and ?AgghooCV.
+CVvoting_core <- function(data, target, task, gmodel, params, loss, CV) {
+ CV <- checkCV(CV)
+ n <- nrow(data)
+ shuffle_inds <- NULL
+ if (CV$type == "vfold" && CV$shuffle)
+ shuffle_inds <- sample(n, n)
+  gmodel <- agghoo::Model$new(data, target, task, gmodel, params)
+  bestP <- rep(0, gmodel$nmodels)
+ for (v in seq_len(CV$V)) {
+ test_indices <- get_testIndices(n, CV, v, shuffle_inds)
+ d <- splitTrainTest(data, target, test_indices)
+ best_p <- NULL
+ best_error <- Inf
+ for (p in seq_len(gmodel$nmodels)) {
+ model_pred <- gmodel$get(d$dataTrain, d$targetTrain, p)
+ prediction <- model_pred(d$dataTest)
+ error <- loss(prediction, d$targetTest)
+ if (error <= best_error) {
+ if (error == best_error)
+ best_p[[length(best_p)+1]] <- p
+ else {
+ best_p <- list(p)
+ best_error <- error
+ }
+ }
+ }
+ for (p in best_p)
+ bestP[p] <- bestP[p] + 1
+ }
+  # Choose a param at random in case of ties (ex aequo); index into the
+  # winners explicitly so sample() is never called on a single integer:
+  maxP <- max(bestP)
+  winners <- which(bestP == maxP)
+  chosenP <- winners[sample(length(winners), 1)]
+ list(model=gmodel$get(data, target, chosenP), param=gmodel$getParam(chosenP))
+}
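A usage sketch for CVvoting_core, kept as comments since it relies on the package's internal helpers (checkCV, get_testIndices, splitTrainTest). The CV list fields are assumed to match those read in the function body, and params is left NULL so the grid is estimated as described in the Model class documentation:

# loss_cl <- function(y1, y2) mean(y1 != y2)  # error rate
# res <- CVvoting_core(iris[, 1:4], iris$Species, "classification",
#                      gmodel = "knn", params = NULL, loss = loss_cl,
#                      CV = list(type = "vfold", V = 10, shuffle = TRUE))
# res$param                  # value selected by majority vote across the folds
# res$model(iris[1:5, 1:4])  # prediction function X --> y, refit on all data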
+
#' standardCV_core
#'
#' Cross-validation method, added here as an example.
list_testinds[[v]] <- get_testIndices(n, CV, v, shuffle_inds)
gmodel <- agghoo::Model$new(data, target, task, gmodel, params)
best_error <- Inf
- best_model <- NULL
+ best_p <- NULL
for (p in seq_len(gmodel$nmodels)) {
error <- Reduce('+', lapply(seq_len(CV$V), function(v) {
testIdx <- list_testinds[[v]]
loss(prediction, d$targetTest)
}) )
if (error <= best_error) {
- newModel <- list(model=gmodel$get(data, target, p),
- param=gmodel$getParam(p))
if (error == best_error)
- best_model[[length(best_model)+1]] <- newModel
+ best_p[[length(best_p)+1]] <- p
else {
- best_model <- list(newModel)
+ best_p <- list(p)
best_error <- error
}
}
}
- best_model[[ sample(length(best_model), 1) ]]
+ chosenP <- best_p[[ sample(length(best_p), 1) ]]
+ list(model=gmodel$get(data, target, chosenP), param=gmodel$getParam(chosenP))
}
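The return value keeps the same shape as before (and as CVvoting_core's): a prediction function refit on the full data, plus the selected parameter. A usage sketch with assumed inputs:

# res <- standardCV_core(data, target, task, gmodel, params, loss, CV)
# res$param        # parameter with the lowest total CV error
# res$model(newX)  # prediction function X --> y, refit on the whole dataset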
#' standardCV_run
Support for missing values (cf. mlbench::Ozone dataset)
-Compare with CV-voting (mode="voting" ?)
-Remove randomForest? (Method is already of the aggregation type)
-(Replace with...?)
+Method for mixed data? (only tree at the moment)
\item{\code{target}}{Vector of targets (generally numeric or factor)}
-\item{\code{task}}{"regression" or "classification"}
+\item{\code{task}}{"regression" or "classification".
+Default: classification if target not numeric.}
-\item{\code{gmodel}}{Generic model returning a predictive function}
+\item{\code{gmodel}}{Generic model returning a predictive function
+Default: tree if mixed data, knn/ppr otherwise.}
-\item{\code{loss}}{Function assessing the error of a prediction}
+\item{\code{loss}}{Function assessing the error of a prediction
+Default: error rate or mean(abs(error)).}
}
\if{html}{\out{</div>}}
}
--- /dev/null
+% Generated by roxygen2: do not edit by hand
+% Please edit documentation in R/compareTo.R
+\name{CVvoting_core}
+\alias{CVvoting_core}
+\title{CVvoting_core}
+\usage{
+CVvoting_core(data, target, task, gmodel, params, loss, CV)
+}
+\description{
+"voting" cross-validation method, added here as an example.
+Parameters are described in ?agghoo and ?AgghooCV.
+}
"Model" class, containing a (generic) learning function, which from
data + target [+ params] returns a prediction function X --> y.
Parameters for cross-validation are either provided or estimated.
-Model family can be chosen among "rf", "tree", "ppr" and "knn" for now.
+Model family can be chosen among "tree", "ppr" and "knn" for now.
}
\section{Public fields}{
\if{html}{\out{<div class="r6-fields">}}