• 论文标题:Attention Residuals
  • 论文链接:https://arxiv.org/pdf/2603.15031

TL;DR

今天解读 Kimi 团队提出的一种替代标准残差连接的新架构设计——注意力残差(Attention Residuals, AttnRes)。该方法将网络深度方向上固定的残差累加替换为学习到的、基于输入的 Softmax 注意力机制,允许每一层选择性地聚合之前的特征表示。

为了解决大规模模型训练中的显存和通信开销,研究人员进一步提出了块级注意力残差(Block AttnRes),将层划分为多个块,在保持全注意力残差大部分收益的同时,将内存和通信复杂度从 降低至 。实验表明,该机制在保持计算效率的同时,缓解了 PreNorm 带来的特征稀释问题,并在多个下游任务中取得了性能提升。


1. 背景

在现代大型语言模型(LLMs)的设计中,标准残差连接是核心的构建模块。残差更新公式通常表示为:

其中 表示进入第 层的隐藏状态, 表示第 层的变换函数。将该递归公式展开,可以发现第 层的隐藏状态等于输入嵌入与之前所有层输出的等权重和:

这种设计在反向传播时提供了一条梯度高速公路。对中间隐藏状态求导时,包含一个单位矩阵 的项,确保梯度可以直接从损失函数传播到任意深度的层,从而实现深层网络的稳定训练。

然而,这种机制也存在尚未被充分讨论的局限性。在序列混合和专家路由等模块已经普遍采用可学习的、输入依赖的权重分配机制的背景下,深度方向上的信息聚合依然采用固定的单位权重。这意味着网络没有机制去选择性地强调或抑制特定层的贡献。

在实际应用中,PreNorm 架构虽然占据主导地位,但其无权重的累加会导致隐藏状态的幅度随着深度的增加以 的速度增长,进而逐渐稀释每一层的相对贡献。早期层的信息被深层信息掩盖,无法被选择性地检索。为了影响累加后的残差流,更深的网络层被迫输出越来越大的值,这会影响训练的稳定性。

尽管 Highway networks 等方法通过引入可学习的逐元素门控机制 放宽了固定权重的限制,但每一层仍然只能访问其直接前驱输入 。这个单一的压缩状态混合了所有早期层的输出,导致不同类型的层(如自注意力和 MLP)接收相同的聚合状态,且聚合过程中丢失的信息无法在更深的层中恢复。

2. 时间与深度的对偶性

上述深度方向上的信息压缩问题,与序列建模中循环神经网络(RNNs)面临的时间维度上的瓶颈具有相似性。RNN 在时间维度上将历史信息压缩到一个单一的隐藏状态中,而残差连接在深度维度上将所有先前的层级信息压缩到 中。

在序列建模领域,Transformer 通过引入注意力机制替代了递归结构,允许每个位置选择性地访问之前的所有位置。基于这种时间与深度的对偶性,研究人员提出将相同的方法应用于深度维度 :

其中 是特定层的注意力权重,满足 。由于网络深度 通常小于 1000,计算深度方向上的 注意力在计算上是可行的。这种机制被称为注意力残差(Attention Residuals, AttnRes)。

3. 注意力残差机制详述

3.1 全注意力残差 (Full Attention Residuals)

在 Full AttnRes 中,注意力权重定义为 。研究人员选择了带有 RMSNorm 的核函数 ,从而在深度方向上实现 Softmax 注意力。

具体而言,对于每一层

  • 查询 (Query): ,这是一个特定于该层的、维度为 的可学习向量。
  • 键和值 (Key/Value): 当 时,(即输入嵌入);当 时,
  • 输出计算:

函数中引入 RMSNorm 是为了防止输出幅度过大的层主导注意力权重的分配。

对于每个 Token,Full AttnRes 需要 的算术运算和 的内存来存储层输出。在标准的非并行训练中,这部分 的内存开销与反向传播所需的激活值缓存完全重叠,因此基本没有额外的内存负担。但是,在大规模模型训练中广泛采用激活重计算和流水线并行(Pipeline Parallelism)时,这些层输出必须显式保留并在流水线阶段之间传输,导致通信和内存开销随深度呈 增长。

图 1 注意力残差概览
图 1 注意力残差概览

3.2 块注意力残差 (Block Attention Residuals)

为了解决大规模分布式训练中的扩展性问题,研究提出了 Block AttnRes 变体。该方法将 层划分为 个块(Blocks),在块内通过累加将层输出缩减为单一表示,而在块间仅对这 个块级表示和 Token 嵌入应用全注意力。这使得内存和通信开销从 降低至

假设 可被 整除,每个块包含 层。设 为块 中的层索引集合。块的完整表示通过对该块内所有层输出求和得到:

同时定义 为块 中前 层的部分和(Partial Sum)。

在跨块注意力中,第 层的输入不再关注之前所有的独立输出,而是关注块级表示。定义 确保 Token 嵌入始终作为源参与计算。对于块 中的第 层,其值矩阵 构造如下:

  • (块 的第一层)时,
  • (后续层)时,,即包含之前的块表示以及当前块的局部部分和。

通过将每层关注的源数量从 减少到 ,内存和计算复杂度分别降至 。当 时退化为 Full AttnRes,当 时退化为标准残差连接。实验表明,设置 即可在不同模型规模下恢复全注意力残差的大部分收益。

图 2 Block Attention Residuals 伪代码
图 2 Block Attention Residuals 伪代码

4. 面向大规模扩展的基础设施设计

Block AttnRes 虽然减少了理论复杂度,但在实际的流水线并行训练和长上下文推理中仍面临系统工程挑战。

4.1 训练阶段的优化:跨阶段缓存 (Cross-stage Caching)

在标准残差连接中,流水线并行在相邻阶段之间传输固定大小的隐藏状态,与流水线深度无关。但 Block AttnRes 要求每个阶段获得所有已累加的块表示,如果每次转换都天真地传输完整历史,会导致冗余通信。

考虑具有 个物理阶段和每个物理阶段 个虚拟阶段的交错流水线调度(Interleaved Pipeline Schedule)。假设每个物理阶段平均产生 个块表示,如果采用天真的全量传输,每个 Token 的通信成本为 ,其中

为了消除这种冗余,研究人员利用了交错流水线的特性:每个物理阶段连续处理多个虚拟阶段。通过在本地内存中缓存早期虚拟阶段接收到的块,后续过渡只需传输增量块。优化后的通信成本降至:

第一项对应第一个虚拟阶段的开销,第二项对应后续虚拟阶段的开销。缓存机制将峰值单次传输成本从 降低至 ,使通信能够与稳定状态的计算完全重叠。实际测试表明,在开启流水线并行的端到端训练中,Block AttnRes 的开销占比不到 4%。

图 3 基于缓存的流水线并行通信示例
图 3 基于缓存的流水线并行通信示例

4.2 推理阶段的优化:两阶段计算策略

对于推理过程,自回归解码中逐层重复访问累加的块表示会增加内存 I/O,影响延迟。由于伪查询向量 独立于前向计算过程,研究设计了包含两个阶段的计算策略 :

  1. Phase 1 (并行跨块注意力) :同一个块内的 个层可以将其查询打包到一个矩阵乘法中,统一对缓存的跨块表示进行注意力计算。这使得内存读取操作从 次分摊为每块 1 次。
  2. Phase 2 (串行块内注意力与在线 Softmax) :顺序计算每个层的块内部分和依赖项,并使用在线 Softmax (Online Softmax) 技术与 Phase 1 的并行结果进行合并。该阶段可以与周围的算子(如 RMSNorm)进行内核融合(Kernel Fusion),进一步降低 I/O 开销。

在典型设置下(如 ),Block AttnRes 平均每层的 I/O 访问成本仅为 ,低于 mHC 等多流方法的 。这使得推理延迟开销在常规负载下控制在 2% 以内。

表 1 各种残差机制的内存访问成本对比
表 1 各种残差机制的内存访问成本对比

针对长上下文(例如 128K Token)的预填充(Prefilling)阶段,缓存块表示会占用大量显存。通过沿序列维度在张量并行(Tensor Parallelism)设备上分片(Sharding)表示矩阵,Phase 1 可以在本地独立执行。随后,Phase 2 的在线 Softmax 合并可集成到标准的张量并行通信路径中,从而显著降低单卡显存占用。

5. 实验

模型的底层架构基于 Kimi Linear,包含混合的 Kimi Delta Attention (KDA) 和 Multi-Head Latent Attention (MLA) 层,以及 MoE 前馈层。所有注意力残差的伪查询向量均初始化为零,以确保训练初期的注意力权重是均匀的,等同于标准的等权重平均。

5.1 缩放定律 (Scaling Laws) 实验

研究人员在五种模型规模(激活参数范围 194M 到 528M)上训练了 Baseline(PreNorm)、Full AttnRes 和 Block AttnRes()变体。拟合的计算量-损失幂律曲线 表明:

  • Baseline:
  • Block AttnRes:
  • Full AttnRes:

三种变体的斜率相似,但 AttnRes 在整个算力区间内均实现了更低的损失。在 5.6 PFLOP/s-days 的算力下,Block AttnRes 达到的验证损失与使用 1.25 倍算力的 Baseline 相当。随着规模扩大,Full 和 Block AttnRes 之间的差距逐渐缩小。

图 4 注意力残差的 Scaling Law 曲线
图 4 注意力残差的 Scaling Law 曲线
表 2 Baseline、Block AttnRes、Full AttnRes 及 mHC 的对比
表 2 Baseline、Block AttnRes、Full AttnRes 及 mHC 的对比

5.2 大模型下游性能评估

最终的评估在包含 48B 总参数(3B 激活参数)的模型上进行,使用 1.4T Token 进行两阶段预训练。

训练动态分析揭示了 AttnRes 的重要优势 :

  • 缓解输出幅度增长:Baseline 受到 PreNorm 稀释问题的影响,输出幅度随深度单调增长。Block AttnRes 在块边界通过选择性聚合重置了累加,使得每个 transformer 块的输出幅度呈现有界的周期性模式。
  • 梯度分布更均匀:由于单位权重,Baseline 的早期层梯度不成比例地大。AttnRes 引入的可学习 Softmax 权重使得不同源之间竞争概率质量,从而在深层间实现更为均匀的梯度分布。
图 5 Baseline 和 Block AttnRes 的训练动态
图 5 Baseline 和 Block AttnRes 的训练动态

下游任务上,Block AttnRes 在所有测试的基准中均优于 Baseline。特别是在需要多步组合推理的任务中提升显著,例如 GPQA-Diamond 提升了 7.5 个百分点,Minerva Math 提升了 3.6 个百分点,HumanEval 代码生成提升了 3.1 个百分点。这种提升模式验证了深度方向上更好的信息流有利于组合任务的假设,深层网络可以有选择性地检索并构建在早期的特征表示之上。

表 3 预训练后的下游任务性能对比
表 3 预训练后的下游任务性能对比

5.3 消融实验 (Ablation Study)

在 16 层的模型上进行的消融实验进一步验证了机制设计的合理性 :

  • 相比以往方法:具有输入无关权重的 DenseFormer 未能优于基线(1.767 vs 1.766),而基于并行流的 mHC 改善至 1.747。Block AttnRes(1.746)和 Full AttnRes(1.737)均表现更优,说明输入依赖的 Softmax 注意力选择机制的重要性。
  • 跨层访问粒度:滑动窗口聚合(SWA)性能较差,表明选择性访问远距离层级特征比仅关注近期层更有效。随着块大小 增加,性能平稳衰减,在 时均能保持较好性能。
  • 组件设计

    • 移除 RMSNorm 导致性能下降,说明归一化对于防止幅度差异主导注意力权重至关重要,特别是在累加效应明显的块级表示中。
    • 将 Softmax 替换为 Sigmoid 导致退化,因为 Softmax 的竞争性归一化能强制模型在特征源之间做出更清晰的选择。
    • 引入多头深度注意力并没有带来提升,表明特定层输出的相关性在不同通道间趋于一致(即某个层的信息若相关,则是作为一个整体相关)。
表 4 AttnRes 关键组件消融实验|图 6 块大小对验证损失的影响
表 4 AttnRes 关键组件消融实验|图 6 块大小对验证损失的影响

6. 深入探讨与分析

6.1 架构资源重分配:深度与宽度的权衡

为了理解 AttnRes 对最优架构扩展律的影响,研究在固定的计算量和激活参数预算下,对不同的深度-宽度比例()进行了网格搜索实验。

结果表明,虽然 Baseline 和 AttnRes 均在 附近取得最优 ,但在所有 25 种配置中 AttnRes 均保持损失优势。值得注意的是,Baseline 的最低损失出现在 ,而 AttnRes 将最优值转移到了 。在固定参数预算下,更低的 意味着更深、更窄的网络结构。这暗示 AttnRes 通过优化深层信息流,使得模型能够更有效地利用额外的深度。

图 7 固定算力下的架构搜索
图 7 固定算力下的架构搜索

6.2 注意力模式可视化

通过可视化平均 Token 上的深度注意力权重 ,观察到以下模式 :

  1. 局部性保持:各层最强烈地关注其直接前驱,但对角线之外的权重集中点表明网络学习到了跳过标准残差路径的跨层连接。
  2. 层的专业化:Token 嵌入 在整个深度中保持不可忽视的权重,特别是在预注意力(Pre-Attention)层。预 MLP 层的输入表现出对近期表示的强依赖,符合注意力层在全局路由信息而 MLP 执行局部计算的直觉。
  3. 块结构的一致性:Block AttnRes 保持了对角主导、嵌入持久性等核心结构,表明块状压缩起到了隐式正则化的作用,同时保留了重要的信息通路。
图 8 深度方向的注意力权重分布
图 8 深度方向的注意力权重分布

6.3 结构化矩阵视角

从更宏观的理论视角,各种残差变体都可以表示为一个深度混合矩阵 ,其中输入层 的聚合表示为

  • 标准残差:对应一个全 1 的下三角矩阵。
  • Highway Networks:权重基于累积乘积,保持 1-半可分(1-semiseparable)秩,但引入了输入依赖。
  • mHC:多流机制相当于在深度方向上将循环状态扩展为 ,使得矩阵 变为 -半可分秩。
  • Full AttnRes:生成一个密集的、秩为 的矩阵。
  • Block AttnRes:矩阵秩介于 之间,在标准残差和全注意力残差的表达能力之间进行插值。
图 9 四种残差变体的深度混合矩阵 M
图 9 四种残差变体的深度混合矩阵 M
表 5 残差更新机制对比
表 5 残差更新机制对比

这一视角进一步将现有的残差变体联系起来:包含矩阵状态的 (m)HC 本质上执行的是具有矩阵值状态的深度线性注意力(Depth-wise Linear Attention),而 AttnRes 则执行深度 Softmax 注意力(Depth-wise Softmax Attention)。

7. 相关工作对比

  • 归一化与深度稳定性:PostNorm 虽然限制了幅度增长,但多次归一化会导致梯度消失。PreNorm 保证了单位梯度路径,但引发了隐藏状态幅度的无界增长(PreNorm 稀释)。AttnRes 通过对早期输出的选择性聚合,避免了 PreNorm 的累积幅度增长,同时也绕过了 PostNorm 的尺度收缩问题。
  • 多状态循环 (Multi-State Recurrence) :诸如 Hyper-Connections (HC)、mHC、DDL 以及 SiameseNorm 等方法通过维护多个并行流或矩阵状态来缓解单状态的信息压缩瓶颈。但它们仍然依赖于前驱状态进行演化。AttnRes 是正交的研究方向,提供了对历史独立层输出的直接选择性访问。
  • 跨层连接 (Cross-Layer Connectivity) :如 DenseNet、DenseFormer 提供跨层访问但采用静态权重 ;MUDDFormer 生成位置依赖权重;MRLA 执行基于 Sigmoid 的门控机制。AttnRes 结合了基于 Softmax 归一化的输入依赖权重与块级系统工程设计,实现了在大规模训练中的可用性。

8. 结论

AttnRes 的核心创新在于认识到深度和时间维度之间深刻的对称性。正如 Transformer 用自注意力取代了 RNN 的序列循环,AttnRes 用学习到的、依赖输入的深度注意力取代了标准的、固定的残差累加循环。

在全连接注意力面临 扩展瓶颈时,研究团队在系统架构上做出了务实的妥协与创新。Block AttnRes 通过块级分区,不仅将复杂度降低到 ,并且证明仅需约 8 个块就能挽回 Full AttnRes 的绝大部分收益。搭配跨阶段通信缓存和两阶段推理计算机制,Block AttnRes 以微不足道的训练与推理额外开销,成为标准残差连接一个极具吸引力的直接替代方案。未来,随着底层硬件内存容量及互连带宽的演进,更细粒度的分块或直接采用全注意力残差机制将成为进一步发掘模型深度潜能的自然路径。

更多细节请阅读原文。


往期文章: