Python中的SubsetRandomSampler()函数及其在数据处理中的作用
发布时间:2024-01-11 23:00:19
SubsetRandomSampler()函数是PyTorch库中的一个采样器函数,用于在数据处理中生成随机无重复的子集采样器。
在机器学习和深度学习中,数据分为训练集和验证集,SubsetRandomSampler()函数用于生成训练集和验证集中的子集,以用于模型训练和验证。
SubsetRandomSampler(indices)函数接受一个indices参数,该参数是一个列表或NumPy数组,包含了需要创建子集的索引。函数会根据这些索引来生成相应的子集。
下面是一个使用SubsetRandomSampler()函数的例子:
import torch
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
# 自定义数据集
class CustomDataset(Dataset):
def __init__(self):
self.data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
def __getitem__(self, index):
return self.data[index]
def __len__(self):
return len(self.data)
# 创建数据集对象
dataset = CustomDataset()
# 创建索引列表
indices = list(range(len(dataset)))
# 设置训练集和验证集的分割比例
split = int(len(dataset) * 0.8)
# 随机打乱索引列表
np.random.shuffle(indices)
# 根据分割比例生成训练集和验证集的索引
train_indices, val_indices = indices[:split], indices[split:]
# 创建训练集和验证集的采样器
train_sampler = SubsetRandomSampler(train_indices)
val_sampler = SubsetRandomSampler(val_indices)
# 创建数据加载器
train_loader = DataLoader(dataset, batch_size=2, sampler=train_sampler)
val_loader = DataLoader(dataset, batch_size=2, sampler=val_sampler)
# 遍历训练集
for batch in train_loader:
print(batch)
# 遍历验证集
for batch in val_loader:
print(batch)
在上面的例子中,我创建了一个自定义数据集CustomDataset,定义了数据和数据长度。然后,我创建了一个包含全部索引的列表indices,并根据分割比例split将索引随机打乱。
接着,我使用SubsetRandomSampler()函数根据训练集和验证集的索引创建了两个采样器train_sampler和val_sampler。
最后,我使用DataLoader创建训练集和验证集的数据加载器train_loader和val_loader,并遍历它们以查看生成的子集。
