下载脚本代码:
''' Script for downloading all GLUE data.'''
import os
import sys
import shutil
import argparse
import tempfile
import urllib.request
import zipfile
TASKS = ["CoLA", "SST", "MRPC", "QQP", "STS", "MNLI", "SNLI", "QNLI", "RTE", "WNLI", "diagnostic"]
TASK2PATH = {"CoLA":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FCoLA.zip?alt=media&token=46d5e637-3411-4188-bc44-5809b5bfb5f4',
"SST":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSST-2.zip?alt=media&token=aabc5f6b-e466-44a2-b9b4-cf6337f84ac8',
"MRPC":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2Fmrpc_dev_ids.tsv?alt=media&token=ec5c0836-31d5-48f4-b431-7480817f1adc',
"QQP":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FQQP.zip?alt=media&token=700c6acf-160d-4d89-81d1-de4191d02cb5',
"STS":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSTS-B.zip?alt=media&token=bddb94a7-8706-4e0d-a694-1109e12273b5',
"MNLI":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FMNLI.zip?alt=media&token=50329ea1-e339-40e2-809c-10c40afff3ce',
"SNLI":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSNLI.zip?alt=media&token=4afcfbb2-ff0c-4b2d-a09a-dbf07926f4df',
"QNLI": 'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FQNLIv2.zip?alt=media&token=6fdcf570-0fc5-4631-8456-9505272d1601',
"RTE":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FRTE.zip?alt=media&token=5efa7e85-a0bb-4f19-8ea2-9e1840f077fb',
"WNLI":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FWNLI.zip?alt=media&token=068ad0a0-ded7-4bd7-99a5-5e00222e0faf',
"diagnostic":'https://storage.googleapis.com/mtl-sentence-representations.appspot.com/tsvsWithoutLabels%2FAX.tsv?GoogleAccessId=firebase-adminsdk-0khhl@mtl-sentence-representations.iam.gserviceaccount.com&Expires=2498860800&Signature=DuQ2CSPt2Yfre0C%2BiISrVYrIFaZH1Lc7hBVZDD4ZyR7fZYOMNOUGpi8QxBmTNOrNPjR3z1cggo7WXFfrgECP6FBJSsURv8Ybrue8Ypt%2FTPxbuJ0Xc2FhDi%2BarnecCBFO77RSbfuz%2Bs95hRrYhTnByqu3U%2FYZPaj3tZt5QdfpH2IUROY8LiBXoXS46LE%2FgOQc%2FKN%2BA9SoscRDYsnxHfG0IjXGwHN%2Bf88q6hOmAxeNPx6moDulUF6XMUAaXCSFU%2BnRO2RDL9CapWxj%2BDl7syNyHhB7987hZ80B%2FwFkQ3MEs8auvt5XW1%2Bd4aCU7ytgM69r8JDCwibfhZxpaa4gd50QXQ%3D%3D'}
MRPC_TRAIN = 'https://dl.fbaipublicfiles.com/senteval/senteval_data/msr_paraphrase_train.txt'
MRPC_TEST = 'https://dl.fbaipublicfiles.com/senteval/senteval_data/msr_paraphrase_test.txt'
def download_and_extract(task, data_dir):
print("Downloading and extracting %s..." % task)
data_file = "%s.zip" % task
urllib.request.urlretrieve(TASK2PATH[task], data_file)
with zipfile.ZipFile(data_file) as zip_ref:
zip_ref.extractall(data_dir)
os.remove(data_file)
print("\tCompleted!")
def format_mrpc(data_dir, path_to_data):
print("Processing MRPC...")
mrpc_dir = os.path.join(data_dir, "MRPC")
if not os.path.isdir(mrpc_dir):
os.mkdir(mrpc_dir)
if path_to_data:
mrpc_train_file = os.path.join(path_to_data, "msr_paraphrase_train.txt")
mrpc_test_file = os.path.join(path_to_data, "msr_paraphrase_test.txt")
else:
print("Local MRPC data not specified, downloading data from %s" % MRPC_TRAIN)
mrpc_train_file = os.path.join(mrpc_dir, "msr_paraphrase_train.txt")
mrpc_test_file = os.path.join(mrpc_dir, "msr_paraphrase_test.txt")
urllib.request.urlretrieve(MRPC_TRAIN, mrpc_train_file)
urllib.request.urlretrieve(MRPC_TEST, mrpc_test_file)
assert os.path.isfile(mrpc_train_file), "Train data not found at %s" % mrpc_train_file
assert os.path.isfile(mrpc_test_file), "Test data not found at %s" % mrpc_test_file
urllib.request.urlretrieve(TASK2PATH["MRPC"], os.path.join(mrpc_dir, "dev_ids.tsv"))
dev_ids = []
with open(os.path.join(mrpc_dir, "dev_ids.tsv"), encoding="utf8") as ids_fh:
for row in ids_fh:
dev_ids.append(row.strip().split('\t'))
with open(mrpc_train_file, encoding="utf8") as data_fh, \
open(os.path.join(mrpc_dir, "train.tsv"), 'w', encoding="utf8") as train_fh, \
open(os.path.join(mrpc_dir, "dev.tsv"), 'w', encoding="utf8") as dev_fh:
header = data_fh.readline()
train_fh.write(header)
dev_fh.write(header)
for row in data_fh:
label, id1, id2, s1, s2 = row.strip().split('\t')
if [id1, id2] in dev_ids:
dev_fh.write("%s\t%s\t%s\t%s\t%s\n" % (label, id1, id2, s1, s2))
else:
train_fh.write("%s\t%s\t%s\t%s\t%s\n" % (label, id1, id2, s1, s2))
with open(mrpc_test_file, encoding="utf8") as data_fh, \
open(os.path.join(mrpc_dir, "test.tsv"), 'w', encoding="utf8") as test_fh:
header = data_fh.readline()
test_fh.write("index\t#1 ID\t#2 ID\t#1 String\t#2 String\n")
for idx, row in enumerate(data_fh):
label, id1, id2, s1, s2 = row.strip().split('\t')
test_fh.write("%d\t%s\t%s\t%s\t%s\n" % (idx, id1, id2, s1, s2))
print("\tCompleted!")
def download_diagnostic(data_dir):
print("Downloading and extracting diagnostic...")
if not os.path.isdir(os.path.join(data_dir, "diagnostic")):
os.mkdir(os.path.join(data_dir, "diagnostic"))
data_file = os.path.join(data_dir, "diagnostic", "diagnostic.tsv")
urllib.request.urlretrieve(TASK2PATH["diagnostic"], data_file)
print("\tCompleted!")
return
def get_tasks(task_names):
task_names = task_names.split(',')
if "all" in task_names:
tasks = TASKS
else:
tasks = []
for task_name in task_names:
assert task_name in TASKS, "Task %s not found!" % task_name
tasks.append(task_name)
return tasks
def main(arguments):
parser = argparse.ArgumentParser()
parser.add_argument('--data_dir', help='directory to save data to', type=str, default='glue_data')
parser.add_argument('--tasks', help='tasks to download data for as a comma separated string',
type=str, default='all')
parser.add_argument('--path_to_mrpc', help='path to directory containing extracted MRPC data, msr_paraphrase_train.txt and msr_paraphrase_text.txt',
type=str, default='')
args = parser.parse_args(arguments)
if not os.path.isdir(args.data_dir):
os.mkdir(args.data_dir)
tasks = get_tasks(args.tasks)
for task in tasks:
if task == 'MRPC':
format_mrpc(args.data_dir, args.path_to_mrpc)
elif task == 'diagnostic':
download_diagnostic(args.data_dir)
else:
download_and_extract(task, args.data_dir)
if __name__ == '__main__':
sys.exit(main(sys.argv[1:]))
运行脚本下载所有数据集:
输出效果:
CoLA数据集文件样式:
文件样式说明:
在使用中常用到的文件是train.tsv, dev.tsv, test.tsv, 分别代表训练集, 验证集和测试集. 其中train.tsv与dev.tsv数据样式相同, 都是带有标签的数据, 其中test.tsv是不带有标签的数据.
train.tsv数据样式:
train.tsv数据样式说明:
train.tsv中的数据内容共分为4列, 第一列数据, 如gj04, bc01等代表每条文本数据的来源即出版物代号; 第二列数据, 0或1, 代表每条文本数据的语法是否正确, 0代表不正确, 1代表正确; 第三列数据, '', 是作者最初的正负样本标记, 与第二列意义相同, ''表示不正确; 第四列即是被标注的语法使用是否正确的文本句子.
test.tsv数据样式:
test.tsv数据样式说明:
test.tsv中的数据内容共分为2列, 第一列数据代表每条文本数据的索引; 第二列数据代表用于测试的句子.
CoLA数据集的任务类型:
二分类任务 评估指标为: MMC(马修斯相关系数, 在正负样本分布十分不均衡的情况下使用的二分类评估指标)
SST-2数据集文件样式:
文件样式说明:
在使用中常用到的文件是train.tsv, dev.tsv, test.tsv, 分别代表训练集, 验证集和测试集. 其中train.tsv与dev.tsv数据样式相同, 都是带有标签的数据, 其中test.tsv是不带有标签的数据.
train.tsv数据样式:
train.tsv数据样式说明:
train.tsv中的数据内容共分为2列, 第一列数据代表具有感情色彩的评论文本; 第二列数据, 0或1, 代表每条文本数据是积极或者消极的评论, 0代表消极, 1代表积极.
test.tsv数据样式:
test.tsv数据样式说明: * test.tsv中的数据内容共分为2列, 第一列数据代表每条文本数据的索引; 第二列数据代表用于测试的句子.
2.2.3 SST-2数据集的任务类型:
二分类任务 评估指标为: ACC
MRPC数据集文件样式:
文件样式说明:
在使用中常用到的文件是train.tsv, dev.tsv, test.tsv, 分别代表训练集, 验证集和测试集. 其中train.tsv与dev.tsv数据样式相同, 都是带有标签的数据, 其中test.tsv是不带有标签的数据.
train.tsv数据样式:
train.tsv数据样式说明:
train.tsv中的数据内容共分为5列, 第一列数据, 0或1, 代表每对句子是否具有相同的含义, 0代表含义不相同, 1代表含义相同. 第二列和第三列分别代表每对句子的id, 第四列和第五列分别具有相同/不同含义的句子对.
test.tsv数据样式:
test.tsv数据样式说明: * test.tsv中的数据内容共分为5列, 第一列数据代表每条文本数据的索引; 其余列的含义与train.tsv中相同.
2.2.4 MRPC数据集的任务类型:
句子对二分类任务 评估指标为: ACC和F1
STS-B数据集文件样式:
文件样式说明:
在使用中常用到的文件是train.tsv, dev.tsv, test.tsv, 分别代表训练集, 验证集和测试集. 其中train.tsv与dev.tsv数据样式相同, 都是带有标签的数据, 其中test.tsv是不带有标签的数据.
train.tsv数据样式:
train.tsv数据样式说明:
train.tsv中的数据内容共分为10列, 第一列数据是数据索引; 第二列代表每对句子的来源, 如main-captions表示来自字幕; 第三列代表来源的具体保存文件名, 第四列代表出现时间(年); 第五列代表原始数据的索引; 第六列和第七列分别代表句子对原始来源; 第八列和第九列代表相似程度不同的句子对; 第十列代表句子对的相似程度由低到高, 值域范围是[0, 5].
test.tsv数据样式:
test.tsv数据样式说明:
test.tsv中的数据内容共分为9列, 含义与train.tsv前9列相同.
STS-B数据集的任务类型:
句子对多分类任务/句子对回归任务 评估指标为: Pearson-Spearman Corr
QQP数据集文件样式:
文件样式说明:
在使用中常用到的文件是train.tsv, dev.tsv, test.tsv, 分别代表训练集, 验证集和测试集. 其中train.tsv与dev.tsv数据样式相同, 都是带有标签的数据, 其中test.tsv是不带有标签的数据.
train.tsv数据样式:
train.tsv数据样式说明:
train.tsv中的数据内容共分为6列, 第一列代表文本数据索引; 第二列和第三列数据分别代表问题1和问题2的id; 第四列和第五列代表需要进行'是否重复'判定的句子对; 第六列代表上述问题是/不是重复性问题的标签, 0代表不重复, 1代表重复.
test.tsv数据样式:
test.tsv数据样式说明:
test.tsv中的数据内容共分为3列, 第一列数据代表每条文本数据的索引; 第二列和第三列数据代表用于测试的问题句子对.
QQP数据集的任务类型:
句子对二分类任务 评估指标为: ACC/F1
MNLI数据集文件样式:
文件样式说明:
在使用中常用到的文件是train.tsv, dev_matched.tsv, dev_mismatched.tsv, test_matched.tsv, test_mismatched.tsv分别代表训练集, 与训练集一同采集的验证集, 与训练集不是一同采集验证集, 与训练集一同采集的测试集, 与训练集不是一同采集测试集. 其中train.tsv与dev_matched.tsv和dev_mismatched.tsv数据样式相同, 都是带有标签的数据, 其中test_matched.tsv与test_mismatched.tsv数据样式相同, 都是不带有标签的数据.
train.tsv数据样式:
train.tsv数据样式说明:
train.tsv中的数据内容共分为12列, 第一列代表文本数据索引; 第二列和第三列数据分别代表句子对的不同类型id; 第四列代表句子对的来源; 第五列和第六列代表具有句法结构分析的句子对表示; 第七列和第八列代表具有句法结构和词性标注的句子对表示, 第九列和第十列代表原始的句子对, 第十一和第十二列代表不同标准的标注方法产生的标签, 在这里,他们始终相同, 一共有三种类型的标签, neutral代表两个句子既不矛盾也不蕴含, entailment代表两个句子具有蕴含关系, contradiction代表两个句子观点矛盾.
test_matched.tsv数据样式:
test_matched.tsv数据样式说明:
test_matched.tsv中的数据内容共分为11列, 与train.tsv的前11列含义相同.
MNLI数据集的任务类型:
句子对多分类任务 评估指标为: ACC
(QNLI/RTE/WNLI)数据集文件样式:
QNLI, RTE, WNLI三个数据集的样式基本相同.
文件样式说明:
在使用中常用到的文件是train.tsv, dev.tsv, test.tsv, 分别代表训练集, 验证集和测试集. 其中train.tsv与dev.tsv数据样式相同, 都是带有标签的数据, 其中test.tsv是不带有标签的数据.
QNLI中的train.tsv数据样式:
RTE中的train.tsv数据样式:
WNLI中的train.tsv数据样式:
(QNLI/RTE/WNLI)中的train.tsv数据样式说明:
train.tsv中的数据内容共分为4列, 第一列代表文本数据索引; 第二列和第三列数据代表需要进行'是否蕴含'判定的句子对; 第四列数据代表两个句子是否具有蕴含关系, 0/not_entailment代表不是蕴含关系, 1/entailment代表蕴含关系.
QNLI中的test.tsv数据样式:
(QNLI/RTE/WNLI)中的test.tsv数据样式说明:
test.tsv中的数据内容共分为3列, 第一列数据代表每条文本数据的索引; 第二列和第三列数据代表需要进行'是否蕴含'判定的句子对.
(QNLI/RTE/WNLI)数据集的任务类型:
句子对二分类任务 评估指标为: ACC