Reorganize code (unfinished: some functions are not exported yet)
[agghoo.git] / R / utils.R
Commit Line Data
afa67660
BA
# Return the test-set indices for fold `v` of a cross-validation scheme.
#
# Arguments:
#   n            number of observations to split (presumably >= 1).
#   CV           list describing the scheme. CV$type is read always;
#                CV$V (number of folds) when type is "vfold", otherwise
#                CV$test_size (fraction of observations put in the test set).
#   v            fold number; only used for "vfold" (assumed in 1..CV$V).
#   shuffle_inds optional permutation of 1..n, or NULL. When given, the
#                contiguous fold slice is mapped through it.
#
# Returns the test indices. For "vfold", folds v = 1..CV$V are contiguous
# slices that together partition 1..n. A fold can be empty (integer(0))
# when n < CV$V. Any other CV$type draws round(n * CV$test_size) distinct
# indices at random (Monte-Carlo cross-validation), so the result depends
# on the RNG state.
get_testIndices <- function(n, CV, v, shuffle_inds) {
  if (CV$type == "vfold") {
    # Slice indices (optionally shuffled)
    first_index <- round((v - 1) * n / CV$V) + 1
    last_index <- round(v * n / CV$V)
    # With n < CV$V two consecutive boundaries may round to the same value.
    # `first_index:last_index` would then count *down* (e.g. 2:1) and return
    # indices of neighbouring folds, so produce an empty fold instead.
    test_indices <-
      if (first_index <= last_index) first_index:last_index else integer(0)
    if (!is.null(shuffle_inds)) {
      test_indices <- shuffle_inds[test_indices]
    }
  } else {
    # Monte-Carlo cross-validation
    test_indices <- sample(n, round(n * CV$test_size))
  }
  test_indices
}
15
# Split predictors and responses into training and test parts.
#
# Arguments:
#   data    matrix or data.frame of predictors, one row per observation.
#   target  vector of responses aligned with the rows of `data`.
#   testIdx positive row indices of the test observations (e.g. from
#           get_testIndices()); may be empty.
#
# Returns a list with elements dataTrain, targetTrain, dataTest and
# targetTest. Row subsets keep their two-dimensional shape (drop = FALSE).
# A single selected row stays a 1 x p table, and a one-column input stays a
# one-column table of the same class.
splitTrainTest <- function(data, target, testIdx) {
  # Take the complement explicitly. `data[-testIdx, ]` selects *no* rows when
  # testIdx is empty, because -integer(0) is still integer(0).
  trainIdx <- setdiff(seq_len(nrow(data)), testIdx)
  dataTrain <- data[trainIdx, , drop = FALSE]
  dataTest <- data[testIdx, , drop = FALSE]
  # Anything that is neither a matrix nor a data.frame is still coerced to a
  # plain matrix, as before.
  if (!is.matrix(dataTrain) && !is.data.frame(dataTrain)) {
    dataTrain <- as.matrix(dataTrain)
  }
  if (!is.matrix(dataTest) && !is.data.frame(dataTest)) {
    dataTest <- as.matrix(dataTest)
  }
  list(dataTrain = dataTrain, targetTrain = target[trainIdx],
       dataTest = dataTest, targetTest = target[testIdx])
}