# 276749ba6b0a0d384b27f88cdc72eccdf3b789c6
# [agghoo.git] / test / compareToCV.R
1 library(agghoo)
2
# Default loss for standardCV(), selected once from the task.
#  - regression: mean absolute error;
#  - classification, y1 a plain vector of predicted labels: error rate;
#  - classification, y1 a matrix of class scores (one column per class):
#    mean L1 distance between each row of y1 and either the matching row of
#    y2 (when y2 is a matrix too) or the one-hot encoding of the label y2
#    (labels are matched to the column names of y1).
cvDefaultLoss <- function(task) {
  if (task != "classification")
    return(function(y1, y2) mean(abs(y1 - y2)))
  function(y1, y2) {
    if (is.null(dim(y1)))
      return(mean(y1 != y2))
    if (!is.null(dim(y2)))
      return(mean(rowSums(abs(y1 - y2))))
    cols <- match(as.character(y2), colnames(y1))
    if (anyNA(cols))
      stop("some labels do not match any column of the predictions",
           call. = FALSE)
    # Compare each score row with the one-hot encoding of its label
    # (not with the label's column index).
    onehot <- matrix(0, nrow = nrow(y1), ncol = ncol(y1))
    onehot[cbind(seq_along(cols), cols)] <- 1
    mean(rowSums(abs(y1 - onehot)))
  }
}

# Row indices of the v-th test set among n rows.
#  - CV$type == "vfold": the v-th of CV$V contiguous slices of 1..n, mapped
#    through the permutation shuffle_inds when it is not NULL;
#  - otherwise (Monte-Carlo): a fresh random subset of round(n * CV$test_size)
#    rows, and at least one: an empty index vector would make x[-idx, ]
#    select no rows at all, i.e. an empty training set.
cvTestIndices <- function(n, CV, v, shuffle_inds = NULL) {
  if (CV$type != "vfold")
    return(sample(n, max(1, round(n * CV$test_size))))
  first_index <- round((v - 1) * n / CV$V) + 1
  last_index <- round(v * n / CV$V)
  test_indices <- first_index:last_index
  if (!is.null(shuffle_inds))
    test_indices <- shuffle_inds[test_indices]
  test_indices
}

# Standard cross-validation baseline, compared against agghoo in compareToCV().
# Every candidate of agghoo::Model$new(data, target, task, gmodel, params) is
# scored by summing `loss` over the CV$V test sets (the same sets for every
# candidate). Among the best-scoring candidates one is picked at random, and
# it is refit on the full (data, target).
#
# data:   matrix or data.frame of predictors (rows = observations).
# target: response vector, one entry per row of data.
# task:   "classification" or "regression"; inferred from target when NULL.
# loss:   function(prediction, truth) -> number; cvDefaultLoss(task) if NULL.
# CV:     list(type = "MC" or "vfold", V = number of test sets,
#              test_size = fraction of rows per test set for "MC",
#              shuffle = permute rows before "vfold" splitting).
# Returns list(model = prediction function applied to new data,
#              param = parameter of the chosen candidate).
standardCV <- function(data, target, task = NULL, gmodel = NULL, params = NULL,
  loss = NULL, CV = list(type = "MC", V = 10, test_size = 0.2, shuffle = TRUE)
) {
  if (!is.null(task))
    task <- match.arg(task, c("classification", "regression"))
  if (is.character(gmodel))
    gmodel <- match.arg(gmodel, c("knn", "ppr", "rf", "tree"))
  if (is.numeric(params) || is.character(params))
    params <- as.list(params)
  if (is.null(task))
    task <- if (is.numeric(target)) "regression" else "classification"
  if (is.null(loss))
    loss <- cvDefaultLoss(task)

  # Draw all test sets up front so that every candidate sees the same splits.
  n <- nrow(data)
  shuffle_inds <- NULL
  if (CV$type == "vfold" && isTRUE(CV$shuffle))
    shuffle_inds <- sample(n, n)
  list_testinds <- lapply(seq_len(CV$V),
                          function(v) cvTestIndices(n, CV, v, shuffle_inds))

  gmodel <- agghoo::Model$new(data, target, task, gmodel, params)
  best_error <- Inf
  best_p <- integer(0)
  for (p in seq_len(gmodel$nmodels)) {
    error <- 0
    for (testIdx in list_testinds) {
      # drop = FALSE keeps a single-row test set as a 1 x d matrix instead of
      # collapsing it to a vector.
      model_pred <- gmodel$get(data[-testIdx, , drop = FALSE],
                               target[-testIdx], p)
      prediction <- model_pred(data[testIdx, , drop = FALSE])
      error <- error + loss(prediction, target[testIdx])
    }
    if (error < best_error) {
      best_p <- p
      best_error <- error
    } else if (error == best_error) {
      best_p <- c(best_p, p)
    }
  }
  if (length(best_p) == 0)
    stop("no candidate model to evaluate", call. = FALSE)
  # Refit the selected candidate on all the data: each per-fold fit only saw
  # part of it.
  chosen <- best_p[sample(length(best_p), 1)]
  list(model = gmodel$get(data, target, chosen),
       param = gmodel$getParam(chosen))
}

# Compare agghoo with standard cross-validation on one random train/test split
# of df: column t_idx is the target, the remaining columns are predictors.
# 10% of the rows (20% when n < 500, and at least one row) are held out. Both
# agghoo() and standardCV() are fit on the remaining rows and scored on the
# held-out ones: error rate for classification, mean absolute error for
# regression. Extra arguments (...) are forwarded to both methods.
#
# task:    "classification" or "regression"; inferred from the target column
#          when NULL.
# rseed:   seed passed to set.seed() when >= 0.
# verbose: print the selected parameters and both test errors.
# Returns invisible(c(agghoo test error, standard CV test error)).
compareToCV <- function(df, t_idx, task=NULL, rseed=-1, verbose=TRUE, ...) {
  if (rseed >= 0)
    set.seed(rseed)
  if (is.null(task)) {
    task <- if (is.numeric(df[, t_idx])) "regression" else "classification"
  } else {
    # Normalize abbreviations so the error metric below matches the task.
    task <- match.arg(task, c("classification", "regression"))
  }
  n <- nrow(df)
  # Keep at least one test row: df[-integer(0), ] selects no rows, which would
  # leave an empty training set.
  test_indices <- sample(n, max(1, round(n / if (n >= 500) 10 else 5)))
  # drop = FALSE keeps single-row / single-column selections two-dimensional.
  data <- as.matrix(df[-test_indices, -t_idx, drop = FALSE])
  target <- df[-test_indices, t_idx]
  test <- as.matrix(df[test_indices, -t_idx, drop = FALSE])
  truth <- df[test_indices, t_idx]
  test_error <- function(pred) {
    if (task == "classification")
      mean(pred != truth)
    else
      mean(abs(pred - truth))
  }

  a <- agghoo(data, target, task, ...)
  a$fit()
  if (verbose) {
    print("Parameters:")
    print(unlist(a$getParams()))
  }
  err_a <- test_error(a$predict(test))
  if (verbose)
    print(paste("error agghoo:", err_a))

  # Compare with standard cross-validation:
  s <- standardCV(data, target, task, ...)
  if (verbose)
    print(paste("Parameter:", s$param))
  err_s <- test_error(s$model(test))
  if (verbose)
    print(paste("error CV:", err_s))
  invisible(c(err_a, err_s))
}

library(parallel)

# Average compareToCV() over N random splits, using seeds 1..N.
# nc: number of worker processes. When NA it defaults to detectCores(); it
#     falls back to 1 (sequential lapply) when the core count is unknown or on
#     Windows, where forking via mclapply is not supported.
# Extra arguments (...) are forwarded to compareToCV().
# Returns c(mean agghoo test error, mean standard CV test error).
compareMulti <- function(df, t_idx, task = NULL, N = 100, nc = NA, ...) {
  if (is.na(nc))
    nc <- parallel::detectCores()
  if (is.na(nc) || .Platform$OS.type == "windows")
    nc <- 1
  compareOne <- function(n) {
    print(n)
    compareToCV(df, t_idx, task, rseed = n, verbose = FALSE, ...)
  }
  errors <- if (nc >= 2) {
    parallel::mclapply(seq_len(N), compareOne, mc.cores = nc)
  } else {
    lapply(seq_len(N), compareOne)
  }
  # mclapply reports a failed run as a "try-error" (or NULL for a killed
  # child) instead of raising: surface it rather than failing inside Reduce().
  failed <- which(vapply(errors,
                         function(e) is.null(e) || inherits(e, "try-error"),
                         logical(1)))
  if (length(failed) > 0)
    stop(length(failed), " of ", N, " runs failed; first failure (run ",
         failed[1], "): ", paste(errors[[failed[1]]], collapse = ""),
         call. = FALSE)
  print("error agghoo vs. cross-validation:")
  Reduce("+", errors) / N
}