Problem description
As said in the title, I am trying to create a mixture of multivariate normal distributions using tensorflow probability package.
In my original project, I am feeding the weights of the categorical, the loc and the variance from the output of a neural network. However, when building the graph, I get the following error:
components[0] batch shape must be compatible with cat shape and other component batch shapes
I recreated the same problem using placeholders:
```python
import tensorflow as tf
import tensorflow_probability as tfp  # dist= tfp.distributions

tf.compat.v1.disable_eager_execution()
sess = tf.compat.v1.InteractiveSession()

l1 = tf.compat.v1.placeholder(dtype=tf.float32, shape=[None, 2], name='observations_1')
l2 = tf.compat.v1.placeholder(dtype=tf.float32, shape=[None, 2], name='observations_2')
log_std = tf.compat.v1.get_variable('log_std', [1, 2], dtype=tf.float32,
                                    initializer=tf.constant_initializer(1.0),
                                    trainable=True)

mix = tf.compat.v1.placeholder(dtype=tf.float32, shape=[None, 1], name='weights')
cat = tfp.distributions.Categorical(probs=[mix, 1. - mix])

components = [
    tfp.distributions.MultivariateNormalDiag(loc=l1, scale_diag=tf.exp(log_std)),
    tfp.distributions.MultivariateNormalDiag(loc=l2, scale_diag=tf.exp(log_std)),
]

bimix_gauss = tfp.distributions.Mixture(
    cat=cat,
    components=components)
```
So, my question is: what am I doing wrong? I looked into the error, and it seems `tensorshape_util.is_compatible_with` is what raises it, but I don't see why.
Thanks!
Recommended answer
It seems you provided a mis-shaped input to `tfp.distributions.Categorical`. Its `probs` parameter should have shape `[batch_size, cat_size]`, while the one you provide is rather `[cat_size, batch_size, 1]`: passing the Python list `[mix, 1.-mix]` stacks the two `[batch_size, 1]` tensors along a new leading axis. So maybe try to parametrize `probs` with `tf.concat([mix, 1-mix], 1)`.
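The shape difference can be checked without TensorFlow. The sketch below uses NumPy as a stand-in, assuming that `np.stack` mirrors how TensorFlow packs a Python list of tensors and `np.concatenate` mirrors `tf.concat`. The helper `categorical_shapes` is hypothetical; it only applies the rule that a categorical distribution reads its batch shape from every axis of `probs` except the last one, and its number of categories from the last axis. `batch_size = 4` is arbitrary.

```python
import numpy as np

batch_size = 4
# Stand-in for the `mix` placeholder: one mixing weight per example, shape [batch, 1].
mix = np.random.uniform(size=(batch_size, 1)).astype(np.float32)

# A Python list of two [batch, 1] arrays packed along a new leading axis -> [2, batch, 1].
stacked = np.stack([mix, 1.0 - mix])
# Concatenating along axis 1 puts the two categories in the last axis -> [batch, 2].
concatenated = np.concatenate([mix, 1.0 - mix], axis=1)


def categorical_shapes(probs):
    """Hypothetical helper: (batch_shape, num_categories) as implied by `probs`."""
    return probs.shape[:-1], probs.shape[-1]


print(categorical_shapes(stacked))       # ((2, 4), 1)
print(categorical_shapes(concatenated))  # ((4,), 2)
```

With the stacked version the categorical has batch shape `(2, 4)` and a single category, which cannot match the components' batch shape `(4,)` (taken from `loc` of shape `[batch, 2]`) or their count of two. The concatenated version gives batch shape `(4,)` and two categories, as `Mixture` expects.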
There may also be a problem with your `log_std`, which doesn't have the same shape as `l1` and `l2`. In case `MultivariateNormalDiag` doesn't broadcast it properly, try to specify its shape as `(None, 2)` or tile it so that its first dimension matches that of your location parameters.