欢迎访问宙启技术站
智能推送

在Python中使用checkpoint()函数实现断点续训

发布时间:2023-12-23 22:47:40

在Python中,可以使用checkpoint()函数来实现断点续训。checkpoint()函数的作用是将模型的参数保存在磁盘上,以便在后续训练中恢复模型的状态。

下面是一个简单的例子,演示了如何在PyTorch中使用checkpoint()函数实现断点续训:

import torch
import torch.nn as nn
import torch.optim as optim

# 定义模型
class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.fc = nn.Linear(10, 1)

    def forward(self, x):
        return self.fc(x)

# 定义数据
inputs = torch.randn(100, 10)
targets = torch.randn(100, 1)

# 创建模型实例和优化器
model = MyModel()
optimizer = optim.SGD(model.parameters(), lr=0.1)

# 定义断点文件路径
checkpoint_path = 'checkpoint.pth'

# 检查是否存在断点文件
if torch.cuda.is_available() and torch.cuda.device_count() > 1:
    device = torch.device('cuda')
    model = nn.DataParallel(model)
else:
    device = torch.device('cpu')

if os.path.exists(checkpoint_path):
    # 加载断点文件
    checkpoint = torch.load(checkpoint_path)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    start_epoch = checkpoint['epoch']
    loss = checkpoint['loss']
    print(f'Resuming training from epoch {start_epoch} with loss {loss:.4f}')
else:
    start_epoch = 0

model = model.to(device)
criterion = nn.MSELoss()

# 训练循环
for epoch in range(start_epoch, 1000):
    # 前向传播
    outputs = model(inputs)
    loss = criterion(outputs, targets)

    # 反向传播和优化
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # 打印训练进度
    if (epoch+1) % 100 == 0:
        print(f'Epoch [{epoch+1}/{1000}], Loss: {loss.item():.4f}')

    # 保存断点文件
    if (epoch+1) % 500 == 0:
        checkpoint = {
            'epoch': epoch+1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss.item()
        }
        torch.save(checkpoint, checkpoint_path)

print('Finished training')

在这个例子中,首先定义了一个简单的线性模型 MyModel,然后生成了随机输入和目标数据。接着创建了模型实例和SGD优化器。checkpoint文件的保存路径定义为 checkpoint.pth

然后检查是否存在断点文件。如果存在,则加载断点文件中的模型和优化器状态,并从上次的epoch和loss开始训练。如果不存在,则将 start_epoch 设置为0,从头开始训练。

接下来,将模型和优化器移动到指定设备上(如果可用的话),并定义了损失函数。

进入训练循环,进行前向传播、反向传播和优化。每100个epoch打印一次训练进度。每500个epoch保存一次断点文件,其中包含当前epoch、模型状态、优化器状态和损失值。

最后,当训练完成后,输出完成的消息。

这个例子演示了如何在PyTorch中使用checkpoint()函数实现断点续训。通过保存和加载模型的参数,可以在训练过程中中断和恢复,节省时间和资源。