使用torch.utils.checkpoint减少模型内存占用
发布时间:2024-01-05 01:16:21
torch.utils.checkpoint是一个用于减少模型内存占用的工具函数,可以在模型训练或推理过程中对计算进行checkpoint,从而减少显存的使用。通过将计算过程分成多个小块,每个小块只保留当前所需的中间结果,其他中间结果可以被释放掉,从而节省内存。
下面是一个使用torch.utils.checkpoint的示例,使用一个简单的神经网络对MNIST数据集进行训练:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
# 定义简单的神经网络模型
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(1, 32, 3, padding=1)
self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
self.fc1 = nn.Linear(64 * 7 * 7, 128)
self.fc2 = nn.Linear(128, 10)
def forward(self, x):
x = nn.functional.relu(self.conv1(x))
x = nn.functional.max_pool2d(x, 2)
# 添加checkpoint
x = torch.utils.checkpoint.checkpoint(self.conv2, x)
x = nn.functional.max_pool2d(x, 2)
x = x.view(-1, 64 * 7 * 7)
x = nn.functional.relu(self.fc1(x))
x = self.fc2(x)
return x
# 加载MNIST数据集
train_dataset = MNIST('./data', train=True, download=True, transform=ToTensor())
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
# 创建模型和优化器
model = Net()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
# 训练模型
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
for epoch in range(5):
for inputs, labels in train_loader:
inputs = inputs.to(device)
labels = labels.to(device)
optimizer.zero_grad()
outputs = model(inputs)
loss = nn.functional.cross_entropy(outputs, labels)
loss.backward()
optimizer.step()
print(f'Epoch [{epoch+1}/5], Loss: {loss.item()}')
在上面的代码中,我们定义了一个简单的神经网络模型Net,并在模型中的某些关键层之间加入了checkpoint。在forward函数中,我们首先使用nn.functional.relu和nn.functional.max_pool2d进行了一系列的操作,然后使用torch.utils.checkpoint.checkpoint函数对conv2层进行了checkpoint,即在该层进行计算时保存中间结果。
在训练过程中,我们使用train_loader进行数据加载和循环,将输入数据和标签数据移动到合适的设备上。然后,我们使用optimizer.zero_grad()将模型参数的梯度置为0,调用model(inputs)得到模型的输出结果,计算loss,并调用loss.backward()进行反向传播和梯度更新。
在每个epoch的末尾,我们打印出当前epoch的loss值。
这样,通过checkpoint,我们可以使用较少的显存来训练模型,特别是当模型较大或显存较小的情况下,checkpoint可以帮助我们更高效地利用有限的显存资源。
