Pruning and Re-parameterization, Lesson 6: Hands-On Model Pruning with VGG

Table of Contents

  • Hands-On Model Pruning with VGG
    • Preface
    • 1. Intro
    • 2. Pruning in Practice
      • 2.1 Notes
      • 2.2 test()
      • 2.3 Loading the Sparsity-Trained Model
      • 2.4 Pre-processing
      • 2.5 Building the New Model and Saving Its Info
      • 2.6 Pruning the BatchNorm Layers
      • 2.7 Pruning the Conv2d Layers
      • 2.8 Pruning the Linear Layers
    • 3. VGG-Based Model Pruning
    • Summary

Hands-On Model Pruning with VGG

Preface

These are my personal study notes for the new model pruning and re-parameterization course from 手写AI, kept mainly for my own reference.

This lesson walks through a hands-on pruning exercise on a VGG model.

The course outline is shown in the mind map below.

[Figure 1: course outline mind map]

1. Intro

This hands-on pruning session reproduces the paper below; the main task is to prune channels based on the BN-layer $\gamma$ parameters.

  • Related paper: Learning Efficient Convolutional Networks through Network Slimming (ICCV 2017)

Let's first consider a question: the convolutional layers of a deep model output a great many feature channels. Might some of those features, and the connections attached to them, be worthless? And how do we judge whether a feature and its connections carry value?

The short answer: impose L1 regularization on the scaling factors of the Batch Normalization layers (this is the core idea of the paper above; please read the paper for more detail).

Advantages:

  • Requires no changes to existing CNN architectures
  • Uses L1 regularization to push the BN scaling factors toward zero
    • This lets us identify unimportant channels (or neurons), since each scaling factor corresponds to a specific convolutional channel (or a neuron in a fully connected layer)
    • which makes channel-level pruning in the following steps straightforward
  • The added regularization term rarely hurts performance; in some cases it even yields higher generalization accuracy
  • Pruning unimportant channels may temporarily degrade performance, but the effect can be compensated by fine-tuning the pruned network afterwards
  • After pruning, the resulting narrower network is more compact than the initial wide network in model size, runtime memory, and compute. The process can be repeated several times, yielding a multi-pass network-slimming scheme and an even more compact network.

Below is the loss function proposed in the paper for sparsity training of the BN-layer $\gamma$ parameters:
$$L = \sum_{(x,y)} l\Big(f(x, W), y\Big) + \lambda \sum_{\gamma \in \Gamma} g(\gamma)$$
In short, this loss adds an L1 penalty on the BN $\gamma$ parameters on top of the usual classification loss. During training, adjusting the hyperparameter $\lambda$ controls how sparse the $\gamma$ parameters become: a portion of them are driven to zero, which reduces model complexity and can improve generalization.
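In implementation terms, the subgradient of $g(\gamma) = |\gamma|$ is $\mathrm{sign}(\gamma)$, so the penalty can be applied as one extra gradient term on each BN weight between loss.backward() and optimizer.step(). A minimal sketch (the helper name and the value of s, playing the role of $\lambda$, are illustrative; the previous lesson's train.py with -sr --s applies the same idea):

import torch
import torch.nn as nn

def update_bn_grad(model, s=0.0001):
    # Add the L1 subgradient s * sign(gamma) to every BN gamma's gradient;
    # call this after loss.backward() and before optimizer.step()
    for m in model.modules():
        if isinstance(m, nn.BatchNorm2d):
            m.weight.grad.data.add_(s * torch.sign(m.weight.data))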

The concrete pipeline is illustrated below.

[Figure 2: network-slimming pipeline]

First train an initial model to get a benchmark => sparsity training => pruning => fine-tuning => final model.

2. Pruning in Practice

2.1 Notes

We only prune the layers that carry parameters: Conv2d, BatchNorm2d, and Linear. The Pool2d layers only perform downsampling and have no learnable parameters, so they need no handling. Below are some notes on the masks.

cfg and cfg_mask

  • In the previous lesson we sparsity-trained the BatchNorm layers
  • After training, we count the parameters of all BatchNorm layers, collect them, and sort them
  • From the pruning ratio $r$ we set a threshold; the gt() (greater-than) method yields a mask, zeroing entries below the threshold (a toy example follows this list)
  • From the mask we count the remaining channels and record them in
    • cfg: used to build the new model
    • cfg_mask: used for pruning
  • These masks are used later to index each layer's inputs and outputs
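As a toy illustration of how gt() turns the $\gamma$ values into a channel mask (the numbers are made up for demonstration):

import torch

gammas = torch.tensor([0.52, 0.01, 0.30, 0.004, 0.85])  # sparsity-trained BN gammas
thre = 0.05                                              # threshold from the pruning ratio
mask = gammas.abs().gt(thre).float()                     # tensor([1., 0., 1., 0., 1.])
print(int(mask.sum()))                                   # 3 -> this layer's entry in cfg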

[Figure 3: cfg and cfg_mask]

Conv2d

  • weights: (out_channels, in_channels, kernel_size, kernel_size)
  • Index with the masks and copy the corresponding values
  • Uses start_mask and end_mask

BatchNorm2d

  • self.weight: stores $\gamma$, shape (num_features)
  • self.bias: stores $\beta$, shape (num_features)
  • Uses end_mask
  • Updates start_mask and end_mask

Linear

  • self.weight: (out_features, in_features)
  • self.bias: (out_features)
  • Uses start_mask

[Figure 4: per-layer mask usage]
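The index juggling in sections 2.6-2.8 follows one pattern: turn a 0/1 mask into a list of surviving channel indices, then slice the weight tensor with it. A toy sketch with made-up shapes:

import numpy as np
import torch

old_weight = torch.randn(4, 3, 3, 3)         # conv weight: (out_ch=4, in_ch=3, k, k)
start_mask = torch.tensor([1., 0., 1.])      # 2 of 3 input channels survive
end_mask   = torch.tensor([0., 1., 1., 1.])  # 3 of 4 output channels survive

idx0 = np.squeeze(np.argwhere(np.asarray(start_mask.numpy())))  # [0 2]
idx1 = np.squeeze(np.argwhere(np.asarray(end_mask.numpy())))    # [1 2 3]

w = old_weight[:, idx0.tolist(), :, :]  # prune input channels first
w = w[idx1.tolist(), :, :, :]           # then prune output channels
print(w.shape)                          # torch.Size([3, 2, 3, 3])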

2.2 test()

Let's first implement a test() function to measure the performance of the pruned model. Example code:

import argparse
from utils import get_test_dataloader
import torch

def parse_opt():
    # Prune setting
    parser = argparse.ArgumentParser(description='PyTorch Slimming CIFAR prune')
    parser.add_argument('--dataset', type=str, default='cifar10', help='training dataset (default: cifar10)')
    parser.add_argument('--test-batch-size', type=int, default=256, metavar='N', help='input batch size for test (default: 256)')
    parser.add_argument('--no-cuda', action='store_true', default=False, help='disables CUDA training')
    parser.add_argument('--depth', type=int, default=11, help='depth of the vgg')
    parser.add_argument('--percent', type=float, default=0.5, help='scale sparse rate (default: 0.5)')
    parser.add_argument('--model', default='', type=str, metavar='PATH', help='path to the model (default: none)')
    parser.add_argument('--save', default='logs/', type=str, metavar='PATH', help='path to save pruned model (default: none)')
    args = parser.parse_args()
    return args

def test(model):
    kwargs = {'num_workers' : 1, 'pin_memory' : True} if args.cuda else {}

    test_loader = get_test_dataloader(batch_size=args.test_batch_size, **kwargs)

    model.eval()
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            if args.cuda:
                data, target = data.cuda(), target.cuda()
            
            output = model(data)
            pred = output.data.max(1, keepdim=True)[1]
            correct += pred.eq(target.data.view_as(pred)).cpu().sum()
        
        accuracy = 100. * correct / len(test_loader.dataset)
        print('\nTest set: Accuracy: {}/{} ({:.1f}%)\n'.format(
            correct, len(test_loader.dataset), accuracy  # divide by dataset size, not batch count
        ))
    return accuracy / 100

if __name__ == "__main__":
    args = parse_opt()

2.3 Loading the Sparsity-Trained Model

We use the previous lesson's train.py to obtain the sparsity-trained weights; run the following command in a terminal:

python .\train.py -sr --s 0.0001 --dataset cifar10 --arch vgg --depth 11 --epochs 10

Next we load the sparsity-trained model, since we need to collect statistics on its BN-layer parameters. Example code:

import os
import argparse
from models.vgg import VGG
from utils import get_test_dataloader
import torch

def parse_opt():
    ...

def test(model):
    ...

if __name__ == "__main__":
    args = parse_opt()

    args.cuda = not args.no_cuda and torch.cuda.is_available()

    if not os.path.exists(args.save):
        os.makedirs(args.save)
    
    model = VGG(depth=args.depth)

    if args.cuda:
        model.cuda()
    
    if args.model:
        if os.path.isfile(args.model):
            print("=> loading checkpoing '{}'".format(args.model))
            checkpoint = torch.load(args.model)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {}) Prec1: {:f}".format(
                args.model, checkpoint['epoch'], best_prec1
            ))
        else:
            print("=> no checkpoing found at '{}'".format(args.model))
    
    print(model)
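One caveat: if the checkpoint was saved from a GPU run but is loaded on a CPU-only machine, torch.load needs a map_location argument. A defensive variant of the load above might be:

# Map GPU-saved tensors onto the CPU when CUDA is unavailable
checkpoint = torch.load(args.model,
                        map_location=None if args.cuda else torch.device('cpu'))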

2.4 Pre-processing

Now we prune the BN layers and save the configuration cfg and the corresponding cfg_mask for later use. Example code:


if __name__ == "__main__":
    
    ...
    
    total = 0
    
    for m in model.modules():
        if isinstance(m, nn.BatchNorm2d):
            total += m.weight.data.shape[0]
    
    bn = torch.zeros(total)

    index = 0
    # All BN gamma parameters live in nn.BatchNorm2d.weight.data;
    # the beta parameters live in nn.BatchNorm2d.bias.data
    for m in model.modules():
        if isinstance(m, nn.BatchNorm2d):
            size = m.weight.data.shape[0]
            bn[index:(index+size)] = m.weight.data.abs().clone()
            index += size
    
    # Compute the pruning threshold
    y, i = torch.sort(bn)
    thre_index = int(total * args.percent)
    thre = y[thre_index]

    pruned  = 0
    cfg = []
    cfg_mask = []
    
    for k, m in enumerate(model.modules()):
        if isinstance(m, nn.BatchNorm2d):
            weight_copy = m.weight.data.abs().clone()
            mask = weight_copy.gt(thre).float()
            if args.cuda:
                mask = mask.cuda()
            pruned = pruned + mask.shape[0] - torch.sum(mask)
            m.weight.data.mul_(mask)
            m.bias.data.mul_(mask)
            cfg.append(int(torch.sum(mask)))
            cfg_mask.append(mask.clone())
            print('layer index: {:d} \t total channel: {:d} \t remaining channel: {:d}'.format(
                k, mask.shape[0], int(torch.sum(mask))
            ))
        elif isinstance(m, nn.MaxPool2d):
            cfg.append('M')
    
    pruned_ratio = pruned / total
    print("Pre-process Sucessful Pruned Ratio: {:.2f}%".format(pruned_ratio * 100.))
    acc = test(model)
    print(cfg)

The printed output is:

layer index: 3   total channel: 64       remaining channel: 63
layer index: 7   total channel: 128      remaining channel: 126
layer index: 11          total channel: 256      remaining channel: 227
layer index: 14          total channel: 256      remaining channel: 162
layer index: 18          total channel: 512      remaining channel: 180
layer index: 21          total channel: 512      remaining channel: 194
layer index: 25          total channel: 512      remaining channel: 191
layer index: 28          total channel: 512      remaining channel: 232
Pre-process successful. Pruned ratio: 50.04%
Files already downloaded and verified

Test set: Accuracy: 1757/10000 (17.6%)

[63, 'M', 126, 'M', 227, 162, 'M', 180, 194, 'M', 191, 232]

The final configuration differs from the original, compared below. The channel counts in the early layers barely change, while the later layers are pruned heavily; accuracy drops severely at this point (fine-tuning later recovers it).

# old
[64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512]
# new
[63, 'M', 126, 'M', 227, 162, 'M', 180, 194, 'M', 191, 232]
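To quantify the difference, one can compare the parameter counts of the two configurations. A quick sketch, assuming the VGG constructor accepts either depth or cfg as used elsewhere in this lesson:

from models.vgg import VGG

old_model = VGG(depth=11)
new_model = VGG(cfg=[63, 'M', 126, 'M', 227, 162, 'M', 180, 194, 'M', 191, 232])
print(sum(p.nelement() for p in old_model.parameters()))  # original parameter count
print(sum(p.nelement() for p in new_model.parameters()))  # pruned parameter count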

2.5 Building the New Model and Saving Its Info

With cfg and cfg_mask in hand, we can prune the three parameterized layer types: Conv2d, BatchNorm2d, and Linear.

First we build a new model from cfg and save its related information. Example code:

if __name__ == "__main__":
    ...
    newmodel = VGG(cfg=cfg)
    if args.cuda:
        newmodel.cuda()
    
    num_parameters = sum([param.nelement() for param in newmodel.parameters()])
    savepath = os.path.join(args.save, "prune.txt")
    with open(savepath, 'w') as fp:
        fp.write("Configuation: " + str(cfg) + "\n")
        fp.write("Number of parameters: " + str(num_parameters) + "\n")
        fp.write("Test accuracy: " + str(acc))
    
    layer_id_in_cfg = 0
    start_mask = torch.ones(3)  # 3 = RGB input channels of the first conv
    end_mask = cfg_mask[layer_id_in_cfg]
    for [m0, m1] in zip(model.modules(), newmodel.modules()):
        pass

2.6 Pruning the BatchNorm Layers

Note: start_mask and end_mask correspond to the input and output channels of a Conv+BN pair.

Example code for BN-layer pruning:

if __name__ == "__main__":
    ...
    for [m0, m1] in zip(model.modules(), newmodel.modules()):
        if isinstance(m0, nn.BatchNorm2d):
            # Indices of the channels that survive in this layer
            idx1 = np.squeeze(np.argwhere(np.asarray(end_mask.cpu().numpy())))
            if idx1.size == 1:
                idx1 = np.resize(idx1, (1,))

            m1.weight.data = m0.weight.data[idx1.tolist()].clone()
            m1.bias.data   = m0.bias.data[idx1.tolist()].clone()
            m1.running_mean = m0.running_mean[idx1.tolist()].clone()
            m1.running_var  = m0.running_var[idx1.tolist()].clone()
            # Advance the masks: this layer's outputs become the next layer's inputs
            layer_id_in_cfg += 1
            start_mask = end_mask.clone()
            if layer_id_in_cfg < len(cfg_mask):  # do not change for the final FC layer
                end_mask = cfg_mask[layer_id_in_cfg]

2.7 Pruning the Conv2d Layers

Example code for Conv-layer pruning:

if __name__ == "__main__":
    ...
    for [m0, m1] in zip(model.modules(), newmodel.modules()):
        ...

        elif isinstance(m0, nn.Conv2d):
            idx0 = np.squeeze(np.argwhere(np.asarray(start_mask.cpu().numpy())))
            idx1 = np.squeeze(np.argwhere(np.asarray(end_mask.cpu().numpy())))

            print("In channels: {:d}, Out channels: {:d}".format(idx0.size, idx1.size))
            if idx0.size == 1:
                idx0 = np.resize(idx0, (1,))
            if idx1.size == 1:
                idx1 = np.resize(idx1, (1,))
            
            # Prune input channels first, then output channels
            w1 = m0.weight.data[:, idx0.tolist(), :, :].clone()
            w1 = w1[idx1.tolist(), :, :, :].clone()
            # Only weights are copied (the convs in this VGG are assumed to use bias=False)
            m1.weight.data = w1.clone()

2.8 Pruning the Linear Layers

Example code for Linear-layer pruning:

if __name__ == "__main__":
    ...
    for [m0, m1] in zip(model.modules(), newmodel.modules()):
        ...
        elif isinstance(m0, nn.Linear):
            idx0 = np.squeeze(np.argwhere(np.asarray(start_mask.cpu().numpy())))
            if idx0.size == 1:
                idx0 = np.resize(idx0, (1,))

            # Keep only the input features fed by surviving channels;
            # the output features (class scores) are untouched
            m1.weight.data = m0.weight.data[:, idx0].clone()
            m1.bias.data   = m0.bias.data.clone()

3. VGG-Based Model Pruning

The complete example code is as follows:

import os
import argparse
import numpy as np
import torch
import torch.nn as nn

from models.vgg import VGG
from utils import get_test_dataloader


def parse_opt():
    # Prune settings
    parser = argparse.ArgumentParser(description='PyTorch Slimming CIFAR prune')
    parser.add_argument('--dataset', type=str, default='cifar10', help='training dataset (default: cifar10)')
    parser.add_argument('--test-batch-size', type=int, default=256, metavar='N', help='input batch size for testing (default: 256)')
    parser.add_argument('--no-cuda', action='store_true', default=False, help='disables CUDA training')
    parser.add_argument('--depth', type=int, default=11, help='depth of the vgg')
    parser.add_argument('--percent', type=float, default=0.5, help='scale sparse rate (default: 0.5)')
    parser.add_argument('--model', default='', type=str, metavar='PATH', help='path to the model (default: none)')
    parser.add_argument('--save', default='logs/', type=str, metavar='PATH', help='path to save pruned model (default: none)')
    args = parser.parse_args()
    return args

# Simple test of the model after the pre-processing prune (which merely zeroes BN scales)
# Define a function named test that takes a PyTorch model as input
def test(model):
    # Set kwargs to num_workers=1 and pin_memory=True if args.cuda is True, 
    # otherwise kwargs is an empty dictionary
    kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}
    
    # Create a test data loader for the CIFAR10 dataset if args.dataset is 'cifar10'
    if args.dataset == 'cifar10':
        test_loader = get_test_dataloader(batch_size=args.test_batch_size, **kwargs)
    else:
        raise ValueError("No valid dataset is given.")
    # Set the model to evaluation mode
    model.eval()
    # Initialize the number of correct predictions to 0
    correct = 0
    # Turn off gradient calculation during inference
    with torch.no_grad():
        # Loop through the test data
        for data, target in test_loader:
            # Move the data and target tensors to the GPU if args.cuda is True
            if args.cuda:
                data, target = data.cuda(), target.cuda()
            # Compute the output of the model on the input data
            output = model(data)
            # Compute the predictions from the output using the argmax operation
            pred = output.data.max(1, keepdim=True)[1] # get the index of the max log-probability
            # Compute the number of correct predictions and add it to the running total
            correct += pred.eq(target.data.view_as(pred)).cpu().sum()
    # Compute the test accuracy and print the result
    accuracy = 100. * correct / len(test_loader.dataset)
    print('\nTest set: Accuracy: {}/{} ({:.1f}%)\n'.format(
        correct, len(test_loader.dataset), accuracy))
    # Return the test accuracy as a float
    return accuracy / 100.



if __name__ == '__main__':
    # Parse command line arguments using the parse_opt() function
    args = parse_opt()
    # Check if CUDA is available and set args.cuda flag accordingly
    args.cuda = not args.no_cuda and torch.cuda.is_available()
    # Create the save directory if it does not exist
    if not os.path.exists(args.save):
        os.makedirs(args.save)

    # Create a new VGG model with the specified depth
    model = VGG(depth=args.depth)
    # Move the model to the GPU if args.cuda is True
    if args.cuda:
        model.cuda()
    # If args.model is not None, 
    # attempt to load a checkpoint from the specified file
    if args.model:
        if os.path.isfile(args.model):
            print("=> loading checkpoint '{}'".format(args.model))
            checkpoint = torch.load(args.model)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {}) Prec1: {:f}"
                .format(args.model, checkpoint['epoch'], best_prec1))
        else:
            print("=> no checkpoint found at '{}'".format(args.model))
            
    # Print the model to the console
    print(model)
    # Initialize the total number of channels to 0
    total = 0
    # Loop through the model's modules and count the number of channels in each BatchNorm2d layer
    for m in model.modules():
        if isinstance(m, nn.BatchNorm2d):
            total += m.weight.data.shape[0]
    # Create a new tensor to store the absolute values of the weights of each BatchNorm2d layer
    bn = torch.zeros(total)
    # Initialize an index variable to 0
    index = 0
    # Loop through the model's modules again and 
    # store the absolute values of the weights of each BatchNorm2d layer in the bn tensor
    for m in model.modules():
        if isinstance(m, nn.BatchNorm2d):
            size = m.weight.data.shape[0]
            bn[index:(index+size)] = m.weight.data.abs().clone()
            index += size
    # Sort the bn tensor and compute the threshold value for pruning
    y, i = torch.sort(bn)
    thre_index = int(total * args.percent)
    thre = y[thre_index]
    
    # Initialize the number of pruned channels to 0 and 
    # create lists to store the new configuration and mask for each layer
    pruned = 0
    cfg = []
    cfg_mask = []
    # Loop through the model's modules a third time and 
    # prune each BatchNorm2d layer that falls below the threshold value
    for k, m in enumerate(model.modules()):
        if isinstance(m, nn.BatchNorm2d):
            # Compute a mask indicating which weights to keep and which to prune
            weight_copy = m.weight.data.abs().clone()
            mask = weight_copy.gt(thre).float()
            if args.cuda:
                mask = mask.cuda()
            pruned = pruned + mask.shape[0] - torch.sum(mask)
            # Apply the mask to the weight and bias tensors of the BatchNorm2d layer
            m.weight.data.mul_(mask)
            m.bias.data.mul_(mask)
            # Record the new configuration and mask for this layer
            cfg.append(int(torch.sum(mask)))
            cfg_mask.append(mask.clone())
            # Print information about the pruning for this layer
            print('layer index: {:d} \t total channel: {:d} \t remaining channel: {:d}'.
                format(k, mask.shape[0], int(torch.sum(mask))))
        elif isinstance(m, nn.MaxPool2d):
            # If the module is a MaxPool2d layer, 
            # record it as an 'M' in the configuration list
            cfg.append('M')
    # Compute the ratio of pruned channels to total channels
    pruned_ratio = pruned/total
    # Print a message indicating that the pre-processing was successful
    print('Pre-processing Successful!')
    # Evaluate the pruned model on the test set and 
    # store the accuracy in the acc variable
    acc = test(model)
    
# ============================ Make real prune ============================

    # Print the new configuration to the console
    print(cfg)
    # Initialize a new VGG model with the pruned configuration
    newmodel = VGG(cfg=cfg)
    # Move the new model to the GPU if available
    if args.cuda:
        newmodel.cuda()
    # Compute the number of parameters in the new model 
    num_parameters = sum([param.nelement() for param in newmodel.parameters()])
    # Save the configuration above, number of parameters, and test accuracy to a file
    savepath = os.path.join(args.save, "prune.txt")
    with open(savepath, "w") as fp:
        fp.write("Configuration: \n"+str(cfg)+"\n")
        fp.write("Number of parameters: "+str(num_parameters)+"\n")
        fp.write("Test accuracy: "+str(acc))

    # Initialize variables for the masks corresponding to the start and end of each pruned layer
    layer_id_in_cfg = 0
    start_mask = torch.ones(3)
    end_mask = cfg_mask[layer_id_in_cfg]
    
    # Loop through the modules of the original and new models
    # Copy the weights and biases of each layer from the original model to the new model
    # Applying the appropriate masks to the weights and biases of the pruned layers
    for [m0, m1] in zip(model.modules(), newmodel.modules()):
        
        # ============================ BatchNorm Layers ============================
        # If the module is a BatchNorm2d layer, 
        # compute the indices of the non-zero weights and biases in the new model and 
        # copy them from the original model
        if isinstance(m0, nn.BatchNorm2d):
            # Compute the list of indices of the remaining channels in the current BatchNorm2d layer
            idx1 = np.squeeze(np.argwhere(np.asarray(end_mask.cpu().numpy())))
            # Resize the index list if it has only one element
            if idx1.size == 1:
                idx1 = np.resize(idx1,(1,))
            # Compute the weight of the current layer 
            # by copying only the weights of the remaining channels using the index list
            m1.weight.data = m0.weight.data[idx1.tolist()].clone()
            # Compute the bias of the current layer 
            # by copying the bias values of the original layer and then cloned
            m1.bias.data = m0.bias.data[idx1.tolist()].clone()
            # Compute the running mean of the current layer by 
            # copying the mean values of the original layer and then cloned
            m1.running_mean = m0.running_mean[idx1.tolist()].clone()
            # Compute the running variance of the current layer by 
            # copying the variance values of the original layer and then cloned
            m1.running_var = m0.running_var[idx1.tolist()].clone()
            # Update the masks for the next pruned layer
            layer_id_in_cfg += 1
            start_mask = end_mask.clone()
            if layer_id_in_cfg < len(cfg_mask):  # do not change in Final FC
                end_mask = cfg_mask[layer_id_in_cfg]
                
        # ============================ Conv2d Layers ============================
        # If the module is a Conv2d layer, 
        # compute the indices of the non-zero weights in the input and output channels and 
        # copy them from the original model
        elif isinstance(m0, nn.Conv2d):
            # Get the indices of input and output channels that are not pruned for this convolutional layer, 
            # by converting the start and end masks from the previous and current layers into numpy arrays, 
            # finding the non-zero elements, and removing the extra dimensions
            idx0 = np.squeeze(np.argwhere(np.asarray(start_mask.cpu().numpy())))
            idx1 = np.squeeze(np.argwhere(np.asarray(end_mask.cpu().numpy())))
            # Print the number of input and output channels that are not pruned
            print('In shape: {:d}, Out shape {:d}.'.format(idx0.size, idx1.size))
            # If either idx0 or idx1 has a size of 1, 
            # resize it to (1,) to avoid a broadcasting error.
            if idx0.size == 1:
                idx0 = np.resize(idx0, (1,))
            if idx1.size == 1:
                idx1 = np.resize(idx1, (1,))
            # Extract the weight tensor for this layer from the original model (m0) 
            # by selecting the input and output channels that are not pruned, 
            # and clone it to create a new tensor (w1)
            w1 = m0.weight.data[:, idx0.tolist(), :, :].clone()
            w1 = w1[idx1.tolist(), :, :, :].clone()
            m1.weight.data = w1.clone()
            
        # ============================ Linear Layers ============================
        # If the module is a Linear layer, 
        # compute the indices of the non-zero weights in the input channels and 
        # copy them from the original model
        elif isinstance(m0, nn.Linear):
            # Compute the list of indices of the remaining neurons/channels 
            # of the previous layer that connect to this current linear layer
            idx0 = np.squeeze(np.argwhere(np.asarray(start_mask.cpu().numpy())))
            # Resize the index list if it has only one element
            if idx0.size == 1:
                idx0 = np.resize(idx0, (1,))
            # Compute the weight of the current layer 
            # by copying only the weights of the remaining channels of the previous layer 
            # using the index list
            m1.weight.data = m0.weight.data[:, idx0].clone()
            m1.bias.data   = m0.bias.data.clone()

    torch.save({'cfg': cfg, 'state_dict': newmodel.state_dict()}, os.path.join(args.save, 'pruned.pth'))

    print(newmodel)
    model = newmodel
    test(model)
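The pruned model normally needs a few epochs of fine-tuning to recover the accuracy lost to pruning (the 88.4% quoted in the summary below comes from that step). A minimal sketch, assuming ordinary SGD settings and a get_training_dataloader helper analogous to get_test_dataloader:

import torch.nn.functional as F
import torch.optim as optim

def finetune(model, train_loader, epochs=10, lr=0.001, use_cuda=True):
    # Plain cross-entropy training; the L1 sparsity term is no longer needed here
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=1e-4)
    model.train()
    for epoch in range(epochs):
        for data, target in train_loader:
            if use_cuda:
                data, target = data.cuda(), target.cuda()
            optimizer.zero_grad()
            loss = F.cross_entropy(model(data), target)
            loss.backward()
            optimizer.step()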

Summary

This lesson completed the pruning workflow for a VGG model, reproducing the paper's sparsity training of the BN $\gamma$ parameters. After obtaining the masks, we pruned the Conv2d, BatchNorm, and Linear layers. The pruned model's size drops dramatically (71M => 9.6M), while prediction accuracy actually improves (87.4% => 88.4%). Applying the same method to YOLOv8 costs a little accuracy but greatly reduces the parameter count, so it is practically usable; stay tuned for the YOLOv8 pruning lesson.
