在 PyTorch 中,Dataset 和 IterableDataset 都是 torch.utils.data 中的数据集基类,但它们的工作方式不同,主要区别如下:
1. Dataset(索引式数据集,Map-Style Dataset)
特点
• 适用于 索引访问 的数据集,即数据可以通过索引 (dataset[i]) 访问。
• 数据是静态的,可以随机访问。
• 通常用于 完整加载到内存的数据(如图像、文本等)。
• 支持 len(dataset)(可以知道数据集的大小)。
• 可以用于 DataLoader 进行多进程 (num_workers>0) 预加载,因为数据索引是确定的。
使用示例
import torch
from torch.utils.data import Dataset, DataLoader
class MyDataset(Dataset):
def __init__(self):
self.data = list(range(100)) # 假设数据是 0-99 的列表
def __len__(self):
return len(self.data)
def __getitem__(self, index):
return self.data[index] # 通过索引访问数据
dataset = MyDataset()
dataloader = DataLoader(dataset, batch_size=10, shuffle=True, num_workers=2)
for batch in dataloader:
print(batch)
✅ 适合小数据集,支持随机索引访问、批量加载、shuffle、并行处理。
2. IterableDataset(迭代式数据集,Stream-Style Dataset)
特点
• 适用于 流式数据(Streaming Data),如 WebDataset、Kafka、网络数据流、大型日志等。
• 不支持 dataset[i],必须通过 for sample in dataset 迭代访问。
• 不支持 len(dataset),因为数据集可能是无限的(如数据流)。
• 不支持 shuffle=True,但可以手动实现 shuffle。
• 适合大规模数据,比如 .tar 文件、日志流、实时数据等。
使用示例
from torch.utils.data import IterableDataset, DataLoader
class MyIterableDataset(IterableDataset):
def __init__(self, start, end):
self.start = start
self.end = end
def __iter__(self):
for i in range(self.start, self.end):
yield i # 逐个生成数据
dataset = MyIterableDataset(0, 100)
dataloader = DataLoader(dataset, batch_size=10, num_workers=0) # 不能 shuffle!
for batch in dataloader:
print(batch)
✅ 适合大规模数据流式处理,不占用太多内存,但不支持索引访问和 shuffle。
3. Dataset vs. IterableDataset 总结对比
| 特性 | Dataset**(索引式)** | IterableDataset**(流式)** |
|---|---|---|
| 索引访问 (dataset[i]) | ✅ 支持 | ❌ 不支持 |
| len(dataset) | ✅ 支持 | ❌ 不支持 |
| 随机访问 | ✅ 支持 | ❌ 不支持 |
| shuffle=True | ✅ 可用 | ❌ 不能自动 shuffle |
| 多进程 (num_workers) | ✅ 支持 | ⚠️ 可能需要手动 sharding |
| 适用场景 | 小数据集、内存加载(如图像、文本) | 大规模数据流(如 .tar 文件、日志流、Kafka) |
4. 什么时候用 IterableDataset?
✅ 使用 IterableDataset 的情况:
处理 WebDataset(如 .tar 文件)。
流式数据(实时传输,如日志、消息队列)。
超大数据集(无法全部加载到内存)。
数据顺序很重要,不能打乱。
❌ 避免使用 IterableDataset 的情况:
数据可索引、随机访问(推荐 Dataset)。
需要 shuffle(IterableDataset 默认不支持)。
5. IterableDataset 结合 DataLoader 的 num_workers
• IterableDataset 不支持索引,所以 num_workers > 0 时,多个 worker 可能会重复处理相同数据。
• 解决方法:使用 torch.utils.data.get_worker_info() 进行数据 切片(sharding)。
示例:
from torch.utils.data import IterableDataset, DataLoader, get_worker_info
class MyIterableDataset(IterableDataset):
def __init__(self, start, end):
self.start = start
self.end = end
def __iter__(self):
worker_info = get_worker_info()
if worker_info is None: # 单进程
start, end = self.start, self.end
else: # 多进程切片
total = self.end - self.start
num_workers = worker_info.num_workers
worker_id = worker_info.id
per_worker = total // num_workers
start = self.start + worker_id * per_worker
end = start + per_worker if worker_id != num_workers - 1 else self.end
for i in range(start, end):
yield i
dataset = MyIterableDataset(0, 100)
dataloader = DataLoader(dataset, batch_size=10, num_workers=4)
for batch in dataloader:
print(batch)
✅ 确保多个 worker 处理的数据不重复,提高效率。
总结
Dataset(索引式)适用于小数据集,支持随机访问和 shuffle。
IterableDataset(流式)适用于超大规模数据流,但不支持索引访问。
IterableDataset 结合 get_worker_info() 可以支持多进程并行处理。
选择哪种取决于你的数据规模和加载方式 🚀。
张芷铭的个人博客
💬 评论