欢迎访问宙启技术站
智能推送

使用torch.optim.lr_scheduler调整学习率的步骤

发布时间:2024-01-19 21:28:25

torch.optim.lr_scheduler是PyTorch中用于调整学习率的工具。它提供了多种学习率调整策略,包括StepLR、MultiStepLR、ExponentialLR、CosineAnnealingLR等。在本文中,我将为你介绍使用torch.optim.lr_scheduler调整学习率的步骤,并提供一个具体的例子来说明其用法。

使用torch.optim.lr_scheduler调整学习率的步骤如下:

1. 导入必要的包

   import torch
   import torch.optim as optim
   import torch.optim.lr_scheduler as lr_scheduler
   

2. 初始化优化器和学习率调度器

   optimizer = optim.SGD(model.parameters(), lr=0.1)
   scheduler = lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
   

这里使用了SGD优化器并设置了初始学习率为0.1,同时使用了StepLR学习率调度器,其中step_size表示每经过多少个epoch将学习率调整为初始学习率的gamma倍。

3. 在训练循环中调用学习率调度器的step()方法

   for epoch in range(num_epochs):
       # 训练过程
       optimizer.zero_grad()
       loss.backward()
       optimizer.step()
       
       # 调整学习率
       scheduler.step()
   

在训练过程中,必须在每个epoch结束后调用学习率调度器的step()方法,用于根据设定的调整策略更新学习率。

下面是使用torch.optim.lr_scheduler的一个具体例子,以说明其用法:

import torch
import torch.nn as nn
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler

# 定义一个简单的网络模型
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc = nn.Linear(10, 1)
    
    def forward(self, x):
        return self.fc(x)

# 初始化数据和模型
data = torch.randn(100, 10)
target = torch.randn(100, 1)
model = Net()

# 初始化优化器和学习率调度器
optimizer = optim.SGD(model.parameters(), lr=0.1)
scheduler = lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.1)

# 训练循环
num_epochs = 100
for epoch in range(num_epochs):
    # 前向传播
    output = model(data)
    loss = nn.MSELoss()(output, target)
    
    # 反向传播和优化
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    
    # 调整学习率
    scheduler.step()
    
    # 打印学习率和损失
    print('Epoch [{}/{}], Learning rate: {:.4f}, Loss: {:.4f}'.format(epoch+1, num_epochs, scheduler.get_lr()[0], loss.item()))

在上述例子中,我们定义了一个简单的网络模型,使用了MSELoss作为损失函数,优化器选择了SGD,并将学习率初始化为0.1,学习率调度器使用了StepLR策略,每经过20个epoch学习率减小为原来的0.1倍。在训练循环中,我们先进行正向传播和反向传播,然后调用optimizer.step()来更新模型参数,最后调用scheduler.step()来更新学习率。同时,我们还打印了每个epoch的学习率和损失值。

通过使用torch.optim.lr_scheduler,我们可以方便地实现对学习率的动态调整,从而提高模型的训练效果。