With the rise of LRMs (large reasoning models), training samples keep getting longer, and sequence parallelism has become a standard choice in training.
However, whether in ordinary training or sequence-parallel training, the traditional practice of defining a batch by the number of samples has problems that cannot be ignored. On one hand, sample lengths vary widely, so a fixed batch size chosen by experience easily hits OOM (out of memory) on very long samples, forcing repeated manual attempts with a smaller batch size. On the other hand, this approach cannot adapt compute resources to the actual sample lengths, so TFLOPs (trillions of floating-point operations per second) utilization is hard to maximize and compute is wasted.
To address this, this TokenAwareBatchPack practice built on VERL takes a different approach: batches are defined by the number of tokens rather than the number of samples.
By splitting batches against a fixed total token budget, this strategy avoids the OOMs caused by very long samples and matches compute resources more precisely, squeezing the full potential out of sequence parallelism.
Since the TokenAwareBatchPack strategy needs the token length of every sample up front, we can take the opportunity to tokenize the samples offline. This not only provides the data needed for batch splitting, but also removes the tokenization overhead from online training, further improving training efficiency.
Training data format
{
    "input_ids": [100, 101, 102, 103, 104],  // token ids of the sample after tokenization
    "loss_mask": [0, 0, 1, 1, 1],            // 1 at the positions of input_ids where the loss is computed
    "category_ids": [0, 0, 12, 12, 12],      // token category for each position of input_ids, used later to compute channel loss
    "length": 5,                             // number of ids
}
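Below is a minimal sketch of the offline tokenization step that produces records in this format with the datasets and transformers libraries; the raw field names (prompt, response, category_id), the model path, the input file, and the output directory are illustrative assumptions, not part of the original pipeline.

import datasets
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("path/to/model")  # hypothetical model path

def build_record(example):
    # Tokenize prompt and response separately so the loss mask covers only the response.
    prompt_ids = tokenizer(example["prompt"], add_special_tokens=False)["input_ids"]
    response_ids = tokenizer(example["response"], add_special_tokens=False)["input_ids"]
    input_ids = prompt_ids + response_ids + [tokenizer.eos_token_id]
    loss_mask = [0] * len(prompt_ids) + [1] * (len(response_ids) + 1)
    # Prompt tokens get category 0 (skipped in the channel-loss statistics);
    # response tokens carry the sample's channel id, as in the example above.
    category_ids = [0] * len(prompt_ids) + [example["category_id"]] * (len(response_ids) + 1)
    return {
        "input_ids": input_ids,
        "loss_mask": loss_mask,
        "category_ids": category_ids,
        "length": len(input_ids),
    }

raw = datasets.load_dataset("json", data_files="sft_raw.jsonl", split="train")  # hypothetical raw file
tokenized = raw.map(build_record, remove_columns=raw.column_names, num_proc=8)
tokenized.save_to_disk("sft_tokenized")  # later passed to the trainer as data.train_dir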
Core code
import torch
import datasets
import numpy as np
from tqdm import tqdm


class SFTDataSet:
    def __init__(self, train_dir, max_tokens, padding_value):
        # Offline-tokenized dataset; each record already contains input_ids,
        # loss_mask, category_ids and length (see the data format above).
        self.dataframe = datasets.load_from_disk(train_dir).shuffle(seed=42)
        self.max_tokens = max_tokens
        self.padding_value = padding_value
        self.token_aware_batch_pack()

    def token_aware_batch_pack(self):
        # Greedily group sample indices so that the total token count of each
        # batch never exceeds max_tokens.
        self.batch_idx_list = []
        batch, batch_tokens = [], 0
        for idx, tok_len in tqdm(enumerate(self.dataframe['length']),
                                 total=len(self.dataframe),
                                 desc="token aware batch packing"):
            # Close the current batch once the next sample would overflow it.
            if batch and batch_tokens + tok_len > self.max_tokens:
                self.batch_idx_list.append(batch)
                batch, batch_tokens = [], 0
            # Samples longer than max_tokens are dropped.
            if tok_len <= self.max_tokens:
                batch.append(idx)
                batch_tokens += tok_len
        if batch and batch_tokens <= self.max_tokens:
            self.batch_idx_list.append(batch)

    def __len__(self):
        return len(self.batch_idx_list)

    def __getitem__(self, item):
        # Each dataset item is a list of sample indices forming one packed batch.
        return self.batch_idx_list[item]

    def collate_fn(self, batch_idx):
        # batch_idx is a list of items from __getitem__; take the first index list
        # and pack its samples back-to-back into one sequence of max_tokens tokens.
        batch = [self.dataframe[idx] for idx in batch_idx[0]]
        packed_input_ids = np.full(self.max_tokens, self.padding_value, dtype=np.int64)
        packed_position_ids = np.zeros(self.max_tokens, dtype=np.int64)
        packed_loss_mask = np.zeros(self.max_tokens, dtype=np.int64)
        packed_category_ids = np.zeros(self.max_tokens, dtype=np.int64)
        attention_mask = np.zeros(self.max_tokens, dtype=np.int64)
        current_pos = 0
        for record in batch:
            end_pos = current_pos + record['length']
            packed_input_ids[current_pos:end_pos] = record['input_ids']
            # Position ids restart at 0 for every packed sample.
            packed_position_ids[current_pos:end_pos] = np.arange(record['length'])
            packed_loss_mask[current_pos:end_pos] = record['loss_mask']
            packed_category_ids[current_pos:end_pos] = record['category_ids']
            attention_mask[current_pos:end_pos] = 1
            current_pos = end_pos
        # Give the trailing padding region its own increasing position ids.
        packed_position_ids[current_pos:] = np.arange(self.max_tokens - current_pos)
        return {
            'input_ids': torch.from_numpy(packed_input_ids).unsqueeze(0),
            'position_ids': torch.from_numpy(packed_position_ids).unsqueeze(0),
            'loss_mask': torch.from_numpy(packed_loss_mask).unsqueeze(0),
            'category_ids': torch.from_numpy(packed_category_ids).unsqueeze(0),
            'attention_mask': torch.from_numpy(attention_mask).unsqueeze(0),
        }
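As a quick sanity check outside of verl (the toy data, the /tmp path, and padding_value=0 are made up for illustration), the packer can be exercised directly: it walks the shuffled samples in order and starts a new index batch whenever adding the next sample would exceed max_tokens, and collate_fn turns one index batch into a single packed sequence of shape (1, max_tokens).

import datasets

# Toy dataset: four samples of lengths 5, 2, 4 and 3 tokens.
toy = datasets.Dataset.from_dict({
    "input_ids":    [[1] * 5, [2] * 2, [3] * 4, [4] * 3],
    "loss_mask":    [[1] * 5, [1] * 2, [1] * 4, [1] * 3],
    "category_ids": [[12] * 5, [12] * 2, [12] * 4, [12] * 3],
    "length":       [5, 2, 4, 3],
})
toy.save_to_disk("/tmp/toy_sft")  # hypothetical path

ds = SFTDataSet(train_dir="/tmp/toy_sft", max_tokens=8, padding_value=0)
print(ds.batch_idx_list)          # index groups, each totalling at most 8 tokens
sample = ds.collate_fn([ds[0]])   # same call pattern the DataLoader uses
print(sample["input_ids"].shape)  # torch.Size([1, 8])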
Changes to verl fsdp_sft_trainer.py
# Build the token-aware dataset and hand its collate_fn to the DataLoader;
# each item yielded by the dataset is already a fully packed batch of indices.
self.train_dataset = dataset_cls(train_dir=config.data.train_dir,
                                 max_tokens=config.data.max_tokens,
                                 padding_value=self.tokenizer.pad_token_id)
self.train_dataloader = DataLoader(
    dataset=self.train_dataset,
    batch_size=config.data.train_batch_size,
    sampler=self.train_sampler,
    num_workers=8,
    pin_memory=True,
    drop_last=True,
    collate_fn=self.train_dataset.collate_fn,
)
Training script
PYTHONUNBUFFERED=1 torchrun --nnodes=$nnodes --nproc_per_node=$nproc_per_node --node_rank=$NODE_RANK --master_addr=$master_ip --master_port=$master_port \
fsdp_sft_trainer.py \
+data.train_dir=$train_files \
+data.max_tokens=131072 \
optim.lr=3e-5 \
data.train_batch_size=$nnodes \
data.micro_batch_size=1 \
data.micro_batch_size_per_gpu=1 \
data.custom_cls.path=sft_data_packer.py \
data.custom_cls.name=SFTDataSet \
model.partial_pretrain=$model_path \
model.use_liger=True \
trainer.default_local_dir=$save_path \
trainer.project_name=$task_name \
trainer.experiment_name=$task_name \
trainer.logger=['console'] \
ulysses_sequence_parallel_size=8 \
use_remove_padding=true
Added after the loss is computed, to make it easy to print the channel loss
self.idx2name = ...  # mapping from category_id to channel name
if self.device_mesh.get_rank() == 0:
    # Requires `import pandas as pd` at the top of the file.
    # Average the per-token loss within each category to get one loss per channel.
    idx2mean_loss = pd.DataFrame({
        'category_id': category_ids.cpu().ravel(),
        'token_loss': loss.clone().detach().cpu().numpy().ravel()
    }).groupby('category_id')['token_loss'].mean().to_dict()
    for idx, mean_loss in idx2mean_loss.items():
        # category_id 0 marks tokens without a channel (e.g. prompt tokens); skip it.
        if idx == 0 or idx not in self.idx2name:
            continue
        self.loss_log[f"{self.idx2name[idx]}_loss"] = mean_loss
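For reference, a purely hypothetical idx2name mapping (the real ids and channel names depend on how category_ids were assigned during offline tokenization) could look like the following; self.loss_log then carries one averaged entry per channel, such as math_loss or code_loss.

# Hypothetical category_id -> channel-name mapping; the real ids and names come
# from the offline tokenization step (category_id 0 is reserved for prompt tokens).
self.idx2name = {12: "math", 27: "code", 35: "dialogue"}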
