Reorganize code (unfinished): some functions are not exported yet
[agghoo.git] / R / compareTo.R
# Standard cross-validation: estimate the cross-validated error of every
# candidate model, then refit the best candidate on the whole data.
#
# data, target: training inputs and responses.
# task, gmodel, params: forwarded to agghoo::Model$new(), which enumerates
#   the candidate models (gmodel$nmodels of them).
# loss: function(prediction, truth) giving the error on one fold; fold
#   errors are summed with `+` and must yield a single number.
# CV: cross-validation settings; $type, $shuffle and $V are read here, and
#   the whole object is passed to get_testIndices().
#
# Returns list(model = <prediction function fitted on all data>,
#              param = <its parameters>).
# Ties in CV error (exact equality) are broken uniformly at random.
# Candidates whose CV error is NA are ignored with a warning; an error is
# raised when no candidate has a non-NA error (including nmodels == 0).
standardCV_core <- function(data, target, task, gmodel, params, loss, CV) {
  n <- nrow(data)
  shuffle_inds <- NULL
  if (CV$type == "vfold" && CV$shuffle)
    shuffle_inds <- sample(n, n)
  # The folds are shared by all candidates: compute them once.
  list_testinds <- lapply(seq_len(CV$V), function(v) {
    get_testIndices(n, CV, v, shuffle_inds)
  })
  gmodel <- agghoo::Model$new(data, target, task, gmodel, params)

  # Cross-validated error of candidate p (sum of the per-fold losses).
  cv_error <- function(p) {
    Reduce("+", lapply(list_testinds, function(testIdx) {
      d <- splitTrainTest(data, target, testIdx)
      model_pred <- gmodel$get(d$dataTrain, d$targetTrain, p)
      loss(model_pred(d$dataTest), d$targetTest)
    }))
  }
  errors <- vapply(seq_len(gmodel$nmodels), cv_error, numeric(1))

  if (all(is.na(errors))) {
    stop("standardCV_core: no candidate model has a non-NA CV error",
         call. = FALSE)
  }
  if (anyNA(errors)) {
    warning("standardCV_core: ignoring ", sum(is.na(errors)),
            " candidate(s) with NA CV error", call. = FALSE)
  }
  best_p <- which(errors == min(errors, na.rm = TRUE))
  # Choose one of the tied winners first, then refit only that one on all
  # data (refitting every provisional best would waste full-data fits).
  # sample.int() avoids sample()'s special case for a length-1 vector.
  p <- best_p[sample.int(length(best_p), 1L)]
  list(model = gmodel$get(data, target, p), param = gmodel$getParam(p))
}
34
# Select a model by standard cross-validation on the training split, then
# measure its error on the held-out test split.
#
# dataTrain, targetTrain: data used for model selection (standardCV_core).
# dataTest, targetTest: held-out data, used only to score the final model.
# CV: cross-validation settings forwarded to standardCV_core().
# floss: function(prediction, truth) scoring the test predictions.
# verbose: when TRUE, print the selected parameter and the test error.
# ...: may carry `task`, `gmodel`, `params` and `loss`, which are validated
#   by checkTask(), checkModPar() and checkLoss() respectively.
#
# Returns the test error, invisibly.
standardCV_run <- function(
  dataTrain, dataTest, targetTrain, targetTest, CV, floss, verbose, ...
) {
  opts <- list(...)
  report <- function(...) if (verbose) print(paste(...))

  task <- checkTask(opts$task, targetTrain)
  model_spec <- checkModPar(opts$gmodel, opts$params)
  cv_loss <- checkLoss(opts$loss, task)

  selected <- standardCV_core(
    dataTrain, targetTrain, task,
    model_spec$gmodel, model_spec$params,
    cv_loss, CV
  )
  report("Parameter:", selected$param)

  test_error <- floss(selected$model(dataTest), targetTest)
  report("error CV:", test_error)
  invisible(test_error)
}
52
# Fit agghoo on the training split and measure its error on the test split.
#
# dataTrain, targetTrain: data given to agghoo() and fitted with $fit(CV).
# dataTest, targetTest: held-out data scored with floss().
# verbose: when TRUE, print the aggregated parameters and the test error.
# ...: forwarded unchanged to agghoo().
#
# Returns the test error, invisibly.
agghoo_run <- function(
  dataTrain, dataTest, targetTrain, targetTest, CV, floss, verbose, ...
) {
  aggregator <- agghoo(dataTrain, targetTrain, ...)
  aggregator$fit(CV)

  if (verbose) {
    print("Parameters:")
    fitted_params <- unlist(aggregator$getParams())
    print(fitted_params)
  }

  test_error <- floss(aggregator$predict(dataTest), targetTest)
  if (verbose) {
    print(paste("error agghoo:", test_error))
  }
  invisible(test_error)
}
68
# Compare one or several learning methods on a single random train/test split.
#
# data, target: full dataset; round(n / 10) rows (when n >= 500) or
#   round(n / 5) rows (otherwise) are drawn at random as the test split.
# method_s: a function, or a list of functions, called as
#   f(dataTrain, dataTest, targetTrain, targetTest, CV, floss, verbose, ...)
#   (e.g. standardCV_run or agghoo_run) and returning a test error.
# rseed: when >= 0, passed to set.seed() first so the split is reproducible.
# floss: function(prediction, truth) scoring test predictions; when NULL,
#   misclassification rate for classification, mean absolute error otherwise.
# verbose: forwarded to each method.
# ...: arguments passed to method_s (agghoo, standard CV or else); the `CV`
#   and `task` entries are also read here.
#
# Returns the error(s) invisibly: one value for a single function, or the
# sapply() result over the list of methods.
compareTo <- function(
  data, target, method_s, rseed=-1, floss=NULL, verbose=TRUE, ...
) {
  if (!is.list(method_s) && !is.function(method_s)) {
    stop("compareTo: method_s must be a function or a list of functions",
         call. = FALSE)
  }
  if (rseed >= 0)
    set.seed(rseed)
  n <- nrow(data)
  test_size <- round(n / (if (n >= 500) 10 else 5))
  test_indices <- sample.int(n, test_size)
  d <- splitTrainTest(data, target, test_indices)
  args <- list(...)
  CV <- checkCV(args$CV)

  # Error function used on model outputs (not in the core methods).
  task <- checkTask(args$task, target)
  if (is.null(floss)) {
    # task is a single string: choose the loss once with a scalar if/else.
    floss <- if (task == "classification") {
      function(y1, y2) mean(y1 != y2)
    } else {
      function(y1, y2) mean(abs(y1 - y2))
    }
  }

  # Run (and compare) all methods on the same split:
  runOne <- function(o) {
    o(d$dataTrain, d$dataTest, d$targetTrain, d$targetTest,
      CV, floss, verbose, ...)
  }
  errors <- if (is.list(method_s)) {
    sapply(method_s, runOne)
  } else {
    runOne(method_s)
  }
  invisible(errors)
}
100
# Run compareTo() N times, in parallel when possible, and average the errors.
#
# data, target, method_s, floss: forwarded to compareTo().
# N: number of repetitions; repetition i uses rseed = i, so runs are
#   reproducible regardless of scheduling.
# nc: number of cores; NA means parallel::detectCores(). Runs are sequential
#   when nc < 2, and on Windows, where parallel::mclapply() cannot fork.
# ...: additional args to be passed to method_s (through compareTo()).
#
# Returns Reduce("+", errors) / N, i.e. the element-wise mean error.
compareMulti <- function(
  data, target, method_s, N=100, nc=NA, floss=NULL, ...
) {
  if (is.na(nc))
    nc <- parallel::detectCores()
  # detectCores() itself returns NA when the core count is unknown.
  if (is.na(nc))
    nc <- 1L

  # "One" comparison for each method in method_s (list)
  compareOne <- function(n) {
    print(n)
    compareTo(data, target, method_s, n, floss, verbose=FALSE, ...)
  }

  runs <- seq_len(N)
  use_fork <- nc >= 2 && .Platform$OS.type != "windows"
  errors <- if (use_fork) {
    parallel::mclapply(runs, compareOne, mc.cores = nc)
  } else {
    lapply(runs, compareOne)
  }

  # mclapply() does not stop on a failing job: it stores a "try-error" in
  # that slot, which would otherwise surface as a confusing arithmetic error.
  failed <- vapply(errors, inherits, logical(1), what = "try-error")
  if (any(failed)) {
    stop("compareMulti: ", sum(failed), " run(s) failed; first error: ",
         errors[[which(failed)[1]]], call. = FALSE)
  }
  print("Errors:")
  Reduce("+", errors) / N
}