pytorch Resnet50分类模型剪枝

Resnet50

网络结构:https://www.jianshu.com/p/993c03c22d52

剪枝方式

1.基于network-slimming论文的方法:pytorch版代码:https://github.com/Eric-mingjie/network-slimming
思路:去掉downsample中的BN层;为了方便,采用ResNet v2(pre-activation)的结构:BN-ReLU-Conv。在每个bottleneck的第一个BN后插入一个自定义的通道选择层(channel_selection,全1层,因此不影响训练过程)。剪枝时先根据BN的缩放因子生成通道mask,再根据mask对通道选择层赋值,选出该BN层保留的通道作为后续Conv的输入通道;然后根据下一个BN的mask选出保留通道作为该Conv的输出通道。按此方式循环遍历每一层即可得到剪枝后的网络,之后再进行finetune或者从头训练。
2.基于VainF开源的Torch-Pruning剪枝工具:https://github.com/VainF/Torch-Pruning
思路:在剪枝前,先构建整个网络各层之间的依赖关系:借助torch的hooks机制,在前向传播时获取每个module的grad_fn,并为每个module构建对应的节点(node),每个节点包含module、grad_fn、inputs、outputs、dependencies、node_name、type等信息。这样就能得到每个module的inputs和outputs所依赖的层(运算);在执行剪枝时,再根据这些依赖关系自动对齐通道。
hooks机制:https://cloud.tencent.com/developer/article/1122582、https://zhuanlan.zhihu.com/p/75054200

剪枝核心代码

1.基于network-slimming论文的方法:

import argparse
import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Variable
from torchvision import datasets, transforms
from Resnet import *
import os
import torchvision
from tqdm import tqdm
from channel_selection import *


# Prune settings
parser = argparse.ArgumentParser()
parser.add_argument('--dataset', type=str, default='cat_dog',
                    help='training dataset (default: cat_dog)')
parser.add_argument('--test-batch-size', type=int, default=8, metavar='N',
                    help='input batch size for testing (default: 8)')
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='disables CUDA training')
parser.add_argument('--depth', type=int, default=164,
                    help='depth of the resnet')
parser.add_argument('--percent', type=float, default=0.3,
                    help='fraction of all BN channels to prune (default: 0.3)')
parser.add_argument('--model', default='logs/model_pruning_final.pth', type=str, metavar='PATH',
                    help='path to the sparsity-trained model (default: logs/model_pruning_final.pth)')
parser.add_argument('--save', default='logs', type=str, metavar='PATH',
                    help='directory to save the pruned model (default: logs)')

args = parser.parse_args()
args.cuda = not args.no_cuda and torch.cuda.is_available()

# exist_ok avoids the check-then-create race of a separate os.path.exists test.
os.makedirs(args.save, exist_ok=True)

# Honour --no-cuda and CUDA availability instead of always targeting cuda:1.
DEVICE = torch.device('cuda:1' if args.cuda else 'cpu')
# NOTE(review): LR, EPOCH and train_root are not used by the pruning steps in
# this script; they look like leftovers from the training script.
LR = 0.0001
EPOCH = 50
# Batch size of the evaluation loader. The misspelled name is kept because the
# rest of the script refers to it. NOTE(review): --test-batch-size is parsed
# but never used.
BTACH_SIZE = 100
train_root = './train'
vaild_root = './test'

# Data loading and preprocessing
train_transform = transforms.Compose([
    transforms.Resize(224),
    transforms.RandomResizedCrop(224, scale=(0.6, 1.0), ratio=(0.8, 1.0)),
    transforms.RandomHorizontalFlip(),
    torchvision.transforms.ColorJitter(brightness=0.5, contrast=0, saturation=0, hue=0),
    torchvision.transforms.ColorJitter(brightness=0, contrast=0.5, saturation=0, hue=0),
    transforms.ToTensor(),
    transforms.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5])
])

# Evaluation pipeline: no flip or colour jitter. Scale and ratio are pinned to
# 1.0 so the crop is not randomly rescaled.
# NOTE(review): this appears to act as a fixed square crop; consider
# transforms.CenterCrop(224) for clarity.
test_transform = transforms.Compose([
    transforms.Resize(224),
    transforms.RandomResizedCrop(224, scale=(1.0, 1.0), ratio=(1.0, 1.0)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5])
])

vaild_data = torchvision.datasets.ImageFolder(
    root=vaild_root,
    transform=test_transform
)

test_set = torch.utils.data.DataLoader(
    vaild_data,
    batch_size=BTACH_SIZE,
    shuffle=False
)
criteration = nn.CrossEntropyLoss()

model = resnet(depth=args.depth, dataset=args.dataset).to(DEVICE)
# map_location lets a checkpoint that was saved on another GPU (or on the CPU)
# be loaded directly onto DEVICE.
model.load_state_dict(torch.load(args.model, map_location=DEVICE))

def vaild(model, device, dataset):
    """Evaluate ``model`` on ``dataset`` and return top-1 accuracy in percent.

    Args:
        model: classifier to evaluate; it is switched to eval mode.
        device: device that each input batch is moved to.
        dataset: iterable (e.g. a DataLoader) yielding ``(inputs, labels)``.

    Returns:
        Accuracy (0-100) over every sample actually seen. The denominator is
        the true sample count, so a short final batch is handled correctly.
        Returns 0.0 when the loader yields no samples.
    """
    model.eval()
    correct = 0
    seen = 0
    loss_sum = 0.0
    with torch.no_grad():
        for x, y in tqdm(dataset):
            x, y = x.to(device), y.to(device)
            output = model(x)
            # The criterion averages over the batch; re-weight by batch size
            # so the reported loss is a per-sample mean over the whole set.
            loss_sum += criteration(output, y).item() * y.size(0)
            pred = output.max(1, keepdim=True)[1]
            correct += pred.eq(y.view_as(pred)).sum().item()
            seen += y.size(0)
    if seen == 0:
        return 0.0
    acc = 100 * correct / seen
    # This report used to sit after the return statement and never ran.
    print("Test Loss {:.4f} Accuracy {}/{} ({:.3f}%)".format(loss_sum / seen, correct, seen, acc))
    return acc

# Baseline accuracy of the sparsity-trained model before any channel is removed.
acc = vaild(model, DEVICE, test_set)
print("preprune acc:", acc)

# Collect |gamma| (BN scale factors) of every BatchNorm2d, in module order,
# into one flat CPU tensor so a single global threshold can be chosen.
bn = torch.cat([m.weight.data.abs().cpu()
                for m in model.modules() if isinstance(m, nn.BatchNorm2d)])
total = bn.numel()

# Global threshold: the value at the args.percent quantile of all |gamma|.
# The index is clamped so --percent 1.0 does not index past the end.
sorted_bn, _ = torch.sort(bn)
thre_index = min(int(total * args.percent), total - 1)
thre = sorted_bn[thre_index].to(DEVICE)

# For each BN layer, build a 0/1 keep-mask (channels with |gamma| > thre
# survive) and zero the pruned channels in place. Record the surviving channel
# count (cfg) and the mask (cfg_mask) for building and filling the slim model.
pruned = 0
cfg = []
cfg_mask = []
for k, m in enumerate(model.modules()):
    if isinstance(m, nn.BatchNorm2d):
        mask = m.weight.data.abs().gt(thre).float()
        kept = int(mask.sum().item())
        pruned += mask.numel() - kept
        m.weight.data.mul_(mask)
        m.bias.data.mul_(mask)
        cfg.append(kept)
        cfg_mask.append(mask)
        print('layer index: {:d} \t total channel: {:d} \t remaining channel: {:d}'.
              format(k, mask.shape[0], kept))

# Plain Python numbers, so the ratio prints as a float, not a device tensor.
pruned_ratio = pruned / total

print('Pre-processing Successful!', "pruned_ratio:", pruned_ratio)

print("Cfg:")
print(cfg, len(cfg))

# Build the slim network from the per-BN surviving-channel counts.
newmodel = resnet(depth=args.depth, dataset=args.dataset, cfg=cfg)
# Move unconditionally. `model` is always moved to DEVICE, and vaild() sends its
# batches to DEVICE, so newmodel must live on DEVICE as well.
newmodel.to(DEVICE)

# Record the pruned configuration and parameter count. The file name follows
# --percent (the default 0.3 still yields "prune_0.3.txt").
num_parameters = sum(param.nelement() for param in newmodel.parameters())
savepath = os.path.join(args.save, "prune_{}.txt".format(args.percent))
with open(savepath, "w") as fp:
    fp.write("Configuration: \n" + str(cfg) + "\n")
    fp.write("Number of parameters: \n" + str(num_parameters) + "\n")

# Copy the surviving weights from the masked `model` into the slimmer
# `newmodel`. Both module lists are walked in lockstep by index, so this
# assumes resnet(..., cfg=cfg) registers modules in exactly the same order as
# the unpruned network, only with narrower layers.
old_modules = list(model.modules())
new_modules = list(newmodel.modules())
# Index into cfg_mask of the BN layer currently being processed.
layer_id_in_cfg = 0
# start_mask: kept input channels for the next conv. It starts as all 3 input
# channels (presumably RGB, matching the 3-value Normalize above).
start_mask = torch.ones(3)
# end_mask: keep-mask of the BN layer currently being processed.
end_mask = cfg_mask[layer_id_in_cfg]
# Counts convolutions handled so far; used below to tell which conv of a
# bottleneck is being processed.
conv_count = 0

for layer_id in range(len(old_modules)):
    m0 = old_modules[layer_id]
    m1 = new_modules[layer_id]
    if isinstance(m0, nn.BatchNorm2d):
        # Indices of the channels this BN keeps. Force a 1-D array so a single
        # surviving channel still indexes correctly.
        idx1 = np.squeeze(np.argwhere(np.asarray(end_mask.cpu().numpy())))
        if idx1.size == 1:
            idx1 = np.resize(idx1,(1,))

        if isinstance(old_modules[layer_id + 1], channel_selection):
            # If the next layer is the channel selection layer, then the current batchnorm 2d layer won't be pruned.
            m1.weight.data = m0.weight.data.clone()
            m1.bias.data = m0.bias.data.clone()
            m1.running_mean = m0.running_mean.clone()
            m1.running_var = m0.running_var.clone()

            # We need to set the channel selection layer: only the kept
            # channels are passed on (indexes = 1), the rest are zeroed.
            m2 = new_modules[layer_id + 1]
            m2.indexes.data.zero_()
            m2.indexes.data[idx1.tolist()] = 1.0

            # This BN's mask becomes the input mask of the next conv.
            layer_id_in_cfg += 1
            start_mask = end_mask.clone()
            if layer_id_in_cfg < len(cfg_mask):
                end_mask = cfg_mask[layer_id_in_cfg]
        else:
            # Ordinary BN: keep only the surviving channels' affine parameters
            # and running statistics.
            m1.weight.data = m0.weight.data[idx1.tolist()].clone()
            m1.bias.data = m0.bias.data[idx1.tolist()].clone()
            m1.running_mean = m0.running_mean[idx1.tolist()].clone()
            m1.running_var = m0.running_var[idx1.tolist()].clone()
            layer_id_in_cfg += 1
            start_mask = end_mask.clone()
            if layer_id_in_cfg < len(cfg_mask):  # do not change in Final FC
                end_mask = cfg_mask[layer_id_in_cfg]
    elif isinstance(m0, nn.Conv2d):
        # The very first (stem) convolution is copied unchanged.
        if conv_count == 0:
            m1.weight.data = m0.weight.data.clone()
            conv_count += 1
            continue
        if isinstance(old_modules[layer_id-1], channel_selection) or isinstance(old_modules[layer_id-1], nn.BatchNorm2d):
            # This covers the convolutions in the residual block.
            # The convolutions are either after the channel selection layer or after the batch normalization layer.
            conv_count += 1
            idx0 = np.squeeze(np.argwhere(np.asarray(start_mask.cpu().numpy())))
            idx1 = np.squeeze(np.argwhere(np.asarray(end_mask.cpu().numpy())))
            print('In shape: {:d}, Out shape {:d}.'.format(idx0.size, idx1.size))
            if idx0.size == 1:
                idx0 = np.resize(idx0, (1,))
            if idx1.size == 1:
                idx1 = np.resize(idx1, (1,))
            # Slice input channels by the previous BN's mask.
            w1 = m0.weight.data[:, idx0.tolist(), :, :].clone()

            # If the current convolution is not the last convolution in the residual block, then we can change the
            # number of output channels. Currently we use `conv_count` to detect whether it is such convolution:
            # after the stem conv the count is 1, so each group of three block convs ends on a count with % 3 == 1.
            # The last conv's output must keep full width to match the residual addition.
            if conv_count % 3 != 1:
                w1 = w1[idx1.tolist(), :, :, :].clone()
            m1.weight.data = w1.clone()
            continue

        # We need to consider the case where there are downsampling convolutions.
        # For these convolutions, we just copy the weights.
        # NOTE(review): these do not increment conv_count; the % 3 logic above relies on that.
        m1.weight.data = m0.weight.data.clone()
    elif isinstance(m0, nn.Linear):
        # Final classifier: keep only the input features that correspond to
        # channels kept by the last BN layer (start_mask at this point).
        idx0 = np.squeeze(np.argwhere(np.asarray(start_mask.cpu().numpy())))
        if idx0.size == 1:
            idx0 = np.resize(idx0, (1,))

        m1.weight.data = m0.weight.data[:, idx0].clone()
        m1.bias.data = m0.bias.data.clone()

# Save cfg together with the weights so the slim architecture can be rebuilt
# with resnet(..., cfg=cfg) before loading the state_dict.
# NOTE(review): the file name hard-codes 0.3 regardless of --percent.
torch.save({'cfg': cfg, 'state_dict': newmodel.state_dict()}, os.path.join(args.save, 'pruned_0.3.pth.tar'))

print(newmodel)
model = newmodel

# Accuracy right after pruning, before any fine-tuning.
acc=vaild(model,DEVICE,test_set)
print("pruned acc:",acc)

2.基于Torch-Pruning剪枝工具

import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
from tqdm import tqdm
import argparse
import torch_pruning as tp


def _str2bool(value):
    """Parse a command-line boolean.

    ``type=bool`` is a trap in argparse: ``bool("False")`` is True, so any
    non-empty string would enable the flag. This accepts the usual spellings.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('1', 'true', 't', 'yes', 'y'):
        return True
    if lowered in ('0', 'false', 'f', 'no', 'n'):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got {!r}'.format(value))


parser = argparse.ArgumentParser()
parser.add_argument('--train_root', type=str, default='/data/xywang/dataset/catdog_classification/train',
                    help='training dataset (default: train)')
parser.add_argument('--vaild_root', type=str, default='/data/xywang/dataset/catdog_classification/test',
                    help='validation dataset (default: test)')
parser.add_argument('--sr', default=True, type=_str2bool,
                    help='train with channel sparsity regularization (default: True)')
parser.add_argument('--s', default=0.0001, type=float,
                    help='scale sparse rate (default: 0.0001)')
parser.add_argument('--batch_size', type=int, default=100, metavar='N',
                    help='input batch size for training (default: 100)')
parser.add_argument('--epochs', type=int, default=50, metavar='N',
                    help='number of epochs to train (default: 50)')
parser.add_argument('--lr', type=float, default=0.001, metavar='LR',
                    help='learning rate (default: 0.001)')
parser.add_argument('--resume', default='', type=str, metavar='PATH',
                    help='path to latest checkpoint (default: none)')
parser.add_argument('--save', default='./models', type=str, metavar='PATH',
                    help='path to save prune model (default: ./models)')
parser.add_argument('--percent', default=0.9, type=float,
                    help='fraction of prunable BN channels to remove (default: 0.9)')

args = parser.parse_args()
# Fall back to the CPU when CUDA is unavailable instead of failing on cuda:1.
device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')

os.makedirs(args.save, exist_ok=True)

# Data loading and preprocessing
train_transform = transforms.Compose([
    transforms.Resize(224),
    transforms.RandomResizedCrop(224, scale=(0.6, 1.0), ratio=(0.8, 1.0)),
    transforms.RandomHorizontalFlip(),
    torchvision.transforms.ColorJitter(brightness=0.5, contrast=0, saturation=0, hue=0),
    torchvision.transforms.ColorJitter(brightness=0, contrast=0.5, saturation=0, hue=0),
    transforms.ToTensor(),
    transforms.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5])
])

# Evaluation pipeline: no flip or colour jitter; scale and ratio pinned to 1.0.
test_transform = transforms.Compose([
    transforms.Resize(224),
    transforms.RandomResizedCrop(224, scale=(1.0, 1.0), ratio=(1.0, 1.0)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5])
])

train_data = torchvision.datasets.ImageFolder(
    root=args.train_root,
    transform=train_transform
)

# Validation uses the evaluation pipeline. Applying train_transform here (as
# before) randomly augmented the validation images and made accuracy noisy.
vaild_data = torchvision.datasets.ImageFolder(
    root=args.vaild_root,
    transform=test_transform
)

train_set = torch.utils.data.DataLoader(
    train_data,
    batch_size=args.batch_size,
    shuffle=True
)

test_set = torch.utils.data.DataLoader(
    vaild_data,
    batch_size=args.batch_size,
    shuffle=False
)

def updateBN(model, s, pruning_modules):
    """Add the L1 sub-gradient ``s * sign(gamma)`` to each BN weight's gradient.

    This is the channel-sparsity regulariser of network slimming. It operates
    in place on ``.weight.grad``, which must already have been populated by
    ``backward()``.

    Args:
        model: unused; kept so existing call sites keep working.
        s: strength of the L1 penalty on the BN scale factors.
        pruning_modules: BatchNorm modules whose weights are regularised.
    """
    for bn_layer in pruning_modules:
        gamma = bn_layer.weight.data
        penalty = s * torch.sign(gamma)
        bn_layer.weight.grad.data.add_(penalty)
# Training and validation
# Classification loss shared by train() and vaild() below.
criteration = nn.CrossEntropyLoss()
def train(model, device, dataset, optimizer, epoch, pruning_modules):
    """Run one training epoch, optionally with BN-gamma L1 sparsity.

    Args:
        model: network to train; switched to train mode and moved to ``device``.
        device: device the batches are moved to.
        dataset: iterable yielding ``(inputs, labels)`` batches.
        optimizer: optimizer over ``model``'s current parameters.
        epoch: epoch number, used only for the log line.
        pruning_modules: BN modules that receive the sparsity penalty when
            ``args.sr`` is true.
    """
    model.train().to(device)
    correct = 0
    seen = 0
    loss = None
    for x, y in tqdm(dataset):
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        output = model(x)
        pred = output.max(1, keepdim=True)[1]
        correct += pred.eq(y.view_as(pred)).sum().item()
        seen += y.size(0)
        loss = criteration(output, y)
        loss.backward()
        # The sparsity sub-gradient must be added *before* optimizer.step().
        # Adding it afterwards (as before) is erased by the next zero_grad()
        # and never influences the weights.
        if args.sr:
            updateBN(model, args.s, pruning_modules)
        optimizer.step()

    if seen == 0:
        print("Epoch {}: empty dataset".format(epoch))
        return
    # Loss shown is that of the final batch; accuracy uses the true sample
    # count, so a short last batch is handled correctly.
    print("Epoch {} Loss {:.4f} Accuracy {}/{} ({:.3f}%)".format(
        epoch, loss.item(), correct, seen, 100 * correct / seen))
    

def vaild(model, device, dataset):
    """Evaluate ``model`` on ``dataset``, print a summary, and return accuracy.

    Args:
        model: classifier; switched to eval mode and moved to ``device``.
        device: device the batches are moved to.
        dataset: iterable yielding ``(inputs, labels)`` batches.

    Returns:
        Top-1 accuracy in percent over all samples seen. The reported loss is
        the per-sample mean over the set. Returns 0.0 for an empty loader.
    """
    model.eval().to(device)
    correct = 0
    seen = 0
    loss_sum = 0.0
    with torch.no_grad():
        for x, y in tqdm(dataset):
            x, y = x.to(device), y.to(device)
            output = model(x)
            # Re-weight the batch-mean loss by batch size for a dataset mean.
            loss_sum += criteration(output, y).item() * y.size(0)
            pred = output.max(1, keepdim=True)[1]
            correct += pred.eq(y.view_as(pred)).sum().item()
            seen += y.size(0)
    if seen == 0:
        print("Test: empty dataset")
        return 0.0
    acc = 100 * correct / seen
    print("Test Loss {:.4f} Accuracy {}/{} ({:.3f}%)".format(loss_sum / seen, correct, seen, acc))
    return acc

def get_pruning_modules(model):
    """Return the BN layers that are eligible for pruning.

    For every torchvision ResNet ``Bottleneck`` found in ``model.modules()``
    order, its ``bn1`` and then its ``bn2`` are included. ``bn3`` feeds the
    residual addition and is left alone.
    """
    bottlenecks = (
        blk for blk in model.modules()
        if isinstance(blk, torchvision.models.resnet.Bottleneck)
    )
    return [bn for blk in bottlenecks for bn in (blk.bn1, blk.bn2)]

def gather_bn_weights(model, pruning_modules):
    """Concatenate |gamma| of the given BN modules into one flat CPU tensor.

    Sizes are taken from ``pruning_modules`` itself, in that list's order.
    Previously the sizes were read in ``model.modules()`` order and then zipped
    against ``pruning_modules``, which silently misaligns if the two orders
    ever differ, and which used an O(n*m) ``in`` scan.

    Args:
        model: unused; kept for call compatibility.
        pruning_modules: BatchNorm modules whose weights are collected.

    Returns:
        1-D tensor (default dtype, on CPU) of absolute BN weights; empty if
        ``pruning_modules`` is empty.
    """
    if not pruning_modules:
        return torch.zeros(0)
    return torch.cat([
        module.weight.data.abs().to(device='cpu', dtype=torch.get_default_dtype())
        for module in pruning_modules
    ])

def computer_eachlayer_pruned_number(bn_weights, thresh, modules=None):
    """Count, per BN layer, how many channels fall below a global threshold.

    Args:
        bn_weights: flat |gamma| vector (kept for call compatibility; per-layer
            boundaries come from ``modules``, so it is not used for counting).
        thresh: global threshold. Channels with |gamma| strictly below it are
            counted for removal.
        modules: BN layers to inspect. Defaults to the module-level
            ``bn_modules`` list, which the original implementation always read.

    Returns:
        List of ints, one per module. Each count is capped at
        ``channels - 1`` so a layer is never asked to drop every channel,
        which would leave a zero-width convolution.
    """
    if modules is None:
        modules = bn_modules
    num_list = []
    for module in modules:
        magnitudes = module.weight.data.abs()
        below = int((magnitudes < thresh).sum().item())
        num_list.append(min(below, max(magnitudes.numel() - 1, 0)))
    print(thresh)
    return num_list

def prune_model(model, num_list):
    """Prune bn1/bn2 channels of every torchvision Bottleneck with Torch-Pruning.

    Args:
        model: torchvision ResNet; moved to the module-level ``device`` and
            pruned in place.
        num_list: channels to remove per BN, ordered as
            ``[blk0.bn1, blk0.bn2, blk1.bn1, ...]`` in ``model.modules()`` order.

    Returns:
        The same ``model`` object, pruned.
    """
    model.to(device)
    # The example input must be on the same device as the model. The
    # dependency graph is built by running it through the model, and a CPU
    # tensor fed to a CUDA model fails.
    example_inputs = torch.randn(1, 3, 224, 224, device=device)
    DG = tp.DependencyGraph().build_dependency(model, example_inputs)

    def prune_bn(bn, num):
        # Remove the `num` channels with the smallest |gamma| from `bn` and
        # let the dependency graph shrink the coupled layers.
        if num <= 0:
            return  # nothing to remove for this layer
        # Rank by magnitude. Sorting the signed weights (as before) would treat
        # large negative scale factors as the least important channels.
        l1_norm = bn.weight.detach().abs().cpu().numpy()
        prune_index = np.argsort(l1_norm)[:num].tolist()
        plan = DG.get_pruning_plan(bn, tp.prune_batchnorm, prune_index)
        plan.exec()

    blk_id = 0
    for m in model.modules():
        if isinstance(m, torchvision.models.resnet.Bottleneck):
            prune_bn(m.bn1, num_list[blk_id])
            prune_bn(m.bn2, num_list[blk_id + 1])
            blk_id += 2
    return model


# Build the 2-class ResNet-50 skeleton. The checkpoint below overwrites every
# weight (strict load_state_dict), so ImageNet weights are not downloaded.
model = torchvision.models.resnet50(pretrained=False)
model.fc = nn.Sequential(
    nn.Linear(2048, 2)
)
model.to(device)
# map_location lets a checkpoint saved on another device load onto `device`.
model.load_state_dict(torch.load("models/model_pruning.pth", map_location=device))

bn_modules = get_pruning_modules(model)

# Global threshold: the args.percent quantile of |gamma| over all prunable BNs.
# The index is clamped so --percent 1.0 does not index past the end.
bn_weights = gather_bn_weights(model, bn_modules)
sorted_bn, _ = torch.sort(bn_weights)
thresh_index = min(int(len(bn_weights) * args.percent), len(bn_weights) - 1)
thresh = sorted_bn[thresh_index].to(device)

num_list = computer_eachlayer_pruned_number(bn_weights, thresh)

prune_model(model, num_list)
print(model)

# Create the optimizer only after pruning. Torch-Pruning shrinks layers by
# assigning new Parameter objects, so an optimizer built earlier would keep
# stepping the discarded pre-pruning tensors and fine-tuning would do nothing.
# NOTE(review): confirm for the installed Torch-Pruning version; building the
# optimizer here is correct either way.
optimizer = optim.Adam(model.parameters(), lr=args.lr, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)

vaild(model, device, test_set)
for epoch in range(1, args.epochs + 1):
    train(model, device, train_set, optimizer, epoch, bn_modules)
    vaild(model, device, test_set)
    # Whole-model pickle, because the pruned architecture differs from stock
    # resnet50. NOTE(review): the file name hard-codes 0.8 regardless of --percent.
    torch.save(model, os.path.join(args.save, 'model_pruned_0.8.pth'))

你可能感兴趣的:(模型剪枝,网络结构,pytorch)