李宏毅作业七其三 Network Compression (Network Pruning)

Network Compression ——Network Pruning

  • 前言
  • 一、Network Pruning是什么?
    • Weight & Neuron Pruning
    • 对于修剪网络并不简单,这里只是给出了相对简单的修剪方法。
    • What to Prune?
    • 如何操作?
    • 细节
  • 二、代码示例
    • 1.加载数据和定义环境
    • 2.network purning
    • 3.数据处理
    • 4.预处理
    • 5.Start Training

前言

网络结构会存在冗余,所以我们要想办法删除里面的神经元neuron或者优化里面的权重weight。删除里面的神经元neuron和dropout让神经元随机失活有些相似,但是随机失活只是起到过拟合的作用,并没有在真正意义上让修剪网络。


一、Network Pruning是什么?

李宏毅作业七其三 Network Compression (Network Pruning)_第1张图片

Weight & Neuron Pruning

  • weight和neuron pruning,对neruron做purning效果来得更明显,且相对简单些。去掉一个neruron相当于把一整个matrix的整个column都删掉了,速度更快。

对于修剪网络并不简单,这里只是给出了相对简单的修剪方法。

What to Prune?

  • 既然要Neuron Pruning,那就必須要先衡量Neuron的重要性。衡量完所有的Neuron後,就可以把比較不重要的Neuron刪減掉。

  • 在這裡我們介紹一個很簡單可以衡量Neuron重要性的方法 - 就是看batchnorm layer的 γ \gamma γ因子來決定neuron的重要性。 (by paper - Network Slimming)
    李宏毅作业七其三 Network Compression (Network Pruning)_第2张图片

  • 相信大家看這個pytorch提供的batchnorm公式應該就可以意識到為甚麼 γ \gamma γ可以當作重要性來衡量了:)

  • Network Slimming其實步驟沒有這麼簡單,有興趣的同學可以check以下連結。Netowrk Slimming

如何操作?

  • 為了避免複雜的操作,我們會將StudentNet(width_mult= α \alpha α)的neuron經過篩選後移植到StudentNet(width_mult= β \beta β)。( α > β \alpha > \beta α>β)
  • 篩選的方法也很簡單,只需要抓出每一個block的batchnorm的 γ \gamma γ即可。

细节

  • 假設model中間兩層是這樣的:
Layer Output # of Channels
Input in_chs
Depthwise(in_chs) in_chs
BatchNorm(in_chs) in_chs
Pointwise(in_chs, mid_chs) mid_chs
Depthwise(mid_chs) mid_chs
BatchNorm(mid_chs) mid_chs
Pointwise(mid_chs, out_chs) out_chs

則你會發現利用第二個BatchNorm來做篩選的時候,跟他的Neuron有直接關係的是該層的Depthwise&Pointwise以及上層的Pointwise。
因此再做neuron篩選時記得要將這四個(包括自己, bn)也要同時prune掉。

  • 在Design Architecure內,model的一個block,名稱所對應的Weight;
# name meaning code weight shape
0 cnn.{i}.0 Depthwise Convolution Layer nn.Conv2d(x, x, 3, 1, 1, group=x) (x, 1, 3, 3)
1 cnn.{i}.1 Batch Normalization nn.BatchNorm2d(x) (x)
2 ReLU6 nn.ReLU6
3 cnn.{i}.3 Pointwise Convolution Layer nn.Conv2d(x, y, 1), (y, x, 1, 1)
4 MaxPooling nn.MaxPool2d(2, 2, 0)

二、代码示例

1.加载数据和定义环境

# Download dataset
!gdown --id '19CzXudqN58R3D-1G8KeFWk8UDQwlb8is' --output food-11.zip
# Unzip the files
!unzip food-11.zip
import torch
import os  #os模块提供了多数操作系统的功能接口函数。
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.models as models
# Load進我們的Model架構(在hw7_Architecture_Design.ipynb內)
!gdown --id '1lJS0ApIyi7eZ2b3GMyGxjPShI8jXM2UC' --output "hw7_Architecture_Design.ipynb"
%run "hw7_Architecture_Design.ipynb"

2.network purning

def network_slimming(old_model, new_model):#网络修剪,旧模型参数,新模型参数
    params = old_model.state_dict()
    new_params = new_model.state_dict()
    
    # selected_idx: 每一層所選擇的neuron index主要还是修剪神经元
    selected_idx = []
    # 我們總共有7層CNN,因此逐一抓取選擇的neuron index們。
    for i in range(8):
        # 根據上表,我們要抓的gamma係數在cnn.{i}.1.weight內。
        importance = params[f'cnn.{i}.1.weight']
        # 抓取總共要篩選幾個neuron。
        old_dim = len(importance)
        new_dim = len(new_params[f'cnn.{i}.1.weight'])

        # 以Ranking做Index排序,較大的會在前面(descending=True)。
        #argsort函数返回的是数组值从小到大(默认)的索引值,descending降序从大到小。
        ranking = torch.argsort(importance, descending=True)
        # 把篩選結果放入selected_idx中。
        selected_idx.append(ranking[:new_dim])

    now_processed = 1
    for (name, p1), (name2, p2) in zip(params.items(), new_params.items()):#注意item返回的是具体的数值
    
        # 如果是cnn層,則移植參數。
        # 如果是FC層,或是該參數只有一個數字(例如batchnorm的tracenum等等資訊),那麼就直接複製。
        if name.startswith('cnn') and p1.size() != torch.Size([]) and now_processed != len(selected_idx):
            # 當處理到Pointwise的weight時,讓now_processed+1,表示該層的移植已經完成。
            if name.startswith(f'cnn.{now_processed}.3'):
                now_processed += 1

            # 如果是pointwise,weight會被上一層的pruning和下一層的pruning所影響,因此需要特判。
            if name.endswith('3.weight'):
                # 如果是最後一層cnn,則輸出的neuron不需要prune掉。
                if len(selected_idx) == now_processed:
                    new_params[name] = p1[:,selected_idx[now_processed-1]]
                # 反之,就依照上層和下層所選擇的index進行移植。
                # 這裡需要注意的是Conv2d(x,y,1)的weight shape是(y,x,1,1),順序是反的。
                else:
                    new_params[name] = p1[selected_idx[now_processed]][:,selected_idx[now_processed-1]]
            else:
                new_params[name] = p1[selected_idx[now_processed]]
        else:
            new_params[name] = p1

    # 讓新model load進被我們篩選過的parameters,並回傳new_model。        
    new_model.load_state_dict(new_params)
    return new_model


3.数据处理

import re
import torch
from glob import glob#glob模块用来查找文件目录和文件
from PIL import Image
import torchvision.transforms as transforms

class MyDataset(torch.utils.data.Dataset):

    def __init__(self, folderName, transform=None):#私有变量
        self.transform = transform
        self.data = []
        self.label = []

        for img_path in sorted(glob(folderName + '/*.jpg')):#寻找照片路径
            try:
                # Get classIdx by parsing image path
                class_idx = int(re.findall(re.compile(r'\d+'), img_path)[1])
            except:
                # if inference mode (there's no answer), class_idx default 0
                class_idx = 0
 
            image = Image.open(img_path)
            # Get File Descriptor
            image_fp = image.fp
            image.load()
            # Close File Descriptor (or it'll reach OPEN_MAX)
            image_fp.close()

            self.data.append(image)#数据扩增图片和索引(标签)
            self.label.append(class_idx)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        image = self.data[idx]
        if self.transform:
            image = self.transform(image)
        return image, self.label[idx]

#图片处理,数据扩增
trainTransform = transforms.Compose([
    transforms.RandomCrop(256, pad_if_needed=True, padding_mode='symmetric'),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ToTensor(),
])
testTransform = transforms.Compose([
    transforms.CenterCrop(256),
    transforms.ToTensor(),
])

def get_dataloader(mode='training', batch_size=32):

    assert mode in ['training', 'testing', 'validation']

    dataset = MyDataset(
        f'./food-11/{mode}',
        transform=trainTransform if mode == 'training' else testTransform)

    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=(mode == 'training'))

    return dataloader

4.预处理

# get dataloader
train_dataloader = get_dataloader('training', batch_size=32)
valid_dataloader = get_dataloader('validation', batch_size=32)
!gdown --id '12wtIa0WVRcpboQzhgRUJOpcXe23tgWUL' --output student_custom_small.bin

net = StudentNet().cuda()
net.load_state_dict(torch.load('student_custom_small.bin'))

criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(net.parameters(), lr=1e-3)

5.Start Training

每次Prune rate是0.95,Prune之后重新fine-tune 3 epochs。

def run_epoch(dataloader, update=True, alpha=0.5):
    total_num, total_hit, total_loss = 0, 0, 0
    for now_step, batch_data in enumerate(dataloader):
        # 清空 optimizer
        optimizer.zero_grad()
        # 處理 input
        inputs, labels = batch_data
        inputs = inputs.cuda()
        labels = labels.cuda()
  
        logits = net(inputs)
        loss = criterion(logits, labels)
        if update:
            loss.backward()
            optimizer.step()

        total_hit += torch.sum(torch.argmax(logits, dim=1) == labels).item()
        total_num += len(inputs)
        total_loss += loss.item() * len(inputs)

    return total_loss / total_num, total_hit / total_num

now_width_mult = 1
for i in range(5):
    now_width_mult *= 0.95
    new_net = StudentNet(width_mult=now_width_mult).cuda()
    params = net.state_dict()
    net = network_slimming(net, new_net)#进行网络修剪
    now_best_acc = 0
    for epoch in range(5):
        net.train()
        train_loss, train_acc = run_epoch(train_dataloader, update=True)
        net.eval()
        valid_loss, valid_acc = run_epoch(valid_dataloader, update=False)
        # 在每個width_mult的情況下,存下最好的model。
        if valid_acc > now_best_acc:
            now_best_acc = valid_acc
            torch.save(net.state_dict(), f'custom_small_rate_{now_width_mult}.bin')
        print('rate {:6.4f} epoch {:>3d}: train loss: {:6.4f}, acc {:6.4f} valid loss: {:6.4f}, acc {:6.4f}'.format(now_width_mult, 
            epoch, train_loss, train_acc, valid_loss, valid_acc))

你可能感兴趣的:(模型优化,李宏毅,Colab,深度学习,python,神经网络,机器学习)