实验步骤:
import torch
import torchvision
import torchvision.models as models
import torchvision.transforms as transforms
# Load an ImageNet-pretrained ResNet-18 to use as the feature extractor.
# NOTE(review): `pretrained=` is deprecated in newer torchvision in favour of
# `weights=`; kept for compatibility with the file's torchvision version.
resnet18 = models.resnet18(pretrained = True)

# CIFAR10 yields PIL.Image objects (RGB, 32x32); ToTensor converts them to
# float CHW tensors in [0, 1]. For extra reshaping/normalisation, add a
# transforms.Lambda(lambda x: ...) to the Compose pipeline.
train_dataset = torchvision.datasets.CIFAR10(
    './data', train=True, download=True,
    transform=transforms.Compose([transforms.ToTensor()]))
# BUG FIX: the evaluation set must use train=False — the original passed
# train=True here, so test() measured accuracy on the training split.
# (download=False is fine: the CIFAR10 archive fetched above contains both splits.)
test_dataset = torchvision.datasets.CIFAR10(
    './data', train=False, download=False,
    transform=transforms.Compose([transforms.ToTensor()]))
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=10, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=10, shuffle=True)

from tqdm import tqdm

# Hyper-parameters and bookkeeping for the transfer-learning experiment.
epoch = 1
learning_rate = 0.001
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
category_list = ['airplane','automobile','bird','cat','deer','dog','frog','horse','ship','truck']
resnet18 = resnet18.to(device)
# Map the 1000-way ImageNet logits down to CIFAR10's 10 classes.
transfer_layer = torch.nn.Linear(1000, 10).to(device)
# To optimise several parameter groups jointly, pass a list of dicts to the
# optimizer; the only recognised key for the parameters themselves is 'params'.
optimizer = torch.optim.SGD(
    [{'params': transfer_layer.parameters()},
     {'params': resnet18.conv1.parameters()}],
    lr=learning_rate)
def train():
    """Fine-tune transfer_layer and resnet18.conv1 for `epoch` passes over train_loader.

    Uses the module-level resnet18, transfer_layer, optimizer, device and epoch.
    Prints the running loss every 500 batches.
    """
    for i in range(epoch):
        for j, (data, target) in tqdm(enumerate(train_loader)):
            # BUG FIX: clear ALL optimised gradients BEFORE backward. The
            # original zeroed transfer_layer's grads after backward() but
            # before optimizer.step(), so the step saw zero gradients for
            # transfer_layer — and resnet18.conv1's grads were never cleared,
            # accumulating across batches.
            optimizer.zero_grad()
            logit = transfer_layer(resnet18(data.to(device)))
            loss = torch.nn.functional.cross_entropy(logit, target.to(device))
            loss.backward()
            optimizer.step()
            if j % 500 == 0:
                # j*10 = number of images processed so far (batch_size is 10).
                print ('第{}次迭代,loss值为{}'.format(j*10,loss))
def test():
    """Evaluate the transfer model on test_loader and print the accuracy.

    Uses the module-level resnet18, transfer_layer, test_loader and device.
    """
    correct_num = torch.tensor(0).to(device)
    total = 0  # count samples actually seen instead of assuming 10000
    # Inference only — no_grad avoids building autograd graphs per batch.
    with torch.no_grad():
        for j, (data, target) in tqdm(enumerate(test_loader)):
            data = data.to(device)
            target = target.to(device)
            logit = transfer_layer(resnet18(data))
            pred = logit.max(1)[1]  # argmax over the class dimension
            correct_num = correct_num + torch.sum(pred == target)
            total += target.size(0)
    print (correct_num)
    # BUG FIX: divide by the number of evaluated samples rather than a
    # hard-coded 10000 (the loader's dataset size may differ — the original
    # file actually iterated the 50000-image training split here).
    print ('\n correct rate is {}'.format(correct_num / max(total, 1)))
# Run the fine-tuning pass, then measure accuracy on the held-out loader.
train()
test()
Output:
import matplotlib.pyplot as plt
import numpy as np

# Capture intermediate activations via a forward hook.
activation = {}
def get_activation(name):
    """Return a forward hook that stores the module's detached output under `name`."""
    def hook(model, input, output):
        activation[name] = output.detach()
    return hook

resnet18.conv1.register_forward_hook(get_activation('conv1'))

# For a single image (no batch dim), add one first: data.unsqueeze_(0)
for i, (data, target) in enumerate(test_loader):
    if i >= 1:
        break  # only the first batch is needed
    print (data.shape)
    output = resnet18(data.to(device))

act = activation['conv1']
# Show the first input image (CHW -> HWC for imshow).
plt.imshow(np.transpose(data[0], (1, 2, 0)).detach().cpu().numpy())
plt.show()

# Plot the feature maps of the first image in a square grid.
plt.figure(figsize=(8 * 2, 8 * 2))
# BUG FIX: plt.subplot requires integer row/col counts; np.floor returns a
# float, which raises on matplotlib >= 3.3. Cast explicitly.
grid = int(np.floor(np.sqrt(act.size(1))))
for cnt in range(1, grid * grid + 1):
    plt.subplot(grid, grid, cnt)
    plt.imshow(act[0][cnt - 1].detach().cpu().numpy(), cmap='gray')
plt.show()
import matplotlib.pyplot as plt
import numpy as np

# Capture a deeper activation (first conv of layer1's first block).
activation = {}
def get_activation(name):
    """Return a forward hook that stores the module's detached output under `name`."""
    def hook(model, input, output):
        activation[name] = output.detach()
    return hook

resnet18.layer1[0].conv1.register_forward_hook(get_activation('layer1_conv1'))

# For a single image (no batch dim), add one first: data.unsqueeze_(0)
for i, (data, target) in enumerate(test_loader):
    if i >= 1:
        break  # only the first batch is needed
    print (data.shape)
    output = resnet18(data.to(device))

act = activation['layer1_conv1']
# Show the SECOND input image of the batch (index 1), matching act[1] below.
plt.imshow(np.transpose(data[1], (1, 2, 0)).detach().cpu().numpy())
plt.show()

# Plot the feature maps of that image in a square grid.
plt.figure(figsize=(8 * 2, 8 * 2))
# BUG FIX: plt.subplot requires integer row/col counts; np.floor returns a
# float, which raises on matplotlib >= 3.3. Cast explicitly.
grid = int(np.floor(np.sqrt(act.size(1))))
for cnt in range(1, grid * grid + 1):
    plt.subplot(grid, grid, cnt)
    plt.imshow(act[1][cnt - 1].detach().cpu().numpy())
plt.show()
参考资料:
MY Coding: