欢迎访问宙启技术站
智能推送

PyTorch中的随机数据采样器(torch.utils.data.sampler)

发布时间:2023-12-24 08:39:33

PyTorch中的随机数据采样器(torch.utils.data.sampler)是一种用于数据加载器的采样策略。它决定了如何从数据集中选择样本,并定义了迭代数据集的顺序。

常见的数据采样器包括随机采样器(RandomSampler)、顺序采样器(SequentialSampler)和自定义采样器(CustomSampler)。

随机采样器(RandomSampler)在每个epoch中随机选择样本。该采样器允许重复选择相同的样本,因此可以用于训练集。以下是使用随机采样器的示例:

import torch
from torch.utils.data import DataLoader, RandomSampler

dataset = torch.arange(10)  # 创建一个0到9的Tensor数据集
sampler = RandomSampler(dataset)  # 创建随机采样器
dataloader = DataLoader(dataset, sampler=sampler, batch_size=2)  # 创建数据加载器

for batch in dataloader:
    print(batch)

# 输出:
# tensor([8, 1])
# tensor([9, 6])
# tensor([7, 0])
# tensor([4, 2])
# tensor([5, 3])

顺序采样器(SequentialSampler)按照数据集的顺序依次选择样本。该采样器不会重复选择样本,适用于验证集和测试集。以下是使用顺序采样器的示例:

import torch
from torch.utils.data import DataLoader, SequentialSampler

dataset = torch.arange(10)  # 创建一个0到9的Tensor数据集
sampler = SequentialSampler(dataset)  # 创建顺序采样器
dataloader = DataLoader(dataset, sampler=sampler, batch_size=2)  # 创建数据加载器

for batch in dataloader:
    print(batch)

# 输出:
# tensor([0, 1])
# tensor([2, 3])
# tensor([4, 5])
# tensor([6, 7])
# tensor([8, 9])

除了随机采样器和顺序采样器,您还可以实现自定义采样器(CustomSampler)。自定义采样器需要继承torch.utils.data.sampler.Sampler类,并实现__iter____len__方法。以下是一个自定义采样器的示例:

import torch
from torch.utils.data import DataLoader, Sampler

class CustomSampler(Sampler):
    def __init__(self, data_source):
        self.data_source = data_source

    def __iter__(self):
        return iter(range(len(self.data_source) - 1, -1, -1))  # 逆向选择样本

    def __len__(self):
        return len(self.data_source)

dataset = torch.arange(10)  # 创建一个0到9的Tensor数据集
sampler = CustomSampler(dataset)  # 创建自定义采样器
dataloader = DataLoader(dataset, sampler=sampler, batch_size=2)  # 创建数据加载器

for batch in dataloader:
    print(batch)

# 输出:
# tensor([9, 8])
# tensor([7, 6])
# tensor([5, 4])
# tensor([3, 2])
# tensor([1, 0])

正如示例所示,使用随机采样器、顺序采样器或自定义采样器可以根据需要灵活地控制样本选择的顺序,从而更好地适应不同的数据加载需求。