Problem description
I would like to extract the terminal nodes of the R implementation of random forest. As I understand random forest, you have a sequence of orthogonal trees. When you predict a new observation (in regression), it is passed down all of these trees and the predictions of the individual trees are averaged. If, instead of averaging, I wanted to do something like a linear regression on the corresponding observations, I would need a list of the observations that are "associated" with this new observation. I have gone through the source code but haven't found a way to obtain this. Can anyone help me?
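The averaging described above can be checked directly with the predict.all argument of predict() for randomForest objects, which returns each tree's prediction alongside the forest's. This is a minimal sketch. The training data, the seed and the new point x = 0.5 (new.obs) are made up for illustration:

```r
library(randomForest)

set.seed(1)
train <- data.frame(x = rnorm(100), y = rnorm(100))
rf <- randomForest(y ~ x, data = train, ntree = 10)

## hypothetical new observation
new.obs <- data.frame(x = 0.5)

## predict.all = TRUE returns the forest prediction ($aggregate) and the
## prediction of every single tree ($individual, one column per tree)
p <- predict(rf, newdata = new.obs, predict.all = TRUE)
p$individual

## in regression the forest prediction is the mean of the tree predictions
all.equal(unname(p$aggregate), unname(rowMeans(p$individual)))
```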
Recommended answer
There must be a better way to do this, but here's a workaround:
library(randomForest)
set.seed(713)
## data
my.df <- data.frame(x = rnorm(100), y = rnorm(100))
## forest
rf <- randomForest(y ~ x, data = my.df, ntree = 10, keep.inbag = TRUE)
keep.inbag = TRUE saves, in rf$inbag, which observations are in-bag (i.e. used to fit) for each of the 10 trees in this example.
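To see what keep.inbag stores, the sketch below rebuilds the forest from the answer and inspects rf$inbag (one row per training observation, one column per tree). It assumes the default bootstrap sampling (replace = TRUE), under which an observation can be drawn more than once for the same tree. The entries may therefore be counts rather than 0/1, so selecting in-bag rows with "> 0" is safer than "== 1":

```r
library(randomForest)

set.seed(713)
my.df <- data.frame(x = rnorm(100), y = rnorm(100))
rf <- randomForest(y ~ x, data = my.df, ntree = 10, keep.inbag = TRUE)

## one row per training observation, one column per tree
dim(rf$inbag)

## distribution of the entries: values above 1 would mean an observation
## was drawn several times for the same tree
table(rf$inbag)

## number of distinct in-bag observations for each tree
colSums(rf$inbag > 0)
```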
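## in-bag rows of tree z; an entry of rf$inbag may record how many times an
## observation was drawn, so "> 0" keeps every in-bag observation
predList <- lapply(seq_len(rf$ntree), function(z)
  predict(rf, newdata = my.df[rf$inbag[, z] > 0, ], nodes = TRUE))
nodes = TRUE tracks the terminal node that each observation ends up in, for every tree; the result is stored in the "nodes" attribute of the prediction.
## group the in-bag x values of tree z by the terminal node they fall in
node.list <- lapply(seq_len(rf$ntree), function(z)
  split(x = my.df[rf$inbag[, z] > 0, "x"],
        f = attr(predList[[z]], "nodes")[, z]))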
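An observation's terminal node in a tree does not depend on which other rows are passed to predict(), so the node matrix can also be computed once for the whole training set and then subset per tree. A sketch of that variant, using the same data and forest as the answer (node.list2 is a name introduced here for illustration):

```r
library(randomForest)

set.seed(713)
my.df <- data.frame(x = rnorm(100), y = rnorm(100))
rf <- randomForest(y ~ x, data = my.df, ntree = 10, keep.inbag = TRUE)

## terminal node of every training observation in every tree
## (an n x ntree matrix stored in the "nodes" attribute)
nodes <- attr(predict(rf, newdata = my.df, nodes = TRUE), "nodes")
dim(nodes)

## group the in-bag x values by terminal node, like node.list,
## but with a single predict() call
node.list2 <- lapply(seq_len(rf$ntree), function(z) {
  inbag <- rf$inbag[, z] > 0   # observations used to grow tree z
  split(my.df[inbag, "x"], f = nodes[inbag, z])
})
```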
First three terminal nodes of the first tree:
node.list[[1]][1:3]
$`3`
[1] 2.028358 2.071939
$`7`
[1] 0.8306559
$`9`
[1] 1.660134 1.621299
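The answer stops at the per-tree groupings. To get the observations "associated" with a new observation, as the question asks, one can look up the new observation's terminal node in each tree and collect the in-bag training rows that share it. The sketch below does this and then fits an ordinary lm() on the pooled rows in place of averaging. The point new.obs (x = 0.5) and the names neighbours, idx and assoc are hypothetical, and pooling the rows into lm() is only one reading of the question's idea, not something randomForest provides:

```r
library(randomForest)

set.seed(713)
my.df <- data.frame(x = rnorm(100), y = rnorm(100))
rf <- randomForest(y ~ x, data = my.df, ntree = 10, keep.inbag = TRUE)

## hypothetical new observation
new.obs <- data.frame(x = 0.5)

## terminal nodes of the training data and of new.obs in every tree
train.nodes <- attr(predict(rf, newdata = my.df, nodes = TRUE), "nodes")
new.nodes   <- attr(predict(rf, newdata = new.obs, nodes = TRUE), "nodes")

## for tree z: in-bag training rows in the same terminal node as new.obs
neighbours <- lapply(seq_len(rf$ntree), function(z)
  which(rf$inbag[, z] > 0 & train.nodes[, z] == new.nodes[1, z]))

## pool the rows over all trees (a row can appear once per tree)
idx <- unlist(neighbours)
assoc <- my.df[idx, ]

## instead of averaging, fit a linear regression on the pooled rows
fit <- lm(y ~ x, data = assoc)
predict(fit, newdata = new.obs)
```

Rows that share a node with new.obs in several trees appear several times in assoc, which weights them by how often they co-occur; unique(idx) gives an unweighted set instead.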