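I would like to know whether it is possible to perform a "Stratified GroupShuffleSplit" in scikit-learn, in other words, a combination of GroupShuffleSplit and StratifiedShuffleSplit.
Below is an example of the code I am using: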
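from sklearn.model_selection import GroupShuffleSplit, GridSearchCV
from sklearn.svm import SVC

# group-aware split: all samples sharing a group id stay on the same side
cv = GroupShuffleSplit(n_splits=n_splits, test_size=test_size,
                       train_size=train_size, random_state=random_state).split(
    allr_sets_nor[:, :2], allr_labels, groups=allr_groups)
opt = GridSearchCV(SVC(decision_function_shape=dfs, tol=tol),
                   param_grid=param_grid, scoring=scoring, n_jobs=n_jobs,
                   cv=cv, verbose=verbose)
opt.fit(allr_sets_nor[:, :2], allr_labels)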
Here I applied GroupShuffleSplit, but I would still like to add stratification based on allr_labels.
Best answer:
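My approach was to apply StratifiedShuffleSplit to the groups, then manually look up the training and test set indices, since they are tied to the group indices (in my case, each group consists of 6 consecutive sets, from 6*index to 6*index+5), as shown below: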
import numpy as np
from sklearn.model_selection import StratifiedShuffleSplit, GridSearchCV
from sklearn.svm import SVC

# stratified splitting for groups only: all_groups holds one entry per group
# (in group order) and all_labels holds one label per group
sss = StratifiedShuffleSplit(n_splits=n_splits, test_size=test_size,
                             train_size=train_size, random_state=random_state)
cv = []
for train_index, test_index in sss.split(all_groups, all_labels):
    # group g covers the 6 consecutive rows 6*g ... 6*g+5, so expand the
    # group indices into the corresponding training and testing row indices
    # (k is used here so it cannot be confused with a split counter)
    train_is = np.concatenate([train_index * 6 + k for k in range(6)])
    test_is = np.concatenate([test_index * 6 + k for k in range(6)])
    cv.append((train_is, test_is))
# cv is the final cross-validation iterable: a list of n_splits tuples,
# each holding the training and testing index arrays respectively
opt = GridSearchCV(SVC(decision_function_shape=dfs, tol=tol),
                   param_grid=param_grid, scoring=scoring, n_jobs=n_jobs,
                   cv=cv, verbose=verbose)
opt.fit(allr_sets_nor[:, :2], allr_labels)
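The index arithmetic above relies on every group occupying exactly 6 consecutive rows, with group ids 0, 1, 2, ... in order. A minimal sketch that drops that assumption, assuming only that all samples of a group share one label: it builds one label per group with np.unique, stratifies the groups with StratifiedShuffleSplit, and maps the selected groups back to row indices with np.isin. The function name stratified_group_shuffle_split and the toy data are hypothetical, for illustration only.

```python
import numpy as np
from sklearn.model_selection import StratifiedShuffleSplit


def stratified_group_shuffle_split(labels, groups, n_splits=5, test_size=0.2,
                                   random_state=None):
    """Hypothetical helper: yield (train_idx, test_idx) sample indices.

    Groups are never split between train and test, and the split is
    stratified on one label per group. Assumes all samples of a group
    share the same label.
    """
    labels = np.asarray(labels)
    groups = np.asarray(groups)
    # one entry per group: its id and the label of its first sample
    group_ids, first_idx = np.unique(groups, return_index=True)
    group_labels = labels[first_idx]
    sss = StratifiedShuffleSplit(n_splits=n_splits, test_size=test_size,
                                 random_state=random_state)
    # split the groups themselves, stratified on the group labels
    for g_train, g_test in sss.split(group_ids.reshape(-1, 1), group_labels):
        # translate the chosen groups back into sample-level row indices
        train_idx = np.flatnonzero(np.isin(groups, group_ids[g_train]))
        test_idx = np.flatnonzero(np.isin(groups, group_ids[g_test]))
        yield train_idx, test_idx


# Toy data (hypothetical): 20 groups of 6 rows, alternating labels 0 and 1
groups = np.repeat(np.arange(20), 6)
labels = np.repeat(np.arange(20) % 2, 6)
cv = list(stratified_group_shuffle_split(labels, groups, n_splits=3,
                                         test_size=0.2, random_state=0))
for train_idx, test_idx in cv:
    print(len(train_idx), len(test_idx), np.bincount(labels[test_idx]))
```

The resulting list can be passed as cv= to GridSearchCV in the same way as the cv list above. Since scikit-learn 1.0 there is also a built-in StratifiedGroupKFold, which stratifies while keeping groups disjoint, but it produces k folds rather than random shuffle splits.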