X-Git-Url: https://git.auder.net/?p=agghoo.git;a=blobdiff_plain;f=R%2Fagghoo.R;h=f3bc74093823dc66d96eaf05299ed72c59ba8876;hp=92d061f0a13790fc8ec5460b6619f4f60b63a51b;hb=d9a139b51ee2e71e13d67cb9d530834b15058617;hpb=cca5f1c67bd622fb7bc1279dfe4c3336d1446efd diff --git a/R/agghoo.R b/R/agghoo.R index 92d061f..f3bc740 100644 --- a/R/agghoo.R +++ b/R/agghoo.R @@ -31,34 +31,32 @@ #' pc <- a_cla$predict(iris[,-5] + rnorm(600, sd=0.1)) #' #' @export -agghoo <- function(data, target, task = NA, gmodel = NA, params = NA, quality = NA) { +agghoo <- function(data, target, task = NULL, gmodel = NULL, params = NULL, quality = NULL) { # Args check: if (!is.data.frame(data) && !is.matrix(data)) stop("data: data.frame or matrix") - if (nrow(data) <= 1 || any(dim(data) == 0)) - stop("data: non-empty, >= 2 rows") if (!is.numeric(target) && !is.factor(target) && !is.character(target)) stop("target: numeric, factor or character vector") - if (!is.na(task)) + if (!is.null(task)) task = match.arg(task, c("classification", "regression")) if (is.character(gmodel)) - gmodel <- match.arg("knn", "ppr", "rf") - else if (!is.na(gmodel) && !is.function(gmodel)) + gmodel <- match.arg(gmodel, c("knn", "ppr", "rf", "tree")) + else if (!is.null(gmodel) && !is.function(gmodel)) # No further checks here: fingers crossed :) stop("gmodel: function(dataHO, targetHO, param) --> function(X) --> y") if (is.numeric(params) || is.character(params)) params <- as.list(params) - if (!is.na(params) && !is.list(params)) + if (!is.list(params) && !is.null(params)) stop("params: numerical, character, or list (passed to model)") - if (!is.na(gmodel) && !is.character(gmodel) && is.na(params)) + if (is.function(gmodel) && !is.list(params)) stop("params must be provided when using a custom model") - if (is.na(gmodel) && !is.na(params)) - stop("model must be provided when using custom params") - if (!is.na(quality) && !is.function(quality)) + if (is.list(params) && is.null(gmodel)) + stop("model (or family) must be provided when using custom params") + if (!is.null(quality) && !is.function(quality)) # No more checks here as well... TODO:? stop("quality: function(y1, y2) --> Real") - if (is.na(task)) { + if (is.null(task)) { if (is.numeric(target)) task = "regression" else @@ -76,10 +74,10 @@ agghoo <- function(data, target, task = NA, gmodel = NA, params = NA, quality = #' (TODO: extended, in another file, more tests - when faster code). #' #' @export -compareToStandard <- function(df, t_idx, task = NA, rseed = -1) { +compareToStandard <- function(df, t_idx, task = NULL, rseed = -1) { if (rseed >= 0) set.seed(rseed) - if (is.na(task)) + if (is.null(task)) task <- ifelse(is.numeric(df[,t_idx]), "regression", "classification") n <- nrow(df) test_indices <- sample( n, round(n / ifelse(n >= 500, 10, 5)) )