Llama2 是 Meta 最近推出的大语言模型,它的训练数据集达到了两万亿个token。与前代产品 Llama 的 2048 的上下文长度相比,Llama2 的上下文长度扩展至 4096,使其能够理解和生成更长篇幅的文本。该模型有三种规模版本——7B、13B 和 70B,它们在各种基准测试集上均展现出卓越的性能。最重要的是,Llama2 支持用于研究和商业目的,具有高度的可用性和实用价值。

下面分RMSNorm、RoPE、Attention、FFN等部分来对LLaMA2的模型代码进行详细解释。

LLaMA2模型源代码链接:https://github.com/facebookresearch/llama/blob/main/llama/model.py

1. RMSNorm

\frac {\mathbf x}{\mathrm{RMS}(\mathbf x)} \cdot \gamma, \text{~~~} \mathrm{RMS}(\mathbf x)=\sqrt {\frac 1 d \sum_{i=1}^d x_i^2}

这段代码比较直观,容易理解,相较于公式其中多了eps变量,这是防止出现除0错误。

class RMSNorm(torch.nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps #防止出现除0错误
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        output = self._norm(x.float()).type_as(x)
        return output * self.weight

2. 旋转式位置编码(Rotary Position Embedding,RoPE)

细节在这里:《大模型基础之旋转式位置编码(Rotary Position Embedding,RoPE)》

RoPE这段代码比较晦涩难懂,现在详细介绍一下。

precompute_freqs_cis中的cis是 "cosine" 和 "sine" 的缩写,它经常在数学中使用来表示复数的极坐标形式。具体来说,给定一个角度\theta,其对应的复数可以表示为:

\text{cis}(\theta) = \cos(\theta) + i \sin(\theta)

其中i是虚数单位,满足i^2 = -1。"cis" 表示的是一个复数,其实部是角度\theta的余弦值,而虚部是角度\theta的正弦值。这种表示方法在复数分析、信号处理等领域中非常有用。

因此,故名思义,该函数的目的是预计算一个复数频率张量。该函数有两个入参,dim和end。dim就是每个attention_head中的维度,在这里就是4096/32=128。end是self.params.max_seq_len * 2,也就是4096,这也是Llama2最大的token处理数量。计算过程解释见注释:

def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    # dim = 128
    # end = 4096
    # torch.arange(0, dim, 2) [0, 2, 4, 6, 8, 10,..., 124, 126] 共64个
    # torch.arange(0, dim, 2)[: (dim // 2)] 保证是64个

    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    # freqs = [1/10000.0^(0/128), 1/10000.0^(2/128), 1/10000.0^(4/128), ..., 1/10000.0^(126/128)]

    t = torch.arange(end, device=freqs.device)  # type: ignore
    # t = [0, 1, 2, ..., 4095]
    freqs = torch.outer(t, freqs).float()  # type: ignore
    # freqs 得到 freqs和t的笛卡尔积,维度为(4096,64)
    # freqs = [[0, 0, 0,..., 0],
    #          [1/10000.0^(0/128), 1/10000.0^(2/128), 1/10000.0^(4/128), ..., 1/10000.0^(126/128)],
    #          [2/10000.0^(0/128), 2/10000.0^(2/128), 2/10000.0^(4/128), ..., 2/10000.0^(126/128)],
    #          ...,
    #          [4095/10000.0^(0/128), 4095/10000.0^(2/128), 4095/10000.0^(4/128), ..., 4095/10000.0^(126/128)]]

    # 在PyTorch中,torch.polar用于通过极坐标(magnitude和angle)来创建一个复数张量。
    # 这个函数接受两个张量作为输入:一个张量包含复数的模(magnitude,也就是复数的长度),
    # 另一个张量包含复数的角度(angle,也就是复数的相位角),然后返回一个相应的复数张量。
    # 下面就是创建模为1的,有不同相位角的复数张量。
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64

    return freqs_cis

apply_rotary_emb的作用是将复数张量freqs_cis与q和k相乘,当然了,不能直接相乘,中间要通过一些变形。

注意freqs_cis的维度并不是(4096,64),而是截取了seqlen的一部分,freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen]。

def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    # freqs_cis.shape = [1024, 64]
    # x.shape = [2, 1024, 32, 64]
    ndim = x.ndim
    assert 0 <= 1 < ndim
    assert freqs_cis.shape == (x.shape[1], x.shape[-1]) #保证
    # 将freqs_cis.shape变为[1, 1024, 1, 64]
    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
    return freqs_cis.view(*shape)


def apply_rotary_emb(
        xq: torch.Tensor,
        xk: torch.Tensor,
        freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:

    # 将xq和xk的最后一个维度进行复数运算,得到新的xq和xk
    # 为了进行复数运算,需要将xq和xk的最后一个维度展开为2维
    # 例如,xq的形状为[2, seq_len, 32, 128], reshape后为[2, seq_len, 32 , 64, 2]
    # view_as_complex函数可以将张量中的最后一维的两个元素作为实部和虚部合成一个复数
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))

    # 将freqs_cis广播到xq和xk的最后一个维度
    freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
    # freqs_cis.shape = [1, 1024, 1, 64]
    # view_as_real和view_as_complex相反,可以将张量中最后一维的复数拆出实部和虚部
    # (xq_ * freqs_cis).shape = [2, seq_len, 32 , 64]
    # torch.view_as_real(xq_ * freqs_cis).shape = [2, seq_len, 32 , 64, 2]
    # flatten(3)将张量展平为[2, seq_len, 32 , 128],3代表从的第3个维度开始展平
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)

3. Attention

大模型一般是分布式训练,这里涉及到几个概念。n_heads是注意力头的总个数,由于并行机制,每个进程会有n_local_heads个注意力头。由于计算当前位置的Attention Score依赖于之前所有的kv,因此需要将kv缓存下来。为了减少空间复杂度,可以对kv的头个数n_kv_heads进行调整,这个值一般小于等于n_heads,n_heads是n_kv_heads的整数倍,这个倍数也就是n_rep。相应的,每个进程会有n_local_kv_heads个注意力头。每个头的维度为head_dim=dim//n_heads。

例如:n_heads=32,model_parallel_size(并行数量)= 4,n_kv_heads = 8,n_local_heads = 32/4, n_local_kv_heads = 8/4,n_rep = 32/8。

其它细节请看注释。

class Attention(nn.Module):
    """Multi-head attention module."""

    def __init__(self, args: ModelArgs):
        super().__init__()

        self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
        model_parallel_size = fs_init.get_model_parallel_world_size()
        self.n_local_heads = args.n_heads // model_parallel_size
        self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
        self.n_rep = self.n_local_heads // self.n_local_kv_heads
        self.head_dim = args.dim // args.n_heads

        # ColumnParallelLinear是一个在大规模并行训练中使用的术语,特别是在训练大型的深度学习模型,
        # 如Transformer模型时。在模型并行训练中,一个大型的矩阵(例如神经网络的权重矩阵)会被分割成不同的列,
        # 并分散到不同的计算设备(如GPU)上。
        #
        # 在ColumnParallelLinear的情况下,每个计算设备存储权重矩阵的一部分列,而不是整个矩阵。
        # 每个设备计算它自己的前向传播部分,并将结果发送给其他设备以进行进一步的处理或合并结果。
        # 对于反向传播和梯度计算,每个设备计算其自己列的梯度,并可能需要与其他设备交换信息以更新权重。
        #
        # 这种方式可以显著减少每个设备上的内存需求,并允许训练更大的模型,因为模型的不同部分可以分布在多个设备上。
        # ColumnParallelLinear和RowParallelLinear(另一种将权重矩阵按行划分的方法)是实现模型并行的两种常见策略。

        self.wq = ColumnParallelLinear(
            args.dim,
            args.n_heads * self.head_dim,
            bias=False,
            gather_output=False,
            init_method=lambda x: x,
        )
        self.wk = ColumnParallelLinear(
            args.dim,
            self.n_kv_heads * self.head_dim,
            bias=False,
            gather_output=False,
            init_method=lambda x: x,
        )
        self.wv = ColumnParallelLinear(
            args.dim,
            self.n_kv_heads * self.head_dim,
            bias=False,
            gather_output=False,
            init_method=lambda x: x,
        )
        self.wo = RowParallelLinear(
            args.n_heads * self.head_dim,
            args.dim,
            bias=False,
            input_is_parallel=True,
            init_method=lambda x: x,
        )

        # kv_cache是缓存键值对,在训练过程中,我们只保存最近n个键值对
        self.cache_k = torch.zeros(
            (
                args.max_batch_size,
                args.max_seq_len,
                self.n_local_kv_heads,
                self.head_dim,
            )
        ).cuda()
        self.cache_v = torch.zeros(
            (
                args.max_batch_size,
                args.max_seq_len,
                self.n_local_kv_heads,
                self.head_dim,
            )
        ).cuda()

    def forward(
            self,
            x: torch.Tensor,
            start_pos: int,
            freqs_cis: torch.Tensor,
            mask: Optional[torch.Tensor],
    ):
        # 假设当前x为(1, 1, dim),也就是上一个预测的token
        bsz, seqlen, _ = x.shape

        # 计算当前token的qkv
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)

        xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
        xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)

        # 对当前token的qkv增加位置编码
        xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)

        # 缓存当前token的kv
        self.cache_k = self.cache_k.to(xq)
        self.cache_v = self.cache_v.to(xq)
        self.cache_k[:bsz, start_pos: start_pos + seqlen] = xk
        self.cache_v[:bsz, start_pos: start_pos + seqlen] = xv

        # 取出前seqlen个token的kv缓存
        keys = self.cache_k[:bsz, : start_pos + seqlen]
        values = self.cache_v[:bsz, : start_pos + seqlen]

        # 将kv重复填充,使kv和q的头数个数相同
        keys = repeat_kv(keys, self.n_rep)  # (bs, seqlen, n_local_heads, head_dim)
        values = repeat_kv(values, self.n_rep)  # (bs, seqlen, n_local_heads, head_dim)

        # 计算当前token的attention score
        xq = xq.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
        keys = keys.transpose(1, 2)
        values = values.transpose(1, 2)
        scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
        if mask is not None:
            scores = scores + mask  # (bs, n_local_heads, seqlen, cache_len + seqlen)
        scores = F.softmax(scores.float(), dim=-1).type_as(xq)
        output = torch.matmul(scores, values)  # (bs, n_local_heads, seqlen, head_dim)
        output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
        return self.wo(output)

4. FNN

这一部分,需要注意就是激活函数及其位置。

激活函数采用了SiLU (Sigmoid-Weighted Linear Unit) ,这也被称为 Swish 函数。这个激活函数是由谷歌的研究者在 2017 年提出的,并已被证明在某些情况下比传统的 ReLU 激活函数更有效。

SiLU 函数的数学表达式是:

\text{SiLU}(x) = x \cdot \sigma(x)

其中x是输入,\sigma(x)是 sigmoid 函数,其表达式为:

\sigma(x) = \frac{1}{1 + e^{-x}}

因此,SiLU 函数将输入值乘以 sigmoid 函数的输出,其效果是在正值上非饱和,负值上平滑并接近于零。与 ReLU 函数类似,SiLU 函数也能够创建非线性决策边界,但它允许一些信息(即使是负值)传递,而不是像 ReLU 那样将所有负值置为零,这种特性可以帮助减轻梯度消失问题。

class FeedForward(nn.Module):
    def __init__(
            self,
            dim: int,
            hidden_dim: int,
            multiple_of: int,
            ffn_dim_multiplier: Optional[float],
    ):
        super().__init__()
        hidden_dim = int(2 * hidden_dim / 3)
        # custom dim factor multiplier
        if ffn_dim_multiplier is not None:
            hidden_dim = int(ffn_dim_multiplier * hidden_dim)
        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)

        self.w1 = ColumnParallelLinear(
            dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x
        )
        self.w2 = RowParallelLinear(
            hidden_dim, dim, bias=False, input_is_parallel=True, init_method=lambda x: x
        )
        self.w3 = ColumnParallelLinear(
            dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x
        )

    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x))

本文对LLaMA2大模型的结构代码进行了详细的介绍,代码虽然不多很多细节值得反复推敲理解。