剪枝与重参第二课:修剪方法和稀疏训练

目录

  • 修剪方法和稀疏训练
    • 前言
    • 1.修剪方法
      • 1.1 经典框架:训练-剪枝-微调
      • 1.2 训练时剪枝(rewind)
      • 1.3 removing剪枝
    • 2.dropout and dropconnect
    • 3.稀疏训练(Sparse training)
    • 总结

修剪方法和稀疏训练

前言

手写AI推出的全新模型剪枝与重参课程。记录下个人学习笔记,仅供自己参考。

本次课程主要讲解修剪方法和稀疏训练。

课程大纲可看下面的思维导图

(图1:课程思维导图)

1.修剪方法

修剪方法主要包含训练后剪枝和训练时剪枝两种方法。

下图展示了这两种常见的剪枝方法的流程:

(图2:训练后剪枝与训练时剪枝的流程示意)

1.1 经典框架:训练-剪枝-微调

训练后剪枝方法包含三个步骤:训练、剪枝、微调。在这种方法中,首先对模型进行训练以获得初始模型,然后对模型进行剪枝以去除冗余参数,最后对剪枝后的模型进行微调以恢复模型性能。可参考下面这篇文献:

Song Han, Jeff Pool, John Tran, and William J Dally. Learning both weights and connections for efficient neural network. In NIPS, 2015.链接

训练后剪枝方法的示例代码如下:

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
import numpy as np

# 1. 训练基础的大网络
class BigModel(nn.Module):
    def __init__(self) -> None:
        super(BigModel, self).__init__()
        self.fc1 = nn.Linear(784, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, 10)
    
    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# 准备MNIST数据集
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)        
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)

def train(model, dataloader, criterion, optimizer, device='cpu', num_epochs=10):
    model.train()
    model.to(device)

    for epoch in range(num_epochs):
        running_loss = 0.0
        for batch_idx, (inputs, targets) in enumerate(dataloader):
            inputs, targets = inputs.to(device), targets.to(device)

            # 前向传播
            outputs = model(inputs.view(inputs.size(0), -1))
            loss = criterion(outputs, targets)

            # 反向传播
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            running_loss += loss.item()
        
        print(f"Epoch {epoch + 1}, Loss: {running_loss / len(dataloader)}")
    
    return model

big_model = BigModel()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(big_model.parameters(), lr=1e-3)
big_model = train(big_model, train_loader, criterion, optimizer, device='cuda', num_epochs=10)

# 保存训练好的大网络
torch.save(big_model.state_dict(), "big_model.pth")

# 2. 修剪大网络为小网络  <==================================
def prune_network(model, pruning_rate=0.5, method='global'):
    for name, param in model.named_parameters():
        if 'weight' in name:
            tensor = param.data.cpu().numpy()
            if method == "global":
                threshold = np.percentile(abs(tensor), pruning_rate * 100)
            else: # local pruning
                threshold = np.percentile(abs(tensor), pruning_rate * 100, axis=1, keepdims=True)
            mask = abs(tensor) > threshold
            param.data = torch.FloatTensor(tensor * mask.astype(float)).to(param.device)
        
big_model.load_state_dict(torch.load("big_model.pth"))
prune_network(big_model, pruning_rate=0.5, method="global")

# 保存修剪后的模型
torch.save(big_model.state_dict(), "pruned_model.pth")

# 3. 以低的学习率做微调
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(big_model.parameters(), lr=1e-4)
finetuned_model = train(big_model, train_loader, criterion, optimizer, device="cuda", num_epochs=10)

# 保存微调后的模型
torch.save(finetuned_model.state_dict(), "finetuned_pruned_model.pth")

上述示例代码展示的是训练后剪枝的过程。首先,我们使用MNIST数据集训练了一个包含三个全连接层的神经网络,并保存训练好的模型。接着,我们对训练好的模型进行剪枝,剪枝率为50%,使用全局剪枝方法。具体而言,对每个权重张量计算其绝对值的百分位数作为阈值,将绝对值小于该阈值的权重置为0。最后,使用较低的学习率对修剪后的模型进行微调,训练10个周期,并保存微调后的模型。
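
下面用一个极简的数值例子(数值为随意假设,仅作演示)说明这种按绝对值百分位数取阈值并置0的做法:

import numpy as np

# 一小段假设的权重,仅用于演示百分位数阈值的计算
tensor = np.array([0.05, -0.40, 0.10, 0.90, -0.02, 0.30])

pruning_rate = 0.5
threshold = np.percentile(np.abs(tensor), pruning_rate * 100)  # 绝对值的50%分位数,这里约为0.2

mask = np.abs(tensor) > threshold
print(tensor * mask.astype(float))  # 绝对值小于阈值的三个元素被置为0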

我们可以拿剪枝前后的模型进行相关测试,测试示例代码如下:

import torch
import torch.nn as nn
from torchvision import datasets, transforms
import numpy as np
import matplotlib.pyplot as plt

# 1. 定义网络模型
class BigModel(nn.Module):
    def __init__(self):
        super(BigModel, self).__init__()
        self.fc1 = nn.Linear(784, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, 10)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# 2. 加载模型和测试数据
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
test_dataset = datasets.MNIST('./data', train=False, download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=True)
model = BigModel()
model.load_state_dict(torch.load("finetuned_pruned_model.pth"))

# 3. 测试模型并计算准确率
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for i, (inputs, targets) in enumerate(test_loader):
        # if i == 10:
        #     break  # 只测试前10个batch的数据
        outputs = model(inputs.view(inputs.size(0), -1))
        _, predicted = torch.max(outputs.data, 1)
        total += targets.size(0)
        correct += (predicted == targets).sum().item()
        if i == 1:
            # 可视化其中一个batch的数据
            fig, axs = plt.subplots(2, 5)
            axs = axs.flatten()
            for j in range(len(axs)):
                axs[j].imshow(inputs[j].squeeze(), cmap='gray')
                axs[j].set_title(f"Target: {targets[j]}, Predicted: {predicted[j]}")
                axs[j].axis('off')
            # plt.savefig("fine-tune.png", bbox_inches="tight")
            plt.show()


accuracy = 100 * correct / total
print(f"Accuracy: {accuracy:.2f}%")

# 97.65% 97.73% 98.57% 97.78%

下图展示了该模型的部分预测效果:

(图3:微调后模型的部分预测效果)

1.2 训练时剪枝(rewind)

训练时剪枝,也称为剪枝回溯(pruning with rewinding)。在这种方法中,模型的训练和剪枝交替进行:模型在训练过程中会被周期性地剪枝,同时保留训练过程中的最佳模型作为最终模型。可参考下面这篇文献:

He Y, Kang G, Dong X, et al. Soft filter pruning for accelerating deep convolutional neural networks[J]. arXiv preprint arXiv:1808.06866, 2018.链接

训练时剪枝方法的示例代码如下:

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
import numpy as np

# 1. 训练基础的大网络
class BigModel(nn.Module):
    def __init__(self) -> None:
        super(BigModel, self).__init__()
        self.fc1 = nn.Linear(784, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, 10)
    
    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# 准备MNIST数据集
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)        
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)

# 2. 修剪大网络为小网络  <==================================
def prune_network(model, pruning_rate=0.5, method='global'):
    for name, param in model.named_parameters():
        if 'weight' in name:
            tensor = param.data.cpu().numpy()
            if method == "global":
                threshold = np.percentile(abs(tensor), pruning_rate * 100)
            else: # local pruning
                threshold = np.percentile(abs(tensor), pruning_rate * 100, axis=1, keepdims=True)
            mask = abs(tensor) > threshold
            param.data = torch.FloatTensor(tensor * mask.astype(float)).to(param.device)

# 3. 训练时修剪
def train_with_pruning(model, dataloader, criterion, optimizer, device='cpu', num_epochs=10, pruning_rate=0.5):
    model.train()
    model.to(device)

    for epoch in range(num_epochs):
        running_loss = 0.0
        for batch_idx, (inputs, targets) in enumerate(dataloader):
            inputs, targets = inputs.to(device), targets.to(device)

            # 前向传播
            outputs = model(inputs.view(inputs.size(0), -1))
            loss = criterion(outputs, targets)

            # 反向传播
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            running_loss += loss.item()
        
        print(f"Epoch {epoch + 1}, Loss: {running_loss / len(dataloader)}")
    
        # 在每个 epoch 结束后进行剪枝
        prune_network(model, pruning_rate, method="global") # <== just prune the weights to 0 but still allow them to grow back by optimizer.step()

    return model

big_model = BigModel()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(big_model.parameters(), lr=1e-3)
big_model = train_with_pruning(big_model, train_loader, criterion, optimizer, device='cuda', num_epochs=10, pruning_rate=0.1)

# 保存训练时剪枝得到的模型
torch.save(big_model.state_dict(), "trained_with_pruning_model.pth")

上述示例代码展示的是训练时剪枝的过程。首先,我们定义了一个包含三个全连接层的神经网络,接着定义了剪枝函数prune_network,该函数可以按设定的修剪率和修剪方式将部分权重置为0。然后我们通过训练函数train_with_pruning,在每个epoch结束后调用剪枝函数进行修剪,实现了训练时剪枝的功能。最后我们将训练好的模型保存到文件中。
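
由于这种做法只是把权重置0,被置0的权重之后仍可能被optimizer.step()更新回非0。下面给出一个简单的检查脚本草稿(文件名沿用上文),统计最终保存的模型中各权重张量里0元素所占的比例:

import torch

# 粗略统计各权重张量中0元素的比例,验证"置0后仍可能被训练恢复"的现象
state_dict = torch.load("trained_with_pruning_model.pth", map_location="cpu")
for name, tensor in state_dict.items():
    if "weight" in name:
        sparsity = (tensor == 0).float().mean().item()
        print(f"{name}: sparsity = {sparsity:.2%}")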

1.3 removing剪枝

在之前的剪枝方式中,我们使用的都是直接填0的剪枝;另一种剪枝方式是直接将不满足条件的元素remove掉。二者的优缺点对比如下(列表之后附有一个简单的存储占用对比示例):

直接填0的剪枝

优点:

  • 保留了原始网络结构,便于实现和微调
  • 部分减少模型的计算量

缺点:

  • 零权重仍然需要存储,因此不会减少内存使用
  • 一些硬件和软件无法利用稀疏计算,从而无法提高计算效率

直接remove的剪枝

优点:

  • 可以减少模型的计算量和内存使用
  • 可以通过减少网络容量来防止过拟合

缺点:

  • 可能会降低网络的表示能力,导致性能下降
  • 需要对网络结构进行改变,这可能会增加实现和微调的复杂性
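
下面是一个简单的存储占用对比草稿:把一半元素置0并不会减小权重张量本身占用的内存,而直接remove掉一半输出行则会让张量变小:

import torch

w = torch.randn(512, 784)

# 置0剪枝:元素仍然存在,占用的字节数不变
w_zeroed = w.clone()
w_zeroed[torch.rand_like(w) < 0.5] = 0.0
print(w_zeroed.numel() * w_zeroed.element_size())  # 与原张量相同:512*784*4 字节

# remove剪枝:直接去掉一半输出行,张量随之变小
keep_idxs = torch.arange(0, 512, 2)
w_removed = w[keep_idxs]
print(w_removed.numel() * w_removed.element_size())  # 约为原来的一半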

现在我们可以来进行直接remove的剪枝,首先我们需要训练一个model,其示例代码如下:


import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
import numpy as np

# 1. Train a large base network
class BigModel(nn.Module):
    def __init__(self) -> None:
        super(BigModel, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 16, kernel_size=3, padding=1)
        self.fc = nn.Linear(16 * 28 * 28, 10)

        # Initialize l1norm as a parameter and register as buffer
        self.conv1_l1norm = nn.Parameter(torch.Tensor(32), requires_grad=False)
        self.conv2_l1norm = nn.Parameter(torch.Tensor(16), requires_grad=False)
        self.register_buffer('conv1_l1norm_buffer', self.conv1_l1norm)
        self.register_buffer('conv2_l1norm_buffer', self.conv2_l1norm)

    def forward(self, x):
        x = torch.relu(self.conv1(x))
        self.conv1_l1norm.data = torch.sum(torch.abs(self.conv1.weight.data), dim=(1, 2, 3))

        x = torch.relu(self.conv2(x))
        self.conv2_l1norm.data = torch.sum(torch.abs(self.conv2.weight.data), dim=(1, 2, 3))

        x = x.view(x.size(0), -1)
        x = self.fc(x)

        return x
    
# Training function
def train(model, dataloader, criterion, optimizer, device='cpu', num_epochs=10):
    model.train()
    model.to(device)

    for epoch in range(num_epochs):
        running_loss = 0.0
        for batch_idx, (inputs, targets) in enumerate(dataloader):
            inputs, targets = inputs.to(device), targets.to(device)

            # Forward propagation
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            # Backpropagation
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            # print(f"Loss: {running_loss / len(dataloader)}")
            
        print(f"Epoch {epoch + 1}, Loss: {running_loss / len(dataloader)}")

    return model

if __name__ == "__main__":
    # Prepare the MNIST dataset
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
    train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)

    big_model = BigModel()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(big_model.parameters(), lr=1e-3)
    big_model = train(big_model, train_loader, criterion, optimizer, device='cuda', num_epochs=10)

    # Save the trained big network
    torch.save(big_model.state_dict(), "big_model.pth")
    
    # Set the input shape of the model
    dummy_input = torch.randn(1, 1, 28, 28).to('cuda')
    
    # Export the model to ONNX format
    torch.onnx.export(big_model, dummy_input, "big_model.onnx")

在上面的示例代码中,我们训练了一个包含两个卷积层和一个全连接层的简单model。值得注意的是,在BigModel类中,我们初始化了两个参数conv1_l1norm和conv2_l1norm,分别记录两个卷积层各输出通道的L1范数,用于后续remove剪枝时求取threshold。其中nn.Parameter()用于将其注册为模型参数,同时设置requires_grad=False,表示在模型训练过程中不需要为其计算梯度。

接下来,self.conv1_l1norm和self.conv2_l1norm通过self.register_buffer()方法被注册为模型缓冲区(buffer)。这样做的目的是使这两个张量能随模型一起保存,在模型加载时可以直接读取它们的值而无需重新计算。在训练过程中,它们的值会在每次前向传播时更新,用于记录对应卷积层各输出通道的L1范数。
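
下面用一个最小示例说明register_buffer的效果(其中的Demo类只是演示用的假设写法,并非上文的BigModel):buffer会出现在state_dict中随模型一起保存,但不属于可训练参数:

import torch
import torch.nn as nn

class Demo(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, 4, kernel_size=3, padding=1)
        self.register_buffer("conv_l1norm_buffer", torch.zeros(4))

    def forward(self, x):
        # 每个输出通道的L1范数:对 (in_channels, kH, kW) 三个维度求和
        self.conv_l1norm_buffer.copy_(self.conv.weight.detach().abs().sum(dim=(1, 2, 3)))
        return self.conv(x)

m = Demo()
m(torch.randn(1, 1, 28, 28))
print("conv_l1norm_buffer" in m.state_dict())               # True:会随模型一起保存
print(any("l1norm" in n for n, _ in m.named_parameters()))  # False:不在可训练参数中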

上述模型训练完成后我们就可以进行remove剪枝了,其具体步骤可以看下图:

(图4:remove剪枝的操作步骤示意)

示例代码如下:

from remove import BigModel, train
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms

# 1. load a model and inspect it
model = BigModel()
model.load_state_dict(torch.load("big_model.pth"))

# 2. get the global threshold according to the l1norm
all_l1norm_values = []
for name, m in model.named_modules():
    if isinstance(m, nn.Conv2d):
        l1norm_buffer_name = f"{name}_l1norm_buffer"
        l1norm = getattr(model, l1norm_buffer_name)
        all_l1norm_values.append(l1norm)

all_l1norm_values = torch.cat(all_l1norm_values)

threshold = torch.sort(all_l1norm_values)[0][int(len(all_l1norm_values) * 0.5)]

# 3. prune the conv based on the l1norm along axis = 0 for each weight tensor
conv1 = model.conv1 # torch.Size([32, 1, 3, 3])
conv2 = model.conv2 #            [16, 32,3, 3]
fc    = model.fc

conv1_l1norm_buffer = model.conv1_l1norm_buffer # 32
conv2_l1norm_buffer = model.conv2_l1norm_buffer # 16

# Top conv
keep_idxs = torch.where(conv1_l1norm_buffer >= threshold)[0]
k = len(keep_idxs)

conv1.weight.data = conv1.weight.data[keep_idxs]
conv1.bias.data   = conv1.bias.data[keep_idxs]
conv1_l1norm_buffer.data = conv1_l1norm_buffer.data[keep_idxs]
conv1.out_channels = k

# Bottom conv: 移除conv2中与conv1被剪掉的输出通道相对应的输入通道,保持通道对齐
conv2.weight.data = conv2.weight.data[:, keep_idxs]
conv2.in_channels = k

# Save the pruned model state_dict
torch.save(model.state_dict(), "pruned_model.pth")

# Set the input shape of the model
dummy_input = torch.randn(1, 1, 28, 28)

# Export the model to ONNX format
torch.onnx.export(model, dummy_input, "pruned_model.onnx")

#################################### FINE TUNE ######################################
# Prepare the MNIST dataset
model.load_state_dict(torch.load("pruned_model.pth"))
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)
model = train(model, train_loader, criterion, optimizer, device='cuda', num_epochs=3)

# Save the fine-tuned pruned model
torch.save(model.state_dict(), "pruned_model_after_finetune.pth")

上面的示例代码就是按照上述remove剪枝图的操作步骤进行的,具体如下(列表之后附有一个简单的形状检查脚本):

  • 1.加载已经训练好的模型BigModel
  • 2.获取全局阈值,通过模型的l1norm_buffer计算获得
  • 3.剪枝卷积层,移除conv1中L1范数小于阈值的输出通道,并同步移除conv2中对应的输入通道
  • 4.导出剪枝后的模型
  • 5.对剪枝后的模型进行微调
  • 6.保存微调后的模型
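
可以用下面这个简单的检查脚本(文件名沿用上文)查看剪枝前后state_dict中各张量的形状,确认conv1的输出通道和conv2的输入通道确实被移除了:

import torch

# 对比剪枝前后各权重张量的形状
for path in ["big_model.pth", "pruned_model.pth"]:
    state_dict = torch.load(path, map_location="cpu")
    print(path)
    for name, tensor in state_dict.items():
        print(f"  {name}: {tuple(tensor.shape)}")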

2.dropout and dropconnect

拓展:实现dropout和dropconnect layer

dropout和dropconnect都是常见的神经网络正则化技术,它们的主要作用是减少神经网络的过拟合现象,提高模型的泛化能力。但是它们在实现上有所不同,下面分别介绍一下它们的区别。(from chatGPT)

dropout
dropout是Hinton团队在2012年提出的正则化方法。它的实现方式是在神经网络的训练过程中,以一定的概率随机地删除一部分神经元,即将神经元的输出设置为0,从而使神经元不会过度依赖其他神经元。dropout可以看做是一种模型平均方法,可以让不同的神经元组合成不同的子网络,增加了模型的泛化能力。

dropconnect
dropconnect是Wan等人在2013年提出的正则化方法。它的实现方式是在神经网络的训练过程中,以一定的概率随机地删除一部分连接权重,即将权重设置为0,从而使每个神经元都不能过度依赖其他神经元的输入。相比于dropout,dropconnect删除的是权重,而不是神经元的输出,从而可以更加灵活地控制神经元之间的相互关系。

综上所述,dropout和dropconnect的主要区别在于它们删除的是神经元输出还是连接权重。由于删除的对象不同,它们对于模型的正则化效果也会有所不同,需要根据具体的应用场景选择合适的正则化方法。

下面展示了三种网络,分别是No-Drop、DropOut以及DropConnect:

(图5:No-Drop、DropOut与DropConnect的网络示意)

示例代码如下:


import numpy as np

def dropout_layer(x, dropout_rate):
    # 以dropout_rate的概率将神经元输出置0(用均匀分布采样,等价于下面注释掉的binomial写法)
    dropout_mask = np.random.rand(*x.shape) > dropout_rate
    # dropout_mask = np.random.binomial(1, 1-dropout_rate, size=x.shape)
    return x * dropout_mask / (1 - dropout_rate)

def dropconnect_layer(weights, input_data, dropconnect_rate):
    # 以dropconnect_rate的概率将连接权重置0,而不是神经元输出
    dropconnect_mask = np.random.rand(*weights.shape) > dropconnect_rate
    masked_weights = weights * dropconnect_mask
    return input_data @ masked_weights

# Example usage for dropout
input_data = np.array([[0.1, 0.5, 0.2],
                       [0.8, 0.6, 0.7],
                       [0.9, 0.3, 0.4]])

dropout_rate = 0.5
output_data_dropout = dropout_layer(input_data, dropout_rate)
print(output_data_dropout)

# Example usage for dropconnect
dropconnect_rate = 0.5
weights = np.random.randn(3,4)
output_data_dropconnect = dropconnect_layer(weights, input_data, dropconnect_rate)
print(output_data_dropconnect)

上述示例代码中,dropout_layer返回的值是x * dropout_mask / (1 - dropout_rate),其中除以(1 - dropout_rate)是为了缩放输出。在训练过程中,一部分神经元的输出被随机置0,为了保持输出的总体期望值不变,需要将保留下来的神经元输出除以(1 - dropout_rate)进行放大,这样在推理阶段就无需再做额外缩放,同时也有助于增强模型的泛化能力。
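
可以用一个简单的数值实验(仅为示意)验证这种缩放确实让输出的期望基本不变:

import numpy as np

rng = np.random.default_rng(0)
x = np.ones(100000)
dropout_rate = 0.5

mask = rng.random(x.shape) > dropout_rate  # 以dropout_rate的概率丢弃
y = x * mask / (1 - dropout_rate)          # inverted dropout缩放

print(x.mean(), y.mean())  # 两者都接近1.0,期望基本一致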

回到前面remove剪枝得到的模型:我们对比原始model和剪枝后的model,发现二者的大小差不多,模型结构对比可见下图:

(图6:剪枝前后模型结构对比)

二者模型参数量的计算如下:

  • Original model
    • conv1:1x32x3x3 = 288 parameters
    • conv2:32x16x3x3 = 4608 parameters
    • fc:16x28x28x10 = 125440 parameters
    • Total:288 + 4608 + 125440 = 130336 parameters
  • Pruned model
    • conv1:1x8x3x3 = 72 parameters
    • conv2:16x8x3x3 = 1152 parameters
    • fc:16x28x28x10 = 125440 parameters
    • Total:72 + 1152 + 125440 = 126664 parameters

对比后可知,原因在于参数量主要集中在fc层(以上统计均不含bias),而我们只对conv层进行了剪枝,对fc层没有做任何操作,所以最后二者的参数量差别不大。
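
上面的参数量也可以直接用代码统计。下面的草稿按state_dict中以weight结尾的张量计数,口径与上面的手算一致(不含bias):

import torch

def count_weight_params(path):
    state_dict = torch.load(path, map_location="cpu")
    return sum(t.numel() for name, t in state_dict.items() if name.endswith("weight"))

print(count_weight_params("big_model.pth"))     # 288 + 4608 + 125440 = 130336
print(count_weight_params("pruned_model.pth"))  # 具体数值取决于实际保留的通道数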

3.稀疏训练(Sparse training)

稀疏训练(Sparse training)最开始起源于下面这篇文章:

Decebal Constantin Mocanu, Elena Mocanu, Peter Stone, Phuong H Nguyen, Madeleine Gibescu, and Antonio Liotta. Scalable training of artificial neural networks with adaptive sparse connectivity inspired by network science. Nature communications, 9(1):1–12, 2018.链接

我们关注其设计方式即可,以引出关于稀疏训练的思考。下图说明了上面这篇文章的实现过程:

(图7:稀疏训练的实现流程示意)

其步骤主要包含以下四步:

  • 1.初始化一个带有随机mask的网络
  • 2.训练这个pruned network一个epoch
  • 3.去掉一些数值较小的weights(或不满足自定义条件的weights)
  • 4.重新生成(regrow)同样数量的random weights

示例代码如下:

# raw net
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader


# Define the network architecture
class SparseNet(nn.Module):
    def __init__(self, sparsity_rate, mutation_rate = 0.5):
        super(SparseNet, self).__init__()
        self.fc1 = nn.Linear(784, 128)
        self.fc2 = nn.Linear(128, 10)
        self.sparsity_rate = sparsity_rate
        self.mutation_rate = mutation_rate
        self.initialize_masks() # <== 1.initialize a network with random mask

    def forward(self, x):
        x = x.view(-1, 784)
        x = x @ (self.fc1.weight * self.mask1.to(x.device)).T + self.fc1.bias
        x = torch.relu(x)
        x = x @ (self.fc2.weight * self.mask2.to(x.device)).T + self.fc2.bias
        return x

    def initialize_masks(self):
        self.mask1 = self.create_mask(self.fc1.weight, self.sparsity_rate)
        self.mask2 = self.create_mask(self.fc2.weight, self.sparsity_rate)

    def create_mask(self, weight, sparsity_rate):
        k = int(sparsity_rate * weight.numel())
        _, indices = torch.topk(weight.abs().view(-1), k, largest=False) # take the minimum k elements
        mask = torch.ones_like(weight, dtype=bool)
        mask.view(-1)[indices] = False
        return mask  # <== 1.initialize a network with random mask

    def update_masks(self):
        self.mask1 = self.mutate_mask(self.fc1.weight, self.mask1, self.mutation_rate)
        self.mask2 = self.mutate_mask(self.fc2.weight, self.mask2, self.mutation_rate)
    
    def mutate_mask(self, weight, mask, mutation_rate=0.5): # weight and mask: 2d shape
        # Find the number of elements in the mask that are true
        num_true = torch.count_nonzero(mask)

        # Compute the number of elements to mutate
        mutate_num = int(mutation_rate * num_true)

        # 3) pruning a certain amount of weights of lower magnitude
        true_indices_2d = torch.where(mask == True) # index the 2d mask where is true
        true_element_1d_idx_prune = torch.topk(weight[true_indices_2d], mutate_num, largest=False)[1]
        
        for i in true_element_1d_idx_prune:
            mask[true_indices_2d[0][i], true_indices_2d[1][i]] = False

        # 4) regrowing the same amount of random weights
        # Get the indices of the False elements in the mask
        false_indices = torch.nonzero(~mask)

        # Randomly select n indices from the false_indices tensor
        random_indices = torch.randperm(false_indices.shape[0])[:mutate_num]

        # the element to be regrow
        regrow_indices = false_indices[random_indices]
        for regrow_idx in regrow_indices:
            mask[tuple(regrow_idx)] = True
        
        return mask
        
# Set the device to CUDA if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load MNIST dataset
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)

# Initialize the network, loss function, and optimizer
sparsity_rate = 0.5
model = SparseNet(sparsity_rate).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

# Training loop
n_epochs = 10
for epoch in range(n_epochs):
    running_loss = 0.0
    for batch_idx, (inputs, targets) in enumerate(train_loader):
        # Move the data to the device
        inputs, targets = inputs.to(device), targets.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        
        # print(f"Loss: {running_loss / (batch_idx+1)}")
    
    # Update masks
    model.update_masks() # generate a new mask based on the update weights

    print(f"Epoch {epoch+1}/{n_epochs}, Loss: {running_loss / (batch_idx+1)}")

上面示例代码演示了稀疏训练的过程,其中包括四个步骤(列表之后附有一个检查mask稀疏度的小脚本):

  • 1.初始化带有随机mask的网络:首先我们定义了一个包含两个线性层的神经网络,并使用create_mask方法为每个线性层创建一个与权重同形状的mask,通过top-k选出数值最小的一部分权重并在mask中置为False,实现一定的稀疏性,其中sparsity_rate为稀疏率
  • 2.训练一个epoch的pruned network:前向传播时权重与mask相乘,训练一个epoch后再更新mask
  • 3.剪枝权重:在mask为True的位置中,将数值较小的一部分权重剪掉,对应的mask元素置为False
  • 4.重新regrow同样数量的random weights:在mask为False的位置中随机选择与剪枝数量相同的位置,将其重新置为True
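
训练结束后,可以接着上面的训练脚本运行下面几行示例代码,检查两个mask的实际稀疏度,验证每次mutate前后被置0的比例保持不变:

# 统计mask中False(对应被置0的权重)所占比例,应接近sparsity_rate=0.5
for name, mask in [("fc1", model.mask1), ("fc2", model.mask2)]:
    sparsity = 1.0 - mask.float().mean().item()
    print(f"{name}: mask sparsity = {sparsity:.2%}")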

总结

本次课程首先学习了训练后剪枝和训练时剪枝(rewind)两种方法,同时拓展学习了dropout和dropconnect,一个是使得神经元输出置0,一个是使得权重置0,最后回顾了一个经典的稀疏训练流程,引发对稀疏训练的思考。
