解决复制Pytorch官方tutorial代码而出现的RuntimeError的问题

解决复制Pytorch官方tutorial代码而出现的RuntimeError的问题

问题描述

学习Pytorch官方的tutorial时,在教程的第四部分(Training a Classifier)中会看到作者展示的代码:
解决复制Pytorch官方tutorial代码而出现的RuntimeError的问题_第1张图片
解决复制Pytorch官方tutorial代码而出现的RuntimeError的问题_第2张图片
解决复制Pytorch官方tutorial代码而出现的RuntimeError的问题_第3张图片
etc.
于是,我就把上面展示代码的如数复制到了pycharm上,想着直接运行。
直接copy得到这样的运行结果:
解决复制Pytorch官方tutorial代码而出现的RuntimeError的问题_第4张图片

出错了!!

解决办法

好端端的怎么会出错呢??仔细看看异常对象和python提供的描述:

RuntimeError: 
        An attempt has been made to start a new process before the
        current process has finished its bootstrapping phase.

        This probably means that you are not using fork to start your
        child processes and you have forgotten to use the proper idiom
        in the main module:

            if __name__ == '__main__':
                freeze_support()
                ...

        The "freeze_support()" line can be omitted if the program
        is not going to be frozen to produce an executable.

原来问题可以通过添加一段代码解决:

if __name__ == '__main__':

只要把它放在运行的文件的开头就行了
像这样:

if __name__=='__main__':
    import torch
    import torchvision
    import torchvision.transforms as transforms

    # function to show an image

    transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

    trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                            download=True, transform=transform)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
                                              shuffle=True, num_workers=2)

    testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                           download=True, transform=transform)
    testloader = torch.utils.data.DataLoader(testset, batch_size=4,
                                             shuffle=False, num_workers=2)

    classes = ('plane', 'car', 'bird', 'cat',
               'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
    import matplotlib.pyplot as plt
    import numpy as np


    # functions to show an image

    def imshow(img):
        img = img / 2 + 0.5  # unnormalize
        npimg = img.numpy()
        plt.imshow(np.transpose(npimg, (1, 2, 0)))
        plt.show()


    # get some random training images
    print(trainloader)
    dataiter = iter(trainloader)
    images, labels = dataiter.next()

    # show images
    imshow(torchvision.utils.make_grid(images))
    # print labels
    print(' '.join('%5s' % classes[labels[j]] for j in range(4)))
运行:

解决复制Pytorch官方tutorial代码而出现的RuntimeError的问题_第5张图片

成功召唤图片!!

你可能感兴趣的:(Deeplearning)