当前位置: 代码迷 >> 综合 >> pytorch MultiheadAttention 出现NaN
  详细解决方案

pytorch MultiheadAttention 出现NaN

热度:35   发布时间:2023-11-21 02:42:01.0
    if attn_mask is not None:if attn_mask.dtype == torch.bool:attn_output_weights.masked_fill_(attn_mask, float('-inf'))else:attn_output_weights += attn_mask

使用MultiheadAttention做self-attention时因为batch内序列长度不一致,难免需要使用mask

以pytorch自带的torch.nn.TransformerEncoder方法为例,其forward函数如下

forward(src, mask=None, src_key_padding_mask=None)

这里的mask会送到torch.nn.TransformerEncoderLayer的forward函数:

def forward(self, src: Tensor, src_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None) -> Tensor:src2 = self.self_attn(src, src, src, attn_mask=src_mask,key_padding_mask=src_key_padding_mask)[0]

的参数src_mask,然后送到MultiheadAttention的forward函数的attn_mask参数,而这里做的是一个self attention

使用attn_mask一定要小心,千万不要出现一整行都是True的情况,如下是源码中实现mask的方法:

    if attn_mask is not None:if attn_mask.dtype == torch.bool:attn_output_weights.masked_fill_(attn_mask, float('-inf'))else:attn_output_weights += attn_mask

把权重矩阵中需要mask的位置置为负无穷,然后再按行做softmax,问题就在这里,把一个元素全是是负无穷的tensor送给softmax,就会得到一个全是NaN的tensor。然后NaN和任何数运算都是NaN,NaN会传染,再经过一轮self attention,输出的tensor就全成NaN了。