欢迎访问宙启技术站
智能推送

PyTorch中torch.utils.data.sampler的无放回和有放回采样实现方法

发布时间:2023-12-19 05:24:24

在PyTorch中,可以使用torch.utils.data.sampler模块来进行数据采样,包括无放回采样和有放回采样。下面将分别介绍两种采样方法以及它们的实现方式,并附上相应的使用例子。

1. 无放回采样(无重复采样):

无放回采样是指每次采样时都会从数据集中移除已被采样的样本,确保每个样本只会被采样一次。torch.utils.data.sampler模块中提供了RandomSampler和SequentialSampler两个类来实现无放回采样。

- RandomSampler:随机采样器,每次采样时都会随机选择一个样本。

from torch.utils.data import Dataset, DataLoader
from torch.utils.data.sampler import RandomSampler

# 自定义数据集类
class MyDataset(Dataset):
    def __init__(self, data):
        self.data = data
    
    def __getitem__(self, index):
        return self.data[index]
    
    def __len__(self):
        return len(self.data)

data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]  # 假设有10个样本
dataset = MyDataset(data)
sampler = RandomSampler(dataset)  # 创建随机采样器

dataloader = DataLoader(dataset, batch_size=2, sampler=sampler)  # 使用随机采样器创建数据加载器

for batch in dataloader:
    print(batch)

- SequentialSampler:顺序采样器,每次采样时按照顺序选择样本。

from torch.utils.data import Dataset, DataLoader
from torch.utils.data.sampler import SequentialSampler

# 自定义数据集类
class MyDataset(Dataset):
    def __init__(self, data):
        self.data = data
    
    def __getitem__(self, index):
        return self.data[index]
    
    def __len__(self):
        return len(self.data)

data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]  # 假设有10个样本
dataset = MyDataset(data)
sampler = SequentialSampler(dataset)  # 创建顺序采样器

dataloader = DataLoader(dataset, batch_size=2, sampler=sampler)  # 使用顺序采样器创建数据加载器

for batch in dataloader:
    print(batch)

2. 有放回采样(重复采样):

有放回采样是指每次采样时都会将样本放回数据集,从而允许一个样本多次被采样。torch.utils.data.sampler模块中的RandomSampler类可以实现有放回采样。

- RandomSampler:随机采样器,每次采样时都会随机选择一个样本,但可能会选择到已经被采样的样本。

from torch.utils.data import Dataset, DataLoader
from torch.utils.data.sampler import RandomSampler

# 自定义数据集类
class MyDataset(Dataset):
    def __init__(self, data):
        self.data = data
    
    def __getitem__(self, index):
        return self.data[index]
    
    def __len__(self):
        return len(self.data)

data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]  # 假设有10个样本
dataset = MyDataset(data)
sampler = RandomSampler(dataset, replacement=True)  # 创建有放回采样器

dataloader = DataLoader(dataset, batch_size=2, sampler=sampler)  # 使用有放回采样器创建数据加载器

for batch in dataloader:
    print(batch)

通过以上例子可以看到,无放回采样和有放回采样在创建Dataset和DataLoader时均通过sampler参数传入相应的采样器对象。无放回采样的RandomSampler和SequentialSampler分别适用于需要无重复采样的情况,而有放回采样的RandomSampler适用于需要允许重复采样的情况。根据具体需求可以选择适当的采样方法处理数据集。