我有连续数据..
X的输入暗淡都相同
X的序列长度在X中不同
我正在使用LSTM,因此我想为每个x
数据(x1
和x2
)调用reset_states。x1
和x2
是独立的数据,因此在x2
之后测试x1
时,我必须重置LSTM的历史记录。
我的代码在这里。我应该使用stateful
选项吗?
# input dimension is two
# but data length is differenct between x1 and y1
x1 = [[1,2],[3,3],[2,1],[2,4]] # x1 length == 4
y1 = [2,3,2,1]
x2 = [[3,2], [2,1]] # x2 length == 2
y2 = [2,4]
input_dim = 2
max_len = 4 # max(len(x1), len(x2)
max_y = 4 # y -> (1,2,3,4)
trainX = [x1, x2]
trainY = [y1, y2]
m = Sequential()
m.add(LSTM(128,
input_shape=(max_len, input_dim),
activation='tanh',
return_sequences=True))
m.add(TimeDistributed(Dense(max_y, activation='softmax')))
m.compile(...)
m.fit(trainX, trainY, nb_epoch=10)
已编辑
我找到了一个有状态的LSTM示例。但是它在每个时期都调用reset_states()。我想做的是调用每个
x
。https://github.com/fchollet/keras/blob/aff40d800891799dc9ed765617fcbfa665349d0d/examples/stateful_lstm.py
最佳答案
您引用的链接将fit
函数与epochs = 1
一起使用。我认为您可以在train_on_batch()
回调中使用fit()
或使用on_batch_end()
函数。这样,您可以在每个x
之后重置状态(通过设置适当的批处理大小)。
关于python - 在keras训练期间如何调用reset_states()?,我们在Stack Overflow上找到一个类似的问题:https://stackoverflow.com/questions/42826456/