最初使用 pytorch_pretrained_bert 加载模型:
from pytorch_pretrained_bert import BertTokenizer, BertModel
model = BertModel.from_pretrained(args.bert_model) 时(args.bert_model 即命令行参数 --bert_model 的值):
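1. 在同一个程序中多次执行
all_hidden_states, encoded_main = model(input_ids=main_x, attention_mask=main_mask)
时,每次调用都会在 GPU 上占用新的显存,而之前占用的显存没有被释放。(可能的原因是前向计算在默认的梯度模式下保留了计算图;若只需提取特征,可在 torch.no_grad() 中调用,并及时释放不再使用的输出张量——待确认。)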
2. 后来改用 transformers 库:
from transformers import BertModel, get_linear_schedule_with_warmup, BertConfig
# Load the bert-base-chinese configuration with output_hidden_states switched on,
# so that every forward pass also returns the hidden states of all layers
# (embedding output plus each Transformer layer), then load the pretrained
# weights using that configuration.
self.model_config = BertConfig.from_pretrained('bert-base-chinese', output_hidden_states=True)
self.bert = BertModel.from_pretrained('bert-base-chinese', config=self.model_config)
然后取 BERT 最后几层的隐藏状态进行处理:
# Forward pass. With output_hidden_states=True the model output holds, in
# order, (last_hidden_state, pooler_output, hidden_states). Index it by
# position instead of tuple-unpacking it: in transformers >= 4 the model
# returns a ModelOutput (an OrderedDict subclass), and unpacking that yields
# its *keys* (strings), not the tensors. Integer indexing works for both the
# plain tuple returned by older versions and the newer ModelOutput.
bert_outputs = self.bert(input_ids=main_x, attention_mask=main_mask)
vector, pooler, enc_layers = bert_outputs[0], bert_outputs[1], bert_outputs[2]

# Per-token feature: concatenation of the last four hidden layers, last layer
# first. enc_layers is indexed as [layer][batch][token], and each layer is
# assumed to be (batch, seq_len, hidden) -- TODO confirm -- so concatenating
# along the last dim gives (batch, seq_len, 4 * hidden). One tensor op
# replaces the former batch x token x layer Python loops.
last_four = torch.cat(
    (enc_layers[-1], enc_layers[-2], enc_layers[-3], enc_layers[-4]), dim=-1
)
# One row per token in batch-major order (row = batch_i * seq_len + token_i),
# the same row order the per-token loops produced.
last_four_cat = last_four.reshape(-1, last_four.size(-1))

# Other per-token strategies, producing the same row layout:
#   sum of the last four layers:
#     torch.stack(enc_layers[-4:]).sum(0).reshape(-1, enc_layers[-1].size(-1))
#   sum of all layers (embedding output included):
#     torch.stack(enc_layers).sum(0).reshape(-1, enc_layers[-1].size(-1))
# Per-sequence alternatives (mean over the token dimension, padding included):
#   torch.mean(enc_layers[0], 1)   # first (embedding) layer
#   torch.mean(enc_layers[-2], 1)  # second-to-last layer
enc_layers 共有 13 个张量:第一个是嵌入层的输出(词向量、位置向量与句子类型向量之和),其余 12 个依次是各 Transformer 层的输出。每个张量的形状都是 [batch_size, seq_length, hidden_size];bert-base 的嵌入维度与 hidden_size 相同(均为 768),因此嵌入层输出与其它各层的形状一致。