【NLP】TextCNN

模型

model.jpg

四种模式

  1. CNN-rand: 单词向量是随机初始化,向量随着模型学习而改变
  2. CNN-static: 使用预训练的静态词向量,向量不会随着模型学习而改变
  3. CNN-non-static: 使用预训练的静态词向量,预训练的向量可以微调(fine-tuned)
  4. CNN-multichannel: 静态+微调 两个channel都使用预训练的静态词向量,卷积核用在两个channel上,反向传播只改变一个channel

代码

if args.static:     #使用预训练的静态词向量
    args.embedding_dim = text_field.vocab.vectors.size()[-1]
    args.vectors = text_field.vocab.vectors
if args.multichannel:
    args.static = True
    args.non_static = True
# args.class_num = len(label_field.vocab)
args.class_num = len(label_field.vocab) - 1
import torch
import torch.nn as nn
import torch.nn.functional as F


class TextCNN(nn.Module):
    def __init__(self, args):
        super(TextCNN, self).__init__()
        self.args = args

        class_num = args.class_num
        channel_num = 1
        filter_num = args.filter_num
        filter_sizes = args.filter_sizes

        vocabulary_size = args.vocabulary_size
        embedding_dimension = args.embedding_dim
        self.embedding = nn.Embedding(vocabulary_size, embedding_dimension)
        if args.static:
            self.embedding = self.embedding.from_pretrained(args.vectors, freeze=not args.non_static)
        if args.multichannel:
            # multichannel:non_static=True and static=True
            # channel1 fine-tuned
            # channel2 static
            self.embedding2 = nn.Embedding(vocabulary_size, embedding_dimension).from_pretrained(args.vectors)
            channel_num += 1
        else:
            self.embedding2 = None
        self.convs = nn.ModuleList(
            # ModuleList是一个特殊的module,可以包含几个子module,
            # 可以像用list一样使用它,但不能直接把输入传给 ModuleList。
            # (N, C_in, H, W) => (N, C_out, H, W)
            [nn.Conv2d(channel_num, filter_num, (size, embedding_dimension)) for size in filter_sizes])
        self.dropout = nn.Dropout(args.dropout)
        self.fc = nn.Linear(len(filter_sizes) * filter_num, class_num)

    def forward(self, x):
        if self.embedding2:
            x = torch.stack([self.embedding(x), self.embedding2(x)], dim=1)
        else:
            x = self.embedding(x)
            # torch.unsqueeze()这个函数主要是对数据维度进行扩充。给指定位置加上维数为一的维度
            # 升维  (N, size, embedding_dimension) => 
            # (N, channel_num, size, embedding_dimension)
            x = x.unsqueeze(1)      
        x = [F.relu(conv(x)).squeeze(3) for conv in self.convs]     #卷积后降维
        x = [F.max_pool1d(item, item.size(2)).squeeze(2) for item in x]     #最大值池化后降维
        #torch.squeeze() 这个函数主要对数据的维度进行压缩,去掉维数为1的的维度
        x = torch.cat(x, 1)  # 拼接 3个卷集核,一个卷集核100(filter_num)个值
        x = self.dropout(x)
        logits = self.fc(x)

        return logits

问题

  1. target = target.data.sub(1)
  2. len(label_field.vocab) == 3 ?

你可能感兴趣的:(【NLP】TextCNN)