![图片](./1550805027539.png)
property
__init__
组件的初始化
create
在训练之前初始化组件
train
训练组件;如果组件不需要训练,可以不实现
persist
保存组件模型到本地以备将来使用,如果没有需要保存的东西,可以不实现
load
定义如何加载 persist(持久化)后的模型,即从本地加载之前保存的内容;如果没有保存任何东西到本地,则无需实现
process
使用组件进行处理,从message中取想要的数据,计算完成后更新到message中
required_packages
指定需要安装哪些python包才能使用此组件。
Note:
训练阶段一般会调用 __init__、train、persist 方法
预测阶段一般会调用 load、process 方法
import os
from nltk.classify import NaiveBayesClassifier
from rasa_nlu import utils
from rasa_nlu.components import Component
# File name (inside the model directory) under which the component is pickled
# by `persist` and looked up again by `load`.
SENTIMENT_MODEL_FILE_NAME = "sentiment_classifier.pkl"
class SentimentAnalyzer(Component):
    """A custom sentiment analysis component.

    Trains an NLTK ``NaiveBayesClassifier`` on bag-of-words features built
    from the ``tokens`` of every training example, using one label per line
    read from the file at ``label_path``. When processing a message, the most
    probable label is added to the message's ``entities`` as an entity named
    ``sentiment``.
    """

    name = "sentiment"
    provides = ["entities"]
    requires = ["tokens"]
    defaults = {
        # File with one sentiment label per line; read by `train`.
        "label_path": 'labels.txt'
    }
    language_list = ["en"]

    def __init__(self, component_config=None):
        """Initialise the component from its (optional) configuration dict."""
        super(SentimentAnalyzer, self).__init__(component_config)
        # Path to the labels file, or None when it is not configured.
        self.label_path = self.component_config.get('label_path')
        # Always define the classifier attribute so that `process` can test
        # for it on an instance that has not been trained yet.
        self.clf = None

    def train(self, training_data, cfg, **kwargs):
        """Train the Naive Bayes classifier.

        Reads the labels from ``self.label_path`` (one label per line),
        turns the tokens of every training example into a bag-of-words
        feature dict and pairs features with labels by position.

        Note: the pairing uses ``zip``, so when the number of labels differs
        from the number of training examples the surplus items are ignored.
        """
        with open(self.label_path, 'r', encoding='utf-8') as f:
            labels = f.read().splitlines()

        examples = training_data.training_examples  # list of Message objects
        tokens = [[token.text for token in example.get('tokens')]
                  for example in examples]
        processed_tokens = [self.preprocessing(t) for t in tokens]
        labeled_data = list(zip(processed_tokens, labels))
        self.clf = NaiveBayesClassifier.train(labeled_data)

    def convert_to_rasa(self, value, confidence):
        """Convert model output into the Rasa NLU compatible output format.

        Returns a dict with the keys ``value``, ``confidence``, ``entity``
        and ``extractor``.
        """
        entity = {"value": value,
                  "confidence": confidence,
                  "entity": "sentiment",
                  "extractor": "sentiment_extractor"}
        return entity

    def preprocessing(self, tokens):
        """Create a bag-of-words representation (``{word: True}``) of the
        given token strings."""
        return {word: True for word in tokens}

    def process(self, message, **kwargs):
        """Classify the tokens of ``message`` and append the predicted
        sentiment to the message's ``entities``.

        When no classifier is available (the component is not trained or did
        not receive enough training data) no sentiment entity is added.
        Entities already present on the message are preserved.
        """
        extracted = []
        # `getattr` also covers instances restored from pickles that were
        # written before `clf` was initialised in `__init__`.
        clf = getattr(self, "clf", None)
        if clf is not None:
            tokens = [t.text for t in message.get("tokens")]
            features = self.preprocessing(tokens)
            pred = clf.prob_classify(features)
            sentiment = pred.max()
            confidence = pred.prob(sentiment)
            extracted.append(self.convert_to_rasa(sentiment, confidence))

        # Append instead of overwrite so that entities found by earlier
        # pipeline components are not lost, and never store a None entry.
        message.set("entities",
                    message.get("entities", []) + extracted,
                    add_to_output=True)

    def persist(self, model_dir):
        """Persist this model into the passed directory.

        Pickles the whole component and returns the metadata dict that
        records the file name used.
        """
        classifier_file = os.path.join(model_dir, SENTIMENT_MODEL_FILE_NAME)
        utils.pycloud_pickle(classifier_file, self)
        return {"classifier_file": SENTIMENT_MODEL_FILE_NAME}

    @classmethod
    def load(cls,
             model_dir=None,
             model_metadata=None,
             cached_component=None,
             **kwargs):
        """Load the persisted component from ``model_dir``.

        Falls back to a fresh, untrained instance built from the component
        metadata when the pickled file does not exist.
        """
        meta = model_metadata.for_component(cls.name)
        file_name = meta.get("classifier_file", SENTIMENT_MODEL_FILE_NAME)
        classifier_file = os.path.join(model_dir, file_name)

        if os.path.exists(classifier_file):
            # NOTE: unpickling executes code stored in the file; only load
            # model directories that come from a trusted source.
            return utils.pycloud_unpickle(classifier_file)
        else:
            return cls(meta)
添加自定义组件:情感分析。运行训练命令时使用的参数如下:
-c ../sample_configs/sentiment.yml --data ../data/examples/rasa/demo-sentiment.md --path models -t 12 --verbose
实验用到的语料:
[demo-sentiment-labels.txt](./demo-sentiment-labels.txt)
[demo-sentiment.md](./demo-sentiment.md)
[sentiment.yml](./sentiment.yml)
from nltk.sentiment.vader import SentimentIntensityAnalyzer
from rasa_nlu.components import Component
class SentimentAnalyzer(Component):
    """A pre-trained sentiment component.

    Scores the raw message text with NLTK's VADER
    ``SentimentIntensityAnalyzer`` and adds the strongest polarity class to
    the message's ``entities`` as an entity named ``sentiment``.
    """

    name = "sentiment_pre_trained"
    provides = ["entities"]
    requires = []
    defaults = {}
    language_list = ["en"]

    def __init__(self, component_config=None):
        """Initialise the component from its (optional) configuration dict."""
        super(SentimentAnalyzer, self).__init__(component_config)
        # Built on first use in `process` and reused afterwards, so the
        # analyzer is not re-created for every message.
        self._analyzer = None

    def train(self, training_data, cfg, **kwargs):
        """Not needed, because the model is pre-trained."""
        pass

    def convert_to_rasa(self, value, confidence):
        """Convert model output into the Rasa NLU compatible output format.

        Returns a dict with the keys ``value``, ``confidence``, ``entity``
        and ``extractor``.
        """
        entity = {"value": value,
                  "confidence": confidence,
                  "entity": "sentiment",
                  "extractor": "sentiment_extractor"}
        return entity

    def process(self, message, **kwargs):
        """Score the message text and append the prediction to the
        message's ``entities``.

        The highest-scoring polarity class is used as the entity value and
        its score as the confidence. Entities already present on the message
        are preserved.
        """
        analyzer = getattr(self, "_analyzer", None)
        if analyzer is None:
            analyzer = self._analyzer = SentimentIntensityAnalyzer()

        scores = analyzer.polarity_scores(message.text)
        # "compound" is an aggregate score, not a polarity class; leaving it
        # in would let "compound" be reported as the detected sentiment.
        class_scores = {label: score for label, score in scores.items()
                        if label != "compound"}
        key, value = max(class_scores.items(), key=lambda item: item[1])
        entity = self.convert_to_rasa(key, value)

        # Append instead of overwrite so that entities found by earlier
        # pipeline components are not lost.
        message.set("entities",
                    message.get("entities", []) + [entity],
                    add_to_output=True)

    def persist(self, model_dir):
        """Nothing to store, because the pre-trained model is already
        persisted."""
        pass
添加自定义组件:预训练情感分析。运行训练命令时使用的参数如下:
-c ../sample_configs/sentiment_pre_trained.yml --data ../data/examples/rasa/demo-sentiment.md --path models -t 12 --verbose
涉及到的文件
[sentiment_pre_trained.yml](./sentiment_pre_trained.yml)