在使用keras的机器学习教程中,训练机器学习模型的代码就是这种典型的一类代码。

model.fit(X_train,
          Y_train,
          nb_epoch=5,
          batch_size = 128,
          verbose=1,
          validation_split=0.1)


当训练数据X_trainY_train很小时,这似乎很容易。 X_trainY_train是numpy ndarrays。在实际情况下,训练数据可能会达到千兆字节,这可能太大,甚至无法安装到计算机的RAM中。

当训练数据太大时,如何将数据发送到model.fit()

最佳答案

在Keras中有一个简单的解决方案。您可以简单地使用python生成器,将数据延迟加载。如果您有图像,也可以使用ImageDataGenerator。

def generate_data(x, y, batch_size):
    while True:
        batch = []
        for b in range(batch_size):
           batch.append(myDataSlice)

        yield np.array(batch )

model.fit_generator(
generator=generate_data(x, y, batch_size),
steps_per_epoch=num_batches,
validation_data=list_batch_generator(x_val, y_val, batch_size),
validation_steps=num_batches_test)

07-24 09:52
查看更多