Problem description
My question concerns the syntax of PyTorch's register_hook.
import torch

x = torch.tensor([1.], requires_grad=True)
y = x**2
z = 2*y
x.register_hook(print)
y.register_hook(print)
z.backward()
Output:
tensor([2.])
tensor([4.])
This snippet simply prints the gradients of z w.r.t. x and y, respectively.
Now my (most likely trivial) question is: how can I return the intermediate gradients (rather than only printing them)?
Update:
It appears that calling retain_grad() solves the issue for leaf nodes, e.g. y.retain_grad(). However, retain_grad does not seem to solve it for non-leaf nodes. Any suggestions?
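For reference, here is a minimal sketch of the retain_grad() call mentioned above, reusing the x, y, z from the first snippet; after backward(), the retained gradient can be read from y.grad:

import torch

x = torch.tensor([1.], requires_grad=True)
y = x**2
z = 2*y
y.retain_grad()   # ask autograd to keep the gradient computed for this intermediate tensor
z.backward()
print(y.grad)     # tensor([2.])
print(x.grad)     # tensor([4.]) -- x is a leaf, so its .grad is populated by default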
Recommended answer
I think you can use those hooks to store the gradients in a global variable:
grads = []
x = torch.tensor([1.], requires_grad=True)
y = x**2 + 1
z = 2*y
x.register_hook(lambda d: grads.append(d))  # store dz/dx when it is computed
y.register_hook(lambda d: grads.append(d))  # store dz/dy when it is computed
z.backward()
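After z.backward() runs, both hooks have fired and grads holds the two gradients; since backpropagation starts from z, the y hook fires before the x hook:

print(grads)  # [tensor([2.]), tensor([4.])]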
But you most likely also need to remember which tensor each gradient was computed for. In that case, we slightly extend the above, using a dict instead of a list:
grads = {}
x = torch.tensor([1., 2.], requires_grad=True)
y = x**2 + 1
z = 2*y

def store(grad, parent):
    print(grad, parent)
    grads[parent] = grad.clone()

x.register_hook(lambda grad: store(grad, x))
y.register_hook(lambda grad: store(grad, y))
z.sum().backward()
Now you can, for example, access tensor y's grad simply using grads[y].
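To illustrate with the values from the snippet above (x = [1., 2.], y = x**2 + 1, z = 2*y, so dz/dy = 2 and dz/dx = 4*x):

print(grads[y])  # tensor([2., 2.])
print(grads[x])  # tensor([4., 8.])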