Dataset 支持索引访问,IterableDataset 支持流式迭代,适用场景不同。
核心对比
| 特性 | Dataset | IterableDataset |
|---|
| 索引访问 | ✅ dataset[i] | ❌ |
| len() | ✅ 支持 | ❌ 不支持 |
| shuffle | ✅ 自动支持 | ❌ 需手动实现 |
| 多进程 | ✅ 自动支持 | ⚠️ 需手动 sharding |
| 适用场景 | 小数据集、内存加载 | 大规模数据流 |
Dataset(索引式)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
| from torch.utils.data import Dataset, DataLoader
class MyDataset(Dataset):
def __init__(self):
self.data = list(range(100))
def __len__(self):
return len(self.data)
def __getitem__(self, index):
return self.data[index]
dataset = MyDataset()
dataloader = DataLoader(dataset, batch_size=10, shuffle=True, num_workers=2)
|
适用:小数据集、内存加载、需要 shuffle。
IterableDataset(流式)
1
2
3
4
5
6
7
8
9
10
11
12
13
| from torch.utils.data import IterableDataset, DataLoader
class MyIterableDataset(IterableDataset):
def __init__(self, start, end):
self.start = start
self.end = end
def __iter__(self):
for i in range(self.start, self.end):
yield i
dataset = MyIterableDataset(0, 100)
dataloader = DataLoader(dataset, batch_size=10) # 不能 shuffle
|
适用:WebDataset、日志流、超大数据集。
多进程 Sharding
1
2
3
4
5
6
7
8
9
10
11
12
13
14
| from torch.utils.data import get_worker_info
class MyIterableDataset(IterableDataset):
def __iter__(self):
worker_info = get_worker_info()
if worker_info is None:
start, end = self.start, self.end
else:
per_worker = (self.end - self.start) // worker_info.num_workers
start = self.start + worker_info.id * per_worker
end = start + per_worker
for i in range(start, end):
yield i
|
选择指南
- 用 Dataset:数据可索引、需要 shuffle、内存可容纳
- 用 IterableDataset:超大数据、流式数据、顺序重要
Comments