# Attention visualization: plot a square matrix whose axes are labelled
# CLS, the characters of a sentence, then SEP.

import matplotlib.pyplot as plt
import numpy as np

def samplemat(dims):
    """Make a matrix of zeros with increasing values 0, 1, 2, ... on the diagonal.

    Intended for 2-D shapes; used as a placeholder when no real attention
    matrix is available.

    Args:
        dims: Shape passed to ``np.zeros``, e.g. ``(rows, cols)``.

    Returns:
        A float ``ndarray`` of shape *dims* where ``aa[i, i] == i`` for every
        ``i < min(dims)`` and every other entry is zero.
    """
    aa = np.zeros(dims)
    for i in range(min(dims)):
        aa[i, i] = i
    return aa

def draw(sentence="城镇小区配套幼儿园不得办成营利性幼儿园", attention_matrix=None):
    """Draw a matrix with ``CLS``, each character of *sentence*, and ``SEP`` as tick labels.

    Both axes carry the same ``len(sentence) + 2`` labels.

    Args:
        sentence: Text whose individual characters label both axes.
        attention_matrix: Optional array-like of shape
            ``(len(sentence) + 2, len(sentence) + 2)`` to display. When
            ``None`` (the default), the diagonal placeholder from
            ``samplemat`` is drawn instead.

    Returns:
        The ``matplotlib`` figure that was created and drawn on.

    Raises:
        ValueError: If *attention_matrix* is given and its shape does not
            match the number of tick labels on each axis.
    """
    text_labels = list(sentence)
    tick_labels = ["CLS"] + text_labels + ["SEP"]
    size = len(tick_labels)

    if attention_matrix is None:
        # No matrix supplied: fall back to the diagonal placeholder.
        attention_matrix = samplemat((size, size))
    else:
        attention_matrix = np.asarray(attention_matrix)
        # Each row/column must line up with one tick label.
        if attention_matrix.shape != (size, size):
            raise ValueError(
                f"attention_matrix must have shape ({size}, {size}) to match "
                f"the CLS/sentence/SEP labels, got {attention_matrix.shape}"
            )

    # Display matrix
    figure = plt.figure()
    ax = figure.add_axes([0.1, 0.1, 0.8, 0.8])

    # Font properties for the tick labels. SimHei is presumably chosen so the
    # (Chinese) default sentence renders; the font must be available locally.
    font = {"family": "SimHei", "weight": "bold", "size": 8}

    positions = range(size)
    ax.set_xticks(positions)  # one tick per label position
    ax.set_yticks(positions)
    ax.set_xticklabels(tick_labels, **font)
    ax.set_yticklabels(tick_labels, **font)

    ax.imshow(X=attention_matrix)  # draw matrix
    return figure


# 你可能感兴趣的 (You may also be interested in): attention visualization