diff --git a/.coveragerc b/.coveragerc deleted file mode 100644 index 53166bbe7..000000000 --- a/.coveragerc +++ /dev/null @@ -1,8 +0,0 @@ -[run] -branch = True -omit = streaming/text/convert/enwiki/mds/*,streaming/text/convert/enwiki/tfrecord/* - -[report] -show_missing = True -precision = 2 -exclude_lines = raise NotImplementedError.* diff --git a/setup.py b/setup.py index 816bcd763..f019734e2 100644 --- a/setup.py +++ b/setup.py @@ -58,6 +58,7 @@ 'azure-storage-blob>=12.0.0,<13', 'azure-storage-file-datalake>=12.11.0,<13', 'azure-identity>=1.13.0', + 'psutil>=5.8.0,<6', ] extra_deps = {} diff --git a/streaming/base/util.py b/streaming/base/util.py index 3be5b729a..d6f29fc11 100644 --- a/streaming/base/util.py +++ b/streaming/base/util.py @@ -13,11 +13,14 @@ import tempfile import urllib.parse from collections import OrderedDict +from multiprocessing import Pool from multiprocessing.shared_memory import SharedMemory as BuiltinSharedMemory from pathlib import Path from time import sleep, time from typing import Any, Callable, List, Sequence, Tuple, Type, TypeVar, Union, cast, overload +import numpy as np +import psutil import torch.distributed as dist from streaming.base.constant import SHM_TO_CLEAN @@ -217,7 +220,7 @@ def get_import_exception_message(package_name: str, extra_deps: str) -> str: def merge_index(*args: Any, **kwargs: Any): - r"""Merge index.json from partitions to form a global index.json. + r"""Merge index.json from streams to form a global index.json. This can be called as merge_index(out, keep_local, download_timeout) - The first signature takes in a list of index files URLs of MDS partitions. - The second takes the root of a MDS dataset and parse the partition folders from there. + The first signature takes in a list of index file URLs of MDS streams. + The second takes the root of a MDS dataset and parses the stream folders from there. 
Args: - index_file_urls (List[Union[str, Tuple[str,str]]]): index.json from all the partitions. + index_file_urls (List[Union[str, Tuple[str,str]]]): index.json from all the streams. Each element can take the form of a single path string or a tuple string. 1. If ``index_file_urls`` is a List of local URLs, merge locally without download. 2. If ``index_file_urls`` is a List of tuple (local, remote) URLs, check if local index.json are missing, download before merging. 3. If ``index_file_urls`` is a List of remote URLs, download all and merge. - out (Union[str, Tuple[str,str]]): folder that contain MDS partitions and to put the merged index file + out (Union[str, Tuple[str,str]]): folder that contains MDS streams and to put the merged index file 1. A local directory, merge index happens locally. 2. A remote directory, download all the sub-directories index.json, merge locally and upload. @@ -253,14 +256,59 @@ def merge_index(*args: Any, **kwargs: Any): raise ValueError(f'Invalid arguments to merge_index: {args}, {kwargs}') +def _download_url(url_info: Tuple[str, str, int]): + """Download a file given URL information.""" + from streaming.base.storage.download import download_file + src, dst, download_timeout = url_info + try: + download_file(src, dst, download_timeout) + except Exception as ex: + return f'Failed to download index.json: {src} to {dst}: {str(ex)}', ex + return dst, None + + +def _merge_stream_indices(stream_indices: List[str]): + """Function to be executed by each process to merge a subset of stream indices.""" + shards = [] + for stream_index in stream_indices: + p = Path(stream_index) + with open(stream_index, 'r') as f: + obj = json.load(f) + for shard in obj['shards']: + for key in ('raw_data', 'zip_data', 'raw_meta', 'zip_meta'): + if shard.get(key): + basename = shard[key]['basename'] + shard[key]['basename'] = os.path.join(os.path.basename(p.parent), basename) + shards.extend(obj['shards']) + return shards + + +def _parallel_merge_streams(streams: 
List[str], n_processes: int = 1): + """Divide the list of streams among multiple processes and merge their shards in parallel.""" + with Pool(processes=n_processes) as pool: + # Split the list of streams into N chunks where N is the number of processes + chunk_size = int(np.ceil(len(streams) / n_processes)) + stream_chunks = [streams[i:i + chunk_size] for i in range(0, len(streams), chunk_size)] + + # Process each chunk in parallel + results = pool.map(_merge_stream_indices, stream_chunks) + pool.close() + pool.join() + + # Combine the results from all processes + final_shards = [shard for result in results for shard in result] + return final_shards + + def _merge_index_from_list(index_file_urls: Sequence[Union[str, Tuple[str, str]]], out: Union[str, Tuple[str, str]], keep_local: bool = True, - download_timeout: int = 60) -> None: + download_timeout: int = 60, + n_processes: int = 1) -> None: """Merge index.json from a list of index files of MDS directories to create joined index. Args: - index_file_urls (Union[str, Tuple[str,str]]): index.json from all the partitions + index_file_urls (Union[str, Tuple[str,str]]): index.json from all the streams each element can take the form of a single path string or a tuple string. The pattern of index_file_urls and corresponding reaction is one of: @@ -272,8 +320,8 @@ def _merge_index_from_list(index_file_urls: Sequence[Union[str, Tuple[str, str]] out (Union[str, Tuple[str, str]]): path to put the merged index file keep_local (bool): Keep local copy of the merged index file. Defaults to ``True`` download_timeout (int): The allowed time for downloading each json file. Defaults to 60. 
+ n_processes (int): The number of processes used to run the function in parallel. Defaults to 1. """ - from streaming.base.storage.download import download_file from streaming.base.storage.upload import CloudUploader if not index_file_urls or not out: @@ -295,12 +343,17 @@ def _merge_index_from_list(index_file_urls: Sequence[Union[str, Tuple[str, str]] else: urls.append((url[0].rstrip('/').strip(), url[1].rstrip('/').strip())) + cpu_count = max(psutil.cpu_count() - 2, 1) + n_processes = n_processes if (1 <= n_processes <= cpu_count) else 1 + + logger.warning(f'Using n_processes = {n_processes} to download and merge index in parallel') + # Prepare a temp folder to download index.json from remote if necessary. Removed in the end. with tempfile.TemporaryDirectory() as temp_root: - logging.warning(f'A temporary folder {temp_root} is created to store index files') + logging.info(f'Created temporary folder {temp_root} to store index files') # Copy files to a temporary directory. Download if necessary - partitions = [] + download_tasks = [] for url in urls: if isinstance(url, tuple): src = url[0] if os.path.exists(url[0]) else url[1] @@ -313,31 +366,21 @@ def _merge_index_from_list(index_file_urls: Sequence[Union[str, Tuple[str, str]] raise FileNotFoundError( f'Check data availability! local index {url[0]} is not accessible.' 
+ f'remote index {url[1]} does not have a valid url format') - dest = os.path.join(temp_root, path.lstrip('/')) + dst = os.path.join(temp_root, path.lstrip('/')) + download_tasks.append((src, dst, download_timeout)) - try: - download_file(src, dest, download_timeout) - except Exception as ex: - raise RuntimeError(f'Failed to download index.json: {src} to {dest}') from ex - - if not os.path.exists(dest): - raise FileNotFoundError(f'Index file {dest} does not exist or not accessible.') - - partitions.append(dest) - - # merge shards from all index files - shards = [] - for partition_index in partitions: - p = Path(partition_index) - obj = json.load(open(partition_index)) - for i in range(len(obj['shards'])): - shard = obj['shards'][i] - for key in ('raw_data', 'zip_data', 'raw_meta', 'zip_meta'): - if shard.get(key): - basename = shard[key]['basename'] - obj['shards'][i][key]['basename'] = os.path.join( - os.path.basename(p.parent), basename) - shards += obj['shards'] + with Pool(processes=n_processes) as pool: + results = pool.map(_download_url, download_tasks) + pool.close() + pool.join() + + streams = [] + for stream_index, error in results: + if error: + raise RuntimeError(stream_index) + streams.append(stream_index) + + shards = _parallel_merge_streams(streams, n_processes) # Save merged index locally obj = { @@ -403,7 +446,7 @@ def _merge_index_from_root(out: Union[str, Tuple[str, str]], """Merge index.json given the root of MDS dataset. Write merged index to the root folder. Args: - out (Union[str, Tuple[str,str]]): folder that contain MDS partitions. + out (Union[str, Tuple[str,str]]): folder that contain MDS shards. 
:A local directory, merge index happens locally :A remote directory, download all the sub-directories index.json in a temporary sub-directories, merge locally, and then upload it to out location @@ -424,28 +467,38 @@ def _merge_index_from_root(out: Union[str, Tuple[str, str]], local_index_files = [] cl = CloudUploader.get(cu.local, exist_ok=True, keep_local=True) + + logger.warning( + f'We will be listing objects from {out}, which may take a long time if the number of stream folders is large. Consider provide the list of path/to/index.json directly.' + ) + for file in cl.list_objects(): if file.endswith('.json') and _not_merged_index(file, cu.local): local_index_files.append(file) + cpu_count = max(psutil.cpu_count() - 2, 1) + if cu.remote: remote_index_files = _format_remote_index_files(cu.remote, cu.list_objects()) if len(local_index_files) == len(remote_index_files): _merge_index_from_list(list(zip(local_index_files, remote_index_files)), out, keep_local=keep_local, - download_timeout=download_timeout) + download_timeout=download_timeout, + n_processes=cpu_count) else: _merge_index_from_list(remote_index_files, out, keep_local=keep_local, - download_timeout=download_timeout) + download_timeout=download_timeout, + n_processes=cpu_count) return _merge_index_from_list(local_index_files, out, keep_local=keep_local, - download_timeout=download_timeout) + download_timeout=download_timeout, + n_processes=cpu_count) @overload diff --git a/tests/test_util.py b/tests/test_util.py index 5aa8cabd7..dee2614be 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -7,8 +7,10 @@ import time import urllib.parse from multiprocessing.shared_memory import SharedMemory as BuiltinSharedMemory -from typing import List, Optional, Tuple, Union +from typing import List, Optional, Sequence, Tuple, Union +from unittest.mock import MagicMock +import numpy as np import pytest from streaming.base.constant import RESUME @@ -194,11 +196,66 @@ def test_format_remote_index_files(scheme: 
str): assert obj.scheme == scheme +@pytest.mark.parametrize('cpu_count', [0, 1, 4]) +def test_merge_index_from_list_local_cpucount(local_remote_dir: Tuple[str, str], cpu_count: int): + """Validate the multiprocessing setting""" + from pyspark.sql import SparkSession + + from streaming.base.converters import dataframeToMDS + + def not_merged_index(index_file_path: str, out: str): + """Check if index_file_path is the merged index at folder out.""" + prefix = str(urllib.parse.urlparse(out).path) + return os.path.dirname(index_file_path).strip('/') != prefix.strip('/') + + keep_local = True + + local, _ = local_remote_dir + + mds_out = out = local + + os.cpu_count = MagicMock() + os.cpu_count.return_value = cpu_count + + spark = SparkSession.builder.getOrCreate() # pyright: ignore + + def random_string(length: int = 1000): + """Generate a random string of fixed length.""" + return ''.join(map(chr, np.random.choice(0x10FFFF - 1, length))) + + # Generate a DataFrame with 10000 rows of random text + num_rows = 100 + data = [(i, random_string(), random_string()) for i in range(num_rows)] + df = spark.createDataFrame(data, ['id', 'name', 'amount']) + + mds_kwargs = {'out': mds_out, 'columns': {'id': 'int64', 'name': 'str'}, 'keep_local': True} + dataframeToMDS(df, merge_index=False, mds_kwargs=mds_kwargs) + + local_cu = CloudUploader.get(local, exist_ok=True, keep_local=True) + local_index_files = [ + o for o in local_cu.list_objects() if o.endswith('.json') and not_merged_index(o, local) + ] + + merge_index(local_index_files, out, keep_local=keep_local) + + d1 = json.load(open(os.path.join(out, 'index.json'))) + + _merge_index_from_list_serial(local_index_files, out, keep_local=keep_local) + d2 = json.load(open(os.path.join(out, 'index.json'))) + + print('d1 = ', d1) + print('d2 = ', d2) + + assert len(d1['shards']) == len(d2['shards']), 'parallel and serial results different' + assert d1['shards'] == d2['shards'], 'parallel and serial results different' + + 
+@pytest.mark.parametrize('cpu_count', [1, 4]) @pytest.mark.parametrize('index_file_urls_pattern', [1, 2, 3]) @pytest.mark.parametrize('keep_local', [True, False]) @pytest.mark.parametrize('scheme', ['gs://', 's3://', 'oci://']) def test_merge_index_from_list_local(local_remote_dir: Tuple[str, str], keep_local: bool, - index_file_urls_pattern: int, scheme: str): + index_file_urls_pattern: int, scheme: str, cpu_count: int): """Validate the final merge index json for following patterns of index_file_urls: 1. All URLs are str (local). All URLs are accessible locally -> no download 2. All URLs are str (local). At least one url is unaccessible locally -> Error @@ -206,10 +263,10 @@ def test_merge_index_from_list_local(local_remote_dir: Tuple[str, str], keep_loc 4. All URLs are tuple (local, remote). At least one url is not accessible locally -> download all 5. All URLs are str (remote) -> download all """ - from decimal import Decimal + import random + import string from pyspark.sql import SparkSession - from pyspark.sql.types import DecimalType, IntegerType, StringType, StructField, StructType from streaming.base.converters import dataframeToMDS @@ -223,15 +280,18 @@ def not_merged_index(index_file_path: str, out: str): mds_out = out = local spark = SparkSession.builder.getOrCreate() # pyright: ignore - schema = StructType([ - StructField('id', IntegerType(), nullable=False), - StructField('name', StringType(), nullable=False), - StructField('amount', DecimalType(10, 2), nullable=False) - ]) - data = [(1, 'Alice', Decimal('123.45')), (2, 'Bob', Decimal('67.89')), - (3, 'Charlie', Decimal('987.65'))] - df = spark.createDataFrame(data=data, schema=schema).repartition(3) - mds_kwargs = {'out': mds_out, 'columns': {'id': 'int', 'name': 'str'}, 'keep_local': True} + + def random_string(length: int = 1000): + """Generate a random string of fixed length.""" + letters = string.ascii_letters + string.digits + string.punctuation + ' ' + return ''.join(random.choice(letters) for 
_ in range(length)) + + # Generate a DataFrame with 10000 rows of random text + num_rows = 100 + data = [(i, random_string(), random_string()) for i in range(num_rows)] + df = spark.createDataFrame(data, ['id', 'name', 'amount']) + + mds_kwargs = {'out': mds_out, 'columns': {'id': 'int64', 'name': 'str'}, 'keep_local': True} dataframeToMDS(df, merge_index=False, mds_kwargs=mds_kwargs) local_cu = CloudUploader.get(local, exist_ok=True, keep_local=True) @@ -242,6 +302,18 @@ def not_merged_index(index_file_path: str, out: str): if index_file_urls_pattern == 1: merge_index(local_index_files, out, keep_local=keep_local) + if keep_local: + d1 = json.load(open(os.path.join(out, 'index.json'))) + + _merge_index_from_list_serial(local_index_files, out, keep_local=keep_local) + d2 = json.load(open(os.path.join(out, 'index.json'))) + + print('d1 = ', d1) + print('d2 = ', d2) + + assert len(d1['shards']) == len(d2['shards']), 'parallel and serial results different' + assert d1['shards'] == d2['shards'], 'parallel and serial results different' + if index_file_urls_pattern == 2: with tempfile.TemporaryDirectory() as a_temporary_folder: index_file_urls = [ @@ -323,3 +395,98 @@ def flaky_function(): return "Third time's a charm" assert flaky_function() == "Third time's a charm" + + +def _merge_index_from_list_serial(index_file_urls: Sequence[Union[str, Tuple[str, str]]], + out: Union[str, Tuple[str, str]], + keep_local: bool = True, + download_timeout: int = 60) -> None: + import logging + import shutil + import urllib.parse + from collections import OrderedDict + from pathlib import Path + + from streaming.base.format.index import get_index_basename + from streaming.base.storage.download import download_file + from streaming.base.storage.upload import CloudUploader + + if not index_file_urls or not out: + return + + # This is the index json file name, e.g., it is index.json as of 0.6.0 + index_basename = get_index_basename() + + cu = CloudUploader.get(out, keep_local=True, 
exist_ok=True) + + # Remove duplicates, and strip '/' from right if any + index_file_urls = list(OrderedDict.fromkeys(index_file_urls)) + urls = [] + for url in index_file_urls: + if isinstance(url, str): + urls.append(url.rstrip('/').strip()) + else: + urls.append((url[0].rstrip('/').strip(), url[1].rstrip('/').strip())) + + # Prepare a temp folder to download index.json from remote if necessary. Removed in the end. + with tempfile.TemporaryDirectory() as temp_root: + logging.warning(f'A temporary folder {temp_root} is created to store index files') + + # Copy files to a temporary directory. Download if necessary + partitions = [] + for url in urls: + if isinstance(url, tuple): + src = url[0] if os.path.exists(url[0]) else url[1] + else: + src = url + + obj = urllib.parse.urlparse(src) + scheme, bucket, path = obj.scheme, obj.netloc, obj.path + if scheme == '' and bucket == '' and path == '': + raise FileNotFoundError( + f'Check data availability! local index {url[0]} is not accessible.' 
+ + f'remote index {url[1]} does not have a valid url format') + dest = os.path.join(temp_root, path.lstrip('/')) + + try: + download_file(src, dest, download_timeout) + except Exception as ex: + raise RuntimeError(f'Failed to download index.json: {src} to {dest}') from ex + + if not os.path.exists(dest): + raise FileNotFoundError(f'Index file {dest} does not exist or not accessible.') + + partitions.append(dest) + + # merge shards from all index files + shards = [] + for partition_index in partitions: + p = Path(partition_index) + obj = json.load(open(partition_index)) + for i in range(len(obj['shards'])): + shard = obj['shards'][i] + for key in ('raw_data', 'zip_data', 'raw_meta', 'zip_meta'): + if shard.get(key): + basename = shard[key]['basename'] + obj['shards'][i][key]['basename'] = os.path.join( + os.path.basename(p.parent), basename) + shards += obj['shards'] + + # Save merged index locally + obj = { + 'version': 2, + 'shards': shards, + } + merged_index_path = os.path.join(temp_root, index_basename) + with open(merged_index_path, 'w') as outfile: + json.dump(obj, outfile) + + # Move merged index from temp path to local part in out + # Upload merged index to remote if out has remote part + shutil.move(merged_index_path, os.path.join(cu.local, index_basename)) + if cu.remote is not None: + cu.upload_file(index_basename) + + # Clean up + if not keep_local: + shutil.rmtree(cu.local, ignore_errors=True)