我对Sklearn的类StratifiedShuffleSplit
如何工作感到困惑。
下面的代码来自Géron的书“ Hands On Machine Learning”(第2章),在那里他进行了分层抽样。
from sklearn.model_selection import StratifiedShuffleSplit
split = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=42)
for train_index, test_index in split.split(housing, housing["income_cat"]):
strat_train_set = housing.loc[train_index]
strat_test_set = housing.loc[test_index]
特别是
split.split
在做什么?谢谢!
最佳答案
split.split()函数返回火车样本和测试样本的索引。它将遍历指定交叉验证的次数,并将每次返回训练和测试样本索引,使用该索引可以通过过滤整个数据集来创建训练和测试数据集。