利用SummaryWriter()函数生成TensorBoard可视化结果
发布时间:2023-12-24 23:56:04
SummaryWriter()函数是PyTorch中的一个工具类,用于生成TensorBoard可视化结果。TensorBoard是TensorFlow提供的一个功能强大的可视化工具,可以用于可视化模型的图形结构、训练过程中的曲线图、计算图以及其他统计信息。在PyTorch中,我们可以使用SummaryWriter()函数将模型的训练过程以及其他统计信息保存为TensorBoard可视化结果。
下面是一个使用SummaryWriter()函数生成TensorBoard可视化结果的示例代码:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter
# 定义一个简单的神经网络
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.fc = nn.Linear(784, 10)
def forward(self, x):
x = x.view(-1, 784)
x = self.fc(x)
return x
# 加载MNIST数据集
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5), (0.5))
])
train_dataset = torchvision.datasets.MNIST(
root='./data', train=True, transform=transform, download=True)
train_loader = torch.utils.data.DataLoader(
dataset=train_dataset, batch_size=100, shuffle=True)
# 初始化模型和优化器
net = Net()
optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
# 创建SummaryWriter对象,指定保存路径
writer = SummaryWriter('./logs')
# 训练模型
total_step = len(train_loader)
for epoch in range(5):
for i, (images, labels) in enumerate(train_loader):
# 前向传播和计算损失
outputs = net(images)
loss = nn.functional.cross_entropy(outputs, labels)
# 反向传播和优化
optimizer.zero_grad()
loss.backward()
optimizer.step()
# 输出训练信息
if (i + 1) % 100 == 0:
print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'
.format(epoch + 1, 5, i + 1, total_step, loss.item()))
# 写入训练信息到TensorBoard
writer.add_scalar('Loss/train', loss.item(), epoch * total_step + i)
# 关闭SummaryWriter对象
writer.close()
在上述代码中,我们首先定义了一个简单的神经网络模型(包含一个全连接层),然后加载MNIST数据集,并将其通过DataLoader进行批处理。接下来,我们初始化模型和优化器,并创建一个SummaryWriter对象,指定保存路径为'./logs'。在训练过程中,我们使用for循环遍历数据集,进行前向传播、损失计算、反向传播和优化。每100个批次,我们将训练损失(loss)写入SummaryWriter对象中,使用writer.add_scalar()函数实现。最后,我们关闭SummaryWriter对象。
在运行上述代码后,可以在指定的保存路径下找到生成的TensorBoard可视化结果。使用命令'python -m tensorboard.main --logdir=./logs'打开TensorBoard服务,然后在浏览器中访问本地地址'http://localhost:6006'即可查看可视化结果。在TensorBoard中,可以看到训练损失的曲线图,以及其他统计信息。
