torch.utils.data.sampler中的RandomSampler与SequentialSampler的区别
torch.utils.data.sampler中的RandomSampler和SequentialSampler是用于数据集采样的两个常用类。
RandomSampler是随机采样器,它可以在数据集中随机选取样本。使用RandomSampler时,每次迭代都会随机选择一个样本,因此可以用于在训练过程中对数据进行随机打乱。例如,假设有一个包含100个样本的数据集,我们可以使用RandomSampler对其进行随机采样,并将其用于训练一个分类模型。
下面是一个使用RandomSampler的例子:
import torch
from torch.utils.data import DataLoader, RandomSampler, Dataset
# 创建一个示例数据集类
class MyDataset(Dataset):
def __init__(self, data):
self.data = data
def __len__(self):
return len(self.data)
def __getitem__(self, index):
return self.data[index]
# 创建示例数据集
data = list(range(100))
dataset = MyDataset(data)
# 使用RandomSampler对数据集进行随机采样
random_sampler = RandomSampler(dataset)
# 创建数据加载器
dataloader = DataLoader(dataset, sampler=random_sampler, batch_size=10)
# 遍历数据加载器
for batch in dataloader:
print(batch)
在上面的例子中,我们首先定义了一个自定义数据集类MyDataset,然后创建了一个包含数字0到99的数据集。我们使用RandomSampler对数据集进行随机采样,并创建了一个数据加载器。在遍历数据加载器时,每个batch中的样本都是随机选择的。
SequentialSampler是顺序采样器,它按照数据集的顺序依次选取样本。使用SequentialSampler时,每次迭代都会从数据集的 个样本开始选择,直到遍历完所有样本。SequentialSampler可以用于在测试阶段对数据进行顺序采样,以便按照固定顺序对模型进行评估。
下面是一个使用SequentialSampler的例子:
import torch
from torch.utils.data import DataLoader, SequentialSampler, Dataset
# 创建一个示例数据集类
class MyDataset(Dataset):
def __init__(self, data):
self.data = data
def __len__(self):
return len(self.data)
def __getitem__(self, index):
return self.data[index]
# 创建示例数据集
data = list(range(100))
dataset = MyDataset(data)
# 使用SequentialSampler对数据集进行顺序采样
sequential_sampler = SequentialSampler(dataset)
# 创建数据加载器
dataloader = DataLoader(dataset, sampler=sequential_sampler, batch_size=10)
# 遍历数据加载器
for batch in dataloader:
print(batch)
在上面的例子中,我们使用SequentialSampler对数据集进行顺序采样,并创建了一个数据加载器。在遍历数据加载器时,每个batch中的样本都是按照数据集的顺序依次选择的。
综上所述,RandomSampler和SequentialSampler是两个常用的数据采样器,RandomSampler用于随机采样,SequentialSampler用于顺序采样。这两个采样器在数据集的遍历方式上有所不同,具体使用哪种采样器取决于任务需求。
