Commit | Line | Data |
---|---|---|
17ea2f13 BA |
1 | # Internal usage: check and fill arguments with default values. |
2 | ||
afa67660 BA |
# Default classification loss.
#
# y1: predictions. Either a vector of predicted labels ("hard" classification)
#     or an object with dimensions holding one row of class probabilities per
#     observation ("soft" classification), with columns named after the classes.
# y2: target. Either a vector/factor of labels, or (soft case only) an object
#     with the same dimensions as y1.
#
# Returns a single number:
#   - hard case: the misclassification rate, mean(y1 != y2);
#   - soft case: the mean over rows of the L1 distance between the predicted
#     probabilities and the target row.
defaultLoss_classif <- function(y1, y2) {
  if (is.null(dim(y1))) {
    # Standard case: "hard" classification
    return(mean(y1 != y2))
  }
  # "Soft" classification: predict() outputs a probability matrix.
  if (is.null(dim(y2))) {
    # Target given as labels: expand it to a one-hot matrix aligned with the
    # columns of y1, so that it is compared exactly like a matrix target.
    # (Subtracting the column *index* from the probabilities would be wrong.)
    labels <- as.character(y2)
    cols <- match(labels, colnames(y1))
    onehot <- matrix(0, nrow = length(labels), ncol = ncol(y1))
    # Labels with no matching column keep an all-zero row: every bit of
    # predicted probability mass then counts as error for that observation.
    found <- !is.na(cols)
    onehot[cbind(which(found), cols[found])] <- 1
    y2 <- onehot
  }
  mean(rowSums(abs(y1 - y2)))
}
28 | ||
# Default regression loss: mean absolute error between predictions (y1)
# and targets (y2).
defaultLoss_regress <- function(y1, y2) {
  residuals <- y1 - y2
  mean(abs(residuals))
}
32 | ||
33 | # TODO: allow strings like "MSE", "abs" etc | |
# Resolve the loss function to use.
#
# loss: NULL, or a function(y1, y2) returning a real number.
# task: compared against "classification" to pick the default when loss is NULL.
#
# Returns `loss` unchanged when it is a function; otherwise (loss is NULL) the
# default classification or regression loss. Stops if loss is neither NULL nor
# a function.
checkLoss <- function(loss, task) {
  if (is.null(loss)) {
    # No user loss: fall back on the task-specific default.
    if (task == "classification") {
      return(defaultLoss_classif)
    }
    return(defaultLoss_regress)
  }
  if (!is.function(loss))
    stop("loss: function(y1, y2) --> Real")
  loss
}
46 | ||
# Validate the cross-validation settings and fill in missing fields.
#
# CV: NULL, or a list with fields type ("MC" or "vfold"), V, and optionally
#     test_size (used with "MC") and shuffle (used with "vfold").
#
# Returns the completed list. A NULL input yields the full default
# configuration; a missing type or V is defaulted with a warning; test_size
# and shuffle are defaulted silently for the type that uses them.
# Stops if CV is neither NULL nor a list.
checkCV <- function(CV) {
  if (is.null(CV)) {
    return(list(type="MC", V=10, test_size=0.2, shuffle=TRUE))
  }
  if (!is.list(CV)) {
    stop("CV: list of type('MC'|'vfold'), V(integer, [test_size, shuffle]")
  }
  if (is.null(CV$type)) {
    warning("CV$type not provided: set to MC")
    CV$type <- "MC"
  }
  if (is.null(CV$V)) {
    warning("CV$V not provided: set to 10")
    CV$V <- 10
  }
  # Type-specific options get their defaults without any warning.
  if (CV$type == "MC" && is.null(CV$test_size)) {
    CV$test_size <- 0.2
  }
  if (CV$type == "vfold" && is.null(CV$shuffle)) {
    CV$shuffle <- TRUE
  }
  CV
}
68 | ||
# Validate the (data, target) pair; called for its side effect only.
#
# data:   must be a data.frame or a matrix.
# target: either a vector (numeric, factor or character), or a multi-column
#         numeric data.frame/matrix (probability matrix) with one row per row
#         of data.
#
# Stops with an explanatory message on the first violated constraint;
# returns NULL invisibly otherwise.
checkDaTa <- function(data, target) {
  if (!is.data.frame(data) && !is.matrix(data))
    stop("data: data.frame or matrix")
  if (is.data.frame(target) || is.matrix(target)) {
    # is.numeric() is always FALSE on a data.frame (it is a list), so a
    # data.frame target has to be checked column by column.
    all_numeric <- if (is.data.frame(target)) {
      all(vapply(target, is.numeric, logical(1)))
    } else {
      is.numeric(target)
    }
    if (!all_numeric)
      stop("multi-columns target must be a probability matrix")
    if (nrow(target) != nrow(data) || ncol(target) == 1)
      stop("target probability matrix does not match data size")
  } else if (!is.numeric(target) && !is.factor(target) && !is.character(target)) {
    stop("target: numeric, factor or character vector")
  }
  invisible(NULL)
}
81 | ||
# Determine the learning task.
#
# task:   NULL, or a string matching (possibly partially) "classification" or
#         "regression".
# target: the response; only used when task is NULL.
#
# Returns the validated task when one is supplied (an explicit choice is no
# longer discarded); otherwise "regression" if target is numeric and
# "classification" if not. match.arg() stops on an unrecognized task.
checkTask <- function(task, target) {
  if (!is.null(task)) {
    return(match.arg(task, c("classification", "regression")))
  }
  # No task given: infer it from the type of the target.
  if (is.numeric(target)) "regression" else "classification"
}
87 | ||
# Validate the model specification and its parameters.
#
# gmodel: NULL, a string matching one of "knn", "ppr", "rf", "tree", or a
#         function(dataHO, targetHO, param) returning a prediction function.
# params: NULL, a numeric or character vector (converted to a list), or a list.
#
# Returns list(gmodel=, params=) with gmodel expanded by match.arg() and params
# converted to a list where applicable. Stops when either argument has an
# unsupported type, when a function model comes without params, or when params
# come without a model.
checkModPar <- function(gmodel, params) {
  if (is.character(gmodel)) {
    gmodel <- match.arg(gmodel, c("knn", "ppr", "rf", "tree"))
  } else if (!(is.null(gmodel) || is.function(gmodel))) {
    stop("gmodel: function(dataHO, targetHO, param) --> function(X) --> y")
  }

  # Atomic numeric/character parameters become a list, one element per value.
  if (is.numeric(params) || is.character(params)) {
    params <- as.list(params)
  }
  has_params <- is.list(params)

  if (!(has_params || is.null(params))) {
    stop("params: numerical, character, or list (passed to model)")
  }
  if (is.function(gmodel) && !has_params) {
    stop("params must be provided when using a custom model")
  }
  if (has_params && is.null(gmodel)) {
    stop("model (or family) must be provided when using custom params")
  }

  list(gmodel=gmodel, params=params)
}