本文使用的环境是jupyter notebook,目的是画实战中的散点图,该例子来自于机器学习网站MachineLearning Plus上的博文:Python可视化50图
# 导入需要的库
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
# 让jupyter notebook能显示图像
%matplotlib inline
matplotlib.pyplot的开发文档
# 定义数据
d1 = np.random.randn(10)
d2 = d1**2 + 5
# 设置画布
plt.figure(figsize=(8,4))
# 画图
plt.scatter(d1,d2
,s=50 # 点的大小
,c='blue' # 点的颜色
,label = 'dom' # 图例名称)
# 显示图例
plt.legend()
# 显示图像
plt.show()
上例代码中d1为横坐标,d2为纵坐标,上图显示的点如下:
[*zip(d1,d2)]
[(-2.161729597367027, 9.67307485213261),
(0.5673234146211847, 5.32185585677744),
(-0.11933596881252635, 5.014241073452424),
(0.5532153528916103, 5.3060472266749885),
(-0.5561178534755685, 5.309267066954274),
(1.7353038997420718, 8.011279624460043),
(2.004534385021128, 9.01815810073203),
(-0.3234661826094298, 5.1046303712919165),
(0.41636007718081874, 5.173355713870017),
(0.7261378153079194, 5.5272761268201585)]
当散点图中需要显示两种及以上时,为了区分需要设置不同的颜色和图例。
注意:区分d1,d2和上述出现的值不一样
d1 = np.random.randn(10,2) # 10行,2列的ndarray,一行代表一个点
d2 = np.array([0,1,0,1,1,1,0,1,0,0]) # 0和1是用来区分种类,达到显示不同的颜色
plt.figure(figsize=(8,4))
plt.scatter(d1[:,0],d1[:,1]
,s=50
,c=d2);
d1是10X2的矩阵,d1.shape是(10,2)元组
colrs = ['red', 'blue']
labels = ['red', 'blue']
# 循环的目的是区分,循环一次就画一种类的所有点
# d2 = np.array([0,1,0,1,1,1,0,1,0,0])
# d2==0时,点的颜色是red,图例是red
# d2==1时,点的颜色是blue,图例是blue
for i in range(d1.shape[1]): # shape[1]的值是2,即便利两次
plt.scatter(d1[d2==i,0]
,d1[d2==i,1]
,s=50
,c=colrs[i]
,label = labels[i])
plt.legend()
plt.show()
机器学习网站MachineLearning Plus上的博文:Python可视化50图-scatter plot
传送门:https://www.machinelearningplus.com/plots/top-50-matplotlib-visualizations-the-> master-plots-python/
# 使用的数据集
midwest = pd.read_csv("https://raw.githubusercontent.com/selva86/datasets/master/midwest_filter.csv")
midwest.head() # 显示前五条数据
PID | county | state | area | poptotal | popdensity | ... | category | dot_size | |
---|---|---|---|---|---|---|---|---|---|
0 | 561 | ADAMS | IL | 0.052 | 66090 | 1270.961540 | ... | AAR | 250.944411 |
1 | 562 | ALEXANDER | IL | 0.014 | 10626 | 759.000000 | ... | LHR | 185.781260 |
2 | 563 | BOND | IL | 0.022 | 14991 | 681.409091 | ... | AAR | 175.905385 |
3 | 564 | BOONE | IL | 0.017 | 30806 | 1812.117650 | ... | ALU | 319.823487 |
4 | 565 | BROWN | IL | 0.018 | 5836 | 324.222222 | ... | AAR | 130.442161 |
5 rows × 29 columns
# 查看所有的属性
midwest.columns
Index(['PID', 'county', 'state', 'area', 'poptotal', 'popdensity', 'popwhite',
'popblack', 'popamerindian', 'popasian', 'popother', 'percwhite',
'percblack', 'percamerindan', 'percasian', 'percother', 'popadults',
'perchsd', 'percollege', 'percprof', 'poppovertyknown',
'percpovertyknown', 'percbelowpoverty', 'percchildbelowpovert',
'percadultpoverty', 'percelderlypoverty', 'inmetro', 'category',
'dot_size'],
dtype='object')
图中需要使用的数据有:area,poptotal,category。其中area是地区为横坐标,poptotal是总人口为纵坐标。 下面是博文的源码。
# 导入数据
midwest = pd.read_csv("https://raw.githubusercontent.com/selva86/datasets/master/midwest_filter.csv")
# category去重,结果就是下图中的图例
categories = np.unique(midwest['category'])
# 为每个category创建一个颜色,使用plt.cm.tab10生成
colors = [[plt.cm.tab10(i/float(len(categories)-1))] for i in range(len(categories))]
# 根据每个Category,画图
plt.figure(figsize=(16, 10), dpi= 80, facecolor='w', edgecolor='k')
for i, category in enumerate(categories):
plt.scatter('area', 'poptotal',
data=midwest.loc[midwest.category==category, :],
s=20, c=colors[i], label=str(category))
# 图像装饰
plt.gca().set(xlim=(0.0, 0.1), ylim=(0, 90000),
xlabel='Area', ylabel='Population')
# 设置刻度
plt.xticks(fontsize=12); plt.yticks(fontsize=12)
# 设置标题
plt.title("Scatterplot of Midwest Area vs Population", fontsize=22)
plt.legend(fontsize=12)
plt.show()