通道注意力机制keras_【干货】基于Keras的注意力机制实战

【导读】近几年,注意力机制(Attention)大量地出现在自动翻译、信息检索等模型中。可以把Attention看成模型中的一个特征选择组件,特征选择一方面可以增强模型的效果,另一方面,我们可以通过计算出的特征的权重来计算结果与特征之间的某种关联。例如在自动翻译模型中,Attention可以计算出不同语种词之间的关系。本文一个简单的例子,来展示Attention是怎么在模型中起到特征选择作用的。

代码

导入相关库

#coding=utf-8

import numpy as np

from keras.models import *

from keras.layers import Input, Dense, merge

import matplotlib.pyplot as plt

import pandas as pd

数据生成函数

# 输入维度

input_dim = 32

# 生成数据,数据的的第attention_column个特征由label决定,

# 即label只与数据的第attention_column个特征相关

def get_data(n, input_dim, attention_column=1):

x = np.random.standard_normal(size=(n, input_dim))

y = np.random.randint(low=0, high=2, size=(n, 1))

x[:, attention_column] = y[:, 0]

return x, y

模型定义函数

将输入进行一次变换后,计算出Attention权重,将输入乘上Attention权重,获得新的特征。

# Attention模型

def build_model():

inputs = Input(shape=(input_dim,))

# 计算Attention权重

attention_probs = Dense(input_dim, activation="softmax",

name="attention_vec")(inputs)

# 根据Attention权重更新特征

attention_mul = merge([inputs, attention_probs],

output_shape=32,

name="attention_mul", mode="mul")

# 预测标签

attention_mul = Dense(64)(attention_mul)

output = Dense(1, activation="sigmoid")(attention_mul)

model = Model(input=[inputs], output=output)

attention_vec_model = Model(input=[inputs],

output=attention_probs)

return model, attention_vec_model

主函数

if __name__ == "__main__":

# 生成训练数据

N = 10000

inputs_1, outputs = get_data(N, input_dim)

# 获取模型,以及用于计算Attention权重的子模型

m, attention_vec_model = build_model()

m.compile(optimizer="adam", loss="binary_crossentropy",

metrics=["accuracy"])

print(m.summary())

# 训练

m.fit([inputs_1], outputs, epochs=20, batch_size=64,

validation_split=0.5)

# 生成测试数据

testing_inputs_1, testing_outputs = get_data(1, input_dim)

# 根据测试数据计算Attention权重

attention_vector = attention_vec_model.

predict([testing_inputs_1])[0].flatten()

print("attention =", attention_vector)

# 绘图

pd.DataFrame(attention_vector, columns=["attention (%)"])

.plot(kind="bar", title="Attention Mechanism as a function of

input dimensions.")

plt.show()

运行结果

代码中,attention_column为1,也就是说,label只与数据的第1个特征相关。从运行结果中可以看出,Attention权重成功地获取了这个信息。

参考链接

https://github.com/philipperemy/keras-attention-mechanism

更多教程资料请访问:人工智能知识资料全集

你可能感兴趣的:(通道注意力机制keras)