PlotNeuralNet 是 github 上做神经网络可视化的一个工具,利用 python 将 .py 文件中定义的网络结构转换成 .tex 文件,最后通过 TeXworks 等工具可以将其转换为 .pdf 等形式来显示网络结构。
点击 http://www.tug.org/texlive/ 进入下载链接,选择 on DVD。
选择 downloading the TeX Live ISO image and burning your own DVD。
选择 downloading from a nearby CTAN mirror。
下载得到一个 iso 映像文件,双击打开,再双击 install-tl-windows.bat 文件进入安装向导,注意修改安装路径,勾选安装 TeXworks 前端选项,然后就是漫长的等待安装过程。
github 链接:https://github.com/HarisIqbal88/PlotNeuralNet
用 pycharm 或者别的 python 编辑器打开 PlotNeuralNet 工程,进入到 pyexamples 路径下,打开 unet.py 文件直接运行。此时会在当前路径下生成一个 unet.tex 文件。
打开 TeXworks editor 软件,选择文件打开刚开路径下的 unet.tex 文件。
点击绿色箭头运行,会直接生成并打开一个 unet.pdf 文件,如果显示不完整可以点击适应页面或者将窗口放大,结果如下。
可以看出输出的网络图非常好看,每一层的 H,W,C 也都是自己可以定义的。
PlotNeuralNet 的一个弊端就是它是用户利用代码来直接生成并定义的,操作难度较高,其实也不大实用,但在制作论文插图的时候是比较好的。下来通过一个自定义的测试代码来简单介绍使用方法(类似 U-Net):
import sys
sys.path.append('../')
from pycore.tikzeng import *
# 自定义一个网络结构
arch = [
to_head('..'),
to_cor(),
to_begin(),
# 写入网络结构,其他代码不要改动
to_Conv("conv1", 64, 3, offset="(0,0,0)", to="(0,0,0)", height=128, depth=128, width=3),
# 参数列表:名字;左下角数字大小和厚度;相对于上一幅图的 x,y,z 坐标;对应放在 conv1 的 东(右边)边;高度;宽度;厚度
to_Conv("conv2", 64, 32, offset="(5,0,0)", to="(conv1-east)", height=64, depth=64, width=32),
to_Conv("conv3", 64, 32, offset="(5,0,0)", to="(conv2-east)", height=32, depth=32, width=64),
to_Conv("conv4", 64, 32, offset="(5,0,0)", to="(conv3-east)", height=16, depth=16, width=128),
to_Conv("conv5", 64, 32, offset="(5,0,0)", to="(conv4-east)", height=16, depth=16, width=128),
to_Conv("conv6", 64, 32, offset="(5,0,0)", to="(conv5-east)", height=32, depth=32, width=64),
to_Conv("conv7", 64, 32, offset="(5,0,0)", to="(conv6-east)", height=64, depth=64, width=32),
to_Conv("conv8", 64, 3, offset="(5,0,0)", to="(conv7-east)", height=128, depth=128, width=3),
# 箭头连接
to_connection("conv1", "conv2"),
to_connection("conv2", "conv3"),
to_connection("conv3", "conv4"),
to_connection("conv4", "conv5"),
to_connection("conv5", "conv6"),
to_connection("conv6", "conv7"),
to_connection("conv7", "conv8"),
# 跳跃连接
to_skip("conv2", "conv7"),
to_end()
]
def main():
namefile = str(sys.argv[0]).split('.')[0]
to_generate(arch, namefile + '.tex')
if __name__ == '__main__':
main()
运行结果: