数据来源于路透社文章的部分节选;如有需要,可在 GitHub 项目中找到下载地址。
from pathlib import Path
import sys
sys.path.append('..')
import argparse
import shutil
import os
import logging
from textblob import TextBlob
from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE
from pytorch_pretrained_bert.modeling import BertForSequenceClassification
from pytorch_pretrained_bert.tokenization import BertTokenizer
from pytorch_pretrained_bert.optimization import *
from finbert.finbert import *
import finbert.utils as tools
from pprint import pprint
from sklearn.metrics import classification_report
# Project root is taken to be the parent of the current working directory,
# so this script must be launched from a first-level subfolder of the repo.
project_dir = Path.cwd().parent

# Show full cell contents when printing DataFrames (no column truncation).
# `None` is the supported "no limit" value; passing -1 was deprecated in
# pandas 1.0 and is rejected by later releases. Older pandas only accepts
# integers here, so fall back to the legacy -1 if `None` is refused.
try:
    pd.set_option('display.max_colwidth', None)
except ValueError:
    pd.set_option('display.max_colwidth', -1)
# %%
# Configure the root logger: only ERROR and above are emitted, each record
# prefixed with a MM/DD/YYYY HH:MM:SS timestamp, its level and logger name.
logging.basicConfig(
    level=logging.ERROR,
    datefmt='%m/%d/%Y %H:%M:%S',
    format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
)
# %% md
## Prepare the model
# %% md
# %%
# Locations used by the rest of the script, all relative to `project_dir`.
# (An earlier variant pointed lm_path at models/TRC2/pytorch_model.bin.)
lm_path = project_dir / 'models' / 'classifier_model' / 'TRC2'
cl_path = project_dir / 'models' / 'classifier_model' / 'finbert-sentiment'
cl_data_path = project_dir / 'data' / 'sentiment_data'

# Start from a clean output directory. NOTE: this deletes any previously
# saved classifier under `cl_path` on every run.
# `ignore_errors=True` keeps the removal best-effort (a missing directory is
# fine) without the bare `except:` that also swallowed KeyboardInterrupt and
# SystemExit.
shutil.rmtree(cl_path, ignore_errors=True)
# Load the weights stored under `lm_path` into a sequence classifier with a
# three-class output head.
bertmodel = BertForSequenceClassification.from_pretrained(
    lm_path,
    num_labels=3,
    cache_dir=None,
)

# Training configuration handed to FinBert.
# (A previously tried variant differed only in num_train_epochs=4.0 and
# max_seq_length=64.)
config = Config(
    # where things live
    data_dir=cl_data_path,
    model_dir=cl_path,
    bert_model=bertmodel,
    # task / optimisation settings
    output_mode='classification',
    num_train_epochs=4,
    train_batch_size=32,
    max_seq_length=48,
    learning_rate=2e-5,
    warm_up_proportion=0.2,
    local_rank=-1,
    # fine-tuning switches
    discriminate=True,
    gradual_unfreeze=True,
)

# NOTE: from here on the name `finbert` refers to this FinBert instance and
# no longer to the `finbert` package bound by `import finbert.utils`.
finbert = FinBert(config)
# %%
# The three sentiment classes, passed to the model in exactly this order.
sentiment_labels = ['positive', 'negative', 'neutral']
finbert.prepare_model(label_list=sentiment_labels)
# %% md
## Fine-tune the model
# %%
# Fetch the training examples from the 'train' split.
# (Splits tried previously instead of 'train': 'TRC2-financial', 'test'.)
train_data = finbert.get_data('train')
# %%
# Build the model object that is fine-tuned in the following cells.
model = finbert.create_the_model()
# %%
# This is for fine-tuning a subset of the model.
# Fine-tune only a subset of the model: switch off gradients for the
# embedding module and for encoder layers 0 .. freeze-1.
# NOTE(review): with a 12-layer encoder this would leave only the last
# encoder layer (plus whatever sits outside model.bert) trainable - confirm
# the layer count of the loaded model.
freeze = 11
frozen_modules = [model.bert.embeddings]
frozen_modules.extend(model.bert.encoder.layer[idx] for idx in range(freeze))
for module in frozen_modules:
    for weight in module.parameters():
        weight.requires_grad = False
# %% md
### Training
# %%
# Fine-tune on the training examples; the returned model is used for
# evaluation below.
trained_model = finbert.train(model=model, train_examples=train_data)
# %%
# Load the held-out 'test' split.
test_data = finbert.get_data('test')
# %%
# Run the fine-tuned model over the test examples and keep the result table.
results = finbert.evaluate(model=trained_model, examples=test_data)
# %% md
### Prepare the classification report
# %%
def report(df, cols=('label', 'prediction', 'logits')):
    """Print loss, accuracy and a classification report for `df`.

    Args:
        df: DataFrame with one row per evaluated example.
        cols: Three column names, in this order: true labels, predicted
            labels, and the per-class scores fed to the loss. Any indexable
            sequence works; the default is a tuple so that no mutable
            object is shared between calls.

    Returns:
        None. All output goes to stdout.

    Relies on the module-level `finbert` object for `class_weights`.
    NOTE(review): the tensors built here live on the CPU; this assumes
    `finbert.class_weights` is on the CPU as well - confirm.
    """
    label_col, prediction_col, logits_col = cols[0], cols[1], cols[2]

    # Class-weighted cross-entropy between the stored scores and true labels.
    criterion = CrossEntropyLoss(weight=finbert.class_weights)
    scores = torch.tensor(list(df[logits_col]))
    targets = torch.tensor(list(df[label_col]))
    loss = criterion(scores, targets)
    print("Loss:{0:.2f}".format(loss))

    # Fraction of rows whose prediction equals the true label.
    n_correct = (df[label_col] == df[prediction_col]).sum()
    print("Accuracy:{0:.2f}".format(n_correct / df.shape[0]))

    print("\nClassification Report:")
    print(classification_report(df[label_col], df[prediction_col]))
# %%
def _best_class(scores):
    # Index of the largest entry along the first axis of one score vector.
    return np.argmax(scores, axis=0)


# Turn each row's score vector into a single predicted class index.
results['prediction'] = results['predictions'].apply(_best_class)
# %%
# In `results` the true labels are in 'labels' and the raw scores in
# 'predictions', so the default column names of report() do not apply.
report(results, cols=['labels', 'prediction', 'predictions'])
# %% md
### Get predictions
# %%
# First sample passage for prediction. Written as adjacent string literals,
# which Python joins into one single-line string (no embedded newlines).
text = (
    "Later that day Apple said it was revising down its earnings expectations in "
    "the fourth quarter of 2018, largely because of lower sales and signs of economic weakness in China. "
    "The news rapidly infected financial markets. Apple’s share price fell by around 7% in after-hours "
    "trading and the decline was extended to more than 10% when the market opened. The dollar fell "
    "by 3.7% against the yen in a matter of minutes after the announcement, before rapidly recovering "
    "some ground. Asian stockmarkets closed down on January 3rd and European ones opened lower. "
    "Yields on government bonds fell as investors fled to the traditional haven in a market storm."
)
# %%
# Load a classifier from `cl_path`, the directory that was given to Config
# as `model_dir` for the training run.
cl_path = project_dir / 'models' / 'classifier_model' / 'finbert-sentiment'
model = BertForSequenceClassification.from_pretrained(
    cl_path,
    num_labels=3,
    cache_dir=None,
)
# %%
# Score the sample passage; the result is used below as a table with a
# `sentiment_score` column.
result = predict(text, model)
# %%
# Add TextBlob's per-sentence polarity next to the model output for comparison.
# NOTE(review): this assignment assumes TextBlob yields one sentence per row
# of `result` - confirm both split the passage identically.
blob = TextBlob(text)
result['textblob_prediction'] = [sent.sentiment.polarity for sent in blob.sentences]
print(f'Average sentiment is {result.sentiment_score.mean():.2f}.')
# Second sample passage for prediction. Written as adjacent string literals,
# which Python joins into one single-line string (no embedded newlines).
text2 = (
    "Shares in the spin-off of South African e-commerce group Naspers surged more than 25% "
    "in the first minutes of their market debut in Amsterdam on Wednesday. Bob van Dijk, CEO of "
    "Naspers and Prosus Group poses at Amsterdam's stock exchange, as Prosus begins trading on the "
    "Euronext stock exchange in Amsterdam, Netherlands, September 11, 2019. REUTERS/Piroschka van de Wouw "
    "Prosus comprises Naspers’ global empire of consumer internet assets, with the jewel in the crown a "
    "31% stake in Chinese tech titan Tencent. There is 'way more demand than is even available, so that’s "
    "good,' said the CEO of Euronext Amsterdam, Maurice van Tilburg. 'It’s going to be an interesting "
    "hour of trade after opening this morning.' Euronext had given an indicative price of 58.70 euros "
    "per share for Prosus, implying a market value of 95.3 billion euros ($105 billion). The shares "
    "jumped to 76 euros on opening and were trading at 75 euros at 0719 GMT."
)
# Score the second passage with the same model and add TextBlob's
# per-sentence polarity alongside for comparison.
# NOTE(review): the column assignment assumes TextBlob yields one sentence
# per row of `result2` - confirm both split the passage identically.
result2 = predict(text2, model)
blob = TextBlob(text2)
result2['textblob_prediction'] = [sent.sentiment.polarity for sent in blob.sentences]
print(f'Average sentiment is {result2.sentiment_score.mean():.2f}.')
python predict.py --text_path test.txt --output_dir output/ --model_path models/classifier_model/finbert-sentiment
启动 main.py
使用 Postman 进行测试
使用训练前的 TRC2 模型:
在原 TRC2 模型的基础上进行训练后,得到的准确率明显提升。