tsn

TSN

1.如何提帧

1.1数据集准备

下载网址:http://crcv.ucf.edu/data/UCF101/UCF101.rar

下载成功后的UCF文件夹如下所示:
该文件夹下是各种动作的视频文件,共有101种类别
tsn_第1张图片

下图是UCF101在进行训练和测试时,分割的依据文件

tsn_第2张图片

1.2源码准备

在实验过程中,我们需要使用tsn-pytorch和mmaction的一些代码文件,所以我们提前从Git上获得存储在本地。

下载mmaction:

git clone --recursive https://github.com/open-mmlab/mmaction.git

下载tsn-pytorch:

git clone --recursive https://github.com/yjxiong/tsn-pytorch

1.3提帧

在我们下载好的UCF101数据集中,视频大多是长时间的,很难对其进行动作识别,所以需要进行提帧操作。
首先在mmaction的data/ucf101中创建rawframes、videos、annotations文件夹。

rawframes:视频提帧后存放的文件目录

videos:拷贝ucf101数据集中的101个文件目录,放置其中

annotations:ucf101之后进行分割训练集、测试集的依据文件

tsn_第3张图片

然后在mmaction/data_tools/build_rawframes.py的同级目录下进行视频提帧的代码文件,输入命令如下所示:

python build_rawframes.py ../data/ucf101/videos ../data/ucf101/rawframes/ --level 2  --ext avi

tsn_第4张图片

生成的文件目录形式如下所示:
tsn_第5张图片

tsn_第6张图片

在这里插入图片描述

运行完成后,将每一个视频的每一帧提取出来,放在特定名称的文件夹中。

1.4生成file_list

在tsn-pytorch的readme文件中可以看到,训练过程中需要<ucf101_rgb_train_list>和<ucf101_rgb_val_list>,所以生成这两个list文件是必需的。使用mmaction/data_tools/buid_file_list.py即可对ucf101生成的帧进行训练集和测试集的划分。输入命令如下所示:

python data_tools/build_file_list.py ucf101 data/ucf101/rawframes/ --level 2 --format rawframes --shuffle

也可在mmaction/data_tools/ucf101/中输入

bash generate_filelist.sh

tsn_第7张图片

生成的filelist在data/ucf101目录下,形式如下:
tsn_第8张图片

file_list的内容如下所示:

tsn_第9张图片

file_list中有三列,第一列代表文件的地址,第二列代表视频的帧数,第三列代表视频的类别。这里仅仅使用ucf101的3个文件夹,所以类别只有0 1 2。

2.如何feed帧出特征

代码修改部分参考 https://blog.csdn.net/qq_39862223/article/details/108461526

2.1IPO

下图展示了,TSN如何将ucf101数据集提出的帧进行分类的过程,标明了每一个阶段的tensor大小

tsn_第10张图片tsn_第11张图片

3.如何save,以便load

定义的保存模型以及参数信息的方法,该方法会在进行模型训练的时候得到调用。

def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):  
    filename = '_'.join((args.snapshot_pref, args.modality.lower(), filename)) # 用于保存模型以及参数信息的路径以及文件名
    torch.save(state, filename)  # 将模型以上述名称保存在该路径下
    if is_best: # 如果准确率得到提高就进行模型的被备份
        best_name = '_'.join((args.snapshot_pref, args.modality.lower(), 'model_best.pth.tar')) # 备份路径以及文件名称
        shutil.copyfile(filename, best_name) # 进行文件复制

对该方法的调用,通过该方法保存模型,准确率,模型参数并判断是否进行模型复制

best_prec1 = max(prec1, best_prec1)
            save_checkpoint({
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
            }, is_best)

加载保存的参数

    if args.resume:  # args.resume是保存模型的路径
        if os.path.isfile(args.resume): # 判断该绝对路径下是否是文件,也就是保存模型方法中的绝对路径
            print(("=> loading checkpoint '{}'".format(args.resume)))
            checkpoint = torch.load(args.resume) # 进行加载checkpoint 字典的形式,里面包括epoch,arch,state_dict,best_prec1 
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            print(("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.evaluate, checkpoint['epoch'])))
        else:
            print(("=> no checkpoint found at '{}'".format(args.resume)))

你可能感兴趣的:(tsn)