Python中WeightedRandomSampler()函数的详细介绍及用法示例
发布时间:2023-12-29 11:08:15
WeightedRandomSampler()是PyTorch中的一个采样器(sampler)类,用于实现根据样本权重进行随机采样。
在机器学习中,样本的权重是指样本被选中的概率,通常根据样本的重要性和分布情况来确定。当数据集中某些类别的样本数量较少时,可以使用权重采样来平衡各个类别的样本。
WeightedRandomSampler()的用法如下:
torch.utils.data.WeightedRandomSampler(weights, num_samples, replacement=True)
参数说明:
- weights: 一个列表或张量,大小为样本总数,表示每个样本的权重。
- num_samples: 要采样的样本数量。
- replacement: 是否允许重复采样,默认为True,即允许重复采样。
WeightedRandomSampler()的返回值为一个采样器对象,可以作为DataLoader的参数传入,用于实现根据样本权重进行随机采样。
下面是一个使用WeightedRandomSampler()的示例:
import torch
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
# 定义一个简单的数据集
class CustomDataset(Dataset):
def __init__(self, data, targets):
self.data = data
self.targets = targets
def __getitem__(self, index):
x = self.data[index]
y = self.targets[index]
return x, y
def __len__(self):
return len(self.data)
# 创建数据集和样本权重
data = [[1,2,3], [4,5,6], [7,8,9], [10,11,12]]
targets = [0, 1, 0, 1]
weights = [0.1, 0.2, 0.3, 0.4]
# 创建WeightedRandomSampler对象
sampler = WeightedRandomSampler(weights, num_samples=len(data), replacement=True)
# 创建DataLoader对象
dataset = CustomDataset(data, targets)
dataloader = DataLoader(dataset, batch_size=2, sampler=sampler)
# 遍历数据集
for batch_data, batch_targets in dataloader:
print(batch_data, batch_targets)
在上面的示例中,定义了一个简单的数据集CustomDataset,包含4个样本。其中weights列表表示每个样本的权重,num_samples参数设为数据集的样本数量,replacement参数设为True。创建WeightedRandomSampler对象后,将其作为DataLoader的参数之一传入,从而实现根据样本权重进行随机采样。
需要注意的是,由于使用了WeightedRandomSampler进行采样,每次迭代得到的样本数量可能会少于batch_size的值。
