我正在尝试使用Tensorflow实现一些自定义GRU单元。我需要堆叠这些单元格,我想从tensorflow.keras.layers.GRU
继承。但是,在查看源代码时,我注意到您只能将units
参数传递给__init__
的GRU
,而RNN
具有作为RNNcell
列表的参数,并利用它堆叠那些调用StackedRNNCells
的单元格。同时,GRU
仅创建一个GRUCell
。
对于我要实施的论文,我实际上需要堆叠GRUCell
。为什么RNN
和GRU
的实现不同?
最佳答案
在搜索这些类的文档以添加链接时,我注意到可能会绊倒您:TensorFlow中有两个(当前是在正式TF 2.0发行版之前)!有一个GRUCell
和tf.nn.rnn_cell.GRUCell
。似乎已弃用tf.keras.layers.GRUCell
中的那个,而Keras之一是您应该使用的那个。
据我所知,tf.nn.rnn_cell
具有与GRUCell
和__call__()
相同的tf.keras.layers.LSTMCell
方法签名,并且它们都继承自tf.keras.layers.SimpleRNNCell
。 Layer
文档对传递给其RNN
参数的对象的__call__()
方法必须执行的操作提出了一些要求,但是我猜测这三个对象都应满足这些要求。您应该能够只使用相同的cell
框架,并向其传递RNN
对象列表,而不是GRUCell
或LSTMCell
。
我现在无法对此进行测试,因此不确定是否将SimpleRNNCell
对象列表或仅GRUCell
对象传递给GRU
,但是我认为其中一个应该起作用。