Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions changelog_entry.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
- bump: minor
changes:
added:
- Subdirectory support for hf:// URLs (e.g., hf://owner/repo/path/to/file.h5).
- Google Cloud Storage support with gs:// URL scheme (e.g., gs://bucket/path/to/file.h5).
35 changes: 32 additions & 3 deletions policyengine_core/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,11 @@
import os
import tempfile
from policyengine_core.tools.hugging_face import *
from policyengine_core.tools.google_cloud import (
parse_gs_url,
download_gcs_file,
upload_gcs_file,
)
import sys
from policyengine_core.tools.win_file_manager import WindowsAtomicFileManager

Expand Down Expand Up @@ -353,9 +358,22 @@ def download(self, url: str = None, version: str = None) -> None:
f"File {file_path} not found in release {release_tag} of {org}/{repo}."
)
elif url.startswith("hf://"):
owner_name, model_name, file_name = url.split("/")[2:]
owner_name, model_name, file_name, hf_version = parse_hf_url(url)
self.download_from_huggingface(
owner_name, model_name, file_name, version
owner_name, model_name, file_name, hf_version or version
)
return
elif url.startswith("gs://"):
bucket, file_path, gs_version = parse_gs_url(url)
print(
f"Downloading from GCS gs://{bucket}/{file_path}",
file=sys.stderr,
)
downloaded_path = download_gcs_file(
bucket=bucket,
file_path=file_path,
version=gs_version or version,
local_path=str(self.file_path),
)
return
else:
Expand Down Expand Up @@ -386,8 +404,19 @@ def upload(self, url: str = None):
url = self.url

if url.startswith("hf://"):
owner_name, model_name, file_name = url.split("/")[2:]
owner_name, model_name, file_name, _ = parse_hf_url(url)
self.upload_to_huggingface(owner_name, model_name, file_name)
elif url.startswith("gs://"):
bucket, file_path, _ = parse_gs_url(url)
print(
f"Uploading to GCS gs://{bucket}/{file_path}",
file=sys.stderr,
)
upload_gcs_file(
bucket=bucket,
file_path=file_path,
local_path=str(self.file_path),
)

def remove(self):
"""Removes the dataset from disk."""
Expand Down
18 changes: 12 additions & 6 deletions policyengine_core/simulations/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@
)
import random
from policyengine_core.tools.hugging_face import *
from policyengine_core.tools.google_cloud import (
parse_gs_url,
download_gcs_file,
)

import json

Expand Down Expand Up @@ -160,17 +164,19 @@ def __init__(
if dataset is not None:
if isinstance(dataset, str):
if "hf://" in dataset:
owner, repo, filename = dataset.split("/")[-3:]
if "@" in filename:
version = filename.split("@")[-1]
filename = filename.split("@")[0]
else:
version = None
owner, repo, filename, version = parse_hf_url(dataset)
dataset = download_huggingface_dataset(
repo=f"{owner}/{repo}",
repo_filename=filename,
version=version,
)
elif "gs://" in dataset:
bucket, file_path, version = parse_gs_url(dataset)
dataset = download_gcs_file(
bucket=bucket,
file_path=file_path,
version=version,
)
datasets_by_name = {
dataset.name: dataset for dataset in self.datasets
}
Expand Down
168 changes: 168 additions & 0 deletions policyengine_core/tools/google_cloud.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
import os
import tempfile
from getpass import getpass
from pathlib import Path


def parse_gs_url(url: str) -> tuple[str, str, str | None]:
"""
Parse a Google Cloud Storage URL into components.

Args:
url: URL in format gs://bucket/path/to/file[@version]

Returns:
Tuple of (bucket, file_path, version)
version is None if not specified
"""
if not url.startswith("gs://"):
raise ValueError(
f"Invalid gs:// URL format: {url}. "
"Expected format: gs://bucket/path/to/file[@version]"
)

# Remove the "gs://" prefix
path = url[5:]
parts = path.split("/", 1)

if len(parts) < 2 or not parts[1]:
raise ValueError(
f"Invalid gs:// URL format: {url}. "
"Expected format: gs://bucket/path/to/file[@version]"
)

bucket = parts[0]
file_path = parts[1]

version = None
if "@" in file_path:
file_path, version = file_path.rsplit("@", 1)

return bucket, file_path, version


def download_gcs_file(
    bucket: str,
    file_path: str,
    version: str | None = None,
    local_path: str | None = None,
) -> str:
    """
    Download a file from Google Cloud Storage.

    Args:
        bucket: The GCS bucket name.
        file_path: The path to the file within the bucket.
        version: The generation of the file (optional). GCS object
            versions are integer generation numbers.
        local_path: The local path to save the file to. If None, downloads
            to a temp directory.

    Returns:
        The local path where the file was saved.

    Raises:
        ImportError: If google-cloud-storage is not installed.
        ValueError: If version is not a valid integer generation number.
    """
    try:
        from google.cloud import storage
    except ImportError:
        raise ImportError(
            "google-cloud-storage is required for gs:// URLs. "
            "Install it with: pip install google-cloud-storage"
        )

    credentials, project_id = _get_gcs_credentials()

    storage_client = storage.Client(
        credentials=credentials, project=project_id
    )

    # GCS identifies object versions by integer "generation" numbers.
    # Fail with a clear message if the caller supplied something else
    # (e.g. a "v1.0" suffix that parse_gs_url accepts syntactically).
    generation = None
    if version:
        try:
            generation = int(version)
        except ValueError:
            raise ValueError(
                f"Invalid GCS generation {version!r}: "
                "must be an integer generation number."
            )

    blob = storage_client.bucket(bucket).blob(
        file_path, generation=generation
    )

    if local_path is None:
        # Download to a temp directory, preserving the filename
        filename = Path(file_path).name
        local_path = os.path.join(tempfile.gettempdir(), filename)

    blob.download_to_filename(local_path)
    return local_path


def upload_gcs_file(
    bucket: str,
    file_path: str,
    local_path: str,
    version_metadata: str | None = None,
) -> str:
    """
    Upload a file to Google Cloud Storage.

    Args:
        bucket: The GCS bucket name.
        file_path: The path to upload to within the bucket.
        local_path: The local path of the file to upload.
        version_metadata: Optional version string to store in blob metadata.

    Returns:
        The gs:// URL of the uploaded file.

    Raises:
        ImportError: If google-cloud-storage is not installed.
    """
    try:
        from google.cloud import storage
    except ImportError:
        raise ImportError(
            "google-cloud-storage is required for gs:// URLs. "
            "Install it with: pip install google-cloud-storage"
        )

    credentials, project_id = _get_gcs_credentials()

    storage_client = storage.Client(
        credentials=credentials, project=project_id
    )

    blob = storage_client.bucket(bucket).blob(file_path)
    blob.upload_from_filename(local_path)

    # Metadata is attached after the upload so it applies to the newly
    # created generation; patch() persists the change server-side.
    if version_metadata:
        blob.metadata = {"version": version_metadata}
        blob.patch()

    return f"gs://{bucket}/{file_path}"


def _get_gcs_credentials():
    """
    Get GCS credentials, prompting for service account key path if needed.

    Returns:
        Tuple of (credentials, project_id)
    """
    try:
        import google.auth
        from google.auth import exceptions as auth_exceptions
    except ImportError:
        raise ImportError(
            "google-cloud-storage is required for gs:// URLs. "
            "Install it with: pip install google-cloud-storage"
        )

    # Prefer application-default credentials (gcloud auth, an attached
    # service account, an already-set GOOGLE_APPLICATION_CREDENTIALS, ...).
    try:
        return google.auth.default()
    except auth_exceptions.DefaultCredentialsError:
        pass

    # No ambient credentials: fall back to an explicit key file, asking
    # the user for its path when the environment variable is unset.
    if os.environ.get("GOOGLE_APPLICATION_CREDENTIALS") is None:
        entered_path = getpass(
            "Enter path to GCS service account key JSON "
            "(or set GOOGLE_APPLICATION_CREDENTIALS): "
        )
        os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = entered_path

    # Retry now that GOOGLE_APPLICATION_CREDENTIALS points at a key file.
    return google.auth.default()
30 changes: 30 additions & 0 deletions policyengine_core/tools/hugging_face.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,36 @@
warnings.simplefilter("ignore")


def parse_hf_url(url: str) -> tuple[str, str, str, str | None]:
"""
Parse a Hugging Face URL into components.

Args:
url: URL in format hf://owner/repo/path/to/file[@version]

Returns:
Tuple of (owner, repo, file_path, version)
version is None if not specified
"""
parts = url.split("/")[2:]

if len(parts) < 3:
raise ValueError(
f"Invalid hf:// URL format: {url}. "
"Expected format: hf://owner/repo/path/to/file[@version]"
)

owner = parts[0]
repo = parts[1]
file_path = "/".join(parts[2:])

version = None
if "@" in file_path:
file_path, version = file_path.rsplit("@", 1)

return owner, repo, file_path, version


def download_huggingface_dataset(
repo: str,
repo_filename: str,
Expand Down
63 changes: 63 additions & 0 deletions tests/core/tools/test_google_cloud.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
import pytest
from policyengine_core.tools.google_cloud import parse_gs_url


class TestParseGsUrl:
    """Unit tests for parse_gs_url URL parsing."""

    def test_basic_url(self):
        parsed = parse_gs_url("gs://my-bucket/file.h5")
        assert parsed == ("my-bucket", "file.h5", None)

    def test_subdirectory_url(self):
        parsed = parse_gs_url("gs://my-bucket/data/2024/file.h5")
        assert parsed == ("my-bucket", "data/2024/file.h5", None)

    def test_url_with_version(self):
        _, path, ver = parse_gs_url("gs://my-bucket/file.h5@12345")
        assert path == "file.h5"
        assert ver == "12345"

    def test_subdirectory_with_version(self):
        parsed = parse_gs_url("gs://my-bucket/path/to/file.h5@67890")
        assert parsed == ("my-bucket", "path/to/file.h5", "67890")

    def test_deep_subdirectory(self):
        _, path, _ = parse_gs_url("gs://my-bucket/a/b/c/d/e/file.h5")
        assert path == "a/b/c/d/e/file.h5"

    def test_invalid_url_no_gs_prefix(self):
        with pytest.raises(ValueError, match="Invalid gs:// URL format"):
            parse_gs_url("s3://my-bucket/file.h5")

    def test_invalid_url_no_file(self):
        with pytest.raises(ValueError, match="Invalid gs:// URL format"):
            parse_gs_url("gs://my-bucket")

    def test_invalid_url_no_file_with_slash(self):
        with pytest.raises(ValueError, match="Invalid gs:// URL format"):
            parse_gs_url("gs://my-bucket/")

    def test_bucket_with_dashes_and_dots(self):
        bucket, path, _ = parse_gs_url(
            "gs://my-project.appspot.com/data/file.h5"
        )
        assert bucket == "my-project.appspot.com"
        assert path == "data/file.h5"

    def test_version_in_middle_of_path(self):
        # Only the final "@" is the version separator; earlier "@"s
        # remain part of the path.
        _, path, ver = parse_gs_url(
            "gs://my-bucket/path@weird/file.h5@v1.0"
        )
        assert path == "path@weird/file.h5"
        assert ver == "v1.0"
Loading