欢迎访问宙启技术站
智能推送

如何使用Python中的WeightedRandomSampler()函数对样本进行加权采样

发布时间:2023-12-29 11:12:13

在Python中,可以使用torch.utils.data中的WeightedRandomSampler()函数对样本进行加权采样。该函数用于创建一个采样器,根据每个样本的权重来确定其被选中的概率。

WeightedRandomSampler()函数的签名如下:

torch.utils.data.WeightedRandomSampler(weights, num_samples, replacement=True)

其中,weights是每个样本的权重列表,num_samples是要采样的样本数量,replacement指定是否可以重复采样。

下面是一个使用WeightedRandomSampler()函数对样本进行加权采样的例子:

首先,导入必要的库:

import torch
import torch.utils.data as data

接下来,定义一个自定义的数据集类,示例中为了方便起见,直接使用了torchvision.datasets包中的CIFAR10数据集:

from torchvision.datasets import CIFAR10
from torchvision import transforms

class CustomDataset(data.Dataset):
    def __init__(self, train=True, transform=None):
        self.dataset = CIFAR10(root='./data', train=train, transform=transform, download=True)
        self.weights = self.calculate_weights()

    def calculate_weights(self):
        # 假设按类别计算权重
        class_counts = [0] * len(self.dataset.classes)
        for _, target in self.dataset:
            class_counts[target] += 1
        total_samples = sum(class_counts)
        # 计算每个样本的权重
        weights = [1.0 / class_counts[target] for _, target in self.dataset]
        return weights

    def __getitem__(self, index):
        sample, target = self.dataset[index]
        weight = self.weights[index]
        return sample, target, weight

    def __len__(self):
        return len(self.dataset)

然后,创建数据预处理函数:

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

接下来,创建自定义数据集的实例以及一个WeightedRandomSampler采样器:

dataset = CustomDataset(train=True, transform=transform)

sampler = torch.utils.data.WeightedRandomSampler(dataset.weights, num_samples=len(dataset), replacement=True)

最后,使用torch.utils.data.DataLoader将sampler应用于自定义数据集,实现加权采样:

batch_size = 32
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, sampler=sampler)

for data, target, weight in dataloader:
    # 进行训练等操作
    pass

在上述代码中,自定义数据集考虑了样本的权重,计算每个样本的权重时,可以根据实际需求进行灵活的设计。WeightedRandomSampler函数会根据每个样本的权重分布进行采样,从而实现对样本的加权采样。