Import the required libraries
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import time
import torch.nn.functional as F
Load the data
# CIFAR-10 images are 32x32; both pipelines resize to 224x224 for the ImageNet-style stem.
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor()
])
val_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor()
])
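Note that ToTensor only scales pixels to [0, 1]; the runs below train on unnormalized tensors. If you wanted zero-centered inputs, a Normalize step could be appended. A minimal sketch, assuming the commonly quoted CIFAR-10 channel statistics (not computed in this notebook):

# Hypothetical variant, not used in the runs below: per-channel normalization.
# The mean/std values are the commonly cited CIFAR-10 estimates (assumption).
normalize = transforms.Normalize(mean=(0.4914, 0.4822, 0.4465),
                                 std=(0.2470, 0.2435, 0.2616))
train_transform_norm = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    normalize,
])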
batch_size = 64
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=train_transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                          shuffle=True, num_workers=2)
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=val_transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                         shuffle=False, num_workers=2)
Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz
Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified
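Before building the model, it can be worth eyeballing a batch. A minimal sketch, assuming matplotlib is available; the class names follow torchvision's CIFAR-10 label order:

# Optional sanity check: display the first image of one training batch.
import matplotlib.pyplot as plt

classes = ('airplane', 'automobile', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')  # torchvision label order

images, labels = next(iter(trainloader))
plt.imshow(images[0].permute(1, 2, 0))  # CHW -> HWC for imshow
plt.title(classes[labels[0]])
plt.show()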
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device} device")
Using cuda device
class Residual(nn.Module):
    def __init__(self, in_channels, out_channels, use_1x1conv=False, stride=1):
        super(Residual, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, stride=stride)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        if use_1x1conv:
            # 1x1 convolution so the shortcut matches the main path's shape
            self.conv3 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride)
        else:
            self.conv3 = None
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)

    def forward(self, X):
        Y = F.relu(self.bn1(self.conv1(X)))
        Y = self.bn2(self.conv2(Y))
        if self.conv3 is not None:
            X = self.conv3(X)
        return F.relu(Y + X)
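A quick shape check illustrates the two paths (a sketch, not part of the original run): without the 1x1 convolution the block preserves the input shape; with it, the block can change the channel count while halving the spatial size.

# Identity shortcut: output shape equals input shape.
blk = Residual(3, 3)
X = torch.rand(4, 3, 6, 6)
print(blk(X).shape)  # torch.Size([4, 3, 6, 6])

# Projected shortcut: channels 3 -> 6, spatial size halved by stride=2.
blk = Residual(3, 6, use_1x1conv=True, stride=2)
print(blk(X).shape)  # torch.Size([4, 6, 3, 3])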
def resnet_block(in_channels, out_channels, num_residuals, first_block=False):
    if first_block:
        assert in_channels == out_channels  # the first stage keeps its input channel count
    blk = []
    for i in range(num_residuals):
        if i == 0 and not first_block:
            # first residual of a new stage: halve H/W and change the channel count
            blk.append(Residual(in_channels, out_channels, use_1x1conv=True, stride=2))
        else:
            blk.append(Residual(out_channels, out_channels))
    return nn.Sequential(*blk)
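As a sketch of what a stage does (not in the original run), a non-first stage halves the feature map and doubles the channels:

# The first Residual of this stage uses stride=2 and a 1x1 projection.
stage = resnet_block(64, 128, 2)
print(stage(torch.rand(1, 64, 56, 56)).shape)  # torch.Size([1, 128, 28, 28])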
net = nn.Sequential(
    nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3),
    nn.BatchNorm2d(64),
    nn.ReLU(),
    nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
net.add_module("resnet_block1", resnet_block(64, 64, 2, first_block=True))
net.add_module("resnet_block2", resnet_block(64, 128, 2))
net.add_module("resnet_block3", resnet_block(128, 256, 2))
net.add_module("resnet_block4", resnet_block(256, 512, 2))
fc = nn.Sequential(
    nn.AvgPool2d(kernel_size=7),  # the 7x7 feature map is pooled down to 1x1
    nn.Flatten(),
    nn.Dropout(0.5),
    nn.Linear(512, 10),
)
net.add_module("fc", fc)
model = net.to(device)
model
Sequential(
  (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3))
  (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): ReLU()
  (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (resnet_block1): Sequential(
    (0): Residual(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): Residual(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (resnet_block2): Sequential(
    (0): Residual(
      (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (conv3): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2))
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): Residual(
      (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (resnet_block3): Sequential(
    (0): Residual(
      (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (conv3): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2))
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): Residual(
      (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (resnet_block4): Sequential(
    (0): Residual(
      (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (conv3): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2))
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): Residual(
      (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (fc): Sequential(
    (0): AvgPool2d(kernel_size=7, stride=7, padding=0)
    (1): Flatten(start_dim=1, end_dim=-1)
    (2): Dropout(p=0.5, inplace=False)
    (3): Linear(in_features=512, out_features=10, bias=True)
  )
)
from torchsummary import summary
summary(model,(3,224,224))
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1         [-1, 64, 112, 112]           9,472
       BatchNorm2d-2         [-1, 64, 112, 112]             128
              ReLU-3         [-1, 64, 112, 112]               0
         MaxPool2d-4           [-1, 64, 56, 56]               0
            Conv2d-5           [-1, 64, 56, 56]          36,928
       BatchNorm2d-6           [-1, 64, 56, 56]             128
            Conv2d-7           [-1, 64, 56, 56]          36,928
       BatchNorm2d-8           [-1, 64, 56, 56]             128
          Residual-9           [-1, 64, 56, 56]               0
           Conv2d-10           [-1, 64, 56, 56]          36,928
      BatchNorm2d-11           [-1, 64, 56, 56]             128
           Conv2d-12           [-1, 64, 56, 56]          36,928
      BatchNorm2d-13           [-1, 64, 56, 56]             128
         Residual-14           [-1, 64, 56, 56]               0
           Conv2d-15          [-1, 128, 28, 28]          73,856
      BatchNorm2d-16          [-1, 128, 28, 28]             256
           Conv2d-17          [-1, 128, 28, 28]         147,584
      BatchNorm2d-18          [-1, 128, 28, 28]             256
           Conv2d-19          [-1, 128, 28, 28]           8,320
         Residual-20          [-1, 128, 28, 28]               0
           Conv2d-21          [-1, 128, 28, 28]         147,584
      BatchNorm2d-22          [-1, 128, 28, 28]             256
           Conv2d-23          [-1, 128, 28, 28]         147,584
      BatchNorm2d-24          [-1, 128, 28, 28]             256
         Residual-25          [-1, 128, 28, 28]               0
           Conv2d-26          [-1, 256, 14, 14]         295,168
      BatchNorm2d-27          [-1, 256, 14, 14]             512
           Conv2d-28          [-1, 256, 14, 14]         590,080
      BatchNorm2d-29          [-1, 256, 14, 14]             512
           Conv2d-30          [-1, 256, 14, 14]          33,024
         Residual-31          [-1, 256, 14, 14]               0
           Conv2d-32          [-1, 256, 14, 14]         590,080
      BatchNorm2d-33          [-1, 256, 14, 14]             512
           Conv2d-34          [-1, 256, 14, 14]         590,080
      BatchNorm2d-35          [-1, 256, 14, 14]             512
         Residual-36          [-1, 256, 14, 14]               0
           Conv2d-37            [-1, 512, 7, 7]       1,180,160
      BatchNorm2d-38            [-1, 512, 7, 7]           1,024
           Conv2d-39            [-1, 512, 7, 7]       2,359,808
      BatchNorm2d-40            [-1, 512, 7, 7]           1,024
           Conv2d-41            [-1, 512, 7, 7]         131,584
         Residual-42            [-1, 512, 7, 7]               0
           Conv2d-43            [-1, 512, 7, 7]       2,359,808
      BatchNorm2d-44            [-1, 512, 7, 7]           1,024
           Conv2d-45            [-1, 512, 7, 7]       2,359,808
      BatchNorm2d-46            [-1, 512, 7, 7]           1,024
         Residual-47            [-1, 512, 7, 7]               0
        AvgPool2d-48            [-1, 512, 1, 1]               0
          Flatten-49                  [-1, 512]               0
          Dropout-50                  [-1, 512]               0
           Linear-51                   [-1, 10]           5,130
================================================================
Total params: 11,184,650
Trainable params: 11,184,650
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 49.97
Params size (MB): 42.67
Estimated Total Size (MB): 93.21
----------------------------------------------------------------
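The total can also be cross-checked without torchsummary (a one-line sketch):

# Should print 11184650, matching the summary above.
print(sum(p.numel() for p in model.parameters() if p.requires_grad))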
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
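The runs below keep the learning rate fixed within each stage; if you preferred automatic decay, a scheduler could be attached instead. A hypothetical variant, not used in the runs shown:

# Hypothetical: decay the LR by 10x every 10 epochs; call scheduler.step()
# once per epoch after val().
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)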
def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.train()
    train_loss, correct = 0, 0
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)
        # Compute prediction error
        pred = model(X)
        loss = loss_fn(pred, y)
        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_loss += loss.item()  # reuse the computed loss; no need to call loss_fn again
        correct += (pred.argmax(1) == y).type(torch.float).sum().item()
        if batch % 100 == 0:
            loss, current = loss.item(), batch * len(X)
            print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]")
    train_loss /= num_batches
    correct /= size
    print(f"train Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {train_loss:>8f} \n")
def val(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    val_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            val_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    val_loss /= num_batches
    correct /= size
    print(f"val Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {val_loss:>8f} \n")
epochs = 30
since = time.time()
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train(trainloader, model, loss_fn, optimizer)
    val(testloader, model, loss_fn)
time_elapsed = time.time() - since
print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
print("Done!")
Epoch 1
-------------------------------
loss: 0.924146 [    0/50000]
loss: 0.881102 [ 6400/50000]
loss: 1.156437 [12800/50000]
loss: 0.981805 [19200/50000]
loss: 0.971733 [25600/50000]
loss: 0.857609 [32000/50000]
loss: 0.725273 [38400/50000]
loss: 0.870839 [44800/50000]
train Error:
 Accuracy: 66.4%, Avg loss: 0.967532
val Error:
 Accuracy: 79.2%, Avg loss: 0.589209
Epoch 2
-------------------------------
loss: 0.949273 [    0/50000]
loss: 0.868454 [ 6400/50000]
loss: 0.839819 [12800/50000]
loss: 0.780564 [19200/50000]
loss: 1.012495 [25600/50000]
loss: 1.088979 [32000/50000]
loss: 1.041193 [38400/50000]
loss: 0.716399 [44800/50000]
train Error:
 Accuracy: 68.1%, Avg loss: 0.912827
val Error:
 Accuracy: 78.0%, Avg loss: 0.635003
...
Epoch 24
-------------------------------
loss: 0.361401 [    0/50000]
loss: 0.525371 [ 6400/50000]
loss: 0.538393 [12800/50000]
loss: 0.341801 [19200/50000]
loss: 0.595791 [25600/50000]
loss: 0.480633 [32000/50000]
loss: 0.619632 [38400/50000]
loss: 0.317345 [44800/50000]
train Error:
 Accuracy: 82.2%, Avg loss: 0.513217
val Error:
 Accuracy: 89.9%, Avg loss: 0.296979
Epoch 25
-------------------------------
loss: 0.423538 [    0/50000]
loss: 0.531991 [ 6400/50000]
loss: 0.515880 [12800/50000]
loss: 0.398145 [19200/50000]
loss: 0.481868 [25600/50000]
loss: 0.375752 [32000/50000]
loss: 0.483883 [38400/50000]
loss: 0.502430 [44800/50000]
train Error:
 Accuracy: 82.8%, Avg loss: 0.500697
val Error:
 Accuracy: 89.4%, Avg loss: 0.313115
Epoch 26
-------------------------------
loss: 0.438543 [    0/50000]
loss: 0.463468 [ 6400/50000]
loss: 0.420758 [12800/50000]
loss: 0.690606 [19200/50000]
loss: 0.537573 [25600/50000]
loss: 0.369099 [32000/50000]
loss: 0.609451 [38400/50000]
loss: 0.674081 [44800/50000]
train Error:
 Accuracy: 83.1%, Avg loss: 0.492140
val Error:
 Accuracy: 90.3%, Avg loss: 0.293007
Epoch 27
-------------------------------
loss: 0.530172 [    0/50000]
loss: 0.578006 [ 6400/50000]
loss: 0.535782 [12800/50000]
loss: 0.520152 [19200/50000]
loss: 0.433240 [25600/50000]
loss: 0.587386 [32000/50000]
loss: 0.541297 [38400/50000]
loss: 0.610835 [44800/50000]
train Error:
 Accuracy: 83.2%, Avg loss: 0.488220
val Error:
 Accuracy: 90.4%, Avg loss: 0.293391
Epoch 28
-------------------------------
loss: 0.409375 [    0/50000]
loss: 0.685596 [ 6400/50000]
loss: 0.317676 [12800/50000]
loss: 0.753025 [19200/50000]
loss: 0.443549 [25600/50000]
loss: 0.538023 [32000/50000]
loss: 0.356508 [38400/50000]
loss: 0.377472 [44800/50000]
train Error:
 Accuracy: 83.4%, Avg loss: 0.480486
val Error:
 Accuracy: 90.5%, Avg loss: 0.299219
Epoch 29
-------------------------------
loss: 0.435585 [    0/50000]
loss: 0.436688 [ 6400/50000]
loss: 0.638597 [12800/50000]
loss: 0.381801 [19200/50000]
loss: 0.310054 [25600/50000]
loss: 0.513684 [32000/50000]
loss: 0.280855 [38400/50000]
loss: 0.444685 [44800/50000]
train Error:
 Accuracy: 83.6%, Avg loss: 0.472668
val Error:
 Accuracy: 90.5%, Avg loss: 0.284233
Epoch 30
-------------------------------
loss: 0.322984 [    0/50000]
loss: 0.380073 [ 6400/50000]
loss: 0.310956 [12800/50000]
loss: 0.712281 [19200/50000]
loss: 0.627207 [25600/50000]
loss: 0.603485 [32000/50000]
loss: 0.499440 [38400/50000]
loss: 0.525810 [44800/50000]
train Error:
 Accuracy: 83.6%, Avg loss: 0.471256
val Error:
 Accuracy: 90.7%, Avg loss: 0.281865
Training complete in 66m 15s
Done!
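Before continuing with a smaller learning rate, it would be natural to checkpoint the weights (a sketch; the filename is arbitrary):

# Save the current weights; 'resnet18_cifar10.pth' is a placeholder name.
torch.save(model.state_dict(), 'resnet18_cifar10.pth')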
For fine-tuning, continue training the same model with the learning rate lowered 100x, from 1e-3 to 1e-5:
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
epochs = 10
since = time.time()
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train(trainloader, model, loss_fn, optimizer)
    val(testloader, model, loss_fn)
time_elapsed = time.time() - since
print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
print("Done!")
Epoch 1
-------------------------------
loss: 0.353835 [    0/50000]
loss: 0.268893 [ 6400/50000]
loss: 0.551030 [12800/50000]
loss: 0.395407 [19200/50000]
loss: 0.488981 [25600/50000]
loss: 0.508927 [32000/50000]
loss: 0.498940 [38400/50000]
loss: 0.284048 [44800/50000]
train Error:
 Accuracy: 86.6%, Avg loss: 0.385834
val Error:
 Accuracy: 92.2%, Avg loss: 0.230816
Epoch 2
-------------------------------
loss: 0.414547 [    0/50000]
loss: 0.389818 [ 6400/50000]
loss: 0.428614 [12800/50000]
loss: 0.506149 [19200/50000]
loss: 0.320031 [25600/50000]
loss: 0.351844 [32000/50000]
loss: 0.455790 [38400/50000]
loss: 0.281952 [44800/50000]
train Error:
 Accuracy: 86.9%, Avg loss: 0.382021
val Error:
 Accuracy: 92.1%, Avg loss: 0.230783
Epoch 3
-------------------------------
loss: 0.465969 [    0/50000]
loss: 0.345926 [ 6400/50000]
loss: 0.708646 [12800/50000]
loss: 0.490343 [19200/50000]
loss: 0.337607 [25600/50000]
loss: 0.368772 [32000/50000]
loss: 0.294270 [38400/50000]
loss: 0.433146 [44800/50000]
train Error:
 Accuracy: 86.7%, Avg loss: 0.391087
val Error:
 Accuracy: 92.1%, Avg loss: 0.229934
Epoch 4
-------------------------------
loss: 0.165779 [    0/50000]
loss: 0.510993 [ 6400/50000]
loss: 0.518649 [12800/50000]
loss: 0.348599 [19200/50000]
loss: 0.333966 [25600/50000]
loss: 0.274560 [32000/50000]
loss: 0.206739 [38400/50000]
loss: 0.521378 [44800/50000]
train Error:
 Accuracy: 86.9%, Avg loss: 0.383319
val Error:
 Accuracy: 92.1%, Avg loss: 0.229616
Epoch 5
-------------------------------
loss: 0.467936 [    0/50000]
loss: 0.307089 [ 6400/50000]
loss: 0.735660 [12800/50000]
loss: 0.397244 [19200/50000]
loss: 0.476827 [25600/50000]
loss: 0.289780 [32000/50000]
loss: 0.344361 [38400/50000]
loss: 0.416531 [44800/50000]
train Error:
 Accuracy: 86.9%, Avg loss: 0.378607
val Error:
 Accuracy: 92.1%, Avg loss: 0.230443
Epoch 6
-------------------------------
loss: 0.422807 [    0/50000]
loss: 0.309076 [ 6400/50000]
loss: 0.288851 [12800/50000]
loss: 0.418327 [19200/50000]
loss: 0.692769 [25600/50000]
loss: 0.304201 [32000/50000]
loss: 0.315583 [38400/50000]
loss: 0.375601 [44800/50000]
train Error:
 Accuracy: 86.7%, Avg loss: 0.378628
val Error:
 Accuracy: 92.2%, Avg loss: 0.228222
Epoch 7
-------------------------------
loss: 0.261598 [    0/50000]
loss: 0.334787 [ 6400/50000]
loss: 0.151369 [12800/50000]
loss: 0.292047 [19200/50000]
loss: 0.392653 [25600/50000]
loss: 0.235816 [32000/50000]
loss: 0.221566 [38400/50000]
loss: 0.404798 [44800/50000]
train Error:
 Accuracy: 86.7%, Avg loss: 0.384073
val Error:
 Accuracy: 92.1%, Avg loss: 0.229147
Epoch 8
-------------------------------
loss: 0.233399 [    0/50000]
loss: 0.257654 [ 6400/50000]
loss: 0.381010 [12800/50000]
loss: 0.351332 [19200/50000]
loss: 0.597045 [25600/50000]
loss: 0.304622 [32000/50000]
loss: 0.529189 [38400/50000]
loss: 0.404364 [44800/50000]
train Error:
 Accuracy: 86.9%, Avg loss: 0.381050
val Error:
 Accuracy: 92.3%, Avg loss: 0.228166
Epoch 9
-------------------------------
loss: 0.355533 [    0/50000]
loss: 0.362407 [ 6400/50000]
loss: 0.363182 [12800/50000]
loss: 0.199613 [19200/50000]
loss: 0.180766 [25600/50000]
loss: 0.486774 [32000/50000]
loss: 0.386876 [38400/50000]
loss: 0.435546 [44800/50000]
train Error:
 Accuracy: 87.0%, Avg loss: 0.377219
val Error:
 Accuracy: 92.4%, Avg loss: 0.228705
Epoch 10
-------------------------------
loss: 0.429351 [    0/50000]
loss: 0.181094 [ 6400/50000]
loss: 0.240211 [12800/50000]
loss: 0.365143 [19200/50000]
loss: 0.447223 [25600/50000]
loss: 0.314632 [32000/50000]
loss: 0.221123 [38400/50000]
loss: 0.396281 [44800/50000]
train Error:
 Accuracy: 87.2%, Avg loss: 0.369424
val Error:
 Accuracy: 92.2%, Avg loss: 0.227289
Training complete in 22m 3s
Done!
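With validation accuracy settling around 92%, the model can be used for batch predictions directly (a minimal inference sketch, reusing testloader):

# Predict on one validation batch and compare against the labels.
model.eval()
images, labels = next(iter(testloader))
with torch.no_grad():
    preds = model(images.to(device)).argmax(1)
print('predicted:', preds[:8].tolist())
print('actual:   ', labels[:8].tolist())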