欢迎访问宙启技术站
智能推送

PyTorch中torch.utils.data.sampler模块的使用介绍

发布时间:2023-12-19 05:21:38

PyTorch中的torch.utils.data.sampler模块提供了用于数据加载和训练的采样器类。采样器类定义了从数据集中提取样本的策略,可以帮助在数据加载和训练过程中控制样本的顺序、分布和重复。

PyTorch中常用的采样器包括顺序采样器(SequentialSampler)、随机采样器(RandomSampler)和子集采样器(SubsetRandomSampler)。下面我们将介绍这些采样器的用法,并给出使用示例。

1. 顺序采样器(SequentialSampler)

顺序采样器按顺序提取数据集中的样本。它适用于希望按照固定顺序遍历数据集的情况。

使用示例:

from torch.utils.data import Dataset, DataLoader
from torch.utils.data.sampler import SequentialSampler

class MyDataset(Dataset):
    def __init__(self):
        self.data = [1, 2, 3, 4, 5]
  
    def __len__(self):
        return len(self.data)
  
    def __getitem__(self, index):
        return self.data[index]

dataset = MyDataset()
sampler = SequentialSampler(dataset)
dataloader = DataLoader(dataset, sampler=sampler, batch_size=2)

for batch in dataloader:
    print(batch)

2. 随机采样器(RandomSampler)

随机采样器会随机从数据集中提取样本。它可以洗牌数据集,使得每个batch的样本顺序都是随机的。

使用示例:

from torch.utils.data import Dataset, DataLoader
from torch.utils.data.sampler import RandomSampler

class MyDataset(Dataset):
    def __init__(self):
        self.data = [1, 2, 3, 4, 5]
  
    def __len__(self):
        return len(self.data)
  
    def __getitem__(self, index):
        return self.data[index]

dataset = MyDataset()
sampler = RandomSampler(dataset)
dataloader = DataLoader(dataset, sampler=sampler, batch_size=2)

for batch in dataloader:
    print(batch)

3. 子集采样器(SubsetRandomSampler)

子集采样器可以根据给定的索引列表来提取子集。它常用于将数据集划分为训练集和验证集,使两者的样本不重复。

使用示例:

from torch.utils.data import Dataset, DataLoader
from torch.utils.data.sampler import SubsetRandomSampler

class MyDataset(Dataset):
    def __init__(self):
        self.data = [1, 2, 3, 4, 5]
  
    def __len__(self):
        return len(self.data)
  
    def __getitem__(self, index):
        return self.data[index]

dataset = MyDataset()
train_indices = [0, 1, 2]
valid_indices = [3, 4]
train_sampler = SubsetRandomSampler(train_indices)
valid_sampler = SubsetRandomSampler(valid_indices)
train_dataloader = DataLoader(dataset, sampler=train_sampler, batch_size=2)
valid_dataloader = DataLoader(dataset, sampler=valid_sampler, batch_size=2)

for batch in train_dataloader:
    print(batch)

for batch in valid_dataloader:
    print(batch)

以上是PyTorch中torch.utils.data.sampler模块的使用介绍及示例。通过采样器类的灵活组合,我们可以灵活地控制数据集样本的顺序、分布和重复,满足不同的训练需求。