当前位置: 代码迷 >> 综合 >> pytorch 实现 GRL Gradient Reversal Layer
  详细解决方案

pytorch 实现 GRL Gradient Reversal Layer

热度:75   发布时间:2024-02-08 03:57:20.0

在GRL中,要实现的目标是:在前向传导的时候,运算结果不变化,在梯度传导的时候,传递给前面的叶子节点的梯度变为原来的相反方向。举个例子最好说明了:

import torch
from torch.autograd  import  Functionx = torch.tensor([1.,2.,3.],requires_grad=True)
y = torch.tensor([4.,5.,6.],requires_grad=True)z = torch.pow(x,2) + torch.pow(y,2)
f = z + x + y
s =6* f.sum()print(s)
s.backward()
print(x)
print(x.grad)

这个程序的运行结果是:

tensor(672., grad_fn=<MulBackward0>)
tensor([1., 2., 3.], requires_grad=True)
tensor([18., 30., 42.])

这个运算过程对于tensor中的每个维度上的运算为:

那么对于x的导数为:

所以当输入x=[1,2,3]时,对应的梯度为:[18,30,42]

因此这个是正常的梯度求导过程,但是如何进行梯度翻转呢?很简单,看下方的代码:

import torch
from torch.autograd  import  Functionx = torch.tensor([1.,2.,3.],requires_grad=True)
y = torch.tensor([4.,5.,6.],requires_grad=True)z = torch.pow(x,2) + torch.pow(y,2)
f = z + x + yclass GRL(Function):def forward(self,input):return inputdef backward(self,grad_output):grad_input = grad_output.neg()return grad_inputGrl = GRL()s =6* f.sum()
s = Grl(s)print(s)
s.backward()
print(x)
print(x.grad)

运行结果为:

tensor(672., grad_fn=<GRL>)
tensor([1., 2., 3.], requires_grad=True)
tensor([-18., -30., -42.])

这个程序相对于上一个程序,只是差在加了一个梯度翻转层:

class GRL(Function):def forward(self,input):return inputdef backward(self,grad_output):grad_input = grad_output.neg()return grad_input

这个部分的forward没有进行任何操作,backward里面做了.neg()操作,相当于进行了梯度的翻转。在torch.autograd 中的FUnction 的backward部分,在不做任何操作的情况下,这里的grad_output的默认值是1.

  相关解决方案