所以,我明白归一化对于训练神经网络很重要。

我也明白我必须使用训练集中的参数对验证集和测试集进行标准化(参见例如这个讨论: https://stats.stackexchange.com/questions/77350/perform-feature-normalization-before-or-within-model-validation )

我的问题是:我如何在 Keras 中做到这一点?

我目前正在做的是:

import numpy as np
from keras.models import Sequential
from keras.layers import Dense
from keras.callbacks import EarlyStopping

def Normalize(data):
    mean_data = np.mean(data)
    std_data = np.std(data)
    norm_data = (data-mean_data)/std_data
    return norm_data

input_data, targets = np.loadtxt(fname='data', delimiter=';')
norm_input = Normalize(input_data)

model = Sequential()
model.add(Dense(25, input_dim=20, activation='relu'))
model.add(Dense(1, activation='sigmoid'))

model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

early_stopping = EarlyStopping(monitor='val_acc', patience=50)
model.fit(norm_input, targets, validation_split=0.2, batch_size=15, callbacks=[early_stopping], verbose=1)

但在这里,我 首先 规范化数据 w.r.t。整个数据集和 然后 拆分验证集,根据上述讨论这是错误的。

从训练集(training_mean 和 training_std)中保存均值和标准差并不是什么大问题,但是我如何分别在验证集上应用 training_mean 和 training_std 的归一化?

最佳答案

在使用 sklearn.model_selection.train_test_split 拟合模型之前,您可以手动将数据拆分为训练和测试数据集。然后,根据训练数据的均值和标准差对训练和测试数据进行归一化。最后,使用 model.fit 参数调用 validation_data

代码示例

import numpy as np
from sklearn.model_selection import train_test_split

data = np.random.randint(0,100,200).reshape(20,10)
target = np.random.randint(0,1,20)

X_train, X_test, y_train, y_test = train_test_split(data, target, test_size=0.2)

def Normalize(data, mean_data =None, std_data =None):
    if not mean_data:
        mean_data = np.mean(data)
    if not std_data:
        std_data = np.std(data)
    norm_data = (data-mean_data)/std_data
    return norm_data, mean_data, std_data

X_train, mean_data, std_data = Normalize(X_train)
X_test, _, _ = Normalize(X_test, mean_data, std_data)

model.fit(X_train, y_train, validation_data=(X_test,y_test), batch_size=15, callbacks=[early_stopping], verbose=1)

关于python - 在 Keras 中标准化神经网络的验证集,我们在Stack Overflow上找到一个类似的问题:https://stackoverflow.com/questions/45301648/

10-12 19:34