模型结构图的可视化能直观地展示模型的结构以及各模块之间的关系。最近借助 PlotNeuralNet Python 库（Windows 版）绘制了一张网络结构图，在此记录一些经验和心得。
这个库贴心地提供了许多示例，可以借助它们理解代码的含义，再改写成自己想要表达的网络结构。
每种模块都有几个固定的参数，大多不难理解，重点需要理解其中两个：s_filer 和 n_filer。
例如输入为 (512, 512, 4) 时，s_filer 对应其中的两个 512（空间尺寸），在图形上由 height 和 depth 两个参数体现；n_filer 对应 4（通道数），在图形上由 width 参数体现。
需要注意的是，两者换算成图形尺寸时的比例不同：对 s_filer 而言，512 对应 40、256 对应 32、128 对应 25，依此类推；而对 n_filer 而言，512 对应的 width 约为 7，256 对应 6，依此类推。这样画出来的图形就是示例中那种“竖立的扁扁的方块”。
绘图部分解决之后，还需要获得网络各部分输入和输出的尺寸，这可以借助 Netron 网站及其 Python 库来实现，站内也有相关介绍。
这样，数据和绘图两部分都准备就绪，就可以开始绘图了。
附:生成文中图的代码
import sys
sys.path.append('../')
from pycore.tikzeng import *
from pycore.blocks import *
# Visual scale used for the boxes (hand-tuned, in PlotNeuralNet drawing units):
#   s_filer (spatial size)  -> height and depth: 512-40, 256-32, 128-25, 64-16
#   n_filer (channel count) -> width: 256-6, 128-5.5, 64-4.5, 32-3.5, 16-2.5, 4-1.5


def _aspp_block(suffix, source, s_filer, n_filer, width, size):
    """Return the layers of one ASPP-style module as a flat list.

    The module consists of:
      * four parallel conv branches ``conv1_<suffix>`` .. ``conv4_<suffix>``,
        each with a connection drawn from ``source``;
      * a sum node ``sum_<suffix>`` that all four branches connect into;
      * a final conv ``conv5_<suffix>`` connected from the sum node.

    The first branch is placed at offset (2,0,0) from ``source``'s east face;
    each further branch at offset (0,-1,0) from the previous branch's east
    face. Items are emitted in the same order as the hand-written version
    this replaces (conv, its connection, ..., sum, 4 connections, conv5,
    its connection).

    Args:
        suffix: name suffix shared by every node of this module.
        source: name of the existing node that feeds all four branches.
        s_filer: value passed as ``s_filer`` to every conv (spatial size).
        n_filer: value passed as ``n_filer`` to every conv (channel count).
        width: box width of every conv (visual stand-in for ``n_filer``).
        size: box height and depth of every conv (stand-in for ``s_filer``).

    Returns:
        A list of the objects returned by ``to_Conv`` / ``to_connection`` /
        ``to_Sum``, meant to be unpacked into ``arch`` with ``*``.
    """
    branches = ['conv{}_{}'.format(i, suffix) for i in range(1, 5)]
    sum_name = 'sum_{}'.format(suffix)
    fuse_name = 'conv5_{}'.format(suffix)
    conv_kwargs = dict(s_filer=s_filer, n_filer=n_filer, width=width,
                       height=size, depth=size)

    layers = []
    anchor, offset = source, "(2,0,0)"
    for branch in branches:
        layers.append(to_Conv(name=branch, offset=offset,
                              to="({}-east)".format(anchor), **conv_kwargs))
        layers.append(to_connection(source, branch))
        # Later branches hang off the previous branch, shifted by (0,-1,0).
        anchor, offset = branch, "(0,-1,0)"

    layers.append(to_Sum(name=sum_name, offset="(2,3,0)",
                         to="({}-east)".format(branches[-1]),
                         radius=3.5, opacity=0.6))
    layers.extend(to_connection(branch, sum_name) for branch in branches)

    layers.append(to_Conv(name=fuse_name, offset="(2,0,0)",
                          to="({}-east)".format(sum_name), **conv_kwargs))
    layers.append(to_connection(sum_name, fuse_name))
    return layers


arch = [
    to_head('..'),
    to_cor(),
    to_begin(),

    # Input image.
    to_input('demo.jpg', to='(-3,0,0)', name="input_b1"),

    # --- Encoder -----------------------------------------------------------
    # Block 1: stem conv pair at s_filer 512, followed by a pool.
    to_ConvConvRelu(name='ccr_b1', s_filer=512, n_filer=(16, 16),
                    offset="(0,0,0)", to="(0,0,0)", width=(2.5, 2.5),
                    height=40, depth=40, caption="Stem Block"),
    to_Pool(name="pool_b1", offset="(0,0,0)", to="(ccr_b1-east)", width=1,
            height=30, depth=30, opacity=0.5),
    # Blocks 2-4: s_filer halves and n_filer doubles at each step; every
    # block is passed offset (1,-8,0).
    *block_2ConvPool(name='b2', botton='pool_b1', top='pool_b2', s_filer=256,
                     n_filer=32, offset="(1,-8,0)", size=(32, 32, 3.5),
                     opacity=0.5),
    *block_2ConvPool(name='b3', botton='pool_b2', top='pool_b3', s_filer=128,
                     n_filer=64, offset="(1,-8,0)", size=(25, 25, 4.5),
                     opacity=0.5),
    *block_2ConvPool(name='b4', botton='pool_b3', top='pool_b4', s_filer=64,
                     n_filer=128, offset="(1,-8,0)", size=(16, 16, 5.5),
                     opacity=0.5),

    # --- Bottleneck (block 5) ----------------------------------------------
    # ASPP-style module fed by pool_b4.
    *_aspp_block('aspp', 'pool_b4', s_filer=64, n_filer=256, width=6, size=16),

    # --- Decoder -----------------------------------------------------------
    # Blocks 6-9 mirror the encoder sizes; each to_skip draws a skip link from
    # an encoder node (ccr_b4 .. ccr_b1) to the decoder block's ccr_res_* node.
    *block_Unconv(name="b6", botton="conv5_aspp", top='end_b6', s_filer=64,
                  n_filer=256, offset="(2.1,0,0)", size=(16, 16, 6.0),
                  opacity=0.5),
    to_skip(of='ccr_b4', to='ccr_res_b6', pos=1.25),
    *block_Unconv(name="b7", botton="end_b6", top='end_b7', s_filer=128,
                  n_filer=128, offset="(2.1,8,0)", size=(25, 25, 5.5),
                  opacity=0.5),
    to_skip(of='ccr_b3', to='ccr_res_b7', pos=1.25),
    *block_Unconv(name="b8", botton="end_b7", top='end_b8', s_filer=256,
                  n_filer=64, offset="(2.1,8,0)", size=(32, 32, 4.5),
                  opacity=0.5),
    to_skip(of='ccr_b2', to='ccr_res_b8', pos=1.25),
    *block_Unconv(name="b9", botton="end_b8", top='end_b9', s_filer=512,
                  n_filer=32, offset="(2.1,8,0)", size=(40, 40, 3.5),
                  opacity=0.5),
    to_skip(of='ccr_b1', to='ccr_res_b9', pos=1.25),

    # --- Head --------------------------------------------------------------
    # Second ASPP-style module, at s_filer 512 with 16 channels.
    *_aspp_block('aspp2', 'end_b9', s_filer=512, n_filer=16, width=2.5, size=40),
    to_Conv(name='b10', s_filer=512, n_filer=16, offset="(2,0,0)",
            to="(conv5_aspp2-east)", width=2.5, height=40, depth=40),
    to_connection("conv5_aspp2", "b10"),
    # Reduce to 4 channels, then softmax.
    to_Conv(name='b11', s_filer=512, n_filer=4, offset="(2,0,0)",
            to="(b10-east)", width=1.5, height=40, depth=40),
    to_connection("b10", "b11"),
    to_ConvSoftMax(name="soft1", s_filer=512, offset="(0.75,0,0)",
                   to="(b11-east)", width=1, height=40, depth=40,
                   caption="SOFT"),
    to_connection("b11", "soft1"),

    # Output mask. Its position is an absolute coordinate, not anchored to
    # soft1, so it must be re-tuned by hand if the layers above change width.
    to_input('mask.png', to='(64,0,0)', name="output"),
    to_end(),
]
def main():
    """Write the TeX for ``arch`` to ``<script path without extension>.tex``.

    The output path is derived from ``sys.argv[0]``, i.e. the script path as
    it was passed on the command line (relative paths resolve against the
    current working directory).
    """
    import os  # local import keeps this fix self-contained

    # os.path.splitext drops only the final extension. The previous
    # str.split('.')[0] cut at the first dot anywhere in the path, so running
    # e.g. "python ../pyexamples/unet.py" produced a file literally named
    # ".tex", and any dotted directory name truncated the path.
    namefile = os.path.splitext(sys.argv[0])[0]
    to_generate(arch, namefile + '.tex')


if __name__ == '__main__':
    main()