我对Sklearn的类StratifiedShuffleSplit如何工作感到困惑。

下面的代码来自Géron的书“ Hands On Machine Learning”(第2章),在那里他进行了分层抽样。

from sklearn.model_selection import StratifiedShuffleSplit

split = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=42)
for train_index, test_index in split.split(housing, housing["income_cat"]):
    strat_train_set = housing.loc[train_index]
    strat_test_set = housing.loc[test_index]


特别是split.split在做什么?

谢谢!

最佳答案

split.split()函数返回火车样本和测试样本的索引。它将遍历指定交叉验证的次数,并将每次返回训练和测试样本索引,使用该索引可以通过过滤整个数据集来创建训练和测试数据集。

08-24 14:03