TorchExplorer是一款创新的人工智能工具,专为使用非常规神经网络架构的研究人员设计。可以在本地或者wandb中生成交互式Vega自定义图表,提供网络结构的模块级可视化。在左边的面板可以模块级方式展现神经网络架构,帮助研究人员导航网络结构。在右边的图中节点表示输入/输出占位符或在转发过程中调用的特定子模块,可以深入检查模块,直方图可视化数据分布。
节点之间的边缘表示数据处理流向,并且提供对输入/输出张量,梯度规范和参数梯度的信息。最主要的是它擅长处理非标准网络架构,这样我们看代码就方便多了,以下是官网的一个演示gif
TorchExplorer需要graphviz,所以先安装graphviz
sudo apt-get install libgraphviz-dev graphviz
pip install torchexplorer
然后就可以使用了:
import torch
import torchvision
import torchexplorer
model = torchvision.models.resnet18(pretrained=False)
dummy_X = torch.randn(5, 3, 32, 32)
# Only log input/output and parameter histograms, if you don't want even these set log=[].
torchexplorer.watch(model, log_freq=1, log=['io', 'params'], backend='standalone')
# Do one forwards and backwards pass
model(dummy_X).sum().backward()
# Your model will be available at http://localhost:5000
这里需要注意的是,需要一个完整的前向和反向传播的过程,这样他才可以得到需要的信息,结果如下:
我实验了一下,这对我们看模型代码来说是一个非常好的工具,它可以让我们更深入的了解模型的架构和工作方式,推荐大家试一试,项目地址: