欢迎访问宙启技术站
智能推送

如何使用torch.utils.data.sampler在PyTorch中进行数据采样

发布时间:2023-12-24 08:39:20

在PyTorch中,可以使用torch.utils.data.sampler来进行数据采样,以控制数据加载和训练过程中的样本顺序。

torch.utils.data.sampler提供了多种采样方法,包括随机采样、有序采样、不平衡类别采样等。

下面是一些常用的数据采样方法及其使用例子:

1. 随机采样(RandomSampler):

随机采样器按照数据集的索引进行随机选择。在每个epoch开始时,打乱数据集,并按照打乱后的顺序加载数据。

使用例子:

import torch
from torch.utils.data import DataLoader
from torch.utils.data.sampler import RandomSampler

dataset = torch.utils.data.TensorDataset(torch.randn(10, 5))
sampler = RandomSampler(dataset)
dataloader = DataLoader(dataset, sampler=sampler, batch_size=2)

for batch in dataloader:
    print(batch)

2. 有序采样(SequentialSampler):

有序采样器按照数据集的索引顺序进行选择。数据集按照索引从小到大的顺序加载。

使用例子:

import torch
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SequentialSampler

dataset = torch.utils.data.TensorDataset(torch.randn(10, 5))
sampler = SequentialSampler(dataset)
dataloader = DataLoader(dataset, sampler=sampler, batch_size=2)

for batch in dataloader:
    print(batch)

3. 不平衡类别采样(WeightedRandomSampler):

不平衡类别采样器根据每个样本的权重进行采样,用于处理数据集中类别不平衡的情况。

使用例子:

import torch
from torch.utils.data import DataLoader
from torch.utils.data.sampler import WeightedRandomSampler

# 假设数据集有100个样本,其中10个为类别1,90个为类别0
labels = torch.cat([torch.zeros(90), torch.ones(10)], dim=0)
weights = 1 / (labels.bincount().float())
sample_weights = weights[labels.long()]

dataset = torch.utils.data.TensorDataset(torch.randn(100, 5), labels)
sampler = WeightedRandomSampler(sample_weights, len(sample_weights))
dataloader = DataLoader(dataset, sampler=sampler, batch_size=2)

for batch, label in dataloader:
    print(batch, label)

除了上述几种采样方法外,torch.utils.data.sampler还提供了其他一些采样方法,如SubsetRandomSampler、BatchSampler等。

使用torch.utils.data.sampler可以灵活地对数据集进行采样,控制数据的加载和训练过程,以适应不同的需求和场景。