From 8ec4ae44e1a3c9278691143b9cd18337f69d19e3 Mon Sep 17 00:00:00 2001
From: ddl-rliu
Date: Thu, 5 Mar 2026 14:37:04 -0800
Subject: [PATCH] Add FileDownloadConfig annotation for FlyteFile inputs

Port new BlobType fields file_extension and enable_legacy_filename to flytekit.

FlyteFile inputs can be annotated with the FileDownloadConfig annotation to configure the file extension to use during the copilot download phase.

e.g.

```python
def t1(file: Annotated[FlyteFile, FileDownloadConfig(file_extension="csv")]):
    ... # copilot downloads the file to e.g. /inputs/file.csv

versus...

def t1(file: FlyteFile["csv"]):
    ... # copilot downloads the file to e.g. /inputs/file
```
---
 flytekit/core/type_engine.py                  | 42 +++++++++++++++++++
 flytekit/models/core/types.py                 | 40 ++++++++++++++++--
 flytekit/types/file/file.py                   | 35 ++++++++++++++--
 pyproject.toml                                |  2 +-
 tests/flytekit/unit/core/test_flyte_file.py   | 14 ++++++-
 tests/flytekit/unit/models/core/test_types.py | 22 ++++++++++
 6 files changed, 146 insertions(+), 9 deletions(-)

diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py
index 9993c98479..d2f9b0a94a 100644
--- a/flytekit/core/type_engine.py
+++ b/flytekit/core/type_engine.py
@@ -105,6 +105,48 @@ def get_batch_size(t: Type) -> Optional[int]:
     return None
 
 
+class FileDownloadConfig:
+    """
+    This is used to annotate a FlyteFile when we want to download the file with a specific extension. For example,
+
+    ```python
+    # ContainerTask
+    def t1(file: Annotated[FlyteFile, FileDownloadConfig(file_extension="csv")]):
+        ... # copilot downloads the file to e.g. /inputs/file.csv
+
+    versus...
+
+    def t1(file: FlyteFile["csv"]):
+        ... # copilot downloads the file to e.g. /inputs/file
+    ```
+
+    file_extension: (Default is "") The file extension (e.g. "csv", "parquet") to use during copilot download.
+    enable_legacy_filename: (Default is False) When true and file_extension is non-empty, the copilot download phase
+    writes the blob to both the full path (with extension) and the old path (without extension), preserving backward compatibility for
+    workflows with tasks that may read from both.
+    """
+
+    def __init__(self, file_extension: str = "", enable_legacy_filename: bool = False):
+        self._file_extension = file_extension
+        self._enable_legacy_filename = enable_legacy_filename
+
+    @property
+    def file_extension(self) -> str:
+        return self._file_extension
+
+    @property
+    def enable_legacy_filename(self) -> bool:
+        return self._enable_legacy_filename
+
+
+def get_file_download_config(t: Type) -> Optional[FileDownloadConfig]:
+    if is_annotated(t):
+        for arg in get_args(t):
+            if isinstance(arg, FileDownloadConfig):
+                return arg
+    return None
+
+
 def modify_literal_uris(lit: Literal):
     """
     Modifies the literal object recursively to replace the URIs with the native paths in case they are of
diff --git a/flytekit/models/core/types.py b/flytekit/models/core/types.py
index 4508961bbc..e01068f95e 100644
--- a/flytekit/models/core/types.py
+++ b/flytekit/models/core/types.py
@@ -38,13 +38,19 @@ class BlobDimensionality(object):
         SINGLE = _types_pb2.BlobType.SINGLE
         MULTIPART = _types_pb2.BlobType.MULTIPART
 
-    def __init__(self, format, dimensionality):
+    def __init__(self, format, dimensionality, file_extension="", enable_legacy_filename=False):
         """
         :param Text format: A string describing the format of the underlying blob data.
         :param int dimensionality: An integer from BlobType.BlobDimensionality enum
+        :param Text file_extension: The file extension (e.g. "csv", "parquet") to use
+            during copilot download. Empty by default, meaning no extension is appended.
+        :param bool enable_legacy_filename: When True and file_extension is set, the copilot
+            download phase writes the blob to both the extended path and the base path.
         """
         self._format = format
         self._dimensionality = dimensionality
+        self._file_extension = file_extension
+        self._enable_legacy_filename = enable_legacy_filename
 
     @property
     def format(self):
@@ -62,11 +68,34 @@ def dimensionality(self):
         """
         return self._dimensionality
 
+    @property
+    def file_extension(self):
+        """
+        The file extension (e.g. "csv", "parquet") to use during copilot download.
+        Default is "", which means no extension is appended.
+        :rtype: Text
+        """
+        return self._file_extension
+
+    @property
+    def enable_legacy_filename(self):
+        """
+        When True and file_extension is set, the copilot download writes the blob to
+        both the full path (with extension) and the old path (without extension).
+        :rtype: bool
+        """
+        return self._enable_legacy_filename
+
     def to_flyte_idl(self):
         """
         :rtype: flyteidl.core.types_pb2.BlobType
         """
-        return _types_pb2.BlobType(format=self.format, dimensionality=self.dimensionality)
+        return _types_pb2.BlobType(
+            format=self.format,
+            dimensionality=self.dimensionality,
+            file_extension=self._file_extension,
+            enable_legacy_filename=self._enable_legacy_filename,
+        )
 
     @classmethod
     def from_flyte_idl(cls, proto):
@@ -74,4 +103,9 @@ def from_flyte_idl(cls, proto):
         :param flyteidl.core.types_pb2.BlobType proto:
         :rtype: BlobType
         """
-        return cls(format=proto.format, dimensionality=proto.dimensionality)
+        return cls(
+            format=proto.format,
+            dimensionality=proto.dimensionality,
+            file_extension=proto.file_extension,
+            enable_legacy_filename=proto.enable_legacy_filename,
+        )
diff --git a/flytekit/types/file/file.py b/flytekit/types/file/file.py
index 47915add8e..bdf772b9fd 100644
--- a/flytekit/types/file/file.py
+++ b/flytekit/types/file/file.py
@@ -24,6 +24,7 @@
     AsyncTypeTransformer,
     TypeEngine,
     TypeTransformerFailedError,
+    get_file_download_config,
     get_underlying_type,
 )
 from flytekit.exceptions.user import FlyteAssertion
@@ -477,8 +478,26 @@ def get_format(t: typing.Union[typing.Type[FlyteFile], os.PathLike]) -> str:
             return ""
         return cast(FlyteFile, t).extension()
 
-    def _blob_type(self, format: str) -> BlobType:
-        return BlobType(format=format, dimensionality=BlobType.BlobDimensionality.SINGLE)
+    @staticmethod
+    def get_file_extension(t: typing.Union[typing.Type[FlyteFile], os.PathLike]) -> str:
+        if t is os.PathLike:
+            return ""
+        file_download_config = get_file_download_config(t)
+        if file_download_config is None:
+            return ""
+        return file_download_config.file_extension or ""
+
+    @staticmethod
+    def get_enable_legacy_filename(t: typing.Union[typing.Type[FlyteFile], os.PathLike]) -> bool:
+        if t is os.PathLike:
+            return False
+        file_download_config = get_file_download_config(t)
+        if file_download_config is None:
+            return False
+        return file_download_config.enable_legacy_filename or False
+
+    def _blob_type(self, format: str, file_extension: str = "", enable_legacy_filename: bool = False) -> BlobType:
+        return BlobType(format=format, dimensionality=BlobType.BlobDimensionality.SINGLE, file_extension=file_extension, enable_legacy_filename=enable_legacy_filename)
 
     def assert_type(
         self, t: typing.Union[typing.Type[FlyteFile], os.PathLike], v: typing.Union[FlyteFile, os.PathLike, str]
@@ -491,7 +510,11 @@ def assert_type(
         )
 
     def get_literal_type(self, t: typing.Union[typing.Type[FlyteFile], os.PathLike]) -> LiteralType:
-        return LiteralType(blob=self._blob_type(format=FlyteFilePathTransformer.get_format(t)))
+        return LiteralType(blob=self._blob_type(
+            format=FlyteFilePathTransformer.get_format(t),
+            file_extension=FlyteFilePathTransformer.get_file_extension(t),
+            enable_legacy_filename=FlyteFilePathTransformer.get_enable_legacy_filename(t),
+        ))
 
     def get_mime_type_from_extension(self, extension: str) -> typing.Union[str, typing.Sequence[str]]:
         extension_to_mime_type = {
@@ -565,7 +588,11 @@ async def async_to_literal(
             raise ValueError(f"Incorrect type {python_type}, must be either a FlyteFile or os.PathLike")
 
         # information used by all cases
-        meta = BlobMetadata(type=self._blob_type(format=FlyteFilePathTransformer.get_format(python_type)))
+        meta = BlobMetadata(type=self._blob_type(
+            format=FlyteFilePathTransformer.get_format(python_type),
+            file_extension=FlyteFilePathTransformer.get_file_extension(python_type),
+            enable_legacy_filename=FlyteFilePathTransformer.get_enable_legacy_filename(python_type),
+        ))
 
         if isinstance(python_val, FlyteFile):
             # Cast the source path to str type to avoid error raised when the source path is used as the blob uri,
diff --git a/pyproject.toml b/pyproject.toml
index 82b8c6c054..7b02665795 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -21,7 +21,7 @@ dependencies = [
     "diskcache>=5.2.1",
     "docker>=4.0.0",
     "docstring-parser>=0.9.0",
-    "flyteidl>=1.16.1,<2.0.0a0",
+    "flyteidl @ git+https://github.com/dominodatalab/flyteidl.git@af517f6",
     "fsspec>=2023.3.0",
     # Bug in 2025.5.0, 2025.5.0post1 https://github.com/fsspec/gcsfs/issues/687
     # Bug in 2024.2.0 https://github.com/fsspec/gcsfs/pull/643
diff --git a/tests/flytekit/unit/core/test_flyte_file.py b/tests/flytekit/unit/core/test_flyte_file.py
index fb0903c567..9681a990fb 100644
--- a/tests/flytekit/unit/core/test_flyte_file.py
+++ b/tests/flytekit/unit/core/test_flyte_file.py
@@ -17,7 +17,7 @@
 from flytekit.core.hash import HashMethod
 from flytekit.core.launch_plan import LaunchPlan
 from flytekit.core.task import task
-from flytekit.core.type_engine import TypeEngine
+from flytekit.core.type_engine import FileDownloadConfig, TypeEngine
 from flytekit.core.workflow import workflow
 from flytekit.models.core.types import BlobType
 from flytekit.models.literals import LiteralMap, Blob, BlobMetadata
@@ -764,6 +764,18 @@ def test_headers():
     assert len(FlyteFilePathTransformer.get_additional_headers(".gz")) == 1
 
 
+def test_transform_flytefile_with_file_download_config():
+    csv_file_no_config = FlyteFile["csv"]
+    lt = FlyteFilePathTransformer().get_literal_type(csv_file_no_config)
+    assert lt.blob.file_extension == ""
+    assert lt.blob.enable_legacy_filename is False
+
+    legacy_file = Annotated[FlyteFile["csv"], FileDownloadConfig(file_extension="csv", enable_legacy_filename=True)]
+    lt = FlyteFilePathTransformer().get_literal_type(legacy_file)
+    assert lt.blob.file_extension == "csv"
+    assert lt.blob.enable_legacy_filename is True
+
+
 def test_new_remote_file():
     nf = FlyteFile.new_remote_file(name="foo.txt")
     assert isinstance(nf, FlyteFile)
diff --git a/tests/flytekit/unit/models/core/test_types.py b/tests/flytekit/unit/models/core/test_types.py
index 21d6cea396..bf4124eb67 100644
--- a/tests/flytekit/unit/models/core/test_types.py
+++ b/tests/flytekit/unit/models/core/test_types.py
@@ -15,11 +15,33 @@ def test_blob_type():
     )
     assert o.format == "csv"
     assert o.dimensionality == _types.BlobType.BlobDimensionality.SINGLE
+    assert o.file_extension == ""
+    assert o.enable_legacy_filename is False
 
     o2 = _types.BlobType.from_flyte_idl(o.to_flyte_idl())
     assert o == o2
     assert o2.format == "csv"
     assert o2.dimensionality == _types.BlobType.BlobDimensionality.SINGLE
+    assert o2.file_extension == ""
+    assert o2.enable_legacy_filename is False
+
+    o = _types.BlobType(
+        format="csv",
+        dimensionality=_types.BlobType.BlobDimensionality.SINGLE,
+        file_extension="csv",
+        enable_legacy_filename=True,
+    )
+    assert o.format == "csv"
+    assert o.dimensionality == _types.BlobType.BlobDimensionality.SINGLE
+    assert o.file_extension == "csv"
+    assert o.enable_legacy_filename is True
+
+    o2 = _types.BlobType.from_flyte_idl(o.to_flyte_idl())
+    assert o == o2
+    assert o2.format == "csv"
+    assert o2.dimensionality == _types.BlobType.BlobDimensionality.SINGLE
+    assert o2.file_extension == "csv"
+    assert o2.enable_legacy_filename is True
 
 
 def test_enum_type():