我有一个具有以下设置的数据框:
import numpy as np
X = np.random.rand(100, 20, 3)
这里有100个时间片,20个观测值和每个观测值3个属性。
我试图弄清楚如何将上面的数据传递给以下Keras序列:
from keras.models import Sequential, Model
from keras.layers import Dense, LSTM, Dropout, Activation
import keras
# config
stateful = False
look_back = 3
lstm_cells = 1024
dropout_rate = 0.5
n_features = int(X.shape[1]*3)
input_shape = (look_back, n_features, 3)
output_shape = n_features
def loss(y_true, y_pred):
return keras.losses.mean_squared_error(y_true, y_pred)
model = Sequential()
model.add(LSTM(lstm_cells, stateful=stateful, return_sequences=True, input_shape=input_shape))
model.add(Dense(output_shape, activation='relu'))
model.compile(loss=loss, optimizer='sgd')
运行此抛出:
ValueError:输入0与层lstm_23不兼容:预期
ndim = 3,找到的ndim = 4
有谁知道我如何重塑
X
并将其传递到模型中?任何的意见都将会有帮助! 最佳答案
这似乎使事情发生了变化:
from keras.models import Sequential, Model
from keras.layers import Dense, LSTM, Dropout, Activation
import keras
# config
stateful = False
look_back = 3
lstm_cells = 1024
dropout_rate = 0.5
n_features = int(X.shape[1]) * 3
input_shape = (look_back, n_features)
output_shape = n_features
def loss(y_true, y_pred):
return keras.losses.mean_squared_error(y_true, y_pred)
model = Sequential()
model.add(LSTM(lstm_cells, stateful=stateful, return_sequences=True, input_shape=input_shape))
model.add(LSTM(lstm_cells, stateful=stateful, return_sequences=True))
model.add(LSTM(lstm_cells, stateful=stateful))
model.add(Dense(output_shape, activation='relu'))
model.compile(loss=loss, optimizer='sgd')
然后,可以按以下方式划分训练数据:
# build training data
train_x = []
train_y = []
n_time = int(X.shape[0])
n_obs = int(X.shape[1])
n_attrs = int(X.shape[2])
# note we flatten the last dimension
for i in range(look_back, n_time-1, 1):
train_x.append( X[i-look_back:i].reshape(look_back, n_obs * n_attrs ) )
train_y.append( X[i+1].ravel() )
train_x = np.array(train_x)
train_y = np.array(train_y)
然后可以训练玩具模型:
model.fit(train_x, train_y, epochs=10, batch_size=10)
关于python - Keras序列:使用三个参数指定输入形状,我们在Stack Overflow上找到一个类似的问题:https://stackoverflow.com/questions/52770595/