安装 torchcrf:pip install pytorch-crf -i https://pypi.tuna.tsinghua.edu.cn/simple/
pip list 显示的时候是 TorchCRF 然而导入的时候是用 import torchcrf 或者 from torchcrf import CRF
import torch
# 安装 torchcrf pip install pytorch-crf -i https://pypi.tuna.tsinghua.edu.cn/simple/
# pip list 显示的时候是 TorchCRF 然而导入的时候是用 import torchcrf 或者 from torchcrf import CRF
from torchcrf import CRF
num_tags = 5 # 实体命名识别 每个汉字可以预测多少中类型
# model = CRF(num_tags,batch_first=True)
model = CRF(num_tags)
seq_length = 3 # 句子长度(一个句子有三个单词)
batch_size = 1 # batch大小 一共输入几个句子 在这里是一个 句子
hidden= torch.randn(batch_size,seq_length,num_tags) # 输入的是 batch:几个句子 ,seq_length:每个句子的长度
print(hidden.shape)# torch.Size([1, 3, 5])
# 表示:一个句子 句子长度是3 每个单词的维度是 5 ,为什么是5呢?因为是为每个单词打标签,一共有五个标签 所以
print(hidden)
mask = torch.tensor([[1,1,0]], dtype=torch.uint8) # mask的意思是 有的汉字的向量 不进行标签的预测
# mask的形状是:[batch,seq_length]
# 这句话由于torchcrf版本不同 进而 函数设置不同 batch_first=True 假设没有这句话 那么输入模型的第一个句子序列的 mask都是true,假设有这句话 就没事 ,mask是正常的
# mask的作用是:因为是中文的句子 那么每句话都要padding 一定的长度 所以 告诉模型那些是padding的
tags = torch.tensor([[0,2,3]], dtype=torch.long) #(batch_size, seq_length)
# tags 是真实的每个单词的标签 在crf模型中用不到啊
loss=model(hidden,tags,mask) # 计算对数似然(用于向前) loss
print(loss)
a=model.viterbi_decode(hidden,mask) # 或者用 model.decode(hidden,mask)
print(a)