"""
Copyright 2022 The TBAALi Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
"""
@Author: TBAALi
@Email: [email protected]
@FileName: CIFAR-10-Classfication.py
@DateTime: 5/19/22 9:48 PM
@SoftWare: PyCharm
@Copyright: Copyright (C) 2022 TBAALi
"""
import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
# Preprocessing applied to every image: convert it to a tensor, then
# normalize each of the three channels with mean 0.5 and std 0.5.
_to_tensor = transforms.ToTensor()
_normalize = transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
transform = transforms.Compose([_to_tensor, _normalize])
# Both splits read from the same local directory and share the same
# preprocessing. Downloading is disabled, so the dataset files are expected
# to already be present under ./data.
_cifar_kwargs = dict(root='./data', download=False, transform=transform)

# Mini-batches of 4 images, loaded by 2 worker processes.
_loader_kwargs = dict(batch_size=4, num_workers=2)

# Training split: reshuffled by the loader.
trainset = torchvision.datasets.CIFAR10(train=True, **_cifar_kwargs)
trainloader = torch.utils.data.DataLoader(
    trainset, shuffle=True, **_loader_kwargs
)

# Test split: kept in dataset order.
testset = torchvision.datasets.CIFAR10(train=False, **_cifar_kwargs)
testloader = torch.utils.data.DataLoader(
    testset, shuffle=False, **_loader_kwargs
)
# Human-readable names for the ten CIFAR-10 labels; the position in the
# tuple is the integer label produced by the dataset.
classes = (
    'plane', 'car', 'bird', 'cat',
    'deer', 'dog', 'frog', 'horse',
    'ship', 'truck',
)
def imshow(img):
    """
    Display a normalized image tensor with matplotlib.

    :param img: channel-first (C, H, W) image tensor whose values were
        normalized with mean 0.5 and std 0.5.
    :return: None; the figure is shown via ``plt.show()``.
    """
    # Undo the mean-0.5 / std-0.5 normalization: x * 0.5 + 0.5.
    img = img / 2 + 0.5
    # detach() and cpu() let the conversion succeed for tensors that track
    # gradients or live on a GPU; for plain CPU tensors they change nothing.
    npimg = img.detach().cpu().numpy()
    # matplotlib expects channel-last (H, W, C) data.
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()
# Pull a single batch from the training loader to sanity-check the input
# pipeline visually.
dataiter = iter(trainloader)
# Use the built-in next(): DataLoader iterators follow the standard iterator
# protocol, and the legacy ``.next()`` method is not available on current
# PyTorch releases.
images, labels = next(dataiter)
imshow(torchvision.utils.make_grid(images))
# Print one class name per image, driven by the labels actually returned
# rather than an assumed batch size.
print(' '.join('%5s' % classes[idx] for idx in labels.tolist()))
# Run on the first CUDA GPU when one is available, otherwise on the CPU.
_device_name = "cuda:0" if torch.cuda.is_available() else "cpu"
device = torch.device(_device_name)
class CNNNet(nn.Module):
    """
    Small convolutional network for image classification.

    Two convolution + max-pooling stages are followed by two fully connected
    layers. ``fc1`` takes 1296 = 36 * 6 * 6 features, which is what the
    convolutional stack produces for 3 x 32 x 32 inputs, and ``fc2`` emits
    10 raw class scores (logits).
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(
            in_channels=3,
            out_channels=16,
            kernel_size=5,
            stride=1,
        )
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(
            in_channels=16,
            out_channels=36,
            kernel_size=3,
            stride=1,
        )
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc1 = nn.Linear(1296, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        """
        Compute class logits for a batch of images.

        :param x: image batch of shape (N, 3, 32, 32).
        :return: tensor of shape (N, 10) holding unnormalized class scores.
        """
        x = self.pool1(F.relu(self.conv1(x)))
        x = self.pool2(F.relu(self.conv2(x)))
        # Flatten to the width fc1 was built for instead of repeating the
        # 36 * 6 * 6 literal here.
        x = x.view(-1, self.fc1.in_features)
        x = F.relu(self.fc1(x))
        # No activation on the last layer: the scores are fed to
        # CrossEntropyLoss, which expects raw logits. A ReLU here would clamp
        # every negative logit to zero and block its gradient.
        return self.fc2(x)
# Build the model and move its parameters to the selected device.
net = CNNNet().to(device)
print(net)

# Re-initialize the parameters layer by layer.
for m in net.modules():
    if isinstance(m, nn.Conv2d):
        # Apply a single initializer: every nn.init call overwrites the
        # weight in place, so stacking several keeps only the last one.
        nn.init.kaiming_normal_(m.weight)
        nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.Linear):
        # NOTE(review): normal_ defaults to a standard normal (std = 1),
        # which looks large for the 1296-input layer -- confirm this is
        # intended. Linear biases keep their default initialization.
        nn.init.normal_(m.weight)
import torch.optim as optim
# Cross-entropy loss on the network's class scores, optimized with SGD plus
# momentum.
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

for epoch in range(10):
    # Loss accumulated since the last progress report.
    running_loss = 0.0
    for i, (inputs, labels) in enumerate(trainloader):
        inputs, labels = inputs.to(device), labels.to(device)

        # Gradients accumulate across backward() calls, so clear them
        # before each step.
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        # Report the mean loss once every 2000 mini-batches.
        if i % 2000 == 1999:
            print('[%d, %5d] loss: %.3f' % (epoch + 1, i + 1, running_loss / 2000))
            running_loss = 0.0

print("Finished Training")
# Overall accuracy on the test split.
correct = 0
total = 0
# No gradients are needed for evaluation.
with torch.no_grad():
    for images, labels in testloader:
        images, labels = images.to(device), labels.to(device)
        outputs = net(images)
        # The index of the largest score is the predicted class.
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

# Report the number of images actually evaluated instead of a hard-coded
# (and incorrect) count.
print('Accuracy of the network on the %d test images: %d %%' % (
    total, 100 * correct / total
))
# Per-class accuracy on the test split: one hit counter and one sample
# counter per class, indexed by integer label.
class_correct = [0.0] * len(classes)
class_total = [0.0] * len(classes)
with torch.no_grad():
    for images, labels in testloader:
        images, labels = images.to(device), labels.to(device)
        outputs = net(images)
        _, predicted = torch.max(outputs, 1)
        hits = predicted == labels
        # Walk the samples actually present in this batch, so a short final
        # batch or a batch size other than 4 is handled correctly.
        for label, hit in zip(labels.tolist(), hits.tolist()):
            class_correct[label] += hit
            class_total[label] += 1

# Every class is assumed to occur at least once in the test split; a class
# with no samples would divide by zero here.
for name, n_correct, n_seen in zip(classes, class_correct, class_total):
    print('Accuracy of %5s : %2d %%' % (name, 100 * n_correct / n_seen))
def print_hi(name: str) -> None:
    """
    Print a short greeting to standard output.

    :param name: text inserted after ``Hi, ``.
    :return: None
    """
    print(f'Hi, {name}')


if __name__ == '__main__':
    # Greet only when the file is executed as a script.
    print_hi('Python')