欢迎访问宙启技术站
智能推送

Python中WeightedRandomSampler()函数的详细介绍及用法示例

发布时间:2023-12-29 11:08:15

WeightedRandomSampler()是PyTorch中的一个采样器(sampler)类,用于实现根据样本权重进行随机采样。

在机器学习中,样本的权重是指样本被选中的概率,通常根据样本的重要性和分布情况来确定。当数据集中某些类别的样本数量较少时,可以使用权重采样来平衡各个类别的样本。

WeightedRandomSampler()的用法如下:

torch.utils.data.WeightedRandomSampler(weights, num_samples, replacement=True)

参数说明:

- weights: 一个列表或张量,大小为样本总数,表示每个样本的权重。

- num_samples: 要采样的样本数量。

- replacement: 是否允许重复采样,默认为True,即允许重复采样。

WeightedRandomSampler()的返回值为一个采样器对象,可以作为DataLoader的参数传入,用于实现根据样本权重进行随机采样。

下面是一个使用WeightedRandomSampler()的示例:

import torch
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler

# 定义一个简单的数据集
class CustomDataset(Dataset):
    def __init__(self, data, targets):
        self.data = data
        self.targets = targets
    
    def __getitem__(self, index):
        x = self.data[index]
        y = self.targets[index]
        return x, y
    
    def __len__(self):
        return len(self.data)

# 创建数据集和样本权重
data = [[1,2,3], [4,5,6], [7,8,9], [10,11,12]]
targets = [0, 1, 0, 1]
weights = [0.1, 0.2, 0.3, 0.4]

# 创建WeightedRandomSampler对象
sampler = WeightedRandomSampler(weights, num_samples=len(data), replacement=True)

# 创建DataLoader对象
dataset = CustomDataset(data, targets)
dataloader = DataLoader(dataset, batch_size=2, sampler=sampler)

# 遍历数据集
for batch_data, batch_targets in dataloader:
    print(batch_data, batch_targets)

在上面的示例中,定义了一个简单的数据集CustomDataset,包含4个样本。其中weights列表表示每个样本的权重,num_samples参数设为数据集的样本数量,replacement参数设为True。创建WeightedRandomSampler对象后,将其作为DataLoader的参数之一传入,从而实现根据样本权重进行随机采样。

需要注意的是,由于使用了WeightedRandomSampler进行采样,每次迭代得到的样本数量可能会少于batch_size的值。