我使用tf.keras构建了一个完全连接的ANN“ my_model”。然后,我试图使用来自TensorFlow的Adam优化器来最小化函数f(x) = my_model.predict(x) - 0.5 + g(x)
。我尝试了以下代码:
x = tf.get_variable('x', initializer = np.array([1.5, 2.6]))
f = my_model.predict(x) - 0.5 + g(x)
optimizer = tf.train.AdamOptimizer(learning_rate=.001).minimize(f)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
for i in range(50):
print(sess.run([x,f]))
sess.run(optimizer)
但是,执行
my_model.predict(x)
时出现以下错误:如果您的数据采用符号张量形式,则应指定
steps
参数(而不是batch_size
参数)我知道错误是什么,但是我无法弄清楚在符号张量存在的情况下如何使
my_model.predict(x)
工作。如果从功能my_model.predict(x)
中删除了f(x)
,则代码将正常运行。我检查了以下link,link,其中使用TensorFlow优化器来最小化任意函数,但我认为我的问题是使用底层keras的
model.predict()
函数。感谢您的帮助。提前致谢! 最佳答案
我找到了答案!
基本上,我正在尝试优化一个函数,该函数涉及经过训练的ANN,而没有输入到ANN的输入变量。因此,我只想知道如何调用my_model
并将其放在f(x)
中。在这里深入Keras文档:https://keras.io/getting-started/functional-api-guide/,我发现所有Keras模型都可以调用,就像模型的层一样!引用链接中的信息,
..您可以将任何模型都视为图层,方法是在
张量。请注意,通过调用模型,您不仅可以重复使用
模型的体系结构,您也在重用它的权重。
同时,model.predict(x)
部分期望x
是numpy数组或评估张量,并且不将tensorflow变量作为输入(https://www.tensorflow.org/api_docs/python/tf/keras/Model#predict)。
因此,以下代码起作用了:
## initializations
sess = tf.InteractiveSession()
x_init_value = np.array([1.5, 2.6])
x_placeholder = tf.placeholder(tf.float32)
x_var = tf.Variable(x_init_value, dtype=tf.float32)
# Check calling my_model
assign_step = tf.assign(x_var, x_placeholder)
sess.run(assign_step, feed_dict={x_placeholder: x_init_value})
model_output = my_model(x_var) # This simple step is all I wanted!
sess.run(model_output) # This outputs my_model's predicted value for input x_init_value
# Now, define the objective function that has to be minimized
f = my_model(x_var) - 0.5 + g(x_var) # g(x_var) is some function of x_var
# Define the optimizer
optimizer = tf.train.AdamOptimizer(learning_rate=.001).minimize(f)
# Run the optimization steps
for i in range(50): # for 50 steps
_,loss = optimizer.minimize(f, var_list=[x_var])
print("step: ", i+1, ", loss: ", loss, ", X: ", x_var.eval()))
关于python - 使用TensorFlow优化器优化涉及tf.keras的“model.predict()”的函数?,我们在Stack Overflow上找到一个类似的问题:https://stackoverflow.com/questions/51515253/