Yunfeng's Simple Blog

Yunfeng's Simple Blog

马上订阅 Yunfeng's Simple Blog RSS 更新: https://vra.github.io/atom.xml

repetition_penality的作用与实现

2025年6月2日 15:49

1. 原理说明

在跑LLM推理的时候,有时候会出现模型不断复读的现象,也就是模型一直输出同一个token或者token序列,不结束输出。transformers库中有一个参数repetition_penality专门针对此现象进行优化,通过将其设置为大于1.0的一个浮点数(如1.05, 1.1, 1.2等),有些情况下能缓解重复问题。 这个优化思路是在2019年的论文CTRL中提出的。

那这个参数是怎么解决重复问题的呢?其实实现原理很简单:对于之前出现过的token,在其logits(没有经过softmax的raw score)上作用一个repetition_penality 系数,使得它的logits数值降低,进而减少被选做下一个token的概率。

原理上,可以设置repetition_penality 为一个小于1.0的浮点数,使得模型增加前面token重复输出的概率,构造一个复读机,虽然好像实际没什么作用。

这个功能在transformers库中的核心代码如下(完整代码参见RepetitionPenaltyLogitsProcessor类的实现):

1
2
3
4
5
6
if self.prompt_ignore_length:
input_ids = input_ids[:, self.prompt_ignore_length :]
score = torch.gather(scores, 1, input_ids)
# if score < 0 then repetition penalty has to be multiplied to reduce the token probabilities
score = torch.where(score < 0, score * self.penalty, score / self.penalty)
scores_processed = scores.scatter(1, input_ids, score)

代码解释如下:

  1. 1-2行:如果设置了prompt_ignore_length(一般是用户的原始input的长度),则忽略 原始input,也就是不对问题token作用惩罚系数,注意这里原始的input_ids既包含输入又包含之前预测tokens。
  2. 3行:获取所有的scores (logits)中input_ids 对应的score
  3. 4行:如果score <0,则乘以惩罚系数,使得logits变得更小(例如-0.5*1.1->-0.55),如果score>0,则除以惩罚系数,使得logits变得更小(例如0.5/1.1->0.454)
  4. 5行:将经过惩罚系数作用后的score写入到大的scores中
    可以看到这个功能的实现是比较简单直接的,没有太多弯弯绕绕的东西。

2. 效果实测

利用下面代码可以明显地看到这个参数对输出的影响,输入I love coding. I love,预测下一个token:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation.logits_process import RepetitionPenaltyLogitsProcessor

def print...

剩余内容已隐藏

查看完整文章以阅读更多