我正在Keras中使用imageDataGenerator(),我想获取整个测试数据的标签。
目前,我正在使用以下代码来完成此任务:

test_batches = ImageDataGenerator().flow_from_directory(...)

test_labels = []

for i in range(0,3):
    test_labels.extend(np.array(test_batches[i][1]))

不过,这段代码只起作用,因为我知道我总共有150个图像,而我的批大小被定义为50。
此外,使用:
imgs, labels = next(test_batches)

正如在这个主题的类似文章中所建议的,只返回一个批次的标签,而不是整个数据集的标签因此,我想知道是否有比我上面使用的方法更有效的方法来做这件事。

最佳答案

当您知道batch_size时,您可以从flow_from_directory对象获取图像的数量:

test_batches = ImageDataGenerator().flow_from_directory(.., batch_size=n)
number_of_examples = len(test_batches.filenames)
number_of_generator_calls = math.ceil(number_of_examples / (1.0 * n))
# 1.0 above is to skip integer division

test_labels = []

for i in range(0,int(number_of_generator_calls)):
    test_labels.extend(np.array(test_batches[i][1]))

08-20 01:34