Reorganize code - unfinished: some functions not exported yet
[agghoo.git] / R / compareTo.R
CommitLineData
# Standard cross-validation: score every candidate parameter on the same set
# of CV folds, keep the one(s) with the smallest summed loss, and return one
# of them (drawn uniformly at random in case of ties).
#
# data, target: inputs and outputs used for selection and for the final fit
# task, gmodel, params: handed unchanged to agghoo::Model$new()
# loss: function(prediction, target) giving the error on one test fold
# CV: list read for its fields 'type', 'V' and (for "vfold") 'shuffle'
#
# Returns a list with 'model' (predictor obtained from the whole data) and
# 'param' (the parameter it corresponds to).
standardCV_core <- function(data, target, task, gmodel, params, loss, CV) {
  n <- nrow(data)

  # A random permutation of the rows is drawn only for shuffled V-fold CV.
  shuffle_inds <- if (CV$type == "vfold" && CV$shuffle) sample(n, n) else NULL

  # Test indices are computed once, up front, and reused for every parameter.
  folds <- seq_len(CV$V)
  list_testinds <- list()
  for (v in folds) {
    list_testinds[[v]] <- get_testIndices(n, CV, v, shuffle_inds)
  }

  gmodel <- agghoo::Model$new(data, target, task, gmodel, params)

  # Loss of parameter number p on fold number v: fit on train, score on test.
  foldLoss <- function(p, v) {
    d <- splitTrainTest(data, target, list_testinds[[v]])
    model_pred <- gmodel$get(d$dataTrain, d$targetTrain, p)
    loss(model_pred(d$dataTest), d$targetTest)
  }

  best_error <- Inf
  best_model <- NULL
  for (p in seq_len(gmodel$nmodels)) {
    error <- Reduce(`+`, lapply(folds, function(v) foldLoss(p, v)))
    if (error > best_error) {
      next
    }
    candidate <- list(
      model = gmodel$get(data, target, p),
      param = gmodel$getParam(p)
    )
    if (error < best_error) {
      # Strictly better than everything seen so far: drop earlier candidates.
      best_model <- list()
      best_error <- error
    }
    best_model[[length(best_model) + 1]] <- candidate
  }

  # Ties are broken at random.
  chosen <- sample(length(best_model), 1)
  best_model[[chosen]]
}
34
# Select a model by standard cross-validation on the training part, then
# compute its error on the test part.
#
# dataTrain, targetTrain: passed to standardCV_core() for model selection
# dataTest, targetTest: used only to compute the returned error
# CV: cross-validation settings, passed to standardCV_core()
# floss: function(prediction, target) computing the test error
# verbose: if TRUE, print the selected parameter and the test error
# ...: read for 'task', 'gmodel', 'params' and 'loss', each of which goes
#      through the matching check*() helper
#
# Returns the test error, invisibly.
standardCV_run <- function(
  dataTrain, dataTest, targetTrain, targetTest, CV, floss, verbose, ...
) {
  args <- list(...)
  task <- checkTask(args$task, targetTrain)
  modPar <- checkModPar(args$gmodel, args$params)
  loss <- checkLoss(args$loss, task)

  selected <- standardCV_core(
    dataTrain, targetTrain, task, modPar$gmodel, modPar$params, loss, CV
  )
  if (verbose) {
    print(paste("Parameter:", selected$param))
  }

  predicted <- selected$model(dataTest)
  err <- floss(predicted, targetTest)
  if (verbose) {
    print(paste("error CV:", err))
  }
  invisible(err)
}
52
# Fit agghoo on the training part, then compute its error on the test part.
#
# dataTrain, targetTrain: passed to agghoo()
# dataTest, targetTest: used only to compute the returned error
# CV: cross-validation settings, passed to the fit() method
# floss: function(prediction, target) computing the test error
# verbose: if TRUE, print the parameters reported by getParams() and the error
# ...: additional arguments forwarded to agghoo()
#
# Returns the test error, invisibly.
agghoo_run <- function(
  dataTrain, dataTest, targetTrain, targetTest, CV, floss, verbose, ...
) {
  a <- agghoo(dataTrain, targetTrain, ...)
  a$fit(CV)
  if (verbose) {
    print("Parameters:")
    print(unlist(a$getParams()))
  }

  err <- floss(a$predict(dataTest), targetTest)
  if (verbose) {
    print(paste("error agghoo:", err))
  }
  invisible(err)
}
68
# Compare the test error of one or several methods on a random train/test
# split of the data.
#
# data, target: full dataset (inputs and outputs)
# method_s: a function, or a list of functions, each called as
#   f(dataTrain, dataTest, targetTrain, targetTest, CV, floss, verbose, ...)
# rseed: when >= 0, seed set before drawing the split
# floss: function(prediction, target) giving the test error; when NULL, the
#   misclassification rate is used for task "classification" and the mean
#   absolute error otherwise
# verbose: forwarded to each method
# ...: arguments passed to method_s (agghoo, standard CV or else). 'CV' and
#   'task' are also read here; 'CV' is checked once and then handed to the
#   methods positionally only.
#
# Returns (invisibly) the error(s) of the method(s); NULL when method_s is
# neither a list nor a function.
compareTo <- function(
  data, target, method_s, rseed=-1, floss=NULL, verbose=TRUE, ...
) {
  if (rseed >= 0)
    set.seed(rseed)
  n <- nrow(data)
  # Hold out one tenth of the rows when n >= 500, one fifth otherwise.
  test_indices <- sample(n, round(n / (if (n >= 500) 10 else 5)))
  d <- splitTrainTest(data, target, test_indices)

  dots <- list(...)
  CV <- checkCV(dots$CV)
  # CV is passed positionally to each method. Forwarding it a second time by
  # name would make R bind the named copy to the 'CV' formal and shift the
  # positional CV, floss and verbose onto the wrong formals.
  forwarded <- dots
  forwarded[names(forwarded) %in% "CV"] <- NULL

  # Set error function to be used on model outputs (not in core method)
  task <- checkTask(dots$task, target)
  if (is.null(floss)) {
    floss <- function(y1, y2) {
      if (task == "classification") mean(y1 != y2) else mean(abs(y1 - y2))
    }
  }

  # Run (and compare) all methods:
  runOne <- function(o) {
    positional <- list(
      d$dataTrain, d$dataTest, d$targetTrain, d$targetTest,
      CV, floss, verbose
    )
    # quote = TRUE: arguments are already values and must not be re-evaluated.
    do.call(o, c(positional, forwarded), quote = TRUE)
  }

  errors <- c()
  if (is.list(method_s)) {
    errors <- sapply(method_s, runOne)
  } else if (is.function(method_s)) {
    errors <- runOne(method_s)
  }
  invisible(errors)
}
100
# Run compareTo N times (in parallel when possible) and average the errors.
#
# data, target, method_s, floss: forwarded to compareTo()
# N: number of repetitions; repetition number n uses n as random seed
# nc: number of cores; NA means "use parallel::detectCores()". Fewer than 2
#   cores, or a Windows host, leads to a sequential run.
# ... : additional args to be passed to method_s
#
# Returns the sum of the N results of compareTo() divided by N.
compareMulti <- function(
  data, target, method_s, N=100, nc=NA, floss=NULL, ...
) {
  if (is.na(nc))
    nc <- parallel::detectCores()
  # detectCores() may itself return NA when the count cannot be determined.
  if (is.na(nc))
    nc <- 1

  # "One" comparison for each method in method_s (list)
  compareOne <- function(n) {
    print(n)
    compareTo(data, target, method_s, n, floss, verbose=FALSE, ...)
  }

  # seq_len() so that N = 0 runs nothing (1:0 would run seeds 1 and 0).
  runs <- seq_len(N)
  # mclapply relies on forking: mc.cores > 1 is an error on Windows.
  use_fork <- nc >= 2 && .Platform$OS.type != "windows"
  errors <- if (use_fork) {
    parallel::mclapply(runs, compareOne, mc.cores = nc)
  } else {
    lapply(runs, compareOne)
  }
  print("Errors:")
  Reduce('+', errors) / N
}