本项目采用 TensorFlow 1.12.0 版本(代码使用 tf.placeholder、tf.Session 等 TF 1.x 接口,TF 1.x 的各版本均可使用,TF 2.x 无法直接运行)。
安装方式:
pip install tensorflow==1.12.0
import tensorflow as tf
import numpy as np
from PIL import Image
import os
import random
# Directories that gen_train_data() and gen_test_data() read captcha GIFs from.
# NOTE(review): one path is Windows-style and the other macOS-style, so at
# least one of them has to be edited before both train() and test() can run
# on the same machine.
train_data_dir = r'd:\img\train' # replace to match your environment
test_data_dir = r'/Users/hupeng/Downloads/img/test'
def gen_train_data(batch_size=32):
    """Load a random batch of training captchas from ``train_data_dir``.

    Only ``.gif`` files are considered. Each image must decode to a 24x60
    single-channel array. The label is taken from the file name: every
    character before the first ``.`` is parsed as one integer digit.

    :param batch_size: number of images to sample without replacement;
        defaults to 32.
    :return: ``(x_data, y_data)``. ``x_data`` is a float64 array of shape
        ``(batch_size, 24, 60, 1)``; ``y_data`` holds one int32 label row per
        image (presumably 4 digits each, matching the network's 4 output
        positions -- confirm against the data set).
    :raises ValueError: if the directory holds fewer than ``batch_size`` GIF
        files, or if an image does not have shape ``(24, 60)``.
    """
    # Filter before sampling so that stray non-GIF files cannot silently
    # shrink the batch below batch_size.
    gif_names = [name for name in os.listdir(train_data_dir) if name.endswith('.gif')]
    if batch_size > len(gif_names):
        raise ValueError(
            '训练目录中的 .gif 图片数量(%d)少于 batch_size(%d)' % (len(gif_names), batch_size))

    x_data = []
    y_data = []
    for name in random.sample(gif_names, batch_size):
        # The context manager closes the underlying file handle once the
        # pixel data has been copied into the array.
        with Image.open(os.path.join(train_data_dir, name)) as captcha_image:
            pixels = np.array(captcha_image)
        # Explicit check instead of assert, which is stripped under `python -O`.
        if pixels.shape != (24, 60):
            raise ValueError('图片 "%s" 的尺寸为 %s,应为 (24, 60)' % (name, pixels.shape))
        # Append a trailing channel axis: (24, 60) -> (24, 60, 1).
        x_data.append(np.expand_dims(pixels, 2))
        y_data.append(np.array(list(name.split('.')[0])).astype(np.int32))

    # np.float was removed in NumPy 1.24; np.float64 is the type it aliased.
    x_data = np.array(x_data).astype(np.float64)
    y_data = np.array(y_data)
    return x_data, y_data
# Graph inputs: image batch, integer digit labels and the dropout keep
# probability (fed as 0.75 by train() and 1 by test()).
X = tf.placeholder(dtype=tf.float32, name="input")
Y = tf.placeholder(dtype=tf.int32)
keep_prob = tf.placeholder(dtype=tf.float32)
# One-hot encode every digit over 10 classes, as float32 for the loss.
y_one_hot = tf.cast(tf.one_hot(Y, depth=10, on_value=1, off_value=0), tf.float32)
def net(w_alpha=0.01, b_alpha=0.1):
    """Build the captcha CNN on top of the module-level placeholder ``X``.

    The graph is three conv/ReLU/max-pool/dropout stages with 16 filters each,
    a 128-unit ReLU dense layer and a linear 40-unit output layer.

    :param w_alpha: scale applied to the random-normal initial weights.
    :param b_alpha: scale applied to the random-normal initial biases.
    :return: logits tensor reshaped to ``(-1, 4, 10)``.
    """

    def scaled_normal(scale, shape):
        # Trainable variable initialised with scaled standard-normal noise.
        return tf.Variable(scale * tf.random_normal(shape))

    def conv_stage(inputs, in_channels, pool_size):
        # 3x3 SAME convolution -> bias -> ReLU -> stride-2 max-pool -> dropout.
        kernel = scaled_normal(w_alpha, [3, 3, in_channels, 16])
        bias = scaled_normal(b_alpha, [16])
        convolved = tf.nn.conv2d(inputs, kernel, strides=[1, 1, 1, 1], padding='SAME')
        activated = tf.nn.relu(tf.nn.bias_add(convolved, bias))
        pooled = tf.nn.max_pool(activated,
                                ksize=[1, pool_size, pool_size, 1],
                                strides=[1, 2, 2, 1],
                                padding='SAME')
        return tf.nn.dropout(pooled, keep_prob)

    features = tf.reshape(X, (-1, 24, 60, 1))
    # (input channels, pooling window) per stage; only the first stage pools
    # with a 4x4 window, the other two use 2x2.
    for in_channels, pool_size in ((1, 4), (16, 2), (16, 2)):
        features = conv_stage(features, in_channels, pool_size)

    # Three stride-2 SAME poolings take 24x60 down to 3x8, with 16 channels.
    flat_size = 3 * 8 * 16
    w_d = scaled_normal(w_alpha, [flat_size, 128])
    b_d = scaled_normal(b_alpha, [128])
    flattened = tf.reshape(features, [-1, flat_size])
    hidden = tf.nn.relu(tf.add(tf.matmul(flattened, w_d), b_d))

    # Output layer: 4 positions x 10 classes.
    w_out = scaled_normal(w_alpha, [128, 4 * 10])
    b_out = scaled_normal(b_alpha, [4 * 10])
    logits = tf.add(tf.matmul(hidden, w_out), b_out)
    return tf.reshape(logits, (-1, 4, 10))
# Wire the network to its training objective: mean softmax cross-entropy
# between the logits and the one-hot labels, minimised with Adam.
cnn = net()
_cross_entropy = tf.nn.softmax_cross_entropy_with_logits(labels=y_one_hot, logits=cnn)
loss = tf.reduce_mean(_cross_entropy)
_adam = tf.train.AdamOptimizer(learning_rate=0.001)
optimizer = _adam.minimize(loss)
def train():
    """Train the network until the batch loss drops below 0.01, then save it.

    Each step draws a fresh random batch of 64 images via gen_train_data(),
    runs one optimizer update with a dropout keep probability of 0.75 and
    prints the step number and loss. Once the loss is below 0.01 the model is
    written to ``./crack_capcha.model`` (suffixed with the step number) and
    the loop ends.

    :raises RuntimeError: if ``train_data_dir`` does not exist.
    """
    if not os.path.exists(train_data_dir):
        raise RuntimeError('训练数据目录不存在,请检查"%s"参数' % 'train_data_dir')
    print('开始执行训练')
    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        step = 0
        # NOTE: there is no step limit; the loop only ends when the loss
        # threshold is reached, and nothing is saved before that.
        while True:
            x_data, y_data = gen_train_data(64)
            # Fetch only what is used: the update op and the scalar loss.
            # The batch is fed as-is because net() reshapes X itself.
            _, loss_ = sess.run([optimizer, loss],
                                feed_dict={X: x_data, Y: y_data, keep_prob: 0.75})
            print('step: %4d, loss: %.4f' % (step, loss_))
            if loss_ < 0.01:
                saver.save(sess, "./crack_capcha.model", global_step=step)
                print("训练完成,模型保存成功!")
                break
            step += 1
def gen_test_data():
    """Load every ``.gif`` captcha found under ``test_data_dir``.

    The directory tree is walked recursively, following symlinks. Each image
    must decode to a 24x60 single-channel array.

    :return: ``(x_data, y_data)``. ``x_data`` is a list of float32 arrays of
        shape ``(24, 60, 1)``; ``y_data`` is the list of expected answers,
        each being the file name up to its first ``.``.
    :raises ValueError: if an image does not have shape ``(24, 60)``.
    """
    x_data = []
    y_data = []
    for parent, _dirnames, filenames in os.walk(test_data_dir, followlinks=True):
        for filename in filenames:
            if not filename.endswith('.gif'):
                continue
            gif_file_path = os.path.join(parent, filename)
            # The context manager closes the file handle once the pixel data
            # has been copied into the array.
            with Image.open(gif_file_path) as captcha_image:
                pixels = np.array(captcha_image)
            # Explicit check instead of assert, which is stripped under `python -O`.
            if pixels.shape != (24, 60):
                raise ValueError(
                    '图片 "%s" 的尺寸为 %s,应为 (24, 60)' % (gif_file_path, pixels.shape))
            # Append a trailing channel axis: (24, 60) -> (24, 60, 1).
            x_data.append(np.expand_dims(pixels, 2).astype(np.float32))
            y_data.append(filename.split('.')[0])
    return x_data, y_data
def test():
    """Evaluate the latest checkpoint in the current directory on the test set.

    Every image returned by gen_test_data() is run through the network with
    dropout disabled (keep probability 1). The arg-max digit at each of the
    output positions is joined into a string and compared with the expected
    answer. One line is printed per image, followed by overall totals.

    :raises RuntimeError: if ``test_data_dir`` does not exist or no checkpoint
        is found in the current directory.
    """
    if not os.path.exists(test_data_dir):
        raise RuntimeError('测试数据目录不存在,请检查"%s"参数' % 'test_data_dir')
    # Resolve the checkpoint once so the path that was checked is the same
    # path that gets restored.
    checkpoint_path = tf.train.latest_checkpoint('.')
    if checkpoint_path is None:
        raise RuntimeError('未找到模型文件,请先执行训练!')
    print('开始执行测试')
    x, y = gen_test_data()
    print('测试目录文件数量:%d' % len(x))
    saver = tf.train.Saver()
    total = 0
    correct = 0
    with tf.Session() as sess:
        saver.restore(sess, checkpoint_path)
        for image, answer in zip(x, y):
            batch = image.reshape((1, 24, 60, 1))
            cnn_out = sess.run(cnn, feed_dict={X: batch, keep_prob: 1})
            # cnn_out[0] holds one row of class scores per output position;
            # take the best class for each.
            predict_vector = np.argmax(cnn_out[0], 1)
            predict = ''.join(str(c) for c in predict_vector)
            is_correct = predict == answer
            print('预测:%s,答案:%s,判定:%s' % (predict, answer, "√" if is_correct else "×"))
            total += 1
            if is_correct:
                correct += 1
    print("总数:%d,正确:%d,错误:%d" % (total, correct, total - correct))
if __name__ == '__main__':
    # Training: uncomment the next line to train and save a model first.
    # train()
    # Testing: evaluate the most recent checkpoint on the test images.
    test()
训练集以及测试集:百度云下载
训练视频下载:链接: https://pan.baidu.com/s/1f0-_to6ynmfC-hsb0xnqMA 提取码: 4r7m (下载后观看更清晰)