Tensorflow入门(二)文本自动生成

参考:简单粗暴Tensorflow https://tf.wiki/zh/models.html#id8

class DataLoader():
    def __init__(self):
        path=tf.keras.utils.get_file('nietzsche.txt',
            origin='https://s3.amazonaws.com/text-datasets/nietzsche.txt')
        with open(path,encoding='utf-8') as f:#文件的安全打开方式,避免读写异常处理和异常发生后文件句柄无法关闭
            self.raw_text=f.read().lower()    #lower()是字符串的内置方法,将所有大写字母转换为小写
        self.chars=sorted(list(set(self.raw_text)))#set是集合,用于文本去重
        self.char_indices=dict((c,i)for i,c in enumerate(self.chars))#建立字符编号对应关系表
        self.indices_char=dict((i,c)for i,c in enumerate(self.chars))
        self.text=[self.char_indices[c]for c in self.raw_text]#根据对应关系表将文本字符编号
        
    def get_batch(self,seq_length,batcha_size):
        seq=[]
        next_char=[]
        for i in range(batch_size):
            index=np.random.randint(0,len(self.text)-seq_length)##理解:随机选取的样本范围是index~index+seq_length,那么
            seq.append(self.text[index:index+seq_length])       ##index的最大值即为len(self.text)-seq_length
            next_char.append(self.text[index+seq_length])
        return np.array(seq),np.array(next_char)
        
class RNN(tf.keras.Model):
    def __init__(self,num_chars):
        super().__init__()
        self.num_chars=num_chars
        self.cell=tf.nn.rnn.cell.BasicLSTMCell(num_units=256)
        self.dense=tf.keras.layers.Dense(units=self.num_chars)
    
    def call(self,inputs):
        batch_size,seq_length=tf.shape(inputs)
        #与cs231n的softamx中one_hot的作用不同,softmax中one_hot只是作用于标签,这里one_hot的作用是将样本编码
        inputs=tf.one_hot(inputs,depth=self.num_chars)
        state=self.cell.zero_state(batch_size=batch_size,dtype=tf.float32)
        for t in range(seq_length.numpy()):
            output,state=self.cell(inputs[:,t,:],state)
        output=self.dense(output)
        return output
    
    def predict(self,inputs,temperature=1.):
        batch_size,_=tf.shape(inputs)
        logits=self(inputs)#logits.shape=[batch_size,num_chars]
        #temperature控制分布的形状,参数越大,分布越平缓(最大值和最小值的差值越小),生成文本的丰富度越高
        prob=tf.nn.softmax(logits/temperature).numpy()
        #np.random.choice(a,size=1,replace=true,p=prob)  从[0,a)中以概率p随机有放回(可能重复)返回size个数
        return np.array([np.random.choice(self.num_chars,p=prob[i,:])for i in range(batch_size.numpy())])
    
data_loader=DataLoader()
model=RNN(len(data_loader.chars))

learning_rate=0.01
num_batches=100
batch_size=50
seq_length=50
optimizer=tf.AdamOptimizer(learning_rate=learning_rate)


for batch_index in range(num_batches):#num_batches哪里来,得自己定义的超参数
    X,y=data_loader.get_batch(seq_length,batch_size)#seq_length,batch_size哪里来?得自己定义超参数
    with tf.GradientTape as tape:
        y_logit_pred=model(X)
        loss=tf.losses.sparse_softmax_cross_entropy(labels=y,logits=y_logit_pred)
    grads=tape.gradient(loss,model.variables)
    optimizer.apply_gradients(grads_and_vars=zip(grads,model.variables))


X_,_=data_loader.get_batch(seq_length,1)
for diversity in [0.2,0.5,1.0,1.2]:
    X=X_
    print("diversity %f:"%diversity)
    for t in range(400):
        y_pred=model.predict(X,diversity)
        #print默认是换行的,若要求不换行,则参数end='';flush表示是否立刻将输出语句输入到file所指向的文件对象(默认是sys.stdout)中
        print(data_loader.indices_char[y_pred[0]],end='',flush=True)#y_pred.shape=[1,]
        #np.concatenate数字拼接函数,axis=0,纵向拼接;axis=1,横向拼接;axis=-1,I don't know.
        #np.expand_dims维度扩展,不太理解
        X=np.concatenate([X[:,1],np.expand_dims(y_pred,axis=1)],axis=-1)

你可能感兴趣的:(Tensorflow入门(二)文本自动生成)