ESIM (Enhanced LSTM for Natural Language Inference)

Historical significance of ESIM:

1) It mines deeper semantic relationship features between the two texts

2) It incorporates syntactic structure information of the text

 

This post follows the structure of the paper:

1. Abstract

1) Reasoning is a key capability of artificial intelligence

2) SNLI provides the data to support inference models

3) The proposed model achieves strong results without a complicated architecture

4) Incorporating a syntactic parse-tree model improves results further

2. Introduction

Mainly introduces the dataset and background knowledge related to natural language inference.


3. Related Work

Briefly reviews related work and how the dataset has been used.


4. Hybrid Neural Inference Models

Input and input-encoding layer: the two input texts to be matched, as matrices of token indices, are embedded and passed through a bidirectional LSTM to obtain context-aware encodings; the paper also gives a Tree-LSTM variant of this encoder. A minimal sketch of the BiLSTM encoding follows.

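A self-contained sketch of this encoding step, with illustrative sizes (batch 32, premise length 20, hypothesis length 15, 300-d embeddings; all of these are assumptions, not values from the post):

import torch
import torch.nn as nn

# Input encoding: a BiLSTM reads each sentence's word embeddings and yields
# context-aware hidden states of size 2 * hidden_size per token.
hidden_size = 300
encoder = nn.LSTM(300, hidden_size, batch_first=True, bidirectional=True)

emb_a = torch.randn(32, 20, 300)  # premise embeddings (batch, len_a, emb_dim)
emb_b = torch.randn(32, 15, 300)  # hypothesis embeddings (batch, len_b, emb_dim)

a_bar, _ = encoder(emb_a)         # (32, 20, 600)
b_bar, _ = encoder(emb_b)         # (32, 15, 600)

Note that the paper shares a single encoder between the two sentences, as sketched here, while the code in section 8 instantiates one BiLSTM per sentence; both choices appear in practice.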

Local inference modeling layer: essentially an attention layer; it computes the attention matrix between the two encoded sentences, soft-aligns each token with the other sentence, and augments the feature dimension with element-wise difference and product (sketched below).
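
A standalone sketch of the alignment and enhancement (the random tensors stand in for the encoder outputs above):

import torch

# Stand-ins for the BiLSTM outputs a_bar (premise) and b_bar (hypothesis).
a_bar, b_bar = torch.randn(32, 20, 600), torch.randn(32, 15, 600)

# Attention matrix: e[i, j] scores premise token i against hypothesis token j.
e = torch.matmul(a_bar, b_bar.transpose(1, 2))            # (32, 20, 15)

# Soft alignment: represent each token by a weighted sum over the other sentence.
a_tilde = torch.matmul(torch.softmax(e, dim=2), b_bar)    # (32, 20, 600)
b_tilde = torch.matmul(torch.softmax(e, dim=1).transpose(1, 2), a_bar)  # (32, 15, 600)

# Enhancement: concatenate encoding, alignment, difference, and product.
ma = torch.cat([a_bar, a_tilde, a_bar - a_tilde, a_bar * a_tilde], dim=2)  # (32, 20, 2400)
mb = torch.cat([b_bar, b_tilde, b_bar - b_tilde, b_bar * b_tilde], dim=2)  # (32, 15, 2400)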

Inference composition layer: a second BiLSTM composes the enhanced features, and average/max pooling over the sequence dimension aggregates each sentence's feature matrix into a single fixed-length vector (sketched together with the prediction layer below).

Output prediction layer: classification via a fully connected network + softmax.
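
A standalone sketch of composition, pooling, and classification (ma and mb stand in for the enhanced features from the previous step; sizes remain the illustrative ones above):

import torch
import torch.nn as nn

hidden_size = 300
ma, mb = torch.randn(32, 20, 8 * hidden_size), torch.randn(32, 15, 8 * hidden_size)

# Inference composition: a second BiLSTM fuses the enhanced local features.
composer = nn.LSTM(8 * hidden_size, hidden_size, batch_first=True, bidirectional=True)
va, _ = composer(ma)   # (32, 20, 600)
vb, _ = composer(mb)   # (32, 15, 600)

# Pooling: average and max over time, concatenated into one fixed-length vector.
v = torch.cat([va.mean(dim=1), va.max(dim=1)[0],
               vb.mean(dim=1), vb.max(dim=1)[0]], dim=1)  # (32, 2400)

# Prediction: an MLP outputs logits for entailment / neutral / contradiction;
# softmax is typically applied inside the cross-entropy loss.
classifier = nn.Sequential(nn.Linear(8 * hidden_size, hidden_size),
                           nn.ReLU(), nn.Linear(hidden_size, 3))
logits = classifier(v)  # (32, 3)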


 

5. Experiment Setup

Describes the experimental settings, including the learning rate, batch size, word-embedding dimensionality, etc.
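
The key values reported in the paper (Adam with an initial learning rate of 0.0004, batch size 32, 300-d embeddings and hidden states, dropout 0.5) could be collected in a config dict; the dict itself is just an illustrative convention:

# Hyperparameters as reported in the paper (Chen et al., 2017).
config = {
    "optimizer": "adam",
    "lr": 4e-4,             # initial learning rate
    "betas": (0.9, 0.999),  # Adam momentum terms
    "batch_size": 32,
    "embedding_dim": 300,   # pre-trained GloVe 840B vectors
    "hidden_size": 300,     # LSTM / Tree-LSTM hidden states
    "dropout": 0.5,
}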


6. Results

The experiments include several controlled comparisons; the ESIM + Tree-LSTM combination performs best.


7. Conclusions and Future Work

Summarizes the paper and outlines directions for future work.


8. Code

# -*- coding: utf-8 -*-

# @Time : 2021/2/16 6:54 PM
# @Author : TaoWang
# @Description : ESIM network architecture

import torch
import torch.nn as nn


class ESIM(nn.Module):
    def __init__(self, text, hidden_size, linear_size, classes):
        """
        :param text: torchtext field whose vocab carries pretrained vectors
        :param hidden_size: LSTM hidden size
        :param linear_size: hidden size of the classifier MLP
        :param classes: number of output classes
        """
        super(ESIM, self).__init__()
        embed_num, embed_dim = text.vocab.vectors.size()
        self.embedding = nn.Embedding(embed_num, embed_dim)
        self.embedding.weight.data.copy_(text.vocab.vectors)

        # Input encoding layer
        self.a_bilstm_input = nn.LSTM(embed_dim, hidden_size, batch_first=True, bidirectional=True)
        self.b_bilstm_input = nn.LSTM(embed_dim, hidden_size, batch_first=True, bidirectional=True)

        # Inference composition layer (input: 4 concatenated blocks of 2 * hidden_size)
        self.a_bilstm_infer = nn.LSTM(8 * hidden_size, hidden_size, batch_first=True, bidirectional=True)
        self.b_bilstm_infer = nn.LSTM(8 * hidden_size, hidden_size, batch_first=True, bidirectional=True)

        # Prediction layer: fully connected network
        self.linear = nn.Sequential(
            nn.Linear(8 * hidden_size, 2 * hidden_size),
            nn.ReLU(True),
            nn.Linear(2 * hidden_size, linear_size),
            nn.ReLU(True),
            nn.Linear(linear_size, classes)
        )
        
    def forward(self, a, b):
        """
        :param a: premise token ids, shape (batch, len_a)
        :param b: hypothesis token ids, shape (batch, len_b)
        :return: class logits, shape (batch, classes)
        """
        # Word embedding
        emb_a, emb_b = self.embedding(a), self.embedding(b)

        # Input encoding layer
        a_ba, _ = self.a_bilstm_input(emb_a)
        b_ba, _ = self.b_bilstm_input(emb_b)

        # Local inference modeling layer: attention matrix and soft alignment
        e = torch.matmul(a_ba, b_ba.permute(0, 2, 1))
        a_ti = torch.matmul(torch.softmax(e, dim=2), b_ba)
        b_ti = torch.matmul(torch.softmax(e, dim=1).permute(0, 2, 1), a_ba)

        # Enhancement: concatenate encoding, alignment, difference, and product
        ma = torch.cat([a_ba, a_ti, a_ba - a_ti, a_ba * a_ti], dim=2)
        mb = torch.cat([b_ba, b_ti, b_ba - b_ti, b_ba * b_ti], dim=2)

        # Inference composition layer with average and max pooling
        va, _ = self.a_bilstm_infer(ma)
        vb, _ = self.b_bilstm_infer(mb)
        va_avg, va_max = torch.mean(va, dim=1), torch.max(va, dim=1)[0]
        vb_avg, vb_max = torch.mean(vb, dim=1), torch.max(vb, dim=1)[0]
        v = torch.cat([va_avg, va_max, vb_avg, vb_max], dim=1)

        # Output prediction layer
        out = self.linear(v)

        return out
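
A minimal smoke test for the class above (a sketch: the SimpleNamespace stands in for a torchtext field carrying pretrained vectors, and the vocab size, sentence lengths, and class count are illustrative):

import torch
from types import SimpleNamespace

# Random vectors stand in for a real torchtext vocab with GloVe embeddings.
text = SimpleNamespace(vocab=SimpleNamespace(vectors=torch.randn(5000, 300)))

model = ESIM(text, hidden_size=300, linear_size=300, classes=3)
a = torch.randint(0, 5000, (32, 20))  # premise token ids (batch, len_a)
b = torch.randint(0, 5000, (32, 15))  # hypothesis token ids (batch, len_b)
print(model(a, b).shape)              # torch.Size([32, 3])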
    
        

 
