张芷铭的个人博客

Pytorch中的Dataset

在 PyTorch 中,Dataset 和 IterableDataset 都是 torch.utils.data 中的数据集基类,但它们的工作方式不同,主要区别如下:


1. Dataset(索引式数据集,Map-Style Dataset)

特点

• 适用于 索引访问 的数据集,即数据可以通过索引 (dataset[i]) 访问。

数据是静态的,可以随机访问

• 通常用于 完整加载到内存的数据(如图像、文本等)。

支持 len(dataset)(可以知道数据集的大小)。

可以用于 DataLoader 进行多进程 (num_workers>0) 预加载,因为数据索引是确定的。

使用示例

import torch
from torch.utils.data import Dataset, DataLoader

class MyDataset(Dataset):
    def __init__(self):
        self.data = list(range(100))  # 假设数据是 0-99 的列表

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.data[index]  # 通过索引访问数据

dataset = MyDataset()
dataloader = DataLoader(dataset, batch_size=10, shuffle=True, num_workers=2)

for batch in dataloader:
    print(batch)

适合小数据集,支持随机索引访问、批量加载、shuffle、并行处理


2. IterableDataset(迭代式数据集,Stream-Style Dataset)

特点

• 适用于 流式数据(Streaming Data),如 WebDataset、Kafka、网络数据流、大型日志等。

不支持 dataset[i],必须通过 for sample in dataset 迭代访问。

不支持 len(dataset),因为数据集可能是无限的(如数据流)。

不支持 shuffle=True,但可以手动实现 shuffle。

适合大规模数据,比如 .tar 文件、日志流、实时数据等。

使用示例

from torch.utils.data import IterableDataset, DataLoader

class MyIterableDataset(IterableDataset):
    def __init__(self, start, end):
        self.start = start
        self.end = end

    def __iter__(self):
        for i in range(self.start, self.end):
            yield i  # 逐个生成数据

dataset = MyIterableDataset(0, 100)
dataloader = DataLoader(dataset, batch_size=10, num_workers=0)  # 不能 shuffle!

for batch in dataloader:
    print(batch)

适合大规模数据流式处理,不占用太多内存,但不支持索引访问和 shuffle


3. Dataset vs. IterableDataset 总结对比

特性Dataset**(索引式)**IterableDataset**(流式)**
索引访问 (dataset[i])✅ 支持❌ 不支持
len(dataset)✅ 支持❌ 不支持
随机访问✅ 支持❌ 不支持
shuffle=True✅ 可用❌ 不能自动 shuffle
多进程 (num_workers)✅ 支持⚠️ 可能需要手动 sharding
适用场景小数据集、内存加载(如图像、文本)大规模数据流(如 .tar 文件、日志流、Kafka)

4. 什么时候用 IterableDataset?

使用 IterableDataset 的情况:

  1. 处理 WebDataset(如 .tar 文件)。

  2. 流式数据(实时传输,如日志、消息队列)。

  3. 超大数据集(无法全部加载到内存)。

  4. 数据顺序很重要,不能打乱

避免使用 IterableDataset 的情况:

  1. 数据可索引、随机访问(推荐 Dataset)。

  2. 需要 shuffle(IterableDataset 默认不支持)。


5. IterableDataset 结合 DataLoader 的 num_workers

• IterableDataset 不支持索引,所以 num_workers > 0 时,多个 worker 可能会重复处理相同数据。

• 解决方法:使用 torch.utils.data.get_worker_info() 进行数据 切片(sharding)。

示例:

from torch.utils.data import IterableDataset, DataLoader, get_worker_info

class MyIterableDataset(IterableDataset):
    def __init__(self, start, end):
        self.start = start
        self.end = end

    def __iter__(self):
        worker_info = get_worker_info()
        if worker_info is None:  # 单进程
            start, end = self.start, self.end
        else:  # 多进程切片
            total = self.end - self.start
            num_workers = worker_info.num_workers
            worker_id = worker_info.id
            per_worker = total // num_workers
            start = self.start + worker_id * per_worker
            end = start + per_worker if worker_id != num_workers - 1 else self.end
        
        for i in range(start, end):
            yield i

dataset = MyIterableDataset(0, 100)
dataloader = DataLoader(dataset, batch_size=10, num_workers=4)

for batch in dataloader:
    print(batch)

确保多个 worker 处理的数据不重复,提高效率


总结

  1. Dataset(索引式)适用于小数据集,支持随机访问和 shuffle。

  2. IterableDataset(流式)适用于超大规模数据流,但不支持索引访问。

  3. IterableDataset 结合 get_worker_info() 可以支持多进程并行处理

选择哪种取决于你的数据规模和加载方式 🚀。

💬 评论