欢迎访问宙启技术站
智能推送

Pytorch网络结构可视化

发布时间:2023-05-16 12:26:58

PyTorch 是一款非常强大的深度学习框架,被广泛地应用于各种深度学习任务中。PyTorch 的优点不仅仅在于其强大的自动求导功能,还在于其方便易用的网络建模功能。在 PyTorch 中,我们可以用简单的代码即可创建复杂的神经网络,并且可以方便地进行可视化与调试。

网络结构可视化是深度学习中非常重要的一个环节,它能够帮助我们更好地理解和调试深度学习模型。在 PyTorch 中,我们可以使用第三方库来轻松地对神经网络进行可视化,常用的可视化库包括 TensorBoardX 和 Graphviz。本文将介绍如何通过这两个库来实现 PyTorch 网络结构可视化。

TensorBoardX 可视化

TensorBoardX 是 TensorFlow 中的一个可视化工具,它可以将 PyTorch 模型的结构、训练过程等信息可视化在 TensorBoard 中。为了方便使用 TensorBoardX,我们需要先安装它,可以通过 pip 指令进行安装:

安装好 TensorBoardX 后,我们可以按照如下流程来可视化我们的模型:

1. 定义模型结构

首先,我们需要定义我们的模型结构。假设我们有一个简单的网络结构,包含两个卷积层和两个全连接层,可以通过如下代码来定义:

import torch.nn as nn

class MyNet(nn.Module):
    def __init__(self):
        super(MyNet, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=1)
        self.fc1 = nn.Linear(in_features=32 * 32 * 32, out_features=256)
        self.fc2 = nn.Linear(in_features=256, out_features=10)

    def forward(self, x):
        x = self.conv1(x)
        x = nn.functional.relu(x)
        x = nn.functional.max_pool2d(x, kernel_size=2)
        x = self.conv2(x)
        x = nn.functional.relu(x)
        x = nn.functional.max_pool2d(x, kernel_size=2)
        x = x.view(-1, 32 * 32 * 32)
        x = self.fc1(x)
        x = nn.functional.relu(x)
        x = self.fc2(x)
        return x

这个网络结构包含了两个卷积层和两个全连接层,其中激活函数都是 relu 函数。在 forward 函数中,先后应用了两个卷积层和池化层,并将最后输出结果进行了 reshape 处理后接入两个全连接层。

2. 定义可视化

接下来,我们需要定义可视化。TensorBoardX 提供了两种可视化方式,一种是使用 SummaryWriter 对象进行可视化,另一种是将模型导入 TensorBoard 中。这里我们选择 种方式。我们可以先在代码开头导入 SummaryWriter:

from tensorboardX import SummaryWriter

然后,在训练过程之前,定义一个 SummaryWriter 对象:

writer = SummaryWriter('./logs')

3. 可视化模型结构

我们可以使用 add_graph 方法将模型结构可视化到 TensorBoard 中:

model = MyNet()
writer.add_graph(model, (torch.rand(1, 3, 32, 32), ))

其中, 个参数是模型,第二个参数是模型的输入,需要放到一个元组中,表示输入张量的形状。

4. 可视化参数分布

我们可以使用 add_histogram 方法将模型参数的分布可视化到 TensorBoard 中:

for name, param in model.named_parameters():
    writer.add_histogram(name, param.clone().cpu().data.numpy(), global_step=1)

上述代码中,name 参数表示参数的名称,param 参数表示参数本身,global_step 参数表示当前迭代的次数。需要注意的是,我们需要将参数的数据从 GPU 上的张量转移到 CPU 上的 numpy 数组。

5. 可视化损失值和准确率

我们可以使用 add_scalar 方法将每次迭代的损失值和准确率可视化到 TensorBoard 中:

writer.add_scalar('Loss', loss.item(), global_step=step)
writer.add_scalar('Accuracy', accuracy, global_step=step)

其中,Loss 表示损失值,accuracy 表示准确率,global_step 表示当前迭代的次数。

6. 启动 TensorBoard

所有的代码写完之后,我们需要在命令行中输入以下命令来启动 TensorBoard:

然后在浏览器里输入以下地址查看可视化结果:

Graphviz 可视化

Graphviz 是一款流程图可视化工具,可以将我们的模型结构以图像的形式展现出来。为了方便使用 Graphviz,我们需要先安装它,可以通过以下命令来安装:

安装好 Graphviz 之后,我们可以按照如下流程来可视化我们的模型:

1. 定义模型结构

我们可以使用上一节的模型结构来演示,这里就不再赘述。

2. 安装 pydot 和 graphviz

Graphviz 可视化需要使用 pydot 库和 graphviz 软件,我们需要用 pip 安装 pydot 库,和 linux 下用 apt-get 安装 graphviz 软件。

3. 可视化模型结构

在 PyTorch 中,我们可以使用 make_dot 方法将模型结构可视化为 Graphviz 图像:

首先在 beginning 导入:

from torch.autograd import Variable
from graphviz import Digraph
import torch.nn as nn
import torch

然后定义方法 make_dot(这个不需要详细解释,可以不用太关注具体实现,主要是使用前面章节的 make_dot_from_trace 函数)

`python

def make_dot(var, params=None):

""" Produces Graphviz representation of PyTorch autograd graph

Parameters

----------

var : torch.Tensor or torch.autograd.Variable

params : dict

Dict of (name, Variable) to add names to node that require gradident

"""

if params is not None:

#assert False

param_map = {id(v): k for k, v in params.items()}

node_attr = dict(style='filled',

shape='box',

align='left',

fontsize='12',

ranksep='0.1',

height='0.2')

dot = Digraph(node_attr=node_attr, graph_attr=dict(size="12,12"), format='png')

seen = set()

def size_to_str(size):

return '('+(', ').join(map(str, size))+')'

def add_nodes(var):

if var not in seen:

if torch.is_tensor(var):

dot.node(str(id(var)), size_to_str(var.size()), fillcolor='lightblue')

elif hasattr(var, 'variable'):

u = var.variable

name = param_map[id(u)] if params is not None else ''

node_name = '{}

{}'.format(name, size_to_str(u.size()))

dot.node(str(id(var)), node_name, fillcolor='lightblue')

elif var._is_view():

dot.node(str(id(var)), str(type(var).__name__), fillcolor='orange')

else:

dot.node(str(id(var)), str(type(var).__name__))

seen.add(var)

if hasattr(var, 'next_functions'):

for u in var.next_functions:

if u[0] is not None:

dot.edge(str(id(u[0])), str(id(var)))

add_nodes(u[0])

if hasattr(var, 'saved_tensors'):

for t in var.saved_tensors: