欢迎访问宙启技术站
智能推送

PyTorch数据加载器和数据采样器的使用

发布时间:2024-01-16 02:03:07

PyTorch提供了数据加载器(DataLoader)和数据采样器(Sampler)来帮助我们更方便地处理数据。

数据加载器是一个可迭代对象,它可以将数据集划分为小批量进行训练。数据加载器还提供了一些功能,比如随机打乱数据集、并行加载数据等。

数据采样器用于定义从数据集中获取样本的策略。主要有两种类型的采样器:随机采样器(RandomSampler)和顺序采样器(SequentialSampler)。随机采样器可以用于随机从数据集中获取样本,而顺序采样器则可以按照顺序获取样本。

下面是一个使用数据加载器和数据采样器的示例:

import torch
from torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler

# 自定义数据集类
class MyDataset(Dataset):
    def __init__(self):
        self.data = torch.randn((100, 10))

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return len(self.data)

# 创建数据集对象
dataset = MyDataset()

# 定义数据采样器
sampler = RandomSampler(dataset)

# 创建数据加载器
dataloader = DataLoader(dataset, batch_size=4, sampler=sampler)

# 使用数据加载器迭代数据集
for batch in dataloader:
    print(batch)

以上示例中,首先定义了一个自定义的数据集类MyDataset。该数据集包含100个样本,每个样本有10个特征。

然后,我们创建了一个RandomSampler作为数据采样器,用于随机从数据集中获取样本。

最后,使用数据加载器DataLoader将数据集划分为大小为4的小批量。在迭代数据加载器时,每次返回一个小批量的数据,用于模型的训练或推理。

通过使用数据加载器和数据采样器,我们可以更方便地处理数据集,加快模型的训练速度,并提高模型的性能。