最近在工作中遇到python 打印可视化3D图。需求是根据以下CSV文件黄色高亮的三列打印3D立体网格图,尝试过用matplotlib打印出来的效果不是很好。
发现了非常强大的可视化包plotly。但是plotly没有打印出四边形网格的函数,只有三角形网格trisurf,所以四边形网格需要自己去画。
附上plotly 官方文档链接 https://plot.ly/python/
首先安装plotly包:
pip install plotly
基本思路如下:
# -*- coding:utf8 -*-
import pandas as pd
import numpy as np
import plotly.graph_objs as go
from collections import defaultdict
import plotly.offline as py # 设置离线画图
"""
打开原始CSV文件数据,必须保证csv 文件只有一行列名,并且检查输入查询的列名是否存在于文件中,存在则返回正确的列名用于坐标轴标签
若存在,则提取这些列的数据,返回新的dataframe,
若不存在,则报错不执行程序
返回值: new_Df, correct_list
"""
def check_file_column():
input_file = input('请输入csv文件路径:')
filename = input_file.strip()
with open(filename, 'rb') as f:
csv_data = pd.read_csv(f, encoding='gb2312')
if len(csv_data.columns) != 3:
print('文件列名不等于3列,请修改文件!')
exit()
correct_list = list(csv_data.columns.values)
# 删除空的数据行
new_Df = csv_data.dropna()
# 将导入的各个类型变成正确的类型
new_Df[new_Df.columns.values[0]] = new_Df[new_Df.columns.values[0]].apply(int)
new_Df[new_Df.columns.values[2]] = new_Df[new_Df.columns.values[2]].apply(float)
new_Df[new_Df.columns.values[1]] = new_Df[new_Df.columns.values[1]].apply(float)
# 将new_Df进行排序
new_Df.sort_values(new_Df.columns.values[0], inplace=True)
return new_Df, correct_list
"""
使用递归对input dataframe 进行处理,每次返回删除指定数据后的dataframe和需要删除的dataframe列表
"""
def get_blockdata(csv_data_copy,sorted_pd_list):
while csv_data_copy.size != 0:
min_value = csv_data_copy.iloc[:, 0].min()
# 筛选转速Speed的误差值在正负10之内,并组成新的dataframe
diff_value_1 = min_value - 10
diff_value_2 = min_value + 10
pd_block = csv_data_copy[(csv_data_copy[csv_data_copy.columns.values[0]] > diff_value_1) & (csv_data_copy[csv_data_copy.columns.values[0]] < diff_value_2)]
# 将pd_block按照Torque排序
pd_block.sort_values(pd_block.columns.values[1], inplace=True)
pd_block.reset_index(drop=True, inplace=True)
# 将每个pd_block添加到一个sorted_pd list中
sorted_pd_list.append(pd_block)
# 从csv_data中删除pd_block的数据
csv_data_copy.drop(
csv_data_copy[(csv_data_copy[csv_data_copy.columns.values[0]] > diff_value_1) & (csv_data_copy[csv_data_copy.columns.values[0]] < diff_value_2)].index,
inplace=True)
return csv_data_copy, sorted_pd_list
"""
提取CSV_data的数据,分别画网格线和3D平面图:
x1,y1,z1 用于画3D mesh 图
x2,y2,z2 用于画同一个x轴坐标上的不同y轴坐标点在三维空间上的线性图
x3,y3,z3 用于画同一个y轴坐标上的不同x轴坐标点在三维空间上的线性图
"""
def handle_data_plot_3D(csv_data, correct_list):
# x1,y1,z1 用于画3D mesh 图
x1 = np.array(list(csv_data.iloc[:, 0].values))
y1 = np.array(list(csv_data.iloc[:, 1].values))
z1 = np.array(list(csv_data.iloc[:, 2].values))
# 将每个相同X值得数据提取放入一个pandas中
csv_data_copy = csv_data
# 用于存出每个数据块pandas数据
sorted_pd_list = []
csv_data_copy, sorted_pd_list = get_blockdata(csv_data_copy, sorted_pd_list)
lines = []
line_marker = dict(color='black', width=0.8)
# 画同一个x轴坐标上的不同y轴坐标点在三维空间上的线性图
for each in sorted_pd_list:
x2 = np.array(list(each.iloc[:, 0].values))
y2 = np.array(list(each.iloc[:, 1].values))
z2 = np.array(list(each.iloc[:, 2].values))
lines.append(go.Scatter3d(x=x2, y=y2, z=z2, mode='lines', line=line_marker, showlegend=False,hoverinfo='none'))
# 用于存储每个数据块中相同index的数据字典
x_y_z_row = defaultdict(list)
for _ in sorted_pd_list:
for index, row in _.iterrows():
x_y_z_row[index].append([row[0], row[1], row[2]])
# 提取每一数据块的x, y,z坐标字典
each_index_x = defaultdict(list)
each_index_y = defaultdict(list)
each_index_z = defaultdict(list)
for key, value in x_y_z_row.items():
for each in value:
each_index_x[key].append(each[0])
each_index_y[key].append(each[1])
each_index_z[key].append(each[2])
# 遍历每个数据块的x,y,z坐标字典,每个块进行描点和画图
for key, value in each_index_x.items():
if key in each_index_y.keys() and key in each_index_z:
x3 = value
y3 = each_index_y[key]
z3 = each_index_z[key]
lines.append(go.Scatter3d(x=x3, y=y3, z=z3, mode='lines', line=line_marker, showlegend=False, hoverinfo='none'))
fig = go.Figure(data=lines)
fig.add_trace(go.Mesh3d(x=x1, y=y1, z=z1, opacity=0.80,intensity=z1,
colorbar={"title": correct_list[2], "len": 0.6, },
hoverinfo='x+y+z',
colorscale='Portland',
))
fig.update_layout(title='3D_Mesh_Grid plotly', autosize=False,
width=900, height=900,
margin=dict(l=65, r=50, b=65, t=90),
scene=dict(
xaxis_title=correct_list[0],
yaxis_title=correct_list[1],
zaxis_title=correct_list[2],
)
)
fig.show()
py.plot(fig, filename='pictures/3Dfigure.html', # 会生成一个网页文件
image='png', ) # 设置保存的文件类型,不会在本地有个png的文件,需要在生成的网页打开另存为png的文件
def main():
# check columns and check files
csv_data, correct_list = check_file_column()
# 处理数据画网格图以及3Dmesh 图
handle_data_plot_3D(csv_data, correct_list)
if __name__ == '__main__':
main()