This article explains how to return intermediate gradients (for non-leaf nodes) in PyTorch; it may be a useful reference for anyone facing the same problem.

Problem Description

My question concerns the syntax of PyTorch's register_hook.

import torch

x = torch.tensor([1.], requires_grad=True)
y = x**2
z = 2*y

# print each tensor's gradient as it is computed during the backward pass
x.register_hook(print)
y.register_hook(print)

z.backward()

Output:

tensor([2.])
tensor([4.])

This snippet simply prints the gradients of z w.r.t. x and y, respectively.

Now my (most likely trivial) question is how to return the intermediate gradients (rather than only printing them)?

Update:

It appears that calling retain_grad() solves the issue for leaf nodes, e.g. y.retain_grad().

However, retain_grad does not seem to solve it for non-leaf nodes. Any suggestions?
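For reference, a minimal sketch of the retain_grad() pattern mentioned in the update, applied to the same x, y, z as in the first snippet; after the backward pass the gradient is available on the tensor's .grad attribute:

import torch

x = torch.tensor([1.], requires_grad=True)
y = x**2
z = 2*y

y.retain_grad()   # ask autograd to keep .grad on the non-leaf tensor y
z.backward()

print(y.grad)     # tensor([2.])
print(x.grad)     # tensor([4.])  (x is a leaf, so .grad is kept by default)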

Recommended Answer

I think you can use those hooks to store the gradients in a global variable:

import torch

grads = []
x = torch.tensor([1.], requires_grad=True)
y = x**2 + 1
z = 2*y

# each hook appends the incoming gradient to the global list
x.register_hook(lambda d: grads.append(d))
y.register_hook(lambda d: grads.append(d))

z.backward()
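With this, after z.backward() the list holds both gradients; a quick check, continuing the snippet above (the hooks fire in the order the backward pass reaches the tensors, so y's gradient is appended before x's):

print(grads)   # [tensor([2.]), tensor([4.])]  first dz/dy, then dz/dx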

But you most likely also need to remember which tensor each gradient was computed for. In that case, we slightly extend the above by using a dict instead of a list:

import torch

grads = {}
x = torch.tensor([1., 2.], requires_grad=True)
y = x**2 + 1
z = 2*y

def store(grad, parent):
    # keep a copy of the gradient, keyed by the tensor it belongs to
    print(grad, parent)
    grads[parent] = grad.clone()

x.register_hook(lambda grad: store(grad, x))
y.register_hook(lambda grad: store(grad, y))

z.sum().backward()

Now you can, for example, access tensor y's grad simply using grads[y].
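As a quick usage check, continuing the dict-based snippet above (the expected values follow from z = 2*(x**2 + 1) with x = [1., 2.]):

print(grads[y])   # tensor([2., 2.])  dz/dy is 2 for every element
print(grads[x])   # tensor([4., 8.])  dz/dx = 4*x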

That concludes this article on how to return intermediate gradients (for non-leaf nodes) in PyTorch. Hopefully the recommended answer is helpful.
