Key functions:
sentences = [[vocabulary_inv[w] for w in s] for s in sentence_matrix]
embedding_model = word2vec.Word2Vec(sentences, workers=num_workers, size=num_features,
                                    min_count=min_word_count, window=context, sample=downsampling)
embedding_weights = [np.array([embedding_model[w] if w in embedding_model
                               else np.random.uniform(-0.25, 0.25, embedding_model.vector_size)
                               for w in vocabulary_inv])]
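Note: the script below targets the pre-4.0 gensim API (size=, direct model[word] indexing, init_sims). Under gensim >= 4.0 the same three steps would look roughly like the sketch below; this is an assumption about the newer API, reusing the same variables (sentence_matrix, vocabulary_inv, num_workers, ...) from the script, and is not part of the original code.

# Sketch only: equivalent calls under the gensim >= 4.0 API (assumed, not from the original script)
from gensim.models import Word2Vec
import numpy as np

sentences = [[vocabulary_inv[w] for w in s] for s in sentence_matrix]
embedding_model = Word2Vec(sentences, workers=num_workers, vector_size=num_features,
                           min_count=min_word_count, window=context, sample=downsampling)
embedding_weights = [np.array([embedding_model.wv[w] if w in embedding_model.wv
                               else np.random.uniform(-0.25, 0.25, embedding_model.wv.vector_size)
                               for w in vocabulary_inv])]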
from gensim.models import word2vec
from os.path import join, exists, split
import os
import numpy as np
import data_helpers
def train_word2vec(sentence_matrix, vocabulary_inv,
                   num_features=300, min_word_count=1, context=10):
    """
    Trains, saves, loads Word2Vec model
    Returns initial weights for embedding layer.

    inputs:
    sentence_matrix  # int matrix: num_sentences x max_sentence_len
    vocabulary_inv   # dict {int: str} mapping word index -> word
    num_features     # Word vector dimensionality
    min_word_count   # Minimum word count
    context          # Context window size
    """
    model_dir = 'word2vec_models'
    model_name = "{:d}features_{:d}minwords_{:d}context".format(num_features, min_word_count, context)
    model_name = join(model_dir, model_name)
    print(model_name)
    if exists(model_name):
        # Reuse a previously trained model instead of retraining
        embedding_model = word2vec.Word2Vec.load(model_name)
        print(split(model_name))
        print('Loading existing Word2Vec model \'%s\'' % split(model_name)[-1])
    else:
        # Set values for various parameters
        num_workers = 2      # Number of threads to run in parallel
        downsampling = 1e-3  # Downsample setting for frequent words

        # Initialize and train the model
        print("Training Word2Vec model...")
        # Map each row of word indices back to word strings for gensim
        sentences = [[vocabulary_inv[w] for w in s] for s in sentence_matrix]
        embedding_model = word2vec.Word2Vec(sentences, workers=num_workers,
                                            size=num_features, min_count=min_word_count,
                                            window=context, sample=downsampling)

        # If we don't plan to train the model any further, calling
        # init_sims will make the model much more memory-efficient.
        embedding_model.init_sims(replace=True)

        # Saving the model for later use. You can load it later using Word2Vec.load()
        if not exists(model_dir):
            os.mkdir(model_dir)
        print('Saving Word2Vec model \'%s\'' % split(model_name)[-1])
        embedding_model.save(model_name)

    print("<<<<<<>>>>>>")
    print(embedding_model)
    print("<<<<>>>>>>")
    print(embedding_model['the'])

    # add unknown words: words missing from the trained vocabulary get small random vectors
    embedding_weights = [np.array([embedding_model[w] if w in embedding_model
                                   else np.random.uniform(-0.25, 0.25, embedding_model.vector_size)
                                   for w in vocabulary_inv])]
    print("<<<<<<>>>>>>")
    print(embedding_weights)
    # form: [array of shape (vocab_size, num_features), dtype=float32]
    print("<<<<<<>>>>>>")
    print(len(embedding_weights[0]))
    print('<<<<<<>>>>>>>>')
    print(embedding_weights[0][0])
    return embedding_weights


if __name__ == '__main__':
    print("Loading data...")
    x, _, _, vocabulary_inv = data_helpers.load_data()
    w = train_word2vec(x, vocabulary_inv, num_features=50)
Result:
Loading data...
word2vec_models\50features_1minwords_10context
('word2vec_models', '50features_1minwords_10context')
Loading existing Word2Vec model '50features_1minwords_10context'
<<<<<<>>>>>>
Word2Vec(vocab=18765, size=50, alpha=0.025)
<<<<>>>>>>
[-0.02123935 -0.05490068 -0.23091695 0.22885838 0.27167228 -0.05158678
-0.05539207 -0.22459298 -0.13045608 0.1070291 -0.02403063 -0.00130523
0.19467545 0.0296479 -0.08463751 0.06144359 -0.12156744 -0.10263078
-0.06417726 -0.01894584 -0.24981096 0.06570466 0.18615869 0.0254211
0.16925883 0.16957371 0.09030077 -0.21692845 0.06049829 -0.16586471
-0.14992879 0.10567563 0.06595352 -0.2763775 0.19153105 0.12031604
-0.04390313 0.17974497 0.0111691 0.02915976 -0.2177939 -0.10611911
0.03313744 -0.1350788 0.14111236 -0.2494767 0.1069486 0.12323371
0.12049541 0.06635255]
<<<<<<>>>>>>
[array([[ 0.22287713, 0.16943391, -0.18881463, ..., 0.14732356,
0.13720259, 0.01359756],
[-0.02123935, -0.05490068, -0.23091695, ..., 0.12323371,
0.12049541, 0.06635255],
[ 0.00441776, 0.18812552, -0.04573391, ..., 0.21478625,
0.13668334, 0.07716116],
...,
[-0.12307893, 0.00517228, -0.15578713, ..., 0.14851454,
0.03096815, 0.02698081],
[ 0.10233459, 0.05561316, -0.14217722, ..., 0.13685261,
0.06430191, 0.07483072],
[ 0.38770503, 0.25239426, 0.05189984, ..., 0.05148352,
0.05523289, -0.00138285]], dtype=float32)]
<<<<<<>>>>>>
18765
<<<<<<>>>>>>>>
[ 0.22287713 0.16943391 -0.18881463 0.17884882 0.18564233 -0.0442104
-0.0878861 -0.1689464 -0.18261056 0.06646598 -0.01281965 -0.13765122
-0.13511361 -0.09671792 0.00905229 0.15683898 -0.05347729 -0.19133224
-0.19077209 0.06136192 -0.23724565 0.02753574 0.12075555 0.04671283
0.24587892 0.04068457 0.13252144 -0.17025913 0.02412261 -0.09669705
-0.1384121 0.23630536 0.02841855 -0.09812463 0.09301346 0.14504847
0.00938961 0.20613976 0.11047759 -0.09049261 -0.303936 0.05824288
0.03945156 -0.07627468 0.07813629 -0.24140984 0.11729364 0.14732356
0.13720259 0.01359756]
[Finished in 2.4s]
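The embedding_weights returned above is a one-element list holding an 18765 x 50 float32 matrix, which matches the weights= format used to seed an embedding layer with pretrained vectors. A minimal sketch of that usage is shown below; it assumes Keras as the downstream framework and a placeholder sequence_length, neither of which appears in this script.

# Sketch only: seeding a Keras Embedding layer with the Word2Vec weights (assumed usage)
from keras.layers import Embedding

embedding_layer = Embedding(input_dim=len(vocabulary_inv),  # vocabulary size (18765 here)
                            output_dim=50,                  # must match num_features used above
                            input_length=sequence_length,   # max sentence length (placeholder)
                            weights=embedding_weights,      # [matrix of shape (vocab_size, num_features)]
                            trainable=True)                 # allow fine-tuning during model training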