From 02afcf1483cbb7bd3f66b799a4cebdeac1052779 Mon Sep 17 00:00:00 2001 From: Xiaohan Zhang Date: Wed, 22 May 2024 23:17:43 -0700 Subject: [PATCH 001/145] update --- setup.py | 1 + streaming/base/dataset.py | 14 ++- streaming/base/stream.py | 179 +++++++++++++++++++++++++++++++++ tests/test_streaming_remote.py | 90 ++++++----------- 4 files changed, 222 insertions(+), 62 deletions(-) diff --git a/setup.py b/setup.py index 2d7419f09..6e7646851 100644 --- a/setup.py +++ b/setup.py @@ -58,6 +58,7 @@ 'azure-storage-blob>=12.0.0,<13', 'azure-storage-file-datalake>=12.11.0,<13', 'azure-identity>=1.13.0', + 'databricks-connect>=14.3.0', ] extra_deps = {} diff --git a/streaming/base/dataset.py b/streaming/base/dataset.py index e43e212e8..cca0245ac 100644 --- a/streaming/base/dataset.py +++ b/streaming/base/dataset.py @@ -34,7 +34,7 @@ from streaming.base.shared import (SharedArray, SharedBarrier, SharedMemory, SharedScalar, _get_path, get_shm_prefix) from streaming.base.spanner import Spanner -from streaming.base.stream import Stream +from streaming.base.stream import Stream, DeltaStream from streaming.base.util import bytes_to_int, number_abbrev_to_int from streaming.base.world import World @@ -443,6 +443,15 @@ def __init__(self, } for stream in streams: stream.apply_default(default) + elif remote is not None and remote.startswith('SELECT'): + default = DeltaStream(remote=remote, + local=local, + split=split, + download_retry=download_retry, + download_timeout=download_timeout, + validate_hash=validate_hash, + keep_zip=keep_zip) + streams = [default] else: default = Stream(remote=remote, local=local, @@ -507,6 +516,8 @@ def __init__(self, # Build the shard index (for partitioning and mapping samples to shards). self.samples_per_shard = np.array([shard.samples for shard in self.shards], np.int64) + print('I am here 5.1, samples_per_shard = ') + print(self.samples_per_shard) self.sample_offset_per_shard = self.samples_per_shard.cumsum() - self.samples_per_shard self.spanner = Spanner(self.samples_per_shard) @@ -1225,6 +1236,7 @@ def get_item(self, sample_id: int, retry: int = 7) -> Any: raise RuntimeError('Background thread failed. Check other traceback.') # Locate the shard and sample offset within that shard where the sample lives. shard_id, shard_sample_id = self.spanner[sample_id] + #print('I am here 5.2', shard_id, shard_sample_id) shard = self.shards[shard_id] sample = None diff --git a/streaming/base/stream.py b/streaming/base/stream.py index 7287a5c71..56d792403 100644 --- a/streaming/base/stream.py +++ b/streaming/base/stream.py @@ -22,6 +22,10 @@ from streaming.base.util import retry, wait_for_file_to_exist from streaming.base.world import World +import pyarrow as pa +import requests +from tempfile import TemporaryDirectory + class Stream: """A dataset, or sub-dataset if mixing, from which we stream/cache samples. @@ -505,3 +509,178 @@ def get_index_size(self) -> int: """ filename = os.path.join(self.local, self.split, get_index_basename()) return os.stat(filename).st_size + + +class DeltaStream(Stream): + + def __init__(self, + remote: Optional[str] = None, + local: Optional[str] = None, + split: Optional[str] = None, + proportion: Optional[float] = None, + repeat: Optional[float] = None, + choose: Optional[int] = None, + download_retry: Optional[int] = None, + download_timeout: Optional[float] = None, + validate_hash: Optional[str] = None, + keep_zip: Optional[bool] = None) -> None: + super().__init__(remote=remote, + local=local, + split=split, + proportion=proportion, + repeat=repeat, + choose=choose, + download_retry=download_retry, + download_timeout=download_timeout, + validate_hash=validate_hash, + keep_zip=keep_zip) + + self.url_to_basename= {} + self.basename_to_url={} + + def generate_unique_basename(self, url: str, index: int) -> str: + """Generate a unique basename for the file path from the URL.""" + hash_object = hashlib.md5(url.encode()) + hex_dig = hash_object.hexdigest() + # basename = f"{hex_dig[:3]}/shard.{int(hex_dig, 16) % 100000:05d}.mds" + basename = '.'.join(['shard', f'{index:05}', 'mds']) + self.url_to_basename[url] = basename + self.basename_to_url[basename] = url + return basename + + def get_shards(self, world: World, allow_unsafe_types: bool) -> List[Reader]: + """Load this Stream's index, retrieving its shard readers. + + Args: + world (World): Distributed context. + allow_unsafe_types (bool): If a shard contains Pickle, which allows arbitrary code + execution during deserialization, whether to keep going if ``True`` or raise an + error. + + Returns: + `List[Reader]: Shard readers. + """ + # Prepare cloudfetch + from databricks.connect import DatabricksSession + from databricks.sdk import WorkspaceClient + + w = WorkspaceClient() + cluster_id = "0201-234512-tcp9nfat" + + print('I am here 1') + sparkSession = DatabricksSession.builder.remote( + host=w.config.host, + token=w.config.token, + cluster_id=cluster_id).getOrCreate() + + print('I am here 2') + df = sparkSession.sql(self.remote) + print('I am here 2.0') + query = df._plan.to_proto(df._session.client) # pyright: ignore + print('I am here 2.1') + schema, cloudfetch_results = df._session.client.experimental_to_cloudfetch(query, "arrow", compression=False) # pyright: ignore + + # Local leader prepares the index file based on cloudfetch results + print('I am here 3') + basename = get_index_basename() + filename = os.path.join(self.local, self.split, basename) + + print('schema = ', schema) + self.columns = {'text': 'str'} + + print('I am here 4', len(cloudfetch_results)) + if world.is_local_leader: + + metadata = { + "version": 2, + "shards": [] + } + + for index, result in enumerate(cloudfetch_results): + shard = { + "column_encodings": ["str"], + "column_names": ["tokenized_example"], + "column_sizes": [None], + "compression": None, + "format": "mds", + "hashes": ["sha1"], + "raw_data": { + "basename": self.generate_unique_basename(result.url, index), + "bytes": result.uncompressed_size, + "hashes": {} + }, + "samples": result.row_count, + "size_limit": 67108864, + "version": 2, + "zip_data": None + } + metadata["shards"].append(shard) + + print('metadata = ') + print(metadata) + + with open(filename, 'w') as f: + json.dump(metadata, f, indent=4) + + else: + wait_for_file_to_exist( + filename, TICK, self.download_timeout, + f'Index file {os.path.join(self.remote or "", self.split or "", basename)} ' + + f'-> {filename} took too long to download. Either increase the ' + + f'`download_timeout` value or check the other traceback.') + + # Load the index. + try: + obj = json.load(open(filename)) + except json.decoder.JSONDecodeError as error: + error.args = (f'Index file at {filename} is empty or corrupted. ' + error.args[0],) + raise error + + # Version check. + if obj['version'] != 2: + raise ValueError(f'Unsupported streaming data version: {obj["version"]}. ' + + f'Expected version 2.') + + # Initialize shard readers according to the loaded info. + shards = [] + for info in obj['shards']: + shard = reader_from_json(self.local, self.split, info) + shard.validate(allow_unsafe_types) + shards.append(shard) + + print('I am here 4.1, shard.samples = ', shard.samples) + + return shards + + def _download_file(self, from_basename: str, to_basename: Optional[str] = None) -> str: + """Safely download a file from remote to local cache. + + Args: + from_basename (str): Source basename. + to_basename (str, optional): Destination basename, if different. + + Returns: + str: Local cache filename. + """ + def fetch_and_convert(cloud_fetch_url: str, local_shard_path: str): + from streaming import MDSWriter + samples = pa.ipc.open_stream(requests.get(cloud_fetch_url).content).read_all().to_pylist() + + with TemporaryDirectory() as temp_dir: + with MDSWriter(columns=self.columns, out=temp_dir, size_limit=None) as out: + for sample in samples: + out.write(sample) + temp_mds_filename = os.path.join(temp_dir, 'shard.00000.mds') + os.rename(temp_mds_filename, local_shard_path) + + print('from_basename = ', from_basename) + cloud_fetch_url = self.basename_to_url[from_basename] + local = os.path.join(self.local, self.split, from_basename) + + # Attempt to download, possibly repeating on failure. + retry(num_attempts=self.download_retry)( + lambda: fetch_and_convert(cloud_fetch_url, local))() + + return local + + diff --git a/tests/test_streaming_remote.py b/tests/test_streaming_remote.py index 1c1f7e10c..a75fe226e 100644 --- a/tests/test_streaming_remote.py +++ b/tests/test_streaming_remote.py @@ -20,69 +20,39 @@ def get_dataset(name: str, other_kwargs: Optional[Dict[str, Any]] = None) -> Tuple[int, StreamingDataset]: other_kwargs = {} if other_kwargs is None else other_kwargs dataset_map = { - 'ade20k': { - 'remote': 's3://mosaicml-internal-dataset-ade20k/mds/2/', + 'cpt': { + 'remote': 'dbfs:/Volumes/main/mosaic_hackathon/managed-volume/CPT/mds_data_11Jan24_3/', 'num_samples': { 'train': 20206, - 'val': 2000, + 'val': 0, }, - 'class': StreamingADE20K, - 'kwargs': {}, - }, - 'imagenet1k': { - 'remote': 's3://mosaicml-internal-dataset-imagenet1k/mds/2/', - 'num_samples': { - 'train': 1281167, - 'val': 50000, - }, - 'class': StreamingImageNet, - 'kwargs': {}, - }, - 'coco': { - 'remote': 's3://mosaicml-internal-dataset-coco/mds/2/', - 'num_samples': { - 'train': 117266, - 'val': 4952, - }, - 'class': StreamingCOCO, + 'class': StreamingDataset, 'kwargs': {}, }, - 'c4': { - 'remote': 's3://mosaicml-internal-dataset-c4/mds/2/', + 'dummy_table': { + 'remote': 'SELECT * FROM main.streaming.dummy_cpt_table', 'num_samples': { - 'train': 364868892, - 'val': 364608, - }, - 'class': StreamingC4, - 'kwargs': { - 'tokenizer_name': 'bert-base-uncased', - 'max_seq_len': 512, - 'group_method': 'truncate' - }, - }, - 'cifar10': { - 'remote': 's3://mosaicml-internal-dataset-cifar10/mds/2/', - 'num_samples': { - 'train': 50000, - 'val': 10000, + 'train': 20206, + 'val': 0, }, - 'class': StreamingCIFAR10, + 'class': StreamingDataset, 'kwargs': {}, }, - 'test_streaming_upload': { - 'remote': 's3://streaming-upload-test-bucket/', + 'random_cpt_table': { + 'remote': 'SELECT text FROM main.streaming.random_cpt_table', 'num_samples': { - 'all': 0, + 'train': 20206, + 'val': 0, }, 'class': StreamingDataset, 'kwargs': {}, - } + }, } - if name not in dataset_map and split not in dataset_map[name]['num_samples'][split]: - raise ValueError('Could not load dataset with name={name} and split={split}') + #if name not in dataset_map and split not in dataset_map[name]['num_samples'][split]: + # raise ValueError('Could not load dataset with name={name} and split={split}') d = dataset_map[name] - expected_samples = d['num_samples'][split] + expected_samples = 1 # d['num_samples'][split] remote = d['remote'] kwargs = {**d['kwargs'], **other_kwargs} dataset = d['class'](local=local, @@ -94,23 +64,14 @@ def get_dataset(name: str, return (expected_samples, dataset) -@pytest.mark.remote -@pytest.mark.parametrize('name', [ - 'ade20k', - 'imagenet1k', - 'coco', - 'cifar10', - 'c4', -]) -@pytest.mark.parametrize('split', ['val']) -def test_streaming_remote_dataset(tmp_path: pathlib.Path, name: str, split: str) -> None: +def test_streaming_remote_dataset(name: str, split: str) -> None: # Build StreamingDataset build_start = time.time() expected_samples, dataset = get_dataset(name=name, - local=str(tmp_path), + local=f'/tmp/test_delta_05May1029', split=split, shuffle=False, - batch_size=None) + batch_size=16) build_end = time.time() build_dur = build_end - build_start print('Built dataset') @@ -121,7 +82,7 @@ def test_streaming_remote_dataset(tmp_path: pathlib.Path, name: str, split: str) for _ in dataset: rcvd_samples += 1 - if (rcvd_samples % 1000 == 0): + if (rcvd_samples % 100 == 0): print(f'samples read: {rcvd_samples}') iter_end = time.time() @@ -129,8 +90,15 @@ def test_streaming_remote_dataset(tmp_path: pathlib.Path, name: str, split: str) samples_per_sec = rcvd_samples / iter_dur # Print debug info + print(f'received {rcvd_samples} samples') print(f'build_dur={build_dur:.2f}s, iter_dur={iter_dur:.2f}, ' + f'samples_per_sec={samples_per_sec:.2f}') # Test all samples arrived - assert rcvd_samples == expected_samples + assert rcvd_samples >= expected_samples + + +if __name__ == "__main__": +# test_streaming_remote_dataset(name = 'cpt', split=None) + # test_streaming_remote_dataset(name = 'dummy_table', split=None) + test_streaming_remote_dataset(name = 'random_cpt_table', split=None) From b0436f9ef0971b77ad95dc0e554d7389a8401975 Mon Sep 17 00:00:00 2001 From: Xiaohan Zhang Date: Thu, 23 May 2024 21:26:52 +0000 Subject: [PATCH 002/145] update --- tests/test_streaming_remote.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/tests/test_streaming_remote.py b/tests/test_streaming_remote.py index a75fe226e..2f2f02afd 100644 --- a/tests/test_streaming_remote.py +++ b/tests/test_streaming_remote.py @@ -20,8 +20,9 @@ def get_dataset(name: str, other_kwargs: Optional[Dict[str, Any]] = None) -> Tuple[int, StreamingDataset]: other_kwargs = {} if other_kwargs is None else other_kwargs dataset_map = { - 'cpt': { - 'remote': 'dbfs:/Volumes/main/mosaic_hackathon/managed-volume/CPT/mds_data_11Jan24_3/', + 'refinedweb': { + 'local': f'/tmp/test_refinedweb_05May1029', + 'remote': 'dbfs:/Volumes/main/mosaic_hackathon/managed-volume/mds/refinedweb/', 'num_samples': { 'train': 20206, 'val': 0, @@ -30,6 +31,7 @@ def get_dataset(name: str, 'kwargs': {}, }, 'dummy_table': { + 'local': f'/tmp/test_dummy_table_05May1029', 'remote': 'SELECT * FROM main.streaming.dummy_cpt_table', 'num_samples': { 'train': 20206, @@ -39,6 +41,7 @@ def get_dataset(name: str, 'kwargs': {}, }, 'random_cpt_table': { + 'local': f'/tmp/test_random_cpt_table_05May1029', 'remote': 'SELECT text FROM main.streaming.random_cpt_table', 'num_samples': { 'train': 20206, @@ -53,6 +56,7 @@ def get_dataset(name: str, d = dataset_map[name] expected_samples = 1 # d['num_samples'][split] + local = d['local'] remote = d['remote'] kwargs = {**d['kwargs'], **other_kwargs} dataset = d['class'](local=local, @@ -99,6 +103,6 @@ def test_streaming_remote_dataset(name: str, split: str) -> None: if __name__ == "__main__": -# test_streaming_remote_dataset(name = 'cpt', split=None) + test_streaming_remote_dataset(name = 'refinedweb', split=None) # test_streaming_remote_dataset(name = 'dummy_table', split=None) - test_streaming_remote_dataset(name = 'random_cpt_table', split=None) +# test_streaming_remote_dataset(name = 'random_cpt_table', split=None) From 11df9673ef6eb018ba0141312ebc6fba31c08724 Mon Sep 17 00:00:00 2001 From: Xiaohan Zhang Date: Thu, 23 May 2024 22:13:24 +0000 Subject: [PATCH 003/145] update --- streaming/base/stream.py | 17 ++++++++++++++--- tests/test_streaming_remote.py | 32 ++++++++++++++++---------------- 2 files changed, 30 insertions(+), 19 deletions(-) diff --git a/streaming/base/stream.py b/streaming/base/stream.py index 56d792403..0b7ff48b3 100644 --- a/streaming/base/stream.py +++ b/streaming/base/stream.py @@ -563,6 +563,7 @@ def get_shards(self, world: World, allow_unsafe_types: bool) -> List[Reader]: # Prepare cloudfetch from databricks.connect import DatabricksSession from databricks.sdk import WorkspaceClient + from streaming.base.converters import infer_dataframe_schema w = WorkspaceClient() cluster_id = "0201-234512-tcp9nfat" @@ -586,9 +587,19 @@ def get_shards(self, world: World, allow_unsafe_types: bool) -> List[Reader]: filename = os.path.join(self.local, self.split, basename) print('schema = ', schema) - self.columns = {'text': 'str'} + self.columns = infer_dataframe_schema(df, None) + column_names = [] + column_encodings = [] + for k, v in self.columns.items(): + column_names.append(k) + column_encodings.append(v) + #self.columns = {'text': 'str'} + print('inferred columns = ', self.columns) print('I am here 4', len(cloudfetch_results)) + +# raise RuntimeError("break") + if world.is_local_leader: metadata = { @@ -598,8 +609,8 @@ def get_shards(self, world: World, allow_unsafe_types: bool) -> List[Reader]: for index, result in enumerate(cloudfetch_results): shard = { - "column_encodings": ["str"], - "column_names": ["tokenized_example"], + "column_encodings": column_encodings, + "column_names": column_names, "column_sizes": [None], "compression": None, "format": "mds", diff --git a/tests/test_streaming_remote.py b/tests/test_streaming_remote.py index 2f2f02afd..a94d4f677 100644 --- a/tests/test_streaming_remote.py +++ b/tests/test_streaming_remote.py @@ -23,30 +23,28 @@ def get_dataset(name: str, 'refinedweb': { 'local': f'/tmp/test_refinedweb_05May1029', 'remote': 'dbfs:/Volumes/main/mosaic_hackathon/managed-volume/mds/refinedweb/', - 'num_samples': { - 'train': 20206, - 'val': 0, - }, + 'num_samples': 20206, 'class': StreamingDataset, 'kwargs': {}, }, 'dummy_table': { 'local': f'/tmp/test_dummy_table_05May1029', 'remote': 'SELECT * FROM main.streaming.dummy_cpt_table', - 'num_samples': { - 'train': 20206, - 'val': 0, - }, + 'num_samples': 20206, 'class': StreamingDataset, 'kwargs': {}, }, 'random_cpt_table': { 'local': f'/tmp/test_random_cpt_table_05May1029', 'remote': 'SELECT text FROM main.streaming.random_cpt_table', - 'num_samples': { - 'train': 20206, - 'val': 0, - }, + 'num_samples': 100000, + 'class': StreamingDataset, + 'kwargs': {}, + }, + 'random_large_table': { + 'local': f'/tmp/test_random_large_table_05May1029', + 'remote': 'SELECT * FROM main.streaming.random_large_table', + 'num_samples': 100000, 'class': StreamingDataset, 'kwargs': {}, }, @@ -55,7 +53,7 @@ def get_dataset(name: str, # raise ValueError('Could not load dataset with name={name} and split={split}') d = dataset_map[name] - expected_samples = 1 # d['num_samples'][split] + expected_samples = d['num_samples'] local = d['local'] remote = d['remote'] kwargs = {**d['kwargs'], **other_kwargs} @@ -72,7 +70,7 @@ def test_streaming_remote_dataset(name: str, split: str) -> None: # Build StreamingDataset build_start = time.time() expected_samples, dataset = get_dataset(name=name, - local=f'/tmp/test_delta_05May1029', + local=None, # f'/tmp/test_delta_05May1029', split=split, shuffle=False, batch_size=16) @@ -103,6 +101,8 @@ def test_streaming_remote_dataset(name: str, split: str) -> None: if __name__ == "__main__": - test_streaming_remote_dataset(name = 'refinedweb', split=None) +# test_streaming_remote_dataset(name = 'refinedweb', split=None) # test_streaming_remote_dataset(name = 'dummy_table', split=None) -# test_streaming_remote_dataset(name = 'random_cpt_table', split=None) +# test_streaming_remote_dataset(name = 'random_cpt_table', split=None) + test_streaming_remote_dataset(name = 'random_large_table', split=None) + From cfca07b4de687bc5cb57d7e9d992df36fe44417c Mon Sep 17 00:00:00 2001 From: Xiaohan Zhang Date: Thu, 23 May 2024 16:50:21 -0700 Subject: [PATCH 004/145] update --- streaming/base/format/mds/encodings.py | 3 +++ streaming/base/stream.py | 20 +++++++++++++++----- tests/test_streaming_remote.py | 23 ++++++++++++++++++++--- 3 files changed, 38 insertions(+), 8 deletions(-) diff --git a/streaming/base/format/mds/encodings.py b/streaming/base/format/mds/encodings.py index 71c58be46..bc5042574 100644 --- a/streaming/base/format/mds/encodings.py +++ b/streaming/base/format/mds/encodings.py @@ -307,6 +307,9 @@ def encode(self, obj: Any) -> bytes: return self.dtype(obj).tobytes() def decode(self, data: bytes) -> Any: + print(f"Data length: {len(data)}, Expected dtype: {self.dtype}, Element size: {np.dtype(self.dtype).itemsize}") + if len(data) % np.dtype(self.dtype).itemsize != 0: + print(f"Error: Buffer size {len(data)} is not a multiple of element size {np.dtype(self.dtype).itemsize}") return np.frombuffer(data, self.dtype)[0] diff --git a/streaming/base/stream.py b/streaming/base/stream.py index 0b7ff48b3..d0704025c 100644 --- a/streaming/base/stream.py +++ b/streaming/base/stream.py @@ -566,7 +566,8 @@ def get_shards(self, world: World, allow_unsafe_types: bool) -> List[Reader]: from streaming.base.converters import infer_dataframe_schema w = WorkspaceClient() - cluster_id = "0201-234512-tcp9nfat" + #cluster_id = "0201-234512-tcp9nfat" # e2-dogfood + cluster_id = "0523-224100-tid6mais" # db-force-one print('I am here 1') sparkSession = DatabricksSession.builder.remote( @@ -588,16 +589,20 @@ def get_shards(self, world: World, allow_unsafe_types: bool) -> List[Reader]: print('schema = ', schema) self.columns = infer_dataframe_schema(df, None) + column_names = [] column_encodings = [] + column_sizes = [] for k, v in self.columns.items(): column_names.append(k) column_encodings.append(v) + column_sizes.append(None) + #self.columns = {'text': 'str'} print('inferred columns = ', self.columns) print('I am here 4', len(cloudfetch_results)) - + # raise RuntimeError("break") if world.is_local_leader: @@ -609,9 +614,9 @@ def get_shards(self, world: World, allow_unsafe_types: bool) -> List[Reader]: for index, result in enumerate(cloudfetch_results): shard = { - "column_encodings": column_encodings, + "column_encodings": column_encodings, "column_names": column_names, - "column_sizes": [None], + "column_sizes": column_sizes, "compression": None, "format": "mds", "hashes": ["sha1"], @@ -673,10 +678,14 @@ def _download_file(self, from_basename: str, to_basename: Optional[str] = None) Returns: str: Local cache filename. """ + from streaming import MDSWriter + def fetch_and_convert(cloud_fetch_url: str, local_shard_path: str): - from streaming import MDSWriter samples = pa.ipc.open_stream(requests.get(cloud_fetch_url).content).read_all().to_pylist() + print('samples = ') + print(len(samples)) + with TemporaryDirectory() as temp_dir: with MDSWriter(columns=self.columns, out=temp_dir, size_limit=None) as out: for sample in samples: @@ -692,6 +701,7 @@ def fetch_and_convert(cloud_fetch_url: str, local_shard_path: str): retry(num_attempts=self.download_retry)( lambda: fetch_and_convert(cloud_fetch_url, local))() + print('download to local is done = ', local) return local diff --git a/tests/test_streaming_remote.py b/tests/test_streaming_remote.py index a94d4f677..f61dde93e 100644 --- a/tests/test_streaming_remote.py +++ b/tests/test_streaming_remote.py @@ -21,7 +21,7 @@ def get_dataset(name: str, other_kwargs = {} if other_kwargs is None else other_kwargs dataset_map = { 'refinedweb': { - 'local': f'/tmp/test_refinedweb_05May1029', + 'local': f'/tmp/test_refinedweb_05May1029', 'remote': 'dbfs:/Volumes/main/mosaic_hackathon/managed-volume/mds/refinedweb/', 'num_samples': 20206, 'class': StreamingDataset, @@ -48,6 +48,20 @@ def get_dataset(name: str, 'class': StreamingDataset, 'kwargs': {}, }, + 'reddit_table': { + 'local': f'/tmp/test_random_reddit_table_05May1029', + 'remote': 'SELECT text, added FROM main.reddit.data', + 'num_samples': 378156152, + 'class': StreamingDataset, + 'kwargs': {}, + }, + 'debug_local': { + 'local': f'/tmp/test_random_reddit_table_05May1029', + 'remote': None, + 'num_samples': 378156152, + 'class': StreamingDataset, + 'kwargs': {}, + }, } #if name not in dataset_map and split not in dataset_map[name]['num_samples'][split]: # raise ValueError('Could not load dataset with name={name} and split={split}') @@ -100,9 +114,12 @@ def test_streaming_remote_dataset(name: str, split: str) -> None: assert rcvd_samples >= expected_samples -if __name__ == "__main__": +#if __name__ == "__main__": # test_streaming_remote_dataset(name = 'refinedweb', split=None) # test_streaming_remote_dataset(name = 'dummy_table', split=None) # test_streaming_remote_dataset(name = 'random_cpt_table', split=None) - test_streaming_remote_dataset(name = 'random_large_table', split=None) +# test_streaming_remote_dataset(name = 'random_large_table', split=None) +test_streaming_remote_dataset(name = 'reddit_table', split=None) +# test_streaming_remote_dataset(name = 'debug_local', split=None) + From 2d9d38e9ba078534072908fe1aa81fb0fdc3035a Mon Sep 17 00:00:00 2001 From: Xiaohan Zhang Date: Thu, 23 May 2024 21:57:16 -0700 Subject: [PATCH 005/145] Make cluser id a param --- streaming/base/dataset.py | 7 ++++-- streaming/base/format/mds/encodings.py | 3 --- streaming/base/stream.py | 30 +++++--------------------- tests/test_streaming_remote.py | 25 ++++++++++++++------- 4 files changed, 27 insertions(+), 38 deletions(-) diff --git a/streaming/base/dataset.py b/streaming/base/dataset.py index cca0245ac..3605b30a3 100644 --- a/streaming/base/dataset.py +++ b/streaming/base/dataset.py @@ -331,7 +331,8 @@ def __init__(self, shuffle_block_size: Optional[int] = None, batching_method: str = 'random', allow_unsafe_types: bool = False, - replication: Optional[int] = None) -> None: + replication: Optional[int] = None, + **kwargs: Any) -> None: # Global arguments (which do not live in Streams). self.predownload = predownload self.cache_limit = cache_limit @@ -444,7 +445,9 @@ def __init__(self, for stream in streams: stream.apply_default(default) elif remote is not None and remote.startswith('SELECT'): - default = DeltaStream(remote=remote, + cluster_id = kwargs.get('cluster_id', None) + default = DeltaStream(cluster_id, + remote=remote, local=local, split=split, download_retry=download_retry, diff --git a/streaming/base/format/mds/encodings.py b/streaming/base/format/mds/encodings.py index bc5042574..71c58be46 100644 --- a/streaming/base/format/mds/encodings.py +++ b/streaming/base/format/mds/encodings.py @@ -307,9 +307,6 @@ def encode(self, obj: Any) -> bytes: return self.dtype(obj).tobytes() def decode(self, data: bytes) -> Any: - print(f"Data length: {len(data)}, Expected dtype: {self.dtype}, Element size: {np.dtype(self.dtype).itemsize}") - if len(data) % np.dtype(self.dtype).itemsize != 0: - print(f"Error: Buffer size {len(data)} is not a multiple of element size {np.dtype(self.dtype).itemsize}") return np.frombuffer(data, self.dtype)[0] diff --git a/streaming/base/stream.py b/streaming/base/stream.py index d0704025c..5ac56c0b9 100644 --- a/streaming/base/stream.py +++ b/streaming/base/stream.py @@ -514,6 +514,7 @@ def get_index_size(self) -> int: class DeltaStream(Stream): def __init__(self, + cluster_id: str, remote: Optional[str] = None, local: Optional[str] = None, split: Optional[str] = None, @@ -537,6 +538,7 @@ def __init__(self, self.url_to_basename= {} self.basename_to_url={} + self.cluster_id = cluster_id def generate_unique_basename(self, url: str, index: int) -> str: """Generate a unique basename for the file path from the URL.""" @@ -566,28 +568,22 @@ def get_shards(self, world: World, allow_unsafe_types: bool) -> List[Reader]: from streaming.base.converters import infer_dataframe_schema w = WorkspaceClient() - #cluster_id = "0201-234512-tcp9nfat" # e2-dogfood - cluster_id = "0523-224100-tid6mais" # db-force-one + ##cluster_id = "0201-234512-tcp9nfat" # e2-dogfood + #cluster_id = "0523-224100-tid6mais" # db-force-one - print('I am here 1') sparkSession = DatabricksSession.builder.remote( host=w.config.host, token=w.config.token, - cluster_id=cluster_id).getOrCreate() + cluster_id=self.cluster_id).getOrCreate() - print('I am here 2') df = sparkSession.sql(self.remote) - print('I am here 2.0') query = df._plan.to_proto(df._session.client) # pyright: ignore - print('I am here 2.1') schema, cloudfetch_results = df._session.client.experimental_to_cloudfetch(query, "arrow", compression=False) # pyright: ignore # Local leader prepares the index file based on cloudfetch results - print('I am here 3') basename = get_index_basename() filename = os.path.join(self.local, self.split, basename) - print('schema = ', schema) self.columns = infer_dataframe_schema(df, None) column_names = [] @@ -598,13 +594,6 @@ def get_shards(self, world: World, allow_unsafe_types: bool) -> List[Reader]: column_encodings.append(v) column_sizes.append(None) - #self.columns = {'text': 'str'} - print('inferred columns = ', self.columns) - - print('I am here 4', len(cloudfetch_results)) - -# raise RuntimeError("break") - if world.is_local_leader: metadata = { @@ -632,9 +621,6 @@ def get_shards(self, world: World, allow_unsafe_types: bool) -> List[Reader]: } metadata["shards"].append(shard) - print('metadata = ') - print(metadata) - with open(filename, 'w') as f: json.dump(metadata, f, indent=4) @@ -664,8 +650,6 @@ def get_shards(self, world: World, allow_unsafe_types: bool) -> List[Reader]: shard.validate(allow_unsafe_types) shards.append(shard) - print('I am here 4.1, shard.samples = ', shard.samples) - return shards def _download_file(self, from_basename: str, to_basename: Optional[str] = None) -> str: @@ -683,9 +667,6 @@ def _download_file(self, from_basename: str, to_basename: Optional[str] = None) def fetch_and_convert(cloud_fetch_url: str, local_shard_path: str): samples = pa.ipc.open_stream(requests.get(cloud_fetch_url).content).read_all().to_pylist() - print('samples = ') - print(len(samples)) - with TemporaryDirectory() as temp_dir: with MDSWriter(columns=self.columns, out=temp_dir, size_limit=None) as out: for sample in samples: @@ -693,7 +674,6 @@ def fetch_and_convert(cloud_fetch_url: str, local_shard_path: str): temp_mds_filename = os.path.join(temp_dir, 'shard.00000.mds') os.rename(temp_mds_filename, local_shard_path) - print('from_basename = ', from_basename) cloud_fetch_url = self.basename_to_url[from_basename] local = os.path.join(self.local, self.split, from_basename) diff --git a/tests/test_streaming_remote.py b/tests/test_streaming_remote.py index f61dde93e..5d9ca701c 100644 --- a/tests/test_streaming_remote.py +++ b/tests/test_streaming_remote.py @@ -32,35 +32,43 @@ def get_dataset(name: str, 'remote': 'SELECT * FROM main.streaming.dummy_cpt_table', 'num_samples': 20206, 'class': StreamingDataset, - 'kwargs': {}, + 'kwargs': { + 'cluster_id': "0201-234512-tcp9nfat" + }, }, 'random_cpt_table': { 'local': f'/tmp/test_random_cpt_table_05May1029', 'remote': 'SELECT text FROM main.streaming.random_cpt_table', 'num_samples': 100000, 'class': StreamingDataset, - 'kwargs': {}, + 'kwargs': { + 'cluster_id': "0201-234512-tcp9nfat" + }, }, 'random_large_table': { 'local': f'/tmp/test_random_large_table_05May1029', 'remote': 'SELECT * FROM main.streaming.random_large_table', 'num_samples': 100000, 'class': StreamingDataset, - 'kwargs': {}, + 'kwargs': { + 'cluster_id': "0201-234512-tcp9nfat" + }, }, 'reddit_table': { 'local': f'/tmp/test_random_reddit_table_05May1029', 'remote': 'SELECT text, added FROM main.reddit.data', 'num_samples': 378156152, 'class': StreamingDataset, - 'kwargs': {}, + 'kwargs': { + 'cluster_id': "0523-224100-tid6mais" + }, }, 'debug_local': { 'local': f'/tmp/test_random_reddit_table_05May1029', 'remote': None, 'num_samples': 378156152, 'class': StreamingDataset, - 'kwargs': {}, + 'kwargs': {} }, } #if name not in dataset_map and split not in dataset_map[name]['num_samples'][split]: @@ -71,7 +79,8 @@ def get_dataset(name: str, local = d['local'] remote = d['remote'] kwargs = {**d['kwargs'], **other_kwargs} - dataset = d['class'](local=local, + dataset = d['class'](d['cluster_id'], + local=local, remote=remote, split=split, shuffle=shuffle, @@ -117,9 +126,9 @@ def test_streaming_remote_dataset(name: str, split: str) -> None: #if __name__ == "__main__": # test_streaming_remote_dataset(name = 'refinedweb', split=None) # test_streaming_remote_dataset(name = 'dummy_table', split=None) -# test_streaming_remote_dataset(name = 'random_cpt_table', split=None) +test_streaming_remote_dataset(name = 'random_cpt_table', split=None) # test_streaming_remote_dataset(name = 'random_large_table', split=None) -test_streaming_remote_dataset(name = 'reddit_table', split=None) +# test_streaming_remote_dataset(name = 'reddit_table', split=None) # test_streaming_remote_dataset(name = 'debug_local', split=None) From ee01255382f103ad2fe368a731235254bd7be00a Mon Sep 17 00:00:00 2001 From: Xiaohan Zhang Date: Thu, 23 May 2024 22:03:45 -0700 Subject: [PATCH 006/145] Remove prints --- streaming/base/dataset.py | 2 -- streaming/base/stream.py | 2 -- tests/test_streaming_remote.py | 3 +-- 3 files changed, 1 insertion(+), 6 deletions(-) diff --git a/streaming/base/dataset.py b/streaming/base/dataset.py index 3605b30a3..02a900e64 100644 --- a/streaming/base/dataset.py +++ b/streaming/base/dataset.py @@ -519,8 +519,6 @@ def __init__(self, # Build the shard index (for partitioning and mapping samples to shards). self.samples_per_shard = np.array([shard.samples for shard in self.shards], np.int64) - print('I am here 5.1, samples_per_shard = ') - print(self.samples_per_shard) self.sample_offset_per_shard = self.samples_per_shard.cumsum() - self.samples_per_shard self.spanner = Spanner(self.samples_per_shard) diff --git a/streaming/base/stream.py b/streaming/base/stream.py index 5ac56c0b9..76c6efa9b 100644 --- a/streaming/base/stream.py +++ b/streaming/base/stream.py @@ -568,8 +568,6 @@ def get_shards(self, world: World, allow_unsafe_types: bool) -> List[Reader]: from streaming.base.converters import infer_dataframe_schema w = WorkspaceClient() - ##cluster_id = "0201-234512-tcp9nfat" # e2-dogfood - #cluster_id = "0523-224100-tid6mais" # db-force-one sparkSession = DatabricksSession.builder.remote( host=w.config.host, diff --git a/tests/test_streaming_remote.py b/tests/test_streaming_remote.py index 5d9ca701c..ee265521b 100644 --- a/tests/test_streaming_remote.py +++ b/tests/test_streaming_remote.py @@ -79,8 +79,7 @@ def get_dataset(name: str, local = d['local'] remote = d['remote'] kwargs = {**d['kwargs'], **other_kwargs} - dataset = d['class'](d['cluster_id'], - local=local, + dataset = d['class'](local=local, remote=remote, split=split, shuffle=shuffle, From 7e17587603ba57fa9e3bc55e16dd532e87d0fec6 Mon Sep 17 00:00:00 2001 From: Xiaohan Zhang Date: Thu, 23 May 2024 22:31:50 -0700 Subject: [PATCH 007/145] Remove prints --- streaming/base/dataset.py | 1 - streaming/base/stream.py | 1 - 2 files changed, 2 deletions(-) diff --git a/streaming/base/dataset.py b/streaming/base/dataset.py index 02a900e64..695ddbe85 100644 --- a/streaming/base/dataset.py +++ b/streaming/base/dataset.py @@ -1237,7 +1237,6 @@ def get_item(self, sample_id: int, retry: int = 7) -> Any: raise RuntimeError('Background thread failed. Check other traceback.') # Locate the shard and sample offset within that shard where the sample lives. shard_id, shard_sample_id = self.spanner[sample_id] - #print('I am here 5.2', shard_id, shard_sample_id) shard = self.shards[shard_id] sample = None diff --git a/streaming/base/stream.py b/streaming/base/stream.py index 76c6efa9b..c3416de03 100644 --- a/streaming/base/stream.py +++ b/streaming/base/stream.py @@ -544,7 +544,6 @@ def generate_unique_basename(self, url: str, index: int) -> str: """Generate a unique basename for the file path from the URL.""" hash_object = hashlib.md5(url.encode()) hex_dig = hash_object.hexdigest() - # basename = f"{hex_dig[:3]}/shard.{int(hex_dig, 16) % 100000:05d}.mds" basename = '.'.join(['shard', f'{index:05}', 'mds']) self.url_to_basename[url] = basename self.basename_to_url[basename] = url From d8227db500ae1c4d391982b0aa6e898f66fb5161 Mon Sep 17 00:00:00 2001 From: xiaohanzhangcmu Date: Sun, 2 Jun 2024 00:33:14 -0700 Subject: [PATCH 008/145] update --- streaming/base/stream.py | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/streaming/base/stream.py b/streaming/base/stream.py index c3416de03..0e4b91b12 100644 --- a/streaming/base/stream.py +++ b/streaming/base/stream.py @@ -510,6 +510,30 @@ def get_index_size(self) -> int: filename = os.path.join(self.local, self.split, get_index_basename()) return os.stat(filename).st_size +import json +import os + +def save_dict_to_file(directory, filename, dictionary): + """Save a dictionary to a file in the specified directory.""" + if not os.path.exists(directory): + os.makedirs(directory) + + file_path = os.path.join(directory, filename) + with open(file_path, 'w') as file: + json.dump(dictionary, file, indent=4) + print(f"Dictionary saved to {file_path}") + +def load_dict_from_file(directory, filename): + """Load a dictionary from a file in the specified directory.""" + file_path = os.path.join(directory, filename) + if not os.path.exists(file_path): + raise FileNotFoundError(f"No such file: '{file_path}'") + + with open(file_path, 'r') as file: + dictionary = json.load(file) + print(f"Dictionary loaded from {file_path}") + return dictionary + class DeltaStream(Stream): @@ -547,6 +571,7 @@ def generate_unique_basename(self, url: str, index: int) -> str: basename = '.'.join(['shard', f'{index:05}', 'mds']) self.url_to_basename[url] = basename self.basename_to_url[basename] = url + return basename def get_shards(self, world: World, allow_unsafe_types: bool) -> List[Reader]: @@ -647,6 +672,8 @@ def get_shards(self, world: World, allow_unsafe_types: bool) -> List[Reader]: shard.validate(allow_unsafe_types) shards.append(shard) + save_dict_to_file('./', 'basename_to_url.json', self.basename_to_url) + return shards def _download_file(self, from_basename: str, to_basename: Optional[str] = None) -> str: From f9871e65563406f58c073f947bc3b04e6b9c85ad Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 3 Jun 2024 17:35:35 +0000 Subject: [PATCH 009/145] Bump pydantic from 2.7.1 to 2.7.2 (#692) --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 2d7419f09..4267dd6d7 100644 --- a/setup.py +++ b/setup.py @@ -75,7 +75,7 @@ 'yamllint==1.35.1', 'moto>=4.0,<6', 'fastapi==0.111.0', - 'pydantic==2.7.1', + 'pydantic==2.7.2', 'uvicorn==0.29.0', 'pytest-split==0.8.2', ] From 4f338e1cdffcda415ca3937289fd4c8b317039f9 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 3 Jun 2024 17:59:20 +0000 Subject: [PATCH 010/145] Bump uvicorn from 0.29.0 to 0.30.1 (#691) --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 4267dd6d7..0ea49f1b4 100644 --- a/setup.py +++ b/setup.py @@ -76,7 +76,7 @@ 'moto>=4.0,<6', 'fastapi==0.111.0', 'pydantic==2.7.2', - 'uvicorn==0.29.0', + 'uvicorn==0.30.1', 'pytest-split==0.8.2', ] From f5c57ebee3c337ce24472981236c8cb2b7422f30 Mon Sep 17 00:00:00 2001 From: Saaketh Narayan Date: Tue, 4 Jun 2024 23:37:14 -0700 Subject: [PATCH 011/145] Make sure epoch_size is an int (#693) * typo * epoch size int --- streaming/base/stream.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/streaming/base/stream.py b/streaming/base/stream.py index 7287a5c71..1fd9fc9ff 100644 --- a/streaming/base/stream.py +++ b/streaming/base/stream.py @@ -288,7 +288,7 @@ def apply_weights(cls, streams: Sequence[Self], samples_per_stream: NDArray[np.i stream.repeat = repeat stream.choose = choose - return choose_per_epoch + return int(choose_per_epoch) def _download_file(self, from_basename: str, to_basename: Optional[str] = None) -> str: """Safely download a file from remote to local cache. From 55c9f85c793be7e1932cdfcae8421dbc6784f370 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 5 Jun 2024 06:51:39 +0000 Subject: [PATCH 012/145] Bump databricks-sdk from 0.27.1 to 0.28.0 (#687) --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 0ea49f1b4..388e8f9ff 100644 --- a/setup.py +++ b/setup.py @@ -116,7 +116,7 @@ ] extra_deps['databricks'] = [ - 'databricks-sdk==0.27.1', + 'databricks-sdk==0.28.0', ] extra_deps['alipan'] = [ From f7c7a9a563d0aec2563fa6d014bd2c088947d6ba Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 11 Jun 2024 18:13:30 +0000 Subject: [PATCH 013/145] Bump pytest from 8.2.1 to 8.2.2 (#697) --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 388e8f9ff..d5c0bdba6 100644 --- a/setup.py +++ b/setup.py @@ -68,7 +68,7 @@ 'docformatter>=1.4', 'jupyter==1.0.0', 'pre-commit>=2.18.1,<4', - 'pytest==8.2.1', + 'pytest==8.2.2', 'pytest_codeblocks==0.17.0', 'pytest-cov>=4,<6', 'toml==0.10.2', From 8904102594eec89f42734219fb38efbd635cb8e2 Mon Sep 17 00:00:00 2001 From: "Xuan (Sean) Hu" Date: Wed, 12 Jun 2024 02:33:07 +0800 Subject: [PATCH 014/145] fix: expand user path for Writer's output directory. (#694) * fix: expand user path for Writer's output directory. * Update streaming/base/format/base/writer.py --------- Co-authored-by: Saaketh Narayan Co-authored-by: Saaketh Narayan --- streaming/base/format/base/writer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/streaming/base/format/base/writer.py b/streaming/base/format/base/writer.py index 152d61ae7..ed3407b86 100644 --- a/streaming/base/format/base/writer.py +++ b/streaming/base/format/base/writer.py @@ -124,7 +124,7 @@ def __init__(self, self.shards = [] # Remove local directory if requested prior to creating writer - local = out if isinstance(out, str) else out[0] + local = os.path.expanduser(out) if isinstance(out, str) else os.path.expanduser(out[0]) if os.path.exists(local) and kwargs.get('exist_ok', False): logger.warning( f'Directory {local} exists and is not empty; exist_ok is set to True so will remove contents.' From 1af436503ff4b9a9044f2a8da5ec94753b80c950 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 11 Jun 2024 19:10:47 +0000 Subject: [PATCH 015/145] Bump pydantic from 2.7.2 to 2.7.3 (#696) Bumps [pydantic](https://github.com/pydantic/pydantic) from 2.7.2 to 2.7.3. - [Release notes](https://github.com/pydantic/pydantic/releases) - [Changelog](https://github.com/pydantic/pydantic/blob/main/HISTORY.md) - [Commits](https://github.com/pydantic/pydantic/compare/v2.7.2...v2.7.3) --- updated-dependencies: - dependency-name: pydantic dependency-type: direct:development update-type: version-update:semver-patch ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Saaketh Narayan --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index d5c0bdba6..f2bec0176 100644 --- a/setup.py +++ b/setup.py @@ -75,7 +75,7 @@ 'yamllint==1.35.1', 'moto>=4.0,<6', 'fastapi==0.111.0', - 'pydantic==2.7.2', + 'pydantic==2.7.3', 'uvicorn==0.30.1', 'pytest-split==0.8.2', ] From be97343cc6d9cf40aae7772a8389a3080d82c6ca Mon Sep 17 00:00:00 2001 From: Saaketh Narayan Date: Fri, 14 Jun 2024 12:28:39 -0700 Subject: [PATCH 016/145] Fix edge cases with scalar or empty numpy array encoding (#702) * typo * wo --- streaming/base/format/mds/encodings.py | 8 ++++++++ tests/test_encodings.py | 14 ++++++++++++++ 2 files changed, 22 insertions(+) diff --git a/streaming/base/format/mds/encodings.py b/streaming/base/format/mds/encodings.py index 71c58be46..0e7c7fed6 100644 --- a/streaming/base/format/mds/encodings.py +++ b/streaming/base/format/mds/encodings.py @@ -201,6 +201,11 @@ def _rightsize_shape_dtype(cls, shape: npt.NDArray[np.int64]) -> str: Returns: str: The smallest acceptable uint* dtype. """ + if len(shape) == 0: + raise ValueError( + 'Attempting to encode a scalar with NDArray encoding. Please use a scalar encoding.' + ) + if shape.min() <= 0: raise ValueError('All dimensions must be greater than zero.') x = shape.max() @@ -235,6 +240,9 @@ def encode(self, obj: npt.NDArray) -> bytes: if obj.dtype != self.dtype: raise ValueError(f'Wrong dtype: expected {self.dtype}, got {obj.dtype.name}.') + if obj.size == 0: + raise ValueError('Attempting to encode a numpy array with 0 elements.') + # Encode shape, if not given in header. if self.shape is None: ndim = len(obj.shape) diff --git a/tests/test_encodings.py b/tests/test_encodings.py index bc3aac670..aac457519 100644 --- a/tests/test_encodings.py +++ b/tests/test_encodings.py @@ -8,6 +8,7 @@ import numpy as np import pytest +from numpy.typing import NDArray from PIL import Image import streaming.base.format.json.encodings as jsonEnc @@ -132,6 +133,19 @@ def test_ndarray_encode_decode(self, dtype_str: str, shape: Tuple[int]): assert b3_len < b2_len < b1_len assert b3_len == np.prod(shape) * dtype().nbytes + def test_error_no_elements_ndarray(self): + encoding = 'ndarray' + with pytest.raises(ValueError, + match='Attempting to encode a numpy array with 0 elements.*'): + _ = mdsEnc.mds_encode(encoding, np.array([])) + + @pytest.mark.parametrize('array', [np.array(0.5), np.empty(()), np.array(1)]) + def test_error_scalar_ndarray(self, array: NDArray): + encoding = 'ndarray' + with pytest.raises(ValueError, + match='Attempting to encode a scalar with NDArray encoding.*'): + _ = mdsEnc.mds_encode(encoding, array) + @pytest.mark.parametrize('mode', ['I', 'L', 'RGB']) def test_pil_encode_decode(self, mode: str): pil_enc = mdsEnc.PIL() From dc8ac7ba0fa74101bba855e6b946a301766bef59 Mon Sep 17 00:00:00 2001 From: Saaketh Narayan Date: Fri, 14 Jun 2024 12:57:40 -0700 Subject: [PATCH 017/145] Raise IndexError in `Spanner` object instead of `ValueError` (#701) * typo * woo * woooo --------- Co-authored-by: Karan Jariwala --- streaming/base/spanner.py | 2 +- tests/test_spanner.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/streaming/base/spanner.py b/streaming/base/spanner.py index 10cd72639..af3fffa99 100644 --- a/streaming/base/spanner.py +++ b/streaming/base/spanner.py @@ -49,7 +49,7 @@ def __getitem__(self, index: int) -> Tuple[int, int]: Tuple[int, int]: Shard and relative sample index. """ if not (0 <= index < self.num_samples): - raise ValueError(f'Invalid sample index `{index}`: 0 <= {index} < {self.num_samples}') + raise IndexError(f'Invalid sample index `{index}`: 0 <= {index} < {self.num_samples}') span = index // self.span_size for shard in self.spans[span]: diff --git a/tests/test_spanner.py b/tests/test_spanner.py index f971813f3..ad8e01a74 100644 --- a/tests/test_spanner.py +++ b/tests/test_spanner.py @@ -24,6 +24,6 @@ def test_spanner_success(): def test_spanner_invalid_index(index: int): shard_sizes = np.arange(5, 100, 5) span_size = 7 - with pytest.raises(ValueError, match='Invalid sample index.*'): + with pytest.raises(IndexError, match='Invalid sample index.*'): spanner = Spanner(shard_sizes, span_size) spanner[index] From ea4f0c3416b8dcf26e3f4363bbf2135afdf77279 Mon Sep 17 00:00:00 2001 From: Saaketh Narayan Date: Mon, 17 Jun 2024 15:30:14 -0700 Subject: [PATCH 018/145] Fix linting issues with numpy 2 (#705) * lint * lint * isin --- streaming/base/shared/prefix.py | 2 +- streaming/base/spanner.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/streaming/base/shared/prefix.py b/streaming/base/shared/prefix.py index 7e8936086..64585381c 100644 --- a/streaming/base/shared/prefix.py +++ b/streaming/base/shared/prefix.py @@ -118,7 +118,7 @@ def _check_and_find(streams_local: List[str], streams_remote: List[Union[str, No if any(streams_remote): # Get the indices of the local directories which matches with the current # shared memory. - matching_index = np.where(np.in1d(streams_local, their_locals))[0] + matching_index = np.where(np.isin(streams_local, their_locals))[0] if matching_index.size > 0: for idx in matching_index: # If there is a conflicting local directory for a non-None remote directory, diff --git a/streaming/base/spanner.py b/streaming/base/spanner.py index af3fffa99..18426af71 100644 --- a/streaming/base/spanner.py +++ b/streaming/base/spanner.py @@ -56,6 +56,6 @@ def __getitem__(self, index: int) -> Tuple[int, int]: shard_start = self.shard_bounds[shard] shard_stop = self.shard_bounds[shard + 1] if shard_start <= index < shard_stop: - return shard, int(index - shard_start) + return shard, int(index - shard_start.item()) raise RuntimeError('Internal error: shards were indexed incorrectly') From 2b53acb7050c4018396b9817773836871e3552e6 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 17 Jun 2024 22:50:22 +0000 Subject: [PATCH 019/145] Bump pydantic from 2.7.3 to 2.7.4 (#704) --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index f2bec0176..5cac170b8 100644 --- a/setup.py +++ b/setup.py @@ -75,7 +75,7 @@ 'yamllint==1.35.1', 'moto>=4.0,<6', 'fastapi==0.111.0', - 'pydantic==2.7.3', + 'pydantic==2.7.4', 'uvicorn==0.30.1', 'pytest-split==0.8.2', ] From a5b9eeaf5fce88e89332a85564e1f6d8c203f7bc Mon Sep 17 00:00:00 2001 From: Saaketh Narayan Date: Mon, 17 Jun 2024 17:41:41 -0700 Subject: [PATCH 020/145] Enable correct resumption from the end of an epoch (#700) * typo * potensh * tests * tests * Update streaming/base/partition/relaxed.py Co-authored-by: Mihir Patel * Update streaming/base/partition/relaxed.py Co-authored-by: Mihir Patel * ready --------- Co-authored-by: Mihir Patel --- streaming/base/dataset.py | 56 ++++++++++++++++++++--------- streaming/base/partition/orig.py | 2 +- streaming/base/partition/relaxed.py | 2 +- tests/test_partition.py | 45 +++++++++++++++++++++++ 4 files changed, 86 insertions(+), 19 deletions(-) diff --git a/streaming/base/dataset.py b/streaming/base/dataset.py index e43e212e8..292405528 100644 --- a/streaming/base/dataset.py +++ b/streaming/base/dataset.py @@ -922,14 +922,18 @@ def resample_streams( sample_ids = np.concatenate(sample_ids).astype(np.int64) return shuffle_units, sample_ids - def _share_work(self, sample_ids: NDArray[np.int64]) -> Tuple[SharedMemory, SharedMemory]: + def _share_work( + self, + sample_ids: NDArray[np.int64], + ) -> Tuple[SharedMemory, Optional[SharedMemory]]: """Put an epoch's sample ordering into shared memory. Args: sample_ids (NDArray[np.int64]): Sample IDs. Returns: - Tuple[SharedMemory, SharedMemory]: Shared memory arrays containing shape and data. + Tuple[SharedMemory, Optional[SharedMemory]]: Shared memory arrays containing shape and + data, if present. """ ndim = 5 @@ -945,19 +949,26 @@ def _share_work(self, sample_ids: NDArray[np.int64]) -> Tuple[SharedMemory, Shar shape_shm = SharedMemory(name=name, create=True, size=size, auto_cleanup=False) shape_shm.buf[:size] = np.array(sample_ids.shape, np.int64).tobytes() - # Save the generated epoch data to shared memory. - name = _get_path(self._shm_prefix_int, EPOCH_DATA) - size = sample_ids.size * np.int64().nbytes - data_shm = SharedMemory(name=name, create=True, size=size, auto_cleanup=False) - data_shm.buf[:size] = sample_ids.tobytes() + if sample_ids.size > 0: + # Save the generated epoch data to shared memory, but only if the sample partition is + # non-empty. Otherwise, the end of the epoch has been reached. + name = _get_path(self._shm_prefix_int, EPOCH_DATA) + size = sample_ids.size * np.int64().nbytes + data_shm = SharedMemory(name=name, create=True, size=size, auto_cleanup=False) + data_shm.buf[:size] = sample_ids.tobytes() - return shape_shm, data_shm + return shape_shm, data_shm - def _attach_work(self) -> Tuple[NDArray[np.int64], SharedMemory, SharedMemory]: + else: + + return shape_shm, None + + def _attach_work(self) -> Tuple[NDArray[np.int64], SharedMemory, Optional[SharedMemory]]: """Get an epoch's sample ordering from shared memory. Returns: - NDArray[np.int64]: Sample IDs. + Tuple[NDArray[np.int64], SharedMemory, Optional[SharedMemory]]: Sample IDs, shared + memory array for shape, and shared memory array for data, if present. """ ndim = 5 @@ -967,13 +978,22 @@ def _attach_work(self) -> Tuple[NDArray[np.int64], SharedMemory, SharedMemory]: shape_shm = SharedMemory(name=name, create=False, size=size, auto_cleanup=False) shape = tuple(np.ndarray(5, buffer=shape_shm.buf, dtype=np.int64)) - # Attach to the generated epoch data in shared memory. - name = _get_path(self._shm_prefix_int, EPOCH_DATA) - size = int(np.prod(shape)) * np.int64().nbytes - data_shm = SharedMemory(name=name, create=False, size=size, auto_cleanup=False) - sample_ids = np.ndarray(shape, buffer=data_shm.buf, dtype=np.int64) + num_elements = int(np.prod(shape)) + + if num_elements > 0: + # Attach to the generated epoch data in shared memory, but only if the sample partition + # is non-empty. Otherwise, the end of the epoch has been reached. + name = _get_path(self._shm_prefix_int, EPOCH_DATA) + size = num_elements * np.int64().nbytes + data_shm = SharedMemory(name=name, create=False, size=size, auto_cleanup=False) + sample_ids = np.ndarray(shape, buffer=data_shm.buf, dtype=np.int64) + + return sample_ids, shape_shm, data_shm + + else: - return sample_ids, shape_shm, data_shm + sample_ids = np.empty(shape=shape, dtype=np.int64) + return sample_ids, shape_shm, None def _get_work(self, epoch: int, sample_in_epoch: int) -> NDArray[np.int64]: """Get this worker's partition of this epoch's sample space. @@ -1025,7 +1045,9 @@ def _get_work(self, epoch: int, sample_in_epoch: int) -> NDArray[np.int64]: # Now clean up after ourselves. shape_shm.cleanup() - data_shm.cleanup() + # Can be None if the sample partition was empty. + if data_shm is not None: + data_shm.cleanup() return worker_sample_ids diff --git a/streaming/base/partition/orig.py b/streaming/base/partition/orig.py index dff6d7878..ce8832cf5 100644 --- a/streaming/base/partition/orig.py +++ b/streaming/base/partition/orig.py @@ -46,7 +46,7 @@ def get_partitions_orig(num_samples: int, NDArray[np.int64]: Partitions of shape (physical nodes, ranks per node, workers per rank, batches per worker, batch size). """ - if num_samples <= drop_first: + if num_samples < drop_first: raise ValueError(f'Resuming further into the dataset ({drop_first}) than it has samples ' + f'({num_samples})') diff --git a/streaming/base/partition/relaxed.py b/streaming/base/partition/relaxed.py index e84bb7efc..1812b977a 100644 --- a/streaming/base/partition/relaxed.py +++ b/streaming/base/partition/relaxed.py @@ -49,7 +49,7 @@ def get_partitions_relaxed(num_samples: int, NDArray[np.int64]: Partitions of shape (physical nodes, ranks per node, workers per rank, batches per worker, batch size). """ - if num_samples <= drop_first: + if num_samples < drop_first: raise ValueError(f'Resuming further into the dataset ({drop_first}) than it has samples ' + f'({num_samples})') diff --git a/tests/test_partition.py b/tests/test_partition.py index 68d4ba8e1..aa26a63d1 100644 --- a/tests/test_partition.py +++ b/tests/test_partition.py @@ -38,6 +38,51 @@ def test_partition_walk(partition_algo: str): assert x.shape == (22, 8, 8, 1, 10) +@pytest.mark.parametrize('num_samples', [400, 1000]) +@pytest.mark.parametrize('num_canonical_nodes', [1, 4]) +@pytest.mark.parametrize('num_physical_nodes', [1, 4]) +@pytest.mark.parametrize('ranks_per_node', [1, 8]) +@pytest.mark.parametrize('workers_per_rank', [1, 8]) +@pytest.mark.parametrize('batch_size', [4]) +@pytest.mark.parametrize('partition_algo', ['orig', 'relaxed']) +def test_partition_drop_all(num_samples: int, num_canonical_nodes: int, num_physical_nodes: int, + ranks_per_node: int, workers_per_rank: int, batch_size: int, + partition_algo: str): + initial_physical_nodes = None + if partition_algo == 'relaxed' and num_canonical_nodes == 4 and ranks_per_node == 8: + num_canonical_nodes = 3 + initial_physical_nodes = 3 + batch_size = batch_size * 3 + num_samples = 3 * num_samples + + drop_first = num_samples + + x = get_partitions(partition_algo, num_samples, num_canonical_nodes, num_physical_nodes, + ranks_per_node, workers_per_rank, batch_size, drop_first, + initial_physical_nodes) + # Partition should still have the appropriate shape, but without any samples in it. + assert x.shape == (num_physical_nodes, ranks_per_node, workers_per_rank, 0, batch_size) + assert x.size == 0 + + +@pytest.mark.parametrize('num_samples', [400, 1000]) +@pytest.mark.parametrize('drop_additional', [1, 400]) +@pytest.mark.parametrize('num_canonical_nodes', [4]) +@pytest.mark.parametrize('num_physical_nodes', [4]) +@pytest.mark.parametrize('ranks_per_node', [8]) +@pytest.mark.parametrize('workers_per_rank', [8]) +@pytest.mark.parametrize('batch_size', [4]) +@pytest.mark.parametrize('partition_algo', ['orig', 'relaxed']) +def test_partition_invalid_drop_first(num_samples: int, drop_additional: int, + num_canonical_nodes: int, num_physical_nodes: int, + ranks_per_node: int, workers_per_rank: int, batch_size: int, + partition_algo: str): + drop_first = num_samples + drop_additional + with pytest.raises(ValueError, match=f'Resuming further into the dataset*'): + _ = get_partitions(partition_algo, num_samples, num_canonical_nodes, num_physical_nodes, + ranks_per_node, workers_per_rank, batch_size, drop_first) + + @pytest.mark.parametrize('num_samples', [1, 4]) @pytest.mark.parametrize('num_canonical_nodes', [1, 4]) @pytest.mark.parametrize('num_physical_nodes', [1, 4]) From 27d61d8d889482d537809e54a7b5382ad0fbae88 Mon Sep 17 00:00:00 2001 From: Saaketh Narayan Date: Tue, 18 Jun 2024 12:02:26 -0700 Subject: [PATCH 021/145] Fix `drop_first` checking in partitioning to account for `world_size` divisibility (#706) * typo * potensh * tests * tests * Update streaming/base/partition/relaxed.py Co-authored-by: Mihir Patel * Update streaming/base/partition/relaxed.py Co-authored-by: Mihir Patel * ready * epoch_size_checks * epoch_size_checks * Update streaming/base/partition/__init__.py Co-authored-by: Mihir Patel --------- Co-authored-by: Mihir Patel --- streaming/base/partition/__init__.py | 14 ++++++++++++ streaming/base/partition/orig.py | 9 ++------ streaming/base/partition/relaxed.py | 4 ---- tests/test_partition.py | 33 +++++++++++++++++++++------- 4 files changed, 41 insertions(+), 19 deletions(-) diff --git a/streaming/base/partition/__init__.py b/streaming/base/partition/__init__.py index 28e908cb1..65271d8e2 100644 --- a/streaming/base/partition/__init__.py +++ b/streaming/base/partition/__init__.py @@ -3,6 +3,7 @@ """Apportion shards/samples to nodes/ranks/workers for elastically deterministic sample order.""" +import logging from typing import Optional import numpy as np @@ -11,6 +12,8 @@ from streaming.base.partition.orig import get_partitions_orig from streaming.base.partition.relaxed import get_partitions_relaxed +logger = logging.getLogger(__name__) + algos = { 'orig': get_partitions_orig, 'relaxed': get_partitions_relaxed, @@ -51,6 +54,17 @@ def get_partitions(algo: str, NDArray[np.int64]: Partitions of shape (physical nodes, ranks per node, workers per rank, batches per worker, batch size). """ + world_size = ranks_per_node * num_physical_nodes + num_repeated_samples = world_size - (num_samples % world_size) + if num_samples + num_repeated_samples < drop_first: + raise ValueError(f'Resuming further into the dataset ({drop_first}) than it has samples ' + + f'({num_samples})') + + if num_repeated_samples > 0: + logger.debug(f'Using {num_repeated_samples} repeated samples to ensure that the epoch ' + + f'size is divisible by the number of total devices. This ensures that each ' + + f'device contributes the same number of samples per global batch. ') + get = algos[algo] return get(num_samples, num_canonical_nodes, num_physical_nodes, ranks_per_node, workers_per_rank, batch_size, drop_first, initial_physical_nodes) diff --git a/streaming/base/partition/orig.py b/streaming/base/partition/orig.py index ce8832cf5..cda16ac1d 100644 --- a/streaming/base/partition/orig.py +++ b/streaming/base/partition/orig.py @@ -46,10 +46,6 @@ def get_partitions_orig(num_samples: int, NDArray[np.int64]: Partitions of shape (physical nodes, ranks per node, workers per rank, batches per worker, batch size). """ - if num_samples < drop_first: - raise ValueError(f'Resuming further into the dataset ({drop_first}) than it has samples ' + - f'({num_samples})') - if num_canonical_nodes < num_physical_nodes: if num_physical_nodes % num_canonical_nodes: raise ValueError('Either canonical or physical nodes must be evenly divisible by ' + @@ -81,7 +77,7 @@ def get_partitions_orig(num_samples: int, # For samples to be properly split across canonical nodes, there must be more samples than nodes. # The edge case is when the number of samples is equal to the number of canonical nodes, but this only works when - # there is an equal or greater number of canonical nodes than physical nodes. + # there is an equal or greater number of canonical nodes than physical nodes. # If these conditions are not met, an alternative sampling approach is used that leads to many repeats. if num_samples > num_canonical_nodes or (num_samples == num_canonical_nodes and num_canonical_nodes >= num_physical_nodes): @@ -141,8 +137,7 @@ def get_partitions_orig(num_samples: int, ids = ids.reshape(-1, num_physical_nodes) ids = ids.transpose() - # Interleave the node sample ranges over each node's ranks, padding by repeating the last - # sample. + # Interleave the node sample ranges over each node's ranks, padding with -1 for reshaping. # # ids: (physical nodes, samples per rank, ranks per node). overflow = ids.shape[1] % ranks_per_node diff --git a/streaming/base/partition/relaxed.py b/streaming/base/partition/relaxed.py index 1812b977a..6baa0a48c 100644 --- a/streaming/base/partition/relaxed.py +++ b/streaming/base/partition/relaxed.py @@ -49,10 +49,6 @@ def get_partitions_relaxed(num_samples: int, NDArray[np.int64]: Partitions of shape (physical nodes, ranks per node, workers per rank, batches per worker, batch size). """ - if num_samples < drop_first: - raise ValueError(f'Resuming further into the dataset ({drop_first}) than it has samples ' + - f'({num_samples})') - if initial_physical_nodes is None or (num_physical_nodes <= num_canonical_nodes and num_canonical_nodes % num_physical_nodes == 0) or \ (num_physical_nodes > num_canonical_nodes and diff --git a/tests/test_partition.py b/tests/test_partition.py index aa26a63d1..42cfaa1f6 100644 --- a/tests/test_partition.py +++ b/tests/test_partition.py @@ -38,16 +38,22 @@ def test_partition_walk(partition_algo: str): assert x.shape == (22, 8, 8, 1, 10) -@pytest.mark.parametrize('num_samples', [400, 1000]) -@pytest.mark.parametrize('num_canonical_nodes', [1, 4]) -@pytest.mark.parametrize('num_physical_nodes', [1, 4]) +@pytest.mark.parametrize('num_samples', [405, 812, 1111]) +@pytest.mark.parametrize('num_canonical_nodes', [1, 2]) +@pytest.mark.parametrize('num_physical_nodes', [2, 8]) @pytest.mark.parametrize('ranks_per_node', [1, 8]) @pytest.mark.parametrize('workers_per_rank', [1, 8]) @pytest.mark.parametrize('batch_size', [4]) @pytest.mark.parametrize('partition_algo', ['orig', 'relaxed']) -def test_partition_drop_all(num_samples: int, num_canonical_nodes: int, num_physical_nodes: int, - ranks_per_node: int, workers_per_rank: int, batch_size: int, - partition_algo: str): +def test_partition_drop_all( + num_samples: int, + num_canonical_nodes: int, + num_physical_nodes: int, + ranks_per_node: int, + workers_per_rank: int, + batch_size: int, + partition_algo: str, +): initial_physical_nodes = None if partition_algo == 'relaxed' and num_canonical_nodes == 4 and ranks_per_node == 8: num_canonical_nodes = 3 @@ -55,7 +61,11 @@ def test_partition_drop_all(num_samples: int, num_canonical_nodes: int, num_phys batch_size = batch_size * 3 num_samples = 3 * num_samples - drop_first = num_samples + # Partitioning should repeat samples so that the epoch size is divisible by the world size. + # To drop all samples, we need to drop all repeated samples as well. + world_size = num_physical_nodes * ranks_per_node + num_repeated_samples = world_size - (num_samples % world_size) + drop_first = num_samples + num_repeated_samples x = get_partitions(partition_algo, num_samples, num_canonical_nodes, num_physical_nodes, ranks_per_node, workers_per_rank, batch_size, drop_first, @@ -77,7 +87,14 @@ def test_partition_invalid_drop_first(num_samples: int, drop_additional: int, num_canonical_nodes: int, num_physical_nodes: int, ranks_per_node: int, workers_per_rank: int, batch_size: int, partition_algo: str): - drop_first = num_samples + drop_additional + + # Partitioning should repeat samples so that the epoch size is divisible by the world size. + # For `drop_first` to be invalid, we need to exceed the number of unique samples plus the + # number of repeated samples. + world_size = num_physical_nodes * ranks_per_node + num_repeated_samples = world_size - (num_samples % world_size) + drop_first = num_samples + num_repeated_samples + drop_additional + with pytest.raises(ValueError, match=f'Resuming further into the dataset*'): _ = get_partitions(partition_algo, num_samples, num_canonical_nodes, num_physical_nodes, ranks_per_node, workers_per_rank, batch_size, drop_first) From eb61352d243e1b8188c0c3a84842a9475dc2acac Mon Sep 17 00:00:00 2001 From: Hayden Prairie <55720063+Hprairie@users.noreply.github.com> Date: Thu, 20 Jun 2024 10:42:04 -0500 Subject: [PATCH 022/145] fix convert imagenet (#708) --- streaming/vision/convert/imagenet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/streaming/vision/convert/imagenet.py b/streaming/vision/convert/imagenet.py index f98883527..cdb84e7af 100644 --- a/streaming/vision/convert/imagenet.py +++ b/streaming/vision/convert/imagenet.py @@ -159,7 +159,7 @@ def main(args: Namespace) -> None: x = open(filenames[i], 'rb').read() y = classes[i] out.write({ - 'i': i, + 'i': int(i), 'x': x, 'y': y, }) From 02db72d39e8eb19c3c74271af1b4cb0eb501af27 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 24 Jun 2024 15:00:07 +0000 Subject: [PATCH 023/145] Bump pytest-split from 0.8.2 to 0.9.0 (#710) --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 5cac170b8..b32f6861b 100644 --- a/setup.py +++ b/setup.py @@ -77,7 +77,7 @@ 'fastapi==0.111.0', 'pydantic==2.7.4', 'uvicorn==0.30.1', - 'pytest-split==0.8.2', + 'pytest-split==0.9.0', ] extra_deps['docs'] = [ From 517dc6d7a524eca26cafd3e4d59f32cb9a97321c Mon Sep 17 00:00:00 2001 From: Vansh Singh Date: Mon, 24 Jun 2024 19:26:38 -0700 Subject: [PATCH 024/145] Remove duplicate `dbfs:` prefix from error message (#712) --- streaming/base/storage/download.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/streaming/base/storage/download.py b/streaming/base/storage/download.py index cdcf3d489..00be843a2 100644 --- a/streaming/base/storage/download.py +++ b/streaming/base/storage/download.py @@ -378,7 +378,7 @@ def download_from_databricks_unity_catalog(remote: str, local: str) -> None: f'operations. Increase the `download_retry` value to retry downloading ' + f'a file.',) if e.error_code == 'NOT_FOUND': - raise FileNotFoundError(f'Object dbfs:{remote} not found.') from e + raise FileNotFoundError(f'Object {remote} not found.') from e raise e os.rename(local_tmp, local) From 67ab85c8ef26529e4c7b3e4e3f87caa2220e6778 Mon Sep 17 00:00:00 2001 From: bigning Date: Fri, 28 Jun 2024 11:29:23 -0700 Subject: [PATCH 025/145] a (#713) --- streaming/base/storage/download.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/streaming/base/storage/download.py b/streaming/base/storage/download.py index 00be843a2..9780836dc 100644 --- a/streaming/base/storage/download.py +++ b/streaming/base/storage/download.py @@ -53,12 +53,15 @@ def _download_file(unsigned: bool = False, extra_args (Dict[str, Any], optional): Extra arguments supported by boto3. Defaults to ``None``. """ + retries = { + 'mode': 'adaptive', + } if unsigned: # Client will be using unsigned mode in which public # resources can be accessed without credentials - config = Config(read_timeout=timeout, signature_version=UNSIGNED) + config = Config(read_timeout=timeout, signature_version=UNSIGNED, retries=retries) else: - config = Config(read_timeout=timeout) + config = Config(read_timeout=timeout, retries=retries) if extra_args is None: extra_args = {} From 7089eef6cd146015b810fcbb08ef3d64e73e93c1 Mon Sep 17 00:00:00 2001 From: Saaketh Narayan Date: Fri, 28 Jun 2024 16:27:22 -0700 Subject: [PATCH 026/145] Upgrade ci_testing, remove codeql (#714) * linting_codeql * yo * yo --- .github/workflows/codeql-analysis.yml | 58 --------------------------- .github/workflows/linting.yaml | 2 +- 2 files changed, 1 insertion(+), 59 deletions(-) delete mode 100644 .github/workflows/codeql-analysis.yml diff --git a/.github/workflows/codeql-analysis.yml b/.github/workflows/codeql-analysis.yml deleted file mode 100644 index 7b32c9ebf..000000000 --- a/.github/workflows/codeql-analysis.yml +++ /dev/null @@ -1,58 +0,0 @@ -# For most projects, this workflow file will not need changing; you simply need -# to commit it to your repository. -# -# You may wish to alter this file to override the set of languages analyzed, -# or to provide custom queries or build logic. -# -# ******** NOTE ******** -# We have attempted to detect the languages in your repository. Please check -# the `language` matrix defined below to confirm you have the correct set of -# supported CodeQL languages. -# -name: "CodeQL" - -on: - push: - branches: [main] - schedule: - - cron: "0 9 * * 1" # Every Monday at 09:00 (9:00 AM) - -jobs: - analyze: - name: Analyze - runs-on: ubuntu-latest - permissions: - actions: read - contents: read - security-events: write - - strategy: - fail-fast: false - matrix: - language: ["python"] - # CodeQL supports [ 'cpp', 'csharp', 'go', 'java', 'javascript', 'python', 'ruby' ] - # Learn more about CodeQL language support at https://git.io/codeql-language-support - - steps: - # the following step is required to avoid running out of space - - name: Maximize build space - run: | - df -h - sudo rm -rf /usr/share/dotnet - sudo rm -rf /opt/ghc - sudo rm -rf "/usr/local/share/boost" - sudo rm -rf "$AGENT_TOOLSDIRECTORY" - echo "Check space..." - df -h - - - name: Checkout repository - uses: actions/checkout@v3 - - name: Get composite run steps repository - uses: actions/checkout@v3 - with: - repository: mosaicml/ci-testing - ref: v0.0.2 - path: ./ci-testing - - uses: ./ci-testing/.github/actions/codeql-analysis - with: - language: ${{ matrix.language }} diff --git a/.github/workflows/linting.yaml b/.github/workflows/linting.yaml index 4e5bd0930..69b3ea6bc 100644 --- a/.github/workflows/linting.yaml +++ b/.github/workflows/linting.yaml @@ -32,7 +32,7 @@ jobs: uses: actions/checkout@v3 with: repository: mosaicml/ci-testing - ref: v0.0.2 + ref: v0.0.9 path: ./ci-testing - uses: ./ci-testing/.github/actions/code-quality with: From d9198adaeaae6a0980453be0c5f424247cd8ecd3 Mon Sep 17 00:00:00 2001 From: Xiaohan Zhang Date: Tue, 9 Jul 2024 12:19:18 -0700 Subject: [PATCH 027/145] Wrap with nparray (#719) --- tests/test_encodings.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_encodings.py b/tests/test_encodings.py index aac457519..36374545a 100644 --- a/tests/test_encodings.py +++ b/tests/test_encodings.py @@ -201,7 +201,7 @@ def test_jpegfile_encode_decode(self, mode: str): # Creating the (32 x 32) NumPy Array with random values size = {'RGB': (224, 224, 3), 'L': (28, 28)}[mode] - np_data = np.random.randint(255, size=size, dtype=np.uint8) + np_data = np.array(np.random.randint(255, size=size, dtype=np.uint8)) # Default image mode of PIL Image is 'I' img = Image.fromarray(np_data).convert(mode) From 6bbcf5affe5c4a1a5046ed44a8a0755af57e50e9 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 9 Jul 2024 13:48:13 -0700 Subject: [PATCH 028/145] Bump pydantic from 2.7.4 to 2.8.2 (#718) Bumps [pydantic](https://github.com/pydantic/pydantic) from 2.7.4 to 2.8.2. - [Release notes](https://github.com/pydantic/pydantic/releases) - [Changelog](https://github.com/pydantic/pydantic/blob/main/HISTORY.md) - [Commits](https://github.com/pydantic/pydantic/compare/v2.7.4...v2.8.2) --- updated-dependencies: - dependency-name: pydantic dependency-type: direct:development update-type: version-update:semver-minor ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Xiaohan Zhang --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index b32f6861b..f13532259 100644 --- a/setup.py +++ b/setup.py @@ -75,7 +75,7 @@ 'yamllint==1.35.1', 'moto>=4.0,<6', 'fastapi==0.111.0', - 'pydantic==2.7.4', + 'pydantic==2.8.2', 'uvicorn==0.30.1', 'pytest-split==0.9.0', ] From 2f7defa1e917a09dd2fb832bb1a54ec3c2469769 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 9 Jul 2024 14:04:47 -0700 Subject: [PATCH 029/145] Bump databricks-sdk from 0.28.0 to 0.29.0 (#715) Bumps [databricks-sdk](https://github.com/databricks/databricks-sdk-py) from 0.28.0 to 0.29.0. - [Release notes](https://github.com/databricks/databricks-sdk-py/releases) - [Changelog](https://github.com/databricks/databricks-sdk-py/blob/main/CHANGELOG.md) - [Commits](https://github.com/databricks/databricks-sdk-py/compare/v0.28.0...v0.29.0) --- updated-dependencies: - dependency-name: databricks-sdk dependency-type: direct:development update-type: version-update:semver-minor ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Xiaohan Zhang --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index f13532259..ae15b31cd 100644 --- a/setup.py +++ b/setup.py @@ -116,7 +116,7 @@ ] extra_deps['databricks'] = [ - 'databricks-sdk==0.28.0', + 'databricks-sdk==0.29.0', ] extra_deps['alipan'] = [ From 55e83ec265561aac2728daeefe9fede85d7dea40 Mon Sep 17 00:00:00 2001 From: Orion Weller <31665361+orionw@users.noreply.github.com> Date: Thu, 11 Jul 2024 11:28:40 -0700 Subject: [PATCH 030/145] Add HF File System Support to Streaming (#711) * init commit for hf * fix reqs * fix test * change error name; throw error; fix reqs * Update docs/source/how_to_guides/configure_cloud_storage_credentials.md Co-authored-by: Saaketh Narayan * fix test credential failure * cleanup * Remove duplicate `dbfs:` prefix from error message (#712) * fix typo in tests * docs * Try to figure out what is wrong with lint; test * fix os join * try to fix precommit * a (#713) * Upgrade ci_testing, remove codeql (#714) * linting_codeql * yo * yo * Wrap with nparray (#719) * Bump pydantic from 2.7.4 to 2.8.2 (#718) Bumps [pydantic](https://github.com/pydantic/pydantic) from 2.7.4 to 2.8.2. - [Release notes](https://github.com/pydantic/pydantic/releases) - [Changelog](https://github.com/pydantic/pydantic/blob/main/HISTORY.md) - [Commits](https://github.com/pydantic/pydantic/compare/v2.7.4...v2.8.2) --- updated-dependencies: - dependency-name: pydantic dependency-type: direct:development update-type: version-update:semver-minor ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Xiaohan Zhang * Bump databricks-sdk from 0.28.0 to 0.29.0 (#715) Bumps [databricks-sdk](https://github.com/databricks/databricks-sdk-py) from 0.28.0 to 0.29.0. - [Release notes](https://github.com/databricks/databricks-sdk-py/releases) - [Changelog](https://github.com/databricks/databricks-sdk-py/blob/main/CHANGELOG.md) - [Commits](https://github.com/databricks/databricks-sdk-py/compare/v0.28.0...v0.29.0) --- updated-dependencies: - dependency-name: databricks-sdk dependency-type: direct:development update-type: version-update:semver-minor ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Xiaohan Zhang * update docstring * add isort --------- Signed-off-by: dependabot[bot] Co-authored-by: Saaketh Narayan Co-authored-by: Vansh Singh Co-authored-by: bigning Co-authored-by: Saaketh Narayan Co-authored-by: Xiaohan Zhang Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .../configure_cloud_storage_credentials.md | 18 +++++ setup.py | 4 + streaming/base/storage/__init__.py | 17 +++-- streaming/base/storage/download.py | 27 +++++++ streaming/base/storage/upload.py | 76 +++++++++++++++++++ tests/conftest.py | 6 ++ tests/test_download.py | 20 ++++- tests/test_upload.py | 29 ++++++- 8 files changed, 187 insertions(+), 10 deletions(-) diff --git a/docs/source/how_to_guides/configure_cloud_storage_credentials.md b/docs/source/how_to_guides/configure_cloud_storage_credentials.md index 8431e5a9e..6c46679c3 100644 --- a/docs/source/how_to_guides/configure_cloud_storage_credentials.md +++ b/docs/source/how_to_guides/configure_cloud_storage_credentials.md @@ -7,6 +7,7 @@ Streaming dataset supports the following cloud storage providers to stream your - [Oracle Cloud Storage](#oracle-cloud-storage) - [Azure Blob Storage](#azure-blob-storage-and-azure-datalake) - [Databricks](#databricks) +- [Huggingface Datasets](#huggingface-datasets) ## Amazon S3 @@ -251,6 +252,23 @@ export AZURE_ACCOUNT_ACCESS_KEY='NN1KHxKKkj20ZO92EMiDQjx3wp2kZG4UUvfAGlgGWRn6sPR ``` ```` +## Huggingface Datasets + +To authenticate Huggingface Hub access, users must set their HuggingFace token ([HF_TOKEN](https://huggingface.co/docs/huggingface_hub/main/en/package_reference/environment_variables#hftoken)) in the run environment. See the [HF's documentation](https://huggingface.co/docs/huggingface_hub/guides/hf_file_system) on the URL format. + +Set the Huggingface token in the run environment as shown below + +````{tabs} +```{code-tab} py +import os +os.environ['HF_TOKEN'] = 'EXAMPLEFODNN7EXAMPLE' +``` + +```{code-tab} sh +export HF_TOKEN='EXAMPLEFODNN7EXAMPLE' +``` +```` + ## Databricks To authenticate Databricks access for both Unity Catalog and Databricks File System (DBFS), users must set their Databricks host (`DATABRICKS_HOST`) and access token (`DATABRICKS_TOKEN`) in the run environment. diff --git a/setup.py b/setup.py index ae15b31cd..4c3255c1c 100644 --- a/setup.py +++ b/setup.py @@ -123,6 +123,10 @@ 'AliPCS-Py>=0.8,<1', ] +extra_deps['hf'] = [ + 'huggingface_hub>=0.23.4,<0.24', +] + extra_deps['testing'] = [ 'mosaicml-cli>=0.5.25,<0.7', ] diff --git a/streaming/base/storage/__init__.py b/streaming/base/storage/__init__.py index e9653db9d..bfe9ce6f5 100644 --- a/streaming/base/storage/__init__.py +++ b/streaming/base/storage/__init__.py @@ -2,15 +2,14 @@ # SPDX-License-Identifier: Apache-2.0 """Base module for downloading/uploading files from/to cloud storage.""" - -from streaming.base.storage.download import (download_file, download_from_alipan, - download_from_azure, download_from_azure_datalake, - download_from_databricks_unity_catalog, - download_from_dbfs, download_from_gcs, - download_from_local, download_from_oci, - download_from_s3, download_from_sftp) +# isort: off +from streaming.base.storage.download import ( + download_file, download_from_alipan, download_from_azure, download_from_azure_datalake, + download_from_databricks_unity_catalog, download_from_dbfs, download_from_gcs, + download_from_hf, download_from_local, download_from_oci, download_from_s3, download_from_sftp) from streaming.base.storage.upload import (AzureDataLakeUploader, AzureUploader, CloudUploader, - GCSUploader, LocalUploader, OCIUploader, S3Uploader) + GCSUploader, HFUploader, LocalUploader, OCIUploader, + S3Uploader) __all__ = [ 'download_file', @@ -21,6 +20,7 @@ 'LocalUploader', 'AzureUploader', 'AzureDataLakeUploader', + 'HFUploader', 'download_from_s3', 'download_from_sftp', 'download_from_gcs', @@ -31,4 +31,5 @@ 'download_from_dbfs', 'download_from_alipan', 'download_from_local', + 'download_from_hf', ] diff --git a/streaming/base/storage/download.py b/streaming/base/storage/download.py index 9780836dc..e5392c3c6 100644 --- a/streaming/base/storage/download.py +++ b/streaming/base/storage/download.py @@ -19,6 +19,7 @@ 'download_from_oci', 'download_from_azure', 'download_from_azure_datalake', + 'download_from_hf', 'download_from_databricks_unity_catalog', 'download_from_dbfs', 'download_from_alipan', @@ -275,6 +276,30 @@ def download_from_oci(remote: str, local: str) -> None: os.rename(local_tmp, local) +def download_from_hf(remote: str, local: str) -> None: + """Download a file from remote Hugging Face to local. + + Args: + remote (str): Remote path (Hugging Face). + local (str): Local path (local filesystem). + """ + from huggingface_hub import hf_hub_download + + obj = urllib.parse.urlparse(remote) + if obj.scheme != 'hf': + raise ValueError(f'Expected remote path to start with `hf://`, got {remote}.') + + _, _, _, repo_org, repo_name, path = remote.split('/', 5) + local_dirname = os.path.dirname(local) + hf_hub_download(repo_id=f'{repo_org}/{repo_name}', + filename=path, + repo_type='dataset', + local_dir=local_dirname) + + downloaded_name = os.path.join(local_dirname, path) + os.rename(downloaded_name, local) + + def download_from_azure(remote: str, local: str) -> None: """Download a file from remote Microsoft Azure to local. @@ -514,6 +539,8 @@ def download_file(remote: Optional[str], local: str, timeout: float): download_from_gcs(remote, local) elif remote.startswith('oci://'): download_from_oci(remote, local) + elif remote.startswith('hf://'): + download_from_hf(remote, local) elif remote.startswith('azure://'): download_from_azure(remote, local) elif remote.startswith('azure-dl://'): diff --git a/streaming/base/storage/upload.py b/streaming/base/storage/upload.py index 6a8c67e3c..1c296bb89 100644 --- a/streaming/base/storage/upload.py +++ b/streaming/base/storage/upload.py @@ -24,6 +24,7 @@ 'S3Uploader', 'GCSUploader', 'OCIUploader', + 'HFUploader', 'AzureUploader', 'DatabricksUnityCatalogUploader', 'DBFSUploader', @@ -37,6 +38,7 @@ 's3': 'S3Uploader', 'gs': 'GCSUploader', 'oci': 'OCIUploader', + 'hf': 'HFUploader', 'azure': 'AzureUploader', 'azure-dl': 'AzureDataLakeUploader', 'dbfs:/Volumes': 'DatabricksUnityCatalogUploader', @@ -616,6 +618,80 @@ def list_objects(self, prefix: Optional[str] = None) -> Optional[List[str]]: return [] +class HFUploader(CloudUploader): + """Upload file from local machine to a Huggingface Dataset. + + Args: + out (str): Output dataset directory to save shard files. + + 1. If ``out`` is a local directory, shard files are saved locally. + 2. If ``out`` is a remote directory then the shard files are uploaded to the + remote location. + keep_local (bool): If the dataset is uploaded, whether to keep the local dataset + shard file or remove it after uploading. Defaults to ``False``. + progress_bar (bool): Display TQDM progress bars for uploading output dataset files to + a remote location. Default to ``False``. + retry (int): Number of times to retry uploading a file. Defaults to ``2``. + exist_ok (bool): When exist_ok = False, raise error if the local part of ``out`` already + exists and has contents. Defaults to ``False``. + """ + + def __init__(self, + out: str, + keep_local: bool = False, + progress_bar: bool = False, + retry: int = 2, + exist_ok: bool = False) -> None: + super().__init__(out, keep_local, progress_bar, retry, exist_ok) + + import huggingface_hub + self.api = huggingface_hub.HfApi() + self.fs = huggingface_hub.HfFileSystem(token=os.environ.get('HF_TOKEN', None)) + + obj = urllib.parse.urlparse(out) + if obj.scheme != 'hf': + raise ValueError(f'Expected remote path to start with `hf://`, got {out}.') + + _, _, _, self.repo_org, self.repo_name, self.path = out.split('/', 5) + self.dataset_id = os.path.join(self.repo_org, self.repo_name) + self.check_dataset_exists() # pyright: ignore + + def upload_file(self, filename: str): + """Upload file from local instance to HF. + + Args: + filename (str): File to upload. + """ + + @retry(num_attempts=self.retry) + def _upload_file(): + local_filename = filename + local_filename = local_filename.replace('\\', '/') + remote_filename = os.path.join('datasets', self.dataset_id, filename) + remote_filename = remote_filename.replace('\\', '/') + logger.debug(f'Uploading to {remote_filename}') + + with self.fs.open(remote_filename, 'wb') as f: + with open(local_filename, 'rb') as data: + f.write(data.read()) + + _upload_file() + + def check_dataset_exists(self): + """Raise an exception if the dataset does not exist. + + Raises: + error: Dataset does not exist. + """ + import huggingface_hub + try: + _ = list(huggingface_hub.list_repo_tree(self.dataset_id, repo_type='dataset')) + except Exception: + raise FileNotFoundError( + f'The HF dataset {self.dataset_id} could not be found. Please make sure ' + + f'that the dataset exists and you have the correct access permissions.') + + class AzureUploader(CloudUploader): """Upload file from local machine to Microsoft Azure bucket. diff --git a/tests/conftest.py b/tests/conftest.py index ac7844539..3b8a416c5 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -51,6 +51,12 @@ def aws_credentials(): os.environ['AWS_SESSION_TOKEN'] = 'testing' +@pytest.fixture(scope='class', autouse=True) +def hf_credentials(): + """Mocked HF Credentials.""" + os.environ['HF_TOKEN'] = 'testing' + + @pytest.fixture() def s3_client(aws_credentials: Any): with mock_aws(): diff --git a/tests/test_download.py b/tests/test_download.py index 8d9a0c1b2..50bee57d1 100644 --- a/tests/test_download.py +++ b/tests/test_download.py @@ -14,7 +14,8 @@ download_from_azure_datalake, download_from_databricks_unity_catalog, download_from_dbfs, download_from_gcs, - download_from_local, download_from_s3) + download_from_hf, download_from_local, + download_from_s3) from tests.conftest import GCS_URL, MY_BUCKET, R2_URL MY_PREFIX = 'train' @@ -47,6 +48,15 @@ def test_invalid_cloud_prefix(self, remote_local_file: Any): download_from_azure_datalake(mock_remote_filepath, mock_local_filepath) +class TestHFClient: + + @pytest.mark.usefixtures('remote_local_file') + def test_invalid_cloud_prefix(self, remote_local_file: Any): + with pytest.raises(ValueError): + mock_remote_filepath, mock_local_filepath = remote_local_file(cloud_prefix='hf://') + download_from_hf(mock_remote_filepath, mock_local_filepath) + + class TestS3Client: @pytest.mark.usefixtures('s3_client', 's3_test', 'remote_local_file') @@ -183,6 +193,14 @@ def test_download_from_gcs_gets_called(self, mocked_requests: Mock, remote_local mocked_requests.assert_called_once() mocked_requests.assert_called_once_with(mock_remote_filepath, mock_local_filepath) + @patch('streaming.base.storage.download.download_from_hf') + @pytest.mark.usefixtures('remote_local_file') + def test_download_from_hf_gets_called(self, mocked_requests: Mock, remote_local_file: Any): + mock_remote_filepath, mock_local_filepath = remote_local_file(cloud_prefix='hf://') + download_file(mock_remote_filepath, mock_local_filepath, 60) + mocked_requests.assert_called_once() + mocked_requests.assert_called_once_with(mock_remote_filepath, mock_local_filepath) + @patch('streaming.base.storage.download.download_from_azure') @pytest.mark.usefixtures('remote_local_file') def test_download_from_azure_gets_called(self, mocked_requests: Mock, remote_local_file: Any): diff --git a/tests/test_upload.py b/tests/test_upload.py index 455b6b8c4..b280ac968 100644 --- a/tests/test_upload.py +++ b/tests/test_upload.py @@ -14,7 +14,7 @@ from streaming.base.storage.upload import (AlipanUploader, AzureDataLakeUploader, AzureUploader, CloudUploader, DatabricksUnityCatalogUploader, DBFSUploader, GCSAuthentication, GCSUploader, - LocalUploader, S3Uploader) + HFUploader, LocalUploader, S3Uploader) from tests.conftest import MY_BUCKET, R2_URL MY_PREFIX = 'train' @@ -425,6 +425,33 @@ def test_local_directory_is_empty(self, local_remote_dir: Tuple[str, str]): _ = AzureDataLakeUploader(out=local) +class TestHFUploader: + + @patch('streaming.base.storage.upload.HFUploader.check_dataset_exists') + @pytest.mark.usefixtures('hf_credentials') + @pytest.mark.parametrize('out', ['hf://datasets/org_name/repo_name/path']) + def test_instantiation(self, mocked_requests: Mock, out: Any): + mocked_requests.side_effect = None + _ = HFUploader(out=out) + if not isinstance(out, str): + shutil.rmtree(out[0], ignore_errors=True) + + @pytest.mark.parametrize('out', ['ss4://container/dir']) + def test_invalid_remote_str(self, out: str): + with pytest.raises(ValueError, match=f'Invalid Cloud provider prefix.*'): + _ = HFUploader(out=out) + + def test_local_directory_is_empty(self, local_remote_dir: Tuple[str, str]): + with pytest.raises(FileExistsError, match=f'Directory is not empty.*'): + local, _ = local_remote_dir + os.makedirs(local, exist_ok=True) + local_file_path = os.path.join(local, 'file.txt') + # Creating an empty file at specified location + with open(local_file_path, 'w') as _: + pass + _ = HFUploader(out=local) + + class TestDatabricksUnityCatalogUploader: @patch('streaming.base.storage.upload.DatabricksUploader._create_workspace_client') From 5f939c9057b041f10342dfc5744d2d3880e3f14b Mon Sep 17 00:00:00 2001 From: bigning Date: Fri, 12 Jul 2024 16:35:50 -0700 Subject: [PATCH 031/145] Improve error message on non-0 rank when index file download failed (#723) * a * lint --- streaming/base/stream.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/streaming/base/stream.py b/streaming/base/stream.py index 1fd9fc9ff..15e288a53 100644 --- a/streaming/base/stream.py +++ b/streaming/base/stream.py @@ -454,8 +454,8 @@ def get_shards(self, world: World, allow_unsafe_types: bool) -> List[Reader]: wait_for_file_to_exist( filename, TICK, self.download_timeout, f'Index file {os.path.join(self.remote or "", self.split or "", basename)} ' + - f'-> {filename} took too long to download. Either increase the ' + - f'`download_timeout` value or check the other traceback.') + f'-> {filename} took too long to download or failed to download. Either increase the ' + + f'`download_timeout` value or check the local rank 0 traceback.') # Load the index. try: From 083191f83ccf62330ba38153d02d0c0333a53d7e Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 29 Jul 2024 14:26:38 +0000 Subject: [PATCH 032/145] Bump pytest from 8.2.2 to 8.3.2 (#735) --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 4c3255c1c..38ede2adb 100644 --- a/setup.py +++ b/setup.py @@ -68,7 +68,7 @@ 'docformatter>=1.4', 'jupyter==1.0.0', 'pre-commit>=2.18.1,<4', - 'pytest==8.2.2', + 'pytest==8.3.2', 'pytest_codeblocks==0.17.0', 'pytest-cov>=4,<6', 'toml==0.10.2', From 551f360496bbedce90c08ca3dfdde419471432f5 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 29 Jul 2024 14:40:16 +0000 Subject: [PATCH 033/145] Bump uvicorn from 0.30.1 to 0.30.3 (#730) Bumps [uvicorn](https://github.com/encode/uvicorn) from 0.30.1 to 0.30.3. - [Release notes](https://github.com/encode/uvicorn/releases) - [Changelog](https://github.com/encode/uvicorn/blob/master/CHANGELOG.md) - [Commits](https://github.com/encode/uvicorn/compare/0.30.1...0.30.3) --- updated-dependencies: - dependency-name: uvicorn dependency-type: direct:development update-type: version-update:semver-patch ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Saaketh Narayan --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 38ede2adb..d9c15d526 100644 --- a/setup.py +++ b/setup.py @@ -76,7 +76,7 @@ 'moto>=4.0,<6', 'fastapi==0.111.0', 'pydantic==2.8.2', - 'uvicorn==0.30.1', + 'uvicorn==0.30.3', 'pytest-split==0.9.0', ] From 4f3bc22823ea308b2ff397acf1cfcb2df919582d Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 29 Jul 2024 14:55:48 +0000 Subject: [PATCH 034/145] Bump fastapi from 0.111.0 to 0.111.1 (#724) Bumps [fastapi](https://github.com/tiangolo/fastapi) from 0.111.0 to 0.111.1. - [Release notes](https://github.com/tiangolo/fastapi/releases) - [Commits](https://github.com/tiangolo/fastapi/compare/0.111.0...0.111.1) --- updated-dependencies: - dependency-name: fastapi dependency-type: direct:development update-type: version-update:semver-patch ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index d9c15d526..457eb74b9 100644 --- a/setup.py +++ b/setup.py @@ -74,7 +74,7 @@ 'toml==0.10.2', 'yamllint==1.35.1', 'moto>=4.0,<6', - 'fastapi==0.111.0', + 'fastapi==0.111.1', 'pydantic==2.8.2', 'uvicorn==0.30.3', 'pytest-split==0.9.0', From b14cd7ad864d7fc5ed8d07845c3a58fc4f6f4467 Mon Sep 17 00:00:00 2001 From: Mihir Patel Date: Tue, 30 Jul 2024 09:48:57 -0700 Subject: [PATCH 035/145] Update _version.py (#738) --- streaming/_version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/streaming/_version.py b/streaming/_version.py index 6999045b3..43c88d2a3 100644 --- a/streaming/_version.py +++ b/streaming/_version.py @@ -3,4 +3,4 @@ """The Streaming Version.""" -__version__ = '0.7.6' +__version__ = '0.8.0' From 54b6801caa6c16b91072f08d252db0dc6b4bc672 Mon Sep 17 00:00:00 2001 From: Eitan Turok <150733043+eitanturok@users.noreply.github.com> Date: Tue, 30 Jul 2024 15:28:02 -0400 Subject: [PATCH 036/145] Make Pytest log in color in Github Action (#739) --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index e648880ac..9878eba0f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -101,7 +101,7 @@ reportUnusedCoroutine = "error" # Pytest [tool.pytest.ini_options] # By default, do not run remote tests -addopts = "--cov=streaming --cov-fail-under=50 --codeblocks --strict-markers -m 'not daily and not remote' -ra --tb=native" +addopts = "--cov=streaming --cov-fail-under=50 --codeblocks --strict-markers -m 'not daily and not remote' -ra --tb=native --color=yes" markers = [ # For distributed testing From 3a6a5490678a2efa028ed96ba9b8813fba8687eb Mon Sep 17 00:00:00 2001 From: jaehwana2z <165435393+jaehwana2z@users.noreply.github.com> Date: Thu, 1 Aug 2024 18:43:47 -0700 Subject: [PATCH 037/145] fix azure container name and blob name in download_from_azure (#733) Co-authored-by: Ubuntu Co-authored-by: Saaketh Narayan --- streaming/base/storage/download.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/streaming/base/storage/download.py b/streaming/base/storage/download.py index e5392c3c6..bf4f0e33b 100644 --- a/streaming/base/storage/download.py +++ b/streaming/base/storage/download.py @@ -320,7 +320,10 @@ def download_from_azure(remote: str, local: str) -> None: account_url=f"https://{os.environ['AZURE_ACCOUNT_NAME']}.blob.core.windows.net", credential=os.environ['AZURE_ACCOUNT_ACCESS_KEY']) try: - blob_client = service.get_blob_client(container=obj.netloc, blob=obj.path.lstrip('/')) + file_path = obj.path.lstrip('/').split('/') + container_name = file_path[0] + blob_name = os.path.join(*file_path[1:]) + blob_client = service.get_blob_client(container=container_name, blob=blob_name) local_tmp = local + '.tmp' with open(local_tmp, 'wb') as my_blob: blob_data = blob_client.download_blob() From 2580d40ca173fe54ca5c23c5bbaa0ecf8b7ba1c9 Mon Sep 17 00:00:00 2001 From: Xiaohan Zhang Date: Wed, 22 May 2024 23:17:43 -0700 Subject: [PATCH 038/145] update --- setup.py | 1 + streaming/base/dataset.py | 14 ++- streaming/base/stream.py | 179 +++++++++++++++++++++++++++++++++ tests/test_streaming_remote.py | 90 ++++++----------- 4 files changed, 222 insertions(+), 62 deletions(-) diff --git a/setup.py b/setup.py index 457eb74b9..92c9aa012 100644 --- a/setup.py +++ b/setup.py @@ -58,6 +58,7 @@ 'azure-storage-blob>=12.0.0,<13', 'azure-storage-file-datalake>=12.11.0,<13', 'azure-identity>=1.13.0', + 'databricks-connect>=14.3.0', ] extra_deps = {} diff --git a/streaming/base/dataset.py b/streaming/base/dataset.py index 292405528..753ffc2a4 100644 --- a/streaming/base/dataset.py +++ b/streaming/base/dataset.py @@ -34,7 +34,7 @@ from streaming.base.shared import (SharedArray, SharedBarrier, SharedMemory, SharedScalar, _get_path, get_shm_prefix) from streaming.base.spanner import Spanner -from streaming.base.stream import Stream +from streaming.base.stream import Stream, DeltaStream from streaming.base.util import bytes_to_int, number_abbrev_to_int from streaming.base.world import World @@ -443,6 +443,15 @@ def __init__(self, } for stream in streams: stream.apply_default(default) + elif remote is not None and remote.startswith('SELECT'): + default = DeltaStream(remote=remote, + local=local, + split=split, + download_retry=download_retry, + download_timeout=download_timeout, + validate_hash=validate_hash, + keep_zip=keep_zip) + streams = [default] else: default = Stream(remote=remote, local=local, @@ -507,6 +516,8 @@ def __init__(self, # Build the shard index (for partitioning and mapping samples to shards). self.samples_per_shard = np.array([shard.samples for shard in self.shards], np.int64) + print('I am here 5.1, samples_per_shard = ') + print(self.samples_per_shard) self.sample_offset_per_shard = self.samples_per_shard.cumsum() - self.samples_per_shard self.spanner = Spanner(self.samples_per_shard) @@ -1247,6 +1258,7 @@ def get_item(self, sample_id: int, retry: int = 7) -> Any: raise RuntimeError('Background thread failed. Check other traceback.') # Locate the shard and sample offset within that shard where the sample lives. shard_id, shard_sample_id = self.spanner[sample_id] + #print('I am here 5.2', shard_id, shard_sample_id) shard = self.shards[shard_id] sample = None diff --git a/streaming/base/stream.py b/streaming/base/stream.py index 15e288a53..104774aef 100644 --- a/streaming/base/stream.py +++ b/streaming/base/stream.py @@ -22,6 +22,10 @@ from streaming.base.util import retry, wait_for_file_to_exist from streaming.base.world import World +import pyarrow as pa +import requests +from tempfile import TemporaryDirectory + class Stream: """A dataset, or sub-dataset if mixing, from which we stream/cache samples. @@ -505,3 +509,178 @@ def get_index_size(self) -> int: """ filename = os.path.join(self.local, self.split, get_index_basename()) return os.stat(filename).st_size + + +class DeltaStream(Stream): + + def __init__(self, + remote: Optional[str] = None, + local: Optional[str] = None, + split: Optional[str] = None, + proportion: Optional[float] = None, + repeat: Optional[float] = None, + choose: Optional[int] = None, + download_retry: Optional[int] = None, + download_timeout: Optional[float] = None, + validate_hash: Optional[str] = None, + keep_zip: Optional[bool] = None) -> None: + super().__init__(remote=remote, + local=local, + split=split, + proportion=proportion, + repeat=repeat, + choose=choose, + download_retry=download_retry, + download_timeout=download_timeout, + validate_hash=validate_hash, + keep_zip=keep_zip) + + self.url_to_basename= {} + self.basename_to_url={} + + def generate_unique_basename(self, url: str, index: int) -> str: + """Generate a unique basename for the file path from the URL.""" + hash_object = hashlib.md5(url.encode()) + hex_dig = hash_object.hexdigest() + # basename = f"{hex_dig[:3]}/shard.{int(hex_dig, 16) % 100000:05d}.mds" + basename = '.'.join(['shard', f'{index:05}', 'mds']) + self.url_to_basename[url] = basename + self.basename_to_url[basename] = url + return basename + + def get_shards(self, world: World, allow_unsafe_types: bool) -> List[Reader]: + """Load this Stream's index, retrieving its shard readers. + + Args: + world (World): Distributed context. + allow_unsafe_types (bool): If a shard contains Pickle, which allows arbitrary code + execution during deserialization, whether to keep going if ``True`` or raise an + error. + + Returns: + `List[Reader]: Shard readers. + """ + # Prepare cloudfetch + from databricks.connect import DatabricksSession + from databricks.sdk import WorkspaceClient + + w = WorkspaceClient() + cluster_id = "0201-234512-tcp9nfat" + + print('I am here 1') + sparkSession = DatabricksSession.builder.remote( + host=w.config.host, + token=w.config.token, + cluster_id=cluster_id).getOrCreate() + + print('I am here 2') + df = sparkSession.sql(self.remote) + print('I am here 2.0') + query = df._plan.to_proto(df._session.client) # pyright: ignore + print('I am here 2.1') + schema, cloudfetch_results = df._session.client.experimental_to_cloudfetch(query, "arrow", compression=False) # pyright: ignore + + # Local leader prepares the index file based on cloudfetch results + print('I am here 3') + basename = get_index_basename() + filename = os.path.join(self.local, self.split, basename) + + print('schema = ', schema) + self.columns = {'text': 'str'} + + print('I am here 4', len(cloudfetch_results)) + if world.is_local_leader: + + metadata = { + "version": 2, + "shards": [] + } + + for index, result in enumerate(cloudfetch_results): + shard = { + "column_encodings": ["str"], + "column_names": ["tokenized_example"], + "column_sizes": [None], + "compression": None, + "format": "mds", + "hashes": ["sha1"], + "raw_data": { + "basename": self.generate_unique_basename(result.url, index), + "bytes": result.uncompressed_size, + "hashes": {} + }, + "samples": result.row_count, + "size_limit": 67108864, + "version": 2, + "zip_data": None + } + metadata["shards"].append(shard) + + print('metadata = ') + print(metadata) + + with open(filename, 'w') as f: + json.dump(metadata, f, indent=4) + + else: + wait_for_file_to_exist( + filename, TICK, self.download_timeout, + f'Index file {os.path.join(self.remote or "", self.split or "", basename)} ' + + f'-> {filename} took too long to download. Either increase the ' + + f'`download_timeout` value or check the other traceback.') + + # Load the index. + try: + obj = json.load(open(filename)) + except json.decoder.JSONDecodeError as error: + error.args = (f'Index file at {filename} is empty or corrupted. ' + error.args[0],) + raise error + + # Version check. + if obj['version'] != 2: + raise ValueError(f'Unsupported streaming data version: {obj["version"]}. ' + + f'Expected version 2.') + + # Initialize shard readers according to the loaded info. + shards = [] + for info in obj['shards']: + shard = reader_from_json(self.local, self.split, info) + shard.validate(allow_unsafe_types) + shards.append(shard) + + print('I am here 4.1, shard.samples = ', shard.samples) + + return shards + + def _download_file(self, from_basename: str, to_basename: Optional[str] = None) -> str: + """Safely download a file from remote to local cache. + + Args: + from_basename (str): Source basename. + to_basename (str, optional): Destination basename, if different. + + Returns: + str: Local cache filename. + """ + def fetch_and_convert(cloud_fetch_url: str, local_shard_path: str): + from streaming import MDSWriter + samples = pa.ipc.open_stream(requests.get(cloud_fetch_url).content).read_all().to_pylist() + + with TemporaryDirectory() as temp_dir: + with MDSWriter(columns=self.columns, out=temp_dir, size_limit=None) as out: + for sample in samples: + out.write(sample) + temp_mds_filename = os.path.join(temp_dir, 'shard.00000.mds') + os.rename(temp_mds_filename, local_shard_path) + + print('from_basename = ', from_basename) + cloud_fetch_url = self.basename_to_url[from_basename] + local = os.path.join(self.local, self.split, from_basename) + + # Attempt to download, possibly repeating on failure. + retry(num_attempts=self.download_retry)( + lambda: fetch_and_convert(cloud_fetch_url, local))() + + return local + + diff --git a/tests/test_streaming_remote.py b/tests/test_streaming_remote.py index 1c1f7e10c..a75fe226e 100644 --- a/tests/test_streaming_remote.py +++ b/tests/test_streaming_remote.py @@ -20,69 +20,39 @@ def get_dataset(name: str, other_kwargs: Optional[Dict[str, Any]] = None) -> Tuple[int, StreamingDataset]: other_kwargs = {} if other_kwargs is None else other_kwargs dataset_map = { - 'ade20k': { - 'remote': 's3://mosaicml-internal-dataset-ade20k/mds/2/', + 'cpt': { + 'remote': 'dbfs:/Volumes/main/mosaic_hackathon/managed-volume/CPT/mds_data_11Jan24_3/', 'num_samples': { 'train': 20206, - 'val': 2000, + 'val': 0, }, - 'class': StreamingADE20K, - 'kwargs': {}, - }, - 'imagenet1k': { - 'remote': 's3://mosaicml-internal-dataset-imagenet1k/mds/2/', - 'num_samples': { - 'train': 1281167, - 'val': 50000, - }, - 'class': StreamingImageNet, - 'kwargs': {}, - }, - 'coco': { - 'remote': 's3://mosaicml-internal-dataset-coco/mds/2/', - 'num_samples': { - 'train': 117266, - 'val': 4952, - }, - 'class': StreamingCOCO, + 'class': StreamingDataset, 'kwargs': {}, }, - 'c4': { - 'remote': 's3://mosaicml-internal-dataset-c4/mds/2/', + 'dummy_table': { + 'remote': 'SELECT * FROM main.streaming.dummy_cpt_table', 'num_samples': { - 'train': 364868892, - 'val': 364608, - }, - 'class': StreamingC4, - 'kwargs': { - 'tokenizer_name': 'bert-base-uncased', - 'max_seq_len': 512, - 'group_method': 'truncate' - }, - }, - 'cifar10': { - 'remote': 's3://mosaicml-internal-dataset-cifar10/mds/2/', - 'num_samples': { - 'train': 50000, - 'val': 10000, + 'train': 20206, + 'val': 0, }, - 'class': StreamingCIFAR10, + 'class': StreamingDataset, 'kwargs': {}, }, - 'test_streaming_upload': { - 'remote': 's3://streaming-upload-test-bucket/', + 'random_cpt_table': { + 'remote': 'SELECT text FROM main.streaming.random_cpt_table', 'num_samples': { - 'all': 0, + 'train': 20206, + 'val': 0, }, 'class': StreamingDataset, 'kwargs': {}, - } + }, } - if name not in dataset_map and split not in dataset_map[name]['num_samples'][split]: - raise ValueError('Could not load dataset with name={name} and split={split}') + #if name not in dataset_map and split not in dataset_map[name]['num_samples'][split]: + # raise ValueError('Could not load dataset with name={name} and split={split}') d = dataset_map[name] - expected_samples = d['num_samples'][split] + expected_samples = 1 # d['num_samples'][split] remote = d['remote'] kwargs = {**d['kwargs'], **other_kwargs} dataset = d['class'](local=local, @@ -94,23 +64,14 @@ def get_dataset(name: str, return (expected_samples, dataset) -@pytest.mark.remote -@pytest.mark.parametrize('name', [ - 'ade20k', - 'imagenet1k', - 'coco', - 'cifar10', - 'c4', -]) -@pytest.mark.parametrize('split', ['val']) -def test_streaming_remote_dataset(tmp_path: pathlib.Path, name: str, split: str) -> None: +def test_streaming_remote_dataset(name: str, split: str) -> None: # Build StreamingDataset build_start = time.time() expected_samples, dataset = get_dataset(name=name, - local=str(tmp_path), + local=f'/tmp/test_delta_05May1029', split=split, shuffle=False, - batch_size=None) + batch_size=16) build_end = time.time() build_dur = build_end - build_start print('Built dataset') @@ -121,7 +82,7 @@ def test_streaming_remote_dataset(tmp_path: pathlib.Path, name: str, split: str) for _ in dataset: rcvd_samples += 1 - if (rcvd_samples % 1000 == 0): + if (rcvd_samples % 100 == 0): print(f'samples read: {rcvd_samples}') iter_end = time.time() @@ -129,8 +90,15 @@ def test_streaming_remote_dataset(tmp_path: pathlib.Path, name: str, split: str) samples_per_sec = rcvd_samples / iter_dur # Print debug info + print(f'received {rcvd_samples} samples') print(f'build_dur={build_dur:.2f}s, iter_dur={iter_dur:.2f}, ' + f'samples_per_sec={samples_per_sec:.2f}') # Test all samples arrived - assert rcvd_samples == expected_samples + assert rcvd_samples >= expected_samples + + +if __name__ == "__main__": +# test_streaming_remote_dataset(name = 'cpt', split=None) + # test_streaming_remote_dataset(name = 'dummy_table', split=None) + test_streaming_remote_dataset(name = 'random_cpt_table', split=None) From dd8f1b980c21401908bb45e2218d578edecd94c9 Mon Sep 17 00:00:00 2001 From: Xiaohan Zhang Date: Thu, 23 May 2024 21:26:52 +0000 Subject: [PATCH 039/145] update --- tests/test_streaming_remote.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/tests/test_streaming_remote.py b/tests/test_streaming_remote.py index a75fe226e..2f2f02afd 100644 --- a/tests/test_streaming_remote.py +++ b/tests/test_streaming_remote.py @@ -20,8 +20,9 @@ def get_dataset(name: str, other_kwargs: Optional[Dict[str, Any]] = None) -> Tuple[int, StreamingDataset]: other_kwargs = {} if other_kwargs is None else other_kwargs dataset_map = { - 'cpt': { - 'remote': 'dbfs:/Volumes/main/mosaic_hackathon/managed-volume/CPT/mds_data_11Jan24_3/', + 'refinedweb': { + 'local': f'/tmp/test_refinedweb_05May1029', + 'remote': 'dbfs:/Volumes/main/mosaic_hackathon/managed-volume/mds/refinedweb/', 'num_samples': { 'train': 20206, 'val': 0, @@ -30,6 +31,7 @@ def get_dataset(name: str, 'kwargs': {}, }, 'dummy_table': { + 'local': f'/tmp/test_dummy_table_05May1029', 'remote': 'SELECT * FROM main.streaming.dummy_cpt_table', 'num_samples': { 'train': 20206, @@ -39,6 +41,7 @@ def get_dataset(name: str, 'kwargs': {}, }, 'random_cpt_table': { + 'local': f'/tmp/test_random_cpt_table_05May1029', 'remote': 'SELECT text FROM main.streaming.random_cpt_table', 'num_samples': { 'train': 20206, @@ -53,6 +56,7 @@ def get_dataset(name: str, d = dataset_map[name] expected_samples = 1 # d['num_samples'][split] + local = d['local'] remote = d['remote'] kwargs = {**d['kwargs'], **other_kwargs} dataset = d['class'](local=local, @@ -99,6 +103,6 @@ def test_streaming_remote_dataset(name: str, split: str) -> None: if __name__ == "__main__": -# test_streaming_remote_dataset(name = 'cpt', split=None) + test_streaming_remote_dataset(name = 'refinedweb', split=None) # test_streaming_remote_dataset(name = 'dummy_table', split=None) - test_streaming_remote_dataset(name = 'random_cpt_table', split=None) +# test_streaming_remote_dataset(name = 'random_cpt_table', split=None) From bcb9429a9102275eb55520e7cbbf0f48e8170770 Mon Sep 17 00:00:00 2001 From: Xiaohan Zhang Date: Thu, 23 May 2024 22:13:24 +0000 Subject: [PATCH 040/145] update --- streaming/base/stream.py | 17 ++++++++++++++--- tests/test_streaming_remote.py | 32 ++++++++++++++++---------------- 2 files changed, 30 insertions(+), 19 deletions(-) diff --git a/streaming/base/stream.py b/streaming/base/stream.py index 104774aef..e946bc462 100644 --- a/streaming/base/stream.py +++ b/streaming/base/stream.py @@ -563,6 +563,7 @@ def get_shards(self, world: World, allow_unsafe_types: bool) -> List[Reader]: # Prepare cloudfetch from databricks.connect import DatabricksSession from databricks.sdk import WorkspaceClient + from streaming.base.converters import infer_dataframe_schema w = WorkspaceClient() cluster_id = "0201-234512-tcp9nfat" @@ -586,9 +587,19 @@ def get_shards(self, world: World, allow_unsafe_types: bool) -> List[Reader]: filename = os.path.join(self.local, self.split, basename) print('schema = ', schema) - self.columns = {'text': 'str'} + self.columns = infer_dataframe_schema(df, None) + column_names = [] + column_encodings = [] + for k, v in self.columns.items(): + column_names.append(k) + column_encodings.append(v) + #self.columns = {'text': 'str'} + print('inferred columns = ', self.columns) print('I am here 4', len(cloudfetch_results)) + +# raise RuntimeError("break") + if world.is_local_leader: metadata = { @@ -598,8 +609,8 @@ def get_shards(self, world: World, allow_unsafe_types: bool) -> List[Reader]: for index, result in enumerate(cloudfetch_results): shard = { - "column_encodings": ["str"], - "column_names": ["tokenized_example"], + "column_encodings": column_encodings, + "column_names": column_names, "column_sizes": [None], "compression": None, "format": "mds", diff --git a/tests/test_streaming_remote.py b/tests/test_streaming_remote.py index 2f2f02afd..a94d4f677 100644 --- a/tests/test_streaming_remote.py +++ b/tests/test_streaming_remote.py @@ -23,30 +23,28 @@ def get_dataset(name: str, 'refinedweb': { 'local': f'/tmp/test_refinedweb_05May1029', 'remote': 'dbfs:/Volumes/main/mosaic_hackathon/managed-volume/mds/refinedweb/', - 'num_samples': { - 'train': 20206, - 'val': 0, - }, + 'num_samples': 20206, 'class': StreamingDataset, 'kwargs': {}, }, 'dummy_table': { 'local': f'/tmp/test_dummy_table_05May1029', 'remote': 'SELECT * FROM main.streaming.dummy_cpt_table', - 'num_samples': { - 'train': 20206, - 'val': 0, - }, + 'num_samples': 20206, 'class': StreamingDataset, 'kwargs': {}, }, 'random_cpt_table': { 'local': f'/tmp/test_random_cpt_table_05May1029', 'remote': 'SELECT text FROM main.streaming.random_cpt_table', - 'num_samples': { - 'train': 20206, - 'val': 0, - }, + 'num_samples': 100000, + 'class': StreamingDataset, + 'kwargs': {}, + }, + 'random_large_table': { + 'local': f'/tmp/test_random_large_table_05May1029', + 'remote': 'SELECT * FROM main.streaming.random_large_table', + 'num_samples': 100000, 'class': StreamingDataset, 'kwargs': {}, }, @@ -55,7 +53,7 @@ def get_dataset(name: str, # raise ValueError('Could not load dataset with name={name} and split={split}') d = dataset_map[name] - expected_samples = 1 # d['num_samples'][split] + expected_samples = d['num_samples'] local = d['local'] remote = d['remote'] kwargs = {**d['kwargs'], **other_kwargs} @@ -72,7 +70,7 @@ def test_streaming_remote_dataset(name: str, split: str) -> None: # Build StreamingDataset build_start = time.time() expected_samples, dataset = get_dataset(name=name, - local=f'/tmp/test_delta_05May1029', + local=None, # f'/tmp/test_delta_05May1029', split=split, shuffle=False, batch_size=16) @@ -103,6 +101,8 @@ def test_streaming_remote_dataset(name: str, split: str) -> None: if __name__ == "__main__": - test_streaming_remote_dataset(name = 'refinedweb', split=None) +# test_streaming_remote_dataset(name = 'refinedweb', split=None) # test_streaming_remote_dataset(name = 'dummy_table', split=None) -# test_streaming_remote_dataset(name = 'random_cpt_table', split=None) +# test_streaming_remote_dataset(name = 'random_cpt_table', split=None) + test_streaming_remote_dataset(name = 'random_large_table', split=None) + From 3c50853537393cc951aba89bfdedde4de1c92082 Mon Sep 17 00:00:00 2001 From: Xiaohan Zhang Date: Thu, 23 May 2024 16:50:21 -0700 Subject: [PATCH 041/145] update --- streaming/base/format/mds/encodings.py | 3 +++ streaming/base/stream.py | 20 +++++++++++++++----- tests/test_streaming_remote.py | 23 ++++++++++++++++++++--- 3 files changed, 38 insertions(+), 8 deletions(-) diff --git a/streaming/base/format/mds/encodings.py b/streaming/base/format/mds/encodings.py index 0e7c7fed6..7bb85c4e3 100644 --- a/streaming/base/format/mds/encodings.py +++ b/streaming/base/format/mds/encodings.py @@ -315,6 +315,9 @@ def encode(self, obj: Any) -> bytes: return self.dtype(obj).tobytes() def decode(self, data: bytes) -> Any: + print(f"Data length: {len(data)}, Expected dtype: {self.dtype}, Element size: {np.dtype(self.dtype).itemsize}") + if len(data) % np.dtype(self.dtype).itemsize != 0: + print(f"Error: Buffer size {len(data)} is not a multiple of element size {np.dtype(self.dtype).itemsize}") return np.frombuffer(data, self.dtype)[0] diff --git a/streaming/base/stream.py b/streaming/base/stream.py index e946bc462..400ca9f16 100644 --- a/streaming/base/stream.py +++ b/streaming/base/stream.py @@ -566,7 +566,8 @@ def get_shards(self, world: World, allow_unsafe_types: bool) -> List[Reader]: from streaming.base.converters import infer_dataframe_schema w = WorkspaceClient() - cluster_id = "0201-234512-tcp9nfat" + #cluster_id = "0201-234512-tcp9nfat" # e2-dogfood + cluster_id = "0523-224100-tid6mais" # db-force-one print('I am here 1') sparkSession = DatabricksSession.builder.remote( @@ -588,16 +589,20 @@ def get_shards(self, world: World, allow_unsafe_types: bool) -> List[Reader]: print('schema = ', schema) self.columns = infer_dataframe_schema(df, None) + column_names = [] column_encodings = [] + column_sizes = [] for k, v in self.columns.items(): column_names.append(k) column_encodings.append(v) + column_sizes.append(None) + #self.columns = {'text': 'str'} print('inferred columns = ', self.columns) print('I am here 4', len(cloudfetch_results)) - + # raise RuntimeError("break") if world.is_local_leader: @@ -609,9 +614,9 @@ def get_shards(self, world: World, allow_unsafe_types: bool) -> List[Reader]: for index, result in enumerate(cloudfetch_results): shard = { - "column_encodings": column_encodings, + "column_encodings": column_encodings, "column_names": column_names, - "column_sizes": [None], + "column_sizes": column_sizes, "compression": None, "format": "mds", "hashes": ["sha1"], @@ -673,10 +678,14 @@ def _download_file(self, from_basename: str, to_basename: Optional[str] = None) Returns: str: Local cache filename. """ + from streaming import MDSWriter + def fetch_and_convert(cloud_fetch_url: str, local_shard_path: str): - from streaming import MDSWriter samples = pa.ipc.open_stream(requests.get(cloud_fetch_url).content).read_all().to_pylist() + print('samples = ') + print(len(samples)) + with TemporaryDirectory() as temp_dir: with MDSWriter(columns=self.columns, out=temp_dir, size_limit=None) as out: for sample in samples: @@ -692,6 +701,7 @@ def fetch_and_convert(cloud_fetch_url: str, local_shard_path: str): retry(num_attempts=self.download_retry)( lambda: fetch_and_convert(cloud_fetch_url, local))() + print('download to local is done = ', local) return local diff --git a/tests/test_streaming_remote.py b/tests/test_streaming_remote.py index a94d4f677..f61dde93e 100644 --- a/tests/test_streaming_remote.py +++ b/tests/test_streaming_remote.py @@ -21,7 +21,7 @@ def get_dataset(name: str, other_kwargs = {} if other_kwargs is None else other_kwargs dataset_map = { 'refinedweb': { - 'local': f'/tmp/test_refinedweb_05May1029', + 'local': f'/tmp/test_refinedweb_05May1029', 'remote': 'dbfs:/Volumes/main/mosaic_hackathon/managed-volume/mds/refinedweb/', 'num_samples': 20206, 'class': StreamingDataset, @@ -48,6 +48,20 @@ def get_dataset(name: str, 'class': StreamingDataset, 'kwargs': {}, }, + 'reddit_table': { + 'local': f'/tmp/test_random_reddit_table_05May1029', + 'remote': 'SELECT text, added FROM main.reddit.data', + 'num_samples': 378156152, + 'class': StreamingDataset, + 'kwargs': {}, + }, + 'debug_local': { + 'local': f'/tmp/test_random_reddit_table_05May1029', + 'remote': None, + 'num_samples': 378156152, + 'class': StreamingDataset, + 'kwargs': {}, + }, } #if name not in dataset_map and split not in dataset_map[name]['num_samples'][split]: # raise ValueError('Could not load dataset with name={name} and split={split}') @@ -100,9 +114,12 @@ def test_streaming_remote_dataset(name: str, split: str) -> None: assert rcvd_samples >= expected_samples -if __name__ == "__main__": +#if __name__ == "__main__": # test_streaming_remote_dataset(name = 'refinedweb', split=None) # test_streaming_remote_dataset(name = 'dummy_table', split=None) # test_streaming_remote_dataset(name = 'random_cpt_table', split=None) - test_streaming_remote_dataset(name = 'random_large_table', split=None) +# test_streaming_remote_dataset(name = 'random_large_table', split=None) +test_streaming_remote_dataset(name = 'reddit_table', split=None) +# test_streaming_remote_dataset(name = 'debug_local', split=None) + From df799fd711690bc68c086b61e1ee1dfbc52c7e9b Mon Sep 17 00:00:00 2001 From: Xiaohan Zhang Date: Thu, 23 May 2024 21:57:16 -0700 Subject: [PATCH 042/145] Make cluser id a param --- streaming/base/dataset.py | 7 ++++-- streaming/base/format/mds/encodings.py | 3 --- streaming/base/stream.py | 30 +++++--------------------- tests/test_streaming_remote.py | 25 ++++++++++++++------- 4 files changed, 27 insertions(+), 38 deletions(-) diff --git a/streaming/base/dataset.py b/streaming/base/dataset.py index 753ffc2a4..1e5e2113e 100644 --- a/streaming/base/dataset.py +++ b/streaming/base/dataset.py @@ -331,7 +331,8 @@ def __init__(self, shuffle_block_size: Optional[int] = None, batching_method: str = 'random', allow_unsafe_types: bool = False, - replication: Optional[int] = None) -> None: + replication: Optional[int] = None, + **kwargs: Any) -> None: # Global arguments (which do not live in Streams). self.predownload = predownload self.cache_limit = cache_limit @@ -444,7 +445,9 @@ def __init__(self, for stream in streams: stream.apply_default(default) elif remote is not None and remote.startswith('SELECT'): - default = DeltaStream(remote=remote, + cluster_id = kwargs.get('cluster_id', None) + default = DeltaStream(cluster_id, + remote=remote, local=local, split=split, download_retry=download_retry, diff --git a/streaming/base/format/mds/encodings.py b/streaming/base/format/mds/encodings.py index 7bb85c4e3..0e7c7fed6 100644 --- a/streaming/base/format/mds/encodings.py +++ b/streaming/base/format/mds/encodings.py @@ -315,9 +315,6 @@ def encode(self, obj: Any) -> bytes: return self.dtype(obj).tobytes() def decode(self, data: bytes) -> Any: - print(f"Data length: {len(data)}, Expected dtype: {self.dtype}, Element size: {np.dtype(self.dtype).itemsize}") - if len(data) % np.dtype(self.dtype).itemsize != 0: - print(f"Error: Buffer size {len(data)} is not a multiple of element size {np.dtype(self.dtype).itemsize}") return np.frombuffer(data, self.dtype)[0] diff --git a/streaming/base/stream.py b/streaming/base/stream.py index 400ca9f16..3f8bd1956 100644 --- a/streaming/base/stream.py +++ b/streaming/base/stream.py @@ -514,6 +514,7 @@ def get_index_size(self) -> int: class DeltaStream(Stream): def __init__(self, + cluster_id: str, remote: Optional[str] = None, local: Optional[str] = None, split: Optional[str] = None, @@ -537,6 +538,7 @@ def __init__(self, self.url_to_basename= {} self.basename_to_url={} + self.cluster_id = cluster_id def generate_unique_basename(self, url: str, index: int) -> str: """Generate a unique basename for the file path from the URL.""" @@ -566,28 +568,22 @@ def get_shards(self, world: World, allow_unsafe_types: bool) -> List[Reader]: from streaming.base.converters import infer_dataframe_schema w = WorkspaceClient() - #cluster_id = "0201-234512-tcp9nfat" # e2-dogfood - cluster_id = "0523-224100-tid6mais" # db-force-one + ##cluster_id = "0201-234512-tcp9nfat" # e2-dogfood + #cluster_id = "0523-224100-tid6mais" # db-force-one - print('I am here 1') sparkSession = DatabricksSession.builder.remote( host=w.config.host, token=w.config.token, - cluster_id=cluster_id).getOrCreate() + cluster_id=self.cluster_id).getOrCreate() - print('I am here 2') df = sparkSession.sql(self.remote) - print('I am here 2.0') query = df._plan.to_proto(df._session.client) # pyright: ignore - print('I am here 2.1') schema, cloudfetch_results = df._session.client.experimental_to_cloudfetch(query, "arrow", compression=False) # pyright: ignore # Local leader prepares the index file based on cloudfetch results - print('I am here 3') basename = get_index_basename() filename = os.path.join(self.local, self.split, basename) - print('schema = ', schema) self.columns = infer_dataframe_schema(df, None) column_names = [] @@ -598,13 +594,6 @@ def get_shards(self, world: World, allow_unsafe_types: bool) -> List[Reader]: column_encodings.append(v) column_sizes.append(None) - #self.columns = {'text': 'str'} - print('inferred columns = ', self.columns) - - print('I am here 4', len(cloudfetch_results)) - -# raise RuntimeError("break") - if world.is_local_leader: metadata = { @@ -632,9 +621,6 @@ def get_shards(self, world: World, allow_unsafe_types: bool) -> List[Reader]: } metadata["shards"].append(shard) - print('metadata = ') - print(metadata) - with open(filename, 'w') as f: json.dump(metadata, f, indent=4) @@ -664,8 +650,6 @@ def get_shards(self, world: World, allow_unsafe_types: bool) -> List[Reader]: shard.validate(allow_unsafe_types) shards.append(shard) - print('I am here 4.1, shard.samples = ', shard.samples) - return shards def _download_file(self, from_basename: str, to_basename: Optional[str] = None) -> str: @@ -683,9 +667,6 @@ def _download_file(self, from_basename: str, to_basename: Optional[str] = None) def fetch_and_convert(cloud_fetch_url: str, local_shard_path: str): samples = pa.ipc.open_stream(requests.get(cloud_fetch_url).content).read_all().to_pylist() - print('samples = ') - print(len(samples)) - with TemporaryDirectory() as temp_dir: with MDSWriter(columns=self.columns, out=temp_dir, size_limit=None) as out: for sample in samples: @@ -693,7 +674,6 @@ def fetch_and_convert(cloud_fetch_url: str, local_shard_path: str): temp_mds_filename = os.path.join(temp_dir, 'shard.00000.mds') os.rename(temp_mds_filename, local_shard_path) - print('from_basename = ', from_basename) cloud_fetch_url = self.basename_to_url[from_basename] local = os.path.join(self.local, self.split, from_basename) diff --git a/tests/test_streaming_remote.py b/tests/test_streaming_remote.py index f61dde93e..5d9ca701c 100644 --- a/tests/test_streaming_remote.py +++ b/tests/test_streaming_remote.py @@ -32,35 +32,43 @@ def get_dataset(name: str, 'remote': 'SELECT * FROM main.streaming.dummy_cpt_table', 'num_samples': 20206, 'class': StreamingDataset, - 'kwargs': {}, + 'kwargs': { + 'cluster_id': "0201-234512-tcp9nfat" + }, }, 'random_cpt_table': { 'local': f'/tmp/test_random_cpt_table_05May1029', 'remote': 'SELECT text FROM main.streaming.random_cpt_table', 'num_samples': 100000, 'class': StreamingDataset, - 'kwargs': {}, + 'kwargs': { + 'cluster_id': "0201-234512-tcp9nfat" + }, }, 'random_large_table': { 'local': f'/tmp/test_random_large_table_05May1029', 'remote': 'SELECT * FROM main.streaming.random_large_table', 'num_samples': 100000, 'class': StreamingDataset, - 'kwargs': {}, + 'kwargs': { + 'cluster_id': "0201-234512-tcp9nfat" + }, }, 'reddit_table': { 'local': f'/tmp/test_random_reddit_table_05May1029', 'remote': 'SELECT text, added FROM main.reddit.data', 'num_samples': 378156152, 'class': StreamingDataset, - 'kwargs': {}, + 'kwargs': { + 'cluster_id': "0523-224100-tid6mais" + }, }, 'debug_local': { 'local': f'/tmp/test_random_reddit_table_05May1029', 'remote': None, 'num_samples': 378156152, 'class': StreamingDataset, - 'kwargs': {}, + 'kwargs': {} }, } #if name not in dataset_map and split not in dataset_map[name]['num_samples'][split]: @@ -71,7 +79,8 @@ def get_dataset(name: str, local = d['local'] remote = d['remote'] kwargs = {**d['kwargs'], **other_kwargs} - dataset = d['class'](local=local, + dataset = d['class'](d['cluster_id'], + local=local, remote=remote, split=split, shuffle=shuffle, @@ -117,9 +126,9 @@ def test_streaming_remote_dataset(name: str, split: str) -> None: #if __name__ == "__main__": # test_streaming_remote_dataset(name = 'refinedweb', split=None) # test_streaming_remote_dataset(name = 'dummy_table', split=None) -# test_streaming_remote_dataset(name = 'random_cpt_table', split=None) +test_streaming_remote_dataset(name = 'random_cpt_table', split=None) # test_streaming_remote_dataset(name = 'random_large_table', split=None) -test_streaming_remote_dataset(name = 'reddit_table', split=None) +# test_streaming_remote_dataset(name = 'reddit_table', split=None) # test_streaming_remote_dataset(name = 'debug_local', split=None) From 872d08d9d8265b4d5c5e52940ed97cbf592794a0 Mon Sep 17 00:00:00 2001 From: Xiaohan Zhang Date: Thu, 23 May 2024 22:03:45 -0700 Subject: [PATCH 043/145] Remove prints --- streaming/base/dataset.py | 2 -- streaming/base/stream.py | 2 -- tests/test_streaming_remote.py | 3 +-- 3 files changed, 1 insertion(+), 6 deletions(-) diff --git a/streaming/base/dataset.py b/streaming/base/dataset.py index 1e5e2113e..c243f6f23 100644 --- a/streaming/base/dataset.py +++ b/streaming/base/dataset.py @@ -519,8 +519,6 @@ def __init__(self, # Build the shard index (for partitioning and mapping samples to shards). self.samples_per_shard = np.array([shard.samples for shard in self.shards], np.int64) - print('I am here 5.1, samples_per_shard = ') - print(self.samples_per_shard) self.sample_offset_per_shard = self.samples_per_shard.cumsum() - self.samples_per_shard self.spanner = Spanner(self.samples_per_shard) diff --git a/streaming/base/stream.py b/streaming/base/stream.py index 3f8bd1956..dd2b2a315 100644 --- a/streaming/base/stream.py +++ b/streaming/base/stream.py @@ -568,8 +568,6 @@ def get_shards(self, world: World, allow_unsafe_types: bool) -> List[Reader]: from streaming.base.converters import infer_dataframe_schema w = WorkspaceClient() - ##cluster_id = "0201-234512-tcp9nfat" # e2-dogfood - #cluster_id = "0523-224100-tid6mais" # db-force-one sparkSession = DatabricksSession.builder.remote( host=w.config.host, diff --git a/tests/test_streaming_remote.py b/tests/test_streaming_remote.py index 5d9ca701c..ee265521b 100644 --- a/tests/test_streaming_remote.py +++ b/tests/test_streaming_remote.py @@ -79,8 +79,7 @@ def get_dataset(name: str, local = d['local'] remote = d['remote'] kwargs = {**d['kwargs'], **other_kwargs} - dataset = d['class'](d['cluster_id'], - local=local, + dataset = d['class'](local=local, remote=remote, split=split, shuffle=shuffle, From e975501f7695e15c662a9c27b148c69e277210d5 Mon Sep 17 00:00:00 2001 From: Xiaohan Zhang Date: Thu, 23 May 2024 22:31:50 -0700 Subject: [PATCH 044/145] Remove prints --- streaming/base/dataset.py | 1 - streaming/base/stream.py | 1 - 2 files changed, 2 deletions(-) diff --git a/streaming/base/dataset.py b/streaming/base/dataset.py index c243f6f23..42c645a7e 100644 --- a/streaming/base/dataset.py +++ b/streaming/base/dataset.py @@ -1259,7 +1259,6 @@ def get_item(self, sample_id: int, retry: int = 7) -> Any: raise RuntimeError('Background thread failed. Check other traceback.') # Locate the shard and sample offset within that shard where the sample lives. shard_id, shard_sample_id = self.spanner[sample_id] - #print('I am here 5.2', shard_id, shard_sample_id) shard = self.shards[shard_id] sample = None diff --git a/streaming/base/stream.py b/streaming/base/stream.py index dd2b2a315..d59e966e8 100644 --- a/streaming/base/stream.py +++ b/streaming/base/stream.py @@ -544,7 +544,6 @@ def generate_unique_basename(self, url: str, index: int) -> str: """Generate a unique basename for the file path from the URL.""" hash_object = hashlib.md5(url.encode()) hex_dig = hash_object.hexdigest() - # basename = f"{hex_dig[:3]}/shard.{int(hex_dig, 16) % 100000:05d}.mds" basename = '.'.join(['shard', f'{index:05}', 'mds']) self.url_to_basename[url] = basename self.basename_to_url[basename] = url From 03393ddc90a96c7f49556e8b2c0b8a79fc734adf Mon Sep 17 00:00:00 2001 From: xiaohanzhangcmu Date: Sun, 2 Jun 2024 00:33:14 -0700 Subject: [PATCH 045/145] update --- streaming/base/stream.py | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/streaming/base/stream.py b/streaming/base/stream.py index d59e966e8..eb6df254a 100644 --- a/streaming/base/stream.py +++ b/streaming/base/stream.py @@ -510,6 +510,30 @@ def get_index_size(self) -> int: filename = os.path.join(self.local, self.split, get_index_basename()) return os.stat(filename).st_size +import json +import os + +def save_dict_to_file(directory, filename, dictionary): + """Save a dictionary to a file in the specified directory.""" + if not os.path.exists(directory): + os.makedirs(directory) + + file_path = os.path.join(directory, filename) + with open(file_path, 'w') as file: + json.dump(dictionary, file, indent=4) + print(f"Dictionary saved to {file_path}") + +def load_dict_from_file(directory, filename): + """Load a dictionary from a file in the specified directory.""" + file_path = os.path.join(directory, filename) + if not os.path.exists(file_path): + raise FileNotFoundError(f"No such file: '{file_path}'") + + with open(file_path, 'r') as file: + dictionary = json.load(file) + print(f"Dictionary loaded from {file_path}") + return dictionary + class DeltaStream(Stream): @@ -547,6 +571,7 @@ def generate_unique_basename(self, url: str, index: int) -> str: basename = '.'.join(['shard', f'{index:05}', 'mds']) self.url_to_basename[url] = basename self.basename_to_url[basename] = url + return basename def get_shards(self, world: World, allow_unsafe_types: bool) -> List[Reader]: @@ -647,6 +672,8 @@ def get_shards(self, world: World, allow_unsafe_types: bool) -> List[Reader]: shard.validate(allow_unsafe_types) shards.append(shard) + save_dict_to_file('./', 'basename_to_url.json', self.basename_to_url) + return shards def _download_file(self, from_basename: str, to_basename: Optional[str] = None) -> str: From 44eb7f459942b21585eaef19f0b8385e97683dfb Mon Sep 17 00:00:00 2001 From: xiaohanzhangcmu Date: Fri, 2 Aug 2024 14:36:01 -0700 Subject: [PATCH 046/145] Add dbsql --- streaming/base/dataset.py | 19 ++-- streaming/base/stream.py | 193 ++++++++++++++++++++++++++++++++- tests/test_streaming_remote.py | 92 +++++++++------- 3 files changed, 248 insertions(+), 56 deletions(-) diff --git a/streaming/base/dataset.py b/streaming/base/dataset.py index 42c645a7e..cc43228d3 100644 --- a/streaming/base/dataset.py +++ b/streaming/base/dataset.py @@ -34,7 +34,7 @@ from streaming.base.shared import (SharedArray, SharedBarrier, SharedMemory, SharedScalar, _get_path, get_shm_prefix) from streaming.base.spanner import Spanner -from streaming.base.stream import Stream, DeltaStream +from streaming.base.stream import Stream, DeltaDBSQLStream from streaming.base.util import bytes_to_int, number_abbrev_to_int from streaming.base.world import World @@ -445,15 +445,14 @@ def __init__(self, for stream in streams: stream.apply_default(default) elif remote is not None and remote.startswith('SELECT'): - cluster_id = kwargs.get('cluster_id', None) - default = DeltaStream(cluster_id, - remote=remote, - local=local, - split=split, - download_retry=download_retry, - download_timeout=download_timeout, - validate_hash=validate_hash, - keep_zip=keep_zip) + default = DeltaDBSQLStream(remote=remote, + local=local, + split=split, + download_retry=download_retry, + download_timeout=download_timeout, + validate_hash=validate_hash, + keep_zip=keep_zip, + **kwargs) streams = [default] else: default = Stream(remote=remote, diff --git a/streaming/base/stream.py b/streaming/base/stream.py index eb6df254a..4c0cdc523 100644 --- a/streaming/base/stream.py +++ b/streaming/base/stream.py @@ -22,6 +22,7 @@ from streaming.base.util import retry, wait_for_file_to_exist from streaming.base.world import World +import re import pyarrow as pa import requests from tempfile import TemporaryDirectory @@ -510,8 +511,6 @@ def get_index_size(self) -> int: filename = os.path.join(self.local, self.split, get_index_basename()) return os.stat(filename).st_size -import json -import os def save_dict_to_file(directory, filename, dictionary): """Save a dictionary to a file in the specified directory.""" @@ -535,7 +534,7 @@ def load_dict_from_file(directory, filename): return dictionary -class DeltaStream(Stream): +class DeltaSCStream(Stream): def __init__(self, cluster_id: str, @@ -672,8 +671,6 @@ def get_shards(self, world: World, allow_unsafe_types: bool) -> List[Reader]: shard.validate(allow_unsafe_types) shards.append(shard) - save_dict_to_file('./', 'basename_to_url.json', self.basename_to_url) - return shards def _download_file(self, from_basename: str, to_basename: Optional[str] = None) -> str: @@ -709,3 +706,189 @@ def fetch_and_convert(cloud_fetch_url: str, local_shard_path: str): return local +class DeltaDBSQLStream(Stream): + + def __init__(self, + remote: Optional[str] = None, + local: Optional[str] = None, + split: Optional[str] = None, + proportion: Optional[float] = None, + repeat: Optional[float] = None, + choose: Optional[int] = None, + download_retry: Optional[int] = None, + download_timeout: Optional[float] = None, + validate_hash: Optional[str] = None, + keep_zip: Optional[bool] = None, + **kwargs: Any) -> None: + super().__init__(remote=remote, + local=local, + split=split, + proportion=proportion, + repeat=repeat, + choose=choose, + download_retry=download_retry, + download_timeout=download_timeout, + validate_hash=validate_hash, + keep_zip=keep_zip) + + warehouse_id = kwargs.get('warehouse_id', None) + host = kwargs.get('host', os.environ['DATABRICKS_HOST']) + token = kwargs.get('token', os.environ['DATABRICKS_TOKEN']) + catalog = kwargs.get('catalog', None) + schema = kwargs.get('schema', None) + + if any([not warehouse_id, not host, not token, not catalog, not schema]): + raise TypeError(f"Need to specify warehouse_id, host, token catalog, schema, during initialization") + + self.base_url = f"https://{self.host}/api/2.0/sql/statements/" + self.headers = { + "Authorization": f"Bearer {self.token}", + "Content-Type": "application/json" + } + self.data = { + "warehouse_id": warehouse_id, + "catalog": catalog, + "schema": schema, + "format": "ARROW_STREAM", + "disposition": "EXTERNAL_LINKS", + "statement": remote, + "wait_timeout": "2s", + } + + def refresh_statement_id(self, timeout=100): + total_time = 0 + while total_time <= timeout: + response = requests.post(self.base_url, headers=self.headers, json=self.data) + response.raise_for_status() + response_data = response.json() + query_status = response_data['status']['state'] + + if query_status == "SUCCEEDED": + self.statement_id = response_data['statement_id'] + return response_data + + print(f"Query status: {query_status}") + time.sleep(3) + total_time += 3 + raise TimeoutError(f"Query execution failed with status: {query_status}") + + def get_shards(self, world: World, allow_unsafe_types: bool) -> List[Reader]: + """Load this Stream's index, retrieving its shard readers. + + Args: + world (World): Distributed context. + allow_unsafe_types (bool): If a shard contains Pickle, which allows arbitrary code + execution during deserialization, whether to keep going if ``True`` or raise an + error. + + Returns: + `List[Reader]: Shard readers. + """ + metadata = self.refresh_statement_id() + + # Local leader prepares the index file based on cloudfetch results + basename = get_index_basename() + filename = os.path.join(self.local, self.split, basename) + + self.columns = metadata['manifest']['schema']['columns'] + column_names = [ c['name'] for c in self.columns ] + column_encodings = [ c['type_name'].lower() for c in self.columns ] + column_sizes = [ None for _ in self.columns ] + total_shard_count = metadata['manifest']['total_chunk_count'] + + if world.is_local_leader: + + metadata = { + "version": 2, + "shards": [] + } + + for shard_id, shard_meta in enumerate(metadata['manifest']['chunks']): + shard = { + "column_encodings": column_encodings, + "column_names": column_names, + "column_sizes": column_sizes, + "compression": None, + "format": "mds", + "hashes": ["sha1"], + "raw_data": { + "basename": f'shard.{shard_id:05}.mds', + "bytes": shard_meta['byte_count'], + "hashes": {} + }, + "samples": shard_meta['row_count'], + "size_limit": 67108864, + "version": 2, + "zip_data": None + } + metadata["shards"].append(shard) + + with open(filename, 'w') as f: + json.dump(metadata, f, indent=4) + + else: + wait_for_file_to_exist( + filename, TICK, self.download_timeout, + f'Index file {os.path.join(self.remote or "", self.split or "", basename)} ' + + f'-> {filename} took too long to download. Either increase the ' + + f'`download_timeout` value or check the other traceback.') + + # Load the index. + try: + obj = json.load(open(filename)) + except json.decoder.JSONDecodeError as error: + error.args = (f'Index file at {filename} is empty or corrupted. ' + error.args[0],) + raise error + + # Version check. + if obj['version'] != 2: + raise ValueError(f'Unsupported streaming data version: {obj["version"]}. ' + + f'Expected version 2.') + + # Initialize shard readers according to the loaded info. + shards = [] + for info in obj['shards']: + shard = reader_from_json(self.local, self.split, info) + shard.validate(allow_unsafe_types) + shards.append(shard) + + return shards + + @retry(num_attemps=2) + def _download_file(self, from_basename: str, to_basename: Optional[str] = None) -> str: + """Safely download a file from remote to local cache. + + Args: + from_basename (str): Source basename. + to_basename (str, optional): Destination basename, if different. + + Returns: + str: Local cache filename. + """ + from streaming import MDSWriter + + def fetch_and_convert(cloud_fetch_url: str, local_shard_path: str): + samples = pa.ipc.open_stream(requests.get(cloud_fetch_url).content).read_all().to_pylist() + with TemporaryDirectory() as temp_dir: + with MDSWriter(columns=self.columns, out=temp_dir, size_limit=None) as out: + for sample in samples: + out.write(sample) + temp_mds_filename = os.path.join(temp_dir, 'shard.00000.mds') + os.rename(temp_mds_filename, local_shard_path) + + chunk_index = int(re.search(r'\d+', from_basename).group()) + cloud_fetch_url = f"{self.base_url}/{self.statement_id}/result/chunks/{chunk_index}" + local = os.path.join(self.local, self.split, from_basename) + + # Attempt to download, possibly repeating on failure. + try: + retry(num_attempts=self.download_retry)( + lambda: fetch_and_convert(cloud_fetch_url, local))() + print('download to local is done = ', local) + return local + except: + self.refresh_statement_id() + + + + diff --git a/tests/test_streaming_remote.py b/tests/test_streaming_remote.py index ee265521b..59282e114 100644 --- a/tests/test_streaming_remote.py +++ b/tests/test_streaming_remote.py @@ -13,48 +13,47 @@ def get_dataset(name: str, - local: str, split: str, shuffle: bool, batch_size: Optional[int], other_kwargs: Optional[Dict[str, Any]] = None) -> Tuple[int, StreamingDataset]: other_kwargs = {} if other_kwargs is None else other_kwargs dataset_map = { - 'refinedweb': { - 'local': f'/tmp/test_refinedweb_05May1029', - 'remote': 'dbfs:/Volumes/main/mosaic_hackathon/managed-volume/mds/refinedweb/', - 'num_samples': 20206, - 'class': StreamingDataset, - 'kwargs': {}, - }, - 'dummy_table': { - 'local': f'/tmp/test_dummy_table_05May1029', - 'remote': 'SELECT * FROM main.streaming.dummy_cpt_table', - 'num_samples': 20206, - 'class': StreamingDataset, - 'kwargs': { - 'cluster_id': "0201-234512-tcp9nfat" - }, - }, - 'random_cpt_table': { - 'local': f'/tmp/test_random_cpt_table_05May1029', - 'remote': 'SELECT text FROM main.streaming.random_cpt_table', - 'num_samples': 100000, - 'class': StreamingDataset, - 'kwargs': { - 'cluster_id': "0201-234512-tcp9nfat" - }, - }, - 'random_large_table': { - 'local': f'/tmp/test_random_large_table_05May1029', - 'remote': 'SELECT * FROM main.streaming.random_large_table', - 'num_samples': 100000, - 'class': StreamingDataset, - 'kwargs': { - 'cluster_id': "0201-234512-tcp9nfat" - }, - }, - 'reddit_table': { + # 'refinedweb': { + # 'local': f'/tmp/test_refinedweb_05May1029', + # 'remote': 'dbfs:/Volumes/main/mosaic_hackathon/managed-volume/mds/refinedweb/', + # 'num_samples': 20206, + # 'class': StreamingDataset, + # 'kwargs': {}, + # }, + # 'dummy_table': { + # 'local': f'/tmp/test_dummy_table_05May1029', + # 'remote': 'SELECT * FROM main.streaming.dummy_cpt_table', + # 'num_samples': 20206, + # 'class': StreamingDataset, + # 'kwargs': { + # 'cluster_id': "0201-234512-tcp9nfat" + # }, + # }, + # 'random_cpt_table': { + # 'local': f'/tmp/test_random_cpt_table_05May1029', + # 'remote': 'SELECT text FROM main.streaming.random_cpt_table', + # 'num_samples': 100000, + # 'class': StreamingDataset, + # 'kwargs': { + # 'cluster_id': "0201-234512-tcp9nfat" + # }, + # }, + # 'random_large_table': { + # 'local': f'/tmp/test_random_large_table_05May1029', + # 'remote': 'SELECT * FROM main.streaming.random_large_table', + # 'num_samples': 100000, + # 'class': StreamingDataset, + # 'kwargs': { + # 'cluster_id': "0201-234512-tcp9nfat" + # }, + # }, + 'reddit_table_sparkconnect': { 'local': f'/tmp/test_random_reddit_table_05May1029', 'remote': 'SELECT text, added FROM main.reddit.data', 'num_samples': 378156152, @@ -63,13 +62,24 @@ def get_dataset(name: str, 'cluster_id': "0523-224100-tid6mais" }, }, - 'debug_local': { + 'reddit_table_dbsql': { 'local': f'/tmp/test_random_reddit_table_05May1029', - 'remote': None, + 'remote': 'SELECT text, added FROM main.reddit.data', 'num_samples': 378156152, 'class': StreamingDataset, - 'kwargs': {} + 'kwargs': { + 'warehouse_id': "0523-224100-tid6mais", + 'catalog': 'main', + 'schema': 'reddit', + }, }, + # 'debug_local': { + # 'local': f'/tmp/test_random_reddit_table_05May1029', + # 'remote': None, + # 'num_samples': 378156152, + # 'class': StreamingDataset, + # 'kwargs': {} + # }, } #if name not in dataset_map and split not in dataset_map[name]['num_samples'][split]: # raise ValueError('Could not load dataset with name={name} and split={split}') @@ -92,7 +102,6 @@ def test_streaming_remote_dataset(name: str, split: str) -> None: # Build StreamingDataset build_start = time.time() expected_samples, dataset = get_dataset(name=name, - local=None, # f'/tmp/test_delta_05May1029', split=split, shuffle=False, batch_size=16) @@ -125,9 +134,10 @@ def test_streaming_remote_dataset(name: str, split: str) -> None: #if __name__ == "__main__": # test_streaming_remote_dataset(name = 'refinedweb', split=None) # test_streaming_remote_dataset(name = 'dummy_table', split=None) -test_streaming_remote_dataset(name = 'random_cpt_table', split=None) +# test_streaming_remote_dataset(name = 'random_cpt_table', split=None) # test_streaming_remote_dataset(name = 'random_large_table', split=None) # test_streaming_remote_dataset(name = 'reddit_table', split=None) +test_streaming_remote_dataset(name = 'reddit_table_dbsql', split=None) # test_streaming_remote_dataset(name = 'debug_local', split=None) From 58a296f48841393a4031639bbc84a70a2e78cc63 Mon Sep 17 00:00:00 2001 From: xiaohanzhangcmu Date: Fri, 2 Aug 2024 15:08:00 -0700 Subject: [PATCH 047/145] update --- streaming/base/stream.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/streaming/base/stream.py b/streaming/base/stream.py index 1e051c20a..5edbaf5f5 100644 --- a/streaming/base/stream.py +++ b/streaming/base/stream.py @@ -7,7 +7,7 @@ import json import os import tempfile -from typing import List, Optional, Sequence, Tuple +from typing import List, Optional, Sequence, Tuple, Any import numpy as np from numpy.typing import NDArray From 58fc267a3af1b63b4e3e4b9813703f0c898a89ad Mon Sep 17 00:00:00 2001 From: xiaohanzhangcmu Date: Fri, 2 Aug 2024 15:15:48 -0700 Subject: [PATCH 048/145] update --- rust/Cargo.lock | 1633 ++++++++++++++++++++++++++++++++++++++ streaming/base/stream.py | 2 +- 2 files changed, 1634 insertions(+), 1 deletion(-) create mode 100644 rust/Cargo.lock diff --git a/rust/Cargo.lock b/rust/Cargo.lock new file mode 100644 index 000000000..bafac98ce --- /dev/null +++ b/rust/Cargo.lock @@ -0,0 +1,1633 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. +version = 3 + +[[package]] +name = "addr2line" +version = "0.22.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6e4503c46a5c0c7844e948c9a4d6acd9f50cccb4de1c48eb9e291ea17470c678" +dependencies = [ + "gimli", +] + +[[package]] +name = "adler" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe" + +[[package]] +name = "ahash" +version = "0.8.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e89da841a80418a9b391ebaea17f5c112ffaaa96f621d2c285b5174da76b9011" +dependencies = [ + "cfg-if", + "const-random", + "getrandom", + "once_cell", + "version_check", + "zerocopy", +] + +[[package]] +name = "aho-corasick" +version = "1.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e60d3430d3a69478ad0993f19238d2df97c507009a52b3c10addcd7f6bcb916" +dependencies = [ + "memchr", +] + +[[package]] +name = "alloc-no-stdlib" +version = "2.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cc7bb162ec39d46ab1ca8c77bf72e890535becd1751bb45f64c597edb4c8c6b3" + +[[package]] +name = "alloc-stdlib" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "94fb8275041c72129eb51b7d0322c29b8387a0386127718b096429201a5d6ece" +dependencies = [ + "alloc-no-stdlib", +] + +[[package]] +name = "android-tzdata" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e999941b234f3131b00bc13c22d06e8c5ff726d1b6318ac7eb276997bbb4fef0" + +[[package]] +name = "android_system_properties" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "819e7219dbd41043ac279b19830f2efc897156490d7fd6ea916720117ee66311" +dependencies = [ + "libc", +] + +[[package]] +name = "arrow" +version = "46.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "04a8801ebb147ad240b2d978d3ab9f73c9ccd4557ba6a03e7800496770ed10e0" +dependencies = [ + "ahash", + "arrow-arith", + "arrow-array", + "arrow-buffer", + "arrow-cast", + "arrow-csv", + "arrow-data", + "arrow-ipc", + "arrow-json", + "arrow-ord", + "arrow-row", + "arrow-schema", + "arrow-select", + "arrow-string", + "pyo3", +] + +[[package]] +name = "arrow-arith" +version = "46.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "895263144bd4a69751cbe6a34a53f26626e19770b313a9fa792c415cd0e78f11" +dependencies = [ + "arrow-array", + "arrow-buffer", + "arrow-data", + "arrow-schema", + "chrono", + "half", + "num", +] + +[[package]] +name = "arrow-array" +version = "46.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "226fdc6c3a4ae154a74c24091d36a90b514f0ed7112f5b8322c1d8f354d8e20d" +dependencies = [ + "ahash", + "arrow-buffer", + "arrow-data", + "arrow-schema", + "chrono", + "half", + "hashbrown", + "num", +] + +[[package]] +name = "arrow-buffer" +version = "46.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fc4843af4dd679c2f35b69c572874da8fde33be53eb549a5fb128e7a4b763510" +dependencies = [ + "bytes", + "half", + "num", +] + +[[package]] +name = "arrow-cast" +version = "46.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "35e8b9990733a9b635f656efda3c9b8308c7a19695c9ec2c7046dd154f9b144b" +dependencies = [ + "arrow-array", + "arrow-buffer", + "arrow-data", + "arrow-schema", + "arrow-select", + "chrono", + "half", + "lexical-core", + "num", +] + +[[package]] +name = "arrow-csv" +version = "46.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "646fbb4e11dd0afb8083e883f53117713b8caadb4413b3c9e63e3f535da3683c" +dependencies = [ + "arrow-array", + "arrow-buffer", + "arrow-cast", + "arrow-data", + "arrow-schema", + "chrono", + "csv", + "csv-core", + "lazy_static", + "lexical-core", + "regex", +] + +[[package]] +name = "arrow-data" +version = "46.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "da900f31ff01a0a84da0572209be72b2b6f980f3ea58803635de47913191c188" +dependencies = [ + "arrow-buffer", + "arrow-schema", + "half", + "num", +] + +[[package]] +name = "arrow-ipc" +version = "46.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2707a8d7ee2d345d045283ece3ae43416175873483e5d96319c929da542a0b1f" +dependencies = [ + "arrow-array", + "arrow-buffer", + "arrow-cast", + "arrow-data", + "arrow-schema", + "flatbuffers", +] + +[[package]] +name = "arrow-json" +version = "46.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5d1b91a63c356d14eedc778b76d66a88f35ac8498426bb0799a769a49a74a8b4" +dependencies = [ + "arrow-array", + "arrow-buffer", + "arrow-cast", + "arrow-data", + "arrow-schema", + "chrono", + "half", + "indexmap", + "lexical-core", + "num", + "serde", + "serde_json", +] + +[[package]] +name = "arrow-ord" +version = "46.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "584325c91293abbca7aaaabf8da9fe303245d641f5f4a18a6058dc68009c7ebf" +dependencies = [ + "arrow-array", + "arrow-buffer", + "arrow-data", + "arrow-schema", + "arrow-select", + "half", + "num", +] + +[[package]] +name = "arrow-row" +version = "46.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0e32afc1329f7b372463b21c6ca502b07cf237e1ed420d87706c1770bb0ebd38" +dependencies = [ + "ahash", + "arrow-array", + "arrow-buffer", + "arrow-data", + "arrow-schema", + "half", + "hashbrown", +] + +[[package]] +name = "arrow-schema" +version = "46.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b104f5daa730f00fde22adc03a12aa5a2ae9ccbbf99cbd53d284119ddc90e03d" +dependencies = [ + "bitflags 2.5.0", +] + +[[package]] +name = "arrow-select" +version = "46.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "73b3ca55356d1eae07cf48808d8c462cea674393ae6ad1e0b120f40b422eb2b4" +dependencies = [ + "arrow-array", + "arrow-buffer", + "arrow-data", + "arrow-schema", + "num", +] + +[[package]] +name = "arrow-string" +version = "46.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "af1433ce02590cae68da0a18ed3a3ed868ffac2c6f24c533ddd2067f7ee04b4a" +dependencies = [ + "arrow-array", + "arrow-buffer", + "arrow-data", + "arrow-schema", + "arrow-select", + "num", + "regex", + "regex-syntax 0.7.5", +] + +[[package]] +name = "autocfg" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0c4b4d0bd25bd0b74681c0ad21497610ce1b7c91b1022cd21c80c6fbdd9476b0" + +[[package]] +name = "backtrace" +version = "0.3.72" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "17c6a35df3749d2e8bb1b7b21a976d82b15548788d2735b9d82f329268f71a11" +dependencies = [ + "addr2line", + "cc", + "cfg-if", + "libc", + "miniz_oxide", + "object", + "rustc-demangle", +] + +[[package]] +name = "base64" +version = "0.21.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9d297deb1925b89f2ccc13d7635fa0714f12c87adce1c75356b39ca9b7178567" + +[[package]] +name = "bitflags" +version = "1.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" + +[[package]] +name = "bitflags" +version = "2.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cf4b9d6a944f767f8e5e0db018570623c85f3d925ac718db4e06d0187adb21c1" + +[[package]] +name = "brotli" +version = "3.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d640d25bc63c50fb1f0b545ffd80207d2e10a4c965530809b40ba3386825c391" +dependencies = [ + "alloc-no-stdlib", + "alloc-stdlib", + "brotli-decompressor", +] + +[[package]] +name = "brotli-decompressor" +version = "2.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4e2e4afe60d7dd600fdd3de8d0f08c2b7ec039712e3b6137ff98b7004e82de4f" +dependencies = [ + "alloc-no-stdlib", + "alloc-stdlib", +] + +[[package]] +name = "bumpalo" +version = "3.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "79296716171880943b8470b5f8d03aa55eb2e645a4874bdbb28adb49162e012c" + +[[package]] +name = "byteorder" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" + +[[package]] +name = "bytes" +version = "1.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "514de17de45fdb8dc022b1a7975556c53c86f9f0aa5f534b98977b171857c2c9" + +[[package]] +name = "cc" +version = "1.0.98" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "41c270e7540d725e65ac7f1b212ac8ce349719624d7bcff99f8e2e488e8cf03f" +dependencies = [ + "jobserver", + "libc", + "once_cell", +] + +[[package]] +name = "cfg-if" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" + +[[package]] +name = "chrono" +version = "0.4.38" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a21f936df1771bf62b77f047b726c4625ff2e8aa607c01ec06e5a05bd8463401" +dependencies = [ + "android-tzdata", + "iana-time-zone", + "num-traits", + "windows-targets 0.52.5", +] + +[[package]] +name = "const-random" +version = "0.1.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "87e00182fe74b066627d63b85fd550ac2998d4b0bd86bfed477a0ae4c7c71359" +dependencies = [ + "const-random-macro", +] + +[[package]] +name = "const-random-macro" +version = "0.1.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f9d839f2a20b0aee515dc581a6172f2321f96cab76c1a38a4c584a194955390e" +dependencies = [ + "getrandom", + "once_cell", + "tiny-keccak", +] + +[[package]] +name = "core-foundation-sys" +version = "0.8.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "06ea2b9bc92be3c2baa9334a323ebca2d6f074ff852cd1d7b11064035cd3868f" + +[[package]] +name = "crc32fast" +version = "1.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a97769d94ddab943e4510d138150169a2758b5ef3eb191a9ee688de3e23ef7b3" +dependencies = [ + "cfg-if", +] + +[[package]] +name = "crunchy" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7a81dae078cea95a014a339291cec439d2f232ebe854a9d672b796c6afafa9b7" + +[[package]] +name = "csv" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ac574ff4d437a7b5ad237ef331c17ccca63c46479e5b5453eb8e10bb99a759fe" +dependencies = [ + "csv-core", + "itoa", + "ryu", + "serde", +] + +[[package]] +name = "csv-core" +version = "0.1.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5efa2b3d7902f4b634a20cae3c9c4e6209dc4779feb6863329607560143efa70" +dependencies = [ + "memchr", +] + +[[package]] +name = "equivalent" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5443807d6dff69373d433ab9ef5378ad8df50ca6298caf15de6e52e24aaf54d5" + +[[package]] +name = "flatbuffers" +version = "23.5.26" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4dac53e22462d78c16d64a1cd22371b54cc3fe94aa15e7886a2fa6e5d1ab8640" +dependencies = [ + "bitflags 1.3.2", + "rustc_version", +] + +[[package]] +name = "flate2" +version = "1.0.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5f54427cfd1c7829e2a139fcefea601bf088ebca651d2bf53ebc600eac295dae" +dependencies = [ + "crc32fast", + "miniz_oxide", +] + +[[package]] +name = "futures" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "645c6916888f6cb6350d2550b80fb63e734897a8498abe35cfb732b6487804b0" +dependencies = [ + "futures-channel", + "futures-core", + "futures-executor", + "futures-io", + "futures-sink", + "futures-task", + "futures-util", +] + +[[package]] +name = "futures-channel" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eac8f7d7865dcb88bd4373ab671c8cf4508703796caa2b1985a9ca867b3fcb78" +dependencies = [ + "futures-core", + "futures-sink", +] + +[[package]] +name = "futures-core" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dfc6580bb841c5a68e9ef15c77ccc837b40a7504914d52e47b8b0e9bbda25a1d" + +[[package]] +name = "futures-executor" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a576fc72ae164fca6b9db127eaa9a9dda0d61316034f33a0a0d4eda41f02b01d" +dependencies = [ + "futures-core", + "futures-task", + "futures-util", +] + +[[package]] +name = "futures-io" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a44623e20b9681a318efdd71c299b6b222ed6f231972bfe2f224ebad6311f0c1" + +[[package]] +name = "futures-macro" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "87750cf4b7a4c0625b1529e4c543c2182106e4dedc60a2a6455e00d212c489ac" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.66", +] + +[[package]] +name = "futures-sink" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9fb8e00e87438d937621c1c6269e53f536c14d3fbd6a042bb24879e57d474fb5" + +[[package]] +name = "futures-task" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "38d84fa142264698cdce1a9f9172cf383a0c82de1bddcf3092901442c4097004" + +[[package]] +name = "futures-util" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3d6401deb83407ab3da39eba7e33987a73c3df0c82b4bb5813ee871c19c41d48" +dependencies = [ + "futures-channel", + "futures-core", + "futures-io", + "futures-macro", + "futures-sink", + "futures-task", + "memchr", + "pin-project-lite", + "pin-utils", + "slab", +] + +[[package]] +name = "getrandom" +version = "0.2.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c4567c8db10ae91089c99af84c68c38da3ec2f087c3f82960bcdbf3656b6f4d7" +dependencies = [ + "cfg-if", + "libc", + "wasi", +] + +[[package]] +name = "gimli" +version = "0.29.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "40ecd4077b5ae9fd2e9e169b102c6c330d0605168eb0e8bf79952b256dbefffd" + +[[package]] +name = "half" +version = "2.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6dd08c532ae367adf81c312a4580bc67f1d0fe8bc9c460520283f4c0ff277888" +dependencies = [ + "cfg-if", + "crunchy", + "num-traits", +] + +[[package]] +name = "hashbrown" +version = "0.14.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1" + +[[package]] +name = "hermit-abi" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d231dfb89cfffdbc30e7fc41579ed6066ad03abda9e567ccafae602b97ec5024" + +[[package]] +name = "iana-time-zone" +version = "0.1.60" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e7ffbb5a1b541ea2561f8c41c087286cc091e21e556a4f09a8f6cbf17b69b141" +dependencies = [ + "android_system_properties", + "core-foundation-sys", + "iana-time-zone-haiku", + "js-sys", + "wasm-bindgen", + "windows-core", +] + +[[package]] +name = "iana-time-zone-haiku" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f31827a206f56af32e590ba56d5d2d085f558508192593743f16b2306495269f" +dependencies = [ + "cc", +] + +[[package]] +name = "indexmap" +version = "2.2.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "168fb715dda47215e360912c096649d23d58bf392ac62f73919e831745e40f26" +dependencies = [ + "equivalent", + "hashbrown", +] + +[[package]] +name = "indoc" +version = "1.0.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bfa799dd5ed20a7e349f3b4639aa80d74549c81716d9ec4f994c9b5815598306" + +[[package]] +name = "integer-encoding" +version = "3.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8bb03732005da905c88227371639bf1ad885cc712789c011c31c5fb3ab3ccf02" + +[[package]] +name = "itoa" +version = "1.0.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "49f1f14873335454500d59611f1cf4a4b0f786f9ac11f4312a78e4cf2566695b" + +[[package]] +name = "jobserver" +version = "0.1.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d2b099aaa34a9751c5bf0878add70444e1ed2dd73f347be99003d4577277de6e" +dependencies = [ + "libc", +] + +[[package]] +name = "js-sys" +version = "0.3.69" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "29c15563dc2726973df627357ce0c9ddddbea194836909d655df6a75d2cf296d" +dependencies = [ + "wasm-bindgen", +] + +[[package]] +name = "lazy_static" +version = "1.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" + +[[package]] +name = "lexical-core" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2cde5de06e8d4c2faabc400238f9ae1c74d5412d03a7bd067645ccbc47070e46" +dependencies = [ + "lexical-parse-float", + "lexical-parse-integer", + "lexical-util", + "lexical-write-float", + "lexical-write-integer", +] + +[[package]] +name = "lexical-parse-float" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "683b3a5ebd0130b8fb52ba0bdc718cc56815b6a097e28ae5a6997d0ad17dc05f" +dependencies = [ + "lexical-parse-integer", + "lexical-util", + "static_assertions", +] + +[[package]] +name = "lexical-parse-integer" +version = "0.8.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6d0994485ed0c312f6d965766754ea177d07f9c00c9b82a5ee62ed5b47945ee9" +dependencies = [ + "lexical-util", + "static_assertions", +] + +[[package]] +name = "lexical-util" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5255b9ff16ff898710eb9eb63cb39248ea8a5bb036bea8085b1a767ff6c4e3fc" +dependencies = [ + "static_assertions", +] + +[[package]] +name = "lexical-write-float" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "accabaa1c4581f05a3923d1b4cfd124c329352288b7b9da09e766b0668116862" +dependencies = [ + "lexical-util", + "lexical-write-integer", + "static_assertions", +] + +[[package]] +name = "lexical-write-integer" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e1b6f3d1f4422866b68192d62f77bc5c700bee84f3069f2469d7bc8c77852446" +dependencies = [ + "lexical-util", + "static_assertions", +] + +[[package]] +name = "libc" +version = "0.2.155" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "97b3888a4aecf77e811145cadf6eef5901f4782c53886191b2f693f24761847c" + +[[package]] +name = "libm" +version = "0.2.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4ec2a862134d2a7d32d7983ddcdd1c4923530833c9f2ea1a44fc5fa473989058" + +[[package]] +name = "lock_api" +version = "0.4.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "07af8b9cdd281b7915f413fa73f29ebd5d55d0d3f0155584dade1ff18cea1b17" +dependencies = [ + "autocfg", + "scopeguard", +] + +[[package]] +name = "log" +version = "0.4.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "90ed8c1e510134f979dbc4f070f87d4313098b704861a105fe34231c70a3901c" + +[[package]] +name = "lz4" +version = "1.24.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7e9e2dd86df36ce760a60f6ff6ad526f7ba1f14ba0356f8254fb6905e6494df1" +dependencies = [ + "libc", + "lz4-sys", +] + +[[package]] +name = "lz4-sys" +version = "1.9.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "57d27b317e207b10f69f5e75494119e391a96f48861ae870d1da6edac98ca900" +dependencies = [ + "cc", + "libc", +] + +[[package]] +name = "memchr" +version = "2.7.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6c8640c5d730cb13ebd907d8d04b52f55ac9a2eec55b440c8892f40d56c76c1d" + +[[package]] +name = "memoffset" +version = "0.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "488016bfae457b036d996092f6cb448677611ce4449e970ceaf42695203f218a" +dependencies = [ + "autocfg", +] + +[[package]] +name = "miniz_oxide" +version = "0.7.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "87dfd01fe195c66b572b37921ad8803d010623c0aca821bea2302239d155cdae" +dependencies = [ + "adler", +] + +[[package]] +name = "mio" +version = "0.8.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a4a650543ca06a924e8b371db273b2756685faae30f8487da1b56505a8f78b0c" +dependencies = [ + "libc", + "wasi", + "windows-sys 0.48.0", +] + +[[package]] +name = "num" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "35bd024e8b2ff75562e5f34e7f4905839deb4b22955ef5e73d2fea1b9813cb23" +dependencies = [ + "num-bigint", + "num-complex", + "num-integer", + "num-iter", + "num-rational", + "num-traits", +] + +[[package]] +name = "num-bigint" +version = "0.4.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c165a9ab64cf766f73521c0dd2cfdff64f488b8f0b3e621face3462d3db536d7" +dependencies = [ + "num-integer", + "num-traits", +] + +[[package]] +name = "num-complex" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "73f88a1307638156682bada9d7604135552957b7818057dcef22705b4d509495" +dependencies = [ + "num-traits", +] + +[[package]] +name = "num-integer" +version = "0.1.46" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f" +dependencies = [ + "num-traits", +] + +[[package]] +name = "num-iter" +version = "0.1.45" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1429034a0490724d0075ebb2bc9e875d6503c3cf69e235a8941aa757d83ef5bf" +dependencies = [ + "autocfg", + "num-integer", + "num-traits", +] + +[[package]] +name = "num-rational" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f83d14da390562dca69fc84082e73e548e1ad308d24accdedd2720017cb37824" +dependencies = [ + "num-bigint", + "num-integer", + "num-traits", +] + +[[package]] +name = "num-traits" +version = "0.2.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841" +dependencies = [ + "autocfg", + "libm", +] + +[[package]] +name = "num_cpus" +version = "1.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4161fcb6d602d4d2081af7c3a45852d875a03dd337a6bfdd6e06407b61342a43" +dependencies = [ + "hermit-abi", + "libc", +] + +[[package]] +name = "object" +version = "0.35.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b8ec7ab813848ba4522158d5517a6093db1ded27575b070f4177b8d12b41db5e" +dependencies = [ + "memchr", +] + +[[package]] +name = "once_cell" +version = "1.19.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92" + +[[package]] +name = "ordered-float" +version = "2.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "68f19d67e5a2795c94e73e0bb1cc1a7edeb2e28efd39e2e1c9b7a40c1108b11c" +dependencies = [ + "num-traits", +] + +[[package]] +name = "parking_lot" +version = "0.12.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f1bf18183cf54e8d6059647fc3063646a1801cf30896933ec2311622cc4b9a27" +dependencies = [ + "lock_api", + "parking_lot_core", +] + +[[package]] +name = "parking_lot_core" +version = "0.9.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e401f977ab385c9e4e3ab30627d6f26d00e2c73eef317493c4ec6d468726cf8" +dependencies = [ + "cfg-if", + "libc", + "redox_syscall", + "smallvec", + "windows-targets 0.52.5", +] + +[[package]] +name = "parquet" +version = "46.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1ad2cba786ae07da4d73371a88b9e0f9d3ffac1a9badc83922e0e15814f5c5fa" +dependencies = [ + "ahash", + "arrow-array", + "arrow-buffer", + "arrow-cast", + "arrow-data", + "arrow-ipc", + "arrow-schema", + "arrow-select", + "base64", + "brotli", + "bytes", + "chrono", + "flate2", + "futures", + "hashbrown", + "lz4", + "num", + "num-bigint", + "paste", + "seq-macro", + "snap", + "thrift", + "tokio", + "twox-hash", + "zstd", +] + +[[package]] +name = "paste" +version = "1.0.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a" + +[[package]] +name = "pin-project-lite" +version = "0.2.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bda66fc9667c18cb2758a2ac84d1167245054bcf85d5d1aaa6923f45801bdd02" + +[[package]] +name = "pin-utils" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" + +[[package]] +name = "pkg-config" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d231b230927b5e4ad203db57bbcbee2802f6bce620b1e4a9024a07d94e2907ec" + +[[package]] +name = "proc-macro2" +version = "1.0.84" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec96c6a92621310b51366f1e28d05ef11489516e93be030060e5fc12024a49d6" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "pyo3" +version = "0.19.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e681a6cfdc4adcc93b4d3cf993749a4552018ee0a9b65fc0ccfad74352c72a38" +dependencies = [ + "cfg-if", + "indoc", + "libc", + "memoffset", + "parking_lot", + "pyo3-build-config", + "pyo3-ffi", + "pyo3-macros", + "unindent", +] + +[[package]] +name = "pyo3-build-config" +version = "0.19.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "076c73d0bc438f7a4ef6fdd0c3bb4732149136abd952b110ac93e4edb13a6ba5" +dependencies = [ + "once_cell", + "target-lexicon", +] + +[[package]] +name = "pyo3-ffi" +version = "0.19.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e53cee42e77ebe256066ba8aa77eff722b3bb91f3419177cf4cd0f304d3284d9" +dependencies = [ + "libc", + "pyo3-build-config", +] + +[[package]] +name = "pyo3-macros" +version = "0.19.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dfeb4c99597e136528c6dd7d5e3de5434d1ceaf487436a3f03b2d56b6fc9efd1" +dependencies = [ + "proc-macro2", + "pyo3-macros-backend", + "quote", + "syn 1.0.109", +] + +[[package]] +name = "pyo3-macros-backend" +version = "0.19.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "947dc12175c254889edc0c02e399476c2f652b4b9ebd123aa655c224de259536" +dependencies = [ + "proc-macro2", + "quote", + "syn 1.0.109", +] + +[[package]] +name = "quote" +version = "1.0.36" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0fa76aaf39101c457836aec0ce2316dbdc3ab723cdda1c6bd4e6ad4208acaca7" +dependencies = [ + "proc-macro2", +] + +[[package]] +name = "redox_syscall" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "469052894dcb553421e483e4209ee581a45100d31b4018de03e5a7ad86374a7e" +dependencies = [ + "bitflags 2.5.0", +] + +[[package]] +name = "regex" +version = "1.10.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c117dbdfde9c8308975b6a18d71f3f385c89461f7b3fb054288ecf2a2058ba4c" +dependencies = [ + "aho-corasick", + "memchr", + "regex-automata", + "regex-syntax 0.8.3", +] + +[[package]] +name = "regex-automata" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "86b83b8b9847f9bf95ef68afb0b8e6cdb80f498442f5179a29fad448fcc1eaea" +dependencies = [ + "aho-corasick", + "memchr", + "regex-syntax 0.8.3", +] + +[[package]] +name = "regex-syntax" +version = "0.7.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dbb5fb1acd8a1a18b3dd5be62d25485eb770e05afb408a9627d14d451bae12da" + +[[package]] +name = "regex-syntax" +version = "0.8.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "adad44e29e4c806119491a7f06f03de4d1af22c3a680dd47f1e6e179439d1f56" + +[[package]] +name = "rust" +version = "0.1.0" +dependencies = [ + "arrow", + "bytes", + "futures", + "parquet", + "pyo3", + "tokio", +] + +[[package]] +name = "rustc-demangle" +version = "0.1.24" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "719b953e2095829ee67db738b3bfa9fa368c94900df327b3f07fe6e794d2fe1f" + +[[package]] +name = "rustc_version" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bfa0f585226d2e68097d4f95d113b15b83a82e819ab25717ec0590d9584ef366" +dependencies = [ + "semver", +] + +[[package]] +name = "ryu" +version = "1.0.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f3cb5ba0dc43242ce17de99c180e96db90b235b8a9fdc9543c96d2209116bd9f" + +[[package]] +name = "scopeguard" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" + +[[package]] +name = "semver" +version = "1.0.23" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "61697e0a1c7e512e84a621326239844a24d8207b4669b41bc18b32ea5cbf988b" + +[[package]] +name = "seq-macro" +version = "0.3.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a3f0bf26fd526d2a95683cd0f87bf103b8539e2ca1ef48ce002d67aad59aa0b4" + +[[package]] +name = "serde" +version = "1.0.203" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7253ab4de971e72fb7be983802300c30b5a7f0c2e56fab8abfc6a214307c0094" +dependencies = [ + "serde_derive", +] + +[[package]] +name = "serde_derive" +version = "1.0.203" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "500cbc0ebeb6f46627f50f3f5811ccf6bf00643be300b4c3eabc0ef55dc5b5ba" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.66", +] + +[[package]] +name = "serde_json" +version = "1.0.117" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "455182ea6142b14f93f4bc5320a2b31c1f266b66a4a5c858b013302a5d8cbfc3" +dependencies = [ + "itoa", + "ryu", + "serde", +] + +[[package]] +name = "signal-hook-registry" +version = "1.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a9e9e0b4211b72e7b8b6e85c807d36c212bdb33ea8587f7569562a84df5465b1" +dependencies = [ + "libc", +] + +[[package]] +name = "slab" +version = "0.4.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f92a496fb766b417c996b9c5e57daf2f7ad3b0bebe1ccfca4856390e3d3bb67" +dependencies = [ + "autocfg", +] + +[[package]] +name = "smallvec" +version = "1.13.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3c5e1a9a646d36c3599cd173a41282daf47c44583ad367b8e6837255952e5c67" + +[[package]] +name = "snap" +version = "1.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1b6b67fb9a61334225b5b790716f609cd58395f895b3fe8b328786812a40bc3b" + +[[package]] +name = "socket2" +version = "0.5.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ce305eb0b4296696835b71df73eb912e0f1ffd2556a501fcede6e0c50349191c" +dependencies = [ + "libc", + "windows-sys 0.52.0", +] + +[[package]] +name = "static_assertions" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f" + +[[package]] +name = "syn" +version = "1.0.109" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "syn" +version = "2.0.66" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c42f3f41a2de00b01c0aaad383c5a45241efc8b2d1eda5661812fda5f3cdcff5" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "target-lexicon" +version = "0.12.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e1fc403891a21bcfb7c37834ba66a547a8f402146eba7265b5a6d88059c9ff2f" + +[[package]] +name = "thrift" +version = "0.17.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7e54bc85fc7faa8bc175c4bab5b92ba8d9a3ce893d0e9f42cc455c8ab16a9e09" +dependencies = [ + "byteorder", + "integer-encoding", + "ordered-float", +] + +[[package]] +name = "tiny-keccak" +version = "2.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2c9d3793400a45f954c52e73d068316d76b6f4e36977e3fcebb13a2721e80237" +dependencies = [ + "crunchy", +] + +[[package]] +name = "tokio" +version = "1.38.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba4f4a02a7a80d6f274636f0aa95c7e383b912d41fe721a31f29e29698585a4a" +dependencies = [ + "backtrace", + "bytes", + "libc", + "mio", + "num_cpus", + "parking_lot", + "pin-project-lite", + "signal-hook-registry", + "socket2", + "tokio-macros", + "windows-sys 0.48.0", +] + +[[package]] +name = "tokio-macros" +version = "2.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5f5ae998a069d4b5aba8ee9dad856af7d520c3699e6159b185c2acd48155d39a" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.66", +] + +[[package]] +name = "twox-hash" +version = "1.6.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "97fee6b57c6a41524a810daee9286c02d7752c4253064d0b05472833a438f675" +dependencies = [ + "cfg-if", + "static_assertions", +] + +[[package]] +name = "unicode-ident" +version = "1.0.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b" + +[[package]] +name = "unindent" +version = "0.1.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e1766d682d402817b5ac4490b3c3002d91dfa0d22812f341609f97b08757359c" + +[[package]] +name = "version_check" +version = "0.9.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f" + +[[package]] +name = "wasi" +version = "0.11.0+wasi-snapshot-preview1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" + +[[package]] +name = "wasm-bindgen" +version = "0.2.92" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4be2531df63900aeb2bca0daaaddec08491ee64ceecbee5076636a3b026795a8" +dependencies = [ + "cfg-if", + "wasm-bindgen-macro", +] + +[[package]] +name = "wasm-bindgen-backend" +version = "0.2.92" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "614d787b966d3989fa7bb98a654e369c762374fd3213d212cfc0251257e747da" +dependencies = [ + "bumpalo", + "log", + "once_cell", + "proc-macro2", + "quote", + "syn 2.0.66", + "wasm-bindgen-shared", +] + +[[package]] +name = "wasm-bindgen-macro" +version = "0.2.92" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a1f8823de937b71b9460c0c34e25f3da88250760bec0ebac694b49997550d726" +dependencies = [ + "quote", + "wasm-bindgen-macro-support", +] + +[[package]] +name = "wasm-bindgen-macro-support" +version = "0.2.92" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e94f17b526d0a461a191c78ea52bbce64071ed5c04c9ffe424dcb38f74171bb7" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.66", + "wasm-bindgen-backend", + "wasm-bindgen-shared", +] + +[[package]] +name = "wasm-bindgen-shared" +version = "0.2.92" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "af190c94f2773fdb3729c55b007a722abb5384da03bc0986df4c289bf5567e96" + +[[package]] +name = "windows-core" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "33ab640c8d7e35bf8ba19b884ba838ceb4fba93a4e8c65a9059d08afcfc683d9" +dependencies = [ + "windows-targets 0.52.5", +] + +[[package]] +name = "windows-sys" +version = "0.48.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "677d2418bec65e3338edb076e806bc1ec15693c5d0104683f2efe857f61056a9" +dependencies = [ + "windows-targets 0.48.5", +] + +[[package]] +name = "windows-sys" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d" +dependencies = [ + "windows-targets 0.52.5", +] + +[[package]] +name = "windows-targets" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9a2fa6e2155d7247be68c096456083145c183cbbbc2764150dda45a87197940c" +dependencies = [ + "windows_aarch64_gnullvm 0.48.5", + "windows_aarch64_msvc 0.48.5", + "windows_i686_gnu 0.48.5", + "windows_i686_msvc 0.48.5", + "windows_x86_64_gnu 0.48.5", + "windows_x86_64_gnullvm 0.48.5", + "windows_x86_64_msvc 0.48.5", +] + +[[package]] +name = "windows-targets" +version = "0.52.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6f0713a46559409d202e70e28227288446bf7841d3211583a4b53e3f6d96e7eb" +dependencies = [ + "windows_aarch64_gnullvm 0.52.5", + "windows_aarch64_msvc 0.52.5", + "windows_i686_gnu 0.52.5", + "windows_i686_gnullvm", + "windows_i686_msvc 0.52.5", + "windows_x86_64_gnu 0.52.5", + "windows_x86_64_gnullvm 0.52.5", + "windows_x86_64_msvc 0.52.5", +] + +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2b38e32f0abccf9987a4e3079dfb67dcd799fb61361e53e2882c3cbaf0d905d8" + +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.52.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7088eed71e8b8dda258ecc8bac5fb1153c5cffaf2578fc8ff5d61e23578d3263" + +[[package]] +name = "windows_aarch64_msvc" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc35310971f3b2dbbf3f0690a219f40e2d9afcf64f9ab7cc1be722937c26b4bc" + +[[package]] +name = "windows_aarch64_msvc" +version = "0.52.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9985fd1504e250c615ca5f281c3f7a6da76213ebd5ccc9561496568a2752afb6" + +[[package]] +name = "windows_i686_gnu" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a75915e7def60c94dcef72200b9a8e58e5091744960da64ec734a6c6e9b3743e" + +[[package]] +name = "windows_i686_gnu" +version = "0.52.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "88ba073cf16d5372720ec942a8ccbf61626074c6d4dd2e745299726ce8b89670" + +[[package]] +name = "windows_i686_gnullvm" +version = "0.52.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "87f4261229030a858f36b459e748ae97545d6f1ec60e5e0d6a3d32e0dc232ee9" + +[[package]] +name = "windows_i686_msvc" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f55c233f70c4b27f66c523580f78f1004e8b5a8b659e05a4eb49d4166cca406" + +[[package]] +name = "windows_i686_msvc" +version = "0.52.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "db3c2bf3d13d5b658be73463284eaf12830ac9a26a90c717b7f771dfe97487bf" + +[[package]] +name = "windows_x86_64_gnu" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "53d40abd2583d23e4718fddf1ebec84dbff8381c07cae67ff7768bbf19c6718e" + +[[package]] +name = "windows_x86_64_gnu" +version = "0.52.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4e4246f76bdeff09eb48875a0fd3e2af6aada79d409d33011886d3e1581517d9" + +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b7b52767868a23d5bab768e390dc5f5c55825b6d30b86c844ff2dc7414044cc" + +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.52.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "852298e482cd67c356ddd9570386e2862b5673c85bd5f88df9ab6802b334c596" + +[[package]] +name = "windows_x86_64_msvc" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ed94fce61571a4006852b7389a063ab983c02eb1bb37b47f8272ce92d06d9538" + +[[package]] +name = "windows_x86_64_msvc" +version = "0.52.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bec47e5bfd1bff0eeaf6d8b485cc1074891a197ab4225d504cb7a1ab88b02bf0" + +[[package]] +name = "zerocopy" +version = "0.7.34" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ae87e3fcd617500e5d106f0380cf7b77f3c6092aae37191433159dda23cfb087" +dependencies = [ + "zerocopy-derive", +] + +[[package]] +name = "zerocopy-derive" +version = "0.7.34" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "15e934569e47891f7d9411f1a451d947a60e000ab3bd24fbb970f000387d1b3b" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.66", +] + +[[package]] +name = "zstd" +version = "0.12.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1a27595e173641171fc74a1232b7b1c7a7cb6e18222c11e9dfb9888fa424c53c" +dependencies = [ + "zstd-safe", +] + +[[package]] +name = "zstd-safe" +version = "6.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ee98ffd0b48ee95e6c5168188e44a54550b1564d9d530ee21d5f0eaed1069581" +dependencies = [ + "libc", + "zstd-sys", +] + +[[package]] +name = "zstd-sys" +version = "2.0.10+zstd.1.5.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c253a4914af5bafc8fa8c86ee400827e83cf6ec01195ec1f1ed8441bf00d65aa" +dependencies = [ + "cc", + "pkg-config", +] diff --git a/streaming/base/stream.py b/streaming/base/stream.py index 5edbaf5f5..5894e2257 100644 --- a/streaming/base/stream.py +++ b/streaming/base/stream.py @@ -854,7 +854,7 @@ def get_shards(self, world: World, allow_unsafe_types: bool) -> List[Reader]: return shards - @retry(num_attemps=2) + @retry(num_attempts=2) def _download_file(self, from_basename: str, to_basename: Optional[str] = None) -> str: """Safely download a file from remote to local cache. From e591f3585cf5185249e5a87bc8719ef6f66c6a91 Mon Sep 17 00:00:00 2001 From: xiaohanzhangcmu Date: Fri, 2 Aug 2024 15:18:03 -0700 Subject: [PATCH 049/145] update --- streaming/base/stream.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/streaming/base/stream.py b/streaming/base/stream.py index 5894e2257..ec04d507f 100644 --- a/streaming/base/stream.py +++ b/streaming/base/stream.py @@ -740,9 +740,9 @@ def __init__(self, if any([not warehouse_id, not host, not token, not catalog, not schema]): raise TypeError(f"Need to specify warehouse_id, host, token catalog, schema, during initialization") - self.base_url = f"https://{self.host}/api/2.0/sql/statements/" + self.base_url = f"https://{host}/api/2.0/sql/statements/" self.headers = { - "Authorization": f"Bearer {self.token}", + "Authorization": f"Bearer {token}", "Content-Type": "application/json" } self.data = { From 7332b0a3e789cdc693fcceedd480cb2f729f8ee0 Mon Sep 17 00:00:00 2001 From: xiaohanzhangcmu Date: Fri, 2 Aug 2024 15:25:51 -0700 Subject: [PATCH 050/145] update --- tests/test_streaming_remote.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_streaming_remote.py b/tests/test_streaming_remote.py index b9fe04ccb..243c0c0b4 100644 --- a/tests/test_streaming_remote.py +++ b/tests/test_streaming_remote.py @@ -68,7 +68,7 @@ def get_dataset(name: str, 'num_samples': 378156152, 'class': StreamingDataset, 'kwargs': { - 'warehouse_id': "0523-224100-tid6mais", + 'warehouse_id': "89cf2c9b9f9cb3bc", 'catalog': 'main', 'schema': 'reddit', }, From d3c7e2fc8243ea7a753ddac124c6abec2be6b874 Mon Sep 17 00:00:00 2001 From: xiaohanzhangcmu Date: Fri, 2 Aug 2024 15:26:43 -0700 Subject: [PATCH 051/145] update --- tests/test_streaming_remote.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_streaming_remote.py b/tests/test_streaming_remote.py index 243c0c0b4..8ea32519c 100644 --- a/tests/test_streaming_remote.py +++ b/tests/test_streaming_remote.py @@ -5,7 +5,7 @@ import time from typing import Any, Dict, Optional, Tuple -import pytest +#import pytest from streaming.base import StreamingDataset from streaming.text import StreamingC4 From fac3b289d174fcb8510ce072ea9518c32a6fda3f Mon Sep 17 00:00:00 2001 From: xiaohanzhangcmu Date: Fri, 2 Aug 2024 15:50:45 -0700 Subject: [PATCH 052/145] update --- streaming/base/stream.py | 3 +-- tests/test_streaming_remote.py | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/streaming/base/stream.py b/streaming/base/stream.py index ec04d507f..76495af1c 100644 --- a/streaming/base/stream.py +++ b/streaming/base/stream.py @@ -747,12 +747,11 @@ def __init__(self, } self.data = { "warehouse_id": warehouse_id, - "catalog": catalog, - "schema": schema, "format": "ARROW_STREAM", "disposition": "EXTERNAL_LINKS", "statement": remote, "wait_timeout": "2s", + "parameters": [], } def refresh_statement_id(self, timeout=100): diff --git a/tests/test_streaming_remote.py b/tests/test_streaming_remote.py index 8ea32519c..1ccb05b85 100644 --- a/tests/test_streaming_remote.py +++ b/tests/test_streaming_remote.py @@ -64,7 +64,7 @@ def get_dataset(name: str, }, 'reddit_table_dbsql': { 'local': f'/tmp/test_random_reddit_table_05May1029', - 'remote': 'SELECT text, added FROM main.reddit.data', + 'remote': 'SELECT * FROM main.reddit.data', 'num_samples': 378156152, 'class': StreamingDataset, 'kwargs': { From 2b632ef3e9a399287513a7579ff54971cc554254 Mon Sep 17 00:00:00 2001 From: xiaohanzhangcmu Date: Fri, 2 Aug 2024 16:02:56 -0700 Subject: [PATCH 053/145] update --- streaming/base/stream.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/streaming/base/stream.py b/streaming/base/stream.py index 76495af1c..9b43a1eec 100644 --- a/streaming/base/stream.py +++ b/streaming/base/stream.py @@ -750,7 +750,7 @@ def __init__(self, "format": "ARROW_STREAM", "disposition": "EXTERNAL_LINKS", "statement": remote, - "wait_timeout": "2s", + "wait_timeout": "5s", # cannot be less than 5 otherwise throws bad request error "parameters": [], } From 22105c3e6c573846263d06cbd4d088e81d9f8bc3 Mon Sep 17 00:00:00 2001 From: xiaohanzhangcmu Date: Fri, 2 Aug 2024 16:05:03 -0700 Subject: [PATCH 054/145] update --- streaming/base/stream.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/streaming/base/stream.py b/streaming/base/stream.py index 9b43a1eec..2d525ba59 100644 --- a/streaming/base/stream.py +++ b/streaming/base/stream.py @@ -783,17 +783,17 @@ def get_shards(self, world: World, allow_unsafe_types: bool) -> List[Reader]: Returns: `List[Reader]: Shard readers. """ - metadata = self.refresh_statement_id() + sql_response = self.refresh_statement_id() # Local leader prepares the index file based on cloudfetch results basename = get_index_basename() filename = os.path.join(self.local, self.split, basename) - self.columns = metadata['manifest']['schema']['columns'] + self.columns = sql_response['manifest']['schema']['columns'] column_names = [ c['name'] for c in self.columns ] column_encodings = [ c['type_name'].lower() for c in self.columns ] column_sizes = [ None for _ in self.columns ] - total_shard_count = metadata['manifest']['total_chunk_count'] + total_shard_count = sql_response['manifest']['total_chunk_count'] if world.is_local_leader: @@ -802,7 +802,7 @@ def get_shards(self, world: World, allow_unsafe_types: bool) -> List[Reader]: "shards": [] } - for shard_id, shard_meta in enumerate(metadata['manifest']['chunks']): + for shard_id, shard_meta in enumerate(sql_response['manifest']['chunks']): shard = { "column_encodings": column_encodings, "column_names": column_names, From 3de533e6a3841252d282b5c9714377c1b4b40488 Mon Sep 17 00:00:00 2001 From: xiaohanzhangcmu Date: Fri, 2 Aug 2024 16:23:29 -0700 Subject: [PATCH 055/145] update --- streaming/base/stream.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/streaming/base/stream.py b/streaming/base/stream.py index 2d525ba59..cdca33028 100644 --- a/streaming/base/stream.py +++ b/streaming/base/stream.py @@ -876,7 +876,11 @@ def fetch_and_convert(cloud_fetch_url: str, local_shard_path: str): os.rename(temp_mds_filename, local_shard_path) chunk_index = int(re.search(r'\d+', from_basename).group()) - cloud_fetch_url = f"{self.base_url}/{self.statement_id}/result/chunks/{chunk_index}" + print('from_basename = ', from_basename) + print('chunk_index = ', chunk_index) + response = requests.get(f"{self.base_url}/{self.statement_id}/result/chunks/{chunk_index}", headers = self.headers) + response.raise_for_status() + cloud_fetch_url = json.loads(response.decode('utf-8'))['external_links'][0]['external_link'] local = os.path.join(self.local, self.split, from_basename) # Attempt to download, possibly repeating on failure. @@ -886,5 +890,6 @@ def fetch_and_convert(cloud_fetch_url: str, local_shard_path: str): print('download to local is done = ', local) return local except: + print('Failed to download, refresh statement id and try again') self.refresh_statement_id() From 13c7d3394dabda1532e54a9d4860f48d78a28e94 Mon Sep 17 00:00:00 2001 From: xiaohanzhangcmu Date: Fri, 2 Aug 2024 16:24:38 -0700 Subject: [PATCH 056/145] update --- streaming/base/stream.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/streaming/base/stream.py b/streaming/base/stream.py index cdca33028..c57c56494 100644 --- a/streaming/base/stream.py +++ b/streaming/base/stream.py @@ -880,7 +880,7 @@ def fetch_and_convert(cloud_fetch_url: str, local_shard_path: str): print('chunk_index = ', chunk_index) response = requests.get(f"{self.base_url}/{self.statement_id}/result/chunks/{chunk_index}", headers = self.headers) response.raise_for_status() - cloud_fetch_url = json.loads(response.decode('utf-8'))['external_links'][0]['external_link'] + cloud_fetch_url = response.json()['external_links'][0]['external_link'] local = os.path.join(self.local, self.split, from_basename) # Attempt to download, possibly repeating on failure. From e273453a9fb217fbb79b21a67154c873261cfe46 Mon Sep 17 00:00:00 2001 From: xiaohanzhangcmu Date: Fri, 2 Aug 2024 16:26:28 -0700 Subject: [PATCH 057/145] update --- streaming/base/stream.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/streaming/base/stream.py b/streaming/base/stream.py index c57c56494..0d053f4e8 100644 --- a/streaming/base/stream.py +++ b/streaming/base/stream.py @@ -891,5 +891,7 @@ def fetch_and_convert(cloud_fetch_url: str, local_shard_path: str): return local except: print('Failed to download, refresh statement id and try again') + print('cloud_fetch_url = ', cloud_fetch_url) + print('response = ', response.json()) self.refresh_statement_id() From 9a0b09bfd6e1aeab834299e9da607aea7a0c82bc Mon Sep 17 00:00:00 2001 From: xiaohanzhangcmu Date: Fri, 2 Aug 2024 16:28:36 -0700 Subject: [PATCH 058/145] update --- streaming/base/stream.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/streaming/base/stream.py b/streaming/base/stream.py index 0d053f4e8..895014a98 100644 --- a/streaming/base/stream.py +++ b/streaming/base/stream.py @@ -889,9 +889,10 @@ def fetch_and_convert(cloud_fetch_url: str, local_shard_path: str): lambda: fetch_and_convert(cloud_fetch_url, local))() print('download to local is done = ', local) return local - except: + except Exception as e: print('Failed to download, refresh statement id and try again') print('cloud_fetch_url = ', cloud_fetch_url) print('response = ', response.json()) + print(e) self.refresh_statement_id() From 4bde02a9abc79cf137f50edc7a29d97662b98957 Mon Sep 17 00:00:00 2001 From: xiaohanzhangcmu Date: Fri, 2 Aug 2024 16:30:29 -0700 Subject: [PATCH 059/145] update --- streaming/base/stream.py | 1 + 1 file changed, 1 insertion(+) diff --git a/streaming/base/stream.py b/streaming/base/stream.py index 895014a98..3c9321c6b 100644 --- a/streaming/base/stream.py +++ b/streaming/base/stream.py @@ -894,5 +894,6 @@ def fetch_and_convert(cloud_fetch_url: str, local_shard_path: str): print('cloud_fetch_url = ', cloud_fetch_url) print('response = ', response.json()) print(e) + raise from e self.refresh_statement_id() From 7ce0c14e3c45ab1769b792adbf20a254e247d7ea Mon Sep 17 00:00:00 2001 From: xiaohanzhangcmu Date: Fri, 2 Aug 2024 16:31:00 -0700 Subject: [PATCH 060/145] update --- streaming/base/stream.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/streaming/base/stream.py b/streaming/base/stream.py index 3c9321c6b..6d3999c30 100644 --- a/streaming/base/stream.py +++ b/streaming/base/stream.py @@ -894,6 +894,6 @@ def fetch_and_convert(cloud_fetch_url: str, local_shard_path: str): print('cloud_fetch_url = ', cloud_fetch_url) print('response = ', response.json()) print(e) - raise from e + raise RuntimeError from e self.refresh_statement_id() From 0f3dbca78199d4618368554f49650d13462d382b Mon Sep 17 00:00:00 2001 From: xiaohanzhangcmu Date: Fri, 2 Aug 2024 16:35:19 -0700 Subject: [PATCH 061/145] update --- streaming/base/stream.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/streaming/base/stream.py b/streaming/base/stream.py index 6d3999c30..a9bf3fc6a 100644 --- a/streaming/base/stream.py +++ b/streaming/base/stream.py @@ -789,10 +789,11 @@ def get_shards(self, world: World, allow_unsafe_types: bool) -> List[Reader]: basename = get_index_basename() filename = os.path.join(self.local, self.split, basename) - self.columns = sql_response['manifest']['schema']['columns'] - column_names = [ c['name'] for c in self.columns ] - column_encodings = [ c['type_name'].lower() for c in self.columns ] - column_sizes = [ None for _ in self.columns ] + column_meta = sql_response['manifest']['schema']['columns'] + column_names = [ c['name'] for c in column_meta ] + column_encodings = [ c['type_name'].lower() for c in column_meta] + column_sizes = [ None for _ in column_meta ] + self.columns = { c['name'] : c['type_name'].lower() for c in column_meta } total_shard_count = sql_response['manifest']['total_chunk_count'] if world.is_local_leader: From d28f113b5aab35beb11d414b8d2fd5b1f91b5a66 Mon Sep 17 00:00:00 2001 From: xiaohanzhangcmu Date: Fri, 2 Aug 2024 16:50:02 -0700 Subject: [PATCH 062/145] update --- streaming/base/stream.py | 38 ++++++++++++++++++++++++++++++++++---- 1 file changed, 34 insertions(+), 4 deletions(-) diff --git a/streaming/base/stream.py b/streaming/base/stream.py index a9bf3fc6a..5d44cc672 100644 --- a/streaming/base/stream.py +++ b/streaming/base/stream.py @@ -754,6 +754,25 @@ def __init__(self, "parameters": [], } + # From dbsql dtyps (lower case) to MDS encoded types + # https://docs.databricks.com/en/dev-tools/python-sql-connector.html + self.dtypes_mapping = { + 'string' : 'str', + 'bigint' : 'int64', + 'array': 'ndarray', + 'binary': 'bytes', + 'boolean': 'uint32', + 'date': 'str', + 'datetime.date': 'str', + 'decimal': 'str_decimal', + 'double' : 'float64', + 'int': 'int', + 'map': 'json', + 'smallint': 'int16', + 'struct': 'json', + 'tinyint': 'int8', + } + def refresh_statement_id(self, timeout=100): total_time = 0 while total_time <= timeout: @@ -771,6 +790,12 @@ def refresh_statement_id(self, timeout=100): total_time += 3 raise TimeoutError(f"Query execution failed with status: {query_status}") + def get_encode_format(self, sql_fmt: str): + mds_fmt = self.dtypes_mapping.get(sql_fmt, None) + if not mds_fmt: + raise TypeError(f"{sql_fmt} is not supported by MDSWrite.") + return mds_fmt + def get_shards(self, world: World, allow_unsafe_types: bool) -> List[Reader]: """Load this Stream's index, retrieving its shard readers. @@ -790,10 +815,15 @@ def get_shards(self, world: World, allow_unsafe_types: bool) -> List[Reader]: filename = os.path.join(self.local, self.split, basename) column_meta = sql_response['manifest']['schema']['columns'] - column_names = [ c['name'] for c in column_meta ] - column_encodings = [ c['type_name'].lower() for c in column_meta] - column_sizes = [ None for _ in column_meta ] - self.columns = { c['name'] : c['type_name'].lower() for c in column_meta } + column_names, column_encodings, column_sizes = [], [], [] + self.columns = {} + for c in column_meta: + column_names.append(c['name']) + encoding = self.get_encode_format(c['type_name']) + column_encodings.append(encoding) + column_sizes.append(None) + self.columns[c['name']] = encoding + total_shard_count = sql_response['manifest']['total_chunk_count'] if world.is_local_leader: From 493f186b772721c3117503fe96e78b6aad447829 Mon Sep 17 00:00:00 2001 From: xiaohanzhangcmu Date: Fri, 2 Aug 2024 16:51:33 -0700 Subject: [PATCH 063/145] update --- streaming/base/stream.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/streaming/base/stream.py b/streaming/base/stream.py index 5d44cc672..53bc8b057 100644 --- a/streaming/base/stream.py +++ b/streaming/base/stream.py @@ -791,7 +791,7 @@ def refresh_statement_id(self, timeout=100): raise TimeoutError(f"Query execution failed with status: {query_status}") def get_encode_format(self, sql_fmt: str): - mds_fmt = self.dtypes_mapping.get(sql_fmt, None) + mds_fmt = self.dtypes_mapping.get(sql_fmt.lower(), None) if not mds_fmt: raise TypeError(f"{sql_fmt} is not supported by MDSWrite.") return mds_fmt From 0c3917af7ef8ba96dae9fceb492b1c22ee2f168f Mon Sep 17 00:00:00 2001 From: xiaohanzhangcmu Date: Fri, 2 Aug 2024 16:52:52 -0700 Subject: [PATCH 064/145] update --- streaming/base/stream.py | 1 + 1 file changed, 1 insertion(+) diff --git a/streaming/base/stream.py b/streaming/base/stream.py index 53bc8b057..030980f50 100644 --- a/streaming/base/stream.py +++ b/streaming/base/stream.py @@ -771,6 +771,7 @@ def __init__(self, 'smallint': 'int16', 'struct': 'json', 'tinyint': 'int8', + 'long': 'int8', } def refresh_statement_id(self, timeout=100): From 5ba220091568660fa7675063e5969475cfc6e430 Mon Sep 17 00:00:00 2001 From: xiaohanzhangcmu Date: Fri, 2 Aug 2024 20:06:57 -0700 Subject: [PATCH 065/145] update --- streaming/base/stream.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/streaming/base/stream.py b/streaming/base/stream.py index 030980f50..2e48723ca 100644 --- a/streaming/base/stream.py +++ b/streaming/base/stream.py @@ -784,6 +784,7 @@ def refresh_statement_id(self, timeout=100): if query_status == "SUCCEEDED": self.statement_id = response_data['statement_id'] + save_dict_to_file(self.local, f'response_{int(time.time()}', response_data) return response_data print(f"Query status: {query_status}") @@ -926,6 +927,5 @@ def fetch_and_convert(cloud_fetch_url: str, local_shard_path: str): print('cloud_fetch_url = ', cloud_fetch_url) print('response = ', response.json()) print(e) - raise RuntimeError from e self.refresh_statement_id() From 0bda7a2e48295dde3b334f00fc13dcdaf5d867a8 Mon Sep 17 00:00:00 2001 From: xiaohanzhangcmu Date: Fri, 2 Aug 2024 20:07:36 -0700 Subject: [PATCH 066/145] update --- streaming/base/stream.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/streaming/base/stream.py b/streaming/base/stream.py index 2e48723ca..d67812624 100644 --- a/streaming/base/stream.py +++ b/streaming/base/stream.py @@ -784,7 +784,7 @@ def refresh_statement_id(self, timeout=100): if query_status == "SUCCEEDED": self.statement_id = response_data['statement_id'] - save_dict_to_file(self.local, f'response_{int(time.time()}', response_data) + save_dict_to_file(self.local, f'response_{int(time.time())}', response_data) return response_data print(f"Query status: {query_status}") From 9d8e642e6a092d01807a16ef6c242e07a21df0e9 Mon Sep 17 00:00:00 2001 From: xiaohanzhangcmu Date: Fri, 2 Aug 2024 20:08:15 -0700 Subject: [PATCH 067/145] update --- streaming/base/stream.py | 1 + 1 file changed, 1 insertion(+) diff --git a/streaming/base/stream.py b/streaming/base/stream.py index d67812624..cc179db2b 100644 --- a/streaming/base/stream.py +++ b/streaming/base/stream.py @@ -5,6 +5,7 @@ import hashlib import json +import time import os import tempfile from typing import List, Optional, Sequence, Tuple, Any From e13fd7110150db155820cb926043a43ece82e24e Mon Sep 17 00:00:00 2001 From: xiaohanzhangcmu Date: Fri, 2 Aug 2024 20:24:55 -0700 Subject: [PATCH 068/145] update --- streaming/base/stream.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/streaming/base/stream.py b/streaming/base/stream.py index cc179db2b..746bc7239 100644 --- a/streaming/base/stream.py +++ b/streaming/base/stream.py @@ -859,6 +859,18 @@ def get_shards(self, world: World, allow_unsafe_types: bool) -> List[Reader]: with open(filename, 'w') as f: json.dump(metadata, f, indent=4) + # Load the index. + try: + obj = json.load(open(filename)) + except json.decoder.JSONDecodeError as error: + error.args = (f'Index file at {filename} is empty or corrupted. ' + error.args[0],) + raise error + + # Version check. + if obj['version'] != 2: + raise ValueError(f'Unsupported streaming data version: {obj["version"]}. ' + + f'Expected version 2.') + else: wait_for_file_to_exist( filename, TICK, self.download_timeout, @@ -866,18 +878,6 @@ def get_shards(self, world: World, allow_unsafe_types: bool) -> List[Reader]: f'-> {filename} took too long to download. Either increase the ' + f'`download_timeout` value or check the other traceback.') - # Load the index. - try: - obj = json.load(open(filename)) - except json.decoder.JSONDecodeError as error: - error.args = (f'Index file at {filename} is empty or corrupted. ' + error.args[0],) - raise error - - # Version check. - if obj['version'] != 2: - raise ValueError(f'Unsupported streaming data version: {obj["version"]}. ' + - f'Expected version 2.') - # Initialize shard readers according to the loaded info. shards = [] for info in obj['shards']: From 5e95abaf42508a3cef76c5307d34f342f7634d44 Mon Sep 17 00:00:00 2001 From: xiaohanzhangcmu Date: Fri, 2 Aug 2024 20:26:10 -0700 Subject: [PATCH 069/145] update --- streaming/base/stream.py | 25 ++++++++++++------------- 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/streaming/base/stream.py b/streaming/base/stream.py index 746bc7239..c3134aaf5 100644 --- a/streaming/base/stream.py +++ b/streaming/base/stream.py @@ -858,19 +858,6 @@ def get_shards(self, world: World, allow_unsafe_types: bool) -> List[Reader]: with open(filename, 'w') as f: json.dump(metadata, f, indent=4) - - # Load the index. - try: - obj = json.load(open(filename)) - except json.decoder.JSONDecodeError as error: - error.args = (f'Index file at {filename} is empty or corrupted. ' + error.args[0],) - raise error - - # Version check. - if obj['version'] != 2: - raise ValueError(f'Unsupported streaming data version: {obj["version"]}. ' + - f'Expected version 2.') - else: wait_for_file_to_exist( filename, TICK, self.download_timeout, @@ -878,6 +865,18 @@ def get_shards(self, world: World, allow_unsafe_types: bool) -> List[Reader]: f'-> {filename} took too long to download. Either increase the ' + f'`download_timeout` value or check the other traceback.') + # Load the index. + try: + obj = json.load(open(filename)) + except json.decoder.JSONDecodeError as error: + error.args = (f'Index file at {filename} is empty or corrupted. ' + error.args[0],) + raise error + + # Version check. + if obj['version'] != 2: + raise ValueError(f'Unsupported streaming data version: {obj["version"]}. ' + + f'Expected version 2.') + # Initialize shard readers according to the loaded info. shards = [] for info in obj['shards']: From 4f29dfacd3d6db3570dad924b0991c76e558586d Mon Sep 17 00:00:00 2001 From: xiaohanzhangcmu Date: Fri, 2 Aug 2024 20:32:48 -0700 Subject: [PATCH 070/145] update --- streaming/base/stream.py | 4 ++-- streaming/base/util.py | 31 +++++++++++++++++++++++++++++++ 2 files changed, 33 insertions(+), 2 deletions(-) diff --git a/streaming/base/stream.py b/streaming/base/stream.py index c3134aaf5..5c4744034 100644 --- a/streaming/base/stream.py +++ b/streaming/base/stream.py @@ -20,7 +20,7 @@ from streaming.base.format import FileInfo, Reader, get_index_basename, reader_from_json from streaming.base.hashing import get_hash from streaming.base.storage import download_file -from streaming.base.util import retry, wait_for_file_to_exist +from streaming.base.util import retry, wait_for_file_to_exist, wait_for_json_to_exist from streaming.base.world import World import re @@ -859,7 +859,7 @@ def get_shards(self, world: World, allow_unsafe_types: bool) -> List[Reader]: with open(filename, 'w') as f: json.dump(metadata, f, indent=4) else: - wait_for_file_to_exist( + wait_for_json_to_exist( filename, TICK, self.download_timeout, f'Index file {os.path.join(self.remote or "", self.split or "", basename)} ' + f'-> {filename} took too long to download. Either increase the ' + diff --git a/streaming/base/util.py b/streaming/base/util.py index 3be5b729a..02a5a2759 100644 --- a/streaming/base/util.py +++ b/streaming/base/util.py @@ -46,6 +46,37 @@ def get_list_arg(text: str) -> List[str]: """ return text.split(',') if text else [] +def wait_for_json_to_exist(filename: str, poll_interval: float, timeout: float, + err_msg: str) -> None: + """Wait for a json to exist till timeout seconds. Raise an Exception after that. + + Difference from wait_for_file_to_exist is that we load json and validate. + + Args: + filename (str): A file name of a json + poll_interval (float): Number of seconds to wait before next polling + timeout (float): Number of seconds to wait for a file to exist before raising an exception + err_msg (str): Error message description for an exception + + Raises: + RuntimeError: Raise an Exception if file does not exist after timeout + """ + def is_valid_json(filename): + try: + obj = json.load(open(filename)) + return True + except json.decoder.JSONDecodeError as error: + return False + + start_time = time() + while True: + sleep(poll_interval) + if os.path.exists(filename) and is_valid_json(filename): + sleep(poll_interval) + break + dt = time() - start_time + if dt > timeout: + raise RuntimeError(f'{err_msg}' + f'{timeout:.3f} < {dt:.3f} secs.') def wait_for_file_to_exist(filename: str, poll_interval: float, timeout: float, err_msg: str) -> None: From ee5b568db9210902c8f912831df84d4dea3aa45c Mon Sep 17 00:00:00 2001 From: xiaohanzhangcmu Date: Fri, 2 Aug 2024 20:42:19 -0700 Subject: [PATCH 071/145] update --- streaming/base/util.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/streaming/base/util.py b/streaming/base/util.py index 02a5a2759..e8b721a36 100644 --- a/streaming/base/util.py +++ b/streaming/base/util.py @@ -27,6 +27,9 @@ logger = logging.getLogger(__name__) +logging.basicConfig(level=logging.INFO, format='%(asctime)s [Process %(process)d, Thread %(thread)d] %(message)s') + + TCallable = TypeVar('TCallable', bound=Callable) __all__ = [ @@ -72,6 +75,7 @@ def is_valid_json(filename): while True: sleep(poll_interval) if os.path.exists(filename) and is_valid_json(filename): + logging.warning('json has read in') sleep(poll_interval) break dt = time() - start_time From 18ce2774d4737c4f4a0d0e22f1600ac428f8166d Mon Sep 17 00:00:00 2001 From: xiaohanzhangcmu Date: Fri, 2 Aug 2024 20:51:29 -0700 Subject: [PATCH 072/145] update --- tests/test_streaming_remote.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_streaming_remote.py b/tests/test_streaming_remote.py index 1ccb05b85..8b6877c55 100644 --- a/tests/test_streaming_remote.py +++ b/tests/test_streaming_remote.py @@ -115,7 +115,7 @@ def test_streaming_remote_dataset(name: str, split: str) -> None: for _ in dataset: rcvd_samples += 1 - if (rcvd_samples % 100 == 0): + if (rcvd_samples % 10000 == 0): print(f'samples read: {rcvd_samples}') iter_end = time.time() From 792efe94b8233d51514f656e3cc10f858cbd5fa5 Mon Sep 17 00:00:00 2001 From: xiaohanzhangcmu Date: Fri, 2 Aug 2024 20:55:48 -0700 Subject: [PATCH 073/145] update --- tests/test_streaming_remote.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_streaming_remote.py b/tests/test_streaming_remote.py index 8b6877c55..1f3018cb5 100644 --- a/tests/test_streaming_remote.py +++ b/tests/test_streaming_remote.py @@ -64,7 +64,7 @@ def get_dataset(name: str, }, 'reddit_table_dbsql': { 'local': f'/tmp/test_random_reddit_table_05May1029', - 'remote': 'SELECT * FROM main.reddit.data', + 'remote': 'SELECT text, added FROM main.reddit.data', 'num_samples': 378156152, 'class': StreamingDataset, 'kwargs': { From d0922bdb326f9ad698859a3b09cfc9ab294322ae Mon Sep 17 00:00:00 2001 From: xiaohanzhangcmu Date: Fri, 2 Aug 2024 22:53:55 -0700 Subject: [PATCH 074/145] update --- streaming/base/stream.py | 50 +++++++++++++++++++++------------------- 1 file changed, 26 insertions(+), 24 deletions(-) diff --git a/streaming/base/stream.py b/streaming/base/stream.py index 5c4744034..c3429c636 100644 --- a/streaming/base/stream.py +++ b/streaming/base/stream.py @@ -886,7 +886,20 @@ def get_shards(self, world: World, allow_unsafe_types: bool) -> List[Reader]: return shards - @retry(num_attempts=2) + def _make_request(self, url: str) -> requests.Response: + response = requests.get(url, headers=self.headers) + response.raise_for_status() + return response + + def _fetch_and_convert(self, cloud_fetch_url: str, local_shard_path: str): + samples = pa.ipc.open_stream(requests.get(cloud_fetch_url).content).read_all().to_pylist() + with TemporaryDirectory() as temp_dir: + with MDSWriter(columns=self.columns, out=temp_dir, size_limit=None) as out: + for sample in samples: + out.write(sample) + temp_mds_filename = os.path.join(temp_dir, 'shard.00000.mds') + os.rename(temp_mds_filename, local_shard_path) + def _download_file(self, from_basename: str, to_basename: Optional[str] = None) -> str: """Safely download a file from remote to local cache. @@ -897,35 +910,24 @@ def _download_file(self, from_basename: str, to_basename: Optional[str] = None) Returns: str: Local cache filename. """ - from streaming import MDSWriter - - def fetch_and_convert(cloud_fetch_url: str, local_shard_path: str): - samples = pa.ipc.open_stream(requests.get(cloud_fetch_url).content).read_all().to_pylist() - with TemporaryDirectory() as temp_dir: - with MDSWriter(columns=self.columns, out=temp_dir, size_limit=None) as out: - for sample in samples: - out.write(sample) - temp_mds_filename = os.path.join(temp_dir, 'shard.00000.mds') - os.rename(temp_mds_filename, local_shard_path) - chunk_index = int(re.search(r'\d+', from_basename).group()) print('from_basename = ', from_basename) print('chunk_index = ', chunk_index) - response = requests.get(f"{self.base_url}/{self.statement_id}/result/chunks/{chunk_index}", headers = self.headers) - response.raise_for_status() - cloud_fetch_url = response.json()['external_links'][0]['external_link'] - local = os.path.join(self.local, self.split, from_basename) - # Attempt to download, possibly repeating on failure. + url = f"{self.base_url}/{self.statement_id}/result/chunks/{chunk_index}" + try: - retry(num_attempts=self.download_retry)( - lambda: fetch_and_convert(cloud_fetch_url, local))() - print('download to local is done = ', local) - return local - except Exception as e: + response = self._make_request(url) + except Exception as e: # requests.exceptions.HTTPError as e: print('Failed to download, refresh statement id and try again') - print('cloud_fetch_url = ', cloud_fetch_url) - print('response = ', response.json()) + print('url = ', url) print(e) self.refresh_statement_id() + response = self._make_request(url) + cloud_fetch_url = response.json()['external_links'][0]['external_link'] + local = os.path.join(self.local, self.split, from_basename) + retry(num_attempts=self.download_retry)(lambda: self._fetch_and_convert(cloud_fetch_url, local))() + + print('Download to local is done = ', local) + return local From c3c47154500dd8b0675e1b8aad15cd5072fed5a4 Mon Sep 17 00:00:00 2001 From: xiaohanzhangcmu Date: Fri, 2 Aug 2024 22:55:17 -0700 Subject: [PATCH 075/145] update --- streaming/base/stream.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/streaming/base/stream.py b/streaming/base/stream.py index c3429c636..a14bb0a37 100644 --- a/streaming/base/stream.py +++ b/streaming/base/stream.py @@ -887,9 +887,13 @@ def get_shards(self, world: World, allow_unsafe_types: bool) -> List[Reader]: return shards def _make_request(self, url: str) -> requests.Response: - response = requests.get(url, headers=self.headers) - response.raise_for_status() - return response +# response = requests.get(url, headers=self.headers) +# response.raise_for_status() +# return response + response = requests.Response() + response.status_code = 404 + response.url = url + raise requests.exceptions.HTTPError("Manually raised HTTPError for testing purposes", response=response) def _fetch_and_convert(self, cloud_fetch_url: str, local_shard_path: str): samples = pa.ipc.open_stream(requests.get(cloud_fetch_url).content).read_all().to_pylist() From 75401ed536341198635c75e4962bf2effb24376b Mon Sep 17 00:00:00 2001 From: xiaohanzhangcmu Date: Fri, 2 Aug 2024 22:59:31 -0700 Subject: [PATCH 076/145] update --- streaming/base/stream.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/streaming/base/stream.py b/streaming/base/stream.py index a14bb0a37..9d3a08881 100644 --- a/streaming/base/stream.py +++ b/streaming/base/stream.py @@ -887,13 +887,16 @@ def get_shards(self, world: World, allow_unsafe_types: bool) -> List[Reader]: return shards def _make_request(self, url: str) -> requests.Response: -# response = requests.get(url, headers=self.headers) -# response.raise_for_status() -# return response - response = requests.Response() - response.status_code = 404 - response.url = url - raise requests.exceptions.HTTPError("Manually raised HTTPError for testing purposes", response=response) + import random + if random.random() < 0.2: # 20% of the time + response = requests.Response() + response.status_code = 404 + response.url = url + raise requests.exceptions.HTTPError("Manually raised HTTPError for testing purposes", response=response) + else: + response = requests.get(url, headers=self.headers) + response.raise_for_status() + return response def _fetch_and_convert(self, cloud_fetch_url: str, local_shard_path: str): samples = pa.ipc.open_stream(requests.get(cloud_fetch_url).content).read_all().to_pylist() From 4c8545dc778e52e82bee234d86e80fbf9cd608e7 Mon Sep 17 00:00:00 2001 From: xiaohanzhangcmu Date: Fri, 2 Aug 2024 23:03:08 -0700 Subject: [PATCH 077/145] update --- streaming/base/stream.py | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/streaming/base/stream.py b/streaming/base/stream.py index 9d3a08881..f758021cb 100644 --- a/streaming/base/stream.py +++ b/streaming/base/stream.py @@ -898,15 +898,6 @@ def _make_request(self, url: str) -> requests.Response: response.raise_for_status() return response - def _fetch_and_convert(self, cloud_fetch_url: str, local_shard_path: str): - samples = pa.ipc.open_stream(requests.get(cloud_fetch_url).content).read_all().to_pylist() - with TemporaryDirectory() as temp_dir: - with MDSWriter(columns=self.columns, out=temp_dir, size_limit=None) as out: - for sample in samples: - out.write(sample) - temp_mds_filename = os.path.join(temp_dir, 'shard.00000.mds') - os.rename(temp_mds_filename, local_shard_path) - def _download_file(self, from_basename: str, to_basename: Optional[str] = None) -> str: """Safely download a file from remote to local cache. @@ -917,6 +908,16 @@ def _download_file(self, from_basename: str, to_basename: Optional[str] = None) Returns: str: Local cache filename. """ + from streaming import MDSWriter + def _fetch_and_convert(cloud_fetch_url: str, local_shard_path: str): + samples = pa.ipc.open_stream(requests.get(cloud_fetch_url).content).read_all().to_pylist() + with TemporaryDirectory() as temp_dir: + with MDSWriter(columns=self.columns, out=temp_dir, size_limit=None) as out: + for sample in samples: + out.write(sample) + temp_mds_filename = os.path.join(temp_dir, 'shard.00000.mds') + os.rename(temp_mds_filename, local_shard_path) + chunk_index = int(re.search(r'\d+', from_basename).group()) print('from_basename = ', from_basename) print('chunk_index = ', chunk_index) @@ -934,7 +935,7 @@ def _download_file(self, from_basename: str, to_basename: Optional[str] = None) cloud_fetch_url = response.json()['external_links'][0]['external_link'] local = os.path.join(self.local, self.split, from_basename) - retry(num_attempts=self.download_retry)(lambda: self._fetch_and_convert(cloud_fetch_url, local))() + retry(num_attempts=self.download_retry)(lambda: ._fetch_and_convert(cloud_fetch_url, local))() print('Download to local is done = ', local) return local From df103e80b5d3312dd74823948b93ad1974c3b29a Mon Sep 17 00:00:00 2001 From: xiaohanzhangcmu Date: Fri, 2 Aug 2024 23:03:55 -0700 Subject: [PATCH 078/145] update --- streaming/base/stream.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/streaming/base/stream.py b/streaming/base/stream.py index f758021cb..56a56d237 100644 --- a/streaming/base/stream.py +++ b/streaming/base/stream.py @@ -935,7 +935,7 @@ def _fetch_and_convert(cloud_fetch_url: str, local_shard_path: str): cloud_fetch_url = response.json()['external_links'][0]['external_link'] local = os.path.join(self.local, self.split, from_basename) - retry(num_attempts=self.download_retry)(lambda: ._fetch_and_convert(cloud_fetch_url, local))() + retry(num_attempts=self.download_retry)(lambda: _fetch_and_convert(cloud_fetch_url, local))() print('Download to local is done = ', local) return local From 0eb61aa0c00c614f4c8b321ddde87def1b5683c2 Mon Sep 17 00:00:00 2001 From: xiaohanzhangcmu Date: Fri, 2 Aug 2024 23:10:19 -0700 Subject: [PATCH 079/145] update --- streaming/base/stream.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/streaming/base/stream.py b/streaming/base/stream.py index 56a56d237..875303e72 100644 --- a/streaming/base/stream.py +++ b/streaming/base/stream.py @@ -24,6 +24,7 @@ from streaming.base.world import World import re +import random import pyarrow as pa import requests from tempfile import TemporaryDirectory @@ -887,12 +888,11 @@ def get_shards(self, world: World, allow_unsafe_types: bool) -> List[Reader]: return shards def _make_request(self, url: str) -> requests.Response: - import random - if random.random() < 0.2: # 20% of the time + if random.random() < 0.3: # 20% of the time response = requests.Response() response.status_code = 404 response.url = url - raise requests.exceptions.HTTPError("Manually raised HTTPError for testing purposes", response=response) + raise requests.exceptions.HTTPError(f"Manually raised HTTPError for testing purposes: {int(time.time())}", response=response) else: response = requests.get(url, headers=self.headers) response.raise_for_status() From 3cd0f240b0d9810379b251de2a4c9896db1b20a4 Mon Sep 17 00:00:00 2001 From: xiaohanzhangcmu Date: Fri, 2 Aug 2024 23:12:00 -0700 Subject: [PATCH 080/145] update --- streaming/base/stream.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/streaming/base/stream.py b/streaming/base/stream.py index 875303e72..3f0aec02f 100644 --- a/streaming/base/stream.py +++ b/streaming/base/stream.py @@ -888,7 +888,7 @@ def get_shards(self, world: World, allow_unsafe_types: bool) -> List[Reader]: return shards def _make_request(self, url: str) -> requests.Response: - if random.random() < 0.3: # 20% of the time + if random.random() < 0.0: # make rhs > 0.0 for testing, so x% of the time return HTTPError response = requests.Response() response.status_code = 404 response.url = url From a6e1ec07f7616ee5743e3d6e3cb2b579cc8e60b4 Mon Sep 17 00:00:00 2001 From: xiaohanzhangcmu Date: Sat, 3 Aug 2024 00:06:58 -0700 Subject: [PATCH 081/145] update --- tests/test_streaming_remote.py | 69 +++++++++++++++++++++++++++++++--- 1 file changed, 63 insertions(+), 6 deletions(-) diff --git a/tests/test_streaming_remote.py b/tests/test_streaming_remote.py index 1f3018cb5..62bddde37 100644 --- a/tests/test_streaming_remote.py +++ b/tests/test_streaming_remote.py @@ -7,7 +7,7 @@ #import pytest -from streaming.base import StreamingDataset +from streaming.base import StreamingDataset, StreamingDataLoader from streaming.text import StreamingC4 from streaming.vision import StreamingADE20K, StreamingCIFAR10, StreamingCOCO, StreamingImageNet @@ -35,7 +35,7 @@ def get_dataset(name: str, 'cluster_id': "0201-234512-tcp9nfat" }, }, - 'random_cpt_table': { + 'random_cpt_table_sparkconnect': { 'local': f'/tmp/test_random_cpt_table_05May1029', 'remote': 'SELECT text FROM main.streaming.random_cpt_table', 'num_samples': 100000, @@ -44,6 +44,17 @@ def get_dataset(name: str, 'cluster_id': "0201-234512-tcp9nfat" }, }, + 'random_cpt_table_dbsql': { + 'local': f'/tmp/test_random_cpt_table_05May1029', + 'remote': 'SELECT text FROM main.streaming.random_cpt_table', + 'num_samples': 100000, + 'class': StreamingDataset, + 'kwargs': { + 'warehouse_id': "7e083095329f3ca5", + 'catalog': 'main', + 'schema': 'streaming', + }, + }, 'random_large_table': { 'local': f'/tmp/test_random_large_table_05May1029', 'remote': 'SELECT * FROM main.streaming.random_large_table', @@ -130,14 +141,60 @@ def test_streaming_remote_dataset(name: str, split: str) -> None: # Test all samples arrived assert rcvd_samples >= expected_samples +def test_streaming_remote_dataloader(name: str, split: str) -> None: + # Build StreamingDataset + build_start = time.time() + batch_size = 16 + expected_samples, dataset = get_dataset(name=name, + split=split, + shuffle=False, + batch_size=batch_size) + + + data_loader = StreamingDataLoader(dataset, + batch_size=16, + num_workers=4, + prefetch_factor=None, + #persistent_workers=True, + pin_memory=True, + drop_last=True) + build_end = time.time() + build_dur = build_end - build_start + print('Built dataset') + + # Test basic iteration + rcvd_samples = 0 + iter_start = time.time() + + for batch_idx, data_dict in enumerate(data_loader): + rcvd_samples += batch_size + + if (rcvd_samples % (10*batch_size) == 0): + print(f'samples read: {rcvd_samples}') + + iter_end = time.time() + iter_dur = iter_end - iter_start + samples_per_sec = rcvd_samples / iter_dur + + # Print debug info + print(f'received {rcvd_samples} samples') + print(f'build_dur={build_dur:.2f}s, iter_dur={iter_dur:.2f}, ' + + f'samples_per_sec={samples_per_sec:.2f}') + + # Test all samples arrived + assert rcvd_samples >= expected_samples + -#if __name__ == "__main__": # test_streaming_remote_dataset(name = 'refinedweb', split=None) # test_streaming_remote_dataset(name = 'dummy_table', split=None) -# test_streaming_remote_dataset(name = 'random_cpt_table', split=None) -# test_streaming_remote_dataset(name = 'random_large_table', split=None) +#test_streaming_remote_dataset(name = 'random_cpt_table_dbsql', split=None) +# test_streaming_remote_dataset(name = 'random_large_table', split=None) # test_streaming_remote_dataset(name = 'reddit_table', split=None) -test_streaming_remote_dataset(name = 'reddit_table_dbsql', split=None) +#test_streaming_remote_dataset(name = 'reddit_table_dbsql', split=None) # test_streaming_remote_dataset(name = 'debug_local', split=None) +if __name__ == "__main__": + #test_streaming_remote_dataloader(name = 'refinedweb', split=None) + test_streaming_remote_dataloader(name = 'random_cpt_table_dbsql', split=None) + From 2e42bc911ac9399067091a658cce64f639a81cf4 Mon Sep 17 00:00:00 2001 From: xiaohanzhangcmu Date: Sat, 3 Aug 2024 00:21:28 -0700 Subject: [PATCH 082/145] update --- streaming/base/stream.py | 3 ++- tests/test_streaming_remote.py | 6 +++--- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/streaming/base/stream.py b/streaming/base/stream.py index 3f0aec02f..66b8f6880 100644 --- a/streaming/base/stream.py +++ b/streaming/base/stream.py @@ -922,15 +922,16 @@ def _fetch_and_convert(cloud_fetch_url: str, local_shard_path: str): print('from_basename = ', from_basename) print('chunk_index = ', chunk_index) - url = f"{self.base_url}/{self.statement_id}/result/chunks/{chunk_index}" try: + url = f"{self.base_url}/{self.statement_id}/result/chunks/{chunk_index}" response = self._make_request(url) except Exception as e: # requests.exceptions.HTTPError as e: print('Failed to download, refresh statement id and try again') print('url = ', url) print(e) self.refresh_statement_id() + url = f"{self.base_url}/{self.statement_id}/result/chunks/{chunk_index}" response = self._make_request(url) cloud_fetch_url = response.json()['external_links'][0]['external_link'] diff --git a/tests/test_streaming_remote.py b/tests/test_streaming_remote.py index 62bddde37..4662889ca 100644 --- a/tests/test_streaming_remote.py +++ b/tests/test_streaming_remote.py @@ -185,16 +185,16 @@ def test_streaming_remote_dataloader(name: str, split: str) -> None: assert rcvd_samples >= expected_samples +if __name__ == "__main__": # test_streaming_remote_dataset(name = 'refinedweb', split=None) # test_streaming_remote_dataset(name = 'dummy_table', split=None) #test_streaming_remote_dataset(name = 'random_cpt_table_dbsql', split=None) # test_streaming_remote_dataset(name = 'random_large_table', split=None) # test_streaming_remote_dataset(name = 'reddit_table', split=None) -#test_streaming_remote_dataset(name = 'reddit_table_dbsql', split=None) + test_streaming_remote_dataset(name = 'reddit_table_dbsql', split=None) # test_streaming_remote_dataset(name = 'debug_local', split=None) -if __name__ == "__main__": #test_streaming_remote_dataloader(name = 'refinedweb', split=None) - test_streaming_remote_dataloader(name = 'random_cpt_table_dbsql', split=None) + # test_streaming_remote_dataloader(name = 'random_cpt_table_dbsql', split=None) From 82b04d5e4c8b8af43bf82e0c1067da1c17d6883e Mon Sep 17 00:00:00 2001 From: xiaohanzhangcmu Date: Sat, 3 Aug 2024 14:29:59 -0700 Subject: [PATCH 083/145] update --- streaming/base/stream.py | 1 + 1 file changed, 1 insertion(+) diff --git a/streaming/base/stream.py b/streaming/base/stream.py index 66b8f6880..93c2c21f4 100644 --- a/streaming/base/stream.py +++ b/streaming/base/stream.py @@ -754,6 +754,7 @@ def __init__(self, "statement": remote, "wait_timeout": "5s", # cannot be less than 5 otherwise throws bad request error "parameters": [], + "byte_limit": 9223372036854775807, } # From dbsql dtyps (lower case) to MDS encoded types From 184c44d5c67832307126d15a89e15cbd654b884c Mon Sep 17 00:00:00 2001 From: xiaohanzhangcmu Date: Sat, 3 Aug 2024 14:32:58 -0700 Subject: [PATCH 084/145] update --- streaming/base/stream.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/streaming/base/stream.py b/streaming/base/stream.py index 93c2c21f4..95c494032 100644 --- a/streaming/base/stream.py +++ b/streaming/base/stream.py @@ -754,7 +754,7 @@ def __init__(self, "statement": remote, "wait_timeout": "5s", # cannot be less than 5 otherwise throws bad request error "parameters": [], - "byte_limit": 9223372036854775807, + "byte_limit": 10000000000000, } # From dbsql dtyps (lower case) to MDS encoded types From 0e36d21461e6f5900e4dfc5573fa7db10e0155ff Mon Sep 17 00:00:00 2001 From: xiaohanzhangcmu Date: Sat, 3 Aug 2024 22:58:18 -0700 Subject: [PATCH 085/145] update --- setup.py | 1 + tests/test_streaming_remote.py | 4 +++- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 92c9aa012..517a4af82 100644 --- a/setup.py +++ b/setup.py @@ -59,6 +59,7 @@ 'azure-storage-file-datalake>=12.11.0,<13', 'azure-identity>=1.13.0', 'databricks-connect>=14.3.0', + 'pyarrow>=17,<18', ] extra_deps = {} diff --git a/tests/test_streaming_remote.py b/tests/test_streaming_remote.py index 4662889ca..5eed5e4f5 100644 --- a/tests/test_streaming_remote.py +++ b/tests/test_streaming_remote.py @@ -186,6 +186,8 @@ def test_streaming_remote_dataloader(name: str, split: str) -> None: if __name__ == "__main__": + from streaming.base.util import clean_stale_shared_memory + clean_stale_shared_memory() # test_streaming_remote_dataset(name = 'refinedweb', split=None) # test_streaming_remote_dataset(name = 'dummy_table', split=None) #test_streaming_remote_dataset(name = 'random_cpt_table_dbsql', split=None) @@ -196,5 +198,5 @@ def test_streaming_remote_dataloader(name: str, split: str) -> None: #test_streaming_remote_dataloader(name = 'refinedweb', split=None) # test_streaming_remote_dataloader(name = 'random_cpt_table_dbsql', split=None) - + # test_streaming_remote_dataloader(name = 'reddit_table_dbsql', split=None) From c6795b339359fdbd1fda309fda6f2817ee36d1c4 Mon Sep 17 00:00:00 2001 From: xiaohanzhangcmu Date: Sat, 3 Aug 2024 23:00:10 -0700 Subject: [PATCH 086/145] update --- tests/test_streaming_remote.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/tests/test_streaming_remote.py b/tests/test_streaming_remote.py index 5eed5e4f5..5511a2423 100644 --- a/tests/test_streaming_remote.py +++ b/tests/test_streaming_remote.py @@ -84,6 +84,18 @@ def get_dataset(name: str, 'schema': 'reddit', }, }, + 'reddit_table_dbsql_cachelimit': { + 'local': f'/tmp/test_random_reddit_table_05May1029', + 'remote': 'SELECT text, added FROM main.reddit.data', + 'num_samples': 378156152, + 'class': StreamingDataset, + 'kwargs': { + 'warehouse_id': "89cf2c9b9f9cb3bc", + 'catalog': 'main', + 'schema': 'reddit', + 'cache_limit': '100mb', + }, + }, 'debug_local': { 'local': f'/tmp/test_random_reddit_table_05May1029', 'remote': None, @@ -193,7 +205,9 @@ def test_streaming_remote_dataloader(name: str, split: str) -> None: #test_streaming_remote_dataset(name = 'random_cpt_table_dbsql', split=None) # test_streaming_remote_dataset(name = 'random_large_table', split=None) # test_streaming_remote_dataset(name = 'reddit_table', split=None) - test_streaming_remote_dataset(name = 'reddit_table_dbsql', split=None) + # test_streaming_remote_dataset(name = 'reddit_table_dbsql', split=None) + test_streaming_remote_dataset(name = 'reddit_table_dbsql_cachelimit', split=None) + # test_streaming_remote_dataset(name = 'debug_local', split=None) #test_streaming_remote_dataloader(name = 'refinedweb', split=None) From 056027197d6bb14289e6dcd4f57cf8fc7e67295b Mon Sep 17 00:00:00 2001 From: xiaohanzhangcmu Date: Sat, 3 Aug 2024 23:12:42 -0700 Subject: [PATCH 087/145] update --- streaming/base/stream.py | 2 +- tests/test_streaming_remote.py | 24 ++++++++++++++++++++---- 2 files changed, 21 insertions(+), 5 deletions(-) diff --git a/streaming/base/stream.py b/streaming/base/stream.py index 95c494032..25bea31d9 100644 --- a/streaming/base/stream.py +++ b/streaming/base/stream.py @@ -754,7 +754,7 @@ def __init__(self, "statement": remote, "wait_timeout": "5s", # cannot be less than 5 otherwise throws bad request error "parameters": [], - "byte_limit": 10000000000000, + # "byte_limit": 10000000000000, } # From dbsql dtyps (lower case) to MDS encoded types diff --git a/tests/test_streaming_remote.py b/tests/test_streaming_remote.py index 5511a2423..27c6a5b55 100644 --- a/tests/test_streaming_remote.py +++ b/tests/test_streaming_remote.py @@ -96,6 +96,19 @@ def get_dataset(name: str, 'cache_limit': '100mb', }, }, + 'wiki_table_dbsql_cachelimit': { + 'local': f'/tmp/test_wiki_table_05May1029', + 'remote': 'SELECT id, text FROM main.streaming.wiki_table', + 'num_samples': 378156152, + 'class': StreamingDataset, + 'kwargs': { + 'warehouse_id': "89cf2c9b9f9cb3bc", + 'catalog': 'main', + 'schema': 'streaming', + 'cache_limit': '100mb', + }, + 'shuffle': True, + }, 'debug_local': { 'local': f'/tmp/test_random_reddit_table_05May1029', 'remote': None, @@ -111,6 +124,7 @@ def get_dataset(name: str, expected_samples = d['num_samples'] local = d['local'] remote = d['remote'] + shuffle = d['shuffle'] or shuffle kwargs = {**d['kwargs'], **other_kwargs} dataset = d['class'](local=local, remote=remote, @@ -178,11 +192,12 @@ def test_streaming_remote_dataloader(name: str, split: str) -> None: rcvd_samples = 0 iter_start = time.time() - for batch_idx, data_dict in enumerate(data_loader): - rcvd_samples += batch_size + for epcoh in range(3): + for batch_idx, data_dict in enumerate(data_loader): + rcvd_samples += batch_size - if (rcvd_samples % (10*batch_size) == 0): - print(f'samples read: {rcvd_samples}') + if (rcvd_samples % (10*batch_size) == 0): + print(f'samples read: {rcvd_samples}') iter_end = time.time() iter_dur = iter_end - iter_start @@ -213,4 +228,5 @@ def test_streaming_remote_dataloader(name: str, split: str) -> None: #test_streaming_remote_dataloader(name = 'refinedweb', split=None) # test_streaming_remote_dataloader(name = 'random_cpt_table_dbsql', split=None) # test_streaming_remote_dataloader(name = 'reddit_table_dbsql', split=None) + test_streaming_remote_dataloader(name = 'wiki_table_dbsql_cachelimit', split=None) From 837c7cc90532bf1bc5c5e7ff9bd89b87fb9e3aa6 Mon Sep 17 00:00:00 2001 From: xiaohanzhangcmu Date: Fri, 9 Aug 2024 16:43:59 -0700 Subject: [PATCH 088/145] update --- tests/test_streaming_remote.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/tests/test_streaming_remote.py b/tests/test_streaming_remote.py index 27c6a5b55..a9d9e7cd5 100644 --- a/tests/test_streaming_remote.py +++ b/tests/test_streaming_remote.py @@ -109,6 +109,19 @@ def get_dataset(name: str, }, 'shuffle': True, }, + 'coco_table_dbsql': { + 'local': f'/tmp/test_coco_table_05May1029', + 'remote': 'SELECT data, captions FROM main.streaming.coco_with_meta_and_captions', + 'num_samples': 26688, + 'class': StreamingDataset, + 'kwargs': { + 'warehouse_id': "89cf2c9b9f9cb3bc", + 'catalog': 'main', + 'schema': 'streaming', + 'cache_limit': '100mb', + }, + 'shuffle': True, + }, 'debug_local': { 'local': f'/tmp/test_random_reddit_table_05May1029', 'remote': None, @@ -221,12 +234,13 @@ def test_streaming_remote_dataloader(name: str, split: str) -> None: # test_streaming_remote_dataset(name = 'random_large_table', split=None) # test_streaming_remote_dataset(name = 'reddit_table', split=None) # test_streaming_remote_dataset(name = 'reddit_table_dbsql', split=None) - test_streaming_remote_dataset(name = 'reddit_table_dbsql_cachelimit', split=None) + #test_streaming_remote_dataset(name = 'reddit_table_dbsql_cachelimit', split=None) + test_streaming_remote_dataset(name = 'coco_table_dbsql', split=None) # test_streaming_remote_dataset(name = 'debug_local', split=None) #test_streaming_remote_dataloader(name = 'refinedweb', split=None) # test_streaming_remote_dataloader(name = 'random_cpt_table_dbsql', split=None) # test_streaming_remote_dataloader(name = 'reddit_table_dbsql', split=None) - test_streaming_remote_dataloader(name = 'wiki_table_dbsql_cachelimit', split=None) + # test_streaming_remote_dataloader(name = 'wiki_table_dbsql_cachelimit', split=None) From cdef3df03ab3f5c9585d4520a640636fc35febc0 Mon Sep 17 00:00:00 2001 From: xiaohanzhangcmu Date: Fri, 9 Aug 2024 16:50:46 -0700 Subject: [PATCH 089/145] update --- tests/test_streaming_remote.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_streaming_remote.py b/tests/test_streaming_remote.py index a9d9e7cd5..fb635dfef 100644 --- a/tests/test_streaming_remote.py +++ b/tests/test_streaming_remote.py @@ -137,7 +137,7 @@ def get_dataset(name: str, expected_samples = d['num_samples'] local = d['local'] remote = d['remote'] - shuffle = d['shuffle'] or shuffle + shuffle = d.get('shuffle', False) or shuffle kwargs = {**d['kwargs'], **other_kwargs} dataset = d['class'](local=local, remote=remote, From 59d19ac44987bb1d9e6007c093f40d8469b4aa6c Mon Sep 17 00:00:00 2001 From: xiaohanzhangcmu Date: Sat, 10 Aug 2024 22:31:01 -0700 Subject: [PATCH 090/145] update --- streaming/base/stream.py | 1 + tests/test_streaming_remote.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/streaming/base/stream.py b/streaming/base/stream.py index 25bea31d9..aaf328237 100644 --- a/streaming/base/stream.py +++ b/streaming/base/stream.py @@ -775,6 +775,7 @@ def __init__(self, 'struct': 'json', 'tinyint': 'int8', 'long': 'int8', + 'list': 'json', # assume items are json serializable } def refresh_statement_id(self, timeout=100): diff --git a/tests/test_streaming_remote.py b/tests/test_streaming_remote.py index fb635dfef..7a3b0f40b 100644 --- a/tests/test_streaming_remote.py +++ b/tests/test_streaming_remote.py @@ -118,7 +118,7 @@ def get_dataset(name: str, 'warehouse_id': "89cf2c9b9f9cb3bc", 'catalog': 'main', 'schema': 'streaming', - 'cache_limit': '100mb', + # 'cache_limit': '100mb', }, 'shuffle': True, }, From 4bbc8fdd2623ad814dc3867204bd06ba16a4f6db Mon Sep 17 00:00:00 2001 From: xiaohanzhangcmu Date: Sat, 10 Aug 2024 23:35:03 -0700 Subject: [PATCH 091/145] update --- streaming/base/format/mds/encodings.py | 29 ++++++++++++++++++++++++++ streaming/base/stream.py | 5 +++-- 2 files changed, 32 insertions(+), 2 deletions(-) diff --git a/streaming/base/format/mds/encodings.py b/streaming/base/format/mds/encodings.py index 0e7c7fed6..d86fcdf95 100644 --- a/streaming/base/format/mds/encodings.py +++ b/streaming/base/format/mds/encodings.py @@ -524,6 +524,34 @@ def _is_valid(self, original: Any, converted: Any) -> None: e.msg = f'Invalid JSON data: {original}' raise +class StrArray(Encoding): + """Store a list of strings.""" + + def encode(self, obj: Any) -> bytes: + encoded_parts = [] + for s in obj: + encoded_str = s.encode('utf-8') + length_prefix = len(encoded_str).to_bytes(4, byteorder='big') + encoded_parts.append(length_prefix + encoded_str) + return b''.join(encode_parts) + + def decode(self, data: bytes) -> Any: + index = 0 + decoded_strings = [] + while index < len(data): + length = int.from_bytes(encdoed_bytes[index:index+4], byteorder='big') + index += 4 + decoded_str = encoded_bytes[index:index+length].decode('utf-8') + decoded_strings.append(decoded_str) + index += length + return decoded_strings + + def _is_valid(self, original: Any, converted: Any) -> None: + try: + json.loads(converted) + except json.decoder.JSONDecodeError as e: + e.msg = f'Invalid JSON data: {original}' + raise # Encodings (name -> class). _encodings = { @@ -545,6 +573,7 @@ def _is_valid(self, original: Any, converted: Any) -> None: 'str_int': StrInt, 'str_float': StrFloat, 'str_decimal': StrDecimal, + 'str_array': StrArray, 'pil': PIL, 'jpeg': JPEG, 'png': PNG, diff --git a/streaming/base/stream.py b/streaming/base/stream.py index aaf328237..3aa076475 100644 --- a/streaming/base/stream.py +++ b/streaming/base/stream.py @@ -763,6 +763,7 @@ def __init__(self, 'string' : 'str', 'bigint' : 'int64', 'array': 'ndarray', + 'array': 'str_array', 'binary': 'bytes', 'boolean': 'uint32', 'date': 'str', @@ -775,7 +776,6 @@ def __init__(self, 'struct': 'json', 'tinyint': 'int8', 'long': 'int8', - 'list': 'json', # assume items are json serializable } def refresh_statement_id(self, timeout=100): @@ -825,7 +825,8 @@ def get_shards(self, world: World, allow_unsafe_types: bool) -> List[Reader]: self.columns = {} for c in column_meta: column_names.append(c['name']) - encoding = self.get_encode_format(c['type_name']) + encoding = self.get_encode_format(c['type_text']) + print(f'c = {c}, encoding = {encoding}') column_encodings.append(encoding) column_sizes.append(None) self.columns[c['name']] = encoding From 905065060eba3b82fa312536ea614a99a610265d Mon Sep 17 00:00:00 2001 From: xiaohanzhangcmu Date: Sat, 10 Aug 2024 23:35:59 -0700 Subject: [PATCH 092/145] update --- streaming/base/format/mds/encodings.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/streaming/base/format/mds/encodings.py b/streaming/base/format/mds/encodings.py index d86fcdf95..03dee1cc2 100644 --- a/streaming/base/format/mds/encodings.py +++ b/streaming/base/format/mds/encodings.py @@ -533,7 +533,7 @@ def encode(self, obj: Any) -> bytes: encoded_str = s.encode('utf-8') length_prefix = len(encoded_str).to_bytes(4, byteorder='big') encoded_parts.append(length_prefix + encoded_str) - return b''.join(encode_parts) + return b''.join(encoded_parts) def decode(self, data: bytes) -> Any: index = 0 From 2079d7fb7414ffd0e10fd1151e9d349a5697e618 Mon Sep 17 00:00:00 2001 From: xiaohanzhangcmu Date: Sat, 10 Aug 2024 23:36:39 -0700 Subject: [PATCH 093/145] update --- streaming/base/format/mds/encodings.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/streaming/base/format/mds/encodings.py b/streaming/base/format/mds/encodings.py index 03dee1cc2..b0de0ea8f 100644 --- a/streaming/base/format/mds/encodings.py +++ b/streaming/base/format/mds/encodings.py @@ -539,9 +539,9 @@ def decode(self, data: bytes) -> Any: index = 0 decoded_strings = [] while index < len(data): - length = int.from_bytes(encdoed_bytes[index:index+4], byteorder='big') + length = int.from_bytes(data[index:index+4], byteorder='big') index += 4 - decoded_str = encoded_bytes[index:index+length].decode('utf-8') + decoded_str = data[index:index+length].decode('utf-8') decoded_strings.append(decoded_str) index += length return decoded_strings From 6f1e84a6d39380fb3cdd7e11f5af665aef0e9fd2 Mon Sep 17 00:00:00 2001 From: xiaohanzhangcmu Date: Sat, 10 Aug 2024 23:41:53 -0700 Subject: [PATCH 094/145] update --- streaming/base/format/mds/encodings.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/streaming/base/format/mds/encodings.py b/streaming/base/format/mds/encodings.py index b0de0ea8f..ee84cbf73 100644 --- a/streaming/base/format/mds/encodings.py +++ b/streaming/base/format/mds/encodings.py @@ -539,11 +539,20 @@ def decode(self, data: bytes) -> Any: index = 0 decoded_strings = [] while index < len(data): - length = int.from_bytes(data[index:index+4], byteorder='big') - index += 4 - decoded_str = data[index:index+length].decode('utf-8') - decoded_strings.append(decoded_str) - index += length + try: + length = int.from_bytes(data[index:index+4], byteorder='big') + index += 4 + encoded_str = data[index:index+length] + decoded_str = encoded_str.decode('utf-8') + decoded_strings.append(decoded_str) + index += length + except UnicodeDecodeError as e: + print(f"UnicodeDecodeError: {e} for bytes: {encoded_str}") + decoded_strings.append(f"") + break + except Exception as e: + print(f"Unexpected error: {e}") + break return decoded_strings def _is_valid(self, original: Any, converted: Any) -> None: From 8f17ea276ec2228b81bef5429700088c48b2dabe Mon Sep 17 00:00:00 2001 From: xiaohanzhangcmu Date: Sat, 10 Aug 2024 23:45:03 -0700 Subject: [PATCH 095/145] update --- streaming/base/format/mds/encodings.py | 1 + 1 file changed, 1 insertion(+) diff --git a/streaming/base/format/mds/encodings.py b/streaming/base/format/mds/encodings.py index ee84cbf73..f00fb4680 100644 --- a/streaming/base/format/mds/encodings.py +++ b/streaming/base/format/mds/encodings.py @@ -549,6 +549,7 @@ def decode(self, data: bytes) -> Any: except UnicodeDecodeError as e: print(f"UnicodeDecodeError: {e} for bytes: {encoded_str}") decoded_strings.append(f"") + raise RuntimeError from e break except Exception as e: print(f"Unexpected error: {e}") From c7613d83bf573e7acfdf2c1a4167f89805754e01 Mon Sep 17 00:00:00 2001 From: xiaohanzhangcmu Date: Sat, 10 Aug 2024 23:55:58 -0700 Subject: [PATCH 096/145] update --- streaming/base/format/mds/encodings.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/streaming/base/format/mds/encodings.py b/streaming/base/format/mds/encodings.py index f00fb4680..3c0606be9 100644 --- a/streaming/base/format/mds/encodings.py +++ b/streaming/base/format/mds/encodings.py @@ -533,6 +533,14 @@ def encode(self, obj: Any) -> bytes: encoded_str = s.encode('utf-8') length_prefix = len(encoded_str).to_bytes(4, byteorder='big') encoded_parts.append(length_prefix + encoded_str) + data = b''.join(encoded_parts) + + try: + self.decode(data) + except: + print(f'Failed to decode an ecoded obj: {obj}') + raise RuntimeError + return b''.join(encoded_parts) def decode(self, data: bytes) -> Any: From 44d0a6eaa1909d5a1e0971c3124a0ad978008936 Mon Sep 17 00:00:00 2001 From: xiaohanzhangcmu Date: Sat, 10 Aug 2024 23:56:42 -0700 Subject: [PATCH 097/145] update --- streaming/base/format/mds/encodings.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/streaming/base/format/mds/encodings.py b/streaming/base/format/mds/encodings.py index 3c0606be9..4654b85ae 100644 --- a/streaming/base/format/mds/encodings.py +++ b/streaming/base/format/mds/encodings.py @@ -555,8 +555,6 @@ def decode(self, data: bytes) -> Any: decoded_strings.append(decoded_str) index += length except UnicodeDecodeError as e: - print(f"UnicodeDecodeError: {e} for bytes: {encoded_str}") - decoded_strings.append(f"") raise RuntimeError from e break except Exception as e: From 2de90e70650f96151f350c9cdef6565051041ac8 Mon Sep 17 00:00:00 2001 From: xiaohanzhangcmu Date: Sun, 11 Aug 2024 00:04:17 -0700 Subject: [PATCH 098/145] update --- streaming/base/format/mds/encodings.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/streaming/base/format/mds/encodings.py b/streaming/base/format/mds/encodings.py index 4654b85ae..da98fdacb 100644 --- a/streaming/base/format/mds/encodings.py +++ b/streaming/base/format/mds/encodings.py @@ -555,6 +555,8 @@ def decode(self, data: bytes) -> Any: decoded_strings.append(decoded_str) index += length except UnicodeDecodeError as e: + print('index =', index) + print('length = ', length) raise RuntimeError from e break except Exception as e: @@ -562,12 +564,6 @@ def decode(self, data: bytes) -> Any: break return decoded_strings - def _is_valid(self, original: Any, converted: Any) -> None: - try: - json.loads(converted) - except json.decoder.JSONDecodeError as e: - e.msg = f'Invalid JSON data: {original}' - raise # Encodings (name -> class). _encodings = { From 348f183879e88aa2f5df5a72090b31bab8148367 Mon Sep 17 00:00:00 2001 From: xiaohanzhangcmu Date: Sun, 11 Aug 2024 00:25:00 -0700 Subject: [PATCH 099/145] update --- streaming/base/format/mds/encodings.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/streaming/base/format/mds/encodings.py b/streaming/base/format/mds/encodings.py index da98fdacb..9ab1cbad5 100644 --- a/streaming/base/format/mds/encodings.py +++ b/streaming/base/format/mds/encodings.py @@ -529,6 +529,8 @@ class StrArray(Encoding): def encode(self, obj: Any) -> bytes: encoded_parts = [] + if len(obj) == 0: + raise ValueError(f"obj cannot be empty {obj}") for s in obj: encoded_str = s.encode('utf-8') length_prefix = len(encoded_str).to_bytes(4, byteorder='big') From d188c5a7aa1e020ce4314ad181764f9c33622f19 Mon Sep 17 00:00:00 2001 From: xiaohanzhangcmu Date: Sun, 11 Aug 2024 09:50:53 -0700 Subject: [PATCH 100/145] update --- streaming/base/format/mds/encodings.py | 60 +++++++++++++------------- 1 file changed, 30 insertions(+), 30 deletions(-) diff --git a/streaming/base/format/mds/encodings.py b/streaming/base/format/mds/encodings.py index 9ab1cbad5..677b51642 100644 --- a/streaming/base/format/mds/encodings.py +++ b/streaming/base/format/mds/encodings.py @@ -527,43 +527,43 @@ def _is_valid(self, original: Any, converted: Any) -> None: class StrArray(Encoding): """Store a list of strings.""" - def encode(self, obj: Any) -> bytes: + def encode(self, strings: Any) -> bytes: encoded_parts = [] - if len(obj) == 0: - raise ValueError(f"obj cannot be empty {obj}") - for s in obj: - encoded_str = s.encode('utf-8') - length_prefix = len(encoded_str).to_bytes(4, byteorder='big') - encoded_parts.append(length_prefix + encoded_str) - data = b''.join(encoded_parts) - try: - self.decode(data) - except: - print(f'Failed to decode an ecoded obj: {obj}') - raise RuntimeError + # Encode the length of the list of strings + list_length = len(strings) + encoded_parts.append(list_length.to_bytes(4, byteorder='big')) + for s in strings: + # Encode each string + encoded_str = s.encode('utf-8') # Encode string to UTF-8 bytes + length_prefix = len(encoded_str).to_bytes(4, byteorder='big') # Prefix with 4-byte length + encoded_parts.append(length_prefix + encoded_str) + + # Return the concatenated byte sequence return b''.join(encoded_parts) - def decode(self, data: bytes) -> Any: + + def decode(self, encoded_bytes: bytes) -> Any: index = 0 decoded_strings = [] - while index < len(data): - try: - length = int.from_bytes(data[index:index+4], byteorder='big') - index += 4 - encoded_str = data[index:index+length] - decoded_str = encoded_str.decode('utf-8') - decoded_strings.append(decoded_str) - index += length - except UnicodeDecodeError as e: - print('index =', index) - print('length = ', length) - raise RuntimeError from e - break - except Exception as e: - print(f"Unexpected error: {e}") - break + + # Decode the length of the list of strings + list_length = int.from_bytes(encoded_bytes[index:index+4], byteorder='big') + index += 4 + + for _ in range(list_length): + # Decode the length of the next string + length = int.from_bytes(encoded_bytes[index:index+4], byteorder='big') + index += 4 + + # Extract and decode the string + encoded_str = encoded_bytes[index:index+length] + decoded_str = encoded_str.decode('utf-8') + decoded_strings.append(decoded_str) + + index += length + return decoded_strings From 2b526c7bade40c41bd04320de25fb269a08db882 Mon Sep 17 00:00:00 2001 From: xiaohanzhangcmu Date: Tue, 20 Aug 2024 23:01:29 -0700 Subject: [PATCH 101/145] Fix column ordering --- streaming/base/format/mds/encodings.py | 8 ++++---- streaming/base/stream.py | 15 +++++---------- 2 files changed, 9 insertions(+), 14 deletions(-) diff --git a/streaming/base/format/mds/encodings.py b/streaming/base/format/mds/encodings.py index 677b51642..4312d1ce7 100644 --- a/streaming/base/format/mds/encodings.py +++ b/streaming/base/format/mds/encodings.py @@ -532,12 +532,12 @@ def encode(self, strings: Any) -> bytes: # Encode the length of the list of strings list_length = len(strings) - encoded_parts.append(list_length.to_bytes(4, byteorder='big')) + encoded_parts.append(list_length.to_bytes(4, byteorder='little')) for s in strings: # Encode each string encoded_str = s.encode('utf-8') # Encode string to UTF-8 bytes - length_prefix = len(encoded_str).to_bytes(4, byteorder='big') # Prefix with 4-byte length + length_prefix = len(encoded_str).to_bytes(4, byteorder='little') # Prefix with 4-byte length encoded_parts.append(length_prefix + encoded_str) # Return the concatenated byte sequence @@ -549,12 +549,12 @@ def decode(self, encoded_bytes: bytes) -> Any: decoded_strings = [] # Decode the length of the list of strings - list_length = int.from_bytes(encoded_bytes[index:index+4], byteorder='big') + list_length = int.from_bytes(encoded_bytes[index:index+4], byteorder='little') index += 4 for _ in range(list_length): # Decode the length of the next string - length = int.from_bytes(encoded_bytes[index:index+4], byteorder='big') + length = int.from_bytes(encoded_bytes[index:index+4], byteorder='little') index += 4 # Extract and decode the string diff --git a/streaming/base/stream.py b/streaming/base/stream.py index 3aa076475..be3b4a6ae 100644 --- a/streaming/base/stream.py +++ b/streaming/base/stream.py @@ -820,16 +820,11 @@ def get_shards(self, world: World, allow_unsafe_types: bool) -> List[Reader]: basename = get_index_basename() filename = os.path.join(self.local, self.split, basename) - column_meta = sql_response['manifest']['schema']['columns'] - column_names, column_encodings, column_sizes = [], [], [] - self.columns = {} - for c in column_meta: - column_names.append(c['name']) - encoding = self.get_encode_format(c['type_text']) - print(f'c = {c}, encoding = {encoding}') - column_encodings.append(encoding) - column_sizes.append(None) - self.columns[c['name']] = encoding + column_meta = sorted([(c['name'], c['type_text'], None) for c in sql_response['manifest']['schema']['columns']], key=lambda x: x[0]) + column_names = [c[0] for c in column_meta] + column_encodings = [self.get_encode_format(c[1]) for c in column_meta] + column_sizes = [c[2] for c in column_meta] + self.columns = dict(zip(column_names, column_encodings)) total_shard_count = sql_response['manifest']['total_chunk_count'] From 335a78b790623731d951de7e46f9a4afc0063a67 Mon Sep 17 00:00:00 2001 From: xiaohanzhangcmu Date: Wed, 21 Aug 2024 22:57:51 -0700 Subject: [PATCH 102/145] Fixed column size should appear in index.json --- streaming/base/stream.py | 25 +++++++++++++++++++------ tests/test_streaming_remote.py | 26 +++++++++++++++++++++++++- 2 files changed, 44 insertions(+), 7 deletions(-) diff --git a/streaming/base/stream.py b/streaming/base/stream.py index be3b4a6ae..fea229a02 100644 --- a/streaming/base/stream.py +++ b/streaming/base/stream.py @@ -778,7 +778,7 @@ def __init__(self, 'long': 'int8', } - def refresh_statement_id(self, timeout=100): + def refresh_statement_id(self, timeout=3600): total_time = 0 while total_time <= timeout: response = requests.post(self.base_url, headers=self.headers, json=self.data) @@ -814,17 +814,30 @@ def get_shards(self, world: World, allow_unsafe_types: bool) -> List[Reader]: Returns: `List[Reader]: Shard readers. """ + from streaming.base.format.mds.encodings import (get_mds_encoded_size, get_mds_encodings, + is_mds_encoding, mds_encode) sql_response = self.refresh_statement_id() # Local leader prepares the index file based on cloudfetch results basename = get_index_basename() filename = os.path.join(self.local, self.split, basename) - column_meta = sorted([(c['name'], c['type_text'], None) for c in sql_response['manifest']['schema']['columns']], key=lambda x: x[0]) - column_names = [c[0] for c in column_meta] - column_encodings = [self.get_encode_format(c[1]) for c in column_meta] - column_sizes = [c[2] for c in column_meta] - self.columns = dict(zip(column_names, column_encodings)) + self.columns = { c['name']: self.get_encode_format(c['type_text']) for c in sql_response['manifest']['schema']['columns'] } + + column_names = [] + column_encodings = [] + column_sizes = [] + for name in sorted(self.columns): + encoding = self.columns[name] + if not is_mds_encoding(encoding): + raise TypeError(f'MDSWriter passed column `{name}` with encoding `{encoding}` ' + + f'is unsupported. Supported encodings are {get_mds_encodings()}') + size = get_mds_encoded_size(encoding) + column_names.append(name) + column_encodings.append(encoding) + column_sizes.append(size) + + print(f'self.columns = {self.columns}') total_shard_count = sql_response['manifest']['total_chunk_count'] diff --git a/tests/test_streaming_remote.py b/tests/test_streaming_remote.py index 7a3b0f40b..7820d36e1 100644 --- a/tests/test_streaming_remote.py +++ b/tests/test_streaming_remote.py @@ -55,6 +55,17 @@ def get_dataset(name: str, 'schema': 'streaming', }, }, + 'prompt_response_table_dbsql': { + 'local': f'/tmp/test_prompt_response_table_05May1029', + 'remote': 'SELECT * FROM main.streaming.prompt_response_table_normal_1000000_20000', + 'num_samples': 1000000, + 'class': StreamingDataset, + 'kwargs': { + 'warehouse_id': "7e083095329f3ca5", + 'catalog': 'main', + 'schema': 'streaming', + }, + }, 'random_large_table': { 'local': f'/tmp/test_random_large_table_05May1029', 'remote': 'SELECT * FROM main.streaming.random_large_table', @@ -64,6 +75,17 @@ def get_dataset(name: str, 'cluster_id': "0201-234512-tcp9nfat" }, }, + 'large_liquid_test_table_08_07': { + 'local': f'/tmp/test_liquid_test_table_05May1029', + 'remote': 'SELECT ss_sold_date_sk, ss_sold_time_sk, ss_item_sk, ss_customer_sk, ss_cdemo_sk FROM auto_maintenance_bugbash.stella.large_liquid_test_table_08_07', + 'num_samples': 100000, + 'class': StreamingDataset, + 'kwargs': { + 'warehouse_id': "7e083095329f3ca5", + 'catalog': 'auto_maintenance_bugbash', + 'schema': 'stella', + }, + }, 'reddit_table_sparkconnect': { 'local': f'/tmp/test_random_reddit_table_05May1029', 'remote': 'SELECT text, added FROM main.reddit.data', @@ -235,7 +257,9 @@ def test_streaming_remote_dataloader(name: str, split: str) -> None: # test_streaming_remote_dataset(name = 'reddit_table', split=None) # test_streaming_remote_dataset(name = 'reddit_table_dbsql', split=None) #test_streaming_remote_dataset(name = 'reddit_table_dbsql_cachelimit', split=None) - test_streaming_remote_dataset(name = 'coco_table_dbsql', split=None) + #test_streaming_remote_dataset(name = 'coco_table_dbsql', split=None) + #test_streaming_remote_dataset(name = 'large_liquid_test_table_08_07', split=None) + test_streaming_remote_dataset(name = 'prompt_response_table_dbsql', split=None) # test_streaming_remote_dataset(name = 'debug_local', split=None) From caf9ce6472f9db04666d726f168280a3609b37ad Mon Sep 17 00:00:00 2001 From: xiaohanzhangcmu Date: Wed, 28 Aug 2024 12:06:50 -0700 Subject: [PATCH 103/145] update --- streaming/base/stream.py | 271 ++++++++++++++++++++++++++++++++- tests/test_streaming_remote.py | 51 +++---- 2 files changed, 292 insertions(+), 30 deletions(-) diff --git a/streaming/base/stream.py b/streaming/base/stream.py index fea229a02..d7e7cc7a3 100644 --- a/streaming/base/stream.py +++ b/streaming/base/stream.py @@ -708,7 +708,7 @@ def fetch_and_convert(cloud_fetch_url: str, local_shard_path: str): return local -class DeltaDBSQLStream(Stream): +class DeltaDBSQLStreamOriginal(Stream): def __init__(self, remote: Optional[str] = None, @@ -951,3 +951,272 @@ def _fetch_and_convert(cloud_fetch_url: str, local_shard_path: str): print('Download to local is done = ', local) return local + + +class DeltaDBSQLStream(Stream): + + def __init__(self, + remote: Optional[str] = None, + local: Optional[str] = None, + split: Optional[str] = None, + proportion: Optional[float] = None, + repeat: Optional[float] = None, + choose: Optional[int] = None, + download_retry: Optional[int] = None, + download_timeout: Optional[float] = None, + validate_hash: Optional[str] = None, + keep_zip: Optional[bool] = None, + **kwargs: Any) -> None: + super().__init__(remote=remote, + local=local, + split=split, + proportion=proportion, + repeat=repeat, + choose=choose, + download_retry=download_retry, + download_timeout=download_timeout, + validate_hash=validate_hash, + keep_zip=keep_zip) + + warehouse_id = kwargs.get('warehouse_id', None) + host = kwargs.get('host', os.environ['DATABRICKS_HOST']) + token = kwargs.get('token', os.environ['DATABRICKS_TOKEN']) + catalog = kwargs.get('catalog', None) + schema = kwargs.get('schema', None) + use_cached_result = kwargs.get('use_cached_result', False) + + if any([not warehouse_id, not host, not token, not catalog, not schema]): + raise TypeError(f"Need to specify warehouse_id, host, token catalog, schema, during initialization") + + self.base_url = f"https://{host}/api/2.0/sql/statements/" + self.session_url = f"https://{host}/api/2.0/sql/sessions/" + + self.headers = { + "Authorization": f"Bearer {token}", + "Content-Type": "application/json" + } + + self.data = { + "warehouse_id": warehouse_id, + "format": "ARROW_STREAM", + "disposition": "EXTERNAL_LINKS", + "statement": remote, + "wait_timeout": "5s", # cannot be less than 5 otherwise throws bad request error + "parameters": [], + } + + if not use_cached_result: + # Create a session id + # Use session id in payload + # Fetch result via get status api + self.session_data = { + "warehouse_id": warehouse_id, + "catalog": catalog, + "schema": schema, + "session_confs": {"use_cached_result": "false"} + } + response = requests.post(self.session_url, headers=self.headers, json=self.session_data) + self.data['session_id'] = response.json()['session_id'] + + + # From dbsql dtyps (lower case) to MDS encoded types + # https://docs.databricks.com/en/dev-tools/python-sql-connector.html + self.dtypes_mapping = { + 'string' : 'str', + 'bigint' : 'int64', + 'array': 'ndarray', + 'array': 'str_array', + 'binary': 'bytes', + 'boolean': 'uint32', + 'date': 'str', + 'datetime.date': 'str', + 'decimal': 'str_decimal', + 'double' : 'float64', + 'int': 'int', + 'map': 'json', + 'smallint': 'int16', + 'struct': 'json', + 'tinyint': 'int8', + 'long': 'int8', + } + + def polling(self, timeout: int = 3600): + total_time = 0 + while total_time <= timeout: + response = requests.get(f"{self.base_url}/{self.statement_id}", headers=self.headers) + response.raise_for_status() + response_data = response.json() + query_status = response_data['status']['state'] + + if query_status == "SUCCEEDED": + save_dict_to_file(self.local, f'response_{int(time.time())}', response_data) + return response_data + + print(f"Query status: {query_status}") + time.sleep(3) + total_time += 3 + raise TimeoutError(f"Query execution failed with status: {query_status}") + + + def refresh_statement_id(self): + response = requests.post(self.base_url, headers=self.headers, json=self.data) + response.raise_for_status() + response_data = response.json() + self.statement_id = response_data['statement_id'] + + def get_encode_format(self, sql_fmt: str): + mds_fmt = self.dtypes_mapping.get(sql_fmt.lower(), None) + if not mds_fmt: + raise TypeError(f"{sql_fmt} is not supported by MDSWrite.") + return mds_fmt + + def get_shards(self, world: World, allow_unsafe_types: bool) -> List[Reader]: + """Load this Stream's index, retrieving its shard readers. + + Args: + world (World): Distributed context. + allow_unsafe_types (bool): If a shard contains Pickle, which allows arbitrary code + execution during deserialization, whether to keep going if ``True`` or raise an + error. + + Returns: + `List[Reader]: Shard readers. + """ + from streaming.base.format.mds.encodings import (get_mds_encoded_size, get_mds_encodings, + is_mds_encoding, mds_encode) + self.refresh_statement_id() + sql_response = self.polling() + + # Local leader prepares the index file based on cloudfetch results + basename = get_index_basename() + filename = os.path.join(self.local, self.split, basename) + + self.columns = { c['name']: self.get_encode_format(c['type_text']) for c in sql_response['manifest']['schema']['columns'] } + + column_names = [] + column_encodings = [] + column_sizes = [] + for name in sorted(self.columns): + encoding = self.columns[name] + if not is_mds_encoding(encoding): + raise TypeError(f'MDSWriter passed column `{name}` with encoding `{encoding}` ' + + f'is unsupported. Supported encodings are {get_mds_encodings()}') + size = get_mds_encoded_size(encoding) + column_names.append(name) + column_encodings.append(encoding) + column_sizes.append(size) + + print(f'self.columns = {self.columns}') + + total_shard_count = sql_response['manifest']['total_chunk_count'] + + if world.is_local_leader: + + metadata = { + "version": 2, + "shards": [] + } + + for shard_id, shard_meta in enumerate(sql_response['manifest']['chunks']): + shard = { + "column_encodings": column_encodings, + "column_names": column_names, + "column_sizes": column_sizes, + "compression": None, + "format": "mds", + "hashes": ["sha1"], + "raw_data": { + "basename": f'shard.{shard_id:05}.mds', + "bytes": shard_meta['byte_count'], + "hashes": {} + }, + "samples": shard_meta['row_count'], + "size_limit": 67108864, + "version": 2, + "zip_data": None + } + metadata["shards"].append(shard) + + with open(filename, 'w') as f: + json.dump(metadata, f, indent=4) + else: + wait_for_json_to_exist( + filename, TICK, self.download_timeout, + f'Index file {os.path.join(self.remote or "", self.split or "", basename)} ' + + f'-> {filename} took too long to download. Either increase the ' + + f'`download_timeout` value or check the other traceback.') + + # Load the index. + try: + obj = json.load(open(filename)) + except json.decoder.JSONDecodeError as error: + error.args = (f'Index file at {filename} is empty or corrupted. ' + error.args[0],) + raise error + + # Version check. + if obj['version'] != 2: + raise ValueError(f'Unsupported streaming data version: {obj["version"]}. ' + + f'Expected version 2.') + + # Initialize shard readers according to the loaded info. + shards = [] + for info in obj['shards']: + shard = reader_from_json(self.local, self.split, info) + shard.validate(allow_unsafe_types) + shards.append(shard) + + return shards + + def _make_request(self, url: str) -> requests.Response: + if random.random() < 0.0: # make rhs > 0.0 for testing, so x% of the time return HTTPError + response = requests.Response() + response.status_code = 404 + response.url = url + raise requests.exceptions.HTTPError(f"Manually raised HTTPError for testing purposes: {int(time.time())}", response=response) + else: + response = requests.get(url, headers=self.headers) + response.raise_for_status() + return response + + def _download_file(self, from_basename: str, to_basename: Optional[str] = None) -> str: + """Safely download a file from remote to local cache. + + Args: + from_basename (str): Source basename. + to_basename (str, optional): Destination basename, if different. + + Returns: + str: Local cache filename. + """ + from streaming import MDSWriter + def _fetch_and_convert(cloud_fetch_url: str, local_shard_path: str): + samples = pa.ipc.open_stream(requests.get(cloud_fetch_url).content).read_all().to_pylist() + with TemporaryDirectory() as temp_dir: + with MDSWriter(columns=self.columns, out=temp_dir, size_limit=None) as out: + for sample in samples: + out.write(sample) + temp_mds_filename = os.path.join(temp_dir, 'shard.00000.mds') + os.rename(temp_mds_filename, local_shard_path) + + chunk_index = int(re.search(r'\d+', from_basename).group()) + print('from_basename = ', from_basename) + print('chunk_index = ', chunk_index) + + try: + url = f"{self.base_url}/{self.statement_id}/result/chunks/{chunk_index}" + response = self._make_request(url) + except Exception as e: # requests.exceptions.HTTPError as e: + print('Failed to download, refresh statement id and try again') + print('url = ', url) + print(e) + self.refresh_statement_id() + url = f"{self.base_url}/{self.statement_id}/result/chunks/{chunk_index}" + response = self._make_request(url) + + cloud_fetch_url = response.json()['external_links'][0]['external_link'] + local = os.path.join(self.local, self.split, from_basename) + retry(num_attempts=self.download_retry)(lambda: _fetch_and_convert(cloud_fetch_url, local))() + + print('Download to local is done = ', local) + return local + diff --git a/tests/test_streaming_remote.py b/tests/test_streaming_remote.py index 7820d36e1..d2ef95827 100644 --- a/tests/test_streaming_remote.py +++ b/tests/test_streaming_remote.py @@ -26,22 +26,15 @@ def get_dataset(name: str, 'class': StreamingDataset, 'kwargs': {}, }, - 'dummy_table': { + 'dummy_table_dbsql': { 'local': f'/tmp/test_dummy_table_05May1029', 'remote': 'SELECT * FROM main.streaming.dummy_cpt_table', - 'num_samples': 20206, - 'class': StreamingDataset, - 'kwargs': { - 'cluster_id': "0201-234512-tcp9nfat" - }, - }, - 'random_cpt_table_sparkconnect': { - 'local': f'/tmp/test_random_cpt_table_05May1029', - 'remote': 'SELECT text FROM main.streaming.random_cpt_table', - 'num_samples': 100000, + 'num_samples': 5, 'class': StreamingDataset, 'kwargs': { - 'cluster_id': "0201-234512-tcp9nfat" + 'warehouse_id': "7e083095329f3ca5", + 'catalog': 'main', + 'schema': 'streaming', }, }, 'random_cpt_table_dbsql': { @@ -115,7 +108,7 @@ def get_dataset(name: str, 'warehouse_id': "89cf2c9b9f9cb3bc", 'catalog': 'main', 'schema': 'reddit', - 'cache_limit': '100mb', + 'cache_limit': '10gb', }, }, 'wiki_table_dbsql_cachelimit': { @@ -142,7 +135,7 @@ def get_dataset(name: str, 'schema': 'streaming', # 'cache_limit': '100mb', }, - 'shuffle': True, + 'shuffle': False, }, 'debug_local': { 'local': f'/tmp/test_random_reddit_table_05May1029', @@ -205,7 +198,7 @@ def test_streaming_remote_dataset(name: str, split: str) -> None: def test_streaming_remote_dataloader(name: str, split: str) -> None: # Build StreamingDataset build_start = time.time() - batch_size = 16 + batch_size = 1 expected_samples, dataset = get_dataset(name=name, split=split, shuffle=False, @@ -213,7 +206,7 @@ def test_streaming_remote_dataloader(name: str, split: str) -> None: data_loader = StreamingDataLoader(dataset, - batch_size=16, + batch_size=batch_size, num_workers=4, prefetch_factor=None, #persistent_workers=True, @@ -227,7 +220,7 @@ def test_streaming_remote_dataloader(name: str, split: str) -> None: rcvd_samples = 0 iter_start = time.time() - for epcoh in range(3): + for epcoh in range(1): for batch_idx, data_dict in enumerate(data_loader): rcvd_samples += batch_size @@ -250,21 +243,21 @@ def test_streaming_remote_dataloader(name: str, split: str) -> None: if __name__ == "__main__": from streaming.base.util import clean_stale_shared_memory clean_stale_shared_memory() -# test_streaming_remote_dataset(name = 'refinedweb', split=None) - # test_streaming_remote_dataset(name = 'dummy_table', split=None) -#test_streaming_remote_dataset(name = 'random_cpt_table_dbsql', split=None) -# test_streaming_remote_dataset(name = 'random_large_table', split=None) -# test_streaming_remote_dataset(name = 'reddit_table', split=None) - # test_streaming_remote_dataset(name = 'reddit_table_dbsql', split=None) + #test_streaming_remote_dataset(name = 'refinedweb', split=None) + #test_streaming_remote_dataset(name = 'dummy_table_dbsql', split=None) + test_streaming_remote_dataset(name = 'random_cpt_table_dbsql', split=None) + #test_streaming_remote_dataset(name = 'random_large_table', split=None) + #test_streaming_remote_dataset(name = 'reddit_table', split=None) + #test_streaming_remote_dataset(name = 'reddit_table_dbsql', split=None) #test_streaming_remote_dataset(name = 'reddit_table_dbsql_cachelimit', split=None) #test_streaming_remote_dataset(name = 'coco_table_dbsql', split=None) #test_streaming_remote_dataset(name = 'large_liquid_test_table_08_07', split=None) - test_streaming_remote_dataset(name = 'prompt_response_table_dbsql', split=None) - -# test_streaming_remote_dataset(name = 'debug_local', split=None) + #test_streaming_remote_dataset(name = 'prompt_response_table_dbsql', split=None) + #test_streaming_remote_dataset(name = 'debug_local', split=None) #test_streaming_remote_dataloader(name = 'refinedweb', split=None) - # test_streaming_remote_dataloader(name = 'random_cpt_table_dbsql', split=None) - # test_streaming_remote_dataloader(name = 'reddit_table_dbsql', split=None) - # test_streaming_remote_dataloader(name = 'wiki_table_dbsql_cachelimit', split=None) + #test_streaming_remote_dataloader(name = 'random_cpt_table_dbsql', split=None) + #test_streaming_remote_dataloader(name = 'reddit_table_dbsql', split=None) + #test_streaming_remote_dataloader(name = 'wiki_table_dbsql_cachelimit', split=None) + #test_streaming_remote_dataloader(name = 'coco_table_dbsql', split=None) From 34ab263eb1393fb8dc2cf7aaa25d04feee41554c Mon Sep 17 00:00:00 2001 From: xiaohanzhangcmu Date: Thu, 29 Aug 2024 12:04:46 -0700 Subject: [PATCH 104/145] update --- streaming/base/stream.py | 51 +++++++++++++++++++++++----------------- 1 file changed, 30 insertions(+), 21 deletions(-) diff --git a/streaming/base/stream.py b/streaming/base/stream.py index d7e7cc7a3..80648f1bc 100644 --- a/streaming/base/stream.py +++ b/streaming/base/stream.py @@ -983,7 +983,7 @@ def __init__(self, token = kwargs.get('token', os.environ['DATABRICKS_TOKEN']) catalog = kwargs.get('catalog', None) schema = kwargs.get('schema', None) - use_cached_result = kwargs.get('use_cached_result', False) + self.use_cached_result = kwargs.get('use_cached_result', False) if any([not warehouse_id, not host, not token, not catalog, not schema]): raise TypeError(f"Need to specify warehouse_id, host, token catalog, schema, during initialization") @@ -996,7 +996,14 @@ def __init__(self, "Content-Type": "application/json" } - self.data = { + self.session_payload = { + "warehouse_id": warehouse_id, + "catalog": catalog, + "schema": schema, + "session_confs": {"use_cached_result": "false"} + } + + self.payload = { "warehouse_id": warehouse_id, "format": "ARROW_STREAM", "disposition": "EXTERNAL_LINKS", @@ -1005,20 +1012,6 @@ def __init__(self, "parameters": [], } - if not use_cached_result: - # Create a session id - # Use session id in payload - # Fetch result via get status api - self.session_data = { - "warehouse_id": warehouse_id, - "catalog": catalog, - "schema": schema, - "session_confs": {"use_cached_result": "false"} - } - response = requests.post(self.session_url, headers=self.headers, json=self.session_data) - self.data['session_id'] = response.json()['session_id'] - - # From dbsql dtyps (lower case) to MDS encoded types # https://docs.databricks.com/en/dev-tools/python-sql-connector.html self.dtypes_mapping = { @@ -1058,12 +1051,28 @@ def polling(self, timeout: int = 3600): raise TimeoutError(f"Query execution failed with status: {query_status}") - def refresh_statement_id(self): - response = requests.post(self.base_url, headers=self.headers, json=self.data) + def refresh_statement_id(self, use_cached_result:bool=False): + + boolean_string = "true" if use_cached_result else "false" + self.session_payload['session_confs']['use_cached_result'] = boolean_string + + print(f"Set the session data to be {self.session_payload}") + + # Create a session id + # Use session id in payload + # Fetch result via get status api + response = requests.post(self.session_url, headers=self.headers, json=self.session_payload) + self.payload['session_id'] = response.json()['session_id'] + + print(f"Set the payload to be {self.payload}") + + response = requests.post(self.base_url, headers=self.headers, json=self.payload) response.raise_for_status() response_data = response.json() self.statement_id = response_data['statement_id'] + return self.polling() + def get_encode_format(self, sql_fmt: str): mds_fmt = self.dtypes_mapping.get(sql_fmt.lower(), None) if not mds_fmt: @@ -1084,8 +1093,8 @@ def get_shards(self, world: World, allow_unsafe_types: bool) -> List[Reader]: """ from streaming.base.format.mds.encodings import (get_mds_encoded_size, get_mds_encodings, is_mds_encoding, mds_encode) - self.refresh_statement_id() - sql_response = self.polling() + + sql_response = self.refresh_statement_id(self.use_cached_result) # Local leader prepares the index file based on cloudfetch results basename = get_index_basename() @@ -1209,7 +1218,7 @@ def _fetch_and_convert(cloud_fetch_url: str, local_shard_path: str): print('Failed to download, refresh statement id and try again') print('url = ', url) print(e) - self.refresh_statement_id() + self.refresh_statement_id(use_cached_result=True) url = f"{self.base_url}/{self.statement_id}/result/chunks/{chunk_index}" response = self._make_request(url) From 0742df01e6b2ed33ab025b2e002a7167a7f3d006 Mon Sep 17 00:00:00 2001 From: xiaohanzhangcmu Date: Fri, 30 Aug 2024 08:59:48 -0700 Subject: [PATCH 105/145] update --- streaming/base/stream.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/streaming/base/stream.py b/streaming/base/stream.py index 80648f1bc..5ca2fa20c 100644 --- a/streaming/base/stream.py +++ b/streaming/base/stream.py @@ -1033,6 +1033,8 @@ def __init__(self, 'long': 'int8', } + self.refresh_statement_id(self.use_cached_result) + def polling(self, timeout: int = 3600): total_time = 0 while total_time <= timeout: @@ -1094,7 +1096,7 @@ def get_shards(self, world: World, allow_unsafe_types: bool) -> List[Reader]: from streaming.base.format.mds.encodings import (get_mds_encoded_size, get_mds_encodings, is_mds_encoding, mds_encode) - sql_response = self.refresh_statement_id(self.use_cached_result) + sql_response = self.refresh_statement_id(True) # Local leader prepares the index file based on cloudfetch results basename = get_index_basename() @@ -1218,7 +1220,7 @@ def _fetch_and_convert(cloud_fetch_url: str, local_shard_path: str): print('Failed to download, refresh statement id and try again') print('url = ', url) print(e) - self.refresh_statement_id(use_cached_result=True) + self.refresh_statement_id(True) url = f"{self.base_url}/{self.statement_id}/result/chunks/{chunk_index}" response = self._make_request(url) From 45bb3575c4f1fbd27ebcf8f72b010bef6bbfa40e Mon Sep 17 00:00:00 2001 From: xiaohanzhangcmu Date: Fri, 6 Sep 2024 09:20:25 -0700 Subject: [PATCH 106/145] update --- streaming/base/stream.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/streaming/base/stream.py b/streaming/base/stream.py index 5ca2fa20c..885ba70aa 100644 --- a/streaming/base/stream.py +++ b/streaming/base/stream.py @@ -708,7 +708,7 @@ def fetch_and_convert(cloud_fetch_url: str, local_shard_path: str): return local -class DeltaDBSQLStreamOriginal(Stream): +class DeltaDBSQLStream(Stream): def __init__(self, remote: Optional[str] = None, @@ -953,7 +953,7 @@ def _fetch_and_convert(cloud_fetch_url: str, local_shard_path: str): return local -class DeltaDBSQLStream(Stream): +class DeltaDBSQLStreamSession(Stream): def __init__(self, remote: Optional[str] = None, From 9b719aeaee7dfa21e65baf8d849b9a5737a12598 Mon Sep 17 00:00:00 2001 From: xiaohanzhangcmu Date: Fri, 6 Sep 2024 14:27:40 -0700 Subject: [PATCH 107/145] Add print --- streaming/base/dataset.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/streaming/base/dataset.py b/streaming/base/dataset.py index 6ae8fe20c..cef35ae6f 100644 --- a/streaming/base/dataset.py +++ b/streaming/base/dataset.py @@ -349,6 +349,8 @@ def __init__(self, self.allow_unsafe_types = allow_unsafe_types self.replication = replication + logger.warning('Using StreamingX:heterogeneous branch') + # Initialize the World context. # * This information is for the per-rank or per-worker process. # * DataLoader worker processes may get a different worker ID and worker count than rank. From b5fe6c56658fe0177bc2627d58ea9ff62d2f0c0e Mon Sep 17 00:00:00 2001 From: xiaohanzhangcmu Date: Tue, 10 Sep 2024 22:42:24 -0700 Subject: [PATCH 108/145] Add broadcast --- streaming/base/stream.py | 45 +++++++++++++++++++++++++++++++++++----- 1 file changed, 40 insertions(+), 5 deletions(-) diff --git a/streaming/base/stream.py b/streaming/base/stream.py index 885ba70aa..c39ef0fcc 100644 --- a/streaming/base/stream.py +++ b/streaming/base/stream.py @@ -16,6 +16,7 @@ from streaming.base.compression import decompress from streaming.base.constant import TICK +import torch.distributed as dist from streaming.base.distributed import barrier, get_local_rank from streaming.base.format import FileInfo, Reader, get_index_basename, reader_from_json from streaming.base.hashing import get_hash @@ -778,16 +779,48 @@ def __init__(self, 'long': 'int8', } - def refresh_statement_id(self, timeout=3600): + def generate_statement_id_and_sync(self, world: World): + if dist.is_available() and dist.is_initialized(): + if world.is_leader: # is_local_leader: + response = requests.post(self.base_url, headers=self.headers, json=self.data) + response.raise_for_status() + response_data = response.json() + self.statement_id = response_data['statement_id'] + data = self.statement_id + else: + data = None + + obj_list = [data] + dist.broadcast_object_list(obj_list, src=0) + self.statement_id = obj_list[0] + return + + world_size = world.num_ranks + if world_size > 1: + raise RuntimeError(''.join([ + f'The world_size({world_size}) > 1, but the distributed package is not available ', + 'or has not been initialized. Please check you have initialized the distributed ', + 'runtime and that PyTorch has been built with distributed support.' + ])) + + response = requests.post(self.base_url, headers=self.headers, json=self.data) + response.raise_for_status() + response_data = response.json() + self.statement_id = response_data['statement_id'] + + def wait_for_query_result(self, timeout=3600): + if not self.statement_id: + raise ValueError(f"statement id is not set yet") + total_time = 0 while total_time <= timeout: - response = requests.post(self.base_url, headers=self.headers, json=self.data) + response = requests.get(f"{self.base_url}/{self.statement_id}", headers=self.headers) response.raise_for_status() response_data = response.json() query_status = response_data['status']['state'] if query_status == "SUCCEEDED": - self.statement_id = response_data['statement_id'] + #self.statement_id = response_data['statement_id'] save_dict_to_file(self.local, f'response_{int(time.time())}', response_data) return response_data @@ -816,7 +849,9 @@ def get_shards(self, world: World, allow_unsafe_types: bool) -> List[Reader]: """ from streaming.base.format.mds.encodings import (get_mds_encoded_size, get_mds_encodings, is_mds_encoding, mds_encode) - sql_response = self.refresh_statement_id() + self.generate_statement_id_and_sync(world) + + sql_response = self.wait_for_query_result() # Local leader prepares the index file based on cloudfetch results basename = get_index_basename() @@ -841,7 +876,7 @@ def get_shards(self, world: World, allow_unsafe_types: bool) -> List[Reader]: total_shard_count = sql_response['manifest']['total_chunk_count'] - if world.is_local_leader: + if world.is_leader: # is_local_leader: metadata = { "version": 2, From aa4ede7df5cc05b3a684d9cac3539e09c4fe689a Mon Sep 17 00:00:00 2001 From: xiaohanzhangcmu Date: Tue, 10 Sep 2024 23:02:41 -0700 Subject: [PATCH 109/145] update tests --- tests/test_streaming_remote.py | 38 +++++++++++++++++++++++++++++----- 1 file changed, 33 insertions(+), 5 deletions(-) diff --git a/tests/test_streaming_remote.py b/tests/test_streaming_remote.py index d2ef95827..54e735ec3 100644 --- a/tests/test_streaming_remote.py +++ b/tests/test_streaming_remote.py @@ -46,6 +46,7 @@ def get_dataset(name: str, 'warehouse_id': "7e083095329f3ca5", 'catalog': 'main', 'schema': 'streaming', + 'use_cached_result': False, }, }, 'prompt_response_table_dbsql': { @@ -68,10 +69,10 @@ def get_dataset(name: str, 'cluster_id': "0201-234512-tcp9nfat" }, }, - 'large_liquid_test_table_08_07': { + 'large_liquid_test_table_08_07_dbsql': { 'local': f'/tmp/test_liquid_test_table_05May1029', - 'remote': 'SELECT ss_sold_date_sk, ss_sold_time_sk, ss_item_sk, ss_customer_sk, ss_cdemo_sk FROM auto_maintenance_bugbash.stella.large_liquid_test_table_08_07', - 'num_samples': 100000, + 'remote': 'SELECT * FROM auto_maintenance_bugbash.stella.large_liquid_test_table_08_07', + 'num_samples': 89279077339, 'class': StreamingDataset, 'kwargs': { 'warehouse_id': "7e083095329f3ca5", @@ -137,6 +138,32 @@ def get_dataset(name: str, }, 'shuffle': False, }, + 'evesize_level1_dbsql': { + 'local': f'/tmp/test_evesize_05May1029', + 'remote': "SELECT prompt, response, class FROM datasets.cody.evesize_level1_evolve_respond WHERE class = \'CODE\'", + 'num_samples': 68784, + 'class': StreamingDataset, + 'kwargs': { + 'warehouse_id': "89cf2c9b9f9cb3bc", + 'catalog': 'datasets', + 'schema': 'cody', + # 'cache_limit': '100mb', + }, + 'shuffle': False, + }, + 'evesize_level1_dbsql': { + 'local': f'/tmp/test_evesize_05May1029', + 'remote': "SELECT * FROM main.streaming.evesize_level1_evolve_response_sub VERSION AS OF 0", + 'num_samples': 273044, + 'class': StreamingDataset, + 'kwargs': { + 'warehouse_id': "89cf2c9b9f9cb3bc", + 'catalog': 'main', + 'schema': 'streaming', + # 'cache_limit': '100mb', + }, + 'shuffle': False, + }, 'debug_local': { 'local': f'/tmp/test_random_reddit_table_05May1029', 'remote': None, @@ -245,15 +272,16 @@ def test_streaming_remote_dataloader(name: str, split: str) -> None: clean_stale_shared_memory() #test_streaming_remote_dataset(name = 'refinedweb', split=None) #test_streaming_remote_dataset(name = 'dummy_table_dbsql', split=None) - test_streaming_remote_dataset(name = 'random_cpt_table_dbsql', split=None) + #test_streaming_remote_dataset(name = 'random_cpt_table_dbsql', split=None) #test_streaming_remote_dataset(name = 'random_large_table', split=None) #test_streaming_remote_dataset(name = 'reddit_table', split=None) #test_streaming_remote_dataset(name = 'reddit_table_dbsql', split=None) #test_streaming_remote_dataset(name = 'reddit_table_dbsql_cachelimit', split=None) #test_streaming_remote_dataset(name = 'coco_table_dbsql', split=None) - #test_streaming_remote_dataset(name = 'large_liquid_test_table_08_07', split=None) + #test_streaming_remote_dataset(name = 'large_liquid_test_table_08_07_dbsql', split=None) #test_streaming_remote_dataset(name = 'prompt_response_table_dbsql', split=None) #test_streaming_remote_dataset(name = 'debug_local', split=None) + test_streaming_remote_dataset(name = 'evesize_level1_dbsql', split=None) #test_streaming_remote_dataloader(name = 'refinedweb', split=None) #test_streaming_remote_dataloader(name = 'random_cpt_table_dbsql', split=None) From 29ec01763c55c9f9e7e95e160cb8a242df45dc6d Mon Sep 17 00:00:00 2001 From: xiaohanzhangcmu Date: Tue, 10 Sep 2024 23:10:09 -0700 Subject: [PATCH 110/145] update --- streaming/base/stream.py | 1 + 1 file changed, 1 insertion(+) diff --git a/streaming/base/stream.py b/streaming/base/stream.py index c39ef0fcc..84ecbcd88 100644 --- a/streaming/base/stream.py +++ b/streaming/base/stream.py @@ -787,6 +787,7 @@ def generate_statement_id_and_sync(self, world: World): response_data = response.json() self.statement_id = response_data['statement_id'] data = self.statement_id + print(f'I am here 1: {data}') else: data = None From b4d369844b9b7c94653439374bf8d934984d9d32 Mon Sep 17 00:00:00 2001 From: Xiaohan Zhang Date: Wed, 11 Sep 2024 21:51:10 +0000 Subject: [PATCH 111/145] update --- streaming/base/stream.py | 8 +++++--- tests/test_streaming_remote.py | 5 +++++ 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/streaming/base/stream.py b/streaming/base/stream.py index 84ecbcd88..e796fa8cb 100644 --- a/streaming/base/stream.py +++ b/streaming/base/stream.py @@ -735,7 +735,7 @@ def __init__(self, keep_zip=keep_zip) warehouse_id = kwargs.get('warehouse_id', None) - host = kwargs.get('host', os.environ['DATABRICKS_HOST']) + host = kwargs.get('host', os.environ['DATABRICKS_HOST']).lstrip('https://') token = kwargs.get('token', os.environ['DATABRICKS_TOKEN']) catalog = kwargs.get('catalog', None) schema = kwargs.get('schema', None) @@ -781,16 +781,18 @@ def __init__(self, def generate_statement_id_and_sync(self, world: World): if dist.is_available() and dist.is_initialized(): + barrier() + if world.is_leader: # is_local_leader: response = requests.post(self.base_url, headers=self.headers, json=self.data) response.raise_for_status() response_data = response.json() self.statement_id = response_data['statement_id'] data = self.statement_id - print(f'I am here 1: {data}') else: data = None + obj_list = [data] dist.broadcast_object_list(obj_list, src=0) self.statement_id = obj_list[0] @@ -877,7 +879,7 @@ def get_shards(self, world: World, allow_unsafe_types: bool) -> List[Reader]: total_shard_count = sql_response['manifest']['total_chunk_count'] - if world.is_leader: # is_local_leader: + if world.is_local_leader: metadata = { "version": 2, diff --git a/tests/test_streaming_remote.py b/tests/test_streaming_remote.py index 54e735ec3..c31dfb7c8 100644 --- a/tests/test_streaming_remote.py +++ b/tests/test_streaming_remote.py @@ -268,8 +268,13 @@ def test_streaming_remote_dataloader(name: str, split: str) -> None: if __name__ == "__main__": + from composer.utils import dist as dist + from composer.utils import get_device + dist.initialize_dist(get_device(None)) + from streaming.base.util import clean_stale_shared_memory clean_stale_shared_memory() + #test_streaming_remote_dataset(name = 'refinedweb', split=None) #test_streaming_remote_dataset(name = 'dummy_table_dbsql', split=None) #test_streaming_remote_dataset(name = 'random_cpt_table_dbsql', split=None) From 9110fde7fa76aa93e9bde9cbb1ce2abd67ef02a8 Mon Sep 17 00:00:00 2001 From: xiaohanzhangcmu Date: Wed, 11 Sep 2024 14:58:55 -0700 Subject: [PATCH 112/145] update --- tests/test_streaming_remote.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/test_streaming_remote.py b/tests/test_streaming_remote.py index c31dfb7c8..1c20759e8 100644 --- a/tests/test_streaming_remote.py +++ b/tests/test_streaming_remote.py @@ -138,7 +138,7 @@ def get_dataset(name: str, }, 'shuffle': False, }, - 'evesize_level1_dbsql': { + 'evesize_level1_filter_dbsql': { 'local': f'/tmp/test_evesize_05May1029', 'remote': "SELECT prompt, response, class FROM datasets.cody.evesize_level1_evolve_respond WHERE class = \'CODE\'", 'num_samples': 68784, @@ -151,7 +151,7 @@ def get_dataset(name: str, }, 'shuffle': False, }, - 'evesize_level1_dbsql': { + 'evesize_level1_version_dbsql': { 'local': f'/tmp/test_evesize_05May1029', 'remote': "SELECT * FROM main.streaming.evesize_level1_evolve_response_sub VERSION AS OF 0", 'num_samples': 273044, @@ -286,7 +286,8 @@ def test_streaming_remote_dataloader(name: str, split: str) -> None: #test_streaming_remote_dataset(name = 'large_liquid_test_table_08_07_dbsql', split=None) #test_streaming_remote_dataset(name = 'prompt_response_table_dbsql', split=None) #test_streaming_remote_dataset(name = 'debug_local', split=None) - test_streaming_remote_dataset(name = 'evesize_level1_dbsql', split=None) + test_streaming_remote_dataset(name = 'evesize_level1_filter_dbsql', split=None) + #test_streaming_remote_dataset(name = 'evesize_level1_version_dbsql', split=None) #test_streaming_remote_dataloader(name = 'refinedweb', split=None) #test_streaming_remote_dataloader(name = 'random_cpt_table_dbsql', split=None) From 30cae093fb9a14d4ad0a6fbf6959fb099a7a737b Mon Sep 17 00:00:00 2001 From: xiaohanzhangcmu Date: Wed, 11 Sep 2024 15:02:59 -0700 Subject: [PATCH 113/145] update --- tests/test_streaming_remote.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_streaming_remote.py b/tests/test_streaming_remote.py index 1c20759e8..63decd9e1 100644 --- a/tests/test_streaming_remote.py +++ b/tests/test_streaming_remote.py @@ -196,7 +196,7 @@ def test_streaming_remote_dataset(name: str, split: str) -> None: expected_samples, dataset = get_dataset(name=name, split=split, shuffle=False, - batch_size=16) + batch_size=1) build_end = time.time() build_dur = build_end - build_start print('Built dataset') From 01972e4a5624576aa767d72ab037b6243aa2a0cd Mon Sep 17 00:00:00 2001 From: xiaohanzhangcmu Date: Wed, 11 Sep 2024 20:43:28 -0700 Subject: [PATCH 114/145] update --- tests/test_streaming_remote.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/tests/test_streaming_remote.py b/tests/test_streaming_remote.py index 63decd9e1..9274206b4 100644 --- a/tests/test_streaming_remote.py +++ b/tests/test_streaming_remote.py @@ -11,6 +11,9 @@ from streaming.text import StreamingC4 from streaming.vision import StreamingADE20K, StreamingCIFAR10, StreamingCOCO, StreamingImageNet +from composer.utils import dist as dist +from composer.utils import get_device +import torch def get_dataset(name: str, split: str, @@ -220,6 +223,8 @@ def test_streaming_remote_dataset(name: str, split: str) -> None: f'samples_per_sec={samples_per_sec:.2f}') # Test all samples arrived + rcvd_samples = torch.tensor(rcvd_samples, dtype=torch.int64) + dist.all_reduce(rcvd_samples, reduce_operation = 'SUM') assert rcvd_samples >= expected_samples def test_streaming_remote_dataloader(name: str, split: str) -> None: @@ -268,8 +273,6 @@ def test_streaming_remote_dataloader(name: str, split: str) -> None: if __name__ == "__main__": - from composer.utils import dist as dist - from composer.utils import get_device dist.initialize_dist(get_device(None)) from streaming.base.util import clean_stale_shared_memory From 6ba7e36ee1e5f5140fd86584449041c0a460c225 Mon Sep 17 00:00:00 2001 From: xiaohanzhangcmu Date: Wed, 11 Sep 2024 21:01:36 -0700 Subject: [PATCH 115/145] update --- tests/test_streaming_remote.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/tests/test_streaming_remote.py b/tests/test_streaming_remote.py index 9274206b4..f2f7b77eb 100644 --- a/tests/test_streaming_remote.py +++ b/tests/test_streaming_remote.py @@ -13,6 +13,7 @@ from composer.utils import dist as dist from composer.utils import get_device +from composer.utils.dist import get_world_size import torch def get_dataset(name: str, @@ -223,8 +224,12 @@ def test_streaming_remote_dataset(name: str, split: str) -> None: f'samples_per_sec={samples_per_sec:.2f}') # Test all samples arrived - rcvd_samples = torch.tensor(rcvd_samples, dtype=torch.int64) - dist.all_reduce(rcvd_samples, reduce_operation = 'SUM') + if dist.is_available() and dist.is_initialized() and get_world_size()>1: + rcvd_samples = torch.tensor(rcvd_samples, dtype=torch.int64).cuda() + dist.all_reduce(rcvd_samples, reduce_operation = 'SUM') + assert rcvd_samples.cpu() >= expected_samples + return + assert rcvd_samples >= expected_samples def test_streaming_remote_dataloader(name: str, split: str) -> None: From 825b58609f0c6ba750598e828e82bcde2d091f7c Mon Sep 17 00:00:00 2001 From: xiaohanzhangcmu Date: Wed, 11 Sep 2024 21:04:30 -0700 Subject: [PATCH 116/145] update --- tests/test_streaming_remote.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_streaming_remote.py b/tests/test_streaming_remote.py index f2f7b77eb..aafd720ef 100644 --- a/tests/test_streaming_remote.py +++ b/tests/test_streaming_remote.py @@ -294,8 +294,8 @@ def test_streaming_remote_dataloader(name: str, split: str) -> None: #test_streaming_remote_dataset(name = 'large_liquid_test_table_08_07_dbsql', split=None) #test_streaming_remote_dataset(name = 'prompt_response_table_dbsql', split=None) #test_streaming_remote_dataset(name = 'debug_local', split=None) - test_streaming_remote_dataset(name = 'evesize_level1_filter_dbsql', split=None) - #test_streaming_remote_dataset(name = 'evesize_level1_version_dbsql', split=None) + #test_streaming_remote_dataset(name = 'evesize_level1_filter_dbsql', split=None) + test_streaming_remote_dataset(name = 'evesize_level1_version_dbsql', split=None) #test_streaming_remote_dataloader(name = 'refinedweb', split=None) #test_streaming_remote_dataloader(name = 'random_cpt_table_dbsql', split=None) From 5abb61776c9a558851306cc471b82df87784d38e Mon Sep 17 00:00:00 2001 From: xiaohanzhangcmu Date: Wed, 11 Sep 2024 21:06:53 -0700 Subject: [PATCH 117/145] update --- tests/test_streaming_remote.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_streaming_remote.py b/tests/test_streaming_remote.py index aafd720ef..6e16964be 100644 --- a/tests/test_streaming_remote.py +++ b/tests/test_streaming_remote.py @@ -290,12 +290,12 @@ def test_streaming_remote_dataloader(name: str, split: str) -> None: #test_streaming_remote_dataset(name = 'reddit_table', split=None) #test_streaming_remote_dataset(name = 'reddit_table_dbsql', split=None) #test_streaming_remote_dataset(name = 'reddit_table_dbsql_cachelimit', split=None) - #test_streaming_remote_dataset(name = 'coco_table_dbsql', split=None) + test_streaming_remote_dataset(name = 'coco_table_dbsql', split=None) #test_streaming_remote_dataset(name = 'large_liquid_test_table_08_07_dbsql', split=None) #test_streaming_remote_dataset(name = 'prompt_response_table_dbsql', split=None) #test_streaming_remote_dataset(name = 'debug_local', split=None) #test_streaming_remote_dataset(name = 'evesize_level1_filter_dbsql', split=None) - test_streaming_remote_dataset(name = 'evesize_level1_version_dbsql', split=None) + #test_streaming_remote_dataset(name = 'evesize_level1_version_dbsql', split=None) #test_streaming_remote_dataloader(name = 'refinedweb', split=None) #test_streaming_remote_dataloader(name = 'random_cpt_table_dbsql', split=None) From f1b07e1ea9b0129827aaabac9daaf08576c98c3e Mon Sep 17 00:00:00 2001 From: xiaohanzhangcmu Date: Thu, 12 Sep 2024 12:29:53 -0700 Subject: [PATCH 118/145] update --- tests/test_streaming_remote.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_streaming_remote.py b/tests/test_streaming_remote.py index 6e16964be..e1cb21683 100644 --- a/tests/test_streaming_remote.py +++ b/tests/test_streaming_remote.py @@ -288,9 +288,9 @@ def test_streaming_remote_dataloader(name: str, split: str) -> None: #test_streaming_remote_dataset(name = 'random_cpt_table_dbsql', split=None) #test_streaming_remote_dataset(name = 'random_large_table', split=None) #test_streaming_remote_dataset(name = 'reddit_table', split=None) - #test_streaming_remote_dataset(name = 'reddit_table_dbsql', split=None) + test_streaming_remote_dataset(name = 'reddit_table_dbsql', split=None) #test_streaming_remote_dataset(name = 'reddit_table_dbsql_cachelimit', split=None) - test_streaming_remote_dataset(name = 'coco_table_dbsql', split=None) + #test_streaming_remote_dataset(name = 'coco_table_dbsql', split=None) #test_streaming_remote_dataset(name = 'large_liquid_test_table_08_07_dbsql', split=None) #test_streaming_remote_dataset(name = 'prompt_response_table_dbsql', split=None) #test_streaming_remote_dataset(name = 'debug_local', split=None) From f1d1cd7bc4c55e94bd0183c7501e1083038858b2 Mon Sep 17 00:00:00 2001 From: xiaohanzhangcmu Date: Thu, 12 Sep 2024 16:11:25 -0700 Subject: [PATCH 119/145] update --- streaming/base/stream.py | 9 +++++---- tests/test_streaming_remote.py | 15 ++++++++++++--- 2 files changed, 17 insertions(+), 7 deletions(-) diff --git a/streaming/base/stream.py b/streaming/base/stream.py index e796fa8cb..1f6fec709 100644 --- a/streaming/base/stream.py +++ b/streaming/base/stream.py @@ -976,12 +976,13 @@ def _fetch_and_convert(cloud_fetch_url: str, local_shard_path: str): url = f"{self.base_url}/{self.statement_id}/result/chunks/{chunk_index}" response = self._make_request(url) except Exception as e: # requests.exceptions.HTTPError as e: - print('Failed to download, refresh statement id and try again') + print('Failed to download, I cannot refresh statement id and try again') print('url = ', url) print(e) - self.refresh_statement_id() - url = f"{self.base_url}/{self.statement_id}/result/chunks/{chunk_index}" - response = self._make_request(url) + raise TimeoutError('Check if the query results retention period of your workspace and make sure it is longer than the expected training period. For multi-node, we do not want to refresh and communicate statement id from worker processes.') from e + # self.refresh_statement_id() + #url = f"{self.base_url}/{self.statement_id}/result/chunks/{chunk_index}" + #response = self._make_request(url) cloud_fetch_url = response.json()['external_links'][0]['external_link'] local = os.path.join(self.local, self.split, from_basename) diff --git a/tests/test_streaming_remote.py b/tests/test_streaming_remote.py index e1cb21683..6ed9b29ae 100644 --- a/tests/test_streaming_remote.py +++ b/tests/test_streaming_remote.py @@ -197,10 +197,11 @@ def get_dataset(name: str, def test_streaming_remote_dataset(name: str, split: str) -> None: # Build StreamingDataset build_start = time.time() + batch_size = 1 expected_samples, dataset = get_dataset(name=name, split=split, shuffle=False, - batch_size=1) + batch_size=batch_size) build_end = time.time() build_dur = build_end - build_start print('Built dataset') @@ -246,7 +247,7 @@ def test_streaming_remote_dataloader(name: str, split: str) -> None: batch_size=batch_size, num_workers=4, prefetch_factor=None, - #persistent_workers=True, + persistent_workers=False, pin_memory=True, drop_last=True) build_end = time.time() @@ -273,6 +274,13 @@ def test_streaming_remote_dataloader(name: str, split: str) -> None: print(f'build_dur={build_dur:.2f}s, iter_dur={iter_dur:.2f}, ' + f'samples_per_sec={samples_per_sec:.2f}') + # Test all samples arrived + if dist.is_available() and dist.is_initialized() and get_world_size()>1: + rcvd_samples = torch.tensor(rcvd_samples, dtype=torch.int64).cuda() + dist.all_reduce(rcvd_samples, reduce_operation = 'SUM') + assert rcvd_samples.cpu() >= expected_samples + return + # Test all samples arrived assert rcvd_samples >= expected_samples @@ -288,7 +296,7 @@ def test_streaming_remote_dataloader(name: str, split: str) -> None: #test_streaming_remote_dataset(name = 'random_cpt_table_dbsql', split=None) #test_streaming_remote_dataset(name = 'random_large_table', split=None) #test_streaming_remote_dataset(name = 'reddit_table', split=None) - test_streaming_remote_dataset(name = 'reddit_table_dbsql', split=None) + #test_streaming_remote_dataset(name = 'reddit_table_dbsql', split=None) #test_streaming_remote_dataset(name = 'reddit_table_dbsql_cachelimit', split=None) #test_streaming_remote_dataset(name = 'coco_table_dbsql', split=None) #test_streaming_remote_dataset(name = 'large_liquid_test_table_08_07_dbsql', split=None) @@ -302,4 +310,5 @@ def test_streaming_remote_dataloader(name: str, split: str) -> None: #test_streaming_remote_dataloader(name = 'reddit_table_dbsql', split=None) #test_streaming_remote_dataloader(name = 'wiki_table_dbsql_cachelimit', split=None) #test_streaming_remote_dataloader(name = 'coco_table_dbsql', split=None) + test_streaming_remote_dataloader(name = 'evesize_level1_version_dbsql', split=None) From 44b9cc3d9845882c37234aeb2fced68f0eb34a32 Mon Sep 17 00:00:00 2001 From: xiaohanzhangcmu Date: Thu, 12 Sep 2024 16:21:20 -0700 Subject: [PATCH 120/145] update --- tests/test_streaming_remote.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_streaming_remote.py b/tests/test_streaming_remote.py index 6ed9b29ae..0e6fffa37 100644 --- a/tests/test_streaming_remote.py +++ b/tests/test_streaming_remote.py @@ -245,7 +245,7 @@ def test_streaming_remote_dataloader(name: str, split: str) -> None: data_loader = StreamingDataLoader(dataset, batch_size=batch_size, - num_workers=4, + num_workers=0, prefetch_factor=None, persistent_workers=False, pin_memory=True, From cb08ce8c6fca1a1dcc4072845e96f96c74041d03 Mon Sep 17 00:00:00 2001 From: xiaohanzhangcmu Date: Thu, 12 Sep 2024 16:42:48 -0700 Subject: [PATCH 121/145] update --- tests/test_streaming_remote.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/test_streaming_remote.py b/tests/test_streaming_remote.py index 0e6fffa37..21a5fb5e0 100644 --- a/tests/test_streaming_remote.py +++ b/tests/test_streaming_remote.py @@ -245,7 +245,7 @@ def test_streaming_remote_dataloader(name: str, split: str) -> None: data_loader = StreamingDataLoader(dataset, batch_size=batch_size, - num_workers=0, + num_workers=4, prefetch_factor=None, persistent_workers=False, pin_memory=True, @@ -310,5 +310,6 @@ def test_streaming_remote_dataloader(name: str, split: str) -> None: #test_streaming_remote_dataloader(name = 'reddit_table_dbsql', split=None) #test_streaming_remote_dataloader(name = 'wiki_table_dbsql_cachelimit', split=None) #test_streaming_remote_dataloader(name = 'coco_table_dbsql', split=None) - test_streaming_remote_dataloader(name = 'evesize_level1_version_dbsql', split=None) + #test_streaming_remote_dataloader(name = 'evesize_level1_version_dbsql', split=None) + test_streaming_remote_dataloader(name = 'reddit_table_dbsql', split=None) From 052af998a573bc849a2a0e35dda340c7ff996caa Mon Sep 17 00:00:00 2001 From: xiaohanzhangcmu Date: Fri, 13 Sep 2024 09:59:25 -0700 Subject: [PATCH 122/145] update --- tests/test_streaming_remote.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_streaming_remote.py b/tests/test_streaming_remote.py index 21a5fb5e0..b782d9b39 100644 --- a/tests/test_streaming_remote.py +++ b/tests/test_streaming_remote.py @@ -245,7 +245,7 @@ def test_streaming_remote_dataloader(name: str, split: str) -> None: data_loader = StreamingDataLoader(dataset, batch_size=batch_size, - num_workers=4, + num_workers=8, prefetch_factor=None, persistent_workers=False, pin_memory=True, From d83c7dac006c8df9b3158e7d95ba9e4ab5370114 Mon Sep 17 00:00:00 2001 From: xiaohanzhangcmu Date: Fri, 13 Sep 2024 13:38:57 -0700 Subject: [PATCH 123/145] update --- tests/test_streaming_remote.py | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/tests/test_streaming_remote.py b/tests/test_streaming_remote.py index b782d9b39..7840074a4 100644 --- a/tests/test_streaming_remote.py +++ b/tests/test_streaming_remote.py @@ -129,6 +129,29 @@ def get_dataset(name: str, }, 'shuffle': True, }, + 'main_streaming_wiki_table_mds': { + 'local': f'/tmp/test_wiki_table_volume_05May1029', + 'remote': 'dbfs:/Volumes/main/streaming/xiaohan_zhang/delta-streaming-benchmarks-mds/wiki_table', + 'num_samples': 5823210, + 'class': StreamingDataset, + 'kwargs': { + 'cache_limit': '100gb', + }, + 'shuffle': True, + }, + 'main_streaming_wiki_table_dbsql': { + 'local': f'/tmp/test_wiki_table_volume_05May1029', + 'remote': 'SELECT text FROM main.streaming.wiki_table' + 'num_samples': 5823210, + 'class': StreamingDataset, + 'kwargs': { + 'warehouse_id': "89cf2c9b9f9cb3bc", + 'catalog': 'main', + 'schema': 'streaming', + 'cache_limit': '100gb', + }, + 'shuffle': True, + }, 'coco_table_dbsql': { 'local': f'/tmp/test_coco_table_05May1029', 'remote': 'SELECT data, captions FROM main.streaming.coco_with_meta_and_captions', @@ -304,6 +327,8 @@ def test_streaming_remote_dataloader(name: str, split: str) -> None: #test_streaming_remote_dataset(name = 'debug_local', split=None) #test_streaming_remote_dataset(name = 'evesize_level1_filter_dbsql', split=None) #test_streaming_remote_dataset(name = 'evesize_level1_version_dbsql', split=None) + test_streaming_remote_dataset(name = 'main_streaming_wiki_table_mds', split=None) + #test_streaming_remote_dataset(name = 'main_streaming_wiki_table_dbsql', split=None) #test_streaming_remote_dataloader(name = 'refinedweb', split=None) #test_streaming_remote_dataloader(name = 'random_cpt_table_dbsql', split=None) From 9faed3137857a4a4118f8623e966d0bb434a70a5 Mon Sep 17 00:00:00 2001 From: xiaohanzhangcmu Date: Fri, 13 Sep 2024 13:54:19 -0700 Subject: [PATCH 124/145] update --- tests/test_streaming_remote.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_streaming_remote.py b/tests/test_streaming_remote.py index 7840074a4..02c4aa1e3 100644 --- a/tests/test_streaming_remote.py +++ b/tests/test_streaming_remote.py @@ -141,7 +141,7 @@ def get_dataset(name: str, }, 'main_streaming_wiki_table_dbsql': { 'local': f'/tmp/test_wiki_table_volume_05May1029', - 'remote': 'SELECT text FROM main.streaming.wiki_table' + 'remote': 'SELECT text FROM main.streaming.wiki_table', 'num_samples': 5823210, 'class': StreamingDataset, 'kwargs': { @@ -336,5 +336,5 @@ def test_streaming_remote_dataloader(name: str, split: str) -> None: #test_streaming_remote_dataloader(name = 'wiki_table_dbsql_cachelimit', split=None) #test_streaming_remote_dataloader(name = 'coco_table_dbsql', split=None) #test_streaming_remote_dataloader(name = 'evesize_level1_version_dbsql', split=None) - test_streaming_remote_dataloader(name = 'reddit_table_dbsql', split=None) + #test_streaming_remote_dataloader(name = 'reddit_table_dbsql', split=None) From cd9bb0c000f55d26b3f3822c349e6e5da7f5e771 Mon Sep 17 00:00:00 2001 From: xiaohanzhangcmu Date: Fri, 13 Sep 2024 14:02:40 -0700 Subject: [PATCH 125/145] update --- tests/test_streaming_remote.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_streaming_remote.py b/tests/test_streaming_remote.py index 02c4aa1e3..c749a103f 100644 --- a/tests/test_streaming_remote.py +++ b/tests/test_streaming_remote.py @@ -327,8 +327,8 @@ def test_streaming_remote_dataloader(name: str, split: str) -> None: #test_streaming_remote_dataset(name = 'debug_local', split=None) #test_streaming_remote_dataset(name = 'evesize_level1_filter_dbsql', split=None) #test_streaming_remote_dataset(name = 'evesize_level1_version_dbsql', split=None) - test_streaming_remote_dataset(name = 'main_streaming_wiki_table_mds', split=None) - #test_streaming_remote_dataset(name = 'main_streaming_wiki_table_dbsql', split=None) + #test_streaming_remote_dataset(name = 'main_streaming_wiki_table_mds', split=None) + test_streaming_remote_dataset(name = 'main_streaming_wiki_table_dbsql', split=None) #test_streaming_remote_dataloader(name = 'refinedweb', split=None) #test_streaming_remote_dataloader(name = 'random_cpt_table_dbsql', split=None) From c88aef4afa222a403a31a51ce811e0dc5232c9ac Mon Sep 17 00:00:00 2001 From: xiaohanzhangcmu Date: Fri, 13 Sep 2024 14:46:07 -0700 Subject: [PATCH 126/145] update --- tests/test_streaming_remote.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_streaming_remote.py b/tests/test_streaming_remote.py index c749a103f..5d8db6094 100644 --- a/tests/test_streaming_remote.py +++ b/tests/test_streaming_remote.py @@ -220,7 +220,7 @@ def get_dataset(name: str, def test_streaming_remote_dataset(name: str, split: str) -> None: # Build StreamingDataset build_start = time.time() - batch_size = 1 + batch_size = 64 expected_samples, dataset = get_dataset(name=name, split=split, shuffle=False, From c0e21b735ad1744fc5eaca5f658aae3e02867709 Mon Sep 17 00:00:00 2001 From: xiaohanzhangcmu Date: Fri, 13 Sep 2024 14:53:49 -0700 Subject: [PATCH 127/145] update --- tests/test_streaming_remote.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_streaming_remote.py b/tests/test_streaming_remote.py index 5d8db6094..63856d4b5 100644 --- a/tests/test_streaming_remote.py +++ b/tests/test_streaming_remote.py @@ -327,8 +327,8 @@ def test_streaming_remote_dataloader(name: str, split: str) -> None: #test_streaming_remote_dataset(name = 'debug_local', split=None) #test_streaming_remote_dataset(name = 'evesize_level1_filter_dbsql', split=None) #test_streaming_remote_dataset(name = 'evesize_level1_version_dbsql', split=None) - #test_streaming_remote_dataset(name = 'main_streaming_wiki_table_mds', split=None) - test_streaming_remote_dataset(name = 'main_streaming_wiki_table_dbsql', split=None) + test_streaming_remote_dataset(name = 'main_streaming_wiki_table_mds', split=None) + #test_streaming_remote_dataset(name = 'main_streaming_wiki_table_dbsql', split=None) #test_streaming_remote_dataloader(name = 'refinedweb', split=None) #test_streaming_remote_dataloader(name = 'random_cpt_table_dbsql', split=None) From 9ffdb2d1f5d769b21131fc40658f3f06d42b8022 Mon Sep 17 00:00:00 2001 From: xiaohanzhangcmu Date: Fri, 13 Sep 2024 15:16:14 -0700 Subject: [PATCH 128/145] update --- tests/test_streaming_remote.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_streaming_remote.py b/tests/test_streaming_remote.py index 63856d4b5..f51777f3d 100644 --- a/tests/test_streaming_remote.py +++ b/tests/test_streaming_remote.py @@ -220,7 +220,7 @@ def get_dataset(name: str, def test_streaming_remote_dataset(name: str, split: str) -> None: # Build StreamingDataset build_start = time.time() - batch_size = 64 + batch_size = 1024 expected_samples, dataset = get_dataset(name=name, split=split, shuffle=False, From 2cd31d90cd6d439361bc5196647b632204784dd5 Mon Sep 17 00:00:00 2001 From: xiaohanzhangcmu Date: Fri, 13 Sep 2024 15:21:02 -0700 Subject: [PATCH 129/145] update --- tests/test_streaming_remote.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_streaming_remote.py b/tests/test_streaming_remote.py index f51777f3d..59217acd8 100644 --- a/tests/test_streaming_remote.py +++ b/tests/test_streaming_remote.py @@ -327,8 +327,8 @@ def test_streaming_remote_dataloader(name: str, split: str) -> None: #test_streaming_remote_dataset(name = 'debug_local', split=None) #test_streaming_remote_dataset(name = 'evesize_level1_filter_dbsql', split=None) #test_streaming_remote_dataset(name = 'evesize_level1_version_dbsql', split=None) - test_streaming_remote_dataset(name = 'main_streaming_wiki_table_mds', split=None) - #test_streaming_remote_dataset(name = 'main_streaming_wiki_table_dbsql', split=None) + #test_streaming_remote_dataset(name = 'main_streaming_wiki_table_mds', split=None) + test_streaming_remote_dataset(name = 'main_streaming_wiki_table_dbsql', split=None) #test_streaming_remote_dataloader(name = 'refinedweb', split=None) #test_streaming_remote_dataloader(name = 'random_cpt_table_dbsql', split=None) From c54cb140641a9e7a9c2ebf02d25a0ddb1f2f6d29 Mon Sep 17 00:00:00 2001 From: xiaohanzhangcmu Date: Fri, 13 Sep 2024 16:06:52 -0700 Subject: [PATCH 130/145] update --- tests/test_streaming_remote.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/test_streaming_remote.py b/tests/test_streaming_remote.py index 59217acd8..68b5a3992 100644 --- a/tests/test_streaming_remote.py +++ b/tests/test_streaming_remote.py @@ -259,7 +259,7 @@ def test_streaming_remote_dataset(name: str, split: str) -> None: def test_streaming_remote_dataloader(name: str, split: str) -> None: # Build StreamingDataset build_start = time.time() - batch_size = 1 + batch_size = 1024 expected_samples, dataset = get_dataset(name=name, split=split, shuffle=False, @@ -283,6 +283,8 @@ def test_streaming_remote_dataloader(name: str, split: str) -> None: for epcoh in range(1): for batch_idx, data_dict in enumerate(data_loader): + if batch_idx < 10: + break rcvd_samples += batch_size if (rcvd_samples % (10*batch_size) == 0): From 3092e8453a72e6e4f2045f63fb3127972c255174 Mon Sep 17 00:00:00 2001 From: xiaohanzhangcmu Date: Fri, 13 Sep 2024 16:11:15 -0700 Subject: [PATCH 131/145] update --- tests/test_streaming_remote.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/test_streaming_remote.py b/tests/test_streaming_remote.py index 68b5a3992..9c8c2ffcc 100644 --- a/tests/test_streaming_remote.py +++ b/tests/test_streaming_remote.py @@ -330,7 +330,7 @@ def test_streaming_remote_dataloader(name: str, split: str) -> None: #test_streaming_remote_dataset(name = 'evesize_level1_filter_dbsql', split=None) #test_streaming_remote_dataset(name = 'evesize_level1_version_dbsql', split=None) #test_streaming_remote_dataset(name = 'main_streaming_wiki_table_mds', split=None) - test_streaming_remote_dataset(name = 'main_streaming_wiki_table_dbsql', split=None) + #test_streaming_remote_dataset(name = 'main_streaming_wiki_table_dbsql', split=None) #test_streaming_remote_dataloader(name = 'refinedweb', split=None) #test_streaming_remote_dataloader(name = 'random_cpt_table_dbsql', split=None) @@ -339,4 +339,6 @@ def test_streaming_remote_dataloader(name: str, split: str) -> None: #test_streaming_remote_dataloader(name = 'coco_table_dbsql', split=None) #test_streaming_remote_dataloader(name = 'evesize_level1_version_dbsql', split=None) #test_streaming_remote_dataloader(name = 'reddit_table_dbsql', split=None) + #test_streaming_remote_dataloader(name = 'main_streaming_wiki_table_mds', split=None) + test_streaming_remote_dataloader(name = 'main_streaming_wiki_table_dbsql', split=None) From 91a56c9dc2a065959ce535f75d2586715f1ba692 Mon Sep 17 00:00:00 2001 From: xiaohanzhangcmu Date: Fri, 13 Sep 2024 16:17:56 -0700 Subject: [PATCH 132/145] update --- tests/test_streaming_remote.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/test_streaming_remote.py b/tests/test_streaming_remote.py index 9c8c2ffcc..3141f575f 100644 --- a/tests/test_streaming_remote.py +++ b/tests/test_streaming_remote.py @@ -3,6 +3,7 @@ import pathlib import time +import itertools from typing import Any, Dict, Optional, Tuple #import pytest @@ -282,9 +283,9 @@ def test_streaming_remote_dataloader(name: str, split: str) -> None: iter_start = time.time() for epcoh in range(1): - for batch_idx, data_dict in enumerate(data_loader): - if batch_idx < 10: - break + skip_batches = 5 + for batch_idx, data_dict in enumerate(itertools.islice(data_loader, skip_batches, None)): + #for batch_idx, data_dict in enumerate(data_loader): rcvd_samples += batch_size if (rcvd_samples % (10*batch_size) == 0): From 03e6309fe4b0f214ae3550ca27b05510a8be0517 Mon Sep 17 00:00:00 2001 From: xiaohanzhangcmu Date: Fri, 13 Sep 2024 23:42:56 -0700 Subject: [PATCH 133/145] update --- streaming/base/stream.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/streaming/base/stream.py b/streaming/base/stream.py index 1f6fec709..cb89f73c0 100644 --- a/streaming/base/stream.py +++ b/streaming/base/stream.py @@ -741,7 +741,7 @@ def __init__(self, schema = kwargs.get('schema', None) if any([not warehouse_id, not host, not token, not catalog, not schema]): - raise TypeError(f"Need to specify warehouse_id, host, token catalog, schema, during initialization") + raise TypeError(f"Need to specify warehouse_id, host, token catalog, schema, during initialization, but got {warehouse_id}, {host}, {token}, {catalog}, {schema}") self.base_url = f"https://{host}/api/2.0/sql/statements/" self.headers = { From 91792604dcb87a6fa003bfa10f2dcf13bf6c73a4 Mon Sep 17 00:00:00 2001 From: xiaohanzhangcmu Date: Fri, 13 Sep 2024 23:48:23 -0700 Subject: [PATCH 134/145] update --- streaming/base/dataset.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/streaming/base/dataset.py b/streaming/base/dataset.py index cef35ae6f..9565d1529 100644 --- a/streaming/base/dataset.py +++ b/streaming/base/dataset.py @@ -332,6 +332,7 @@ def __init__(self, batching_method: str = 'random', allow_unsafe_types: bool = False, replication: Optional[int] = None, + delta_kwargs: Optional[dict] = None, **kwargs: Any) -> None: # Global arguments (which do not live in Streams). self.predownload = predownload @@ -447,7 +448,7 @@ def __init__(self, for stream in streams: stream.apply_default(default) elif remote is not None and remote.startswith('SELECT'): - cluster_id = kwargs.get('cluster_id', None) + cluster_id = delta_kwargs.get('cluster_id', None) if not cluster_id: default = DeltaDBSQLStream(remote=remote, local=local, @@ -456,7 +457,7 @@ def __init__(self, download_timeout=download_timeout, validate_hash=validate_hash, keep_zip=keep_zip, - **kwargs) + **delta_kwargs) else: default = DeltaSCStream(cluster_id, remote=remote, From 565963f547b44f9bf139be80159dada5e04a34db Mon Sep 17 00:00:00 2001 From: xiaohanzhangcmu Date: Fri, 13 Sep 2024 23:55:25 -0700 Subject: [PATCH 135/145] update --- streaming/base/dataset.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/streaming/base/dataset.py b/streaming/base/dataset.py index 9565d1529..a0978221d 100644 --- a/streaming/base/dataset.py +++ b/streaming/base/dataset.py @@ -332,7 +332,9 @@ def __init__(self, batching_method: str = 'random', allow_unsafe_types: bool = False, replication: Optional[int] = None, - delta_kwargs: Optional[dict] = None, + warehouse_id: Optional[str] = None, + catalog: Optional[str] = None, + schema: Optional[str] = None, **kwargs: Any) -> None: # Global arguments (which do not live in Streams). self.predownload = predownload @@ -448,6 +450,7 @@ def __init__(self, for stream in streams: stream.apply_default(default) elif remote is not None and remote.startswith('SELECT'): + delta_kwargs = {'warehouse_id': warehouse_id, 'catalog': catalog, 'schema': schema} cluster_id = delta_kwargs.get('cluster_id', None) if not cluster_id: default = DeltaDBSQLStream(remote=remote, From ab477cfeac3ce8f090f10a04145992e004656f4d Mon Sep 17 00:00:00 2001 From: xiaohanzhangcmu Date: Sat, 14 Sep 2024 00:06:25 -0700 Subject: [PATCH 136/145] update --- streaming/base/dataset.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/streaming/base/dataset.py b/streaming/base/dataset.py index a0978221d..cef35ae6f 100644 --- a/streaming/base/dataset.py +++ b/streaming/base/dataset.py @@ -332,9 +332,6 @@ def __init__(self, batching_method: str = 'random', allow_unsafe_types: bool = False, replication: Optional[int] = None, - warehouse_id: Optional[str] = None, - catalog: Optional[str] = None, - schema: Optional[str] = None, **kwargs: Any) -> None: # Global arguments (which do not live in Streams). self.predownload = predownload @@ -450,8 +447,7 @@ def __init__(self, for stream in streams: stream.apply_default(default) elif remote is not None and remote.startswith('SELECT'): - delta_kwargs = {'warehouse_id': warehouse_id, 'catalog': catalog, 'schema': schema} - cluster_id = delta_kwargs.get('cluster_id', None) + cluster_id = kwargs.get('cluster_id', None) if not cluster_id: default = DeltaDBSQLStream(remote=remote, local=local, @@ -460,7 +456,7 @@ def __init__(self, download_timeout=download_timeout, validate_hash=validate_hash, keep_zip=keep_zip, - **delta_kwargs) + **kwargs) else: default = DeltaSCStream(cluster_id, remote=remote, From fc2d9eb370afe51edbb1b5285db2b5dc4984714b Mon Sep 17 00:00:00 2001 From: xiaohanzhangcmu Date: Sat, 14 Sep 2024 00:27:12 -0700 Subject: [PATCH 137/145] update --- streaming/base/stream.py | 1 + 1 file changed, 1 insertion(+) diff --git a/streaming/base/stream.py b/streaming/base/stream.py index cb89f73c0..33c4978ac 100644 --- a/streaming/base/stream.py +++ b/streaming/base/stream.py @@ -777,6 +777,7 @@ def __init__(self, 'struct': 'json', 'tinyint': 'int8', 'long': 'int8', + 'array>': 'json', # special for messages } def generate_statement_id_and_sync(self, world: World): From d8251052bae1d86c3ca4ec0c269275a5d7289b79 Mon Sep 17 00:00:00 2001 From: xiaohanzhangcmu Date: Mon, 16 Sep 2024 09:45:43 -0700 Subject: [PATCH 138/145] update --- tests/test_streaming_remote.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_streaming_remote.py b/tests/test_streaming_remote.py index 3141f575f..bb253abff 100644 --- a/tests/test_streaming_remote.py +++ b/tests/test_streaming_remote.py @@ -329,7 +329,7 @@ def test_streaming_remote_dataloader(name: str, split: str) -> None: #test_streaming_remote_dataset(name = 'prompt_response_table_dbsql', split=None) #test_streaming_remote_dataset(name = 'debug_local', split=None) #test_streaming_remote_dataset(name = 'evesize_level1_filter_dbsql', split=None) - #test_streaming_remote_dataset(name = 'evesize_level1_version_dbsql', split=None) + test_streaming_remote_dataset(name = 'evesize_level1_version_dbsql', split=None) #test_streaming_remote_dataset(name = 'main_streaming_wiki_table_mds', split=None) #test_streaming_remote_dataset(name = 'main_streaming_wiki_table_dbsql', split=None) @@ -341,5 +341,5 @@ def test_streaming_remote_dataloader(name: str, split: str) -> None: #test_streaming_remote_dataloader(name = 'evesize_level1_version_dbsql', split=None) #test_streaming_remote_dataloader(name = 'reddit_table_dbsql', split=None) #test_streaming_remote_dataloader(name = 'main_streaming_wiki_table_mds', split=None) - test_streaming_remote_dataloader(name = 'main_streaming_wiki_table_dbsql', split=None) + #test_streaming_remote_dataloader(name = 'main_streaming_wiki_table_dbsql', split=None) From b36770d08c2205f7d3bcaa381bc8e672da260cca Mon Sep 17 00:00:00 2001 From: xiaohanzhangcmu Date: Wed, 18 Sep 2024 12:45:33 -0700 Subject: [PATCH 139/145] update --- streaming/base/stream.py | 1 + 1 file changed, 1 insertion(+) diff --git a/streaming/base/stream.py b/streaming/base/stream.py index 33c4978ac..11fab5e51 100644 --- a/streaming/base/stream.py +++ b/streaming/base/stream.py @@ -765,6 +765,7 @@ def __init__(self, 'bigint' : 'int64', 'array': 'ndarray', 'array': 'str_array', + 'array': 'ndarray:int32', 'binary': 'bytes', 'boolean': 'uint32', 'date': 'str', From 95bb76a70f571b27aae02218a1cbea05f62ee104 Mon Sep 17 00:00:00 2001 From: xiaohanzhangcmu Date: Wed, 18 Sep 2024 13:03:50 -0700 Subject: [PATCH 140/145] update --- streaming/base/format/mds/encodings.py | 9 +++++++++ streaming/base/stream.py | 2 +- 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/streaming/base/format/mds/encodings.py b/streaming/base/format/mds/encodings.py index 4312d1ce7..06e984a52 100644 --- a/streaming/base/format/mds/encodings.py +++ b/streaming/base/format/mds/encodings.py @@ -566,6 +566,14 @@ def decode(self, encoded_bytes: bytes) -> Any: return decoded_strings +class IntArray(NDArray): + """Store a list of integers.""" + + def encode(self, ints: list[int]) -> bytes: + return self.encode(np.array(ints, dtype=np.int32)) + + def decode(self, encoded_bytes: bytes) -> list(int): + return self.decode(encoded_bytes).tolist() # Encodings (name -> class). _encodings = { @@ -588,6 +596,7 @@ def decode(self, encoded_bytes: bytes) -> Any: 'str_float': StrFloat, 'str_decimal': StrDecimal, 'str_array': StrArray, + 'int_array': IntArray, 'pil': PIL, 'jpeg': JPEG, 'png': PNG, diff --git a/streaming/base/stream.py b/streaming/base/stream.py index 11fab5e51..d9644604b 100644 --- a/streaming/base/stream.py +++ b/streaming/base/stream.py @@ -765,7 +765,7 @@ def __init__(self, 'bigint' : 'int64', 'array': 'ndarray', 'array': 'str_array', - 'array': 'ndarray:int32', + 'array': 'int_array', 'binary': 'bytes', 'boolean': 'uint32', 'date': 'str', From 8744c82608de73cc75def3c60d3401fccd13caf4 Mon Sep 17 00:00:00 2001 From: xiaohanzhangcmu Date: Wed, 18 Sep 2024 13:05:19 -0700 Subject: [PATCH 141/145] update --- streaming/base/format/mds/encodings.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/streaming/base/format/mds/encodings.py b/streaming/base/format/mds/encodings.py index 06e984a52..ccc40b18b 100644 --- a/streaming/base/format/mds/encodings.py +++ b/streaming/base/format/mds/encodings.py @@ -572,7 +572,7 @@ class IntArray(NDArray): def encode(self, ints: list[int]) -> bytes: return self.encode(np.array(ints, dtype=np.int32)) - def decode(self, encoded_bytes: bytes) -> list(int): + def decode(self, encoded_bytes: bytes) -> list[int]: return self.decode(encoded_bytes).tolist() # Encodings (name -> class). From c531aa45f02235d8b234a05e3910bcc835825820 Mon Sep 17 00:00:00 2001 From: xiaohanzhangcmu Date: Wed, 18 Sep 2024 17:29:05 -0700 Subject: [PATCH 142/145] update --- streaming/base/format/mds/encodings.py | 35 +++++++++++++++++++++----- 1 file changed, 29 insertions(+), 6 deletions(-) diff --git a/streaming/base/format/mds/encodings.py b/streaming/base/format/mds/encodings.py index ccc40b18b..170544f76 100644 --- a/streaming/base/format/mds/encodings.py +++ b/streaming/base/format/mds/encodings.py @@ -16,6 +16,8 @@ from PIL.JpegImagePlugin import JpegImageFile from typing_extensions import Self +import struct + __all__ = [ 'get_mds_encoded_size', 'get_mds_encodings', 'is_mds_encoding', 'mds_decode', 'mds_encode', 'is_mds_encoding_safe' @@ -566,14 +568,35 @@ def decode(self, encoded_bytes: bytes) -> Any: return decoded_strings -class IntArray(NDArray): - """Store a list of integers.""" - def encode(self, ints: list[int]) -> bytes: - return self.encode(np.array(ints, dtype=np.int32)) +class IntArray(Encoding): + """Store a list of int32 integers efficiently.""" + + def encode(self, integers: Any) -> bytes: + # Pack the length of the list as an unsigned 4-byte integer + list_length = len(integers) + encoded = struct.pack(' Any: + index = 0 + + # Unpack the length of the list + list_length = struct.unpack_from(' 0: + int_bytes_length = 4 * list_length + integers = list(struct.unpack_from(f'<{list_length}i', encoded_bytes, index)) + + return integers - def decode(self, encoded_bytes: bytes) -> list[int]: - return self.decode(encoded_bytes).tolist() # Encodings (name -> class). _encodings = { From 745290ec37e951711dee16f9198723f9c13666e6 Mon Sep 17 00:00:00 2001 From: xiaohanzhangcmu Date: Wed, 18 Sep 2024 19:39:31 -0700 Subject: [PATCH 143/145] update --- tests/test_streaming_remote.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/tests/test_streaming_remote.py b/tests/test_streaming_remote.py index bb253abff..cbad5a9e9 100644 --- a/tests/test_streaming_remote.py +++ b/tests/test_streaming_remote.py @@ -192,6 +192,19 @@ def get_dataset(name: str, }, 'shuffle': False, }, + 'finance_more_like_dbsql': { + 'local': f'/tmp/test_finance_more_like_05May1029', + 'remote': "SELECT llama_3_1_tokens AS tokens FROM main.seanowen.finance_more_like", + 'num_samples': 210463508, + 'class': StreamingDataset, + 'kwargs': { + 'warehouse_id': "89cf2c9b9f9cb3bc", + 'catalog': 'main', + 'schema': 'seanowen', + # 'cache_limit': '100mb', + }, + 'shuffle': False, + }, 'debug_local': { 'local': f'/tmp/test_random_reddit_table_05May1029', 'remote': None, @@ -329,9 +342,10 @@ def test_streaming_remote_dataloader(name: str, split: str) -> None: #test_streaming_remote_dataset(name = 'prompt_response_table_dbsql', split=None) #test_streaming_remote_dataset(name = 'debug_local', split=None) #test_streaming_remote_dataset(name = 'evesize_level1_filter_dbsql', split=None) - test_streaming_remote_dataset(name = 'evesize_level1_version_dbsql', split=None) + #test_streaming_remote_dataset(name = 'evesize_level1_version_dbsql', split=None) #test_streaming_remote_dataset(name = 'main_streaming_wiki_table_mds', split=None) #test_streaming_remote_dataset(name = 'main_streaming_wiki_table_dbsql', split=None) + test_streaming_remote_dataset(name = 'finance_more_like_dbsql', split=None) #test_streaming_remote_dataloader(name = 'refinedweb', split=None) #test_streaming_remote_dataloader(name = 'random_cpt_table_dbsql', split=None) From 177631d6bd3aba0ff59454b44c13039a9a7b552a Mon Sep 17 00:00:00 2001 From: xiaohanzhangcmu Date: Wed, 18 Sep 2024 20:18:55 -0700 Subject: [PATCH 144/145] update --- streaming/base/format/mds/encodings.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/streaming/base/format/mds/encodings.py b/streaming/base/format/mds/encodings.py index 170544f76..3a96b9c76 100644 --- a/streaming/base/format/mds/encodings.py +++ b/streaming/base/format/mds/encodings.py @@ -583,7 +583,7 @@ def encode(self, integers: Any) -> bytes: return encoded - def decode(self, encoded_bytes: bytes) -> Any: + def decode(self, encoded_bytes: bytes) -> npt.NDArray: index = 0 # Unpack the length of the list @@ -595,7 +595,7 @@ def decode(self, encoded_bytes: bytes) -> Any: int_bytes_length = 4 * list_length integers = list(struct.unpack_from(f'<{list_length}i', encoded_bytes, index)) - return integers + return np.array(integers, dtype=np.int32) # Encodings (name -> class). From 2cce6eafd336a69ee6536806442a296b53b2aea0 Mon Sep 17 00:00:00 2001 From: xiaohanzhangcmu Date: Sat, 28 Sep 2024 22:23:04 -0700 Subject: [PATCH 145/145] get host/token from WorkspaceClient.config --- streaming/base/stream.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/streaming/base/stream.py b/streaming/base/stream.py index d9644604b..a56955ad6 100644 --- a/streaming/base/stream.py +++ b/streaming/base/stream.py @@ -734,9 +734,14 @@ def __init__(self, validate_hash=validate_hash, keep_zip=keep_zip) + from databricks.sdk import WorkspaceClient + w = WorkspaceClient() + host = w.config.host.lstrip('https://') + token = w.config.token + #host = kwargs.get('host', os.environ['DATABRICKS_HOST']).lstrip('https://') + #token = kwargs.get('token', os.environ['DATABRICKS_TOKEN']) + warehouse_id = kwargs.get('warehouse_id', None) - host = kwargs.get('host', os.environ['DATABRICKS_HOST']).lstrip('https://') - token = kwargs.get('token', os.environ['DATABRICKS_TOKEN']) catalog = kwargs.get('catalog', None) schema = kwargs.get('schema', None)