# Default loss for classification tasks.
#
# y1: predictions. Either a vector of labels ("hard" classification), or a
#     matrix / data.frame of class probabilities with one column per class,
#     whose column names are the class labels ("soft" classification).
# y2: observed targets. A label vector/factor, or (soft case only) a
#     probability matrix with the same shape as y1.
# Returns the misclassification rate in the hard case, or the mean over rows
# of the L1 distance between predicted and target probability vectors in the
# soft case.
defaultLoss_classif <- function(y1, y2) {
  if (is.null(dim(y1))) {
    # Standard case: "hard" classification.
    # Compare factors through their labels: `!=` between two factors whose
    # level sets differ is an error in R.
    if (is.factor(y1)) y1 <- as.character(y1)
    if (is.factor(y2)) y2 <- as.character(y2)
    return(mean(y1 != y2))
  }
  # "Soft" classification: predict() outputs a probability matrix.
  if (!is.null(dim(y2))) {
    # Target already in matrix form: direct row-wise L1 distance.
    return(mean(rowSums(abs(y1 - y2))))
  }
  # Target given as labels: build the matching one-hot matrix (1 in the
  # column named after the row's label, 0 elsewhere) in a vectorised way.
  # NOTE: providing the target in matrix form avoids this conversion.
  pos <- match(as.character(y2), colnames(y1))
  known <- !is.na(pos)
  onehot <- matrix(0, nrow = nrow(y1), ncol = ncol(y1))
  onehot[cbind(which(known), pos[known])] <- 1
  # A label absent from colnames(y1) belongs to a column y1 does not have:
  # its one-hot entry adds an extra |0 - 1| = 1 to that row's distance.
  mean(rowSums(abs(y1 - onehot)) + !known)
}
| 26 | |
# Default loss for regression tasks: mean absolute error between the
# predictions y1 and the observed targets y2.
defaultLoss_regress <- function(y1, y2) {
  residuals <- y1 - y2
  mean(abs(residuals))
}
| 30 | |
# TODO: allow strings like "MSE", "abs" etc
# Resolve the loss function used to score held-out predictions.
# loss: NULL (use the default for the task) or a function(y1, y2) -> real.
# task: "classification" selects defaultLoss_classif; any other value selects
#       defaultLoss_regress. Only consulted when loss is NULL.
# Returns the loss function to use; errors if loss is neither NULL nor a
# function.
checkLoss <- function(loss, task) {
  if (is.function(loss))
    return(loss)
  if (!is.null(loss))
    stop("loss: function(y1, y2) --> Real")
  if (task == "classification") defaultLoss_classif else defaultLoss_regress
}
| 44 | |
# Validate the cross-validation settings and fill in defaults.
# CV: NULL, or a list with
#   type      "MC" (Monte-Carlo) or "vfold"; defaults to "MC" with a warning,
#   V         number of splits / folds; defaults to 10 with a warning,
#   test_size (MC only) defaults to 0.2,
#   shuffle   (vfold only) defaults to TRUE.
# Returns the completed CV list.
checkCV <- function(CV) {
  if (is.null(CV))
    return(list(type = "MC", V = 10, test_size = 0.2, shuffle = TRUE))
  if (!is.list(CV))
    stop("CV: list of type('MC'|'vfold'), V(integer), [test_size, shuffle]")
  if (is.null(CV$type)) {
    warning("CV$type not provided: set to MC")
    CV$type <- "MC"
  }
  # Reject unknown types early: otherwise neither per-type default below is
  # filled in and the failure only surfaces later, far from its cause.
  if (length(CV$type) != 1L || !(CV$type %in% c("MC", "vfold")))
    stop("CV$type: 'MC' or 'vfold'")
  if (is.null(CV$V)) {
    warning("CV$V not provided: set to 10")
    CV$V <- 10
  }
  if (CV$type == "MC" && is.null(CV$test_size))
    CV$test_size <- 0.2
  if (CV$type == "vfold" && is.null(CV$shuffle))
    CV$shuffle <- TRUE
  CV
}
| 66 | |
# Validate the (data, target) pair; signals an error on any mismatch.
# data:   data.frame or matrix of predictors.
# target: a numeric, factor or character vector with one entry per row of
#         data, or a numeric matrix / data.frame of class probabilities
#         (same number of rows as data, more than one column).
# Returns invisible NULL when the inputs are valid.
checkDaTa <- function(data, target) {
  if (!is.data.frame(data) && !is.matrix(data))
    stop("data: data.frame or matrix")
  if (is.data.frame(target) || is.matrix(target)) {
    # is.numeric() is always FALSE for a data.frame, so test its cells via
    # as.matrix(), which is numeric only when every column is numeric.
    if (!is.numeric(as.matrix(target)))
      stop("multi-columns target must be a probability matrix")
    if (nrow(target) != nrow(data) || ncol(target) == 1)
      stop("target probability matrix does not match data size")
  } else {
    if (!is.numeric(target) && !is.factor(target) && !is.character(target))
      stop("target: numeric, factor or character vector")
    # A vector target must provide exactly one response per observation.
    if (length(target) != nrow(data))
      stop("target length does not match the number of rows of data")
  }
  invisible(NULL)
}
| 79 | |
# Determine the learning task.
# task:   NULL (infer from target) or a possibly abbreviated task name,
#         "classification" or "regression", resolved with match.arg().
# target: a vector/factor of responses, or a numeric matrix / data.frame of
#         class probabilities.
# Returns "classification" or "regression".
checkTask <- function(task, target) {
  # An explicitly requested task is returned as-is, never overridden by
  # inference from the target.
  if (!is.null(task))
    return(match.arg(task, c("classification", "regression")))
  # A multi-column target is a probability matrix (checkDaTa requires this),
  # hence classification even though its cells are numeric.
  if (is.matrix(target) || is.data.frame(target))
    return("classification")
  if (is.numeric(target)) "regression" else "classification"
}
| 85 | |
# Validate the model specification and its parameters.
# gmodel: NULL, the name (possibly abbreviated) of a built-in model among
#         "knn", "ppr", "rf", "tree", or a custom
#         function(dataHO, targetHO, param) returning a predictor.
# params: NULL, a list, or a numeric / character vector (converted to a list
#         with one element per value).
# A custom model requires params; params require a model.
# Returns list(gmodel = <resolved model>, params = <list or NULL>).
checkModPar <- function(gmodel, params) {
  if (is.character(gmodel)) {
    gmodel <- match.arg(gmodel, c("knn", "ppr", "rf", "tree"))
  } else if (!(is.null(gmodel) || is.function(gmodel))) {
    stop("gmodel: function(dataHO, targetHO, param) --> function(X) --> y")
  }
  # Atomic parameter vectors become one list entry per candidate value.
  if (is.numeric(params) || is.character(params))
    params <- as.list(params)
  has_params <- is.list(params)
  if (!(has_params || is.null(params)))
    stop("params: numerical, character, or list (passed to model)")
  if (!has_params && is.function(gmodel))
    stop("params must be provided when using a custom model")
  if (has_params && is.null(gmodel))
    stop("model (or family) must be provided when using custom params")
  list(gmodel = gmodel, params = params)
}