Machine translation: Spanish <---> English

# -*- coding: utf-8 -*-

from __future__ import absolute_import, division, print_function, unicode_literals   # import features of a future Python version into the current one
from sklearn.model_selection import train_test_split
import tensorflow.compat.v1 as tf
import matplotlib.ticker as ticker
import matplotlib.pyplot as plt
import numpy as np
import unicodedata
import re   # regular expressions
### .     matches any character except a newline
### \w    matches a letter, digit, underscore, or CJK character
### \s    matches any whitespace character
### \d    matches a digit
### \b    matches a word boundary (the start or end of a word)
### ^     matches the start of the string
### $     matches the end of the string
### *     repeat zero or more times, greedy
### +     repeat one or more times, greedy
### ?     repeat zero or one time, greedy
### {n}   repeat exactly n times
### {n,}  repeat n or more times
### {n,m} repeat n to m times
### 
### str="匹配s规则这s个字符串是否s匹配f规则则re则则则"
### print(re.findall("^匹配",str)) # matches only if the pattern occurs at the start of the string; the matches are returned as a list
### ['匹配']
###
### str="匹配s规则这s个字符串是否s匹配f规则则re则则则"
### print(re.findall("[^a-z]",str)) # negated class: matches every character that is not a lowercase letter; returns a list
### ['匹', '配', '规', '则', '这', '个', '字', '符', '串', '是', '否', '匹', '配', '规', '则', '则', '则', '则', '则']
### 
### str="匹配s规则这s个字符串是否s匹配f规则则re则则则"
### print(re.findall("则$",str)) # matches 则 only at the end of the string; returns a list
### ['则']
### 
### str="匹配s规则这s个字符串是否s匹配f规则则re则则则"
### print(re.findall("则*",str)) # the character before * may repeat zero or more times; returns a list
### print(re.findall("规则*",str)) # the character before * may repeat zero or more times; returns a list
### ['', '', '', '', '则', '', '', '', '', '', '', '', '', '', '', '', '', '', '则则', '', '', '则则则', '']
### ['规则', '规则则']
### 
### str="匹配s规则这s个字符串是否s匹配f规则则re则则则"
### print(re.findall("则+",str)) # the character before + must repeat one or more times; returns a list
### print(re.findall("规则+",str)) # the character before + must repeat one or more times; returns a list
### ['则', '则则', '则则则']
### ['规则', '规则则']
###
### str="匹配s规则这s个字符串是否s匹配f规则则re则则则"
### print(re.findall("则?",str)) # the character before ? may occur zero or one time; returns a list
### print(re.findall("规则?",str)) # the character before ? may occur zero or one time; returns a list
### ['', '', '', '', '则', '', '', '', '', '', '', '', '', '', '', '', '', '', '则', '则', '', '', '则', '则', '则', '']
### ['规则', '规则']
### 
### str="匹配s规则这s个字符串是否s匹配f规则则re则则则"
### print(re.findall("则{2}",str)) # {2}: the preceding character must repeat exactly 2 times; returns a list
### print(re.findall("规则{1,2}",str)) # {1,2}: the preceding character repeats 1 to 2 times; returns a list
### ['则则', '则则']
### ['规则', '规则则']
### 
### str="匹配s规则这s个字符串是否s匹配f规则则re则则则"
### print(re.findall("匹配[s,f]规则",str)) # the character class [] matches any single character listed inside it; returns a list
### ['匹配s规则', '匹配f规则']
### 
### str="匹配s规则这s个字符串4是否s匹配3f规则则re则则2则"
### print(re.findall("\d",str)) # matches every digit in the string; returns a list
### ['4', '3', '2']
### 
### str="匹配s规则这s个字符串455是否s匹配3f规则则re则则2则"
### print(re.findall("\d+",str)) # matches runs of one or more digits; returns a list
### ['455', '3', '2']
### 
### str="匹配s规则这s个字符串455是否s匹配3f规则则re则则2则"
### print(re.findall("\D",str)) # matches every non-digit character; returns a list
### ['匹', '配', 's', '规', '则', '这', 's', '个', '字', '符', '串', '是', '否', 's', '匹', '配', 'f', '规', '则', '则', 'r', 'e', '则', '则', '则']
### 
### str="匹配s规则这s个字 符 串 \n \t \f \v455是否s匹配3f规则则re则则2则"
### print(re.findall("\s",str)) # matches whitespace characters (space, \t, \n, \r, \f, \v); returns a list
### [' ', ' ', ' ', '\n', ' ', '\t', ' ', '\x0c', ' ', '\x0b']
### 
### str="匹配s规则这s个字 符 串 \n \t \f \v455是"
### print(re.findall("\S",str)) # matches non-whitespace characters; returns a list
### ['匹', '配', 's', '规', '则', '这', 's', '个', '字', '符', '串', '4', '5', '5', '是']
### 
### str="匹配s规则这s个_字 S符 串-455是"
### print(re.findall("\w",str)) # matches word characters: underscores, CJK characters, letters, digits; returns a list
### ['匹', '配', 's', '规', '则', '这', 's', '个', '_', '字', 'S', '符', '串', '4', '5', '5', '是']
### 
### str="匹配s规则这s个_字 S符 串-455是"
### print(re.findall("\W",str)) # matches non-word characters (anything that is not an underscore, CJK character, letter, or digit); returns a list
### [' ', ' ', '-']
### 
### str="a3a3ddd"
### print(re.search("(a3)+",str).group()) # matches one or more repetitions of a3
### a3a3
### 
### str="a3死a3d有dd"
### print(re.findall(r"死|有+",str)) # | matches either the pattern on its left or the one on its right
### ['死', '有']
### 
### str="hello egon bcd egon lge egon acd 19"
### r=re.match("h\w+",str) #match,从起始位置开始匹配,匹配成功返回一个对象,未匹配成功返回None,非字母,汉字,数字及下划线分割
### print(r.group()) # 获取匹配到的所有结果,不管有没有分组将匹配到的全部拿出来
### print(r.groups()) # 获取模型中匹配到的分组结果,只拿出匹配到的字符串中分组部分的结果
### print(r.groupdict())  # 获取模型中匹配到的分组结果,只拿出匹配到的字符串中分组部分定义了key的组结果
### hello
### ()
### {}
###
### r2=re.match("h(\w+)",str) #match,从起始位置开始匹配,匹配成功返回一个对象,未匹配成功返回None
### print(r2.group())
### print(r2.groups())
### print(r2.groupdict())
### hello
### ('ello',)
### {}
### 
### r3=re.match("(?Ph)(?P\w+)",str)  #?P<>定义组里匹配内容的key(键),<>里面写key名称,值就是匹配到的内容
### print(r3.group())
### print(r3.groups())
### print(r3.groupdict())
### hello
### ('h', 'ello')
### {'n1': 'h', 'n2': 'ello'}
### 
### str="hello egon bcd egon lge egon acd 19"
### r=re.search("h\w+",str) #match,从起始位置开始匹配,匹配成功返回一个对象,未匹配成功返回None,非字母,汉字,数字及下划线分割
### print(r.group()) # 获取匹配到的所有结果,不管有没有分组将匹配到的全部拿出来
### print(r.groups()) # 获取模型中匹配到的分组结果,只拿出匹配到的字符串中分组部分的结果
### print(r.groupdict())  # 获取模型中匹配到的分组结果,只拿出匹配到的字符串中分组部分定义了key的组结果
### hello
### ()
### {}
### 
### r2=re.search("h(\w+)",str) #match,从起始位置开始匹配,匹配成功返回一个对象,未匹配成功返回None
### print(r2.group())
### print(r2.groups())
### print(r2.groupdict())
### hello
### ('ello',)
### {}
### 
### r3=re.search("(?Ph)(?P\w+)",str)  #?P<>定义组里匹配内容的key(键),<>里面写key名称,值就是匹配到的内容
### print(r3.group())
### print(r3.groups())
### print(r3.groupdict())
### hello
### ('h', 'ello')
### {'n1': 'h', 'n2': 'ello'}
### 
### r=re.findall("\d+\w\d+","a2b3c4d5") # scan the whole string and put every non-overlapping match into a list
### print(r)
### ['2b3', '4d5'] # characters already consumed by a match are not reused, so '3c4' also fits the pattern but is not returned
### 
### r=re.findall("","a2b3c4d5") # scan the whole string and put every match into a list
### print(r)
### ['', '', '', '', '', '', '', '', ''] # an empty pattern matches at every position, so the result is one empty string per position, i.e. len(string)+1 items: 8 characters give 9 empty strings
### 
### r=re.findall("(ca)*","ca2b3caa4d5") # scan the whole string and put every match into a list
### print(r)
### ['ca', '', '', '', 'ca', '', '', '', '', ''] # * also produces empty matches
### 
### r=re.findall("a\w+","ca2b3 caa4d5") # scan the whole string and put every match into a list
### print(r)
### ['a2b3', 'aa4d5'] # every match of the pattern ends up in the list
### 
### r=re.findall("a(\w+)","ca2b3 caa4d5") # with one group: only the group part of each match is returned in the list
### print(r)
### ['2b3', 'a4d5'] # the list contains the captured group contents
### 
### r=re.findall("(a)(\w+)","ca2b3 caa4d5") # with several groups: the groups of each match are collected into a tuple, and the tuples are returned in a list
### print(r)
### [('a', '2b3'), ('a', 'a4d5')] # the result is a list of tuples
### 
### r=re.findall("(a)(\w+(b))","ca2b3 caa4b5") # nested groups: the outer group (which contains the inner one) and the inner group each contribute an element to the tuple; all tuples are returned in a list
### print(r)
### [('a', '2b', 'b'), ('a', 'a4b', 'b')] # the result is a list of tuples
### 
### r=re.findall("a(?:\w+)","a2b3 a4b5 edd") # (?:...) is a non-capturing group: findall returns the whole match instead of only the group; ?: only matters for functions like findall() that do not return a match object
### print(r)
### ['a2b3', 'a4b5']
### 
### r=re.split("a\w","sdfadfdfadsfsfafsff")
### print(r)
### r2=re.split("a\w","sdfadfdfadsfsfafsff",maxsplit=2)
### print(r2)
### ['sdf', 'fdf', 'sfsf', 'sff']
### ['sdf', 'fdf', 'sfsfafsff']
### 
### r=re.sub("a\w","替换","sdfadfdfadsfsfafsff")
### print(r)
### sdf替换fdf替换sfsf替换sff
### 
### a,b=re.subn("a\w","替换","sdfadfdfadsfsfafsff") # like sub(), but also returns how many replacements were made; the two results can be unpacked into separate variables
### print(a) # the string after replacement
### print(b) # the number of replacements
### sdf替换fdf替换sfsf替换sff
### 3

import os
import io
import time

tf.disable_v2_behavior()
tf.enable_eager_execution()

### Download the dataset
path_to_zip = tf.keras.utils.get_file('spa-eng.zip', origin='https://storage.googleapis.com/download.tensorflow.org/data/spa-eng.zip',extract=True)
path_to_file = os.path.dirname(path_to_zip)+"/spa-eng/spa.txt"

### Character conversion
def unicode_to_ascii(s):
    return ''.join(c for c in unicodedata.normalize('NFD', s) if unicodedata.category(c) != 'Mn')
    ### normalize(): the first argument selects the normalization form. NFC composes each character into a single code point, while NFD decomposes characters into base characters plus combining marks.
    ### s1 = 'Spicy Jalape\u00f1o'
    ### s2 = 'Spicy Jalapen\u0303o'
    ### s1
    ### 'Spicy Jalapeño'
    ### s2
    ### 'Spicy Jalapeño'
    ### s1 == s2
    ### False
    ### len(s1)
    ### 14
    ### len(s2)
    ### 15
    ### t1 = unicodedata.normalize('NFC', s1)
    ### t2 = unicodedata.normalize('NFC', s2)
    ### t1 == t2
    ### True
    ### print(ascii(t1))
    ### 'Spicy Jalape\xf1o'
    ### t3 = unicodedata.normalize('NFD', s1)
    ### t4 = unicodedata.normalize('NFD', s2)
    ### t3 == t4
    ### True
    ### print(ascii(t3))
    ### 'Spicy Jalapen\u0303o'
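    ###
    ### A quick sanity check of unicode_to_ascii itself (hypothetical input, not taken from the dataset):
    ### print(unicode_to_ascii('Él está aquí'))
    ### El esta aqui   # the combining accent marks (Unicode category 'Mn') left by NFD decomposition are dropped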

def preprocess_sentence(w):
    w = unicode_to_ascii(w.lower().strip())
    w = re.sub(r"([?.!,?])", r" \1 ", w)   # 在标点和单词间增加空格
    w = re.sub(r'[" "]+', " ", w)
    w = re.sub(r"[^a-zA-Z?.!,?]+", " ", w)   # 除指定字符外,其他都用空格替换
    w = w.rstrip().strip()   # 删除末尾空格,开头空格
    w = ' ' + w + ' '
    return w

en_sentence = u"May I borrow this book?"
sp_sentence = u"?Puedo tomar prestado este libro?"
print("英文:    ",preprocess_sentence(en_sentence))
print("西班牙文:",preprocess_sentence(sp_sentence).encode('utf-8'))

### Return the cleaned sentence pairs
def create_dataset(path, num_examples):
    lines = io.open(path, encoding='UTF-8').read().strip().split('\n')
    word_pairs = [[preprocess_sentence(w) for w in l.split('\t')]  for l in lines[:num_examples]]
    return zip(*word_pairs)

en, sp = create_dataset(path_to_file, None)
print("英文:    ",en[-1])
print("西班牙文:",sp[-1])

def max_length(tensor):
    return max(len(t) for t in tensor)

def tokenize(lang):
    lang_tokenizer = tf.keras.preprocessing.text.Tokenizer(filters='')
    lang_tokenizer.fit_on_texts(lang)
    tensor = lang_tokenizer.texts_to_sequences(lang)
    tensor = tf.keras.preprocessing.sequence.pad_sequences(tensor,padding='post')
    return tensor, lang_tokenizer
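### Rough illustration of what tokenize() produces on a toy corpus (hypothetical values; the exact ids
### depend on word frequency and order in the data):
### tensor, tok = tokenize(['<start> hola . <end>', '<start> hola amigo . <end>'])
### tok.word_index   maps each word to an integer id, e.g. something like {'<start>': 1, 'hola': 2, ...}
### tensor           is an integer matrix; shorter rows are padded with 0 at the end (padding='post')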

def load_dataset(path, num_examples=None):
    targ_lang, inp_lang = create_dataset(path, num_examples)
    input_tensor, inp_lang_tokenizer = tokenize(inp_lang)
    target_tensor, targ_lang_tokenizer = tokenize(targ_lang)
    return input_tensor, target_tensor, inp_lang_tokenizer, targ_lang_tokenizer

num_examples = 30000
input_tensor, target_tensor, inp_lang, targ_lang = load_dataset(path_to_file, num_examples)
max_length_targ, max_length_inp = max_length(target_tensor), max_length(input_tensor)
input_tensor_train, input_tensor_val, target_tensor_train, target_tensor_val = train_test_split(input_tensor, target_tensor, test_size=0.2)

print("训练集长度:",len(input_tensor_train), "训练集值长度:",len(target_tensor_train), "验证集长度:",len(input_tensor_val), "验证值集长度:",len(target_tensor_val))

def convert(lang, tensor):
    for t in tensor:
        if t!=0:
            print ("%s ---------> %s" % (fixedlen(str(t)), lang.index_word[t]))

def fixedlen(inputstr):
    if len(inputstr)<10:
        inputstr=" "*(10-len(inputstr))+inputstr
    return inputstr

print ("输入语种,索引与单词映射")
convert(inp_lang,   input_tensor_train[0])
print ()
print ("输出语种,索引与单词映射")
convert(targ_lang, target_tensor_train[0])

BUFFER_SIZE = len(input_tensor_train)
BATCH_SIZE = 64
steps_per_epoch = len(input_tensor_train)//BATCH_SIZE
embedding_dim = 256
units = 1024
vocab_inp_size = len(inp_lang.word_index)+1
vocab_tar_size = len(targ_lang.word_index)+1

dataset = tf.data.Dataset.from_tensor_slices((input_tensor_train, target_tensor_train)).shuffle(BUFFER_SIZE)
dataset = dataset.batch(BATCH_SIZE, drop_remainder=True)
example_input_batch, example_target_batch = next(iter(dataset))
example_input_batch.shape, example_target_batch.shape
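### The tf.data pipeline above shuffles the training pairs and yields batches of shape
### (BATCH_SIZE, max_length_inp) and (BATCH_SIZE, max_length_targ); drop_remainder=True discards the
### final partial batch so every batch has exactly BATCH_SIZE rows.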

class Encoder(tf.keras.Model):
    def __init__(self, vocab_size, embedding_dim, enc_units, batch_sz):
        super(Encoder, self).__init__()
        self.batch_sz = batch_sz
        self.enc_units = enc_units
        self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)
        self.gru = tf.keras.layers.GRU(self.enc_units,return_sequences=True,return_state=True,recurrent_initializer='glorot_uniform')

    def call(self, x, hidden):
        x = self.embedding(x)
        output, state = self.gru(x, initial_state = hidden)
        return output, state

    def initialize_hidden_state(self):
        return tf.zeros((self.batch_sz, self.enc_units))
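### Shape walkthrough for Encoder.call (a sketch based on the layer settings above):
###   x: (batch_sz, T)  -> Embedding -> (batch_sz, T, embedding_dim)
###   GRU output: (batch_sz, T, enc_units)   because return_sequences=True
###   GRU state:  (batch_sz, enc_units)      because return_state=True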

class BahdanauAttention(tf.keras.layers.Layer):
    def __init__(self, units):
        super(BahdanauAttention, self).__init__()
        self.W1 = tf.keras.layers.Dense(units)
        self.W2 = tf.keras.layers.Dense(units)
        self.V  = tf.keras.layers.Dense(1)

    def call(self, query, values):
        hidden_with_time_axis = tf.expand_dims(query, 1)
        score = self.V(tf.nn.tanh(self.W1(values) + self.W2(hidden_with_time_axis)))
        attention_weights = tf.nn.softmax(score, axis=1)
        context_vector = attention_weights * values
        context_vector = tf.reduce_sum(context_vector, axis=1)
        return context_vector, attention_weights
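### Bahdanau (additive) attention as computed in call() above, with values = encoder outputs of shape
### (batch, T, enc_units) and query = decoder hidden state of shape (batch, enc_units):
###   score             = V(tanh(W1(values) + W2(query)))           -> (batch, T, 1)
###   attention_weights = softmax(score, axis=1)                    -> (batch, T, 1)
###   context_vector    = sum over T of attention_weights * values  -> (batch, enc_units)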

class Decoder(tf.keras.Model):
    def __init__(self, vocab_size, embedding_dim, dec_units, batch_sz):
        super(Decoder, self).__init__()
        self.batch_sz = batch_sz
        self.dec_units = dec_units
        self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)
        self.gru = tf.keras.layers.GRU(self.dec_units,return_sequences=True,return_state=True,recurrent_initializer='glorot_uniform')
        self.fc = tf.keras.layers.Dense(vocab_size)
        self.attention = BahdanauAttention(self.dec_units)   # use Bahdanau attention over the encoder outputs

    def call(self, x, hidden, enc_output):
        context_vector, attention_weights = self.attention(hidden, enc_output)
        x = self.embedding(x)
        x = tf.concat([tf.expand_dims(context_vector, 1), x], axis=-1)
        output, state = self.gru(x)
        output = tf.reshape(output, (-1, output.shape[2]))
        x = self.fc(output)
        return x, state, attention_weights
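### Decoder.call handles a single target time step:
###   x: (batch, 1) -> Embedding -> (batch, 1, embedding_dim)
###   concatenated with the context vector -> (batch, 1, embedding_dim + enc_units)
###   GRU -> (batch, 1, dec_units) -> reshape -> (batch, dec_units) -> Dense -> (batch, vocab_size) logits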

def loss_function(real, pred):
    mask   = tf.math.logical_not(tf.math.equal(real, 0))
    loss_  = loss_object(real, pred)
    mask   = tf.cast(mask, dtype=loss_.dtype)
    loss_ *= mask
    return tf.reduce_mean(loss_)
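### loss_function masks padded positions (token id 0) so that padding does not contribute to the loss.
### Hypothetical example: real = [5, 3, 0, 0] gives mask = [1, 1, 0, 0]; the per-token losses at the two
### padded positions are zeroed out before tf.reduce_mean is taken.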

@tf.function   # the decorator traces the Python function into a TensorFlow graph
def train_step(inp, targ, enc_hidden):
    loss = 0
    with tf.GradientTape() as tape:
        enc_output, enc_hidden = encoder(inp, enc_hidden)
        dec_hidden = enc_hidden
        dec_input = tf.expand_dims([targ_lang.word_index['<start>']] * BATCH_SIZE, 1)
        for t in range(1, targ.shape[1]):
            predictions, dec_hidden, _ = decoder(dec_input, dec_hidden, enc_output)
            loss += loss_function(targ[:, t], predictions)
            dec_input = tf.expand_dims(targ[:, t], 1)
    batch_loss = (loss / int(targ.shape[1]))
    variables = encoder.trainable_variables + decoder.trainable_variables
    gradients = tape.gradient(loss, variables)
    optimizer.apply_gradients(zip(gradients, variables))
    return batch_loss
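### train_step uses teacher forcing: at every step t the ground-truth target word targ[:, t] is fed as the
### next decoder input instead of the model's own prediction, and the loss is averaged over the target length.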

def evaluate(sentence):
    attention_plot = np.zeros((max_length_targ, max_length_inp))
    sentence = preprocess_sentence(sentence)
    inputs = [inp_lang.word_index[i] for i in sentence.split(' ')]
    inputs = tf.keras.preprocessing.sequence.pad_sequences([inputs],maxlen=max_length_inp,padding='post')
    inputs = tf.convert_to_tensor(inputs)
    result = ''
    hidden = [tf.zeros((1, units))]
    enc_out, enc_hidden = encoder(inputs, hidden)
    dec_hidden = enc_hidden
    dec_input = tf.expand_dims([targ_lang.word_index['<start>']], 0)
    for t in range(max_length_targ):
        predictions, dec_hidden, attention_weights = decoder(dec_input,dec_hidden,enc_out)
        attention_weights = tf.reshape(attention_weights, (-1, ))
        attention_plot[t] = attention_weights.numpy()
        predicted_id = tf.argmax(predictions[0]).numpy()
        result += targ_lang.index_word[predicted_id] + ' '
        if targ_lang.index_word[predicted_id] == '<end>':
            return result, sentence, attention_plot
        dec_input = tf.expand_dims([predicted_id], 0)
    return result, sentence, attention_plot
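### evaluate() decodes greedily: at each step the argmax token is chosen, appended to the result and fed back
### as the next decoder input, stopping at '<end>' or after max_length_targ steps; the attention weights of
### every step are stored in attention_plot for visualization.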

def plot_attention(attention, sentence, predicted_sentence):
    fig = plt.figure(figsize=(10,10))
    ax = fig.add_subplot(1, 1, 1)
    ax.matshow(attention, cmap='viridis')
    fontdict = {'fontsize': 14}
    ax.set_xticklabels([''] + sentence, fontdict=fontdict, rotation=90)
    ax.set_yticklabels([''] + predicted_sentence, fontdict=fontdict)
    ax.xaxis.set_major_locator(ticker.MultipleLocator(1))
    ax.yaxis.set_major_locator(ticker.MultipleLocator(1))
    plt.show()

def translate(sentence):
    result, sentence, attention_plot = evaluate(sentence)
    print('Input: %s' % (sentence))
    print('Translation: {}'.format(result))
    attention_plot = attention_plot[:len(result.split(' ')), :len(sentence.split(' '))]
    plot_attention(attention_plot, sentence.split(' '), result.split(' '))

encoder = Encoder(vocab_inp_size, embedding_dim, units, BATCH_SIZE)
sample_hidden = encoder.initialize_hidden_state()
sample_output, sample_hidden = encoder(example_input_batch, sample_hidden)
print ('Encoder output shape: (batch size, sequence length, units) {}'.format(sample_output.shape))
print ('Encoder hidden state shape: (batch size, units) {}'.format(sample_hidden.shape))

attention_layer = BahdanauAttention(10)
attention_result, attention_weights = attention_layer(sample_hidden, sample_output)
print("注意力结果结构:(批大小,单元) {}".format(attention_result.shape))
print("注意力权重结构:(批大小,序列长度, 1) {}".format(attention_weights.shape))

decoder = Decoder(vocab_tar_size, embedding_dim, units, BATCH_SIZE)
sample_decoder_output, _, _ = decoder(tf.random.uniform((64, 1)),sample_hidden, sample_output)
print ('Decoder output shape: (batch size, vocab size) {}'.format(sample_decoder_output.shape))

optimizer = tf.keras.optimizers.Adam()
loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction='none')

checkpoint_dir   = './machinetranslation.checkpoints'
checkpoint_prefix= os.path.join(checkpoint_dir, "ckpt")
checkpoint       = tf.train.Checkpoint(optimizer=optimizer,encoder=encoder,decoder=decoder)
checkpointmanager= tf.train.CheckpointManager(checkpoint, directory=checkpoint_dir, checkpoint_name='ckpt', max_to_keep=1)
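### The CheckpointManager with max_to_keep=1 keeps only the most recent checkpoint under
### ./machinetranslation.checkpoints; checkpointmanager.save() below is only called when the epoch loss improves.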

if os.path.isdir(checkpoint_dir) and os.listdir(checkpoint_dir):
    print("-------------------- Loading previously trained model --------------------")
    checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))

TrainYES=False
time.sleep(60)
STRpredicted=""
INTinitloss=100
epoch=0
while True:
    if not TrainYES:
        os.system("clear")
        print("机器翻译:西班牙文 <<<--->>> 英文")
        print("        Train..............训练数据",STRpredicted)
        print("        Translate..........翻译语句")
        print("        Quit...............退出系统")
        STRinput=input("        >>>>>>>>请输入选择项:")
        STRinput=STRinput.upper()                                                   # 将输入项转换为大写
    if STRinput=="TRAIN":
        STRstarttime=">>>>>>>>>> Epoch start time: "+time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
        print(STRstarttime)
        start = time.time()
        enc_hidden = encoder.initialize_hidden_state()
        total_loss = 0
        for (batch, (inp, targ)) in enumerate(dataset.take(steps_per_epoch)):
            batch_loss = train_step(inp, targ, enc_hidden)
            total_loss += batch_loss
            if batch % 10 == 0:
                print('Epoch: {:>4d}, batch: {:>10d}, loss: {:.10f}'.format(epoch + 1,batch,batch_loss.numpy()))
        if total_loss / steps_per_epoch < INTinitloss:
            INTinitloss=total_loss / steps_per_epoch
            checkpointmanager.save()
            ### checkpoint.save(file_prefix = checkpoint_prefix)
        STRendtime="<<<<<<<<<< Epoch end time: "+time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
        print(STRendtime)
        print('Epoch: {:>4d}, loss: {:.10f}, elapsed: {:.10f} s\n'.format(epoch + 1, total_loss / steps_per_epoch, time.time() - start ))
        STRpredicted="..... epoch: "+str(epoch + 1)+", loss: "+str(total_loss / steps_per_epoch)+", elapsed: "+str(time.time() - start)+" s"
        epoch+=1
        if epoch < 5:
            TrainYES=True
        else:
            TrainYES=False
    elif STRinput=="TRANS" or STRinput=="TRANSLATE":
        if os.path.isdir(checkpoint_dir) and os.listdir(checkpoint_dir):
            print("-------------------- Loading previously trained model --------------------")
            checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))
        while True:
            print("输入样例:hace mucho frio aqui. esta es mi vida. ?todavia estan en casa?")
            STRinput1=input("请输入待翻译语句(Exit返回):")
            if STRinput1.upper()=="EXIT" or STRinput1.upper()=="E":
                break
            elif len(STRinput1)==0:
                STRinput1=u'hace mucho frio aqui. esta es mi vida. ¿todavia estan en casa?'
            else:
                STRinput1=u"'"+STRinput1+"'"
            translate(STRinput1)
    elif STRinput=="QUIT" or STRinput=="Q":
        break
    else:
        continue
