#' CVvoting_core
#'
#' "voting" cross-validation method, added here as an example.
#' On every fold, each candidate parameter index reaching the lowest test
#' loss receives one vote. The index with the most votes is then used to
#' build the returned model on the whole dataset; ties are broken at random.
#' Parameters are described in ?agghoo and ?AgghooCV
#'
#' @return A list with two elements: \code{model}, the result of
#'   \code{gmodel$get(data, target, chosenP)}, and \code{param}, the result
#'   of \code{gmodel$getParam(chosenP)}, where \code{chosenP} is the index
#'   of the most voted parameter.
CVvoting_core <- function(data, target, task, gmodel, params, loss, CV) {
  CV <- checkCV(CV)
  n <- nrow(data)
  shuffle_inds <- NULL
  if (CV$type == "vfold" && CV$shuffle) {
    shuffle_inds <- sample(n, n)
  }
  # Wrap the user-supplied 'gmodel' first: 'nmodels' must be read from the
  # Model object (as in the loop below), not from the raw argument.
  gmodel <- agghoo::Model$new(data, target, task, gmodel, params)
  votes <- rep(0, gmodel$nmodels)
  for (v in seq_len(CV$V)) {
    test_indices <- get_testIndices(n, CV, v, shuffle_inds)
    d <- splitTrainTest(data, target, test_indices)
    # Indices of all parameters achieving the lowest loss on this fold.
    winners <- integer(0)
    lowest_error <- Inf
    for (p in seq_len(gmodel$nmodels)) {
      model_pred <- gmodel$get(d$dataTrain, d$targetTrain, p)
      prediction <- model_pred(d$dataTest)
      error <- loss(prediction, d$targetTest)
      if (error < lowest_error) {
        winners <- p
        lowest_error <- error
      } else if (error == lowest_error) {
        # Exact tie with the current best: every tied parameter gets a vote.
        winners <- c(winners, p)
      }
    }
    votes[winners] <- votes[winners] + 1
  }
  # Choose a param at random in case of ex-aequos.
  # Index through sample.int(): sample(k, 1) with a single index k > 1 would
  # draw from 1:k instead of returning k.
  top <- which(votes == max(votes))
  chosenP <- top[sample.int(length(top), 1L)]
  list(
    model = gmodel$get(data, target, chosenP),
    param = gmodel$getParam(chosenP)
  )
}

#' standardCV_core
#'
#' Cross-validation method, added here as an example.
@@ -12,7 +51,7 @@ standardCV_core <- function(data, target, task, gmodel, params, loss, CV) { list_testinds[[v]] <- get_testIndices(n, CV, v, shuffle_inds) gmodel <- agghoo::Model$new(data, target, task, gmodel, params) best_error <- Inf - best_model <- NULL + best_p <- NULL for (p in seq_len(gmodel$nmodels)) { error <- Reduce('+', lapply(seq_len(CV$V), function(v) { testIdx <- list_testinds[[v]] @@ -22,17 +61,16 @@ standardCV_core <- function(data, target, task, gmodel, params, loss, CV) { loss(prediction, d$targetTest) }) ) if (error <= best_error) { - newModel <- list(model=gmodel$get(data, target, p), - param=gmodel$getParam(p)) if (error == best_error) - best_model[[length(best_model)+1]] <- newModel + best_p[[length(best_p)+1]] <- p else { - best_model <- list(newModel) + best_p <- list(p) best_error <- error } } } - best_model[[ sample(length(best_model), 1) ]] + chosenP <- best_p[[ sample(length(best_p), 1) ]] + list(model=gmodel$get(data, target, chosenP), param=gmodel$getParam(chosenP)) } #' standardCV_run