NLP 实战 (4) | 我发现的飞桨(paddlepaddle)大坑

文章目录

    • 分离阶段:以交付为目标
    • 尽早集成:暴露内存和性能问题
    • 重构代码:做好模块化
    • 耗时分析:找到性能瓶颈
    • 深入分析:二分排查
    • 解决问题:通常就几行代码
    • 小结

在 上一篇 我们介绍了数据集和模型的上传/下载管理。解决数据集和模型的管理问题,在我们的新成员加入时就体现了优势,新成员克隆仓库代码、根据文档执行命令下载相关数据集、下载相关模型、启动服务、执行测试,以最快的时间跑通全流程,进而获取新任务,达成 first commit 的目标。本节我们分析一个实战问题诊断的过程。


团队博客: CSDN AI小组


分离阶段:以交付为目标

我们反复强调,从数据集入手、训练模型、服务化,最终我们是要达成集成或交付的目标。所以事实上核心要解决的有两个方面的事情:

  • 清洗数据输入、选择算法和预训练模型,训练出有足够精度和召回率的模型,这是基础,这个步骤算法工程师们已经需要付出极大的努力。项目开发中要为算法工程师提供一定的时间集中精力解决这个部分。
  • 提供服务,这里面有大量的工程问题,例如前一个步骤不需要考虑代码如何组织的问题,对内存占用、性能也可以暂时不做要求。但是到了提供服务阶段,则必须考虑这些问题。

尽早集成:暴露内存和性能问题

实际上工程师们在第一步已经付出极大的努力,这第二步当然也可以通过学习获得快速成长。本节我们以一个实际的例子说明工程上的模块化和问题诊断的基本手法,我们展示实际的代码和问题只是为了示例。

当我们做了集成,提供了API给上层的时候,事实上是信心不太足的:

  • “该接口会耗时,它是异步的”
  • “那么它的平均耗时多少”
  • “嗯,确实需要测一下”

当我这么回复的时候,实际上说明我们的模块缺乏对响应的基本 profile,这是一个信号。但是为了尽早集成这个目标,是可以先把 API 推出来。上层这边做集成里还有一个压测环节,测试时立刻暴露了问题。

我们的工程师经过努力做了一个打标签的接口,这个接口内部有一组策略,其中有一个环节用到了paddlehub。在批量跑测试中立刻暴露了第1个问题:

  • 在4G内存的机器上接口跑着跑着就会报内存不足崩溃

我们的工程师咨询 paddlepaddle 的技术人员,认为应该跑在8G内存上才可以。我其实是存疑的,但是官方这么说,我们就试一下,在8G内存的服务器上提供服务确实可以把批量测试跑完而不出问题。

接着的一个问题是接口的平均耗时问题。实际上使用预训练模型做微调训练后的模型还是比较大,这种模式我心里对内存占用总有点没底。另外一个没底的是 NLP 预测是否一定会是耗时的?不过在多次集成模型之后,包括第2节我们提到的模型加载单例模式,我们还是有一些经验:

  • 进程内应该用单例模式管理模型的加载,避免同进程内反复加载同一个模型,这对内存占用和耗时都是不必要的浪费。
  • 如果有词向量计算,应该尽可能把能预先计算的做好预计算,提供服务的接口内应该做最小计算。
  • 如果是需要动态加载的,应该最小化每次加载的数据量。解决单次内存占用最小化和加载耗时之间的平衡。

回头来说,批量测试暴露的性能问题:

  • 该接口平均耗时5秒

这个数据并不好,我们立刻调整了优先级,必须解决性能问题。

重构代码:做好模块化

这又回到上一节 里提到的“Hackable Project” 的主题。我们希望问题出现的时候,代码是可以明显看出问题可能在哪。这里我花了一些时间重构已有的代码,原来的代码如下:

class TagService:
    def __init__(self, config, options):
        self.config = config
        self.options = options

    def load(self):
        self.inner_classifier = SGDText2PL()
        self.tag_label = TagLabel()
        self.tag_label.load_label()
        self.inner_classifier.load()
        self.ocr_client = OCRClient(self.config, self.options)
        self.code_extract = CodeExtractService(self.config, self.options)
        self.tag_score = TagScoreService(
            self.tag_label.kg_list, self.tag_label.catalog_xy_list, self.tag_label.kg_position_list)
        self.hub = PaddleHubPL()
        self.hub.load()
        self.ocr_client.load()

    def catalog_predict(self, title, content):
        # import paddlehub as hub
        # sentence = []
        # sentence.append(title + content)
        # sentences = []
        # catalog_id = ''
        # sentences.append(sentence)
        # # model = hub.Module(
        # #     name='ernie_tiny',
        # #     version='2.0.1',
        # #     task='seq-cls',
        # #     load_checkpoint= get_tag_model_path()+'model.pdparams',
        # #     label_map=LABEL_MAP)
        # results = self.hub.model.predict(sentences, max_seq_len=128, batch_size=1, use_gpu=False)
        # for idx, text in enumerate(sentences):
        #     catalog_id = results[idx]
        catalog_id = self.hub.predict(title, content)

        return catalog_id

    def predict(self, title, content):

        # print(question)
        # title = question.get('title')  # 获取标题
        # content = question.get('body')  # 获取内容

        code_title = get_en_character(title)  # 找标题代码
        code_content = self.code_extract.extract_code_for_title(content)[
            'code']  # 找内容代码
        checked = 0
        pre_result = []
        catalog_id = ''
        code_id = ''
        status = 0
        img_list = []
        ocr_text = ''

        try:
            img_list = get_img_url2(content)

            if len(img_list) > 0:

                # print(img_list[0])
                checked = 4
                ocr = self.ocr_client.extract(img_list[0])
                ocr_text = '\n'.join(ocr['code_text'])
                ocr_code_content = get_code_character(ocr_text)
                pre_result = self.inner_classifier.classify(ocr_code_content)
                code_id = pre_result.get('language')
                if code_id == 'text':
                    code_id = ''
            else:
                if code_title != '' and code_content != '':  # 当标题和内容都有代码的情况
                    checked = 0
                if code_title == '' and code_content != '':  # 当标题没有代码,但内容有代码的情况
                    checked = 1
                if code_title != '' and code_content == '':  # 当标题和代码都没有代码的情况
                    checked = 2
                if code_title == '' and code_content == '':  # 当标题和代码都没有代码的情况
                    checked = 3

            # if checked == 4:
            #     ocr = self.ocr_client.extract(img_list[0])
            #     ocr_text = '\n'.join(ocr['code_text'])
            #     ocr_code_content = get_code_character(ocr_text)
            #     pre_result = self.inner_classifier.classify(ocr_code_content)
            #     code_id = pre_result.get('language')
            #     print('#$$$$$$#'+code_id)

            if checked == 0 and code_id == '':
                pre_result = self.inner_classifier.classify(code_content)
                code_id = pre_result.get('language')

            if checked == 1 and code_id == '':
                pre_result = self.inner_classifier.classify(code_content)
                code_id = pre_result.get('language')

            if checked == 2 and code_id == '':
                datatitle = []
                datatitle.append(code_title)
                pre_result = self.inner_classifier.classify(datatitle)
                code_id = pre_result.get('language')

            if checked == 3 and code_id == '':
                cn_title = get_cn_character(title)
                cn_content = get_cn_character(content)
                code_id = self.catalog_predict(cn_title, cn_content)

                # print('**************'+code_id)

        except Exception as e:
            temp = str(e)
            code_id = '其他'
            status = -1

        if code_id == 'jar':
            catalog_id == 'jar'

        if checked != 3 and status != -1:
            cn_title = get_cn_character(title)
            cn_content = get_cn_character(content)
            character = {
     }
            if len(cn_title.strip()) == 0 and len(cn_content.strip()) == 0:
                character = {
     }
                status = 1
            else:
                catalog_id = self.catalog_predict(cn_title, cn_content)
                # print(catalog_id + '#################')
                status = 2
        if status == 0:
            catalog_id = code_id

        if catalog_id == '':
            catalog_id = '其他'

        if catalog_id == 'text':
            catalog_id = '其他'
        if code_id == 'text':
            catalog_id = '其他'

        return {
     
            'title': title,
            'content': content,
            'code_id': code_id,
            'catalog_id': catalog_id,
            'status': status
        }

这段代码存在两个典型的问题:

  • 方法 catalog_predict 存在大段注释代码。不要的代码不应该提交到 git,不要用注释的方式保留大段的“备用”代码,应该毫不留情地删除它,如果想看该文件历史上的代码片段,直接看 git 的历史即可。很多工程师不能理解这点,实际上有了 git ,你可以查看该代码文件历史上的任何提交过的代码,没有必要用注释的方式保留“备用代码”,例如上面这个代码片段,就是从 git 的历史 commit 里拷贝的。
  • 方法 predict 里存在一种典型的用flag变量,做代码分支逻辑判断依据的实现方式,而且存在两个交叉的 flag 变量:checkedstatus

第一个问题好解决,删除代码提交即可。第二个问题则让代码不好诊断问题。例如:

  • 几个连续的 if code_title != '' and code_content != '' 需要很费劲才能知道checked flag 的含义,以及它确实在每种情况下只会出现一个唯一的值,这样的代码一不小心就会挂。
  • 几个连续的 if checked == 0 and code_id == '' 需要很费劲才能知道这个分支的含义,以及它确实和其他 if 分支只会被执行一次。至少应该用if elif elif也比全部及格不做闭环的if好理解。
  • 需要很费劲才能理解 if checked == 3 and code_id == ''catalog_idcode_id 不存在时用来赋予了 code_id 的值

总之,需要很费劲才能分析这段代码的分支处理逻辑,以及多个处理情况之间是否有交叉,谁的优先级更高。

经过协调,我决定自己上手改这段代码。我觉的只在类内部用多个函数也能写好,不过我决定拆分下,让每个小类只做一件事。

首先分析 predict 要解决的问题,核心思路应该是:

  • 识别代码标签:
    • 如果内容里有图片,走OCR识别代码类型
    • 否则,如果内容里有代码,识别内容里的代码类型
    • 否则,如果标题里有代码,识别标题里的代码类型
  • 识别大类标签:
    • 提取代码和内容里的中文,用 paddlepaddle 模型来对标签分类

因此,建立一个子文件夹,把上面四个叶子结点的识别分别独立一个类,每个类只做一件事:

ocr_predict.py

class OCRPredict:
    def __init__(self, config, options, code_classifier):
        self.config = config
        self.options = options
        self.code_classifier = code_classifier
        self.ocr_client = None

    def load(self):
        if self.ocr_client is not None:
            return
        self.ocr_client = OCRClient(self.config, self.options)
        self.ocr_client.load()

    def predict(self, content):
        # 查找并处理图片,TODO:查找图片遍历内容和代码提取遍历重复了!
        img_list = get_img_url2(content)
        if len(img_list) == 0:
            return {
     
                'err': ErrorCode.NOT_FOUND
            }

        # OCR 识别
        ocr = self.ocr_client.extract(img_list[0])
        ocr_text = '\n'.join(ocr['code_text'])

        # TODO: get_code_character 这个步骤未必要,直接丢给 code_classifier 也是可以的
        ocr_code_content = get_code_character(ocr_text)
        pre_result = self.code_classifier.classify(ocr_code_content)
        code_name = pre_result.get('language')
        if code_name == 'text':
            return {
     
                'err': ErrorCode.NOT_FOUND
            }
        else:
            return {
     
                'err': ErrorCode.SUCCESS,
                'code_name': code_name,
            }

code_predict.py

class CodePredict:
    def __init__(self, config, options, code_classifier):
        self.config = config
        self.options = options
        self.code_classifier = code_classifier

    def load(self):
        pass

    def predict(self, code_content):
        pre_result = self.code_classifier.classify(code_content)
        code_name = pre_result.get('language')
        return {
     
            'err': ErrorCode.SUCCESS,
            'code_name': code_name
        }

category_predict.py

class CategoryPredict:
    def __init__(self, config, options):
        self.config = config
        self.options = options
        self.hub = None

    def load(self):
        if self.hub is not None:
            return
        self.hub = PaddleHubPL()
        self.hub.load()

    def predict(self, title, content):
        category_name = self.hub.predict(title, content)
        return {
     
            'err': ErrorCode.SUCCESS,
            'category_name': category_name,
        }

这里,OCRPredictCodePredict 都使用依赖注入的方式让外层传入code_classifier。每个类做的事情很简单:load and predict

有了上述三个叶子结点,我们提供一个 策略类 来组织管道的复合逻辑:

class ComposePredict:
    def __init__(self, config, options):
        self.config = config
        self.options = options

        self.code_classifier = None
        self.code_predict = None
        self.ocr_predict = None
        self.category_predict = None
        self.has_load = False

    def load(self):
        if self.has_load:
            return
        self.code_classifier = SGDText2PL()
        self.code_classifier.load()

        self.ocr_predict = OCRPredict(
            self.config, self.options, self.code_classifier)
        self.ocr_predict.load()

        self.code_predict = CodePredict(
            self.config, self.options, self.code_classifier)
        self.code_predict.load()

        self.category_predict = CategoryPredict(self.config, self.options)
        self.category_predict.load()
        self.has_load = True

    def predict(self, title, content, code_title, code_content, cn_title, cn_content):
        # 识别 code_name
        code_name = None
        code_ret = self.predict_code_name(
            title, content, code_title, code_content)

        if code_ret['err'] == ErrorCode.SUCCESS:
            code_name = code_ret['code_name']
            if code_name == 'text' or code_name == 'scheme' or code_name == '':
                code_name = '其他'
            if code_name == 'c':
                code_name = 'c语言'
            if code_name == 'go':
                code_name = 'golang'

        # 识别 category_name
        category_name = None
        category_ret = self.category_predict.predict(cn_title, cn_content)

        if category_ret['err'] != ErrorCode.SUCCESS:
            return category_ret
        else:
            category_name = category_ret['category_name']
            if category_name == 'text' or category_name == '':
                category_name = '其他'

        return {
     
            'err': ErrorCode.SUCCESS,
            'code_name': code_name,  # 可空
            'category_name': category_name,
        }

    def predict_code_name(self, title, content, code_title, code_content):
        # 内容有代码,尝试识别内容里的代码(内容比标题优先级高)
        if code_content != '':
            ret = self.code_predict.predict(code_content)
            if ret['err'] == ErrorCode.SUCCESS:
                return ret

        # 标题有代码,尝试识别标题里的代码
        if code_title != '':
            ret = self.code_predict.predict([code_title])
            if ret['err'] == ErrorCode.SUCCESS:
                return ret

        # 标题和内容都没有代码,尝试识别图片里的代码(成本最高,放在最后)
        ret = self.ocr_predict.predict(content)
        if ret['err'] == ErrorCode.SUCCESS:
            return ret

        # 识别失败
        return {
     
            'err': ErrorCode.NOT_FOUND
        }

可以看到,这个类 聚合 了前面的三个功能简单的类, ComposePredict 的使用方式同样是 load and predict。但是我们重点看下区别:

  • predict_code_name 里面使用 快速短路 的方式,从上往下组织管道处理:
    • 如果内容有代码,尝试识别内容里的代码(内容比标题优先级高),成功就直接返回
    • 如果标题有代码,尝试识别标题里的代码,成功就直接返回
    • 如果标题和内容都没有代码,尝试识别图片里的代码(成本最高,放在最后),成功就直接返回
    • 否则,返回失败

当你有一个 管道处理 流程时,用这种方式可以良好的组织管道过程和优先级编排,代码也不会很乱。事实上它是经典设计模式 职责链 模式。不过我日常并不记得它的名字叫什么,如果一个代码组织适合这样写,我们就这样写了。这里给它们起名字只是我在写博客的时候方便说明才起的而已。

再上面的一层 predict 内部,则是拆解了原始代码的意图:

  • 无论怎样 category_name 都是要识别的
  • code_name 可能不存在
  • 当然,这里顺手的变动是,code_namecategory_name 比原来的 code_idcategory_id 更符合含义,它们是名字,不是id。

好了,到了这里,核心的代码重构就完成了,其他还有一些细节的地方只是同理。

耗时分析:找到性能瓶颈

我们的目标是诊断性能瓶颈,最原始的方法就是对代码的每个环节做耗时统计,看哪部分耗时最多。先提供两个AK-47小函数:

def time_start(name):
    '''开始计时,返回计时器上下文'''
    return {
     
        'name': name,
        'start': round(time.time() * 1000)
    }

def time_end(ctx):
    '''结束计时,返回耗时统计'''
    end = round(time.time() * 1000)
    ctx['end'] = end
    ctx['elapse_mill_secs'] = end - ctx['start']
    ctx['elapse_secs'] = ctx['elapse_mill_secs']/1000
    print("{}耗时:{}毫秒".format(ctx['name'], ctx['elapse_mill_secs']))
    return ctx

于是,我们只需在代码的不同环节加上耗时统计:

timer = time_start()
...
time_end(timer)

通过这种方式,我们很快找到最耗时的地方是 CategoryPredict 类的 predict 方法。而这个类的实现其实是委托给 PaddleHubPL 类,我们看下这个类:

class PaddleHubPL:
    def __init__(self) -> None:
        # 使用 g_model_manager 做单例
        self.model_key = 'paddlehub_tag_svm'
        g_model_manager.register(self.model_key, lambda: PaddleHubPLImpl())
        self.model = None

    def load(self):
        try:
            self.model = g_model_manager.load(self.model_key)
            return {
     
                'err': ErrorCode.SUCCESS
            }

        except Exception as e:
            logger.error('load SGDText2PL model failed:', str(e))
            logger.error(traceback.format_exc())
            return {
     
                'err': ErrorCode.NOT_FOUND
            }

    def predict(self, title, content):
        ret = self.model.predict(title, content)
        return ret

由于 CategoryPredict 内部没有别的逻辑,它可以直接被 PaddleHubPL 替代,不过这个我们可以先不管。 PaddleHubPL 内部使用 g_model_manager 来单例化 PaddleHubPLImpl那既然已经单例化了,至少 PaddleHubPLImplload 应该最多只会被执行一次,为什么 predict 会耗时接近 5 秒呢?

我们先看下 PaddleHubPLImpl 的实现:

class PaddleHubPLImpl:
    def __init__(self):
        self.model = None
        self.tokenizer = None

    def load(self):
        self.model = self.load_or_fit()

    def load_or_fit(self):
        model = hub.Module(
            name='ernie_tiny',
            version='2.0.1',
            task='seq-cls',
            load_checkpoint=get_tag_model_path()+'model.pdparams',
            label_map=LABEL_MAP)
        return model

    def predict(self, title, content):
        sentence = []
        sentence.append(title + content)
        sentences = []
        catalog_id = ''
        sentences.append(sentence)
        results = self.model.predict(
            sentences, max_seq_len=128, batch_size=1, use_gpu=False)
        for idx, text in enumerate(sentences):
            catalog_id = results[idx]

        return catalog_id

这里面做的很简单,也就是在 load 里加载 paddlehub 的模型,在 predict 里预测标签而已。问题会出在哪里呢?

深入分析:二分排查

到这里,我们的工程师再次去咨询 paddlepaddle 的技术支持,他们回复要么用servering 的模式使用 paddle,要么可以用另外一个paddlelite

可是直觉告诉我,这是不对的,支撑的理由是:

  • 即使模型很大,第一次加载可以比较耗时,后续的预测应该是比较快才对
  • 除非是每次预测都要动态加载不同的模型或者预计算词向量

但是我看不出这个标签预测有什么地方应该动态加载不同的模型或者预计算词向量?进一步的思考是:

  • 如果 servering 能解决,那没理由不用 servering 耗时就要高。
  • 因为 servering 只是解决服务化的问题,没有理由只是拆分了进程,耗时就能被减少,这并不站得住脚。
  • paddlelite 则听上去像是 “ windows 系统不行,你用 mac/linux 试试” 的味道。

我决定继续诊断,那就要进入 paddlepaddle 内部的代码去做。想法很简单:

  • 如果 paddlehub 的代码是 Native 的,例如 C++的,那么跟进去的成本就比较高,远水救不了近火。
  • 如果 paddlehub 的代码是纯 Python 的,那我就有信心找到问题。

直觉告诉我,一定是内部有重复加载的地方。因为每次执行都会出现三条同样的加载日志:

[2021-06-22 18:04:57,093] [INFO] - Found /Users/{user_name}/.paddlenlp/models/ernie-tiny/vocab.txt
[2021-06-22 18:04:57,099] [INFO] - Found /Users/{user_name}/.paddlenlp/models/ernie-tiny/spm_cased_simp_sampled.model
[2021-06-22 18:04:57,101] [INFO] - Found /Users/{user_name}/.paddlenlp/models/ernie-tiny/dict.wordseg.pickle

至少这三个文件不应该重复加载吧?看上去路径是一样的,同样的文件每次加载肯定不对。做好了决定,我就直接找到paddlehub的安装路径:

/Users/{user_name}/.pyenv/versions/3.8.9/lib/python3.8/site-packages/paddlehub

用VSCode打开,方便直接加日志诊断。我们先找到 predict 方法,在里面通过上面的定时器加了一些日志,但是很奇怪函数中间环节的耗时合计不等于整个函数的耗时??

class TransformerModule(RunModule, TextServing):
    ...
    def predict(self,
                data: List[List[str]],
                max_seq_len: int = 128,
                split_char: str = '\002',
                batch_size: int = 1,
                use_gpu: bool = False):
        ...
        tt = time_start()

        t1 = time_start()
        if self.task not in self._tasks_supported \
                and self.task is not None:      # None for getting embedding
            raise RuntimeError(f'Unknown task {self.task}, current tasks supported:\n'
                               '1. seq-cls: sequence classification;\n'
                               '2. token-cls: sequence labeling;\n'
                               '3. text-matching: text matching;\n'
                               '4. None: embedding')

        paddle.set_device('gpu') if use_gpu else paddle.set_device('cpu')

        print('NLP Predict')

        batches = self._batchify(
            data, max_seq_len, batch_size, split_char)
        results = []
        self.eval()

        d1 = time_end(t1)
        print("NLP 预处理耗时:{}".format(d1['elapse_mill_secs']))

        for batch in batches:
            print('batch')
            if self.task == 'text-matching':
                ...
            else:
                t2 = time_start()
                input_ids, segment_ids = batch
                input_ids = paddle.to_tensor(input_ids)
                segment_ids = paddle.to_tensor(segment_ids)
                d2 = time_end(t2)
                print("NLP 加载耗时:{}".format(d2['elapse_mill_secs']))

                if self.task == 'seq-cls':
                    t3 = time_start()
                    probs = self(input_ids, segment_ids)
                    idx = paddle.argmax(probs, axis=1).numpy()
                    idx = idx.tolist()
                    labels = [self.label_map[i] for i in idx]
                    results.extend(labels)

                    d3 = time_end(t3)
                    print("NLP 预测耗时:{}".format(d3['elapse_mill_secs']))
                elif self.task == 'token-cls':
                    ...

        dd = time_end(tt)
        print("NLP 总耗时:{}".format(dd['elapse_mill_secs']))
        return results

我纳闷了一会,我发现 for 循环里的 print('batch') 只打印了一次。我本想是不是 batches 比较多,导致单次执行的耗时虽然不多,但是 batches 很大,多次循环后累计就很大,如果是这样的话也很麻烦。但是 print('batch') 只打印了一次只打印了一次,那到底耗时在哪呢?

好在我立刻想到会不会是for batch in batches 这个语句耗时,我们知道 Python 的遍历背后是 迭代器,一种让人写的很爽,也有延迟计算能力的状态机语法糖。有了这个怀疑,我就加了一行代码,类似哈利波特的显形咒: “急急现形(Apareciym)”

batches = list(self._batchify(
            data, max_seq_len, batch_size, split_char))

加上 list 后,统计的耗时就正常了,下面这句显示耗时最多:

print("NLP 预处理耗时:{}".format(d1['elapse_mill_secs']))

那么毫无疑问,耗时在 self._batchify 这个成员函数里面,拆开盒子:

def _batchify(self, data: List[List[str]], max_seq_len: int, batch_size: int, split_char: str):
        def _parse_batch(batch):
            if self.task != 'text-matching':
                print('text-matching parse batch')
                input_ids = [entry[0] for entry in batch]
                segment_ids = [entry[1] for entry in batch]
                return input_ids, segment_ids
            else:
                print('no text-matching parse batch')
                query_input_ids = [entry[0] for entry in batch]
                query_segment_ids = [entry[1] for entry in batch]
                title_input_ids = [entry[2] for entry in batch]
                title_segment_ids = [entry[3] for entry in batch]
                return query_input_ids, query_segment_ids, title_input_ids, title_segment_ids

        t1 = time_start()
        tokenizer = self.get_tokenizer()
        examples = []
        d1 = time_end(t1)
        print("get_tokenizer: {}".format(d1['elapse_mill_secs']))

        t2 = time_start()
        for texts in data:
            encoded_inputs = self._convert_text_to_input(
                tokenizer, texts, max_seq_len, split_char)
            example = []
            for inp in encoded_inputs:
                input_ids = inp['input_ids']
                if Version(paddlenlp.__version__) >= Version('2.0.0rc5'):
                    token_type_ids = inp['token_type_ids']
                else:
                    token_type_ids = inp['segment_ids']
                example.extend((input_ids, token_type_ids))
            examples.append(example)
        d2 = time_end(t2)
        print("for texts in data: {}".format(d2['elapse_mill_secs']))

        # Seperates data into some batches.
        t3 = time_start()
        one_batch = []
        for example in examples:
            one_batch.append(example)
            if len(one_batch) == batch_size:
                yield _parse_batch(one_batch)
                one_batch = []

        d3 = time_end(t3)
        print("for texts in data: {}".format(d3['elapse_mill_secs']))

        if one_batch:
            # The last batch whose size is less than the config batch_size setting.
            yield _parse_batch(one_batch)

果然,我们看到了几个熟悉的关键字 yield。不过耗时统计却显示最耗时的是这句:

t1 = time_start()
tokenizer = self.get_tokenizer()
examples = []
d1 = time_end(t1)
print("get_tokenizer: {}".format(d1['elapse_mill_secs']))

于是我想进一步看下 self.get_tokenizer 这个函数,很不幸这个函数并不是一个直接的成员函数,隐藏在一堆函数动态注入里面:

class RunModule(object):
    '''The base class of PaddleHub Module, users can inherit this class to implement to realize custom class.'''

    def __init__(self, *args, **kwargs):
        super(RunModule, self).__init__()

    def _get_func_name(self, current_cls: Generic, module_func_dict: dict) -> Optional[str]:
        mod = current_cls.__module__ + '.' + current_cls.__name__
        if mod in module_func_dict:
            _func_name = module_func_dict[mod]
            return _func_name
        elif current_cls.__bases__:
            for base_class in current_cls.__bases__:
                base_run_func = self._get_func_name(
                    base_class, module_func_dict)
                if base_run_func:
                    return base_run_func
        else:
            return None
  ...

不过到了这里,我看不出 self.get_tokenizer() 有每次重新加载的必要:

  • 每次 predict 动态的参数都是从参数传入的
  • 预处理的开头并没有把动态参数用上,就开始调用 self.get_tokenizer()

从逻辑上来说, self.get_tokenizer() 就不应该每次重新调用,它应该被:

  • 缓存

掉个书袋,计算机里的两个核心问题就是:

  • 命名
  • 缓存

解决问题:通常就几行代码

终于,我可以用最快的方式验证下这个猜想,只需从外部改造下PaddleHubPLImpl

class PaddleHubPLImpl:
    def __init__(self):
        self.model = None
        self.tokenizer = None

    def load(self):
        self.model = self.load_or_fit()
        # 修正辣鸡重复加载 tokenizer 代码
        get_tokenizer = self.model.get_tokenizer
        self.model.get_tokenizer = lambda: self.cached_get_tokenizer(
            get_tokenizer)

    def cached_get_tokenizer(self, get_tokenizer):
        if self.tokenizer is None:
            self.tokenizer = get_tokenizer()
        return self.tokenizer

    def load_or_fit(self):
        model = hub.Module(
            name='ernie_tiny',
            version='2.0.1',
            task='seq-cls',
            load_checkpoint=get_tag_model_path()+'model.pdparams',
            label_map=LABEL_MAP)
        return model

    def predict(self, title, content):
        sentence = []
        sentence.append(title + content)
        sentences = []
        catalog_id = ''
        sentences.append(sentence)
        results = self.model.predict(
            sentences, max_seq_len=128, batch_size=1, use_gpu=False)
        for idx, text in enumerate(sentences):
            catalog_id = results[idx]

        return catalog_id

其中核心的代码是:

def load(self):
        self.model = self.load_or_fit()
        get_tokenizer = self.model.get_tokenizer
        self.model.get_tokenizer = lambda: self.cached_get_tokenizer(
            get_tokenizer)

    def cached_get_tokenizer(self, get_tokenizer):
        if self.tokenizer is None:
            self.tokenizer = get_tokenizer()
        return self.tokenizer

我们用最快速的方式替换掉 self.model.get_tokenizer,让他有缓存的能力,Python 的动态语言特性在这里就显示出一定的便利性。

跑下测试:

@test_classifier_question_tag..
code-c++:100,编程语言:100

耗时: 0.297 秒

@test_classifier_question_tag..
code-c++:100,开发语言:0

耗时: 0.331 秒

@test_classifier_question_tag..
code-c++:100,编程语言:100

耗时: 0.244 秒

@test_classifier_question_tag..
code-c语言:35,编程语言:65

耗时: 0.373 秒

@test_classifier_question_tag..
code-c++:86,后端开发:14

耗时: 0.243 秒

@test_classifier_question_tag..
code-c++:100,编程语言:100

耗时: 0.242 秒

@test_classifier_question_tag..
code-c++:55,人工智能:45

耗时: 0.245 秒

屏幕一闪而过,看到 0.245秒 的时候,还是挺激动的,我解决了发现的飞桨(paddlepaddle) 大坑!而且进一步的猜测是内存占用的问题应该也跟这里有关系,我们可以进一步验证。

小结

本节我们通过一个实战的小例子,展示代码组织和问题诊断。通过尽早集成发现问题,同时通过一系列代码模块重新组织加上有序的profile找到性能的瓶颈,接着通过一系列合理的猜测和验证层层定位到问题,最后给出一个简洁的解决方式。

你可能感兴趣的:(NLP,In,Action,python,自然语言处理,机器学习)