Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 42 additions & 0 deletions flytekit/core/type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,48 @@ def get_batch_size(t: Type) -> Optional[int]:
return None


class FileDownloadConfig:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Follows same pattern as BatchSize:

class BatchSize:

"""
This is used to annotate a FlyteFile when we want to download the file with a specific extension. For example,

```python
# ContainerTask
def t1(file: Annotated[FlyteFile, FileDownloadConfig(file_extension="csv")]):
... # copilot downloads the file to e.g. /inputs/file.csv

# versus...

def t1(file: FlyteFile["csv"]):
... # copilot downloads the file to e.g. /inputs/file
```

file_extension: (Default is "") The file extension (e.g. "csv", "parquet") to use during copilot download.
enable_legacy_filename: (Default is False) When true and file_extension is non-empty, the copilot download phase
writes the blob to both the full path (with extension) and the old path (without extension), preserving backward compatibility for
workflows with tasks that may read from both.
"""

def __init__(self, file_extension: str = "", enable_legacy_filename: bool = False):
    # Keep both settings on private attributes; the properties of the same
    # names defined on this class hand them back unchanged.
    self._file_extension = file_extension
    self._enable_legacy_filename = enable_legacy_filename

@property
def file_extension(self) -> str:
    """Return the file extension given to the constructor ("" by default)."""
    return self._file_extension

@property
def enable_legacy_filename(self) -> bool:
    """Return the legacy-filename flag given to the constructor (False by default)."""
    return self._enable_legacy_filename


def get_file_download_config(t: Type) -> Optional[FileDownloadConfig]:
    """
    Return the first ``FileDownloadConfig`` found among the arguments of an
    annotated type, or ``None`` when ``t`` is not annotated or carries no such
    config.
    """
    if not is_annotated(t):
        return None
    # First matching annotation argument wins; fall back to None when absent.
    return next((meta for meta in get_args(t) if isinstance(meta, FileDownloadConfig)), None)


def modify_literal_uris(lit: Literal):
"""
Modifies the literal object recursively to replace the URIs with the native paths in case they are of
Expand Down
40 changes: 37 additions & 3 deletions flytekit/models/core/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,19 @@ class BlobDimensionality(object):
SINGLE = _types_pb2.BlobType.SINGLE
MULTIPART = _types_pb2.BlobType.MULTIPART

def __init__(self, format, dimensionality):
def __init__(self, format, dimensionality, file_extension="", enable_legacy_filename=False):
"""
:param Text format: A string describing the format of the underlying blob data.
:param int dimensionality: An integer from BlobType.BlobDimensionality enum
:param Text file_extension: The file extension (e.g. "csv", "parquet") to use
during copilot download. Empty by default.
:param bool enable_legacy_filename: When True and file_extension is set, the copilot
download phase writes the blob to both the extended path and the base path.
"""
self._format = format
self._dimensionality = dimensionality
self._file_extension = file_extension
self._enable_legacy_filename = enable_legacy_filename

@property
def format(self):
Expand All @@ -62,16 +68,44 @@ def dimensionality(self):
"""
return self._dimensionality

@property
def file_extension(self):
"""
The file extension (e.g. "csv", "parquet") to use during copilot download.
Default is "", which means no extension is appended.
:rtype: Text
"""
return self._file_extension

@property
def enable_legacy_filename(self):
"""
When True and file_extension is set, the copilot download writes the blob to
both the full path (with extension) and the old path (without extension).
:rtype: bool
"""
return self._enable_legacy_filename

def to_flyte_idl(self):
"""
:rtype: flyteidl.core.types_pb2.BlobType
"""
return _types_pb2.BlobType(format=self.format, dimensionality=self.dimensionality)
return _types_pb2.BlobType(
format=self.format,
dimensionality=self.dimensionality,
file_extension=self._file_extension,
enable_legacy_filename=self._enable_legacy_filename,
)

@classmethod
def from_flyte_idl(cls, proto):
"""
:param flyteidl.core.types_pb2.BlobType proto:
:rtype: BlobType
"""
return cls(format=proto.format, dimensionality=proto.dimensionality)
return cls(
format=proto.format,
dimensionality=proto.dimensionality,
file_extension=proto.file_extension,
enable_legacy_filename=proto.enable_legacy_filename,
)
35 changes: 31 additions & 4 deletions flytekit/types/file/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
AsyncTypeTransformer,
TypeEngine,
TypeTransformerFailedError,
get_file_download_config,
get_underlying_type,
)
from flytekit.exceptions.user import FlyteAssertion
Expand Down Expand Up @@ -477,8 +478,26 @@ def get_format(t: typing.Union[typing.Type[FlyteFile], os.PathLike]) -> str:
return ""
return cast(FlyteFile, t).extension()

def _blob_type(self, format: str) -> BlobType:
return BlobType(format=format, dimensionality=BlobType.BlobDimensionality.SINGLE)
@staticmethod
def get_file_extension(t: typing.Union[typing.Type[FlyteFile], os.PathLike]) -> str:
    """
    Return the ``file_extension`` of the ``FileDownloadConfig`` attached to ``t``.

    Yields ``""`` when ``t`` is ``os.PathLike``, when no config is attached, or
    when the configured extension is empty/falsy.
    """
    # os.PathLike is never inspected for a download config.
    config = None if t is os.PathLike else get_file_download_config(t)
    return (config.file_extension or "") if config is not None else ""

@staticmethod
def get_enable_legacy_filename(t: typing.Union[typing.Type[FlyteFile], os.PathLike]) -> bool:
    """
    Return the ``enable_legacy_filename`` flag of the ``FileDownloadConfig``
    attached to ``t``.

    :param t: The Python type to inspect (``os.PathLike`` or a possibly annotated FlyteFile type).
    :return: ``False`` when ``t`` is ``os.PathLike`` or carries no config; otherwise the
        configured flag coerced to ``bool``.
    """
    if t is os.PathLike:
        return False
    file_download_config = get_file_download_config(t)
    if file_download_config is None:
        return False
    # Coerce so the declared bool return type holds even if the config was
    # built with None or another truthy/falsy non-bool value.
    return bool(file_download_config.enable_legacy_filename)

def _blob_type(self, format: str, file_extension: str = "", enable_legacy_filename: bool = False) -> BlobType:
    """
    Build a ``BlobType`` with SINGLE dimensionality from the given format,
    file extension and legacy-filename flag.
    """
    single = BlobType.BlobDimensionality.SINGLE
    return BlobType(
        format=format,
        dimensionality=single,
        file_extension=file_extension,
        enable_legacy_filename=enable_legacy_filename,
    )

def assert_type(
self, t: typing.Union[typing.Type[FlyteFile], os.PathLike], v: typing.Union[FlyteFile, os.PathLike, str]
Expand All @@ -491,7 +510,11 @@ def assert_type(
)

def get_literal_type(self, t: typing.Union[typing.Type[FlyteFile], os.PathLike]) -> LiteralType:
return LiteralType(blob=self._blob_type(format=FlyteFilePathTransformer.get_format(t)))
return LiteralType(blob=self._blob_type(
format=FlyteFilePathTransformer.get_format(t),
file_extension=FlyteFilePathTransformer.get_file_extension(t),
enable_legacy_filename=FlyteFilePathTransformer.get_enable_legacy_filename(t),
))

def get_mime_type_from_extension(self, extension: str) -> typing.Union[str, typing.Sequence[str]]:
extension_to_mime_type = {
Expand Down Expand Up @@ -565,7 +588,11 @@ async def async_to_literal(
raise ValueError(f"Incorrect type {python_type}, must be either a FlyteFile or os.PathLike")

# information used by all cases
meta = BlobMetadata(type=self._blob_type(format=FlyteFilePathTransformer.get_format(python_type)))
meta = BlobMetadata(type=self._blob_type(
format=FlyteFilePathTransformer.get_format(python_type),
file_extension=FlyteFilePathTransformer.get_file_extension(python_type),
enable_legacy_filename=FlyteFilePathTransformer.get_enable_legacy_filename(python_type),
))

if isinstance(python_val, FlyteFile):
# Cast the source path to str type to avoid error raised when the source path is used as the blob uri,
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ dependencies = [
"diskcache>=5.2.1",
"docker>=4.0.0",
"docstring-parser>=0.9.0",
"flyteidl>=1.16.1,<2.0.0a0",
"flyteidl @ git+https://github.com/dominodatalab/flyteidl.git@af517f6",
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bump this after flyteorg/flyte#7009 merges

"fsspec>=2023.3.0",
# Bug in 2025.5.0, 2025.5.0post1 https://github.com/fsspec/gcsfs/issues/687
# Bug in 2024.2.0 https://github.com/fsspec/gcsfs/pull/643
Expand Down
14 changes: 13 additions & 1 deletion tests/flytekit/unit/core/test_flyte_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from flytekit.core.hash import HashMethod
from flytekit.core.launch_plan import LaunchPlan
from flytekit.core.task import task
from flytekit.core.type_engine import TypeEngine
from flytekit.core.type_engine import FileDownloadConfig, TypeEngine
from flytekit.core.workflow import workflow
from flytekit.models.core.types import BlobType
from flytekit.models.literals import LiteralMap, Blob, BlobMetadata
Expand Down Expand Up @@ -764,6 +764,18 @@ def test_headers():
assert len(FlyteFilePathTransformer.get_additional_headers(".gz")) == 1


def test_transform_flytefile_with_file_download_config():
    """FileDownloadConfig annotations must surface on the generated blob literal type."""
    transformer = FlyteFilePathTransformer()

    # Without a FileDownloadConfig the blob type keeps the defaults.
    csv_file_no_config = FlyteFile["csv"]
    lt = transformer.get_literal_type(csv_file_no_config)
    assert lt.blob.file_extension == ""
    assert lt.blob.enable_legacy_filename is False

    # With a FileDownloadConfig both settings are carried onto the blob type.
    legacy_file = Annotated[FlyteFile["csv"], FileDownloadConfig(file_extension="csv", enable_legacy_filename=True)]
    lt = transformer.get_literal_type(legacy_file)
    assert lt.blob.file_extension == "csv"
    assert lt.blob.enable_legacy_filename is True


def test_new_remote_file():
nf = FlyteFile.new_remote_file(name="foo.txt")
assert isinstance(nf, FlyteFile)
Expand Down
22 changes: 22 additions & 0 deletions tests/flytekit/unit/models/core/test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,33 @@ def test_blob_type():
)
assert o.format == "csv"
assert o.dimensionality == _types.BlobType.BlobDimensionality.SINGLE
assert o.file_extension == ""
assert o.enable_legacy_filename == False

o2 = _types.BlobType.from_flyte_idl(o.to_flyte_idl())
assert o == o2
assert o2.format == "csv"
assert o2.dimensionality == _types.BlobType.BlobDimensionality.SINGLE
assert o2.file_extension == ""
assert o2.enable_legacy_filename == False

o = _types.BlobType(
format="csv",
dimensionality=_types.BlobType.BlobDimensionality.SINGLE,
file_extension="csv",
enable_legacy_filename=True,
)
assert o.format == "csv"
assert o.dimensionality == _types.BlobType.BlobDimensionality.SINGLE
assert o.file_extension == "csv"
assert o.enable_legacy_filename == True

o2 = _types.BlobType.from_flyte_idl(o.to_flyte_idl())
assert o == o2
assert o2.format == "csv"
assert o2.dimensionality == _types.BlobType.BlobDimensionality.SINGLE
assert o2.file_extension == "csv"
assert o2.enable_legacy_filename == True


def test_enum_type():
Expand Down
Loading