python 可视化 ploty 画3dmesh网格图

python 可视化 plotly 画3dmesh网格图

最近在工作中遇到python 打印可视化3D图。需求是根据以下CSV文件黄色高亮的三列打印3D立体网格图,尝试过用matplotlib打印出来的效果不是很好。
python 可视化 ploty 画3dmesh网格图_第1张图片
发现了非常强大的可视化包plotly。但是plotly没有打印出四边形网格的函数,只有三角形网格trisurf,所以四边形网格需要自己去画。
附上plotly 官方文档链接 https://plot.ly/python/
首先安装plotly包:
pip install plotly
基本思路如下:

  1. 首先画出同一个X轴上不同Y轴在Z轴上的点,分别连线,其中X轴是Speed,Y轴是Torque, Z轴是功率。对CSV进行排序,转变成dataframe,X的误差范围在正负10,即对于X相差在10以内都放在一个数据块dataframe中,用while循环获取每个小数据块,放入数据字典中。遍历数据字典,取出对应的x2,y2,z2,画点并对每个数据块中的点连线。
    2.其次画同一个y轴坐标上的不同x轴坐标点在三维空间上的线性图,利用上面的每一个数据块的相同index及其不同坐标点放入字典Key-value中。之后遍历该字典,对每个item画点(X3,y3,z3)和线。这个做法有一个缺点,即每个数据块的大小可能不相等,多出来的点并没有连线画图。
  2. 最后将最初文件中的dataframe中的第1,2,3列取出(x1,y1,z1)画Mesh3d,为网格画颜色平面图。
# -*- coding:utf8 -*-
import pandas as pd
import numpy as np
import plotly.graph_objs as go
from collections import defaultdict
import plotly.offline as py  # 设置离线画图

"""
打开原始CSV文件数据,必须保证csv 文件只有一行列名,并且检查输入查询的列名是否存在于文件中,存在则返回正确的列名用于坐标轴标签
    若存在,则提取这些列的数据,返回新的dataframe,
    若不存在,则报错不执行程序
    返回值: new_Df, correct_list
"""
def check_file_column():
    input_file = input('请输入csv文件路径:')
    filename = input_file.strip()
    with open(filename, 'rb') as f:
        csv_data = pd.read_csv(f, encoding='gb2312')
    if len(csv_data.columns) != 3:
        print('文件列名不等于3列,请修改文件!')
        exit()
    correct_list = list(csv_data.columns.values)
    # 删除空的数据行
    new_Df = csv_data.dropna()
    # 将导入的各个类型变成正确的类型
    new_Df[new_Df.columns.values[0]] = new_Df[new_Df.columns.values[0]].apply(int)
    new_Df[new_Df.columns.values[2]] = new_Df[new_Df.columns.values[2]].apply(float)
    new_Df[new_Df.columns.values[1]] = new_Df[new_Df.columns.values[1]].apply(float)
    # 将new_Df进行排序
    new_Df.sort_values(new_Df.columns.values[0], inplace=True)
    return new_Df, correct_list


"""

使用递归对input dataframe 进行处理,每次返回删除指定数据后的dataframe和需要删除的dataframe列表

"""
def get_blockdata(csv_data_copy,sorted_pd_list):
    while csv_data_copy.size != 0:
        min_value = csv_data_copy.iloc[:, 0].min()
        # 筛选转速Speed的误差值在正负10之内,并组成新的dataframe
        diff_value_1 = min_value - 10
        diff_value_2 = min_value + 10
        pd_block = csv_data_copy[(csv_data_copy[csv_data_copy.columns.values[0]] > diff_value_1) & (csv_data_copy[csv_data_copy.columns.values[0]] < diff_value_2)]
        # 将pd_block按照Torque排序
        pd_block.sort_values(pd_block.columns.values[1], inplace=True)
        pd_block.reset_index(drop=True, inplace=True)
        # 将每个pd_block添加到一个sorted_pd list中
        sorted_pd_list.append(pd_block)
        # 从csv_data中删除pd_block的数据
        csv_data_copy.drop(
            csv_data_copy[(csv_data_copy[csv_data_copy.columns.values[0]] > diff_value_1) & (csv_data_copy[csv_data_copy.columns.values[0]] < diff_value_2)].index,
            inplace=True)
    return csv_data_copy, sorted_pd_list

"""
提取CSV_data的数据,分别画网格线和3D平面图:
x1,y1,z1 用于画3D mesh 图
x2,y2,z2 用于画同一个x轴坐标上的不同y轴坐标点在三维空间上的线性图
x3,y3,z3 用于画同一个y轴坐标上的不同x轴坐标点在三维空间上的线性图

"""
def handle_data_plot_3D(csv_data, correct_list):
    # x1,y1,z1 用于画3D mesh 图
    x1 = np.array(list(csv_data.iloc[:, 0].values))
    y1 = np.array(list(csv_data.iloc[:, 1].values))
    z1 = np.array(list(csv_data.iloc[:, 2].values))
    # 将每个相同X值得数据提取放入一个pandas中
    csv_data_copy = csv_data
    # 用于存出每个数据块pandas数据
    sorted_pd_list = []
    csv_data_copy, sorted_pd_list = get_blockdata(csv_data_copy, sorted_pd_list)
    lines = []
    line_marker = dict(color='black', width=0.8)
    # 画同一个x轴坐标上的不同y轴坐标点在三维空间上的线性图
    for each in sorted_pd_list:
        x2 = np.array(list(each.iloc[:, 0].values))
        y2 = np.array(list(each.iloc[:, 1].values))
        z2 = np.array(list(each.iloc[:, 2].values))
        lines.append(go.Scatter3d(x=x2, y=y2, z=z2, mode='lines', line=line_marker, showlegend=False,hoverinfo='none'))
    # 用于存储每个数据块中相同index的数据字典
    x_y_z_row = defaultdict(list)
    for _ in sorted_pd_list:
        for index, row in _.iterrows():
            x_y_z_row[index].append([row[0], row[1], row[2]])
    # 提取每一数据块的x, y,z坐标字典
    each_index_x = defaultdict(list)
    each_index_y = defaultdict(list)
    each_index_z = defaultdict(list)
    for key, value in x_y_z_row.items():
        for each in value:
            each_index_x[key].append(each[0])
            each_index_y[key].append(each[1])
            each_index_z[key].append(each[2])
    # 遍历每个数据块的x,y,z坐标字典,每个块进行描点和画图
    for key, value in each_index_x.items():
        if key in each_index_y.keys() and key in each_index_z:
            x3 = value
            y3 = each_index_y[key]
            z3 = each_index_z[key]
            lines.append(go.Scatter3d(x=x3, y=y3, z=z3, mode='lines', line=line_marker, showlegend=False, hoverinfo='none'))
    fig = go.Figure(data=lines)
    fig.add_trace(go.Mesh3d(x=x1, y=y1, z=z1, opacity=0.80,intensity=z1,
                            colorbar={"title": correct_list[2], "len": 0.6, },
                            hoverinfo='x+y+z',
                            colorscale='Portland',
                            ))
    fig.update_layout(title='3D_Mesh_Grid plotly', autosize=False,
                      width=900, height=900,
                      margin=dict(l=65, r=50, b=65, t=90),
                      scene=dict(
                          xaxis_title=correct_list[0],
                          yaxis_title=correct_list[1],
                          zaxis_title=correct_list[2],
                      )
                      )
    fig.show()
    py.plot(fig, filename='pictures/3Dfigure.html',  # 会生成一个网页文件
            image='png', )  # 设置保存的文件类型,不会在本地有个png的文件,需要在生成的网页打开另存为png的文件


def main():
    # check columns and check files
    csv_data, correct_list = check_file_column()

    # 处理数据画网格图以及3Dmesh 图
    handle_data_plot_3D(csv_data, correct_list)


if __name__ == '__main__':
    main()

没有画颜色填充平面的只有网格效果的如下:
python 可视化 ploty 画3dmesh网格图_第2张图片
最后整个图形效果如下:
python 可视化 ploty 画3dmesh网格图_第3张图片

你可能感兴趣的:(python,可视化,plotly,python)