网络结构:https://www.jianshu.com/p/993c03c22d52
1.基于network-slimming论文的方法:pytorch版代码:https://github.com/Eric-mingjie/network-slimming
思路:去掉downsample里面的BN层,为了方便采用Resnetv2的结构:BN-Conv-ReLU,在每一个bottleneck的第一个BN后自定义一个通道选择层(全1层),训练过程中不影响,剪枝时先生成BN的通道mask,根据mask对通道选择层进行赋值,选出该BN层保留通道作为Conv的输入,再根据下一个BN的mask选出通道作为Conv的输出通道,这样循环遍历每一层得到剪枝后的网络,再进行finetune或者从头训练。
2.基于大佬最近开源的torch剪枝工具:https://github.com/VainF/Torch-Pruning
思路:在剪枝前,构建整体网络每一层的依赖关系。根据torch中的hooks机制,获取前向传播中每个module的grad_fn,为module构建对应的节点node,每个节点包含module、grad_fn、inputs、outputs、dependencies、node_name、type等属性。这样可以获取每个module的inputs和outputs所依赖的层(运算);再执行剪枝时,根据依赖关系自动对齐通道。
hooks机制:https://cloud.tencent.com/developer/article/1122582、https://zhuanlan.zhihu.com/p/75054200
1.基于network-slimming论文的方法:
import argparse
import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Variable
from torchvision import datasets, transforms
from Resnet import *
import os
import torchvision
from tqdm import tqdm
from channel_selection import *
# Prune settings: command-line interface and global constants for the
# network-slimming based pruning script.
parser = argparse.ArgumentParser()
parser.add_argument('--dataset', type=str, default='cat_dog',
                    help='training dataset (default: cat_dog)')
parser.add_argument('--test-batch-size', type=int, default=8, metavar='N',
                    help='input batch size for testing (default: 8)')
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='disables CUDA training')
parser.add_argument('--depth', type=int, default=164,
                    help='depth of the resnet')
parser.add_argument('--percent', type=float, default=0.3,
                    help='scale sparse rate (default: 0.5)')
parser.add_argument('--model', default='logs/model_pruning_final.pth', type=str, metavar='PATH',
                    help='path to the model (default: none)')
parser.add_argument('--save', default='logs', type=str, metavar='PATH',
                    help='path to save pruned model (default: none)')
args = parser.parse_args()
args.cuda = not args.no_cuda and torch.cuda.is_available()
if not os.path.exists(args.save):
    os.makedirs(args.save)
# NOTE(review): device is hard-coded to 'cuda:1' and ignores args.cuda -- confirm intended.
DEVICE = torch.device('cuda:1')
LR = 0.0001
EPOCH = 50
BTACH_SIZE = 100  # NOTE(review): typo for BATCH_SIZE; name kept as-is because later code references it
train_root = './train'
vaild_root = './test'
# Data loading and preprocessing.
# Training-time augmentation: random resized crop, horizontal flip, plus
# separate brightness and contrast jitter passes.
train_transform = transforms.Compose([
    transforms.Resize(224),
    transforms.RandomResizedCrop(224, scale=(0.6, 1.0), ratio=(0.8, 1.0)),
    transforms.RandomHorizontalFlip(),
    torchvision.transforms.ColorJitter(brightness=0.5, contrast=0, saturation=0, hue=0),
    torchvision.transforms.ColorJitter(brightness=0, contrast=0.5, saturation=0, hue=0),
    transforms.ToTensor(),
    transforms.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5])
])
# Evaluation transform: crop is deterministic because scale and ratio are
# pinned to 1.0; no augmentation is applied.
test_transform = transforms.Compose([
    transforms.Resize(224),
    transforms.RandomResizedCrop(224, scale=(1.0, 1.0), ratio=(1.0, 1.0)),
    # transforms.RandomHorizontalFlip(),
    # torchvision.transforms.ColorJitter(brightness=0.5, contrast=0, saturation=0, hue=0),
    # torchvision.transforms.ColorJitter(brightness=0, contrast=0.5, saturation=0, hue=0),
    transforms.ToTensor(),
    transforms.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5])
])
# Validation images, loaded from ./test with class-per-subfolder layout.
vaild_data = torchvision.datasets.ImageFolder(
    root=vaild_root,
    transform=test_transform
)
test_set = torch.utils.data.DataLoader(
    vaild_data,
    batch_size=BTACH_SIZE,
    shuffle=False
)
# Cross-entropy loss shared by the evaluation helper below.
criteration = nn.CrossEntropyLoss()  # NOTE(review): name is a typo for "criterion"; kept for compatibility
# Build the project-local pre-activation ResNet and restore sparsity-trained weights.
model = resnet(depth=args.depth, dataset=args.dataset).to(DEVICE)
model.load_state_dict(torch.load(args.model))
def vaild(model, device, dataset):
    """Evaluate ``model`` on ``dataset`` and return top-1 accuracy in percent.

    Args:
        model: network to evaluate (switched to eval mode).
        device: torch device inputs and targets are moved to.
        dataset: iterable of (inputs, labels) batches, e.g. a DataLoader.

    Returns:
        float: 100 * correct / total over all samples actually seen.

    Relies on the module-level ``criteration`` loss (used for logging only).
    """
    model.eval()
    correct = 0
    total = 0   # count real samples so a smaller final batch is handled correctly
    loss = 0.0
    with torch.no_grad():
        for i, (x, y) in tqdm(enumerate(dataset)):
            x, y = x.to(device), y.to(device)
            output = model(x)
            loss = criteration(output, y)  # loss of the *last* batch, for the log line
            pred = output.max(1, keepdim=True)[1]
            correct += pred.eq(y.view_as(pred)).sum().item()
            total += y.size(0)
    acc = 100 * correct / total if total else 0.0
    # Bug fix: this print originally appeared *after* the return and never executed.
    print("Test Loss {:.4f} Accuracy {}/{} ({:.3f}%)".format(loss, correct, total, acc))
    return acc
# Accuracy before pruning (BN scales not yet zeroed).
acc = vaild(model, DEVICE, test_set)
print("preprune acc:", acc)
# model = resnet(164, dataset="cat_dog").to(DEVICE)
# model.load_state_dict(torch.load('model_pruning_test.pth'))
# Total number of BN channels across the whole network.
total = 0
for m in model.modules():
    if isinstance(m, nn.BatchNorm2d):
        total += m.weight.data.shape[0]
# Gather |gamma| of every BN channel into one flat vector.
bn = torch.zeros(total)
index = 0
for m in model.modules():
    if isinstance(m, nn.BatchNorm2d):
        size = m.weight.data.shape[0]
        bn[index:(index+size)] = m.weight.data.abs().clone()
        index += size
# Global pruning threshold: the `percent` quantile of all BN scale magnitudes.
y, i = torch.sort(bn)
thre_index = int(total * args.percent)
thre = y[thre_index]
pruned = 0     # running count of channels removed
cfg = []       # remaining channel count per BN layer (config for the compact model)
cfg_mask = []  # per-layer 0/1 keep-mask, reused when copying weights below
for k, m in enumerate(model.modules()):
    if isinstance(m, nn.BatchNorm2d):
        weight_copy = m.weight.data.abs().clone()
        # Keep channels whose |gamma| exceeds the global threshold.
        mask = weight_copy.gt(thre.to(DEVICE)).float().to(DEVICE)
        pruned = pruned + mask.shape[0] - torch.sum(mask)
        # Zero the scale and shift of pruned channels in the *old* model.
        m.weight.data.mul_(mask)
        m.bias.data.mul_(mask)
        cfg.append(int(torch.sum(mask)))
        cfg_mask.append(mask.clone())
        print('layer index: {:d} \t total channel: {:d} \t remaining channel: {:d}'.
              format(k, mask.shape[0], int(torch.sum(mask))))
pruned_ratio = pruned/total
print('Pre-processing Successful!', "pruned_ratio:", pruned_ratio)
# simple test model after Pre-processing prune (simple set BN scales to zeros)
print("Cfg:")
print(cfg, len(cfg))
# Build the compact model whose per-layer widths come from `cfg`.
newmodel = resnet(depth=args.depth, dataset=args.dataset, cfg=cfg)
if args.cuda:
    newmodel.to(DEVICE)
num_parameters = sum([param.nelement() for param in newmodel.parameters()])
savepath = os.path.join(args.save, "prune_0.3.txt")
with open(savepath, "w") as fp:
    fp.write("Configuration: \n"+str(cfg)+"\n")
    fp.write("Number of parameters: \n"+str(num_parameters)+"\n")
    #fp.write("Test accuracy: \n"+str(acc))
# Copy weights from the (mask-zeroed) old model into the compact new model,
# walking both module lists in lockstep.  `start_mask` selects surviving
# *input* channels and `end_mask` surviving *output* channels of the layer
# currently being copied.
old_modules = list(model.modules())
new_modules = list(newmodel.modules())
layer_id_in_cfg = 0
start_mask = torch.ones(3)  # network input has 3 (RGB) channels, all kept
end_mask = cfg_mask[layer_id_in_cfg]
conv_count = 0
for layer_id in range(len(old_modules)):
    m0 = old_modules[layer_id]  # source (old) module
    m1 = new_modules[layer_id]  # destination (new, compact) module
    if isinstance(m0, nn.BatchNorm2d):
        # Indices of the channels kept by the current mask.
        idx1 = np.squeeze(np.argwhere(np.asarray(end_mask.cpu().numpy())))
        if idx1.size == 1:
            # np.squeeze collapses a single index to a 0-d array; restore shape (1,).
            idx1 = np.resize(idx1,(1,))
        if isinstance(old_modules[layer_id + 1], channel_selection):
            # If the next layer is the channel selection layer, then the current batchnorm 2d layer won't be pruned.
            m1.weight.data = m0.weight.data.clone()
            m1.bias.data = m0.bias.data.clone()
            m1.running_mean = m0.running_mean.clone()
            m1.running_var = m0.running_var.clone()
            # We need to set the channel selection layer.
            m2 = new_modules[layer_id + 1]
            m2.indexes.data.zero_()
            m2.indexes.data[idx1.tolist()] = 1.0
            layer_id_in_cfg += 1
            start_mask = end_mask.clone()
            if layer_id_in_cfg < len(cfg_mask):
                end_mask = cfg_mask[layer_id_in_cfg]
        else:
            # Regular BN layer: physically keep only the surviving channels.
            m1.weight.data = m0.weight.data[idx1.tolist()].clone()
            m1.bias.data = m0.bias.data[idx1.tolist()].clone()
            m1.running_mean = m0.running_mean[idx1.tolist()].clone()
            m1.running_var = m0.running_var[idx1.tolist()].clone()
            layer_id_in_cfg += 1
            start_mask = end_mask.clone()
            if layer_id_in_cfg < len(cfg_mask): # do not change in Final FC
                end_mask = cfg_mask[layer_id_in_cfg]
    elif isinstance(m0, nn.Conv2d):
        if conv_count == 0:
            # Very first (stem) convolution: copied as-is.
            m1.weight.data = m0.weight.data.clone()
            conv_count += 1
            continue
        if isinstance(old_modules[layer_id-1], channel_selection) or isinstance(old_modules[layer_id-1], nn.BatchNorm2d):
            # This covers the convolutions in the residual block.
            # The convolutions are either after the channel selection layer or after the batch normalization layer.
            conv_count += 1
            idx0 = np.squeeze(np.argwhere(np.asarray(start_mask.cpu().numpy())))
            idx1 = np.squeeze(np.argwhere(np.asarray(end_mask.cpu().numpy())))
            print('In shape: {:d}, Out shape {:d}.'.format(idx0.size, idx1.size))
            if idx0.size == 1:
                idx0 = np.resize(idx0, (1,))
            if idx1.size == 1:
                idx1 = np.resize(idx1, (1,))
            # Select surviving *input* channels first.
            w1 = m0.weight.data[:, idx0.tolist(), :, :].clone()
            # If the current convolution is not the last convolution in the residual block, then we can change the
            # number of output channels. Currently we use `conv_count` to detect whether it is such convolution.
            if conv_count % 3 != 1:
                w1 = w1[idx1.tolist(), :, :, :].clone()
            m1.weight.data = w1.clone()
            continue
        # We need to consider the case where there are downsampling convolutions.
        # For these convolutions, we just copy the weights.
        m1.weight.data = m0.weight.data.clone()
    elif isinstance(m0, nn.Linear):
        # Final classifier: prune its input features to match the last kept channels.
        idx0 = np.squeeze(np.argwhere(np.asarray(start_mask.cpu().numpy())))
        if idx0.size == 1:
            idx0 = np.resize(idx0, (1,))
        m1.weight.data = m0.weight.data[:, idx0].clone()
        m1.bias.data = m0.bias.data.clone()
# Persist the pruned architecture (cfg) together with its copied weights.
torch.save({'cfg': cfg, 'state_dict': newmodel.state_dict()}, os.path.join(args.save, 'pruned_0.3.pth.tar'))
print(newmodel)
# Evaluate the pruned model (before any fine-tuning).
model = newmodel
acc = vaild(model, DEVICE, test_set)
print("pruned acc:", acc)
2.基于Torch-Pruning剪枝工具
import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
from tqdm import tqdm
import argparse
import torch_pruning as tp
# Command-line interface for the Torch-Pruning based script.
parser = argparse.ArgumentParser()
parser.add_argument('--train_root', type=str, default='/data/xywang/dataset/catdog_classification/train',
                    help='training dataset (default: train)')
parser.add_argument('--vaild_root', type=str, default='/data/xywang/dataset/catdog_classification/test',
                    help='training dataset (default: test)')
# NOTE(review): argparse `type=bool` treats ANY non-empty string (even "False")
# as True -- confirm; `action='store_true'` is the usual idiom.
parser.add_argument('--sr', default=True, type=bool,
                    help='train with channel sparsity regularization')
parser.add_argument('--s', default=0.0001, type=float,
                    help='scale sparse rate (default: 0.0001)')
parser.add_argument('--batch_size', type=int, default=100, metavar='N',
                    help='input batch size for training (default: 64)')
parser.add_argument('--epochs', type=int, default=50, metavar='N',
                    help='number of epochs to train (default: 160)')
parser.add_argument('--lr', type=float, default=0.001, metavar='LR',
                    help='learning rate (default: 0.001)')
parser.add_argument('--resume', default='', type=str, metavar='PATH',
                    help='path to latest checkpoint (default: none)')
parser.add_argument('--save', default='./models', type=str, metavar='PATH',
                    help='path to save prune model (default: current directory)')
parser.add_argument('--percent', default=0.9, type=float,
                    help='the PATH to the pruned model')
args = parser.parse_args()
# NOTE(review): device is hard-coded to 'cuda:1' -- confirm intended.
device = torch.device('cuda:1')
if not os.path.exists(args.save):
    os.makedirs(args.save)
# Data loading and preprocessing.
# Training-time augmentation: random resized crop, horizontal flip, plus
# separate brightness and contrast jitter passes.
train_transform = transforms.Compose([
    transforms.Resize(224),
    transforms.RandomResizedCrop(224, scale=(0.6, 1.0), ratio=(0.8, 1.0)),
    transforms.RandomHorizontalFlip(),
    torchvision.transforms.ColorJitter(brightness=0.5, contrast=0, saturation=0, hue=0),
    torchvision.transforms.ColorJitter(brightness=0, contrast=0.5, saturation=0, hue=0),
    transforms.ToTensor(),
    transforms.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5])
])
# Evaluation transform: deterministic resize/crop, no augmentation.
test_transform = transforms.Compose([
    transforms.Resize(224),
    transforms.RandomResizedCrop(224, scale=(1.0, 1.0), ratio=(1.0, 1.0)),
    # transforms.RandomHorizontalFlip(),
    # torchvision.transforms.ColorJitter(brightness=0.5, contrast=0, saturation=0, hue=0),
    # torchvision.transforms.ColorJitter(brightness=0, contrast=0.5, saturation=0, hue=0),
    transforms.ToTensor(),
    transforms.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5])
])
train_data = torchvision.datasets.ImageFolder(
    root=args.train_root,
    transform=train_transform
)
# Bug fix: validation data previously used `train_transform`, applying random
# augmentation at evaluation time while `test_transform` was never used.
vaild_data = torchvision.datasets.ImageFolder(
    root=args.vaild_root,
    transform=test_transform
)
train_set = torch.utils.data.DataLoader(
    train_data,
    batch_size=args.batch_size,
    shuffle=True
)
test_set = torch.utils.data.DataLoader(
    vaild_data,
    batch_size=args.batch_size,
    shuffle=False
)
def updateBN(model, s, pruning_modules):
    """Apply the L1 sparsity subgradient to BN scale factors.

    Adds ``s * sign(gamma)`` to the gradient of each BN layer's weight
    (network-slimming sparsity penalty).  ``model`` is not used; it is kept
    for interface compatibility with callers.
    """
    for bn in pruning_modules:
        penalty = torch.sign(bn.weight.data) * s
        bn.weight.grad.data += penalty
# Training and validation helpers.
criteration = nn.CrossEntropyLoss()  # NOTE(review): name is a typo for "criterion"; kept because train/vaild reference it
def train(model, device, dataset, optimizer, epoch, pruning_modules):
    """Run one training epoch with optional BN L1 sparsity regularization.

    Args:
        model: network to train (moved to `device`, set to train mode).
        device: torch device for inputs/targets.
        dataset: iterable of (inputs, labels) batches (a DataLoader).
        optimizer: optimizer stepping `model`'s parameters.
        epoch: epoch number, used only for the log line.
        pruning_modules: BN layers whose gamma gradients receive the penalty.

    Relies on module-level ``criteration``, ``args`` and ``updateBN``.
    """
    model.train().to(device)
    correct = 0
    total = 0  # count samples seen so a smaller final batch is handled correctly
    loss = 0.0
    for i, (x, y) in tqdm(enumerate(dataset)):
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        output = model(x)
        pred = output.max(1, keepdim=True)[1]
        correct += pred.eq(y.view_as(pred)).sum().item()
        total += y.size(0)
        loss = criteration(output, y)
        loss.backward()
        # Bug fix: the sparsity subgradient must be added *between* backward()
        # and step(); it previously ran after step(), so the next iteration's
        # zero_grad() discarded it and the penalty never had any effect.
        if args.sr:
            updateBN(model, args.s, pruning_modules)
        optimizer.step()
    acc = 100 * correct / total if total else 0.0
    # `loss` here is the last batch's loss, kept for lightweight logging.
    print("Epoch {} Loss {:.4f} Accuracy {}/{} ({:.3f}%)".format(epoch, loss, correct, total, acc))
def vaild(model, device, dataset):
    """Evaluate ``model`` on ``dataset`` and return top-1 accuracy in percent.

    Args:
        model: network to evaluate (moved to `device`, set to eval mode).
        device: torch device for inputs/targets.
        dataset: iterable of (inputs, labels) batches (a DataLoader).

    Returns:
        float: 100 * correct / total over all samples actually seen.

    Relies on the module-level ``criteration`` loss (logging only).
    """
    model.eval().to(device)
    correct = 0
    total = 0  # count real samples: fixes the accuracy denominator for a partial last batch
    loss = 0.0
    with torch.no_grad():
        for i, (x, y) in tqdm(enumerate(dataset)):
            x, y = x.to(device), y.to(device)
            output = model(x)
            loss = criteration(output, y)  # last batch's loss, for the log line
            pred = output.max(1, keepdim=True)[1]
            correct += pred.eq(y.view_as(pred)).sum().item()
            total += y.size(0)
    acc = 100 * correct / total if total else 0.0
    print("Test Loss {:.4f} Accuracy {}/{} ({:.3f}%)".format(loss, correct, total, acc))
    return acc
def get_pruning_modules(model):
    """Collect the prunable BN layers of every ResNet Bottleneck block.

    Returns bn1 and bn2 of each Bottleneck, in model traversal order
    (bn3 is excluded so residual-sum channel counts stay aligned).
    """
    bottlenecks = (m for m in model.modules()
                   if isinstance(m, torchvision.models.resnet.Bottleneck))
    layers = []
    for blk in bottlenecks:
        layers.extend((blk.bn1, blk.bn2))
    return layers
def gather_bn_weights(model, pruning_modules):
    """Concatenate |gamma| of the given BN layers into one flat 1-D tensor.

    Args:
        model: unused; kept for interface compatibility with existing callers.
        pruning_modules: BN layers in the order their weights should appear.

    Returns:
        1-D float tensor of absolute BN scale factors, in list order.

    Improvement: the original scanned ``model.modules()`` and tested list
    membership per module (O(n*m)); sizes are now read directly from
    ``pruning_modules``.  Empty input returns an empty tensor.
    """
    if not pruning_modules:
        return torch.zeros(0)
    return torch.cat([m.weight.data.abs().flatten() for m in pruning_modules])
def computer_eachlayer_pruned_number(bn_weights, thresh):
    """Count, per BN layer, how many channels fall strictly below ``thresh``.

    Args:
        bn_weights: NOTE(review): unused -- the function reads the module-level
            ``bn_modules`` global instead; confirm whether it should iterate a
            parameter.
        thresh: scalar threshold (tensor); channels with |gamma| < thresh are
            counted as pruned.

    Returns:
        list[int]: pruned-channel count for each layer in ``bn_modules``.
    """
    num_list = []
    for module in bn_modules:
        # Vectorized count; strict '<' matches the original `thresh > data` test.
        num_list.append(int((module.weight.data.abs() < thresh).sum().item()))
    print(thresh)
    return num_list
def prune_model(model, num_list):
    """Prune bn1/bn2 of every Bottleneck block via the Torch-Pruning dependency graph.

    `num_list` holds the number of channels to remove from each BN layer, in the
    same (bn1, bn2) per-block order produced by get_pruning_modules /
    computer_eachlayer_pruned_number.  Mutates `model` in place and returns it.
    """
    model.to(device)
    # Build the layer dependency graph from one dummy forward pass so that
    # pruning a BN layer automatically realigns the channels of dependent layers.
    DG = tp.DependencyGraph().build_dependency(model, torch.randn(1, 3, 224, 224) )
    def prune_bn(bn, num):
        L1_norm = bn.weight.detach().cpu().numpy()
        prune_index = np.argsort(L1_norm)[:num].tolist() # remove filters with small L1-Norm
        plan = DG.get_pruning_plan(bn, tp.prune_batchnorm, prune_index)
        plan.exec()
    blk_id = 0  # index into num_list; advances by 2 (bn1, bn2) per Bottleneck
    for m in model.modules():
        if isinstance( m, torchvision.models.resnet.Bottleneck ):
            prune_bn( m.bn1, num_list[blk_id] )
            prune_bn( m.bn2, num_list[blk_id+1] )
            blk_id+=2
    return model
# Build ResNet-50 with a 2-class head and restore the sparsity-trained weights.
model = torchvision.models.resnet50(pretrained=True)
model.fc = nn.Sequential(
    nn.Linear(2048, 2)
)
model.to(device)
model.load_state_dict(torch.load("models/model_pruning.pth"))
optimizer = optim.Adam(model.parameters(), lr=args.lr, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
# Collect the prunable BN layers and their |gamma| magnitudes.
bn_modules = get_pruning_modules(model)
bn_weights = gather_bn_weights(model, bn_modules)
# Bug fix: the sort was previously performed twice; once is sufficient.
sorted_bn, sorted_index = torch.sort(bn_weights)
# Global threshold: the `percent` quantile of all BN scales
# (index clamped so percent=1.0 cannot run past the end).
thresh_index = min(int(len(bn_weights) * args.percent), len(bn_weights) - 1)
thresh = sorted_bn[thresh_index].to(device)
num_list = computer_eachlayer_pruned_number(bn_weights, thresh)
prune_model(model, num_list)
print(model)
# Accuracy right after pruning, before fine-tuning.
prec = vaild(model, device, test_set)
# Fine-tune the pruned network; sparsity regularization still applies when args.sr is set.
for epoch in range(1, args.epochs + 1):
    train(model, device, train_set, optimizer, epoch, bn_modules)
    vaild(model, device, test_set)
    # torch.save(model.state_dict(), 'model_pruned.pth')
    # Save the whole module: pruning changed the architecture, so a plain
    # state_dict would not load into a stock resnet50.
    # NOTE(review): original indentation was lost; checkpointing every epoch assumed -- confirm.
    torch.save(model, 'models/model_pruned_0.8.pth' )