将 Convolution1D 用于文本分类Keras的python源码DIY

将 Convolution1D 用于文本分类。
2个轮次后达到 0.89 的测试精度。
在 Intel i5 2.4Ghz CPU 上每轮次 90秒。
在 Tesla K40 GPU 上每轮次 10秒。
参考文档:https://keras-zh.readthedocs.io/examples/imdb_cnn/

代码如下:

from __future__ import print_function
from keras.preprocessing import sequence
from keras.models import Sequential
from keras.layers import Dense, Dropout, Activation
from keras.layers import Embedding
from keras.layers import Conv1D, GlobalMaxPooling1D
from keras.datasets import imdb

# 设置参数:
max_features = 5000
maxlen = 400
batch_size = 32
embedding_dims = 50
filters = 250
kernel_size = 3
hidden_dims = 250
epochs = 2

print('Loading data...')
(x_train, y_train), (x_test, y_test) = imdb.load_data(num_words=max_features)
print(len(x_train), 'train sequences')
print(len(x_test), 'test sequences')

print('Pad sequences (samples x time)')
x_train = sequence.pad_sequences(x_train, maxlen=maxlen)
x_test = sequence.pad_sequences(x_test, maxlen=maxlen)
print('x_train shape:', x_train.shape)
print('x_test shape:', x_test.shape)

print('Build model...')
model = Sequential()

# 我们从有效的嵌入层开始,该层将 vocab 索引映射到 embedding_dims 维度
model.add(Embedding(max_features,
                    embedding_dims,
                    input_length=maxlen))
model.add(Dropout(0.2))

# 我们添加了一个 Convolution1D,它将学习大小为 filter_length 的过滤器词组过滤器:
model.add(Conv1D(filters,
                 kernel_size,
                 padding='valid',
                 activation='relu',
                 strides=1))
# 我们使用最大池化:
model.add(GlobalMaxPooling1D())

# We add a vanilla hidden layer:
model.add(Dense(hidden_dims))
model.add(Dropout(0.2))
model.add(Activation('relu'))

# 我们投影到单个单位输出层上,并用 sigmoid 压扁它:
model.add(Dense(1))
model.add(Activation('sigmoid'))

model.compile(loss='binary_crossentropy',
              optimizer='adam',
              metrics=['accuracy'])
model.fit(x_train, y_train,
          batch_size=batch_size,
          epochs=epochs,
          validation_data=(x_test, y_test))

效果如图
将 Convolution1D 用于文本分类Keras的python源码DIY_第1张图片
Downloading data from

https://storage.googleapis.com/tensorflow/tf-keras-datasets/imdb.npz
Keras中文网:https://keras-zh.readthedocs.io/examples/imdb_cnn/

你可能感兴趣的:(python2021,keras,深度学习,神经网络,python)