Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions flytekit/core/launch_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,7 @@ def create(
# we don't have to reverse it back every time.
default_inputs.update(fixed_inputs)
lp._saved_inputs = default_inputs
lp._raw_fixed_inputs = fixed_inputs

if name in cls.CACHE:
raise AssertionError(f"Launch plan named {name} was already created! Make sure your names are unique.")
Expand Down Expand Up @@ -347,6 +348,7 @@ def __init__(
self._fixed_inputs = fixed_inputs
# See create() for additional information
self._saved_inputs: Dict[str, Any] = {}
self._raw_fixed_inputs: Dict[str, Any] = {}

self._schedule = schedule
self._notifications = notifications or []
Expand Down Expand Up @@ -423,6 +425,10 @@ def fixed_inputs(self) -> _literal_models.LiteralMap:
def workflow(self) -> _annotated_workflow.WorkflowBase:
return self._workflow

@property
def raw_fixed_inputs(self) -> Dict[str, Any]:
return self._raw_fixed_inputs.copy()

@property
def saved_inputs(self) -> Dict[str, Any]:
# See note in create()
Expand Down
10 changes: 10 additions & 0 deletions flytekit/remote/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
from flytekit.core.data_persistence import FileAccessProvider
from flytekit.core.launch_plan import LaunchPlan, ReferenceLaunchPlan
from flytekit.core.node import Node as CoreNode
from flytekit.core.promise import translate_inputs_to_literals
from flytekit.core.python_auto_container import (
PICKLE_FILE_PATH,
PickledEntity,
Expand Down Expand Up @@ -1537,6 +1538,15 @@ def register_launch_plan(
entity.workflow, serialization_settings, version, default_launch_plan=False, options=options
)

if entity.raw_fixed_inputs:
fixed_literals = translate_inputs_to_literals(
self.context,
incoming_values=entity.raw_fixed_inputs,
flyte_interface_types=entity.workflow.interface.inputs,
native_types=entity.workflow.python_interface.inputs,
)
entity._fixed_inputs = literal_models.LiteralMap(literals=fixed_literals)

# Underlying workflow, exists, only register the launch plan itself
launch_plan_model = get_serializable(
OrderedDict(), settings=serialization_settings, entity=entity, options=options
Expand Down
51 changes: 51 additions & 0 deletions tests/flytekit/unit/remote/test_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -911,3 +911,54 @@ def hello_world_wf() -> str:
assert remote_lp is mock_remote_lp
assert not mock_serialize_and_register.called
assert mock_raw_register.called


@mock.patch("flytekit.remote.remote.translate_inputs_to_literals")
@mock.patch("flytekit.remote.remote.get_serializable")
@mock.patch("flytekit.remote.remote.FlyteRemote.fetch_launch_plan")
@mock.patch("flytekit.remote.remote.FlyteRemote.raw_register")
@mock.patch("flytekit.remote.remote.FlyteRemote._serialize_and_register")
def test_register_launch_plan_retranslates_fixed_inputs_with_remote_context(
mock_serialize_and_register, mock_raw_register, mock_fetch_launch_plan,
mock_get_serializable, mock_translate, mock_flyte_remote_client
):
from flytekit.types.file import FlyteFile

@task
def t_with_file(f: FlyteFile) -> str:
return str(f)

@workflow
def wf_with_file(f: FlyteFile) -> str:
return t_with_file(f=f)

with tempfile.NamedTemporaryFile() as tmp:
ff = FlyteFile(path=tmp.name)
lp = LaunchPlan.get_or_create(
workflow=wf_with_file,
name="lp_with_flytefile_fixed",
fixed_inputs={"f": ff},
)

assert lp.raw_fixed_inputs == {"f": ff}

rr = FlyteRemote(
Config.for_sandbox(),
default_project="flytesnacks",
default_domain="development",
)

mock_translate.return_value = {"f": MagicMock()}
mock_get_serializable.return_value = MagicMock()
mock_flyte_remote_client.get_workflow.return_value = MagicMock()
mock_fetch_launch_plan.return_value = MagicMock()

ss = SerializationSettings(image_config=ImageConfig.auto_default_image(), version="v1")
rr.register_launch_plan(lp, version="v1", serialization_settings=ss)

mock_translate.assert_called_once_with(
rr.context,
incoming_values={"f": ff},
flyte_interface_types=wf_with_file.interface.inputs,
native_types=wf_with_file.python_interface.inputs,
)
Loading