Some fixes + refactoring
[agghoo.git] / R / checks.R
CommitLineData
17ea2f13
BA
# Internal usage: check and fill arguments with default values.
2
afa67660
BA
# Default loss for classification tasks.
#
# y1: predictions. Either a vector of labels ("hard" classification), or an
#     object with dimensions holding one row of class scores per observation
#     and one named column per class ("soft" classification).
# y2: targets. A vector/factor of labels or, in the soft case, optionally an
#     object of the same shape as y1.
#
# Returns the misclassification rate in the hard case, and the mean (over
# rows) L1 distance between predicted and target rows in the soft case.
# Stops if label targets cannot be aligned with the columns of y1.
defaultLoss_classif <- function(y1, y2) {
  if (is.null(dim(y1))) {
    # Standard case: "hard" classification.
    return(mean(y1 != y2))
  }

  # "Soft" classification: predictions come as a probability matrix.
  if (!is.null(dim(y2))) {
    # Target already in matrix form: compare row by row.
    return(mean(rowSums(abs(y1 - y2))))
  }

  # Target given as labels (e.g. a factor): build the matching 0/1 indicator
  # matrix so the loss agrees with the matrix-target branch above. The
  # previous code subtracted the column *index* of the label from the
  # probabilities, which is not a distance to the target.
  probs <- as.matrix(y1)
  labels <- as.character(y2)
  if (length(labels) != nrow(probs))
    stop("target length does not match the number of predicted rows")
  cols <- match(labels, colnames(probs))
  if (anyNA(cols))
    stop("target labels must match the column names of the probability matrix")
  indicator <- matrix(0, nrow = nrow(probs), ncol = ncol(probs))
  # Two-column matrix indexing sets one cell (row i, column of label i) per row.
  indicator[cbind(seq_along(cols), cols)] <- 1
  mean(rowSums(abs(probs - indicator)))
}
28
# Default loss for regression tasks: mean absolute error.
#
# y1: numeric predictions.
# y2: numeric targets, conformable with y1.
#
# Returns the mean of the elementwise absolute differences.
defaultLoss_regress <- function(y1, y2) {
  abs_errors <- abs(y1 - y2)
  mean(abs_errors)
}
32
# TODO: allow strings like "MSE", "abs" etc
# Validate the loss argument, falling back on a default when it is NULL.
#
# loss: NULL, or a function(y1, y2) returning a real number.
# task: task name; only consulted when loss is NULL, in which case
#       "classification" selects defaultLoss_classif and anything else
#       selects defaultLoss_regress.
#
# Returns the loss function to use; stops if loss is neither NULL nor a
# function.
checkLoss <- function(loss, task) {
  # A user-supplied function is taken as is.
  if (is.function(loss))
    return(loss)
  if (!is.null(loss))
    stop("loss: function(y1, y2) --> Real")
  # No loss given: pick the default matching the task.
  if (task == "classification")
    defaultLoss_classif
  else
    defaultLoss_regress
}
46
# Validate the cross-validation settings and fill in missing fields.
#
# CV: NULL, or a list with fields type, V and, optionally, test_size and
#     shuffle.
#
# Returns the completed list. NULL yields the full default configuration.
# A missing type or V is filled with a warning; test_size (for type "MC")
# and shuffle (for type "vfold") are filled silently. Stops if CV is neither
# NULL nor a list.
checkCV <- function(CV) {
  # Nothing provided: return the default setup.
  if (is.null(CV))
    return(list(type="MC", V=10, test_size=0.2, shuffle=TRUE))

  if (!is.list(CV))
    stop("CV: list of type('MC'|'vfold'), V(integer, [test_size, shuffle]")

  # Mandatory fields: fill in, but tell the user.
  if (is.null(CV$type)) {
    warning("CV$type not provided: set to MC")
    CV$type <- "MC"
  }
  if (is.null(CV$V)) {
    warning("CV$V not provided: set to 10")
    CV$V <- 10
  }

  # Type-specific optional fields: fill in silently.
  needs_test_size <- CV$type == "MC" && is.null(CV$test_size)
  if (needs_test_size)
    CV$test_size <- 0.2
  needs_shuffle <- CV$type == "vfold" && is.null(CV$shuffle)
  if (needs_shuffle)
    CV$shuffle <- TRUE

  CV
}
68
# Validate the data / target pair; called for its errors only.
#
# data:   data.frame or matrix of observations (one row per observation).
# target: numeric, factor or character vector with one entry per row of data,
#         or a numeric matrix / data.frame with several columns and as many
#         rows as data.
#
# Returns NULL invisibly when everything is consistent; stops otherwise.
checkDaTa <- function(data, target) {
  if (!is.data.frame(data) && !is.matrix(data))
    stop("data: data.frame or matrix")

  if (is.data.frame(target) || is.matrix(target)) {
    # is.numeric() is always FALSE on a data.frame (it is a list), so a
    # data.frame target has to be checked column by column.
    numeric_target <- if (is.data.frame(target)) {
      all(vapply(target, is.numeric, logical(1)))
    } else {
      is.numeric(target)
    }
    if (!numeric_target)
      stop("multi-columns target must be a probability matrix")
    if (nrow(target) != nrow(data) || ncol(target) == 1)
      stop("target probability matrix does not match data size")
  } else {
    if (!is.numeric(target) && !is.factor(target) && !is.character(target))
      stop("target: numeric, factor or character vector")
    # A vector target needs exactly one value per observation.
    if (length(target) != nrow(data))
      stop("target length does not match data size")
  }
  invisible(NULL)
}
81
# Resolve the learning task.
#
# task:   NULL, or a (possibly abbreviated) string among "classification"
#         and "regression".
# target: the target values; only used to infer the task when task is NULL.
#
# Returns "classification" or "regression". An explicit task is validated
# with match.arg() and returned; previously it was validated and then
# overridden by the inferred value. When task is NULL, a numeric target means
# regression and anything else classification.
# NOTE(review): a numeric matrix target is therefore inferred as
# "regression"; callers with such a target presumably need to pass task
# explicitly -- confirm against the calling code.
checkTask <- function(task, target) {
  if (!is.null(task))
    return(match.arg(task, c("classification", "regression")))
  if (is.numeric(target)) "regression" else "classification"
}
87
# Validate the model / parameters pair.
#
# gmodel: NULL, a (possibly abbreviated) name among "knn", "ppr", "rf" and
#         "tree", or a function(dataHO, targetHO, param) returning a
#         prediction function.
# params: NULL, a list, or a numeric / character vector (turned into a list).
#
# Returns list(gmodel=, params=) with the model name expanded and params in
# list form. Stops when either argument has an unsupported type, when a
# custom model function comes without params, or when params come without a
# model.
checkModPar <- function(gmodel, params) {
  custom_model <- is.function(gmodel)
  if (is.character(gmodel)) {
    gmodel <- match.arg(gmodel, c("knn", "ppr", "rf", "tree"))
  } else if (!(is.null(gmodel) || custom_model)) {
    stop("gmodel: function(dataHO, targetHO, param) --> function(X) --> y")
  }

  # Atomic parameter vectors are accepted as a shorthand for a list.
  if (is.numeric(params) || is.character(params))
    params <- as.list(params)
  has_params <- is.list(params)
  if (!(has_params || is.null(params)))
    stop("params: numerical, character, or list (passed to model)")

  # A custom model and custom params only make sense together.
  if (custom_model && !has_params)
    stop("params must be provided when using a custom model")
  if (has_params && is.null(gmodel))
    stop("model (or family) must be provided when using custom params")

  list(gmodel=gmodel, params=params)
}