张芷铭的个人博客

webdataset 用于高效处理大规模数据集,通过 .tar 文件流式加载,适配分布式训练场景。

核心优势

特性说明
流式加载逐条读取 .tar,无需全部加载到内存
兼容 PyTorch与 Dataset/DataLoader 无缝集成
分布式支持自动数据分片,多 GPU 并行处理
灵活预处理支持 decode、map、filter 管道

安装

1
pip install webdataset

数据组织

每个样本的文件通过扩展名关联:

  • image1.jpg + label1.txt → 同一样本
  • 打包为 .tar 文件存储
1
tar -cf dataset.tar image1.jpg label1.txt image2.jpg label2.txt

基本用法

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
import webdataset as wds
from torch.utils.data import DataLoader

dataset = wds.WebDataset("dataset.tar") \
    .decode("rgb") \
    .to_tuple("jpg", "txt") \
    .batched(32)

dataloader = DataLoader(dataset, num_workers=4)

for images, labels in dataloader:
    print(images.shape, labels)

关键方法

  • decode("rgb"):自动解码图像为 RGB 张量
  • to_tuple("jpg", "txt"):按扩展名提取为元组
  • batched(32):分批加载

预处理管道

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
import torchvision.transforms as T

transform = T.Compose([
    T.Resize((224, 224)),
    T.ToTensor(),
])

def preprocess(sample):
    image, label = sample
    return transform(image), int(label)

dataset = wds.WebDataset("dataset.tar") \
    .decode("rgb") \
    .to_tuple("jpg", "txt") \
    .map(preprocess) \
    .batched(32)

分布式训练

1
2
3
4
5
6
7
8
9
import torch.distributed as dist

dist.init_process_group("nccl")

# 使用分片文件模式
dataset = wds.WebDataset("shards/{000000..000099}.tar") \
    .decode("rgb") \
    .to_tuple("jpg", "txt") \
    .batched(32)

shards/{000000..000099}.tar 表示 100 个分片文件。

高级操作

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
# 过滤样本
dataset = wds.WebDataset("dataset.tar") \
    .select(lambda s: "positive" in s["txt"])

# 多模态数据
dataset = wds.WebDataset("dataset.tar") \
    .to_tuple("jpg", "wav", "txt")

# 数据增强
dataset = wds.WebDataset("dataset.tar") \
    .decode("rgb") \
    .to_tuple("jpg", "txt") \
    .map(lambda x: (augment(x[0]), x[1]))

最佳实践

  1. 多线程 I/O:设置 num_workers > 1
  2. 均匀分片:确保 .tar 文件大小均匀
  3. 验证完整性:训练前检查 .tar 文件
  4. 动态加载:支持 AWS S3 等云存储

Comments