欢迎访问宙启技术站
智能推送

Python中的SubsetRandomSampler()函数及其在数据预处理中的应用

发布时间:2024-01-11 23:05:58

SubsetRandomSampler()函数是PyTorch库中的一个类,用于数据集的随机采样。

在数据预处理中,常常需要将数据集划分为训练集、验证集和测试集。SubsetRandomSampler()函数通过随机抽样的方式从数据集中选择子集,并将该子集用于训练、验证或测试。

下面是SubsetRandomSampler()函数的使用示例:

import torch
import numpy as np
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler

# 定义一个自定义的数据集类
class MyDataset(Dataset):
    def __init__(self, data):
        self.data = data
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, index):
        return self.data[index]

# 创建一个虚拟的数据集
data = np.arange(100)

# 创建数据集对象
dataset = MyDataset(data)

# 设置训练集、验证集和测试集的样本比例
train_ratio = 0.7
valid_ratio = 0.2
test_ratio = 0.1

# 计算对应的样本数量
train_size = int(train_ratio * len(dataset))
valid_size = int(valid_ratio * len(dataset))
test_size = len(dataset) - train_size - valid_size

# 创建索引列表
indices = np.arange(len(dataset))

# 随机打乱索引列表
np.random.shuffle(indices)

# 根据样本数量划分索引列表为训练集、验证集和测试集的索引
train_indices = indices[:train_size]
valid_indices = indices[train_size:train_size+valid_size]
test_indices = indices[train_size+valid_size:]

# 创建采样器对象
train_sampler = SubsetRandomSampler(train_indices)
valid_sampler = SubsetRandomSampler(valid_indices)
test_sampler = SubsetRandomSampler(test_indices)

# 创建数据加载器
train_loader = DataLoader(dataset, batch_size=10, sampler=train_sampler)
valid_loader = DataLoader(dataset, batch_size=10, sampler=valid_sampler)
test_loader = DataLoader(dataset, batch_size=10, sampler=test_sampler)

# 输出训练集、验证集和测试集的batch数据
for batch in train_loader:
    print("Training batch:", batch)

for batch in valid_loader:
    print("Validation batch:", batch)

for batch in test_loader:
    print("Testing batch:", batch)

在上述示例中,我们首先定义了一个自定义的数据集类MyDataset,然后创建了一个虚拟的数据集data,并基于虚拟数据集创建了数据集对象dataset

接下来,我们设置了训练集、验证集和测试集的样本比例,并计算了对应的样本数量。

然后,我们创建了一个索引列表indices,并对其进行随机打乱操作。

最后,根据样本数量划分索引列表为训练集、验证集和测试集的索引,并利用这些索引创建了对应的采样器对象train_samplervalid_samplertest_sampler

最后,我们创建了数据加载器对象train_loadervalid_loadertest_loader,并通过遍历来输出训练集、验证集和测试集的batch数据。

通过使用SubsetRandomSampler()函数,我们可以方便地实现对数据集的随机采样,并将其应用于数据预处理的过程中,例如划分训练集、验证集和测试集等操作。