1. Code
import time
import torch
from torch import nn, optim
import torch.nn.functional as F
import torchvision

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
def conv_block(in_channels, out_channels):
    # BN -> ReLU -> 3x3 conv: the "pre-activation" ordering DenseNet uses.
    blk = nn.Sequential(nn.BatchNorm2d(in_channels),
                        nn.ReLU(),
                        nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1))
    return blk
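# Quick check (commented out, mirroring the checks below): padding=1 keeps
# H and W, so conv_block changes only the channel count.
# _blk = conv_block(3, 10)
# print(_blk(torch.rand(4, 3, 8, 8)).shape)  # expected: torch.Size([4, 10, 8, 8])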
class DenseBlock(nn.Module):
    def __init__(self, num_convs, in_channels, out_channels):
        super(DenseBlock, self).__init__()
        net = []
        for i in range(num_convs):
            # Each conv block sees the block input plus all earlier outputs.
            in_c = in_channels + i * out_channels
            net.append(conv_block(in_c, out_channels))
        self.net = nn.ModuleList(net)
        self.out_channels = in_channels + num_convs * out_channels

    def forward(self, X):
        for blk in self.net:
            Y = blk(X)
            X = torch.cat((X, Y), dim=1)  # concatenate along the channel dim
        return X
""" blk = DenseBlock(2,3,10)
X = torch.rand(4,3,8,8)
Y = blk(X) """
def transition_block(in_channels, out_channels):
    # A 1x1 conv shrinks the channel count; 2x2 average pooling halves H and W.
    blk = nn.Sequential(nn.BatchNorm2d(in_channels),
                        nn.ReLU(),
                        nn.Conv2d(in_channels, out_channels, kernel_size=1),
                        nn.AvgPool2d(kernel_size=2, stride=2))
    return blk
""" blk = transition_block(23,10)
print(blk(Y).shape) """
# Stem: the same 7x7 conv + 3x3 max-pool opening as ResNet.
net = nn.Sequential(nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3),
                    nn.BatchNorm2d(64),
                    nn.ReLU(),
                    nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
num_channels, growth_rate = 64, 32  # current channels and per-conv growth rate
num_convs_in_dense_blocks = [4, 4, 4, 4]
for i, num_convs in enumerate(num_convs_in_dense_blocks):
    DB = DenseBlock(num_convs, num_channels, growth_rate)
    net.add_module("DenseBlock_%d" % i, DB)
    num_channels = DB.out_channels
    # Between dense blocks, a transition block halves both the channel count and
    # the spatial size: channels go 64 -> 192 -> 96 -> 224 -> 112 -> 240 -> 120 -> 248.
    if i != len(num_convs_in_dense_blocks) - 1:
        net.add_module("transition_block_%d" % i,
                       transition_block(num_channels, num_channels // 2))
        num_channels = num_channels // 2
class GlobalAvgPool2d(nn.Module):
    # Global average pooling; nn.AdaptiveAvgPool2d(1) is the built-in equivalent.
    def __init__(self):
        super(GlobalAvgPool2d, self).__init__()

    def forward(self, x):
        return F.avg_pool2d(x, kernel_size=x.size()[2:])
class FlattenLayer(nn.Module):
    # Flattens (N, C, H, W) to (N, C*H*W); nn.Flatten plays the same role.
    def __init__(self):
        super(FlattenLayer, self).__init__()

    def forward(self, x):
        return x.view(x.shape[0], -1)
net.add_module("BN",nn.BatchNorm2d(num_channels))
net.add_module("relu",nn.ReLU())
net.add_module("global_avg_pool",GlobalAvgPool2d())
net.add_module("fc",nn.Sequential(FlattenLayer(),
nn.Linear(num_channels,10)))
# Walk a dummy batch through the network and print the shape after each child.
X = torch.rand((1, 1, 96, 96))
for name, layer in net.named_children():
    X = layer(X)
    print(name, 'output shape:\t', X.shape)
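# Expected progression for a 96x96 input: the stem gives (1, 64, 24, 24);
# each dense block adds 4*32 = 128 channels, each transition block halves
# channels and spatial size, ending at (1, 248, 3, 3) before global average
# pooling, flattening, and the final (1, 10) output.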
def evaluate_accuracy(data_iter, net, device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')):
    acc_sum, n = 0.0, 0
    with torch.no_grad():
        for X, y in data_iter:
            if isinstance(net, torch.nn.Module):
                net.eval()  # evaluation mode: use running BN statistics
                acc_sum += (net(X.to(device)).argmax(dim=1) == y.to(device)).float().sum().cpu().item()
                net.train()  # switch back to training mode
            else:
                # Fallback for a plain function: pass is_training=False if the
                # function accepts that argument.
                if 'is_training' in net.__code__.co_varnames:
                    acc_sum += (net(X, is_training=False).argmax(dim=1) == y).float().sum().item()
                else:
                    acc_sum += (net(X).argmax(dim=1) == y).float().sum().item()
            n += y.shape[0]
    return acc_sum / n
def load_data_fashion_mnist(batch_size, resize=None, root='~/Datasets/FashionMNIST'):
    trans = []
    if resize:
        trans.append(torchvision.transforms.Resize(size=resize))
    trans.append(torchvision.transforms.ToTensor())
    transform = torchvision.transforms.Compose(trans)
    mnist_train = torchvision.datasets.FashionMNIST(root=root, train=True, download=True, transform=transform)
    mnist_test = torchvision.datasets.FashionMNIST(root=root, train=False, download=True, transform=transform)
    # On Windows, num_workers > 0 may require an if __name__ == '__main__' guard.
    train_iter = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size, shuffle=True, num_workers=4)
    test_iter = torch.utils.data.DataLoader(mnist_test, batch_size=batch_size, shuffle=False, num_workers=4)
    return train_iter, test_iter
def train_ch5(net, train_iter, test_iter, batch_size, optimizer, device, num_epochs):
    net = net.to(device)
    print("training on ", device)
    loss = torch.nn.CrossEntropyLoss()
    for epoch in range(num_epochs):
        # Reset the running sums each epoch so the printed loss is a per-epoch average.
        train_l_sum, train_acc_sum, n, batch_count, start = 0.0, 0.0, 0, 0, time.time()
        for X, y in train_iter:
            X = X.to(device)
            y = y.to(device)
            y_hat = net(X)
            l = loss(y_hat, y)
            optimizer.zero_grad()
            l.backward()
            optimizer.step()
            train_l_sum += l.cpu().item()
            train_acc_sum += (y_hat.argmax(dim=1) == y).sum().cpu().item()
            n += y.shape[0]
            batch_count += 1
        test_acc = evaluate_accuracy(test_iter, net)
        print('epoch %d, loss %.4f, train acc %.3f, test acc %.3f, time %.1f sec'
              % (epoch + 1, train_l_sum / batch_count, train_acc_sum / n, test_acc, time.time() - start))
batch_size = 256
# Resize Fashion-MNIST from 28x28 to 96x96 so the network has room to downsample.
train_iter, test_iter = load_data_fashion_mnist(batch_size, resize=96)
lr, num_epochs = 0.001, 5
optimizer = optim.Adam(net.parameters(), lr=lr)
train_ch5(net, train_iter, test_iter, batch_size, optimizer, device, num_epochs)
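# A minimal inference sketch (commented out; assumes training above has run and
# uses the standard torchvision FashionMNIST label order): predict on one test
# batch and compare a few predictions with the true labels.
# labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',
#           'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
# X, y = next(iter(test_iter))
# net.eval()
# with torch.no_grad():
#     preds = net(X.to(device)).argmax(dim=1).cpu()
# for true_label, pred in list(zip(y, preds))[:5]:
#     print('true:', labels[int(true_label)], '\tpred:', labels[int(pred)])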