核心要点:
优化matplotlib大数量级绘图速率和内存管理的核心要点在于,尽可能少的新建画布。
也就是尽可能一张纸,画了擦,擦了画,画完拍个照(即保存),再擦了准备画下一张,反反复复。
问题描述:
在使用matplotlib进行绘制图片时, 如果使用循环,如第一种类型,循环内新建画布、绘制、保存、关闭画布:
import matplotlib.pyplot as plt
# for loop, type 1
# 反复新建画布,关闭画布
for k in range(150):
fig = plt.figure(figsize=(6,6))
plt.plot(...)
plt.imshow(...)
......
plt.savefig(..,dpi = 600)
plt.close()
以及第二种,比如说我们要画一张迷宫,然后这张迷宫的图是没有办法用一句简单的plt.plot(x,y)就画出来的:
import matplotlib.pyplot as plt
fig = plt.figure(figsize=(6,6))
# we have to plot a maze, but maze is composed of several discontinuous lines that we can
# only plot them using a loop.
axes = plt.axes()
for k in range(1440000):
if isEastBorder(i, nx = nx):
axes.plot([x+0.5, x+0.5],[y+0.5, y-0.5],'-',color = color,linewidth = linewidth)
border[0] = 1
if isSouthBorder(i, nx = nx):
axes.plot([x-0.5, x+0.5],[y-0.5, y-0.5],'-',color = color,linewidth = linewidth)
border[1] = 1
if isWestBorder(i, nx = nx):
axes.plot([x-0.5, x-0.5],[y-0.5, y+0.5],'-',color = color,linewidth = linewidth)
border[2] = 1
if isNorthBorder(i, nx = nx):
axes.plot([x-0.5, x+0.5],[y+0.5, y+0.5],'-',color = color,linewidth = linewidth)
border[3] = 1
for k in range(4):
if border[0] != 1 and loc_to_idx(x+1, y, nx = nx) not in surr:
axes.plot([x+0.5, x+0.5],[y-0.5, y+0.5],'-',color = color,
linewidth = linewidth)
if border[1] != 1 and loc_to_idx(x, y-1, nx = nx) not in surr:
axes.plot([x-0.5, x+0.5],[y-0.5, y-0.5],'-',color = color,
linewidth = linewidth)
if border[2] != 1 and loc_to_idx(x-1, y, nx = nx) not in surr:
axes.plot([x-0.5, x-0.5],[y-0.5, y+0.5],'-',color = color,
linewidth = linewidth)
if border[3] != 1 and loc_to_idx(x, y+1, nx = nx) not in surr:
axes.plot([x-0.5, x+0.5],[y+0.5, y+0.5],'-',color = color,
linewidth = linewidth)
那么这个时候就会特别慢.而且非常容易爆内存,这和你内存大小是没有关系的。
对于第一种的一种优化的方法是,在循环后面加上
plt.close()
plt.clf()
这两句是清楚画布的操作,相当于清除内存。但实际上因为matplotlib的某些更深层次的内存回收机制的原因,清除画布只能让内存的增长从几何增长到线性增长,并不能根本上解决爆内存的问题。我认为这可能和matplotlib.figure类型的内存回收可能不是很干净有关,因此即便使用上述两句清空画布,仍然有内存残余。在较大的绘图任务前,这些残余仍然有可能使你爆内存。
因此,根本的优化途径在于,少新建画布,少出现第二种所示的循环绘图。
import matplotlib.pyplot as plt
# 在外面新建画布,1次
fig = plt.figure(figsize=(6,6))
for k in ...:
plt.plot(..)
plt.imshow(..)
# 这些是所有图片中共用的一些组分,在循环外绘制
# 有时这些公共部分是类似第二种需要循环绘制的图,那么正好放在循环外以减少绘制次数。
# 而且这样可以使得这些图再内存充裕的情况下提前绘制,速率会较循环内更快。
for k in range(150):
# 这些是每张图特有的部分,循环内绘制
a = plt.plot(...)
im = plt.imshow(...)
......
plt.savefig(..,dpi = 600)
a.remove() # 绘制、保存之后回收这些特有部分
im.remove() # 恢复到公共部分
plt.close() #所有图画完再关闭画布
这样可以使得所有图片的绘制速率至少提升1~3的数量级。而唯一的限速步则回到保存这一步
plt.savefig(..,dpi = 600)
如果没有保存,可能绘图速率可以达到1000张/s~10000张/s,如果加上保存,则速率与dpi大小相关。dpi=600时也可以做到1张/s以内。(由于科研制图需要所以选了600的dpi,对于一般绘图完全不需要,dpi=300即可做到4~10张/s)