建议使用tensorflow数据集作为输入管道,可以如下设置:
# Specify dataset
dataset = tf.data.Dataset.from_tensor_slices((features, labels))
# Suffle
dataset = dataset.shuffle(buffer_size=1e5)
# Specify batch size
dataset = dataset.batch(128)
# Create an iterator
iterator = dataset.make_one_shot_iterator()
# Get next batch
next_batch = iterator.get_next()
我应该能够获得批处理大小(无论是从数据集本身还是从它创建的迭代器,即
iterator
和next_batch
)。也许有人想知道数据集或其迭代器中有多少批。还是调用了多少个批处理,并且迭代器中还剩下多少个批处理?可能还需要一次获取特定元素,甚至是整个数据集。我无法在tensorflow文档中找到任何东西。这可能吗?如果没有,有人知道这是否在Tensorflow GitHub上被要求作为问题吗?
最佳答案
试试这个
import tensorflow as tf
import numpy as np
features=np.array([[3.0, 0.0], [1.0, 2.0], [0.0, 0.0]], dtype="float32")
labels=np.array([[0], [0], [1]], dtype="float32")
dataset = tf.data.Dataset.from_tensor_slices((features, labels))
batch_size = 2
dataset = dataset.batch(batch_size)
iterator = dataset.make_initializable_iterator()
batch_data = iterator.get_next()
with tf.Session() as sess:
sess.run(iterator.initializer)
print(np.shape(sess.run(batch_data)[0])[0])
你会看到
关于tensorflow - 如何从 tensorflow 数据集中获取批量大小?,我们在Stack Overflow上找到一个类似的问题:https://stackoverflow.com/questions/49912441/