欢迎访问宙启技术站
智能推送

使用Python的SummaryWriter()在训练过程中生成结果摘要

发布时间:2023-12-19 06:31:42

SummaryWriter 是 PyTorch 中的一个工具类,可以用来生成训练过程中的摘要和可视化结果。它提供了多种方法来记录和保存训练过程中的标量、图像、网络结构等信息,方便用户对训练过程进行可视化和分析。下面是一个使用 SummaryWriter 的例子,以说明其使用方法和功能。

首先,我们需要导入 torch 和 tensorboardX 库,tensorboardX 是一个与 TensorFlow 兼容的库,用于将 SummaryWriter 生成的摘要保存为 TensorBoard 可以读取的格式。

import torch
import tensorboardX
from tensorboardX import SummaryWriter

然后,我们定义一个简单的深度神经网络,用于分类 MNIST 数据集。这里我们使用一个二层的神经网络来进行演示。

class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = torch.nn.Linear(784, 256)
        self.fc2 = torch.nn.Linear(256, 10)
    
    def forward(self, x):
        x = x.view(-1, 784)
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

接下来,我们定义训练过程。在每个迭代中,我们将输入数据传入网络进行前向传播和反向传播,并更新网络的权重。同时,使用 SummaryWriter 记录训练过程中的标量和图像信息。

def train():
    writer = SummaryWriter()
    net = Net()
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(net.parameters(), lr=0.001)

    for epoch in range(10):
        for i, data in enumerate(trainloader):
            inputs, labels = data
            optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # 记录每个迭代的损失值和准确率
            writer.add_scalar('data/loss', loss.item(), epoch * len(trainloader) + i)

            if i % 100 == 0:
                # 保存输入图像和网络输出图像
                writer.add_image('Image', inputs[0], epoch * len(trainloader) + i)
                writer.add_graph(net, inputs)

    # 关闭 writer
    writer.close()

在上面的代码中,我们使用 writer.add_scalar() 记录了每个迭代的损失值,并使用 writer.add_image() 记录了训练集中的 张图像。另外,我们还使用 writer.add_graph() 记录了网络的结构,方便可视化。

最后,我们可以使用 TensorBoard 来查看记录的结果。首先,需要在命令行中切换到代码所在的目录,然后执行以下命令:

tensorboard --logdir=./runs

执行完成后,会显示一个链接,通常为 http://localhost:6006/,在浏览器中打开该链接,即可看到训练过程中记录的摘要。在 Scalars 标签页中,可以查看损失值随着迭代次数的变化情况;在 Images 标签页中,可以查看记录的图像;在 Graph 标签页中,可以查看网络的结构。

使用 SummaryWriter 可以方便地对训练过程进行可视化和分析,并通过 TensorBoard 在浏览器中查看记录的结果。这样可以更好地理解模型的训练进展,并提供反馈或改进模型的性能。