| 1 | # Internal usage: check and fill arguments with default values. |
| 2 | |
# Default loss for classification tasks.
#
# y1: predictions. Either a vector of labels ("hard" classification), or a
#     matrix-like object of class probabilities with one named column per
#     class ("soft" classification).
# y2: targets. Either a vector/factor of labels, or a probability matrix with
#     the same layout as y1.
#
# Returns a single number: the misclassification rate in the hard case, and
# the mean L1 distance between predicted and target probability rows in the
# soft case.
defaultLoss_classif <- function(y1, y2) {
  if (is.null(dim(y1))) {
    # Standard case: "hard" classification.
    # Factors are compared through their labels, so that two factors with
    # different level sets do not make `!=` raise an error.
    if (is.factor(y1))
      y1 <- as.character(y1)
    if (is.factor(y2))
      y2 <- as.character(y2)
    return(mean(y1 != y2))
  }

  # "Soft" classification: predict() outputs a probability matrix.
  if (!is.null(dim(y2))) {
    # The target is in matrix form too: plain row-wise L1 distance.
    return(mean(rowSums(abs(y1 - y2))))
  }

  # The target is a vector of labels: compare each probability row with the
  # one-hot encoding of its label, which gives the same value as passing the
  # equivalent one-hot matrix as y2.
  probs <- as.matrix(y1)
  if (length(y2) != nrow(probs))
    stop("y1 and y2 must describe the same number of observations")
  cols <- match(as.character(y2), colnames(probs))
  known <- !is.na(cols)
  onehot <- matrix(0, nrow = nrow(probs), ncol = ncol(probs))
  onehot[cbind(which(known), cols[known])] <- 1
  # A label with no matching column has all of its target mass on a class that
  # was never predicted: that missing column contributes |0 - 1| = 1 to the
  # row distance.
  mean(rowSums(abs(probs - onehot)) + !known)
}
| 28 | |
# Default loss for regression tasks: mean absolute error between the
# predictions y1 and the targets y2.
defaultLoss_regress <- function(y1, y2) {
  abs_errors <- abs(y1 - y2)
  mean(abs_errors)
}
| 32 | |
| 33 | # TODO: allow strings like "MSE", "abs" etc |
# Validate the loss argument, or pick a default one.
#
# loss: NULL, or a function(y1, y2) returning a real number.
# task: string; "classification" selects the classification default, any
#       other value selects the regression default.
#
# Returns `loss` itself when it is a function, otherwise the default loss for
# the task. Stops if `loss` is neither NULL nor a function.
checkLoss <- function(loss, task) {
  # A user-supplied function is returned untouched.
  if (is.function(loss))
    return(loss)
  if (!is.null(loss))
    stop("loss: function(y1, y2) --> Real")
  # No loss given: fall back on the default for the task.
  if (task == "classification")
    return(defaultLoss_classif)
  defaultLoss_regress
}
| 46 | |
# Validate the cross-validation settings and fill in missing fields.
#
# CV: NULL, or a list with fields `type` ("MC" or "vfold"), `V`, and
#     optionally `test_size` (used for "MC") and `shuffle` (used for "vfold").
#
# Returns the completed list. A missing `type` or `V` is filled in with a
# warning; `test_size` and `shuffle` are filled in silently. Stops if CV is
# neither NULL nor a list.
checkCV <- function(CV) {
  # Nothing provided: return the full set of defaults.
  if (is.null(CV))
    return(list(type="MC", V=10, test_size=0.2, shuffle=TRUE))

  if (!is.list(CV))
    stop("CV: list of type('MC'|'vfold'), V(integer, [test_size, shuffle]")

  # Mandatory fields: defaulted, but the user is told about it.
  if (is.null(CV$type)) {
    warning("CV$type not provided: set to MC")
    CV$type <- "MC"
  }
  if (is.null(CV$V)) {
    warning("CV$V not provided: set to 10")
    CV$V <- 10
  }

  # Type-specific optional fields.
  need_test_size <- CV$type == "MC" && is.null(CV$test_size)
  if (need_test_size)
    CV$test_size <- 0.2
  need_shuffle <- CV$type == "vfold" && is.null(CV$shuffle)
  if (need_shuffle)
    CV$shuffle <- TRUE

  CV
}
| 68 | |
# Validate the `data` and `target` arguments.
#
# data:   data.frame or matrix, one row per observation.
# target: numeric, factor or character vector with one value per row of
#         `data`; or a multi-column numeric matrix / data.frame (probability
#         matrix) with as many rows as `data`.
#
# Called for its side effect: stops with an explanatory message on invalid
# input, and returns NULL invisibly otherwise.
checkDaTa <- function(data, target) {
  if (!is.data.frame(data) && !is.matrix(data))
    stop("data: data.frame or matrix")
  if (is.data.frame(target) || is.matrix(target)) {
    # is.numeric() is always FALSE on a data.frame, so its columns have to be
    # checked one by one; a matrix can be tested directly.
    all_numeric <- if (is.data.frame(target)) {
      all(vapply(target, is.numeric, logical(1)))
    } else {
      is.numeric(target)
    }
    if (!all_numeric)
      stop("multi-columns target must be a probability matrix")
    if (nrow(target) != nrow(data) || ncol(target) == 1)
      stop("target probability matrix does not match data size")
  } else {
    if (!is.numeric(target) && !is.factor(target) && !is.character(target))
      stop("target: numeric, factor or character vector")
    # A vector target must provide exactly one value per observation.
    if (length(target) != nrow(data))
      stop("target length does not match data size")
  }
  invisible(NULL)
}
| 81 | |
# Determine the learning task.
#
# task:   NULL, or a string matching (possibly partially) "classification" or
#         "regression".
# target: the target values, used to infer the task when `task` is NULL.
#
# Returns "classification" or "regression". A task supplied by the user is
# validated and returned as is. Otherwise the task is "regression" for a
# numeric vector target and "classification" for anything else, including a
# multi-column probability target.
checkTask <- function(task, target) {
  if (!is.null(task))
    return(match.arg(task, c("classification", "regression")))
  # A target with dimensions is a probability matrix (soft classification),
  # whether it is stored as a numeric matrix or as a data.frame.
  if (is.numeric(target) && is.null(dim(target))) {
    "regression"
  } else {
    "classification"
  }
}
| 87 | |
# Validate the model / parameters pair.
#
# gmodel: NULL, one of the strings "knn", "ppr", "rf", "tree" (partial
#         matching allowed), or a custom function.
# params: NULL, a numeric or character vector (converted to a list), or a
#         list.
#
# Returns list(gmodel=..., params=...) with both entries normalized. Stops
# when an argument has the wrong type, when a custom model function comes
# without params, or when params come without a model.
checkModPar <- function(gmodel, params) {
  custom_model <- is.function(gmodel)
  if (is.character(gmodel)) {
    gmodel <- match.arg(gmodel, c("knn", "ppr", "rf", "tree"))
  } else if (!(is.null(gmodel) || custom_model)) {
    stop("gmodel: function(dataHO, targetHO, param) --> function(X) --> y")
  }

  # Atomic parameter vectors become a list with one candidate per element.
  if (is.numeric(params) || is.character(params)) {
    params <- as.list(params)
  } else if (!(is.list(params) || is.null(params))) {
    stop("params: numerical, character, or list (passed to model)")
  }

  # Model and parameters must be consistent with each other.
  has_params <- is.list(params)
  if (custom_model && !has_params)
    stop("params must be provided when using a custom model")
  if (has_params && is.null(gmodel))
    stop("model (or family) must be provided when using custom params")

  list(gmodel=gmodel, params=params)
}