我正在使用MXnet制作简单的NN,但是step()方法有一些问题

x1.shape=(64, 1, 1000)
y1.shape=(64, 1, 10)


net =nm.Sequential()
net.add(nn.Dense(H,activation='relu'),nn.Dense(90,activation='relu'),nn.Dense(D_out))


for t in range(500):
    #y_pred = net(x1)

    #loss = loss_fn(y_pred, y)
    #for i in range(len(x1)):

    with autograd.record():
        output=net(x1)
        loss =loss_fn(output,y1)
    loss.backward()
    trainer.step(64)
    if t % 100 == 99:
        print(t, loss)
        #optimizer.zero_grad()



  用户警告:上下文cpu(0)上参数dense30_weight的渐变
  自上一个step以来未进行过向后更新。这可能意味着
  模型中的错误,使其仅使用一部分参数
  (块)进行此迭代。如果您有意仅使用
  子集,使用ignore_stale_grad = True调用步骤以禁止显示此警告
  并跳过具有陈旧渐变的参数更新

最佳答案

该错误表明您正在训练器中传递不在计算图中的参数。
您需要初始化模型的参数并定义训练器。与Pytorch不同,您无需在MXNet中调用zero_grad,因为默认情况下会写入新渐变而不进行累积。以下代码显示了使用MXNet的Gluon API实现的简单神经网络:

# Define model
net = gluon.nn.Dense(1)
net.collect_params().initialize(mx.init.Normal(sigma=1.), ctx=model_ctx)
square_loss = gluon.loss.L2Loss()
trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.0001})

# Create random input and labels
def real_fn(X):
    return 2 * X[:, 0] - 3.4 * X[:, 1] + 4.2

X = nd.random_normal(shape=(num_examples, num_inputs))
noise = 0.01 * nd.random_normal(shape=(num_examples,))
y = real_fn(X) + noise

# Define Dataloader
batch_size = 4
train_data = gluon.data.DataLoader(gluon.data.ArrayDataset(X, y), batch_size=batch_size, shuffle=True)
num_batches = num_examples / batch_size

for e in range(10):

    # Iterate over training batches
    for i, (data, label) in enumerate(train_data):

    # Load data on the CPU
        data = data.as_in_context(mx.cpu())
        label = label.as_in_context(mx.cpu())

        with autograd.record():
            output = net(data)
            loss = square_loss(output, label)

    # Backpropagation
        loss.backward()
        trainer.step(batch_size)

        cumulative_loss += nd.mean(loss).asscalar()

    print("Epoch %s, loss: %s" % (e, cumulative_loss / num_examples))

关于machine-learning - 在mxnet错误中定义简单的神经网络,我们在Stack Overflow上找到一个类似的问题:https://stackoverflow.com/questions/57720590/

10-12 18:13