我正在尝试使用Tensorflow实现一些自定义GRU单元。我需要堆叠这些单元格,我想从tensorflow.keras.layers.GRU继承。但是,在查看源代码时,我注意到您只能将units参数传递给__init__GRU,而RNN具有作为RNNcell列表的参数,并利用它堆叠那些调用StackedRNNCells的单元格。同时,GRU仅创建一个GRUCell

对于我要实施的论文,我实际上需要堆叠GRUCell。为什么RNNGRU的实现不同?

最佳答案

在搜索这些类的文档以添加链接时,我注意到可能会绊倒您:TensorFlow中有两个(当前是在正式TF 2.0发行版之前)!有一个GRUCelltf.nn.rnn_cell.GRUCell。似乎已弃用tf.keras.layers.GRUCell中的那个,而Keras之一是您应该使用的那个。

据我所知,tf.nn.rnn_cell具有与GRUCell__call__()相同的tf.keras.layers.LSTMCell方法签名,并且它们都继承自tf.keras.layers.SimpleRNNCellLayer文档对传递给其RNN参数的对象的__call__()方法必须执行的操作提出了一些要求,但是我猜测这三个对象都应满足这些要求。您应该能够只使用相同的cell框架,并向其传递RNN对象列表,而不是GRUCellLSTMCell

我现在无法对此进行测试,因此不确定是否将SimpleRNNCell对象列表或仅GRUCell对象传递给GRU,但是我认为其中一个应该起作用。

07-26 02:08