webdataset 用于高效处理大规模数据集,通过 .tar 文件流式加载,适配分布式训练场景。
核心优势
| 特性 | 说明 |
|---|
| 流式加载 | 逐条读取 .tar,无需全部加载到内存 |
| 兼容 PyTorch | 与 Dataset/DataLoader 无缝集成 |
| 分布式支持 | 自动数据分片,多 GPU 并行处理 |
| 灵活预处理 | 支持 decode、map、filter 管道 |
安装
数据组织
每个样本的文件通过扩展名关联:
image1.jpg + label1.txt → 同一样本- 打包为
.tar 文件存储
1
| tar -cf dataset.tar image1.jpg label1.txt image2.jpg label2.txt
|
基本用法
1
2
3
4
5
6
7
8
9
10
11
12
| import webdataset as wds
from torch.utils.data import DataLoader
dataset = wds.WebDataset("dataset.tar") \
.decode("rgb") \
.to_tuple("jpg", "txt") \
.batched(32)
dataloader = DataLoader(dataset, num_workers=4)
for images, labels in dataloader:
print(images.shape, labels)
|
关键方法:
decode("rgb"):自动解码图像为 RGB 张量to_tuple("jpg", "txt"):按扩展名提取为元组batched(32):分批加载
预处理管道
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
| import torchvision.transforms as T
transform = T.Compose([
T.Resize((224, 224)),
T.ToTensor(),
])
def preprocess(sample):
image, label = sample
return transform(image), int(label)
dataset = wds.WebDataset("dataset.tar") \
.decode("rgb") \
.to_tuple("jpg", "txt") \
.map(preprocess) \
.batched(32)
|
分布式训练
1
2
3
4
5
6
7
8
9
| import torch.distributed as dist
dist.init_process_group("nccl")
# 使用分片文件模式
dataset = wds.WebDataset("shards/{000000..000099}.tar") \
.decode("rgb") \
.to_tuple("jpg", "txt") \
.batched(32)
|
shards/{000000..000099}.tar 表示 100 个分片文件。
高级操作
1
2
3
4
5
6
7
8
9
10
11
12
13
| # 过滤样本
dataset = wds.WebDataset("dataset.tar") \
.select(lambda s: "positive" in s["txt"])
# 多模态数据
dataset = wds.WebDataset("dataset.tar") \
.to_tuple("jpg", "wav", "txt")
# 数据增强
dataset = wds.WebDataset("dataset.tar") \
.decode("rgb") \
.to_tuple("jpg", "txt") \
.map(lambda x: (augment(x[0]), x[1]))
|
最佳实践
- 多线程 I/O:设置
num_workers > 1 - 均匀分片:确保 .tar 文件大小均匀
- 验证完整性:训练前检查 .tar 文件
- 动态加载:支持 AWS S3 等云存储
Comments