当前位置: 代码迷 >> 综合 >> detach的简易用法
  详细解决方案

detach的简易用法

热度:90   发布时间:2023-09-22 06:55:36.0
import torch 
a = torch.tensor([1, 2, 3.], requires_grad=True)
b = torch.tensor([2, 3, 4.], requires_grad=True)
n = a*2
n2 = n.detach()
f = n2 + 3*a/b
#detach 用法 阻断梯度传播 比如此时n2就没有梯度 但是a有 如果把对应a改成b 则b也有
f.sum().backward()
print(a.grad)对于经常出现的round函数 本身没有梯度 可以采用
w_1 = round(w)-w
w_2 = w_1.detach()
w_3 = w_2 + w
这种方式 绕过对round求梯度 采取用w的梯度代替代码运行结果:
tensor([3., 3., 3.])