From bf95770e2c83f8852075a66aeb7630a29184fda8 Mon Sep 17 00:00:00 2001 From: "baogorek@gmail.com" Date: Wed, 26 Nov 2025 11:00:23 -0500 Subject: [PATCH] Add subdirectory support for hf:// URLs and gs:// scheme for GCS MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add parse_hf_url() helper function to centralize hf:// URL parsing - Support subdirectory paths: hf://owner/repo/path/to/file.h5[@version] - Add google_cloud.py module with parse_gs_url(), download_gcs_file(), upload_gcs_file() - Support gs:// URLs: gs://bucket/path/to/file.h5[@version] - Update dataset.py download/upload methods for both schemes - Update simulation.py dataset initialization for both schemes - google-cloud-storage is optional - raises helpful ImportError if not installed Closes #405 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com> --- changelog_entry.yaml | 5 + policyengine_core/data/dataset.py | 35 +++- policyengine_core/simulations/simulation.py | 18 ++- policyengine_core/tools/google_cloud.py | 168 ++++++++++++++++++++ policyengine_core/tools/hugging_face.py | 30 ++++ tests/core/tools/test_google_cloud.py | 63 ++++++++ tests/core/tools/test_hugging_face.py | 57 ++++++- 7 files changed, 364 insertions(+), 12 deletions(-) create mode 100644 policyengine_core/tools/google_cloud.py create mode 100644 tests/core/tools/test_google_cloud.py diff --git a/changelog_entry.yaml b/changelog_entry.yaml index e69de29bb..e8f941786 100644 --- a/changelog_entry.yaml +++ b/changelog_entry.yaml @@ -0,0 +1,5 @@ +- bump: minor + changes: + added: + - Subdirectory support for hf:// URLs (e.g., hf://owner/repo/path/to/file.h5). + - Google Cloud Storage support with gs:// URL scheme (e.g., gs://bucket/path/to/file.h5). 
diff --git a/policyengine_core/data/dataset.py b/policyengine_core/data/dataset.py index d97852be1..2cd715b6f 100644 --- a/policyengine_core/data/dataset.py +++ b/policyengine_core/data/dataset.py @@ -8,6 +8,11 @@ import os import tempfile from policyengine_core.tools.hugging_face import * +from policyengine_core.tools.google_cloud import ( + parse_gs_url, + download_gcs_file, + upload_gcs_file, +) import sys from policyengine_core.tools.win_file_manager import WindowsAtomicFileManager @@ -353,9 +358,22 @@ def download(self, url: str = None, version: str = None) -> None: f"File {file_path} not found in release {release_tag} of {org}/{repo}." ) elif url.startswith("hf://"): - owner_name, model_name, file_name = url.split("/")[2:] + owner_name, model_name, file_name, hf_version = parse_hf_url(url) self.download_from_huggingface( - owner_name, model_name, file_name, version + owner_name, model_name, file_name, hf_version or version + ) + return + elif url.startswith("gs://"): + bucket, file_path, gs_version = parse_gs_url(url) + print( + f"Downloading from GCS gs://{bucket}/{file_path}", + file=sys.stderr, + ) + downloaded_path = download_gcs_file( + bucket=bucket, + file_path=file_path, + version=gs_version or version, + local_path=str(self.file_path), ) return else: @@ -386,8 +404,19 @@ def upload(self, url: str = None): url = self.url if url.startswith("hf://"): - owner_name, model_name, file_name = url.split("/")[2:] + owner_name, model_name, file_name, _ = parse_hf_url(url) self.upload_to_huggingface(owner_name, model_name, file_name) + elif url.startswith("gs://"): + bucket, file_path, _ = parse_gs_url(url) + print( + f"Uploading to GCS gs://{bucket}/{file_path}", + file=sys.stderr, + ) + upload_gcs_file( + bucket=bucket, + file_path=file_path, + local_path=str(self.file_path), + ) def remove(self): """Removes the dataset from disk.""" diff --git a/policyengine_core/simulations/simulation.py b/policyengine_core/simulations/simulation.py index 
a0b1ec851..fea8216ae 100644 --- a/policyengine_core/simulations/simulation.py +++ b/policyengine_core/simulations/simulation.py @@ -23,6 +23,10 @@ ) import random from policyengine_core.tools.hugging_face import * +from policyengine_core.tools.google_cloud import ( + parse_gs_url, + download_gcs_file, +) import json @@ -160,17 +164,19 @@ def __init__( if dataset is not None: if isinstance(dataset, str): if "hf://" in dataset: - owner, repo, filename = dataset.split("/")[-3:] - if "@" in filename: - version = filename.split("@")[-1] - filename = filename.split("@")[0] - else: - version = None + owner, repo, filename, version = parse_hf_url(dataset) dataset = download_huggingface_dataset( repo=f"{owner}/{repo}", repo_filename=filename, version=version, ) + elif "gs://" in dataset: + bucket, file_path, version = parse_gs_url(dataset) + dataset = download_gcs_file( + bucket=bucket, + file_path=file_path, + version=version, + ) datasets_by_name = { dataset.name: dataset for dataset in self.datasets } diff --git a/policyengine_core/tools/google_cloud.py b/policyengine_core/tools/google_cloud.py new file mode 100644 index 000000000..825063c61 --- /dev/null +++ b/policyengine_core/tools/google_cloud.py @@ -0,0 +1,168 @@ +import os +import tempfile +from getpass import getpass +from pathlib import Path + + +def parse_gs_url(url: str) -> tuple[str, str, str | None]: + """ + Parse a Google Cloud Storage URL into components. + + Args: + url: URL in format gs://bucket/path/to/file[@version] + + Returns: + Tuple of (bucket, file_path, version) + version is None if not specified + """ + if not url.startswith("gs://"): + raise ValueError( + f"Invalid gs:// URL format: {url}. " + "Expected format: gs://bucket/path/to/file[@version]" + ) + + # Remove the "gs://" prefix + path = url[5:] + parts = path.split("/", 1) + + if len(parts) < 2 or not parts[1]: + raise ValueError( + f"Invalid gs:// URL format: {url}. 
" + "Expected format: gs://bucket/path/to/file[@version]" + ) + + bucket = parts[0] + file_path = parts[1] + + version = None + if "@" in file_path: + file_path, version = file_path.rsplit("@", 1) + + return bucket, file_path, version + + +def download_gcs_file( + bucket: str, + file_path: str, + version: str = None, + local_path: str = None, +): + """ + Download a file from Google Cloud Storage. + + Args: + bucket: The GCS bucket name. + file_path: The path to the file within the bucket. + version: The generation/version of the file (optional). + local_path: The local path to save the file to. If None, downloads to a temp directory. + + Returns: + The local path where the file was saved. + """ + try: + from google.cloud import storage + import google.auth + except ImportError: + raise ImportError( + "google-cloud-storage is required for gs:// URLs. " + "Install it with: pip install google-cloud-storage" + ) + + credentials, project_id = _get_gcs_credentials() + + storage_client = storage.Client( + credentials=credentials, project=project_id + ) + + bucket_obj = storage_client.bucket(bucket) + blob = bucket_obj.blob(file_path) + + if version: + blob = bucket_obj.blob(file_path, generation=int(version)) + + if local_path is None: + # Download to a temp directory, preserving the filename + filename = Path(file_path).name + local_path = os.path.join(tempfile.gettempdir(), filename) + + blob.download_to_filename(local_path) + return local_path + + +def upload_gcs_file( + bucket: str, + file_path: str, + local_path: str, + version_metadata: str = None, +): + """ + Upload a file to Google Cloud Storage. + + Args: + bucket: The GCS bucket name. + file_path: The path to upload to within the bucket. + local_path: The local path of the file to upload. + version_metadata: Optional version string to store in blob metadata. 
+ """ + try: + from google.cloud import storage + import google.auth + except ImportError: + raise ImportError( + "google-cloud-storage is required for gs:// URLs. " + "Install it with: pip install google-cloud-storage" + ) + + credentials, project_id = _get_gcs_credentials() + + storage_client = storage.Client( + credentials=credentials, project=project_id + ) + + bucket_obj = storage_client.bucket(bucket) + blob = bucket_obj.blob(file_path) + blob.upload_from_filename(local_path) + + if version_metadata: + blob.metadata = {"version": version_metadata} + blob.patch() + + return f"gs://{bucket}/{file_path}" + + +def _get_gcs_credentials(): + """ + Get GCS credentials, prompting for service account key path if needed. + + Returns: + Tuple of (credentials, project_id) + """ + try: + import google.auth + from google.auth import exceptions as auth_exceptions + except ImportError: + raise ImportError( + "google-cloud-storage is required for gs:// URLs. " + "Install it with: pip install google-cloud-storage" + ) + + # First try default credentials (e.g., from gcloud auth, service account, etc.) 
+ try: + credentials, project_id = google.auth.default() + return credentials, project_id + except auth_exceptions.DefaultCredentialsError: + pass + + # If no default credentials, check for service account key in environment + key_path = os.environ.get("GOOGLE_APPLICATION_CREDENTIALS") + + if key_path is None: + key_path = getpass( + "Enter path to GCS service account key JSON " + "(or set GOOGLE_APPLICATION_CREDENTIALS): " + ) + os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = key_path + + # Try again with the provided credentials + credentials, project_id = google.auth.default() + return credentials, project_id diff --git a/policyengine_core/tools/hugging_face.py b/policyengine_core/tools/hugging_face.py index 65891e61e..887d44479 100644 --- a/policyengine_core/tools/hugging_face.py +++ b/policyengine_core/tools/hugging_face.py @@ -13,6 +13,36 @@ warnings.simplefilter("ignore") +def parse_hf_url(url: str) -> tuple[str, str, str, str | None]: + """ + Parse a Hugging Face URL into components. + + Args: + url: URL in format hf://owner/repo/path/to/file[@version] + + Returns: + Tuple of (owner, repo, file_path, version) + version is None if not specified + """ + parts = url.split("/")[2:] + + if len(parts) < 3: + raise ValueError( + f"Invalid hf:// URL format: {url}. 
" + "Expected format: hf://owner/repo/path/to/file[@version]" + ) + + owner = parts[0] + repo = parts[1] + file_path = "/".join(parts[2:]) + + version = None + if "@" in file_path: + file_path, version = file_path.rsplit("@", 1) + + return owner, repo, file_path, version + + def download_huggingface_dataset( repo: str, repo_filename: str, diff --git a/tests/core/tools/test_google_cloud.py b/tests/core/tools/test_google_cloud.py new file mode 100644 index 000000000..e9ba42ea1 --- /dev/null +++ b/tests/core/tools/test_google_cloud.py @@ -0,0 +1,63 @@ +import pytest +from policyengine_core.tools.google_cloud import parse_gs_url + + +class TestParseGsUrl: + def test_basic_url(self): + bucket, file_path, version = parse_gs_url("gs://my-bucket/file.h5") + assert (bucket, file_path, version) == ("my-bucket", "file.h5", None) + + def test_subdirectory_url(self): + bucket, file_path, version = parse_gs_url( + "gs://my-bucket/data/2024/file.h5" + ) + assert bucket == "my-bucket" + assert file_path == "data/2024/file.h5" + assert version is None + + def test_url_with_version(self): + bucket, file_path, version = parse_gs_url( + "gs://my-bucket/file.h5@12345" + ) + assert (file_path, version) == ("file.h5", "12345") + + def test_subdirectory_with_version(self): + bucket, file_path, version = parse_gs_url( + "gs://my-bucket/path/to/file.h5@67890" + ) + assert bucket == "my-bucket" + assert (file_path, version) == ("path/to/file.h5", "67890") + + def test_deep_subdirectory(self): + bucket, file_path, version = parse_gs_url( + "gs://my-bucket/a/b/c/d/e/file.h5" + ) + assert file_path == "a/b/c/d/e/file.h5" + + def test_invalid_url_no_gs_prefix(self): + with pytest.raises(ValueError, match="Invalid gs:// URL format"): + parse_gs_url("s3://my-bucket/file.h5") + + def test_invalid_url_no_file(self): + with pytest.raises(ValueError, match="Invalid gs:// URL format"): + parse_gs_url("gs://my-bucket") + + def test_invalid_url_no_file_with_slash(self): + with pytest.raises(ValueError, 
match="Invalid gs:// URL format"): + parse_gs_url("gs://my-bucket/") + + def test_bucket_with_dashes_and_dots(self): + bucket, file_path, version = parse_gs_url( + "gs://my-project.appspot.com/data/file.h5" + ) + assert bucket == "my-project.appspot.com" + assert file_path == "data/file.h5" + + def test_version_in_middle_of_path(self): + # @ in subdirectory name should NOT be treated as version separator + # Only the last @ should be used for version + bucket, file_path, version = parse_gs_url( + "gs://my-bucket/path@weird/file.h5@v1.0" + ) + assert file_path == "path@weird/file.h5" + assert version == "v1.0" diff --git a/tests/core/tools/test_hugging_face.py b/tests/core/tools/test_hugging_face.py index 54f04bd88..3dcda6134 100644 --- a/tests/core/tools/test_hugging_face.py +++ b/tests/core/tools/test_hugging_face.py @@ -1,11 +1,12 @@ import os import pytest -from unittest.mock import patch +from unittest.mock import patch, MagicMock from huggingface_hub import ModelInfo from huggingface_hub.errors import RepositoryNotFoundError from policyengine_core.tools.hugging_face import ( get_or_prompt_hf_token, download_huggingface_dataset, + parse_hf_url, ) @@ -55,8 +56,11 @@ def test_download_private_repo(self): with patch( "policyengine_core.tools.hugging_face.model_info" ) as mock_model_info: + mock_response = MagicMock() + mock_response.status_code = 404 + mock_response.headers = {} mock_model_info.side_effect = RepositoryNotFoundError( - "Test error" + "Test error", response=mock_response ) with patch( "policyengine_core.tools.hugging_face.get_or_prompt_hf_token" @@ -88,8 +92,11 @@ def test_download_private_repo_no_token(self): with patch( "policyengine_core.tools.hugging_face.model_info" ) as mock_model_info: + mock_response = MagicMock() + mock_response.status_code = 404 + mock_response.headers = {} mock_model_info.side_effect = RepositoryNotFoundError( - "Test error" + "Test error", response=mock_response ) with patch( 
"policyengine_core.tools.hugging_face.get_or_prompt_hf_token" @@ -156,3 +163,47 @@ def test_environment_variable_persistence(self): assert first_result == second_result == test_token assert os.environ.get("HUGGING_FACE_TOKEN") == test_token + + +class TestParseHfUrl: + def test_basic_url(self): + owner, repo, file_path, version = parse_hf_url( + "hf://owner/repo/file.h5" + ) + assert (owner, repo, file_path, version) == ( + "owner", + "repo", + "file.h5", + None, + ) + + def test_subdirectory_url(self): + owner, repo, file_path, version = parse_hf_url( + "hf://owner/repo/data/2024/file.h5" + ) + assert owner == "owner" + assert repo == "repo" + assert file_path == "data/2024/file.h5" + assert version is None + + def test_url_with_version(self): + owner, repo, file_path, version = parse_hf_url( + "hf://owner/repo/file.h5@v1.0" + ) + assert (file_path, version) == ("file.h5", "v1.0") + + def test_subdirectory_with_version(self): + owner, repo, file_path, version = parse_hf_url( + "hf://owner/repo/path/to/file.h5@v2.0" + ) + assert (file_path, version) == ("path/to/file.h5", "v2.0") + + def test_deep_subdirectory(self): + owner, repo, file_path, version = parse_hf_url( + "hf://owner/repo/a/b/c/d/e/file.h5" + ) + assert file_path == "a/b/c/d/e/file.h5" + + def test_invalid_url_too_short(self): + with pytest.raises(ValueError, match="Invalid hf:// URL format"): + parse_hf_url("hf://owner/repo")