PyTorch 神经网络模型可视化(Netron)
Netron 是一个用于可视化深度学习模型的工具,可以帮助我们更好地理解模型的结构和参数。
支持以下格式的模型存储文件:
GitHub 链接:https://github.com/lutzroeder/netron
ONNX
(1)在 PyTorch 中,可以使用 torch.onnx.export
函数将模型导出为 ONNX 格式:
import torch import netron # 定义 PyTorch 模型 class MyModel(torch.nn.Module): def __init__(self): super(MyModel, self).__init__() self.conv = torch.nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1) self.bn = torch.nn.BatchNorm2d(64) self.relu = torch.nn.ReLU(inplace=True) self.pool = torch.nn.MaxPool2d(kernel_size=2, stride=2) self.fc = torch.nn.Linear(64 * 8 * 8, 10) def forward(self, x): x = self.conv(x) x = self.bn(x) x = self.relu(x) x = self.pool(x) x = x.view(-1, 64 * 8 * 8) x = self.fc(x) return x # 创建模型实例并加载预训练权重 model = MyModel() # 设置示例输入 input = torch.randn(1, 3, 32, 32) # 将模型导出为 ONNX 格式 torch.onnx.export(model, input, './model/Test/onnx_model.onnx') # 导出后 netron.start(path) 打开
(2)再使用 Netron 的 netron.start
指令打开导出的 ONNX 模型文件:
import netron # 打开导出的 ONNX 模型文件 netron.start('./model/Test/onnx_model.onnx')
Serving './model/Test/onnx_model.onnx' at http://localhost:8080
将在浏览器中自动启动 Netron 工具,并对该模型文件进行可视化。
注意:
当模型被导出为 ONNX 格式,会在指定目录生成以 .onnx
为后缀的文件,只需将其上传至 Netron 官网 也可实现可视化:
在 Netron 中,可以查看模型的结构、参数和输入输出等信息。可以通过缩放、旋转和平移等操作来调整模型的可视化效果,以更好地理解模型的结构和参数。
torch.save
当使用 torch.save
对保存的模型进行可视化时:
# 保存模型 torch.save(model.state_dict(), './model/Test/saved_model.pt') # 可视化 netron.start('./model/Test/saved_model.pt')
如下图,这种方式并不能显示该模型的详细信息:
所以: Netron 不支持 PyTorch 通过 torch.save
方式导出的模型文件。
torch.jit.script
使用 torch.jit.script
先将模型转换为脚本,再使用 torch.jit.save
保存模型,最后进行可视化:
# TorchScript:script scripted_model = torch.jit.script(model) # 保存模型 torch.jit.save(scripted_model, './model/Test/scripted_model.pth') # 可视化 netron.start('./model/Test/scripted_model.pth')
torch.jit.trace
使用 torch.jit.trace
先将模型转换为跟踪模型执行的工具,再使用 torch.jit.save
保存模型,最后进行可视化:
# TorchScript:trace traced_model = torch.jit.trace(model, torch.randn(1, 3, 32, 32)) # 保存模型 torch.jit.save(traced_model, './model/Test/traced_model.pth') # 可视化 netron.start('./model/Test/traced_model.pth')