使用WeightedRandomSampler()实现类别不平衡数据集的采样
WeightedRandomSampler()是PyTorch中用于实现类别不平衡数据集采样的一个采样器。在处理类别不平衡数据集时,由于某些类别的样本数量较少,直接使用随机采样可能会导致模型对于少数类别的预测效果不佳。这时可以使用WeightedRandomSampler()来调整样本权重,从而平衡数据集。
下面举一个例子来说明如何使用WeightedRandomSampler()。
假设我们有一个二分类任务的数据集,总共有100个样本,其中正例有10个,负例有90个。由于正例的样本数量较少,直接采用随机采样可能导致模型对正例的预测效果不佳。为了解决这个问题,我们可以使用WeightedRandomSampler()来进行采样。
首先,需要计算每个样本的权重。我们可以使用样本的类别比例的倒数作为样本的权重。在本例中,正例的权重为总样本数90除以正例数10,即9;负例的权重为总样本数10除以负例数90,即0.11。
接下来,我们可以使用WeightedRandomSampler()来进行采样。首先,需要导入相关的库:
import torch from torch.utils.data import DataLoader from torch.utils.data.sampler import WeightedRandomSampler
接下来,我们可以定义一个自定义的数据集类,例如名为CustomDataset:
class CustomDataset(torch.utils.data.Dataset):
def __init__(self):
# 定义数据集,例如使用torch.Tensor来表示样本和标签
self.data = torch.randn(100, 10) # 样本数据,假设有100个样本,每个样本有10个特征
self.labels = torch.cat([torch.zeros(90), torch.ones(10)]) # 标签数据,其中前90个为负例,后10个为正例
def __getitem__(self, index):
x = self.data[index] # 获取样本
y = self.labels[index] # 获取标签
return x, y
def __len__(self):
return len(self.data) # 返回总样本数
然后,我们可以定义一个WeightedRandomSampler(),并将其传递给DataLoader来实现采样。在定义WeightedRandomSampler()时,需要指定样本权重(weights)和样本数量(num_samples)。
dataset = CustomDataset() # 创建数据集实例
# 计算样本权重
weights = []
for label in dataset.labels:
if label == 0:
weights.append(0.11) # 负例的权重为0.11
else:
weights.append(9) # 正例的权重为9
sampler = WeightedRandomSampler(weights=weights, num_samples=len(dataset), replacement=True)
dataloader = DataLoader(dataset, batch_size=16, sampler=sampler)
在上述代码中,sampler被传递给了DataLoader,这样在每个epoch中,WeightedRandomSampler会根据样本权重进行采样,从而平衡数据集。
最后,我们可以通过遍历dataloader来获取采样后的数据:
for batch in dataloader:
x, y = batch
# 进行训练或者预测操作
通过使用WeightedRandomSampler,我们可以在处理类别不平衡数据集时,提高模型对于少数类别的预测效果。
