【MXNet Gluon】模型训练使用多块显卡加速(multi-gpu)

承接图像分类、检测、分割、生成相关项目,私信。

使用单块显卡时的代码:

			devices = mx.gpu(0)
			data = mx.nd.array(batch_data).as_in_context(devices)
			label = mx.nd.array(batch_label).as_in_context(devices)
			# 更新生成器G
			with autograd.record():
     			output = G(data)
     			errG_idt = idt_loss(output, label)
     			errG_idt.backward()
	 			metric_G.update([label, ], [output, ])
			G_trainer.step(batch_size)

使用多块显卡时的代码:

			devices = [mx.gpu(0), mx.gpu(1), mx.gpu(2)]
			gpu_Xs = gutils.split_and_load(batch_data, devices)
            gpu_ys = gutils.split_and_load(batch_label, devices)
            ls = []
            with autograd.record():
                for gpu_X, gpu_y in zip(gpu_Xs, gpu_ys):
                    output = G(gpu_X)
                    errG_idt = idt_loss(output, gpu_y)
                    ls.append(errG_idt)
                    # update metrics
                    metric_G.update([gpu_y, ], [output, ])
                for l in ls:
                    l.backward()
            G_trainer.step(batch_size)

你可能感兴趣的:(MXNet从上手到入门)