使用tensorboard_logger可视化模型的训练曲线
发布时间:2024-01-14 07:23:05
Tensorboard是TensorFlow官方提供的一个强大的可视化工具,可以帮助开发者更好地理解和调试模型。在Tensorboard中,tensorboard_logger是一个基于Python的轻量级库,可以将训练模型的曲线数据以实时的方式显示在Tensorboard中。
首先,我们需要安装tensorboard_logger库,可以使用以下命令进行安装:
pip install tensorboard_logger
下面是一个使用tensorboard_logger可视化训练模型的例子。
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from tensorboard_logger import configure, log_value
# 定义网络模型
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.fc1 = nn.Linear(784, 100)
self.fc2 = nn.Linear(100, 10)
def forward(self, x):
x = x.view(-1, 784)
x = self.fc1(x)
x = torch.relu(x)
x = self.fc2(x)
return x
# 加载MNIST数据集
transform = transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))])
train_dataset = datasets.MNIST('data/', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST('data/', train=False, download=True, transform=transform)
train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=64, shuffle=False)
# 配置tensorboard_logger,设置日志保存路径
configure("logs")
# 定义训练函数
def train(model, dataloader, criterion, optimizer, epoch):
model.train()
for batch_idx, (data, target) in enumerate(dataloader):
optimizer.zero_grad()
output = model(data)
loss = criterion(output, target)
loss.backward()
optimizer.step()
# 记录训练过程中的损失值
log_value('loss', loss.item(), step=epoch * len(dataloader) + batch_idx)
# 创建网络模型、损失函数和优化器
model = Net()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
# 训练模型,并使用tensorboard_logger记录训练过程的loss值
for epoch in range(10):
train(model, train_dataloader, criterion, optimizer, epoch)
# 在Tensorboard中查看训练过程中的损失值曲线
在使用tensorboard_logger之前,首先需要使用configure函数配置Tensorboard日志保存路径。然后,在每个训练步骤中,使用log_value函数将损失值记录到Tensorboard中。在上面的例子中,我们记录了每个batch的损失值,并以step为x轴,loss为y轴绘制曲线。
最后,可以使用以下命令启动Tensorboard来查看训练过程中的损失值曲线:
tensorboard --logdir=logs
通过浏览器访问http://localhost:6006,即可在Tensorboard中查看训练过程中的损失值曲线。
除了损失值曲线,tensorboard_logger还可以记录和可视化其他的训练指标,比如准确率、学习率等。通过使用log_value函数记录这些指标,可以更直观地了解模型的训练过程。
