我想看看我是否能用tensorflow而不是pymc3解决this问题。实验的想法是,我要定义一个包含开关点的问题系统。我可以用抽样作为推断的方法,但我开始想为什么我不能用梯度下降来代替。
我决定在tensorflow中进行梯度搜索,但是当涉及到tf.where时,tensorflow似乎很难执行梯度搜索。
你可以在下面找到代码。

import tensorflow as tf
import numpy as np

x1 = np.random.randn(50)+1
x2 = np.random.randn(50)*2 + 5
x_all = np.hstack([x1, x2])
len_x = len(x_all)
time_all = np.arange(1, len_x + 1)

mu1 = tf.Variable(0, name="mu1", dtype=tf.float32)
mu2 = tf.Variable(5, name = "mu2", dtype=tf.float32)
sigma1 = tf.Variable(2, name = "sigma1", dtype=tf.float32)
sigma2 = tf.Variable(2, name = "sigma2", dtype=tf.float32)
tau = tf.Variable(10, name = "tau", dtype=tf.float32)

mu = tf.where(time_all < tau,
              tf.ones(shape=(len_x,), dtype=tf.float32) * mu1,
              tf.ones(shape=(len_x,), dtype=tf.float32) * mu2)
sigma = tf.where(time_all < tau,
              tf.ones(shape=(len_x,), dtype=tf.float32) * sigma1,
              tf.ones(shape=(len_x,), dtype=tf.float32) * sigma2)

likelihood_arr = tf.log(tf.sqrt(1/(2*np.pi*tf.pow(sigma, 2)))) -tf.pow(x_all - mu, 2)/(2*tf.pow(sigma, 2))
total_likelihood = tf.reduce_sum(likelihood_arr, name="total_likelihood")

optimizer = tf.train.RMSPropOptimizer(0.01)
opt_task = optimizer.minimize(-total_likelihood)
init = tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init)
    print("these variables should be trainable: {}".format([_.name for _ in tf.trainable_variables()]))
    for step in range(10000):
        _lik, _ = sess.run([total_likelihood, opt_task])
        if step % 1000 == 0:
            variables = {_.name:_.eval() for _ in [mu1, mu2, sigma1, sigma2, tau]}
            print("step: {}, values: {}".format(str(step).zfill(4), variables))

你会注意到tau参数并没有改变,即使tensorflow似乎知道这个变量和它的梯度。有什么不对劲的线索吗?这是可以用tensorflow计算的,还是需要一个不同的模式?

最佳答案

tau仅用于condition参数中的where:(tf.where(time_all < tau, ...),它是布尔张量。由于计算梯度只对连续值有意义,因此相对于tau的输出梯度将为零。
即使忽略tf.where,您也在表达式tau中使用了time_all < tau,它几乎处处都是常数,因此梯度为零。
由于梯度为零,因此无法使用梯度下降法学习tau
根据您的问题,您可以使用加权和来代替p*val1 + (1-p)*val2,而不是在两个值之间进行硬切换,其中p以连续方式依赖于tau

关于python - tf.where使 tensorflow 中的优化器失败,我们在Stack Overflow上找到一个类似的问题:https://stackoverflow.com/questions/43544007/

10-10 17:16