我想使用Keras训练2维回归的神经网络。

我的输入是一个数字,而我的输出则有两个数字:

model = Sequential()
model.add(Dense(16, input_shape=(1,), kernel_initializer=initializers.constant(0.0), bias_initializer=initializers.constant(0.0)))
model.add(Activation('relu'))
model.add(Dense(16, input_shape=(1,), kernel_initializer=initializers.constant(0.0), bias_initializer=initializers.constant(0.0)))
model.add(Activation('relu'))
model.add(Dense(2, kernel_initializer=initializers.constant(0.0), bias_initializer=initializers.constant(0.0)))
adam = Adam(lr=0.001, beta_1=0.9, beta_2=0.999, epsilon=1e-08, decay=0.0)
model.compile(loss='mean_squared_error', optimizer=adam)

然后,我创建了一些用于训练的虚拟数据:
inputs = np.zeros((10, 1), dtype=np.float32)
targets = np.zeros((10, 2), dtype=np.float32)

for i in range(10):
    inputs[i] = i / 10.0
    targets[i, 0] = 0.1
    targets[i, 1] = 0.01 * i

最后,我循环训练了迷你批次,同时对训练数据进行了测试:
while True:

    loss = model.train_on_batch(inputs, targets)

    test_outputs = model.predict(inputs)

    print test_outputs

问题是,输出的输出如下:

[0.1,0.045]
[0.1,0.045]
[0.1,0.045]
.....
.....
.....

因此,虽然第一维是正确的(0.1),但第二维是不正确的。第二维应为[0.01,0.02,0.03,.....]。因此,实际上,网络(0.45)的输出只是第二维中所有值的平均值。

我究竟做错了什么?

最佳答案

问题是,您正在用零初始化所有权重。问题是,如果所有权重都相同,则所有梯度都相同。因此,好像您的网络在每个层上都具有单个神经元。删除它,以便使用默认的随机初始化,并且可以正常工作:

model = Sequential()
model.add(Dense(16, input_shape=(1,)))
model.add(Activation('relu'))
model.add(Dense(16, input_shape=(1,)))
model.add(Activation('relu'))
model.add(Dense(2))
model.compile(loss='mean_squared_error', optimizer='Adam')

1000个纪元后的结果:
Epoch 1000/1000
10/10 [==============================] - 0s - loss: 5.2522e-08

In [59]: test_outputs
Out[59]:
array([[ 0.09983768,  0.00040025],
       [ 0.09986718,  0.010469  ],
       [ 0.09985521,  0.02051429],
       [ 0.09984323,  0.03055958],
       [ 0.09983127,  0.04060487],
       [ 0.09995781,  0.05083206],
       [ 0.09995599,  0.06089856],
       [ 0.09995417,  0.07096504],
       [ 0.09995237,  0.08103154],
       [ 0.09995055,  0.09109804]], dtype=float32)

关于tensorflow - 用Keras进行多维回归,我们在Stack Overflow上找到一个类似的问题:https://stackoverflow.com/questions/43985082/

10-12 21:20