X-Git-Url: https://git.auder.net/?p=agghoo.git;a=blobdiff_plain;f=R%2FR6_Model.R;h=05cb7d8dd4bd52cb261a4110b56103799a83c1f8;hp=3c84812e57d5703064b5f02dedd5438c18913613;hb=17ea2f13e0c32c107db20677750bd7a98bb7e0f8;hpb=afa676609daba103e43d6d4654560ca4c1c9b38b diff --git a/R/R6_Model.R b/R/R6_Model.R index 3c84812..05cb7d8 100644 --- a/R/R6_Model.R +++ b/R/R6_Model.R @@ -77,10 +77,18 @@ Model <- R6::R6Class("Model", colnames(dataHO) <- paste0("V", 1:ncol(dataHO)) df <- data.frame(cbind(dataHO, target=targetHO)) model <- rpart::rpart(target ~ ., df, method=method, control=list(cp=param)) + if (task == "regression") + type <- "vector" + else { + if (is.null(dim(targetHO))) + type <- "class" + else + type <- "prob" + } function(X) { if (is.null(colnames(X))) colnames(X) <- paste0("V", 1:ncol(X)) - predict(model, as.data.frame(X)) + predict(model, as.data.frame(X), type=type) } } } @@ -139,7 +147,7 @@ Model <- R6::R6Class("Model", p <- ncol(data) # Use caret package to obtain the CV grid of mtry values require(caret) - caret::var_seq(p, classification = (task == "classificaton"), + caret::var_seq(p, classification = (task == "classification"), len = min(10, p-1)) } else if (family == "ppr")