我正在参加 APTOS 2019 Kaggle 比赛,并试图做 5 折集成(5-fold ensemble),但在正确实现 StratifiedKFold 时遇到了问题。

我已经搜索过 fastai 的相关讨论,但没有找到解决方案。
我正在使用fastai库,并且有一个预先训练的模型。

def get_df(base_image_dir=None, seed=None):
    """Load the APTOS 2019 training table and the sample submission.

    Args:
        base_image_dir: Directory that holds ``train.csv``,
            ``sample_submission.csv`` and the ``train_images/`` folder.
            Defaults to ``../input/aptos2019-blindness-detection/``.
        seed: Optional seed forwarded to ``DataFrame.sample`` so the shuffle
            is reproducible. ``None`` keeps the unseeded shuffle.

    Returns:
        A ``(df, test_df)`` tuple. ``df`` is ``train.csv`` with ``id_code``
        replaced by a ``path`` column pointing at
        ``<train_images>/<id_code>.png``, rows shuffled and re-indexed from 0.
        ``test_df`` is ``sample_submission.csv`` as read.
    """
    if base_image_dir is None:
        # The path literals must stay on one line: a string broken across
        # lines is a syntax error (and would embed whitespace in the path).
        base_image_dir = os.path.join(
            '..', 'input/aptos2019-blindness-detection/')
    train_dir = os.path.join(base_image_dir, 'train_images/')

    df = pd.read_csv(os.path.join(base_image_dir, 'train.csv'))
    df['path'] = df['id_code'].map(
        lambda x: os.path.join(train_dir, '{}.png'.format(x)))
    df = df.drop(columns=['id_code'])
    # Shuffle the rows, then rebuild a clean 0..n-1 index.
    df = df.sample(frac=1, random_state=seed).reset_index(drop=True)

    test_df = pd.read_csv(
        os.path.join(base_image_dir, 'sample_submission.csv'))
    return df, test_df

# np.random.seed() returns None, so its result must not be used as
# ``random_state``. Keep the integer seed and pass it explicitly.
random_state = 2019
# Seed the global NumPy RNG before loading, so any unseeded shuffle done
# while building ``df`` is repeatable as well.
np.random.seed(random_state)

df, test_df = get_df()

skf = StratifiedKFold(n_splits=5, random_state=random_state, shuffle=True)

X = df['path']
y = df['diagnosis']

# Keep every (train_index, test_index) pair so each fold can be used later
# (e.g. one databunch / model per fold), not only the last one.
folds = list(skf.split(X, y))

# getting the splits
for train_index, test_index in folds:
    print('##')
    # split() yields positional indices, so select rows by position.
    X_train, X_test = X.iloc[train_index], X.iloc[test_index]
    y_train, y_test = y.iloc[train_index], y.iloc[test_index]
    train = X_train, y_train
    test = X_test, y_test
    train_list = [list(x) for x in train]
    test_list = [list(x) for x in test]


# Build the databunch one stage at a time: image items -> random split with
# 0.2 passed as the validation fraction -> float labels -> transforms ->
# batches -> normalisation with ``imagenet_stats``.
image_items = ImageList.from_df(df=df, path='./', cols='path')
split_items = image_items.split_by_rand_pct(0.2)
labelled_items = split_items.label_from_df(cols='diagnosis', label_cls=FloatList)
transformed_items = labelled_items.transform(
    tfms,
    size=sz,
    resize_method=ResizeMethod.SQUISH,
    padding_mode='zeros',
)
data = transformed_items.databunch(bs=bs, num_workers=4).normalize(imagenet_stats)

# Wrap the model in a Learner scored with ``qk``, then switch to fp16.
learn = Learner(data, md_ef, metrics=[qk], model_dir="models")
learn = learn.to_fp16()

# Attach the competition test images as the databunch's test set.
test_items = ImageList.from_df(
    test_df,
    '../input/aptos2019-blindness-detection',
    folder='test_images',
    suffix='.png',
)
learn.data.add_test(test_items)


我想使用 skf.split 得到的各折(fold)划分来训练模型,但不确定该怎么做。

最佳答案

有两种方法可以做到这一点。


使用 'split_by_idxs',并传入训练集和验证集的索引


    # Split by explicit index arrays instead of a random percentage.
    # train_index / test_index are expected to be one fold's indices
    # (e.g. a pair yielded by skf.split(X, y)).
    fold_items = ImageList.from_df(df=df, path='./', cols='path')
    fold_split = fold_items.split_by_idxs(train_idx=train_index, valid_idx=test_index)
    fold_labelled = fold_split.label_from_df(cols='diagnosis', label_cls=FloatList)
    fold_transformed = fold_labelled.transform(
        tfms,
        size=sz,
        resize_method=ResizeMethod.SQUISH,
        padding_mode='zeros',
    )
    fold_batches = fold_transformed.databunch(bs=bs, num_workers=4)
    data = fold_batches.normalize(imagenet_stats)



使用 'split_by_list'


   # Index the item list with one fold's index arrays, then pass the two
   # resulting item lists to split_by_list as the train / validation sides.
   il = ImageList.from_df(df=df, path='./', cols='path')
   train_items = il[train_index]
   valid_items = il[test_index]

   fold_split = il.split_by_list(train=train_items, valid=valid_items)
   fold_labelled = fold_split.label_from_df(cols='diagnosis', label_cls=FloatList)
   fold_transformed = fold_labelled.transform(
       tfms,
       size=sz,
       resize_method=ResizeMethod.SQUISH,
       padding_mode='zeros',
   )
   fold_batches = fold_transformed.databunch(bs=bs, num_workers=4)
   data = fold_batches.normalize(imagenet_stats)

09-27 08:58