用resnet50进行中文字符分类(汉字字符识别)

在项目根目录下打开终端

在终端输入以下命令进行训练:

    python resnet_classifier.py --is_train=True --data_path=data/train --dict_path=key_index.txt --model_dir=model --log_dir=logs --batch_size=8 --class_num=7 --epoch=100 --lr=0.01 --gpu_id=0

在终端输入以下命令进行测试:

    python resnet_classifier.py --is_train=False --data_path=data/test --dict_path=key_index.txt --model_dir=model --log_dir=logs --batch_size=8 --class_num=7 --epoch=1 --gpu_id=0

 我的文件目录结构如下图所示:

train和test下文件结构一样,每个文件夹代表一类数据,文件夹名就是该文件夹下所有图片类的标签。

用resnet50进行中文字符分类(汉字字符识别)_第1张图片

key_index.txt是字典文件,存放所有的字符标签,如下图所示:

用resnet50进行中文字符分类(汉字字符识别)_第2张图片

代码实现如下:

# -*- coding: UTF-8 -*-
import os
from tensorflow.contrib.slim.nets import resnet_v1
import tensorflow as tf
import glob
import numpy as np
import cv2
import logging

# Short alias for the TF-Slim namespace used when building the network.
slim = tf.contrib.slim
# Command-line flags; the entry point reads them as FLAGS.<name>.
FLAGS = tf.flags.FLAGS
# Selects between the train() and test() entry points.
tf.flags.DEFINE_bool('is_train', True, 'train or test')
# Root folder containing one sub-folder of images per class label.
tf.flags.DEFINE_string('data_path', 'data/train', 'The dataset path ')
# Text file listing the labels; passed to get_index().
tf.flags.DEFINE_string('dict_path', 'key_index.txt',
                       'The directpry of chardict')
# Directory checkpoints are written to and restored from.
tf.flags.DEFINE_string('model_dir', 'models', 'The model saved directory ')
# Directory for the train/test log files and TensorBoard summaries.
tf.flags.DEFINE_string('log_dir', 'logs', 'The log saved directory ')
tf.flags.DEFINE_integer('batch_size', 8, 'The batch size')
# Number of output classes handed to resnet_v1_50.
tf.flags.DEFINE_integer('class_num', 7, 'Total class number')
# Number of passes over the dataset during training.
tf.flags.DEFINE_integer('epoch', 100, 'The repeat times')
# Initial learning rate (decayed exponentially inside train()).
tf.flags.DEFINE_float('lr', 0.01, 'The learning rate')
# Exported as CUDA_VISIBLE_DEVICES by the entry point.
tf.flags.DEFINE_string('gpu_id','0','the gpu id')


def get_index(index_file):
    """Build the two-way mapping between label strings and class indices.

    Each line of ``index_file`` holds one label; its zero-based line number
    is the class index.

    Args:
        index_file: Path to the UTF-8 encoded dictionary file.

    Returns:
        A tuple ``(value2key, key2value)`` where ``value2key`` maps a label
        string to its index and ``key2value`` maps the index back to the
        label string.
    """
    value2key = {}
    key2value = {}
    # The labels are Chinese characters, so the encoding must not depend on
    # the platform default. 'utf-8-sig' also drops a leading BOM, which would
    # otherwise become part of the first label.
    with open(index_file, 'r', encoding='utf-8-sig') as f:
        for i, line in enumerate(f):
            label = line.strip()
            value2key[label] = i
            key2value[i] = label
    return value2key, key2value


def getdata(datapath, batch_size, v2k):
    """Build an input pipeline that reads JPEG images straight from disk.

    ``datapath`` must contain one sub-folder per class; the folder name is
    the label and is converted to a class index through ``v2k``.

    Args:
        datapath: Root directory of the dataset.
        batch_size: Number of samples per batch; an incomplete last batch is
            dropped.
        v2k: Mapping from label string (folder name) to integer class index.

    Returns:
        A tuple ``(iterator, imgs, labels)``: the initializable iterator
        (run ``iterator.initializer`` before every pass) and the tensors of
        one batch, with images resized to 28x28, converted to a single
        grayscale channel and standardized per image.

    Raises:
        KeyError: If a folder that contains images is not listed in ``v2k``.
        ValueError: If no ``*.jpg`` file is found under ``datapath``.
    """
    def parse(path, label):
        img = tf.read_file(path)
        img = tf.image.decode_jpeg(img, 3)
        img = tf.image.resize_images(img, (28, 28))
        img = tf.image.rgb_to_grayscale(img)  # convert to grayscale
        img = tf.image.per_image_standardization(img)  # standardize
        return img, label

    path = []
    labels = []
    for label in os.listdir(datapath):
        class_dir = os.path.join(datapath, label)
        # Ignore stray files (e.g. .DS_Store) next to the class folders.
        if not os.path.isdir(class_dir):
            continue
        imgs_path = glob.glob(os.path.join(class_dir, '*.jpg'))
        # Only look the label up when the folder actually contributes
        # samples, so empty helper folders do not raise KeyError.
        if not imgs_path:
            continue
        path += imgs_path
        labels += [v2k[label]] * len(imgs_path)

    if not path:
        raise ValueError('no *.jpg images found under %r' % datapath)

    data = tf.data.Dataset.from_tensor_slices((path, labels))
    # The paths are collected class by class, so a small shuffle buffer
    # would produce batches made almost entirely of one class. Shuffling the
    # cheap (path, label) pairs over the whole dataset before decoding mixes
    # the classes properly without holding decoded images in memory.
    data = data.shuffle(len(path)).map(parse)
    # TensorFlow 1.4 API; on TensorFlow 1.10+ the equivalent is
    # data.batch(batch_size, drop_remainder=True).
    data = data.apply(tf.contrib.data.batch_and_drop_remainder(batch_size))
    iterator = data.make_initializable_iterator()
    imgs, labels = iterator.get_next()
    return iterator, imgs, labels


def gettfdata(datapath, batch_size):
    """Build an input pipeline that reads samples from a TFRecord file.

    Every record is expected to carry the features ``image/encoded``
    (JPEG bytes), ``image/format`` and ``image/class/label`` (int64).

    Args:
        datapath: Path (or list of paths) of the TFRecord file(s).
        batch_size: Number of samples per batch; an incomplete last batch is
            dropped.

    Returns:
        A tuple ``(iterator, imgs, labels)``: the initializable iterator and
        the tensors of one batch, with images decoded as one channel,
        resized to 28x28 and standardized per image, and labels cast to
        int32.
    """
    keys2features = {'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''),
                     'image/format': tf.FixedLenFeature((), tf.string, default_value='jpg'),
                     'image/class/label': tf.FixedLenFeature([], tf.int64, default_value=tf.zeros([], dtype=tf.int64))
                     }

    def parse_example(example):
        features = tf.parse_single_example(example, keys2features)
        image = tf.image.decode_jpeg(features['image/encoded'], 1)
        # resize_images is the name available in the TensorFlow 1.4 API that
        # the batching call below targets (and the one getdata() uses).
        image = tf.image.resize_images(image, (28, 28))
        image = tf.image.per_image_standardization(image)
        label = tf.cast(features["image/class/label"], tf.int32)
        return image, label

    # Records are parsed exactly once, by parse_example.
    data = tf.data.TFRecordDataset(datapath).map(parse_example)
    data = data.shuffle(batch_size * 10)
    # TensorFlow 1.4 API; on TensorFlow 1.10+ the equivalent is
    # data.batch(batch_size, drop_remainder=True).
    data = data.apply(tf.contrib.data.batch_and_drop_remainder(batch_size))
    iterator = data.make_initializable_iterator()
    imgs, labels = iterator.get_next()
    return iterator, imgs, labels


def train(datapath, dictpath, logpath, modelpath, learning_rate, batch_size, epoch, class_num):
    """Train ResNet-50 on the image folders under ``datapath``.

    Training resumes from the latest checkpoint in ``modelpath`` when one
    exists. Progress is logged to ``<logpath>/train.log``, TensorBoard
    summaries are written to ``logpath``, and checkpoints are written to
    ``modelpath`` every 1000 global steps and once more when training ends.

    Args:
        datapath: Root directory with one sub-folder of JPEG images per class.
        dictpath: Label dictionary file, read with get_index().
        logpath: Directory for the log file and summaries (created if missing).
        modelpath: Directory for checkpoints (created if missing).
        learning_rate: Initial learning rate; decayed by 0.9 every 10000 steps.
        batch_size: Number of samples per batch.
        epoch: Number of passes over the dataset.
        class_num: Number of output classes.
    """
    # Both the log file handler and the checkpoint saver need their target
    # directories to exist up front.
    for directory in (logpath, modelpath):
        if not os.path.isdir(directory):
            os.makedirs(directory)

    logging.basicConfig(filename=os.path.join(logpath, 'train.log'), filemode="w", format="%(asctime)s %(name)s:%(levelname)s:%(message)s",
                        datefmt="%Y-%m-%d %H:%M:%S", level=logging.DEBUG)
    v2k, _ = get_index(dictpath)

    # Input pipeline: the batch tensors are wired directly into the network.
    iterator, imgs, labels = getdata(datapath, batch_size, v2k)

    with slim.arg_scope(resnet_v1.resnet_arg_scope()):
        net, end_points = resnet_v1.resnet_v1_50(
            imgs, class_num, is_training=True)
        # Flatten the network outputs to (batch_size, class_num).
        probabilities = tf.reshape(
            end_points['predictions'], (batch_size, -1))
        logits = tf.reshape(net, (batch_size, -1))
        predictions = tf.argmax(probabilities, axis=-1,
                                output_type=tf.int32)
        acc = tf.reduce_mean(
            tf.cast(tf.equal(labels, predictions), tf.float32))
        loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=logits, labels=labels))
        tf.summary.scalar('accuracy', acc)
        tf.summary.scalar('loss', loss)
        global_step = tf.Variable(0, name='global_step', trainable=False)
        decay_lr = tf.train.exponential_decay(
            learning_rate, global_step, 10000, 0.9)
        # Batch normalization registers its moving mean/variance updates in
        # the UPDATE_OPS collection; they only run if the train op depends
        # on them, otherwise the statistics stay at their initial values.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            opt = tf.train.GradientDescentOptimizer(
                decay_lr).minimize(loss, global_step=global_step)

    saver = tf.train.Saver(max_to_keep=3, keep_checkpoint_every_n_hours=5)
    checkpoint_prefix = os.path.join(modelpath, "model.ckpt")
    with tf.Session() as sess:
        kpt = tf.train.latest_checkpoint(modelpath)
        if kpt:
            # global_step is stored in the checkpoint, so restoring it is
            # all that is needed to continue the step count.
            saver.restore(sess, kpt)
        else:
            sess.run(tf.global_variables_initializer())
        merged = tf.summary.merge_all()
        writer = tf.summary.FileWriter(logpath, sess.graph)
        try:
            for epc in range(epoch):
                sess.run(iterator.initializer)
                logging.info("====the epoch%s is begin====" % epc)
                while True:
                    try:
                        merged_, global_step_, _, acc_, loss_, decay_lr_ = sess.run(
                            [merged, global_step, opt, acc, loss, decay_lr])
                    except tf.errors.OutOfRangeError:
                        # The iterator is exhausted: this epoch is done.
                        logging.info("====the epoch%s is finished====" % epc)
                        break

                    writer.add_summary(merged_, global_step=global_step_)

                    if global_step_ % 1000 == 0:
                        logging.info("global_step:{},decay_lr:{},loss:{},acc:{}".format(
                            global_step_, decay_lr_, loss_, acc_))
                        saver.save(sess, checkpoint_prefix,
                                   global_step=global_step)

            # Always keep the final weights: short runs may never reach a
            # multiple of 1000 steps and would otherwise save nothing.
            saver.save(sess, checkpoint_prefix, global_step=global_step)
        finally:
            # Flush pending summaries to disk.
            writer.close()


def test(datapath, dictpath, logpath, modelpath, batch_size, class_num):
    """Evaluate the latest checkpoint on the image folders under ``datapath``.

    For every batch the ground-truth and predicted labels, the accuracy and
    the loss are written to ``<logpath>/test.log``; the averages over all
    batches are logged at the end. If ``modelpath`` holds no checkpoint, an
    error is logged and the function returns without evaluating.

    Args:
        datapath: Root directory with one sub-folder of JPEG images per class.
        dictpath: Label dictionary file, read with get_index().
        logpath: Directory for the log file (created if missing).
        modelpath: Directory the checkpoint is restored from.
        batch_size: Number of samples per batch.
        class_num: Number of output classes.
    """
    if not os.path.isdir(logpath):
        os.makedirs(logpath)
    logging.basicConfig(filename=os.path.join(logpath, 'test.log'), filemode="w", format="%(asctime)s %(name)s:%(levelname)s:%(message)s",
                        datefmt="%Y-%m-%d %H:%M:%S", level=logging.DEBUG)
    v2k, k2v = get_index(dictpath)
    iterator, imgs, labels = getdata(datapath, batch_size, v2k)

    # Batches are fetched as arrays first and then fed through placeholders,
    # so the ground-truth labels are available for the per-sample log line.
    labels_ph = tf.placeholder(dtype=tf.int32, shape=[batch_size])
    images_ph = tf.placeholder(dtype=tf.float32, shape=[batch_size, 28, 28, 1])

    with slim.arg_scope(resnet_v1.resnet_arg_scope()):
        # NOTE(review): is_training=True makes batch normalization use the
        # statistics of the current batch at test time. Switch to False only
        # if the checkpoint was trained with the batch-norm moving averages
        # being updated -- confirm before changing.
        net, end_points = resnet_v1.resnet_v1_50(
            images_ph, class_num, is_training=True)
        # Flatten the network outputs to (batch_size, class_num).
        probabilities = tf.reshape(
            end_points['predictions'], (batch_size, -1))
        logits = tf.reshape(net, (batch_size, -1))
        predictions = tf.argmax(probabilities, axis=-1,
                                output_type=tf.int32)
        acc = tf.reduce_mean(
            tf.cast(tf.equal(labels_ph, predictions), tf.float32))
        loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=logits, labels=labels_ph))

    saver = tf.train.Saver()
    with tf.Session() as sess:
        kpt = tf.train.latest_checkpoint(modelpath)
        if not kpt:
            # Return normally instead of killing the interpreter so that
            # the session and log handlers are cleaned up.
            logging.error("model not found")
            return
        saver.restore(sess, kpt)

        sess.run(iterator.initializer)
        logging.info("====the test is begin====")
        total_acc = 0
        total_loss = 0
        step = 0
        while True:
            try:
                imgs_, labels_ = sess.run([imgs, labels])
            except tf.errors.OutOfRangeError:
                logging.info("====the test is finished====")
                break

            acc_, loss_, predictions_ = sess.run(
                [acc, loss, predictions],
                feed_dict={labels_ph: labels_, images_ph: imgs_})
            total_acc += acc_
            total_loss += loss_
            step += 1

            # One "ground truth=>prediction" pair per sample.
            pairs = ''.join(
                '{}=>{} '.format(k2v[gt], k2v[pre])
                for gt, pre in zip(labels_.tolist(), predictions_.tolist()))
            logging.info('gt=>pre:' + pairs)
            logging.info(
                "this batch acc is {},loss is {}".format(acc_, loss_))

        if step == 0:
            # Fewer samples than batch_size: the only (partial) batch was
            # dropped, so there is nothing to average.
            logging.warning(
                "no complete batch was evaluated; check data_path and batch_size")
            return

        logging.info("avg_acc:{},avg_loss:{}".format(
            total_acc / step, total_loss / step))

if __name__ == "__main__":
    # Script entry point: echo the configuration, then train or evaluate.
    tf.reset_default_graph()

    # Expose only the GPU(s) selected on the command line.
    os.environ["CUDA_VISIBLE_DEVICES"] = FLAGS.gpu_id
    print('gpu_id:{}'.format(FLAGS.gpu_id))
    print('data_path:{},dict_path:{},log_dir:{},model_dir:{}'.format(
        FLAGS.data_path, FLAGS.dict_path, FLAGS.log_dir, FLAGS.model_dir))

    if not FLAGS.is_train:
        test(FLAGS.data_path, FLAGS.dict_path, FLAGS.log_dir,
             FLAGS.model_dir, FLAGS.batch_size, FLAGS.class_num)
    else:
        train(FLAGS.data_path, FLAGS.dict_path, FLAGS.log_dir,
              FLAGS.model_dir, FLAGS.lr, FLAGS.batch_size, FLAGS.epoch,
              FLAGS.class_num)

 

你可能感兴趣的:(resnet,字符识别,字符分类,分类,tensorflow)