如何获得每行rpart
模型的终端节点的ID(或名称)? predict.rpart
只能返回分类树的预测类(数字或因子)或类概率或某种组合(使用type="matrix"
)。
我想做类似的事情:
fit <- rpart(Kyphosis ~ Age + Number + Start, data = kyphosis)
plot(fit) # there are 5 terminal nodes
predict(fit, type = "node_id") # should return IDs of terminal nodes (e.g. 1-5) (does not work)
最佳答案
对于该模型,有4个拆分,产生5个“终端节点”或rpart中使用的术语:<leaf>
。我不明白为什么对任何事情都要有5个预测。这些预测是针对特定情况的,叶子是用于进行这些预测的可变数量拆分的结果。原始数据集中出现在叶子中的行数可能是您想要的,在这种情况下,这些是获取这些数字的方法:
# Row-wise predicted class
fit$where
# counts of cases in leaves of prediction rules
table(fit$where)
3 5 7 8 9
29 12 14 7 19
为了组装适用于特定叶子的
labels(fit)
,您将需要遍历规则树并累积用于生成特定叶子的所有拆分的所有标签。您可能想看一下:?print.rpart
?rpart.object
?text.rpart
?labels.rpart