save_checkpoint()函数如何在python中用于保存训练模型参数
发布时间:2023-12-30 13:29:38
在Python中,可以使用torch.save()函数来保存训练模型的参数。torch.save()函数将模型的状态保存到磁盘上,它会自动选择 的方式来保存模型参数,包括支持GPU/CPU的保存。
下面是一个保存训练模型参数的示例:
import torch
import torch.nn as nn
import torch.optim as optim
# 定义一个简单的模型
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.fc = nn.Linear(10, 3)
def forward(self, x):
x = self.fc(x)
return x
# 初始化模型和优化器
model = Net()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
# 训练模型,更新参数
for epoch in range(10):
# 前向传播和计算损失
output = model(torch.randn(1, 10))
loss = output.sum()
# 反向传播和更新参数
optimizer.zero_grad()
loss.backward()
optimizer.step()
# 保存训练模型的参数
torch.save(model.state_dict(), 'model.pth')
在上面的示例中,我们首先定义了一个简单的模型 Net,该模型包含一个全连接层。然后通过使用 optim.SGD 定义了一个优化器 optimizer。在训练循环中,我们使用随机生成的数据进行前向传播、计算损失、反向传播并更新模型的参数。最后,使用 torch.save() 函数将模型的参数保存到 model.pth 文件中。
当需要加载之前保存的模型参数时,可以使用torch.load()函数来加载模型参数。下面是一个加载模型参数的示例:
import torch
import torch.nn as nn
# 定义一个简单的模型
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.fc = nn.Linear(10, 3)
def forward(self, x):
x = self.fc(x)
return x
# 初始化模型
model = Net()
# 加载模型的参数
model.load_state_dict(torch.load('model.pth'))
在上面的示例中,我们首先定义了一个与之前保存的模型一样的模型 Net。然后使用 torch.load() 函数加载之前保存的模型参数,并使用 model.load_state_dict() 方法将加载的参数加载到模型中。
通过保存和加载模型参数,可以轻松地保存和恢复训练后的模型状态,并在需要的时候重新使用它们。这对于在不同时间点保存和加载模型参数,进行模型迁移、调整和继续训练都非常有用。
