Paper: TSM: Temporal Shift Module for Efficient Video Understanding
Code: https://github.com/mit-han-lab/temporal-shift-module
The main structure of the code is as follows:
| python file | function |
|---|---|
| main.py | the main training entry point |
| opts.py | command-line argument configuration |
| ops/dataset.py | dataset loading; the core is the `__getitem__` function |
| ops/dataset_config.py | configuration for the different datasets |
| ops/models.py | model assembly |
| ops/temporal_shift.py | the core temporal shift operation |
```python
parser.add_argument('--shift', default=False, action="store_true", help='use shift for models')
parser.add_argument('--shift_div', default=8, type=int, help='number of div for shift (default: 8)')
```
- `--modality`: the input type, RGB or Flow.
- `--num_segments`: the number of frames sampled from each video, typically 8 or 16.
- `--shift`: whether to insert the TSM module.
- `--shift_div`: controls the fraction of channels that are shifted, typically 8. With shift_div=8, a 2×(1/8) fraction of the channels is moved: 1/8 of the channels are shifted left and another 1/8 are shifted right.
Each dataset implements a `return_xxx(modality)` function that returns the number of classes (or the category file), the train_list path, the val_list path, the dataset root path, and related information.
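As an illustration, a config function for one dataset might look like the sketch below. It follows the repo's `return_xxx` pattern, but the concrete paths and the class count are assumptions for illustration, not the repo's actual values:

```python
import os

ROOT_DATASET = '/data/datasets/'  # assumed dataset root; adjust to your setup

def return_somethingv1(modality):
    # Hypothetical example following the repo's return_xxx pattern.
    filename_categories = 174  # assumed number of classes
    if modality == 'RGB':
        root_data = os.path.join(ROOT_DATASET, 'something/v1/frames')
        filename_imglist_train = 'something/v1/train_videofolder.txt'
        filename_imglist_val = 'something/v1/val_videofolder.txt'
        prefix = '{:05d}.jpg'
    else:
        raise NotImplementedError('unsupported modality: ' + modality)
    return filename_categories, filename_imglist_train, filename_imglist_val, root_data, prefix
```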
The main job of dataset.py is to read the dataset and sparsely sample it, returning the sparsely sampled data.
dataset.py implements the TSNDataSet class, which processes the raw data; it inherits from torch.utils.data.Dataset.
It first defines a simple helper class:
```python
class VideoRecord(object):
    def __init__(self, row):
        self._data = row

    @property
    def path(self):
        return self._data[0]

    @property
    def num_frames(self):
        return int(self._data[1])

    @property
    def label(self):
        return int(self._data[2])
```
```python
class TSNDataSet(data.Dataset):
    def __init__(self, root_path, list_file,
                 num_segments=3, new_length=1, modality='RGB',
                 image_tmpl='img_{:05d}.jpg', transform=None,
                 random_shift=True, test_mode=False,
                 remove_missing=False, dense_sample=False, twice_sample=False):
        self.root_path = root_path
        self.list_file = list_file
        self.num_segments = num_segments
        self.new_length = new_length
        self.modality = modality
        self.image_tmpl = image_tmpl
        self.transform = transform
        self.random_shift = random_shift
        self.test_mode = test_mode
        self.remove_missing = remove_missing
        self.dense_sample = dense_sample  # using dense sample as in I3D
        self.twice_sample = twice_sample  # twice sample for more robust validation
        if self.dense_sample:
            print('=> Using dense sample for the dataset...')
        if self.twice_sample:
            print('=> Using twice sample for the dataset...')
        if self.modality == 'RGBDiff':
            self.new_length += 1  # Diff needs one more image to calculate diff
        self._parse_list()
```
After setting the parameters and their defaults, the constructor calls _parse_list(). There, tmp is a list whose length equals the number of training samples; each entry is a space-separated record of three fields: the path to the video's frames, the number of frames the video contains, and the label. Each record is wrapped in a VideoRecord object, and the resulting objects are collected into video_list.
```python
tmp = [x.strip().split(' ') for x in open(self.list_file)]
self.video_list = [VideoRecord(item) for item in tmp]
```
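For concreteness, here is a tiny self-contained demo of this parsing step; the annotation line is made up for illustration:

```python
# A made-up annotation line: <frame folder> <num frames> <label>
line = 'video_001 120 3'

record = VideoRecord(line.strip().split(' '))
print(record.path)        # 'video_001'
print(record.num_frames)  # 120
print(record.label)       # 3
```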
First we need to obtain the images and the label for a video; the images are laid out as [n*t, c, h, w].
So for a given video, how do we obtain num_segments frames?

```python
def __getitem__(self, index):
    record = self.video_list[index]
    if not self.test_mode:  # test_mode: False; random_shift: True
        segment_indices = self._sample_indices(record) if self.random_shift else self._get_val_indices(record)
    else:
        segment_indices = self._get_test_indices(record)
    return self.get(record, segment_indices)
```
There are three different sampling functions, for train, val, and test:

- _sample_indices is the strategy used for training. By default it draws num_segments indices at random, in one of two modes: dense sampling or sparse ("normal") sampling.
- _get_val_indices picks num_segments evenly spaced indices for validation.
- _get_test_indices produces num_segments indices sparsely and deterministically.

As a simple example, with num_frames=120 and num_segments=3:

- normal sampling in _sample_indices randomly returns index sets such as [4, 44, 84], [5, 45, 85], or [11, 51, 91];
- dense sampling in _sample_indices randomly returns sets such as [15, 36, 57], [30, 51, 72], or [44, 65, 86];
- dense sampling in _get_test_indices works the same way, but deterministically;
- twice sampling in _get_test_indices deterministically returns an offset sample concatenated with a non-offset sample: [21, 61, 101, 1, 41, 81].

In the end, num_segments frames are sampled from every video, and __getitem__ returns data of shape [n * t, c, h, w] (frames) together with 1 label. A runnable sketch of the normal training-time sampling follows.
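The sketch below re-implements the core idea of the normal (sparse) sampling in _sample_indices; it follows the repo's logic but is not the verbatim code:

```python
import numpy as np

def sparse_sample_indices(num_frames, num_segments, new_length=1):
    # Split the video into num_segments equal chunks and pick one random
    # position inside each chunk (the repo returns 1-based frame indices).
    average_duration = (num_frames - new_length + 1) // num_segments
    if average_duration > 0:
        offsets = np.multiply(list(range(num_segments)), average_duration) \
                  + np.random.randint(average_duration, size=num_segments)
    else:
        offsets = np.zeros((num_segments,), dtype=int)
    return offsets + 1

print(sparse_sample_indices(120, 3))  # e.g. [ 4 44 84]
```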
models.py prepares the model for training. It starts from a classic backbone such as ResNet-50 and modifies the final fully connected layer to match the input modality and dataset, which yields the TSN model; optionally, the TSM module is inserted into the backbone, which yields the TSM model we actually want.

The __init__ function sets parameters and their defaults, then assembles the model in three steps:

1. call _prepare_base_model(base_model) to build the backbone;
2. call _prepare_tsn(num_class) to resize the fc layer to the number of classes in the target dataset;
3. for Flow and RGBDiff inputs, call _construct_flow_model or _construct_diff_model to adapt the first convolution to the different number of input channels.

The TSM module itself is inserted by calling make_temporal_shift():
```python
def _prepare_base_model(self, base_model):
    print('=> base model: {}'.format(base_model))
    if 'resnet' in base_model:
        # e.g. torchvision.models.resnet50
        self.base_model = getattr(torchvision.models, base_model)(True if self.pretrain == 'imagenet' else False)
        if self.is_shift:  # False by default
            print('Adding temporal shift...')
            from ops.temporal_shift import make_temporal_shift
            make_temporal_shift(self.base_model, self.num_segments,
                                n_div=self.shift_div, place=self.shift_place,
                                temporal_pool=self.temporal_pool)
        if self.non_local:  # False by default
            print('Adding non-local module...')
            from ops.non_local import make_non_local
            make_non_local(self.base_model, self.num_segments)
        self.base_model.last_layer_name = 'fc'
        self.input_size = 224
        self.input_mean = [0.485, 0.456, 0.406]
        self.input_std = [0.229, 0.224, 0.225]
        # The argument 1 is the output spatial size:
        # torch.Size([2, 32, 16, 16]) -> torch.Size([2, 32, 1, 1])
        self.base_model.avgpool = nn.AdaptiveAvgPool2d(1)
```
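A quick sanity check of what AdaptiveAvgPool2d(1) does to the feature map shape:

```python
import torch
import torch.nn as nn

x = torch.randn(2, 32, 16, 16)
pool = nn.AdaptiveAvgPool2d(1)  # pool each channel map down to 1x1
print(pool(x).shape)            # torch.Size([2, 32, 1, 1])
```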
The job of _prepare_tsn is to modify the base model by fine-tuning the structure of its last (fully connected) layer so that its output matches the target dataset.
```python
def _prepare_tsn(self, num_class):
    # Input dimension of the model's last layer
    # (normal_ and constant_ come from torch.nn.init)
    feature_dim = getattr(self.base_model, self.base_model.last_layer_name).in_features
    # If dropout == 0, directly replace the last layer with a new fc layer
    # whose output dimension is num_class
    if self.dropout == 0:
        setattr(self.base_model, self.base_model.last_layer_name, nn.Linear(feature_dim, num_class))
        self.new_fc = None
    # If dropout != 0, replace the last layer with a dropout layer and
    # append a new fc layer with output dimension num_class
    else:
        setattr(self.base_model, self.base_model.last_layer_name, nn.Dropout(p=self.dropout))
        self.new_fc = nn.Linear(feature_dim, num_class)
    std = 0.001
    if self.new_fc is None:
        normal_(getattr(self.base_model, self.base_model.last_layer_name).weight, 0, std)
        constant_(getattr(self.base_model, self.base_model.last_layer_name).bias, 0)
    else:
        if hasattr(self.new_fc, 'weight'):
            normal_(self.new_fc.weight, 0, std)
            constant_(self.new_fc.bias, 0)
    return feature_dim
```
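The same fc-resizing trick in isolation, as a standalone sketch using ResNet-50 and an assumed 174-class dataset:

```python
import torch.nn as nn
import torchvision

num_class = 174  # assumed number of classes for illustration
model = torchvision.models.resnet50()

feature_dim = getattr(model, 'fc').in_features   # 2048 for ResNet-50
setattr(model, 'fc', nn.Linear(feature_dim, num_class))
print(model.fc)  # Linear(in_features=2048, out_features=174, bias=True)
```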
```python
make_temporal_shift(self.base_model, self.num_segments,
                    n_div=self.shift_div, place=self.shift_place,
                    temporal_pool=self.temporal_pool)
```
So how is TemporalShift(b, n_segment=this_segment, n_div=n_div) implemented?
For an input of shape [n*t, c, h, w], where t is the number of segments, the tensor is first reshaped to [n, t, c, h, w]. If the current feature map has 256 channels and fold_div=8, then 256/8 = 32 channels are shifted left, another 32 channels are shifted right, and the remaining channels are left in place. A sketch of the wrapper module follows; the shift routine itself comes right after.
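Based on the repo's code, the wrapper shifts part of the channels along the temporal dimension and then applies the wrapped layer. A simplified sketch (not the verbatim class; the shift static method shown right after belongs to this class):

```python
import torch.nn as nn

class TemporalShift(nn.Module):
    """Simplified sketch of ops.temporal_shift.TemporalShift."""

    def __init__(self, net, n_segment=3, n_div=8, inplace=False):
        super(TemporalShift, self).__init__()
        self.net = net            # the wrapped layer, e.g. a residual block's conv1
        self.n_segment = n_segment
        self.fold_div = n_div
        self.inplace = inplace

    def forward(self, x):
        # Shift 1/n_div of the channels forward and 1/n_div backward in time,
        # then run the wrapped layer on the shifted features.
        x = self.shift(x, self.n_segment, fold_div=self.fold_div, inplace=self.inplace)
        return self.net(x)
```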
```python
@staticmethod
def shift(x, n_segment, fold_div=3, inplace=False):
    nt, c, h, w = x.size()
    n_batch = nt // n_segment
    x = x.view(n_batch, n_segment, c, h, w)
    fold = c // fold_div
    if inplace:
        # Due to some out-of-order error when performing parallel computing.
        # May need to write a CUDA kernel.
        raise NotImplementedError
        # out = InplaceShift.apply(x, fold)
    else:
        out = torch.zeros_like(x)
        out[:, :-1, :fold] = x[:, 1:, :fold]  # shift left
        out[:, 1:, fold: 2 * fold] = x[:, :-1, fold: 2 * fold]  # shift right
        out[:, :, 2 * fold:] = x[:, :, 2 * fold:]  # not shift
    return out.view(nt, c, h, w)
```
A simple example: with c = 8 and num_segment = 4, write the features as a 2-D table where 0_xx denotes the channels of the first frame, 1_xx those of the second frame, and so on, each frame having 8 channels.
With fold_div = 8, after shifting, the first frame has absorbed features from the second frame, and the second frame has absorbed features from both the first and the third frames (a runnable illustration follows).
With fold_div = 4, a larger share of channels moves, so each frame's features contain more information from the preceding and following frames.
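The before/after tables can be reproduced with a tiny tensor, using the TemporalShift class sketched above together with the shift method from the repo (entry 10*frame + channel, so e.g. 12.0 means "frame 1, channel 2"):

```python
import torch

t, c = 4, 8  # 4 frames, 8 channels, batch size n = 1
x = torch.stack([torch.arange(c, dtype=torch.float32) + 10 * f for f in range(t)])
x = x.view(t, c, 1, 1)  # [n*t, c, h, w] with h = w = 1

out = TemporalShift.shift(x, n_segment=t, fold_div=8)
print(x.view(t, c))    # original features, one row per frame
print(out.view(t, c))  # shifted features
# Channel 0 now holds the NEXT frame's values (shift left, zero-padded at the end),
# channel 1 the PREVIOUS frame's values (shift right, zero-padded at the start),
# channels 2-7 are unchanged.
```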
Finally, let's walk through the main training function, which ties together the classes and functions discussed above.
adjust_learning_rate adjusts the learning rate at each epoch according to the chosen schedule; a sketch of the step schedule follows.
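For the 'step' schedule, the function boils down to decaying the learning rate by 10× at each milestone epoch. A sketch of the idea (the repo's version additionally applies per-group lr_mult/decay_mult factors):

```python
import numpy as np

def adjust_learning_rate_sketch(optimizer, epoch, base_lr, lr_steps, weight_decay):
    # Decay the learning rate by 10x at every milestone in lr_steps.
    decay = 0.1 ** (sum(epoch >= np.array(lr_steps)))
    lr = base_lr * decay
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
        param_group['weight_decay'] = weight_decay
```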
```python
for epoch in range(args.start_epoch, args.epochs):
    # Set the learning rate: the initial LR decayed by 10 at each milestone in lr_steps
    adjust_learning_rate(optimizer, epoch, args.lr_type, args.lr_steps)
    # train for one epoch
    train(train_loader, model, criterion, optimizer, epoch, log_training, tf_writer)
    # evaluate on validation set
    if (epoch + 1) % args.eval_freq == 0 or epoch == args.epochs - 1:
        prec1 = validate(val_loader, model, criterion, epoch, log_training, tf_writer)
        # remember best prec@1 and save checkpoint
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)
        tf_writer.add_scalar('acc/test_top1_best', best_prec1, epoch)
        output_best = 'Best Prec@1: %.3f\n' % (best_prec1)
        print(output_best)
        log_training.write(output_best + '\n')
        log_training.flush()  # flush the log buffer
        save_checkpoint({
            'epoch': epoch + 1,
            'arch': args.arch,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'best_prec1': best_prec1,
        }, is_best)
```
The core of the train function is as follows:
```python
def train(train_loader, model, criterion, optimizer, epoch, log, tf_writer):
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    if args.no_partialbn:  # False by default
        model.module.partialBN(False)
    else:
        model.module.partialBN(True)
    # switch to train mode
    model.train()
    end = time.time()  # current timestamp
    for i, (input, target) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)
        input_var = torch.autograd.Variable(input)
        target_var = torch.autograd.Variable(target)
        # compute output
        output = model(input_var)
        loss = criterion(output, target_var)
        # measure accuracy and record loss
        prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
        losses.update(loss.item(), input.size(0))
        top1.update(prec1.item(), input.size(0))
        top5.update(prec5.item(), input.size(0))
        # compute gradient and do SGD step
        loss.backward()
        if args.clip_gradient is not None:  # None by default
            total_norm = clip_grad_norm_(model.parameters(), args.clip_gradient)
        optimizer.step()
        optimizer.zero_grad()
        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        if i % args.print_freq == 0:  # print_freq = 20
            output = ('Epoch: [{0}][{1}/{2}], lr: {lr:.5f}\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                      'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                          epoch, i, len(train_loader), batch_time=batch_time,
                          data_time=data_time, loss=losses, top1=top1, top5=top5,
                          lr=optimizer.param_groups[-1]['lr'] * 0.1))  # TODO
            print(output)
            log.write(output + '\n')
            log.flush()
    tf_writer.add_scalar('loss/train', losses.avg, epoch)
    tf_writer.add_scalar('acc/train_top1', top1.avg, epoch)
    tf_writer.add_scalar('acc/train_top5', top5.avg, epoch)
    tf_writer.add_scalar('lr', optimizer.param_groups[-1]['lr'], epoch)
```
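The train function relies on two small helpers, AverageMeter and accuracy. For completeness, here is a sketch of what they do, following the standard PyTorch ImageNet-example pattern that this repo adopts:

```python
class AverageMeter(object):
    """Track the current value and the running average of a metric."""

    def __init__(self):
        self.val, self.avg, self.sum, self.count = 0, 0, 0, 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def accuracy(output, target, topk=(1,)):
    """Compute precision@k for the specified values of k."""
    maxk = max(topk)
    batch_size = target.size(0)
    _, pred = output.topk(maxk, 1, True, True)  # indices of the top-k predictions
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))
    res = []
    for k in topk:
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res
```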