webdataset 用于高效处理大规模数据集,通过 .tar 文件流式加载,适配分布式训练场景。

核心优势

特性说明
流式加载逐条读取 .tar,无需全部加载到内存
兼容 PyTorch与 Dataset/DataLoader 无缝集成
分布式支持自动数据分片,多 GPU 并行处理
灵活预处理支持 decode、map、filter 管道

安装

pip install webdataset

数据组织

每个样本的文件通过扩展名关联:

  • image1.jpg + label1.txt → 同一样本
  • 打包为 .tar 文件存储
tar -cf dataset.tar image1.jpg label1.txt image2.jpg label2.txt

基本用法

import webdataset as wds
from torch.utils.data import DataLoader
 
dataset = wds.WebDataset("dataset.tar") \
    .decode("rgb") \
    .to_tuple("jpg", "txt") \
    .batched(32)
 
dataloader = DataLoader(dataset, num_workers=4)
 
for images, labels in dataloader:
    print(images.shape, labels)

关键方法

  • decode("rgb"):自动解码图像为 RGB 张量
  • to_tuple("jpg", "txt"):按扩展名提取为元组
  • batched(32):分批加载

预处理管道

import torchvision.transforms as T
 
transform = T.Compose([
    T.Resize((224, 224)),
    T.ToTensor(),
])
 
def preprocess(sample):
    image, label = sample
    return transform(image), int(label)
 
dataset = wds.WebDataset("dataset.tar") \
    .decode("rgb") \
    .to_tuple("jpg", "txt") \
    .map(preprocess) \
    .batched(32)

分布式训练

import torch.distributed as dist
 
dist.init_process_group("nccl")
 
# 使用分片文件模式
dataset = wds.WebDataset("shards/{000000..000099}.tar") \
    .decode("rgb") \
    .to_tuple("jpg", "txt") \
    .batched(32)

shards/{000000..000099}.tar 表示 100 个分片文件。

高级操作

# 过滤样本
dataset = wds.WebDataset("dataset.tar") \
    .select(lambda s: "positive" in s["txt"])
 
# 多模态数据
dataset = wds.WebDataset("dataset.tar") \
    .to_tuple("jpg", "wav", "txt")
 
# 数据增强
dataset = wds.WebDataset("dataset.tar") \
    .decode("rgb") \
    .to_tuple("jpg", "txt") \
    .map(lambda x: (augment(x[0]), x[1]))

最佳实践

  1. 多线程 I/O:设置 num_workers > 1
  2. 均匀分片:确保 .tar 文件大小均匀
  3. 验证完整性:训练前检查 .tar 文件
  4. 动态加载:支持 AWS S3 等云存储