在上次的网络学习中,我发现推理时需要重新定义一遍网络,过于繁琐,便想到能不能将 train 脚本中的网络定义 import 到 test 脚本中。但在运行测试脚本时发现,导入训练脚本会导致网络重新训练。于是上网查阅资料,发现了问题所在并找到了解决方法,也就是 if __name__ == '__main__'。
if __name__ == '__main__': 这个判断可以用于区分文件是作为脚本直接运行,还是作为模块被导入到其他文件中。如果该条件为真,则文件是作为脚本直接运行的;为假,则是作为模块被导入的,此时该判断下的代码不会执行。
训练函数
import torch
import torchvision
import torchvision.transforms as transforms
from torch import nn
from torch.utils.data import DataLoader
# Build the network model
class Model(nn.Module):
    """VGG-style CNN that maps 3-channel images to 10 class logits.

    ``feature`` stacks four Conv -> BatchNorm -> ReLU -> MaxPool stages with
    64, 128, 256 and 512 output channels. ``classifier`` flattens the result
    and applies three fully connected layers. The first linear layer expects
    2048 flattened features, which is what a 32x32 input yields
    (512 channels x 2 x 2).
    """

    def __init__(self):
        super().__init__()

        # (in_channels, out_channels, padding) of each 3x3 convolution stage.
        stage_specs = (
            (3, 64, 2),
            (64, 128, 2),
            (128, 256, 1),
            (256, 512, 1),
        )
        stages = []
        for in_channels, out_channels, padding in stage_specs:
            stages.extend([
                nn.Conv2d(in_channels, out_channels, 3, padding=padding),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(),
                nn.MaxPool2d(2, 2),
            ])
        # A flat Sequential keeps the submodule indices (feature.0 ...
        # feature.15) identical to a hand-written layer list.
        self.feature = nn.Sequential(*stages)

        head = [
            nn.Flatten(),
            nn.Linear(2048, 4096),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(4096, 4096),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(4096, 10),
        ]
        self.classifier = nn.Sequential(*head)

    def forward(self, x):
        """Return the class logits for the image batch ``x``."""
        features = self.feature(x)
        return self.classifier(features)
# Training
def train():
    """Train the module-level ``model`` for one pass over ``train_dataloader``.

    Relies on the module-level names ``model``, ``train_dataloader``,
    ``device``, ``optimizer`` and ``criterion``, which the ``__main__``
    section of this script assigns before calling this function.

    Prints the loss of every 200th batch, then the accuracy over all samples
    seen and the mean per-batch loss. Returns ``None``.
    """
    model.train()
    correct = 0
    seen = 0
    loss_sum = 0.0
    num_batches = 0
    for batch, (data, target) in enumerate(train_dataloader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        # Read the scalar loss once and reuse it for the sum and the log line.
        batch_loss = loss.item()
        correct += torch.sum(torch.argmax(output, dim=1) == target).item()
        seen += len(target)
        loss_sum += batch_loss
        num_batches += 1
        if batch % 200 == 0:
            print('\tbatch: %d, loss:%.4f' % (batch, batch_loss))

    if seen == 0:
        # Empty dataloader: nothing to report, and dividing would raise.
        return
    # '%.4f' prints four decimals; the previous '%4.f' rounded the loss to an
    # integer padded to width 4.
    print('train acc : %.2f%%, loss : %.4f'
          % (100 * correct / seen, loss_sum / num_batches))
# Evaluation
def test():
    """Evaluate the module-level ``model`` on ``test_dataloader``.

    Relies on the module-level names ``model``, ``test_dataloader``,
    ``device`` and ``criterion``, and on the module-level ``acc_max``, which
    holds the highest number of correct predictions seen so far (treated as
    0.0 when it has not been assigned yet).

    Prints the accuracy and the mean per-batch loss. When this evaluation
    beats ``acc_max`` it updates ``acc_max`` and saves the whole model object
    to ``model_weights.pth``. Returns ``None``.
    """
    # The best score must survive across calls; a local reset on every call
    # would make the "better than before" check below always succeed.
    global acc_max

    model.eval()
    correct = 0
    seen = 0
    loss_sum = 0.0
    num_batches = 0
    # No gradients are needed for evaluation, so skip building the autograd
    # graph.
    with torch.no_grad():
        for data, target in test_dataloader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            loss = criterion(output, target)
            correct += torch.sum(torch.argmax(output, dim=1) == target).item()
            seen += len(target)
            loss_sum += loss.item()
            num_batches += 1

    if seen == 0:
        # Empty dataloader: nothing to report, and dividing would raise.
        return
    # '%.2f' prints two decimals; the previous '%2.f' rounded the accuracy to
    # an integer.
    print('test acc: %.2f%%, loss: %.4f'
          % (100 * correct / seen, loss_sum / num_batches))

    if correct > globals().get('acc_max', 0.0):
        acc_max = correct
        # Saves the full pickled model (not just a state_dict), which is what
        # the inference script loads with torch.load().
        torch.save(model, 'model_weights.pth')
if __name__ == '__main__':
    # Everything below runs only when this file is executed as a script, so
    # importing Model from another file does not trigger training.
    # train() and test() read the names assigned here as module-level globals.

    # Data preprocessing: random crop and horizontal flip are applied to the
    # training set only; both sets share the same normalisation constants.
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])

    # Dataset download
    train_data = torchvision.datasets.CIFAR10(root='../data', train=True,
                                              transform=transform_train,
                                              download=True)
    test_data = torchvision.datasets.CIFAR10(root='../data', train=False,
                                             transform=transform_test,
                                             download=True)
    print("训练集的长度:{}".format(len(train_data)))
    print("测试集的长度:{}".format(len(test_data)))

    train_dataloader = DataLoader(train_data, batch_size=128, shuffle=True)
    test_dataloader = DataLoader(test_data, batch_size=256, shuffle=False)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = Model().to(device)
    print('training on ', device)

    # Optimizer and loss function. The loss is moved with .to(device) rather
    # than .cuda() so the script also runs on machines without CUDA.
    criterion = nn.CrossEntropyLoss().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=5e-4)

    # Best evaluation score so far, stored at module level.
    acc_max = 0.0
    for epoch in range(30):
        print('epoch: %d' % epoch)
        train()
        test()
测试函数
import torchvision
import torch
from PIL import Image
from train import Model
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

image = Image.open('C:/Users/PC/Desktop/p1/cifar10_CNN/cat.jpeg')
print(image)
image = image.convert('RGB')

# The input must be preprocessed like the evaluation data in the training
# script: ToTensor followed by the same Normalize constants. Without the
# normalisation the model sees inputs from a different distribution.
transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize((32, 32)),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.4914, 0.4822, 0.4465),
                                     (0.2023, 0.1994, 0.2010)),
])
image = transform(image)

# The checkpoint is a whole pickled model, not a state_dict, so no fresh
# Model() is needed; the imported Model class only has to be in scope so the
# unpickler can resolve it.
# NOTE(review): torch.load unpickles arbitrary objects; only load checkpoints
# from a trusted source. weights_only=False is required for a full-model
# checkpoint on PyTorch releases where weights_only defaults to True.
model = torch.load('model_weights.pth', map_location=device, weights_only=False)
if not isinstance(model, Model):
    raise TypeError('model_weights.pth does not contain a Model instance')
model = model.to(device)

# Add the batch dimension: (3, 32, 32) -> (1, 3, 32, 32).
image = image.unsqueeze(0)
print(image.shape)

model.eval()
with torch.no_grad():
    image = image.to(device)
    output = model(image)
print(output.argmax(1))
Python中 if __name__ == "__main__" 的深层含义