阅读该博客 https://blog.csdn.net/weixin_41424027/article/details/87896768#convet_caffe_pretrain_318
对一些点进行补充记录,便于自己记忆
自定义数据读取
init:需要读入路径,把文件名都读入一个self.ids中
len:返回一个文件总数
getitem:给一个idx,能够返回img,bbox,label 最后需要用np.stack堆成np形式
import os
import xml.etree.ElementTree as ET
import numpy as np
from .util import read_image
VOC_CLASSES=('aeroplane',
'bicycle',
'bird',
'boat',
'bottle',
'bus',
'car',
'cat',
'chair',
'cow',
'diningtable',
'dog',
'horse',
'motorbike',
'person',
'pottedplant',
'sheep',
'sofa',
'train',
'tvmonitor')
class VOCBboxDataset:
def __init__(self,data_dir,split='trainval',use_difficult=False,return_difficult=False):
id_list_file=os.path.join(data_dir,'ImageSets/Main/{0}.txt'.format(split))
self.ids=[ id_.strip() for id_ in open(id_list_file)]
self.data_dir=data_dir
self.use_difficult=use_difficult
self.return_difficult=return_difficult
self.label_names=VOC_CLASSES
def __len__(self):
return len(self.ids)
def __getitem__(self, i):
id_=self.ids[i]
anno= ET.parse(os.path.join(self.data_dir,'Annotations',id_+'.xml'))
bbox=list()
label=list()
difficult=list()
for obj in anno.findall('object'):
if not self.use_difficult and int(obj.find('difficult').text)==1:
continue
difficult.append(int(obj.find('difficult').text))
bndbox_anno=obj.find('bndbox')
bbox.append([int(bndbox_anno.find(tag).text-1) for tag in ('ymin','xmin','ymax','ymax')])
name=obj.find('name').text.lower().strip()
label.append(VOC_CLASSES.index(name))
bbox=np.stack(bbox).astype(np.float32)#trans the box from list to np.float32
label=np.stack(label).astype(np.int32)
difficult=np.array(difficult,dtype=np.bool).astype(np.uint8)
img_file=os.path.join(self.data_dir,'JPEGImages',id_+'.jpg')
img=read_image(img_file,color=True)
return img,bbox,label,difficult
把之前自定义的data类型包进dataset.py中
init :读入config的参数,初始化 自定义数据,初始化trans
getiitem:读入img,bbox,label后,用transform处理,然后返回预处理后的数据的copy
len:返回数目总条数
1.训练预处理
(1)图像先归一化到0-1
(2)比例转化 比如到600,1000
(3)标准化,减去均值除以标准差
(4)若图片需要水平翻转等数据增强操作,在这时候添加
其中比例转化和数据增强操作也需要对gt_bbox进行操作,保持图片和gt_box能够对应上
from __future__ import absolute_import
#绝对引入主要是针对python2.4及之前的版本的,这些版本在引入某一个.py文件时,
# 会首先从当前目录下查找是否有该文件。如果有,则优先引用当前包内的文件。而如果我们想引用python自带的.py文件时,则需要使用,
from __future__ import division
import torch as t
from data.voc_dataset import VOCBboxDataset
from skimage import transform as sktsf
from torchvision import transforms as tvtsf
from data import util
import numpy as np
from utils._config import opt
def inverse_normalize(img):# for the vis
if opt.caffe_pretrain:
img=img+(np.array([122.7717,115.9465,102.9801]).reshape(3,1,1))#add mean
#caffe has no std reshape is for the broadcast
return img[::-1,:,:]#caffe is [BGR,H,W] need to be [RGB,H,W]
return (img*0.225+0.45).clip(min=0,max=1)*255#pytorch pretrain img from 0-1 add mean and multiply std need to mul 255
def pytorch_normalize(img):
normalize=tvtsf.Normalize(mean=[0.485,0.456,0.406],std=[0.229,0.224,0.225])#pytorch method
img=normalize(t.from_numpy(img))# 因为normalize方法只接受tensor对象,将img转化为tensor传入
return img.numpy()#将标准化后的img再从tensor转化诶numpy
def caffe_normalize(img):
img=img[::-1,:,:]#RGB2BGR 因为如果使用caffe_pretrain ,那么整个模型参数都是基于caffe训练的,需要先用bgr图篇进行训练,万了之后再inverse_nomalize还原
img=img*255
mean=np.array([122.7717,115.9465,102.9801]).reshape(3,1,1)
img=(img-mean).astype(np.float32,copy=True)#返回一个float32类型的 img矩阵的副本
return img
def preprocess(img,min_size=600,max_size=1000):#输入原始的img矩阵,返回取值0-1的,经过resize的,标准化后的img numpy矩阵 min_size就是输出图片的短边最长为600
#长边最长为1000
C,H,W=img.shape
scale1=min_size/min(H,W)
scale2=max_size/max(H,W)
scale=min(scale1,scale2)#比较两个scale,看哪个才是主要影响所方因子,就是为防止 长边可能超过1000或者短边可能超过600
img=img/255#先转化为0-1
img=sktsf.resize(img,(C,H*scale,W*scale),mode='reflect')#resize来自skimage的transform
if opt.caffe_pretrain:
normalize=caffe_normalize
else:
normalize=pytorch_normalize
return normalize(img)
#一张图可能有R个box和label box shape为(R,4) label shape 为(R,)
# 接受魔术方法get_example传来的一张图片的原始 img box label
# 返回resize和normalize后的img 对应处理后的box 以及label(没有处理)
class Transform():
def __init__(self,min_size,max_size):
self.min_size=min_size
self.max_size=max_size
def __call__(self, in_data):#使得类的实例也能像函数一样
img,bbox,label=in_data
_,H,W=img.shape
img=preprocess(img,self.min_size,self.max_size)
_,o_H,o_W=img.shape
scale=o_H/H
bbox=util.resize_bbox(bbox,(H,W),(o_H,o_W))
#水平翻转
img,params=util.random_flip(img,x_random=True,return_param=True)
bbox=util.flip_bbox(bbox,(o_H,o_W),x_flip=params['x_flip'])#根据img水平翻转情况,对bbox也进行翻转
return img,bbox,label,scale
class Dataset:
# 取训练数据最大的类
# 如果你读过pytorch源码
# 你会发现其实并不用继承dataset类
# 因为那个类是空
# 的
# 只实现了两个pass空方法
# getitem和len两个魔术方法
# 所以我们只要实现这两个方法就不用继承就可以传入DataLoader
def __init__(self,opt):#opt是传进来的参数,来自utils.config 包含了voc_data 的路径
self.opt=opt
self.db=VOCBboxDataset(opt.voc_data_dir)
self.tsf=Transform(opt.min_size,opt.max_size)
def __getitem__(self, idx):
ori_img,bbox,label,difficult=self.db.__getitem__(idx)
img,bbox,label,scale=self.tsf((ori_img,bbox,label))
return img.copy(),bbox.copy(),label.copy(),scale
def __len__(self):
return len(self.db)
class TestDataset:
def __init__(self,opt,split='test',use_difficult=True):
self.opt=opt
self.db=VOCBboxDataset(opt.voc_data_dir,split=split,use_difficult=use_difficult)
def __getitem__(self, idx):
ori_img,bbox,label,difficult=self.db.__getitem__(idx)
img=preprocess(ori_img)
return img,ori_img.shape[1:],bbox,label,difficult#返回原图的HW 去掉了C
def __len__(self):
return len(self.db)
import numpy as np
from PIL import Image
import random
def read_image(path,dtype=np.float32,color=True):
f=Image.open(path)
try:
if color:
img=f.convert('RGB')
else:
img=f.convert('P')#gray
img=np.asarray(img,dtype=dtype)#trans to np.float32 array
finally:
if hasattr(f,'close'):
f.close()
if img.ndim==2:#gray
return img[np.newaxis]#add a new axis
else:
return img.transpose((2,0,1))# HWC 2 CHW
def resize_bbox(bbox,in_size,out_size):
bbox=bbox.copy()
y_scale=out_size[0]/in_size[0]
x_scale=out_size[1]/in_size[1]
bbox[:,0]=bbox[:,0]*y_scale
bbox[:,1]=bbox[:,1]*x_scale
bbox[:,2]=bbox[:,2]*y_scale
bbox[:,3]=bbox[:,3]*y_scale
return bbox
def random_flip(img,y_random=False,x_random=True,return_param=True,copy=False):
#
# img: 图片矩阵
# y_random: 是否使用垂直随机翻
# return_param:是否返回翻转状态
# 一个dict很好懂
# copy: 是否返回img的副本
y_flip,x_flip=False,False
if y_random:
y_flip=random.choice([True,False])#随即选取是否翻转
if x_random:
x_flip=random.choice([True,False])
if y_flip:
img=img[:,::-1,:]#图片翻转,CHW H翻转
if x_flip:
img=img[:,:,::-1]
if copy:
img=img.copy()
if return_param:
# 因为我们这里只翻转了图片
# 保留dict参数是为了翻转box时使用
# 如果img水平翻转了
# 那么x_flip = True
# 我们记录这个参数
# 以后也应当水平翻转这张图片的所有box
# R个box
return img,{'y_flip':y_flip,'x_flip':x_flip}
else:
return img
def flip_bbox(bbox,size,y_flip=False,x_flip=False):
H,W=size
bbox=bbox.copy()
if y_flip:
y_max=H-bbox[:,0]#H-ymin
y_min=H-bbox[:,2]
bbox[:,0]=y_min
bbox[:,2]=y_max
if x_flip:
x_max=W-bbox[:,1]
x_min=W-bbox[:,3]
bbox[:,1]=x_max
bbox[:,3]=x_min
return bbox
from pprint import pprint#打印出来更美观
class Config:
#data
voc_data_dir='/home/wrc/yuyijie/KITTI/VOCdevkit/VOC2007'
min_size=600
max_size=1000
num_works=8
test_num_works=8
rpn_sigma=3.
roi_sigma=1.
# for optimizer
wd=0.0005
lr_decay=0.1
lr=1e-3
#vis
env='faster-rcnn'
port=8097#visdom 端口
plot_every=40
#preset
data='voc'
pretrained_model='vgg16'
epoch=14
use_adam=False
use_chainer=False
use_drop=False
#debug
debug_file='/tmp/debugf'
test_num=10000
#model
load_path=None
caffe_pretrain=False
caffe_pretrain_path='checkpoints/vgg16_caffe.pth'
def _parse(self,kwargs):#解析并设置用户设定的参数
state_dict=self._state_dict()#读取Config类所有参数dict{para_name:para_value}
for k,v in kwargs.items():#遍历用户传来的dict
if k not in state_dict:
raise ValueError('Unknow option:"--%s"'%k)
setattr(self,k,v)#设置参数
print('=============user config=========')
pprint(self._state_dict())#打印参数
print('=============end=========')
def _state_dict(self):
return {k:getattr(self,k) for k ,_ in Config.__dict__.items() if not k.startswith('_')}
#字典解析,字典解析,Config.__dict__.items() 取出类中所有的函数、全局变量以及一些内置的属性
# 前面我们设定的都是全局变量(键值对:比如min_size = 600),没有函数,而系统内置属性都是_打头的,
# 所以我们要not k.startswith('_') 返回结果dict{para_name0:para_value0,para_name1:para_value1,....}
opt=Config()#创建config对象
trans the box from list to np.float32
想要转变np数据类型可以直接用astype
[BGR,H,W] need to be [RGB,H,W]
t.from_numpy(img)
img.numpy()
img[np.newaxis]#add a new axis
比如(100,2)新增之后会变成(1,100,2)
img.transpose((2,0,1))# HWC 2 CHW
random.choice([True,False])#随即选取是否翻转