如何使用torch.utils.data.sampler在PyTorch中进行数据采样
发布时间:2023-12-24 08:39:20
在PyTorch中,可以使用torch.utils.data.sampler来进行数据采样,以控制数据加载和训练过程中的样本顺序。
torch.utils.data.sampler提供了多种采样方法,包括随机采样、有序采样、不平衡类别采样等。
下面是一些常用的数据采样方法及其使用例子:
1. 随机采样(RandomSampler):
随机采样器按照数据集的索引进行随机选择。在每个epoch开始时,打乱数据集,并按照打乱后的顺序加载数据。
使用例子:
import torch
from torch.utils.data import DataLoader
from torch.utils.data.sampler import RandomSampler
dataset = torch.utils.data.TensorDataset(torch.randn(10, 5))
sampler = RandomSampler(dataset)
dataloader = DataLoader(dataset, sampler=sampler, batch_size=2)
for batch in dataloader:
print(batch)
2. 有序采样(SequentialSampler):
有序采样器按照数据集的索引顺序进行选择。数据集按照索引从小到大的顺序加载。
使用例子:
import torch
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SequentialSampler
dataset = torch.utils.data.TensorDataset(torch.randn(10, 5))
sampler = SequentialSampler(dataset)
dataloader = DataLoader(dataset, sampler=sampler, batch_size=2)
for batch in dataloader:
print(batch)
3. 不平衡类别采样(WeightedRandomSampler):
不平衡类别采样器根据每个样本的权重进行采样,用于处理数据集中类别不平衡的情况。
使用例子:
import torch
from torch.utils.data import DataLoader
from torch.utils.data.sampler import WeightedRandomSampler
# 假设数据集有100个样本,其中10个为类别1,90个为类别0
labels = torch.cat([torch.zeros(90), torch.ones(10)], dim=0)
weights = 1 / (labels.bincount().float())
sample_weights = weights[labels.long()]
dataset = torch.utils.data.TensorDataset(torch.randn(100, 5), labels)
sampler = WeightedRandomSampler(sample_weights, len(sample_weights))
dataloader = DataLoader(dataset, sampler=sampler, batch_size=2)
for batch, label in dataloader:
print(batch, label)
除了上述几种采样方法外,torch.utils.data.sampler还提供了其他一些采样方法,如SubsetRandomSampler、BatchSampler等。
使用torch.utils.data.sampler可以灵活地对数据集进行采样,控制数据的加载和训练过程,以适应不同的需求和场景。
