Yunfeng's Simple Blog
Yunfeng's Simple Blog
马上订阅 Yunfeng's Simple Blog RSS 更新: https://vra.github.io/atom.xml
repetition_penality的作用与实现
1. 原理说明
在跑LLM推理的时候,有时候会出现模型不断复读的现象,也就是模型一直输出同一个token或者token序列,不结束输出。transformers库中有一个参数repetition_penality专门针对此现象进行优化,通过将其设置为大于1.0的一个浮点数(如1.05, 1.1, 1.2等),有些情况下能缓解重复问题。 这个优化思路是在2019年的论文CTRL中提出的。
那这个参数是怎么解决重复问题的呢?其实实现原理很简单:对于之前出现过的token,在其logits(没有经过softmax的raw score)上作用一个repetition_penality 系数,使得它的logits数值降低,进而减少被选做下一个token的概率。
原理上,可以设置repetition_penality 为一个小于1.0的浮点数,使得模型增加前面token重复输出的概率,构造一个复读机,虽然好像实际没什么作用。
这个功能在transformers库中的核心代码如下(完整代码参见RepetitionPenaltyLogitsProcessor类的实现):
1 | if self.prompt_ignore_length: |
代码解释如下:
- 1-2行:如果设置了
prompt_ignore_length(一般是用户的原始input的长度),则忽略 原始input,也就是不对问题token作用惩罚系数,注意这里原始的input_ids既包含输入又包含之前预测tokens。 - 3行:获取所有的
scores (logits)中input_ids 对应的score - 4行:如果
score <0,则乘以惩罚系数,使得logits变得更小(例如-0.5*1.1->-0.55),如果score>0,则除以惩罚系数,使得logits变得更小(例如0.5/1.1->0.454) - 5行:将经过惩罚系数作用后的score写入到大的scores中
可以看到这个功能的实现是比较简单直接的,没有太多弯弯绕绕的东西。
2. 效果实测
利用下面代码可以明显地看到这个参数对输出的影响,输入I love coding. I love,预测下一个token:
1 | import torch |