欢迎访问宙启技术站
智能推送

实现类别平衡的数据采样方法:torch.utils.data.sampler.WeightedRandomSampler

发布时间:2023-12-24 08:42:05

类别平衡的数据采样是在不平衡的数据集中通过一定方式调整样本的权重,以达到每个类别的样本数量相对平衡的效果。在PyTorch中,可以使用torch.utils.data.sampler.WeightedRandomSampler类来实现类别平衡的数据采样。

WeightedRandomSampler继承自torch.utils.data.sampler.Sampler类,它根据每个样本的权重来进行采样。权重可以自定义,一般来说,类别样本数量较多的样本的权重会较小,类别样本数量较少的样本的权重会较大。

下面是一个使用WeightedRandomSampler进行类别平衡采样的例子:

import torch
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.sampler import WeightedRandomSampler

# 自定义数据集类
class CustomDataset(Dataset):
    def __init__(self, data, targets):
        self.data = data
        self.targets = targets
    
    def __getitem__(self, index):
        x = self.data[index]
        y = self.targets[index]
        return x, y
    
    def __len__(self):
        return len(self.data)

# 假设数据集中有3个类别,每个类别有不同数量的样本
data = [...]  # 数据集
targets = [...]  # 标签
num_classes = 3

# 统计每个类别的样本数量
class_counts = torch.bincount(targets)
class_weights = 1.0 / class_counts  # 计算每个类别的权重,样本数量越多,权重越小

# 构建sampler
weights = class_weights[targets]  # 根据样本的标签获取对应的权重
sampler = WeightedRandomSampler(weights, len(weights), replacement=True)  # replacement=True表示可以有重复样本

# 构建dataloader
dataset = CustomDataset(data, targets)
dataloader = DataLoader(dataset, batch_size=32, sampler=sampler)

# 使用dataloader进行训练
for inputs, labels in dataloader:
    # 进行训练
    ...

在上面的例子中,我们首先统计每个类别的样本数量,然后计算每个类别的权重。接下来,根据样本的标签获取对应的权重,并使用这些权重构建WeightedRandomSampler对象。最后,我们使用该sampler作为参数来构建数据加载器,并使用数据加载器进行模型的训练。

通过使用WeightedRandomSampler类,我们可以实现对不平衡数据集的类别平衡采样,以提高模型训练的效果。