我有一个具有以下设置的数据框:

import numpy as np

X = np.random.rand(100, 20, 3)


这里有100个时间片,20个观测值和每个观测值3个属性。

我试图弄清楚如何将上面的数据传递给以下Keras序列:

from keras.models import Sequential, Model
from keras.layers import Dense, LSTM, Dropout, Activation
import keras

# config
stateful = False
look_back = 3
lstm_cells = 1024
dropout_rate = 0.5
n_features = int(X.shape[1]*3)
input_shape = (look_back, n_features, 3)
output_shape = n_features

def loss(y_true, y_pred):
  return keras.losses.mean_squared_error(y_true, y_pred)

model = Sequential()
model.add(LSTM(lstm_cells, stateful=stateful, return_sequences=True, input_shape=input_shape))
model.add(Dense(output_shape, activation='relu'))
model.compile(loss=loss, optimizer='sgd')


运行此抛出:


  ValueError:输入0与层lstm_23不兼容:预期
  ndim = 3,找到的ndim = 4


有谁知道我如何重塑X并将其传递到模型中?任何的意见都将会有帮助!

最佳答案

这似乎使事情发生了变化:

from keras.models import Sequential, Model
from keras.layers import Dense, LSTM, Dropout, Activation
import keras

# config
stateful = False
look_back = 3
lstm_cells = 1024
dropout_rate = 0.5
n_features = int(X.shape[1]) * 3
input_shape = (look_back, n_features)
output_shape = n_features

def loss(y_true, y_pred):
  return keras.losses.mean_squared_error(y_true, y_pred)

model = Sequential()
model.add(LSTM(lstm_cells, stateful=stateful, return_sequences=True, input_shape=input_shape))
model.add(LSTM(lstm_cells, stateful=stateful, return_sequences=True))
model.add(LSTM(lstm_cells, stateful=stateful))
model.add(Dense(output_shape, activation='relu'))
model.compile(loss=loss, optimizer='sgd')


然后,可以按以下方式划分训练数据:

# build training data
train_x = []
train_y = []
n_time = int(X.shape[0])
n_obs = int(X.shape[1])
n_attrs = int(X.shape[2])

# note we flatten the last dimension
for i in range(look_back, n_time-1, 1):
  train_x.append( X[i-look_back:i].reshape(look_back, n_obs * n_attrs ) )
  train_y.append( X[i+1].ravel() )

train_x = np.array(train_x)
train_y = np.array(train_y)


然后可以训练玩具模型:

model.fit(train_x, train_y, epochs=10, batch_size=10)

关于python - Keras序列:使用三个参数指定输入形状,我们在Stack Overflow上找到一个类似的问题:https://stackoverflow.com/questions/52770595/

10-14 05:45