我正在尝试使用cforest function(R,party package)。

这是我构建森林的工作:

library("party")
set.seed(42)
readingSkills.cf <- cforest(score ~ ., data = readingSkills,
                         control = cforest_unbiased(mtry = 2, ntree = 50))


然后我要打印第一棵树,然后执行

party:::prettytree(readingSkills.cf@ensemble[[1]],names(readingSkills.cf@data@get("input")))


结果看起来像这样

     1) shoeSize <= 28.29018; criterion = 1, statistic = 89.711
       2) age <= 6; criterion = 1, statistic = 48.324
    3) age <= 5; criterion = 0.997, statistic = 8.917
      4)*  weights = 0
    3) age > 5
      5)*  weights = 0
  2) age > 6
    6) age <= 7; criterion = 1, statistic = 13.387
      7) shoeSize <= 26.66743; criterion = 0.214, statistic = 0.073
        8)*  weights = 0
      7) shoeSize > 26.66743
        9)*  weights = 0
    6) age > 7
      10)*  weights = 0
1) shoeSize > 28.29018
  11) age <= 9; criterion = 1, statistic = 36.836
    12) nativeSpeaker == {}; criterion = 0.998, statistic = 9.347
      13)*  weights = 0
    12) nativeSpeaker == {}
      14)*  weights = 0
  11) age > 9
    15) nativeSpeaker == {}; criterion = 1, statistic = 19.124
      16) age <= 10; criterion = 1, statistic = 18.441
        17)*  weights = 0
      16) age > 10
        18)*  weights = 0
    15) nativeSpeaker == {}
      19)*  weights = 0


为什么它是空的(每个节点的权重等于零)?

最佳答案

简短答案:每个节点中的案例权重weightsNULL,即不存储。 prettytree函数输出weights = 0,因为sum(NULL)在R中等于0。



考虑下面的ctree示例:

library("party")
x <- ctree(Species ~ ., data=iris)
plot(x, type="simple")




对于生成的对象x(类BinaryTree),案例权重存储在每个节点中:

R> sum(x@tree$left$weights)
[1] 50
R> sum(x@tree$right$weights)
[1] 100
R> sum(x@tree$right$left$weights)
[1] 54
R> sum(x@tree$right$right$weights)
[1] 46


现在,让我们仔细看看cforest

y <- cforest(Species ~ ., data=iris, control=cforest_control(mtry=2))
tr <- party:::prettytree(y@ensemble[[1]], names(y@data@get("input")))
plot(new("BinaryTree", tree=tr, data=y@data, responses=y@responses))




案例权重未存储在树集合中,可以通过以下方式查看:

fixInNamespace("print.TerminalNode", "party")


print方法更改为

function (x, n = 1, ...)·
{
    print(names(x))
    print(x$weights)
    cat(paste(paste(rep(" ", n - 1), collapse = ""), x$nodeID,·
        ")* ", sep = "", collapse = ""), "weights =", sum(x$weights),·
        "\n")
}


现在我们可以观察到在每个节点中weightsNULL

R> tr
1) Petal.Width <= 0.4; criterion = 10.641, statistic = 10.641
 [1] "nodeID"     "weights"    "criterion"  "terminal"   "psplit"
 [6] "ssplits"    "prediction" "left"       "right"      NA
NULL
  2)*  weights = 0
1) Petal.Width > 0.4
  3) Petal.Width <= 1.6; criterion = 8.629, statistic = 8.629
 [1] "nodeID"     "weights"    "criterion"  "terminal"   "psplit"
 [6] "ssplits"    "prediction" "left"       "right"      NA
NULL
    4)*  weights = 0
  3) Petal.Width > 1.6
 [1] "nodeID"     "weights"    "criterion"  "terminal"   "psplit"
 [6] "ssplits"    "prediction" "left"       "right"      NA
NULL
    5)*  weights = 0


更新这是一个hack,以显示案例权重的总和:

update_tree <- function(x) {
  if(!x$terminal) {
    x$left <- update_tree(x$left)
    x$right <- update_tree(x$right)
  } else {
    x$weights <- x[[9]]
    x$weights_ <- x[[9]]
  }
  x
}
tr_weights <- update_tree(tr)
plot(new("BinaryTree", tree=tr_weights, data=y@data, responses=y@responses))

07-26 06:54