欢迎访问宙启技术站
智能推送

使用WeightedRandomSampler()实现类别不平衡数据集的采样

发布时间:2023-12-29 11:03:35

WeightedRandomSampler()是PyTorch中用于实现类别不平衡数据集采样的一个采样器。在处理类别不平衡数据集时,由于某些类别的样本数量较少,直接使用随机采样可能会导致模型对于少数类别的预测效果不佳。这时可以使用WeightedRandomSampler()来调整样本权重,从而平衡数据集。

下面举一个例子来说明如何使用WeightedRandomSampler()。

假设我们有一个二分类任务的数据集,总共有100个样本,其中正例有10个,负例有90个。由于正例的样本数量较少,直接采用随机采样可能导致模型对正例的预测效果不佳。为了解决这个问题,我们可以使用WeightedRandomSampler()来进行采样。

首先,需要计算每个样本的权重。我们可以使用样本的类别比例的倒数作为样本的权重。在本例中,正例的权重为总样本数90除以正例数10,即9;负例的权重为总样本数10除以负例数90,即0.11。

接下来,我们可以使用WeightedRandomSampler()来进行采样。首先,需要导入相关的库:

import torch
from torch.utils.data import DataLoader
from torch.utils.data.sampler import WeightedRandomSampler

接下来,我们可以定义一个自定义的数据集类,例如名为CustomDataset:

class CustomDataset(torch.utils.data.Dataset):
    def __init__(self):
        # 定义数据集,例如使用torch.Tensor来表示样本和标签
        self.data = torch.randn(100, 10)  # 样本数据,假设有100个样本,每个样本有10个特征
        self.labels = torch.cat([torch.zeros(90), torch.ones(10)])  # 标签数据,其中前90个为负例,后10个为正例

    def __getitem__(self, index):
        x = self.data[index]  # 获取样本
        y = self.labels[index]  # 获取标签
        return x, y

    def __len__(self):
        return len(self.data)  # 返回总样本数

然后,我们可以定义一个WeightedRandomSampler(),并将其传递给DataLoader来实现采样。在定义WeightedRandomSampler()时,需要指定样本权重(weights)和样本数量(num_samples)。

dataset = CustomDataset()  # 创建数据集实例

# 计算样本权重
weights = []
for label in dataset.labels:
    if label == 0:
        weights.append(0.11)  # 负例的权重为0.11
    else:
        weights.append(9)  # 正例的权重为9

sampler = WeightedRandomSampler(weights=weights, num_samples=len(dataset), replacement=True)

dataloader = DataLoader(dataset, batch_size=16, sampler=sampler)

在上述代码中,sampler被传递给了DataLoader,这样在每个epoch中,WeightedRandomSampler会根据样本权重进行采样,从而平衡数据集。

最后,我们可以通过遍历dataloader来获取采样后的数据:

for batch in dataloader:
    x, y = batch
    # 进行训练或者预测操作

通过使用WeightedRandomSampler,我们可以在处理类别不平衡数据集时,提高模型对于少数类别的预测效果。