用 Python 实现全连接层网络可视化

以下代码即可实现全连接层网络的可视化:

# 引用模块
from pylab import mpl #matplotlib使用中文

# 自编函数
def ANN_ksh(number_input,number_hidden,number_output):
    import numpy as np
    import networkx as nx
    import matplotlib.pyplot as plt
    
    mpl.rcParams['font.sans-serif']=['SimHei'] #matplotlib使用中文,SimHei为黑体
    
    # number_input为输入层节点个数,number_hidden为隐藏层各层节点个数,number_output为输出层节点个数
    ceng_hidden=len(number_hidden) #隐藏层层数
    G=nx.DiGraph()
    
    # 节点
    vertex_input_list=['v'+str(i) for i in range(1,number_input+1)] #输入层
    vertex_hidden_list=[]
    start=number_input+1
    end=number_input+number_hidden[0]+1
    vertex_hidden_list.append(['v'+str(i) for i in range(start,end)]) #隐藏层
    for j in range(1,ceng_hidden):
        start=end
        end=start+number_hidden[j]
        vertex_hidden_list.append(['v'+str(i) for i in range(start,end)]) #隐藏层
    vertex_output_list=['v'+str(i) for i in range(end,end+number_output)] #输出层
    vertex_list=[]
    vertex_list.extend(vertex_input_list)
    list(map(lambda i:vertex_list.extend(vertex_hidden_list[i]),range(ceng_hidden)))
    vertex_list.extend(vertex_output_list)
    G.add_nodes_from(vertex_list)
    
    # 连接
    edge_input_hidden_list=[]
    edge_input_hidden_list.extend([(i,j) for i in vertex_input_list for j in vertex_hidden_list[0]]) #输入层-隐藏层
    edge_list=[]
    edge_list.extend(edge_input_hidden_list)
    edge_hidden_hidden_list=[]
    if ceng_hidden>1:
        for k in range(ceng_hidden-1):
            edge_hidden_hidden_list.extend([(i,j) for i in vertex_hidden_list[k] for j in vertex_hidden_list[k+1]]) #隐藏层-隐藏层
        edge_list.extend(edge_hidden_hidden_list)
    edge_hidden_output_list=[]
    edge_hidden_output_list.extend([(i,j) for i in vertex_hidden_list[len(vertex_hidden_list)-1] for j in vertex_output_list]) #隐藏层-输出层
    edge_list.extend(edge_hidden_output_list)
    G.add_edges_from(edge_list)
    
    # 位置
    pos={}
    ceng_pos_x=np.linspace(-(ceng_hidden+2)/2,(ceng_hidden+2)/2,num=ceng_hidden+2)
    list(map(lambda i:pos.update({vertex_input_list[int(np.where(np.arange(
        -number_input/2*1+1/2,number_input/2*1+1/2,1)==i)[0])]:(ceng_pos_x[0],i)}),
        np.arange(-number_input/2*1+1/2,number_input/2*1+1/2,1))) #输入层
    list(map(lambda j:list(map(lambda i:pos.update({vertex_hidden_list[j][int(np.where(np.arange(
            -number_hidden[j]/2*1+1/2,number_hidden[j]/2*1+1/2,1)==i)[0])]:(ceng_pos_x[j+1],i)}),
            np.arange(-number_hidden[j]/2*1+1/2,number_hidden[j]/2*1+1/2,1))),range(ceng_hidden))) #隐藏层
    list(map(lambda i:pos.update({vertex_output_list[int(np.where(np.arange(
        -number_output/2*1+1/2,number_output/2*1+1/2,1)==i)[0])]:(ceng_pos_x[len(ceng_pos_x)-1],i)}),
        np.arange(-number_output/2*1+1/2,number_output/2*1+1/2,1))) #输出层
    
    fig=plt.figure(figsize=(8,5),dpi=300)
    plt.xlim(ceng_pos_x[0]-1,ceng_pos_x[len(ceng_pos_x)-1]+1)
    plt.ylim(-max(number_input,max(number_hidden),number_output)/2*1,
             max(number_input,max(number_hidden),number_output)/2*1+1/2)
    
    nx.draw(
            G,
            pos=pos,
            node_color='red',
            edge_color='black',
            with_labels=False,
            font_size=10,
            node_size=300,
           )
    fig.savefig('全连接层网络可视化.png')

函数参数说明:

  number_input 为输入层的节点个数,number_hidden 为隐藏层各层的节点个数,number_output 为输出层的节点个数。

调用函数示例:

ANN_ksh(8,[8,5,2],2)

结果:
用 Python 实现全连接层网络可视化_第1张图片

图 1 全连接层网络可视化

你可能感兴趣的:(全连接层,可视化,神经网络,深度学习)