pytorch 转移参数到gpu

问题描述:有时候需要将绑定到类属性的参数转移到gpu,该参数不随训练而变化,此时使用model.cuda()是不会将这部分参数转移到gpu的,例如下例中Test类中的参数p,将Test类转移到gpu上,但p仍在cpu上

解决方法:使用register_buffer函数,注册这部分参数,例如下例中的w

import torch
class Test(torch.nn.Module):
    def __init__(self):
        super(Test, self).__init__()
        self.p = torch.zeros((5, 10))
        self.register_buffer('w', torch.zeros((5, 10)))
t = Test()
t.to('cuda')
print('device of p: ',t.p.device)
print('device of w: ', t.w.device)

输出结果
在这里插入图片描述

你可能感兴趣的:(pytorch,深度学习,人工智能)