欢迎访问宙启技术站
智能推送

torch.utils.data.sampler中的RandomSampler与SequentialSampler的区别

发布时间:2023-12-24 08:41:30

torch.utils.data.sampler中的RandomSampler和SequentialSampler是用于数据集采样的两个常用类。

RandomSampler是随机采样器,它可以在数据集中随机选取样本。使用RandomSampler时,每次迭代都会随机选择一个样本,因此可以用于在训练过程中对数据进行随机打乱。例如,假设有一个包含100个样本的数据集,我们可以使用RandomSampler对其进行随机采样,并将其用于训练一个分类模型。

下面是一个使用RandomSampler的例子:

import torch
from torch.utils.data import DataLoader, RandomSampler, Dataset

# 创建一个示例数据集类
class MyDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.data[index]

# 创建示例数据集
data = list(range(100))
dataset = MyDataset(data)

# 使用RandomSampler对数据集进行随机采样
random_sampler = RandomSampler(dataset)

# 创建数据加载器
dataloader = DataLoader(dataset, sampler=random_sampler, batch_size=10)

# 遍历数据加载器
for batch in dataloader:
    print(batch)

在上面的例子中,我们首先定义了一个自定义数据集类MyDataset,然后创建了一个包含数字0到99的数据集。我们使用RandomSampler对数据集进行随机采样,并创建了一个数据加载器。在遍历数据加载器时,每个batch中的样本都是随机选择的。

SequentialSampler是顺序采样器,它按照数据集的顺序依次选取样本。使用SequentialSampler时,每次迭代都会从数据集的 个样本开始选择,直到遍历完所有样本。SequentialSampler可以用于在测试阶段对数据进行顺序采样,以便按照固定顺序对模型进行评估。

下面是一个使用SequentialSampler的例子:

import torch
from torch.utils.data import DataLoader, SequentialSampler, Dataset

# 创建一个示例数据集类
class MyDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.data[index]

# 创建示例数据集
data = list(range(100))
dataset = MyDataset(data)

# 使用SequentialSampler对数据集进行顺序采样
sequential_sampler = SequentialSampler(dataset)

# 创建数据加载器
dataloader = DataLoader(dataset, sampler=sequential_sampler, batch_size=10)

# 遍历数据加载器
for batch in dataloader:
    print(batch)

在上面的例子中,我们使用SequentialSampler对数据集进行顺序采样,并创建了一个数据加载器。在遍历数据加载器时,每个batch中的样本都是按照数据集的顺序依次选择的。

综上所述,RandomSampler和SequentialSampler是两个常用的数据采样器,RandomSampler用于随机采样,SequentialSampler用于顺序采样。这两个采样器在数据集的遍历方式上有所不同,具体使用哪种采样器取决于任务需求。