With the rise of LRMs, training samples keep getting longer, and sequence parallelism has become a common choice during training.
However, whether in standard training or sequence-parallel training, the conventional practice of defining a batch by the number of samples has problems that cannot be ignored. On one hand, sample lengths vary widely, so an empirically chosen fixed batch size easily runs out of memory (OOM) on extra-long samples, forcing repeated manual attempts at shrinking the batch size. On the other hand, this approach cannot adapt compute to the actual lengths of the samples, so TFLOPs (trillions of floating-point operations per second) utilization is hard to maximize and compute is wasted.
To address this, this TokenAwareBatchPack practice built on VERL offers an alternative solution: define the batch by the number of tokens rather than the number of samples.
By using a fixed total token budget as the criterion for splitting batches, this strategy both avoids the OOM problems caused by extra-long samples and matches compute resources more precisely, squeezing the full potential out of sequence parallelism.
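To make the packing rule concrete, here is a toy sketch of the greedy token-budget split; the `lengths` and `max_tokens` values are made up for illustration, and the real implementation appears in the core code below.

```python
# Toy illustration of token-aware batch packing (illustrative values only).
lengths = [8000, 30000, 120000, 5000, 60000]   # token length of each sample
max_tokens = 131072                            # token budget per packed batch

batches, batch, used = [], [], 0
for idx, n in enumerate(lengths):
    if batch and used + n > max_tokens:        # budget exceeded: close the current batch
        batches.append(batch)
        batch, used = [], 0
    if n <= max_tokens:                        # samples longer than the budget are skipped
        batch.append(idx)
        used += n
if batch:
    batches.append(batch)

print(batches)  # [[0, 1], [2, 3], [4]] -> every batch stays within the token budget
```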
Since the TokenAwareBatchPack strategy needs the token length of every sample up front, we can take this opportunity to move sample tokenization offline (a sketch of this offline step follows the data format below). This not only provides the data needed for batch packing, it also removes the tokenization overhead from online training, further improving training efficiency.
Training data format
{ "input_ids": [100, 101, 102, 103, 104], // 每个样本分词后的ids "loss_mask": [0, 0, 1, 1, 1], // input_ids计算损失的位置为1 "category_ids": [0, 0, 12, 12, 12], // input_ids对应位置的 token 类型,方便后续统计 channel loss "length": 5, // ids长度 }
Core code
```python
import torch
import datasets
import numpy as np
from tqdm import tqdm


class SFTDataSet:
    def __init__(self, train_dir, max_tokens, padding_value):
        self.dataframe = datasets.load_from_disk(train_dir).shuffle(seed=42)
        self.max_tokens = max_tokens
        self.padding_value = padding_value
        self.token_aware_batch_pack()

    def token_aware_batch_pack(self):
        # Greedily group sample indices so that the total token count of each
        # batch never exceeds max_tokens; samples longer than max_tokens are skipped.
        self.batch_idx_list = []
        batch, batch_tokens = [], 0
        for idx, tok_len in tqdm(enumerate(self.dataframe['length']),
                                 total=len(self.dataframe),
                                 desc="token aware batch packing"):
            if batch and batch_tokens + tok_len > self.max_tokens:
                self.batch_idx_list.append(batch)
                batch, batch_tokens = [], 0
            if tok_len <= self.max_tokens:
                batch.append(idx)
                batch_tokens += tok_len
        if batch and batch_tokens <= self.max_tokens:
            self.batch_idx_list.append(batch)

    def __len__(self):
        return len(self.batch_idx_list)

    def __getitem__(self, item):
        # Each dataset item is a list of sample indices that form one packed batch.
        return self.batch_idx_list[item]

    def collate_fn(self, batch_idx):
        # Pack the samples of one batch back-to-back into a single sequence of
        # exactly max_tokens tokens; the tail is filled with padding_value.
        batch = [self.dataframe[idx] for idx in batch_idx[0]]
        packed_input_ids = np.full(self.max_tokens, self.padding_value, dtype=np.int64)
        packed_position_ids = np.full(self.max_tokens, 0, dtype=np.int64)
        packed_loss_mask = np.full(self.max_tokens, 0, dtype=np.int64)
        packed_category_ids = np.zeros(self.max_tokens, dtype=np.int64)
        attention_mask = np.zeros(self.max_tokens, dtype=np.int64)
        current_pos = 0
        for record in batch:
            end_pos = current_pos + record['length']
            packed_input_ids[current_pos:end_pos] = record['input_ids']
            packed_position_ids[current_pos:end_pos] = [i for i in range(record['length'])]
            packed_loss_mask[current_pos:end_pos] = record['loss_mask']
            packed_category_ids[current_pos:end_pos] = record['category_ids']
            attention_mask[current_pos:end_pos] = 1
            current_pos = end_pos
        # Give the padded tail its own restart of position ids.
        packed_position_ids[end_pos:] = [i for i in range(self.max_tokens - end_pos)]
        return {
            'input_ids': torch.from_numpy(packed_input_ids).unsqueeze(0),
            'position_ids': torch.from_numpy(packed_position_ids).unsqueeze(0),
            'loss_mask': torch.from_numpy(packed_loss_mask).unsqueeze(0),
            'category_ids': torch.from_numpy(packed_category_ids).unsqueeze(0),
            'attention_mask': torch.from_numpy(attention_mask).unsqueeze(0),
        }
```
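For a quick sanity check outside of verl, the dataset can be exercised on its own. The path and padding value below are placeholders, not values from the original setup.

```python
from torch.utils.data import DataLoader

# Hypothetical path and padding value for a standalone smoke test.
ds = SFTDataSet(train_dir="tokenized_sft_data", max_tokens=131072, padding_value=0)
loader = DataLoader(ds, batch_size=1, shuffle=False, collate_fn=ds.collate_fn)

batch = next(iter(loader))
print(batch['input_ids'].shape)             # torch.Size([1, 131072])
print(int(batch['attention_mask'].sum()))   # number of real (non-padding) tokens packed into this batch
```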
Changes to verl fsdp_sft_trainer.py
```python
self.train_dataset = dataset_cls(train_dir=config.data.train_dir,
                                 max_tokens=config.data.max_tokens,
                                 padding_value=self.tokenizer.pad_token_id)
self.train_dataloader = DataLoader(
    dataset=self.train_dataset,
    batch_size=config.data.train_batch_size,
    sampler=self.train_sampler,
    num_workers=8,
    pin_memory=True,
    drop_last=True,
    collate_fn=self.train_dataset.collate_fn
)
```
Training script
```bash
PYTHONUNBUFFERED=1 torchrun --nnodes=$nnodes --nproc_per_node=$nproc_per_node --node_rank=$NODE_RANK --master_addr=$master_ip --master_port=$master_port \
    fsdp_sft_trainer.py \
    +data.train_dir=$train_files \
    +data.max_tokens=131072 \
    optim.lr=3e-5 \
    data.train_batch_size=$nnodes \
    data.micro_batch_size=1 \
    data.micro_batch_size_per_gpu=1 \
    data.custom_cls.path=sft_data_packer.py \
    data.custom_cls.name=SFTDataSet \
    model.partial_pretrain=$model_path \
    model.use_liger=True \
    trainer.default_local_dir=$save_path \
    trainer.project_name=$task_name \
    trainer.experiment_name=$task_name \
    trainer.logger=['console'] \
    ulysses_sequence_parallel_size=8 \
    use_remove_padding=true
```
Add the following after the loss computation to print the channel loss.
```python
self.idx2name = ...  # mapping from category_id to channel name (requires `import pandas as pd` in the file)
if self.device_mesh.get_rank() == 0:
    idx2mean_loss = pd.DataFrame({
        'category_id': category_ids.cpu().numpy().ravel(),
        'token_loss': loss.clone().detach().cpu().numpy().ravel()
    }).groupby('category_id')['token_loss'].mean().to_dict()
    for idx, mean_loss in idx2mean_loss.items():
        if idx == 0 or idx not in self.idx2name:
            continue
        self.loss_log[f"{self.idx2name[idx]}_loss"] = mean_loss
```
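If you prefer not to pull pandas into the training hot path, the same per-category mean can be computed with pure torch. This is a sketch of an equivalent alternative for the body of the rank-0 branch above, not part of the original change; it assumes the same `loss`, `category_ids`, `self.idx2name`, and `self.loss_log` as in that snippet.

```python
# Flat views of the per-token loss and its category ids.
flat_loss = loss.detach().reshape(-1).float()
flat_cat = category_ids.reshape(-1).to(flat_loss.device)

num_cats = int(flat_cat.max().item()) + 1
# Per-category sum of token losses and per-category token counts.
loss_sum = torch.zeros(num_cats, device=flat_loss.device).scatter_add_(0, flat_cat, flat_loss)
tok_cnt = torch.bincount(flat_cat, minlength=num_cats)
mean_loss_per_cat = (loss_sum / tok_cnt.clamp(min=1)).cpu()

for idx, mean_loss in enumerate(mean_loss_per_cat.tolist()):
    if idx == 0 or idx not in self.idx2name or tok_cnt[idx].item() == 0:
        continue
    self.loss_log[f"{self.idx2name[idx]}_loss"] = mean_loss
```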