行人重识别的代码复现

参考:https://github.com/layumi/Person_reID_baseline_pytorch/tree/master/tutorial

1、环境的安装

系统的基础环境:

  • ubantu16.04
  • CUDA9.0+Cudnn7.4.2
  • Python3.7.4
  • Anaconda 3

创建虚拟环境

conda create -n re_id python=3.7.4
source activate re_id

安装Pytorch

根据CUDA的版本来安装:
https://pytorch.org/get-started/previous-versions/

conda install pytorch==1.0.1 torchvision==0.2.2 cudatoolkit=9.0 -c pytorch

安装 yacs

git clone https://github.com/rbgirshick/yacs
cd yacs
python setup.py install

安装其他依赖库

pip install pretrainedmodels
conda install matplotlib
conda install future
pip install torchvision
pip install tensorboardX
pip install tensorflow -i https://pypi.mirrors.ustc.edu.cn/simple
conda install scipy
conda install Cython

2、开始

数据集和代码的准备

数据集:Market-1501
代码:Practical-Baseline

2.1训练

2.1.1:数据的准备(python prepare.py)

数据集分布如下:

├── Market/
│   ├── bounding_box_test/          /* Files for testing (candidate images pool)
│   ├── bounding_box_train/         /* Files for training 
│   ├── gt_bbox/                    /* Files for multiple query testing 
│   ├── gt_query/                   /* We do not use it 
│   ├── query/                      /* Files for testing (query images)
│   ├── readme.txt

 1. "bounding_box_test" – 19732张图片,测试集,也是所谓的gallery参考图像集;
 2. "bounding_box_train" – 12936张图片,训练集;
 3. "query" – 3368张query图片,即要查询的图片,在 "bounding_box_test"中执行搜索;
 4. "gt_bbox" – 25259张图片(人工标注),对应test和train数据集中1501个个体,用于区分"good"、“junk"和"distractors”;
 5. "gt_query" – 对于3368张query图片的每一个,都有"good"和"junk"相关的图像(包含相同个体),这个文件夹包含了"good"和"junk"图像的索引,用在性能评估中。

打开代码prepare.py。 将第五行的地址改为你本地的地址,比如 \home\zzd\Download\Market,然后在终端中运行代码。
记得所有操作都在刚刚创建的虚拟环境下

python prepare.py

运行后文件的改变如下:

├── Market/
│   ├── bounding_box_test/          /* Files for testing (candidate images pool)
│   ├── bounding_box_train/         /* Files for training 
│   ├── gt_bbox/                    /* Files for multiple query testing 
│   ├── gt_query/                   /* We do not use it
│   ├── query/                      /* Files for testing (query images)
│   ├── readme.txt
│   ├── pytorch/
│       ├── train/                   /* train 
│           ├── 0002
|           ├── 0007
|           ...
│       ├── val/                     /* val
│       ├── train_all/               /* train+val      
│       ├── query/                   /* query files  
│       ├── gallery/                 /* gallery files

2.1.2:搭建神经网络模型(model.py)

我们可以使用预先训练好的网络结构,例如“ AlexNet”,“ VGG16”,“ ResNet”和“ DenseNet”。 通常,经过预训练的网络结构有助于保留更好的性能,因为它保留了ImageNet的优点[1].

在 pytorch中, 两行代码就可以导入模型:

from torchvision import models
model = models.resnet50(pretrained=True)

但是我们需要稍微调整一下网络结构。 Market-1501中有751个类别(不同的人),与ImageNet中的1,000个类别所不同。 因此,在这里我们修正模型以使用分类器。

import torch
import torch.nn as nn
from torchvision import models

# Define the ResNet50-based Model
class ft_net(nn.Module):
    def __init__(self, class_num = 751):
        super(ft_net, self).__init__()
        #load the model
        model_ft = models.resnet50(pretrained=True) 
        # change avg pooling to global pooling
        model_ft.avgpool = nn.AdaptiveAvgPool2d((1,1))
        self.model = model_ft
        self.classifier = ClassBlock(2048, class_num) #define our classifier.

    def forward(self, x):
        x = self.model.conv1(x)
        x = self.model.bn1(x)
        x = self.model.relu(x)
        x = self.model.maxpool(x)
        x = self.model.layer1(x)
        x = self.model.layer2(x)
        x = self.model.layer3(x)
        x = self.model.layer4(x)
        x = self.model.avgpool(x)
        x = torch.squeeze(x)
        x = self.classifier(x) #use our classifier.
        return x

为什么我们使用AdaptiveAvgPool2d? AvgPool2d和AdaptiveAvgPool2d有什么区别? 该模型现在有参数吗? 如何在新的网络层中初始化参数?

仔细看看model.py吧。
这里我们不需要修改model.py,已经修改好了

2.1.3:开始训练(train.py)

  • 训练方法 【ResNet-50】
python train.py --gpu_ids 0 --name ft_ResNet50 --train_all --batchsize 32  --data_dir /home/huan/deep_learning/ReID/Person_reID_baseline_pytorch/Market/pytorch/
  • 训练方法 【ResNet-50(alltricks)
python train.py --warm_epoch 5 --stride 1 --erasing_p 0.5 --batchsize 8 --lr 0.02 --name warm5_s1_b8_lr2_p0.5 --gpu_ids 0  --data_dir /home/huan/deep_learning/ReID/Person_reID_baseline_pytorch/Market/pytorch/
--gpu_ids 运行的gpu型号

--name 模型名字

--data_dir 训练数据路径

--train_all 所有用来训练的图像.

--batchsize batch大小

--erasing_p 随机删除参数.

2.1.4:开始测试(test.py)

  • 测试Market-1501数据集
python test.py --gpu_ids 0 --name ft_ResNet50 --test_dir /home/huan/deep_learning/ReID/Person_reID_baseline_pytorch/Market/pytorch/  --batchsize 32 --which_epoch 19
  • 测试自己的数据集,并生成json格式的文件。

你可能感兴趣的:(Pytorch,行人重识别)