建议使用tensorflow数据集作为输入管道,可以如下设置:

# Specify dataset
dataset  = tf.data.Dataset.from_tensor_slices((features, labels))
# Suffle
dataset  = dataset.shuffle(buffer_size=1e5)
# Specify batch size
dataset  = dataset.batch(128)
# Create an iterator
iterator = dataset.make_one_shot_iterator()
# Get next batch
next_batch = iterator.get_next()

我应该能够获得批处理大小(无论是从数据集本身还是从它创建的迭代器,即iteratornext_batch)。也许有人想知道数据集或其迭代器中有多少批。还是调用了多少个批处理,并且迭代器中还剩下多少个批处理?可能还需要一次获取特定元素,甚至是整个数据集。

我无法在tensorflow文档中找到任何东西。这可能吗?如果没有,有人知道这是否在Tensorflow GitHub上被要求作为问题吗?

最佳答案

试试这个

import tensorflow as tf
import numpy as np

features=np.array([[3.0, 0.0], [1.0, 2.0], [0.0, 0.0]], dtype="float32")
labels=np.array([[0], [0], [1]], dtype="float32")
dataset = tf.data.Dataset.from_tensor_slices((features, labels))

batch_size = 2
dataset = dataset.batch(batch_size)
iterator = dataset.make_initializable_iterator()
batch_data = iterator.get_next()
with tf.Session() as sess:
    sess.run(iterator.initializer)
    print(np.shape(sess.run(batch_data)[0])[0])

你会看到
tensorflow - 如何从 tensorflow 数据集中获取批量大小?-LMLPHP

关于tensorflow - 如何从 tensorflow 数据集中获取批量大小?,我们在Stack Overflow上找到一个类似的问题:https://stackoverflow.com/questions/49912441/

10-14 17:54
查看更多