Skip to content

Commit 766405f

Browse files
Merge pull request #214 from fastlabel/feature/download-dataset-objects
download と get の引数を同じにする
2 parents 723e2a2 + 2acff17 commit 766405f

File tree

4 files changed

+79
-31
lines changed

4 files changed

+79
-31
lines changed

README.md

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2639,7 +2639,10 @@ dataset_objects = client.get_dataset_objects(
26392639
dataset="YOUR_DATASET_NAME",
26402640
version="latest", # default is "latest"
26412641
tags=["cat"],
2642-
licenses=["MIT"]
2642+
licenses=["fastlabel"],
2643+
types=["train", "valid"], # choices are "train", "valid", "test" and "none" (Optional)
2644+
offset=0, # default is 0 (Optional)
2645+
limit=1000, # default is 1000, and must be less than 1000 (Optional)
26432646
)
26442647
```
26452648

@@ -2663,6 +2666,9 @@ client.download_dataset_objects(
26632666
version="latest", # default is "latest"
26642667
tags=["cat"],
26652668
types=["train", "valid"], # choices are "train", "valid", "test" and "none" (Optional)
2669+
licenses=["fastlabel"],
2670+
offset=0, # default is 0 (Optional)
2671+
limit=1000, # default is 1000, and must be less than 1000 (Optional)
26662672
)
26672673
```
26682674

fastlabel/__init__.py

Lines changed: 48 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828

2929
from .api import Api
3030
from .exceptions import FastLabelInvalidException
31+
from .query import DatasetObjectGetQuery
3132

3233
logger = logging.getLogger(__name__)
3334
logging.basicConfig(
@@ -3961,6 +3962,7 @@ def get_dataset_objects(
39613962
tags: Optional[List[str]] = None,
39623963
licenses: Optional[List[str]] = None,
39633964
revision_id: str = None,
3965+
types: Optional[List[Union[str, DatasetObjectType]]] = None,
39643966
offset: int = 0,
39653967
limit: int = 1000,
39663968
) -> list:
@@ -3973,6 +3975,31 @@ def get_dataset_objects(
39733975
revision_id is dataset rebision (Optional).
39743976
Only use specify one of revision_id or version.
39753977
"""
3978+
endpoint = "dataset-objects-v2"
3979+
types = [DatasetObjectType.create(type_) for type_ in types or []]
3980+
params = self._prepare_params(
3981+
dataset=dataset,
3982+
version=version,
3983+
tags=tags,
3984+
licenses=licenses,
3985+
revision_id=revision_id,
3986+
types=types,
3987+
offset=offset,
3988+
limit=limit,
3989+
)
3990+
return self.api.get_request(endpoint, params=params)
3991+
3992+
def _prepare_params(
3993+
self,
3994+
dataset: str,
3995+
offset: int,
3996+
limit: int,
3997+
version: str,
3998+
revision_id: str,
3999+
tags: Optional[List[str]],
4000+
licenses: Optional[List[str]],
4001+
types: Optional[List[DatasetObjectType]],
4002+
) -> DatasetObjectGetQuery:
39764003
if version and revision_id:
39774004
raise FastLabelInvalidException(
39784005
"only use specify one of revisionId or version.", 400
@@ -3981,56 +4008,47 @@ def get_dataset_objects(
39814008
raise FastLabelInvalidException(
39824009
"Limit must be less than or equal to 1000.", 422
39834010
)
3984-
endpoint = "dataset-objects-v2"
3985-
params = {"dataset": dataset, "offset": offset, "limit": limit}
4011+
params: DatasetObjectGetQuery = {
4012+
"dataset": dataset,
4013+
"offset": offset,
4014+
"limit": limit,
4015+
}
39864016
if revision_id:
39874017
params["revisionId"] = revision_id
39884018
if version:
39894019
params["version"] = version
3990-
3991-
tags = tags or []
39924020
if tags:
39934021
params["tags"] = tags
39944022
if licenses:
39954023
params["licenses"] = licenses
3996-
return self.api.get_request(endpoint, params=params)
4024+
if types:
4025+
params["types"] = [t.value for t in types]
4026+
return params
39974027

39984028
def download_dataset_objects(
39994029
self,
40004030
dataset: str,
40014031
path: str,
40024032
version: str = "",
4033+
revision_id: str = "",
40034034
tags: Optional[List[str]] = None,
4035+
licenses: Optional[List[str]] = None,
40044036
types: Optional[List[Union[str, DatasetObjectType]]] = None,
40054037
offset: int = 0,
40064038
limit: int = 1000,
40074039
):
40084040
endpoint = "dataset-objects-v2/signed-urls"
4009-
if limit > 1000:
4010-
raise FastLabelInvalidException(
4011-
"Limit must be less than or equal to 1000.", 422
4012-
)
4013-
params = {"dataset": dataset, "offset": offset, "limit": limit}
4014-
if version:
4015-
params["version"] = version
4016-
if tags:
4017-
params["tags"] = tags
4018-
if types:
4019-
try:
4020-
types = list(
4021-
map(
4022-
lambda t: t
4023-
if isinstance(t, DatasetObjectType)
4024-
else DatasetObjectType(t),
4025-
types,
4026-
)
4027-
)
4028-
except ValueError:
4029-
raise FastLabelInvalidException(
4030-
f"types must be {[k for k in DatasetObjectType.__members__.keys()]}.",
4031-
422,
4032-
)
4033-
params["types"] = [t.value for t in types]
4041+
types = [DatasetObjectType.create(type_) for type_ in types or []]
4042+
params = self._prepare_params(
4043+
dataset=dataset,
4044+
offset=offset,
4045+
limit=limit,
4046+
version=version,
4047+
revision_id=revision_id,
4048+
tags=tags,
4049+
types=types,
4050+
licenses=licenses,
4051+
)
40344052

40354053
response = self.api.get_request(endpoint, params=params)
40364054

fastlabel/const.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -254,3 +254,15 @@ class DatasetObjectType(Enum):
254254
train = "train"
255255
valid = "valid"
256256
test = "test"
257+
258+
@classmethod
259+
def create(cls, value: "str | DatasetObjectType") -> "DatasetObjectType":
260+
if isinstance(value, cls):
261+
return value
262+
try:
263+
return cls(value)
264+
except ValueError:
265+
raise ValueError(
266+
f"Invalid DatasetObjectType: {value}. "
267+
f"types must be {[k for k in DatasetObjectType.__members__.keys()]}"
268+
)

fastlabel/query.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
from typing import List, Optional, TypedDict
2+
3+
4+
class DatasetObjectGetQuery(TypedDict, total=False):
5+
dataset: str
6+
version: str
7+
revisionId: str
8+
tags: Optional[List[str]]
9+
licenses: Optional[List[str]]
10+
types: Optional[List[str]]
11+
offset: int
12+
limit: int

0 commit comments

Comments
 (0)