I created a DataGenerator(Sequence) class:
class DataGenerator(Sequence):
    def __init__(self, x_set, y_set, batch_size):
        self.x, self.y = x_set, y_set
        self.batch_size = batch_size

    def __len__(self):
        return math.ceil(len(self.x) / self.batch_size)

    def __getitem__(self, idx):
        batch_x = self.x[idx*self.batch_size : (idx + 1)*self.batch_size]
        batch_x = np.array(imread(file_name) for file_name in batch_x)
        batch_x = batch_x * 1./255
        batch_y = self.y[idx*self.batch_size : (idx + 1)*self.batch_size]
        batch_y = np.array(batch_y)
        return batch_x, batch_y
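This DataGenerator is supposed to produce batches from two lists. x_set is a list of image file paths, and y_set is the list of corresponding labels for those images. batch_x is one batch of x_set, read from disk and divided by 255. batch_y is the matching batch of y_set. I then fit the model with this generator: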
model.fit_generator(generator=training_generator,
                    validation_data=validation_generator,
                    steps_per_epoch=num_train_samples // 128,
                    validation_steps=num_val_samples // 128,
                    epochs=5)
This raises the following error:
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-77-e56bae3d1c4b> in <module>()
3 steps_per_epoch = num_train_samples // 128,
4 validation_steps = num_val_samples // 128,
----> 5 epochs = 5)
8 frames
<ipython-input-75-6e4037882cc3> in __getitem__(self, idx)
11 batch_x = self.x[idx*self.batch_size : (idx + 1)*self.batch_size]
12 batch_x = np.array(imread(file_name) for file_name in batch_x)
---> 13 batch_x = batch_x * 1./255
14 batch_y = self.y[idx*self.batch_size : (idx + 1)*self.batch_size]
15 batch_y = np.array(batch_y)
TypeError: unsupported operand type(s) for *: 'generator' and 'float'
How should I change the line batch_x = batch_x * 1./255? Thanks!
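Best answer
Try this. np.array does not iterate over a generator expression; it wraps the generator object itself in a 0-dimensional object array, which is why multiplying by a float fails. Build a list of images first, then convert that list to an array: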
def __getitem__(self, idx):
    batch_x = self.x[idx*self.batch_size : (idx + 1)*self.batch_size]
    batch_x = [imread(file_name) for file_name in batch_x]
    batch_x = np.array(batch_x)
    batch_x = batch_x * 1./255
    batch_y = self.y[idx*self.batch_size : (idx + 1)*self.batch_size]
    batch_y = np.array(batch_y)
    return batch_x, batch_y
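The difference between the two versions can be reproduced without any image files. The sketch below is only an illustration: fake_imread and the file names are hypothetical stand-ins for the real imread call and paths, and each fake image is a tiny array filled with 255.

```python
import numpy as np

# Hypothetical stand-in for imread: returns a small fake "image" array.
def fake_imread(file_name):
    return np.full((2, 2, 3), 255, dtype=np.uint8)

paths = ["a.png", "b.png", "c.png"]  # hypothetical file names

# np.array does not iterate a generator expression; it stores the generator
# object itself inside a 0-dimensional array of dtype object.
wrapped = np.array(fake_imread(p) for p in paths)
wrapped_shape = wrapped.shape
wrapped_dtype = wrapped.dtype

# Multiplying that array by a float tries generator * float and fails.
try:
    wrapped * 1. / 255
    failed = False
except TypeError:
    failed = True

# A list comprehension materializes the images first, so np.array can
# stack them into one numeric array that supports arithmetic.
batch = np.array([fake_imread(p) for p in paths])
scaled = batch * 1. / 255
```

Here wrapped has an empty shape and object dtype, while batch is a regular 4-dimensional array, so scaling works. Note that stacking only yields a numeric array when every image has the same shape.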
Regarding python - DataGenerator TypeError : unsupported operand type(s) for *: 'generator' and 'float', we found a similar question on Stack Overflow: https://stackoverflow.com/questions/62712080/