WMT是机器翻译和机器翻译研究的主要活动。 该会议每年与自然语言处理方面的大型会议联合举行。2006年,第一届机器翻译研讨会在计算语言学协会北美分会年会上举行。2016年,随着神经机器翻译的兴起,WMT成为了一个自己的会议。 机器翻译会议仍然主要被称为WMT[1]。
有些机器翻译工作会使用历年WMT公开的数据集作为他们的数据集[2],如下图所示:
当笔者想要复现工作结果时,首先需要收集得到这样的数据集。而以WMT13[3]为例。如下图所示,笔者需要手动点击下载上面公开的每一个子数据集,然后汇总得到整个WMT13的训练、验证和测试集。而由于每一个子数据集的形式也不同,且数量较多…总的来说还是很麻烦的。
而笔者发现,huggingface[4]上面已经收集了部分年份的WMT数据,并提供了下载接口。以wmt14的所有hi-en数据为例,最终的下载结果如下图所示:
(笔者后知后觉意识到,只要想办法打开.arrow文件就可以得到对应数据了…艹)
本文旨在总结批量获取所有WMT数据的初步解决方案,通过修改huggingface datasets库的源码实现。
第一步,pip install datasets安装datasets库。
第二步,通过git clone https://github.com/huggingface/datasets克隆datasets库,datasets/datasets路径下面包含了该库提供的所有数据集的相关代码:
第三步,创建主程序文件(run.py),代码如下,其中,py_file_path
为上面说的datasets/datasets路径,save_dir
为保存到本地的路径:
from datasets import load_dataset
import os
wmt_dict = {
"wmt14": [(lang, "en") for lang in ["cs", "de", "fr", "hi", "ru"]],
"wmt15": [(lang, "en") for lang in ["cs", "de", "fi", "fr", "ru"]],
"wmt16": [(lang, "en") for lang in ["cs", "de", "fi", "ro", "ru", "tr"]],
"wmt17": [(lang, "en") for lang in ["cs", "de", "fi", "lv", "ru", "tr", "zh"]],
"wmt18": [(lang, "en") for lang in ["cs", "de", "et", "fi", "kk", "ru", "tr", "zh"]],
"wmt19": [(lang, "en") for lang in ["cs", "de", "fi", "gu", "kk", "lt", "ru", "zh"]] + [("fr", "de")],
}
py_file_path = r"C:\Users\13359\PycharmProjects\for_fun\other\wmt_datasets\datasets\datasets"
save_dir = r"D:\dataset\mt"
for wmt in wmt_dict:
for lang_tuple in wmt_dict[wmt]:
lang_pair = "-".join(lang_tuple)
print(f"wmt: {wmt} | lang_pair: {lang_pair}")
load_dataset(os.path.join(py_file_path, wmt), name = lang_pair, cache_dir = save_dir)
第四步,在上述datasets/datasets路径下面随便选择一个wmt文件夹,比如wmt14,将里面的wmt_utils.py复制到run.py的同级目录下。(暂时不知道为何,尝试下来这样没有错),也就是文件目录结构如下:
第五步,如果此时运行run.py,则会像前言中的那样,得到所有wmt的所有语言对的数据,但数据格式是arrow的。笔者脑抽,没有直接去想如何把.arrow文件转成更好理解的格式,而是通过修改pip install下来的datasets源码,来直接修改保存数据的过程。具体来说,通过ctrl+B追溯run.py中load_dataset的执行顺序,最终找到了保存数据的源码位置:load_dataset(run.py)->builder_instance.download_and_prepare(load.py,1738行)->self._download_and_prepare(builder.py, 638行)->self._prepare_split(builder.py, 723行),由于此时的self是一个wmtxx object,从具体的wmtxx.py(如wmt14.py,位于datasets/datasets/wmt14/wmt14.py)可知,wmtxx类的继承顺序是:wmtxx->Wmt(wmt_utils.py)->GeneratorBasedBuilder(builder.py)->DatasetBuilder(builder.py),所以self._prepare_split最终方法实现是GeneratorBasedBuilder类的_prepare_split方法。
该方法中完成了.arrow数据的创建,具体代码如下所示:
with ArrowWriter(
features=self.info.features,
path=fpath,
writer_batch_size=self._writer_batch_size,
hash_salt=split_info.name,
check_duplicates=check_duplicate_keys,
) as writer:
try:
for key, record in logging.tqdm(
generator,
unit=" examples",
total=split_info.num_examples,
leave=False,
disable=not logging.is_progress_bar_enabled(),
desc=f"Generating {split_info.name} split",
):
example = self.info.features.encode_example(record)
writer.write(example, key)
finally:
num_examples, num_bytes = writer.finalize()
split_generator.split_info.num_examples = num_examples
split_generator.split_info.num_bytes = num_bytes
其中,generator
就是包含了所有数据的生成器。于是,笔者修改了上面的代码,完成了对数据保存的修改:
# ...其它代码
generator = self._generate_examples(**split_generator.gen_kwargs)
# 修改代码
user_name = "xushaoyang"
str_lst = fpath.split("\\")
index = str_lst.index(self.name)
lang_pair = str_lst[index + 1]
source, target = lang_pair.split("-")
path_lst = str_lst[:index + 2]
path_lst[index] += f"_{user_name}"
dir_path = os.path.join(*path_lst)
os.makedirs(dir_path, exist_ok=True)
source_file_name = f"{self.name}.{lang_pair}-{split_generator.name}.{source}"
source_path = os.path.join(dir_path, source_file_name)
target_file_name = f"{self.name}.{lang_pair}-{split_generator.name}.{target}"
target_path = os.path.join(dir_path, target_file_name)
# flag = "population of the US, or of the combined current population of India and China"
with open(source_path, mode="w", encoding="utf-8") as source_f:
with open(target_path, mode="w", encoding="utf-8") as target_f:
with ArrowWriter(
features=self.info.features,
path=fpath,
writer_batch_size=self._writer_batch_size,
hash_salt=split_info.name,
check_duplicates=check_duplicate_keys,
) as writer:
try:
for key, record in logging.tqdm(
generator,
unit=" examples",
total=split_info.num_examples,
leave=False,
disable=not logging.is_progress_bar_enabled(),
desc=f"Generating {split_info.name} split",
):
assert list(record.keys()) == ["translation"]
lang_keys = list(record["translation"].keys())
if source not in lang_keys: # 问题1:zh的key在有些数据集里是ch
target_idx = lang_keys.index(target)
source_idx = (target_idx + 1) % 2
error_source_key = lang_keys[source_idx]
new_record = {"translation": {
source: record["translation"][error_source_key],
target: record["translation"][target]
}}
record = new_record
del new_record
# arrow write
example = self.info.features.encode_example(record)
writer.write(example, key)
# file write
source_sentence = record['translation'][source]
target_sentence = record['translation'][target]
source_sentence = source_sentence.replace("\r", "").replace("\n", "") # 问题2:多余的换行
target_sentence = target_sentence.replace("\r", "").replace("\n", "")
source_f.write(source_sentence + "\n")
target_f.write(target_sentence + "\n")
finally:
num_examples, num_bytes = writer.finalize()
split_generator.split_info.num_examples = num_examples
split_generator.split_info.num_bytes = num_bytes
完成修改后执行run.py,以wmt14的hi-en数据为例,得到的数据如下图所示,文件的命名仿照了OPUS100[5]:
如"不足之处"的第3点所述,有一些文件没有提供自动下载的url,笔者的解决方案是在下载过程中记录哪些数据集没有自动下载,所有下载完成之后再去手动补上。具体来说,笔者在wmt_utils.py中的_split_generators
函数中的 if dataset.get_manual_dl_files(source):
语句下,加入了如下语句:
with open(f"{self.name}_error_log", mode="a", encoding="utf-8") as file:
file.write(f"lang: {'-'.join(self.config.language_pair)} | data_name: {dataset.name} | url: {str(dataset.get_manual_dl_files(source))}" + "\n")
在运行run.py的过程中,出现数据集缺失的情况,这样的记录就会被保存在wmtxxx_error_log日志文件中,如下图所示:
打开datasets/datasets中的wmt19文件夹,修改里面的wmt19.py和wmt_utils.py。
# wmt19.py:line79
datasets.Split.TEST: ["newstest2019", "newstest2019_csen", "newstest2019_frde"],
# wmt_utils.py:line617
SubDataset(
name="newstest2019",
target="en",
sources={"de", "fi", "gu", "kk", "lt", "ru", "zh"},
url="http://data.statmt.org/wmt19/translation-task/test.tgz",
path=("sgm/newstest2019-{src}en-src.{src}.sgm", "sgm/newstest2019-{src}en-ref.en.sgm"),
),
SubDataset(
name="newstest2019_csen",
target="en",
sources={"cs"},
url="http://data.statmt.org/wmt19/translation-task/test.tgz",
path=("sgm/newstest2019-en{src}-src.en.sgm", "sgm/newstest2019-en{src}-ref.{src}.sgm"),
),
SubDataset(
name="newstest2019_frde",
target="de",
sources={"fr"},
url="http://data.statmt.org/wmt19/translation-task/test.tgz",
path=("sgm/newstest2019-frde-ref.de.sgm", "sgm/newstest2019-frde-src.fr.sgm"),
),
另外,直接这样运行run.py会报错,因为dataset_infos.json中的内容还没有修改。而程序会读取dataset_info.json中一些预设的信息,然后和下载下来的结果进行一些校验:verify_checksums、verify_splits,但是直接修改dataset_infos.json较为麻烦,所以笔者选择取消校验:
load_dataset(os.path.join(py_file_path, wmt), name = lang_pair, cache_dir = save_dir, save_infos = True)
# ignore_verifications = True,也可以
另外,还是要对dataset_infos.json做一个细微的修改,即在splits中增加”test“:
"test": {
"name": "test",
"num_bytes": 3000, # 这个随便设置的,不影响下载
"num_examples": 3000, # 同上
"dataset_name": "wmt19"
}
cs-en:
train:
source:953621
target:953621
validation:
source:3000
target:3000
test:
source:3003
target:3003
de-en:
train:
source:4508785
target:4508785
validation:
source:3000
target:3000
test:
source:3003
target:3003
fr-en:
train:
source:40836715
target:40836715
validation:
source:3000
target:3000
test:
source:3003
target:3003
hi-en:
train:
source:32863
target:32863
validation:
source:520
target:520
test:
source:2507
target:2507
ru-en:
train:
source:1486965
target:1486965
validation:
source:3000
target:3000
test:
source:3003
target:3003
cs-en:
train:
source:959768
target:959768
validation:
source:3003
target:3003
test:
source:2656
target:2656
de-en:
train:
source:4522998
target:4522998
validation:
source:3003
target:3003
test:
source:2169
target:2169
fi-en:
train:
source:2073394
target:2073394
validation:
source:1500
target:1500
test:
source:1370
target:1370
fr-en:
train:
source:40853137
target:40853137
validation:
source:4503
target:4503
test:
source:1500
target:1500
ru-en:
train:
source:1495081
target:1495081
validation:
source:3003
target:3003
test:
source:2818
target:2818
cs-en:
train:
source:997240
target:997240
validation:
source:2656
target:2656
test:
source:2999
target:2999
de-en:
train:
source:4548885
target:4548885
validation:
source:2169
target:2169
test:
source:2999
target:2999
fi-en:
train:
source:2073394
target:2073394
validation:
source:1370
target:1370
test:
source:6000
target:6000
ro-en:
train:
source:610320
target:610320
validation:
source:1999
target:1999
test:
source:1999
target:1999
ru-en:
train:
source:1516162
target:1516162
validation:
source:2818
target:2818
test:
source:2998
target:2998
tr-en:
train:
source:205756
target:205756
validation:
source:1001
target:1001
test:
source:3000
target:3000
cs-en:
train:
source:1018291
target:1018291
validation:
source:2999
target:2999
test:
source:3005
target:3005
de-en:
train:
source:5906184
target:5906184
validation:
source:2999
target:2999
test:
source:3004
target:3004
fi-en:
train:
source:2656542
target:2656542
validation:
source:6000
target:6000
test:
source:6004
target:6004
lv-en:
train:
source:3567528
target:3567528
validation:
source:2003
target:2003
test:
source:2001
target:2001
ru-en:
train:
source:24782720
target:24782720
validation:
source:2998
target:2998
test:
source:3001
target:3001
tr-en:
train:
source:205756
target:205756
validation:
source:3000
target:3000
test:
source:3007
target:3007
zh-en:
train:
source:25134743
target:25134743
validation:
source:2002
target:2002
test:
source:2001
target:2001
cs-en:
train:
source:11046024
target:11046024
validation:
source:3005
target:3005
test:
source:2983
target:2983
de-en:
train:
source:42271874
target:42271874
validation:
source:3004
target:3004
test:
source:2998
target:2998
et-en:
train:
source:2175873
target:2175873
validation:
source:2000
target:2000
test:
source:2000
target:2000
fi-en:
train:
source:3280600
target:3280600
validation:
source:6004
target:6004
test:
source:3000
target:3000
kk-en:
train:
source:0
target:0
validation:
source:0
target:0
test:
source:0
target:0
ru-en:
train:
source:36858512
target:36858512
validation:
source:3001
target:3001
test:
source:3000
target:3000
tr-en:
train:
source:205756
target:205756
validation:
source:3007
target:3007
test:
source:3000
target:3000
zh-en:
train:
source:25160346
target:25160346
validation:
source:2001
target:2001
test:
source:3981
target:3981
cs-en:
train:
source:7270695
target:7270695
validation:
source:2983
target:2983
test:
source:1997
target:1997
de-en:
train:
source:38690334
target:38690334
validation:
source:2998
target:2998
test:
source:2000
target:2000
fi-en:
train:
source:6587448
target:6587448
validation:
source:3000
target:3000
test:
source:1996
target:1996
gu-en:
train:
source:11670
target:11670
validation:
source:1998
target:1998
test:
source:1016
target:1016
kk-en:
train:
source:126583
target:126583
validation:
source:2066
target:2066
test:
source:1000
target:1000
lt-en:
train:
source:2344893
target:2344893
validation:
source:2000
target:2000
test:
source:1000
target:1000
ru-en:
train:
source:37492126
target:37492126
validation:
source:3000
target:3000
test:
source:2000
target:2000
zh-en:
train:
source:25984574
target:25984574
validation:
source:3981
target:3981
test:
source:2000
target:2000
fr-de:
train:
source:9824476
target:9824476
validation:
source:1512
target:1512
test:
source:1707
target:1707
可以看到存在大量的重复。以上语料的下载链接统计如下:
基本都需要提交申请
[1]https://machinetranslate.org/wmt
[2]https://arxiv.org/pdf/2105.09259v1.pdf
[3]https://www.statmt.org/wmt14/translation-task.html
[4]https://github.com/huggingface/datasets
[5]https://github.com/EdinburghNLP/opus-100-corpus