如何使用Python中的WeightedRandomSampler()函数对样本进行加权采样
发布时间:2023-12-29 11:12:13
在Python中,可以使用torch.utils.data中的WeightedRandomSampler()函数对样本进行加权采样。该函数用于创建一个采样器,根据每个样本的权重来确定其被选中的概率。
WeightedRandomSampler()函数的签名如下:
torch.utils.data.WeightedRandomSampler(weights, num_samples, replacement=True)
其中,weights是每个样本的权重列表,num_samples是要采样的样本数量,replacement指定是否可以重复采样。
下面是一个使用WeightedRandomSampler()函数对样本进行加权采样的例子:
首先,导入必要的库:
import torch import torch.utils.data as data
接下来,定义一个自定义的数据集类,示例中为了方便起见,直接使用了torchvision.datasets包中的CIFAR10数据集:
from torchvision.datasets import CIFAR10
from torchvision import transforms
class CustomDataset(data.Dataset):
def __init__(self, train=True, transform=None):
self.dataset = CIFAR10(root='./data', train=train, transform=transform, download=True)
self.weights = self.calculate_weights()
def calculate_weights(self):
# 假设按类别计算权重
class_counts = [0] * len(self.dataset.classes)
for _, target in self.dataset:
class_counts[target] += 1
total_samples = sum(class_counts)
# 计算每个样本的权重
weights = [1.0 / class_counts[target] for _, target in self.dataset]
return weights
def __getitem__(self, index):
sample, target = self.dataset[index]
weight = self.weights[index]
return sample, target, weight
def __len__(self):
return len(self.dataset)
然后,创建数据预处理函数:
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
接下来,创建自定义数据集的实例以及一个WeightedRandomSampler采样器:
dataset = CustomDataset(train=True, transform=transform) sampler = torch.utils.data.WeightedRandomSampler(dataset.weights, num_samples=len(dataset), replacement=True)
最后,使用torch.utils.data.DataLoader将sampler应用于自定义数据集,实现加权采样:
batch_size = 32
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, sampler=sampler)
for data, target, weight in dataloader:
# 进行训练等操作
pass
在上述代码中,自定义数据集考虑了样本的权重,计算每个样本的权重时,可以根据实际需求进行灵活的设计。WeightedRandomSampler函数会根据每个样本的权重分布进行采样,从而实现对样本的加权采样。
