我的情况是这样的:
我有一个训练 tensorflow 模型的脚本。在此脚本中,我实例化了一个提供培训数据的类。该类的初始化又实例化了另一个称为“图像”的类,以进行各种数据增强操作以及不进行数据增强的操作。
main script -> instantiates data_feed class -> instantiates image class
我的问题是我正在尝试使用tensorflow通过传递 session 本身或图形来在此图像类中执行一些操作。但是我收效甚微。
可行的方法(但太慢了)
我现在所拥有的,但是工作很慢,却是这样的(简化的):
class image(object):
def __init__(self, im):
self.im = im
def augment(self):
aux_im = tf.image.random_saturation(self.im, 0.6)
sess = tf.Session(graph=aux_im.graph)
self.im = sess.run(aux_im)
class data_feed(object):
def __init__(self, data_dir):
self.images = load_data(data_dir)
def process_data(self):
for im in self.images:
image = image(im)
image.augment()
if __name__ == "__main__":
# initialize everything tensorflow related here, including model
sess = tf.Session()
# next load the data
data_feed = data_feed(TRAIN_DATA_DIR)
train_data = data_feed.process_data()
此方法有效,但会为每个图像创建一个新的 session :
I tensorflow/core/common_runtime/gpu/gpu_device.cc:975] Creating TensorFlow device (/gpu:0) -> (device: 0, name: GeForce GTX 1070, pci bus id: 0000:01:00.0)
I tensorflow/core/common_runtime/gpu/gpu_device.cc:975] Creating TensorFlow device (/gpu:0) -> (device: 0, name: GeForce GTX 1070, pci bus id: 0000:01:00.0)
I tensorflow/core/common_runtime/gpu/gpu_device.cc:975] Creating TensorFlow device (/gpu:0) -> (device: 0, name: GeForce GTX 1070, pci bus id: 0000:01:00.0)
I tensorflow/core/common_runtime/gpu/gpu_device.cc:975] Creating TensorFlow device (/gpu:0) -> (device: 0, name: GeForce GTX 1070, pci bus id: 0000:01:00.0)
etc ...
无效的方法(应该已经快了很多)
例如,什么是行不通的,而我不知道为什么,是通过主脚本传递图或 session ,如下所示:
class image(object):
def __init__(self, im):
self.im = im
def augment(self, tf_sess):
with tf_sess.as_default():
aux_im = tf.image.random_saturation(self.im, 0.6)
self.im = tf_sess.run(aux_im)
class data_feed(object):
def __init__(self, data_dir, tf_sess):
self.images = load_data(data_dir)
self.tf_sess = tf_sess
def process_data(self):
for im in self.images:
image = image(im)
image.augment(self.tf_sess)
if __name__ == "__main__":
# initialize everything tensorflow related here, including model
sess = tf.Session()
# next load the data
data_feed = data_feed(TRAIN_DATA_DIR, sess)
train_data = data_feed.process_data()
这是我得到的错误:
Traceback (most recent call last):
File "/usr/lib/python2.7/threading.py", line 801, in __bootstrap_inner
self.run()
File "/usr/lib/python2.7/threading.py", line 754, in run
self.__target(*self.__args, **self.__kwargs)
File "/usr/local/lib/python2.7/dist-packages/keras/engine/training.py", line 409, in data_generator_task
generator_output = next(generator)
File "/home/mathetes/Dropbox/ML/load_gluc_data.py", line 198, in generate
yield self.next_batch()
File "/home/mathetes/Dropbox/ML/load_gluc_data.py", line 192, in next_batch
X, y, l = self.process_image(json_im, X, y, l)
File "/home/mathetes/Dropbox/ML/load_gluc_data.py", line 131, in process_image
im.augment_with_tf(self.tf_sess)
File "/home/mathetes/Dropbox/ML/load_gluc_data.py", line 85, in augment_with_tf
self.im = sess.run(saturation, {im_placeholder: np.asarray(self.im)})
File "/home/mathetes/.local/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 766, in run
run_metadata_ptr)
File "/home/mathetes/.local/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 921, in _run
+ e.args[0])
TypeError: Cannot interpret feed_dict key as Tensor: Tensor Tensor("Placeholder:0", shape=(96, 96, 3), dtype=float32) is not an element of this graph.
任何帮助将非常感激!
最佳答案
如何创建Image
类,而不是使用ImageAugmenter
类,该类在初始化时会进行 session ,然后使用Tensorflow处理图像?您可以执行以下操作:
import tensorflow as tf
import numpy as np
class ImageAugmenter(object):
def __init__(self, sess):
self.sess = sess
self.im_placeholder = tf.placeholder(tf.float32, shape=[1,784,3])
def augment(self, image):
augment_op = tf.image.random_saturation(self.im_placeholder, 0.6, 0.8)
return self.sess.run(augment_op, {self.im_placeholder: image})
class DataFeed(object):
def __init__(self, data_dir, sess):
self.images = load_data(data_dir)
self.augmenter = ImageAugmenter(sess)
def process_data(self):
processed_images = []
for im in self.images:
processed_images.append(self.augmenter.augment(im))
return processed_images
def load_data(data_dir):
# True method would read images from disk
# This is just a mockup
images = []
images.append(np.random.random([1,784,3]))
images.append(np.random.random([1,784,3]))
return images
if __name__ == "__main__":
TRAIN_DATA_DIR = '/some/dir/'
sess = tf.Session()
data_feed = DataFeed(TRAIN_DATA_DIR, sess)
train_data = data_feed.process_data()
print(train_data)
有了这个,您将不会为每个图像创建一个新的 session ,它将为您提供所需的内容。
注意
sess.run()
的调用方式。我传递给它的提要dict的关键是上面定义的占位符张量。根据您的错误跟踪,您可能正在尝试从未定义sess.run()
或已将其定义为im_placeholder
之外的代码的一部分调用tf.placeholder
。另外,您可以通过更改
ImageAugmenter.augment()
方法以接收上下参数作为tf.image.random_saturation()
方法的输入来进一步改进代码,或者可以使用特定形状初始化ImageAugmenter
而不是对其进行硬编码。关于python - Tensorflow:在方法中使用 session /图形,我们在Stack Overflow上找到一个类似的问题:https://stackoverflow.com/questions/42438170/