15.python设计模式【函数工厂模式】

1.知识讲解

  • 内容:定义一个字典,在python中一切皆对象,将所有的函数进行封装,然后定一个分发函数进行分发,将原来if…else全部干掉。
  • 角色:
    • 函数(function)
    • 函数工厂(function factory)
    • 客户端 (client)
  • 举个例子:
    需求:封装一个函数,能够同时进行加减乘除运算。
    加减乘除函数:
# 定义一个计算器的相关功能
def plus(a, b):
    return a + b


def substact(a, b):
    return a - b


def multiply(a, b):
    return a * b


def divide(a, b):
    return a / b

定义封装函数:

# 定义一个计算函数
def cal(a, b, how):
    if how == 1:
        return plus(a, b)
    elif how == 2:
        return substact(a, b)
    elif how == 3:
        return multiply(a, b)
    else:
        return None

从上面这个封装函数来看,太多了if…else…很冗余
于是定义一个函数工厂,将所有函数进行封装,然后根据函数名进行调用

# 定义函数工厂
# 在python里面一切皆是对象
# 定义了一个字典,key是函数名称,value是函数对象
func_map = {
    "plus": plus,
    "substract": substact,
    "multiply": multiply,
    "divide": divide
}
# 函数工厂模式就是一种对函数进行动态分发的模式
def cal(a,b,how):
    if how in func_map.keys():
        return func_map[how](a,b)
    else:
        return None
  • 优点:
    • 对函数进行动态分发,减少了函数的冗余代码。

2.实战

2.1 demo1

需求:这个是我在写深度学习项目的时候遇到的一个设计模式,当初不明白,现在明白了这个设计模式。自然语言处理中,有一次有一个实验,需要同时验证Bert,roberta,gpt,Xnet等预训练模型的相关功能的性能,他们大致分以下几个模块

  • config
  • tokenizer
  • 掩码模型:Bert,roberta,gpt使用的是mlm掩码模型,而Xnet使用的是plm掩码模型
  • 自带的分类模型:sequence_classifier ,但是GPT没有

因为他们每个的这四个部分的功能实现都不相同,但是在实验过程中都需要用到,因此就用到了函数工厂模式。

from torch import nn
from transformers import BertConfig, BertTokenizer, BertForSequenceClassification, BertForMaskedLM, RobertaConfig, \
    RobertaTokenizer, RobertaForSequenceClassification, RobertaForMaskedLM, XLMRobertaConfig, XLMRobertaTokenizer, \
    XLMRobertaForSequenceClassification, XLMRobertaForMaskedLM, XLNetConfig, XLNetTokenizer, \
    XLNetForSequenceClassification, XLNetLMHeadModel, AlbertConfig, AlbertTokenizer, AlbertForSequenceClassification, \
    AlbertForMaskedLM, GPT2Config, GPT2Tokenizer, GPT2LMHeadModel, AutoTokenizer

# 定义一个函数工厂,将所有的函数全部用一个字典封装好,到时候用到那个预训练模型,则就根据预训练模型的名称调用对应的函数。
MODEL_CLASSES = {
    'bert': {
        'config': BertConfig,
        'tokenizer': BertTokenizer,
        "sequence_classifier": BertForSequenceClassification,
        "mlm":BertForMaskedLM
    },
    'roberta': {
        'config': RobertaConfig,
        'tokenizer': RobertaTokenizer,
        "sequence_classifier": RobertaForSequenceClassification,
        "mlm": RobertaForMaskedLM
    },
    'xlm-roberta': {
        'config': XLMRobertaConfig,
        'tokenizer': XLMRobertaTokenizer,
        "sequence_classifier": XLMRobertaForSequenceClassification,
        "mlm": XLMRobertaForMaskedLM
    },
    'xlnet': {
        'config': XLNetConfig,
        'tokenizer': XLNetTokenizer,
        "sequence_classifier": XLNetForSequenceClassification,
        "plm": XLNetLMHeadModel
    },
    'albert': {
        'config': AlbertConfig,
        'tokenizer': AlbertTokenizer,
        "sequence_classifier": AlbertForSequenceClassification,
        "mlm": AlbertForMaskedLM
    },
    'gpt2': {
        'config': GPT2Config,
        'tokenizer': GPT2Tokenizer,
        "mlm": GPT2LMHeadModel
    },
}

class TransformerModelWrapper(nn.Module):
    # 基于Transformer的语言模型的包装器。
    '''WrapperConfig封装了:
    model_type为Bert,roberta,gpt,Xnet,
    wrapper_type为mlm和plm两种类型'''

    def __init__(self, config: WrapperConfig):
        super(TransformerModelWrapper, self).__init__()
        self.config = config
        config_class = MODEL_CLASSES[self.config.model_type]['config']
        tokenizer_class = MODEL_CLASSES[self.config.model_type]['tokenizer']
        model_class = MODEL_CLASSES[self.config.model_type][self.config.wrapper_type]

你可能感兴趣的:(#,设计模式,设计模式,python,开发语言)