Basic usage of NNI

Git link

1. Train a CIFAR-10 model
The network is torchvision's resnet18 with three small modifications: the conv1 kernel is changed from 7x7 to 3x3, the stride of layer2 is changed from 2 to 1, and the maxpool before layer1 is removed.
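
The cifar_resnet module itself is not included in this post; below is a minimal sketch of what its ResNet18 might look like, assuming it simply patches torchvision's resnet18 with the three changes listed above (treat it as an illustration, not the author's actual file).

# Hypothetical sketch of cifar_resnet.py: torchvision's resnet18 patched with
# the three CIFAR-10-friendly changes described above.
import torch.nn as nn
from torchvision.models import resnet18


def ResNet18(num_classes=10):
    model = resnet18(num_classes=num_classes)
    # 1. 3x3/stride-1 stem instead of 7x7/stride-2: 32x32 inputs are too small
    #    for the ImageNet stem.
    model.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
    # 2. drop the maxpool so layer1 still sees 32x32 feature maps.
    model.maxpool = nn.Identity()
    # 3. run layer2 at stride 1 (torchvision uses stride 2 in its first block).
    model.layer2[0].conv1.stride = (1, 1)
    model.layer2[0].downsample[0].stride = (1, 1)
    return model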

import torch
import numpy as np
from nni.compression.pytorch.utils.counter import count_flops_params
from cifar_resnet import ResNet18
import torch.nn as nn
from tqdm import tqdm
from torchvision.datasets import CIFAR10
from torchvision.transforms import transforms


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

num_workers = 16
torch.set_num_threads(num_workers)

def test(model, valid_dataloader):
    model.eval()
    
    loss_func = nn.CrossEntropyLoss()
    acc_list, loss_list = [], []
    with torch.no_grad():
        for i, (inputs, labels) in enumerate(tqdm(valid_dataloader)):
            inputs, labels = inputs.float().to(device), labels.to(device)
            preds = model(inputs)
            pred_idx = preds.max(1).indices
            acc = (pred_idx == labels).sum().item() / labels.size(0)
            acc_list.append(acc)
            loss = loss_func(preds, labels).item()
            loss_list.append(loss)

    valid_loss = np.array(loss_list).mean()
    valid_acc = np.array(acc_list).mean()
    
    return valid_loss, valid_acc

def train():
    # Differences between torchvision's resnet18 and cifar_resnet's ResNet18:
    # 1. torchvision's conv1 kernel is 7x7, cifar_resnet's is 3x3;
    # 2. torchvision's layer2 uses stride 2, cifar_resnet uses stride 1;
    # 3. torchvision has a maxpool before layer1, cifar_resnet does not.
    model = ResNet18(num_classes=10)
    model = model.to(device)

    print(model)

    # check model FLOPs and parameter counts with NNI utils
    dummy_input = torch.rand([1, 3, 32, 32]).to(device)

    flops, params, results = count_flops_params(model, dummy_input)
    print(f"FLOPs: {flops}, params: {params}")

    
    train_dataloader = torch.utils.data.DataLoader(
        CIFAR10('/workspace_wjr/develop/tutorials/data', train=True, transform=transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
        ]), download=True), batch_size=512, num_workers=num_workers)
    valid_dataloader = torch.utils.data.DataLoader(
        CIFAR10('/workspace_wjr/develop/tutorials/data', train=False, transform=transforms.Compose([
            transforms.ToTensor(),
        ]), download=True), batch_size=512, num_workers=num_workers)


    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(
        model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4)
    # lr_policy = torch.optim.lr_scheduler.StepLR(optimizer, 70, 0.1)

    # optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    lr_policy = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100)

    best_valid_acc = 0

    for epoch in range(100):
        print('Start training epoch {}'.format(epoch))
        loss_list = []
        # train
        model.train()
        for i, (inputs, labels) in enumerate(tqdm(train_dataloader)):
            optimizer.zero_grad()
            inputs, labels = inputs.float().to(device), labels.to(device)
            preds = model(inputs)
            loss = criterion(preds, labels)
            loss_list.append(loss.item())
            loss.backward()
            optimizer.step()
        lr_policy.step()
        print(lr_policy.get_last_lr()[0])  # get_lr() is deprecated for external calls

        
        # validation
        valid_loss, valid_acc = test(model, valid_dataloader)
        train_loss = np.array(loss_list).mean()
        print('Epoch {}: train loss {:.4f}, valid loss {:.4f}, valid acc {:.4f}'.format(
            epoch, train_loss, valid_loss, valid_acc))
        
        # save
        if valid_acc > best_valid_acc:
            best_valid_acc = valid_acc
            torch.save(model.state_dict(), 'checkpoint_best.pt')


if __name__ == '__main__':
    # train()  # uncomment to train from scratch; the code below only evaluates a saved checkpoint
    model = ResNet18(num_classes=10).to(device)
    model.load_state_dict(torch.load("checkpoint_best.pt",map_location=device))
    model.eval()

    valid_dataloader = torch.utils.data.DataLoader(
        CIFAR10('/workspace_wjr/develop/tutorials/data', train=False, transform=transforms.Compose([
            transforms.ToTensor(),
        ]), download=True), batch_size=512, num_workers=num_workers)
    valid_loss, valid_acc = test(model, valid_dataloader)
    print(valid_acc, valid_loss)
    

2. Model sensitivity analysis
Prune each layer in turn and evaluate the model with a validation function to see how the metric reacts to increasing sparsity.

from nni.compression.pytorch.utils.sensitivity_analysis import SensitivityAnalysis
import torch
from torchvision import datasets, transforms
from cifar_resnet import ResNet18
import os

device = "cuda:0"


def test(model):
    criterion = torch.nn.CrossEntropyLoss()
    test_loader = torch.utils.data.DataLoader(
        datasets.CIFAR10('/workspace_wjr/develop/tutorials/data', train=False, transform=transforms.Compose([
            transforms.ToTensor(),
        ]), download=True), batch_size=512, num_workers=16)
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += criterion(output, target).item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
    test_loss /= len(test_loader.dataset)
    acc = 100 * correct / len(test_loader.dataset)

    # print('Test Loss: {:.6f}  Accuracy: {}%\n'.format(
    #     test_loss, acc))
    return acc


model = ResNet18(num_classes=10).to(device)
model.load_state_dict(torch.load("checkpoint_best.pt", map_location=device))
model.eval()

# print(test(model))

# Note: test() returns accuracy on a 0-100 scale, so early_stop_value is also on a 0-100 scale
# SensitivityAnalysis essentially prunes each layer in turn (currently only Conv2d is
# supported) and evaluates the resulting model with val_func
s_analyzer = SensitivityAnalysis(model=model, val_func=test, early_stop_mode="minimize", early_stop_value=50,
                                 prune_type="fpgm")
# The specified_layers argument restricts the analysis to the given layers, e.g. ["conv1", "layer1.0.conv1"]
sensitivity = s_analyzer.analysis(val_args=[model], specified_layers=["layer4.1.conv2", "layer4.1.conv1"])
os.makedirs("outdir", exist_ok=True)
s_analyzer.export(os.path.join("outdir", "fpgm.csv"))
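
analysis() returns a nested dict keyed by layer name, mapping each tested sparsity to the metric value. Assuming that {layer_name: {sparsity: accuracy}} structure, a small post-processing loop can pick the largest sparsity each layer tolerates; the 1-point tolerance below is an arbitrary illustrative threshold:

# Sketch: for each analyzed layer, find the largest sparsity whose accuracy stays
# within 1 point of the unpruned baseline. Assumes `sensitivity` has the
# {layer_name: {sparsity: accuracy}} structure and test() uses the 0-100 scale.
baseline_acc = test(model)
for layer_name, results in sensitivity.items():
    tolerated = [s for s, acc in sorted(results.items()) if acc >= baseline_acc - 1.0]
    print(layer_name, "max tolerated sparsity:", max(tolerated) if tolerated else 0.0)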

3. Channel dependency analysis
Analyze the channel dependencies between layers; layers in the same dependency set must be pruned with the same sparsity.

import os

import torch
from cifar_resnet import ResNet18
from nni.compression.pytorch.utils.shape_dependency import ChannelDependency

model = ResNet18()
print(model)
dummy_input = torch.rand((1, 3, 32, 32))
channel_depen = ChannelDependency(model, dummy_input)

# make sure the output directory exists before exporting
os.makedirs("outdir", exist_ok=True)
channel_depen.export("outdir/dependency.csv")
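
As a hedged follow-up, the dependency sets can also be consumed programmatically to build a pruning config in which coupled layers share one sparsity. This assumes ChannelDependency exposes a dependency_sets attribute returning the sets of coupled layer names, as in the NNI shape_dependency utilities; the 0.5 value is an arbitrary example:

# Sketch: one shared sparsity per dependency set (dependency_sets is assumed to
# return a list of sets of conv-layer names).
shared_sparsity = 0.5
cfg = []
for dep_set in channel_depen.dependency_sets:
    if len(dep_set) > 1:  # only coupled layers need an identical sparsity
        cfg.append({"sparsity": shared_sparsity,
                    "op_names": sorted(dep_set),
                    "op_types": ["Conv2d"]})
print(cfg)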

4. Simple pruning
Specify the layers to prune and the target sparsity, then run a simple pruning pass. The script below grid-searches sparsity values for two layers and evaluates the accuracy of each combination.

import numpy as np
import torch
from nni.algorithms.compression.pytorch.pruning import FPGMPruner
from cifar_resnet import ResNet18
from torchvision import datasets
from torchvision.transforms import transforms

device = "cuda:0"
test_loader = torch.utils.data.DataLoader(
        datasets.CIFAR10('/workspace_wjr/develop/tutorials/data', train=False, transform=transforms.Compose([
            transforms.ToTensor(),
        ]), download=True), batch_size=512, num_workers=16)

def test(model):
    criterion = torch.nn.CrossEntropyLoss()

    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += criterion(output, target).item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
    test_loss /= len(test_loader.dataset)
    acc = 100 * correct / len(test_loader.dataset)

    # print('Test Loss: {:.6f}  Accuracy: {}%\n'.format(
    #     test_loss, acc))
    return acc



model = ResNet18(10)
ori_state_dict = torch.load("checkpoint_best.pt", map_location="cuda")
model.load_state_dict(ori_state_dict)
model = model.cuda().eval()

# First analyze the channel dependencies (coupled layers must share one sparsity),
# then choose per-layer sparsities based on the sensitivity-analysis results
# cfg = [{'sparsity': 0.3, 'op_names': ["conv1"], 'op_types': ['Conv2d']}]
for val1 in np.arange(0.1, 1.0, 0.1):
    for val2 in np.arange(0.1, 1.0, 0.1):
        cfg = [{"sparsity": val2, "op_names": ["layer4.1.conv2"], 'op_types': ['Conv2d']},
               {"sparsity": val1, "op_names": ["layer4.1.conv1"], 'op_types': ['Conv2d']}]
        # cfg = [{"sparsity": val2, "op_names": ["layer4.1.conv2"], 'op_types': ['Conv2d']},]
        pruner = FPGMPruner(model, cfg, dummy_input=torch.rand((4, 3, 32, 32)).cuda())
        pruner.compress()
        print(val1, val2, test(model))
        pruner._unwrap_model()
        del pruner
        # This reload is essential; without it the next iteration would prune on
        # top of the already-pruned weights instead of the original checkpoint.
        model.load_state_dict(ori_state_dict)
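
After the grid search settles on a configuration, the masks produced by the pruner only zero weights out; NNI's ModelSpeedup is the usual next step to physically remove the masked channels. A minimal sketch of the export-then-speedup flow, with illustrative file names and an arbitrary example config:

# Sketch: export the masks for the chosen config, then let ModelSpeedup shrink
# the model for real. "pruned.pth" and "mask.pth" are placeholder paths.
from nni.compression.pytorch import ModelSpeedup

best_cfg = [{"sparsity": 0.5, "op_names": ["layer4.1.conv1"], "op_types": ["Conv2d"]}]
model.load_state_dict(ori_state_dict)
pruner = FPGMPruner(model, best_cfg, dummy_input=torch.rand((4, 3, 32, 32)).cuda())
pruner.compress()
pruner.export_model(model_path="pruned.pth", mask_path="mask.pth")
pruner._unwrap_model()

ModelSpeedup(model, torch.rand((1, 3, 32, 32)).cuda(), "mask.pth").speedup_model()
print(test(model))  # accuracy of the physically pruned model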
