Skip to content

Commit 8698c06

Browse files
authored
Merge pull request #406 from PolicyEngine/feature/hf-subdirectory-support
Add subdirectory support for hf:// URLs and gs:// scheme for GCS
2 parents f4ba960 + bf95770 commit 8698c06

7 files changed

Lines changed: 364 additions & 12 deletions

File tree

changelog_entry.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
- bump: minor
2+
changes:
3+
added:
4+
- Subdirectory support for hf:// URLs (e.g., hf://owner/repo/path/to/file.h5).
5+
- Google Cloud Storage support with gs:// URL scheme (e.g., gs://bucket/path/to/file.h5).

policyengine_core/data/dataset.py

Lines changed: 32 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,11 @@
88
import os
99
import tempfile
1010
from policyengine_core.tools.hugging_face import *
11+
from policyengine_core.tools.google_cloud import (
12+
parse_gs_url,
13+
download_gcs_file,
14+
upload_gcs_file,
15+
)
1116
import sys
1217
from policyengine_core.tools.win_file_manager import WindowsAtomicFileManager
1318

@@ -353,9 +358,22 @@ def download(self, url: str = None, version: str = None) -> None:
353358
f"File {file_path} not found in release {release_tag} of {org}/{repo}."
354359
)
355360
elif url.startswith("hf://"):
356-
owner_name, model_name, file_name = url.split("/")[2:]
361+
owner_name, model_name, file_name, hf_version = parse_hf_url(url)
357362
self.download_from_huggingface(
358-
owner_name, model_name, file_name, version
363+
owner_name, model_name, file_name, hf_version or version
364+
)
365+
return
366+
elif url.startswith("gs://"):
367+
bucket, file_path, gs_version = parse_gs_url(url)
368+
print(
369+
f"Downloading from GCS gs://{bucket}/{file_path}",
370+
file=sys.stderr,
371+
)
372+
downloaded_path = download_gcs_file(
373+
bucket=bucket,
374+
file_path=file_path,
375+
version=gs_version or version,
376+
local_path=str(self.file_path),
359377
)
360378
return
361379
else:
@@ -386,8 +404,19 @@ def upload(self, url: str = None):
386404
url = self.url
387405

388406
if url.startswith("hf://"):
389-
owner_name, model_name, file_name = url.split("/")[2:]
407+
owner_name, model_name, file_name, _ = parse_hf_url(url)
390408
self.upload_to_huggingface(owner_name, model_name, file_name)
409+
elif url.startswith("gs://"):
410+
bucket, file_path, _ = parse_gs_url(url)
411+
print(
412+
f"Uploading to GCS gs://{bucket}/{file_path}",
413+
file=sys.stderr,
414+
)
415+
upload_gcs_file(
416+
bucket=bucket,
417+
file_path=file_path,
418+
local_path=str(self.file_path),
419+
)
391420

392421
def remove(self):
393422
"""Removes the dataset from disk."""

policyengine_core/simulations/simulation.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,10 @@
2323
)
2424
import random
2525
from policyengine_core.tools.hugging_face import *
26+
from policyengine_core.tools.google_cloud import (
27+
parse_gs_url,
28+
download_gcs_file,
29+
)
2630

2731
import json
2832

@@ -160,17 +164,19 @@ def __init__(
160164
if dataset is not None:
161165
if isinstance(dataset, str):
162166
if "hf://" in dataset:
163-
owner, repo, filename = dataset.split("/")[-3:]
164-
if "@" in filename:
165-
version = filename.split("@")[-1]
166-
filename = filename.split("@")[0]
167-
else:
168-
version = None
167+
owner, repo, filename, version = parse_hf_url(dataset)
169168
dataset = download_huggingface_dataset(
170169
repo=f"{owner}/{repo}",
171170
repo_filename=filename,
172171
version=version,
173172
)
173+
elif "gs://" in dataset:
174+
bucket, file_path, version = parse_gs_url(dataset)
175+
dataset = download_gcs_file(
176+
bucket=bucket,
177+
file_path=file_path,
178+
version=version,
179+
)
174180
datasets_by_name = {
175181
dataset.name: dataset for dataset in self.datasets
176182
}
Lines changed: 168 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,168 @@
1+
import os
2+
import tempfile
3+
from getpass import getpass
4+
from pathlib import Path
5+
6+
7+
def parse_gs_url(url: str) -> tuple[str, str, str | None]:
8+
"""
9+
Parse a Google Cloud Storage URL into components.
10+
11+
Args:
12+
url: URL in format gs://bucket/path/to/file[@version]
13+
14+
Returns:
15+
Tuple of (bucket, file_path, version)
16+
version is None if not specified
17+
"""
18+
if not url.startswith("gs://"):
19+
raise ValueError(
20+
f"Invalid gs:// URL format: {url}. "
21+
"Expected format: gs://bucket/path/to/file[@version]"
22+
)
23+
24+
# Remove the "gs://" prefix
25+
path = url[5:]
26+
parts = path.split("/", 1)
27+
28+
if len(parts) < 2 or not parts[1]:
29+
raise ValueError(
30+
f"Invalid gs:// URL format: {url}. "
31+
"Expected format: gs://bucket/path/to/file[@version]"
32+
)
33+
34+
bucket = parts[0]
35+
file_path = parts[1]
36+
37+
version = None
38+
if "@" in file_path:
39+
file_path, version = file_path.rsplit("@", 1)
40+
41+
return bucket, file_path, version
42+
43+
44+
def download_gcs_file(
    bucket: str,
    file_path: str,
    version: str = None,
    local_path: str = None,
):
    """
    Download a file from Google Cloud Storage.

    Args:
        bucket: The GCS bucket name.
        file_path: The path to the file within the bucket.
        version: The generation/version of the file (optional). Must be
            parseable as an int — GCS object generations are integers.
        local_path: The local path to save the file to. If None, downloads
            into a fresh temporary directory (preserving the filename), so
            concurrent downloads of same-named files cannot clobber each
            other.

    Returns:
        The local path where the file was saved.

    Raises:
        ImportError: If google-cloud-storage is not installed.
    """
    try:
        from google.cloud import storage
    except ImportError:
        raise ImportError(
            "google-cloud-storage is required for gs:// URLs. "
            "Install it with: pip install google-cloud-storage"
        )

    credentials, project_id = _get_gcs_credentials()

    storage_client = storage.Client(
        credentials=credentials, project=project_id
    )

    bucket_obj = storage_client.bucket(bucket)

    # Construct the blob once, pinning it to a specific generation when a
    # version was requested (the original built an unversioned blob first
    # and then rebuilt it, discarding the first object).
    if version:
        blob = bucket_obj.blob(file_path, generation=int(version))
    else:
        blob = bucket_obj.blob(file_path)

    if local_path is None:
        # Use a unique temp directory rather than the shared temp dir with
        # a fixed filename, which risks collisions/overwrites across runs.
        filename = Path(file_path).name
        local_path = os.path.join(tempfile.mkdtemp(), filename)

    blob.download_to_filename(local_path)
    return local_path
90+
91+
92+
def upload_gcs_file(
    bucket: str,
    file_path: str,
    local_path: str,
    version_metadata: str = None,
):
    """
    Upload a file to Google Cloud Storage.

    Args:
        bucket: The GCS bucket name.
        file_path: The path to upload to within the bucket.
        local_path: The local path of the file to upload.
        version_metadata: Optional version string to store in blob metadata.

    Returns:
        The gs:// URL of the uploaded object.

    Raises:
        ImportError: If google-cloud-storage is not installed.
    """
    try:
        from google.cloud import storage
        import google.auth
    except ImportError:
        raise ImportError(
            "google-cloud-storage is required for gs:// URLs. "
            "Install it with: pip install google-cloud-storage"
        )

    credentials, project_id = _get_gcs_credentials()
    client = storage.Client(credentials=credentials, project=project_id)

    target_blob = client.bucket(bucket).blob(file_path)
    target_blob.upload_from_filename(local_path)

    if version_metadata:
        # Record the version string as custom metadata on the object and
        # persist it with a PATCH request.
        target_blob.metadata = {"version": version_metadata}
        target_blob.patch()

    return f"gs://{bucket}/{file_path}"
131+
132+
133+
def _get_gcs_credentials():
    """
    Get GCS credentials, prompting for a service account key path if needed.

    Returns:
        Tuple of (credentials, project_id)

    Raises:
        ImportError: If google-auth (shipped with google-cloud-storage) is
            not installed.
    """
    try:
        import google.auth
        from google.auth import exceptions as auth_exceptions
    except ImportError:
        raise ImportError(
            "google-cloud-storage is required for gs:// URLs. "
            "Install it with: pip install google-cloud-storage"
        )

    # Prefer application-default credentials (gcloud auth, attached service
    # account, GOOGLE_APPLICATION_CREDENTIALS already set, ...).
    try:
        return google.auth.default()
    except auth_exceptions.DefaultCredentialsError:
        # No default credentials available; fall through to an explicit
        # key file below.
        pass

    key_path = os.environ.get("GOOGLE_APPLICATION_CREDENTIALS")
    if key_path is None:
        # NOTE(review): getpass hides the typed path and will block in
        # non-interactive environments — confirm this is intended for a
        # (non-secret) file path.
        key_path = getpass(
            "Enter path to GCS service account key JSON "
            "(or set GOOGLE_APPLICATION_CREDENTIALS): "
        )
        os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = key_path

    # Retry now that GOOGLE_APPLICATION_CREDENTIALS is populated; lets any
    # credential error propagate to the caller.
    return google.auth.default()

policyengine_core/tools/hugging_face.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,36 @@
1313
warnings.simplefilter("ignore")
1414

1515

16+
def parse_hf_url(url: str) -> tuple[str, str, str, str | None]:
17+
"""
18+
Parse a Hugging Face URL into components.
19+
20+
Args:
21+
url: URL in format hf://owner/repo/path/to/file[@version]
22+
23+
Returns:
24+
Tuple of (owner, repo, file_path, version)
25+
version is None if not specified
26+
"""
27+
parts = url.split("/")[2:]
28+
29+
if len(parts) < 3:
30+
raise ValueError(
31+
f"Invalid hf:// URL format: {url}. "
32+
"Expected format: hf://owner/repo/path/to/file[@version]"
33+
)
34+
35+
owner = parts[0]
36+
repo = parts[1]
37+
file_path = "/".join(parts[2:])
38+
39+
version = None
40+
if "@" in file_path:
41+
file_path, version = file_path.rsplit("@", 1)
42+
43+
return owner, repo, file_path, version
44+
45+
1646
def download_huggingface_dataset(
1747
repo: str,
1848
repo_filename: str,
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
import pytest
2+
from policyengine_core.tools.google_cloud import parse_gs_url
3+
4+
5+
class TestParseGsUrl:
    """Unit tests for parse_gs_url covering valid and invalid gs:// URLs."""

    def test_basic_url(self):
        result = parse_gs_url("gs://my-bucket/file.h5")
        assert result == ("my-bucket", "file.h5", None)

    def test_subdirectory_url(self):
        result = parse_gs_url("gs://my-bucket/data/2024/file.h5")
        assert result == ("my-bucket", "data/2024/file.h5", None)

    def test_url_with_version(self):
        _, file_path, version = parse_gs_url("gs://my-bucket/file.h5@12345")
        assert file_path == "file.h5"
        assert version == "12345"

    def test_subdirectory_with_version(self):
        result = parse_gs_url("gs://my-bucket/path/to/file.h5@67890")
        assert result == ("my-bucket", "path/to/file.h5", "67890")

    def test_deep_subdirectory(self):
        _, file_path, _ = parse_gs_url("gs://my-bucket/a/b/c/d/e/file.h5")
        assert file_path == "a/b/c/d/e/file.h5"

    def test_invalid_url_no_gs_prefix(self):
        with pytest.raises(ValueError, match="Invalid gs:// URL format"):
            parse_gs_url("s3://my-bucket/file.h5")

    def test_invalid_url_no_file(self):
        with pytest.raises(ValueError, match="Invalid gs:// URL format"):
            parse_gs_url("gs://my-bucket")

    def test_invalid_url_no_file_with_slash(self):
        with pytest.raises(ValueError, match="Invalid gs:// URL format"):
            parse_gs_url("gs://my-bucket/")

    def test_bucket_with_dashes_and_dots(self):
        result = parse_gs_url("gs://my-project.appspot.com/data/file.h5")
        assert result == ("my-project.appspot.com", "data/file.h5", None)

    def test_version_in_middle_of_path(self):
        # Only the last "@" should act as the version separator; an "@" in
        # a directory name must not be mistaken for one.
        _, file_path, version = parse_gs_url(
            "gs://my-bucket/path@weird/file.h5@v1.0"
        )
        assert file_path == "path@weird/file.h5"
        assert version == "v1.0"

0 commit comments

Comments
 (0)