大模型训练时的内存泄漏与显存不足

内存泄漏位置

位置1 FaceDetection

不可重复创建FaceDetection,该位置是内存泄漏的大头

mediapipe.solutions.face_detection.FaceDetection(model_selection=1, min_detection_confidence=0.5)

位置2 tensorboardX

不可重复创建 tensorboardX.SummaryWriter

SummaryWriter(file_name)

位置3 torch.utils.data.DataLoader

torch.utils.data.DataLoader是torch中已知的内存泄漏点,直接压测DataLoader能直接观察到内测无法释放的现象。

解决办法是改为单线程。这在大规模数据训练时,是不可取的。数据量少,倒是无所谓。

data_loader = torch.utils.data.DataLoader({}, num_workers=1,
                                              collate_fn=funa,
                                              )

位置4 GeneralCustomKStablePipeline

对于diffuser_extension.pipeline.GeneralCustomKDiffusionPipeline,无法完全释放内存。

pipe = GeneralCustomKDiffusionPipeline

del pipe # 显存得到完全释放,而内存却有泄漏

多模型显存不足的问题解决

由于pipeline不能重复创建与释放,但显存无法容纳两个模型。
解决办法是把模型在内存与显存之间反复迁移,以解决显存不足的问题。

pipeline.to('cuda')

pipeline.to('cpu')

显存回收时的效果:
大模型训练时的内存泄漏与显存不足_第1张图片

内存泄漏的分析过程

主要用到的工具就是memory_profiler与objgraph。当然还有别的内存分析工具。

第一个工具memory_profiler

非常简单,只需一个注解,能显示每行代码执行后的内存增量变化。

from memory_profiler import profile

@profile(precision=4, stream=open("memo_profiler.log", "w"))
def your_function():
  pass

第一个工具memory_profiler,显示每行代码执行后的内存增量,在大的内存泄漏时能大致定位到代码位置,只在FaceDetection上有明确指导。
但是微量内存泄漏,看不出来哪行代码的导致的最终增量。

第二个工具objgraph

微量内存泄漏,主要用到objgraph。

import objgraph
import gc
mem_incress_file = open("", "w")
while True:
  # your code
  gc.collect()
  objgraph.show_growth(limit=10, file=mem_incress_file)
  mem_incress_file.write("===>>>finished once time.\n")
  mem_incress_file.flush()

工具objgraph,能在执行完一段代码后,展示增量。
正常的代码逻辑只会在前几次有增量,之后就完全没有增量,而内存泄漏的代码则每次循环中都有稳定的增量。
例如排除结束后,确认没有增量的日志。

dict                         111214   +111214
function                     100747   +100747
tuple                         72811    +72811
list                          26095    +26095
cell                          25263    +25263
ReferenceType                 23604    +23604
OrderedDict                   19758    +19758
builtin_function_or_method    18043    +18043
getset_descriptor             15438    +15438
type                          11861    +11861
===>>>finished once time.
TopLevelThreadTracerNoBackFrame       13       +12
ThreadTracer                          19       +12
tuple                              72822       +11
dict                              111219        +5
ReferenceType                      23606        +2
list                               26096        +1
builtin_function_or_method         18044        +1
NetCommand                             1        +1
SafeCallWrapper                        5        +1
===>>>finished once time.
tuple                              72860       +38
TopLevelThreadTracerNoBackFrame       50       +37
ThreadTracer                          56       +37
dict                              111227        +8
lock                                 111        +5
ReferenceType                      23611        +5
cell                               25268        +5
builtin_function_or_method         18048        +4
method                              2134        +2
Event                                 18        +2
===>>>finished once time.
ReferenceType    23613        +2
===>>>finished once time.
===>>>finished once time.
tuple    72861        +1
===>>>finished once time.
===>>>finished once time.
===>>>finished once time.
===>>>finished once time.
===>>>finished once time.
===>>>finished once time.
===>>>finished once time.

你可能感兴趣的:(机器学习,stable,diffusion)