论文题目:PoSE: Efficient Context Window Extension of LLMs via Positional Skip-wise Training
论文链接:
代码链接:
一、研究简介
大型语言模型(LLMs)通常有一个预定义的上下文窗口大小,这限制了它们在长输入的场景中的使用。为了使 LLMs 适应更长的输入,通常需要用目标长度的样本对其进行微调(全长微调),由此导致训练成本十分昂贵。
举例来说,在 Positional Interpolation[1]这份工作中,将 LLaMA 的上下文窗口从 2048 拓展到 8192 使用了 32 张 A100,对于更大的上下文窗口则使用了 128 张 A100。
为了将训练长度与目标长度解耦合,以实现高效的上下文窗口扩展,我们提出了一种称为位置跳跃式训练(Positional Skip-wisE training, PoSE)的方法,在原始的上下文窗口中模拟更长的训练样本。
如下图所示,我们将原始的上下文窗口分成几块,然后引入不同的 bias 项来调整每个块的位置编码。对于每一条训练样本,这些 bias 项和块的长度都会发生变化,因此通过大量的训练,模型能适应目标长度内的所有位置。
实验结果表明,PoSE 有以下三方面的优势:
二、技术背景
旋转位置编码 RoPE:RoPE 是当下主流的位置编码方式,被 LLaMA、GPT-J 等大语言模型所采用。给定一个 ,RoPE 通过如下方式编码位置信息:
其中 。此前的绝对位置编码多是直接作用在输入向量 上,与之不同的是,RoPE 是在作用在每一层的 query 和 key 向量上。RoPE 可以看作是一种相对位置编码,给定位置
上下文窗口扩展:给定一个以 为原始上下文窗口长度的大语言模型,我们的目标是其支持的上下文长度拓展到 个输入内能较好地保持原有的性能。
位置插值(PI):为了将 LLM 的上下文窗口从
然而,实践表明[1] [2], 这部分的位置在前向传播时会产生灾难性的离群值,从而导致训练无法达到预期的效果。这主要是因为模型在预训练时只见过
这些位置,无法很好的泛化到外推出去的这部分位置。
为了解决这个问题,Position Interpolation[1]这份工作首先提出用“内插”代替“外推”,设定缩放因子 ,并将上述注意力公式修改为 (也就是将位置编码线性修改为
这种方式可以减少离群值的出现,将上下文窗口拓展到了 32k。在此基础上,NTK 提出通过修改 来进行位置插值,取得了更好的效果。YaRN 则根据不同的维度,对上述线性插值和 NTK 进行了整合。
三、方法描述
尽管上述 Linear / NTK / YaRN 等插值方式能一定程度上解决位置外推的问题,他们仍然需要用目标长度的训练样本来训练模型(即全长微调)。
随着目标长度的增加,平方级别的计算复杂度带来的开销依旧是难以承受的。因此,在插值技术的基础上,我们提出调整原始的上下文窗口中的位置编码,来模拟更长的训练样本,从而实现高效的上下文扩展。
位置编码的调整主要有两个考量:
第一步:我们将原上下文窗口 ,则这个块的位置编码如下:
第二步:我们从离散均匀分布
为了避免块之间位置编码的重合,我们施加了 这一限制。值得注意的是,对于每条数据,我们会重新采样每个块的大小和跳跃偏置项。直观上来说,通过这种方式,我们扩大了原上下文窗口能覆盖的相对位置范围,并且位置编码的不连续只发生在块之间,因此尽可能地保持了预训练阶段的位置编码结构。
第三步:选定每个块内的内容。给定输入文本 ,我们用类似的方法来抽取每个块内的填充的内容:
我们也尝试了其它 ,此时块间的内容也是连续的;或如 ,此时调整后的位置编码恰好对应训练数据在原始文本中的位置。实验结果表明,这几种赋值方式并没有明显的差别。
第四步:位置插值及超参初始化。我们使用位置插值来使训练更稳定。
四、实验分析
1. 实验设置
训练过程:我们主要使用 LLaMA-7B 作为基模型,对于所有设定都只训练 1000 步,训练时长度为 2k,batch size 为 64。我们使用 8 张 V100 进行训练,1 张 A100 进行推理。对于我们的方法和各个 baseline,我们都默认采用线性插值来使训练更稳定。
2. 主要结果
语言模型:
我们使用滑动窗口的方式来计算困惑度 PPL。在 GovReport 和 Proof-Pile 两个数据集上,PoSE 的性能和 Full-length 十分接近,远超未做窗口扩展的版本(Original)和随机位置的版本(RandPos)。且随着窗口长度从 2k 增加到 32k,PPL 呈下降趋势,说明拓展后的模型能充分利用更长的上下文信息。
密码检索:
在密码检索任务上,利用 PoSE 拓展到 16k 和 32k 的模型能分别在 16k 和 32k 的上下文内取得接近 100% 的密码检索准确率,说明模型能关注到目标长度内的每个位置。
时空效率:
在时空效率方面,全长微调的训练时长和内存消耗随目标长度的增加而迅速增长,相比之下,PoSE 需要的训练时间和内存较为稳定。并且在每个时间步上,性能和全长微调都很接近。
兼容性:
兼容性方面,PoSE 可以适配 LLaMA、LLaMA2、GPT-J、Baichuan2 等各种基于 RoPE 的基础模型,以及 Linear、NTK、YaRN 等各种插值策略,展现出较好的普适性。其中 NTK 在最后阶段会有一个 PPL 的突增,这主要是因为给定缩放因子,NTK 实际实现的缩放倍数会略小于[3]。YaRN 解决了这个缺陷,取得了三者中最好的效果。
超长上下文拓展的潜力:
只使用 2k 的训练长度和 1000 步的训练步数,我们尝试了将 LLaMA 模型拓展到 128k。实验表明,在使用 YaRN 的情况下,模型在 128k 的窗口下仍然能保持较低的 PPL。
原窗口内的语言能力:
最后,我们分析了经由 PoSE 训练过后的模型在原窗口内的语言能力。可以看出,和全长微调以及原始模型相比,PoSE 模型能力的损失非常微小,这说明 PoSE 在拓展上下文窗口的同时较好地保持了模型的基础能力。
五、总结与讨论
本文提出了一种位置跳跃式训练(PoSE)来高效的拓展大语言模型的上下文窗口。通过调整位置编码,PoSE 在原始的上下文窗口中模拟更长的训练样本,以达到解耦合训练长度和目标长度的目的。
实验结果表明 PoSE 在和全长微调保持同等性能的情况下,大大缩小了训练所需的时空开销,并表现出良好的普适性和超长上下文扩展的潜力。我们相信 PoSE 将大大降低上下文窗口拓展的成本,使更多人可以参与到相关的研究中来,从而推动长上下文建模领域的快速发展。
PoSE 完成于 2023 年 9 月,我们相信这种位置跳跃的思路是 Long Context 的有效解决方案。结合近几个月来 Long Context 相关研究的进展,我们认为 PoSE 可能有以下一些方面值得进一步探究:
原文链接: