张芷铭的个人博客

Dataset 支持索引访问,IterableDataset 支持流式迭代,适用场景不同。

核心对比

特性DatasetIterableDataset
索引访问dataset[i]
len()✅ 支持❌ 不支持
shuffle✅ 自动支持❌ 需手动实现
多进程✅ 自动支持⚠️ 需手动 sharding
适用场景小数据集、内存加载大规模数据流

Dataset(索引式)

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
from torch.utils.data import Dataset, DataLoader

class MyDataset(Dataset):
    def __init__(self):
        self.data = list(range(100))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.data[index]

dataset = MyDataset()
dataloader = DataLoader(dataset, batch_size=10, shuffle=True, num_workers=2)

适用:小数据集、内存加载、需要 shuffle。

IterableDataset(流式)

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
from torch.utils.data import IterableDataset, DataLoader

class MyIterableDataset(IterableDataset):
    def __init__(self, start, end):
        self.start = start
        self.end = end

    def __iter__(self):
        for i in range(self.start, self.end):
            yield i

dataset = MyIterableDataset(0, 100)
dataloader = DataLoader(dataset, batch_size=10)  # 不能 shuffle

适用:WebDataset、日志流、超大数据集。

多进程 Sharding

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
from torch.utils.data import get_worker_info

class MyIterableDataset(IterableDataset):
    def __iter__(self):
        worker_info = get_worker_info()
        if worker_info is None:
            start, end = self.start, self.end
        else:
            per_worker = (self.end - self.start) // worker_info.num_workers
            start = self.start + worker_info.id * per_worker
            end = start + per_worker

        for i in range(start, end):
            yield i

选择指南

  • 用 Dataset:数据可索引、需要 shuffle、内存可容纳
  • 用 IterableDataset:超大数据、流式数据、顺序重要

Comments