欢迎访问宙启技术站
智能推送

使用torch.utils.data.dataloader进行数据随机采样的示例代码

发布时间:2023-12-27 18:08:52

torch.utils.data.dataloader是PyTorch中用于加载数据的一个工具类,可以方便地将数据加载到模型中进行训练。其中的随机采样功能可以用于每个epoch从数据集中随机选择一定数量的样本进行训练。

下面是使用torch.utils.data.dataloader进行数据随机采样的示例代码:

import torch
from torch.utils.data import Dataset, DataLoader

# 自定义数据集类,继承自torch.utils.data.Dataset
class MyDataset(Dataset):
    def __init__(self):
        # 数据集初始化
        self.data = [i for i in range(1000)]
    
    def __len__(self):
        # 返回数据集的大小
        return len(self.data)
    
    def __getitem__(self, index):
        # 将数据集的第index个样本返回
        return self.data[index]

# 创建自定义数据集对象
dataset = MyDataset()

# 创建数据加载器对象,设置batch_size为32,shuffle为True进行随机采样
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

# 遍历每个epoch的数据
for epoch in range(10):
    # 遍历每个batch的数据
    for data in dataloader:
        # 在这里进行模型训练的操作
        # data是一个batch的数据,可以直接作为模型的输入进行训练
        pass

在以上示例代码中,首先定义了一个自定义数据集类MyDataset,继承自torch.utils.data.Dataset。其中,在__init__方法中初始化了一个包含1000个元素的数据集,__len__方法返回数据集的大小,__getitem__方法用于通过索引从数据集中获取样本。

然后通过创建DataLoader对象,将自定义数据集dataset传入其中。设置batch_size为32,表示每次从数据集中选择32个样本作为一个batch。设置shuffle=True表示每个epoch开始时都对数据进行随机重排,以实现随机采样。

最后通过嵌套的for循环进行模型的训练。外层的for循环表示遍历每个epoch的数据,内层的for循环表示遍历每个batch的数据。在内层循环中,data变量会被赋值为一个batch的数据,可以直接作为模型的输入进行训练。

使用torch.utils.data.dataloader进行数据随机采样的示例代码如上所示,可以根据实际需求进行修改和扩展。