NNI模型剪枝

本文参考:

【模型部署】NNI:剪枝和量化_Jackilina_Stone的博客-CSDN博客_nni剪枝

 模型剪枝入门 — Neural Network Intelligence

NNI:Neural Network Intelligence,是一个轻量但强大的自动机器学习工具包,能帮助用户自动进行特征工程、神经网络架构搜索、超参调优以及模型压缩。

一、安装NNI

pip install nni

安装完成后检查nni版本:

nnictl --version

二、剪枝操作

之前已经写过两篇介绍模型剪枝方法的文章,分别是:

模型剪枝初级方法_benben044的博客-CSDN博客

Slimming剪枝方法_benben044的博客-CSDN博客 

 

VGG模型代码:

import torch
import torch.nn as nn
import math
from torch.autograd import Variable

class vgg(nn.Module):
    """VGG-style CNN with batch normalization for CIFAR-10 / CIFAR-100.

    The feature extractor is built from ``cfg``. Each int adds a 3x3 conv
    (padding 1, no bias) + BatchNorm2d + ReLU with that many output channels.
    Each ``'M'`` adds a 2x2 max-pool. The default ``cfg`` has 16 conv layers
    and 4 max-pools, so a 32x32 input leaves the feature extractor as a 2x2
    map. A final 2x2 average pool reduces it to 1x1, giving ``cfg[-1]``
    features for the linear classifier.

    Args:
        dataset: ``'cifar10'`` (10 classes) or ``'cifar100'`` (100 classes).
        init_weights: if True, re-initialize all weights with
            ``_initialize_weights``.
        cfg: optional layer specification; defaults to the 16-conv layout.

    Raises:
        ValueError: if ``dataset`` is not one of the supported names.
    """

    # Number of output classes for each supported dataset name.
    _NUM_CLASSES = {'cifar10': 10, 'cifar100': 100}

    def __init__(self, dataset='cifar10', init_weights=True, cfg=None):
        super(vgg, self).__init__()
        if cfg is None:
            cfg = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M',
                   512, 512, 512, 512, 'M', 512, 512, 512, 512]
        # Reject unknown names explicitly instead of failing later on an
        # unbound class count.
        if dataset not in self._NUM_CLASSES:
            raise ValueError(
                "unsupported dataset {!r}; expected one of {}".format(
                    dataset, sorted(self._NUM_CLASSES)))
        num_classes = self._NUM_CLASSES[dataset]

        self.feature = self.make_layers(cfg, True)
        # The classifier input size assumes the feature map is 1x1 after
        # the average pool in forward().
        self.classifier = nn.Linear(cfg[-1], num_classes)
        if init_weights:
            self._initialize_weights()

    def make_layers(self, cfg, batch_norm=False):
        """Build the convolutional feature extractor described by ``cfg``.

        Args:
            cfg: sequence of ints (conv output channels) and ``'M'``
                (2x2 max-pool with stride 2).
            batch_norm: insert a BatchNorm2d after every conv when True.

        Returns:
            An ``nn.Sequential`` expecting 3-channel input.
        """
        layers = []
        in_channels = 3
        for v in cfg:
            if v == 'M':
                layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
                continue
            conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1, bias=False)
            if batch_norm:
                layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
            else:
                layers += [conv2d, nn.ReLU(inplace=True)]
            in_channels = v
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)``."""
        x = self.feature(x)
        # Functional pooling avoids constructing a new nn.AvgPool2d module on
        # every call. It computes the same 2x2 average pool (stride 2).
        x = nn.functional.avg_pool2d(x, 2)
        x = x.view(x.size(0), -1)
        return self.classifier(x)

    def _initialize_weights(self):
        """Initialize conv, BatchNorm and linear layers in place.

        The initial values are:
        - Conv weights: normal with mean 0 and std
          ``sqrt(2 / (kernel_h * kernel_w * out_channels))``.
        - BatchNorm weights: 0.5; BatchNorm biases: 0.
        - Linear weights: normal with mean 0 and std 0.01.
        - Any present biases: 0.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                nn.init.normal_(m.weight, 0, math.sqrt(2. / n))
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.BatchNorm2d):
                # 0.5 (not the default 1.0) is this model's chosen BN scale.
                nn.init.constant_(m.weight, 0.5)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

if __name__ == '__main__':
    # Smoke test: push a random batch through the default model and print
    # the output shape. torch.randn gives well-defined values, whereas
    # torch.FloatTensor(*shape) returns uninitialized memory (possibly NaN).
    net = vgg()
    x = torch.randn(16, 3, 40, 40)
    with torch.no_grad():
        y = net(x)
    print(y.data.shape)

使用NNI进行模型剪枝的代码:

import torch
import torch.optim as optim
from vgg import vgg
from tqdm import tqdm
from torchvision import datasets, transforms
from torch.autograd import Variable
import torch.nn.functional as F
from nni.compression.pytorch.pruning import L1NormPruner
from nni.compression.pytorch.speedup import ModelSpeedup
import shutil


# Pruning configuration: target 50% sparsity for every Conv2d layer
# (consumed by the pruner constructed below).
config_list = [
    {'sparsity_per_layer': 0.5, 'op_types': ['Conv2d']},
]

# Fine-tuning hyper-parameters for SGD after pruning.
learning_rate = 0.1
momentum = 0.9
weight_decay = 1e-4
epochs = 6
batch_size = 100
log_interval = 100  # print the training loss every `log_interval` batches

# NOTE(review): these two slimming-style settings are not read anywhere in
# this script; they appear to be leftovers from the slimming version.
sparsity_regularization = True
scale_sparse_rate = 1e-4

# Checkpoint files: latest epoch, and a copy of the best-accuracy epoch.
prune_checkpoint_path = 'nni_pruned_checkpoint.pth.tar'
prune_best_model_path = 'nni_pruned_model_best.pth.tar'




# Build the baseline model on the GPU and show its architecture.
model = vgg()
model.cuda()
print("--------------raw model--------------")
print(model)

# Constructing the pruner wraps the target layers in place, so this printout
# shows the wrapped model; masks are only computed by compress() below.
pruner = L1NormPruner(model, config_list)
print("--------------pruned model--------------")
print(model)

# compress the model and generate the masks
_, masks = pruner.compress()
# Show per-layer sparsity. A mask value of 1 keeps a weight and 0 prunes it,
# so mask.sum() / numel() is the kept fraction (density); sparsity is 1 - density.
for name, mask in masks.items():
    weight_mask = mask['weight']
    density = weight_mask.sum().item() / weight_mask.numel()
    print(name, ' sparsity: ', '{:.2f}'.format(1 - density))

# The model must be unwrapped (it was wrapped by the pruner) before speedup.
pruner._unwrap_model()

# Speed up the model: replace masked layers with physically smaller ones.
# The dummy input matches the 32x32 CIFAR-10 images used for retraining.
ModelSpeedup(model, torch.rand(16, 3, 32, 32).cuda(), masks).speedup_model()

print("--------------after sppedup--------------")
print(model)

# retrain
# retrain
# Create the optimizer only after speedup: ModelSpeedup replaces pruned
# layers with new modules, so parameters must be collected afterwards.
optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=momentum, weight_decay=weight_decay)

# CIFAR-10 location (the training split is downloaded on first use if missing).
cifar10_root = 'D:\\ai_data\\cifar10'
# Same per-channel normalization for training and evaluation.
cifar_normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))

train_loader = torch.utils.data.DataLoader(
    datasets.CIFAR10(cifar10_root, train=True, download=True,
                     transform=transforms.Compose([
                         transforms.Pad(4),
                         transforms.RandomCrop(32),
                         transforms.RandomHorizontalFlip(),
                         transforms.ToTensor(),
                         cifar_normalize,
                     ])),
    batch_size=batch_size, shuffle=True)

# Evaluation metrics do not depend on sample order, so the test set is not shuffled.
test_loader = torch.utils.data.DataLoader(
    datasets.CIFAR10(cifar10_root, train=False, transform=transforms.Compose([
        transforms.ToTensor(),
        cifar_normalize,
    ])),
    batch_size=batch_size, shuffle=False)

def train(epoch):
    """Run one training epoch over ``train_loader`` and log the loss periodically.

    Uses the module-level ``model``, ``optimizer``, ``train_loader`` and
    ``log_interval``.

    Args:
        epoch: epoch index, used only in the log line.
    """
    model.train()
    num_batches = len(train_loader)
    num_samples = len(train_loader.dataset)
    for batch_idx, (data, target) in enumerate(tqdm(train_loader)):
        # Plain tensors track gradients directly; the deprecated
        # autograd.Variable wrapper is unnecessary.
        data, target = data.cuda(), target.cuda()
        optimizer.zero_grad()
        output = model(data)
        loss = F.cross_entropy(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.1f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), num_samples,
                100. * batch_idx / num_batches, loss.item()))

def test():
    """Evaluate ``model`` on ``test_loader`` and return top-1 accuracy.

    Prints the average cross-entropy loss and the accuracy. Uses the
    module-level ``model`` and ``test_loader``.

    Returns:
        float: fraction of correctly classified test samples, in [0, 1].
    """
    model.eval()
    test_loss = 0.0
    correct = 0
    # Evaluation needs no gradients; skipping autograd saves memory and time.
    with torch.no_grad():
        for data, target in tqdm(test_loader):
            data, target = data.cuda(), target.cuda()
            output = model(data)
            # reduction='sum' replaces the deprecated size_average=False.
            test_loss += F.cross_entropy(output, target, reduction='sum').item()
            pred = output.argmax(dim=1)
            # .item() keeps `correct` a Python int rather than a tensor,
            # so the log shows a number and the return value is a float.
            correct += pred.eq(target).sum().item()

    num_samples = len(test_loader.dataset)
    test_loss /= num_samples
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.1f}%)\n'.format(
        test_loss, correct, num_samples,
        100. * correct / num_samples))
    return correct / float(num_samples)

def save_checkpoint(state, is_best, filename=prune_checkpoint_path):
    """Save ``state`` to ``filename``; also copy it to ``prune_best_model_path`` when best.

    Args:
        state: object passed to ``torch.save`` (in this script, a dict with
            epoch, model state, best accuracy and optimizer state).
        is_best: whether this checkpoint has the best accuracy so far.
        filename: path of the regular (latest-epoch) checkpoint.
    """
    torch.save(state, filename)
    if not is_best:
        return
    # Keep a separate copy so later, worse epochs don't overwrite the best one.
    shutil.copyfile(filename, prune_best_model_path)

# Fine-tune the pruned model: evaluate and checkpoint after every epoch,
# keeping an extra copy of the best-accuracy checkpoint.
# NOTE(review): the learning rate stays at `learning_rate` for every epoch
# (no scheduler is used here).
best_prec = 0
for epoch in range(epochs):
    train(epoch)
    prec = test()
    is_best = prec > best_prec
    best_prec = max(prec, best_prec)
    checkpoint = dict(
        epoch=epoch + 1,
        state_dict=model.state_dict(),
        best_prec=best_prec,
        optimizer=optimizer.state_dict(),
    )
    save_checkpoint(checkpoint, is_best)

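在训练相同轮数(epoch)的情况下,NNI剪枝后模型的准确率没有Slimming方法高。需要注意,上面脚本中的 sparsity_regularization 和 scale_sparse_rate 并未被使用,且学习率在整个训练过程中固定为0.1(没有学习率衰减),这些设置可能会影响两种方法的对比结果。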

你可能感兴趣的:(神经网络,剪枝,算法)