Tensorflow应用之简单验证码识别

1.Tensorflow的安装方式:

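本项目使用 TensorFlow 1.12.0 版本。代码基于 TF 1.x 的 API(tf.placeholder、tf.Session 等),可在 TF 1.x 下直接运行;若使用 TF 2.x,需将 `import tensorflow as tf` 改为 `import tensorflow.compat.v1 as tf`,并在其后调用 `tf.disable_v2_behavior()`。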
安装方式
pip install tensorflow==1.12.0 numpy pillow

2.训练集以及测试集如下(纯手工打造,所以数量不多)。每张图片为 24×60 的 GIF,文件名即验证码内容(4 位数字),例如 0123.gif:

Tensorflow应用之简单验证码识别_第1张图片
Tensorflow应用之简单验证码识别_第2张图片

3.实现代码


CPU 版和 GPU 版的 TensorFlow 均可运行,训练约两千步(2 分钟左右)即可完成。


main.py(主要的神经网络代码)

import tensorflow as tf
import numpy as np
from PIL import Image
import os
import random

# Data locations. Replace the defaults to match your machine, or override them
# without editing the code via the CAPTCHA_TRAIN_DIR / CAPTCHA_TEST_DIR
# environment variables.
train_data_dir = os.environ.get('CAPTCHA_TRAIN_DIR', r'd:\img\train')
test_data_dir = os.environ.get('CAPTCHA_TEST_DIR', r'/Users/hupeng/Downloads/img/test')


def gen_train_data(batch_size=32):
    '''
    Load one random training batch from the GIF captchas in ``train_data_dir``.

    The label is read from the file name: ``"0123.gif"`` -> ``[0, 1, 2, 3]``.
    Sampling only considers ``.gif`` files, so other files in the directory
    no longer silently shrink the batch.

    :param batch_size: number of images in the batch (default 32)
    :return: x_data: float64 array of shape (batch_size, 24, 60, 1);
             y_data: int32 array of shape (batch_size, label_len), where
             label_len is the length of each file-name stem (4 for this model)
    :raises ValueError: if the directory holds fewer than ``batch_size`` GIF
             files, or an image is not 24x60
    '''
    gif_file_names = [name for name in os.listdir(train_data_dir) if name.endswith('.gif')]
    if len(gif_file_names) < batch_size:
        raise ValueError('need %d .gif files in %s, found %d'
                         % (batch_size, train_data_dir, len(gif_file_names)))
    x_data = []
    y_data = []
    for file_name in random.sample(gif_file_names, batch_size):
        # The context manager releases the file handle immediately instead of
        # relying on garbage collection.
        with Image.open(os.path.join(train_data_dir, file_name)) as captcha_image:
            captcha_image_np = np.array(captcha_image)
        if captcha_image_np.shape != (24, 60):
            raise ValueError('%s: expected a 24x60 image, got shape %s'
                             % (file_name, captcha_image_np.shape))
        # Add a trailing channel axis -> (24, 60, 1).
        x_data.append(np.expand_dims(captcha_image_np, 2))
        # Each character of the stem is one digit label, e.g. "0123" -> [0, 1, 2, 3].
        y_data.append(np.array([int(ch) for ch in file_name.split('.')[0]], dtype=np.int32))
    # np.float was an alias of the builtin float (float64) and was removed in NumPy 1.24.
    x_data = np.array(x_data).astype(np.float64)
    y_data = np.array(y_data)
    return x_data, y_data

# Graph inputs shared by train() and test().
# X deliberately has no static shape; net() reshapes whatever is fed to
# (-1, 24, 60, 1).
X = tf.placeholder(dtype=tf.float32, name="input")
# Integer digit labels, one per captcha position.
Y = tf.placeholder(dtype=tf.int32)
# Dropout keep probability, supplied at run time through feed_dict.
keep_prob = tf.placeholder(dtype=tf.float32)
# One-hot over the 10 digit classes, produced directly as float32 for the loss.
y_one_hot = tf.one_hot(indices=Y, depth=10, on_value=1.0, off_value=0.0, dtype=tf.float32)

# NOTE(review): the tf.Variables below are created without explicit names, so
# TensorFlow assigns default names in creation order. Restoring an existing
# checkpoint (see test()) presumably relies on that order; confirm before
# reordering or inserting variables.
def net(w_alpha=0.01, b_alpha=0.1):
    '''
    Build the CNN: three conv blocks, a 128-unit hidden dense layer and a linear
    output layer.

    Reads the module-level placeholders ``X`` (reshaped to (-1, 24, 60, 1)) and
    ``keep_prob`` (dropout applied after each conv block).

    :param w_alpha: scale applied to the standard-normal initial weights
    :param b_alpha: scale applied to the standard-normal initial biases
    :return: logits Tensor of shape (-1, 4, 10): 4 captcha positions x 10 digit classes
    '''
    x_reshape = tf.reshape(X, (-1, 24, 60, 1))
    # Block 1: 3x3 conv 1->16 channels, ReLU, 4x4 max-pool with stride 2 (24x60 -> 12x30).
    w_c1 = tf.Variable(w_alpha * tf.random_normal([3, 3, 1, 16]))
    b_c1 = tf.Variable(b_alpha * tf.random_normal([16]))
    conv1 = tf.nn.relu(tf.nn.bias_add(tf.nn.conv2d(x_reshape, w_c1, strides=[1, 1, 1, 1], padding='SAME'), b_c1))
    conv1 = tf.nn.max_pool(conv1, ksize=[1, 4, 4, 1], strides=[1, 2, 2, 1], padding='SAME')
    conv1 = tf.nn.dropout(conv1, keep_prob)

    # Block 2: 3x3 conv 16->16, ReLU, 2x2 max-pool with stride 2 (12x30 -> 6x15).
    w_c2 = tf.Variable(w_alpha * tf.random_normal([3, 3, 16, 16]))
    b_c2 = tf.Variable(b_alpha * tf.random_normal([16]))
    conv2 = tf.nn.relu(tf.nn.bias_add(tf.nn.conv2d(conv1, w_c2, strides=[1, 1, 1, 1], padding='SAME'), b_c2))
    conv2 = tf.nn.max_pool(conv2, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
    conv2 = tf.nn.dropout(conv2, keep_prob)

    # Block 3: same as block 2 (6x15 -> 3x8; SAME padding rounds 15/2 up to 8).
    w_c3 = tf.Variable(w_alpha * tf.random_normal([3, 3, 16, 16]))
    b_c3 = tf.Variable(b_alpha * tf.random_normal([16]))
    conv3 = tf.nn.relu(tf.nn.bias_add(tf.nn.conv2d(conv2, w_c3, strides=[1, 1, 1, 1], padding='SAME'), b_c3))
    conv3 = tf.nn.max_pool(conv3, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
    conv3 = tf.nn.dropout(conv3, keep_prob)

    # Hidden fully connected layer: flattened 3*8*16 features -> 128 ReLU units.
    # Randomly initialised weights.
    w_d = tf.Variable(w_alpha * tf.random_normal([3 * 8 * 16, 128]))
    # Randomly initialised biases.
    b_d = tf.Variable(b_alpha * tf.random_normal([128]))
    dense = tf.reshape(conv3, [-1, w_d.get_shape().as_list()[0]])
    dense = tf.nn.relu(tf.add(tf.matmul(dense, w_d), b_d))

    # Linear output layer: 4 positions * 10 classes logits, reshaped to (-1, 4, 10).
    w_out = tf.Variable(w_alpha * tf.random_normal([128, 4 * 10]))
    b_out = tf.Variable(b_alpha * tf.random_normal([4 * 10]))
    out = tf.add(tf.matmul(dense, w_out), b_out)
    out = tf.reshape(out, (-1, 4, 10))
    return out
# Build the model graph once at import time; train() and test() share these tensors.
cnn = net()
# Softmax cross-entropy over the last axis (10 digit classes) for every captcha
# position, averaged over batch and positions.
_xent = tf.nn.softmax_cross_entropy_with_logits(labels=y_one_hot, logits=cnn)
loss = tf.reduce_mean(_xent)
# Training op: one Adam update (learning rate 1e-3) per sess.run.
_adam = tf.train.AdamOptimizer(learning_rate=0.001)
optimizer = _adam.minimize(loss)

def train(max_steps=None, batch_size=64, loss_threshold=0.01):
    '''
    Train the network until a batch loss drops below ``loss_threshold``, then
    save a checkpoint ("./crack_capcha.model-<step>") in the current directory.

    :param max_steps: optional upper bound on training steps. ``None`` (default)
        loops until the threshold is reached. If the bound is hit first, the
        current model is still saved.
    :param batch_size: images per training step (default 64)
    :param loss_threshold: stop once a single batch's loss is below this (default 0.01)
    :raises RuntimeError: if ``train_data_dir`` does not exist
    '''
    if not os.path.exists(train_data_dir):
        raise RuntimeError('训练数据目录不存在,请检查"%s"参数' % 'train_data_dir')
    print('开始执行训练')
    saver = tf.train.Saver()
    with tf.Session() as sess:
        tf.global_variables_initializer().run()
        step = 0
        while max_steps is None or step < max_steps:
            x_data, y_data = gen_train_data(batch_size)
            # net() reshapes X itself, so the batch is fed as-is. Only the loss is
            # fetched; running `optimizer` applies one Adam update.
            loss_, _ = sess.run([loss, optimizer],
                                feed_dict={X: x_data, Y: y_data, keep_prob: 0.75})
            print('step: %4d, loss: %.4f' % (step, loss_))
            if loss_ < loss_threshold:
                saver.save(sess, "./crack_capcha.model", global_step=step)
                print("训练完成,模型保存成功!")
                break
            step += 1
        else:
            # Step budget exhausted before reaching the threshold: keep the progress anyway.
            saver.save(sess, "./crack_capcha.model", global_step=step)
            print('max_steps=%d reached before loss < %g; model saved anyway'
                  % (max_steps, loss_threshold))
def gen_test_data():
    '''
    Load every GIF captcha under ``test_data_dir`` (recursively, following symlinks).

    :return: (x_data, y_data). x_data is a list of float32 arrays of shape
        (24, 60, 1); y_data is the matching list of label strings taken from
        the file names (``"0123.gif"`` -> ``"0123"``).
    :raises ValueError: if an image is not 24x60
    '''
    x_data = []
    y_data = []
    for parent, _dirnames, filenames in os.walk(test_data_dir, followlinks=True):
        for filename in filenames:
            if not filename.endswith('.gif'):
                continue
            gif_file_path = os.path.join(parent, filename)
            # The context manager releases the file handle immediately.
            with Image.open(gif_file_path) as captcha_image:
                captcha_image_np = np.array(captcha_image)
            if captcha_image_np.shape != (24, 60):
                raise ValueError('%s: expected a 24x60 image, got shape %s'
                                 % (gif_file_path, captcha_image_np.shape))
            # Add a trailing channel axis -> (24, 60, 1).
            x_data.append(np.expand_dims(captcha_image_np, 2).astype(np.float32))
            y_data.append(filename.split('.')[0])
    return x_data, y_data
def test():
    '''
    Evaluate the latest checkpoint in the current directory on ``test_data_dir``.

    Prints each prediction next to its expected label, then a summary line with
    the total, correct and wrong counts.

    :raises RuntimeError: if ``test_data_dir`` does not exist or no checkpoint is found
    '''
    if not os.path.exists(test_data_dir):
        raise RuntimeError('测试数据目录不存在,请检查"%s"参数' % 'test_data_dir')
    checkpoint = tf.train.latest_checkpoint('.')
    if checkpoint is None:
        raise RuntimeError('未找到模型文件,请先执行训练!')
    print('开始执行测试')
    x, y = gen_test_data()
    print('测试目录文件数量:%d' % len(x))
    saver = tf.train.Saver()
    total = 0
    correct = 0
    with tf.Session() as sess:
        saver.restore(sess, checkpoint)
        for image, answer in zip(x, y):
            cnn_out = sess.run(cnn, feed_dict={X: image.reshape((1, 24, 60, 1)), keep_prob: 1})
            # cnn_out[0] has shape (4, 10): take the most likely digit at each position.
            predict = ''.join(str(digit) for digit in np.argmax(cnn_out[0], 1))
            is_correct = predict == answer
            print('预测:%s,答案:%s,判定:%s' % (predict, answer, "√" if is_correct else "×"))
            total += 1
            if is_correct:
                correct += 1
    print("总数:%d,正确:%d,错误:%d" % (total, correct, total - correct))
if __name__ == '__main__':
    import sys
    # Usage: python main.py [test|train]. With no argument this runs test().
    mode = sys.argv[1] if len(sys.argv) > 1 else 'test'
    if mode == 'train':
        train()
    elif mode == 'test':
        test()
    else:
        sys.exit('usage: python main.py [test|train]')

4.效果

在测试集上的识别率
Tensorflow应用之简单验证码识别_第3张图片

5.相关文件下载

训练集以及测试集 百度云下载

训练视频下载:链接: https://pan.baidu.com/s/1f0-_to6ynmfC-hsb0xnqMA 提取码: 4r7m (下载后观看更清晰)

你可能感兴趣的:(tensorflow)