我有以下遵循目录结构

data/
    train/
        Cat 1/ ### 5000 pictures
            dog001.jpg

            ...
        cat 2/ ### 3000 pictures
            cat001.jpg

       Cat 3/ ### 50000 pictures
            Unicorn.jpg

            ...
        Cat 4/ ### 10000 pictures
            Angels.jpg


我正在使用以下代码加载我的图像

datagen = ImageDataGenerator(rescale=1./255)

# automagically retrieve images and their classes for train and validation sets
train_generator = datagen.flow_from_directory(
        train_data_dir,
        target_size=(img_width, img_height),
        batch_size=batch_size,
        class_mode="categorical")


由于我的数据不是均匀分布的,所以我的模型不太适合,因此会偏向Cat 3,那么我如何加载所有四个类别都一致的火车数据呢?

最佳答案

您有两种方法:


cat3删除一些数据,以便可以将数据统一改组
将数据添加到其他类


1非常简单,要添加数据,您可以从其他不经常使用的类中复制数据,或者更好的方法是从现有数据生成新数据

通过处理图像,您可以将一行/行设置为空白,可以旋转图像或移动图像,我使用了像这样的方法来实现28x28图像的效果

import numpy as np
from scipy.ndimage.interpolation import rotate, shift

def rand_jitter(temp, prob=0.5):
    np.random.seed(1337)  # for reproducibility
    if np.random.random() > prob:
        temp[np.random.randint(0,28,1), :] = 0
    if np.random.random() > prob:
        temp[:, np.random.randint(0,28,1)] = 0
    if np.random.random() > prob:
        temp = shift(temp, shift=(np.random.randint(-3,4,2)))
    if np.random.random() > prob:
        temp = rotate(temp, angle = np.random.randint(-20,21,1), reshape=False)
    return temp


这样一来,您可以为您的网络训练更多的数据,并对其进行概括并使其预测最可靠

关于python - 如何统一分配火车,我们在Stack Overflow上找到一个类似的问题:https://stackoverflow.com/questions/49964049/

10-16 03:14