我正在Keras中使用imageDataGenerator(),我想获取整个测试数据的标签。
目前,我正在使用以下代码来完成此任务:
test_batches = ImageDataGenerator().flow_from_directory(...)
test_labels = []
for i in range(0,3):
test_labels.extend(np.array(test_batches[i][1]))
不过,这段代码只起作用,因为我知道我总共有150个图像,而我的批大小被定义为50。
此外,使用:
imgs, labels = next(test_batches)
正如在这个主题的类似文章中所建议的,只返回一个批次的标签,而不是整个数据集的标签因此,我想知道是否有比我上面使用的方法更有效的方法来做这件事。
最佳答案
当您知道batch_size
时,您可以从flow_from_directory
对象获取图像的数量:
test_batches = ImageDataGenerator().flow_from_directory(.., batch_size=n)
number_of_examples = len(test_batches.filenames)
number_of_generator_calls = math.ceil(number_of_examples / (1.0 * n))
# 1.0 above is to skip integer division
test_labels = []
for i in range(0,int(number_of_generator_calls)):
test_labels.extend(np.array(test_batches[i][1]))