From 041ecee74cc80c343949d3dbc62bbdb11410d6b5 Mon Sep 17 00:00:00 2001
From: cronopioelectronico
Date: Wed, 19 May 2021 18:41:50 +0100
Subject: [PATCH] Change GLUE data download script for a working one

---
 download_glue_data.py | 123 ++++++++++++++++++++++++++----------------
 1 file changed, 78 insertions(+), 45 deletions(-)

diff --git a/download_glue_data.py b/download_glue_data.py
index 9712870..fe1bfc8 100644
--- a/download_glue_data.py
+++ b/download_glue_data.py
@@ -1,10 +1,18 @@
-''' Script for downloading all GLUE data.
+#!/usr/bin/env python3
+
+""" Script for downloading all GLUE data.
+
+Example usage:
+  python download_glue_data.py --data_dir data --tasks all
 
 Note: for legal reasons, we are unable to host MRPC.
-You can either use the version hosted by the SentEval team, which is already tokenized,
-or you can download the original data from (https://download.microsoft.com/download/D/4/6/D46FF87A-F6B9-4252-AA8B-3604ED519838/MSRParaphraseCorpus.msi) and extract the data from it manually.
-For Windows users, you can run the .msi file. For Mac and Linux users, consider an external library such as 'cabextract' (see below for an example).
-You should then rename and place specific files in a folder (see below for an example).
+You can either use the version hosted by the SentEval team, which is already tokenized,
+or you can download the original data from:
+https://download.microsoft.com/download/D/4/6/D46FF87A-F6B9-4252-AA8B-3604ED519838/MSRParaphraseCorpus.msi # noqa
+and extract the data from it manually.
+For Windows users, you can run the .msi file. For Mac and Linux users, consider an external library
+such as 'cabextract' (see below for an example). You should then rename and place specific files in
+a folder (see below for an example).
 
 mkdir MRPC
 cabextract MSRParaphraseCorpus.msi -d MRPC
@@ -13,9 +21,7 @@
 rm MRPC/_*
 rm MSRParaphraseCorpus.msi
 
-1/30/19: It looks like SentEval is no longer hosting their extracted and tokenized MRPC data, so you'll need to download the data from the original source for now.
-2/11/19: It looks like SentEval actually *is* hosting the extracted data. Hooray!
-'''
+"""
 
 import os
 import sys
@@ -26,20 +32,26 @@
 import zipfile
 
 TASKS = ["CoLA", "SST", "MRPC", "QQP", "STS", "MNLI", "SNLI", "QNLI", "RTE", "WNLI", "diagnostic"]
-TASK2PATH = {"CoLA":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FCoLA.zip?alt=media&token=46d5e637-3411-4188-bc44-5809b5bfb5f4',
-             "SST":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSST-2.zip?alt=media&token=aabc5f6b-e466-44a2-b9b4-cf6337f84ac8',
-             "MRPC":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2Fmrpc_dev_ids.tsv?alt=media&token=ec5c0836-31d5-48f4-b431-7480817f1adc',
-             "QQP":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FQQP.zip?alt=media&token=700c6acf-160d-4d89-81d1-de4191d02cb5',
-             "STS":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSTS-B.zip?alt=media&token=bddb94a7-8706-4e0d-a694-1109e12273b5',
-             "MNLI":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FMNLI.zip?alt=media&token=50329ea1-e339-40e2-809c-10c40afff3ce',
-             "SNLI":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSNLI.zip?alt=media&token=4afcfbb2-ff0c-4b2d-a09a-dbf07926f4df',
-             "QNLI": 'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FQNLIv2.zip?alt=media&token=6fdcf570-0fc5-4631-8456-9505272d1601',
-             "RTE":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FRTE.zip?alt=media&token=5efa7e85-a0bb-4f19-8ea2-9e1840f077fb',
-             "WNLI":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FWNLI.zip?alt=media&token=068ad0a0-ded7-4bd7-99a5-5e00222e0faf',
-             "diagnostic":'https://storage.googleapis.com/mtl-sentence-representations.appspot.com/tsvsWithoutLabels%2FAX.tsv?GoogleAccessId=firebase-adminsdk-0khhl@mtl-sentence-representations.iam.gserviceaccount.com&Expires=2498860800&Signature=DuQ2CSPt2Yfre0C%2BiISrVYrIFaZH1Lc7hBVZDD4ZyR7fZYOMNOUGpi8QxBmTNOrNPjR3z1cggo7WXFfrgECP6FBJSsURv8Ybrue8Ypt%2FTPxbuJ0Xc2FhDi%2BarnecCBFO77RSbfuz%2Bs95hRrYhTnByqu3U%2FYZPaj3tZt5QdfpH2IUROY8LiBXoXS46LE%2FgOQc%2FKN%2BA9SoscRDYsnxHfG0IjXGwHN%2Bf88q6hOmAxeNPx6moDulUF6XMUAaXCSFU%2BnRO2RDL9CapWxj%2BDl7syNyHhB7987hZ80B%2FwFkQ3MEs8auvt5XW1%2Bd4aCU7ytgM69r8JDCwibfhZxpaa4gd50QXQ%3D%3D'}
-
-MRPC_TRAIN = 'https://dl.fbaipublicfiles.com/senteval/senteval_data/msr_paraphrase_train.txt'
-MRPC_TEST = 'https://dl.fbaipublicfiles.com/senteval/senteval_data/msr_paraphrase_test.txt'
+TASK2PATH = {
+    "CoLA": "https://dl.fbaipublicfiles.com/glue/data/CoLA.zip",
+    "SST": "https://dl.fbaipublicfiles.com/glue/data/SST-2.zip",
+    "MRPC": "https://dl.fbaipublicfiles.com/glue/data/mrpc_dev_ids.tsv",
+    "QQP": "https://dl.fbaipublicfiles.com/glue/data/QQP-clean.zip",
+    "STS": "https://dl.fbaipublicfiles.com/glue/data/STS-B.zip",
+    "MNLI": "https://dl.fbaipublicfiles.com/glue/data/MNLI.zip",
+    "SNLI": "https://dl.fbaipublicfiles.com/glue/data/SNLIv2.zip",
+    "QNLI": "https://dl.fbaipublicfiles.com/glue/data/QNLIv2.zip",
+    "RTE": "https://dl.fbaipublicfiles.com/glue/data/RTE.zip",
+    "WNLI": "https://dl.fbaipublicfiles.com/glue/data/WNLI.zip",
+    "diagnostic": [
+        "https://dl.fbaipublicfiles.com/glue/data/AX.tsv",
+        "https://dl.fbaipublicfiles.com/glue/data/diagnostic-full.tsv",
+    ],
+}
+
+MRPC_TRAIN = "https://dl.fbaipublicfiles.com/senteval/senteval_data/msr_paraphrase_train.txt"
"https://dl.fbaipublicfiles.com/senteval/senteval_data/msr_paraphrase_train.txt" +MRPC_TEST = "https://dl.fbaipublicfiles.com/senteval/senteval_data/msr_paraphrase_test.txt" + def download_and_extract(task, data_dir): print("Downloading and extracting %s..." % task) @@ -50,6 +62,7 @@ def download_and_extract(task, data_dir): os.remove(data_file) print("\tCompleted!") + def format_mrpc(data_dir, path_to_data): print("Processing MRPC...") mrpc_dir = os.path.join(data_dir, "MRPC") @@ -71,41 +84,46 @@ def format_mrpc(data_dir, path_to_data): dev_ids = [] with open(os.path.join(mrpc_dir, "dev_ids.tsv"), encoding="utf8") as ids_fh: for row in ids_fh: - dev_ids.append(row.strip().split('\t')) + dev_ids.append(row.strip().split("\t")) - with open(mrpc_train_file, encoding="utf8") as data_fh, \ - open(os.path.join(mrpc_dir, "train.tsv"), 'w', encoding="utf8") as train_fh, \ - open(os.path.join(mrpc_dir, "dev.tsv"), 'w', encoding="utf8") as dev_fh: + with open(mrpc_train_file, encoding="utf8") as data_fh, open( + os.path.join(mrpc_dir, "train.tsv"), "w", encoding="utf8" + ) as train_fh, open(os.path.join(mrpc_dir, "dev.tsv"), "w", encoding="utf8") as dev_fh: header = data_fh.readline() train_fh.write(header) dev_fh.write(header) for row in data_fh: - label, id1, id2, s1, s2 = row.strip().split('\t') + label, id1, id2, s1, s2 = row.strip().split("\t") if [id1, id2] in dev_ids: dev_fh.write("%s\t%s\t%s\t%s\t%s\n" % (label, id1, id2, s1, s2)) else: train_fh.write("%s\t%s\t%s\t%s\t%s\n" % (label, id1, id2, s1, s2)) - with open(mrpc_test_file, encoding="utf8") as data_fh, \ - open(os.path.join(mrpc_dir, "test.tsv"), 'w', encoding="utf8") as test_fh: + with open(mrpc_test_file, encoding="utf8") as data_fh, open( + os.path.join(mrpc_dir, "test.tsv"), "w", encoding="utf8" + ) as test_fh: header = data_fh.readline() test_fh.write("index\t#1 ID\t#2 ID\t#1 String\t#2 String\n") for idx, row in enumerate(data_fh): - label, id1, id2, s1, s2 = row.strip().split('\t') + label, id1, id2, s1, s2 = row.strip().split("\t") test_fh.write("%d\t%s\t%s\t%s\t%s\n" % (idx, id1, id2, s1, s2)) print("\tCompleted!") + def download_diagnostic(data_dir): - print("Downloading and extracting diagnostic...") - if not os.path.isdir(os.path.join(data_dir, "diagnostic")): - os.mkdir(os.path.join(data_dir, "diagnostic")) - data_file = os.path.join(data_dir, "diagnostic", "diagnostic.tsv") - urllib.request.urlretrieve(TASK2PATH["diagnostic"], data_file) + print("Downloading and extracting diagnostic data...") + if not os.path.isdir(os.path.join(data_dir, "MNLI")): + os.mkdir(os.path.join(data_dir, "MNLI")) + data_file = os.path.join(data_dir, "MNLI", "diagnostic.tsv") + urllib.request.urlretrieve(TASK2PATH["diagnostic"][0], data_file) + data_file = os.path.join(data_dir, "MNLI", "diagnostic-full.tsv") + urllib.request.urlretrieve(TASK2PATH["diagnostic"][1], data_file) print("\tCompleted!") return + def get_tasks(task_names): - task_names = task_names.split(',') + task_names = task_names.split(",") if "all" in task_names: tasks = TASKS else: @@ -113,15 +131,30 @@ def get_tasks(task_names): for task_name in task_names: assert task_name in TASKS, "Task %s not found!" 
             tasks.append(task_name)
+    if "MNLI" in tasks and "diagnostic" not in tasks:
+        tasks.append("diagnostic")
+
     return tasks
 
+
 def main(arguments):
     parser = argparse.ArgumentParser()
-    parser.add_argument('--data_dir', help='directory to save data to', type=str, default='glue_data')
-    parser.add_argument('--tasks', help='tasks to download data for as a comma separated string',
-                        type=str, default='all')
-    parser.add_argument('--path_to_mrpc', help='path to directory containing extracted MRPC data, msr_paraphrase_train.txt and msr_paraphrase_text.txt',
-                        type=str, default='')
+    parser.add_argument(
+        "--data_dir", help="directory to save data to", type=str, default="glue_data"
+    )
+    parser.add_argument(
+        "--tasks",
+        help="tasks to download data for as a comma separated string",
+        type=str,
+        default="all",
+    )
+    parser.add_argument(
+        "--path_to_mrpc",
+        help="path to directory containing extracted MRPC data, msr_paraphrase_train.txt and "
+        "msr_paraphrase_test.txt",
+        type=str,
+        default="",
+    )
     args = parser.parse_args(arguments)
 
     if not os.path.isdir(args.data_dir):
@@ -129,13 +162,13 @@ def main(arguments):
     tasks = get_tasks(args.tasks)
 
     for task in tasks:
-        if task == 'MRPC':
+        if task == "MRPC":
             format_mrpc(args.data_dir, args.path_to_mrpc)
-        elif task == 'diagnostic':
+        elif task == "diagnostic":
             download_diagnostic(args.data_dir)
         else:
             download_and_extract(task, args.data_dir)
 
 
-if __name__ == '__main__':
-    sys.exit(main(sys.argv[1:]))
\ No newline at end of file
+if __name__ == "__main__":
+    sys.exit(main(sys.argv[1:]))
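
After the patch is applied, a quick way to smoke-test the new dl.fbaipublicfiles.com URL table is to drive the script from Python instead of the command line. This is only a sketch, not part of the patch: it assumes the patched download_glue_data.py is importable from the current directory and that outbound HTTPS access is available.

    # Minimal sketch (not part of the patch); equivalent to:
    #   python download_glue_data.py --data_dir glue_data --tasks MNLI
    import download_glue_data

    download_glue_data.main(["--data_dir", "glue_data", "--tasks", "MNLI"])

Because get_tasks() now appends "diagnostic" whenever MNLI is requested, this single call also places diagnostic.tsv and diagnostic-full.tsv under glue_data/MNLI/, next to the extracted MNLI data.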