# -*- coding: UTF-8 -*-

import os
import glob
import time
import numpy as np
import tensorflow as tf
import cv2
from skimage import io, transform
from tensorflow.python.framework import graph_util

# Maps each card-class folder name to its integer class id (0-53).
# Names are "<suit prefix><rank>": suits in pinyin (hongtao = hearts,
# heitao = spades, fangkuai = diamonds, yinghua presumably clubs), ranks
# 2..10, J, Q, K, A.  Ids run suit by suit in that rank order (13 per suit),
# followed by the two jokers (dawang / xiaowang).
strIndexDict = {
    suit + rank: suit_no * 13 + rank_no
    for suit_no, suit in enumerate(('1hongtao', '2heitao', '3fangkuai', '4yinghua'))
    for rank_no, rank in enumerate(('2', '3', '4', '5', '6', '7', '8', '9', '10',
                                    'J', 'Q', 'K', 'A'))
}
strIndexDict.update({'5dawang': 52, '5xiaowang': 53})


def convert_rgbtogray(img):
    """Convert a BGR colour image into a single-channel grey image.

    Uses the luma weights 0.299 R + 0.587 G + 0.114 B.  Channels are read in
    OpenCV's BGR order: channel 2 is red, 1 is green and 0 is blue.

    Args:
        img: array of shape (rows, cols, channels) with at least 3 channels.

    Returns:
        float64 array of shape (rows, cols).

    Raises:
        ValueError: if ``img`` is not a 3-D array with at least 3 channels.
    """
    # `img.shape` is a tuple attribute; the old `img.shape()` call always
    # raised TypeError.  Converting to float64 up front matches the float64
    # buffer the per-pixel loop used to fill.
    pixels = np.asarray(img, dtype=np.float64)
    if pixels.ndim != 3 or pixels.shape[2] < 3:
        raise ValueError("expected a (rows, cols, >=3 channels) image, got shape %r"
                         % (pixels.shape,))

    red = pixels[:, :, 2]
    green = pixels[:, :, 1]
    blue = pixels[:, :, 0]
    # Vectorised: one NumPy expression instead of a Python loop per pixel.
    return red * 0.299 + green * 0.587 + blue * 0.114

# Read images
def read_img(path, w, h):
    """Load every ``*.jpg`` in the class sub-folders of *path* as grey images.

    Each immediate sub-directory of *path* must be named after a key of
    ``strIndexDict``; that key's id becomes the label of every image inside
    it.  Images are loaded in greyscale, resized to ``w`` x ``h`` when they
    are not already that size, and scaled from 0..255 to 0..1.

    Args:
        path: root directory holding one sub-directory per card class.
        w: target image width in pixels.
        h: target image height in pixels.

    Returns:
        (imgs, labels): float32 array of shape (N, h, w) and float32 array of
        shape (N,) holding the class id of each image.

    Raises:
        KeyError: if a sub-directory name is not a key of ``strIndexDict``.
    """
    imgs = []
    labels = []
    print('Start read the image ...')
    # os.path.join works whether or not *path* ends with a separator; the
    # folder name itself is the strIndexDict key.
    for name in sorted(os.listdir(path)):
        folder = os.path.join(path, name)
        if not os.path.isdir(folder):
            continue
        idx = strIndexDict[name]
        for im in sorted(glob.glob(os.path.join(folder, '*.jpg'))):
            img = cv2.imread(im, cv2.IMREAD_GRAYSCALE)
            if img is None:
                # cv2.imread signals an unreadable file by returning None.
                print('Skipping unreadable image: %s' % im)
                continue
            if img.shape != (h, w):
                # cv2.resize takes the target size as (width, height).
                img = cv2.resize(img, (w, h))
            # Previously the image was loaded but never stored, so this
            # function always returned empty arrays.
            imgs.append(img * (1.0 / 255))
            labels.append(idx)
    print('Finished ...')
    return np.asarray(imgs, np.float32), np.asarray(labels, np.float32)

def read_img_data(filePathName, num_pixels=1120):
    """Load samples from a comma-separated text file.

    Each non-blank line is one sample.  Column 0 is read as the class id and
    columns 1..num_pixels as the pixel values; any further columns are
    ignored.  NOTE(review): this layout is inferred from the column filter
    the original code applied (skip column 0 and columns > 1120). Confirm it
    against whatever writes the file.  Pixel values are returned as stored
    (no 1/255 scaling is applied here).

    Args:
        filePathName: path of the text file.
        num_pixels: number of pixel columns per sample (56 * 20 by default).

    Returns:
        (imgs, labels): float32 array of shape (N, num_pixels) and int32
        array of shape (N,).

    Raises:
        ValueError: if a line has fewer than ``num_pixels`` pixel columns or
            holds a non-numeric field.
    """
    imgs = []
    labels = []

    # `with` guarantees the file is closed; the old version leaked the handle.
    with open(filePathName) as data_file:
        for line_no, raw_line in enumerate(data_file, 1):
            line = raw_line.strip()
            if not line:
                continue  # tolerate blank / trailing lines
            fields = line.split(",")
            pixels = fields[1:num_pixels + 1]
            if len(pixels) != num_pixels:
                raise ValueError("%s:%d: expected %d pixel values, got %d"
                                 % (filePathName, line_no, num_pixels, len(pixels)))
            # float() first so labels written as e.g. "3.0" are accepted.
            labels.append(int(float(fields[0])))
            imgs.append([float(d) for d in pixels])

    return np.asarray(imgs, np.float32), np.asarray(labels, np.int32)

def get_img_paths(path):
    """Collect the path and class id of every ``*.jpg`` in *path*'s class folders.

    Each immediate sub-directory of *path* must be named after a key of
    ``strIndexDict``; every image inside it gets that key's id.

    Args:
        path: root directory holding one sub-directory per card class.

    Returns:
        (imgpaths, labels): two parallel lists holding the image file paths
        and the integer class id of each image.

    Raises:
        KeyError: if a sub-directory name is not a key of ``strIndexDict``.
    """
    imgpaths = []
    labels = []
    print('Start read the image ...')
    for name in sorted(os.listdir(path)):
        folder = os.path.join(path, name)
        if not os.path.isdir(folder):
            continue
        idx = strIndexDict[name]
        for im in sorted(glob.glob(os.path.join(folder, '*.jpg'))):
            # Previously this loop had no body, which was a syntax error.
            imgpaths.append(im)
            labels.append(idx)
    print('Finished ...')
    return imgpaths, labels

# Shuffle the sample order
def shuffleOrder(data, label):
    """Shuffle samples and labels with one shared random permutation.

    Uses NumPy's global random state, so ``np.random.seed`` makes the order
    reproducible.

    Args:
        data: array whose first axis indexes samples.
        label: array with the same number of samples as *data*.

    Returns:
        (data, label): shuffled copies in which row i of ``label`` still
        belongs to row i of ``data``.

    Raises:
        ValueError: if *data* and *label* hold different numbers of samples.
    """
    num_example = data.shape[0]
    if len(label) != num_example:
        raise ValueError("data has %d samples but label has %d"
                         % (num_example, len(label)))
    # The old code indexed with np.arange(n) (the identity permutation),
    # so nothing was ever shuffled.
    order = np.random.permutation(num_example)
    return data[order], label[order]

# Split all data into a training set and a validation set
def segmentation(data, label, ratio=0.8):
    """Split samples into a leading training part and a trailing validation part.

    Args:
        data: array whose first axis indexes samples.
        label: labels aligned with *data*.
        ratio: fraction of samples assigned to training; the split index is
            ``int(len(data) * ratio)`` (truncated).

    Returns:
        (x_train, y_train, x_val, y_val).
    """
    num_example = data.shape[0]
    # np.int was removed in NumPy 1.24; the builtin int truncates identically.
    split = int(num_example * ratio)
    return data[:split], label[:split], data[split:], label[split:]

# Yield the data in fixed-size mini-batches
def minibatches(inputs=None, targets=None, batch_size=None, shuffle=False):
    """Yield (inputs, targets) mini-batches of exactly *batch_size* samples.

    A trailing partial batch (``len(inputs) % batch_size`` samples) is
    dropped.  Each inputs batch is reshaped to (batch_size, 20*56*1), so
    every sample must contain exactly 1120 values.

    Args:
        inputs: NumPy array of samples (first axis indexes samples).
        targets: labels aligned with *inputs*.
        batch_size: number of samples per batch.
        shuffle: if true, visit the samples in a random order (NumPy's global
            random state); otherwise in their stored order.

    Yields:
        (inputs_batch, targets_batch).

    Raises:
        ValueError: if *inputs* and *targets* differ in length.
    """
    # Explicit check instead of `assert`, which `python -O` strips.
    if len(inputs) != len(targets):
        raise ValueError("inputs and targets must have the same length")
    if shuffle:
        indices = np.arange(len(inputs))
        # Previously the indices were created but never shuffled.
        np.random.shuffle(indices)
    for start_idx in range(0, len(inputs) - batch_size + 1, batch_size):
        # The missing `else` used to overwrite the shuffled excerpt with a
        # plain slice, and left `excerpt` undefined when shuffle was False.
        if shuffle:
            excerpt = indices[start_idx:start_idx + batch_size]
        else:
            excerpt = slice(start_idx, start_idx + batch_size)

        targets2 = targets[excerpt]
        inputs2 = inputs[excerpt].reshape(batch_size, 20 * 56 * 1)

        yield inputs2, targets2

def reshapebatchs(inputs, labels, num_classes=54):
    """Flatten a batch of images and one-hot encode its labels.

    Args:
        inputs: NumPy array holding a multiple of 1120 (= 20*56*1) values.
        labels: class id per sample.
        num_classes: width of the one-hot rows (54 card classes by default).

    Returns:
        (inputs_, labels_): ``inputs`` reshaped to (-1, 1120), and an int32
        array of shape (len(labels), num_classes) where row i holds 1 in
        column ``labels[i]`` and 0 elsewhere.  A label that equals none of
        0..num_classes-1 gives an all-zero row.
    """
    inputs_ = inputs.reshape((-1, 20 * 56 * 1))

    # Broadcast-compare each label against every class id.  This matches the
    # per-element `labels[i] == idx` test of the old nested loop (whose body
    # was missing, a syntax error) and works for float labels such as those
    # read_img returns.
    label_col = np.asarray(labels).reshape(-1, 1)
    labels_ = (label_col == np.arange(num_classes)).astype(np.int32)

    return inputs_, labels_

# weight initialization
def weight_variable(shape):
    """Return a trainable weight tensor of *shape*.

    Initial values come from a truncated normal distribution with stddev 0.1
    and a fixed op-level seed of 1.
    """
    return tf.Variable(tf.truncated_normal(shape, stddev=0.1, seed=1))
def bias_variable(shape):
    """Return a trainable bias tensor of *shape* with every element set to 0.1."""
    return tf.Variable(tf.constant(0.1, shape=shape))

# convolution
def conv2d(x, W):
    """Convolve *x* with filter *W* using stride 1 everywhere and 'SAME' padding."""
    unit_strides = [1, 1, 1, 1]
    return tf.nn.conv2d(x, W, strides=unit_strides, padding='SAME')
# pooling
def max_pool_2x2(x):
    """Max-pool *x* over 2x2 windows with stride 2 and 'SAME' padding.

    Window and stride act on dimensions 1 and 2 (height and width in
    TensorFlow's default NHWC layout), halving each and rounding up.
    """
    window = [1, 2, 2, 1]
    return tf.nn.max_pool(x, ksize=window, strides=window, padding='SAME')

def loss(logits, label_batches):
    """Return the batch-mean sparse softmax cross-entropy.

    Args:
        logits: unscaled class scores, one row per sample.
        label_batches: integer class id per sample.
    """
    per_sample = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=label_batches, logits=logits)
    return tf.reduce_mean(per_sample)

def startTrain():
    """Build the 54-class card CNN, train it on ./charSamples/ and export it.

    Network: 1120-value input reshaped to 56x20x1 -> conv 5x5 (32) + ReLU ->
    2x2 max-pool -> conv 5x5 (64) + ReLU -> 2x2 max-pool -> fully connected
    1024 + ReLU -> dropout -> fully connected 54 -> softmax.  Graph nodes
    looked up by name at inference time: "input", "keep_prob", "softmax",
    "output".

    Training resumes from the newest checkpoint in model/ (the step counter
    continues from the number after the last "-" in its name).  A checkpoint
    is saved to model/poker-<step> whenever the step is a multiple of 10.
    The function then prints the validation accuracy and writes the graph,
    with its variables folded into constants (output node "softmax"), to
    model/poker.pb.

    Raises:
        ValueError: if the training split holds fewer samples than one batch.
    """
    # ---- model ----
    x = tf.placeholder("float", [None, 1120], name='input')
    y_ = tf.placeholder("float", [None, 54])  # target distribution, one column per class

    # first convolutional layer
    w_conv1 = weight_variable([5, 5, 1, 32])
    b_conv1 = bias_variable([32])
    x_image = tf.reshape(x, [-1, 56, 20, 1])
    h_conv1 = tf.nn.relu(conv2d(x_image, w_conv1) + b_conv1)
    h_pool1 = max_pool_2x2(h_conv1)  # 56x20 -> 28x10

    # second convolutional layer
    w_conv2 = weight_variable([5, 5, 32, 64])
    b_conv2 = bias_variable([64])
    h_conv2 = tf.nn.relu(conv2d(h_pool1, w_conv2) + b_conv2)
    h_pool2 = max_pool_2x2(h_conv2)  # 28x10 -> 14x5

    # densely connected layer
    w_fc1 = weight_variable([14 * 5 * 64, 1024])
    b_fc1 = bias_variable([1024])
    h_pool2_flat = tf.reshape(h_pool2, [-1, 14 * 5 * 64])
    h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, w_fc1) + b_fc1)

    # dropout
    keep_prob = tf.placeholder("float", name="keep_prob")
    h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)

    # readout layer
    w_fc2 = weight_variable([1024, 54])
    b_fc2 = bias_variable([54])
    logits = tf.matmul(h_fc1_drop, w_fc2) + b_fc2
    y_conv = tf.nn.softmax(logits, name="softmax")

    # Same quantity as -sum(y_ * log(softmax)), but computed from the logits,
    # so a saturated softmax cannot produce log(0) = -inf and NaN gradients.
    cross_entropy = tf.reduce_sum(
        tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=logits))
    train_step = tf.train.GradientDescentOptimizer(1e-4).minimize(cross_entropy)
    argmaxy = tf.argmax(y_conv, 1, name="output")
    argmaxy_ = tf.argmax(y_, 1)
    correct_prediction = tf.equal(argmaxy, argmaxy_)
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))

    # Variables are created in the same order as before, so existing
    # checkpoints keep matching this saver.
    saver = tf.train.Saver(max_to_keep=100)

    # ---- data ----
    imgpath = './charSamples/'
    w = 20
    h = 56
    ratio = 0.8  # fraction of samples used for training
    data, label = read_img(path=imgpath, w=w, h=h)
    data, label = shuffleOrder(data=data, label=label)
    x_train, y_train, x_val, y_val = segmentation(data=data, label=label, ratio=ratio)
    n_epoch = 200
    batch_size = 128
    if len(x_train) < batch_size:
        # Otherwise no batch is produced and the per-epoch report has no data.
        raise ValueError("need at least %d training samples, found %d"
                         % (batch_size, len(x_train)))

    sess = tf.InteractiveSession()
    try:
        # Variables were previously never initialised.  Initialise first,
        # then let a checkpoint (if any) overwrite them.  Restoring through
        # `saver`, which is built on this graph, loads the weights into the
        # variables train_step updates.  import_meta_graph would instead add
        # a second, unused copy of the model.
        sess.run(tf.global_variables_initializer())
        latestStep = 0
        latestPath = tf.train.latest_checkpoint('model/')
        if latestPath is not None:
            latestStep = int(latestPath[latestPath.rfind("-") + 1:])
            latestMeta = latestPath + ".meta"
            print(" latestPath %s, latestStep %d, latestMeta %s" % (latestPath, latestStep, latestMeta))
            saver.restore(sess, latestPath)

        # Saver.save refuses to write into a missing directory.
        os.makedirs('model', exist_ok=True)

        for epoch in range(n_epoch):
            step = latestStep + epoch + 1
            # time.clock was removed in Python 3.8.
            t_start = time.perf_counter()

            for batch0_, batch1_ in minibatches(x_train, y_train, batch_size, shuffle=True):
                batch0, batch1 = reshapebatchs(batch0_, batch1_)
                train_step.run(feed_dict={x: batch0, y_: batch1, keep_prob: 0.5})

            # Reported accuracy and loss cover the last mini-batch of the epoch only.
            train_accuracy, losses = sess.run(
                [accuracy, cross_entropy],
                feed_dict={x: batch0, y_: batch1, keep_prob: 1.0})

            t_spent = time.perf_counter() - t_start

            print("step %d, train accuracy %g, loss: %g, spent %f" % (step, train_accuracy, losses, t_spent))

            if step % 10 == 0:
                saver.save(sess, "model/poker", global_step=step)
                print("save model")

        x_val_, y_val_ = reshapebatchs(x_val, y_val)
        print("test accuracy %g" % accuracy.eval(feed_dict={x: x_val_, y_: y_val_, keep_prob: 1.0}))

        constant_graph = graph_util.convert_variables_to_constants(sess, sess.graph_def, ["softmax"])
        with tf.gfile.FastGFile("model/poker.pb", mode='wb') as f:
            # Previously this `with` had no body, so the frozen graph was never written.
            f.write(constant_graph.SerializeToString())
    finally:
        sess.close()

def predict2(img):
    """Classify one card image with the newest checkpoint found in model/.

    Args:
        img: array-like holding exactly 1120 values (a 56x20 grey image) in
            0..255.  It is not modified.

    Returns:
        Array of shape (1,) holding the predicted class id (argmax of the
        "softmax" output).

    Raises:
        FileNotFoundError: if model/ contains no checkpoint.
        ValueError: if *img* does not hold exactly 1120 values.
    """
    # Copy and reshape instead of the old in-place ndarray.resize.  That call
    # mutated the caller's array, raised when the array was referenced
    # elsewhere, and silently padded or truncated arrays of the wrong size.
    img = np.asarray(img, dtype=np.float32).reshape((1, 1120)) * (1.0 / 255)

    latestPath = tf.train.latest_checkpoint('model/')
    if latestPath is None:
        raise FileNotFoundError("no checkpoint found in model/")
    latestMeta = latestPath + ".meta"
    print(" latestPath %s, latestMeta %s" % (latestPath, latestMeta))

    # Import into a private graph.  Importing into the shared default graph
    # adds another copy of the model on every call, so name lookups such as
    # "input:0" can resolve to a copy this session never restored.
    graph = tf.Graph()
    with graph.as_default():
        saver_r = tf.train.import_meta_graph(latestMeta)

    # `with` closes the session; it used to be leaked on every call.
    with tf.Session(graph=graph) as sess:
        saver_r.restore(sess, latestPath)
        # Look up the placeholders and output by the names the training graph gave them.
        x = graph.get_tensor_by_name("input:0")
        keep_prob = graph.get_tensor_by_name("keep_prob:0")
        out_softmax = graph.get_tensor_by_name("softmax:0")

        img_out_softmax = sess.run(out_softmax, feed_dict={x: img, keep_prob: 1.0})
    print("img_out_softmax:", img_out_softmax)

    prediction_labels = np.argmax(img_out_softmax, axis=1)
    print("label:", prediction_labels)
    return prediction_labels
