欢迎访问宙启技术站
智能推送

使用torch.utils.checkpoint减少模型内存占用

发布时间:2024-01-05 01:16:21

torch.utils.checkpoint是一个用于减少模型内存占用的工具函数,可以在模型训练或推理过程中对计算进行checkpoint,从而减少显存的使用。通过将计算过程分成多个小块,每个小块只保留当前所需的中间结果,其他中间结果可以被释放掉,从而节省内存。

下面是一个使用torch.utils.checkpoint的示例,使用一个简单的神经网络对MNIST数据集进行训练:

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor

# 定义简单的神经网络模型
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = nn.functional.relu(self.conv1(x))
        x = nn.functional.max_pool2d(x, 2)
        # 添加checkpoint
        x = torch.utils.checkpoint.checkpoint(self.conv2, x)
        x = nn.functional.max_pool2d(x, 2)
        x = x.view(-1, 64 * 7 * 7)
        x = nn.functional.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# 加载MNIST数据集
train_dataset = MNIST('./data', train=True, download=True, transform=ToTensor())
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)

# 创建模型和优化器
model = Net()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

# 训练模型
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
for epoch in range(5):
    for inputs, labels in train_loader:
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = nn.functional.cross_entropy(outputs, labels)
        loss.backward()
        optimizer.step()

    print(f'Epoch [{epoch+1}/5], Loss: {loss.item()}')

在上面的代码中,我们定义了一个简单的神经网络模型Net,并在模型中的某些关键层之间加入了checkpoint。在forward函数中,我们首先使用nn.functional.relu和nn.functional.max_pool2d进行了一系列的操作,然后使用torch.utils.checkpoint.checkpoint函数对conv2层进行了checkpoint,即在该层进行计算时保存中间结果。

在训练过程中,我们使用train_loader进行数据加载和循环,将输入数据和标签数据移动到合适的设备上。然后,我们使用optimizer.zero_grad()将模型参数的梯度置为0,调用model(inputs)得到模型的输出结果,计算loss,并调用loss.backward()进行反向传播和梯度更新。

在每个epoch的末尾,我们打印出当前epoch的loss值。

这样,通过checkpoint,我们可以使用较少的显存来训练模型,特别是当模型较大或显存较小的情况下,checkpoint可以帮助我们更高效地利用有限的显存资源。