X-Git-Url: https://git.auder.net/?a=blobdiff_plain;f=R%2Fagghoo.R;h=f3bc74093823dc66d96eaf05299ed72c59ba8876;hb=d9a139b51ee2e71e13d67cb9d530834b15058617;hp=528df2a546b383a9ebb280b37f1b7a495480d4a0;hpb=d09659f0e609bc8c1a6c390329d8f2d3b3ac5b24;p=agghoo.git diff --git a/R/agghoo.R b/R/agghoo.R index 528df2a..f3bc740 100644 --- a/R/agghoo.R +++ b/R/agghoo.R @@ -16,9 +16,9 @@ #' Default: see R6::Model. #' @param quality A function assessing the quality of a prediction. #' Arguments are y1 and y2 (comparing a prediction to known values). -#' Default: see R6::Agghoo. +#' Default: see R6::AgghooCV. #' -#' @return An R6::Agghoo object. +#' @return An R6::AgghooCV object. #' #' @examples #' # Regression: @@ -31,34 +31,32 @@ #' pc <- a_cla$predict(iris[,-5] + rnorm(600, sd=0.1)) #' #' @export -agghoo <- function(data, target, task = NA, gmodel = NA, params = NA, quality = NA) { +agghoo <- function(data, target, task = NULL, gmodel = NULL, params = NULL, quality = NULL) { # Args check: if (!is.data.frame(data) && !is.matrix(data)) stop("data: data.frame or matrix") - if (nrow(data) <= 1 || any(dim(data) == 0)) - stop("data: non-empty, >= 2 rows") if (!is.numeric(target) && !is.factor(target) && !is.character(target)) stop("target: numeric, factor or character vector") - if (!is.na(task)) + if (!is.null(task)) task = match.arg(task, c("classification", "regression")) if (is.character(gmodel)) - gmodel <- match.arg("knn", "ppr", "rf") - else if (!is.na(gmodel) && !is.function(gmodel)) + gmodel <- match.arg(gmodel, c("knn", "ppr", "rf", "tree")) + else if (!is.null(gmodel) && !is.function(gmodel)) # No further checks here: fingers crossed :) stop("gmodel: function(dataHO, targetHO, param) --> function(X) --> y") if (is.numeric(params) || is.character(params)) params <- as.list(params) - if (!is.na(params) && !is.list(params)) + if (!is.list(params) && !is.null(params)) stop("params: numerical, character, or list (passed to model)") - if (!is.na(gmodel) && !is.character(gmodel) && is.na(params)) + if (is.function(gmodel) && !is.list(params)) stop("params must be provided when using a custom model") - if (is.na(gmodel) && !is.na(params)) - stop("model must be provided when using custom params") - if (!is.na(quality) && !is.function(quality)) + if (is.list(params) && is.null(gmodel)) + stop("model (or family) must be provided when using custom params") + if (!is.null(quality) && !is.function(quality)) # No more checks here as well... TODO:? stop("quality: function(y1, y2) --> Real") - if (is.na(task)) { + if (is.null(task)) { if (is.numeric(target)) task = "regression" else @@ -66,8 +64,8 @@ agghoo <- function(data, target, task = NA, gmodel = NA, params = NA, quality = } # Build Model object (= list of parameterized models) model <- Model$new(data, target, task, gmodel, params) - # Return Agghoo object, to run and predict - Agghoo$new(data, target, task, model, quality) + # Return AgghooCV object, to run and predict + AgghooCV$new(data, target, task, model, quality) } #' compareToStandard @@ -76,10 +74,10 @@ agghoo <- function(data, target, task = NA, gmodel = NA, params = NA, quality = #' (TODO: extended, in another file, more tests - when faster code). #' #' @export -compareToStandard <- function(df, t_idx, task = NA, rseed = -1) { +compareToStandard <- function(df, t_idx, task = NULL, rseed = -1) { if (rseed >= 0) set.seed(rseed) - if (is.na(task)) + if (is.null(task)) task <- ifelse(is.numeric(df[,t_idx]), "regression", "classification") n <- nrow(df) test_indices <- sample( n, round(n / ifelse(n >= 500, 10, 5)) )