streamlit+ndraw进行可视化训练深度学习模型

简介

如果你喜欢web可视化的方式训练深度学习模型,那么streamlit是一个不可错过的选择!

优点:

  1. 提供丰富的web组件支持
  2. 嵌入python中,简单易用
  3. 轻松构建一个web页面,按钮控制训练过程

本文使用streamlit进行web可视化渲染,并使用ndraw进行模型可视化,做到了:

  1. 训练过程可视化
  2. 模型输入输出shape一目了然

构建环境

首先安装必要的依赖,tensorflow、streamlit和ndraw为必须依赖,其他依赖根据自己的情况进行安装

pip install streamlit
pip install tensorflow
pip install ndraw

其他的环境自行安装,不过多赘述

然后引入模块:

import ndraw
import streamlit as st
import tensorflow as tf
import streamlit.components.v1 as components

编写代码

以mnist数据集为例

1.获取数据

书写数据加载方法,如果你的数据集没有改动的话,可以使用@st.cache装饰器,其作用是缓存数据,不用每次训练都重新加载数据

@st.cache(allow_output_mutation=True)
def get_data(is_onehot = False):
    # 根据自己的训练数据进行加载
    (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
    x_train = x_train/255.0
    x_test = x_test/255.0
    if is_onehot:
        y_train = tf.one_hot(y_train,10)
        y_test = tf.one_hot(y_test,10)
    return (x_train, y_train), (x_test, y_test)

2.构建模型

简单构建一个模型:如果是较为复杂模型,可以使用ndraw进行维度的查看

def build_model():
    model = tf.keras.models.Sequential([
        tf.keras.layers.Flatten(input_shape=(28, 28)),
        tf.keras.layers.Dense(128, activation='relu'),
        tf.keras.layers.Dropout(0.2),
        tf.keras.layers.Dense(128, activation='relu'),
        tf.keras.layers.Dropout(0.2),
        tf.keras.layers.Dense(128, activation='relu'),
        tf.keras.layers.Dropout(0.2),
        tf.keras.layers.Dense(128, activation='relu'),
        tf.keras.layers.Dropout(0.2),
        tf.keras.layers.Dense(10, activation='softmax')
    ])
    return model

3.构建逻辑

使用streamlit构建模型的逻辑:

  1. 首先设置一个web页面的标题
  2. 在左侧设置一个导航栏:开始和结束
  3. 点击开始的时候开始训练
  4. 添加一个模型扩展位置,点击的时候可以查看模型
if __name__ == '__main__':
    #设置网页标题
    st.title("训练xx模型")
    #点击开始后进行数据加载和训练
    if st.sidebar.button('开始'):
        (x_train, y_train), (x_test, y_test) = get_data(is_onehot=True)

        st.text("train size: {} {}".format(x_train.shape, y_train.shape))
        st.text("test size: {} {}".format(x_test.shape, y_test.shape))

        model = build_model()
        #点击查看模型后可以可视化模型
        with st.expander("查看模型"):
            components.html(ndraw.render(model,init_x=200, flow=ndraw.VERTICAL), height=1000, scrolling=True)
        model.compile(loss="categorical_crossentropy", optimizer=tf.keras.optimizers.Adam(lr=0.001),metrics=["accuracy"])
        model.fit(x_train, y_train, batch_size=128, validation_data=(x_test, y_test), epochs=10, verbose=1,callbacks=[TrainCallback(x_test, y_test)])
        st.success('训练结束')

    if st.sidebar.button('停止'):
        st.stop()


4.自定义指标可视化

tf提供了丰富的自定义功能,包括模型自定义,指标自定义,loss自定义、训练过程自定义等等,此处自定义一个训练过程自定义的Callback,主要用于在训练过程中获取相关的loss和acc进行绘图

class TrainCallback(tf.keras.callbacks.Callback):
    def __init__(self, test_x, test_y):
        super(TrainCallback, self).__init__()
        self.test_x = test_x
        self.test_y = test_y

    def on_train_begin(self, logs=None):
        st.header("训练汇总")
        self.summary_line = st.area_chart()

        st.subheader("训练进度")
        self.process_text = st.text("0/{}".format(self.params['epochs']))
        self.process_bar = st.progress(0)

        st.subheader('loss曲线')
        self.loss_line = st.line_chart()

        st.subheader('accuracy曲线')
        self.acc_line = st.line_chart()

    def on_epoch_end(self, epoch, logs=None):
        self.loss_line.add_rows({'train_loss': [logs['loss']], 'val_loss': [logs['val_loss']]})
        self.acc_line.add_rows({'train_acc': [logs['accuracy']], 'val_accuracy': [logs['val_accuracy']]})
        self.process_bar.progress(epoch / self.params['epochs'])
        self.process_text.empty()
        self.process_text.text("{}/{}".format(epoch, self.params['epochs']))

    def on_batch_end(self, epoch, logs=None):
        if epoch % 10 == 0 or epoch == self.params['epochs']:
            self.summary_line.add_rows({'loss': [logs['loss']], 'accuracy': [logs['accuracy']]})

展示

streamlit+ndraw进行可视化训练深度学习模型_第1张图片
streamlit+ndraw进行可视化训练深度学习模型_第2张图片

总结

以上就是整个训练过程,不同的模型只需要更改一下加载数据和构建模型的函数即可,其他内容不变或者根据自己的需求添加

完整外码已放到gitee自取 visualneu

你可能感兴趣的:(深度学习,深度学习,python,tensorflow,模型可视化)