引言
Beam Search 是一种受限的宽度优先搜索方法,经常用在各种 NLP 生成类任务中,例如机器翻译、对话系统、文本摘要。本文首先介绍 Beam Search 的基本思想,然后再介绍一些beam search的优化方法,最后附上自己的代码实现。
1. Beam Search的基础版本
在生成文本的时候,通常需要进行解码操作,贪心搜索 (Greedy Search) 是比较简单的解码。Beam Search 对贪心搜索进行了改进,扩大了搜索空间,更容易得到全局最优解。Beam Search 包含一个参数 beam size k,表示每一时刻均保留得分最高的 k 个序列,然后下一时刻用这 k 个序列继续生成。示意图如下所示:
假设我们生成词表中有三个单词{我,爱,你}。我们设K=2K=2K=2。那么我们在第一时刻确定两个候选输出是{我,你}。紧接着我们要考虑第二个输出,具体步骤如下:
- 确定单词“我”为第一时刻输出,并将其作为第二时刻输入,在已知p(x,我)p(x,我)p(x,我)的情况下,各个单词的输出概率为3种情况,每个组合的概率为 P(我∣x)P(y2∣x,我)P(我|x)P(y_2|x,我)P(我∣x)P(y2?∣x,我)。
- 同样我们把“你”也作为第二时刻输入,同样也有三种组合。
- 最后我们在六种组合中选择概率最大的三个组合。
接下来要做的重复这个过程,逐步生成单词,直到遇到结束标识符停止。最后得到概率最大的那个生成序列。其概率为:
以上就是Beam search算法的思想,当beam size=1时,就变成了贪心算法。
2. Beam Search的优化
Beam search算法也有许多改进的地方。
2.1 Length normalization:惩罚短句
根据最后的概率公式可知,该算法倾向于选择最短的句子,因为在这个连乘操作中,每个因子都是小于1的数,因子越多,最后的概率就越小。解决这个问题的方式,最后的概率值除以这个生成序列的单词数,这样比较的就是每个单词的平均概率大小。此外,连乘因子较多时,可能会超过浮点数的最小值,可以考虑取对数来缓解这个问题。谷歌给的公式如下:
其中α∈[0,1],谷歌建议取值为[0.6,0.7]之间,α用于length normalization。
2.2 Coverage normalization:惩罚重复
另外我们在序列到序列任务中经常会发现一个问题,2016 年, 华为诺亚方舟实验室的论文提到,机器翻译的时候会存在over translation or undertranslation due to attention coverage。 作者提出coverage-based atttention机制来解决coverage 问题。 Google machine system 利用了如下的方式进行了length normalization 和 coverage penalty。
还是上述公式,β用于控制coverage penalty
coverage penalty 主要用于使用 Attention 的场合,通过 coverage penalty 可以让 Decoder 均匀地关注于输入序列 xxx 的每一个 token,防止一些 token 获得过多的 Attention。
2.3 End of sentence normalization:抑制长句
有的时候我们发现生成的序列一直生成下去不会停止,有的时候我们可以显式的设置最大生成长度进行控制,这里我们可以采用下式来进行约束:
其中∣X∣|X|∣X∣是source的长度,∣Y∣|Y|∣Y∣是当前target的长度,那么由上式可知,target长度越长的话,上述得分越低,这样就会防止出现生成一直不停止的情况。
3. Beam Search的代码实现
总的来说,beam search不保证全局最优,但是比greedy search搜索空间更大,一般结果比greedy search要好。下面附上一些代码实现:
首先,首先定义一个 Beam
类,作为一个存放候选序列的容器,属性需维护当前序列中的 token 以及对应的对数概率,同时还需维护跟当前 timestep 的 Decoder 相关的一些变量。此外,还需要给 Beam 类实现两个函数:一个 extend
函数用以扩展当前的序列(即添加新的 time step的 token 及相关变量);一个 score
函数用来计算当前序列的分数(在Beam类下的seq_score函数中有Length normalization以及Coverage normalization)。
class Beam(object):def __init__(self,tokens,log_probs,decoder_states,coverage_vector):self.tokens = tokensself.log_probs = log_probsself.decoder_states = decoder_statesself.coverage_vector = coverage_vectordef extend(self,token,log_prob,decoder_states,coverage_vector):return Beam(tokens=self.tokens + [token],log_probs=self.log_probs + [log_prob],decoder_states=decoder_states,coverage_vector=coverage_vector)def seq_score(self):"""This function calculate the score of the current sequence."""len_Y = len(self.tokens)# Lenth normalizationln = (5+len_Y)**config.alpha / (5+1)**config.alphacn = config.beta * torch.sum( # Coverage normalizationtorch.log(config.eps +torch.where(self.coverage_vector < 1.0,self.coverage_vector,torch.ones((1, self.coverage_vector.shape[1])).to(torch.device(config.DEVICE)))))score = sum(self.log_probs) / ln + cnreturn scoredef __lt__(self, other):return self.seq_score() < other.seq_score()def __le__(self, other):return self.seq_score() <= other.seq_score()
接着我们需要实现一个 best_k
函数,作用是将一个 Beam 容器中当前 time step 的变量传入 Decoder 中,计算出新一轮的词表概率分布,并从中选出概率最大的 k 个 token 来扩展当前序列(其中加入了End of sentence normalization),得到 k 个新的候选序列。
def best_k(self, beam, k, encoder_output, x_padding_masks, x, len_oovs):"""Get best k tokens to extend the current sequence at the current time step."""# use decoder to generate vocab distribution for the next tokenx_t = torch.tensor(beam.tokens[-1]).reshape(1, 1)x_t = x_t.to(self.DEVICE)# Get context vector from attention network.context_vector, attention_weights, coverage_vector = \self.model.attention(beam.decoder_states,encoder_output,x_padding_masks,beam.coverage_vector)# Replace the indexes of OOV words with the index of OOV token# to prevent index-out-of-bound error in the decoder.p_vocab, decoder_states, p_gen = \self.model.decoder(replace_oovs(x_t, self.vocab),beam.decoder_states,context_vector)final_dist = self.model.get_final_distribution(x,p_gen,p_vocab,attention_weights,torch.max(len_oovs))# Calculate log probabilities.log_probs = torch.log(final_dist.squeeze())# Filter forbidden tokens.# EOS token penalty. Follow the definition in# https://opennmt.net/OpenNMT/translation/beam_search/.log_probs[self.vocab.EOS] *= \config.gamma * x.size()[1] / len(beam.tokens)log_probs[self.vocab.UNK] = -float('inf')# Get top k tokens and the corresponding logprob.topk_probs, topk_idx = torch.topk(log_probs, k)# Extend the current hypo with top k tokens, resulting k new hypos.best_k = [beam.extend(x,log_probs[x],decoder_states,coverage_vector) for x in topk_idx.tolist()]return best_k
最后我们实现主函数 beam_search
。初始化encoder、attention和decoder的输?,然后对于每?个decodestep,对于现有的k个beam,我们分别利?best_k
函数来得到各?最佳的k个extended beam,也就是每个decode step我们会得到k*k个新的beam,然后只保留分数最?的k个,作为下?轮需要扩展的k个beam。为了只保留分数最?的k个beam,我们可以??个堆(heap)来实现,堆的中只保存k个节点,根结点保存分数最低的beam。
def beam_search(self,x,max_sum_len,beam_width,len_oovs,x_padding_masks):"""Using beam search to generate summary."""# run body_sequence input through encoderencoder_output, encoder_states = self.model.encoder(replace_oovs(x, self.vocab))coverage_vector = torch.zeros((1, x.shape[1])).to(self.DEVICE)# initialize decoder states with encoder forward statesdecoder_states = self.model.reduce_state(encoder_states)# initialize the hypothesis with a class Beam instance.init_beam = Beam([self.vocab.SOS],[0],decoder_states,coverage_vector)# get the beam size and create a list for stroing current candidates# and a list for completed hypothesisk = beam_widthcurr, completed = [init_beam], []# use beam search for max_sum_len (maximum length) stepsfor _ in range(max_sum_len):# get k best hypothesis when adding a new tokentopk = []for beam in curr:# When an EOS token is generated, add the hypo to the completed# list and decrease beam size.if beam.tokens[-1] == self.vocab.EOS:completed.append(beam)k -= 1continuefor can in self.best_k(beam,k,encoder_output,x_padding_masks,x,torch.max(len_oovs)):# Using topk as a heap to keep track of top k candidates.# Using the sequence scores of the hypos to campare# and object ids to break ties.add2heap(topk, (can.seq_score(), id(can), can), k)curr = [items[2] for items in topk]# stop when there are enough completed hypothesisif len(completed) == beam_width:break# When there are not engouh completed hypotheses,# take whatever when have in current best k as the final candidates.completed += curr# sort the hypothesis by normalized probability and choose the best oneresult = sorted(completed,key=lambda x: x.seq_score(),reverse=True)[0].tokensreturn result