Keras:使用预训练模型迁移学习单通道灰度图像

目录

1. 问题引出    

2. 解决方案

2.1. 直接使用convert将L转为RGB

2.2. 数组拼接方法

3. 多进程加速运行 

4.使用预训练模型训练


1. 问题引出    

      最近在做一个图像分类的项目,由于性能比较差,因此需要尝试将彩色图转为灰度图进行训练,从而屏蔽掉颜色对分类结果的影响而着重关注纹理、结构等信息。由于样本数量较少,只有几百张的样子,如果自己搭网络的话从头训练的话,势必会因为样本数量的问题,无法达到一个满意的效果,因此考虑借鉴Imagenet的预训练权重。但是在Imagenet上预训练的模型(Xception, Resnet, VGG等)都是处理的彩色图,如果要使用预训练模型就必须要3通道的图像。

      搜索了一下,基本上目前的解决方法:

      暴力的将单通道的图复制为3份,然后合成为一张RGB图。显然,该图3个通道的数值完全相等,这样存在很多冗余计算,我们称之为“伪RGB图”。为了方便起见,自己实现了两种方法,完成如下转换:

RGB图  →  灰度图   →   伪RGB图

其中,转换为灰度图时,均使用的是如下标准公式:

2. 解决方案

首先,导入必要的包:

from multiprocessing import Pool
from PIL import Image
import numpy as np
import os

2.1. 直接使用convert将L转为RGB

def fakeRgb1(path, dst):
    '''
    方法1:直接使用convert将L转为RGB
    :param path:图片输出路径
    :param dst:图片输出路径
    :return:rgb3个通道值相等的rgb图像
    '''
    b = Image.open(path)
    # L代表转换为灰度图
    if b.mode != 'L':
        L = b.convert('L')
    L = L.convert('RGB')
    # 将图像转为数组
    rgb_array = np.asarray(L)
    # 将数组转换为图像
    rgb_image = Image.fromarray(rgb_array)
    rgb_image.save(dst + '\\' + path.split('\\')[-1])
    print(dst + '\\' + path.split('\\')[-1])

2.2. 数组拼接方法

def fakeRgb2(path, dst):
    '''
    方法二:最原始的拼接数组方法
    :param path:图片输入路径
    :param dst:图片输出路径
    :return:rgb3个通道值相等的rgb图像
    '''

    b = Image.open(path)
    # L代表转换为灰度图
    if b.mode != 'L':
        L = b.convert('L')
    # 将图像转为数组
    b_array = np.asarray(L)
    # 将3个二维数组重叠为一个三维数组
    rgb_array = np.zeros((b_array.shape[0], b_array.shape[1], 3), "uint8")
    rgb_array[:, :, 0], rgb_array[:, :, 1], rgb_array[:, :, 2] = b_array, b_array, b_array
    rgb_image = Image.fromarray(rgb_array)
    rgb_image.save(dst + '\\' + path.split('\\')[-1])
    print(dst + '\\' + path.split('\\')[-1])

3. 多进程加速运行 

由于是批量处理,因此可能会遇到同时转换很多张图片,那么这个时候就必须使用多进程加速了,具体的加速方法看我的这篇博客:

Python:多进程运行含有任意个参数的函数

本文的加速代码如下:

def get_image_paths(folder):
    return [os.path.join(folder, f) for f in os.listdir(folder)]

if __name__ == '__main__': # 多线程,多参数,starmap版本
    images = get_image_paths(path)
    output = [src for i in images]

    zip_args = list(zip(images, output))
    pool = Pool()
    pool.starmap(fakeRgb2, zip_args)
    pool.close()
    pool.join()

4.使用预训练模型训练

     这部分就和训练普通RGB图像一样即可,在这里不赘述。 

你可能感兴趣的:(图像处理,深度学习,灰度图,与训练模型,单通道,Imagenet,Xception)