I have read the paper several times and followed some of the discussion here: https://www.reddit.com/r/MachineLearning/comments/ayh2hf/r_repr_improved_training_of_convolutional_filters/eozi40e/
I also referred to this reproduction: https://github.com/siahuat0727/RePr/blob/master/main.py
Bottom line: I did not reach the results reported in the paper, but there is a small improvement.
Compared with the previous attempt: the earlier reproduction used Keras; this time I use siahuat0727's code with a few modifications. The Keras code was rather redundant, and instead of stopping gradient updates for the pruned filters it simply re-zeroed them after every batch; siahuat0727's code freezes the gradients of the pruned filters during training.
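A minimal sketch of the gradient-freezing idea (my own illustration; the actual loop used is in train() below):
import torch

def freeze_pruned_gradients(conv_weights, mask):
    # conv_weights: list of (name, weight) pairs; mask: dict of layer name -> pruned filter indices (numpy array)
    with torch.no_grad():
        for name, W in conv_weights:
            if W.grad is not None and mask[name].size > 0:
                W.grad[mask[name]] = 0   # pruned filters receive no update in optimizer.step()

# call this between loss.backward() and optimizer.step(); the old Keras version instead
# re-zeroed the pruned weights after every batch, so momentum and weight decay could still
# perturb them before the re-zeroing.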
On finding orthogonal vectors with QR decomposition: a QR decomposition factors a matrix into an orthogonal square matrix Q and an upper-triangular R. Let A be the matrix whose columns are the flattened existing filters. If A is tall (more rows than columns) and has full column rank, then in the full decomposition A = QR the bottom rows of R are zero. Since Q is orthogonal, Q^T = Q^{-1}, so Q^T A = R: the last rows of Q^T multiplied with A give exactly those zero rows of R, which means the last rows of Q^T (equivalently, the last columns of Q) are orthogonal to every column of A. These are the vectors used for re-initialization.
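A minimal numpy/scipy sketch of this idea (my own illustration; the actual qr_null helper in utils.py may be written differently):
import numpy as np
from scipy.linalg import qr

def qr_null_sketch(A, tol=1e-7):
    # A: existing filters stacked as rows, shape (num_filters, flattened_dim)
    Q, R = qr(A.T, mode='full')               # A.T = Q R, Q orthogonal, R upper triangular
    rank = np.sum(np.abs(np.diag(R)) > tol)   # count non-negligible pivots as a rank estimate
    return Q[:, rank:]                        # columns orthogonal to every row of A (may be empty)

# usage: stack the pruned filters and the surviving filters as rows, then each returned
# column can serve as one re-initialized (flattened) filter
# null_basis = qr_null_sketch(np.vstack((dropped_2d, surviving_2d)))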
Note: the earlier Keras code is kept below rather than deleted. The modified version of siahuat0727's code comes first, and only the modified parts are shown; see siahuat0727's repository for the rest: https://github.com/siahuat0727/RePr/blob/master/main.py
The plots here are drawn with visdom.
'''Train CIFAR10 with PyTorch.'''
from __future__ import print_function
import math
import visdom
import argparse
import time
import datetime
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import torchvision
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
from models import Vanilla
from average_meter import AverageMeter
from utils import qr_null, test_filter_sparsity, accuracy
# from tensorboardX import SummaryWriter
# import tensorflow as tf
parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Training',
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--lr', type=float, default=0.01, help="learning rate")
parser.add_argument('--repr', action='store_true', help="whether to use RePr training scheme")
parser.add_argument('--S1', type=int, default=20, help="S1 epochs for RePr")
parser.add_argument('--S2', type=int, default=10, help="S2 epochs for RePr")
parser.add_argument('--epochs', type=int, default=100, help="total epochs for training")
parser.add_argument('--workers', type=int, default=0, help="number of workers used to load data")
parser.add_argument('--print_freq', type=int, default=50, help="print frequency")
parser.add_argument('--gpu', type=int, default=0, help="gpu id")
parser.add_argument('--save_model', type=str, default='best.pt', help="path to save model")
parser.add_argument('--prune_ratio', type=float, default=0.3, help="prune ratio")
parser.add_argument('--comment', type=str, default='', help="tag for tensorboardX event name")
parser.add_argument('--zero_init', action='store_true', help="whether to initialize with zero")
def train(train_loader, criterion, optimizer, epoch, model, viz, train_loss_win, train_acc_win, mask, args, conv_weights):
batch_time = AverageMeter()
data_time = AverageMeter()
losses = AverageMeter()
top1 = AverageMeter()
# switch to train mode
model.train()
    end = time.time()  # timestamp used to measure batch time
for i, (data, target) in enumerate(train_loader):
# measure data loading time
data_time.update(time.time() - end)
if args.gpu is not None: # TODO None?
            data = data.cuda(args.gpu, non_blocking=True)  # move data to the GPU without blocking
target = target.cuda(args.gpu, non_blocking=True)
output = model(data)
loss = criterion(output, target)
acc1, _ = accuracy(output, target, topk=(1, 5))
losses.update(loss.item(), data.size(0))
top1.update(acc1[0], data.size(0))
optimizer.zero_grad()
loss.backward()
S1, S2 = args.S1, args.S2
        if args.repr and any(s1 <= epoch < s1+S2 for s1 in range(S1, args.epochs, S1+S2)):  # inside an S2 stage of the RePr schedule
if i == 0:
print('freeze for this epoch')
with torch.no_grad():
for name, W in conv_weights:
                    W.grad[mask[name]] = 0  # zero the gradients of pruned filters so they stay frozen
optimizer.step()
# measure elapsed time
batch_time.update(time.time() - end)
if i % args.print_freq == 0:
print('Epoch: [{0}][{1}/{2}]\t'
'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
'Acc@1 {top1.val:.3f} ({top1.avg:.3f})\t'
'LR {lr:.3f}\t'
.format(
epoch, i, len(train_loader), batch_time=batch_time,
data_time=data_time, loss=losses, top1=top1,
lr=optimizer.param_groups[0]['lr']))
end = time.time()
viz.line(Y=[losses.avg], X=[epoch], update='append', win=train_loss_win)
viz.line(Y=[top1.avg.item()], X=[epoch], update='append', win=train_acc_win)
# writer.add_scalar('Train/Acc', top1.avg, epoch) # tensorboard
# writer.add_scalar('Train/Loss', losses.avg, epoch)
def validate(val_loader, criterion, model, viz, test_loss_win, test_acc_win, args, epoch, best_acc):
batch_time = AverageMeter()
losses = AverageMeter()
top1 = AverageMeter()
# switch to evaluate mode
model.eval()
with torch.no_grad():
end = time.time()
for i, (data, target) in enumerate(val_loader):
if args.gpu is not None: # TODO None?
data = data.cuda(args.gpu, non_blocking=True)
target = target.cuda(args.gpu, non_blocking=True)
# compute output
output = model(data)
loss = criterion(output, target)
# measure accuracy and record loss
acc1, _ = accuracy(output, target, topk=(1, 5))
losses.update(loss.item(), data.size(0))
top1.update(acc1[0], data.size(0))
# measure elapsed time
batch_time.update(time.time() - end)
if i % args.print_freq == 0:
print('Test: [{0}/{1}]\t'
'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
'Acc@1 {top1.val:.3f} ({top1.avg:.3f})\t'
.format(
i, len(val_loader), batch_time=batch_time, loss=losses,
top1=top1))
end = time.time()
print(' * Acc@1 {top1.avg:.3f} '.format(top1=top1))
viz.line(Y=[losses.avg], X=[epoch], update='append', win=test_loss_win)
viz.line(Y=[top1.avg.item()], X=[epoch], update='append', win=test_acc_win)
# writer.add_scalar('Test/Acc', top1.avg, epoch)
# writer.add_scalar('Test/Loss', losses.avg, epoch)
if top1.avg.item() > best_acc:
print('new best_acc is {top1.avg:.3f}'.format(top1=top1))
print('saving model {}'.format(args.save_model))
torch.save(model.state_dict(), args.save_model)
return top1.avg.item()
def pruning(conv_weights, prune_ratio):
print('Pruning...')
# calculate inter-filter orthogonality
inter_filter_ortho = {}
for name, W in conv_weights:
size = W.size()
        W2d = W.view(size[0], -1)  # flatten each filter into one row
        W2d = F.normalize(W2d, p=2, dim=1)  # L2-normalize every row
        W_WT = torch.mm(W2d, W2d.transpose(0, 1))  # pairwise correlations between filters
        I = torch.eye(W_WT.size()[0], dtype=torch.float32).cuda()  # identity matrix
        P = torch.abs(W_WT - I)
        P = P.sum(dim=1) / size[0]  # row average: one orthogonality value per filter
inter_filter_ortho[name] = P.cpu().detach().numpy()
# the ranking is computed overall the filters in the network
ranks = np.concatenate([v.flatten() for v in inter_filter_ortho.values()])
    threshold = np.percentile(ranks, 100*(1-prune_ratio))  # the pruning threshold is this percentile of the ranking
prune = {}
mask = {}
drop_filters = {}
for name, W in conv_weights:
        prune[name] = inter_filter_ortho[name] > threshold  # e.g. [True, False, True, True, False]: filters with large overlap
        # get indices of the bad filters
        mask[name] = np.where(prune[name])[0]  # e.g. [0, 2, 3]: indices of the filters to prune
drop_filters[name] = None
if mask[name].size > 0:
with torch.no_grad():
drop_filters[name] = W.data[mask[name]].view(mask[name].size, -1).cpu().numpy()
                W.data[mask[name]] = 0  # zero out the pruned filters
test_filter_sparsity(conv_weights)
return prune, mask, drop_filters
def reinitialize(mask, drop_filters, conv_weights, fc_weights, zero_init):
print('Reinitializing...')
with torch.no_grad():
prev_layer_name = None
prev_num_filters = None
for name, W in conv_weights + fc_weights:
if W.dim() == 4 and drop_filters[name] is not None: # conv weights
# find null space
size = W.size()
stdv = 1. / math.sqrt(size[1]*size[2]*size[3]) # https://github.com/pytorch/pytorch/blob/08891b0a4e08e2c642deac2042a02238a4d34c67/torch/nn/modules/conv.py#L40-L47
W2d = W.view(size[0], -1).cpu().numpy()
                null_space = qr_null(np.vstack((drop_filters[name], W2d)))
                if null_space.size == 0:
                    # nothing is orthogonal to all existing filters any more: fall back to random init
                    # (index one filter at a time; fancy indexing returns a copy, so an in-place
                    #  uniform_ on W.data[mask[name]] would not write back into W)
                    for mask_idx in mask[name]:
                        W.data[mask_idx].uniform_(-stdv, stdv)
                else:
                    null_space = torch.from_numpy(null_space).cuda()
                    null_space = null_space.transpose(0, 1).view(-1, size[1], size[2], size[3])
                    null_count = 0
                    for mask_idx in mask[name]:
                        if null_count < null_space.size(0):
                            # re-initialize with a vector orthogonal to the pruned and surviving filters
                            W.data[mask_idx] = null_space.data[null_count].clamp_(-stdv, stdv)
                            null_count += 1
                        else:
                            # ran out of null-space vectors: random init for the remaining filters
                            W.data[mask_idx].uniform_(-stdv, stdv)
# # mask channels of prev-layer-pruned-filters' outputs
# if prev_layer_name is not None:
# if W.dim() == 4: # conv
# if zero_init:
# W.data[:, mask[prev_layer_name]] = 0
# else:
# W.data[:, mask[prev_layer_name]].uniform_(-stdv, stdv)
# elif W.dim() == 2: # fc
# if zero_init:
# W.view(W.size(0), prev_num_filters, -1).data[:, mask[prev_layer_name]] = 0
# else:
# stdv = 1. / math.sqrt(W.size(1))
# W.view(W.size(0), prev_num_filters, -1).data[:, mask[prev_layer_name]].uniform_(-stdv, stdv)
# prev_layer_name, prev_num_filters = name, W.size(0)
test_filter_sparsity(conv_weights)
def main():
    viz = visdom.Visdom(env='repr')  # set up the visdom environment
if not torch.cuda.is_available():
raise Exception("Only support GPU training")
    cudnn.benchmark = True  # speed up convolutions
args = parser.parse_args()
# Data
print('==> Preparing data..')
    transform_train = transforms.Compose([  # data augmentation
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
transform_test = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
trainset = torchvision.datasets.CIFAR10(
root='./data', train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(
trainset, batch_size=128, shuffle=True, num_workers=args.workers)
testset = torchvision.datasets.CIFAR10(
root='./data', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(
testset, batch_size=100, shuffle=False, num_workers=args.workers)
# Model
print('==> Building model..')
model = Vanilla()
print(model)
if args.gpu is not None:
torch.cuda.set_device(args.gpu)
model.cuda()
else:
model.cuda()
model = torch.nn.DataParallel(model)
    conv_weights = []  # conv-layer weights
    fc_weights = []  # fully-connected-layer weights
    for name, W in model.named_parameters():
        if W.dim() == 4:  # conv weights
            conv_weights.append((name, W))
        elif W.dim() == 2:  # fc weights
            fc_weights.append((name, W))
criterion = nn.CrossEntropyLoss().cuda()
optimizer = torch.optim.SGD(model.parameters(), args.lr, momentum=0.9, weight_decay=0.001)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.5)
    train_loss_win = viz.line([0.0], [0.], win='train_loss', opts=dict(title='train loss', legend=['train loss']))  # create the visdom windows up front
    train_acc_win = viz.line([0.0], [0.], win='train_acc', opts=dict(title='train acc', legend=['train acc']))
    test_loss_win = viz.line([0.0], [0.], win='test_loss', opts=dict(title='test loss', legend=['test loss']))
    test_acc_win = viz.line([0.0], [0.], win='test_acc', opts=dict(title='test acc', legend=['test acc']))
# comment = "-{}-{}-{}".format("repr" if args.repr else "norepr", args.epochs, args.comment)
# writer = SummaryWriter(comment=comment)
mask = None
drop_filters = None
best_acc = 0 # best test accuracy
prune_map = []
for epoch in range(args.epochs):
if args.repr:
# check if the end of S1 stage
if any(epoch == s for s in range(args.S1, args.epochs, args.S1+args.S2)):
prune, mask, drop_filters = pruning(conv_weights, args.prune_ratio)
prune_map.append(np.concatenate(list(prune.values())))
# check if the end of S2 stage
if any(epoch == s for s in range(args.S1+args.S2, args.epochs, args.S1+args.S2)):
reinitialize(mask, drop_filters, conv_weights, fc_weights, args.zero_init)
# scheduler.step()
train(trainloader, criterion, optimizer, epoch, model, viz, train_loss_win, train_acc_win, mask, args, conv_weights)
acc = validate(testloader, criterion, model, viz, test_loss_win, test_acc_win, args, epoch, best_acc)
scheduler.step()
best_acc = max(best_acc, acc)
test_filter_sparsity(conv_weights)
# writer.close()
print('overall best_acc is {}'.format(best_acc))
# # Shows which filters turn off as training progresses
# if args.repr:
# prune_map = np.array(prune_map).transpose()
# print(prune_map)
# plt.matshow(prune_map.astype(np.int), cmap=ListedColormap(['k', 'w']))
# plt.xticks(np.arange(prune_map.shape[1]))
# plt.yticks(np.arange(prune_map.shape[0]))
# plt.title('Filters on/off map\nwhite: off (pruned)\nblack: on')
# plt.xlabel('Pruning stage')
# plt.ylabel('Filter index from shallower layer to deeper layer')
# plt.savefig('{}-{}.png'.format(
# datetime.datetime.fromtimestamp(time.time()).strftime('%Y-%m-%d-%H:%M:%S'),
# comment))
if __name__ == '__main__':
main()
Result plots:
Earlier write-up (Keras attempt):
Thoughts:
1. First, the ranking is global: filters from the whole network are ranked and pruned together, but O (Eq. 2) is computed within each layer, and W (Eq. 1) is formed by first L2-normalizing the flattened filters. See Section 5 of the paper for details. Note that in the discussion thread the author mentions that the first conv layer is not considered when ranking.
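For reference, the two quantities as I read them from the paper and the pruning code above (the notation here is mine): with J_ℓ filters in layer ℓ,
\hat{W}_{\ell}[f] = \frac{\mathrm{vec}(W_f)}{\lVert \mathrm{vec}(W_f) \rVert_2} \quad \text{(Eq. 1: each row is a flattened, L2-normalized filter)}
P_{\ell} = \bigl|\hat{W}_{\ell}\hat{W}_{\ell}^{\top} - I\bigr|, \qquad O_{\ell}^{f} = \frac{1}{J_{\ell}} \sum_{j} P_{\ell}[f, j] \quad \text{(Eq. 2: inter-filter orthogonality of filter } f\text{)}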
2. Re-initialization. The paper's method uses QR decomposition. One question I had: if the (global) number of filters is much larger than the flattened weight length, or if the layers have weights of different shapes, how is the QR step supposed to work? The paper states that at re-initialization the new weights must be orthogonal both to the previously pruned weights and to the current surviving weights.
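A quick toy check of the first concern (my own illustration): once the stacked filters have at least as many rows as the flattened dimension, the null space is empty and QR yields no orthogonal vector; the reinitialize() code above handles this per layer by falling back to uniform re-initialization.
import numpy as np
from scipy.linalg import qr

rng = np.random.default_rng(0)
# 5 filters of flattened length 9: a null space exists (dimension 9 - 5 = 4)
A_tall = rng.standard_normal((5, 9))
# 12 filters of flattened length 9: no null space is left
A_wide = rng.standard_normal((12, 9))
for A in (A_tall, A_wide):
    Q, R = qr(A.T, mode='full')
    rank = np.sum(np.abs(np.diag(R)) > 1e-7)
    print('null-space dimension:', Q.shape[1] - rank)   # prints 4, then 0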
3. The training curves in Figure 1 of the paper are very stable and show no drop at re-initialization, while Figure 7 of the paper does show a drop; my experiments also show a drop.
My training curves (showing the drop after re-initialization):
The test curve is not very stable either (the labels in the figure are swapped: blue is the training accuracy, orange is the test accuracy):
4. Results: I trained several times and the results are not very stable. Perhaps I have not fully understood the author's idea, or my code has problems, but I am wrapping this work up for now.
Summary:
1. The network and hyperparameters probably differ from the author's, but the experiments still show some effect: out of roughly 20 training runs, the best one raised test accuracy from 67% (standard) to 70% (RePr).
2. Even with a somewhat different model, a correct implementation should still give some improvement.
Main parts of the earlier Keras code:
# assumed imports for the snippets below (the original post omitted them)
from keras.models import Model
from keras.layers import (Input, Conv2D, BatchNormalization, ReLU, MaxPooling2D,
                          Flatten, Dense, Dropout, Softmax)
from keras.callbacks import LambdaCallback
import numpy as np
from scipy.linalg import qr

def standard(shape=(32, 32, 3), num_classes=10):
modelinput = Input(shape)
conv1 = Conv2D(32, (3, 3))(modelinput)
bn1 = BatchNormalization()(conv1)
act1 = ReLU()(bn1)
pool1 = MaxPooling2D((2, 2))(act1)
conv2 = Conv2D(32, (3, 3))(pool1)
bn2 = BatchNormalization()(conv2)
act2 = ReLU()(bn2)
pool2 = MaxPooling2D((2, 2))(act2)
conv3 = Conv2D(32, (3, 3))(pool2)
bn3 = BatchNormalization()(conv3)
act3 = ReLU()(bn3)
pool3 = MaxPooling2D((2, 2))(act3)
flat = Flatten()(pool3)
dense1 = Dense(512)(flat)
    act4 = ReLU()(dense1)
drop = Dropout(0.5)(act4)
dense2 = Dense(num_classes)(drop)
act5 = Softmax()(dense2)
model = Model(modelinput, act5)
return model
def get_convlayername(model):
    '''
    Get the names of the conv layers.
    # Arguments
        model: the network model
    '''
    layername = []
    for i in range(len(model.layers)):
        # collect the names of every layer in the model
        layername.append(model.layers[i].name)
    # keep only the 2D conv layers
    convlayername = [layername[name] for name in range(len(layername)) if 'conv2d' in layername[name]]
    return convlayername[1:]  # the first conv layer is excluded from ranking/pruning
def prunefilters(model, convlayername, count=0):
    '''
    Prune filters.
    # Arguments
        model: the network model
        convlayername: names of all the (2D) conv layers
        count: running offset of the first filter index of each layer
    '''
    convnum = len(convlayername)  # number of conv layers
    params = [i for i in range(convnum)]
    weight = [i for i in range(convnum)]
    MASK = [i for i in range(convnum)]
    rank = dict()  # dict holding the ranking value of every filter
drop = []
index1 = 0
index2 = 0
for j in range(convnum):
        # store this conv layer's weights in a list; each element is an array
        params[j] = model.get_layer(convlayername[j]).get_weights()
        weight[j] = params[j][0].T  # transpose to the usual layout (filters, channels, h, w), e.g. (32, 32, 3, 3)
        filternum = weight[j].shape[0]  # number of filters in this layer
        # matrix used to check orthogonality: one flattened, normalized filter per row
        W = np.zeros((weight[j].shape[0], weight[j].shape[2]*weight[j].shape[3]*weight[j].shape[1]), dtype='float32')
        for x in range(filternum):
            # flatten each filter of this layer into a 1-D vector
filter = weight[j][x,:,:,:].flatten()
filter_length = np.linalg.norm(filter)
eps = np.finfo(filter_length.dtype).eps
filter_length = max([filter_length, eps])
            filter_norm = filter / filter_length  # L2-normalize
            # put each normalized filter of this layer into one row of W
            W[x,:] = filter_norm
        # intra-layer orthogonality
        I = np.identity(filternum)
        P = abs(np.dot(W, W.T) - I)
        O = P.sum(axis=1) / filternum  # row average (every conv layer here has 32 filters, so this matches the original /32)
for index, o in enumerate(O):
rank.update({index+count: o})
count = filternum + count
    # sort the dict by value: the ranking is over all filters of the network
    ranking = sorted(rank.items(), key=lambda x: x[1])  # a list of (filter index, orthogonality) tuples
    for t in range(int(len(ranking)*0.8), len(ranking)):  # drop the worst 20% (largest overlap)
drop.append(ranking[t][0])
for j in range(convnum):
MASK[j] = np.ones((weight[j].shape), dtype='float32')
index2 = weight[j].shape[0] + index1
for a in drop:
if a >= index1 and a < index2:
MASK[j][a-index1,:,:,:] = 0
index1 = index2
# weight[j] = (weight[j] * MASK[j]).T
# for j in range(convnum):
# params[j][0] = weight[j]
# model.get_layer(convlayername[j]).set_weights(params[j])
return MASK, weight, drop, convnum, convlayername
def Mask(model, mask):
convlayername = get_convlayername(model)
for i in range(len(convlayername)):
Params = [i for i in range(len(convlayername))]
Weight = [i for i in range(len(convlayername))]
Params[i] = model.get_layer(convlayername[i]).get_weights()
Weight[i] = (Params[i][0].T*mask[i]).T
Params[i][0] = Weight[i]
model.get_layer(convlayername[i]).set_weights(Params[i])
prune_callback = LambdaCallback(
    on_batch_end=lambda batch, logs: Mask(model, mask))  # re-apply the 0/1 mask after every batch
def reinit(model, weight, drop, convnum, convlayername):
index1 = 0
index2 = 0
new_params = [i for i in range(convnum)]
new_weight = [i for i in range(convnum)]
for j in range(convnum):
new_params[j] = model.get_layer(convlayername[j]).get_weights()
new_weight[j] = new_params[j][0].T
stack_new_filters = new_weight[0]
stack_filters = weight[0]
filter_index1 = 0
filter_index2 = 0
for i in range(len(new_weight)-1):
next_new_filter = new_weight[i+1]
next_filter = weight[i+1]
stack_new_filters = np.vstack((stack_new_filters, next_new_filter))
stack_filters = np.vstack((stack_filters, next_filter))
stack_new_filters_flat = np.zeros((stack_new_filters.shape[0],
stack_new_filters.shape[1]*stack_new_filters.shape[2]*stack_new_filters.shape[3]), dtype='float32')
stack_filters_flat = np.zeros((stack_filters.shape[0],
stack_filters.shape[1]*stack_filters.shape[2]*stack_filters.shape[3]), dtype='float32')
for p in range(stack_new_filters.shape[0]):
stack_new_filters_flat[p] = stack_new_filters[p].flatten()
stack_filters_flat[p] = stack_filters[p].flatten()
    reinit = None
    for b in drop:
        # re-run QR each time: the first zero diagonal entry of R marks a column of Q
        # that is orthogonal to all current filters (see the QR note near the top of the post)
        Q, R = qr(stack_new_filters_flat.T)
        diag_R = np.abs(np.diag(R))
        for k in range(diag_R.size):
            if diag_R[k] == 0:
                reinit = Q.T[k]
                break
        null_space = reinit
        stack_new_filters_flat[b] = null_space
for filter_in_stack in range(stack_new_filters_flat.shape[0]):
stack_new_filters[filter_in_stack] = stack_new_filters_flat[filter_in_stack].reshape(
(stack_new_filters.shape[1], stack_new_filters.shape[2], stack_new_filters.shape[3]))
for f in range(len(new_weight)):
filter_index2 = new_weight[f].shape[0] + filter_index1
new_weight[f] = stack_new_filters[filter_index1:filter_index2,:,:,:]
        filter_index1 = filter_index2  # advance the offset to the start of the next layer's filters
new_params[f][0] = new_weight[f].T
model.get_layer(convlayername[f]).set_weights(new_params[f])