【mindspore】【训练】训练过程内存占用大

问题描述:

我目前在做pytorch reconet模型在mindspore上复现的工作,现在遇到了显存溢出的问题,而且显存占用是torch中的三倍以上,pytorch只需要7.6G显存,而mindspore 24G都溢出了

在pytorch中,在训练初始时加载一次vgg模型,在每个batch中vgg当做一个特征提取工具,也不需要参与模型梯度回传,训练步骤大体如下

model = ReCoNet().cuda()
vgg = Vgg16().cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

for epoch in range(n_epochs):
    for sample in traindata:
        optimizer.zero_grad()
        # Compute ReCoNet features and output
        reconet_input = preprocess_for_reconet(sample["frame"])
        feature_maps = model.encoder(reconet_input)
        output_frame = model.decoder(feature_maps)

        previous_reconet_input = preprocess_for_reconet(sample["previous_frame"])
        previous_feature_maps = model.encoder(previous_reconet_input)
        previous_output_frame = model.decoder(previous_feature_maps)

        # Compute VGG features
        vgg_input_frame = preprocess_for_vgg(sample["frame"])
        vgg_output_frame = preprocess_for_vgg(postprocess_reconet(output_frame))
        input_vgg_features = vgg(vgg_input_frame)
        output_vgg_features = vgg(vgg_output_frame)

        vgg_previous_input_frame = preprocess_for_vgg(sample["previous_frame"])
        vgg_previous_output_frame = preprocess_for_vgg(postprocess_reconet(previous_output_frame))
        previous_input_vgg_features = vgg(vgg_previous_input_frame)
        previous_output_vgg_features = vgg(vgg_previous_output_frame)

        loss = loss_func(...)
        loss.backward()
        optimizer.step()

而在mindspore中,由于模型的loss函数比较复杂,无法通过传入一个loss_fn的方式,因此参考了教程中的自定义loss的方式,定义了一个包含loss的模型reconet_with_loss,并在construct中返回loss,loss计算过程与上面的pytorch过程一致,另外为了能在模型中使用vgg模型我把vgg作为一个初始化参数送入模型中,通过TrainOneStepCell来完成训练,代码如下

model = RecoNet_with_Loss(reconet, vgg)
optim = nn.Adam(reconet.trainable_params(), learning_rate=0.1, weight_decay=0.0)
train_net = nn.TrainOneStepCell(model, optim)

通过parameters_dict查看发现train_net参数量很多,是pytorch的几倍,而且包含了vgg的权重以及还包含大量moment的权重,不清楚这些是否占用了过多内存
pytorch:


mindspore:

解答:

首先,需要确认是哪一部分的内存占用格外的高,一般网络图不怎么占内存,占内存的操作主要集中在数据处理。

1. 不带网络,单跑下数据处理看内存使用情况

for data in dataset:
    print("="*20)
    for item in data:
        print(item.shape)

2. 如果是数据处理的问题,建议减小并行度或者看是否有操作导致内存泄露

3. 如果不是数据问题,打桩一些怀疑的模块,看内存占用是否变小。

你可能感兴趣的:(大数据)