With the LRM wave, training samples keep growing longer, and sequence parallelism has become a common choice in training.

However, whether in regular or sequence-parallel training, the traditional practice of defining a batch by the number of samples has problems that cannot be ignored. On one hand, sample lengths vary widely, so a fixed batch size set by experience easily hits OOM (out of memory) on extra-long samples, forcing repeated manual attempts to shrink the batch size. On the other hand, this approach cannot adapt compute to the actual sample lengths, so TFLOPs (trillions of floating-point operations per second) utilization is hard to maximize and compute is wasted.

To address this, this TokenAwareBatchPack practice on top of VERL takes a different approach: define batches by token count rather than by sample count.

By using a fixed total token budget as the batch-splitting criterion, this strategy avoids the OOM caused by extra-long samples and matches compute resources more precisely, squeezing the full potential out of sequence parallelism.
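
As a minimal illustration of the packing rule (the 16-token budget and the sample lengths below are made up for this example; the full implementation follows in the core code section), samples are appended greedily and the current batch is flushed as soon as the next sample would push it over the token budget:

# Toy sketch of the greedy packing rule; numbers are invented for illustration.
max_tokens = 16
lengths = [7, 6, 5, 4]  # token length of each sample

batches, batch, batch_tokens = [], [], 0
for idx, tok_len in enumerate(lengths):
    # flush once the next sample would exceed the budget
    if batch and batch_tokens + tok_len > max_tokens:
        batches.append(batch)
        batch, batch_tokens = [], 0
    if tok_len <= max_tokens:  # samples longer than the budget are dropped
        batch.append(idx)
        batch_tokens += tok_len
if batch:
    batches.append(batch)

print(batches)  # [[0, 1], [2, 3]] -> 13 and 9 tokens per packed batch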

Since the TokenAwareBatchPack strategy needs each sample's token length up front, we can take the opportunity to do the tokenization offline. This not only provides the data needed for batch splitting, but also removes the tokenize overhead from online training, further improving training efficiency.

Training data format

{
    "input_ids": [100, 101, 102, 103, 104], // 每个样本分词后的ids
    "loss_mask": [0, 0, 1, 1, 1], // input_ids计算损失的位置为1
    "category_ids": [0, 0, 12, 12, 12], // input_ids对应位置的 token 类型,方便后续统计 channel loss
    "length": 5, // ids长度
}
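
The offline tokenization that produces records in this format could look like the following sketch. The raw field names ("prompt"/"response"), the category ids, and the paths are assumptions made for illustration, not part of the original pipeline; adapt them to your own data.

import datasets
from transformers import AutoTokenizer

# Hypothetical category ids for this sketch; real channel ids depend on your data.
PROMPT_CATEGORY, RESPONSE_CATEGORY = 0, 12

def tokenize_record(example, tokenizer):
    # assumed raw fields "prompt" and "response"; the loss is computed only on the response
    prompt_ids = tokenizer(example["prompt"], add_special_tokens=False)["input_ids"]
    response_ids = tokenizer(example["response"], add_special_tokens=False)["input_ids"] + [tokenizer.eos_token_id]
    input_ids = prompt_ids + response_ids
    return {
        "input_ids": input_ids,
        "loss_mask": [0] * len(prompt_ids) + [1] * len(response_ids),
        "category_ids": [PROMPT_CATEGORY] * len(prompt_ids) + [RESPONSE_CATEGORY] * len(response_ids),
        "length": len(input_ids),
    }

if __name__ == "__main__":
    tokenizer = AutoTokenizer.from_pretrained("path/to/model")  # placeholder model path
    raw = datasets.load_dataset("json", data_files="raw_sft.jsonl", split="train")  # placeholder raw file
    tokenized = raw.map(lambda ex: tokenize_record(ex, tokenizer), remove_columns=raw.column_names, num_proc=8)
    tokenized.save_to_disk("tokenized_sft")  # this directory is then passed as data.train_dir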

Core code

import torch
import datasets
import numpy as np
from tqdm import tqdm


class SFTDataSet:
    def __init__(self, train_dir, max_tokens, padding_value):
        # train_dir holds the pre-tokenized dataset produced offline; max_tokens is the token budget per packed batch
        self.dataframe = datasets.load_from_disk(train_dir).shuffle(seed=42)
        self.max_tokens = max_tokens
        self.padding_value = padding_value
        self.token_aware_batch_pack()

    def token_aware_batch_pack(self):
        # Greedily group sample indices so that each packed batch's total token count stays within max_tokens
        self.batch_idx_list = []
        batch, batch_tokens = [], 0
        for idx, tok_len in tqdm(enumerate(self.dataframe['length']), total=len(self.dataframe), desc="token aware batch packing"):
            # flush the current batch once adding this sample would exceed the token budget
            if batch and batch_tokens + tok_len > self.max_tokens:
                self.batch_idx_list.append(batch)
                batch, batch_tokens = [], 0
            # samples longer than the budget are skipped entirely
            if tok_len <= self.max_tokens:
                batch.append(idx)
                batch_tokens += tok_len

        # keep the last, possibly partial, batch
        if batch and batch_tokens <= self.max_tokens:
            self.batch_idx_list.append(batch)

    def __len__(self):
        return len(self.batch_idx_list)

    def __getitem__(self, item):
        # each item is the list of dataset indices that make up one packed batch
        return self.batch_idx_list[item]

    def collate_fn(self, batch_idx):
        # batch_idx is a list of packed index lists from the DataLoader; pack the first one into a single sequence
        batch = [self.dataframe[idx] for idx in batch_idx[0]]
        # pre-allocate fixed-length buffers of max_tokens and fill them sample by sample
        packed_input_ids = np.full(self.max_tokens, self.padding_value, dtype=np.int64)
        packed_position_ids = np.zeros(self.max_tokens, dtype=np.int64)
        packed_loss_mask = np.zeros(self.max_tokens, dtype=np.int64)
        packed_category_ids = np.zeros(self.max_tokens, dtype=np.int64)
        attention_mask = np.zeros(self.max_tokens, dtype=np.int64)
        current_pos = 0
        for record in batch:
            end_pos = current_pos + record['length']
            packed_input_ids[current_pos:end_pos] = record['input_ids']
            # position ids restart from 0 for every packed sample
            packed_position_ids[current_pos:end_pos] = np.arange(record['length'])
            packed_loss_mask[current_pos:end_pos] = record['loss_mask']
            packed_category_ids[current_pos:end_pos] = record['category_ids']
            attention_mask[current_pos:end_pos] = 1
            current_pos = end_pos
        # give the trailing padding region its own position ids starting from 0
        packed_position_ids[current_pos:] = np.arange(self.max_tokens - current_pos)
        return {
            'input_ids': torch.from_numpy(packed_input_ids).unsqueeze(0),
            'position_ids': torch.from_numpy(packed_position_ids).unsqueeze(0),
            'loss_mask': torch.from_numpy(packed_loss_mask).unsqueeze(0),
            'category_ids': torch.from_numpy(packed_category_ids).unsqueeze(0),
            'attention_mask': torch.from_numpy(attention_mask).unsqueeze(0),
        }
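
Before wiring the dataset into verl, it can be sanity-checked on its own; the path, token budget, and padding id below are placeholders:

from torch.utils.data import DataLoader

dataset = SFTDataSet(train_dir="tokenized_sft", max_tokens=131072, padding_value=0)
loader = DataLoader(dataset, batch_size=1, shuffle=False, collate_fn=dataset.collate_fn)

batch = next(iter(loader))
print(batch['input_ids'].shape)       # torch.Size([1, 131072])
print(batch['attention_mask'].sum())  # number of real (non-padding) tokens in this packed batch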

Changes to verl fsdp_sft_trainer.py

self.train_dataset = dataset_cls(
    train_dir=config.data.train_dir,
    max_tokens=config.data.max_tokens,
    padding_value=self.tokenizer.pad_token_id,
)

self.train_dataloader = DataLoader(
    dataset=self.train_dataset,
    batch_size=config.data.train_batch_size,
    sampler=self.train_sampler,
    num_workers=8,
    pin_memory=True,
    drop_last=True,
    collate_fn=self.train_dataset.collate_fn
)

Training script

PYTHONUNBUFFERED=1 torchrun --nnodes=$nnodes --nproc_per_node=$nproc_per_node --node_rank=$NODE_RANK --master_addr=$master_ip --master_port=$master_port  \
    fsdp_sft_trainer.py \
    +data.train_dir=$train_files \
    +data.max_tokens=131072 \
    optim.lr=3e-5 \
    data.train_batch_size=$nnodes \
    data.micro_batch_size=1 \
    data.micro_batch_size_per_gpu=1 \
    data.custom_cls.path=sft_data_packer.py \
    data.custom_cls.name=SFTDataSet \
    model.partial_pretrain=$model_path \
    model.use_liger=True \
    trainer.default_local_dir=$save_path \
    trainer.project_name=$task_name \
    trainer.experiment_name=$task_name \
    trainer.logger=['console'] \
    ulysses_sequence_parallel_size=8 \
    use_remove_padding=true

Add the following after the loss is computed, so the channel loss can be logged conveniently.

self.idx2name = ...  # mapping from category_id to channel name

import pandas as pd  # add at the top of fsdp_sft_trainer.py

if self.device_mesh.get_rank() == 0:
    # mean per-token loss for each category appearing in this batch
    idx2mean_loss = pd.DataFrame({
        'category_id': category_ids.cpu().numpy().ravel(),
        'token_loss': loss.clone().detach().cpu().numpy().ravel()
    }).groupby('category_id')['token_loss'].mean().to_dict()

    for idx, mean_loss in idx2mean_loss.items():
        # skip category 0 (padding / untracked tokens) and any id without a registered name
        if idx == 0 or idx not in self.idx2name:
            continue
        self.loss_log[f"{self.idx2name[idx]}_loss"] = mean_loss
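
For reference, a hypothetical idx2name mapping and the resulting log entries could look like this (the channel names and values are illustrative only):

# hypothetical mapping from category_id to channel name; in practice it comes from your data pipeline
self.idx2name = {12: "math", 13: "code", 14: "general_chat"}

# with the snippet above, self.loss_log then picks up entries such as
# {"math_loss": 0.41, "code_loss": 0.55, ...}, which can be printed alongside the regular training loss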