I need to iterate through a large number of image files and feed the data to TensorFlow. I created a Dataset backed by a generator function that produces the file path names as strings, and then I transform each string path into image data using map. But it fails because generating string values does not work, as shown below. Is there a fix or workaround for this?
2017-12-07 15:29:05.820708: I tensorflow/core/platform/cpu_feature_guard.cc:137] Your CPU supports instructions that this TensorFlow binary was not compiled to use: SSE4.1 SSE4.2 AVX AVX2 FMA
producing data/miniImagenet/val/n01855672/n0185567200001000.jpg
2017-12-07 15:29:06.009141: W tensorflow/core/framework/op_kernel.cc:1192] Unimplemented: Unsupported object type str
2017-12-07 15:29:06.009215: W tensorflow/core/framework/op_kernel.cc:1192] Unimplemented: Unsupported object type str
	 [[Node: PyFunc = PyFunc[Tin=[DT_INT64], Tout=[DT_STRING], token="pyfunc_1"](arg0)]]
Traceback (most recent call last):
  File "/Users/me/.tox/tf2/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 1323, in _do_call
    return fn(*args)
  File "/Users/me/.tox/tf2/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 1302, in _run_fn
    status, run_metadata)
  File "/Users/me/.tox/tf2/lib/python3.5/site-packages/tensorflow/python/framework/errors_impl.py", line 473, in __exit__
tensorflow.python.framework.errors_impl.UnimplementedError: Unsupported object type str
	 [[Node: PyFunc = PyFunc[Tin=[DT_INT64], Tout=[DT_STRING], token="pyfunc_1"](arg0)]]
	 [[Node: IteratorGetNext = IteratorGetNext[output_shapes=[[?,21168]], output_types=[DT_FLOAT], _device="/job:localhost/replica:0/task:0/device:CPU:0"](OneShotIterator)]]
The test code is shown below. It works correctly with from_tensor_slices, or by first putting the file name list in a tensor and yielding an index from the generator. However, either workaround would exhaust GPU memory.
import tensorflow as tf

if __name__ == "__main__":
    file_names = ['data/miniImagenet/val/n01855672/n0185567200001000.jpg']

    # note: converting the file list to tensor and returning an index from generator works
    # path_to_indexes = {p: i for i, p in enumerate(file_names)}
    # file_names_tensor = tf.convert_to_tensor(file_names)

    def dataset_producer():
        for s in file_names:
            print('producing', s)
            yield s

    dataset = tf.data.Dataset.from_generator(dataset_producer, output_types=(tf.string))

    # note: this would also work
    # dataset = tf.data.Dataset.from_tensor_slices(tf.convert_to_tensor(file_names))

    def read_image(filename):
        # filename = file_names_tensor[filename_index]
        image_file = tf.read_file(filename, name='read_file')
        image = tf.image.decode_jpeg(image_file, channels=3)
        image = tf.reshape(image, [21168])
        image = tf.cast(image, tf.float32) / 255.0
        return image

    dataset = dataset.map(read_image)
    dataset = dataset.batch(2)
    data_iterator = dataset.make_one_shot_iterator()
    images = data_iterator.get_next()
    print('images', images)
    max_value = tf.argmax(images)
    with tf.Session() as session:
        result = session.run(max_value)
This is a bug affecting Python 3.x that was fixed after the TensorFlow 1.4 release. All releases of TensorFlow from 1.5 onwards contain the fix.
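To decide at runtime whether the affected combination is in use, a small check along these lines might help. This is only a sketch: the helper name `needs_bytes_workaround` is hypothetical, and it assumes the version string has the usual "major.minor.patch" shape (such as the value of `tf.__version__`).

```python
import sys


def needs_bytes_workaround(tf_version, python_major=sys.version_info[0]):
    """Return True for the combination described above: Python 3.x with TensorFlow older than 1.5.

    `tf_version` is assumed to look like "1.4.0"; only the first two
    components are inspected.
    """
    major, minor = (int(part) for part in tf_version.split(".")[:2])
    return python_major >= 3 and (major, minor) < (1, 5)


# Possible use (assumes TensorFlow is installed):
#   import tensorflow as tf
#   if needs_bytes_workaround(tf.__version__):
#       ...  # yield bytes instead of str from the generator
```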
If you are using an earlier version, the workaround is to convert the strings to bytes before yielding them from the generator. The following code should work:
def dataset_producer():
    for s in file_names:
        print('producing', s)
        yield s.encode('utf-8')  # Convert `s` to `bytes`.

dataset = tf.data.Dataset.from_generator(dataset_producer, output_types=(tf.string))
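If several generators need the same treatment, the encoding step could be factored into a wrapper instead of editing each generator. The following is a minimal sketch, not part of the original answer; `as_bytes` is a hypothetical helper name, and it assumes the generator yields single `str` (or already-encoded `bytes`) values.

```python
def as_bytes(generator_fn, encoding="utf-8"):
    """Wrap a generator function so that every `str` it yields is encoded to `bytes`.

    Values that are not `str` (for example, values that are already `bytes`)
    are passed through unchanged.
    """
    def wrapped():
        for item in generator_fn():
            yield item.encode(encoding) if isinstance(item, str) else item
    return wrapped


# Possible use with the generator from the question (assumes TensorFlow is installed):
#   dataset = tf.data.Dataset.from_generator(as_bytes(dataset_producer),
#                                            output_types=tf.string)
```

The wrapper returns a new callable rather than a generator object because `from_generator` expects a callable that it can invoke to start iteration.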