PyTorch tutorial, batch normalization: fixing RuntimeError: Expected object of type Variable[torch.FloatTensor]

While working through the batch-normalization chapter of Liao Xingyu's PyTorch tutorial, I ran into the following error:

File "C:/Users/demons/Desktop/trainingtorch/batch_normalization.py", line 25, in batch_norm_1d
    moving_mean[:] = moving_momentum * moving_mean + (1. - moving_momentum) * x_mean

RuntimeError: Expected object of type Variable[torch.FloatTensor] but found type Variable[torch.cuda.FloatTensor] for argument #1 'other'
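A likely cause, judging from the traceback: net.cuda() in utils.train only moves a module's registered parameters and buffers (and those of its submodules). In the tutorial's multi_network, moving_mean and moving_var appear to be stored as plain Variable attributes, so they stay on the CPU while x, and therefore x_mean, is on the GPU. The update on line 25 of batch_norm_1d then mixes a CPU tensor with a CUDA tensor. The sketch below illustrates this mechanism without needing a GPU by using .double() as a stand-in for .cuda(), since nn.Module applies both conversions only to registered parameters and buffers. It assumes PyTorch 0.4 or later, where Variable and Tensor are merged; the Demo class is hypothetical.

```python
import torch
from torch import nn


class Demo(nn.Module):
    """Hypothetical module mimicking how the tutorial stores its statistics."""

    def __init__(self):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(3))           # registered parameter
        self.moving_mean = torch.zeros(3)                   # plain attribute, not registered
        self.register_buffer("moving_var", torch.ones(3))  # registered buffer, for comparison


# .double() goes through the same conversion path as .cuda():
# only parameters and buffers are converted.
net = Demo().double()
print(net.gamma.dtype)        # torch.float64
print(net.moving_var.dtype)   # torch.float64
print(net.moving_mean.dtype)  # torch.float32 (left behind, like moving_mean on the CPU)
```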
There are two ways to fix it:
(1) Comment out the cuda()-related code in utils.py and re-indent the affected lines accordingly.
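# In the train function in utils.py, comment out the CUDA-related lines below and
# dedent the body of each former else branch so that it always runs on the CPU.
def train(net, train_data, valid_data, num_epochs, optimizer, criterion):
    # if torch.cuda.is_available():
    #     net = net.cuda()
    ...
        # training loop
        for im, label in train_data:
            # if torch.cuda.is_available():
            #     im = Variable(im.cuda())  # (bs, 3, h, w)
            #     label = Variable(label.cuda())  # (bs, h, w)
            # else:
            ...  # former else-branch body, dedented by one level
        ...
            # validation loop
            # if torch.cuda.is_available():
            #     im = Variable(im.cuda(), volatile=True)
            #     label = Variable(label.cuda(), volatile=True)
            # else:
            ...  # former else-branch body, dedented by one level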
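Commenting code out works, but it has to be undone to use the GPU again. A common alternative, sketched here as an assumption rather than the tutorial's code, is to choose the device once and move the model and every batch with .to(device). The sketch assumes PyTorch 0.4 or later (no Variable wrappers, loss.item() instead of loss.data[0]); train_one_epoch is a hypothetical, simplified stand-in for the tutorial's train function that omits accuracy, validation and timing. Note that .to(device) only moves parameters and buffers, so setting device to "cpu" is what mirrors solution (1).

```python
import torch
from torch import nn

# Select the device in one place. "cpu" reproduces solution (1);
# "cuda" only works once every tensor the model uses is a parameter or buffer.
device = torch.device("cpu")


def train_one_epoch(net, train_data, optimizer, criterion, device):
    """Hypothetical simplified training loop; returns the mean batch loss."""
    net.to(device)
    net.train()
    total_loss = 0.0
    for im, label in train_data:
        # Move each batch to the same device as the model.
        im, label = im.to(device), label.to(device)
        output = net(im)
        loss = criterion(output, label)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(train_data)


# Usage on random data (shapes are illustrative only).
torch.manual_seed(0)
net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 3)).to(device)
optimizer = torch.optim.SGD(net.parameters(), lr=0.1)
criterion = nn.CrossEntropyLoss()
data = [(torch.randn(16, 4), torch.randint(0, 3, (16,))) for _ in range(5)]
avg_loss = train_one_epoch(net, data, optimizer, criterion, device)
print(avg_loss)
```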

(2) Alternatively, change the batch_norm_1d call in the forward method of class multi_network as follows:
x = batch_norm_1d(x.cpu(), self.gamma.cpu(), self.beta.cpu(), is_train, self.moving_mean.cpu(), self.moving_var.cpu()).cuda()
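Solution (2) works because the .cpu() calls bring every argument onto the same device as moving_mean and moving_var, and the trailing .cuda() moves the result back to the GPU, at the cost of extra device transfers on every forward pass. An alternative that is not taken from the post is to register the running statistics as buffers with register_buffer, so that net.cuda() or net.to(device) moves them together with the parameters. The sketch assumes PyTorch 0.4 or later; MultiNetwork is a hypothetical re-implementation modeled on multi_network, its 784-100-10 layer sizes are assumptions, and its batch_norm_1d is a reconstruction rather than the tutorial's exact code. It keeps the update convention shown in the traceback (moving_momentum weights the old value) and performs the update under torch.no_grad() so autograd does not record it.

```python
import torch
from torch import nn


def batch_norm_1d(x, gamma, beta, is_train, moving_mean, moving_var,
                  moving_momentum=0.1, eps=1e-5):
    # Per-feature statistics over the batch dimension.
    x_mean = x.mean(dim=0)
    x_var = ((x - x_mean) ** 2).mean(dim=0)
    if is_train:
        x_hat = (x - x_mean) / torch.sqrt(x_var + eps)
        # In-place update of the running statistics, outside the autograd graph.
        with torch.no_grad():
            moving_mean.mul_(moving_momentum).add_((1. - moving_momentum) * x_mean)
            moving_var.mul_(moving_momentum).add_((1. - moving_momentum) * x_var)
    else:
        x_hat = (x - moving_mean) / torch.sqrt(moving_var + eps)
    return gamma * x_hat + beta


class MultiNetwork(nn.Module):
    """Hypothetical variant of multi_network with buffers for the statistics."""

    def __init__(self, in_features=784, hidden=100, classes=10):
        super().__init__()
        self.layer1 = nn.Linear(in_features, hidden)
        self.layer2 = nn.Linear(hidden, classes)
        self.gamma = nn.Parameter(torch.ones(hidden))
        self.beta = nn.Parameter(torch.zeros(hidden))
        # Buffers are moved by .cuda()/.to(); plain tensor attributes are not.
        self.register_buffer("moving_mean", torch.zeros(hidden))
        self.register_buffer("moving_var", torch.ones(hidden))

    def forward(self, x, is_train=True):
        x = self.layer1(x)
        x = batch_norm_1d(x, self.gamma, self.beta, is_train,
                          self.moving_mean, self.moving_var)
        x = torch.relu(x)
        return self.layer2(x)


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net = MultiNetwork().to(device)
out = net(torch.randn(32, 784, device=device))
print(out.shape)                                    # torch.Size([32, 10])
print(net.moving_mean.device == net.gamma.device)  # True
```

With the statistics registered as buffers, the unmodified call batch_norm_1d(x, self.gamma, ...) works on either device, and the buffers are also saved in state_dict().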

The script now runs successfully:
runfile('C:/Users/demons/Desktop/trainingtorch/batch_normalization.py', wdir='C:/Users/demons/Desktop/trainingtorch')
Reloaded modules: utils
Epoch 0. Train Loss: 0.302185, Train Acc: 0.912797, Valid Loss: 0.186934, Valid Acc: 0.947191, Time 00:00:05
Epoch 1. Train Loss: 0.169958, Train Acc: 0.951359, Valid Loss: 0.133628, Valid Acc: 0.962520, Time 00:00:05
Epoch 2. Train Loss: 0.129881, Train Acc: 0.962803, Valid Loss: 0.117917, Valid Acc: 0.965487, Time 00:00:05
Epoch 3. Train Loss: 0.106306, Train Acc: 0.969150, Valid Loss: 0.106132, Valid Acc: 0.968354, Time 00:00:05
Epoch 4. Train Loss: 0.090785, Train Acc: 0.973764, Valid Loss: 0.101401, Valid Acc: 0.971025, Time 00:00:05
Epoch 5. Train Loss: 0.081850, Train Acc: 0.975746, Valid Loss: 0.093533, Valid Acc: 0.971618, Time 00:00:05
Epoch 6. Train Loss: 0.072291, Train Acc: 0.978995, Valid Loss: 0.092226, Valid Acc: 0.972112, Time 00:00:05
Epoch 7. Train Loss: 0.065007, Train Acc: 0.980844, Valid Loss: 0.090979, Valid Acc: 0.972310, Time 00:00:05
Epoch 8. Train Loss: 0.059790, Train Acc: 0.981726, Valid Loss: 0.090877, Valid Acc: 0.973299, Time 00:00:05
Epoch 9. Train Loss: 0.054136, Train Acc: 0.984325, Valid Loss: 0.089308, Valid Acc: 0.974288, Time 00:00:05
Total time:
57.95331573486328
Variable containing:
-1.8784
 4.0507
 0.2430
 0.1976
-0.3430
-2.2162
 0.8868
-1.9118
-1.3165
 0.9459
[torch.FloatTensor of size 10]
