WebDataset 通过流式处理和顺序读取替代随机访问,解决大规模深度学习训练中的数据 I/O 瓶颈。
什么是 WebDataset
WebDataset 是基于 TAR 归档格式的深度学习数据加载库,核心思想是将大量小文件打包成 TAR 文件,通过顺序读取提升 I/O 效率。本质上,wds 格式就是遵循额外约定的 tar 文件,一般不压缩以支持流式读取。
| 特性 | 传统文件系统 | WebDataset |
|---|
| 访问模式 | 随机访问,高延迟 | 顺序读取,高吞吐 |
| 存储效率 | 元数据开销大 | TAR 容器减少元数据 |
| 分布式支持 | 需复杂协调 | 天然支持分片 |
| 网络传输 | 小文件效率低 | 大文件流式传输 |
核心原理
顺序读取优势
机械硬盘随机读取速度仅为顺序读取的 1/100,WebDataset 将随机 I/O 转换为顺序 I/O,充分利用存储吞吐能力。
分片机制
大数据集分割为多个 TAR 分片,每个分片包含数千样本:
- 并行加载:不同分片由不同进程并行读取
- 分布式训练:每个节点处理不同分片子集
- 容错性:单个分片损坏不影响整体
样本组织规范
同一样本的所有文件共享相同前缀 key,通过扩展名区分数据类型。
1
2
3
| images17/image194.left.jpg
images17/image194.right.jpg
images17/image194.json
|
读取后得到字典:
1
| {"__key__": "images17/image194", "left.jpg": b"...", "right.jpg": b"...", "json": b"..."}
|
创建 WebDataset
1
2
3
4
5
6
7
8
9
10
11
| import webdataset as wds
import json
with wds.TarWriter("output.tar") as sink:
for i, (image_data, label, metadata) in enumerate(samples):
sink.write({
"__key__": f"sample{i:06d}",
"jpg": image_data,
"cls": str(label).encode(),
"json": json.dumps(metadata).encode()
})
|
读取数据
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
| import webdataset as wds
from torchvision import transforms
preprocess = transforms.Compose([
transforms.RandomResizedCrop(224),
transforms.ToTensor(),
])
dataset = (wds.WebDataset("dataset-{000000..000099}.tar")
.shuffle(1000)
.decode("pil")
.to_tuple("jpg", "cls")
.map_tuple(preprocess, lambda x: int(x))
.batched(32)
)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=None, num_workers=4)
|
分布式训练
1
2
3
4
5
6
7
8
9
10
| dataset = (wds.WebDataset("dataset-{000000..012345}.tar",
resampled=True,
nodesplitter=wds.split_by_node)
.shuffle(10000)
.decode("torchrgb")
.to_tuple("image", "label")
.batched(64)
)
loader = wds.WebLoader(dataset, batch_size=None, num_workers=8)
|
性能优化
分片大小建议
| 存储类型 | 推荐分片大小 |
|---|
| 本地硬盘 | 256MB-1GB |
| 网络存储 | 1-4GB |
| 云对象存储 | 4-16GB |
缓存策略
1
2
3
4
5
6
| dataset = (wds.WebDataset("https://example.com/dataset-{000000..000999}.tar")
.cache_dir("./cache")
.cache_size(10 * 1024 ** 3) # 10GB
.shuffle(10000)
.decode("pil")
)
|
随机读取
虽然 wds 为流式设计,但可通过 wids 库实现随机读取。不过如果已知样本所在 tar 路径和 key,直接基于 webdataset 接口读取更快。
总结
WebDataset 核心优势:
- 性能:顺序读取比随机访问快 3-10 倍
- 分布式友好:天然支持多节点、多 GPU 训练
- 灵活性:支持任意数据类型和多模态场景
- 易用性:与 PyTorch 无缝集成
Comments