- function(X) predict(model, X)
- }
- }
- else if (family == "rf") {
- function(dataHO, targetHO, param) {
- require(randomForest)
- if (task == "classification" && !is.factor(targetHO))
- targetHO <- as.factor(targetHO)
- model <- randomForest::randomForest(dataHO, targetHO, mtry=param)
- function(X) predict(model, X)
+ if (task == "regression")
+ type <- "vector"
+ else {
+ if (is.null(dim(targetHO)))
+ type <- "class"
+ else
+ type <- "prob"
+ }
+ function(X) {
+ if (is.null(colnames(X)))
+ colnames(X) <- paste0("V", 1:ncol(X))
+ predict(model, as.data.frame(X), type=type)
+ }