Caffe实现多标签图像分类(2)——基于Python接口实现多标签图像分类(自己的数据集)

1.前言

上一篇博客介绍了一个Caffe自带的例子,使用Python接口实现多标签数据的输入,通过SigmoidCrossEntropy Loss函数实现Caffe多标签分类。实现多标签分类的关键是实现多标签图像数据的录入,这篇博客介绍如何为自己的数据集定制一个基于Python的多标签图像数据输入的data层。

2.数据集

本文选用多标签数据集Corel5k用于实验,平均每幅图像含有3.5个标注词。

3.为数据集定制用于输入的data层

由上一篇博客可以看出,实现多标签图像数据输入的两个文件是pascal_multilabel_datalayers.py(实现数据输入)、tools.py(实现图像数据的处理)。因此,我们需要为目标数据集定制两个类似的文件用于多标签图像数据的输入。

3.1Corel5k_multilabel_datalayers.py(用于多标签图像数据的输入)

多标签图像数据的输入需要两个txt文件(train.txt与val.txt),文件中存储了图像的文件名与标签,标签由1和0组成(含有这个标签则为1,不含有这个标签则为0)。使用Python读取的时候可以使用空格为分隔符,依次读取出图像的文件名与标签。

Caffe实现多标签图像分类(2)——基于Python接口实现多标签图像分类(自己的数据集)_第1张图片

import scipy.misc
import caffe

import numpy as np
import os.path as osp

from random import shuffle
from PIL import Image

from tools import SimpleTransformer


class Corel5kMultilabelDataLayerSync(caffe.Layer):


    def setup(self, bottom, top):
        
        #定义data层的输入名称(data和label)
        self.top_names = ['data', 'label']

        #读取Python data层的变量名
        params = eval(self.param_str)

        #判断params是否合法
        check_params(params)

        # 获取每次训练需要读取的batch_size
        self.batch_size = params['batch_size']

        # 用于读取每个批次的图像数据信息
        self.batch_loader = BatchLoader(params, None)

        #reshape输入的图像数据
        top[0].reshape(self.batch_size, 3, params['im_shape'][0], params['im_shape'][1])
        
        #reshape标签信息
        top[1].reshape(self.batch_size, 260)
        
        #输出params信息
        print_info("Corel5kMultilabelDataLayerSync", params)
    
    # 前向函数
    def forward(self, bottom, top):
        # 读取图像数据与标签
        for itt in range(self.batch_size):
            #读取并保存每张图片的图像内容信息和标签信息
            im, multilabel = self.batch_loader.load_next_image()

            # 上面定义了top[0]存取的是label、top[1]是标签
            top[0].data[itt, ...] = im
            top[1].data[itt, ...] = multilabel

    # data层不需要Reshape,因为输入是固定的
    def reshape(self, bottom, top):
        pass

    # 定义反向传播函数(data层不需要反向传播)
    def backward(self, top, propagate_down, bottom):
        pass


class BatchLoader(object):


    def __init__(self, params, result):
        self.result = result
        # 获取batch_size大小
        self.batch_size = params['batch_size']
        # 获取root文件夹,里面有train.tx与val.txt
        self.Corel5k_root = params['Corel5k_root']
        # 输入图像的大小
        self.im_shape = params['im_shape']
       
        # 判断是训练阶段还是测试阶段
        # self.isshuffle定义是否打乱图像的顺序
        # self.iscenter 定义是否在中心裁剪图像数据
        # self.isflip 定义图像是否翻转
        if params['split'] == 'train':
            self.isshuffle = True
            self.iscenter = False
            self.isflip = True
        else:
            self.isshuffle = False
            self.iscenter = True
            self.isflip = False
        
        #图片的存储地址
        self.floder = r'E:\MATLAB_space\DataSet_Image\Corel5k'
        #保存着数据信息的txt文件的位置
        list_file = params['split'] + '.txt'
        
        #获取每张图片的图片名
        self.indexlist = [line.rstrip('\n') for line in open(
            osp.join(self.Corel5k_root, list_file))]
        
        # 用于读取第一幅图像
        self._cur = 0  
        # 定义是否中心裁剪出用于输入的图像
        self.transformer = SimpleTransformer(center=self.iscenter)
        
        #开始的时候打乱训练数据
        if self.isshuffle:
            shuffle(self.indexlist)

        print "BatchLoader initialized with {} images".format(
            len(self.indexlist))
   
    # 用于读取图像数据与信息
    def load_next_image(self):
        # 判断是否跑完一个epoch(即整个训练集都参加了一遍训练)
        if self._cur == len(self.indexlist):
            self._cur = 0
            #每次读取一遍训练数据都打乱训练数据的列表
            if self.isshuffle:
                shuffle(self.indexlist)

        #获取图片的信息
        rowdata = self.indexlist[self._cur]
        rowdata = rowdata.split()
        image_file_name = rowdata[0]
        im = np.asarray(Image.open(osp.join( self.floder, image_file_name)))
        
        #resize图像的大小
        im = scipy.misc.imresize(im,[256,256]) 
        
        #用于进行图像翻转
        if self.isflip:
            flip = np.random.choice(2)*2-1
            im = im[:, ::flip, :]

        # 初始化用于存储图像标签数据的矩阵
        multilabel = np.zeros(260).astype(np.float64)
        
        #获取图片的标签信息
        for j in range(260):
            multilabel[j] = np.float64(rowdata[j+1])
        self._cur += 1
        #返回图像信息,和标签信息
        return self.transformer.preprocess(im), multilabel

#判断params是否合法
def check_params(params):
   
    assert 'split' in params.keys(
    ), 'Params must include split (train, val, or test).'

    required = ['batch_size', 'Corel5k_root', 'im_shape']
    for r in required:
        assert r in params.keys(), 'Params must include {}'.format(r)

#输出params信息
def print_info(name, params):
  
    print "{} initialized for split: {}, with bs: {}, im_shape: {}.".format(
        name,
        params['split'],
        params['batch_size'],
        params['im_shape'])

3.2tools.py(实现图像数据的预处理)

import numpy as np


class SimpleTransformer:


    def __init__(self, center=False):
        #导入mean文件
        #导入数据集的均值文件(256*256)
        mean = np.load(r'D:\caffe-master\data\Corel5k_multilabel\mean.npy')
        #导入中心均值文件(227*227)
        center_mean = np.load(r'D:\caffe-master\data\Corel5k_multilabel\center_mean.npy')
        #将均值文件由3*256*256转换为256*256*3
        self.mean = mean.transpose((1,2,0))
        self.center_mean = center_mean.transpose((1,2,0))
        #缩放系数
        self.scale = 1.0
        #是否中心裁剪
        self.center = center
  
    # 初始中心文件
    def set_mean(self, mean):
    
        self.mean = mean
    
    # 初始缩放系数
    def set_scale(self, scale):
      
        self.scale = scale
    
    #新加的一个crop函数
    #对图片进行crop
    #当中心裁剪时,在256*256的图像中心,裁剪出227*227的图像
    #当随机裁剪时,在256*256的图像中,随机裁剪出227*227的图像
    def crop(self, im, cropx=227, cropy=227):
        y,x,_ = im.shape
        
        if self.center:
            startx = x//2 - (cropx//2)
            starty = y//2 - (cropy//2)
        else:
            startx, starty = np.random.randint(0,29), np.random.randint(0,29)
        return im[starty:starty+cropy,startx:startx+cropx,:]

    # 图像数据预处理
    def preprocess(self, im):
      
        im = np.float64(im)
        #将图像数据转换为BGR通道格式
        im = im[:, :, ::-1] 
        #减图像的均值文件,进行归一化
        im -= self.mean
        
        #crop后的图像数据
        im = self.crop(im)
        
        im *= self.scale
        #对im进行转置
        im = im.transpose((2, 0, 1))

        return im
    
    #还原图像
    def deprocess(self, im):

        #将图像由3*256*256转换为256*256*3
        im = im.transpose(1, 2, 0)
        im /= self.scale
        
        #还原图像
        im +=self.center_mean
        #将图像转换为RGB格式
        im = im[:, :, ::-1]  

        return np.uint8(im)

这里需要用到均值文件,一个简单的方法使用LMDB文件或LevelDB文件所生成的均值文件来转换。转换的代码如下所示:

#将binaryproto文件转换为npy文件
import caffe
import numpy as np

def convert():
    # 均值文件路径
    binaryproto=r'D:\caffe-master\data\Corel5k\Corel5k_mean.binaryproto'
    blob = caffe.proto.caffe_pb2.BlobProto()
    data = open(binaryproto, 'rb' ).read()
    blob.ParseFromString(data)
    arr = np.array(caffe.io.blobproto_to_array(blob))
    out = arr[0]
    # npy均值文件存储路径
    savepath=r'D:/caffe-master/data/Corel5k/mean'
    np.save(savepath, out)
    
    #生成center_mean
    startx = 256//2 - (224//2) 
    starty = 256//2 - (224//2)
    cropIm = out[:,startx:startx+224,starty:starty+224]
    cropImPath = r'D:/caffe-master/data/Corel5k/center_mean'
    np.save(cropImPath, cropIm)

if __name__=='__main__':
    convert()

3.3运行文件

import os
 
import numpy as np
import os.path as osp
import matplotlib.pyplot as plt

from copy import copy

plt.rcParams['figure.figsize'] = (6, 6)

import caffe

import tools
transformer = tools.SimpleTransformer()

classes = np.asarray(['city', 'mountain', 'sky', 'sun', 'water', 'clouds', 'tree', 'lake', 'sea', 'beach', 'boats', 'people', 'branch', 'leaf', 'grass', 'palm', 'horizon', 'hills', 'waves', 'birds', 'land', 'bridge', 'ships', 'buildings', 'fence', 'island', 'peaks', 'jet', 'plane', 'runway', 'basket', 'flight', 'flag', 'prop', 'f-16', 'tails', 'smoke', 'formation', 'bear', 'polar', 'snow', 'tundra', 'ice', 'head', 'black', 'reflection', 'ground', 'forest', 'river', 'field', 'flowers', 'meadow', 'rocks', 'hillside', 'shrubs', 'close-up', 'grizzly', 'cubs', 'log', 'hut', 'sunset', 'display', 'plants', 'pool', 'coral', 'fan', 'anemone', 'fish', 'ocean', 'sunrise', 'face', 'sand', 'farms', 'reefs', 'vegetation', 'house', 'village', 'path', 'wood', 'dress', 'coast', 'cat', 'tiger', 'bengal', 'fox', 'kit', 'shadows', 'bush', 'den', 'coyote', 'light', 'arctic', 'shore', 'town', 'road', 'harbor', 'windmills', 'restaurant', 'wall', 'skyline', 'window', 'clothes', 'shops', 'street', 'cafe', 'tables', 'nets', 'crafts', 'roofs', 'ruins', 'stone', 'cars', 'castle', 'courtyard', 'statue', 'stairs', 'costume', 'sign', 'palace', 'sheep', 'valley', 'balcony', 'post', 'gate', 'plaza', 'festival', 'temple', 'sculpture', 'museum', 'hotel', 'art', 'fountain', 'market', 'door', 'garden', 'butterfly', 'lion', 'cave', 'crab', 'buddha', 'decoration', 'monastery', 'landscape', 'detail', 'sails', 'food', 'entrance', 'fruit', 'night', 'cow', 'church', 'park', 'barn', 'arch', 'hats', 'cathedral', 'ceremony', 'glass', 'pillar', 'monument', 'vines', 'cottage', 'lawn', 'tower', 'tulip', 'canal', 'dock', 'horses', 'petals', 'column', 'elephant', 'monks', 'interior', 'vendor', 'silhouette', 'architecture', 'athlete', 'sidewalk', 'store', 'relief', 'frost', 'frozen', 'crystals', 'needles', 'mist', 'doorway', 'vineyard', 'pots', 'terrace', 'bulls', 'albatross', 'booby', 'nest', 'iguana', 'lizard', 'marine', 'deer', 'white-tailed', 'horns', 'slope', 'mule', 'antlers', 'elk', 'caribou', 'herd', 'moose', 'mare', 'foals', 'orchid', 'stems', 'blooms', 'cactus', 'giraffe', 'zebra', 'tusks', 'train', 'desert', 'dunes', 'canyon', 'lighthouse', 'swimmers', 'pyramid', 'mosque', 'sphinx', 'truck', 'fly', 'trunk', 'baby', 'lynx', 'rodent', 'squirrel', 'goat', 'marsh', 'porcupine', 'whales', 'tracks', 'locomotive', 'railroad', 'vehicle', 'man', 'woman', 'girl', 'indian', 'dance', 'african', 'buddhist', 'outside', 'formula', 'turn', 'prototype', 'scotland', 'antelope', 'calf', 'reptile', 'snake', 'cougar', 'oahu', 'kauai', 'maui', 'hawaii'])

if not os.path.isfile(r'D:\caffe-master\models\bvlc_reference_caffenet\bvlc_reference_caffenet.caffemodel'):
    print("Please downloading pre-trained CaffeNet model...")

#设定GPU运算
caffe.set_mode_gpu()

workdir = r"D:\caffe-master\data\Corel5k_multilabel"

solver = caffe.SGDSolver(osp.join(workdir, 'solver.prototxt'))
solver.net.copy_from(r'D:\caffe-master\models\bvlc_reference_caffenet\bvlc_reference_caffenet.caffemodel')
solver.test_nets[0].share_with(solver.net)
solver.step(1)

#用于检测导入的数据是否正确
image_index = 0 
plt.figure()
plt.imshow(transformer.deprocess(copy(solver.net.blobs['data'].data[image_index, ...])))

gtlist = solver.net.blobs['label'].data[image_index, ...].astype(np.int)
plt.title('GT: {}'.format(classes[np.where(gtlist)]))
plt.axis('off');
plt.show()

#测试准确率
#计算预测数据与真实数据相同的个数,占总个数的比例
def hamming_distance(gt, est):
    return sum([1 for (g, e) in zip(gt, est) if g == e]) / float(len(gt))

#这里的batch_size就是设定验证网络的batch_size
def check_accuracy(net, num_batches, batch_size = 100):
    acc = 0.0
    for t in range(num_batches):
        net.forward()
        gts = net.blobs['label'].data
        ests = net.blobs['score'].data > 0
        for gt, est in zip(gts, ests): 
            acc += hamming_distance(gt, est)
    return acc / (num_batches * batch_size)
# 这里决定你的网络训练多少次,Solver里的最大迭代次数不好使
for itt in range(200):
    solver.step(200)
    # 这里计算5个batch的平均准确率
    print 'itt:{:3d}'.format((itt + 1) * 100), 'accuracy:{0:.4f}'.format(check_accuracy(solver.test_nets[0], 5))
    
#检测基本准确率 
#这里的batch_size就是设定验证网络的batch_size
def check_baseline_accuracy(net, num_batches, batch_size = 100):
    acc = 0.0
    for t in range(num_batches):
        net.forward()
        gts = net.blobs['label'].data
        ests = np.zeros((batch_size, len(gts)))
        for gt, est in zip(gts, ests):
            acc += hamming_distance(gt, est)
    return acc / (num_batches * batch_size)

#个人认为这里的500是指验证集内所包含的图像数
print 'Baseline accuracy:{0:.4f}'.format(check_baseline_accuracy(solver.test_nets[0], 500/100))


#测试训练的网络
#取5幅图像测试一下训练网络的结果
test_net = solver.test_nets[0]
for image_index in range(5):
    plt.figure()
    plt.imshow(transformer.deprocess(copy(test_net.blobs['data'].data[image_index, ...])))
    gtlist = test_net.blobs['label'].data[image_index, ...].astype(np.int)
    #将最后一层输出数值大于0的标签定义为图像的标签
    estlist = test_net.blobs['score'].data[image_index, ...] > 0
    plt.title('GT: {} \n EST: {}'.format(classes[np.where(gtlist)], classes[np.where(estlist)]))
    plt.axis('off')
plt.show()

将上面3个文件放在同一个文件夹下即可运行,切记用CMD窗口运行,这样你可以看到报错,90%大部分是路径错误。

4.网络文件与solver文件

上篇博客里面,是利用python代码实现网络文件与solver文件的创建,写起来需要逻辑,用Caffe的大部分是找个现成的网络文件改一下data层与最后一个全连接层和Loss层就运行了,这种方法简单粗暴。

修改后的网data层如下所示(这里的路径十分容易出错,多加注意):

layer {
  name: "data"
  type: "Python"
  top: "data"
  top: "label"
  python_param {
    module: "Corel5k_multilabel_datalayers"
    layer: "Corel5kMultilabelDataLayerSync"
    param_str: "{\'im_shape\': [227, 227], \'Corel5k_root\': \'D:\\\\caffe-master\\\\data\\\\Corel5k_multilabel\', \'split\': \'train\', \'batch_size\': 128}"
  }
}

修改后的网络Loos层如下所示:

layer {
  name: "loss"
  type: "SigmoidCrossEntropyLoss"
  bottom: "score"
  bottom: "label"
  top: "loss"
}

Solver文件如下所示(与传统的Solver文件一致,只是分别指定了训练文件和测试文件):

test_net: "D:/caffe-master/data/Corel5k_multilabel/valnet.prototxt"
train_net: "D:/caffe-master/data/Corel5k_multilabel/trainnet.prototxt"
test_iter: 5
test_interval: 200
base_lr: 0.0001
lr_policy: "step"
gamma: 0.1
stepsize: 10000
display: 200
max_iter: 40000
momentum: 0.9
weight_decay: 0.0005
snapshot: 10000
snapshot_prefix: "D:/caffe-master/data/Corel5k_multilabel/Corel5k-CaffeNet"
solver_mode: GPU

5.总结

以上是Caffe多标签分类的一个学习总结,纯属个人理解,如有错误,请批评指正。

你可能感兴趣的:(Caffe实现多标签图像分类(2)——基于Python接口实现多标签图像分类(自己的数据集))