【bert】: 在eval时pooler、last_hiddent_state、cls的区分

问题

在使用bert的时候,有几种输出:poolerlast_hiddent_statecls的区分:

解决

self.model = BertModel.from_pretrained(model_name_or_path)    
outputs = self.bert(**inputs, output_hidden_states=True)
# 说明:
# 一、如果 self.model(**inputs, output_hidden_states=True) 时, outputs 有三个内容
# 
# 其中 outputs[0] 表示:last_hidden_state  也可以使用:outputs.last_hidden_state调用
# 其中 outputs[1] 表示:pooler             也可以使用:outputs.pooler_output调用
# 其中 outputs[2] 表示:hidden_states      也可以使用:outputs.hidden_states调用
#       是一个存储所有state的元组:outputs.hidden_states 大小为13,其中:第0个表示embeddings层

# 二、如果 不设置 output_hidden_states, 那么只有outputs[0] 和 outputs[1]

# 1. last_hiddent_state 三种调用方式:  [batch, seqlen, 768]
print(outputs[0]) 
print(outputs.last_hidden_state) 
print(outputs.hidden_states[-1])

# 2. cls 两种调用方式   [batch, 768]
print(outputs.last_hidden_state[:, 0]) # 注意:这里要使用[:, 0]而不能使用[0],因为前一个表示整个batch的第一个cls
print(outputs[0][:,0])

# 3. pooler 的两种调用方式: [batch, 768] [这个是cls 接一个全连接层,然后再接一个tanh 激活层的输出] 
print(outputs[1]) 
print(outputs.pooler_output)

# 4. first hiddent state  [batch, seqlen, 768]
print(outputs.hidden_states[1])
print(outputs[2][1])
                        

你可能感兴趣的:(pytorch,bert,pytorch,bert)