Building Multi-Turn Dialogue Datasets for Large Model Fine-Tuning

To fine-tune a model on multi-turn dialogue data, supply the chat history alongside each sample. For example, the following is training data for a three-turn conversation:

{"prompt": "长城h3风扇不转。继电器好的。保险丝好的传感器新的风扇也新的这是为什么。就是继电器缺一个信号线",
"response": "用电脑能读数据流吗?水温多少",
"history": []}

{"prompt": "95",
"response": "上下水管温差怎么样啊?空气是不是都排干净了呢?",
"history": [["长城h3风扇不转。继电器好的。保险丝好的传感器新的风扇也新的这是为什么。就是继电器缺一个信号线", "用电脑能读数据流吗?水温多少"]]}

{"prompt": "是的。上下水管都好的",
"response": "那就要检查线路了,一般风扇继电器是由电脑控制吸合的,如果电路存在断路,或者电脑坏了的话会出现继电器不吸合的情况!",
"history": [["长城h3风扇不转。继电器好的。保险丝好的传感器新的风扇也新的这是为什么。就是继电器缺一个信号线", "用电脑能读数据流吗?水温多少"], ["95", "上下水管温差怎么样啊?空气是不是都排干净了呢?"]]}
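Each record carries all earlier turns as a list of [query, response] pairs, so one conversation expands into one training sample per model reply. As a minimal sketch of how such records can be generated (the helper name and output path are illustrative, not part of the original script):

import json

def conversation_to_samples(turns):
    # Expand a conversation [(query, response), ...] into one training
    # record per turn, each carrying the preceding turns as its history.
    samples = []
    for i, (query, response) in enumerate(turns):
        samples.append({
            "prompt": query,
            "response": response,
            "history": [list(t) for t in turns[:i]],
        })
    return samples

turns = [
    ("长城h3风扇不转。继电器好的。保险丝好的传感器新的风扇也新的这是为什么。就是继电器缺一个信号线", "用电脑能读数据流吗?水温多少"),
    ("95", "上下水管温差怎么样啊?空气是不是都排干净了呢?"),
    ("是的。上下水管都好的", "那就要检查线路了,一般风扇继电器是由电脑控制吸合的,如果电路存在断路,或者电脑坏了的话会出现继电器不吸合的情况!"),
]

with open("train.json", "w", encoding="utf-8") as f:
    for sample in conversation_to_samples(turns):
        f.write(json.dumps(sample, ensure_ascii=False) + "\n")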

During training, set --history_column to the key under which the chat history is stored (history in this example); the history turns are then concatenated into the model input automatically. Note that any content beyond the input length max_source_length will be truncated:

--prompt_column prompt \
--response_column response \
--history_column history 
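Before launching training, it can help to sanity-check that every record in the JSONL file exposes the three keys named by the flags above; a small illustrative check (the function name is an assumption, not part of the original script):

import json

def check_dataset(path, prompt_col="prompt", response_col="response", history_col="history"):
    with open(path, encoding="utf-8") as f:
        for line_no, line in enumerate(f, 1):
            record = json.loads(line)
            assert isinstance(record.get(prompt_col), str), f"line {line_no}: bad prompt"
            assert isinstance(record.get(response_col), str), f"line {line_no}: bad response"
            # Each history entry must be a [query, response] pair.
            history = record.get(history_col, [])
            assert all(len(turn) == 2 for turn in history), f"line {line_no}: bad history"

check_dataset("train.json")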

The input preprocessing code for training is shown below (an excerpt from the ptuning script, so tokenizer, data_args, prefix, and the *_column variables come from the enclosing scope):

def preprocess_function_train(examples):
    # Total sequence budget: prompt tokens plus answer tokens.
    max_seq_length = data_args.max_source_length + data_args.max_target_length

    model_inputs = {
        "input_ids": [],
        "labels": [],
    }
    for i in range(len(examples[prompt_column])):
        # Skip records with an empty prompt or response.
        if examples[prompt_column][i] and examples[response_column][i]:
            query, answer = examples[prompt_column][i], examples[response_column][i]

            if history_column is None:
                prompt = query
            else:
                # Concatenate the chat history into the prompt, one
                # "[Round k]\n问:...\n答:...\n" block per past turn, then
                # append the current query with an empty answer slot.
                prompt = ""
                history = examples[history_column][i]
                for turn_idx, (old_query, response) in enumerate(history):
                    prompt += "[Round {}]\n问:{}\n答:{}\n".format(turn_idx, old_query, response)
                prompt += "[Round {}]\n问:{}\n答:".format(len(history), query)

            prompt = prefix + prompt
            a_ids = tokenizer.encode(text=prompt, add_special_tokens=False)
            b_ids = tokenizer.encode(text=answer, add_special_tokens=False)

            # Truncate source and target separately, leaving room for the
            # special tokens added by build_inputs_with_special_tokens.
            if len(a_ids) > data_args.max_source_length - 1:
                a_ids = a_ids[: data_args.max_source_length - 1]

            if len(b_ids) > data_args.max_target_length - 2:
                b_ids = b_ids[: data_args.max_target_length - 2]

            input_ids = tokenizer.build_inputs_with_special_tokens(a_ids, b_ids)

            # The upstream ChatGLM-6B script locates the prompt/answer
            # boundary via the bos token:
            # context_length = input_ids.index(tokenizer.bos_token_id)
            # Here the prompt length is used directly instead.
            context_length = len(a_ids)
            mask_position = context_length - 1
            # Mask the prompt part with -100 so it is ignored by the loss;
            # only the answer tokens contribute.
            labels = [-100] * context_length + input_ids[mask_position+1:]

            # Right-pad both sequences to the fixed max_seq_length.
            pad_len = max_seq_length - len(input_ids)
            input_ids = input_ids + [tokenizer.pad_token_id] * pad_len
            labels = labels + [tokenizer.pad_token_id] * pad_len
            if data_args.ignore_pad_token_for_loss:
                # Padding must not contribute to the loss either.
                labels = [(l if l != tokenizer.pad_token_id else -100) for l in labels]

            model_inputs["input_ids"].append(input_ids)
            model_inputs["labels"].append(labels)

    return model_inputs
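As a concrete trace of the history loop above, the prompt built for the third example record (two turns of history, assuming an empty prefix) is:

# With an empty prefix, the third record above is flattened to:
prompt = (
    "[Round 0]\n问:长城h3风扇不转。继电器好的。保险丝好的传感器新的风扇也新的这是为什么。就是继电器缺一个信号线\n答:用电脑能读数据流吗?水温多少\n"
    "[Round 1]\n问:95\n答:上下水管温差怎么样啊?空气是不是都排干净了呢?\n"
    "[Round 2]\n问:是的。上下水管都好的\n答:"
)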

The preprocessing code for evaluation is as follows:

def preprocess_function_eval(examples):
    inputs, targets = [], []
    for i in range(len(examples[prompt_column])):
        if examples[prompt_column][i] and examples[response_column][i]:
            query = examples[prompt_column][i]
            if history_column is None or len(examples[history_column][i]) == 0:
                prompt = query
            else:
                # Same history concatenation as in training.
                prompt = ""
                history = examples[history_column][i]
                for turn_idx, (old_query, response) in enumerate(history):
                    prompt += "[Round {}]\n问:{}\n答:{}\n".format(turn_idx, old_query, response)
                prompt += "[Round {}]\n问:{}\n答:".format(len(history), query)
            inputs.append(prompt)
            targets.append(examples[response_column][i])

    inputs = [prefix + inp for inp in inputs]
    model_inputs = tokenizer(inputs, max_length=data_args.max_source_length, truncation=True, padding=True)
    # max_target_length is set in the enclosing scope (in the upstream
    # script it is data_args.val_max_target_length).
    labels = tokenizer(text_target=targets, max_length=max_target_length, truncation=True)

    if data_args.ignore_pad_token_for_loss:
        # Replace padding in the labels with -100 so it is ignored by the loss.
        labels["input_ids"] = [
            [(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"]
        ]
    model_inputs["labels"] = labels["input_ids"]

    return model_inputs
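Both functions are designed to be applied with the datasets library's batched map, which is how the upstream ptuning script drives them; a hedged usage sketch (the file name and column handling are assumptions):

from datasets import load_dataset

raw = load_dataset("json", data_files={"train": "train.json"})
column_names = raw["train"].column_names  # ["prompt", "response", "history"]

train_dataset = raw["train"].map(
    preprocess_function_train,
    batched=True,
    remove_columns=column_names,  # keep only input_ids / labels
)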

https://github.com/THUDM/ChatGLM-6B/tree/main/ptuning
