1. 加载 CIFAR-10 数据集，并对数据进行增强和类型转换（转换为 Tensor 并标准化）
数据集使用官网提供的 CIFAR-10。
链接：https://www.cs.toronto.edu/~kriz/cifar.html
读取数据时，可以调整 batch_size（每批样本数）和 num_workers（加载数据的子进程数）来加快训练速度。
# Data loading for CIFAR-10: augmentation, tensor conversion and normalization.
# NOTE(review): torchvision's CIFAR10 treats `root` as the directory that
# *contains* `cifar-10-batches-py`; this path already ends in that folder, so
# the data is probably extracted one level deeper than intended — confirm.
# NOTE(review): the Windows-style path suggests this runs on Windows, where
# DataLoader workers (num_workers > 0) are started with `spawn` and the
# script's top-level code should sit under `if __name__ == '__main__':` — confirm.
DATA_ROOT = r'E:\桌面\资料\cv3\数据集\cifar-10-batches-py'

# Training pipeline: random augmentation, then conversion to a tensor in
# [0, 1] and per-channel normalization to [-1, 1] via (x - 0.5) / 0.5.
# Resize(120) followed by a 96x96 crop produces the 96x96 input that Net's
# first fully connected layer (16 * 441 = 16 * 21 * 21 features) is sized for.
transform = transforms.Compose([
    transforms.Resize(120),
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(96),
    transforms.ColorJitter(brightness=0.5, contrast=0.5, hue=0.5),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Evaluation pipeline: deterministic. Applying random flips/crops/colour
# jitter to the test set would make the measured accuracy noisy and
# non-reproducible; a centre crop keeps the same 96x96 geometry as training.
test_transform = transforms.Compose([
    transforms.Resize(120),
    transforms.CenterCrop(96),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

trainset = tv.datasets.CIFAR10(
    root=DATA_ROOT,
    train=True,
    download=True,
    transform=transform,
)
trainloader = data.DataLoader(
    trainset,
    batch_size=8,
    shuffle=True,  # reshuffle the training set every epoch
    num_workers=4,
)

testset = tv.datasets.CIFAR10(
    root=DATA_ROOT,
    train=False,
    download=True,
    transform=test_transform,
)
testloader = data.DataLoader(
    testset,
    batch_size=2,
    shuffle=False,  # fixed order for evaluation
    num_workers=2,
)
2. 定义网络 Net：
class Net(nn.Module):
    """Small LeNet-style CNN mapping 3-channel images to 10 class scores.

    The first fully connected layer expects 16 * 21 * 21 flattened features,
    which is what a 96x96 input produces:
    96 -conv5-> 92 -pool2-> 46 -conv5-> 42 -pool2-> 21.
    """

    def __init__(self):
        super().__init__()
        # Feature extractor: two 5x5 convolutions (3 -> 6 -> 16 channels).
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
        # 2x2 max pooling with stride 2, applied after each convolution.
        self.max = nn.MaxPool2d(2, 2)
        # Classifier head: 16*21*21 -> 120 -> 84 -> 10.
        self.q1 = nn.Linear(16 * 441, 120)
        self.q2 = nn.Linear(120, 84)
        self.q3 = nn.Linear(84, 10)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return raw (unnormalized) class scores of shape (batch, 10)."""
        # conv -> ReLU -> pool, for each convolution stage.
        for conv in (self.conv1, self.conv2):
            x = self.max(self.relu(conv(x)))
        # Flatten everything except the batch dimension.
        x = x.view(x.size(0), -1)
        # Hidden fully connected layers with ReLU, then the linear output layer.
        for fc in (self.q1, self.q2):
            x = self.relu(fc(x))
        return self.q3(x)
3. 训练模型，并在测试集上计算准确率：
# Train the network for 5 epochs with SGD and cross-entropy loss, then save
# its weights to net.pth.
# Use the first GPU when one is available, otherwise fall back to the CPU
# (a hard-coded `.cuda()` fails on CPU-only machines).
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

net = Net()
# Move the model once, before building the optimizer, rather than calling
# `net.to(...)` again on every iteration.
net.to(device)

# Loss function: cross-entropy over the 10 class scores returned by Net.
loss = nn.CrossEntropyLoss()
opt = optim.SGD(net.parameters(), lr=0.001)

LOG_EVERY = 2000  # print the mean loss every this many mini-batches

for epoch in range(5):
    net.train()
    running_loss = 0.0
    # Unpack the batch directly instead of binding it to `data`, which would
    # shadow the `data` module used to build the DataLoaders.
    for i, (inputs, labels) in enumerate(trainloader):
        inputs = inputs.to(device)
        labels = labels.to(device)
        opt.zero_grad()
        h = net(inputs)
        cost = loss(h, labels)
        cost.backward()
        opt.step()
        running_loss += cost.item()
        if i % LOG_EVERY == LOG_EVERY - 1:
            print('[%d,%5d] loss:%.3f' % (epoch + 1, i + 1, running_loss / LOG_EVERY))
            running_loss = 0.0

torch.save(net.state_dict(), r'net.pth')
# Measure top-1 accuracy of `net` on the test set.
# Evaluate on whatever device the model's parameters currently live on.
eval_device = next(net.parameters()).device
net.eval()  # inference mode (matters only if dropout/batch-norm layers exist)
correct = 0
total = 0
with torch.no_grad():  # no autograd graph is needed for evaluation
    for images, labels in testloader:
        images = images.to(eval_device)
        labels = labels.to(eval_device)
        outputs = net(images)
        # The index of the largest score in each row is the predicted class.
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        # `.item()` keeps `correct` a plain Python int rather than a tensor.
        correct += (predicted == labels).sum().item()
print("准确率: %d %%" % (100 * correct / total))