X-Git-Url: https://git.auder.net/?p=agghoo.git;a=blobdiff_plain;f=R%2Futils.R;fp=R%2Futils.R;h=fa3a9dfc43e3e547ee2281841294e80fee1093c3;hp=0000000000000000000000000000000000000000;hb=afa676609daba103e43d6d4654560ca4c1c9b38b;hpb=43a6578d444f388d72755e74c7eed74f3af638ec diff --git a/R/utils.R b/R/utils.R new file mode 100644 index 0000000..fa3a9df --- /dev/null +++ b/R/utils.R @@ -0,0 +1,28 @@ +get_testIndices <- function(n, CV, v, shuffle_inds) { + if (CV$type == "vfold") { + # Slice indices (optionnally shuffled) + first_index = round((v-1) * n / CV$V) + 1 + last_index = round(v * n / CV$V) + test_indices = first_index:last_index + if (!is.null(shuffle_inds)) + test_indices <- shuffle_inds[test_indices] + } + else + # Monte-Carlo cross-validation + test_indices = sample(n, round(n * CV$test_size)) + test_indices +} + +splitTrainTest <- function(data, target, testIdx) { + dataTrain <- data[-testIdx,] + targetTrain <- target[-testIdx] + dataTest <- data[testIdx,] + targetTest <- target[testIdx] + # [HACK] R will cast 1-dim matrices into vectors: + if (!is.matrix(dataTrain) && !is.data.frame(dataTrain)) + dataTrain <- as.matrix(dataTrain) + if (!is.matrix(dataTest) && !is.data.frame(dataTest)) + dataTest <- as.matrix(dataTest) + list(dataTrain=dataTrain, targetTrain=targetTrain, + dataTest=dataTest, targetTest=targetTest) +}