import glob
import json
import re

import pandas as pd
import tensorflow as tf
from keras import Sequential
from keras.layers import Embedding, Bidirectional, LSTM, Dense
from keras.preprocessing.text import Tokenizer
from keras.utils import pad_sequences, to_categorical
# --- Data loading -----------------------------------------------------------
# The corpus is read as a tab-separated file without a header row; the two
# columns are named 'label' and 'text'.
files = glob.glob('./nlu_data/SMSSpamCollection.csv')
if not files:
    # Fail with a clear message instead of pandas' "No objects to concatenate".
    raise FileNotFoundError('./nlu_data/SMSSpamCollection.csv')
data_pd = pd.concat(
    [pd.read_csv(f, header=None, names=['label', 'text'], sep='\t') for f in files],
    ignore_index=True,
)
# DataFrame.info() prints its report itself and returns None, so wrapping it
# in print() would only add a stray "None" line.
data_pd.info()

# --- Tokenization -----------------------------------------------------------
# Case is preserved (lower=False) and words are split on single spaces.
# NOTE(review): the empty-string OOV token looks like an extraction artifact
# (possibly '<OOV>' originally) -- confirm against the original script.
text_tok = Tokenizer(lower=False, split=' ', oov_token='')
label_tok = Tokenizer(lower=False, split=' ', oov_token='')
text_tok.fit_on_texts(data_pd['text'])
label_tok.fit_on_texts(data_pd['label'])

text_config = text_tok.get_config()
label_config = label_tok.get_config()
print(text_config.get('document_count'))
print(label_config)

# get_config() stores the vocabularies as JSON strings; decode them with the
# JSON parser rather than eval(), which would execute arbitrary expressions.
text_vocab = json.loads(text_config['index_word'])
label_vocab = json.loads(label_config['index_word'])

# Convert every message / label into a list of integer ids.
x_tok = text_tok.texts_to_sequences(data_pd['text'])
y_tok = label_tok.texts_to_sequences(data_pd['label'])
print('text', data_pd['text'][0], x_tok[0])
print('label', data_pd['label'][0], y_tok[0])

# --- Padding and one-hot labels ---------------------------------------------
# Every sequence is brought to exactly max_len ids, with zeros appended at the
# end. NOTE(review): 172 is presumably the longest message in tokens -- confirm.
max_len = 172
x_pad = pad_sequences(x_tok, padding='post', maxlen=max_len)
y_pad = y_tok
# +1 because the Tokenizer never assigns id 0, so ids run from 1 to len(vocab).
num_classes = len(label_vocab) + 1
Y = to_categorical(y_pad, num_classes)

# --- Hyperparameters --------------------------------------------------------
vocab_size = len(text_vocab) + 1  # +1 for the reserved id 0 used as padding
embedding_dim = 64
rnn_units = 100
BATCH_SIZE = 90
dropout = 0.2
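唯一需要注意的地方是:需要堆叠两层 BiLSTM。第一层设置 return_sequences=True,输出整个序列;第二层使用默认的 return_sequences=False,只输出最后一个时间步,从而去掉序列(时间步)这一维度。
如果不降维,模型输出的形状为 (batch, seq_len, num_classes),与标签的形状 (batch, num_classes) 不一致,训练时会报形状不匹配的错误。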
# --- Model ------------------------------------------------------------------
# Embedding -> BiLSTM (full sequence) -> BiLSTM (last step only) -> softmax.
# The second BiLSTM leaves return_sequences at its default, so it collapses the
# time axis and the Dense layer emits one class distribution per message.
model = Sequential([
    # mask_zero=True makes the downstream layers ignore the padding id 0.
    # NOTE(review): batch_input_shape pins the batch size to BATCH_SIZE, which
    # is presumably why the splits below are whole multiples of it -- confirm.
    Embedding(vocab_size, embedding_dim, mask_zero=True, batch_input_shape=[BATCH_SIZE, None]),
    Bidirectional(LSTM(units=rnn_units, return_sequences=True, dropout=dropout, kernel_initializer=tf.keras.initializers.he_normal())),
    Bidirectional(LSTM(round(num_classes / 2))),
    Dense(num_classes, activation='softmax')
])
# summary() prints the table itself and returns None, so it is not wrapped in
# print() (that would only add a stray "None" line).
model.summary()
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])

# --- Train / test split -----------------------------------------------------
X = x_pad
# The original note reads "5572 61 = 49 + 12": 5572 rows give 61 full batches
# of BATCH_SIZE, of which 49 are used for training and 12 for testing.
TRAIN_BATCHES = 49
TEST_BATCHES = 12
train_end = TRAIN_BATCHES * BATCH_SIZE                  # 4410 when BATCH_SIZE is 90
test_end = (TRAIN_BATCHES + TEST_BATCHES) * BATCH_SIZE  # 5490 when BATCH_SIZE is 90
X_train = X[0: train_end]
Y_train = Y[0: train_end]
print(Y_train.shape)
X_test = X[train_end: test_end]
Y_test = Y[train_end: test_end]

# --- Training and evaluation ------------------------------------------------
model.fit(X_train, Y_train, batch_size=BATCH_SIZE, epochs=15)
model.evaluate(X_test, Y_test, batch_size=BATCH_SIZE)
y_pred = model.predict(X_test, batch_size=BATCH_SIZE)
# 3s 43ms/step - loss: 0.2169 - accuracy: 0.9333

# Convert the predicted class distributions back to class ids.
y_pred = tf.argmax(y_pred, -1)
y_pnp = y_pred.numpy()
# Convert the one-hot ground truth back to class ids.
y_ground_true = tf.argmax(Y_test, -1)
y_ground_true_pnp = y_ground_true.numpy()

# --- Show the first 20 test messages with true and predicted labels ---------
for padded_ids, true_id, pred_id in zip(X_test[:20], y_ground_true_pnp, y_pnp):
    # Drop the padding id 0 before decoding. The previous approach decoded the
    # padded sequence and then called re.sub(r'*.', '', x), but '*.' is not a
    # valid regular expression and raises re.error ("nothing to repeat").
    word_ids = [tok for tok in padded_ids if tok != 0]
    x = 'sentence=> ' + text_tok.sequences_to_texts([word_ids])[0]
    ground_true = 'ground_true=> ' + label_tok.sequences_to_texts([[true_id]])[0]
    prediction = 'prediction=> ' + label_tok.sequences_to_texts([[pred_id]])[0]
    print(x)
    print(ground_true)
    print(prediction)
    print('\n')
测试集的准确率为 97.87%:
12/12 [==============================] - 3s 53ms/step - loss: 0.1081 - accuracy: 0.9787
输出结果
sentence=> For your chance to WIN a FREE Bluetooth Headset then simply reply back with ADP
ground_true=> spam
prediction=> spam
sentence=> You also didnt get na hi hi hi hi hi
ground_true=> ham
prediction=> ham
sentence=> Ya but it cant display internal subs so i gotta extract them
ground_true=> ham
prediction=> ham
sentence=> If i said anything wrong sorry de
ground_true=> ham
prediction=> ham
sentence=> Sad story of a Man Last week was my b'day My Wife did'nt wish me My Parents forgot n so did my Kids I went to work Even my Colleagues did not wish
ground_true=> ham
prediction=> ham
sentence=> How stupid to say that i challenge god You dont think at all on what i write instead you respond immed
ground_true=> ham
prediction=> ham
sentence=> Yeah I should be able to I'll text you when I'm ready to meet up
ground_true=> ham
prediction=> ham
sentence=> V skint too but fancied few bevies waz gona go meet othrs in spoon but jst bin watchng planet earth sofa is v comfey If i dont make it hav gd night
ground_true=> ham
prediction=> ham
sentence=> says that he's quitting at least5times a day so i wudn't take much notice of that Nah she didn't mind Are you gonna see him again Do you want to come to taunton tonight U can tell me all about
ground_true=> ham
prediction=> ham
sentence=> When you get free call me
ground_true=> ham
prediction=> ham
sentence=> How have your little darlings been so far this week Need a coffee run tomo Can't believe it's that time of week already …
ground_true=> ham
prediction=> ham
sentence=> Ok i msg u b4 i leave my house
ground_true=> ham
prediction=> ham
sentence=> Still at west coast Haiz Ü'll take forever to come back
ground_true=> ham
prediction=> ham
sentence=> MMM Fuck Merry Christmas to me
ground_true=> ham
prediction=> ham
sentence=> alright Thanks for the advice Enjoy your night out I'ma try to get some sleep
ground_true=> ham
prediction=> ham
sentence=> Update your face book status frequently
ground_true=> ham
prediction=> ham
sentence=> Just now saw your message it k da
ground_true=> ham
prediction=> ham
sentence=> Was it something u ate
ground_true=> ham
prediction=> ham
sentence=> So what did the bank say about the money
ground_true=> ham
prediction=> ham
sentence=> Aiyar dun disturb u liao Thk u have lots 2 do aft ur cupboard come
ground_true=> ham
prediction=> ham
代码传送门