使用torch.utils.serialization保存和加载训练好的PyTorch模型
发布时间:2024-01-10 08:00:19
在PyTorch中,可以使用torch.utils.serialization来保存和加载训练好的模型。torch.utils.serialization提供了两个主要函数:torch.save和torch.load,分别用于保存和加载模型。
下面是一个完整的例子来保存和加载一个训练好的模型:
import torch
import torch.nn as nn
import torch.optim as optim
# 定义一个简单的神经网络模型
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.fc = nn.Linear(10, 2)
def forward(self, x):
x = self.fc(x)
return x
# 创建一个模型实例
model = Net()
# 实例化优化器和损失函数
optimizer = optim.SGD(model.parameters(), lr=0.01)
criterion = nn.MSELoss()
# 训练模型
for epoch in range(10):
# 模拟输入和目标张量
input_data = torch.randn(64, 10)
target = torch.randn(64, 2)
# 前向传播
output = model(input_data)
# 计算损失
loss = criterion(output, target)
# 反向传播和优化
optimizer.zero_grad()
loss.backward()
optimizer.step()
# 保存模型
torch.save(model.state_dict(), 'model.pth')
# 加载模型
model = Net()
model.load_state_dict(torch.load('model.pth'))
在上述例子中,我们首先定义了一个简单的神经网络模型,其中包含一个全连接层。然后,我们创建了一个模型实例以及一个优化器和损失函数。接下来,我们使用随机数据训练了模型,并保存了模型的状态字典(model.state_dict())到名为'model.pth'的文件中。
最后,我们使用torch.load加载了保存的模型,并将其赋值给另一个模型实例。这样我们就成功地加载了训练好的模型。
需要注意的是,加载模型时,要确保模型结构和保存的模型状态字典的键匹配。因此,在加载模型之前, 先定义一个与原始模型结构相同的模型实例。在加载模型时,可以使用load_state_dict方法将保存的状态字典加载到新模型实例中。
通过使用torch.utils.serialization保存和加载模型,我们可以方便地在PyTorch中保存和重用训练好的模型。这对于避免每次重新训练模型以及在不同任务上迁移学习都非常有用。
