我最近开始使用scikit Learn sklearn.ensemble.RandomForestClassifier在Python中使用随机森林实现。我在Kaggle上找到了一个示例脚本,用于使用随机森林(见下文)对土地覆被进行分类,以用于磨练自己的技能。我对评估随机森林分类的​​结果感兴趣。

例如,如果我要在R中使用randomForest进行分析,则可以使用varImpPlot()包中的randomForest评估变量的重要性:

require(randomForests)
...
myrf = randomForests(predictors, response)
varImpPlot(myrf)


为了了解错误率和分类的错误矩阵的即用型估计,我只需在解释器中键入“ myrf”即可。

如何使用Python以编程方式评估这些错误指标?

请注意,尽管我不确定如何实际应用这些属性,但我知道文档中有多个潜在有用的属性(例如feature_importances_oob_score_oob_decision_function_)。



样本RF脚本

import pandas as pd
from sklearn import ensemble

if __name__ == "__main__":
  loc_train = "kaggle_forest\\train.csv"
  loc_test = "kaggle_forest\\test.csv"
  loc_submission = "kaggle_forest\\kaggle.forest.submission.csv"

  df_train = pd.read_csv(loc_train)
  df_test = pd.read_csv(loc_test)

  feature_cols = [col for col in df_train.columns if col not in ['Cover_Type','Id']]

  X_train = df_train[feature_cols]
  X_test = df_test[feature_cols]
  y = df_train['Cover_Type']
  test_ids = df_test['Id']

  clf = ensemble.RandomForestClassifier(n_estimators = 500, n_jobs = -1)

  clf.fit(X_train, y)

  with open(loc_submission, "wb") as outfile:
    outfile.write("Id,Cover_Type\n")
    for e, val in enumerate(list(clf.predict(X_test))):
      outfile.write("%s,%s\n"%(test_ids[e],val))

最佳答案

训练后,如果您有测试数据和标签,则可以通过以下方法检查准确性并生成ROC图/ AUC分数:

from sklearn.metrics import classification_report
from sklearn.metrics import roc_curve, auc
import matplotlib.pyplot as plt

# overall accuracy
acc = clf.score(X_test,Y_test)

# get roc/auc info
Y_score = clf.predict_proba(X_test)[:,1]
fpr = dict()
tpr = dict()
fpr, tpr, _ = roc_curve(Y_test, Y_score)

roc_auc = dict()
roc_auc = auc(fpr, tpr)

# make the plot
plt.figure(figsize=(10,10))
plt.plot([0, 1], [0, 1], 'k--')
plt.xlim([-0.05, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.grid(True)
plt.plot(fpr, tpr, label='AUC = {0}'.format(roc_auc))
plt.legend(loc="lower right", shadow=True, fancybox =True)
plt.show()

关于python - 如何评估随机森林分类器的性能?,我们在Stack Overflow上找到一个类似的问题:https://stackoverflow.com/questions/29148355/

10-09 03:06