我的输入看起来像这样:
[
[1, 2, 3]
[4, 5, 6]
[7, 8, 9]
...]
形状为
(1, num_samples, num_features)
,标签如下所示:[
[0, 1]
[1, 0]
[1, 0]
...]
形状为
(1, num_samples, 2)
。但是,当我尝试运行以下Keras代码时,出现此错误:
ValueError: Error when checking model target: expected dense_1 to have 2 dimensions, but got array with shape (1, 8038, 2)
。从我阅读的内容来看,这似乎源于我的标签是2D而不是简单的整数的事实。这是正确的吗?如果是的话,如何在Keras中使用一键式标签?这是代码:
num_features = 463
trX = np.random(8038, num_features)
trY = # one-hot array of shape (8038, 2) as described above
def keras_builder(): #generator to build the inputs
while(1):
x = np.reshape(trX, (1,) + np.shape(trX))
y = np.reshape(trY, (1,) + np.shape(trY))
print(np.shape(x)) # (1, 8038, 463)
print(np.shape(y)) # (1, 8038, 2)
yield x, y
model = Sequential()
model.add(LSTM(100, input_dim = num_features))
model.add(Dense(1, activation='sigmoid'))
model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
model.fit_generator(keras_builder(), samples_per_epoch = 1, nb_epoch=3, verbose = 2, nb_worker = 1)
会立即引发上面的错误:
Traceback (most recent call last):
File "file.py", line 35, in <module>
model.fit_generator(keras_builder(), samples_per_epoch = 1, nb_epoch=3, verbose = 2, nb_worker = 1)
...
ValueError: Error when checking model target: expected dense_1 to have 2 dimensions, but got array with shape (1, 8038, 2)
谢谢!
最佳答案
很多事情没有加在一起。
我假设您正在尝试解决顺序分类任务,即您的数据形状为(<batch size>, <sequence length>, <feature length>)
。
在批处理生成器中,您将创建一个批处理,该批处理由每个序列元素的一个长度为8038和463个特征的序列组成。您将创建一个匹配的Y批次进行比较,该批次由一个包含8038个元素的序列组成,每个序列的大小为2。
您的问题是Y
与最后一层的输出不匹配。您的Y
是3维的,而模型的输出仅为2维:Y.shape = (1, 8038, 2)
与dense_1.shape = (1,1)
不匹配。这说明了您收到的错误消息。
解决方案:您需要在LSTM层中启用return_sequences=True
以返回序列,而不是仅返回最后一个元素(有效地删除时维)。这将在LSTM层给出(1, 8038, 100)
的输出形状。由于Dense
层无法处理顺序数据,因此需要将其分别应用于每个序列元素,方法是将其包装在 TimeDistributed
包装器中。然后,这会为您的模型提供输出形状(1, 8038, 1)
。
您的模型应如下所示:
from keras.layers.wrappers import TimeDistributed
model = Sequential()
model.add(LSTM(100, input_dim=num_features, return_sequences=True))
model.add(TimeDistributed(Dense(1, activation='sigmoid')))
在检查模型摘要时,可以很容易地发现这一点:
print(model.summary())