当前位置: 代码迷 >> 综合 >> Video Object Segmentation with Adaptive Feature Bank and Uncertain-Region Refinement论文解读和代码实践
  详细解决方案

Video Object Segmentation with Adaptive Feature Bank and Uncertain-Region Refinement论文解读和代码实践

热度:81   发布时间:2023-10-20 17:15:30.0

NeurPIS2020的论文

官方代码

基础的框架采用STM,但是提出了自适应调整memory bank和一种新颖的refinement操作。

Motivation

STM是目前半监督VOS方向的SOTA论文方法,几乎后面的论文都是在STM的基础上改进。
作者分析了STM的缺点:

  • 在测试的时候,每5帧增加一个memory,如果是长序列,memory bank可能会爆显存。
  • 每隔五帧,更新一次memory bank,可能会漏过一些关键帧

作者提出Adaptive feature bank(AFB)来自适应更新bank,加入新的特征进来,同时如果有需要,则排出一些特征,将memory bank的容量控制在一个上限之下。
同时针对目标边缘的分割,目标边缘是难分类样本,也是不确定区域。作者提出uncertain-region refinement(URR)来提升边缘分割质量。

方法

Video Object Segmentation with Adaptive Feature Bank and Uncertain-Region Refinement论文解读和代码实践

从框架上看,可以看作是一个二阶段的vos方法:

  • 第一阶段,采用STM结合Adaptive Feature bank得到initial segmentation。
  • 第二阶段,采用不确定区域检测 和local refinement module得到最终细化的结果。

约定一些符号:

  • LLL:目标数目,davis17是多目标数据集,包含背景(没错,确实包含背景)
  • FBFBFB: feature bank,存有value 和key。 FB={(kl,vl)∣l=0,1,2,...L?1}FB=\{(k_l, v_l) | l = 0,1,2,...L-1\}FB={ (kl?,vl?)l=0,1,2,...L?1},也就是说每一个目标都有一对(key,value)
  • query或者Q:是当前帧
  • kQ,vQk^Q, v^QkQ,vQ: 从当前帧中提取出的key和value

Pipeline,假设对第i个目标进行分割:

  • kQk^QkQ和feature bank中的FBiFB_iFBi?做attention transform。这一步和STM一样。
  • 然后经过和STM一样的Decoder。得到initial segmentation。
  • 使用Uncertain-Regions Refinement对结果进行细化。
  • 将预测结果mask和当前帧,塞进feature bank更新,然后就可以下一帧预测了

整个过程相对于STM,有两个核心,第一个是Uncertain-Regions Refinement,第二个就是如何自适应更新feature bank。

Uncertain-Regions Refinement

这个模块分两步走:

  • 先得到confience loss,和不确定区域的mask
  • 根据不确定区域的mask,局部细化

confidence loss

第一阶段得到的输出是一个shape如(LLL, h, w)的张量,经过softmax之后,得到每个像素点属于每个目标的概率,记做M。
不确定区域的mask通过如下方式获得:
U=exp(1?M1′M2′)U = exp(1-\frac{M'_1}{M'_2})U=exp(1?M2?M1??)
M1′M'_1M1?M2′M'_2M2?是每个像素位置最大的概率值和第二大的概率值组成的两个mask。U的值范围在(0,1]。
越接近0,代表M1′M'_1M1?的值远大于M2′M'_2M2?,说明网络很确信这个像素的所属目标。
Lconf=∣∣U∣∣2L_{conf} = ||U||_2Lconf?=U2?
让U的2范数作为loss,让U的值尽量都降到0。优化方向就是让每个像素点都确信自己的类别。
Video Object Segmentation with Adaptive Feature Bank and Uncertain-Region Refinement论文解读和代码实践
上图红色圈中的白色区域,就是不确定区域。可以发现都集中在边缘。

local refinement mechanism

针对U中体现的不确定区域,使用 local refinement mechanism来refine。
作者认为,一个不确定点p的类别,应该可以从p点周围的点的类别来推断。这里说的类别就是哪个目标的意思。
首先要获得一个局部特征。将res1特征和第一阶段预测的mask做乘法。

 rough_seg = rough_seg.view(bs * obj_n, 1, h, w)  # bs*obj_n, 1, h, wr1_weighted = r1 * rough_seg  # 得到reference feature

然后在邻域7*7的范围内求平均,得到local feature

r1_local = self.local_avg(r1_weighted)  # bs*obj_n, 64, h, w # 7*7的邻域

然后希望求均值的系数不是均分的,而是能由周围的概率值决定。

r1_local = r1_local / (self.local_avg(rough_seg) + 1e-8)  # neighborhood reference

其实就是论文中这个公式
Video Object Segmentation with Adaptive Feature Bank and Uncertain-Region Refinement论文解读和代码实践
然后要学习res1和r1_local的相似度,

r1_conf = self.local_max(rough_seg)  # bs*obj_n, 1, h, w
local_match = torch.cat([r1, r1_local], dim=1)
q = self.local_ResMM(self.local_convFM(local_match))
q = r1_conf * self.local_pred2(NF.relu(q))

等价这条公式
Video Object Segmentation with Adaptive Feature Bank and Uncertain-Region Refinement论文解读和代码实践
最后,用来refine第一阶段预测的logits

 p = p + uncertainty * q # 用P+;p是没有经过softmax的特征值 不是概率值
p = NF.interpolate(p, scale_factor=2, mode='bilinear', align_corners=False)  # 原图大小
p = NF.softmax(p, dim=1)[:, 1]  # no, h, w 相当于sigmoid

之后还有一个soft aggregation 和softmax。

Adaptive feature bank(AFB)

AFB也有两个步骤:

  • 吸收新特征, Absorbing new features
  • 删除过期特征, Removing obsolete features
    Note:不同目标有不同的Bank。在第一帧的时候初始化,第一帧不做预测。并且只在测试的时候用。

吸收新特征

如果新特征和旧特征距离近,就直接merge了。省了空间,提高test效率
融合的过程也很暴力,如果两个向量的余弦距离大于0.95,则直接使用滑动平均融合。如果小于0.95,则直接添加进去。
Video Object Segmentation with Adaptive Feature Bank and Uncertain-Region Refinement论文解读和代码实践

删除过期特征

思想: 如果feature bank中的特征很久没有被read out,那我们也没有必要保留他了。
什么是read out:即被使用。如果在做matching的过程中。Feature bank的第j个特征和query的第i个特征的dot product大于1e-4,对第j个特征的计数+1。 没错,feature bank的每个特征都对应有一个counter和time span counter。前者记录被readout 的次数,后者记录这个特征跨越的时间帧数。
使用LFU评价特征重要程度,如果一个FB中的特征很久没有使用了,readout的次数很少,且时间跨度很大,则LFU值就小
Video Object Segmentation with Adaptive Feature Bank and Uncertain-Region Refinement论文解读和代码实践

# p是softmax的attention mapif self.update_bank:  # 只在测试的时候用!!!try:ones = torch.ones_like(p)zeros = torch.zeros_like(p)bank_cnt = torch.where(p > self.thres_valid, ones, zeros).sum(dim=2)[0]except RuntimeError as e:device = p.devicep = p.cpu()ones = torch.ones_like(p)zeros = torch.zeros_like(p)bank_cnt = torch.where(p > self.thres_valid, ones, zeros).sum(dim=2)[0].to(device)print('\tLine 170. GPU out of memory, use CPU', f'p size: {p.shape}')feature_bank.info[i][:, 1] += torch.log(bank_cnt + 1)

从matcher代码中,可以看到,对p中的每个位置判断和阈值thres_valid的关系,实现计数,然后把数字保存到FB中。

细节

  • FB只在eval中用,训练的时候,采6张图像,把第一张往后5张warp。
  • 400*400 。uncertain loss的权重为0.5。一次最大从序列中采3个目标。
  相关解决方案