scikit-learn的DecisionTreeClassifier
支持通过predict_proba()
函数预测每个类的概率。 DecisionTreeRegressor
中不存在:
AttributeError:'DecisionTreeRegressor'对象没有属性'predict_proba'
我的理解是,决策树分类器和回归器之间的底层机制非常相似,主要区别在于,将回归器的预测计算为潜在叶子的手段。因此,我希望有可能提取每个值的概率。
有没有另一种方法可以模拟这种情况,例如通过处理tree structure? DecisionTreeClassifier
的predict_proba
的code不可直接转让。
最佳答案
您可以从树结构中获取该数据:
import sklearn
import numpy as np
import graphviz
from sklearn.tree import DecisionTreeRegressor, DecisionTreeClassifier
from sklearn.datasets import make_regression
# Generate a simple dataset
X, y = make_regression(n_features=2, n_informative=2, random_state=0)
clf = DecisionTreeRegressor(random_state=0, max_depth=2)
clf.fit(X, y)
# Visualize the tree
graphviz.Source(sklearn.tree.export_graphviz(clf)).view()
>>> clf.predict(X[:5])
0 184.005667
1 53.017289
2 184.005667
3 -20.603498
4 -97.414461
如果调用
clf.apply(X)
,您将获得实例所属的节点ID:array([6, 5, 6, 3, 2, 5, 5, 3, 6, ... 5, 5, 6, 3, 2, 2, 5, 2, 2], dtype=int64)
将其与目标变量合并:
df = pd.DataFrame(np.vstack([y, clf.apply(X)]), index=['y','node_id']).T
y node_id
0 190.370562 6.0
1 13.339570 5.0
2 141.772669 6.0
3 -3.069627 3.0
4 -26.062465 2.0
5 54.922541 5.0
6 25.952881 5.0
...
现在,如果您在
node_id
上进行分组,则表示您将获得与clf.predict(X)
相同的值>>> df.groupby('node_id').mean()
y
node_id
2.0 -97.414461
3.0 -20.603498
5.0 53.017289
6.0 184.005667
我们的树中叶子的
value
是哪些:>>> clf.tree_.value[6]
array([[184.00566679]])
要获取新数据集的节点ID,您需要调用
clf.decision_path(X[:5]).toarray()
它显示了这样的数组
array([[1, 0, 0, 0, 1, 0, 1],
[1, 0, 0, 0, 1, 1, 0],
[1, 0, 0, 0, 1, 0, 1],
[1, 1, 0, 1, 0, 0, 0],
[1, 1, 1, 0, 0, 0, 0]], dtype=int64)
您需要获取最后一个非零元素(即叶子)的地方
>>> pd.DataFrame(clf.decision_path(X[:5]).toarray()).apply(lambda x:x.nonzero()[0].max(), axis=1)
0 6
1 5
2 6
3 3
4 2
dtype: int64
因此,如果不是预测平均值,而是想要预测中位数,
>>> pd.DataFrame(clf.decision_path(X[:5]).toarray()).apply(lambda x: x.nonzero()[0].max(
), axis=1).to_frame(name='node_id').join(df.groupby('node_id').median(), on='node_id')['y']
0 181.381106
1 54.053170
2 181.381106
3 -28.591188
4 -93.891889
关于python - 对于DecisionTreeRegressor而言,等效于Forecast_proba,我们在Stack Overflow上找到一个类似的问题:https://stackoverflow.com/questions/53586860/