张芷铭的个人博客

WebDataset 通过流式处理和顺序读取替代随机访问,解决大规模深度学习训练中的数据 I/O 瓶颈。

什么是 WebDataset

WebDataset 是基于 TAR 归档格式的深度学习数据加载库,核心思想是将大量小文件打包成 TAR 文件,通过顺序读取提升 I/O 效率。本质上,wds 格式就是遵循额外约定的 tar 文件,一般不压缩以支持流式读取

特性传统文件系统WebDataset
访问模式随机访问,高延迟顺序读取,高吞吐
存储效率元数据开销大TAR 容器减少元数据
分布式支持需复杂协调天然支持分片
网络传输小文件效率低大文件流式传输

核心原理

顺序读取优势

机械硬盘随机读取速度仅为顺序读取的 1/100,WebDataset 将随机 I/O 转换为顺序 I/O,充分利用存储吞吐能力。

分片机制

大数据集分割为多个 TAR 分片,每个分片包含数千样本:

  • 并行加载:不同分片由不同进程并行读取
  • 分布式训练:每个节点处理不同分片子集
  • 容错性:单个分片损坏不影响整体

样本组织规范

同一样本的所有文件共享相同前缀 key,通过扩展名区分数据类型。

1
2
3
images17/image194.left.jpg
images17/image194.right.jpg
images17/image194.json

读取后得到字典:

1
{"__key__": "images17/image194", "left.jpg": b"...", "right.jpg": b"...", "json": b"..."}

创建 WebDataset

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
import webdataset as wds
import json

with wds.TarWriter("output.tar") as sink:
    for i, (image_data, label, metadata) in enumerate(samples):
        sink.write({
            "__key__": f"sample{i:06d}",
            "jpg": image_data,
            "cls": str(label).encode(),
            "json": json.dumps(metadata).encode()
        })

读取数据

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
import webdataset as wds
from torchvision import transforms

preprocess = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.ToTensor(),
])

dataset = (wds.WebDataset("dataset-{000000..000099}.tar")
    .shuffle(1000)
    .decode("pil")
    .to_tuple("jpg", "cls")
    .map_tuple(preprocess, lambda x: int(x))
    .batched(32)
)

dataloader = torch.utils.data.DataLoader(dataset, batch_size=None, num_workers=4)

分布式训练

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
dataset = (wds.WebDataset("dataset-{000000..012345}.tar",
        resampled=True,
        nodesplitter=wds.split_by_node)
    .shuffle(10000)
    .decode("torchrgb")
    .to_tuple("image", "label")
    .batched(64)
)

loader = wds.WebLoader(dataset, batch_size=None, num_workers=8)

性能优化

分片大小建议

存储类型推荐分片大小
本地硬盘256MB-1GB
网络存储1-4GB
云对象存储4-16GB

缓存策略

1
2
3
4
5
6
dataset = (wds.WebDataset("https://example.com/dataset-{000000..000999}.tar")
    .cache_dir("./cache")
    .cache_size(10 * 1024 ** 3)  # 10GB
    .shuffle(10000)
    .decode("pil")
)

随机读取

虽然 wds 为流式设计,但可通过 wids 库实现随机读取。不过如果已知样本所在 tar 路径和 key,直接基于 webdataset 接口读取更快。

总结

WebDataset 核心优势:

  • 性能:顺序读取比随机访问快 3-10 倍
  • 分布式友好:天然支持多节点、多 GPU 训练
  • 灵活性:支持任意数据类型和多模态场景
  • 易用性:与 PyTorch 无缝集成

Comments