全体团队:Simran*,Sabri*,Michael*,Aman,Silas,Dylan,James,Atri,Chris
Arxiv:arxiv.org/abs/2402.18668
代码:github.com/HazyResearch/based
在ICLR 论文(以及博客文章)中,我们在去年年底分享了一个发现,许多高效的架构(例如Mamba,RWKV,Hyena,RetNet)在召回方面表现不及 Transformer,召回是将生成内容与上下文中看到的信息联系起来的能力,这对于上下文学习和复制至关重要。我们利用这一分析设计了一个名为 Based 的新架构(在这篇博客文章中预览)。我们很高兴分享这一工作线的���新进展。
我们最近的工作深入探讨了召回挑战。我们首先阐明了模型的召回能力与其在生成过程中的内存消耗之间的基本权衡。这一分析为 Based 的设计提供了指导,Based 是一个简单的循环架构,在真实世界的召回密集任务(信息提取、阅读理解)和上下文学习方面优于先前的次二次模型。同时,Based 提供快速的生成速度:与 FlashAttention-2 和 Mamba 相比,Based 在处理提示时分别快 56%和 44%。Based 的文本生成吞吐量比 FlashAttention-2 高 24 倍。
我们特别对 Based 的简单性感到兴奋。仅使用两个众所周知的、类似注意力的构建块,滑动窗口注意力(具有微小的窗口大小)和线性注意力(具有exp(QKT)\exp(QK^T)exp(QKT)的泰勒级数近似),我们可以在语言建模上胜过最强的次二次架构,并实现比优化的 Transformer 更快的速度提升!
本博文概述了我们对次二次架构中召回的分析,这导致了 Based 设计,以及我们如何让 Based 快速生成!
激励分析:召回-记忆权衡。 驱动我们探索的主要问题是:我们是否可以大幅提高语言模型的现实世界速度和内存消耗,而不影响召回和上下文学习能力?
要开始回答这个问题,我们首先必须考虑什么会减慢架构的速度。高效的架构(例如 Mamba)在推理时比 Transformer 快得多(例如吞吐量高 5 倍),这在很大程度上是因为它们具有较小的内存占用。较小的内存占用意味着更大的批处理大小和更少的 I/O。然而,直观地减少内存占用太多可能会损害模型回忆先前序列中看到的信息的能力。这对我们来说看起来像是一个经典的“没有免费午餐”情况,因此我们采取了一些流行架构,变化了影响内存占用的超参数,并在具有挑战性的合成联想召回任务上评估性能。
召回-记忆权衡。 我们发现所有架构都遵循一个基本权衡:模型在推理时消耗的内存越少,它在联想召回上的表现就越差。我们关注循环状态大小,即在逐个生成标记时用于表示先前看到的标记的字节数。
在注意力中,状态通常被称为 KV 缓存,并随着序列长度增长。在图 1 的右上角,我们可以看到注意力完美地执行召回,尽管以巨大的循环状态为代价。滑动窗口注意力提供了一种限制 KV 缓存大小的方法,但我们发现(并不奇怪)随着我们减少循环状态的大小(例如从 1MB 循环状态到 65KB 循环状态),召回性能会迅速下降(图 1,浅蓝色)。
令人兴奋的是,我们发现 Mamba 扩展了召回-记忆权衡曲线的帕累托前沿,超越了滑动窗口注意力。这意味着它比滑动窗口注意力更好地利用了有限的循环状态大小。
自然的问题是:是否有其他,也许更简单的模型可以扩展帕累托前沿?
Based:帕累托前沿上的简单模型。 为了回答这个问题,我们开始研究为什么最简单的替代 softmax 注意力的方法未能达到有利的权衡。作为进一步的设计原则,我们寻找了可以在当前和未来硬件上良好扩展的基元。例如,如果我们的基元能够利用 GPU Tensor Cores,这将是很好的,现代 GPU 上的专用硬件可以比默认的(CUDA 核心)更快地执行矩阵乘法(GEMMs)16x16 矩阵!
在我们的ICLR 论文中,我们深入探讨了为什么任何具有卷积视图(例如 H3 或 Hyena)的模型在召回方面会遇到困难。接下来,我们考虑了两种最简单的高效注意力技术:(1)滑动 窗口 注意力和(2)线性注意力(即没有 softmax 的注意力)。
我们在现实世界的语言建模(高达 14 亿参数)和合成联想召回的实验中发现,单独的基元都无法足以导航帕累托前沿。
-
我们发现纯线性注意力模型在执行精确的本地标记移位和标记比较方面存在困难,这些技能在召回(Fu 等,2023 年;Arora 等,2023a 年)以及密集注意力中很重要。在我们的发现上进行扩展,我们发现我们的纯线性注意力模型在早期的次二次架构上有所改进。专注于 Pile 测试集中召回密集的部分(即强制模型使用先前上下文而非记忆知识进行下一个标记预测),355M 的纯线性注意力模型在 ppl 上优于 RWKV-v5 0.1 个单位,优于 H3 2.6 个单位(表 1,论文)。在这个召回部分上,纯线性注意力甚至与 Mamba 架构相媲美 - Mamba 为 2.21 个单位,纯线性注意力为 2.29 个单位!然而,我们观察到与 Transformer 之间存在较大差距,在召回部分达到 1.87 个单位。
-
在滑动窗口注意力中,模型只能召回滑动窗口内的标记(图 2,中间)。随着窗口大小的增加,循环状态呈线性增长,并对并行训练和推理速度产生非线性影响(图 2,左侧)。
然而,我们发现这两个基元是互补的 - 线性注意力用于建模长距离标记交互,滑动窗口用于序列中的本地标记交互。我们将它们结合成一个称为 Based 的单一架构(图 2,右侧)。
-
滑动窗口注意力可以执行关联召回所需的精确局部移位。我们在实验中使用小窗口大小(例如 64),与Mistral-7B和最近提出的Griffin等架构中较大的窗口大小形成对比。直观上,更多的注意力(更大的窗口大小)从质量的角度来看是好的,但我们希望在质量和挂钟速度之间取得平衡。为了平衡这些目标,让我们看看上图中的左图。观察到 16x16 与 64x64 矩阵的矩阵乘法延迟大致相等,超过 64 后,延迟随窗口大小呈非线性增长。请注意,16x16 和 64x64 之间的粗略相似性是因为后者保持了 GPU 张量核心的占用率足够高,以使其饱和!
-
线性注意力使全局令牌交互成为可能,同时保持固定大小的循环状态。与 softmax 注意力不同,线性注意力的循环状态大小是超参数(例如特征映射选择)的函数,而不是序列长度。这使我们能够平稳地遍历权衡空间。我们使用指数函数的泰勒近似作为特征映射,这在我们之前关于线性注意力的工作中首次使用!
在 Based 中,与注意力不同,循环状态大小不随序列长度增长。相反,它由线性注意力特征维度和窗口大小确定。通过调整这些超参数,我们可以在图 1 中的帕累托前沿中权衡召回和吞吐量。
尽管它很简单,在真实的语言建模实验中(至少达到 13 亿参数),Based 在整体 Pile 困惑度和来自LM eval harness的标准零-shot 基准方面与 Mamba 竞争力相当(显示在问答-常见中)。
这些常用的零-shot 基准仅限于极短的文本,因此它们无法测试模型的召回能力。为了解决这个缺点,我们精心策划了一套真实世界召回密集基准,需要从长文档(例如从FDA 文件和原始 HTML中提取信息,以及阅读理解)。__Based 是这些任务中最强大的次二次架构,平均比 Mamba 高出 6.22 个准确度点。然而,Based 和 Mamba 仍然表现不佳,有时差距很大,与我们上面的“没有免费午餐”观察一致。
值得注意的是,我们认为 Based 并不是唯一可以在权衡曲线上运行的架构。例如,我们在论文中展示,我们可以用短卷积(滤波器大小为 3)替换滑动窗口注意力,并在 0.1 困惑度点内达到类似的性能。我们怀疑还有许多其他架构也可以匹配这个帕累托前沿,我们希望还有其他架构甚至可以超越它!
我们如何使用固定大小的循环状态也很重要!
有许多可能具有相同隐藏状态大小的循环架构,但我们的工作突显了特征化(例如线性注意力特征图、状态更新机制)的重要性。我们在 Based 中的映射选择出人意料地简单(只需高中微积分即可):它用泰勒级数逼近指数函数。我们计算ϕ\phiϕ,使得ϕ(q)ϕ(k)T≈exp(qkT)\phi(q) \phi(k)^T \approx \exp (q k^T)ϕ(q)ϕ(k)T≈exp(qkT)。我们仅使用二阶泰勒级数,与我们之前的工作相同,其中exp^(x)=1+x+x2/2\hat{\exp}(x) = 1 + x + x² / 2exp^(x)=1+x+x2/2!。请注意,如果xxx 的维度为d’d’d’,那么x2x²x2 项的维度将为d’2d’²d’2。键-值外积的结果(上述步骤 1)在d’d’d’中迅速增长,扩展了 Based 的状态大小。
我们选择特征化与扩展状态大小对 Based 质量的影响有多大? 模型有效使用状态的能力至关重要。在准确度与循环状态大小权衡曲线中,几种替代泰勒映射的模型都落在帕累托前沿以下。下面我们将其与使用学习投影扩展状态大小然后应用文献中的流行特征映射(Performer、CosFormer、PosELU)的模型进行比较。我们在MQAR 合成测试上训练这些模型,为图中所示的所有点扫描超参数(学习率),发现泰勒映射最有效。这一趋势延续到 Pile 语言建模语料库的真实世界实验中(更多内容请参阅我们的论文)。
IO 和数据流感知实现。
下一个关键问题是如何使 Based 在挂钟效率上具有竞争力。线性注意力在序列长度的函数上理论上比标准注意力更有效。然而,现有的线性注意力方法实现通常比优化良好的注意力实现(如FlashAttention)慢。
在 Based 中,我们使用了二次泰勒近似,这扩展了键的维度,导致了大的状态大小和大的内存消耗O(Nd’2d)O(Nd’²d)O(Nd’2d),在序列长度NNN,键维度d’d’d’和值维度ddd(如上所述)。由于大量的键值状态,朴素的 Taylor 线性注意力实现速度相当慢。
首先让我们重新审视一下硬件工作原理的一些背景。GPU 具有少量快速访问的内存(线程特定寄存器,使用 SRAM 的 warp/32 个线程级别的共享内存)和大量慢速访问的内存(HBM)。减少慢速 HBM 和 SRAM 之间以及 SRAM 和寄存器之间的读写次数非常重要,以提高效率。我们提出了新的 IO 感知算法,用于 Taylor 线性注意力前向传递和推断,将 HBM 到 SRAM 的数据移动减少了O(Nd’2)O(Nd’²)O(Nd’2)字节,将 SRAM 到寄存器的数据移动减少了O(Nd’2d)O(Nd’²d)O(Nd’2d)字节。我们的算法允许在特征维度d’d’d’ = 16 时将 KV 状态保持在线程寄存器中,这在实验中得到了应用。
下面我们将比较朴素的 Taylor 注意力前向传递,一个利用来自Fast Transformers的流行线性注意力内核的实现,以及我们的自定义内核在批处理大小(序列长度 1024)上的表现。
然后,我们使用我们的 IO 感知算法比较 FlashAttention-2、Mamba 和 Based 360M 和 1.3Bn 参数模型的端到端生成速度。我们将批处理大小保持为 2 以进行预填充,并为下一个标记预测生成 1024 个标记。令人惊讶的是,Based 的吞吐量比 FlashAttention-2 高出多达 24 倍!
敬请关注! 这些算法是在我们实验室开发的一种令人兴奋的新 CUDA DSL ThunderKittens 中实现的。敬请关注更多相关内容 - 我们希望这种 DSL 能提高 CUDA 开发的可访问性!与像 Triton 这样的框架不同,后者对用户可以执行的操作范围做出了明确的决定,我们的 DSL 是嵌入在 C++中的。我们非常期待分享它并获得您的反馈!在接下来的几周里,我们将继续制作更多模型工件,动机是:硬件需要哪些模型?
您可以在Hugging Face上玩我们的检查点和评估,并在此代码库中查看:github.com/HazyResearch/based
!非常感谢Together AI,Stanford HAI和Stanford CRFM支持这项工作!请将您的反馈和问���发送至:Simran(simarora@stanford.edu),Sabri(eyuboglu@stanford.edu)和 Michael(mzhang@cs.stanford.edu)。