Commit | Line | Data |
---|---|---|
# Helper for cross-validation: return the test indices of the v-th split.
#
# Arguments:
#   n            number of observations to split.
#   CV           list describing the scheme. CV$type == "vfold" selects
#                V-fold CV with CV$V folds; any other type selects
#                Monte-Carlo CV, drawing round(n * CV$test_size) indices.
#   v            index of the current fold (only used for "vfold"; assumed
#                to lie in 1..CV$V).
#   shuffle_inds optional index vector (presumably a permutation of 1:n);
#                when non-NULL, the contiguous fold slice is mapped through
#                it. Ignored for Monte-Carlo CV.
#
# Returns a vector of test indices (possibly empty).
get_testIndices <- function(n, CV, v, shuffle_inds) {
  if (CV$type == "vfold") {
    # Fold v covers positions round((v-1)*n/V)+1 .. round(v*n/V); adjacent
    # folds share the same rounded boundary, so together they cover 1:n
    # exactly once.
    first_index <- round((v - 1) * n / CV$V) + 1
    last_index <- round(v * n / CV$V)
    # When n < V some folds are empty (first_index > last_index); a plain
    # `first_index:last_index` would count downwards and return indices
    # belonging to neighbouring folds.
    test_indices <- if (first_index <= last_index) {
      first_index:last_index
    } else {
      integer(0)
    }
    # Optionally map the contiguous slice through the shuffled ordering.
    if (!is.null(shuffle_inds)) {
      test_indices <- shuffle_inds[test_indices]
    }
  } else {
    # Monte-Carlo cross-validation: a fresh random draw (consumes the RNG).
    test_indices <- sample(n, round(n * CV$test_size))
  }
  test_indices
}
16 | ||
# Helper which splits data into training and testing parts.
#
# Arguments:
#   data    matrix or data.frame, one row per observation (it is indexed
#           with `[i, ]`, so it must be two-dimensional).
#   target  response vector, presumably aligned with the rows of `data`.
#   testIdx positive row indices of the test observations (may be empty).
#
# Returns a list with dataTrain/targetTrain (rows not in testIdx, in their
# original order) and dataTest/targetTest (rows in testIdx, in testIdx
# order). Both data parts keep the two-dimensional shape of `data`.
splitTrainTest <- function(data, target, testIdx) {
  # Negative indexing with an empty vector (`x[-integer(0)]`) selects
  # nothing, which would silently produce an empty training set.
  if (length(testIdx) > 0) {
    dataTrain <- data[-testIdx, , drop = FALSE]
    targetTrain <- target[-testIdx]
  } else {
    dataTrain <- data
    targetTrain <- target
  }
  # drop = FALSE keeps matrices/data.frames two-dimensional: a single
  # selected row stays 1 x p (as.matrix() on the collapsed vector would
  # have produced p x 1), and a single column stays k x 1.
  dataTest <- data[testIdx, , drop = FALSE]
  targetTest <- target[testIdx]
  list(dataTrain = dataTrain, targetTrain = targetTrain,
       dataTest = dataTest, targetTest = targetTest)
}