我正在与Bach chorales dataset合作。每个合唱的长度约为100-500个时间步长,每个时间步长包含4个整数(例如:[74、70、65、58]),其中每个整数对应于钢琴上的音符索引。
我正在尝试训练可以预测下一个步骤的模型(4
注意),并从合唱开始按时间顺序排列。
问题是什么:对于与模型训练相同大小的输入,我得到正确的输出,但对于不同大小的输入却得到错误的输出。
到目前为止,我已经完成了什么:我使用了Keras的TimeseriesGenerator来生成输入和相应输出的序列:
generator = TimeseriesGenerator(dataX, dataY, length=3, batch_size=1)
print(generator[0])
输出:
(array([[[74, 70, 65, 58],
[74, 70, 65, 58],
[74, 70, 65, 58]]]), array([[75, 70, 58, 55]]))
然后,我训练了一个LSTM模型。我在input_shape中使用
None
来允许大小可变的输入。n_features = 4
model = Sequential()
model.add(LSTM(100, activation='relu', input_shape=(None, n_features), return_sequences=True))
model.add(LSTM(128 , activation = 'relu'))
model.add(Dense(n_features))
model.compile(optimizer='adam', loss='mse')
# fit model
model.fit_generator(generator, epochs=500, validation_data=validation_generator)
我预测大小为3的输入似乎有效(因为它是针对长度为3的输入进行训练的):
# demonstrate prediction
x_input = dataX[5:8]
x_input = x_input.reshape((1, len(x_input), 4))
print(x_input)
yhat = model.predict(x_input, verbose=0)
print(yhat)
print('expected: ', dataY[8])
[[[75 70 58 55]
[75 70 60 55]
[75 70 60 55]]]
[[76.25768 68.525444 59.745518 53.799873]]
expected: [77 69 62 50]
现在,我尝试预测输入长度为5的不同大小的输入,这不起作用。
测试样品的输出:
# demonstrate prediction
x_input = dataX[1:6]
x_input = x_input.reshape((1, len(x_input), 4))
print(x_input)
yhat = model.predict(x_input, verbose=0)
print(yhat)
print('expected: ', dataY[6])
[[[74 70 65 58]
[74 70 65 58]
[74 70 65 58]
[75 70 58 55]
[75 70 58 55]]]
[[227.16667 217.89767 213.62988 148.44817]]
expected: [75 70 60 55]
该预测是完全错误的,似乎正在做一些总结。任何关于为什么会发生这种情况以及如何解决它的输入/帮助都将受到高度赞赏。
最佳答案
我可以为您提供三个模型无法学习的可能原因。
最后一层model.add(Dense(n_features))
这可能是您模型中的主要罪魁祸首(但我建议全部解决)。分类模型的最后一层需要是softmax
层。所以只需将其更改为model.add(Dense(n_features, activation='softmax`))
损失函数
通常,对于分类问题,crossentropy
比mse
更有效。所以尝试一下model.compile(optimizer='adam', loss='categorical_crossentropy')
在LSTM中激活
LSTM使用tanh
作为激活。除非您有充分的理由将其更改为relu
,否则不要这样做,因为当激活函数发生变化时,LSTM不会输出与常规前馈层相同的行为。
关于python - 错误的LSTM时间序列预测的输入大小与训练的输入大小不同,我们在Stack Overflow上找到一个类似的问题:https://stackoverflow.com/questions/59601739/