论文收藏夹
🗒️Medusa: Simple Framework for Accelerating LLM Generation with Multiple Decoding Heads
00 分钟
2023-10-9
2024-1-20
type
status
date
slug
summary
tags
category
icon
password

一 概述

一个更简单、更友好的 LLM 生成加速框架。Medusa 没有使用像投机采样那样的额外草稿模型,而只是引入了几个额外的解码头,沿用了 [Stern et al. 2018] 的思路,并加入了一些其他成分。尽管设计简单,但 Medusa 可以将 LLM 的生成效率提高约 2 倍。
下面我们将探讨 LLM 生成的基本瓶颈和推测解码的一些限制,然后展示 Medusa 如何设法解决这些瓶颈并实现加速。
The implementation is available at this repo.
notion image

二 为什么LLM生成效率低下?

从系统角度看,LLM 生成遵循内存约束计算模式,主要延迟瓶颈来自内存读/写,而非算术计算。这一问题的根源在于自动回归解码过程本身的顺序性。每次前向传递都需要将整个模型的参数从高带宽内存(HBM)传输到加速器的计算单元。这种操作虽然只为每个样本生成一个 token,却无法充分利用现代加速器的算术计算能力,导致效率低下。
在 LLM 兴起之前,缓解这种低效的常见方法是简单地增加批量大小,从而并行生产更多令牌。但有了 LLM,情况就复杂得多。在这种情况下,增加批量大小不仅会带来更高的延迟,还会大大增加 Transformer 模型键值缓存的内存需求。在这种权衡下,对于许多以低延迟为关键要求的应用来说,使用大批次是不切实际的。
此外,这种低效率还体现在成本结构上。截至 2023 年 9 月,与仅仅处理提示相比,GPT-4 的生成成本大约高出 2 倍,Claude 2 大约高出 3 倍。我们注意到,本博客的主要重点是改进 LLM 生成的延迟,而我们认为这里的技术也可以应用于 LLM 服务,因为后者需要平衡延迟和吞吐量。
notion image

三 投机采样(Speculative Decoding)是最终解决方案吗?

鉴于上述挑战,加速文本生成的一个吸引人的策略是更有效地利用计算资源,具体来说,就是并行处理更多 token。这正是投机采样发挥作用的地方。该方法采用简化的 "草稿 "模型,每一步都能快速生成一批候选 token。然后,这些候选 token 将通过原始的完整语言模型进行验证,以确定最合理的文本连续性。其基本逻辑依赖于一个耐人寻味的假设:草稿模型虽然较小,但应该足够精通,能够产生原始模型认为可以接受的序列。
如果这一假设成立,草稿模型就能快速生成 token 序列,而原始模型则能高效地并行审核多个 token,从而最大限度地提高计算吞吐量。最近的研究表明,如果有一个调整良好的草稿模型,投机采样可以将延迟时间缩短 2.5 倍,令人印象深刻。
然而,这种方法并非没有挑战:
  1. 寻找理想的模型草案: 找到一个 "小而强大 "的模型草案,与原始模型保持一致,说起来容易做起来难。
  1. 系统复杂性:在一个系统中托管两个不同的模型会带来计算和操作上的复杂性,尤其是在分布式环境中。
  1. 采样效率低: 在使用投机采样进行采样时,需要使用重要性采样方案。这会带来额外的生成开销,尤其是在较高的采样温度下。
这些复杂性和权衡限制了推测解码技术的广泛应用。因此,投机解码虽然前景广阔,但并未被广泛采用。
在这里,我们用投机采样来指那些需要独立草稿模型的方法。从广义上讲,我们的方法也可以看作是投机采样,而草稿模型与原始模型是纠缠在一起的。

四 Medusa:简单与高效的结合

投机采样固有的复杂性抵消了它的前景。我们认识到需要一种更方便用户但更有效的解决方案,因此自豪地推出了 Medusa。这一创新框架不仅加速了 LLM 的生成,还使 LLM 技术为更多人所了解和使用。
我们的方法重温了《深度自回归模型的分块并行解码》(Blockwise Parallel Decoding for Deep Autoregressive Models)[Stern et al. 2018] 一文中被低估的精华,回到 Transformer 模型的发明:与其调入一个全新的草稿模型来预测后续 token,为什么不简单地扩展原始模型本身呢?这就是 "美杜莎头 "的用武之地。这些额外的解码头与原始模型无缝集成,在每个生成关口产生 token 块。
与草稿模型不同,"美杜莎头 "可以与原始模型一起训练,而原始模型在训练过程中保持冻结。这种方法可以在单个 GPU 上对大型模型进行微调,充分利用强大的基础模型所学到的表征。此外,由于新的头部仅由单层组成,类似于原始语言模型头部,因此 Medusa 不会增加服务系统设计的复杂性,而且对分布式设置也很友好。
就其本身而言,Medusa 头并没有达到将处理速度提高一倍的目标。但这里有一个转折点:当我们将其与基于树的关注机制搭配使用时,我们可以并行验证由 Medusa head 生成的多个候选语料。这样一来,美杜莎头的预测能力就真正发挥出来了,速度提高了 2 到 3 倍。
我们也没有止步于此。我们摒弃了传统的重要性采样方案,创造了一种高效、高质量的替代方案,专门用于与美杜莎磁头一起生成数据。这种新方法完全避免了采样开销,甚至为美杜莎已经加速的步骤增加了额外的动力。
简而言之,我们用一个简单的系统解决了投机采样的难题:
  1. 没有单独的模型: 我们不引入新的草稿模型,而是在同一模型上训练多个解码头。
  1. 与现有系统集成简单: 训练具有参数效率高的特点,因此即使 GPU 性能较差也能完成训练。由于没有额外的模型,因此无需调整分布式计算设置。
  1. 将采样视为一种放松: 放宽与原始模型分布相匹配的要求,使非贪婪生成比贪婪解码更快。
下图提供了 Medusa 管道的可视化分解:
notion image

4.1 Medusa 概述

Medusa 在 LLM 的最后一个隐藏状态之上引入多个头部,从而能够并行预测多个后续标记。
在用 Medusa 头增强模型时,原始模型在训练过程中会被冻结,只有 Medusa 头会进行微调。这种方法使得在单个 GPU 上对大型模型进行微调变得可行。
在推理过程中,每个头部都会为其指定位置生成多个最高预测值。这些预测被组合成候选模型,并使用基于树的关注机制进行并行处理。最后一步是利用典型的接受方案来选择合理的前缀,最长的接受候选前缀将用于下一个解码阶段。
同时接受更多前缀可以提高解码过程的效率,从而减少所需的解码步骤。
让我们深入了解一下 Medusa 的三个组成部分: 美杜莎头、树状注意力和典型接受方案。

4.2 Medusa 头

美杜莎头类似于原始架构中的语言模型头(因果转换器模型的最后一层),但有一个转折点:它们可以预测多个即将出现的 token,而不仅仅是下一个 token。我们从分块并行解码(Blockwise Parallel Decoding)方法中汲取灵感,将每个美杜莎头作为单层前馈网络来实现,并增加了一个残差连接。
这些头部的训练非常简单。可以使用训练原始模型的相同语料库,也可以使用模型本身生成新的语料库。重要的是,在这一训练阶段,原始模型保持不变;只有美杜莎头进行了微调。这种有针对性的训练带来了参数效率极高的过程,能迅速达到收敛,尤其是与投机采样方法中训练单独的草稿模型所带来的计算负担相比。
美杜莎头的功效令人印象深刻。在我们测试的 Vicuna 模型中,Medusa head 预测下下个 token 的最高准确率约为 60%。但仍有改进的余地。

4.3 树形注意力

在测试过程中,我们发现了一些惊人的指标:虽然预测下下个 token 的 top-1 准确率徘徊在 60% 左右,但 top-5 的准确率却飙升至 80% 以上。这一大幅提升表明,如果我们能战略性地利用美杜莎头做出的多个排名靠前的预测,就能显著提高每个解码步骤生成的 token 数量。
为了实现这一目标,我们首先通过提取每个美杜莎头的最高预测值的笛卡尔乘积来制作一组候选词。然后,我们按照图神经网络的思路将依赖图编码到注意力中,这样就可以并行处理多个候选词。
例如,让我们考虑这样一种情况:我们使用来自第一个美杜莎头的前 2 位预测和来自第二个美杜莎头的前 3 位预测,如下图所示。在这种情况下,第一个头部的任何预测都可以与第二个头部的任何预测配对,最终形成一个多层次的树状结构。这棵树的每一层都对应着一个美杜莎头的预测。在这个树形结构中,我们实现了一个注意力掩码,它将注意力限制在一个 token 的前一个 token 上,从而保留了历史语境的概念。通过这样做,并相应地为位置编码设置位置索引,我们就可以同时处理大量候选信息,而无需扩大批处理量。
我们还想说的是,一些独立的研究也采用了非常类似的树形关注思路。与它们相比,我们的方法偏向于一种更简单的树注意形式,即在推理过程中树的模式是规则和固定的,这使得树注意掩码的预处理能够进一步提高效率。
notion image
上图演示了如何使用树状注意力同时处理多个候选项。例如,第一个美杜莎头的 top-2 预测结果和第二个头的 top-3 预测结果产生了 2*3=6 个候选结果。每个候选项都对应树形结构中的一个不同分支。为了保证每个标记只访问其前置标记,我们设计了一个注意力掩码,只允许注意力从当前标记流回其前置标记。位置编码的位置索引就是根据这一结构调整的。

4.4 经典接受度

在早期关于投机采样的研究中,重要度采样技术被用来产生与原始模型预测密切相关的各种输出结果。然而,后来的研究表明,随着采样温度的升高,这种方法的效率往往会降低。
简单地说,如果你的草稿模型和原始模型一样好,那么你最好接受它的所有输出结果,从而使整个过程变得超级高效。然而,重要性取样很可能会在中间环节拒绝这种解决方案。
在现实世界中,我们经常调整采样温度,只是为了控制模型的创造性,而不一定是为了与原始模型的分布相匹配。那么,为什么不只关注接受可信的候选方案呢?
接下来,我们将介绍典型的接受方案。从现有的截断抽样工作中汲取灵感,我们的目标是根据原始模型挑选出足够可能的候选者。我们根据原始模型的预测概率设定一个阈值,如果候选者超过了这个阈值,就会被接受。
用专业术语来说,我们取硬阈值和与熵相关的阈值的最小值来决定是否接受候选者,就像截断采样一样。这可以确保在解码过程中选择有意义的 token 和合理的连续词。我们总是使用贪婪解码法接受第一个 token,确保每一步至少生成一个 token。最后的输出结果就是通过验收测试的最长序列。
这种方法的最大优点是适应性强。如果将采样温度设为零,它就会恢复到最有效的形式贪心解码。当温度升高时,我们的方法会变得更加高效,可以接受更长的序列,我们已经通过严格的测试证实了这一点。
因此,从本质上讲,我们的典型接受方案提供了一种更高效的方法,来生成 LLM 的创造性输出。

(五)Llama 吞吐有多快?

我们使用 Vicuna 模型对 Medusa 进行了测试,Vicuna 模型是专门针对聊天应用进行微调的专用 Llama 模型。这些模型大小不一,参数数分别为 7B、13B 和 33B。我们的目标是衡量 Medusa 如何在真实世界的聊天机器人环境中加速这些模型。
在训练 Medusa 头部时,我们选择了一种简单的方法。我们使用了公开的 ShareGPT 数据集,该数据集是最初用于 Vicuna 模型的训练数据的一个子集,并且只训练了一次。
最重要的是,根据模型的大小,整个训练过程只需要几小时到一天的时间,这一切都可以在单个 A100-80G GPU 上完成。值得注意的是,Medusa 可以轻松地与量化基础模型相结合,以减少内存需求。我们利用这一优势,在训练 33B 模型时使用了 8 位量化。
为了模拟真实世界的环境,我们使用 MT 标准进行评估。结果令人鼓舞: 凭借其简单的设计,Medusa 在各种使用情况下都能持续实现约 2 倍的挂壁时间加速。值得注意的是,经过 Medusa 的优化,33B 参数的 Vicuna 模型可以像 13B 模型一样快速运行。
notion image
notion image

(六)消融研究

6.1 Medusa 头选择

notion image
在利用美杜莎头像的预测能力时,我们可以灵活选择每个头像应考虑多少顶级候选人。例如,我们可以选择第一个头的前 3 个预测结果,第二个头的前 2 个预测结果。当我们取这些顶级候选者的笛卡尔乘积时,就会生成一组六个连续项供模型评估。
这种程度的可配置性是需要权衡的。一方面,选择更多顶级预测会增加模型接受生成标记的可能性。另一方面,它也增加了每个解码步骤的计算开销。为了找到最佳平衡点,我们尝试了各种配置,并确定了最有效的设置,如附图所示。

6.2 典型接受度阈值的选择

notion image
在典型的接受方案中,一个关键的超参数(称为 "阈值")可以帮助我们根据模型自身的预测来确定生成的标记是否可信。阈值越高,接受的标准就越严格,这反过来又会影响通过这种方法获得的整体速度。
我们通过对 MT 工作台上的两个以创造力为导向的任务进行实验,探索了质量和速度之间的权衡。结果如图所示,与贪婪解码方法相比,典型的验收速度提高了 10%。这种提速效果明显优于采用随机抽样的推测解码法,后者实际上比贪婪解码法更慢。