加载时数据集未指定 inputs 和 labels 。
# Get the column names for input/target.
# 设置 input/target 的逻辑: 1.指定名称(data_args.text_column, data_args.summary_column) 2.指定数据集(自带名称map) 3.默认为(dataset_columns[0], dataset_columns[1])
dataset_columns = summarization_name_mapping.get(data_args.dataset_name, None)
if data_args.text_column is None:
text_column = dataset_columns[0] if dataset_columns is not None else column_names[0]
else:
text_column = data_args.text_column
if text_column not in column_names:
raise ValueError(
f"--text_column' value '{data_args.text_column}' needs to be one of: {', '.join(column_names)}"
)
if data_args.summary_column is None:
summary_column = dataset_columns[1] if dataset_columns is not None else column_names[1]
else:
summary_column = data_args.summary_column
if summary_column not in column_names:
raise ValueError(
f"--summary_column' value '{data_args.summary_column}' needs to be one of: {', '.join(column_names)}"
)
设置 input/target 的逻辑:
1.指定名称(data_args.text_column, data_args.summary_column)
2.指定数据集(自带名称map)
3.默认为(dataset_columns[0], dataset_columns[1])
这里采用了在读取参数时指定,这里代码写了一种从文件读取参数的方式,需要一个参数文件config.json
。
在 run_summarization.py:311
行加载参数文件。
# main, run_summarization.py:311
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
# If we pass only one argument to the script and it's the path to a json file,
# let's parse it to get our arguments.
model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
参数文件内容
{
"model_name_or_path":"fnlp_bart-base-chinese",
"text_column":"context",
"summary_column":"response",
"max_source_length":128,
"max_target_length":256,
"dataset_name":"data_douhao",
"num_train_epochs":160,
"save_steps":1000,
"per_device_train_batch_size":32,
"per_device_eval_batch_size":32,
"do_train":true,
"do_eval":true,
"do_predict":false,
"include_inputs_for_metrics":true,
"predict_with_generate":true,
"output_dir":"checkpoints/model_douhao2",
"overwrite_output_dir":true
}
这里在参数文件中用 "text_column":"context"
, "summary_column":"response",
指定 inputs 和 labels 签。
设置 input/target 的逻辑:
1.指定名称 (data_args.text_column, data_args.summary_column)
在加载完数据集后,直接赋值就可以了
data_args.text_column = "context",
data_args.summary_column = "response",
2.指定数据集(自带名称map)
在 run_summarization.py:289
行有个数据集映射表,增加一行数据集:元组映射即可。
summarization_name_mapping = {
"amazon_reviews_multi": ("review_body", "review_title"),
"big_patent": ("description", "abstract"),
......
"wiki_summary": ("article", "highlights"),
"multi_news": ("document", "summary"),
# 加一行即可
# "数据集名称": (input,output)
"data": ("context", "response"),
}
3.默认为(dataset_columns[0], dataset_columns[1])
(被坑了)
训练出来的模型生成长度始终为20。
transformers 库中加载模型超参数时,有个默认值 max_length = 20
,控制生成文本长度,在载入模型config文件时,没设置值,自动加载的默认值。(默认值有点短)
# 下面按照顺序一层一层进入
# --------------------------------
main, run_summarization.py:417 # 这行加载了BART模型config
config = AutoConfig.from_pretrained(
# --------------------------------
from_pretrained, configuration_auto.py:941 # 加载完 config 文件数值去找对应的模型 config 类了
# 这里模型的 config 里写明了 "model_type": "bart"
# 所以载入时 config_dict["model_type"] = "bart"
return config_class.from_dict(config_dict, **unused_kwargs)
# --------------------------------
from_dict, configuration_utils.py:701 # 同样在找找对应的模型 config 类
config = cls(**config_dict)
# --------------------------------
__init__, configuration_bart.py:165 # 找到对应 bart 模型 config 类,进行初始化
super().__init__(
# --------------------------------
__init__, configuration_utils.py:285 # 用的是通用 config 加载
self.max_length = kwargs.pop("max_length", 20)
在最后这个文件configuration_utils.py
第285行打个断点就可以看到加载的默认值了。
改的话其实比较好改,加载完 config 后,将变量 config.max_length 改为 256,生成文本长度即可改变。
config.max_length = 256