Table of Contents
Code Part 1
Code Part 2
The first part of the code is shared by the skip-gram and CBOW models; the second part implements the skip-gram model.
Code Part 1
import os
from six.moves.urllib.request import urlretrieve
import zipfile
import collections
# http://mattmahoney.net/dc/textdata.html
dataset_link = 'http://mattmahoney.net/dc/'
zip_file = 'text8.zip'
# Progress callback to report download progress
def cbk(a, b, c):
    '''Download progress callback
    @a: number of data blocks downloaded so far
    @b: size of one data block
    @c: size of the remote file
    '''
    per = 100.0 * a * b / c
    if per > 100:
        per = 100
    print('%.2f%%' % per)

def data_download(zip_file):
    '''Download the dataset'''
    if not os.path.exists(zip_file):
        # urlretrieve() downloads the remote data directly to a local file
        zip_file, _ = urlretrieve(dataset_link + zip_file, zip_file, cbk)
        print('File downloaded successfully!')
    return None
def extracting(extracted_folder, zip_file):
    '''Extract the archive'''
    if not os.path.isdir(extracted_folder):
        with zipfile.ZipFile(zip_file) as zf:
            # Extract every file in the zip archive into extracted_folder
            zf.extractall(extracted_folder)
def text_processing(ft8_text):
    # Replace punctuation with spaces
    ft8_text = ft8_text.lower()
    ft8_text = ft8_text.replace('.', ' ')
    ft8_text = ft8_text.replace(',', ' ')
    ft8_text = ft8_text.replace('"', ' ')
    ft8_text = ft8_text.replace(';', ' ')
    ft8_text = ft8_text.replace('!', ' ')
    ft8_text = ft8_text.replace('?', ' ')
    ft8_text = ft8_text.replace('(', ' ')
    ft8_text = ft8_text.replace(')', ' ')
    ft8_text = ft8_text.replace('--', ' ')
    ft8_text = ft8_text.replace(':', ' ')
    ft8_text_tokens = ft8_text.split()
    return ft8_text_tokens
def remove_lowerfreword(ft_tokens):
    '''Remove word-level noise: words that appear 7 times or fewer in the input dataset'''
    word_cnt = collections.Counter(ft_tokens)  # counts occurrences of each token; an unordered container that stores tokens as dict keys and their counts as values
    shortlisted_words = [w for w in ft_tokens if word_cnt[w] > 7]
    print(shortlisted_words[:15])  # show the first 15 tokens that survive the filter
    print('Total number of shortlisted_words', len(shortlisted_words))  # 16616688
    print('Unique number of shortlisted_words', len(set(shortlisted_words)))  # 53721
    return shortlisted_words
def dict_creation(shortlisted_words):
    '''Build the vocabulary mappings between words and integer indices'''
    counts = collections.Counter(shortlisted_words)
    vocabulary = sorted(counts, key=counts.get, reverse=True)
    rev_dictionary = {ii: word for ii, word in enumerate(vocabulary)}  # integer -> word
    dictionary = {word: ii for ii, word in rev_dictionary.items()}     # word -> integer
    return dictionary, rev_dictionary
Notes on the libraries used:
1. six is a compatibility library for Python 2 and 3. six.moves wraps the functions whose locations changed between the two versions, so importing through six.moves hides those differences.
2. zipfile.ZipFile(zip_file) opens the archive zip_file. ZipFile.extractall([path[, members[, pwd]]]) extracts all files in the archive; path is the destination folder, members selects the file names or ZipInfo objects to extract, and pwd is the password for an encrypted archive.
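To see these two calls in isolation, here is a minimal sketch (the file name demo_text8.zip and the folder demo_dataset are placeholder names, not taken from the code above) that downloads the archive with urlretrieve from six.moves and unpacks it with ZipFile.extractall:
from six.moves.urllib.request import urlretrieve
import zipfile

# Illustrative sketch only: download the text8 archive and extract it into a placeholder folder
archive, _ = urlretrieve('http://mattmahoney.net/dc/text8.zip', 'demo_text8.zip')
with zipfile.ZipFile(archive) as zf:
    print(zf.namelist())           # list the members of the archive
    zf.extractall('demo_dataset')  # extract everything into demo_dataset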
Code Part 2
import collections
import time
import numpy as np
import random
import tensorflow as tf
from text_processing import *  # the functions from Code Part 1, assumed to be saved as text_processing.py
from sklearn.manifold import TSNE
def subsampling(words_cnt):
    # Subsample very frequent words (stop words) in the text
    thresh = 0.00005
    word_counts = collections.Counter(words_cnt)
    total_count = len(words_cnt)
    freqs = {word: count / total_count for word, count in word_counts.items()}
    p_drop = {word: 1 - np.sqrt(thresh / freqs[word]) for word in word_counts}
    train_words = [word for word in words_cnt if p_drop[word] < random.random()]
    return train_words
def skipG_target_set_generation(batch_, batch_index, word_window):
    # Build the skip-gram input in the required format: the words surrounding the center word
    random_num = np.random.randint(1, word_window + 1)  # randomly pick the context size, up to word_window
    words_start = batch_index - random_num if (batch_index - random_num) > 0 else 0
    words_stop = batch_index + random_num
    window_target = set(batch_[words_start:batch_index] + batch_[batch_index + 1:words_stop + 1])
    return list(window_target)
def skipG_batch_creation(short_words, batch_length, word_window):
    # Yield batches of (center word, surrounding word) combinations
    batch_cnt = len(short_words) // batch_length
    print('batch_cnt=', batch_cnt)
    short_words = short_words[:batch_cnt * batch_length]
    for word_index in range(0, len(short_words), batch_length):
        input_words, label_words = [], []
        word_batch = short_words[word_index:word_index + batch_length]
        for index_ in range(len(word_batch)):  # iterate over every center word in the batch
            batch_input = word_batch[index_]
            batch_label = skipG_target_set_generation(word_batch, index_, word_window)  # fetch the surrounding words
            label_words.extend(batch_label)
            input_words.extend([batch_input] * len(batch_label))  # skip-gram input format: repeat the center word once per surrounding word
        yield input_words, label_words
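# (Illustration, not part of the original script.) For a toy batch such as
#   word_batch = [5, 7, 9, 11, 13] with word_window = 2,
# skipG_target_set_generation(word_batch, 2, 2) returns a randomly sized window
# around the center word 9, e.g. [7, 11] or [5, 7, 11, 13]; skipG_batch_creation
# then yields input_words -> [9, 9, ...] and label_words -> [7, 11, ...],
# i.e. the center word is repeated once for every surrounding word.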
# extracted_folder = 'dataset'
# full_text = extracting(extracted_folder, zip_file)
with open('dataset/text8') as ft_:
    full_text = ft_.read()
ft_tokens = text_processing(full_text)  # list of tokens
shortlisted_words = remove_lowerfreword(ft_tokens)
dictionary, rev_dictionary = dict_creation(shortlisted_words)
words_cnt = [dictionary[word] for word in shortlisted_words]  # map each word to its integer index via the dictionary
train_words = subsampling(words_cnt)
print('train_words=', len(train_words))
# 1. Define the input and label placeholders
tf_graph = tf.Graph()
with tf_graph.as_default():
    input_ = tf.placeholder(tf.int32, [None], name='input_')
    label_ = tf.placeholder(tf.int32, [None, None], name='label_')
# 2. Build the embedding lookup
with tf_graph.as_default():
    word_embed = tf.Variable(tf.random_uniform((len(rev_dictionary), 300), -1, 1))
    embedding = tf.nn.embedding_lookup(word_embed, input_)  # map each word index to its embedding vector
# 3. Define the loss and the optimizer
vocabulary_size = len(rev_dictionary)
with tf_graph.as_default():
    sf_weights = tf.Variable(tf.truncated_normal((vocabulary_size, 300), stddev=0.1))
    sf_bias = tf.Variable(tf.zeros(vocabulary_size))
    # Compute the loss with sampled softmax (a negative-sampling-style approximation)
    loss_fn = tf.nn.sampled_softmax_loss(weights=sf_weights,
                                         biases=sf_bias,
                                         labels=label_,
                                         inputs=embedding,
                                         num_sampled=100,
                                         num_classes=vocabulary_size)
    cost_fn = tf.reduce_mean(loss_fn)
    optim = tf.train.AdamOptimizer().minimize(cost_fn)
# 4. Validation set: pick a mix of common and less common words from the corpus and return the words closest to them based on the cosine similarity of their embeddings
with tf_graph.as_default():
    validation_cnt = 16
    validation_dict = 100
    # sample 8 indices from range(validation_dict) and 8 from range(1000, 1000 + validation_dict)
    validation_words = np.array(random.sample(range(validation_dict), validation_cnt // 2))
    validation_words = np.append(validation_words, random.sample(range(1000, 1000 + validation_dict), validation_cnt // 2))
    validation_data = tf.constant(validation_words, dtype=tf.int32)
    # normalize the embeddings to unit length, so the matmul below yields cosine similarities
    normalization_embed = word_embed / (tf.sqrt(tf.reduce_sum(tf.square(word_embed), 1, keep_dims=True)))
    validation_embed = tf.nn.embedding_lookup(normalization_embed, validation_data)
    word_similarity = tf.matmul(validation_embed, tf.transpose(normalization_embed))
epochs = 2
batch_length = 1000
word_window = 10
# Define the checkpoint saver (checkpoints go into the model_checkpoint directory)
with tf_graph.as_default():
    saver = tf.train.Saver()
with tf.Session(graph=tf_graph) as sess:
    iteration = 1
    loss = 0
    sess.run(tf.global_variables_initializer())
    print("Begin training-----------")
    for e in range(1, epochs + 1):
        batches = skipG_batch_creation(train_words, batch_length, word_window)
        start = time.time()
        for x, y in batches:
            train_loss, _ = sess.run([cost_fn, optim],
                                     feed_dict={input_: x, label_: np.array(y)[:, None]})
            loss += train_loss
            if iteration % 100 == 0:
                end = time.time()
                print('Epoch {}/{}'.format(e, epochs),
                      ', Iteration:{}'.format(iteration),
                      ', Avg. Training loss:{:.4f}'.format(loss / 100),
                      ', Processing:{:.4f} sec/batch'.format((end - start) / 100))
                loss = 0
                start = time.time()
            if iteration % 2000 == 0:
                similarity_ = word_similarity.eval()  # evaluate the similarity matrix
                for i in range(validation_cnt):
                    validated_words = rev_dictionary[validation_words[i]]
                    top_k = 8
                    nearest = (-similarity_[i, :]).argsort()[1:top_k + 1]  # argsort sorts the negated similarities ascending, so the most similar words come first; skip index 0, which is the word itself
                    log = 'Nearest to %s:' % validated_words
                    for k in range(top_k):
                        close_word = rev_dictionary[nearest[k]]
                        log = '%s %s,' % (log, close_word)
                    print(log)
            iteration += 1  # increment once per processed batch
    save_path = saver.save(sess, "model_checkpoint/skipGram_text8.ckpt")
    embed_mat = sess.run(normalization_embed)
with tf_graph.as_default():
    saver = tf.train.Saver()
with tf.Session(graph=tf_graph) as sess:
    saver.restore(sess, tf.train.latest_checkpoint('model_checkpoint'))
    embed_mat = sess.run(word_embed)
# Visualize the embeddings with t-distributed stochastic neighbor embedding (t-SNE)
word_graph = 250
tsne = TSNE()
word_embedding_tsne = tsne.fit_transform(embed_mat[:word_graph, :])
Visualization result:
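The original post shows the resulting scatter plot as an image. A minimal plotting sketch along the following lines (assuming matplotlib is available; the figure size and styling are arbitrary choices, not from the original) produces that kind of figure from word_embedding_tsne and rev_dictionary:
import matplotlib.pyplot as plt

# Illustrative sketch: scatter the first word_graph words in t-SNE space and label each point
fig, ax = plt.subplots(figsize=(12, 12))
for idx in range(word_graph):
    x, y = word_embedding_tsne[idx, 0], word_embedding_tsne[idx, 1]
    ax.scatter(x, y, color='steelblue')
    ax.annotate(rev_dictionary[idx], (x, y), alpha=0.7)
plt.show()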