自学内容网 自学内容网

ChatGPT背后的创新之源:InstructGPT的详细解读~

Training language models to follow instructions with human feedback

Note:InstructGPT作为ChatGPT的前身,他们的模型结构,训练方式都完全一致,即都是用了instrcut learning和RLHF指导模型学习。区别可能就是微调的元模型不同(InstructGPT是在GPT3基础上,而ChatGPT是在GPT3.5)

本篇用自己通俗易懂的方式讲解自己对InstructGPT的理解~

原文链接: https://arxiv.org/pdf/2203.02155.pdf

1.Abstract

大语言模型在生成答案时,可能会产生有毒的、不真实的、对用户没有帮助的(胡编乱造)的输出。例如GPT3虽然能力很强大,但是它的训练数据中来自互联网中大量没有筛选过的内容,其中可能存在各种偏见、歧视性言论等不适当的内容。InstructGPT旨在通过提供更加细粒度的指导和控制,来解决GPT3存在的一些缺陷:

1.1.InstructGPT对标GPT3中的缺陷:

  1. 增强上下文理解:InstructGPT使用prompt对输入的训练数据进行重新的定义和引导,帮助模型更好的理解当下的语境和任务,从而避免误解或忽略特定的上下文信息。

  2. 排除推广偏见和不当内容:InstructGPT通过人工干预,指导和约束尽量减少模型生成的偏见性言论或不适当内容,提升生成文本的准确性和中立性。

1.2.InstructGPT训练流程:

  1. 收集人工标注的演示数据集并微调GPT3:首先,需要创建一个人工标注的演示数据集,其中包含了任务示例文本或指令以及对应的期望输出(这些示例可以是从专家或众包平台收集的)。然后,将收集到的演示数据集输入到 GPT3模型中进行微调。微调的目标是让模型学习在特定任务中生成符合期望的输出。

  2. 生成多个输出-进行排序,以训练奖励模型:首先,使用微调后的GPT3模型,将演示文本输入模型并生成多个候选输出。这些候选输出可以通过模型的自动推理生成。然后,对生成的多个候选输出进行人工干预和正确性排序。最后,使用人工干预排序的数据,训练一个奖励模型。

  3. 强化学习微调GPT3:将奖励模型用作强化学习的优化目标,进一步微调GPT3。

(这里的描述只是说一下模型的流程,后续会细节性描述)

2.DataSet

InstructGPT的训练分为3个步骤,每个步骤对应一个专属的训练数据集:

2.1.SFT数据集(step 1):

SFT数据集是用来训练step 1的GPT3模型,即按照GPT3的训练方式对GPT3进行微调。因为GPT3是一个自回归基于提示学习的生成模型,因此SFT数据集也是由提示-答复对组成的样本。

SFT数据一部分来自使用OpenAI的PlayGround的用户,另一部分来自OpenAI雇佣的40名标注工(labeler),在SFT中,标注工作是根据内容自己编写指示,并且要求编写的指示满足下面三点:

简单任务:labeler给出任意一个简单的任务,同时要确保任务的多样性;

Few-shot任务:labeler给出一个指示,以及该指示的多个查询-响应对;

用户相关的:从接口中获取用例,然后让labeler根据这些用例编写指示。

(SFT数据集包含13k个训练提示)

2.1.1.指示学习(Instruct Learning)和提示(Prompt Learning)学习

指示学习和提示学习的目的都是去挖掘语言模型本身具备的知识。不同的是Prompt是激发语言模型的补全能力,例如根据上半句生成下半句,或是完形填空等。Instruct是激发语言模型的理解能力,它通过给出更明显的指令,让模型去做出正确的行动。

提示学习:今天发了工资,我感觉我要____了!

指示学习:这句话的情感是非常正向的:今天发了工资,我感觉我要发财了!

Instruct Learning的优点是它经过多任务的微调后,也能够在其他任务上做zero-shot,而Prompt Learning都是针对一个任务的。泛化能力不如指示学习。

2.2.RM数据集

RM数据集用来训练step 2的奖励模型,为InstructGPT的训练设置一个奖励目标,要尽可能全面且真实的对齐需要模型生成的内容。很自然的,可以通过人工标注的方式来提供这个奖励,通过人工对可以给那些涉及偏见的生成内容更低的分从而鼓励模型不去生成这些人类不喜欢的内容。InstructGPT的做法是先让模型生成一批候选文本,让后通过labeler根据生成数据的质量对这些生成内容进行排序。

(RM 数据集有 33k 个训练提示)

2.3.PPO数据集

PPO数据集用来训练强化模型,即InstructGPT。InstructGPT的PPO数据没有进行标注,它均来自GPT-3的API的用户。既又不同用户提供的不同种类的生成任务

(PPO 数据集有 31k 个训练提示)

img
img

InstructGPT中数据集的分布以及其他详细信息

3.InstructGPT原理解读

img

图2.InstructGPT的三个步骤

LLMs模型能够通过提示的方式把任务作为输入,但是这些模型也经常会输出一些不好的回复,比如说捏造事实,生成有偏见的、有害的或者是没有按照想要的方式来,这是因为整个语言模型训练的目标函数有问题。LLMs模型通过预测下一个词的方式进行训练,其目标函数是最大化给定语言序列的条件概率,而不是“有帮助且安全地遵循用户的指示”。

InstructGPT是如何实现上述目标的呢?

主要是使用来自人类反馈的强化学习(利用人类的偏好作为奖励信号,让模型仿照人来生成答案),对GPT-3进行微调。具体实现步骤如下(如图2):

  1. 收集示范数据,进行有监督微调SFT

    • 标注数据:根据prompts(提示,这里就是写的各种各样的问题),人类会撰写一系列demonstrations(演示)作为模型的期望输出。

    • 模型微调:将prompts和人类标注的答案拼在一起,作为人工标注的数据集,然后使用这部分数据集对预训练的GPT-3进行监督微调,得到第一个模型SFT。

    Note:因为问题和答案是拼在一起的,所以在 GPT 眼中都是一样的,都是给定一段话然后预测下一个词,所以在微调上跟之前的在别的地方做微调或者是做预训练没有任何区别。

  2. 收集比较数据,训练奖励模型RM

    • 标注数据:生成式标注是很贵的一件事,所以第二步是进行排序式/判别式标注。用上一步得到的SFT模型生成各种问题的答案,标注者(labelers)会对这些输出进行比较和排序(由好到坏,比如图2 D>C>A=B)。

    • 训练模型:基于这个数据集,训练一个RM(reward model)。训练好了之后这个RM模型就可以对生成的答案进行打分,且打出的分数能够满足人工排序的关系。

  3. 使用强化学习的机制,优化SFT模型,得到最终的RL模型(InstructGPT)

    • 微调模型:将新的标注数据输入到SFT模型得到输出,并将输出输入RM进行打分,通过强化学习来优化SFT模型的参数。具体使用 PPO 针对奖励模型优化策略,使用 RM 的输出作为标量奖励,使用 PPO 算法微调监督策略以优化此奖励。

步骤2和步骤3可以不断迭代;收集当前最佳策略的更多比较数据,用于训练新的 RM,然后训练新的策略。

3.1.step 1 有监督微调(微调SFT)

与训练GPT3的过程一致,而且作者发现让模型适当过拟合有助于后面两步的训练:根据验证集上的RM分数,选择最终的SFT模型。作者发现,训练更多的epochs尽管会产生过拟合,但有助于提高后续步骤的RM分数。

3.2.step 2 奖励模型(RM)

由上述可知,训练RM的数据是labeler根据SFT输出的结果进行排序的形式,为的是求出每个排序结果的得分,因此RM可以看作一个回归模型。

RM的结构:RM结构是将SFT训练后的模型的最后的嵌入层去掉后的模型。它的输入是prompt和Response,输出是该response对应的score(奖励值)。(将SFT模型最后的softmax层去掉,换成一个线性层来投影,将所有词的输出投影到一个值上面,也就是说输出的是一个标量)

具体的讲,每个prompt,InstructGPT会随机生成 K个输出,然后它们向每个labeler成对的展示输出结果,也就是每个prompt共展示 C k 2 C_k^2 Ck2个结果,然后用户从中选择效果更好的输出。在训练时,InstructGPT将每个prompt的 C k 2 C_k^2 Ck2个响应对作为一个batch,这种按prompt为batch的训练方式要比传统的按样本为batch的方式更不容易过拟合,因为这种方式每个prompt会且仅会输入到模型中一次。

损失函数:这里使用的是排序中常见的pairwise ranking loss。这是因为人工标注的是答案的顺序,而不是分数,所以中间需要转换一下。这个损失函数的目标是最大化labeler更喜欢的响应和不喜欢的响应之间的差值。

img
其中, y w , y l y_w,y_l yw,yl :SFT在表示prompt x下生成的结果;

r θ ( x , y ) r_{\theta}(x,y) rθ(x,y):表示prompt x和结果y在参数 θ \theta θ下RM的输出值,即奖励值

D D D:是训练数据集

K 2 \frac{K}{2} 2K:对于每个prompt,InstructGPT会随机生成K个输出,每个prompt的输出可以产生 C k 2 C_k^2 Ck2 对,这里就表达将loss除以 C k 2 C_k^2 Ck2

img

RM损失函数细节

Note:

已经有了人工标注的数据集,直接训练一个模型就行,为什么还要另外训练一个参数为 的RM模型呢? 这是因为RM模型标注的仅仅是排序,而非真正的分数scores。这样RL模型更新之后,又生成新的数据,需要新的标注。在强化学习中,叫做在线学习。在线学习在训练时,需要人工一直不断的反馈(标注),非常的贵。这里通过学习一个 ,代替人工排序,从而给模型实时的反馈,这就是为什么这里需要训练两个模型。

2.3.step 3 强化学习模型(PPO)

之前不少科研工作者说强化学习并不是一个非常适合应用到预训练模型中,因为很难通过模型的输出内容建立奖励机制。InstructGPT做到了这点,它通过结合人工标注,将强化学习引入到预训练语言模型是这个算法最大的创新点。

在强化学习中,模型用policy (策略)表示。所以文中的 RL policy ,其实就是step1中的SFT模型。当policy做了一些action之后(输出Y),环境会发生变化。

该模型的流程如上述,将PPO数据输入到step 1中的SFT模型中,生成K个输出,将该输出送入RM模型进行打分,使用打分后的结果进一步优化SFT,即RL在损失函数层面改进:

img
由三部分组成:打分损失+KL损失+GPT3预训练损失,其中

x x x:表示PPO数据集的prompt,即问题;

π ϕ R L \pi_\phi^{RL} πϕRL :表示待学习的RL策略,即对于每个prompt,其是RL模型的输出 y y y

π S F T \pi^{SFT} πSFT:表示step1中的SFT模型,注意,强化学习中,模型叫做Policy,通过不断的更新参数, π ϕ R L \pi_\phi^{RL} πϕRL就是最终的InstructGPT模型,并且其由最开始 π S F T \pi^{SFT} πSFT初始化而来,也就是说最开始的时候这来两个是一样的。

r θ ( x , y ) r_{\theta}(x,y) rθ(x,y) :表示输出得分,即把输出 y y y输入到step2训练好的RM模型中得到的结果;(损失函数希望这个得分最大化,即说明RL模型输出的答案总是人类排序中最优的)

β l o g ( π ϕ R L ( y ∣ x ) / π S F T ( y ∣ x ) ) \beta log(\pi_\phi^{RL}(y|x)/\pi^{SFT}(y|x)) βlog(πϕRL(yx)/πSFT(yx)) :是一个正则项,即PPO的主要思想

随着模型的更新,RL产生的输出y和原始的SFT模型输出的y会逐渐不一样,即数据分布 ( y / x ) (y/x) (y/x)的差异会越来越大,RL的输出可能会不准。所以在loss里加入了一个KL散度(评估两个概率分布的差异),希望RL在SFT模型的基础上优化一些就行,但是不要偏太远,即相当于加入了一个正则项。

因为需要最大化 o b j e c t i v e ( ϕ ) objective(\phi) objective(ϕ),所以β前面加了一个负号,表示希望KL散度比较小(两个概率分布一样时,相除结果为1,取对数后结果为0)。

r θ ( x , y ) r_{\theta}(x,y) rθ(x,y) :RM模型的输出

将PPO数据集中的问题x输入到 π ϕ R L \pi_\phi^{RL} πϕRL模型得到答案y,然后把数据 ( x , y ) (x,y) (x,y) 输入到RM模型中得到 r θ ( x , y ) r_{\theta}(x,y) rθ(x,y) ,这个分数越高说明生成的答案越好,越符合人类预期,确保回答的安全性。

γ E x ∼ D p r e t r a i n [ l o g ( π ϕ R L ( x ) ) ] \gamma E_x \sim D_{pretrain}[log(\pi_\phi^{RL}(x))] γExDpretrain[log(πϕRL(x))] :GPT3本身的损失函数

如果只使用上述两项进行训练,会导致该模型仅仅对人类的排序结果较好,而在通用NLP任务上,性能可能会大幅下降,文章通过在loss中加入了GPT-3预训练模型的目标函数来规避这一问题。

D p r e t r a i n D_{pretrain} Dpretrain表示从训练GPT3的预训练数据中采样x,然后输入RL模型中得到输出概率。这样使得前面两个部分在新的数据集上做拟合,同时保证原始的数据也不要丢,主要是保证NLU的能力。

综合起来,整个RL模型(InstructGPT)简单来说就是一个PPO的目标函数(在新的标注数据集上做微调)加上一个GPT3的目标函数(原始的预训练数据)结合在一起。

img

RL损失函数具体细节

(PPO算法属于强化学习,RL领域的知识后续在补充)

4.Conclusion

LLMs模型其实就是用大量的训练数据和大规模的硬件堆造出来的,并且当探究其中的原理后,发现它并没有业内宣传的那么恐怖。InstructGPT的亮点主要分为两个:1.高质量的训练数据集构建;2.将强化学习机制引入到预训练语言模型中,构造奖励模型来引导RL模型的优化。

作者在一开始提到了三个目标:想要语言模型更加具有帮助性、真实性和无害性。实际上这篇文章主要还是在讲帮助性,包括在人工标注时,也更多的是在考虑帮助性,但在模型评估时,更考虑真实性和无害性。所以从所以从创新性和完成度的角度,这篇文章一般,没有考虑另外两个方面如何显著的优化。

另外最后的RL模型可能也是没有必要做的。我们只需要在第一步多标一些数据(比如10万条),这样直接在GPT-3上进行微调就行,是不是会更好一些呢?

img

GPT3与InstructGPT在同prompt下输出区别

InstructGPT与GPT3相比:

1.InstructGPT/ChatGPT的效果比GPT-3更加真实

2.InstructGPT/ChatGPT在模型的无害性上比GPT-3效果要有些许提升

3.InstructGPT/ChatGPT具有很强的Coding能力

缺点:

1.InstructGPT会降低模型在通用NLP任务上的效果

2.InstructGPT对指示非常敏感


原文地址:https://blog.csdn.net/weixin_44362044/article/details/136232561

免责声明:本站文章内容转载自网络资源,如本站内容侵犯了原著者的合法权益,可联系本站删除。更多内容请关注自学内容网(zxcms.com)!