我正在尝试扩充MNIST数据集。这就是我尝试过的。无法获得任何成功。
from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf
X = mnist.train.images
y = mnist.train.labels
def flip_images(X_imgs):
X_flip = []
tf.reset_default_graph()
X = tf.placeholder(tf.float32, shape = (28, 28, 1))
input_d = tf.reshape(X_imgs, [-1, 28, 28, 1])
tf_img1 = tf.image.flip_left_right(X)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
for img in input_d:
flipped_imgs = sess.run([tf_img1], feed_dict = {X: img})
X_flip.extend(flipped_imgs)
X_flip = np.array(X_flip, dtype = np.float32)
return X_flip
flip = flip_images(X)
我究竟做错了什么?我似乎不知道。
错误:
Line: for img in input_d:
raise TypeError("'Tensor' object is not iterable.")
TypeError: 'Tensor' object is not iterable
最佳答案
首先,请注意您的tf.reshape将类型从ndarray更改为张量。将需要一个.eval()调用来将其恢复原状。在该for循环中,您尝试遍历张量(而不是列表或真正的可迭代),请考虑按数字索引,如下所示:
X = mnist.train.images
y = mnist.train.labels
def flip_images(X_imgs):
X_flip = []
tf.reset_default_graph()
X = tf.placeholder(tf.float32, shape = (28, 28, 1))
input_d = tf.reshape(X_imgs, [-1, 28, 28, 1])
tf_img1 = tf.image.flip_left_right(X)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
for img_ind in range(input_d.shape[0]):
img = input_d[img_ind].eval()
flipped_imgs = sess.run([tf_img1], feed_dict={X: img})
X_flip.extend(flipped_imgs)
X_flip = np.array(X_flip, dtype = np.float32)
return X_flip
flip = flip_images(X)
让我知道这是否可以解决您的问题!可能希望将范围设置为较小的常数以进行测试,如果您周围没有GPU,则可能需要一段时间。