Problem description
I need to iterate through a large number of image files and feed the data to TensorFlow. I created a Dataset backed by a generator function that produces the file path names as strings, and then transform each string path to image data using map. But it fails, because generating string values doesn't work, as shown below. Is there a fix or workaround for this?
2017-12-07 15:29:05.820708: I tensorflow/core/platform/cpu_feature_guard.cc:137] Your CPU supports instructions that this TensorFlow binary was not compiled to use: SSE4.1 SSE4.2 AVX AVX2 FMA
producing data/miniImagenet/val/n01855672/n0185567200001000.jpg
2017-12-07 15:29:06.009141: W tensorflow/core/framework/op_kernel.cc:1192] Unimplemented: Unsupported object type str
2017-12-07 15:29:06.009215: W tensorflow/core/framework/op_kernel.cc:1192] Unimplemented: Unsupported object type str
[[Node: PyFunc = PyFunc[Tin=[DT_INT64], Tout=[DT_STRING], token="pyfunc_1"](arg0)]]
Traceback (most recent call last):
  File "/Users/me/.tox/tf2/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 1323, in _do_call
    return fn(*args)
  File "/Users/me/.tox/tf2/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 1302, in _run_fn
    status, run_metadata)
  File "/Users/me/.tox/tf2/lib/python3.5/site-packages/tensorflow/python/framework/errors_impl.py", line 473, in __exit__
    c_api.TF_GetCode(self.status.status))
tensorflow.python.framework.errors_impl.UnimplementedError: Unsupported object type str
	 [[Node: PyFunc = PyFunc[Tin=[DT_INT64], Tout=[DT_STRING], token="pyfunc_1"](arg0)]]
	 [[Node: IteratorGetNext = IteratorGetNext[output_shapes=[[?,21168]], output_types=[DT_FLOAT], _device="/job:localhost/replica:0/task:0/device:CPU:0"](OneShotIterator)]]
The test code is shown below. It works correctly with from_tensor_slices, or if I first put the file name list in a tensor. However, either workaround would exhaust GPU memory.
import tensorflow as tf

if __name__ == "__main__":
    file_names = ['data/miniImagenet/val/n01855672/n0185567200001000.jpg',
                  'data/miniImagenet/val/n01855672/n0185567200001005.jpg']

    # note: converting the file list to tensor and returning an index from generator works
    # path_to_indexes = {p: i for i, p in enumerate(file_names)}
    # file_names_tensor = tf.convert_to_tensor(file_names)

    def dataset_producer():
        for s in file_names:
            print('producing', s)
            yield s

    dataset = tf.data.Dataset.from_generator(dataset_producer, output_types=(tf.string),
                                             output_shapes=(tf.TensorShape([])))
    # note: this would also work
    # dataset = tf.data.Dataset.from_tensor_slices(tf.convert_to_tensor(file_names))

    def read_image(filename):
        # filename = file_names_tensor[filename_index]
        image_file = tf.read_file(filename, name='read_file')
        image = tf.image.decode_jpeg(image_file, channels=3)
        image.set_shape((84, 84, 3))
        image = tf.reshape(image, [21168])
        image = tf.cast(image, tf.float32) / 255.0
        return image

    dataset = dataset.map(read_image)
    dataset = dataset.batch(2)
    data_iterator = dataset.make_one_shot_iterator()
    images = data_iterator.get_next()
    print('images', images)
    max_value = tf.argmax(images)
    with tf.Session() as session:
        result = session.run(max_value)
        print(result)
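The constant 21168 in read_image (and in output_shapes=[[?,21168]] in the error log) is simply 84 × 84 × 3: each decoded 84x84 RGB image is flattened into one vector and scaled to [0, 1]. The following pure-Python sketch mirrors what tf.reshape followed by the cast and division by 255.0 does to a single image, without TensorFlow; the function name flatten_and_normalize is hypothetical and only illustrates the arithmetic.

```python
HEIGHT, WIDTH, CHANNELS = 84, 84, 3  # shape assumed from image.set_shape((84, 84, 3))


def flatten_and_normalize(image):
    """Flatten a HEIGHT x WIDTH x CHANNELS nested list of uint8 values into a
    single list of floats in [0, 1] (hypothetical helper, illustration only).

    Iterating rows, then pixels, then channels gives the same row-major order
    that tf.reshape uses when it flattens a (84, 84, 3) tensor.
    """
    flat = [float(v) / 255.0 for row in image for pixel in row for v in pixel]
    if len(flat) != HEIGHT * WIDTH * CHANNELS:  # 84 * 84 * 3 == 21168
        raise ValueError("expected %d values, got %d"
                         % (HEIGHT * WIDTH * CHANNELS, len(flat)))
    return flat


if __name__ == "__main__":
    # A synthetic image in which every pixel is (255, 0, 51).
    image = [[[255, 0, 51] for _ in range(WIDTH)] for _ in range(HEIGHT)]
    flat = flatten_and_normalize(image)
    print(len(flat), flat[:3])
```

If the decoded JPEG is not actually 84x84, the real pipeline would fail at set_shape, just as this sketch raises on a length mismatch.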
Recommended answer
This is a bug affecting Python 3.x that was fixed after the TensorFlow 1.4 release. All releases of TensorFlow from 1.5 onwards contain the fix.
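If the same script has to run on several installations, one option is to decide at runtime whether the workaround is needed. This is a minimal sketch that encodes only the two facts stated above (the bug affects Python 3.x, and TensorFlow 1.5 onwards contains the fix); the function name needs_bytes_workaround is hypothetical, and in practice you would pass it tf.__version__.

```python
import sys


def needs_bytes_workaround(tf_version, py_major=sys.version_info.major):
    """Return True if generators should yield bytes instead of str.

    Hypothetical helper: assumes the bug affects only Python 3.x and that
    every TensorFlow release from 1.5 onwards contains the fix.
    """
    if py_major < 3:
        # Under Python 2, plain str is already a byte string.
        return False
    # Compare only the numeric major/minor components, e.g. "1.4.0" -> (1, 4).
    major, minor = (int(part) for part in tf_version.split(".")[:2])
    return (major, minor) < (1, 5)


if __name__ == "__main__":
    for version in ("1.4.0", "1.5.0", "2.3.1"):
        print(version, needs_bytes_workaround(version, py_major=3))
```

Since encoding to bytes is harmless on fixed versions too, always encoding is also a reasonable (and simpler) choice.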
If you are using an earlier version, the workaround is to convert the strings to bytes before returning them from the generator. The following code should work:
def dataset_producer():
    for s in file_names:
        print('producing', s)
        yield s.encode('utf-8')  # Convert `s` to `bytes`.

dataset = tf.data.Dataset.from_generator(dataset_producer, output_types=(tf.string),
                                         output_shapes=(tf.TensorShape([])))
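If you would rather not edit every generator by hand, the same str-to-bytes conversion can be applied by a small wrapper around the generator function. This is a sketch, not part of any TensorFlow API: encode_strings is a hypothetical name, and it assumes the generator yields either str or bytes items.

```python
def encode_strings(generator_fn, encoding="utf-8"):
    """Wrap a zero-argument generator function so that every str it yields
    is converted to bytes; other items are passed through unchanged.

    Hypothetical helper illustrating the workaround described above.
    """
    def wrapped():
        for item in generator_fn():
            yield item.encode(encoding) if isinstance(item, str) else item
    return wrapped


if __name__ == "__main__":
    def producer():
        yield 'data/a.jpg'
        yield b'data/b.jpg'  # already bytes, left as-is

    print(list(encode_strings(producer)()))
```

Because from_generator only needs a callable that returns an iterator, the wrapper can be passed in directly, e.g. tf.data.Dataset.from_generator(encode_strings(dataset_producer), output_types=tf.string, output_shapes=tf.TensorShape([])), leaving the original dataset_producer unchanged.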