关于 `if __name__ == '__main__'` 判断的运用

遇到的问题

在上次的网络学习中,我觉得推理时还要重新定义一遍网络过于繁琐,便想到能否把 train 中的网络定义直接 import 到测试脚本中。但在运行测试脚本时发现,导入训练脚本会导致网络重新训练。原因是 import 一个模块时,Python 会执行该模块顶层的全部代码,其中的训练代码自然也会被执行。查阅资料后,我找到了问题所在和解决方法,也就是 `if __name__ == '__main__'`。

解决方法

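`if __name__ == '__main__':` 这个判断可以用来区分文件是作为脚本直接运行,还是作为模块被导入到其他文件中。直接运行时,Python 把该文件的 `__name__` 设为 `'__main__'`,条件为真,其下的代码会执行。被导入时,`__name__` 是模块名(例如 `train`),条件为假,其下的代码不会执行。因此,把数据加载和训练代码放在这个判断之下,测试脚本就可以只导入 `Model` 而不会触发训练。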

改进后的代码

训练脚本(train.py)

import torch
import torchvision
import torchvision.transforms as transforms
from torch import nn
from torch.utils.data import DataLoader

#构建网络模型
class Model(nn.Module):
    """Small VGG-style CNN mapping 3-channel images to 10 class logits.

    ``feature`` is four stages of Conv(3x3) -> BatchNorm -> ReLU -> MaxPool(2),
    growing the channels 3 -> 64 -> 128 -> 256 -> 512. With the paddings used,
    a 32x32 input shrinks 32 -> 17 -> 9 -> 4 -> 2, so the flattened feature
    vector has 512 * 2 * 2 = 2048 entries. That is the input size of the first
    Linear layer in ``classifier``. Other input resolutions would not match it.
    """

    def __init__(self):
        super().__init__()

        # (in_channels, out_channels, padding) for each conv stage, in order.
        stage_specs = ((3, 64, 2), (64, 128, 2), (128, 256, 1), (256, 512, 1))
        stage_layers = []
        for in_ch, out_ch, pad in stage_specs:
            stage_layers.extend([
                nn.Conv2d(in_ch, out_ch, 3, padding=pad),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(),
                nn.MaxPool2d(2, 2),
            ])
        self.feature = nn.Sequential(*stage_layers)

        # Two dropout-regularised hidden layers, then the 10-way output layer.
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(2048, 4096), nn.ReLU(), nn.Dropout(0.5),
            nn.Linear(4096, 4096), nn.ReLU(), nn.Dropout(0.5),
            nn.Linear(4096, 10),
        )

    def forward(self, x):
        """Return the class logits for the image batch ``x``."""
        return self.classifier(self.feature(x))



#训练
def train():
    """Run one training epoch over ``train_dataloader`` and print its metrics.

    Reads the module-level globals ``model``, ``train_dataloader``, ``device``,
    ``optimizer`` and ``criterion``. In this script they are created inside the
    ``if __name__ == '__main__':`` section.

    Prints the current batch loss every 200 batches. At the end it prints the
    epoch accuracy (percent) and the mean per-batch loss.
    """
    model.train()
    correct = 0
    total = 0
    loss_sum = 0.0
    num_batches = 0
    for batch, (data, target) in enumerate(train_dataloader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        correct += torch.sum(torch.argmax(output, dim=1) == target).item()
        total += len(target)
        loss_sum += loss.item()
        num_batches += 1

        if batch % 200 == 0:
            print('\tbatch: %d, loss:%.4f' % (batch, loss.item()))
    # The loss is printed with %.4f; the old '%4.f' rounded it to an integer.
    # max(..., 1) keeps an empty dataloader from raising ZeroDivisionError.
    print('train acc : %.2f%%, loss : %.4f'
          % (100 * correct / max(total, 1), loss_sum / max(num_batches, 1)))


#测试
def test():
    """Evaluate ``model`` on ``test_dataloader``, print metrics, save the best model.

    Reads the module-level globals ``model``, ``test_dataloader``, ``device``
    and ``criterion``. In this script they are created inside the
    ``if __name__ == '__main__':`` section.

    The best count of correct predictions seen so far lives in the module-level
    ``acc_max``. It is declared ``global`` so it persists across epochs.
    Whenever the count improves, the whole model object is written to
    ``model_weights.pth`` with ``torch.save``.
    """
    global acc_max
    model.eval()
    correct = 0
    total = 0
    loss_sum = 0.0
    num_batches = 0
    # Evaluation only: skip autograd graph construction to save memory and time.
    with torch.no_grad():
        for data, target in test_dataloader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            loss = criterion(output, target)
            correct += torch.sum(torch.argmax(output, dim=1) == target).item()
            total += len(target)
            loss_sum += loss.item()
            num_batches += 1
    # max(..., 1) keeps an empty dataloader from raising ZeroDivisionError.
    print('test acc: %.2f%%, loss: %.4f'
          % (100 * correct / max(total, 1), loss_sum / max(num_batches, 1)))
    # acc_max used to be a local reset to 0 on every call, so the model was
    # saved after every epoch instead of only when accuracy improved.
    if correct > globals().get('acc_max', 0.0):
        acc_max = correct
        # Pickles the entire module object (not just its state_dict).
        torch.save(model, 'model_weights.pth')

if __name__ == '__main__':
    # This section only runs when the file is executed directly (python train.py).
    # When another script does `from train import Model`, __name__ is 'train',
    # so no data is downloaded and no training starts.
    # The names assigned here (model, device, criterion, optimizer, the
    # dataloaders) become module-level globals that train() and test() read.

    # Data preprocessing: random crop/flip augmentation for training only.
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])

    # Download CIFAR-10 (a no-op if it is already present under ../data).
    train_data = torchvision.datasets.CIFAR10(root='../data', train=True, transform=transform_train,
                                              download=True)
    test_data = torchvision.datasets.CIFAR10(root='../data', train=False, transform=transform_test,
                                             download=True)

    print("训练集的长度:{}".format(len(train_data)))
    print("测试集的长度:{}".format(len(test_data)))

    train_dataloader = DataLoader(train_data, batch_size=128, shuffle=True)
    test_dataloader = DataLoader(test_data, batch_size=256, shuffle=False)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = Model().to(device)
    print('training on ', device)
    # Loss and optimizer. The loss goes to `device` rather than an
    # unconditional .cuda(), which crashed on CPU-only machines.
    criterion = nn.CrossEntropyLoss().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=5e-4)

    # Best test result (number of correct predictions) seen so far.
    acc_max = 0.0

    for epoch in range(30):
        print('epoch: %d' % epoch)
        train()
        test()

测试(推理)脚本

import torchvision
import torch
from PIL import Image
# This import must stay even though Model is never called directly. train.py
# saved the whole model with torch.save while running as __main__, so the
# pickle refers to __main__.Model. Importing Model into this script's
# namespace lets torch.load resolve that reference. The
# `if __name__ == '__main__'` guard in train.py keeps this import from
# starting a training run.
from train import Model

IMAGE_PATH = 'C:/Users/PC/Desktop/p1/cifar10_CNN/cat.jpeg'
MODEL_PATH = 'model_weights.pth'

# Use the same Normalize constants as the transforms in train.py. Without
# them the network sees inputs on a different scale than it was trained on.
transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize((32, 32)),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])


def _load_model(path, device):
    """Load the whole pickled model saved by train.py and map it onto ``device``."""
    # torch >= 2.6 defaults to weights_only=True, which rejects fully pickled
    # modules. This file is our own training output, so full unpickling is
    # acceptable here. Never load .pth files from untrusted sources this way.
    try:
        return torch.load(path, map_location=device, weights_only=False)
    except TypeError:
        # torch < 1.13 does not accept the weights_only argument.
        return torch.load(path, map_location=device)


def main():
    """Classify the image at IMAGE_PATH and print the predicted class index."""
    # Fall back to CPU so the script also runs without a GPU.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    image = Image.open(IMAGE_PATH)
    print(image)
    image = image.convert('RGB')
    # (3, 32, 32) -> (1, 3, 32, 32): add the batch dimension the model expects.
    image = transform(image).unsqueeze(0).to(device)
    print(image.shape)

    model = _load_model(MODEL_PATH, device)
    model.eval()
    with torch.no_grad():
        output = model(image)

    print(output.argmax(1))


if __name__ == '__main__':
    main()

参考资料

Python中 `if __name__ == "__main__"` 的深层含义

你可能感兴趣的:(学习,python)