Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
170 changes: 170 additions & 0 deletions sdk/python/kfp/compiler/compiler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,6 +419,176 @@ def my_pipeline():
self.assertEqual(my_pipeline.pipeline_spec.pipeline_info.display_name,
'my display name')

def test_set_display_name_for_container_task(self):
"""Test that display_name is used as DAG task key for container tasks."""

@dsl.component
def my_component(text: str) -> str:
return text

@dsl.pipeline(name='my-pipeline')
def my_pipeline():
task = my_component(text='hello').set_display_name('My Task')
return task

my_pipeline()
# Check that the DAG task key uses the sanitized display_name
self.assertIn('my-task', my_pipeline.pipeline_spec.root.dag.tasks)
# Verify task_info.name is set correctly
task_spec = my_pipeline.pipeline_spec.root.dag.tasks['my-task']
self.assertEqual(task_spec.task_info.name, 'My Task')

def test_set_display_name_backward_compatibility(self):
"""Test that tasks without display_name use current behavior."""

@dsl.component
def my_component(text: str) -> str:
return text

@dsl.pipeline(name='my-pipeline')
def my_pipeline():
task = my_component(text='hello')
return task

my_pipeline()
# Check that the DAG task key uses sanitized task name
# The task variable name would be sanitized
task_keys = list(my_pipeline.pipeline_spec.root.dag.tasks.keys())
self.assertEqual(len(task_keys), 1)
# Task name should be sanitized (component name gets prefixed with 'comp-')
# but the DAG task key uses the task variable name which is sanitized
# component name without prefix
self.assertGreater(len(task_keys[0]), 0)

def test_set_display_name_uniqueness(self):
"""Test that duplicate display_names get uniqueness suffix."""

@dsl.component
def my_component(text: str) -> str:
return text

@dsl.pipeline(name='my-pipeline')
def my_pipeline():
task1 = my_component(text='hello').set_display_name('Same Name')
task2 = my_component(text='world').set_display_name('Same Name')
return task1, task2

my_pipeline()
# Check that both tasks exist with unique keys
task_keys = list(my_pipeline.pipeline_spec.root.dag.tasks.keys())
self.assertEqual(len(task_keys), 2)
# One should be 'same-name', the other 'same-name-2'
self.assertIn('same-name', task_keys)
self.assertIn('same-name-2', task_keys)

def test_set_display_name_long_name_truncation(self):
"""Test that very long display_names are truncated."""

@dsl.component
def my_component(text: str) -> str:
return text

@dsl.pipeline(name='my-pipeline')
def my_pipeline():
# Create a very long display name (> 15 chars)
long_name = 'a' * 50
task = my_component(text='hello').set_display_name(long_name)
return task

my_pipeline()
# Check that the task key is truncated
task_keys = list(my_pipeline.pipeline_spec.root.dag.tasks.keys())
self.assertEqual(len(task_keys), 1)
# Task key should be truncated to max_task_name_length (15)
self.assertLessEqual(len(task_keys[0]), 20) # Allow some room for suffix

def test_set_display_name_in_loop(self):
"""Test that display_name works for tasks inside ParallelFor loops."""

@dsl.component
def my_component(text: str) -> str:
return text

@dsl.pipeline(name='my-pipeline')
def my_pipeline():
with dsl.ParallelFor(items=['a', 'b', 'c']) as item:
task = my_component(text=item).set_display_name('Loop Task')
return task

my_pipeline()
# Find the loop component
loop_component = None
for comp_name, comp_spec in my_pipeline.pipeline_spec.components.items():
if comp_spec.dag and 'loop-task' in comp_spec.dag.tasks:
loop_component = comp_spec
break
self.assertIsNotNone(loop_component)
self.assertIn('loop-task', loop_component.dag.tasks)

def test_set_display_name_with_dependencies(self):
"""Test that task dependencies work correctly with display_name."""

@dsl.component
def producer(text: str) -> str:
return text

@dsl.component
def consumer(text: str) -> str:
return text

@dsl.pipeline(name='my-pipeline')
def my_pipeline():
prod = producer(text='hello').set_display_name('Producer')
cons = consumer(text=prod.output).set_display_name('Consumer')
return cons

my_pipeline()
# Check that both tasks exist
self.assertIn('producer', my_pipeline.pipeline_spec.root.dag.tasks)
self.assertIn('consumer', my_pipeline.pipeline_spec.root.dag.tasks)
# Check that consumer depends on producer
consumer_spec = my_pipeline.pipeline_spec.root.dag.tasks['consumer']
self.assertIn('producer', consumer_spec.dependent_tasks)

def test_set_display_name_special_characters(self):
"""Test that display_name with special characters is sanitized."""

@dsl.component
def my_component(text: str) -> str:
return text

@dsl.pipeline(name='my-pipeline')
def my_pipeline():
task = my_component(text='hello').set_display_name('My Task @#$%')
return task

my_pipeline()
# Check that special characters are sanitized
task_keys = list(my_pipeline.pipeline_spec.root.dag.tasks.keys())
self.assertEqual(len(task_keys), 1)
# Should only contain lowercase letters, numbers, and hyphens
self.assertTrue(all(c.isalnum() or c == '-' for c in task_keys[0]))

def test_set_display_name_empty_after_sanitization(self):
"""Test that empty display_name after sanitization falls back to task name."""

@dsl.component
def my_component(text: str) -> str:
return text

@dsl.pipeline(name='my-pipeline')
def my_pipeline():
# Display name that becomes empty after sanitization
task = my_component(text='hello').set_display_name('!!!')
return task

my_pipeline()
# Should fall back to task name
task_keys = list(my_pipeline.pipeline_spec.root.dag.tasks.keys())
self.assertEqual(len(task_keys), 1)
# Should not be empty
self.assertGreater(len(task_keys[0]), 0)

def test_set_description_through_pipeline_decorator(self):

@dsl.pipeline(description='Prefer me.')
Expand Down
Loading