我有以下遵循目录结构
data/
train/
Cat 1/ ### 5000 pictures
dog001.jpg
...
cat 2/ ### 3000 pictures
cat001.jpg
Cat 3/ ### 50000 pictures
Unicorn.jpg
...
Cat 4/ ### 10000 pictures
Angels.jpg
我正在使用以下代码加载我的图像
datagen = ImageDataGenerator(rescale=1./255)
# automagically retrieve images and their classes for train and validation sets
train_generator = datagen.flow_from_directory(
train_data_dir,
target_size=(img_width, img_height),
batch_size=batch_size,
class_mode="categorical")
由于我的数据不是均匀分布的,所以我的模型不太适合,因此会偏向
Cat 3
,那么我如何加载所有四个类别都一致的火车数据呢? 最佳答案
您有两种方法:
从cat3
删除一些数据,以便可以将数据统一改组
将数据添加到其他类
1非常简单,要添加数据,您可以从其他不经常使用的类中复制数据,或者更好的方法是从现有数据生成新数据
通过处理图像,您可以将一行/行设置为空白,可以旋转图像或移动图像,我使用了像这样的方法来实现28x28图像的效果
import numpy as np
from scipy.ndimage.interpolation import rotate, shift
def rand_jitter(temp, prob=0.5):
np.random.seed(1337) # for reproducibility
if np.random.random() > prob:
temp[np.random.randint(0,28,1), :] = 0
if np.random.random() > prob:
temp[:, np.random.randint(0,28,1)] = 0
if np.random.random() > prob:
temp = shift(temp, shift=(np.random.randint(-3,4,2)))
if np.random.random() > prob:
temp = rotate(temp, angle = np.random.randint(-20,21,1), reshape=False)
return temp
这样一来,您可以为您的网络训练更多的数据,并对其进行概括并使其预测最可靠
关于python - 如何统一分配火车,我们在Stack Overflow上找到一个类似的问题:https://stackoverflow.com/questions/49964049/