projects
/
agghoo.git
/ blobdiff
commit
grep
author
committer
pickaxe
?
search:
re
summary
|
shortlog
|
log
|
commit
|
commitdiff
|
tree
raw
|
inline
| side by side
Fix gmodel == tree for regression
[agghoo.git]
/
R
/
R6_Model.R
diff --git
a/R/R6_Model.R
b/R/R6_Model.R
index
96f892d
..
3c84812
100644
(file)
--- a/
R/R6_Model.R
+++ b/
R/R6_Model.R
@@
-40,7
+40,7
@@
Model <- R6::R6Class("Model",
if (is.null(params))
# Here, gmodel is a string (= its family),
# because a custom model must be given with its parameters.
if (is.null(params))
# Here, gmodel is a string (= its family),
# because a custom model must be given with its parameters.
- params <- as.list(private$getParams(gmodel, data, target))
+ params <- as.list(private$getParams(gmodel, data, target
, task
))
private$params <- params
if (is.character(gmodel))
gmodel <- private$getGmodel(gmodel, task)
private$params <- params
if (is.character(gmodel))
gmodel <- private$getGmodel(gmodel, task)
@@
-115,7
+115,7
@@
Model <- R6::R6Class("Model",
}
},
# Return a default list of parameters, given a gmodel family
}
},
# Return a default list of parameters, given a gmodel family
- getParams = function(family, data, target) {
+ getParams = function(family, data, target
, task
) {
if (family == "tree") {
# Run rpart once to obtain a CV grid for parameter cp
require(rpart)
if (family == "tree") {
# Run rpart once to obtain a CV grid for parameter cp
require(rpart)
@@
-125,13
+125,13
@@
Model <- R6::R6Class("Model",
minsplit = 2,
minbucket = 1,
xval = 0)
minsplit = 2,
minbucket = 1,
xval = 0)
- r <- rpart(target ~ ., df, method="class", control=ctrl)
+ method <- ifelse(task == "classification", "class", "anova")
+ r <- rpart(target ~ ., df, method=method, control=ctrl)
cps <- r$cptable[-1,1]
cps <- r$cptable[-1,1]
- if (length(cps) <= 1
1) {
-
if (length(cps == 0)
)
-
stop("No cross-validation possible: select another model"
)
+ if (length(cps) <= 1
)
+
stop("No cross-validation possible: select another model"
)
+
if (length(cps) <= 11
)
return (cps)
return (cps)
- }
step <- (length(cps) - 1) / 10
cps[unique(round(seq(1, length(cps), step)))]
}
step <- (length(cps) - 1) / 10
cps[unique(round(seq(1, length(cps), step)))]
}