GaitSet Source Code Walkthrough (Part 3)

This article is the final part of the series.

Below, I will walk through the testing code.

First, let's look at test.py.

from datetime import datetime
import numpy as np
import argparse

from model.initialization import initialization
from model.utils import evaluation
from config import conf


def boolean_string(s):
    if s.upper() not in {'FALSE', 'TRUE'}:
        raise ValueError('Not a valid boolean string')
    return s.upper() == 'TRUE'


parser = argparse.ArgumentParser(description='Test')
parser.add_argument('--iter', default='80000', type=int,
                    help='iter: iteration of the checkpoint to load. Default: 80000')
parser.add_argument('--batch_size', default='1', type=int,
                    help='batch_size: batch size for parallel test. Default: 1')  # batch_size is 1 at test time
parser.add_argument('--cache', default=False, type=boolean_string,
                    help='cache: if set as TRUE all the test data will be loaded at once'
                         ' before the transforming start. Default: FALSE')
opt = parser.parse_args()


# Exclude identical-view cases
def de_diag(acc, each_angle=False):
    result = np.sum(acc - np.diag(np.diag(acc)), 1) / 10.0
    if not each_angle:
        result = np.mean(result)
    return result


m = initialization(conf, test=opt.cache)[0]

# load model checkpoint of iteration opt.iter
print('Loading the model of iteration %d...' % opt.iter)
m.load(opt.iter)  # load the checkpoint weights
print('Transforming...')
time = datetime.now()  # start timing
test = m.transform('test', opt.batch_size)
print('Evaluating...')
acc = evaluation(test, conf['data'])
print('Evaluation complete. Cost:', datetime.now() - time)  # stop timing

# Print rank-1 accuracy of the best model
# e.g.
# ===Rank-1 (Include identical-view cases)===
# NM: 95.405,     BG: 88.284,     CL: 72.041
for i in range(1):
    print('===Rank-%d (Include identical-view cases)===' % (i + 1))
    print('NM: %.3f,\tBG: %.3f,\tCL: %.3f' % (
        np.mean(acc[0, :, :, i]),
        np.mean(acc[1, :, :, i]),
        np.mean(acc[2, :, :, i])))

# Print rank-1 accuracy of the best model, excluding identical-view cases
# e.g.
# ===Rank-1 (Exclude identical-view cases)===
# NM: 94.964,     BG: 87.239,     CL: 70.355
for i in range(1):
    print('===Rank-%d (Exclude identical-view cases)===' % (i + 1))
    print('NM: %.3f,\tBG: %.3f,\tCL: %.3f' % (
        de_diag(acc[0, :, :, i]),
        de_diag(acc[1, :, :, i]),
        de_diag(acc[2, :, :, i])))

# Print rank-1 accuracy of the best model (Each Angle)
# e.g.
# ===Rank-1 of each angle (Exclude identical-view cases)===
# NM: [90.80 97.90 99.40 96.90 93.60 91.70 95.00 97.80 98.90 96.80 85.80]
# BG: [83.80 91.20 91.80 88.79 83.30 81.00 84.10 90.00 92.20 94.45 79.00]
# CL: [61.40 75.40 80.70 77.30 72.10 70.10 71.50 73.50 73.50 68.40 50.00]
# np.set_printoptions(precision=2, floatmode='fixed')
for i in range(1):
    print('===Rank-%d of each angle (Exclude identical-view cases)===' % (i + 1))
    print('NM:', de_diag(acc[0, :, :, i], True))
    print('BG:', de_diag(acc[1, :, :, i], True))
    print('CL:', de_diag(acc[2, :, :, i], True))
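With those arguments, the script is run as, for example (assuming training has already produced the iteration-80000 checkpoint under the configured WORK_PATH):

python test.py --iter 80000 --batch_size 1 --cache True

One detail worth a quick check is de_diag: it sums each row of the 11x11 cross-view accuracy matrix with the diagonal (identical-view) entry removed and divides by the 10 remaining views. A toy example with hypothetical numbers:

import numpy as np
toy = np.full((11, 11), 90.0)  # hypothetical cross-view accuracies
np.fill_diagonal(toy, 100.0)   # identical-view entries are typically higher
print(de_diag(toy))            # 90.0 -- the inflated diagonal is excluded
print(de_diag(toy, True))      # array of eleven 90.0s, one mean per probe angle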

Once the arguments are parsed, execution enters initialization.py.

Author: 顾道长生
Link: https://zhuanlan.zhihu.com/p/446182146
Source: Zhihu
Copyright belongs to the author. For commercial reprints, contact the author for authorization; for non-commercial reprints, credit the source.

# -*- coding: utf-8 -*-
# @Author  : admin
# @Time    : 2018/11/15
import os
from copy import deepcopy  # deep copy

import numpy as np

from .utils import load_data  # dataset loading
from .model import Model  # the model


# Initialize the data sources with parameters from the config file.
# Returns the train and test data sources.
def initialize_data(config, train=False, test=False):
    # here, train and test indicate whether to cache the respective data
    print("Initializing data source...")
    # obtain the DataSet objects
    train_source, test_source = load_data(**config['data'], cache=(train or test))  # ** unpacks a dict into keyword arguments (just as * unpacks a tuple into positional arguments)
    if train:
        print("Loading training data...")
        train_source.load_all_data()
    if test:
        print("Loading test data...")
        test_source.load_all_data()
    print("Data initialization complete.")
    return train_source, test_source

# Initialize the model parameters from the config file.
def initialize_model(config, train_source, test_source):
    print("Initializing model...")
    data_config = config['data']
    model_config = config['model']
    model_param = deepcopy(model_config)
    model_param['train_source'] = train_source
    model_param['test_source'] = test_source
    model_param['train_pid_num'] = data_config['pid_num']
    batch_size = int(np.prod(model_config['batch_size']))  # np.prod multiplies all elements, e.g. (8, 16) -> 128
    model_param['save_name'] = '_'.join(map(str,[
        model_config['model_name'],
        data_config['dataset'],
        data_config['pid_num'],
        data_config['pid_shuffle'],
        model_config['hidden_dim'],
        model_config['margin'],
        batch_size,
        model_config['hard_or_full_trip'],
        model_config['frame_num'],
    ]))

    m = Model(**model_param)
    print("Model initialization complete.")
    return m, model_param['save_name']


def initialization(config, train=False, test=False):
    print("Initializing...")
    WORK_PATH = config['WORK_PATH']
    os.chdir(WORK_PATH)  # os.chdir() changes the current working directory to the given path
    os.environ["CUDA_VISIBLE_DEVICES"] = config["CUDA_VISIBLE_DEVICES"]
    train_source, test_source = initialize_data(config, train, test)
    print('train:', len(train_source))
    print('test:', len(test_source))
    return initialize_model(config, train_source, test_source)
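As an aside, save_name is just the configuration values joined with underscores. With the defaults in config.py (hedged from memory: model_name 'GaitSet', dataset 'CASIA-B', pid_num 73, pid_shuffle False, hidden_dim 256, margin 0.2, batch_size (8, 16), 'full' triplet loss, frame_num 30), it would come out as:

print('_'.join(map(str, ['GaitSet', 'CASIA-B', 73, False, 256, 0.2, 8 * 16, 'full', 30])))
# GaitSet_CASIA-B_73_False_256_0.2_128_full_30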

This part is the same as in training and was covered in the earlier articles of this series, so I won't repeat it here.

Now back in test.py, follow the call to transform into model.py.

def transform(self, flag, batch_size=1):  # testing
        self.encoder.eval()
        source = self.test_source if flag == 'test' else self.train_source
        self.sample_type = 'all'  # full sampling: use every frame
        data_loader = tordata.DataLoader(
            dataset=source,
            batch_size=batch_size,  # 1
            sampler=tordata.sampler.SequentialSampler(source),  # default sampler: iterate in order
            collate_fn=self.collate_fn,  # assembles one batch
            num_workers=self.num_workers)

        feature_list = list()  # entries are ndarrays of shape (1, 15872); 5485 samples in total
        view_list = list()  # 5485 entries cycling through the 11 views: '000', '018', '036', ..., '180'
        seq_type_list = list()  # 5485 entries: 'bg-01', 'bg-02', 'cl-01', 'cl-02', 'nm-01', ..., 'nm-06'
        label_list = list()  # 5485 subject IDs: '075', '076', ...

        for i, x in enumerate(data_loader):
            seq, view, seq_type, label, batch_frame = x
            for j in range(len(seq)):
                seq[j] = self.np2ts(seq[j]).float()
            if batch_frame is not None:
                batch_frame = self.np2ts(batch_frame).int()
                print(batch_frame.size())  # number of frames in this batch
                print(batch_frame)  # torch.Size([1, 1])
            # print(batch_frame, np.sum(batch_frame))

            feature, _ = self.encoder(*seq, batch_frame)  # feature: (1, 62, 256); the test batch holds a single subject
            n, num_bin, _ = feature.size()
            feature_list.append(feature.view(n, -1).data.cpu().numpy())  # feature_list holds one feature vector per test sample; 1 x 256 x 31 x 2 corresponds to the test branch at the bottom of the network figure in the paper
            view_list += view
            seq_type_list += seq_type
            label_list += label

        # print(np.concatenate(feature_list, 0).shape)  (5485, 15872)
        return np.concatenate(feature_list, 0), view_list, seq_type_list, label_list  # concatenate along dim 0

This is the test routine. The DataLoader differs from the one used in training in its sampler and collate_fn.

Unlike training (where self.sample_type is 'random'), here self.sample_type = 'all', so every frame of each sequence is sampled (full sampling).
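Before moving on to collate_fn, here is a minimal sketch (random data, no real checkpoint) of the flattening done in transform: each sequence produces num_bin = 62 bins (31 horizontal scales x 2 branches) of 256-dim features, and view(n, -1) flattens them into a single 62 * 256 = 15872-dim vector per sample.

import torch
feature = torch.randn(1, 62, 256)  # (n, num_bin, hidden_dim), matching the shapes noted above
print(feature.view(1, -1).shape)   # torch.Size([1, 15872])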


def collate_fn(self, batch):
        # batch is a list; during training it has length 128 and each element has 5 fields
        # (frames of 64x44 silhouettes, frame indices 0..num_frames, view, seq_type such as 'bg-02', id).
        # It runs when iterating the DataLoader, assembling the batch into the prescribed format.
        # At test time, len(batch) = 1.
        """
        This function customizes how the DataLoader assembles samples.
        Only data changes: a sample's data originally holds many silhouettes; select_frame draws
        30 frames with replacement, and the results are then formed into a batch.
        :param batch: [30-frame data tensor, view, seq_type, label, None], all aligned by index
        :return:
        """
        # print("batch", len(batch))
        batch_size = len(batch)
        """
        data = [self.__loader__(_path) for _path in self.seq_dir[index]]
        feature_num is the number of feature sets in data; it is always 1 here, because what is read is
            _seq_dir = osp.join(seq_type_path, _view)
            seqs = os.listdir(_seq_dir)  # list all silhouettes
        """
        feature_num = len(batch[0][0])
        # print(batch[0][0])
        # print(batch[0][1])
        # print(batch[0][2])
        # print(batch[0][3])
        # print(batch[0][4])
        seqs = [batch[i][0] for i in range(batch_size)]  # corresponds to data
        # print(len(seqs))
        frame_sets = [batch[i][1] for i in range(batch_size)]  # corresponds to frame_set
        # print(frame_sets)
        # print("____________________")
        view = [batch[i][2] for i in range(batch_size)]  # corresponds to self.view[index]
        seq_type = [batch[i][3] for i in range(batch_size)]  # corresponds to self.seq_type[index]
        label = [batch[i][4] for i in range(batch_size)]  # corresponds to self.label[index]; these lines pull batch_size-many seqs, frame_sets, views, seq_types, and labels out of batch
        batch = [seqs, view, seq_type, label, None]  # batch is reorganized; its format now differs from the one originally passed in
        '''
        A sample here consists of data, frame_set, self.view[index], self.seq_type[index], self.label[index].
        '''

        def select_frame(index):
            sample = seqs[index]
            frame_set = frame_sets[index]
            if self.sample_type == 'random':
                # random.choices samples with replacement; k is the number of draws, and frame_num = 30 here
                frame_id_list = random.choices(frame_set, k=self.frame_num)  # draw 30 frames from all available frames to form a list
                _ = [feature.loc[frame_id_list].values for feature in sample]  # _: (30, 64, 44); .loc indexes by label, .iloc indexes by row number
            else:
                _ = [feature.values for feature in sample]
                # c = np.array(_)
                # print(c.shape)
            return _

        # print(len(seqs))
        seqs = list(map(select_frame, range(len(seqs))))  # in training: a list of length 128, each item (30, 64, 44); at test time: a list of length 1, each item (num_frames, 64, 44). map applies its first argument to every item of the second.


        if self.sample_type == 'random':
            seqs = [np.asarray([seqs[i][j] for i in range(batch_size)]) for j in range(feature_num)]  # 30 frames are drawn per sample, so each sample is one set; feature_num = 1. asarray and array both convert to ndarray.
        else:  # with full sampling the sequences no longer all have 30 frames, so padding is needed. batch_frames only matters with full sampling on multiple GPUs; otherwise it is rarely used.
            gpu_num = min(torch.cuda.device_count(), batch_size)
            batch_per_gpu = math.ceil(batch_size / gpu_num)  # math.ceil rounds up to the next integer
            batch_frames = [[
                                len(frame_sets[i])
                                for i in range(batch_per_gpu * _, batch_per_gpu * (_ + 1))
                                if i < batch_size
                                ] for _ in range(gpu_num)]
            # print(batch_frames)
            if len(batch_frames[-1]) != batch_per_gpu:
                for _ in range(batch_per_gpu - len(batch_frames[-1])):
                    batch_frames[-1].append(0)  # pad the last batch_frames entry with zeros
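The excerpt above stops before collate_fn returns. For completeness: in the upstream GaitSet repository, the 'all' branch goes on to concatenate the frames assigned to each GPU, zero-pad them to a common length, and store batch_frames in the batch before returning it. A sketch of that remainder (reconstructed from the original repo, so treat the details as approximate):

            max_sum_frame = np.max([np.sum(batch_frames[_]) for _ in range(gpu_num)])
            # concatenate every sequence assigned to one GPU along the frame axis
            seqs = [[np.concatenate([seqs[i][j]
                                     for i in range(batch_per_gpu * _, batch_per_gpu * (_ + 1))
                                     if i < batch_size], 0)
                     for _ in range(gpu_num)] for j in range(feature_num)]
            # zero-pad each GPU's frame stack to max_sum_frame frames
            seqs = [np.asarray([np.pad(seqs[j][_],
                                       ((0, max_sum_frame - seqs[j][_].shape[0]), (0, 0), (0, 0)),
                                       'constant', constant_values=0)
                                for _ in range(gpu_num)]) for j in range(feature_num)]
            batch[4] = np.asarray(batch_frames)

        batch[0] = seqs
        return batch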

Next, the test samples are fed through the network.

The resulting feature matrix has shape (5485, 15872). Back in test.py, step forward to the evaluation call and follow it into evaluator.py.


import torch
import torch.nn.functional as F
import numpy as np


def cuda_dist(x, y):
    # compute the distance between every sample in x and every sample in y
    x = torch.from_numpy(x).cuda()
    y = torch.from_numpy(y).cuda()
    dist = torch.sum(x ** 2, 1).unsqueeze(1) + torch.sum(y ** 2, 1).unsqueeze(
        1).transpose(0, 1) - 2 * torch.matmul(x, y.transpose(0, 1))
    dist = torch.sqrt(F.relu(dist))
    return dist


def evaluation(data, config):
    # data: np.concatenate(feature_list, 0), view_list, seq_type_list, label_list
    dataset = config['dataset'].split('-')[0]  # split 'CASIA-B' on '-'; dataset: 'CASIA'
    feature, view, seq_type, label = data
    label = np.array(label)
    view_list = list(set(view))
    view_list.sort()  # view_list: ['000', '018', '036', '054', '072', '090', '108', '126', '144', '162', '180']
    view_num = len(view_list)  # view_num = 11
    sample_num = len(feature)  # sample_num = 5485

    probe_seq_dict = {'CASIA': [['nm-05', 'nm-06'], ['bg-01', 'bg-02'], ['cl-01', 'cl-02']],
                      'OUMVLP': [['00']]}
    gallery_seq_dict = {'CASIA': [['nm-01', 'nm-02', 'nm-03', 'nm-04']],
                        'OUMVLP': [['01']]}

    num_rank = 5
    # The loop below computes the accuracy of probes under probe_view against the gallery
    # under gallery_view, for each probe_seq with its corresponding gallery_seq;
    # probe_seq covers the three walking conditions.
    #               num conditions                 num views  num views  top-5
    acc = np.zeros([len(probe_seq_dict[dataset]), view_num, view_num, num_rank])  # (3, 11, 11, 5)
    for (p, probe_seq) in enumerate(probe_seq_dict[dataset]):  # probe sets
        for gallery_seq in gallery_seq_dict[dataset]:  # gallery sets
            for (v1, probe_view) in enumerate(view_list):  # probe views
                for (v2, gallery_view) in enumerate(view_list):  # gallery views
                    # select seq types (nm-01, nm-02, ...) that are in gallery_seq and under the
                    # current gallery_view, since accuracy is computed per view;
                    # gallery_seq and probe_seq are both lists
                    gseq_mask = np.isin(seq_type, gallery_seq) & np.isin(view, [gallery_view])
                    gallery_x = feature[gseq_mask, :]  # select the matching gallery features
                    gallery_y = label[gseq_mask]  # select the matching gallery labels

                    # likewise, select the corresponding probe features and labels
                    pseq_mask = np.isin(seq_type, probe_seq) & np.isin(view, [probe_view])
                    probe_x = feature[pseq_mask, :]
                    probe_y = label[pseq_mask]

                    dist = cuda_dist(probe_x, gallery_x)
                    idx = dist.sort(1)[1].cpu().numpy()  # sort each probe's distances to the gallery; this returns indices into the original array
                    acc[p, v1, v2, :] = np.round(  # this computes the top-(num_rank) accuracy
                        # acc[p, v1, v2, 0] holds the top-1 accuracy and acc[p, v1, v2, num_rank-1]
                        # the top-5 accuracy (since num_rank = 5 here);
                        # gallery_y[idx[:, 0:num_rank]] takes the labels of the num_rank nearest gallery samples;
                        # np.cumsum accumulates the hits, yielding the top-1, top-2, ..., top-num_rank accuracies
                        np.sum(np.cumsum(np.reshape(probe_y, [-1, 1]) == gallery_y[idx[:, 0:num_rank]], 1) > 0,
                               0) * 100 / dist.shape[0], 2)

    return acc
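The np.cumsum line is dense, so here is a tiny hypothetical walk-through with 2 probes and the labels of their 5 nearest gallery samples:

probe_y = np.array(['001', '002'])
top5 = np.array([['003', '001', '001', '004', '005'],   # probe '001' is first matched at rank 2
                 ['002', '002', '006', '007', '008']])  # probe '002' is matched at rank 1
hits = np.reshape(probe_y, [-1, 1]) == top5
print(np.cumsum(hits, 1) > 0)                       # True from the rank where the correct id first appears
print(np.sum(np.cumsum(hits, 1) > 0, 0) * 100 / 2)  # [ 50. 100. 100. 100. 100.] = rank-1..rank-5 accuracy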

Finally, acc is returned to test.py, which prints the rank-1 accuracy including identical-view cases and the rank-1 accuracy excluding them.
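Incidentally, cuda_dist implements the standard expansion sqrt(||x||^2 + ||y||^2 - 2x.y); it can be sanity-checked against torch.cdist (a minimal sketch, assuming a CUDA device is available):

a = np.random.rand(4, 8).astype(np.float32)
b = np.random.rand(6, 8).astype(np.float32)
print(torch.allclose(cuda_dist(a, b),
                     torch.cdist(torch.from_numpy(a).cuda(), torch.from_numpy(b).cuda()),
                     atol=1e-4))  # True, up to floating-point error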

That concludes the walkthrough of the testing code.

If this series helped you, please like, favorite, and share. Thank you!
