python 可视化 plotly 画3dmesh网格图
最近在工作中遇到python 打印可视化3D图。需求是需要根据根据之前用matplotlib打印出来的效果不是很好。发现了非常强大的可视化包plotly。附上官方文档链接 https://plot.ly/python/
首先安装plotly包:
pip install plotly
# -*- coding:utf8 -*-
import pandas as pd
import numpy as np
import plotly.graph_objs as go
from collections import defaultdict
import plotly.offline as py # 设置离线画图
"""
打开原始CSV文件数据,必须保证csv 文件只有一行列名,并且检查输入查询的列名是否存在于文件中,存在则返回正确的列名用于坐标轴标签
若存在,则提取这些列的数据,返回新的dataframe,
若不存在,则报错不执行程序
返回值: new_Df, correct_list
"""
def check_file_column():
input_file = input('请输入csv文件路径:')
filename = input_file.strip()
with open(filename, 'rb') as f:
csv_data = pd.read_csv(f, encoding='gb2312')
if len(csv_data.columns) != 3:
print('文件列名不等于3列,请修改文件!')
exit()
correct_list = list(csv_data.columns.values)
# 删除空的数据行
new_Df = csv_data.dropna()
# 将导入的各个类型变成正确的类型
new_Df[new_Df.columns.values[0]] = new_Df[new_Df.columns.values[0]].apply(int)
new_Df[new_Df.columns.values[2]] = new_Df[new_Df.columns.values[2]].apply(float)
new_Df[new_Df.columns.values[1]] = new_Df[new_Df.columns.values[1]].apply(float)
# 将new_Df进行排序
new_Df.sort_values(new_Df.columns.values[0], inplace=True)
return new_Df, correct_list
"""
使用递归对input dataframe 进行处理,每次返回删除指定数据后的dataframe和需要删除的dataframe列表
"""
def get_blockdata(csv_data_copy,sorted_pd_list):
while csv_data_copy.size != 0:
min_value = csv_data_copy.iloc[:, 0].min()
# 筛选转速Speed的误差值在正负10之内,并组成新的dataframe
diff_value_1 = min_value - 10
diff_value_2 = min_value + 10
pd_block = csv_data_copy[(csv_data_copy[csv_data_copy.columns.values[0]] > diff_value_1) & (csv_data_copy[csv_data_copy.columns.values[0]] < diff_value_2)]
# 将pd_block按照Torque排序
pd_block.sort_values(pd_block.columns.values[1], inplace=True)
pd_block.reset_index(drop=True, inplace=True)
# 将每个pd_block添加到一个sorted_pd list中
sorted_pd_list.append(pd_block)
# 从csv_data中删除pd_block的数据
csv_data_copy.drop(
csv_data_copy[(csv_data_copy[csv_data_copy.columns.values[0]] > diff_value_1) & (csv_data_copy[csv_data_copy.columns.values[0]] < diff_value_2)].index,
inplace=True)
return csv_data_copy, sorted_pd_list
"""
提取CSV_data的数据,分别画网格线和3D平面图:
x1,y1,z1 用于画3D mesh 图
x2,y2,z2 用于画同一个x轴坐标上的不同y轴坐标点在三维空间上的线性图
x3,y3,z3 用于画同一个y轴坐标上的不同x轴坐标点在三维空间上的线性图
"""
def handle_data_plot_3D(csv_data, correct_list):
# x1,y1,z1 用于画3D mesh 图
x1 = np.array(list(csv_data.iloc[:, 0].values))
y1 = np.array(list(csv_data.iloc[:, 1].values))
z1 = np.array(list(csv_data.iloc[:, 2].values))
# 将每个相同X值得数据提取放入一个pandas中
csv_data_copy = csv_data
# 用于存出每个数据块pandas数据
sorted_pd_list = []
csv_data_copy, sorted_pd_list = get_blockdata(csv_data_copy, sorted_pd_list)
lines = []
line_marker = dict(color='black', width=0.8)
# 画同一个x轴坐标上的不同y轴坐标点在三维空间上的线性图
for each in sorted_pd_list:
x2 = np.array(list(each.iloc[:, 0].values))
y2 = np.array(list(each.iloc[:, 1].values))
z2 = np.array(list(each.iloc[:, 2].values))
lines.append(go.Scatter3d(x=x2, y=y2, z=z2, mode='lines', line=line_marker, showlegend=False,hoverinfo='none'))
# 用于存储每个数据块中相同index的数据字典
x_y_z_row = defaultdict(list)
for _ in sorted_pd_list:
for index, row in _.iterrows():
x_y_z_row[index].append([row[0], row[1], row[2]])
# 提取每一数据块的x, y,z坐标字典
each_index_x = defaultdict(list)
each_index_y = defaultdict(list)
each_index_z = defaultdict(list)
for key, value in x_y_z_row.items():
for each in value:
each_index_x[key].append(each[0])
each_index_y[key].append(each[1])
each_index_z[key].append(each[2])
# 遍历每个数据块的x,y,z坐标字典,每个块进行描点和画图
for key, value in each_index_x.items():
if key in each_index_y.keys() and key in each_index_z:
x3 = value
y3 = each_index_y[key]
z3 = each_index_z[key]
lines.append(go.Scatter3d(x=x3, y=y3, z=z3, mode='lines', line=line_marker, showlegend=False, hoverinfo='none'))
fig = go.Figure(data=lines)
fig.add_trace(go.Mesh3d(x=x1, y=y1, z=z1, opacity=0.80,intensity=z1,
colorbar={"title": correct_list[2], "len": 0.6, },
hoverinfo='x+y+z',
colorscale='Portland',
))
fig.update_layout(title='3D_Mesh_Grid plotly', autosize=False,
width=900, height=900,
margin=dict(l=65, r=50, b=65, t=90),
scene=dict(
xaxis_title=correct_list[0],
yaxis_title=correct_list[1],
zaxis_title=correct_list[2],
)
)
fig.show()
py.plot(fig, filename='pictures/3Dfigure.html', # 会生成一个网页文件
image='png', ) # 设置保存的文件类型,不会在本地有个png的文件,需要在生成的网页打开另存为png的文件
def main():
# check columns and check files
csv_data, correct_list = check_file_column()
# 处理数据画网格图以及3Dmesh 图
handle_data_plot_3D(csv_data, correct_list)
if __name__ == '__main__':
main()1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132