From 56eef2d68a0b39d5f6be33d3f2ac6389deb1af4c Mon Sep 17 00:00:00 2001 From: Aditi Sharma <165942273+Aditi2424@users.noreply.github.com> Date: Wed, 4 Feb 2026 11:58:41 -0800 Subject: [PATCH 01/21] Feature store v3 (#5490) * feat: Add Feature Store Support to V3 * Add feature store tests --------- Co-authored-by: adishaa --- .../src/sagemaker/mlops/feature_store/feature_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_utils.py b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_utils.py index 3e3e7813df..5ee04780be 100644 --- a/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_utils.py +++ b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_utils.py @@ -731,4 +731,4 @@ def _cast_object_to_string(data_frame: pandas.DataFrame) -> pandas.DataFrame: """ for label in data_frame.select_dtypes(["object", "O"]).columns.tolist(): data_frame[label] = data_frame[label].astype("str").astype("string") - return data_frame \ No newline at end of file + return data_frame From d5fd78330c175f3145d27ca32c2341601f51e238 Mon Sep 17 00:00:00 2001 From: Basssem Halim Date: Thu, 19 Feb 2026 11:26:08 -0800 Subject: [PATCH 02/21] feat: feature_processor v3 --- .../sagemaker/core/helper/session_helper.py | 220 ++ .../tests/unit/session/test_session_helper.py | 545 +++ .../feature_processor/__init__.py | 45 + .../feature_processor/_config_uploader.py | 209 ++ .../feature_processor/_constants.py | 54 + .../feature_processor/_data_source.py | 154 + .../feature_store/feature_processor/_enums.py | 33 + .../feature_store/feature_processor/_env.py | 78 + .../_event_bridge_rule_helper.py | 305 ++ .../_event_bridge_scheduler_helper.py | 118 + .../feature_processor/_exceptions.py | 18 + .../feature_processor/_factory.py | 167 + .../_feature_processor_config.py | 72 + .../_feature_processor_pipeline_events.py | 29 + .../feature_processor/_input_loader.py | 366 ++ .../feature_processor/_input_offset_parser.py | 129 + .../feature_processor/_params_loader.py | 83 + .../feature_processor/_spark_factory.py | 202 ++ .../feature_processor/_udf_arg_provider.py | 239 ++ .../feature_processor/_udf_output_receiver.py | 98 + .../feature_processor/_udf_wrapper.py | 88 + .../feature_processor/_validation.py | 210 ++ .../feature_processor/feature_processor.py | 129 + .../feature_processor/feature_scheduler.py | 1100 ++++++ .../feature_processor/lineage/__init__.py | 0 .../lineage/_feature_group_contexts.py | 31 + .../_feature_group_lineage_entity_handler.py | 182 + .../lineage/_feature_processor_lineage.py | 759 +++++ .../_feature_processor_lineage_name_helper.py | 101 + .../lineage/_lineage_association_handler.py | 300 ++ .../_pipeline_lineage_entity_handler.py | 105 + .../lineage/_pipeline_schedule.py | 44 + .../lineage/_pipeline_trigger.py | 36 + ...pipeline_version_lineage_entity_handler.py | 92 + .../lineage/_s3_lineage_entity_handler.py | 316 ++ .../lineage/_transformation_code.py | 31 + .../feature_processor/lineage/constants.py | 43 + .../lineage/test_constants.py | 401 +++ ...st_feature_group_lineage_entity_handler.py | 62 + .../lineage/test_feature_processor_lineage.py | 2966 +++++++++++++++++ .../test_lineage_association_handler.py | 224 ++ .../test_pipeline_lineage_entity_handler.py | 74 + .../lineage/test_pipeline_trigger.py | 33 + ...pipeline_version_lineage_entity_handler.py | 67 + .../lineage/test_s3_lineage_entity_handler.py | 434 +++ .../feature_processor/test_config_uploader.py | 317 ++ 
.../feature_processor/test_data_helpers.py | 166 + .../feature_processor/test_data_source.py | 34 + .../feature_processor/test_env.py | 122 + .../test_event_bridge_rule_helper.py | 301 ++ .../test_event_bridge_scheduler_helper.py | 96 + .../feature_processor/test_factory.py | 75 + .../test_feature_processor.py | 122 + .../test_feature_processor_config.py | 46 + .../test_feature_processor_pipeline_events.py | 30 + .../test_feature_scheduler.py | 1057 ++++++ .../feature_processor/test_input_loader.py | 320 ++ .../test_input_offset_parser.py | 143 + .../feature_processor/test_params_loader.py | 86 + .../test_spark_session_factory.py | 174 + .../test_udf_arg_provider.py | 280 ++ .../test_udf_output_receiver.py | 106 + .../feature_processor/test_udf_wrapper.py | 85 + .../feature_processor/test_validation.py | 192 ++ 64 files changed, 14744 insertions(+) create mode 100644 sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/__init__.py create mode 100644 sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_config_uploader.py create mode 100644 sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_constants.py create mode 100644 sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_data_source.py create mode 100644 sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_enums.py create mode 100644 sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_env.py create mode 100644 sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_event_bridge_rule_helper.py create mode 100644 sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_event_bridge_scheduler_helper.py create mode 100644 sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_exceptions.py create mode 100644 sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_factory.py create mode 100644 sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_feature_processor_config.py create mode 100644 sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_feature_processor_pipeline_events.py create mode 100644 sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_input_loader.py create mode 100644 sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_input_offset_parser.py create mode 100644 sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_params_loader.py create mode 100644 sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_spark_factory.py create mode 100644 sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_udf_arg_provider.py create mode 100644 sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_udf_output_receiver.py create mode 100644 sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_udf_wrapper.py create mode 100644 sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_validation.py create mode 100644 sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/feature_processor.py create mode 100644 sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/feature_scheduler.py create mode 100644 sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/lineage/__init__.py create mode 100644 sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/lineage/_feature_group_contexts.py create mode 100644 
sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/lineage/_feature_group_lineage_entity_handler.py create mode 100644 sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/lineage/_feature_processor_lineage.py create mode 100644 sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/lineage/_feature_processor_lineage_name_helper.py create mode 100644 sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/lineage/_lineage_association_handler.py create mode 100644 sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/lineage/_pipeline_lineage_entity_handler.py create mode 100644 sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/lineage/_pipeline_schedule.py create mode 100644 sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/lineage/_pipeline_trigger.py create mode 100644 sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/lineage/_pipeline_version_lineage_entity_handler.py create mode 100644 sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/lineage/_s3_lineage_entity_handler.py create mode 100644 sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/lineage/_transformation_code.py create mode 100644 sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/lineage/constants.py create mode 100644 sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/lineage/test_constants.py create mode 100644 sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/lineage/test_feature_group_lineage_entity_handler.py create mode 100644 sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/lineage/test_feature_processor_lineage.py create mode 100644 sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/lineage/test_lineage_association_handler.py create mode 100644 sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/lineage/test_pipeline_lineage_entity_handler.py create mode 100644 sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/lineage/test_pipeline_trigger.py create mode 100644 sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/lineage/test_pipeline_version_lineage_entity_handler.py create mode 100644 sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/lineage/test_s3_lineage_entity_handler.py create mode 100644 sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_config_uploader.py create mode 100644 sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_data_helpers.py create mode 100644 sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_data_source.py create mode 100644 sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_env.py create mode 100644 sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_event_bridge_rule_helper.py create mode 100644 sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_event_bridge_scheduler_helper.py create mode 100644 sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_factory.py create mode 100644 sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_feature_processor.py create mode 100644 sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_feature_processor_config.py create mode 100644 
sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_feature_processor_pipeline_events.py create mode 100644 sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_feature_scheduler.py create mode 100644 sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_input_loader.py create mode 100644 sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_input_offset_parser.py create mode 100644 sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_params_loader.py create mode 100644 sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_spark_session_factory.py create mode 100644 sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_udf_arg_provider.py create mode 100644 sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_udf_output_receiver.py create mode 100644 sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_udf_wrapper.py create mode 100644 sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_validation.py diff --git a/sagemaker-core/src/sagemaker/core/helper/session_helper.py b/sagemaker-core/src/sagemaker/core/helper/session_helper.py index 41957e30a2..b4c327ec09 100644 --- a/sagemaker-core/src/sagemaker/core/helper/session_helper.py +++ b/sagemaker-core/src/sagemaker/core/helper/session_helper.py @@ -85,6 +85,10 @@ TAGS, SESSION_DEFAULT_S3_BUCKET_PATH, SESSION_DEFAULT_S3_OBJECT_KEY_PREFIX_PATH, + FEATURE_GROUP, + FEATURE_GROUP_ROLE_ARN_PATH, + FEATURE_GROUP_ONLINE_STORE_CONFIG_PATH, + FEATURE_GROUP_OFFLINE_STORE_CONFIG_PATH, ) # Setting LOGGER for backward compatibility, in case users import it... @@ -1607,6 +1611,222 @@ def delete_endpoint_config(self, endpoint_config_name): logger.info("Deleting endpoint configuration with name: %s", endpoint_config_name) self.sagemaker_client.delete_endpoint_config(EndpointConfigName=endpoint_config_name) + def delete_feature_group(self, feature_group_name): + """Delete an Amazon SageMaker Feature Group. + + Args: + feature_group_name (str): Name of the Amazon SageMaker Feature Group to delete. + """ + logger.info("Deleting feature group with name: %s", feature_group_name) + self.sagemaker_client.delete_feature_group(FeatureGroupName=feature_group_name) + + def create_feature_group( + self, + feature_group_name, + record_identifier_name, + event_time_feature_name, + feature_definitions, + role_arn=None, + online_store_config=None, + offline_store_config=None, + throughput_config=None, + description=None, + tags=None, + ): + """Create an Amazon SageMaker Feature Group. + + Args: + feature_group_name (str): Name of the Feature Group. + record_identifier_name (str): Name of the record identifier feature. + event_time_feature_name (str): Name of the event time feature. + feature_definitions (list): List of feature definitions. + role_arn (str): ARN of the role used to execute the API (default: None). + Resolved from SageMaker Config if not provided. + online_store_config (dict): Online store configuration (default: None). + offline_store_config (dict): Offline store configuration (default: None). + throughput_config (dict): Throughput configuration (default: None). + description (str): Description of the Feature Group (default: None). + tags (Optional[Tags]): Tags for labeling the Feature Group (default: None). + + Returns: + dict: Response from the CreateFeatureGroup API. 
+ """ + tags = format_tags(tags) + tags = _append_project_tags(tags) + tags = self._append_sagemaker_config_tags( + tags, "{}.{}.{}".format(SAGEMAKER, FEATURE_GROUP, TAGS) + ) + + role_arn = resolve_value_from_config( + role_arn, FEATURE_GROUP_ROLE_ARN_PATH, sagemaker_session=self + ) + + inferred_online_store_config = update_nested_dictionary_with_values_from_config( + online_store_config, + FEATURE_GROUP_ONLINE_STORE_CONFIG_PATH, + sagemaker_session=self, + ) + if inferred_online_store_config is not None: + # OnlineStore should be handled differently because if you set KmsKeyId, then you + # need to set EnableOnlineStore key as well + inferred_online_store_config["EnableOnlineStore"] = True + + inferred_offline_store_config = update_nested_dictionary_with_values_from_config( + offline_store_config, + FEATURE_GROUP_OFFLINE_STORE_CONFIG_PATH, + sagemaker_session=self, + ) + + kwargs = dict( + FeatureGroupName=feature_group_name, + RecordIdentifierFeatureName=record_identifier_name, + EventTimeFeatureName=event_time_feature_name, + FeatureDefinitions=feature_definitions, + RoleArn=role_arn, + ) + update_args( + kwargs, + OnlineStoreConfig=inferred_online_store_config, + OfflineStoreConfig=inferred_offline_store_config, + ThroughputConfig=throughput_config, + Description=description, + Tags=tags, + ) + + logger.info("Creating feature group with name: %s", feature_group_name) + return self.sagemaker_client.create_feature_group(**kwargs) + + def describe_feature_group(self, feature_group_name, next_token=None): + """Describe an Amazon SageMaker Feature Group. + + Args: + feature_group_name (str): Name of the Amazon SageMaker Feature Group to describe. + next_token (str): A token for paginated results (default: None). + + Returns: + dict: Response from the DescribeFeatureGroup API. + """ + args = {"FeatureGroupName": feature_group_name} + update_args(args, NextToken=next_token) + return self.sagemaker_client.describe_feature_group(**args) + + def update_feature_group( + self, + feature_group_name, + feature_additions=None, + online_store_config=None, + throughput_config=None, + ): + """Update an Amazon SageMaker Feature Group. + + Args: + feature_group_name (str): Name of the Amazon SageMaker Feature Group to update. + feature_additions (list): List of feature definitions to add (default: None). + online_store_config (dict): Online store configuration updates (default: None). + throughput_config (dict): Throughput configuration updates (default: None). + + Returns: + dict: Response from the UpdateFeatureGroup API. + """ + args = {"FeatureGroupName": feature_group_name} + update_args( + args, + FeatureAdditions=feature_additions, + OnlineStoreConfig=online_store_config, + ThroughputConfig=throughput_config, + ) + return self.sagemaker_client.update_feature_group(**args) + + def list_feature_groups( + self, + name_contains=None, + feature_group_status_equals=None, + offline_store_status_equals=None, + creation_time_after=None, + creation_time_before=None, + sort_order=None, + sort_by=None, + max_results=None, + next_token=None, + ): + """List Amazon SageMaker Feature Groups. + + Args: + name_contains (str): Filter by name substring (default: None). + feature_group_status_equals (str): Filter by status (default: None). + offline_store_status_equals (str): Filter by offline store status (default: None). + creation_time_after (datetime): Filter by creation time lower bound (default: None). + creation_time_before (datetime): Filter by creation time upper bound (default: None). 
+ sort_order (str): Sort order, 'Ascending' or 'Descending' (default: None). + sort_by (str): Sort by field (default: None). + max_results (int): Maximum number of results (default: None). + next_token (str): Pagination token (default: None). + + Returns: + dict: Response from the ListFeatureGroups API. + """ + args = {} + update_args( + args, + NameContains=name_contains, + FeatureGroupStatusEquals=feature_group_status_equals, + OfflineStoreStatusEquals=offline_store_status_equals, + CreationTimeAfter=creation_time_after, + CreationTimeBefore=creation_time_before, + SortOrder=sort_order, + SortBy=sort_by, + MaxResults=max_results, + NextToken=next_token, + ) + return self.sagemaker_client.list_feature_groups(**args) + + def update_feature_metadata( + self, + feature_group_name, + feature_name, + description=None, + parameter_additions=None, + parameter_removals=None, + ): + """Update metadata for a feature in an Amazon SageMaker Feature Group. + + Args: + feature_group_name (str): Name of the Feature Group. + feature_name (str): Name of the feature to update metadata for. + description (str): Updated description for the feature (default: None). + parameter_additions (list): Parameters to add (default: None). + parameter_removals (list): Parameters to remove (default: None). + + Returns: + dict: Response from the UpdateFeatureMetadata API. + """ + args = { + "FeatureGroupName": feature_group_name, + "FeatureName": feature_name, + } + update_args( + args, + Description=description, + ParameterAdditions=parameter_additions, + ParameterRemovals=parameter_removals, + ) + return self.sagemaker_client.update_feature_metadata(**args) + + def describe_feature_metadata(self, feature_group_name, feature_name): + """Describe metadata for a feature in an Amazon SageMaker Feature Group. + + Args: + feature_group_name (str): Name of the Feature Group. + feature_name (str): Name of the feature to describe metadata for. + + Returns: + dict: Response from the DescribeFeatureMetadata API. + """ + return self.sagemaker_client.describe_feature_metadata( + FeatureGroupName=feature_group_name, + FeatureName=feature_name, + ) + def wait_for_optimization_job(self, job, poll=5): """Wait for an Amazon SageMaker Optimization job to complete. 
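Taken together, the methods above are thin wrappers over the corresponding sagemaker_client Feature Group APIs, with config resolution and tag handling layered onto create_feature_group. A minimal usage sketch of how they compose, assuming an already-configured Session; the feature group name, feature definitions, and role ARN below are illustrative only, not from this patch:

    from sagemaker.core.helper.session_helper import Session

    session = Session()

    # Create a group with two features; role_arn may instead be resolved
    # from SageMaker Config when omitted.
    session.create_feature_group(
        feature_group_name="customers",
        record_identifier_name="customer_id",
        event_time_feature_name="event_time",
        feature_definitions=[
            {"FeatureName": "customer_id", "FeatureType": "String"},
            {"FeatureName": "event_time", "FeatureType": "String"},
        ],
        role_arn="arn:aws:iam::111122223333:role/FeatureStoreAccess",
        online_store_config={"EnableOnlineStore": True},
    )

    # Inspect the group and annotate a single feature.
    session.describe_feature_group("customers")
    session.update_feature_metadata("customers", "customer_id", description="Primary key")
    session.describe_feature_metadata("customers", "customer_id")

    # None-valued filters are omitted from the ListFeatureGroups request via update_args.
    session.list_feature_groups(name_contains="cust", max_results=10)

    session.delete_feature_group("customers")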
diff --git a/sagemaker-core/tests/unit/session/test_session_helper.py b/sagemaker-core/tests/unit/session/test_session_helper.py index ca4fd81aa8..7e2004c1d0 100644 --- a/sagemaker-core/tests/unit/session/test_session_helper.py +++ b/sagemaker-core/tests/unit/session/test_session_helper.py @@ -29,6 +29,11 @@ update_args, NOTEBOOK_METADATA_FILE, ) +from sagemaker.core.config.config_schema import ( + FEATURE_GROUP_ROLE_ARN_PATH, + FEATURE_GROUP_ONLINE_STORE_CONFIG_PATH, + FEATURE_GROUP_OFFLINE_STORE_CONFIG_PATH, +) class TestSession: @@ -451,3 +456,543 @@ def test_update_args_with_none_values(self): assert args["existing"] == "value" assert "new_key" not in args assert args["another_key"] == "another_value" + +class TestFeatureGroupSessionMethods: + """Test cases for Feature Group session methods""" + + @pytest.fixture + def session_with_mock_client(self): + """Create a Session with a mocked sagemaker_client.""" + mock_boto_session = Mock() + mock_boto_session.region_name = "us-west-2" + mock_boto_session.client.return_value = Mock() + mock_boto_session.resource.return_value = Mock() + session = Session(boto_session=mock_boto_session) + session.sagemaker_client = Mock() + return session + + # --- delete_feature_group --- + + def test_delete_feature_group(self, session_with_mock_client): + """Test delete_feature_group delegates to sagemaker_client.""" + session = session_with_mock_client + session.delete_feature_group("my-feature-group") + + session.sagemaker_client.delete_feature_group.assert_called_once_with( + FeatureGroupName="my-feature-group" + ) + + # --- describe_feature_group --- + + def test_describe_feature_group(self, session_with_mock_client): + """Test describe_feature_group delegates and returns response.""" + session = session_with_mock_client + expected = {"FeatureGroupName": "my-fg", "CreationTime": "2024-01-01"} + session.sagemaker_client.describe_feature_group.return_value = expected + + result = session.describe_feature_group("my-fg") + + session.sagemaker_client.describe_feature_group.assert_called_once_with( + FeatureGroupName="my-fg" + ) + assert result == expected + + def test_describe_feature_group_with_next_token(self, session_with_mock_client): + """Test describe_feature_group includes NextToken when provided.""" + session = session_with_mock_client + session.sagemaker_client.describe_feature_group.return_value = {} + + session.describe_feature_group("my-fg", next_token="abc123") + + session.sagemaker_client.describe_feature_group.assert_called_once_with( + FeatureGroupName="my-fg", NextToken="abc123" + ) + + def test_describe_feature_group_omits_none_next_token(self, session_with_mock_client): + """Test describe_feature_group omits NextToken when None.""" + session = session_with_mock_client + session.sagemaker_client.describe_feature_group.return_value = {} + + session.describe_feature_group("my-fg", next_token=None) + + call_kwargs = session.sagemaker_client.describe_feature_group.call_args[1] + assert "NextToken" not in call_kwargs + + # --- update_feature_group --- + + def test_update_feature_group_all_params(self, session_with_mock_client): + """Test update_feature_group with all optional params provided.""" + session = session_with_mock_client + expected = {"FeatureGroupArn": "arn:aws:sagemaker:us-west-2:123:feature-group/my-fg"} + session.sagemaker_client.update_feature_group.return_value = expected + + additions = [{"FeatureName": "new_feat", "FeatureType": "String"}] + online_cfg = {"EnableOnlineStore": True} + throughput_cfg = {"ThroughputMode": 
"OnDemand"} + + result = session.update_feature_group( + "my-fg", + feature_additions=additions, + online_store_config=online_cfg, + throughput_config=throughput_cfg, + ) + + session.sagemaker_client.update_feature_group.assert_called_once_with( + FeatureGroupName="my-fg", + FeatureAdditions=additions, + OnlineStoreConfig=online_cfg, + ThroughputConfig=throughput_cfg, + ) + assert result == expected + + def test_update_feature_group_omits_none_params(self, session_with_mock_client): + """Test update_feature_group omits None optional params.""" + session = session_with_mock_client + session.sagemaker_client.update_feature_group.return_value = {} + + session.update_feature_group("my-fg") + + call_kwargs = session.sagemaker_client.update_feature_group.call_args[1] + assert call_kwargs == {"FeatureGroupName": "my-fg"} + + def test_update_feature_group_partial_params(self, session_with_mock_client): + """Test update_feature_group with only some optional params.""" + session = session_with_mock_client + session.sagemaker_client.update_feature_group.return_value = {} + + throughput_cfg = {"ThroughputMode": "Provisioned"} + session.update_feature_group("my-fg", throughput_config=throughput_cfg) + + call_kwargs = session.sagemaker_client.update_feature_group.call_args[1] + assert call_kwargs == { + "FeatureGroupName": "my-fg", + "ThroughputConfig": throughput_cfg, + } + + # --- list_feature_groups --- + + def test_list_feature_groups_no_params(self, session_with_mock_client): + """Test list_feature_groups with no filters delegates with empty args.""" + session = session_with_mock_client + expected = {"FeatureGroupSummaries": []} + session.sagemaker_client.list_feature_groups.return_value = expected + + result = session.list_feature_groups() + + session.sagemaker_client.list_feature_groups.assert_called_once_with() + assert result == expected + + def test_list_feature_groups_all_params(self, session_with_mock_client): + """Test list_feature_groups with all params provided.""" + session = session_with_mock_client + session.sagemaker_client.list_feature_groups.return_value = {} + + session.list_feature_groups( + name_contains="test", + feature_group_status_equals="Created", + offline_store_status_equals="Active", + creation_time_after="2024-01-01", + creation_time_before="2024-12-31", + sort_order="Ascending", + sort_by="Name", + max_results=10, + next_token="token123", + ) + + session.sagemaker_client.list_feature_groups.assert_called_once_with( + NameContains="test", + FeatureGroupStatusEquals="Created", + OfflineStoreStatusEquals="Active", + CreationTimeAfter="2024-01-01", + CreationTimeBefore="2024-12-31", + SortOrder="Ascending", + SortBy="Name", + MaxResults=10, + NextToken="token123", + ) + + def test_list_feature_groups_omits_none_params(self, session_with_mock_client): + """Test list_feature_groups omits None params.""" + session = session_with_mock_client + session.sagemaker_client.list_feature_groups.return_value = {} + + session.list_feature_groups(name_contains="test", max_results=5) + + call_kwargs = session.sagemaker_client.list_feature_groups.call_args[1] + assert call_kwargs == {"NameContains": "test", "MaxResults": 5} + + # --- update_feature_metadata --- + + def test_update_feature_metadata_all_params(self, session_with_mock_client): + """Test update_feature_metadata with all optional params.""" + session = session_with_mock_client + session.sagemaker_client.update_feature_metadata.return_value = {} + + additions = [{"Key": "team", "Value": "ml"}] + removals = [{"Key": 
"deprecated"}] + + result = session.update_feature_metadata( + "my-fg", + "my-feature", + description="Updated desc", + parameter_additions=additions, + parameter_removals=removals, + ) + + session.sagemaker_client.update_feature_metadata.assert_called_once_with( + FeatureGroupName="my-fg", + FeatureName="my-feature", + Description="Updated desc", + ParameterAdditions=additions, + ParameterRemovals=removals, + ) + assert result == {} + + def test_update_feature_metadata_omits_none_params(self, session_with_mock_client): + """Test update_feature_metadata omits None optional params.""" + session = session_with_mock_client + session.sagemaker_client.update_feature_metadata.return_value = {} + + session.update_feature_metadata("my-fg", "my-feature") + + call_kwargs = session.sagemaker_client.update_feature_metadata.call_args[1] + assert call_kwargs == { + "FeatureGroupName": "my-fg", + "FeatureName": "my-feature", + } + + def test_update_feature_metadata_partial_params(self, session_with_mock_client): + """Test update_feature_metadata with only description.""" + session = session_with_mock_client + session.sagemaker_client.update_feature_metadata.return_value = {} + + session.update_feature_metadata("my-fg", "my-feature", description="New desc") + + call_kwargs = session.sagemaker_client.update_feature_metadata.call_args[1] + assert call_kwargs == { + "FeatureGroupName": "my-fg", + "FeatureName": "my-feature", + "Description": "New desc", + } + + # --- describe_feature_metadata --- + + def test_describe_feature_metadata(self, session_with_mock_client): + """Test describe_feature_metadata delegates and returns response.""" + session = session_with_mock_client + expected = {"FeatureGroupName": "my-fg", "FeatureName": "my-feature"} + session.sagemaker_client.describe_feature_metadata.return_value = expected + + result = session.describe_feature_metadata("my-fg", "my-feature") + + session.sagemaker_client.describe_feature_metadata.assert_called_once_with( + FeatureGroupName="my-fg", FeatureName="my-feature" + ) + assert result == expected + +MODULE = "sagemaker.core.helper.session_helper" + + +class TestCreateFeatureGroup: + """Test cases for create_feature_group session method.""" + + @pytest.fixture + def session(self): + """Create a Session with a mocked sagemaker_client.""" + mock_boto_session = Mock() + mock_boto_session.region_name = "us-west-2" + mock_boto_session.client.return_value = Mock() + mock_boto_session.resource.return_value = Mock() + session = Session(boto_session=mock_boto_session) + session.sagemaker_client = Mock() + return session + + @pytest.fixture + def base_args(self): + """Minimal required arguments for create_feature_group.""" + return dict( + feature_group_name="my-fg", + record_identifier_name="record_id", + event_time_feature_name="event_time", + feature_definitions=[{"FeatureName": "f1", "FeatureType": "String"}], + ) + + # --- Full parameter pass-through --- + + def test_create_feature_group_all_params(self, session, base_args): + """Test that all parameters are passed through to sagemaker_client.""" + role = "arn:aws:iam::123456789012:role/Role" + online_cfg = {"SecurityConfig": {"KmsKeyId": "key-123"}} + offline_cfg = {"S3StorageConfig": {"S3Uri": "s3://bucket"}} + throughput_cfg = {"ThroughputMode": "ON_DEMAND"} + description = "My feature group" + tags = [{"Key": "team", "Value": "ml"}] + + expected_response = {"FeatureGroupArn": "arn:aws:sagemaker:us-west-2:123456789012:feature-group/my-fg"} + session.sagemaker_client.create_feature_group.return_value = 
expected_response + + with patch(f"{MODULE}.format_tags", return_value=tags) as mock_format, \ + patch(f"{MODULE}._append_project_tags", return_value=tags) as mock_proj, \ + patch.object(session, "_append_sagemaker_config_tags", return_value=tags), \ + patch(f"{MODULE}.resolve_value_from_config", return_value=role), \ + patch(f"{MODULE}.update_nested_dictionary_with_values_from_config", side_effect=[online_cfg, offline_cfg]): + + result = session.create_feature_group( + **base_args, + role_arn=role, + online_store_config=online_cfg, + offline_store_config=offline_cfg, + throughput_config=throughput_cfg, + description=description, + tags=tags, + ) + + assert result == expected_response + call_kwargs = session.sagemaker_client.create_feature_group.call_args[1] + assert call_kwargs["FeatureGroupName"] == "my-fg" + assert call_kwargs["RecordIdentifierFeatureName"] == "record_id" + assert call_kwargs["EventTimeFeatureName"] == "event_time" + assert call_kwargs["FeatureDefinitions"] == base_args["feature_definitions"] + assert call_kwargs["RoleArn"] == role + # EnableOnlineStore is set to True when online config is inferred + assert call_kwargs["OnlineStoreConfig"]["EnableOnlineStore"] is True + assert call_kwargs["OfflineStoreConfig"] == offline_cfg + assert call_kwargs["ThroughputConfig"] == throughput_cfg + assert call_kwargs["Description"] == description + assert call_kwargs["Tags"] == tags + + # --- Tag processing pipeline --- + + def test_tag_processing_pipeline_order(self, session, base_args): + """Test that tags go through format_tags -> _append_project_tags -> _append_sagemaker_config_tags.""" + raw_tags = {"team": "ml"} + formatted = [{"Key": "team", "Value": "ml"}] + with_project = [{"Key": "team", "Value": "ml"}, {"Key": "project", "Value": "p1"}] + with_config = [{"Key": "team", "Value": "ml"}, {"Key": "project", "Value": "p1"}, {"Key": "cfg", "Value": "v"}] + + with patch(f"{MODULE}.format_tags", return_value=formatted) as mock_format, \ + patch(f"{MODULE}._append_project_tags", return_value=with_project) as mock_proj, \ + patch.object(session, "_append_sagemaker_config_tags", return_value=with_config) as mock_cfg, \ + patch(f"{MODULE}.resolve_value_from_config", return_value="arn:role"), \ + patch(f"{MODULE}.update_nested_dictionary_with_values_from_config", return_value=None): + + session.create_feature_group(**base_args, tags=raw_tags) + + # format_tags is called with the raw input + mock_format.assert_called_once_with(raw_tags) + # _append_project_tags receives the formatted tags + mock_proj.assert_called_once_with(formatted) + # _append_sagemaker_config_tags receives the project-appended tags + mock_cfg.assert_called_once_with(with_project, "SageMaker.FeatureGroup.Tags") + + # Final tags in the API call should be the config-appended tags + call_kwargs = session.sagemaker_client.create_feature_group.call_args[1] + assert call_kwargs["Tags"] == with_config + + def test_tags_none_still_processed(self, session, base_args): + """Test that None tags still go through the pipeline (format_tags handles None).""" + with patch(f"{MODULE}.format_tags", return_value=None) as mock_format, \ + patch(f"{MODULE}._append_project_tags", return_value=None) as mock_proj, \ + patch.object(session, "_append_sagemaker_config_tags", return_value=None), \ + patch(f"{MODULE}.resolve_value_from_config", return_value="arn:role"), \ + patch(f"{MODULE}.update_nested_dictionary_with_values_from_config", return_value=None): + + session.create_feature_group(**base_args, tags=None) + + 
mock_format.assert_called_once_with(None) + mock_proj.assert_called_once_with(None) + # Tags=None should be omitted from the API call via update_args + call_kwargs = session.sagemaker_client.create_feature_group.call_args[1] + assert "Tags" not in call_kwargs + + # --- role_arn resolution from config --- + + def test_role_arn_resolved_from_config_when_none(self, session, base_args): + """Test that role_arn is resolved from SageMaker Config when not provided.""" + config_role = "arn:aws:iam::123456789012:role/ConfigRole" + + with patch(f"{MODULE}.format_tags", return_value=None), \ + patch(f"{MODULE}._append_project_tags", return_value=None), \ + patch.object(session, "_append_sagemaker_config_tags", return_value=None), \ + patch(f"{MODULE}.resolve_value_from_config", return_value=config_role) as mock_resolve, \ + patch(f"{MODULE}.update_nested_dictionary_with_values_from_config", return_value=None): + + session.create_feature_group(**base_args, role_arn=None) + + mock_resolve.assert_called_once_with( + None, FEATURE_GROUP_ROLE_ARN_PATH, sagemaker_session=session + ) + call_kwargs = session.sagemaker_client.create_feature_group.call_args[1] + assert call_kwargs["RoleArn"] == config_role + + def test_role_arn_passed_through_when_provided(self, session, base_args): + """Test that an explicit role_arn is passed to resolve_value_from_config (which returns it).""" + explicit_role = "arn:aws:iam::123456789012:role/ExplicitRole" + + with patch(f"{MODULE}.format_tags", return_value=None), \ + patch(f"{MODULE}._append_project_tags", return_value=None), \ + patch.object(session, "_append_sagemaker_config_tags", return_value=None), \ + patch(f"{MODULE}.resolve_value_from_config", return_value=explicit_role) as mock_resolve, \ + patch(f"{MODULE}.update_nested_dictionary_with_values_from_config", return_value=None): + + session.create_feature_group(**base_args, role_arn=explicit_role) + + mock_resolve.assert_called_once_with( + explicit_role, FEATURE_GROUP_ROLE_ARN_PATH, sagemaker_session=session + ) + call_kwargs = session.sagemaker_client.create_feature_group.call_args[1] + assert call_kwargs["RoleArn"] == explicit_role + + # --- online_store_config merging and EnableOnlineStore --- + + def test_online_store_config_merged_and_enable_set(self, session, base_args): + """Test that online_store_config is merged from config and EnableOnlineStore=True is set.""" + inferred_online = {"SecurityConfig": {"KmsKeyId": "config-key"}} + + with patch(f"{MODULE}.format_tags", return_value=None), \ + patch(f"{MODULE}._append_project_tags", return_value=None), \ + patch.object(session, "_append_sagemaker_config_tags", return_value=None), \ + patch(f"{MODULE}.resolve_value_from_config", return_value="arn:role"), \ + patch(f"{MODULE}.update_nested_dictionary_with_values_from_config", + side_effect=[inferred_online, None]) as mock_update: + + session.create_feature_group(**base_args, online_store_config=None) + + # First call is for online store config + mock_update.assert_any_call( + None, FEATURE_GROUP_ONLINE_STORE_CONFIG_PATH, sagemaker_session=session + ) + call_kwargs = session.sagemaker_client.create_feature_group.call_args[1] + assert call_kwargs["OnlineStoreConfig"]["EnableOnlineStore"] is True + assert call_kwargs["OnlineStoreConfig"]["SecurityConfig"]["KmsKeyId"] == "config-key" + + def test_online_store_config_none_when_no_config(self, session, base_args): + """Test that OnlineStoreConfig is omitted when config returns None.""" + with patch(f"{MODULE}.format_tags", return_value=None), \ + 
patch(f"{MODULE}._append_project_tags", return_value=None), \ + patch.object(session, "_append_sagemaker_config_tags", return_value=None), \ + patch(f"{MODULE}.resolve_value_from_config", return_value="arn:role"), \ + patch(f"{MODULE}.update_nested_dictionary_with_values_from_config", return_value=None): + + session.create_feature_group(**base_args) + + call_kwargs = session.sagemaker_client.create_feature_group.call_args[1] + assert "OnlineStoreConfig" not in call_kwargs + + def test_online_store_config_explicit_gets_enable_set(self, session, base_args): + """Test that explicitly provided online_store_config also gets EnableOnlineStore=True.""" + explicit_online = {"SecurityConfig": {"KmsKeyId": "my-key"}} + # update_nested_dictionary returns the merged result + merged_online = {"SecurityConfig": {"KmsKeyId": "my-key"}} + + with patch(f"{MODULE}.format_tags", return_value=None), \ + patch(f"{MODULE}._append_project_tags", return_value=None), \ + patch.object(session, "_append_sagemaker_config_tags", return_value=None), \ + patch(f"{MODULE}.resolve_value_from_config", return_value="arn:role"), \ + patch(f"{MODULE}.update_nested_dictionary_with_values_from_config", + side_effect=[merged_online, None]): + + session.create_feature_group(**base_args, online_store_config=explicit_online) + + call_kwargs = session.sagemaker_client.create_feature_group.call_args[1] + assert call_kwargs["OnlineStoreConfig"]["EnableOnlineStore"] is True + + # --- offline_store_config merging --- + + def test_offline_store_config_merged_from_config(self, session, base_args): + """Test that offline_store_config is merged from SageMaker Config.""" + inferred_offline = {"S3StorageConfig": {"S3Uri": "s3://config-bucket"}} + + with patch(f"{MODULE}.format_tags", return_value=None), \ + patch(f"{MODULE}._append_project_tags", return_value=None), \ + patch.object(session, "_append_sagemaker_config_tags", return_value=None), \ + patch(f"{MODULE}.resolve_value_from_config", return_value="arn:role"), \ + patch(f"{MODULE}.update_nested_dictionary_with_values_from_config", + side_effect=[None, inferred_offline]) as mock_update: + + session.create_feature_group(**base_args, offline_store_config=None) + + # Second call is for offline store config + mock_update.assert_any_call( + None, FEATURE_GROUP_OFFLINE_STORE_CONFIG_PATH, sagemaker_session=session + ) + call_kwargs = session.sagemaker_client.create_feature_group.call_args[1] + assert call_kwargs["OfflineStoreConfig"] == inferred_offline + + def test_offline_store_config_none_when_no_config(self, session, base_args): + """Test that OfflineStoreConfig is omitted when config returns None.""" + with patch(f"{MODULE}.format_tags", return_value=None), \ + patch(f"{MODULE}._append_project_tags", return_value=None), \ + patch.object(session, "_append_sagemaker_config_tags", return_value=None), \ + patch(f"{MODULE}.resolve_value_from_config", return_value="arn:role"), \ + patch(f"{MODULE}.update_nested_dictionary_with_values_from_config", return_value=None): + + session.create_feature_group(**base_args) + + call_kwargs = session.sagemaker_client.create_feature_group.call_args[1] + assert "OfflineStoreConfig" not in call_kwargs + + # --- None optional parameters omitted --- + + def test_none_optional_params_omitted(self, session, base_args): + """Test that None optional params (throughput, description, tags) are omitted from API call.""" + with patch(f"{MODULE}.format_tags", return_value=None), \ + patch(f"{MODULE}._append_project_tags", return_value=None), \ + patch.object(session, 
"_append_sagemaker_config_tags", return_value=None), \ + patch(f"{MODULE}.resolve_value_from_config", return_value="arn:role"), \ + patch(f"{MODULE}.update_nested_dictionary_with_values_from_config", return_value=None): + + session.create_feature_group(**base_args) + + call_kwargs = session.sagemaker_client.create_feature_group.call_args[1] + assert "ThroughputConfig" not in call_kwargs + assert "Description" not in call_kwargs + assert "Tags" not in call_kwargs + assert "OnlineStoreConfig" not in call_kwargs + assert "OfflineStoreConfig" not in call_kwargs + # Required params should still be present + assert "FeatureGroupName" in call_kwargs + assert "RecordIdentifierFeatureName" in call_kwargs + assert "EventTimeFeatureName" in call_kwargs + assert "FeatureDefinitions" in call_kwargs + assert "RoleArn" in call_kwargs + + def test_partial_optional_params(self, session, base_args): + """Test that only provided optional params appear in the API call.""" + throughput = {"ThroughputMode": "ON_DEMAND"} + + with patch(f"{MODULE}.format_tags", return_value=None), \ + patch(f"{MODULE}._append_project_tags", return_value=None), \ + patch.object(session, "_append_sagemaker_config_tags", return_value=None), \ + patch(f"{MODULE}.resolve_value_from_config", return_value="arn:role"), \ + patch(f"{MODULE}.update_nested_dictionary_with_values_from_config", return_value=None): + + session.create_feature_group( + **base_args, + throughput_config=throughput, + description="test desc", + ) + + call_kwargs = session.sagemaker_client.create_feature_group.call_args[1] + assert call_kwargs["ThroughputConfig"] == throughput + assert call_kwargs["Description"] == "test desc" + assert "Tags" not in call_kwargs + assert "OnlineStoreConfig" not in call_kwargs + assert "OfflineStoreConfig" not in call_kwargs + + # --- Return value --- + + def test_returns_api_response(self, session, base_args): + """Test that the method returns the sagemaker_client response.""" + expected = {"FeatureGroupArn": "arn:fg"} + session.sagemaker_client.create_feature_group.return_value = expected + + with patch(f"{MODULE}.format_tags", return_value=None), \ + patch(f"{MODULE}._append_project_tags", return_value=None), \ + patch.object(session, "_append_sagemaker_config_tags", return_value=None), \ + patch(f"{MODULE}.resolve_value_from_config", return_value="arn:role"), \ + patch(f"{MODULE}.update_nested_dictionary_with_values_from_config", return_value=None): + + result = session.create_feature_group(**base_args) + + assert result == expected diff --git a/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/__init__.py b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/__init__.py new file mode 100644 index 0000000000..1051096d0e --- /dev/null +++ b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/__init__.py @@ -0,0 +1,45 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+"""Exported classes for the sagemaker.mlops.feature_store.feature_processor module.""" +from __future__ import absolute_import + +from sagemaker.mlops.feature_store.feature_processor._data_source import ( # noqa: F401 + CSVDataSource, + FeatureGroupDataSource, + ParquetDataSource, + BaseDataSource, + PySparkDataSource, +) +from sagemaker.mlops.feature_store.feature_processor._exceptions import ( # noqa: F401 + IngestionError, +) +from sagemaker.mlops.feature_store.feature_processor.feature_processor import ( # noqa: F401 + feature_processor, +) +from sagemaker.mlops.feature_store.feature_processor.feature_scheduler import ( # noqa: F401 + to_pipeline, + schedule, + describe, + put_trigger, + delete_trigger, + enable_trigger, + disable_trigger, + delete_schedule, + list_pipelines, + execute, + TransformationCode, + FeatureProcessorPipelineEvents, +) +from sagemaker.mlops.feature_store.feature_processor._enums import ( # noqa: F401 + FeatureProcessorPipelineExecutionStatus, +) diff --git a/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_config_uploader.py b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_config_uploader.py new file mode 100644 index 0000000000..d181218fb5 --- /dev/null +++ b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_config_uploader.py @@ -0,0 +1,209 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+"""Contains classes for preparing and uploading configs for a scheduled feature processor.""" +from __future__ import absolute_import +from typing import Callable, Dict, Optional, Tuple, List, Union + +import attr + +from sagemaker.core.helper.session_helper import Session +from sagemaker.mlops.feature_store.feature_processor._constants import ( + SPARK_JAR_FILES_PATH, + SPARK_PY_FILES_PATH, + SPARK_FILES_PATH, + S3_DATA_DISTRIBUTION_TYPE, +) +from sagemaker.core.inputs import TrainingInput +from sagemaker.core.shapes import Channel, DataSource, S3DataSource +from sagemaker.core.remote_function.core.stored_function import StoredFunction +from sagemaker.core.remote_function.job import ( + _prepare_and_upload_workspace, + _prepare_and_upload_runtime_scripts, + _JobSettings, + RUNTIME_SCRIPTS_CHANNEL_NAME, + REMOTE_FUNCTION_WORKSPACE, + SPARK_CONF_CHANNEL_NAME, + _prepare_and_upload_spark_dependent_files, +) +from sagemaker.core.remote_function.runtime_environment.runtime_environment_manager import ( + RuntimeEnvironmentManager, +) +from sagemaker.core.remote_function.spark_config import SparkConfig +from sagemaker.core.remote_function.custom_file_filter import CustomFileFilter +from sagemaker.core.s3 import s3_path_join + + +@attr.s +class ConfigUploader: + """Prepares and uploads customer provided configs to S3""" + + remote_decorator_config: _JobSettings = attr.ib() + runtime_env_manager: RuntimeEnvironmentManager = attr.ib() + + def prepare_step_input_channel_for_spark_mode( + self, func: Callable, s3_base_uri: str, sagemaker_session: Session + ) -> Tuple[List[Channel], Dict]: + """Prepares input channels for SageMaker Pipeline Step. + + Returns: + Tuple of (List[Channel], spark_dependency_paths dict) + """ + self._prepare_and_upload_callable(func, s3_base_uri, sagemaker_session) + bootstrap_scripts_s3uri = self._prepare_and_upload_runtime_scripts( + self.remote_decorator_config.spark_config, + s3_base_uri, + self.remote_decorator_config.s3_kms_key, + sagemaker_session, + ) + dependencies_list_path = self.runtime_env_manager.snapshot( + self.remote_decorator_config.dependencies + ) + user_workspace_s3uri = self._prepare_and_upload_workspace( + dependencies_list_path, + self.remote_decorator_config.include_local_workdir, + self.remote_decorator_config.pre_execution_commands, + self.remote_decorator_config.pre_execution_script, + s3_base_uri, + self.remote_decorator_config.s3_kms_key, + sagemaker_session, + self.remote_decorator_config.custom_file_filter, + ) + + ( + submit_jars_s3_paths, + submit_py_files_s3_paths, + submit_files_s3_path, + config_file_s3_uri, + ) = self._prepare_and_upload_spark_dependent_files( + self.remote_decorator_config.spark_config, + s3_base_uri, + self.remote_decorator_config.s3_kms_key, + sagemaker_session, + ) + + channels = [ + Channel( + channel_name=RUNTIME_SCRIPTS_CHANNEL_NAME, + data_source=DataSource( + s3_data_source=S3DataSource( + s3_uri=bootstrap_scripts_s3uri, + s3_data_type="S3Prefix", + s3_data_distribution_type=S3_DATA_DISTRIBUTION_TYPE, + ) + ), + input_mode="File", + ) + ] + + if user_workspace_s3uri: + channels.append( + Channel( + channel_name=REMOTE_FUNCTION_WORKSPACE, + data_source=DataSource( + s3_data_source=S3DataSource( + s3_uri=s3_path_join(s3_base_uri, REMOTE_FUNCTION_WORKSPACE), + s3_data_type="S3Prefix", + s3_data_distribution_type=S3_DATA_DISTRIBUTION_TYPE, + ) + ), + input_mode="File", + ) + ) + + if config_file_s3_uri: + channels.append( + Channel( + channel_name=SPARK_CONF_CHANNEL_NAME, + data_source=DataSource( + 
s3_data_source=S3DataSource( + s3_uri=config_file_s3_uri, + s3_data_type="S3Prefix", + s3_data_distribution_type=S3_DATA_DISTRIBUTION_TYPE, + ) + ), + input_mode="File", + ) + ) + + return channels, { + SPARK_JAR_FILES_PATH: submit_jars_s3_paths, + SPARK_PY_FILES_PATH: submit_py_files_s3_paths, + SPARK_FILES_PATH: submit_files_s3_path, + } + + def _prepare_and_upload_callable( + self, func: Callable, s3_base_uri: str, sagemaker_session: Session + ) -> None: + """Prepares and uploads callable to S3""" + stored_function = StoredFunction( + sagemaker_session=sagemaker_session, + s3_base_uri=s3_base_uri, + s3_kms_key=self.remote_decorator_config.s3_kms_key, + ) + stored_function.save(func) + + def _prepare_and_upload_workspace( + self, + local_dependencies_path: str, + include_local_workdir: bool, + pre_execution_commands: List[str], + pre_execution_script_local_path: str, + s3_base_uri: str, + s3_kms_key: str, + sagemaker_session: Session, + custom_file_filter: Optional[Union[Callable[[str, List], List], CustomFileFilter]] = None, + ) -> str: + """Upload the training step dependencies to S3 if present""" + return _prepare_and_upload_workspace( + local_dependencies_path=local_dependencies_path, + include_local_workdir=include_local_workdir, + pre_execution_commands=pre_execution_commands, + pre_execution_script_local_path=pre_execution_script_local_path, + s3_base_uri=s3_base_uri, + s3_kms_key=s3_kms_key, + sagemaker_session=sagemaker_session, + custom_file_filter=custom_file_filter, + ) + + def _prepare_and_upload_runtime_scripts( + self, + spark_config: SparkConfig, + s3_base_uri: str, + s3_kms_key: str, + sagemaker_session: Session, + ) -> str: + """Copy runtime scripts to a folder and upload to S3""" + return _prepare_and_upload_runtime_scripts( + spark_config=spark_config, + s3_base_uri=s3_base_uri, + s3_kms_key=s3_kms_key, + sagemaker_session=sagemaker_session, + ) + + def _prepare_and_upload_spark_dependent_files( + self, + spark_config: SparkConfig, + s3_base_uri: str, + s3_kms_key: str, + sagemaker_session: Session, + ) -> Tuple: + """Upload the spark dependencies to S3 if present""" + if not spark_config: + return None, None, None, None + + return _prepare_and_upload_spark_dependent_files( + spark_config=spark_config, + s3_base_uri=s3_base_uri, + s3_kms_key=s3_kms_key, + sagemaker_session=sagemaker_session, + ) diff --git a/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_constants.py b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_constants.py new file mode 100644 index 0000000000..e010446904 --- /dev/null +++ b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_constants.py @@ -0,0 +1,54 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+"""Module containing constants for feature_processor and feature_scheduler module.""" +from __future__ import absolute_import + +from sagemaker.core.workflow.parameters import Parameter, ParameterTypeEnum + +DEFAULT_INSTANCE_TYPE = "ml.m5.xlarge" +DEFAULT_SCHEDULE_STATE = "ENABLED" +DEFAULT_TRIGGER_STATE = "ENABLED" +UNDERSCORE = "_" +RESOURCE_NOT_FOUND_EXCEPTION = "ResourceNotFoundException" +RESOURCE_NOT_FOUND = "ResourceNotFound" +EXECUTION_TIME_PIPELINE_PARAMETER = "scheduled_time" +VALIDATION_EXCEPTION = "ValidationException" +EVENT_BRIDGE_INVOCATION_TIME = "" +SCHEDULED_TIME_PIPELINE_PARAMETER = Parameter( + name=EXECUTION_TIME_PIPELINE_PARAMETER, parameter_type=ParameterTypeEnum.STRING +) +EXECUTION_TIME_PIPELINE_PARAMETER_FORMAT = "%Y-%m-%dT%H:%M:%SZ" # 2023-01-01T07:00:00Z +NO_FLEXIBLE_TIME_WINDOW = dict(Mode="OFF") +PIPELINE_NAME_MAXIMUM_LENGTH = 80 +PIPELINE_CONTEXT_TYPE = "FeatureEngineeringPipeline" +SPARK_JAR_FILES_PATH = "submit_jars_s3_paths" +SPARK_PY_FILES_PATH = "submit_py_files_s3_paths" +SPARK_FILES_PATH = "submit_files_s3_path" +FEATURE_PROCESSOR_TAG_KEY = "sm-fs-fe:created-from" +FEATURE_PROCESSOR_TAG_VALUE = "fp-to-pipeline" +FEATURE_GROUP_ARN_REGEX_PATTERN = r"arn:(.*?):sagemaker:(.*?):(.*?):feature-group/(.*?)$" +PIPELINE_ARN_REGEX_PATTERN = r"arn:(.*?):sagemaker:(.*?):(.*?):pipeline/(.*?)$" +EVENTBRIDGE_RULE_ARN_REGEX_PATTERN = r"arn:(.*?):events:(.*?):(.*?):rule/(.*?)$" +SAGEMAKER_WHL_FILE_S3_PATH = "s3://ada-private-beta/sagemaker-2.151.1.dev0-py2.py3-none-any.whl" +S3_DATA_DISTRIBUTION_TYPE = "FullyReplicated" +PIPELINE_CONTEXT_NAME_TAG_KEY = "sm-fs-fe:feature-engineering-pipeline-context-name" +PIPELINE_VERSION_CONTEXT_NAME_TAG_KEY = "sm-fs-fe:feature-engineering-pipeline-version-context-name" +TO_PIPELINE_RESERVED_TAG_KEYS = [ + FEATURE_PROCESSOR_TAG_KEY, + PIPELINE_CONTEXT_NAME_TAG_KEY, + PIPELINE_VERSION_CONTEXT_NAME_TAG_KEY, +] +BASE_EVENT_PATTERN = { + "source": ["aws.sagemaker"], + "detail": {"currentPipelineExecutionStatus": [], "pipelineArn": []}, +} diff --git a/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_data_source.py b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_data_source.py new file mode 100644 index 0000000000..a6c452267c --- /dev/null +++ b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_data_source.py @@ -0,0 +1,154 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Contains classes to define input data sources.""" +from __future__ import absolute_import + +from typing import Optional, Dict, Union, TypeVar, Generic +from abc import ABC, abstractmethod +from pyspark.sql import DataFrame, SparkSession + + +import attr + +T = TypeVar("T") + + +@attr.s +class BaseDataSource(Generic[T], ABC): + """Abstract base class for feature processor data sources. + + Provides a skeleton for customization requiring the overriding of the method to read data from + data source and return the specified type. 
+ """ + + @abstractmethod + def read_data(self, *args, **kwargs) -> T: + """Read data from data source and return the specified type. + + Args: + args: Arguments for reading the data. + kwargs: Keyword argument for reading the data. + Returns: + T: The specified abstraction of data source. + """ + + @property + @abstractmethod + def data_source_unique_id(self) -> str: + """The identifier for the customized feature processor data source. + + Returns: + str: The data source unique id. + """ + + @property + @abstractmethod + def data_source_name(self) -> str: + """The name for the customized feature processor data source. + + Returns: + str: The data source name. + """ + + +@attr.s +class PySparkDataSource(BaseDataSource[DataFrame], ABC): + """Abstract base class for feature processor data sources. + + Provides a skeleton for customization requiring the overriding of the method to read data from + data source and return the Spark DataFrame. + """ + + @abstractmethod + def read_data( + self, spark: SparkSession, params: Optional[Dict[str, Union[str, Dict]]] = None + ) -> DataFrame: + """Read data from data source and convert the data to Spark DataFrame. + + Args: + spark (SparkSession): The Spark session to read the data. + params (Optional[Dict[str, Union[str, Dict]]]): Parameters provided to the + feature_processor decorator. + Returns: + DataFrame: The Spark DataFrame as an abstraction on the data source. + """ + + +@attr.s +class FeatureGroupDataSource: + """A Feature Group data source definition for a FeatureProcessor. + + Attributes: + name (str): The name or ARN of the Feature Group. + input_start_offset (Optional[str], optional): A duration specified as a string in the + format ' ' where 'no' is a number and 'unit' is a unit of time in ['hours', + 'days', 'weeks', 'months', 'years'] (plural and singular forms). Inputs contain data + with event times no earlier than input_start_offset in the past. Offsets are relative + to the function execution time. If the function is executed by a Schedule, then the + offset is relative to the scheduled start time. Defaults to None. + input_end_offset (Optional[str], optional): The 'end' (as opposed to start) counterpart for + the 'input_start_offset'. Inputs will contain records with event times no later than + 'input_end_offset' in the past. Defaults to None. + """ + + name: str = attr.ib() + input_start_offset: Optional[str] = attr.ib(default=None) + input_end_offset: Optional[str] = attr.ib(default=None) + + +@attr.s +class CSVDataSource: + """An CSV data source definition for a FeatureProcessor. + + Attributes: + s3_uri (str): S3 URI of the data source. + csv_header (bool): Whether to read the first line of the CSV file as column names. This + option is only valid when file_format is set to csv. By default the value of this + option is true, and all column types are assumed to be a string. + infer_schema (bool): Whether to infer the schema of the CSV data source. This option is only + valid when file_format is set to csv. If set to true, two passes of the data is required + to load and infer the schema. + """ + + s3_uri: str = attr.ib() + csv_header: bool = attr.ib(default=True) + csv_infer_schema: bool = attr.ib(default=False) + + +@attr.s +class ParquetDataSource: + """An parquet data source definition for a FeatureProcessor. + + Attributes: + s3_uri (str): S3 URI of the data source. 
+ """ + + s3_uri: str = attr.ib() + + +@attr.s +class IcebergTableDataSource: + """An iceberg table data source definition for FeatureProcessor + + Attributes: + warehouse_s3_uri (str): S3 URI of data warehouse. The value is usually + the URI where data is stored. + catalog (str): Name of the catalog. + database (str): Name of the database. + table (str): Name of the table. + """ + + warehouse_s3_uri: str = attr.ib() + catalog: str = attr.ib() + database: str = attr.ib() + table: str = attr.ib() diff --git a/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_enums.py b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_enums.py new file mode 100644 index 0000000000..b63ed3a65a --- /dev/null +++ b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_enums.py @@ -0,0 +1,33 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Module containing Enums for the feature_processor module.""" +from __future__ import absolute_import + +from enum import Enum + + +class FeatureProcessorMode(Enum): + """Enum of feature_processor modes.""" + + PYSPARK = "pyspark" # Execute a pyspark job. + PYTHON = "python" # Execute a regular python script. + + +class FeatureProcessorPipelineExecutionStatus(Enum): + """Enum of feature_processor pipeline execution status.""" + + EXECUTING = "Executing" + STOPPING = "Stopping" + STOPPED = "Stopped" + FAILED = "Failed" + SUCCEEDED = "Succeeded" diff --git a/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_env.py b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_env.py new file mode 100644 index 0000000000..d4ccfb1197 --- /dev/null +++ b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_env.py @@ -0,0 +1,78 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Contains class that determines the current execution environment.""" +from __future__ import absolute_import + + +from typing import Dict, Optional +from datetime import datetime, timezone +import json +import logging +import os +import attr +from sagemaker.mlops.feature_store.feature_processor._constants import ( + EXECUTION_TIME_PIPELINE_PARAMETER, + EXECUTION_TIME_PIPELINE_PARAMETER_FORMAT, +) + + +logger = logging.getLogger("sagemaker") + + +@attr.s +class EnvironmentHelper: + """Helper class to retrieve info from environment. + + Attributes: + current_time (datetime): The current datetime. 
+ """ + + current_time = attr.ib(default=datetime.now(timezone.utc)) + + def is_training_job(self) -> bool: + """Determine if the current execution environment is inside a SageMaker Training Job""" + return self.load_training_resource_config() is not None + + def get_instance_count(self) -> int: + """Determine the number of instances for the current execution environment.""" + resource_config = self.load_training_resource_config() + return len(resource_config["hosts"]) if resource_config else 1 + + def load_training_resource_config(self) -> Optional[Dict]: + """Load the contents of resourceconfig.json contents. + + Returns: + Optional[Dict]: None if not found. + """ + SM_TRAINING_CONFIG_FILE_PATH = "/opt/ml/input/config/resourceconfig.json" + try: + with open(SM_TRAINING_CONFIG_FILE_PATH, "r") as cfgfile: + resource_config = json.load(cfgfile) + logger.debug("Contents of %s: %s", SM_TRAINING_CONFIG_FILE_PATH, resource_config) + return resource_config + except FileNotFoundError: + return None + + def get_job_scheduled_time(self) -> str: + """Get the job scheduled time. + + Returns: + str: Timestamp when the job is scheduled. + """ + + scheduled_time = self.current_time.strftime(EXECUTION_TIME_PIPELINE_PARAMETER_FORMAT) + if self.is_training_job(): + envs = dict(os.environ) + return envs.get(EXECUTION_TIME_PIPELINE_PARAMETER, scheduled_time) + + return scheduled_time diff --git a/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_event_bridge_rule_helper.py b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_event_bridge_rule_helper.py new file mode 100644 index 0000000000..250e7d456f --- /dev/null +++ b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_event_bridge_rule_helper.py @@ -0,0 +1,305 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+"""Contains classes for EventBridge Schedule management for a feature processor.""" +from __future__ import absolute_import + +import json +import logging +import re +from typing import Dict, List, Tuple, Optional, Any +import attr +from botocore.exceptions import ClientError +from botocore.paginate import PageIterator +from sagemaker.core.helper.session_helper import Session +from sagemaker.mlops.feature_store.feature_processor._feature_processor_pipeline_events import ( + FeatureProcessorPipelineEvents, +) +from sagemaker.mlops.feature_store.feature_processor._constants import ( + RESOURCE_NOT_FOUND_EXCEPTION, + PIPELINE_ARN_REGEX_PATTERN, + BASE_EVENT_PATTERN, +) +from sagemaker.mlops.feature_store.feature_processor._enums import ( + FeatureProcessorPipelineExecutionStatus, +) +from sagemaker.core.common_utils import TagsDict + +logger = logging.getLogger("sagemaker") + + +@attr.s +class EventBridgeRuleHelper: + """Contains helper methods for managing EventBridge rules for a feature processor.""" + + sagemaker_session: Session = attr.ib() + event_bridge_rule_client = attr.ib() + + def put_rule( + self, + source_pipeline_events: List[FeatureProcessorPipelineEvents], + target_pipeline: str, + event_pattern: str, + state: str, + ) -> str: + """Creates an EventBridge Rule for a given target pipeline. + + Args: + source_pipeline_events: The list of pipeline events that trigger the EventBridge Rule. + target_pipeline: The name of the pipeline that is triggered by the EventBridge Rule. + event_pattern: The EventBridge EventPattern that triggers the EventBridge Rule. + If specified, will override source_pipeline_events. + state: Indicates whether the rule is enabled or disabled. + + Returns: + The Amazon Resource Name (ARN) of the rule. + """ + self._validate_feature_processor_pipeline_events(source_pipeline_events) + rule_name = target_pipeline + _event_patterns = ( + event_pattern + or self._generate_event_pattern_from_feature_processor_pipeline_events( + source_pipeline_events + ) + ) + rule_arn = self.event_bridge_rule_client.put_rule( + Name=rule_name, EventPattern=_event_patterns, State=state + )["RuleArn"] + return rule_arn + + def put_target( + self, + rule_name: str, + target_pipeline: str, + target_pipeline_parameters: Dict[str, str], + role_arn: str, + ) -> None: + """Attach target pipeline to an event based trigger. + + Args: + rule_name: The name of the EventBridge Rule. + target_pipeline: The name of the pipeline that is triggered by the EventBridge Rule. + target_pipeline_parameters: The list of parameters to start execution of a pipeline. + role_arn: The Amazon Resource Name (ARN) of the IAM role associated with the rule. + """ + target_pipeline_arn_and_name = self._generate_pipeline_arn_and_name(target_pipeline) + target_pipeline_name = target_pipeline_arn_and_name["pipeline_name"] + target_pipeline_arn = target_pipeline_arn_and_name["pipeline_arn"] + target_request_dict = { + "Id": target_pipeline_name, + "Arn": target_pipeline_arn, + "RoleArn": role_arn, + } + if target_pipeline_parameters: + target_request_dict["SageMakerPipelineParameters"] = { + "PipelineParameterList": target_pipeline_parameters + } + put_targets_response = self.event_bridge_rule_client.put_targets( + Rule=rule_name, + Targets=[target_request_dict], + ) + if put_targets_response["FailedEntryCount"] != 0: + error_msg = put_targets_response["FailedEntries"][0]["ErrorMessage"] + raise Exception(f"Failed to add target pipeline to rule. 
Failure reason: {error_msg}") + + def delete_rule(self, rule_name: str) -> None: + """Deletes an EventBridge Rule of a given pipeline if there is one. + + Args: + rule_name: The name of the EventBridge Rule. + """ + self.event_bridge_rule_client.delete_rule(Name=rule_name) + + def remove_targets(self, rule_name: str, ids: List[str]) -> None: + """Deletes an EventBridge Targets of a given rule if there is one. + + Args: + rule_name: The name of the EventBridge Rule. + ids: The ids of the EventBridge Target. + """ + self.event_bridge_rule_client.remove_targets(Rule=rule_name, Ids=ids) + + def list_targets_by_rule(self, rule_name: str) -> PageIterator: + """List EventBridge Targets of a given rule. + + Args: + rule_name: The name of the EventBridge Rule. + + Returns: + The page iterator of list_targets_by_rule call. + """ + return self.event_bridge_rule_client.get_paginator("list_targets_by_rule").paginate( + Rule=rule_name + ) + + def describe_rule(self, rule_name: str) -> Optional[Dict[str, Any]]: + """Describe the EventBridge Rule ARN corresponding to a sagemaker pipeline + + Args: + rule_name: The name of the EventBridge Rule. + Returns: + Optional[Dict[str, str]] : Describe EventBridge Rule response if exists. + """ + try: + event_bridge_rule_response = self.event_bridge_rule_client.describe_rule(Name=rule_name) + return event_bridge_rule_response + except ClientError as e: + if RESOURCE_NOT_FOUND_EXCEPTION == e.response["Error"]["Code"]: + logger.info("No EventBridge Rule found for pipeline %s.", rule_name) + return None + raise e + + def enable_rule(self, rule_name: str) -> None: + """Enables an EventBridge Rule of a given pipeline if there is one. + + Args: + rule_name: The name of the EventBridge Rule. + """ + self.event_bridge_rule_client.enable_rule(Name=rule_name) + logger.info("Enabled EventBridge Rule for pipeline %s.", rule_name) + + def disable_rule(self, rule_name: str) -> None: + """Disables an EventBridge Rule of a given pipeline if there is one. + + Args: + rule_name: The name of the EventBridge Rule. + """ + self.event_bridge_rule_client.disable_rule(Name=rule_name) + logger.info("Disabled EventBridge Rule for pipeline %s.", rule_name) + + def add_tags(self, rule_arn: str, tags: List[TagsDict]) -> None: + """Adds tags to the EventBridge Rule. + + Args: + rule_arn: The ARN of the EventBridge Rule. + tags: List of tags to be added. + """ + self.event_bridge_rule_client.tag_resource(ResourceARN=rule_arn, Tags=tags) + + def _generate_event_pattern_from_feature_processor_pipeline_events( + self, pipeline_events: List[FeatureProcessorPipelineEvents] + ) -> str: + """Generates the event pattern json string from the pipeline events. + + Args: + pipeline_events: List of pipeline events. + Returns: + str: The event pattern json string. + + Raises: + ValueError: If pipeline events contain duplicate pipeline names. 
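+
+        Example output (illustrative, single-filter case; the account and ARN are hypothetical):
+
+            {
+                "detail-type": ["SageMaker Model Building Pipeline Execution Status Change"],
+                "source": ["aws.sagemaker"],
+                "detail": {
+                    "currentPipelineExecutionStatus": ["Succeeded"],
+                    "pipelineArn": ["arn:aws:sagemaker:us-east-1:111122223333:pipeline/upstream"]
+                }
+            }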
+ """ + + result_event_pattern = { + "detail-type": ["SageMaker Model Building Pipeline Execution Status Change"], + } + filters = [] + desired_status_to_pipeline_names_map = ( + self._aggregate_pipeline_events_with_same_desired_status(pipeline_events) + ) + for desired_status in desired_status_to_pipeline_names_map: + pipeline_arns = [ + self._generate_pipeline_arn_and_name(pipeline_name)["pipeline_arn"] + for pipeline_name in desired_status_to_pipeline_names_map[desired_status] + ] + curr_filter = BASE_EVENT_PATTERN.copy() + curr_filter["detail"]["pipelineArn"] = pipeline_arns + curr_filter["detail"]["currentPipelineExecutionStatus"] = [ + status_enum.value for status_enum in desired_status + ] + filters.append(curr_filter) + if len(filters) > 1: + result_event_pattern["$or"] = filters + else: + result_event_pattern.update(filters[0]) + return json.dumps(result_event_pattern) + + def _validate_feature_processor_pipeline_events( + self, pipeline_events: List[FeatureProcessorPipelineEvents] + ) -> None: + """Validates the pipeline events. + + Args: + pipeline_events: List of pipeline events. + Raises: + ValueError: If pipeline events contain duplicate pipeline names. + """ + + unique_pipelines = {event.pipeline_name for event in pipeline_events} + potential_infinite_loop = [] + if len(unique_pipelines) != len(pipeline_events): + raise ValueError("Pipeline names in pipeline_events must be unique.") + + for event in pipeline_events: + if FeatureProcessorPipelineExecutionStatus.EXECUTING in event.pipeline_execution_status: + potential_infinite_loop.append(event.pipeline_name) + if potential_infinite_loop: + logger.warning( + "Potential infinite loop detected for pipelines %s. " + "Setting pipeline_execution_status to EXECUTING might cause infinite loop. " + "Please consider a terminal status instead.", + potential_infinite_loop, + ) + + def _aggregate_pipeline_events_with_same_desired_status( + self, pipeline_events: List[FeatureProcessorPipelineEvents] + ) -> Dict[Tuple, List[str]]: + """Aggregate pipeline events with same desired status. + + e.g. + { + (FeatureProcessorPipelineExecutionStatus.FAILED, + FeatureProcessorPipelineExecutionStatus.STOPPED): + ["pipeline_name_1", "pipeline_name_2"], + (FeatureProcessorPipelineExecutionStatus.STOPPED, + FeatureProcessorPipelineExecutionStatus.STOPPED): + ["pipeline_name_3"], + } + Args: + pipeline_events: List of pipeline events. + Returns: + Dict[Tuple, List[str]]: A dictionary of desired status keys and corresponding pipeline + names. + """ + events_by_desired_status = {} + + for event in pipeline_events: + sorted_execution_status = sorted(event.pipeline_execution_status, key=lambda x: x.value) + desired_status_keys = tuple(sorted_execution_status) + + if desired_status_keys not in events_by_desired_status: + events_by_desired_status[desired_status_keys] = [] + events_by_desired_status[desired_status_keys].append(event.pipeline_name) + + return events_by_desired_status + + def _generate_pipeline_arn_and_name(self, pipeline_uri: str) -> Dict[str, str]: + """Generate pipeline arn and pipeline name from pipeline uri. + + Args: + pipeline_uri: The name or arn of the pipeline. + Returns: + Dict[str, str]: The arn and name of the pipeline. 
+ """ + match = re.match(PIPELINE_ARN_REGEX_PATTERN, pipeline_uri) + pipeline_arn = "" + pipeline_name = "" + if not match: + pipeline_name = pipeline_uri + describe_pipeline_response = self.sagemaker_session.sagemaker_client.describe_pipeline( + PipelineName=pipeline_name + ) + pipeline_arn = describe_pipeline_response["PipelineArn"] + else: + pipeline_arn = pipeline_uri + pipeline_name = match.group(4) + return dict(pipeline_arn=pipeline_arn, pipeline_name=pipeline_name) diff --git a/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_event_bridge_scheduler_helper.py b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_event_bridge_scheduler_helper.py new file mode 100644 index 0000000000..f454a217e2 --- /dev/null +++ b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_event_bridge_scheduler_helper.py @@ -0,0 +1,118 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Contains classes for EventBridge Schedule management for a feature processor.""" +from __future__ import absolute_import +import logging +from datetime import datetime +from typing import Dict, Optional, Any +import attr +from botocore.exceptions import ClientError +from sagemaker.core.helper.session_helper import Session +from sagemaker.mlops.feature_store.feature_processor._constants import ( + EXECUTION_TIME_PIPELINE_PARAMETER, + EVENT_BRIDGE_INVOCATION_TIME, + NO_FLEXIBLE_TIME_WINDOW, + RESOURCE_NOT_FOUND_EXCEPTION, +) + +logger = logging.getLogger("sagemaker") + + +@attr.s +class EventBridgeSchedulerHelper: + """Contains helper methods for scheduling events to EventBridge""" + + sagemaker_session: Session = attr.ib() + event_bridge_scheduler_client = attr.ib() + + def upsert_schedule( + self, + schedule_name: str, + pipeline_arn: str, + schedule_expression: str, + state: str, + start_date: datetime, + role: str, + ) -> Dict: + """Creates or updates a Schedule for the given pipeline_arn and schedule_expression. + + Args: + schedule_name: The name of the schedule. + pipeline_arn: The ARN of the sagemaker pipeline that needs to scheduled. + schedule_expression: The schedule expression. + state: Specifies whether the schedule is enabled or disabled. Can only + be ENABLED or DISABLED. + start_date: The date, in UTC, after which the schedule can begin invoking its target. + role: The RoleArn used to execute the scheduled events. + + Returns: + schedule_arn: The arn of the schedule. 
+ """ + pipeline_parameter = dict( + PipelineParameterList=[ + dict( + Name=EXECUTION_TIME_PIPELINE_PARAMETER, + Value=EVENT_BRIDGE_INVOCATION_TIME, + ) + ] + ) + create_or_update_schedule_request_dict = dict( + Name=schedule_name, + ScheduleExpression=schedule_expression, + FlexibleTimeWindow=NO_FLEXIBLE_TIME_WINDOW, + Target=dict( + Arn=pipeline_arn, + SageMakerPipelineParameters=pipeline_parameter, + RoleArn=role, + ), + State=state, + StartDate=start_date, + ) + try: + return self.event_bridge_scheduler_client.update_schedule( + **create_or_update_schedule_request_dict + ) + except ClientError as e: + if RESOURCE_NOT_FOUND_EXCEPTION == e.response["Error"]["Code"]: + return self.event_bridge_scheduler_client.create_schedule( + **create_or_update_schedule_request_dict + ) + raise e + + def delete_schedule(self, schedule_name: str) -> None: + """Deletes an EventBridge Schedule of a given pipeline if there is one. + + Args: + schedule_name: The name of the EventBridge Schedule. + """ + logger.info("Deleting EventBridge Schedule for pipeline %s.", schedule_name) + self.event_bridge_scheduler_client.delete_schedule(Name=schedule_name) + + def describe_schedule(self, schedule_name) -> Optional[Dict[str, Any]]: + """Describe the EventBridge Schedule ARN corresponding to a sagemaker pipeline + + Args: + schedule_name: The name of the EventBridge Schedule. + Returns: + Optional[Dict[str, str]] : Describe EventBridge Schedule response if exists. + """ + try: + event_bridge_scheduler_response = self.event_bridge_scheduler_client.get_schedule( + Name=schedule_name + ) + return event_bridge_scheduler_response + except ClientError as e: + if RESOURCE_NOT_FOUND_EXCEPTION == e.response["Error"]["Code"]: + logger.info("No EventBridge Schedule found for pipeline %s.", schedule_name) + return None + raise e diff --git a/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_exceptions.py b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_exceptions.py new file mode 100644 index 0000000000..0b21d10ab9 --- /dev/null +++ b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_exceptions.py @@ -0,0 +1,18 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""This module stores exceptions related to the feature_store.feature_processor module.""" +from __future__ import absolute_import + + +class IngestionError(Exception): + """Exception raised to indicate that ingestion did not complete successfully.""" diff --git a/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_factory.py b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_factory.py new file mode 100644 index 0000000000..f205c32665 --- /dev/null +++ b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_factory.py @@ -0,0 +1,167 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. 
A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Contains static factory classes to instantiate complex objects for the FeatureProcessor.""" +from __future__ import absolute_import + +from typing import Dict +from pyspark.sql import DataFrame + +from sagemaker.mlops.feature_store.feature_processor._enums import FeatureProcessorMode +from sagemaker.mlops.feature_store.feature_processor._env import EnvironmentHelper +from sagemaker.mlops.feature_store.feature_processor._feature_processor_config import ( + FeatureProcessorConfig, +) +from sagemaker.mlops.feature_store.feature_processor._input_loader import ( + SparkDataFrameInputLoader, +) +from sagemaker.mlops.feature_store.feature_processor._params_loader import ( + ParamsLoader, + SystemParamsLoader, +) +from sagemaker.mlops.feature_store.feature_processor._spark_factory import ( + FeatureStoreManagerFactory, + SparkSessionFactory, +) +from sagemaker.mlops.feature_store.feature_processor._udf_arg_provider import SparkArgProvider +from sagemaker.mlops.feature_store.feature_processor._udf_output_receiver import ( + SparkOutputReceiver, +) +from sagemaker.mlops.feature_store.feature_processor._udf_wrapper import UDFWrapper +from sagemaker.mlops.feature_store.feature_processor._validation import ( + FeatureProcessorArgValidator, + InputValidator, + SparkUDFSignatureValidator, + InputOffsetValidator, + BaseDataSourceValidator, + ValidatorChain, +) + + +class ValidatorFactory: + """Static factory to handle ValidationChain instantiation.""" + + @staticmethod + def get_validation_chain(fp_config: FeatureProcessorConfig) -> ValidatorChain: + """Instantiate a ValidationChain""" + base_validators = [ + InputValidator(), + FeatureProcessorArgValidator(), + InputOffsetValidator(), + BaseDataSourceValidator(), + ] + + mode = fp_config.mode + if FeatureProcessorMode.PYSPARK == mode: + base_validators.append(SparkUDFSignatureValidator()) + return ValidatorChain(validators=base_validators) + + raise ValueError(f"FeatureProcessorMode {mode} is not supported.") + + +class UDFWrapperFactory: + """Static factory to handle UDFWrapper instantiation at runtime.""" + + @staticmethod + def get_udf_wrapper(fp_config: FeatureProcessorConfig) -> UDFWrapper: + """Instantiate a UDFWrapper based on the FeatureProcessingMode. + + Args: + fp_config (FeatureProcessorConfig): the configuration values for the + feature_processor decorator. + + Raises: + ValueError: if an unsupported FeatureProcessorMode is provided in fp_config. + + Returns: + UDFWrapper: An instance of UDFWrapper to decorate the UDF. + """ + mode = fp_config.mode + + if FeatureProcessorMode.PYSPARK == mode: + return UDFWrapperFactory._get_spark_udf_wrapper(fp_config) + + raise ValueError(f"FeatureProcessorMode {mode} is not supported.") + + @staticmethod + def _get_spark_udf_wrapper(fp_config: FeatureProcessorConfig) -> UDFWrapper[DataFrame]: + """Instantiate a new UDFWrapper for PySpark functions. + + Args: + fp_config (FeatureProcessorConfig): the configuration values for the feature_processor + decorator. 
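+
+        Returns:
+            UDFWrapper[DataFrame]: A UDFWrapper wired with the Spark-specific argument
+                provider and output receiver.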
+ """ + spark_session_factory = UDFWrapperFactory._get_spark_session_factory(fp_config.spark_config) + feature_store_manager_factory = UDFWrapperFactory._get_feature_store_manager_factory() + + output_manager = UDFWrapperFactory._get_spark_output_receiver(feature_store_manager_factory) + arg_provider = UDFWrapperFactory._get_spark_arg_provider(spark_session_factory) + + return UDFWrapper[DataFrame](arg_provider, output_manager) + + @staticmethod + def _get_spark_arg_provider( + spark_session_factory: SparkSessionFactory, + ) -> SparkArgProvider: + """Instantiate a new SparkArgProvider for PySpark functions. + + Args: + spark_session_factory (SparkSessionFactory): A factory to provide a reference to the + SparkSession initialized for the feature_processor wrapped function. The factory + lazily loads the SparkSession, i.e. defers to function execution time. + + Returns: + SparkArgProvider: An instance that generates arguments to provide to the + feature_processor wrapped function. + """ + environment_helper = EnvironmentHelper() + + system_parameters_arg_provider = SystemParamsLoader(environment_helper) + params_arg_provider = ParamsLoader(system_parameters_arg_provider) + input_loader = SparkDataFrameInputLoader(spark_session_factory, environment_helper) + + return SparkArgProvider(params_arg_provider, input_loader, spark_session_factory) + + @staticmethod + def _get_spark_output_receiver( + feature_store_manager_factory: FeatureStoreManagerFactory, + ) -> SparkOutputReceiver: + """Instantiate a new SparkOutputManager for PySpark functions. + + Args: + feature_store_manager_factory (FeatureStoreManagerFactory): A factory to provide + that provides a FeatureStoreManager that handles data ingestion to a Feature Group. + The factory lazily loads the FeatureStoreManager. + + Returns: + SparkOutputReceiver: An instance that handles outputs of the wrapped function. + """ + return SparkOutputReceiver(feature_store_manager_factory) + + @staticmethod + def _get_spark_session_factory(spark_config: Dict[str, str]) -> SparkSessionFactory: + """Instantiate a new SparkSessionFactory + + Args: + spark_config (Dict[str, str]): The Spark configuration that will be passed to the + initialization of Spark session. + + Returns: + SparkSessionFactory: A Spark session factory instance. + """ + environment_helper = EnvironmentHelper() + return SparkSessionFactory(environment_helper, spark_config) + + @staticmethod + def _get_feature_store_manager_factory() -> FeatureStoreManagerFactory: + """Instantiate a new FeatureStoreManagerFactory""" + return FeatureStoreManagerFactory() diff --git a/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_feature_processor_config.py b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_feature_processor_config.py new file mode 100644 index 0000000000..f5d4dd91f3 --- /dev/null +++ b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_feature_processor_config.py @@ -0,0 +1,72 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. 
See the License for the specific +# language governing permissions and limitations under the License. +"""Contains data classes for the FeatureProcessor.""" +from __future__ import absolute_import + +from typing import Dict, List, Optional, Sequence, Union + +import attr + +from sagemaker.mlops.feature_store.feature_processor._data_source import ( + CSVDataSource, + FeatureGroupDataSource, + ParquetDataSource, + BaseDataSource, +) +from sagemaker.mlops.feature_store.feature_processor._enums import FeatureProcessorMode + + +@attr.s(frozen=True) +class FeatureProcessorConfig: + """Immutable data class containing the arguments for a FeatureProcessor. + + This class is used throughout sagemaker.mlops.feature_store.feature_processor module. Documentation + for each field can be be found in the feature_processor decorator. + + Defaults are defined as literals in the feature_processor decorator's parameters for usability + (i.e. literals in docs). Defaults, or any business logic, should not be added to this class. + It only serves as an immutable data class. + """ + + inputs: Sequence[ + Union[FeatureGroupDataSource, CSVDataSource, ParquetDataSource, BaseDataSource] + ] = attr.ib() + output: str = attr.ib() + mode: FeatureProcessorMode = attr.ib() + target_stores: Optional[List[str]] = attr.ib() + parameters: Optional[Dict[str, Union[str, Dict]]] = attr.ib() + enable_ingestion: bool = attr.ib() + spark_config: Dict[str, str] = attr.ib() + + @staticmethod + def create( + inputs: Sequence[ + Union[FeatureGroupDataSource, CSVDataSource, ParquetDataSource, BaseDataSource] + ], + output: str, + mode: FeatureProcessorMode, + target_stores: Optional[List[str]], + parameters: Optional[Dict[str, Union[str, Dict]]], + enable_ingestion: bool, + spark_config: Dict[str, str], + ) -> "FeatureProcessorConfig": + """Static initializer.""" + return FeatureProcessorConfig( + inputs=inputs, + output=output, + mode=mode, + target_stores=target_stores, + parameters=parameters, + enable_ingestion=enable_ingestion, + spark_config=spark_config, + ) diff --git a/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_feature_processor_pipeline_events.py b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_feature_processor_pipeline_events.py new file mode 100644 index 0000000000..4ce9fb1b76 --- /dev/null +++ b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_feature_processor_pipeline_events.py @@ -0,0 +1,29 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Contains data classes for the Feature Processor Pipeline Events.""" +from __future__ import absolute_import + +from typing import List +import attr +from sagemaker.mlops.feature_store.feature_processor._enums import FeatureProcessorPipelineExecutionStatus + + +@attr.s(frozen=True) +class FeatureProcessorPipelineEvents: + """Immutable data class containing the execution events for a FeatureProcessor pipeline. + + This class is used for creating event based triggers for feature processor pipelines. 
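+
+    Example (illustrative pipeline name):
+
+        FeatureProcessorPipelineEvents(
+            pipeline_name="my-upstream-pipeline",
+            pipeline_execution_status=[FeatureProcessorPipelineExecutionStatus.SUCCEEDED],
+        )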
+ """ + + pipeline_name: str = attr.ib() + pipeline_execution_status: List[FeatureProcessorPipelineExecutionStatus] = attr.ib() diff --git a/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_input_loader.py b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_input_loader.py new file mode 100644 index 0000000000..627de943c1 --- /dev/null +++ b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_input_loader.py @@ -0,0 +1,366 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Contains classes that loads user specified input sources (e.g. Feature Groups, S3 URIs, etc).""" +from __future__ import absolute_import + +import logging +import re +from abc import ABC, abstractmethod +from typing import Generic, Optional, TypeVar, Union + +import attr +from pyspark.sql import DataFrame + +from sagemaker.core.helper.session_helper import Session +from sagemaker.mlops.feature_store.feature_processor._constants import FEATURE_GROUP_ARN_REGEX_PATTERN +from sagemaker.mlops.feature_store.feature_processor._data_source import ( + CSVDataSource, + FeatureGroupDataSource, + ParquetDataSource, + IcebergTableDataSource, +) +from sagemaker.mlops.feature_store.feature_processor._spark_factory import SparkSessionFactory +from sagemaker.mlops.feature_store.feature_processor._input_offset_parser import ( + InputOffsetParser, +) +from sagemaker.mlops.feature_store.feature_processor._env import EnvironmentHelper + +T = TypeVar("T") + +logger = logging.getLogger("sagemaker") + + +class InputLoader(Generic[T], ABC): + """Loads the contents of a Feature Group's offline store or contents at an S3 URI.""" + + @abstractmethod + def load_from_feature_group(self, feature_group_data_source: FeatureGroupDataSource) -> T: + """Load the data from a Feature Group's offline store. + + Args: + feature_group_data_source (FeatureGroupDataSource): the feature group source. + + Returns: + T: The contents of the offline store as an instance of type T. + """ + + @abstractmethod + def load_from_s3(self, s3_data_source: Union[CSVDataSource, ParquetDataSource]) -> T: + """Load the contents from an S3 based data source. + + Args: + s3_data_source (Union[CSVDataSource, ParquetDataSource]): a data source that is based + in S3. + + Returns: + T: The contents stored at the data source as an instance of type T. + """ + + +@attr.s +class SparkDataFrameInputLoader(InputLoader[DataFrame]): + """InputLoader that reads data in as a Spark DataFrame.""" + + spark_session_factory: SparkSessionFactory = attr.ib() + environment_helper: EnvironmentHelper = attr.ib() + sagemaker_session: Optional[Session] = attr.ib(default=None) + + _supported_table_format = ["Iceberg", "Glue", None] + + def load_from_feature_group( + self, feature_group_data_source: FeatureGroupDataSource + ) -> DataFrame: + """Load the contents of a Feature Group's offline store as a DataFrame. + + Args: + feature_group_data_source (FeatureGroupDataSource): the Feature Group source. 
+ + Raises: + ValueError: If the Feature Group does not have an Offline Store. + ValueError: If the Feature Group's Table Type is not supported by the feature_processor. + + Returns: + DataFrame: A Spark DataFrame containing the contents of the Feature Group's + offline store. + """ + sagemaker_session: Session = self.sagemaker_session or Session() + + feature_group_name = feature_group_data_source.name + feature_group = sagemaker_session.describe_feature_group( + self._parse_name_from_arn(feature_group_name) + ) + logger.debug( + "Called describe_feature_group with %s and received: %s", + feature_group_name, + feature_group, + ) + + if "OfflineStoreConfig" not in feature_group: + raise ValueError( + f"Input Feature Groups must have an enabled Offline Store." + f" Feature Group: {feature_group_name} does not have an Offline Store enabled." + ) + + offline_store_uri = feature_group["OfflineStoreConfig"]["S3StorageConfig"][ + "ResolvedOutputS3Uri" + ] + + table_format = feature_group["OfflineStoreConfig"].get("TableFormat", None) + + if table_format not in self._supported_table_format: + raise ValueError( + f"Feature group with table format {table_format} is not supported. " + f"The table format should be one of {self._supported_table_format}." + ) + + start_offset = feature_group_data_source.input_start_offset + end_offset = feature_group_data_source.input_end_offset + + if table_format == "Iceberg": + data_catalog_config = feature_group["OfflineStoreConfig"]["DataCatalogConfig"] + return self.load_from_iceberg_table( + IcebergTableDataSource( + offline_store_uri, + data_catalog_config["Catalog"], + data_catalog_config["Database"], + data_catalog_config["TableName"], + ), + feature_group["EventTimeFeatureName"], + start_offset, + end_offset, + ) + + return self.load_from_date_partitioned_s3( + ParquetDataSource(offline_store_uri), start_offset, end_offset + ) + + def load_from_date_partitioned_s3( + self, + s3_data_source: ParquetDataSource, + input_start_offset: str, + input_end_offset: str, + ) -> DataFrame: + """Load the contents from a Feature Group's partitioned offline S3 as a DataFrame. + + Args: + s3_data_source (ParquetDataSource): + A data source that is based in S3. + input_start_offset (str): Start offset that is used to calculate the input start date. + input_end_offset (str): End offset that is used to calculate the input end date. + + Returns: + DataFrame: Contents of the data loaded from S3. + """ + + spark_session = self.spark_session_factory.spark_session + s3a_uri = s3_data_source.s3_uri.replace("s3://", "s3a://") + filter_condition = self._get_s3_partitions_offset_filter_condition( + input_start_offset, input_end_offset + ) + + logger.info( + "Loading data from %s with filtering condition %s.", + s3a_uri, + filter_condition, + ) + input_df = spark_session.read.parquet(s3a_uri) + if filter_condition: + input_df = input_df.filter(filter_condition) + + return input_df + + def load_from_s3(self, s3_data_source: Union[CSVDataSource, ParquetDataSource]) -> DataFrame: + """Load the contents from an S3 based data source as a DataFrame. + + Args: + s3_data_source (Union[CSVDataSource, ParquetDataSource]): + A data source that is based in S3. + + Raises: + ValueError: If an invalid DataSource is provided. + + Returns: + DataFrame: Contents of the data loaded from S3. + """ + spark_session = self.spark_session_factory.spark_session + s3a_uri = s3_data_source.s3_uri.replace("s3://", "s3a://") + + if isinstance(s3_data_source, CSVDataSource): + # TODO: Accept `schema` parameter. 
(Inferring schema requires a pass through every row) + logger.info("Loading data from %s.", s3a_uri) + return spark_session.read.csv( + s3a_uri, + header=s3_data_source.csv_header, + inferSchema=s3_data_source.csv_infer_schema, + ) + + if isinstance(s3_data_source, ParquetDataSource): + logger.info("Loading data from %s.", s3a_uri) + return spark_session.read.parquet(s3a_uri) + + raise ValueError("An invalid data source was provided.") + + def load_from_iceberg_table( + self, + iceberg_table_data_source: IcebergTableDataSource, + event_time_feature_name: str, + input_start_offset: str, + input_end_offset: str, + ) -> DataFrame: + """Load the contents from an Iceberg table as a DataFrame. + + Args: + iceberg_table_data_source (IcebergTableDataSource): An Iceberg Table source. + event_time_feature_name (str): Event time feature's name of feature group. + input_start_offset (str): Start offset that is used to calculate the input start date. + input_end_offset (str): End offset that is used to calculate the input end date. + + Returns: + DataFrame: Contents of the Iceberg Table as a Spark DataFrame. + """ + catalog = iceberg_table_data_source.catalog.lower() + database = iceberg_table_data_source.database.lower() + table = iceberg_table_data_source.table.lower() + iceberg_table = f"{catalog}.{database}.{table}" + + spark_session = self.spark_session_factory.get_spark_session_with_iceberg_config( + iceberg_table_data_source.warehouse_s3_uri, catalog + ) + + filter_condition = self._get_iceberg_offset_filter_condition( + event_time_feature_name, + input_start_offset, + input_end_offset, + ) + + iceberg_df = spark_session.table(iceberg_table) + + if filter_condition: + logger.info( + "The filter condition for iceberg feature group is %s.", + filter_condition, + ) + iceberg_df = iceberg_df.filter(filter_condition) + + return iceberg_df + + def _get_iceberg_offset_filter_condition( + self, + event_time_feature_name: str, + input_start_offset: str, + input_end_offset: str, + ): + """Load the contents from an Iceberg table as a DataFrame. + + Args: + iceberg_table_data_source (IcebergTableDataSource): An Iceberg Table source. + input_start_offset (str): Start offset that is used to calculate the input start date. + input_end_offset (str): End offset that is used to calculate the input end date. + + Returns: + DataFrame: Contents of the Iceberg Table as a Spark DataFrame. + """ + if input_start_offset is None and input_end_offset is None: + return None + + offset_parser = InputOffsetParser(self.environment_helper.get_job_scheduled_time()) + start_offset_time = offset_parser.get_iso_format_offset_date(input_start_offset) + end_offset_time = offset_parser.get_iso_format_offset_date(input_end_offset) + + start_condition = ( + f"{event_time_feature_name} >= '{start_offset_time}'" if input_start_offset else None + ) + end_condition = ( + f"{event_time_feature_name} < '{end_offset_time}'" if input_end_offset else None + ) + + conditions = filter(None, [start_condition, end_condition]) + return " AND ".join(conditions) + + def _get_s3_partitions_offset_filter_condition( + self, input_start_offset: str, input_end_offset: str + ) -> str: + """Get s3 partitions filter condition based on input offsets. + + Args: + input_start_offset (str): Start offset that is used to calculate the input start date. + input_end_offset (str): End offset that is used to calculate the input end date. + + Returns: + str: A SQL string that defines the condition of time range filter. 
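+
+        Example (illustrative; the dates depend on the job's scheduled time): for
+        input_start_offset="1 day" and no end offset, the generated condition has the shape
+
+            (year >= '2023') AND NOT ((year = '2023' AND month < '06')
+            OR (year = '2023' AND month = '06' AND day < '14')
+            OR (year = '2023' AND month = '06' AND day = '14' AND hour < '07'))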
+ """ + if input_start_offset is None and input_end_offset is None: + return None + + offset_parser = InputOffsetParser(self.environment_helper.get_job_scheduled_time()) + ( + start_year, + start_month, + start_day, + start_hour, + ) = offset_parser.get_offset_date_year_month_day_hour(input_start_offset) + ( + end_year, + end_month, + end_day, + end_hour, + ) = offset_parser.get_offset_date_year_month_day_hour(input_end_offset) + + # Include all records that event time is between start_year and end_year + start_year_include_condition = f"year >= '{start_year}'" if input_start_offset else None + end_year_include_condition = f"year <= '{end_year}'" if input_end_offset else None + year_include_condition = " AND ".join( + filter(None, [start_year_include_condition, end_year_include_condition]) + ) + + # Exclude all records that the event time is earlier than the start or later than the end + start_offset_exclude_condition = ( + f"(year = '{start_year}' AND month < '{start_month}') " + f"OR (year = '{start_year}' AND month = '{start_month}' AND day < '{start_day}') " + f"OR (year = '{start_year}' AND month = '{start_month}' AND day = '{start_day}' " + f"AND hour < '{start_hour}')" + if input_start_offset + else None + ) + end_offset_exclude_condition = ( + f"(year = '{end_year}' AND month > '{end_month}') " + f"OR (year = '{end_year}' AND month = '{end_month}' AND day > '{end_day}') " + f"OR (year = '{end_year}' AND month = '{end_month}' AND day = '{end_day}' " + f"AND hour >= '{end_hour}')" + if input_end_offset + else None + ) + offset_exclude_condition = " OR ".join( + filter(None, [start_offset_exclude_condition, end_offset_exclude_condition]) + ) + + filter_condition = f"({year_include_condition}) AND NOT ({offset_exclude_condition})" + + logger.info("The filter condition for hive feature group is %s.", filter_condition) + + return filter_condition + + def _parse_name_from_arn(self, fg_uri: str) -> str: + """Parse a Feature Group's name from an arn. + + Args: + fg_uri (str): a string identifier of the Feature Group. + + Returns: + str: the name of the feature group. + """ + match = re.match(FEATURE_GROUP_ARN_REGEX_PATTERN, fg_uri) + if match: + feature_group_name = match.group(4) + return feature_group_name + return fg_uri diff --git a/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_input_offset_parser.py b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_input_offset_parser.py new file mode 100644 index 0000000000..89d816af49 --- /dev/null +++ b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_input_offset_parser.py @@ -0,0 +1,129 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+"""Contains class that parse the input data start and end offset""" +from __future__ import absolute_import + +import re +from typing import Optional, Tuple, Union +from datetime import datetime, timezone +from dateutil.relativedelta import relativedelta +from sagemaker.mlops.feature_store.feature_processor._constants import ( + EXECUTION_TIME_PIPELINE_PARAMETER_FORMAT, +) + +UNIT_RE = r"(\d+?)\s+([a-z]+?)s?" +VALID_UNITS = ["hour", "day", "week", "month", "year"] + + +class InputOffsetParser: + """Contains methods to parse the input offset to different formats. + + Args: + now (datetime): + The point of time that the parser should calculate offset against. + """ + + def __init__(self, now: Union[datetime, str] = None) -> None: + if now is None: + self.now = datetime.now(timezone.utc) + elif isinstance(now, datetime): + self.now = now + else: + self.now = datetime.strptime(now, EXECUTION_TIME_PIPELINE_PARAMETER_FORMAT) + + def get_iso_format_offset_date(self, offset: Optional[str]) -> str: + """Get the iso format of target date based on offset diff. + + Args: + offset (Optional[str]): Offset that is used for target date calcluation. + + Returns: + str: ISO-8061 formatted string of the offset date. + """ + if offset is None: + return None + + offset_datetime = self.get_offset_datetime(offset) + return offset_datetime.strftime(EXECUTION_TIME_PIPELINE_PARAMETER_FORMAT) + + def get_offset_datetime(self, offset: Optional[str]) -> datetime: + """Get the datetime format of target date based on offset diff. + + Args: + offset (Optional[str]): Offset that is used for target date calcluation. + + Returns: + datetime: datetime instance of the offset date. + """ + if offset is None: + return None + + offset_td = InputOffsetParser.parse_offset_to_timedelta(offset) + + return self.now + offset_td + + def get_offset_date_year_month_day_hour( + self, offset: Optional[str] + ) -> Tuple[str, str, str, str]: + """Get the year, month, day and hour based on offset diff. + + Args: + offset (Optional[str]): Offset that is used for target date calcluation. + + Returns: + Tuple[str, str, str, str]: A tuple that consists of extracted year, month, day, hour from offset date. + """ + if offset is None: + return (None, None, None, None) + + offset_dt = self.get_offset_datetime(offset) + return ( + offset_dt.strftime("%Y"), + offset_dt.strftime("%m"), + offset_dt.strftime("%d"), + offset_dt.strftime("%H"), + ) + + @staticmethod + def parse_offset_to_timedelta(offset: Optional[str]) -> relativedelta: + """Parse the offset to time delta. + + Args: + offset (Optional[str]): Offset that is used for target date calcluation. + + Raises: + ValueError: If an offset is provided in a unrecognizable format. + ValueError: If an invalid offset unit is provided. + + Returns: + reletivedelta: Time delta representation of the time offset. + """ + if offset is None: + return None + + unit_match = re.fullmatch(UNIT_RE, offset) + + if not unit_match: + raise ValueError( + f"[{offset}] is not in a valid offset format. " + "Please pass a valid offset e.g '1 day'." + ) + + multiple, unit = unit_match.groups() + + if unit not in VALID_UNITS: + raise ValueError(f"[{unit}] is not a valid offset unit. 
Supported units: {VALID_UNITS}") + + shift_args = {f"{unit}s": -int(multiple)} + + return relativedelta(**shift_args) diff --git a/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_params_loader.py b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_params_loader.py new file mode 100644 index 0000000000..f5be546e86 --- /dev/null +++ b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_params_loader.py @@ -0,0 +1,83 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Contains classes for loading the 'params' argument for the UDF.""" +from __future__ import absolute_import + +from typing import Dict, Union + +import attr + +from sagemaker.mlops.feature_store.feature_processor._env import EnvironmentHelper +from sagemaker.mlops.feature_store.feature_processor._feature_processor_config import ( + FeatureProcessorConfig, +) + + +@attr.s +class SystemParamsLoader: + """Provides the fields for the params['system'] namespace. + + These are the parameters that the feature_processor automatically loads from various SageMaker + resources. + """ + + _SYSTEM_PARAMS_KEY = "system" + + environment_helper: EnvironmentHelper = attr.ib() + + def get_system_args(self) -> Dict[str, Union[str, Dict]]: + """Generates the system generated parameters for the feature_processor wrapped function. + + Args: + fp_config (FeatureProcessorConfig): The configuration values for the + feature_processor decorator. + + Returns: + Dict[str, Union[str, Dict]]: The system parameters. + """ + + return { + self._SYSTEM_PARAMS_KEY: { + "scheduled_time": self.environment_helper.get_job_scheduled_time(), + } + } + + +@attr.s +class ParamsLoader: + """Provides 'params' argument for the FeatureProcessor.""" + + _PARAMS_KEY = "params" + + system_parameters_arg_provider: SystemParamsLoader = attr.ib() + + def get_parameter_args( + self, + fp_config: FeatureProcessorConfig, + ) -> Dict[str, Union[str, Dict]]: + """Loads the 'params' argument for the FeatureProcessor. + + Args: + fp_config (FeatureProcessorConfig): The configuration values for the + feature_processor decorator. + + Returns: + Dict[str, Union[str, Dict]]: A dictionary that contains both user provided + parameters (feature_processor argument) and system parameters. + """ + return { + self._PARAMS_KEY: { + **(fp_config.parameters or {}), + **self.system_parameters_arg_provider.get_system_args(), + } + } diff --git a/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_spark_factory.py b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_spark_factory.py new file mode 100644 index 0000000000..0d6d41506e --- /dev/null +++ b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_spark_factory.py @@ -0,0 +1,202 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. 
A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Contains factory classes for instantiating Spark objects.""" +from __future__ import absolute_import + +from functools import lru_cache +from typing import List, Tuple, Dict + +import feature_store_pyspark +import feature_store_pyspark.FeatureStoreManager as fsm +from pyspark.conf import SparkConf +from pyspark.context import SparkContext +from pyspark.sql import SparkSession + +from sagemaker.mlops.feature_store.feature_processor._env import EnvironmentHelper + +SPARK_APP_NAME = "FeatureProcessor" + + +class SparkSessionFactory: + """Lazy loading, memoizing, instantiation of SparkSessions. + + Useful when you want to defer SparkSession instantiation and provide access to the same + instance throughout the application. + """ + + def __init__( + self, environment_helper: EnvironmentHelper, spark_config: Dict[str, str] = None + ) -> None: + """Initialize the SparkSessionFactory. + + Args: + environment_helper (EnvironmentHelper): A helper class to determine the current + execution. + spark_config (Dict[str, str]): The Spark configuration that will be passed to the + initialization of Spark session. + """ + self.environment_helper = environment_helper + self.spark_config = spark_config + + @property + @lru_cache() + def spark_session(self) -> SparkSession: + """Instantiate a new SparkSession or return the existing one.""" + is_training_job = self.environment_helper.is_training_job() + instance_count = self.environment_helper.get_instance_count() + + spark_configs = self._get_spark_configs(is_training_job) + spark_conf = SparkConf().setAll(spark_configs).setAppName(SPARK_APP_NAME) + + if instance_count == 1: + spark_conf.setMaster("local[*]") + + sc = SparkContext.getOrCreate(conf=spark_conf) + + jsc = sc._jsc # Java Spark Context (JVM SparkContext) + for cfg in self._get_jsc_hadoop_configs(): + jsc.hadoopConfiguration().set(cfg[0], cfg[1]) + + return SparkSession(sparkContext=sc) + + def _get_spark_configs(self, is_training_job) -> List[Tuple[str, str]]: + """Generate Spark Configurations optimized for feature_processing functionality. + + Args: + is_training_job (bool): a boolean indicating whether the current execution environment + is a training job or not. + + Returns: + List[Tuple[str, str]]: Spark configurations. 
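+
+        Example entries (taken from the defaults below):
+
+            ("spark.hadoop.mapreduce.fileoutputcommitter.algorithm.version", "2")
+            ("spark.sql.parquet.filterPushdown", "true")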
+ """ + spark_configs = [ + ( + "spark.hadoop.fs.s3a.aws.credentials.provider", + ",".join( + [ + "com.amazonaws.auth.ContainerCredentialsProvider", + "com.amazonaws.auth.profile.ProfileCredentialsProvider", + "com.amazonaws.auth.DefaultAWSCredentialsProviderChain", + ] + ), + ), + # spark-3.3.1#recommended-settings-for-writing-to-object-stores - https://tinyurl.com/54rkhef6 + ("spark.hadoop.mapreduce.fileoutputcommitter.algorithm.version", "2"), + ( + "spark.hadoop.mapreduce.fileoutputcommitter.cleanup-failures.ignored", + "true", + ), + ("spark.hadoop.parquet.enable.summary-metadata", "false"), + # spark-3.3.1#parquet-io-settings https://tinyurl.com/59a7uhwu + ("spark.sql.parquet.mergeSchema", "false"), + ("spark.sql.parquet.filterPushdown", "true"), + ("spark.sql.hive.metastorePartitionPruning", "true"), + # hadoop-aws#performance - https://tinyurl.com/mutxj96f + ("spark.hadoop.fs.s3a.threads.max", "500"), + ("spark.hadoop.fs.s3a.connection.maximum", "500"), + ("spark.hadoop.fs.s3a.experimental.input.fadvise", "normal"), + ("spark.hadoop.fs.s3a.block.size", "128M"), + ("spark.hadoop.fs.s3a.fast.upload.buffer", "disk"), + ("spark.hadoop.fs.trash.interval", "0"), + ("spark.port.maxRetries", "50"), + ] + + if self.spark_config: + spark_configs.extend(self.spark_config.items()) + + if not is_training_job: + fp_spark_jars = feature_store_pyspark.classpath_jars() + fp_spark_packages = [ + "org.apache.hadoop:hadoop-aws:3.3.1", + "org.apache.hadoop:hadoop-common:3.3.1", + ] + + if self.spark_config and "spark.jars" in self.spark_config: + fp_spark_jars.append(self.spark_config.get("spark.jars")) + + if self.spark_config and "spark.jars.packages" in self.spark_config: + fp_spark_packages.append(self.spark_config.get("spark.jars.packages")) + + spark_configs.extend( + ( + ("spark.jars", ",".join(fp_spark_jars)), + ( + "spark.jars.packages", + ",".join(fp_spark_packages), + ), + ) + ) + + return spark_configs + + def _get_jsc_hadoop_configs(self) -> List[Tuple[str, str]]: + """JVM SparkContext Hadoop configurations.""" + # Skip generation of _SUCCESS files to speed up writes. + return [("mapreduce.fileoutputcommitter.marksuccessfuljobs", "false")] + + def _get_iceberg_configs(self, warehouse_s3_uri: str, catalog: str) -> List[Tuple[str, str]]: + """Spark configurations for reading and writing data from Iceberg Table sources. + + Args: + warehouse_s3_uri (str): The S3 URI of the warehouse. + catalog (str): The catalog. + + Returns: + List[Tuple[str, str]]: the Spark configurations. + """ + catalog = catalog.lower() + return [ + (f"spark.sql.catalog.{catalog}", "smfs.shaded.org.apache.iceberg.spark.SparkCatalog"), + (f"spark.sql.catalog.{catalog}.warehouse", warehouse_s3_uri), + ( + f"spark.sql.catalog.{catalog}.catalog-impl", + "smfs.shaded.org.apache.iceberg.aws.glue.GlueCatalog", + ), + ( + f"spark.sql.catalog.{catalog}.io-impl", + "smfs.shaded.org.apache.iceberg.aws.s3.S3FileIO", + ), + (f"spark.sql.catalog.{catalog}.glue.skip-name-validation", "true"), + ] + + def get_spark_session_with_iceberg_config(self, warehouse_s3_uri, catalog) -> SparkSession: + """Get an instance of the SparkSession with Iceberg settings configured. + + Args: + warehouse_s3_uri (str): The S3 URI of the warehouse. + catalog (str): The catalog. + + Returns: + SparkSession: A SparkSession ready to support reading and writing data from an Iceberg + Table. 
+        """
+        conf = self.spark_session._jvm.SparkSession().conf()
+
+        for cfg in self._get_iceberg_configs(warehouse_s3_uri, catalog):
+            conf.set(cfg[0], cfg[1])
+
+        return self.spark_session
+
+
+class FeatureStoreManagerFactory:
+    """Lazy loading, memoizing, instantiation of FeatureStoreManagers.
+
+    Useful when you want to defer FeatureStoreManagers instantiation and provide access to the same
+    instance throughout the application.
+    """
+
+    @property
+    @lru_cache()
+    def feature_store_manager(self) -> fsm.FeatureStoreManager:
+        """Instantiate a new FeatureStoreManager."""
+        return fsm.FeatureStoreManager()
diff --git a/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_udf_arg_provider.py b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_udf_arg_provider.py
new file mode 100644
index 0000000000..bd21e804eb
--- /dev/null
+++ b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_udf_arg_provider.py
@@ -0,0 +1,239 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+"""Contains classes for loading arguments for the parameters defined in the UDF."""
+from __future__ import absolute_import
+
+from abc import ABC, abstractmethod
+from inspect import signature
+from typing import Any, Callable, Dict, Generic, List, OrderedDict, TypeVar, Union, Optional
+
+import attr
+from pyspark.sql import DataFrame, SparkSession
+
+from sagemaker.mlops.feature_store.feature_processor._data_source import (
+    CSVDataSource,
+    FeatureGroupDataSource,
+    ParquetDataSource,
+    BaseDataSource,
+    PySparkDataSource,
+)
+from sagemaker.mlops.feature_store.feature_processor._feature_processor_config import (
+    FeatureProcessorConfig,
+)
+from sagemaker.mlops.feature_store.feature_processor._input_loader import (
+    SparkDataFrameInputLoader,
+)
+from sagemaker.mlops.feature_store.feature_processor._params_loader import ParamsLoader
+from sagemaker.mlops.feature_store.feature_processor._spark_factory import SparkSessionFactory
+
+T = TypeVar("T")
+
+
+@attr.s
+class UDFArgProvider(Generic[T], ABC):
+    """Base class for argument providers for the UDF.
+
+    Args:
+        Generic (T): The type of the auto-loaded data values.
+    """
+
+    @abstractmethod
+    def provide_input_args(
+        self, udf: Callable[..., T], fp_config: FeatureProcessorConfig
+    ) -> OrderedDict[str, T]:
+        """Provides a dict of (input name, auto-loaded data) using the feature_processor parameters.
+
+        The input name is the udf's parameter name, and the data source is the one defined at the
+        same index (as the input name) in fp_config.inputs.
+
+        Args:
+            udf (Callable[..., T]): The feature_processor wrapped user function.
+            fp_config (FeatureProcessorConfig): The configuration for the feature_processor.
+
+        Returns:
+            OrderedDict[str, T]: The loaded data sources, in the same order as fp_config.inputs.
+        """
+
+    @abstractmethod
+    def provide_params_arg(
+        self, udf: Callable[..., T], fp_config: FeatureProcessorConfig
+    ) -> Dict[str, Dict]:
+        """Provides the 'params' argument that is provided to the UDF.
+
+        Args:
+            udf (Callable[..., T]): The feature_processor wrapped user function.
+            fp_config (FeatureProcessorConfig): The configuration for the feature_processor.
+
+        Returns:
+            Dict[str, Dict]: A combination of user defined parameters (in fp_config) and system
+                provided parameters.
+        """
+
+    @abstractmethod
+    def provide_additional_kwargs(self, udf: Callable[..., T]) -> Dict[str, Any]:
+        """Provides any additional arguments to be provided to the UDF, dependent on the mode.
+
+        Args:
+            udf (Callable[..., T]): The feature_processor wrapped user function.
+
+        Returns:
+            Dict[str, Any]: additional kwargs for the user function.
+        """
+
+
+@attr.s
+class SparkArgProvider(UDFArgProvider[DataFrame]):
+    """Provides arguments to Spark UDFs."""
+
+    PARAMS_ARG_NAME = "params"
+    SPARK_SESSION_ARG_NAME = "spark"
+
+    params_loader: ParamsLoader = attr.ib()
+    input_loader: SparkDataFrameInputLoader = attr.ib()
+    spark_session_factory: SparkSessionFactory = attr.ib()
+
+    def provide_input_args(
+        self, udf: Callable[..., DataFrame], fp_config: FeatureProcessorConfig
+    ) -> OrderedDict[str, DataFrame]:
+        """Provide a DataFrame for each requested input.
+
+        Args:
+            udf (Callable[..., T]): The feature_processor wrapped user function.
+            fp_config (FeatureProcessorConfig): The configuration for the feature_processor.
+
+        Raises:
+            ValueError: If the signature of the UDF does not match the fp_config.inputs.
+            ValueError: If there are no inputs provided to the user defined function.
+
+        Returns:
+            OrderedDict[str, DataFrame]: The loaded data sources, in the same order as
+                fp_config.inputs.
+        """
+        udf_parameter_names = list(signature(udf).parameters.keys())
+        udf_input_names = self._get_input_parameters(udf_parameter_names)
+        udf_params = self.params_loader.get_parameter_args(fp_config).get(
+            self.PARAMS_ARG_NAME, None
+        )
+
+        if len(udf_input_names) == 0:
+            raise ValueError("Expected at least one input to the user defined function.")
+
+        if len(udf_input_names) != len(fp_config.inputs):
+            raise ValueError(
+                f"The signature of the user defined function does not match the list of inputs"
+                f" requested. Expected {len(fp_config.inputs)} parameter(s)."
+            )
+
+        return OrderedDict(
+            (input_name, self._load_data_frame(data_source=input_uri, params=udf_params))
+            for (input_name, input_uri) in zip(udf_input_names, fp_config.inputs)
+        )
+
+    def provide_params_arg(
+        self, udf: Callable[..., DataFrame], fp_config: FeatureProcessorConfig
+    ) -> Dict[str, Union[str, Dict]]:
+        """Provide params for the UDF, if the udf has a parameter named 'params'.
+
+        Args:
+            udf (Callable[..., T]): the feature_processor wrapped user function.
+            fp_config (FeatureProcessorConfig): The configuration for the feature_processor.
+        """
+        return (
+            self.params_loader.get_parameter_args(fp_config)
+            if self._has_param(udf, self.PARAMS_ARG_NAME)
+            else {}
+        )
+
+    def provide_additional_kwargs(self, udf: Callable[..., DataFrame]) -> Dict[str, SparkSession]:
+        """Provide the Spark session, if the udf has a parameter named 'spark'.
+
+        Args:
+            udf (Callable[..., T]): the feature_processor wrapped user function.
+        """
+        return (
+            {self.SPARK_SESSION_ARG_NAME: self.spark_session_factory.spark_session}
+            if self._has_param(udf, self.SPARK_SESSION_ARG_NAME)
+            else {}
+        )
+
+    def _get_input_parameters(self, udf_parameter_names: List[str]) -> List[str]:
+        """Parses the parameter names from the UDF that correspond to the input data sources.
+ + This function assumes that the udf signature's `params` and `spark` parameters are at the + end, in any order, if provided. + + Args: + udf_parameter_names (List[str]): The full list of parameters names in the UDF. + + Returns: + List[str]: A subset of parameter names corresponding to the input data sources. + """ + inputs_end_index = len(udf_parameter_names) - 1 + + # Reduce range based on the position of optional kwargs of the UDF. + if self.PARAMS_ARG_NAME in udf_parameter_names: + inputs_end_index = udf_parameter_names.index(self.PARAMS_ARG_NAME) - 1 + + if self.SPARK_SESSION_ARG_NAME in udf_parameter_names: + inputs_end_index = min( + inputs_end_index, + udf_parameter_names.index(self.SPARK_SESSION_ARG_NAME) - 1, + ) + + return udf_parameter_names[: inputs_end_index + 1] + + def _load_data_frame( + self, + data_source: Union[ + FeatureGroupDataSource, CSVDataSource, ParquetDataSource, BaseDataSource + ], + params: Optional[Dict[str, Union[str, Dict]]] = None, + ) -> DataFrame: + """Given a data source definition, load the data as a Spark DataFrame. + + Args: + data_source (Union[FeatureGroupDataSource, CSVDataSource, ParquetDataSource, + BaseDataSource]): A user specified data source from the feature_processor + decorator's parameters. + params (Optional[Dict[str, Union[str, Dict]]]): Parameters provided to the + feature_processor decorator. + + Returns: + DataFrame: The contents of the data source as a Spark DataFrame. + """ + if isinstance(data_source, (CSVDataSource, ParquetDataSource)): + return self.input_loader.load_from_s3(data_source) + + if isinstance(data_source, FeatureGroupDataSource): + return self.input_loader.load_from_feature_group(data_source) + + if isinstance(data_source, PySparkDataSource): + spark_session = self.spark_session_factory.spark_session + return data_source.read_data(spark=spark_session, params=params) + + if isinstance(data_source, BaseDataSource): + return data_source.read_data(params=params) + + raise ValueError(f"Unknown data source type: {type(data_source)}") + + def _has_param(self, udf: Callable, name: str) -> bool: + """Determine if a function has a parameter with a given name. + + Args: + udf (Callable): the user defined function. + name (str): the name of the parameter. + + Returns: + bool: True if the udf contains a parameter with the name. + """ + return name in list(signature(udf).parameters.keys()) diff --git a/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_udf_output_receiver.py b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_udf_output_receiver.py new file mode 100644 index 0000000000..a037e837c2 --- /dev/null +++ b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_udf_output_receiver.py @@ -0,0 +1,98 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+"""Contains classes for handling UDF outputs""" +from __future__ import absolute_import + +import logging +from abc import ABC, abstractmethod +from typing import Generic, TypeVar + +import attr +from py4j.protocol import Py4JJavaError +from pyspark.sql import DataFrame + +from sagemaker.mlops.feature_store.feature_processor import IngestionError +from sagemaker.mlops.feature_store.feature_processor._feature_processor_config import ( + FeatureProcessorConfig, +) +from sagemaker.mlops.feature_store.feature_processor._spark_factory import ( + FeatureStoreManagerFactory, +) + +T = TypeVar("T") + +logger = logging.getLogger("sagemaker") + + +class UDFOutputReceiver(Generic[T], ABC): + """Base class for handling outputs of the UDF.""" + + @abstractmethod + def ingest_udf_output(self, output: T, fp_config: FeatureProcessorConfig) -> None: + """Ingests data to the output feature group. + + Args: + output (T): The output of the feature_processor wrapped function. + fp_config (FeatureProcessorConfig): The configuration for the feature_processor. + """ + + +@attr.s +class SparkOutputReceiver(UDFOutputReceiver[DataFrame]): + """Handles the Spark DataFrame the output from the UDF""" + + feature_store_manager_factory: FeatureStoreManagerFactory = attr.ib() + + def ingest_udf_output(self, output: DataFrame, fp_config: FeatureProcessorConfig) -> None: + """Ingests UDF to the output Feature Group. + + Args: + output (T): The output of the feature_processor wrapped function. + fp_config (FeatureProcessorConfig): The configuration for the feature_processor. + + Raises: + Py4JError: If there is a problem with Py4J, including client code errors. + IngestionError: If any rows are not ingested successfully then a sample of the records, + with failure reasons, is logged. + """ + if fp_config.enable_ingestion is False: + logging.info("Ingestion is disabled. Skipping ingestion.") + return + + logger.info( + "Ingesting transformed data to %s with target_stores: %s", + fp_config.output, + fp_config.target_stores, + ) + + feature_store_manager = self.feature_store_manager_factory.feature_store_manager + try: + feature_store_manager.ingest_data( + input_data_frame=output, + feature_group_arn=fp_config.output, + target_stores=fp_config.target_stores, + ) + except Py4JJavaError as e: + if e.java_exception.getClass().getSimpleName() == "StreamIngestionFailureException": + logger.warning( + "Ingestion did not complete successfully. Failed records and error messages" + " have been printed to the console." + ) + feature_store_manager.get_failed_stream_ingestion_data_frame().show( + n=20, truncate=False + ) + raise IngestionError(e.java_exception) + + raise e + + logger.info("Ingestion to %s complete.", fp_config.output) diff --git a/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_udf_wrapper.py b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_udf_wrapper.py new file mode 100644 index 0000000000..95b07de7c1 --- /dev/null +++ b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_udf_wrapper.py @@ -0,0 +1,88 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. 
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+"""This module provides a wrapper for user provided functions."""
+from __future__ import absolute_import
+
+import functools
+from typing import Any, Callable, Dict, Generic, Tuple, TypeVar
+
+import attr
+
+from sagemaker.mlops.feature_store.feature_processor._feature_processor_config import (
+    FeatureProcessorConfig,
+)
+from sagemaker.mlops.feature_store.feature_processor._udf_arg_provider import UDFArgProvider
+from sagemaker.mlops.feature_store.feature_processor._udf_output_receiver import (
+    UDFOutputReceiver,
+)
+
+T = TypeVar("T")
+
+
+@attr.s
+class UDFWrapper(Generic[T]):
+    """Class that wraps a user provided function."""
+
+    udf_arg_provider: UDFArgProvider[T] = attr.ib()
+    udf_output_receiver: UDFOutputReceiver[T] = attr.ib()
+
+    def wrap(self, udf: Callable[..., T], fp_config: FeatureProcessorConfig) -> Callable[..., None]:
+        """Wrap the provided UDF with the logic defined by the FeatureProcessorConfig.
+
+        General functionality of the wrapper function includes but is not limited to loading data
+        sources and ingesting output data to a Feature Group.
+
+        Args:
+            udf (Callable[..., T]): The feature_processor wrapped user function.
+            fp_config (FeatureProcessorConfig): The configuration for the feature_processor.
+
+        Returns:
+            Callable[..., None]: the user provided function wrapped with feature_processor logic.
+        """
+
+        @functools.wraps(udf)
+        def wrapper() -> None:
+            udf_args, udf_kwargs = self._prepare_udf_args(
+                udf=udf,
+                fp_config=fp_config,
+            )
+
+            output = udf(*udf_args, **udf_kwargs)
+
+            self.udf_output_receiver.ingest_udf_output(output, fp_config)
+
+        return wrapper
+
+    def _prepare_udf_args(
+        self,
+        udf: Callable[..., T],
+        fp_config: FeatureProcessorConfig,
+    ) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
+        """Generate the arguments for the user defined function, provided by the wrapper function.
+
+        Args:
+            udf (Callable[..., T]): The feature_processor wrapped user function.
+            fp_config (FeatureProcessorConfig): The configuration for the feature_processor.
+
+        Returns:
+            Tuple[Tuple[Any, ...], Dict[str, Any]]: A tuple of positional arguments and keyword
+                arguments for the UDF.
+        """
+        args = ()
+        kwargs = {
+            **self.udf_arg_provider.provide_input_args(udf, fp_config),
+            **self.udf_arg_provider.provide_params_arg(udf, fp_config),
+            **self.udf_arg_provider.provide_additional_kwargs(udf),
+        }
+
+        return (args, kwargs)
diff --git a/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_validation.py b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_validation.py
new file mode 100644
index 0000000000..307838be0c
--- /dev/null
+++ b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_validation.py
@@ -0,0 +1,210 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+"""Module that contains validators and a validation chain""" +from __future__ import absolute_import + +import inspect +import re +from abc import ABC, abstractmethod +from typing import Any, Callable, List + +import attr + +from sagemaker.mlops.feature_store.feature_processor._data_source import ( + FeatureGroupDataSource, + BaseDataSource, +) +from sagemaker.mlops.feature_store.feature_processor._feature_processor_config import ( + FeatureProcessorConfig, +) +from sagemaker.mlops.feature_store.feature_processor._input_offset_parser import ( + InputOffsetParser, +) + + +@attr.s +class Validator(ABC): + """Base class for all validators. Errors are raised if validation fails.""" + + @abstractmethod + def validate(self, udf: Callable[..., Any], fp_config: FeatureProcessorConfig) -> None: + """Validates FeatureProcessorConfig and a UDF.""" + + +@attr.s +class ValidatorChain: + """Executes a series of validators.""" + + validators: List[Validator] = attr.ib() + + def validate(self, udf: Callable[..., Any], fp_config: FeatureProcessorConfig) -> None: + """Validates a value using the list of validators. + + Args: + udf (Callable[..., T]): The feature_processor wrapped user function. + fp_config (FeatureProcessorConfig): The configuration for the feature_processor. + + Raises: + ValueError: If there are any validation errors raised by the validators in this chain. + """ + for validator in self.validators: + validator.validate(udf, fp_config) + + +class FeatureProcessorArgValidator(Validator): + """A validator for arguments provided to FeatureProcessor.""" + + def validate(self, udf: Callable[..., Any], fp_config: FeatureProcessorConfig) -> None: + """Temporary validator for unsupported feature_processor parameters. + + Args: + udf (Callable[..., T]): The feature_processor wrapped user function. + fp_config (FeatureProcessorConfig): The configuration for the feature_processor. + """ + # TODO: Validate target_stores values. + + +class InputValidator(Validator): + """A validator for the 'input' parameter.""" + + def validate(self, udf: Callable[..., Any], fp_config: FeatureProcessorConfig) -> None: + """Validate the arguments provided to the decorator's input parameter. + + Args: + udf (Callable[..., T]): The feature_processor wrapped user function. + fp_config (FeatureProcessorConfig): The configuration for the feature_processor. + + Raises: + ValueError: If no inputs are provided. + """ + + inputs = fp_config.inputs + if inputs is None or len(inputs) == 0: + raise ValueError("At least one input is required.") + + +class SparkUDFSignatureValidator(Validator): + """A validator for PySpark UDF signatures.""" + + def validate(self, udf: Callable[..., Any], fp_config: FeatureProcessorConfig) -> None: + """Validate the signature of the UDF based on the configurations provided to the decorator. + + Args: + udf (Callable[..., T]): The feature_processor wrapped user function. + fp_config (FeatureProcessorConfig): The configuration for the feature_processor. + + Raises (ValueError): raises ValueError when any of the following scenario happen: + 1. No input provided to feature_processor. + 2. Number of provided parameters does not match with that of provided inputs. + 3. Required parameters are not provided in the right order. 
+        """
+        parameters = list(inspect.signature(udf).parameters.keys())
+        input_parameters = self._get_input_params(udf)
+        if len(input_parameters) < 1:
+            raise ValueError("feature_processor expects at least 1 input parameter.")
+
+        # Validate count of input parameters against requested inputs.
+        num_data_sources = len(fp_config.inputs)
+        if len(input_parameters) != num_data_sources:
+            raise ValueError(
+                f"feature_processor expected a function with ({num_data_sources}) parameter(s)"
+                f" before any optional 'params' or 'spark' parameters for the ({num_data_sources})"
+                f" requested data source(s)."
+            )
+
+        # Validate position of non-input parameters.
+        if "params" in parameters and parameters[-1] != "params" and parameters[-2] != "params":
+            raise ValueError(
+                "feature_processor expected the 'params' parameter to be the last or second last"
+                " parameter after input parameters."
+            )
+
+        if "spark" in parameters and parameters[-1] != "spark" and parameters[-2] != "spark":
+            raise ValueError(
+                "feature_processor expected the 'spark' parameter to be the last or second last"
+                " parameter after input parameters."
+            )
+
+    def _get_input_params(self, udf: Callable[..., Any]) -> List[str]:
+        """Get the parameters that correspond to the inputs for a UDF.
+
+        Args:
+            udf (Callable[..., Any]): the user provided function.
+        """
+        parameters = list(inspect.signature(udf).parameters.keys())
+
+        # Remove non-input parameter names.
+        if "params" in parameters:
+            parameters.remove("params")
+        if "spark" in parameters:
+            parameters.remove("spark")
+
+        return parameters
+
+
+class InputOffsetValidator(Validator):
+    """A validator for input offsets."""
+
+    def validate(self, udf: Callable[..., Any], fp_config: FeatureProcessorConfig) -> None:
+        """Validate the start and end input offset provided to the decorator.
+
+        Args:
+            udf (Callable[..., T]): The feature_processor wrapped user function.
+            fp_config (FeatureProcessorConfig): The configuration for the feature_processor.
+
+        Raises (ValueError): raises ValueError when input_start_offset is later than
+            input_end_offset.
+        """
+
+        for config_input in fp_config.inputs:
+            if isinstance(config_input, FeatureGroupDataSource):
+                input_start_offset = config_input.input_start_offset
+                input_end_offset = config_input.input_end_offset
+                start_td = InputOffsetParser.parse_offset_to_timedelta(input_start_offset)
+                end_td = InputOffsetParser.parse_offset_to_timedelta(input_end_offset)
+                if start_td and end_td and start_td > end_td:
+                    raise ValueError("input_start_offset should always be before input_end_offset.")
+
+
+class BaseDataSourceValidator(Validator):
+    """A validator for BaseDataSource."""
+
+    def validate(self, udf: Callable[..., Any], fp_config: FeatureProcessorConfig) -> None:
+        """Validate the BaseDataSource provided to the decorator.
+
+        Args:
+            udf (Callable[..., T]): The feature_processor wrapped user function.
+            fp_config (FeatureProcessorConfig): The configuration for the feature_processor.
+
+        Raises (ValueError): raises ValueError when data_source_unique_id or data_source_name
+            of the input data source is not valid.
+        """
+
+        for config_input in fp_config.inputs:
+            if isinstance(config_input, BaseDataSource):
+                source_name = config_input.data_source_name
+                source_id = config_input.data_source_unique_id
+
+                source_name_pattern = r"^[a-zA-Z0-9](-*[a-zA-Z0-9]){0,119}$"
+                source_id_pattern = r"^.{1,2048}$"
+
+                if not re.match(source_name_pattern, source_name):
+                    raise ValueError(
+                        f"data_source_name of input does not match pattern '{source_name_pattern}'."
+                    )
+
+                if not re.match(source_id_pattern, source_id):
+                    raise ValueError(
+                        f"data_source_unique_id of input does not match "
+                        f"pattern '{source_id_pattern}'."
+                    )
diff --git a/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/feature_processor.py b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/feature_processor.py
new file mode 100644
index 0000000000..31593a3f1c
--- /dev/null
+++ b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/feature_processor.py
@@ -0,0 +1,129 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+"""Feature Processor decorator for feature transformation functions."""
+from __future__ import absolute_import
+
+from typing import Any, Callable, Dict, List, Optional, Sequence, Union
+
+from sagemaker.mlops.feature_store.feature_processor import (
+    CSVDataSource,
+    FeatureGroupDataSource,
+    ParquetDataSource,
+    BaseDataSource,
+)
+from sagemaker.mlops.feature_store.feature_processor._enums import FeatureProcessorMode
+from sagemaker.mlops.feature_store.feature_processor._factory import (
+    UDFWrapperFactory,
+    ValidatorFactory,
+)
+from sagemaker.mlops.feature_store.feature_processor._feature_processor_config import (
+    FeatureProcessorConfig,
+)
+
+
+def feature_processor(
+    inputs: Sequence[
+        Union[FeatureGroupDataSource, CSVDataSource, ParquetDataSource, BaseDataSource]
+    ],
+    output: str,
+    target_stores: Optional[List[str]] = None,
+    parameters: Optional[Dict[str, Union[str, Dict]]] = None,
+    enable_ingestion: bool = True,
+    spark_config: Dict[str, str] = None,
+) -> Callable:
+    """Decorator to facilitate feature engineering for Feature Groups.
+
+    If the decorated function is executed without arguments, then the decorated function's
+    arguments are automatically loaded from the input data sources. Outputs are ingested to the
+    output Feature Group. If arguments are provided to this function, then arguments are not
+    automatically loaded (for testing).
+
+    Decorated functions must conform to the expected signature. Parameters: one parameter of type
+    pyspark.sql.DataFrame for each DataSource in 'inputs'; followed by the optional parameters with
+    names and types in [params: Dict[str, Any], spark: SparkSession]. Outputs: a single return
+    value of type pyspark.sql.DataFrame. The function can have any name.
+
+    **Example:**
+
+    .. code-block:: python
+
+        @feature_processor(
+            inputs=[FeatureGroupDataSource("input-fg"), CSVDataSource("s3://bucket/prefix")],
+            output='arn:aws:sagemaker:us-west-2:123456789012:feature-group/output-fg'
+        )
+        def transform(
+            input_feature_group: DataFrame, input_csv: DataFrame, params: Dict[str, Any],
+            spark: SparkSession
+        ) -> DataFrame:
+            return ...
+
+    **More concisely:**
+
+    .. code-block:: python
+
+        @feature_processor(
+            inputs=[FeatureGroupDataSource("input-fg"), CSVDataSource("s3://bucket/prefix")],
+            output='arn:aws:sagemaker:us-west-2:123456789012:feature-group/output-fg'
+        )
+        def transform(input_feature_group, input_csv):
+            return ...
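+
+    **With parameters** (an illustrative sketch; the parameter name
+    ``lookback_days`` is hypothetical). User-defined values passed via the
+    ``parameters`` argument are surfaced through the optional ``params``
+    argument, alongside system-provided values under the ``system`` key:
+
+    .. code-block:: python
+
+        @feature_processor(
+            inputs=[FeatureGroupDataSource("input-fg")],
+            output='arn:aws:sagemaker:us-west-2:123456789012:feature-group/output-fg',
+            parameters={"lookback_days": "7"},
+        )
+        def transform(input_feature_group, params):
+            lookback = params["lookback_days"]  # user-defined parameter
+            scheduled = params["system"]["scheduled_time"]  # system-provided parameter
+            return input_feature_group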
+
+    Args:
+        inputs (Sequence[Union[FeatureGroupDataSource, CSVDataSource, ParquetDataSource,\
+            BaseDataSource]]): A list of data sources.
+        output (str): A Feature Group ARN to write results of this function to.
+        target_stores (Optional[list[str]], optional): A list containing at least one of
+            'OnlineStore' or 'OfflineStore'. If unspecified, data will be ingested to the enabled
+            stores of the output feature group. Defaults to None.
+        parameters (Optional[Dict[str, Union[str, Dict]]], optional): Parameters to be provided to
+            the decorated function, available as the 'params' argument. Useful for parameterized
+            functions. The params argument also contains the set of system provided parameters
+            under the key 'system'. E.g. 'scheduled_time': a timestamp representing the time that
+            the execution was scheduled to run, if triggered by a Scheduler; otherwise, the
+            current time.
+        enable_ingestion (bool, optional): A boolean indicating whether the decorated function's
+            return value is ingested to the 'output' Feature Group. This flag is useful during the
+            development phase to ensure that data is not used until the function is ready. It is
+            also useful for users that want to manage their own data ingestion. Defaults to True.
+        spark_config (Dict[str, str]): A dict that contains the key-value pairs for Spark
+            configurations.
+
+    Raises:
+        IngestionError: If any rows are not ingested successfully then a sample of the records,
+            with failure reasons, is logged.
+
+    Returns:
+        Callable: The decorated function.
+    """
+
+    def decorator(udf: Callable[..., Any]) -> Callable:
+        fp_config = FeatureProcessorConfig.create(
+            inputs=inputs,
+            output=output,
+            mode=FeatureProcessorMode.PYSPARK,
+            target_stores=target_stores,
+            parameters=parameters,
+            enable_ingestion=enable_ingestion,
+            spark_config=spark_config,
+        )
+
+        validator_chain = ValidatorFactory.get_validation_chain(fp_config)
+        udf_wrapper = UDFWrapperFactory.get_udf_wrapper(fp_config)
+
+        validator_chain.validate(udf=udf, fp_config=fp_config)
+        wrapped_function = udf_wrapper.wrap(udf=udf, fp_config=fp_config)
+
+        wrapped_function.feature_processor_config = fp_config
+
+        return wrapped_function
+
+    return decorator
diff --git a/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/feature_scheduler.py b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/feature_scheduler.py
new file mode 100644
index 0000000000..c9039d982c
--- /dev/null
+++ b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/feature_scheduler.py
@@ -0,0 +1,1100 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+"""Feature Processor schedule APIs.""" +from __future__ import absolute_import +import logging +import json +import re +from datetime import datetime +from typing import Callable, List, Optional, Dict, Sequence, Union, Any, Tuple + +import pytz +from botocore.exceptions import ClientError + +from sagemaker.mlops.feature_store.feature_processor._config_uploader import ConfigUploader +from sagemaker.mlops.feature_store.feature_processor._enums import FeatureProcessorMode +from sagemaker.mlops.feature_store.feature_processor._event_bridge_rule_helper import ( + EventBridgeRuleHelper, +) +from sagemaker.mlops.feature_store.feature_processor._feature_processor_pipeline_events import ( + FeatureProcessorPipelineEvents, +) + +# pylint: disable=C0301 +from sagemaker.mlops.feature_store.feature_processor.lineage._feature_processor_lineage_name_helper import ( + _get_feature_group_lineage_context_name, + _get_feature_group_pipeline_lineage_context_name, + _get_feature_group_pipeline_version_lineage_context_name, + _get_feature_processor_pipeline_lineage_context_name, + _get_feature_processor_pipeline_version_lineage_context_name, +) +from sagemaker.core.lineage import context +from sagemaker.core.lineage._utils import get_resource_name_from_arn +from sagemaker.core.remote_function.runtime_environment.runtime_environment_manager import ( + RuntimeEnvironmentManager, +) + +from sagemaker.core.remote_function.spark_config import SparkConfig +from sagemaker.core.network import SUBNETS_KEY, SECURITY_GROUP_IDS_KEY + +from sagemaker.mlops.feature_store.feature_processor._constants import ( + EXECUTION_TIME_PIPELINE_PARAMETER, + EXECUTION_TIME_PIPELINE_PARAMETER_FORMAT, + RESOURCE_NOT_FOUND_EXCEPTION, + SPARK_JAR_FILES_PATH, + SPARK_PY_FILES_PATH, + SPARK_FILES_PATH, + FEATURE_PROCESSOR_TAG_KEY, + FEATURE_PROCESSOR_TAG_VALUE, + PIPELINE_CONTEXT_TYPE, + DEFAULT_SCHEDULE_STATE, + SCHEDULED_TIME_PIPELINE_PARAMETER, + PIPELINE_CONTEXT_NAME_TAG_KEY, + PIPELINE_VERSION_CONTEXT_NAME_TAG_KEY, + PIPELINE_NAME_MAXIMUM_LENGTH, + RESOURCE_NOT_FOUND, + FEATURE_GROUP_ARN_REGEX_PATTERN, + TO_PIPELINE_RESERVED_TAG_KEYS, + DEFAULT_TRIGGER_STATE, + EVENTBRIDGE_RULE_ARN_REGEX_PATTERN, +) +from sagemaker.mlops.feature_store.feature_processor._feature_processor_config import ( + FeatureProcessorConfig, +) + +from sagemaker.core.s3 import s3_path_join + +from sagemaker.core.helper.session_helper import Session, get_execution_role +from sagemaker.mlops.feature_store.feature_processor._event_bridge_scheduler_helper import ( + EventBridgeSchedulerHelper, +) +from sagemaker.mlops.workflow.pipeline import Pipeline +from sagemaker.mlops.workflow.retry import ( + StepRetryPolicy, + StepExceptionTypeEnum, + SageMakerJobStepRetryPolicy, + SageMakerJobExceptionTypeEnum, +) + +from sagemaker.mlops.workflow.steps import TrainingStep + +from sagemaker.train.model_trainer import ModelTrainer +from sagemaker.train.configs import Compute, Networking, StoppingCondition, SourceCode, Tag +from sagemaker.core.shapes import OutputDataConfig +from sagemaker.core.workflow.pipeline_context import PipelineSession + +from sagemaker.mlops.feature_store.feature_processor.lineage._feature_processor_lineage import ( + FeatureProcessorLineageHandler, + TransformationCode, +) + +from sagemaker.core.remote_function.job import ( + _JobSettings, + JOBS_CONTAINER_ENTRYPOINT, + SPARK_APP_SCRIPT_PATH, +) + +from sagemaker.mlops.feature_store.feature_processor._data_source import ( + CSVDataSource, + FeatureGroupDataSource, + ParquetDataSource, +) + +logger = 
logging.getLogger("sagemaker")
+
+
+def to_pipeline(
+    pipeline_name: str,
+    step: Callable,
+    role: Optional[str] = None,
+    transformation_code: Optional[TransformationCode] = None,
+    max_retries: Optional[int] = None,
+    tags: Optional[List[Tuple[str, str]]] = None,
+    sagemaker_session: Optional[Session] = None,
+) -> str:
+    """Creates a sagemaker pipeline that takes in a callable as a training step.
+
+    To configure the training step used in the sagemaker pipeline, the input argument step needs
+    to be wrapped by the remote decorator in module sagemaker.remote_function. If not wrapped by
+    the remote decorator, default configurations in sagemaker.remote_function.job._JobSettings
+    will be used to create the training step.
+
+    Args:
+        pipeline_name (str): The name of the pipeline.
+        step (Callable): A user provided function wrapped by feature_processor and optionally
+            wrapped by remote_decorator.
+        role (Optional[str]): The Amazon Resource Name (ARN) of the role used by the pipeline to
+            access and create resources. If not specified, it will default to the credentials
+            provided by the AWS configuration chain.
+        transformation_code (Optional[TransformationCode]): The data source for a reference to the
+            transformation code for Lineage tracking. This code is not used for actual
+            transformation.
+        max_retries (Optional[int]): The number of times to retry the sagemaker pipeline step.
+            If not specified, the sagemaker pipeline step will not retry.
+        tags (Optional[List[Tuple[str, str]]]): A list of tags attached to the pipeline and all
+            corresponding lineage resources that support tags. If not specified, no custom tags
+            will be attached.
+        sagemaker_session (Optional[Session]): Session object which manages interactions
+            with Amazon SageMaker APIs and any other AWS services needed. If not specified, the
+            function creates one using the default AWS configuration chain.
+    Returns:
+        str: SageMaker Pipeline ARN.
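+
+    **Example** (an illustrative sketch; the pipeline name is hypothetical, and
+    ``transform`` is assumed to already be wrapped by both the remote and
+    feature_processor decorators as described above):
+
+    .. code-block:: python
+
+        pipeline_arn = to_pipeline(
+            pipeline_name="my-feature-pipeline",
+            step=transform,
+            max_retries=3,
+        )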
+ """ + + _validate_input_for_to_pipeline_api(pipeline_name, step) + if tags: + _validate_tags_for_to_pipeline_api(tags) + + _sagemaker_session = sagemaker_session or Session() + + _validate_lineage_resources_for_to_pipeline_api( + step.feature_processor_config, _sagemaker_session + ) + + remote_decorator_config = _get_remote_decorator_config_from_input( + wrapped_func=step, sagemaker_session=_sagemaker_session + ) + _role = role or get_execution_role(_sagemaker_session) + + runtime_env_manager = RuntimeEnvironmentManager() + client_python_version = runtime_env_manager._current_python_version() + config_uploader = ConfigUploader(remote_decorator_config, runtime_env_manager) + + s3_base_uri = s3_path_join(remote_decorator_config.s3_root_uri, pipeline_name) + + ( + input_data_config, + spark_dependency_paths, + ) = config_uploader.prepare_step_input_channel_for_spark_mode( + func=getattr(step, "wrapped_func", step), + s3_base_uri=s3_base_uri, + sagemaker_session=_sagemaker_session, + ) + + pipeline_session = PipelineSession( + boto_session=_sagemaker_session.boto_session, + default_bucket=_sagemaker_session.default_bucket(), + default_bucket_prefix=_sagemaker_session.default_bucket_prefix, + ) + logger.info("Created PipelineSession for pipeline %s", pipeline_name) + + model_trainer = _prepare_model_trainer_from_remote_decorator_config( + remote_decorator_config=remote_decorator_config, + s3_base_uri=s3_base_uri, + client_python_version=client_python_version, + spark_dependency_paths=spark_dependency_paths, + pipeline_session=pipeline_session, + role=_role, + ) + + step_args = model_trainer.train(input_data_config=input_data_config) + logger.info("Obtained step_args from ModelTrainer.train() for pipeline %s", pipeline_name) + + step_name = "-".join([pipeline_name, "feature-processor"]) + training_step_request_dict = dict( + name=step_name, + step_args=step_args, + ) + logger.info("Created TrainingStep '%s' with step_args", step_name) + + if max_retries: + training_step_request_dict["retry_policies"] = [ + StepRetryPolicy( + exception_types=[ + StepExceptionTypeEnum.SERVICE_FAULT, + StepExceptionTypeEnum.THROTTLING, + ], + max_attempts=max_retries, + ), + SageMakerJobStepRetryPolicy( + exception_types=[ + SageMakerJobExceptionTypeEnum.INTERNAL_ERROR, + SageMakerJobExceptionTypeEnum.CAPACITY_ERROR, + SageMakerJobExceptionTypeEnum.RESOURCE_LIMIT, + ], + max_attempts=max_retries, + ), + ] + + pipeline_request_dict = dict( + name=pipeline_name, + steps=[TrainingStep(**training_step_request_dict)], + sagemaker_session=_sagemaker_session, + parameters=[SCHEDULED_TIME_PIPELINE_PARAMETER], + ) + pipeline_tags = [dict(Key=FEATURE_PROCESSOR_TAG_KEY, Value=FEATURE_PROCESSOR_TAG_VALUE)] + if tags: + pipeline_tags.extend([dict(Key=k, Value=v) for k, v in tags]) + + pipeline = Pipeline(**pipeline_request_dict) + logger.info("Creating/Updating sagemaker pipeline %s", pipeline_name) + pipeline.upsert( + role_arn=_role, + tags=pipeline_tags, + ) + logger.info("Created sagemaker pipeline %s", pipeline_name) + + describe_pipeline_response = pipeline.describe() + pipeline_arn = describe_pipeline_response["PipelineArn"] + tags_propagate_to_lineage_resources = _get_tags_from_pipeline_to_propagate_to_lineage_resources( + pipeline_arn, _sagemaker_session + ) + + lineage_handler = FeatureProcessorLineageHandler( + pipeline_name=pipeline_name, + pipeline_arn=pipeline_arn, + pipeline=describe_pipeline_response, + inputs=_get_feature_processor_inputs(wrapped_func=step), + 
output=_get_feature_processor_outputs(wrapped_func=step),
+        transformation_code=transformation_code,
+        sagemaker_session=_sagemaker_session,
+    )
+    lineage_handler.create_lineage(tags_propagate_to_lineage_resources)
+    lineage_handler.upsert_tags_for_lineage_resources(tags_propagate_to_lineage_resources)
+
+    pipeline_lineage_names: Dict[str, str] = lineage_handler.get_pipeline_lineage_names()
+
+    if pipeline_lineage_names is None:
+        raise RuntimeError("Failed to retrieve pipeline lineage. Pipeline Lineage does not exist")
+
+    pipeline.upsert(
+        role_arn=_role,
+        tags=[
+            {
+                "Key": PIPELINE_CONTEXT_NAME_TAG_KEY,
+                "Value": pipeline_lineage_names["pipeline_context_name"],
+            },
+            {
+                "Key": PIPELINE_VERSION_CONTEXT_NAME_TAG_KEY,
+                "Value": pipeline_lineage_names["pipeline_version_context_name"],
+            },
+        ],
+    )
+    return pipeline_arn
+
+
+def schedule(
+    pipeline_name: str,
+    schedule_expression: str,
+    role_arn: Optional[str] = None,
+    state: Optional[str] = DEFAULT_SCHEDULE_STATE,
+    start_date: Optional[datetime] = None,
+    sagemaker_session: Optional[Session] = None,
+) -> str:
+    """Creates an EventBridge Schedule that schedules executions of a sagemaker pipeline.
+
+    The pipeline created will also have a pipeline parameter `scheduled-time` indicating when the
+    pipeline is scheduled to run.
+
+    Args:
+        pipeline_name (str): The SageMaker Pipeline name that will be scheduled.
+        schedule_expression (str): The expression that defines when the schedule runs. It supports
+            at, rate, and cron expressions. See the CreateSchedule API documentation in the
+            Amazon EventBridge Scheduler API Reference for more details.
+        state (str): Specifies whether the schedule is enabled or disabled. Valid values are
+            ENABLED and DISABLED. See the State request parameter of the CreateSchedule API for
+            more details. If not specified, it will default to ENABLED.
+        start_date (Optional[datetime]): The date, in UTC, after which the schedule can begin
+            invoking its target. Depending on the schedule’s recurrence expression, invocations
+            might occur on, or after, the StartDate you specify.
+        role_arn (Optional[str]): The Amazon Resource Name (ARN) of the IAM role that EventBridge
+            Scheduler will assume for this target when the schedule is invoked.
+        sagemaker_session (Optional[Session]): Session object which manages interactions
+            with Amazon SageMaker APIs and any other AWS services needed. If not specified, the
+            function creates one using the default AWS configuration chain.
+    Returns:
+        str: The EventBridge Schedule ARN.
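+
+    **Example** (an illustrative sketch; the pipeline name is hypothetical, and
+    the cron expression schedules a daily run at midnight UTC):
+
+    .. code-block:: python
+
+        schedule_arn = schedule(
+            pipeline_name="my-feature-pipeline",
+            schedule_expression="cron(0 0 * * ? *)",
+            state="ENABLED",
+        )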
+    """
+
+    _sagemaker_session = sagemaker_session or Session()
+    _validate_pipeline_lineage_resources(pipeline_name, _sagemaker_session)
+    _start_date = start_date or datetime.now(tz=pytz.utc)
+    _role_arn = role_arn or get_execution_role(_sagemaker_session)
+    event_bridge_scheduler_helper = EventBridgeSchedulerHelper(
+        _sagemaker_session,
+        _sagemaker_session.boto_session.client("scheduler"),
+    )
+    describe_pipeline_response = _sagemaker_session.sagemaker_client.describe_pipeline(
+        PipelineName=pipeline_name
+    )
+    pipeline_arn = describe_pipeline_response["PipelineArn"]
+    tags_propagate_to_lineage_resources = _get_tags_from_pipeline_to_propagate_to_lineage_resources(
+        pipeline_arn, _sagemaker_session
+    )
+
+    logger.info("Creating/Updating EventBridge Schedule for pipeline %s.", pipeline_name)
+    event_bridge_schedule_arn = event_bridge_scheduler_helper.upsert_schedule(
+        schedule_name=pipeline_name,
+        pipeline_arn=pipeline_arn,
+        schedule_expression=schedule_expression,
+        state=state,
+        start_date=_start_date,
+        role=_role_arn,
+    )
+    logger.info("Created/Updated EventBridge Schedule for pipeline %s.", pipeline_name)
+    lineage_handler = FeatureProcessorLineageHandler(
+        pipeline_name=pipeline_name,
+        pipeline_arn=describe_pipeline_response["PipelineArn"],
+        pipeline=describe_pipeline_response,
+        sagemaker_session=_sagemaker_session,
+    )
+    lineage_handler.create_schedule_lineage(
+        pipeline_name=pipeline_name,
+        schedule_arn=event_bridge_schedule_arn["ScheduleArn"],
+        schedule_expression=schedule_expression,
+        state=state,
+        start_date=_start_date,
+        tags=tags_propagate_to_lineage_resources,
+    )
+    return event_bridge_schedule_arn["ScheduleArn"]
+
+
+def put_trigger(
+    source_pipeline_events: List[FeatureProcessorPipelineEvents],
+    target_pipeline: str,
+    target_pipeline_parameters: Optional[Dict[str, str]] = None,
+    state: Optional[str] = DEFAULT_TRIGGER_STATE,
+    event_pattern: Optional[str] = None,
+    role_arn: Optional[str] = None,
+    sagemaker_session: Optional[Session] = None,
+) -> str:
+    """Creates an event-based trigger that triggers executions of a sagemaker pipeline.
+
+    Args:
+        source_pipeline_events (List[FeatureProcessorPipelineEvents]): The list of
+            FeatureProcessorPipelineEvents that will trigger the target_pipeline.
+        target_pipeline (str): The name of the SageMaker Pipeline that will be triggered.
+        target_pipeline_parameters (Optional[Dict[str, str]]): The list of parameters to start
+            execution of a pipeline.
+        state (Optional[str]): Indicates whether the rule is enabled or disabled.
+            If not specified, it will default to ENABLED.
+        event_pattern (Optional[str]): The EventBridge EventPattern that triggers the
+            target_pipeline. If specified, will override source_pipeline_events. For more
+            information, see
+            https://docs.aws.amazon.com/eventbridge/latest/userguide/eb-event-patterns.html
+            in the Amazon EventBridge User Guide.
+        role_arn (Optional[str]): The Amazon Resource Name (ARN) of the IAM role that EventBridge
+            will assume when invoking the target pipeline for this rule.
+        sagemaker_session (Optional[Session]): Session object which manages interactions
+            with Amazon SageMaker APIs and any other AWS services needed. If not specified, the
+            function creates one using the default AWS configuration chain.
+    Returns:
+        str: The EventBridge Rule ARN.
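+
+    **Example** (an illustrative sketch; pipeline names are hypothetical, and the
+    event construction assumes ``FeatureProcessorPipelineEvents`` takes the source
+    pipeline name and the execution statuses to match on):
+
+    .. code-block:: python
+
+        rule_arn = put_trigger(
+            source_pipeline_events=[
+                FeatureProcessorPipelineEvents(
+                    pipeline_name="upstream-pipeline",
+                    pipeline_execution_status=[
+                        FeatureProcessorPipelineExecutionStatus.SUCCEEDED
+                    ],
+                )
+            ],
+            target_pipeline="downstream-pipeline",
+        )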
+ """ + _sagemaker_session = sagemaker_session or Session() + _role_arn = role_arn or get_execution_role(_sagemaker_session) + event_bridge_rule_helper = EventBridgeRuleHelper( + _sagemaker_session, + _sagemaker_session.boto_session.client("events"), + ) + logger.info("Creating/Updating EventBridge Rule for pipeline %s.", target_pipeline) + rule_arn = event_bridge_rule_helper.put_rule( + source_pipeline_events=source_pipeline_events, + target_pipeline=target_pipeline, + event_pattern=event_pattern, + state=state, + ) + rule_name = _parse_name_from_arn(rule_arn, EVENTBRIDGE_RULE_ARN_REGEX_PATTERN) + logger.info("Created/Updated EventBridge Rule for pipeline %s.", target_pipeline) + + logger.info("Attaching pipeline %s to EventBridge Rule %s as target", target_pipeline, rule_arn) + event_bridge_rule_helper.put_target( + rule_name=rule_name, + target_pipeline=target_pipeline, + target_pipeline_parameters=target_pipeline_parameters, + role_arn=_role_arn, + ) + logger.info("Attached pipeline %s to EventBridge Rule %s as target", target_pipeline, rule_arn) + + describe_pipeline_response = _sagemaker_session.sagemaker_client.describe_pipeline( + PipelineName=target_pipeline + ) + describe_rule_response = event_bridge_rule_helper.describe_rule(rule_name=rule_name) + pipeline_arn = describe_pipeline_response["PipelineArn"] + tags_propagate_to_lineage_resources = _get_tags_from_pipeline_to_propagate_to_lineage_resources( + pipeline_arn, _sagemaker_session + ) + + event_bridge_rule_helper.add_tags(rule_arn=rule_arn, tags=tags_propagate_to_lineage_resources) + + lineage_handler = FeatureProcessorLineageHandler( + pipeline_name=target_pipeline, + pipeline_arn=describe_pipeline_response["PipelineArn"], + pipeline=describe_pipeline_response, + sagemaker_session=_sagemaker_session, + ) + lineage_handler.create_trigger_lineage( + pipeline_name=target_pipeline, + trigger_arn=rule_arn, + state=state, + tags=tags_propagate_to_lineage_resources, + event_pattern=describe_rule_response["EventPattern"], + ) + return rule_arn + + +def enable_trigger( + pipeline_name: str, + sagemaker_session: Optional[Session] = None, +) -> None: + """Enable the EventBridge Rule that is associated with the pipeline. + + Args: + pipeline_name (str): The SageMaker Pipeline name that will be executed. + sagemaker_session (Optional[Session]): Session object which manages interactions + with Amazon SageMaker APIs and any other AWS services needed. If not specified, the + function creates one using the default AWS configuration chain. + """ + _sagemaker_session = sagemaker_session or Session() + event_bridge_rule_helper = EventBridgeRuleHelper( + _sagemaker_session, + _sagemaker_session.boto_session.client("events"), + ) + event_bridge_rule_helper.enable_rule(rule_name=pipeline_name) + logger.info("Enabled EventBridge Rule for pipeline %s.", pipeline_name) + + +def disable_trigger(pipeline_name: str, sagemaker_session: Optional[Session] = None) -> None: + """Disable the EventBridge Rule that is associated with the pipeline. + + Args: + pipeline_name (str): The SageMaker Pipeline name that will be executed. + sagemaker_session (Optional[Session]): Session object which manages interactions + with Amazon SageMaker APIs and any other AWS services needed. If not specified, the + function creates one using the default AWS configuration chain. 
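+
+    For example, a trigger can be paused and later resumed (the pipeline name is
+    illustrative):
+
+    .. code-block:: python
+
+        disable_trigger(pipeline_name="downstream-pipeline")
+        enable_trigger(pipeline_name="downstream-pipeline")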
+    """
+    _sagemaker_session = sagemaker_session or Session()
+    event_bridge_rule_helper = EventBridgeRuleHelper(
+        _sagemaker_session,
+        _sagemaker_session.boto_session.client("events"),
+    )
+    event_bridge_rule_helper.disable_rule(rule_name=pipeline_name)
+    logger.info("Disabled EventBridge Rule for pipeline %s.", pipeline_name)
+
+
+def execute(
+    pipeline_name: str,
+    execution_time: Optional[datetime] = None,
+    sagemaker_session: Optional[Session] = None,
+) -> str:
+    """Starts an execution of a SageMaker Pipeline created by feature_processor.
+
+    Args:
+        pipeline_name (str): The SageMaker Pipeline name that will be executed.
+        execution_time (Optional[datetime]): The date, in UTC, that will be used as a sagemaker
+            pipeline parameter indicating the time at which the execution is scheduled to execute.
+            If not specified, it will default to the current timestamp.
+        sagemaker_session (Optional[Session]): Session object which manages interactions
+            with Amazon SageMaker APIs and any other AWS services needed. If not specified, the
+            function creates one using the default AWS configuration chain.
+    Returns:
+        str: The pipeline execution ARN.
+    """
+    _sagemaker_session = sagemaker_session or Session()
+    _validate_pipeline_lineage_resources(pipeline_name, _sagemaker_session)
+    _execution_time = execution_time or datetime.now()
+    start_pipeline_execution_request = dict(
+        PipelineName=pipeline_name,
+        PipelineParameters=[
+            dict(
+                Name=EXECUTION_TIME_PIPELINE_PARAMETER,
+                Value=_execution_time.strftime(EXECUTION_TIME_PIPELINE_PARAMETER_FORMAT),
+            )
+        ],
+    )
+    logger.info("Starting an execution for pipeline %s", pipeline_name)
+    execution_response = _sagemaker_session.sagemaker_client.start_pipeline_execution(
+        **start_pipeline_execution_request
+    )
+    execution_arn = execution_response["PipelineExecutionArn"]
+    logger.info(
+        "Execution %s for pipeline %s is successfully started.",
+        execution_arn,
+        pipeline_name,
+    )
+    return execution_arn
+
+
+def delete_schedule(pipeline_name: str, sagemaker_session: Optional[Session] = None) -> None:
+    """Delete the EventBridge Schedule corresponding to a SageMaker Pipeline if there is one.
+
+    Args:
+        pipeline_name (str): The name of the SageMaker Pipeline whose EventBridge Schedule will
+            be deleted.
+        sagemaker_session (Optional[Session]): Session object which manages interactions
+            with Amazon SageMaker APIs and any other AWS services needed. If not specified, the
+            function creates one using the default AWS configuration chain.
+    """
+    _sagemaker_session = sagemaker_session or Session()
+    event_bridge_scheduler_helper = EventBridgeSchedulerHelper(
+        _sagemaker_session, _sagemaker_session.boto_session.client("scheduler")
+    )
+    try:
+        event_bridge_scheduler_helper.delete_schedule(pipeline_name)
+        logger.info("Deleted EventBridge Schedule for pipeline %s.", pipeline_name)
+    except ClientError as e:
+        if RESOURCE_NOT_FOUND_EXCEPTION != e.response["Error"]["Code"]:
+            raise e
+
+
+def delete_trigger(pipeline_name: str, sagemaker_session: Optional[Session] = None) -> None:
+    """Delete the EventBridge Rule corresponding to a SageMaker Pipeline if there is one.
+
+    Args:
+        pipeline_name (str): The name of the SageMaker Pipeline whose EventBridge Rule will be
+            deleted.
+        sagemaker_session (Optional[Session]): Session object which manages interactions
+            with Amazon SageMaker APIs and any other AWS services needed. If not specified, the
+            function creates one using the default AWS configuration chain.
+ """ + _sagemaker_session = sagemaker_session or Session() + event_bridge_rule_helper = EventBridgeRuleHelper( + _sagemaker_session, + _sagemaker_session.boto_session.client("events"), + ) + try: + target_ids = [] + for page in event_bridge_rule_helper.list_targets_by_rule(pipeline_name): + target_ids.extend([target["Id"] for target in page["Targets"]]) + event_bridge_rule_helper.remove_targets(rule_name=pipeline_name, ids=target_ids) + event_bridge_rule_helper.delete_rule(pipeline_name) + logger.info("Deleted EventBridge Rule for pipeline %s.", pipeline_name) + except ClientError as e: + if RESOURCE_NOT_FOUND_EXCEPTION != e.response["Error"]["Code"]: + raise e + + +def describe( + pipeline_name: str, sagemaker_session: Optional[Session] = None +) -> Dict[str, Union[int, str]]: + """Describe feature processor and other related resources. + + This API will include details related to the feature processor including SageMaker Pipeline and + EventBridge Schedule. + + Args: + pipeline_name (str): Name of the pipeline. + sagemaker_session (Optional[Session]): Session object which manages interactions + with Amazon SageMaker APIs and any other AWS services needed. If not specified, the + function creates one using the default AWS configuration chain. + Returns: + Dict[str, Union[int, str]]: Return information for resources related to feature processor. + """ + + _sagemaker_session = sagemaker_session or Session() + describe_response_dict = {} + + try: + describe_pipeline_response = _sagemaker_session.sagemaker_client.describe_pipeline( + PipelineName=pipeline_name + ) + pipeline_definition = json.loads(describe_pipeline_response["PipelineDefinition"]) + pipeline_step = pipeline_definition["Steps"][0] + describe_response_dict = dict( + pipeline_arn=describe_pipeline_response["PipelineArn"], + pipeline_execution_role_arn=describe_pipeline_response["RoleArn"], + ) + + if "RetryPolicies" in pipeline_step: + describe_response_dict["max_retries"] = pipeline_step["RetryPolicies"][0]["MaxAttempts"] + except ClientError as e: + if RESOURCE_NOT_FOUND_EXCEPTION == e.response["Error"]["Code"]: + logger.info("Pipeline %s does not exist.", pipeline_name) + + event_bridge_scheduler_helper = EventBridgeSchedulerHelper( + _sagemaker_session, + _sagemaker_session.boto_session.client("scheduler"), + ) + + event_bridge_schedule = event_bridge_scheduler_helper.describe_schedule(pipeline_name) + if event_bridge_schedule: + describe_response_dict.update( + dict( + schedule_arn=event_bridge_schedule["Arn"], + schedule_expression=event_bridge_schedule["ScheduleExpression"], + schedule_state=event_bridge_schedule["State"], + schedule_start_date=event_bridge_schedule["StartDate"].strftime( + EXECUTION_TIME_PIPELINE_PARAMETER_FORMAT + ), + schedule_role=event_bridge_schedule["Target"]["RoleArn"], + ) + ) + + event_bridge_rule_helper = EventBridgeRuleHelper( + _sagemaker_session, + _sagemaker_session.boto_session.client("events"), + ) + event_based_trigger = event_bridge_rule_helper.describe_rule(pipeline_name) + if event_based_trigger: + describe_response_dict.update( + dict( + trigger=event_based_trigger["Arn"], + event_pattern=event_based_trigger["EventPattern"], + trigger_state=event_based_trigger["State"], + ) + ) + + return describe_response_dict + + +def list_pipelines(sagemaker_session: Optional[Session] = None) -> List[Dict[str, Any]]: + """Lists all SageMaker Pipelines created by Feature Processor SDK. 
+
+    Args:
+        sagemaker_session (Optional[Session]): Session object which manages interactions
+            with Amazon SageMaker APIs and any other AWS services needed. If not specified, the
+            function creates one using the default AWS configuration chain.
+    Returns:
+        List[Dict[str, Any]]: Return list of SageMaker Pipeline metadata created for
+            feature_processor.
+    """
+
+    _sagemaker_session = sagemaker_session or Session()
+    next_token = None
+    list_response = []
+    pipeline_names_so_far = set()
+    while True:
+        list_contexts_request = dict(ContextType=PIPELINE_CONTEXT_TYPE)
+        if next_token:
+            list_contexts_request["NextToken"] = next_token
+        list_contexts_response = _sagemaker_session.sagemaker_client.list_contexts(
+            **list_contexts_request
+        )
+        for _context in list_contexts_response["ContextSummaries"]:
+            pipeline_name = get_resource_name_from_arn(_context["Source"]["SourceUri"])
+            if pipeline_name not in pipeline_names_so_far:
+                list_response.append(dict(pipeline_name=pipeline_name))
+                pipeline_names_so_far.add(pipeline_name)
+        next_token = list_contexts_response.get("NextToken")
+        if not next_token:
+            break
+
+    return list_response
+
+
+def _validate_input_for_to_pipeline_api(pipeline_name: str, step: Callable) -> None:
+    """Validate input to the to_pipeline API.
+
+    The provided callable is considered valid if it's wrapped by the feature_processor decorator
+    and uses pyspark mode.
+
+    Args:
+        pipeline_name (str): The name of the pipeline.
+        step (Callable): A user provided function wrapped by feature_processor and optionally
+            wrapped by remote_decorator.
+
+    Raises (ValueError): raises ValueError when any of the following scenarios happen:
+        1. pipeline name is longer than 80 characters.
+        2. function is not annotated with either feature_processor or remote decorator.
+        3. provides a mode other than pyspark.
+    """
+    if len(pipeline_name) > PIPELINE_NAME_MAXIMUM_LENGTH:
+        raise ValueError(
+            "Pipeline name used by feature processor should be no longer than 80 "
+            "characters. Please choose another pipeline name."
+        )
+
+    if not hasattr(step, "feature_processor_config") or not step.feature_processor_config:
+        raise ValueError(
+            "Please wrap step parameter with feature_processor decorator"
+            " in order to use to_pipeline API."
+        )
+
+    if not hasattr(step, "job_settings") or not step.job_settings:
+        raise ValueError(
+            "Please wrap step parameter with remote decorator in order to use to_pipeline API."
+        )
+
+    if FeatureProcessorMode.PYSPARK != step.feature_processor_config.mode:
+        raise ValueError(
+            f"Mode {step.feature_processor_config.mode} is not supported by to_pipeline API."
+        )
+
+
+def _validate_tags_for_to_pipeline_api(tags: List[Tuple[str, str]]) -> None:
+    """Validate tags provided to the to_pipeline API.
+
+    Args:
+        tags (List[Tuple[str, str]]): A list of tags attached to the pipeline.
+
+    Raises (ValueError): raises ValueError when any of the following scenarios happen:
+        1. reserved tag keys are provided to the API.
+    """
+    provided_tag_keys = [tag_key_value_pair[0] for tag_key_value_pair in tags]
+    for reserved_tag_key in TO_PIPELINE_RESERVED_TAG_KEYS:
+        if reserved_tag_key in provided_tag_keys:
+            raise ValueError(
+                f"{reserved_tag_key} is a reserved tag key for to_pipeline API. Please choose another tag."
+            )
+
+
+def _validate_lineage_resources_for_to_pipeline_api(
+    feature_processor_config: FeatureProcessorConfig, sagemaker_session: Session
+) -> None:
+    """Validate existence of feature group lineage resources for the to_pipeline API.
+ + Args: + feature_processor_config (FeatureProcessorConfig): The configuration values for the + feature_processor decorator. + sagemaker_session (Session): Session object which manages interactions + with Amazon SageMaker APIs and any other AWS services needed. + """ + inputs = feature_processor_config.inputs + output = feature_processor_config.output + for ds in inputs: + if isinstance(ds, FeatureGroupDataSource): + fg_name = _parse_name_from_arn(ds.name) + _validate_fg_lineage_resources(fg_name, sagemaker_session) + output_fg_name = _parse_name_from_arn(output) + _validate_fg_lineage_resources(output_fg_name, sagemaker_session) + + +def _validate_fg_lineage_resources(feature_group_name: str, sagemaker_session: Session) -> None: + """Validate existence of feature group lineage resources. + + Args: + feature_group_name (str): The name or arn of the feature group. + sagemaker_session (Session): Session object which manages interactions + with Amazon SageMaker APIs and any other AWS services needed. + + Raises (ValueError): raises ValueError when lineage resources are not created for feature + groups. + """ + + feature_group = sagemaker_session.describe_feature_group(feature_group_name=feature_group_name) + feature_group_creation_time = feature_group["CreationTime"].strftime("%s") + feature_group_context = _get_feature_group_lineage_context_name( + feature_group_name=feature_group_name, + feature_group_creation_time=feature_group_creation_time, + ) + feature_group_pipeline_context = _get_feature_group_pipeline_lineage_context_name( + feature_group_name=feature_group_name, + feature_group_creation_time=feature_group_creation_time, + ) + feature_group_pipeline_version_context = ( + _get_feature_group_pipeline_version_lineage_context_name( + feature_group_name=feature_group_name, + feature_group_creation_time=feature_group_creation_time, + ) + ) + for context_name in [ + feature_group_context, + feature_group_pipeline_context, + feature_group_pipeline_version_context, + ]: + try: + logger.info("Verifying existence of context %s.", context_name) + context.Context.load(context_name=context_name, sagemaker_session=sagemaker_session) + except ClientError as e: + if RESOURCE_NOT_FOUND == e.response["Error"]["Code"]: + raise ValueError( + f"Lineage resource {context_name} has not yet been created for feature group" + f" {feature_group_name} or has already been deleted. Please try again later." + ) + raise e + + +def _validate_pipeline_lineage_resources(pipeline_name: str, sagemaker_session: Session) -> None: + """Validate existence of pipeline lineage resources. + + Args: + pipeline_name (str): The name of the pipeline. + sagemaker_session (Session): Session object which manages interactions + with Amazon SageMaker APIs and any other AWS services needed. 
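+
+    Raises (ValueError): raises ValueError when the pipeline or pipeline version
+        lineage contexts have not been created yet or have already been deleted.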
+    """
+    pipeline = sagemaker_session.sagemaker_client.describe_pipeline(PipelineName=pipeline_name)
+    pipeline_creation_time = pipeline["CreationTime"].strftime("%s")
+    pipeline_context_name = _get_feature_processor_pipeline_lineage_context_name(
+        pipeline_name=pipeline_name, pipeline_creation_time=pipeline_creation_time
+    )
+    try:
+        pipeline_context = context.Context.load(
+            context_name=pipeline_context_name, sagemaker_session=sagemaker_session
+        )
+        last_update_time = pipeline_context.properties["LastUpdateTime"]
+        pipeline_version_context_name = (
+            _get_feature_processor_pipeline_version_lineage_context_name(
+                pipeline_name=pipeline_name, pipeline_last_update_time=last_update_time
+            )
+        )
+        context.Context.load(
+            context_name=pipeline_version_context_name, sagemaker_session=sagemaker_session
+        )
+    except ClientError as e:
+        if RESOURCE_NOT_FOUND == e.response["Error"]["Code"]:
+            raise ValueError(
+                "Pipeline lineage resources have not been created yet or have already been "
+                "deleted. Please try again later."
+            )
+        raise e
+
+
+def _prepare_model_trainer_from_remote_decorator_config(
+    remote_decorator_config: _JobSettings,
+    s3_base_uri: str,
+    client_python_version: str,
+    spark_dependency_paths: Dict[str, Optional[str]],
+    pipeline_session: PipelineSession,
+    role: str,
+) -> ModelTrainer:
+    """Prepares a ModelTrainer instance from remote decorator configuration.
+
+    Args:
+        remote_decorator_config (_JobSettings): Configurations used for setting up the
+            SageMaker Pipeline Step.
+        s3_base_uri (str): S3 URI used as the destination for dependency uploads.
+        client_python_version (str): Python version used on the client side.
+        spark_dependency_paths (Dict[str, Optional[str]]): A dictionary containing the S3 paths
+            that Spark dependency files are uploaded to, if present.
+        pipeline_session (PipelineSession): Pipeline-aware session that causes
+            ModelTrainer.train() to return step arguments instead of launching a job.
+        role (str): The IAM role ARN for the training job.
+    Returns:
+        ModelTrainer: A configured ModelTrainer instance.
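+
+    Example (illustrative sketch; the bucket, role and session below are hypothetical):
+        model_trainer = _prepare_model_trainer_from_remote_decorator_config(
+            remote_decorator_config=wrapped_func.job_settings,
+            s3_base_uri="s3://my-bucket/feature-processor",
+            client_python_version="3.10",
+            spark_dependency_paths={
+                SPARK_JAR_FILES_PATH: None,
+                SPARK_PY_FILES_PATH: None,
+                SPARK_FILES_PATH: None,
+            },
+            pipeline_session=pipeline_session,
+            role="arn:aws:iam::111122223333:role/SageMakerRole",
+        )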
+ """ + logger.info("Mapping remote decorator config to ModelTrainer params") + + # Build environment dict from remote_decorator_config (strings only for Pydantic validation) + environment = dict(remote_decorator_config.environment_variables or {}) + + # Build command from container entry point and arguments + entry_point_and_args = _get_container_entry_point_and_arguments( + remote_decorator_config=remote_decorator_config, + s3_base_uri=s3_base_uri, + client_python_version=client_python_version, + spark_dependency_paths=spark_dependency_paths, + ) + joined_command = " ".join( + entry_point_and_args["container_entry_point"] + + entry_point_and_args["container_arguments"] + ) + source_code = SourceCode(command=joined_command) + logger.info("SourceCode command: %s", joined_command) + + # Create Compute config + compute = Compute( + instance_type=remote_decorator_config.instance_type, + instance_count=remote_decorator_config.instance_count, + volume_size_in_gb=remote_decorator_config.volume_size, + volume_kms_key_id=remote_decorator_config.volume_kms_key, + ) + logger.info( + "Compute: instance_type=%s, instance_count=%s, volume_size=%s", + remote_decorator_config.instance_type, + remote_decorator_config.instance_count, + remote_decorator_config.volume_size, + ) + + # Create Networking config if VPC config is present + networking = None + if remote_decorator_config.vpc_config: + networking = Networking( + subnets=remote_decorator_config.vpc_config[SUBNETS_KEY], + security_group_ids=remote_decorator_config.vpc_config[SECURITY_GROUP_IDS_KEY], + enable_inter_container_traffic_encryption=( + remote_decorator_config.encrypt_inter_container_traffic + ), + ) + logger.info( + "Networking: subnets=%s, security_groups=%s, encrypt=%s", + remote_decorator_config.vpc_config[SUBNETS_KEY], + remote_decorator_config.vpc_config[SECURITY_GROUP_IDS_KEY], + remote_decorator_config.encrypt_inter_container_traffic, + ) + + # Create StoppingCondition if max_runtime_in_seconds is configured + stopping_condition = None + if remote_decorator_config.max_runtime_in_seconds: + stopping_condition = StoppingCondition( + max_runtime_in_seconds=remote_decorator_config.max_runtime_in_seconds, + ) + + # Create OutputDataConfig + output_data_config = OutputDataConfig( + s3_output_path=s3_base_uri, + kms_key_id=remote_decorator_config.s3_kms_key, + ) + + # Convert tags from List[Tuple[str, str]] to List[Tag] + tags = None + if remote_decorator_config.tags: + tags = [Tag(key=k, value=v) for k, v in remote_decorator_config.tags] + logger.info("Tags count: %d", len(tags) if tags else 0) + + logger.info("Environment keys: %s", list(environment.keys())) + + model_trainer = ModelTrainer( + training_image=remote_decorator_config.image_uri, + role=role, + sagemaker_session=pipeline_session, + compute=compute, + networking=networking, + stopping_condition=stopping_condition, + output_data_config=output_data_config, + source_code=source_code, + training_input_mode="File", + environment=environment, + tags=tags, + ) + + # Inject SCHEDULED_TIME_PIPELINE_PARAMETER after construction to bypass Pydantic + # validation (Parameter is not a string). The @runnable_by_pipeline decorator resolves + # Parameter objects to strings during pipeline definition serialization. 
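+    # For example, in the serialized pipeline definition this entry is expected to appear as a
+    # parameter reference such as {"Get": "Parameters.scheduled_time"} rather than a plain
+    # string (the exact parameter name depends on SCHEDULED_TIME_PIPELINE_PARAMETER).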
+ model_trainer.environment[EXECUTION_TIME_PIPELINE_PARAMETER] = SCHEDULED_TIME_PIPELINE_PARAMETER + + logger.info( + "Created ModelTrainer with image=%s, instance_type=%s, instance_count=%s", + remote_decorator_config.image_uri, + remote_decorator_config.instance_type, + remote_decorator_config.instance_count, + ) + + return model_trainer + + +def _get_container_entry_point_and_arguments( + remote_decorator_config: _JobSettings, + s3_base_uri: str, + client_python_version: str, + spark_dependency_paths: Dict[str, Optional[str]], +) -> Dict[str, List[str]]: + """Extracts the container entry point and container arguments from remote decorator configs + + Args: + remote_decorator_config (_JobSettings): Configurations used for setting up + SageMaker Pipeline Step. + s3_base_uri (str): S3 URI used as destination for dependencies upload. + client_python_version (str): Python version used on client side. + spark_dependency_paths (Dict[str, Optional[str]]): A dictionary contains S3 paths spark + dependency files get uploaded to if present. + Returns: + Dict[str, List[str]]: Request dictionary containing container entry point and + arguments setup. + """ + + spark_config = remote_decorator_config.spark_config + jobs_container_entrypoint = JOBS_CONTAINER_ENTRYPOINT.copy() + + if spark_dependency_paths[SPARK_JAR_FILES_PATH]: + jobs_container_entrypoint.extend(["--jars", spark_dependency_paths[SPARK_JAR_FILES_PATH]]) + + if spark_dependency_paths[SPARK_PY_FILES_PATH]: + jobs_container_entrypoint.extend( + ["--py-files", spark_dependency_paths[SPARK_PY_FILES_PATH]] + ) + + if spark_dependency_paths[SPARK_FILES_PATH]: + jobs_container_entrypoint.extend(["--files", spark_dependency_paths[SPARK_FILES_PATH]]) + + if spark_config and spark_config.spark_event_logs_uri: + jobs_container_entrypoint.extend( + ["--spark-event-logs-s3-uri", spark_config.spark_event_logs_uri] + ) + + if spark_config: + jobs_container_entrypoint.extend([SPARK_APP_SCRIPT_PATH]) + + container_args = ["--s3_base_uri", s3_base_uri] + container_args.extend(["--region", remote_decorator_config.sagemaker_session.boto_region_name]) + container_args.extend(["--client_python_version", client_python_version]) + + if remote_decorator_config.s3_kms_key: + container_args.extend(["--s3_kms_key", remote_decorator_config.s3_kms_key]) + + return dict( + container_entry_point=jobs_container_entrypoint, + container_arguments=container_args, + ) + + +def _get_remote_decorator_config_from_input( + wrapped_func: Callable, sagemaker_session: Session +) -> _JobSettings: + """Extracts the remote decorator configuration from the wrapped function and other inputs. + + Args: + wrapped_func (Callable): Wrapped user defined function. If it contains remote decorator + job settings, configs will be used to construct remote_decorator_config, otherwise + default job settings will be used. + sagemaker_session (Session): Session object which manages interactions + with Amazon SageMaker APIs and any other AWS services needed. If not specified, the + function creates one using the default AWS configuration chain. + Returns: + _JobSettings: Configurations used for creating sagemaker pipeline step. + """ + remote_decorator_config = getattr( + wrapped_func, + "job_settings", + ) + # TODO: Remove this after GA + remote_decorator_config.sagemaker_session = sagemaker_session + + # TODO: This needs to be removed when new mode is introduced. 
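+    # When no Spark config is supplied, the code below falls back to a default SparkConfig
+    # and the default SageMaker Spark image so the step can still run PySpark code.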
+    if remote_decorator_config.spark_config is None:
+        remote_decorator_config.spark_config = SparkConfig()
+        remote_decorator_config.image_uri = _JobSettings._get_default_spark_image(sagemaker_session)
+
+    return remote_decorator_config
+
+
+def _get_feature_processor_inputs(
+    wrapped_func: Callable,
+) -> Sequence[Union[FeatureGroupDataSource, CSVDataSource, ParquetDataSource]]:
+    """Retrieve Feature Processor Config Inputs."""
+    feature_processor_config: FeatureProcessorConfig = wrapped_func.feature_processor_config
+    return feature_processor_config.inputs
+
+
+def _get_feature_processor_outputs(
+    wrapped_func: Callable,
+) -> str:
+    """Retrieve Feature Processor Config Output."""
+    feature_processor_config: FeatureProcessorConfig = wrapped_func.feature_processor_config
+    return feature_processor_config.output
+
+
+def _parse_name_from_arn(
+    name_or_arn: str, regex_pattern: str = FEATURE_GROUP_ARN_REGEX_PATTERN
+) -> str:
+    """Parse the name from a string, if it's an ARN. Otherwise, return the string.
+
+    Args:
+        name_or_arn (str): The Feature Group Name or ARN.
+        regex_pattern (str): The regex pattern used to extract the name from an ARN.
+
+    Returns:
+        str: The Feature Group Name.
+    """
+    match = re.match(regex_pattern, name_or_arn)
+    if match:
+        name = match.group(4)
+        return name
+    return name_or_arn
+
+
+def _get_tags_from_pipeline_to_propagate_to_lineage_resources(
+    pipeline_arn: str, sagemaker_session: Session
+) -> List[Dict[str, str]]:
+    """Retrieve custom tags attached to the SageMaker Pipeline.
+
+    Args:
+        pipeline_arn (str): SageMaker Pipeline Arn.
+        sagemaker_session (Session): Session object which manages interactions
+            with Amazon SageMaker APIs and any other AWS services needed. If not specified, the
+            function creates one using the default AWS configuration chain.
+
+    Returns:
+        List[Dict[str, str]]: List of custom tags to be propagated to lineage resources.
+    """
+    tags_in_pipeline = sagemaker_session.sagemaker_client.list_tags(ResourceArn=pipeline_arn)[
+        "Tags"
+    ]
+    return [d for d in tags_in_pipeline if d["Key"] not in TO_PIPELINE_RESERVED_TAG_KEYS]
diff --git a/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/lineage/__init__.py b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/lineage/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/lineage/_feature_group_contexts.py b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/lineage/_feature_group_contexts.py
new file mode 100644
index 0000000000..2b4f134f0a
--- /dev/null
+++ b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/lineage/_feature_group_contexts.py
@@ -0,0 +1,31 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+"""Contains class to store Feature Group Contexts"""
+from __future__ import absolute_import
+import attr
+
+
+@attr.s
+class FeatureGroupContexts:
+    """Holds the lineage contexts of a Feature Group.
+
+    Attributes:
+        name (str): The name of the Feature Group.
+        pipeline_context_arn (str): The ARN of the Feature Group Pipeline Context.
+        pipeline_version_context_arn (str):
+            The ARN of the Feature Group Pipeline Version Context.
+    """
+
+    name: str = attr.ib()
+    pipeline_context_arn: str = attr.ib()
+    pipeline_version_context_arn: str = attr.ib()
diff --git a/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/lineage/_feature_group_lineage_entity_handler.py b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/lineage/_feature_group_lineage_entity_handler.py
new file mode 100644
index 0000000000..55230d7c1c
--- /dev/null
+++ b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/lineage/_feature_group_lineage_entity_handler.py
@@ -0,0 +1,182 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+"""Contains class to handle Feature Group Lineage"""
+from __future__ import absolute_import
+
+import re
+from typing import Dict, Any
+import logging
+
+from sagemaker.core.helper.session_helper import Session
+from sagemaker.mlops.feature_store.feature_processor._constants import FEATURE_GROUP_ARN_REGEX_PATTERN
+from sagemaker.mlops.feature_store.feature_processor.lineage._feature_group_contexts import (
+    FeatureGroupContexts,
+)
+from sagemaker.mlops.feature_store.feature_processor.lineage.constants import (
+    SAGEMAKER,
+    FEATURE_GROUP,
+    CREATION_TIME,
+)
+from sagemaker.core.lineage.context import Context
+
+# pylint: disable=C0301
+from sagemaker.mlops.feature_store.feature_processor.lineage._feature_processor_lineage_name_helper import (
+    _get_feature_group_pipeline_lineage_context_name,
+    _get_feature_group_pipeline_version_lineage_context_name,
+)
+
+logger = logging.getLogger(SAGEMAKER)
+
+
+class FeatureGroupLineageEntityHandler:
+    """Class for handling Feature Group Lineage"""
+
+    @staticmethod
+    def retrieve_feature_group_context_arns(
+        feature_group_name: str, sagemaker_session: Session
+    ) -> FeatureGroupContexts:
+        """Retrieve Feature Group Contexts.
+
+        Arguments:
+            feature_group_name (str): The Feature Group Name or ARN.
+            sagemaker_session (Session): Session object which manages interactions
+                with Amazon SageMaker APIs and any other AWS services needed. If not specified, the
+                function creates one using the default AWS configuration chain.
+
+        Returns:
+            FeatureGroupContexts: The Feature Group Pipeline and Version Context.
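+
+        Example (illustrative; "my-feature-group" is a hypothetical Feature Group):
+            contexts = FeatureGroupLineageEntityHandler.retrieve_feature_group_context_arns(
+                feature_group_name="my-feature-group",
+                sagemaker_session=Session(),
+            )
+            print(contexts.pipeline_context_arn, contexts.pipeline_version_context_arn)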
+ """ + feature_group = FeatureGroupLineageEntityHandler._describe_feature_group( + feature_group_name=FeatureGroupLineageEntityHandler.parse_name_from_arn( + feature_group_name + ), + sagemaker_session=sagemaker_session, + ) + feature_group_name = feature_group[FEATURE_GROUP] + feature_group_creation_time = feature_group[CREATION_TIME].strftime("%s") + feature_group_pipeline_context = ( + FeatureGroupLineageEntityHandler._load_feature_group_pipeline_context( + feature_group_name=feature_group_name, + feature_group_creation_time=feature_group_creation_time, + sagemaker_session=sagemaker_session, + ) + ) + feature_group_pipeline_version_context = ( + FeatureGroupLineageEntityHandler._load_feature_group_pipeline_version_context( + feature_group_name=feature_group_name, + feature_group_creation_time=feature_group_creation_time, + sagemaker_session=sagemaker_session, + ) + ) + return FeatureGroupContexts( + name=feature_group_name, + pipeline_context_arn=feature_group_pipeline_context.context_arn, + pipeline_version_context_arn=feature_group_pipeline_version_context.context_arn, + ) + + @staticmethod + def _describe_feature_group( + feature_group_name: str, sagemaker_session: Session + ) -> Dict[str, Any]: + """Retrieve the Feature Group. + + Arguments: + feature_group_name (str): The Feature Group Name. + sagemaker_session (Session): Session object which manages interactions + with Amazon SageMaker APIs and any other AWS services needed. If not specified, the + function creates one using the default AWS configuration chain. + + Returns: + Dict[str, Any]: The Feature Group details. + """ + feature_group = sagemaker_session.describe_feature_group(feature_group_name) + logger.debug( + "Called describe_feature_group with %s and received: %s", + feature_group_name, + feature_group, + ) + return feature_group + + @staticmethod + def _load_feature_group_pipeline_context( + feature_group_name: str, + feature_group_creation_time: str, + sagemaker_session: Session, + ) -> Context: + """Retrieve Feature Group Pipeline Context + + Arguments: + feature_group_name (str): The Feature Group Name. + feature_group_creation_time (str): The Feature Group Creation Time, + in long epoch seconds. + sagemaker_session (Session): Session object which manages interactions + with Amazon SageMaker APIs and any other AWS services needed. If not specified, the + function creates one using the default AWS configuration chain. + + Returns: + Context: The Feature Group Pipeline Context. + """ + feature_group_pipeline_context = _get_feature_group_pipeline_lineage_context_name( + feature_group_name=feature_group_name, + feature_group_creation_time=feature_group_creation_time, + ) + return Context.load( + context_name=feature_group_pipeline_context, + sagemaker_session=sagemaker_session, + ) + + @staticmethod + def _load_feature_group_pipeline_version_context( + feature_group_name: str, + feature_group_creation_time: str, + sagemaker_session: Session, + ) -> Context: + """Retrieve Feature Group Pipeline Version Context + + Arguments: + feature_group_name (str): The Feature Group Name. + feature_group_creation_time (str): The Feature Group Creation Time, + in long epoch seconds. + sagemaker_session (Session): Session object which manages interactions + with Amazon SageMaker APIs and any other AWS services needed. If not specified, the + function creates one using the default AWS configuration chain. + + Returns: + Context: The Feature Group Pipeline Version Context. 
+ """ + feature_group_pipeline_version_context = ( + _get_feature_group_pipeline_version_lineage_context_name( + feature_group_name=feature_group_name, + feature_group_creation_time=feature_group_creation_time, + ) + ) + return Context.load( + context_name=feature_group_pipeline_version_context, + sagemaker_session=sagemaker_session, + ) + + @staticmethod + def parse_name_from_arn(fg_uri: str) -> str: + """Parse the name from a string, if it's an ARN. Otherwise, return the string. + + Arguments: + fg_uri (str): The Feature Group Name or ARN. + + Returns: + str: The Feature Group Name. + """ + match = re.match(FEATURE_GROUP_ARN_REGEX_PATTERN, fg_uri) + if match: + feature_group_name = match.group(4) + return feature_group_name + return fg_uri diff --git a/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/lineage/_feature_processor_lineage.py b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/lineage/_feature_processor_lineage.py new file mode 100644 index 0000000000..d706b3b441 --- /dev/null +++ b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/lineage/_feature_processor_lineage.py @@ -0,0 +1,759 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Contains class to handle Lineage Associations""" +from __future__ import absolute_import +import logging +from datetime import datetime +from typing import Optional, Iterator, List, Dict, Set, Sequence, Union +import attr +from botocore.exceptions import ClientError + +from sagemaker.mlops.feature_store.feature_processor._event_bridge_rule_helper import ( + EventBridgeRuleHelper, +) +from sagemaker.mlops.feature_store.feature_processor._event_bridge_scheduler_helper import ( + EventBridgeSchedulerHelper, +) +from sagemaker.mlops.feature_store.feature_processor.lineage._lineage_association_handler import ( + LineageAssociationHandler, +) + +# pylint: disable=C0301 +from sagemaker.mlops.feature_store.feature_processor.lineage._feature_group_lineage_entity_handler import ( + FeatureGroupLineageEntityHandler, +) +from sagemaker.mlops.feature_store.feature_processor.lineage._feature_group_contexts import ( + FeatureGroupContexts, +) +from sagemaker.mlops.feature_store.feature_processor.lineage._pipeline_lineage_entity_handler import ( + PipelineLineageEntityHandler, +) +from sagemaker.mlops.feature_store.feature_processor.lineage._pipeline_schedule import ( + PipelineSchedule, +) +from sagemaker.mlops.feature_store.feature_processor.lineage._pipeline_trigger import ( + PipelineTrigger, +) + +# pylint: disable=C0301 +from sagemaker.mlops.feature_store.feature_processor.lineage._pipeline_version_lineage_entity_handler import ( + PipelineVersionLineageEntityHandler, +) +from sagemaker.mlops.feature_store.feature_processor.lineage._s3_lineage_entity_handler import ( + S3LineageEntityHandler, +) +from sagemaker.mlops.feature_store.feature_processor.lineage._transformation_code import ( + TransformationCode, +) +from sagemaker.mlops.feature_store.feature_processor.lineage.constants import 
( + SAGEMAKER, + LAST_UPDATE_TIME, + PIPELINE_CONTEXT_NAME_KEY, + PIPELINE_CONTEXT_VERSION_NAME_KEY, + FEATURE_GROUP_PIPELINE_VERSION_CONTEXT_TYPE, + DATA_SET, + TRANSFORMATION_CODE, + CREATION_TIME, + RESOURCE_NOT_FOUND, + ERROR, + CODE, + LAST_MODIFIED_TIME, + TRANSFORMATION_CODE_STATUS_INACTIVE, + TRANSFORMATION_CODE_STATUS_ACTIVE, + CONTRIBUTED_TO, +) +from sagemaker.core.lineage.context import Context +from sagemaker.core.lineage.artifact import Artifact +from sagemaker.core.lineage.association import AssociationSummary +from sagemaker.core.helper.session_helper import Session +from sagemaker.mlops.feature_store.feature_processor._data_source import ( + CSVDataSource, + FeatureGroupDataSource, + ParquetDataSource, + BaseDataSource, +) + +logger = logging.getLogger(SAGEMAKER) + + +@attr.s +class FeatureProcessorLineageHandler: + """Class to Create and Update FeatureProcessor Lineage Entities. + + Attributes: + pipeline_name (str): Pipeline Name. + pipeline_arn (str): The ARN of the Pipeline. + pipeline (str): The details of the Pipeline. + inputs (Sequence[Union[FeatureGroupDataSource, CSVDataSource, ParquetDataSource, + BaseDataSource]]): The inputs to the Feature processor. + output (str): The output Feature Group. + transformation_code (TransformationCode): The Transformation Code for Feature Processor. + sagemaker_session (Session): Session object which manages interactions + with Amazon SageMaker APIs and any other AWS services needed. If not specified, the + function creates one using the default AWS configuration chain. + """ + + pipeline_name: str = attr.ib() + pipeline_arn: str = attr.ib() + pipeline: Dict = attr.ib() + sagemaker_session: Session = attr.ib() + inputs: Sequence[ + Union[FeatureGroupDataSource, CSVDataSource, ParquetDataSource, BaseDataSource] + ] = attr.ib(default=None) + output: str = attr.ib(default=None) + transformation_code: TransformationCode = attr.ib(default=None) + + def create_lineage(self, tags: Optional[List[Dict[str, str]]] = None) -> None: + """Create and Update Feature Processor Lineage""" + input_feature_group_contexts: List[FeatureGroupContexts] = ( + self._retrieve_input_feature_group_contexts() + ) + output_feature_group_contexts: FeatureGroupContexts = ( + self._retrieve_output_feature_group_contexts() + ) + input_raw_data_artifacts: List[Artifact] = self._retrieve_input_raw_data_artifacts() + transformation_code_artifact: Optional[Artifact] = ( + S3LineageEntityHandler.create_transformation_code_artifact( + transformation_code=self.transformation_code, + pipeline_last_update_time=self.pipeline[LAST_MODIFIED_TIME].strftime("%s"), + sagemaker_session=self.sagemaker_session, + ) + ) + if transformation_code_artifact is not None: + logger.info("Created Transformation Code Artifact: %s", transformation_code_artifact) + if tags: + transformation_code_artifact.set_tags(tags) # pylint: disable=E1101 + # Create the Pipeline Lineage for the first time + if not self._check_if_pipeline_lineage_exists(): + self._create_new_pipeline_lineage( + input_feature_group_contexts=input_feature_group_contexts, + input_raw_data_artifacts=input_raw_data_artifacts, + output_feature_group_contexts=output_feature_group_contexts, + transformation_code_artifact=transformation_code_artifact, + ) + else: + self._update_pipeline_lineage( + input_feature_group_contexts=input_feature_group_contexts, + input_raw_data_artifacts=input_raw_data_artifacts, + output_feature_group_contexts=output_feature_group_contexts, + transformation_code_artifact=transformation_code_artifact, + 
) + + def get_pipeline_lineage_names(self) -> Optional[Dict[str, str]]: + """Retrieve Pipeline Lineage Names. + + Returns: + Optional[Dict[str, str]]: Pipeline and Pipeline version lineage names. + """ + if not self._check_if_pipeline_lineage_exists(): + return None + pipeline_context: Context = self._get_pipeline_context() + current_pipeline_version_context: Context = self._get_pipeline_version_context( + last_update_time=pipeline_context.properties[LAST_UPDATE_TIME] + ) + return { + PIPELINE_CONTEXT_NAME_KEY: pipeline_context.context_name, + PIPELINE_CONTEXT_VERSION_NAME_KEY: current_pipeline_version_context.context_name, + } + + def create_schedule_lineage( + self, + pipeline_name: str, + schedule_arn, + schedule_expression, + state, + start_date: datetime, + tags: Optional[List[Dict[str, str]]] = None, + ) -> None: + """Class to Create and Update FeatureProcessor Lineage Entities. + + Arguments: + pipeline_name (str): Pipeline Name. + schedule_arn (str): The ARN of the Schedule. + schedule_expression (str): The expression that defines when the schedule runs. + It supports at expression, rate expression and cron expression. + state (str):Specifies whether the schedule is enabled or disabled. Valid values are + ENABLED and DISABLED. See https://docs.aws.amazon.com/scheduler/latest/APIReference/ + API_CreateSchedule.html#scheduler-CreateSchedule-request-State for more details. + If not specified, it will default to ENABLED. + start_date (Optional[datetime]): The date, in UTC, after which the schedule can begin + invoking its target. Depending on the schedule’s recurrence expression, invocations + might occur on, or after, the StartDate you specify. + tags (Optional[List[Dict[str, str]]]): Custom tags to be attached to schedule + lineage resource. + """ + pipeline_context: Context = self._get_pipeline_context() + pipeline_version_context: Context = self._get_pipeline_version_context( + last_update_time=pipeline_context.properties[LAST_UPDATE_TIME] + ) + pipeline_schedule: PipelineSchedule = PipelineSchedule( + schedule_name=pipeline_name, + schedule_arn=schedule_arn, + schedule_expression=schedule_expression, + pipeline_name=pipeline_name, + state=state, + start_date=start_date.strftime("%s"), + ) + schedule_artifact: Artifact = S3LineageEntityHandler.retrieve_pipeline_schedule_artifact( + pipeline_schedule=pipeline_schedule, + sagemaker_session=self.sagemaker_session, + ) + if tags: + schedule_artifact.set_tags(tags) + + LineageAssociationHandler.add_upstream_schedule_associations( + schedule_artifact=schedule_artifact, + pipeline_version_context_arn=pipeline_version_context.context_arn, + sagemaker_session=self.sagemaker_session, + ) + + def create_trigger_lineage( + self, + pipeline_name: str, + trigger_arn: str, + event_pattern: str, + state: str, + tags: Optional[List[Dict[str, str]]] = None, + ) -> None: + """Class to Create and Update FeatureProcessor Pipeline Trigger Lineage Entities. + + Arguments: + pipeline_name (str): Pipeline Name. + trigger_arn (str): The ARN of the EventBridge Rule. + event_pattern (str): The event pattern for the rule. + state (str): Specifies whether the trigger is enabled or disabled. Valid values are + ENABLED and DISABLED. If not specified, it will default to ENABLED. + tags (Optional[List[Dict[str, str]]]): Custom tags to be attached to trigger + lineage resource. 
+ """ + pipeline_context: Context = self._get_pipeline_context() + pipeline_version_context: Context = self._get_pipeline_version_context( + last_update_time=pipeline_context.properties[LAST_UPDATE_TIME] + ) + pipeline_trigger: PipelineTrigger = PipelineTrigger( + trigger_name=pipeline_name, + trigger_arn=trigger_arn, + event_pattern=event_pattern, + pipeline_name=pipeline_name, + state=state, + ) + trigger_artifact: Artifact = S3LineageEntityHandler.retrieve_pipeline_trigger_artifact( + pipeline_trigger=pipeline_trigger, + sagemaker_session=self.sagemaker_session, + ) + if tags: + trigger_artifact.set_tags(tags) + + LineageAssociationHandler._add_association( + source_arn=trigger_artifact.artifact_arn, + destination_arn=pipeline_version_context.context_arn, + association_type=CONTRIBUTED_TO, + sagemaker_session=self.sagemaker_session, + ) + + def upsert_tags_for_lineage_resources(self, tags: List[Dict[str, str]]) -> None: + """Add or update tags for lineage resources using tags attached to sagemaker pipeline as + + source of truth. + + Args: + tags (List[Dict[str, str]]): Custom tags to be attached to lineage resources. + """ + if not tags: + return + pipeline_context: Context = self._get_pipeline_context() + current_pipeline_version_context: Context = self._get_pipeline_version_context( + last_update_time=pipeline_context.properties[LAST_UPDATE_TIME] + ) + input_raw_data_artifacts: List[Artifact] = self._retrieve_input_raw_data_artifacts() + pipeline_context.set_tags(tags) + current_pipeline_version_context.set_tags(tags) + for input_raw_data_artifact in input_raw_data_artifacts: + input_raw_data_artifact.set_tags(tags) + + event_bridge_scheduler_helper = EventBridgeSchedulerHelper( + self.sagemaker_session, + self.sagemaker_session.boto_session.client("scheduler"), + ) + event_bridge_schedule = event_bridge_scheduler_helper.describe_schedule(self.pipeline_name) + + event_bridge_rule_helper = EventBridgeRuleHelper( + self.sagemaker_session, + self.sagemaker_session.boto_session.client("events"), + ) + event_bridge_rule = event_bridge_rule_helper.describe_rule(self.pipeline_name) + + if event_bridge_schedule: + schedule_artifact_summary = S3LineageEntityHandler._load_artifact_from_s3_uri( + s3_uri=event_bridge_schedule["Arn"], + sagemaker_session=self.sagemaker_session, + ) + if schedule_artifact_summary is not None: + pipeline_schedule_artifact: Artifact = ( + S3LineageEntityHandler.load_artifact_from_arn( + artifact_arn=schedule_artifact_summary.artifact_arn, + sagemaker_session=self.sagemaker_session, + ) + ) + pipeline_schedule_artifact.set_tags(tags) + + if event_bridge_rule: + rule_artifact_summary = S3LineageEntityHandler._load_artifact_from_s3_uri( + s3_uri=event_bridge_rule["Arn"], + sagemaker_session=self.sagemaker_session, + ) + if rule_artifact_summary: + pipeline_trigger_artifact: Artifact = S3LineageEntityHandler.load_artifact_from_arn( + artifact_arn=rule_artifact_summary.artifact_arn, + sagemaker_session=self.sagemaker_session, + ) + pipeline_trigger_artifact.set_tags(tags) + + def _create_new_pipeline_lineage( + self, + input_feature_group_contexts: List[FeatureGroupContexts], + input_raw_data_artifacts: List[Artifact], + output_feature_group_contexts: FeatureGroupContexts, + transformation_code_artifact: Optional[Artifact], + ) -> None: + """Create pipeline lineage resources.""" + + pipeline_context = self._create_pipeline_lineage_for_new_pipeline() + pipeline_version_context = self._create_pipeline_version_lineage() + self._add_associations_for_pipeline( + # pylint: 
disable=no-member + pipeline_context_arn=pipeline_context.context_arn, + # pylint: disable=no-member + pipeline_versions_context_arn=pipeline_version_context.context_arn, + input_feature_group_contexts=input_feature_group_contexts, + input_raw_data_artifacts=input_raw_data_artifacts, + output_feature_group_contexts=output_feature_group_contexts, + transformation_code_artifact=transformation_code_artifact, + ) + LineageAssociationHandler.add_pipeline_and_pipeline_version_association( + # pylint: disable=no-member + pipeline_context_arn=pipeline_context.context_arn, + # pylint: disable=no-member + pipeline_version_context_arn=pipeline_version_context.context_arn, + sagemaker_session=self.sagemaker_session, + ) + + def _update_pipeline_lineage( + self, + input_feature_group_contexts: List[FeatureGroupContexts], + input_raw_data_artifacts: List[Artifact], + output_feature_group_contexts: FeatureGroupContexts, + transformation_code_artifact: Optional[Artifact], + ) -> None: + """Update pipeline lineage resources.""" + + # If pipeline lineage exists then determine whether to create a new version. + pipeline_context: Context = self._get_pipeline_context() + current_pipeline_version_context: Context = self._get_pipeline_version_context( + last_update_time=pipeline_context.properties[LAST_UPDATE_TIME] + ) + upstream_feature_group_associations: Iterator[AssociationSummary] = ( + LineageAssociationHandler.list_upstream_associations( + # pylint: disable=no-member + entity_arn=current_pipeline_version_context.context_arn, + source_type=FEATURE_GROUP_PIPELINE_VERSION_CONTEXT_TYPE, + sagemaker_session=self.sagemaker_session, + ) + ) + + upstream_raw_data_associations: Iterator[AssociationSummary] = ( + LineageAssociationHandler.list_upstream_associations( + # pylint: disable=no-member + entity_arn=current_pipeline_version_context.context_arn, + source_type=DATA_SET, + sagemaker_session=self.sagemaker_session, + ) + ) + + upstream_transformation_code: Iterator[AssociationSummary] = ( + LineageAssociationHandler.list_upstream_associations( + # pylint: disable=no-member + entity_arn=current_pipeline_version_context.context_arn, + source_type=TRANSFORMATION_CODE, + sagemaker_session=self.sagemaker_session, + ) + ) + + downstream_feature_group_associations: Iterator[AssociationSummary] = ( + LineageAssociationHandler.list_downstream_associations( + # pylint: disable=no-member + entity_arn=current_pipeline_version_context.context_arn, + destination_type=FEATURE_GROUP_PIPELINE_VERSION_CONTEXT_TYPE, + sagemaker_session=self.sagemaker_session, + ) + ) + + is_upstream_feature_group_equal: bool = self._compare_upstream_feature_groups( + upstream_feature_group_associations=upstream_feature_group_associations, + input_feature_group_contexts=input_feature_group_contexts, + ) + is_downstream_feature_group_equal: bool = self._compare_downstream_feature_groups( + downstream_feature_group_associations=downstream_feature_group_associations, + output_feature_group_contexts=output_feature_group_contexts, + ) + is_upstream_raw_data_equal: bool = self._compare_upstream_raw_data( + upstream_raw_data_associations=upstream_raw_data_associations, + input_raw_data_artifacts=input_raw_data_artifacts, + ) + + self._update_last_transformation_code( + upstream_transformation_code_associations=upstream_transformation_code + ) + if ( + not is_upstream_feature_group_equal + or not is_downstream_feature_group_equal + or not is_upstream_raw_data_equal + ): + if not is_upstream_raw_data_equal: + logger.info("Raw data inputs have changed from 
the last pipeline configuration.") + if not is_upstream_feature_group_equal: + logger.info( + "Feature group inputs have changed from the last pipeline configuration." + ) + if not is_downstream_feature_group_equal: + logger.info( + "Feature Group output has changed from the last pipeline configuration." + ) + pipeline_context.properties["LastUpdateTime"] = self.pipeline[ + "LastModifiedTime" + ].strftime("%s") + PipelineLineageEntityHandler.update_pipeline_context(pipeline_context=pipeline_context) + new_pipeline_version_context: Context = self._create_pipeline_version_lineage() + self._add_associations_for_pipeline( + # pylint: disable=no-member + pipeline_context_arn=pipeline_context.context_arn, + # pylint: disable=no-member + pipeline_versions_context_arn=new_pipeline_version_context.context_arn, + input_feature_group_contexts=input_feature_group_contexts, + input_raw_data_artifacts=input_raw_data_artifacts, + output_feature_group_contexts=output_feature_group_contexts, + transformation_code_artifact=transformation_code_artifact, + ) + LineageAssociationHandler.add_pipeline_and_pipeline_version_association( + # pylint: disable=no-member + pipeline_context_arn=pipeline_context.context_arn, + # pylint: disable=no-member + pipeline_version_context_arn=new_pipeline_version_context.context_arn, + sagemaker_session=self.sagemaker_session, + ) + elif transformation_code_artifact is not None: + # We will append the new transformation code artifact + # to the existing pipeline version. + LineageAssociationHandler.add_upstream_transformation_code_associations( + transformation_code_artifact=transformation_code_artifact, + # pylint: disable=no-member + pipeline_version_context_arn=current_pipeline_version_context.context_arn, + sagemaker_session=self.sagemaker_session, + ) + + def _retrieve_input_raw_data_artifacts(self) -> List[Artifact]: + """Retrieve input Raw Data Artifacts. + + Returns: + List[Artifact]: List of Raw Data Artifacts. + """ + raw_data_artifacts: List[Artifact] = list() + raw_data_uri_set: Set[str] = set() + + for data_source in self.inputs: + if isinstance(data_source, (CSVDataSource, ParquetDataSource, BaseDataSource)): + data_source_uri = ( + data_source.s3_uri + if isinstance(data_source, (CSVDataSource, ParquetDataSource)) + else data_source.data_source_unique_id + ) + if data_source_uri not in raw_data_uri_set: + raw_data_uri_set.add(data_source_uri) + raw_data_artifacts.append( + S3LineageEntityHandler.retrieve_raw_data_artifact( + raw_data=data_source, + sagemaker_session=self.sagemaker_session, + ) + ) + + return raw_data_artifacts + + def _compare_upstream_raw_data( + self, + upstream_raw_data_associations: Iterator[AssociationSummary], + input_raw_data_artifacts: List[Artifact], + ) -> bool: + """Compare the existing and the new upstream Raw Data. + + Arguments: + upstream_raw_data_associations (Iterator[AssociationSummary]): + Upstream existing raw data associations for the pipeline. + input_raw_data_artifacts (List[Artifact]): + New Upstream raw data for the pipeline. + + Returns: + bool: Boolean if old and new upstream is same. 
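+
+        Example (illustrative): existing associations referencing artifacts {A, B} compared
+            with new inputs resolving to artifacts {A, B} returns True; {A, B} compared
+            with {A, C} returns False.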
+ """ + raw_data_association_set = { + raw_data_association.source_arn + for raw_data_association in upstream_raw_data_associations + } + if len(raw_data_association_set) != len(input_raw_data_artifacts): + return False + for raw_data in input_raw_data_artifacts: + if raw_data.artifact_arn not in raw_data_association_set: + return False + return True + + def _compare_downstream_feature_groups( + self, + downstream_feature_group_associations: Iterator[AssociationSummary], + output_feature_group_contexts: FeatureGroupContexts, + ) -> bool: + """Compare the existing and the new downstream Feature Groups. + + Arguments: + downstream_feature_group_associations (Iterator[AssociationSummary]): + Downstream existing Feature Group association for the pipeline. + output_feature_group_contexts (List[Artifact]): + New Downstream Feature group for the pipeline. + + Returns: + bool: Boolean if old and new Downstream is same. + """ + feature_group_association_set = set() + for feature_group_association in downstream_feature_group_associations: + feature_group_association_set.add(feature_group_association.destination_arn) + if len(feature_group_association_set) != 1: + ValueError( + f"There should only be one Feature Group as output, " + f"instead we got {len(feature_group_association_set)}. " + f"With Feature Group Versions Contexts: {feature_group_association_set}" + ) + return ( + output_feature_group_contexts.pipeline_version_context_arn + in feature_group_association_set + ) + + def _compare_upstream_feature_groups( + self, + upstream_feature_group_associations: Iterator[AssociationSummary], + input_feature_group_contexts: List[FeatureGroupContexts], + ) -> bool: + """Compare the existing and the new upstream Feature Group. + + Arguments: + upstream_feature_group_associations (Iterator[AssociationSummary]): + Upstream existing Feature Group association for the pipeline. + input_feature_group_contexts (List[Artifact]): + New Upstream Feature group for the pipeline. + + Returns: + bool: Boolean if old and new upstream is same. + """ + feature_group_association_set = set() + for feature_group_association in upstream_feature_group_associations: + feature_group_association_set.add(feature_group_association.source_arn) + if len(feature_group_association_set) != len(input_feature_group_contexts): + return False + for feature_group in input_feature_group_contexts: + if feature_group.pipeline_version_context_arn not in feature_group_association_set: + return False + return True + + def _update_last_transformation_code( + self, upstream_transformation_code_associations: Iterator[AssociationSummary] + ) -> None: + """Compare the existing and the new upstream Transformation Code. + + Arguments: + upstream_transformation_code_associations (Iterator[AssociationSummary]): + Upstream existing transformation code associations for the pipeline. + + Returns: + bool: Boolean if old and new upstream is same. 
+ """ + upstream_transformation_code = next(upstream_transformation_code_associations, None) + if upstream_transformation_code is None: + return + + last_transformation_code_artifact = S3LineageEntityHandler.load_artifact_from_arn( + artifact_arn=upstream_transformation_code.source_arn, + sagemaker_session=self.sagemaker_session, + ) + logger.info( + "Retrieved previous transformation code artifact: %s", last_transformation_code_artifact + ) + if ( + last_transformation_code_artifact.properties["state"] + == TRANSFORMATION_CODE_STATUS_ACTIVE + ): + last_transformation_code_artifact.properties["state"] = ( + TRANSFORMATION_CODE_STATUS_INACTIVE + ) + last_transformation_code_artifact.properties["exclusive_end_date"] = self.pipeline[ + LAST_MODIFIED_TIME + ].strftime("%s") + S3LineageEntityHandler.update_transformation_code_artifact( + transformation_code_artifact=last_transformation_code_artifact + ) + logger.info("Updated the last transformation artifact") + + def _get_pipeline_context(self) -> Context: + """Retrieve Pipeline Context. + + Returns: + Context: The Pipeline Context. + """ + return PipelineLineageEntityHandler.load_pipeline_context( + pipeline_name=self.pipeline_name, + creation_time=self.pipeline[CREATION_TIME].strftime("%s"), + sagemaker_session=self.sagemaker_session, + ) + + def _get_pipeline_version_context(self, last_update_time: str) -> Context: + """Retrieve Pipeline Version Context. + + Returns: + Context: The Pipeline Version Context. + """ + return PipelineVersionLineageEntityHandler.load_pipeline_version_context( + pipeline_name=self.pipeline_name, + last_update_time=last_update_time, + sagemaker_session=self.sagemaker_session, + ) + + def _check_if_pipeline_lineage_exists(self) -> bool: + """Check if Pipeline Lineage exists. + + Returns: + bool: Check if pipeline lineage exists. + """ + try: + PipelineLineageEntityHandler.load_pipeline_context( + pipeline_name=self.pipeline_name, + creation_time=self.pipeline[CREATION_TIME].strftime("%s"), + sagemaker_session=self.sagemaker_session, + ) + return True + except ClientError as e: + if e.response[ERROR][CODE] == RESOURCE_NOT_FOUND: + return False + raise e + + def _retrieve_input_feature_group_contexts(self) -> List[FeatureGroupContexts]: + """Retrieve input Feature Groups' Context ARNs. + + Returns: + List[FeatureGroupContexts]: List of Input Feature Groups for the pipeline. + """ + feature_group_contexts: List[FeatureGroupContexts] = list() + feature_group_input_set: Set[str] = set() + for data_source in self.inputs: + if isinstance(data_source, FeatureGroupDataSource): + feature_group_name: str = FeatureGroupLineageEntityHandler.parse_name_from_arn( + data_source.name + ) + if feature_group_name not in feature_group_input_set: + feature_group_input_set.add(feature_group_name) + feature_group_contexts.append( + FeatureGroupLineageEntityHandler.retrieve_feature_group_context_arns( + feature_group_name=data_source.name, + sagemaker_session=self.sagemaker_session, + ) + ) + return feature_group_contexts + + def _retrieve_output_feature_group_contexts(self) -> FeatureGroupContexts: + """Retrieve output Feature Group's Context ARNs. + + Returns: + FeatureGroupContexts: The output Feature Group for the pipeline. + """ + return FeatureGroupLineageEntityHandler.retrieve_feature_group_context_arns( + feature_group_name=self.output, sagemaker_session=self.sagemaker_session + ) + + def _create_pipeline_lineage_for_new_pipeline(self) -> Context: + """Create Pipeline Context for a new pipeline. 
+ + Returns: + Context: The Pipeline Context. + """ + return PipelineLineageEntityHandler.create_pipeline_context( + pipeline_name=self.pipeline_name, + pipeline_arn=self.pipeline_arn, + creation_time=self.pipeline[CREATION_TIME].strftime("%s"), + last_update_time=self.pipeline[LAST_MODIFIED_TIME].strftime("%s"), + sagemaker_session=self.sagemaker_session, + ) + + def _create_pipeline_version_lineage(self) -> Context: + """Create a new Pipeline Version Context. + + Returns: + Context: The Pipeline Versions Context. + """ + return PipelineVersionLineageEntityHandler.create_pipeline_version_context( + pipeline_name=self.pipeline_name, + pipeline_arn=self.pipeline_arn, + last_update_time=self.pipeline[LAST_MODIFIED_TIME].strftime("%s"), + sagemaker_session=self.sagemaker_session, + ) + + def _add_associations_for_pipeline( + self, + pipeline_context_arn: str, + pipeline_versions_context_arn: str, + input_feature_group_contexts: List[FeatureGroupContexts], + input_raw_data_artifacts: List[Artifact], + output_feature_group_contexts: FeatureGroupContexts, + transformation_code_artifact: Optional[Artifact] = None, + ) -> None: + """Add Feature Processor Lineage Associations for the Pipeline + + Arguments: + pipeline_context_arn (str): The pipeline Context ARN. + pipeline_versions_context_arn (str): The pipeline Version Context ARN. + input_feature_group_contexts (List[FeatureGroupContexts]): List of input FeatureGroups. + input_raw_data_artifacts (List[Artifact]): List of input raw data. + output_feature_group_contexts (FeatureGroupContexts): Output Feature Group + transformation_code_artifact (Optional[Artifact]): The transformation Code. + """ + LineageAssociationHandler.add_upstream_feature_group_data_associations( + feature_group_inputs=input_feature_group_contexts, + pipeline_context_arn=pipeline_context_arn, + pipeline_version_context_arn=pipeline_versions_context_arn, + sagemaker_session=self.sagemaker_session, + ) + + LineageAssociationHandler.add_downstream_feature_group_data_associations( + feature_group_output=output_feature_group_contexts, + pipeline_context_arn=pipeline_context_arn, + pipeline_version_context_arn=pipeline_versions_context_arn, + sagemaker_session=self.sagemaker_session, + ) + + LineageAssociationHandler.add_upstream_raw_data_associations( + raw_data_inputs=input_raw_data_artifacts, + pipeline_context_arn=pipeline_context_arn, + pipeline_version_context_arn=pipeline_versions_context_arn, + sagemaker_session=self.sagemaker_session, + ) + + if transformation_code_artifact is not None: + LineageAssociationHandler.add_upstream_transformation_code_associations( + transformation_code_artifact=transformation_code_artifact, + pipeline_version_context_arn=pipeline_versions_context_arn, + sagemaker_session=self.sagemaker_session, + ) diff --git a/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/lineage/_feature_processor_lineage_name_helper.py b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/lineage/_feature_processor_lineage_name_helper.py new file mode 100644 index 0000000000..1a4e9ed04f --- /dev/null +++ b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/lineage/_feature_processor_lineage_name_helper.py @@ -0,0 +1,101 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. 
A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Contains class to handle lineage resource name generation.""" +from __future__ import absolute_import + +FEATURE_PROCESSOR_CREATED_PREFIX = "sm-fs-fe" +FEATURE_PROCESSOR_CREATED_TRIGGER_PREFIX = "sm-fs-fe-trigger" +FEATURE_GROUP_PIPELINE_CONTEXT_SUFFIX = "feature-group-pipeline" +FEATURE_GROUP_PIPELINE_CONTEXT_VERSION_SUFFIX = "feature-group-pipeline-version" +FEATURE_PROCESSOR_PIPELINE_CONTEXT_SUFFIX = "fep" +FEATURE_PROCESSOR_PIPELINE_VERSION_CONTEXT_SUFFIX = "fep-ver" + + +def _get_feature_processor_lineage_context_name( + resource_name: str, + resource_creation_time: str, + lineage_context_prefix: str = None, + lineage_context_suffix: str = None, +) -> str: + """Generic naming generation function for lineage resources used by feature_processor.""" + context_name_base = [f"{resource_name}-{resource_creation_time}"] + if lineage_context_prefix: + context_name_base.insert(0, lineage_context_prefix) + if lineage_context_suffix: + context_name_base.append(lineage_context_suffix) + return "-".join(context_name_base) + + +def _get_feature_group_lineage_context_name( + feature_group_name: str, feature_group_creation_time: str +) -> str: + """Generate context name for feature group contexts.""" + return _get_feature_processor_lineage_context_name( + resource_name=feature_group_name, resource_creation_time=feature_group_creation_time + ) + + +def _get_feature_group_pipeline_lineage_context_name( + feature_group_name: str, feature_group_creation_time: str +) -> str: + """Generate context name for feature group pipeline.""" + return _get_feature_processor_lineage_context_name( + resource_name=feature_group_name, + resource_creation_time=feature_group_creation_time, + lineage_context_suffix=FEATURE_GROUP_PIPELINE_CONTEXT_SUFFIX, + ) + + +def _get_feature_group_pipeline_version_lineage_context_name( + feature_group_name: str, feature_group_creation_time: str +) -> str: + """Generate context name for feature group pipeline version.""" + return _get_feature_processor_lineage_context_name( + resource_name=feature_group_name, + resource_creation_time=feature_group_creation_time, + lineage_context_suffix=FEATURE_GROUP_PIPELINE_CONTEXT_VERSION_SUFFIX, + ) + + +def _get_feature_processor_pipeline_lineage_context_name( + pipeline_name: str, pipeline_creation_time: str +) -> str: + """Generate context name for feature processor pipeline.""" + return _get_feature_processor_lineage_context_name( + resource_name=pipeline_name, + resource_creation_time=pipeline_creation_time, + lineage_context_prefix=FEATURE_PROCESSOR_CREATED_PREFIX, + lineage_context_suffix=FEATURE_PROCESSOR_PIPELINE_CONTEXT_SUFFIX, + ) + + +def _get_feature_processor_pipeline_version_lineage_context_name( + pipeline_name: str, pipeline_last_update_time: str +) -> str: + """Generate context name for feature processor pipeline version.""" + return _get_feature_processor_lineage_context_name( + resource_name=pipeline_name, + resource_creation_time=pipeline_last_update_time, + lineage_context_prefix=FEATURE_PROCESSOR_CREATED_PREFIX, + lineage_context_suffix=FEATURE_PROCESSOR_PIPELINE_VERSION_CONTEXT_SUFFIX, + ) + + +def _get_feature_processor_schedule_lineage_artifact_name(schedule_name: 
str) -> str:
+    """Generate artifact name for feature processor pipeline schedule."""
+    return "-".join([FEATURE_PROCESSOR_CREATED_PREFIX, schedule_name])
+
+
+def _get_feature_processor_trigger_lineage_artifact_name(trigger_name: str) -> str:
+    """Generate artifact name for feature processor pipeline trigger."""
+    return "-".join([FEATURE_PROCESSOR_CREATED_TRIGGER_PREFIX, trigger_name])
diff --git a/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/lineage/_lineage_association_handler.py b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/lineage/_lineage_association_handler.py
new file mode 100644
index 0000000000..0413b5d7c1
--- /dev/null
+++ b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/lineage/_lineage_association_handler.py
@@ -0,0 +1,300 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+"""Contains class to handle Lineage Associations"""
+from __future__ import absolute_import
+import logging
+from typing import List, Optional, Iterator
+from botocore.exceptions import ClientError
+
+from sagemaker.core.helper.session_helper import Session
+from sagemaker.mlops.feature_store.feature_processor.lineage._feature_group_contexts import (
+    FeatureGroupContexts,
+)
+from sagemaker.mlops.feature_store.feature_processor._constants import VALIDATION_EXCEPTION
+from sagemaker.mlops.feature_store.feature_processor.lineage.constants import (
+    CONTRIBUTED_TO,
+    ERROR,
+    CODE,
+    SAGEMAKER,
+    ASSOCIATED_WITH,
+)
+from sagemaker.core.lineage.artifact import Artifact
+from sagemaker.core.lineage.association import Association, AssociationSummary
+
+logger = logging.getLogger(SAGEMAKER)
+
+
+class LineageAssociationHandler:
+    """Class to handle the FeatureProcessor Lineage Associations"""
+
+    @staticmethod
+    def add_upstream_feature_group_data_associations(
+        feature_group_inputs: List[FeatureGroupContexts],
+        pipeline_context_arn: str,
+        pipeline_version_context_arn: str,
+        sagemaker_session: Session,
+    ) -> None:
+        """Add the FeatureProcessor Upstream Feature Group Lineage Associations.
+
+        Arguments:
+            feature_group_inputs (List[FeatureGroupContexts]): The input Feature Group List.
+            pipeline_context_arn (str): The pipeline context arn.
+            pipeline_version_context_arn (str): The pipeline version context arn.
+            sagemaker_session (Session): Session object which manages interactions
+                with Amazon SageMaker APIs and any other AWS services needed. If not specified, the
+                function creates one using the default AWS configuration chain.
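+
+        Example (illustrative): each input Feature Group yields two CONTRIBUTED_TO
+            associations, one from its pipeline context to the pipeline context and
+            one from its pipeline version context to the pipeline version context.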
+ """ + for feature_group in feature_group_inputs: + LineageAssociationHandler._add_association( + source_arn=feature_group.pipeline_context_arn, + destination_arn=pipeline_context_arn, + association_type=CONTRIBUTED_TO, + sagemaker_session=sagemaker_session, + ) + LineageAssociationHandler._add_association( + source_arn=feature_group.pipeline_version_context_arn, + destination_arn=pipeline_version_context_arn, + association_type=CONTRIBUTED_TO, + sagemaker_session=sagemaker_session, + ) + + @staticmethod + def add_upstream_raw_data_associations( + raw_data_inputs: List[Artifact], + pipeline_context_arn: str, + pipeline_version_context_arn: str, + sagemaker_session: Session, + ) -> None: + """Add the FeatureProcessor Upstream Raw Data Lineage Associations. + + Arguments: + raw_data_inputs (List[Artifact]): The input raw data List. + pipeline_context_arn (str): The pipeline context arn. + pipeline_version_context_arn (str): The pipeline version context arn. + sagemaker_session (Session): Session object which manages interactions + with Amazon SageMaker APIs and any other AWS services needed. If not specified, the + function creates one using the default AWS configuration chain. + """ + for raw_data_artifact in raw_data_inputs: + LineageAssociationHandler._add_association( + source_arn=raw_data_artifact.artifact_arn, + destination_arn=pipeline_context_arn, + association_type=CONTRIBUTED_TO, + sagemaker_session=sagemaker_session, + ) + LineageAssociationHandler._add_association( + source_arn=raw_data_artifact.artifact_arn, + destination_arn=pipeline_version_context_arn, + association_type=CONTRIBUTED_TO, + sagemaker_session=sagemaker_session, + ) + + @staticmethod + def add_upstream_transformation_code_associations( + transformation_code_artifact: Artifact, + pipeline_version_context_arn: str, + sagemaker_session: Session, + ) -> None: + """Add the FeatureProcessor Upstream Transformation Code Lineage Associations. + + Arguments: + transformation_code_artifact (Artifact): The transformation Code Artifact. + pipeline_version_context_arn (str): The pipeline version context arn. + sagemaker_session (Session): Session object which manages interactions + with Amazon SageMaker APIs and any other AWS services needed. If not specified, the + function creates one using the default AWS configuration chain. + """ + LineageAssociationHandler._add_association( + source_arn=transformation_code_artifact.artifact_arn, + destination_arn=pipeline_version_context_arn, + association_type=CONTRIBUTED_TO, + sagemaker_session=sagemaker_session, + ) + + @staticmethod + def add_upstream_schedule_associations( + schedule_artifact: Artifact, + pipeline_version_context_arn: str, + sagemaker_session: Session, + ) -> None: + """Add the FeatureProcessor Upstream Schedule Lineage Associations. + + Arguments: + schedule_artifact (Artifact): The schedule Artifact. + pipeline_version_context_arn (str): The pipeline version context arn. + sagemaker_session (Session): Session object which manages interactions + with Amazon SageMaker APIs and any other AWS services needed. If not specified, the + function creates one using the default AWS configuration chain. 
+ """ + LineageAssociationHandler._add_association( + source_arn=schedule_artifact.artifact_arn, + destination_arn=pipeline_version_context_arn, + association_type=CONTRIBUTED_TO, + sagemaker_session=sagemaker_session, + ) + + @staticmethod + def add_downstream_feature_group_data_associations( + feature_group_output: FeatureGroupContexts, + pipeline_context_arn: str, + pipeline_version_context_arn: str, + sagemaker_session: Session, + ) -> None: + """Add the FeatureProcessor Downstream Feature Group Lineage Associations. + + Arguments: + feature_group_output (FeatureGroupContexts): The output Feature Group. + pipeline_context_arn (str): The pipeline context arn. + pipeline_version_context_arn (str): The pipeline version context arn. + sagemaker_session (Session): Session object which manages interactions + with Amazon SageMaker APIs and any other AWS services needed. If not specified, the + function creates one using the default AWS configuration chain. + """ + LineageAssociationHandler._add_association( + source_arn=pipeline_context_arn, + destination_arn=feature_group_output.pipeline_context_arn, + association_type=CONTRIBUTED_TO, + sagemaker_session=sagemaker_session, + ) + LineageAssociationHandler._add_association( + source_arn=pipeline_version_context_arn, + destination_arn=feature_group_output.pipeline_version_context_arn, + association_type="ContributedTo", + sagemaker_session=sagemaker_session, + ) + + @staticmethod + def add_pipeline_and_pipeline_version_association( + pipeline_context_arn: str, + pipeline_version_context_arn: str, + sagemaker_session: Session, + ) -> None: + """Add the FeatureProcessor Lineage Association + + between the Pipeline and the Pipeline Versions. + + Arguments: + pipeline_context_arn (str): The pipeline context arn. + pipeline_version_context_arn (str): The pipeline version context arn. + sagemaker_session (Session): Session object which manages interactions + with Amazon SageMaker APIs and any other AWS services needed. If not specified, the + function creates one using the default AWS configuration chain. + """ + LineageAssociationHandler._add_association( + source_arn=pipeline_context_arn, + destination_arn=pipeline_version_context_arn, + association_type=ASSOCIATED_WITH, + sagemaker_session=sagemaker_session, + ) + + @staticmethod + def list_upstream_associations( + entity_arn: str, source_type: str, sagemaker_session: Session + ) -> Iterator[AssociationSummary]: + """List Upstream Lineage Associations. + + Arguments: + entity_arn (str): The Lineage Entity ARN. + source_type (str): The Source Type. + sagemaker_session (Session): Session object which manages interactions + with Amazon SageMaker APIs and any other AWS services needed. If not specified, the + function creates one using the default AWS configuration chain. + """ + return LineageAssociationHandler._list_association( + destination_arn=entity_arn, + source_type=source_type, + sagemaker_session=sagemaker_session, + ) + + @staticmethod + def list_downstream_associations( + entity_arn: str, destination_type: str, sagemaker_session: Session + ) -> Iterator[AssociationSummary]: + """List Downstream Lineage Associations. + + Arguments: + entity_arn (str): The Lineage Entity ARN. + destination_type (str): The Destination Type. + sagemaker_session (Session): Session object which manages interactions + with Amazon SageMaker APIs and any other AWS services needed. If not specified, the + function creates one using the default AWS configuration chain. 
+ """ + return LineageAssociationHandler._list_association( + source_arn=entity_arn, + destination_type=destination_type, + sagemaker_session=sagemaker_session, + ) + + @staticmethod + def _add_association( + source_arn: str, + destination_arn: str, + association_type: str, + sagemaker_session: Session, + ) -> None: + """Add Lineage Association. + + Arguments: + source_arn (str): The source ARN. + destination_arn (str): The destination ARN. + association_type (str): The association type. + sagemaker_session (Session): Session object which manages interactions + with Amazon SageMaker APIs and any other AWS services needed. If not specified, the + function creates one using the default AWS configuration chain. + """ + try: + logger.info( + "Adding association with source_arn: " + "%s, destination_arn: %s and association_type: %s.", + source_arn, + destination_arn, + association_type, + ) + Association.create( + source_arn=source_arn, + destination_arn=destination_arn, + association_type=association_type, + sagemaker_session=sagemaker_session, + ) + except ClientError as e: + if e.response[ERROR][CODE] == VALIDATION_EXCEPTION: + logger.info("Association already exists") + else: + raise e + + @staticmethod + def _list_association( + sagemaker_session: Session, + source_arn: Optional[str] = None, + source_type: Optional[str] = None, + destination_arn: Optional[str] = None, + destination_type: Optional[str] = None, + ) -> Iterator[AssociationSummary]: + """List Lineage Associations. + + Arguments: + source_arn (str): The source ARN. + source_type (str): The source type. + destination_arn (str): The destination ARN. + destination_type (str): The destination type. + sagemaker_session (Session): Session object which manages interactions + with Amazon SageMaker APIs and any other AWS services needed. If not specified, the + function creates one using the default AWS configuration chain. + """ + return Association.list( + source_arn=source_arn, + source_type=source_type, + destination_arn=destination_arn, + destination_type=destination_type, + sagemaker_session=sagemaker_session, + ) diff --git a/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/lineage/_pipeline_lineage_entity_handler.py b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/lineage/_pipeline_lineage_entity_handler.py new file mode 100644 index 0000000000..3bf80e9d95 --- /dev/null +++ b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/lineage/_pipeline_lineage_entity_handler.py @@ -0,0 +1,105 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+"""Contains class to handle Pipeline Lineage""" +from __future__ import absolute_import +import logging + +from sagemaker.core.helper.session_helper import Session +from sagemaker.mlops.feature_store.feature_processor.lineage.constants import ( + SAGEMAKER, + PIPELINE_NAME_KEY, + PIPELINE_CREATION_TIME_KEY, + LAST_UPDATE_TIME_KEY, +) +from sagemaker.core.lineage.context import Context + +# pylint: disable=C0301 +from sagemaker.mlops.feature_store.feature_processor.lineage._feature_processor_lineage_name_helper import ( + _get_feature_processor_pipeline_lineage_context_name, +) +from sagemaker.core.lineage import context + +logger = logging.getLogger(SAGEMAKER) + + +class PipelineLineageEntityHandler: + """Class for handling FeatureProcessor Pipeline Lineage""" + + @staticmethod + def create_pipeline_context( + pipeline_name: str, + pipeline_arn: str, + creation_time: str, + last_update_time: str, + sagemaker_session: Session, + ) -> Context: + """Create the FeatureProcessor Pipeline context. + + Arguments: + pipeline_name (str): The pipeline name. + pipeline_arn (str): The pipeline ARN. + creation_time (str): The pipeline creation time. + last_update_time (str): The pipeline last update time. + sagemaker_session (Session): Session object which manages interactions + with Amazon SageMaker APIs and any other AWS services needed. If not specified, the + function creates one using the default AWS configuration chain. + + Returns: + Context: The pipeline context. + """ + return context.Context.create( + context_name=_get_feature_processor_pipeline_lineage_context_name( + pipeline_name, creation_time + ), + context_type="FeatureEngineeringPipeline", + source_uri=pipeline_arn, + source_type=creation_time, + properties={ + PIPELINE_NAME_KEY: pipeline_name, + PIPELINE_CREATION_TIME_KEY: creation_time, + LAST_UPDATE_TIME_KEY: last_update_time, + }, + sagemaker_session=sagemaker_session, + ) + + @staticmethod + def load_pipeline_context( + pipeline_name: str, creation_time: str, sagemaker_session: Session + ) -> Context: + """Load the FeatureProcessor Pipeline context. + + Arguments: + pipeline_name (str): The pipeline name. + creation_time (str): The pipeline creation time. + sagemaker_session (Session): Session object which manages interactions + with Amazon SageMaker APIs and any other AWS services needed. If not specified, the + function creates one using the default AWS configuration chain. + + Returns: + Context: The pipeline context. + """ + return Context.load( + context_name=_get_feature_processor_pipeline_lineage_context_name( + pipeline_name, creation_time + ), + sagemaker_session=sagemaker_session, + ) + + @staticmethod + def update_pipeline_context(pipeline_context: Context) -> None: + """Update the FeatureProcessor Pipeline context + + Arguments: + pipeline_context (Context): The pipeline context. + """ + pipeline_context.save() diff --git a/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/lineage/_pipeline_schedule.py b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/lineage/_pipeline_schedule.py new file mode 100644 index 0000000000..08f10fb8fb --- /dev/null +++ b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/lineage/_pipeline_schedule.py @@ -0,0 +1,44 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. 
A copy of
+# the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+"""Contains class to store the Pipeline Schedule"""
+from __future__ import absolute_import
+import attr
+
+
+@attr.s
+class PipelineSchedule:
+    """A Schedule definition for FeatureProcessor Lineage.
+
+    Attributes:
+        schedule_name (str): Schedule Name.
+        schedule_arn (str): The ARN of the Schedule.
+        schedule_expression (str): The expression that defines when the schedule runs. It supports
+            at, rate, and cron expressions. See https://docs.aws.amazon.com/
+            scheduler/latest/APIReference/API_CreateSchedule.html#scheduler-CreateSchedule-request
+            -ScheduleExpression for more details.
+        pipeline_name (str): The SageMaker Pipeline name that will be scheduled.
+        state (str): Specifies whether the schedule is enabled or disabled. Valid values are
+            ENABLED and DISABLED. See https://docs.aws.amazon.com/scheduler/latest/APIReference/
+            API_CreateSchedule.html#scheduler-CreateSchedule-request-State for more details.
+            If not specified, it will default to DISABLED.
+        start_date (str): The date, in UTC, after which the schedule can begin
+            invoking its target. Depending on the schedule's recurrence expression, invocations
+            might occur on, or after, the StartDate you specify.
+    """
+
+    schedule_name: str = attr.ib()
+    schedule_arn: str = attr.ib()
+    schedule_expression: str = attr.ib()
+    pipeline_name: str = attr.ib()
+    state: str = attr.ib()
+    start_date: str = attr.ib()
diff --git a/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/lineage/_pipeline_trigger.py b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/lineage/_pipeline_trigger.py
new file mode 100644
index 0000000000..e58003f396
--- /dev/null
+++ b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/lineage/_pipeline_trigger.py
@@ -0,0 +1,36 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+"""Contains class to store the Pipeline Trigger"""
+from __future__ import absolute_import
+import attr
+
+
+@attr.s
+class PipelineTrigger:
+    """An event-based trigger definition for FeatureProcessor Lineage.
+
+    Attributes:
+        trigger_name (str): Trigger Name.
+        trigger_arn (str): The ARN of the Trigger.
+        event_pattern (str): The event pattern. For more information, see Amazon EventBridge
+            event patterns in the Amazon EventBridge User Guide.
+        pipeline_name (str): The SageMaker Pipeline name that will be triggered.
+        state (str): Specifies whether the trigger is enabled or disabled. Valid values are
+            ENABLED and DISABLED.
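+
+    Example (an illustrative sketch; the names, ARN, and event pattern are hypothetical):
+        >>> trigger = PipelineTrigger(
+        ...     trigger_name="my-trigger",
+        ...     trigger_arn="arn:aws:events:...:rule/my-trigger",
+        ...     event_pattern='{"source": ["aws.sagemaker"]}',
+        ...     pipeline_name="my-pipeline",
+        ...     state="ENABLED",
+        ... )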
+ """ + + trigger_name: str = attr.ib() + trigger_arn: str = attr.ib() + event_pattern: str = attr.ib() + pipeline_name: str = attr.ib() + state: str = attr.ib() diff --git a/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/lineage/_pipeline_version_lineage_entity_handler.py b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/lineage/_pipeline_version_lineage_entity_handler.py new file mode 100644 index 0000000000..5d0b4c979b --- /dev/null +++ b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/lineage/_pipeline_version_lineage_entity_handler.py @@ -0,0 +1,92 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Contains class to handle Pipeline Version Lineage""" +from __future__ import absolute_import +import logging + +from sagemaker.core.helper.session_helper import Session +from sagemaker.mlops.feature_store.feature_processor.lineage.constants import ( + SAGEMAKER, + PIPELINE_VERSION_CONTEXT_TYPE, + PIPELINE_NAME_KEY, + LAST_UPDATE_TIME_KEY, +) +from sagemaker.core.lineage.context import Context + +# pylint: disable=C0301 +from sagemaker.mlops.feature_store.feature_processor.lineage._feature_processor_lineage_name_helper import ( + _get_feature_processor_pipeline_version_lineage_context_name, +) + +logger = logging.getLogger(SAGEMAKER) + + +class PipelineVersionLineageEntityHandler: + """Class for handling FeatureProcessor Pipeline Version Lineage""" + + @staticmethod + def create_pipeline_version_context( + pipeline_name: str, + pipeline_arn: str, + last_update_time: str, + sagemaker_session: Session, + ) -> Context: + """Create the FeatureProcessor Pipeline Version context. + + Arguments: + pipeline_name (str): The pipeline name. + pipeline_arn (str): The pipeline ARN. + last_update_time (str): The pipeline last update time. + sagemaker_session (Session): Session object which manages interactions + with Amazon SageMaker APIs and any other AWS services needed. If not specified, the + function creates one using the default AWS configuration chain. + + Returns: + Context: The pipeline version context. + """ + return Context.create( + context_name=_get_feature_processor_pipeline_version_lineage_context_name( + pipeline_name, last_update_time + ), + context_type=f"{PIPELINE_VERSION_CONTEXT_TYPE}-{pipeline_name}", + source_uri=pipeline_arn, + source_type=last_update_time, + properties={ + PIPELINE_NAME_KEY: pipeline_name, + LAST_UPDATE_TIME_KEY: last_update_time, + }, + sagemaker_session=sagemaker_session, + ) + + @staticmethod + def load_pipeline_version_context( + pipeline_name: str, last_update_time: str, sagemaker_session: Session + ) -> Context: + """Load the FeatureProcessor Pipeline Version context. + + Arguments: + pipeline_name (str): The pipeline name. + last_update_time (str): The pipeline last update time. + sagemaker_session (Session): Session object which manages interactions + with Amazon SageMaker APIs and any other AWS services needed. 
If not specified, the + function creates one using the default AWS configuration chain. + + Returns: + Context: The pipeline version context. + """ + return Context.load( + context_name=_get_feature_processor_pipeline_version_lineage_context_name( + pipeline_name, last_update_time + ), + sagemaker_session=sagemaker_session, + ) diff --git a/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/lineage/_s3_lineage_entity_handler.py b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/lineage/_s3_lineage_entity_handler.py new file mode 100644 index 0000000000..78a0f18c7c --- /dev/null +++ b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/lineage/_s3_lineage_entity_handler.py @@ -0,0 +1,316 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Contains class to handle S3 Lineage""" +from __future__ import absolute_import +import logging +from typing import Union, Optional, List + +from sagemaker.core.helper.session_helper import Session +from sagemaker.mlops.feature_store.feature_processor import ( + CSVDataSource, + ParquetDataSource, + BaseDataSource, +) + +# pylint: disable=C0301 +from sagemaker.mlops.feature_store.feature_processor.lineage._feature_processor_lineage_name_helper import ( + _get_feature_processor_schedule_lineage_artifact_name, + _get_feature_processor_trigger_lineage_artifact_name, +) +from sagemaker.mlops.feature_store.feature_processor.lineage._pipeline_schedule import ( + PipelineSchedule, +) +from sagemaker.mlops.feature_store.feature_processor.lineage._pipeline_trigger import PipelineTrigger +from sagemaker.mlops.feature_store.feature_processor.lineage._transformation_code import ( + TransformationCode, +) +from sagemaker.mlops.feature_store.feature_processor.lineage.constants import ( + TRANSFORMATION_CODE_STATUS_ACTIVE, + FEP_LINEAGE_PREFIX, + TRANSFORMATION_CODE_ARTIFACT_NAME, +) +from sagemaker.core.lineage.artifact import Artifact, ArtifactSummary + +logger = logging.getLogger("sagemaker") + + +class S3LineageEntityHandler: + """Class for handling FeatureProcessor S3 Artifact Lineage""" + + @staticmethod + def retrieve_raw_data_artifact( + raw_data: Union[CSVDataSource, ParquetDataSource, BaseDataSource], + sagemaker_session: Session, + ) -> Artifact: + """Load or create the FeatureProcessor Pipeline's raw data Artifact. + + Arguments: + raw_data (Union[CSVDataSource, ParquetDataSource, BaseDataSource]): The raw data to be + retrieved. + sagemaker_session (Session): Session object which manages interactions + with Amazon SageMaker APIs and any other AWS services needed. If not specified, the + function creates one using the default AWS configuration chain. + + Returns: + Artifact: The raw data artifact. 
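+
+        Example (an illustrative sketch; the S3 URI is hypothetical):
+            Reuses the existing DataSet artifact for the URI if one exists;
+            otherwise creates a new one.
+
+            >>> artifact = S3LineageEntityHandler.retrieve_raw_data_artifact(
+            ...     raw_data=CSVDataSource(s3_uri="s3://my-bucket/raw-data/file.csv"),
+            ...     sagemaker_session=Session(),
+            ... )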
+ """ + raw_data_uri = ( + raw_data.s3_uri + if isinstance(raw_data, (CSVDataSource, ParquetDataSource)) + else raw_data.data_source_unique_id + ) + raw_data_artifact_name = ( + "sm-fs-fe-raw-data" + if isinstance(raw_data, (CSVDataSource, ParquetDataSource)) + else raw_data.data_source_name + ) + + load_artifact: ArtifactSummary = S3LineageEntityHandler._load_artifact_from_s3_uri( + s3_uri=raw_data_uri, sagemaker_session=sagemaker_session + ) + if load_artifact is not None: + return S3LineageEntityHandler.load_artifact_from_arn( + artifact_arn=load_artifact.artifact_arn, + sagemaker_session=sagemaker_session, + ) + + return S3LineageEntityHandler._create_artifact( + s3_uri=raw_data_uri, + artifact_type="DataSet", + artifact_name=raw_data_artifact_name, + sagemaker_session=sagemaker_session, + ) + + @staticmethod + def update_transformation_code_artifact( + transformation_code_artifact: Artifact, + ) -> None: + """Update Pipeline's transformation code Artifact. + + Arguments: + transformation_code_artifact (TransformationCode): The transformation code Artifact to be updated. + """ + transformation_code_artifact.save() + + @staticmethod + def create_transformation_code_artifact( + transformation_code: TransformationCode, + pipeline_last_update_time: str, + sagemaker_session: Session, + ) -> Optional[Artifact]: + """Create the FeatureProcessor Pipeline's transformation code Artifact. + + Arguments: + transformation_code (TransformationCode): The transformation code to be retrieved. + pipeline_last_update_time (str): The last update time of the pipeline. + sagemaker_session (Session): Session object which manages interactions + with Amazon SageMaker APIs and any other AWS services needed. If not specified, the + function creates one using the default AWS configuration chain. + + Returns: + Artifact: The transformation code artifact. + """ + if transformation_code is None: + return None + + properties = dict( + state=TRANSFORMATION_CODE_STATUS_ACTIVE, + inclusive_start_date=pipeline_last_update_time, + ) + if transformation_code.name is not None: + properties["name"] = transformation_code.name + if transformation_code.author is not None: + properties["author"] = transformation_code.author + + return S3LineageEntityHandler._create_artifact( + s3_uri=transformation_code.s3_uri, + source_types=[dict(SourceIdType="Custom", Value=pipeline_last_update_time)], + properties=properties, + artifact_type="TransformationCode", + artifact_name=f"{FEP_LINEAGE_PREFIX}-" + f"{TRANSFORMATION_CODE_ARTIFACT_NAME}-" + f"{pipeline_last_update_time}", + sagemaker_session=sagemaker_session, + ) + + @staticmethod + def retrieve_pipeline_schedule_artifact( + pipeline_schedule: PipelineSchedule, + sagemaker_session: Session, + _get_feature_processor_schedule_lineage_artifact_namef=None, + ) -> Optional[Artifact]: + """Load or create the FeatureProcessor Pipeline's schedule Artifact + + Arguments: + pipeline_schedule (PipelineSchedule): Class to hold the Pipeline Schedule details + sagemaker_session (Session): Session object which manages interactions + with Amazon SageMaker APIs and any other AWS services needed. If not specified, the + function creates one using the default AWS configuration chain. + + Returns: + Artifact: The Schedule Artifact. 
+ """ + if pipeline_schedule is None: + return None + load_artifact: ArtifactSummary = S3LineageEntityHandler._load_artifact_from_s3_uri( + s3_uri=pipeline_schedule.schedule_arn, + sagemaker_session=sagemaker_session, + ) + if load_artifact is not None: + pipeline_schedule_artifact: Artifact = S3LineageEntityHandler.load_artifact_from_arn( + artifact_arn=load_artifact.artifact_arn, + sagemaker_session=sagemaker_session, + ) + pipeline_schedule_artifact.properties["pipeline_name"] = pipeline_schedule.pipeline_name + pipeline_schedule_artifact.properties["schedule_expression"] = ( + pipeline_schedule.schedule_expression + ) + pipeline_schedule_artifact.properties["state"] = pipeline_schedule.state + pipeline_schedule_artifact.properties["start_date"] = pipeline_schedule.start_date + pipeline_schedule_artifact.save() + return pipeline_schedule_artifact + + return S3LineageEntityHandler._create_artifact( + s3_uri=pipeline_schedule.schedule_arn, + artifact_type="PipelineSchedule", + artifact_name=_get_feature_processor_schedule_lineage_artifact_name( + schedule_name=pipeline_schedule.schedule_name + ), + properties=dict( + pipeline_name=pipeline_schedule.pipeline_name, + schedule_expression=pipeline_schedule.schedule_expression, + state=pipeline_schedule.state, + start_date=pipeline_schedule.start_date, + ), + sagemaker_session=sagemaker_session, + ) + + @staticmethod + def retrieve_pipeline_trigger_artifact( + pipeline_trigger: PipelineTrigger, + sagemaker_session: Session, + ) -> Optional[Artifact]: + """Load or create the FeatureProcessor Pipeline's trigger Artifact + + Arguments: + pipeline_trigger (PipelineTrigger): Class to hold the Pipeline Trigger details + sagemaker_session (Session): Session object which manages interactions + with Amazon SageMaker APIs and any other AWS services needed. If not specified, the + function creates one using the default AWS configuration chain. + + Returns: + Artifact: The Trigger Artifact. + """ + if pipeline_trigger is None: + return None + load_artifact: ArtifactSummary = S3LineageEntityHandler._load_artifact_from_s3_uri( + s3_uri=pipeline_trigger.trigger_arn, + sagemaker_session=sagemaker_session, + ) + if load_artifact is not None: + pipeline_trigger_artifact: Artifact = S3LineageEntityHandler.load_artifact_from_arn( + artifact_arn=load_artifact.artifact_arn, + sagemaker_session=sagemaker_session, + ) + pipeline_trigger_artifact.properties["pipeline_name"] = pipeline_trigger.pipeline_name + pipeline_trigger_artifact.properties["event_pattern"] = pipeline_trigger.event_pattern + pipeline_trigger_artifact.properties["state"] = pipeline_trigger.state + pipeline_trigger_artifact.save() + return pipeline_trigger_artifact + + return S3LineageEntityHandler._create_artifact( + s3_uri=pipeline_trigger.trigger_arn, + artifact_type="PipelineTrigger", + artifact_name=_get_feature_processor_trigger_lineage_artifact_name( + trigger_name=pipeline_trigger.trigger_name + ), + properties=dict( + pipeline_name=pipeline_trigger.pipeline_name, + event_pattern=pipeline_trigger.event_pattern, + state=pipeline_trigger.state, + ), + sagemaker_session=sagemaker_session, + ) + + @staticmethod + def load_artifact_from_arn(artifact_arn: str, sagemaker_session: Session) -> Artifact: + """Load Lineage Artifacts from ARN. + + Arguments: + artifact_arn (str): The Artifact ARN. + sagemaker_session (Session): Session object which manages interactions + with Amazon SageMaker APIs and any other AWS services needed. 
If not specified, the
+                function creates one using the default AWS configuration chain.
+
+        Returns:
+            Artifact: The Artifact for the provided ARN.
+        """
+        return Artifact.load(artifact_arn=artifact_arn, sagemaker_session=sagemaker_session)
+
+    @staticmethod
+    def _load_artifact_from_s3_uri(
+        s3_uri: str, sagemaker_session: Session
+    ) -> Optional[ArtifactSummary]:
+        """Load FeatureProcessor S3 Lineage Artifacts.
+
+        Arguments:
+            s3_uri (str): The s3 uri of the Artifact.
+            sagemaker_session (Session): Session object which manages interactions
+                with Amazon SageMaker APIs and any other AWS services needed. If not specified, the
+                function creates one using the default AWS configuration chain.
+
+        Returns:
+            ArtifactSummary: The Artifact Summary for the provided S3 URI.
+        """
+        artifacts = Artifact.list(source_uri=s3_uri, sagemaker_session=sagemaker_session)
+        for artifact_summary in artifacts:
+            # Make sure that source_types is empty, since the SDK
+            # does not set it when creating artifacts.
+            if (
+                artifact_summary.source.source_types is None
+                or len(artifact_summary.source.source_types) == 0
+            ):
+                return artifact_summary
+        return None
+
+    @staticmethod
+    def _create_artifact(
+        s3_uri: str,
+        artifact_type: str,
+        sagemaker_session: Session,
+        properties: Optional[dict] = None,
+        artifact_name: Optional[str] = None,
+        source_types: Optional[List[dict]] = None,
+    ) -> Artifact:
+        """Create Lineage Artifacts.
+
+        Arguments:
+            s3_uri (str): The s3 uri of the Artifact.
+            artifact_type (str): The Artifact type.
+            properties (Optional[dict]): The properties of the Artifact.
+            artifact_name (Optional[str]): The name of the Artifact.
+            source_types (Optional[List[dict]]): The source types of the Artifact.
+            sagemaker_session (Session): Session object which manages interactions
+                with Amazon SageMaker APIs and any other AWS services needed. If not specified, the
+                function creates one using the default AWS configuration chain.
+
+        Returns:
+            Artifact: The new Artifact.
+        """
+        return Artifact.create(
+            source_uri=s3_uri,
+            source_types=source_types,
+            artifact_type=artifact_type,
+            artifact_name=artifact_name,
+            properties=properties,
+            sagemaker_session=sagemaker_session,
+        )
diff --git a/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/lineage/_transformation_code.py b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/lineage/_transformation_code.py
new file mode 100644
index 0000000000..70ce48d910
--- /dev/null
+++ b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/lineage/_transformation_code.py
@@ -0,0 +1,31 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+"""Contains class to store Transformation Code"""
+from __future__ import absolute_import
+from typing import Optional
+import attr
+
+
+@attr.s
+class TransformationCode:
+    """A Transformation Code definition for FeatureProcessor Lineage.
+
+    Attributes:
+        s3_uri (str): The S3 URI of the code.
+        name (Optional[str]): The name of the code Artifact object.
+        author (Optional[str]): The author of the code.
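+
+    Example (an illustrative sketch; the S3 URI and names are hypothetical):
+        >>> code = TransformationCode(
+        ...     s3_uri="s3://my-bucket/code/transform.py",
+        ...     name="transform-code",
+        ...     author="data-science-team",
+        ... )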
+ """ + + s3_uri: str = attr.ib() + name: Optional[str] = attr.ib(default=None) + author: Optional[str] = attr.ib(default=None) diff --git a/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/lineage/constants.py b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/lineage/constants.py new file mode 100644 index 0000000000..25f4b04716 --- /dev/null +++ b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/lineage/constants.py @@ -0,0 +1,43 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Module containing constants for feature_processor and feature_scheduler module.""" +from __future__ import absolute_import + +FEATURE_GROUP_PIPELINE_VERSION_CONTEXT_TYPE = "FeatureGroupPipelineVersion" +PIPELINE_CONTEXT_TYPE = "FeatureEngineeringPipeline" +PIPELINE_VERSION_CONTEXT_TYPE = "FeatureEngineeringPipelineVersion" +PIPELINE_CONTEXT_NAME_SUFFIX = "fep" +PIPELINE_VERSION_CONTEXT_NAME_SUFFIX = "fep-ver" +FEP_LINEAGE_PREFIX = "sm-fs-fe" +DATA_SET = "DataSet" +TRANSFORMATION_CODE = "TransformationCode" +LAST_UPDATE_TIME = "LastUpdateTime" +LAST_MODIFIED_TIME = "LastModifiedTime" +CREATION_TIME = "CreationTime" +RESOURCE_NOT_FOUND = "ResourceNotFound" +ERROR = "Error" +CODE = "Code" +SAGEMAKER = "sagemaker" +CONTRIBUTED_TO = "ContributedTo" +ASSOCIATED_WITH = "AssociatedWith" +FEATURE_GROUP = "FeatureGroupName" +FEATURE_GROUP_PIPELINE_SUFFIX = "feature-group-pipeline" +FEATURE_GROUP_PIPELINE_VERSION_SUFFIX = "feature-group-pipeline-version" +PIPELINE_CONTEXT_NAME_KEY = "pipeline_context_name" +PIPELINE_CONTEXT_VERSION_NAME_KEY = "pipeline_version_context_name" +PIPELINE_NAME_KEY = "PipelineName" +PIPELINE_CREATION_TIME_KEY = "PipelineCreationTime" +LAST_UPDATE_TIME_KEY = "LastUpdateTime" +TRANSFORMATION_CODE_STATUS_ACTIVE = "Active" +TRANSFORMATION_CODE_STATUS_INACTIVE = "Inactive" +TRANSFORMATION_CODE_ARTIFACT_NAME = "transformation-code" diff --git a/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/lineage/test_constants.py b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/lineage/test_constants.py new file mode 100644 index 0000000000..12a323f871 --- /dev/null +++ b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/lineage/test_constants.py @@ -0,0 +1,401 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+"""Contains constants of feature processor to be used for unit tests.""" +from __future__ import absolute_import + +import datetime +from typing import List, Sequence, Union + +from botocore.exceptions import ClientError +from mock import Mock +from pyspark.sql import DataFrame + +from sagemaker.core.helper.session_helper import Session +from sagemaker.mlops.feature_store.feature_processor._data_source import ( + CSVDataSource, + FeatureGroupDataSource, + ParquetDataSource, + BaseDataSource, +) +from sagemaker.mlops.feature_store.feature_processor.lineage._feature_group_contexts import ( + FeatureGroupContexts, +) +from sagemaker.mlops.feature_store.feature_processor.lineage._pipeline_schedule import ( + PipelineSchedule, +) +from sagemaker.mlops.feature_store.feature_processor.lineage._pipeline_trigger import PipelineTrigger +from sagemaker.mlops.feature_store.feature_processor.lineage._transformation_code import ( + TransformationCode, +) +from sagemaker.core.lineage._api_types import ContextSource +from sagemaker.core.lineage.artifact import Artifact, ArtifactSource, ArtifactSummary +from sagemaker.core.lineage.context import Context + +PIPELINE_NAME = "test-pipeline-01" +PIPELINE_ARN = "arn:aws:sagemaker:us-west-2:12345789012:pipeline/test-pipeline-01" +CREATION_TIME = "123123123" +LAST_UPDATE_TIME = "234234234" +SAGEMAKER_SESSION_MOCK = Mock(Session) +CONTEXT_MOCK_01 = Mock(Context) +CONTEXT_MOCK_02 = Mock(Context) + + +class MockDataSource(BaseDataSource): + + data_source_unique_id = "test_source_unique_id" + data_source_name = "test_source_name" + + def read_data(self, spark, params) -> DataFrame: + return None + + +FEATURE_GROUP_DATA_SOURCE: List[FeatureGroupDataSource] = [ + FeatureGroupDataSource( + name="feature-group-01", + ), + FeatureGroupDataSource( + name="feature-group-02", + ), +] + +FEATURE_GROUP_INPUT: List[FeatureGroupContexts] = [ + FeatureGroupContexts( + name="feature-group-01", + pipeline_context_arn="feature-group-01-pipeline-context-arn", + pipeline_version_context_arn="feature-group-01-pipeline-version-context-arn", + ), + FeatureGroupContexts( + name="feature-group-02", + pipeline_context_arn="feature-group-02-pipeline-context-arn", + pipeline_version_context_arn="feature-group-02-pipeline-version-context-arn", + ), +] + +RAW_DATA_INPUT: Sequence[Union[CSVDataSource, ParquetDataSource, BaseDataSource]] = [ + CSVDataSource(s3_uri="raw-data-uri-01"), + CSVDataSource(s3_uri="raw-data-uri-02"), + ParquetDataSource(s3_uri="raw-data-uri-03"), + MockDataSource(), +] + +RAW_DATA_INPUT_ARTIFACTS: List[Artifact] = [ + Artifact(artifact_arn="artifact-01-arn"), + Artifact(artifact_arn="artifact-02-arn"), + Artifact(artifact_arn="artifact-03-arn"), + Artifact(artifact_arn="artifact-04-arn"), +] + +PIPELINE_SCHEDULE = PipelineSchedule( + schedule_name="schedule-name", + schedule_arn="schedule-arn", + schedule_expression="schedule-expression", + pipeline_name="pipeline-name", + state="state", + start_date="123123123", +) + +PIPELINE_SCHEDULE_2 = PipelineSchedule( + schedule_name="schedule-name-2", + schedule_arn="schedule-arn", + schedule_expression="schedule-expression-2", + pipeline_name="pipeline-name", + state="state-2", + start_date="234234234", +) + +PIPELINE_TRIGGER = PipelineTrigger( + trigger_name="trigger-name", + trigger_arn="trigger-arn", + pipeline_name="pipeline-name", + event_pattern="event-pattern", + state="Enabled", +) + +PIPELINE_TRIGGER_2 = PipelineTrigger( + trigger_name="trigger-name-2", + trigger_arn="trigger-arn", + pipeline_name="pipeline-name", + 
event_pattern="event-pattern-2", + state="Enabled", +) + +PIPELINE_TRIGGER_ARTIFACT: Artifact = Artifact( + artifact_arn="arn:aws:sagemaker:us-west-2:789975069016:artifact/7be06af3274fd01d1c18c96f97141f32", + artifact_name="sm-fs-fe-trigger-trigger-name", + artifact_type="PipelineTrigger", + source={"source_uri": "trigger-arn"}, + properties=dict( + pipeline_name=PIPELINE_TRIGGER.pipeline_name, + event_pattern=PIPELINE_TRIGGER.event_pattern, + state=PIPELINE_TRIGGER.state, + ), +) + +PIPELINE_TRIGGER_ARTIFACT_SUMMARY: ArtifactSummary = ArtifactSummary( + artifact_arn="arn:aws:sagemaker:us-west-2:789975069016:artifact/7be06af3274fd01d1c18c96f97141f32", + artifact_name="sm-fs-fe-trigger-trigger-name", + source=ArtifactSource( + source_uri="trigger-arn", + ), + artifact_type="PipelineTrigger", + creation_time=datetime.datetime(2023, 4, 27, 21, 4, 17, 926000), +) + +ARTIFACT_RESULT: Artifact = Artifact( + artifact_arn="arn:aws:sagemaker:us-west-2:789975069016:artifact/7be06af3274fd01d1c18c96f97141f32", + artifact_name="sm-fs-fe-raw-data", + source={ + "source_uri": "s3://sagemaker-us-west-2-789975069016/transform-2023-04-28-21-50-14-616/" + "transform-2023-04-28-21-50-14-616/output/model.tar.gz" + }, + artifact_type="DataSet", + creation_time=datetime.datetime(2023, 4, 28, 21, 53, 47, 912000), +) + +SCHEDULE_ARTIFACT_RESULT: Artifact = Artifact( + artifact_arn="arn:aws:sagemaker:us-west-2:789975069016:artifact/7be06af3274fd01d1c18c96f97141f32", + artifact_name="sm-fs-fe-raw-data", + source={ + "source_uri": "s3://sagemaker-us-west-2-789975069016/transform-2023-04-28-21-50-14-616/" + "transform-2023-04-28-21-50-14-616/output/model.tar.gz" + }, + properties=dict( + pipeline_name=PIPELINE_SCHEDULE.pipeline_name, + schedule_expression=PIPELINE_SCHEDULE.schedule_expression, + state=PIPELINE_SCHEDULE.state, + start_date=PIPELINE_SCHEDULE.start_date, + ), + artifact_type="DataSet", + creation_time=datetime.datetime(2023, 4, 28, 21, 53, 47, 912000), +) + +ARTIFACT_SUMMARY: ArtifactSummary = ArtifactSummary( + artifact_arn="arn:aws:sagemaker:us-west-2:789975069016:artifact/7be06af3274fd01d1c18c96f97141f32", + artifact_name="sm-fs-fe-raw-data", + source=ArtifactSource( + source_uri="s3://sagemaker-us-west-2-789975069016/transform-2023-04-28-21-50-14-616/" + "transform-2023-04-28-21-50-14-616/output/model.tar.gz", + source_types=[], + ), + artifact_type="DataSet", + creation_time=datetime.datetime(2023, 4, 27, 21, 4, 17, 926000), +) + +TRANSFORMATION_CODE_ARTIFACT_1 = Artifact( + artifact_arn="ts-artifact-01-arn", + artifact_name="sm-fs-fe-transformation-code", + source={ + "source_uri": "s3://sagemaker-us-west-2-789975069016/transform-2023-04-28-21-50-14-616/" + "transform-2023-04-28-21-50-14-616/output/model.tar.gz", + "source_types": [{"source_id_type": "Custom", "value": "1684369626"}], + }, + properties={ + "name": "test-name", + "author": "test-author", + "inclusive_start_date": "1684369626", + "state": "Active", + }, +) + +TRANSFORMATION_CODE_ARTIFACT_2 = Artifact( + artifact_arn="ts-artifact-02-arn", + artifact_name="sm-fs-fe-transformation-code", + source={ + "source_uri": "s3://sagemaker-us-west-2-789975069016/transform-2023-04-28-21-50-14-616/" + "transform-2023-04-28-21-50-14-616/output/model.tar.gz/2", + "source_types": [{"source_id_type": "Custom", "value": "1684369626"}], + }, + properties={ + "name": "test-name", + "author": "test-author", + "inclusive_start_date": "1684369626", + "state": "Active", + }, +) + +INACTIVE_TRANSFORMATION_CODE_ARTIFACT_1 = Artifact( + 
artifact_arn="ts-artifact-02-arn", + artifact_name="sm-fs-fe-transformation-code", + source={ + "source_uri": "s3://sagemaker-us-west-2-789975069016/transform-2023-04-28-21-50-14-616/" + "transform-2023-04-28-21-50-14-616/output/model.tar.gz/2", + "source_types": [{"source_id_type": "Custom", "value": "1684369307"}], + }, + Properties={ + "name": "test-name", + "author": "test-author", + "exclusive_end_date": "1684369626", + "inclusive_start_date": "1684369307", + "state": "Inactive", + }, +) + +VALIDATION_EXCEPTION = ClientError( + {"Error": {"Code": "ValidationException", "Message": "AssociationAlreadyExists"}}, + "Operation", +) + +RESOURCE_NOT_FOUND_EXCEPTION = ClientError( + {"Error": {"Code": "ResourceNotFound", "Message": "ResourceDoesNotExists"}}, + "Operation", +) + +NON_VALIDATION_EXCEPTION = ClientError( + {"Error": {"Code": "NonValidationException", "Message": "NonValidationError"}}, + "Operation", +) + +FEATURE_GROUP_NAME = "feature-group-name-01" +FEATURE_GROUP = { + "FeatureGroupArn": "arn:aws:sagemaker:us-west-2:789975069016:feature-group/feature-group-name-01", + "FeatureGroupName": "feature-group-name-01", + "RecordIdentifierFeatureName": "model_year_status", + "EventTimeFeatureName": "ingest_time", + "FeatureDefinitions": [ + {"FeatureName": "model_year_status", "FeatureType": "String"}, + {"FeatureName": "avg_mileage", "FeatureType": "String"}, + {"FeatureName": "max_mileage", "FeatureType": "String"}, + {"FeatureName": "avg_price", "FeatureType": "String"}, + {"FeatureName": "max_price", "FeatureType": "String"}, + {"FeatureName": "avg_msrp", "FeatureType": "String"}, + {"FeatureName": "max_msrp", "FeatureType": "String"}, + {"FeatureName": "ingest_time", "FeatureType": "Fractional"}, + ], + "CreationTime": datetime.datetime(2023, 4, 27, 21, 4, 17, 926000), + "OnlineStoreConfig": {"EnableOnlineStore": True}, + "OfflineStoreConfig": { + "S3StorageConfig": { + "S3Uri": "s3://sagemaker-us-west-2-789975069016/" + "feature-store/feature-processor/" + "suryans-v2/offline-store", + "ResolvedOutputS3Uri": "s3://sagemaker-us-west-2-" + "789975069016/feature-store/" + "feature-processor/suryans-v2/" + "offline-store/789975069016/" + "sagemaker/us-west-2/" + "offline-store/" + "feature-group-name-01-" + "1682629457/data", + }, + "DisableGlueTableCreation": False, + "DataCatalogConfig": { + "TableName": "feature-group-name-01_1682629457", + "Catalog": "AwsDataCatalog", + "Database": "sagemaker_featurestore", + }, + }, + "RoleArn": "arn:aws:iam::789975069016:role/service-role/AmazonSageMaker-ExecutionRole-20230421T100744", + "FeatureGroupStatus": "Created", + "OnlineStoreTotalSizeBytes": 0, + "ResponseMetadata": { + "RequestId": "8f139791-345d-4388-8d6d-40420495a3c4", + "HTTPStatusCode": 200, + "HTTPHeaders": { + "x-amzn-requestid": "8f139791-345d-4388-8d6d-40420495a3c4", + "content-type": "application/x-amz-json-1.1", + "content-length": "1608", + "date": "Mon, 01 May 2023 21:42:59 GMT", + }, + "RetryAttempts": 0, + }, +} + +PIPELINE = { + "PipelineArn": "arn:aws:sagemaker:us-west-2:597217924798:pipeline/test-pipeline-26", + "PipelineName": "test-pipeline-26", + "PipelineDisplayName": "test-pipeline-26", + "PipelineDefinition": '{"Version": "2020-12-01", "Metadata": {}, ' + '"Parameters": [{"Name": "scheduled-time", "Type": "String"}], ' + '"PipelineExperimentConfig": {"ExperimentName": {"Get": "Execution.PipelineName"}, ' + '"TrialName": {"Get": "Execution.PipelineExecutionId"}}, ' + '"Steps": [{"Name": "test-pipeline-26-training-step", "Type": ' + '"Training", "Arguments": 
{"AlgorithmSpecification": {"TrainingInputMode": ' + '"File", "TrainingImage": "153931337802.dkr.ecr.us-west-2.amazonaws.com/' + 'sagemaker-spark-processing:3.2-cpu-py39-v1.1", "ContainerEntrypoint": ' + '["/bin/bash", "/opt/ml/input/data/sagemaker_remote_function_bootstrap/' + 'job_driver.sh", "--files", "s3://bugbash-schema-update/temp.sh", ' + '"/opt/ml/input/data/sagemaker_remote_function_bootstrap/spark_app.py"], ' + '"ContainerArguments": ["--s3_base_uri", ' + '"s3://bugbash-schema-update-suryans/test-pipeline-26", ' + '"--region", "us-west-2", "--client_python_version", "3.9"]}, ' + '"OutputDataConfig": {"S3OutputPath": ' + '"s3://bugbash-schema-update-suryans/test-pipeline-26"}, ' + '"StoppingCondition": {"MaxRuntimeInSeconds": 86400}, "ResourceConfig": ' + '{"VolumeSizeInGB": 30, "InstanceCount": 1, "InstanceType": "ml.m5.xlarge"}, ' + '"RoleArn": "arn:aws:iam::597217924798:role/Admin", "InputDataConfig": ' + '[{"DataSource": {"S3DataSource": {"S3DataType": "S3Prefix", "S3Uri": ' + '"s3://bugbash-schema-update-suryans/test-pipeline-26/' + 'sagemaker_remote_function_bootstrap", "S3DataDistributionType": ' + '"FullyReplicated"}}, "ChannelName": "sagemaker_remote_function_bootstrap"}, ' + '{"DataSource": {"S3DataSource": {"S3DataType": "S3Prefix", "S3Uri": ' + '"s3://bugbash-schema-update/sagemaker-2.142.1.dev0-py2.py3-none-any.whl", ' + '"S3DataDistributionType": "FullyReplicated"}}, "ChannelName": ' + '"sagemaker_wheel_file"}], "Environment": {"AWS_DEFAULT_REGION": "us-west-2"}, ' + '"DebugHookConfig": {"S3OutputPath": ' + '"s3://bugbash-schema-update-suryans/test-pipeline-26", ' + '"CollectionConfigurations": []},' + ' "ProfilerConfig": {"S3OutputPath": ' + '"s3://bugbash-schema-update-suryans/test-pipeline-26", ' + '"DisableProfiler": false}, "RetryStrategy": {"MaximumRetryAttempts": 1}}}]}', + "RoleArn": "arn:aws:iam::597217924798:role/Admin", + "PipelineStatus": "Active", + "CreationTime": datetime.datetime(2023, 4, 27, 9, 46, 35, 686000), + "LastModifiedTime": datetime.datetime(2023, 4, 27, 20, 27, 36, 648000), + "CreatedBy": {}, + "LastModifiedBy": {}, + "ResponseMetadata": { + "RequestId": "2075bc1c-1b34-4fe5-b7d8-7cfdf784a7d9", + "HTTPStatusCode": 200, + "HTTPHeaders": { + "x-amzn-requestid": "2075bc1c-1b34-4fe5-b7d8-7cfdf784a7d9", + "content-type": "application/x-amz-json-1.1", + "content-length": "2555", + "date": "Thu, 04 May 2023 00:28:35 GMT", + }, + "RetryAttempts": 0, + }, +} + +PIPELINE_CONTEXT: Context = Context( + context_arn=f"{PIPELINE_NAME}-context-arn", + context_name=f"sm-fs-fe-{PIPELINE_NAME}-{CREATION_TIME}-fep", + context_type="FeatureEngineeringPipeline", + source=ContextSource(source_uri=PIPELINE_ARN, source_types=[]), + properties={ + "PipelineName": PIPELINE_NAME, + "PipelineCreationTime": CREATION_TIME, + "LastUpdateTime": LAST_UPDATE_TIME, + }, +) + +PIPELINE_VERSION_CONTEXT: Context = Context( + context_arn=f"{PIPELINE_NAME}-version-context-arn", + context_name=f"sm-fs-fe-{PIPELINE_NAME}-{LAST_UPDATE_TIME}-fep-ver", + context_type=f"FeatureEngineeringPipelineVersion-{PIPELINE_NAME}", + source=ContextSource(source_uri=PIPELINE_ARN, source_types=LAST_UPDATE_TIME), + properties={"PipelineName": PIPELINE_NAME, "LastUpdateTime": LAST_UPDATE_TIME}, +) + +TRANSFORMATION_CODE_INPUT_1: TransformationCode = TransformationCode( + s3_uri="s3://sagemaker-us-west-2-789975069016/transform-2023-04-28-21-50-14-616/" + "transform-2023-04-28-21-50-14-616/output/model.tar.gz", + author="test-author", + name="test-name", +) + +TRANSFORMATION_CODE_INPUT_2: 
TransformationCode = TransformationCode( + s3_uri="s3://sagemaker-us-west-2-789975069016/transform-2023-04-28-21-50-14-616/" + "transform-2023-04-28-21-50-14-616/output/model.tar.gz/2", + author="test-author", + name="test-name", +) diff --git a/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/lineage/test_feature_group_lineage_entity_handler.py b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/lineage/test_feature_group_lineage_entity_handler.py new file mode 100644 index 0000000000..bc725570b9 --- /dev/null +++ b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/lineage/test_feature_group_lineage_entity_handler.py @@ -0,0 +1,62 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import +from mock import patch, call + +from sagemaker.mlops.feature_store.feature_processor.lineage._feature_group_lineage_entity_handler import ( + FeatureGroupLineageEntityHandler, +) +from sagemaker.core.lineage.context import Context + +from test_constants import ( + SAGEMAKER_SESSION_MOCK, + CONTEXT_MOCK_01, + CONTEXT_MOCK_02, + FEATURE_GROUP, + FEATURE_GROUP_NAME, +) + + +def test_retrieve_feature_group_context_arns(): + with patch.object( + SAGEMAKER_SESSION_MOCK, "describe_feature_group", return_value=FEATURE_GROUP + ) as fg_describe_method: + with patch.object( + Context, "load", side_effect=[CONTEXT_MOCK_01, CONTEXT_MOCK_02] + ) as context_load: + type(CONTEXT_MOCK_01).context_arn = "context-arn-fep" + type(CONTEXT_MOCK_02).context_arn = "context-arn-fep-ver" + result = FeatureGroupLineageEntityHandler.retrieve_feature_group_context_arns( + feature_group_name=FEATURE_GROUP_NAME, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + assert result.name == FEATURE_GROUP_NAME + assert result.pipeline_context_arn == "context-arn-fep" + assert result.pipeline_version_context_arn == "context-arn-fep-ver" + fg_describe_method.assert_called_once_with(FEATURE_GROUP_NAME) + context_load.assert_has_calls( + [ + call( + context_name=f'{FEATURE_GROUP_NAME}-{FEATURE_GROUP["CreationTime"].strftime("%s")}' + f"-feature-group-pipeline", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + call( + context_name=f'{FEATURE_GROUP_NAME}-{FEATURE_GROUP["CreationTime"].strftime("%s")}' + f"-feature-group-pipeline-version", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + ] + ) + assert 2 == context_load.call_count diff --git a/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/lineage/test_feature_processor_lineage.py b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/lineage/test_feature_processor_lineage.py new file mode 100644 index 0000000000..fe5783210a --- /dev/null +++ b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/lineage/test_feature_processor_lineage.py @@ -0,0 +1,2966 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). 
You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +import copy +import datetime +from typing import Iterator, List + +import pytest +from mock import call, patch, Mock + +from sagemaker.core.helper.session_helper import Session +from sagemaker.mlops.feature_store.feature_processor._event_bridge_scheduler_helper import ( + EventBridgeSchedulerHelper, +) +from sagemaker.mlops.feature_store.feature_processor._event_bridge_rule_helper import ( + EventBridgeRuleHelper, +) +from sagemaker.mlops.feature_store.feature_processor.lineage.constants import ( + TRANSFORMATION_CODE_STATUS_INACTIVE, +) +from sagemaker.core.lineage.context import Context +from sagemaker.core.lineage.artifact import Artifact +from test_constants import ( + FEATURE_GROUP_DATA_SOURCE, + FEATURE_GROUP_INPUT, + LAST_UPDATE_TIME, + PIPELINE, + PIPELINE_ARN, + PIPELINE_CONTEXT, + PIPELINE_NAME, + PIPELINE_VERSION_CONTEXT, + RAW_DATA_INPUT, + RAW_DATA_INPUT_ARTIFACTS, + RESOURCE_NOT_FOUND_EXCEPTION, + SAGEMAKER_SESSION_MOCK, + SCHEDULE_ARTIFACT_RESULT, + PIPELINE_TRIGGER_ARTIFACT, + TRANSFORMATION_CODE_ARTIFACT_1, + TRANSFORMATION_CODE_ARTIFACT_2, + TRANSFORMATION_CODE_INPUT_1, + TRANSFORMATION_CODE_INPUT_2, + ARTIFACT_SUMMARY, + ARTIFACT_RESULT, +) + +from sagemaker.mlops.feature_store.feature_processor.lineage._feature_group_lineage_entity_handler import ( + FeatureGroupLineageEntityHandler, +) +from sagemaker.mlops.feature_store.feature_processor.lineage._feature_processor_lineage import ( + FeatureProcessorLineageHandler, +) +from sagemaker.mlops.feature_store.feature_processor.lineage._lineage_association_handler import ( + LineageAssociationHandler, +) +from sagemaker.mlops.feature_store.feature_processor.lineage._pipeline_lineage_entity_handler import ( + PipelineLineageEntityHandler, +) +from sagemaker.mlops.feature_store.feature_processor.lineage._pipeline_schedule import ( + PipelineSchedule, +) +from sagemaker.mlops.feature_store.feature_processor.lineage._pipeline_trigger import ( + PipelineTrigger, +) +from sagemaker.mlops.feature_store.feature_processor.lineage._pipeline_version_lineage_entity_handler import ( + PipelineVersionLineageEntityHandler, +) +from sagemaker.mlops.feature_store.feature_processor.lineage._s3_lineage_entity_handler import ( + S3LineageEntityHandler, +) +from sagemaker.core.lineage._api_types import AssociationSummary + +SCHEDULE_ARN = "" +SCHEDULE_EXPRESSION = "" +STATE = "" +TRIGGER_ARN = "" +EVENT_PATTERN = "" +START_DATE = datetime.datetime(2023, 4, 28, 21, 53, 47, 912000) +TAGS = [dict(Key="key_1", Value="value_1"), dict(Key="key_2", Value="value_2")] + + +@pytest.fixture +def sagemaker_session(): + boto_session = Mock() + boto_session.client("scheduler").return_value = Mock() + return Mock(Session, boto_session=boto_session) + + +@pytest.fixture +def event_bridge_scheduler_helper(sagemaker_session): + return EventBridgeSchedulerHelper( + sagemaker_session, sagemaker_session.boto_session.client("scheduler") + ) + + +def test_create_lineage_when_no_lineage_exists_with_fg_only(): + lineage_handler = FeatureProcessorLineageHandler( + pipeline_name=PIPELINE_NAME, + 
pipeline_arn=PIPELINE_ARN, + pipeline=PIPELINE, + inputs=FEATURE_GROUP_DATA_SOURCE, + output=FEATURE_GROUP_DATA_SOURCE[0].name, + transformation_code=TRANSFORMATION_CODE_INPUT_1, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + with ( + patch.object( + FeatureGroupLineageEntityHandler, + "retrieve_feature_group_context_arns", + side_effect=[ + FEATURE_GROUP_INPUT[0], + FEATURE_GROUP_INPUT[1], + FEATURE_GROUP_INPUT[0], + ], + ) as retrieve_feature_group_context_arns_method, + patch.object( + S3LineageEntityHandler, + "retrieve_raw_data_artifact", + side_effect=[ + RAW_DATA_INPUT_ARTIFACTS[0], + RAW_DATA_INPUT_ARTIFACTS[1], + RAW_DATA_INPUT_ARTIFACTS[2], + RAW_DATA_INPUT_ARTIFACTS[3], + ], + ) as retrieve_raw_data_artifact_method, + patch.object( + S3LineageEntityHandler, + "create_transformation_code_artifact", + return_value=TRANSFORMATION_CODE_ARTIFACT_1, + ) as create_transformation_code_artifact_method, + patch.object( + PipelineLineageEntityHandler, + "load_pipeline_context", + side_effect=RESOURCE_NOT_FOUND_EXCEPTION, + ) as load_pipeline_context_method, + patch.object( + PipelineLineageEntityHandler, + "create_pipeline_context", + return_value=PIPELINE_CONTEXT, + ), + patch.object( + PipelineVersionLineageEntityHandler, + "create_pipeline_version_context", + return_value=PIPELINE_VERSION_CONTEXT, + ), + patch.object( + PipelineVersionLineageEntityHandler, + "load_pipeline_version_context", + return_value=PIPELINE_VERSION_CONTEXT, + ) as load_pipeline_version_context_method, + patch.object( + LineageAssociationHandler, + "list_upstream_associations", + side_effect=[ + generate_pipeline_version_upstream_feature_group_list(), + [], + generate_pipeline_version_upstream_transformation_code(), + ], + ) as list_upstream_associations_method, + patch.object( + LineageAssociationHandler, + "list_downstream_associations", + return_value=generate_pipeline_version_downstream_feature_group(), + ) as list_downstream_associations_method, + patch.object( + PipelineLineageEntityHandler, + "update_pipeline_context", + ) as update_pipeline_context_method, + patch.object( + LineageAssociationHandler, "add_upstream_feature_group_data_associations" + ) as add_upstream_feature_group_data_associations_method, + patch.object( + LineageAssociationHandler, "add_downstream_feature_group_data_associations" + ) as add_downstream_feature_group_data_associations_method, + patch.object( + LineageAssociationHandler, "add_upstream_raw_data_associations" + ) as add_upstream_raw_data_associations_method, + patch.object( + LineageAssociationHandler, "add_upstream_transformation_code_associations" + ) as add_upstream_transformation_code_associations_method, + patch.object( + LineageAssociationHandler, "add_pipeline_and_pipeline_version_association" + ) as add_pipeline_and_pipeline_version_association_method, + ): + lineage_handler.create_lineage() + + retrieve_feature_group_context_arns_method.assert_has_calls( + [ + call( + feature_group_name=FEATURE_GROUP_DATA_SOURCE[0].name, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + call( + feature_group_name=FEATURE_GROUP_DATA_SOURCE[1].name, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + call( + feature_group_name=FEATURE_GROUP_DATA_SOURCE[0].name, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + ] + ) + assert 3 == retrieve_feature_group_context_arns_method.call_count + + retrieve_raw_data_artifact_method.assert_not_called() + + create_transformation_code_artifact_method.assert_called_once_with( + transformation_code=TRANSFORMATION_CODE_INPUT_1, + 
pipeline_last_update_time=PIPELINE["LastModifiedTime"].strftime("%s"), + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + load_pipeline_context_method.assert_called_once_with( + pipeline_name=PIPELINE_NAME, + creation_time=PIPELINE["CreationTime"].strftime("%s"), + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + load_pipeline_version_context_method.assert_not_called() + list_upstream_associations_method.assert_not_called() + list_downstream_associations_method.assert_not_called() + update_pipeline_context_method.assert_not_called() + + add_upstream_feature_group_data_associations_method.assert_called_once_with( + feature_group_inputs=FEATURE_GROUP_INPUT, + pipeline_context_arn=PIPELINE_CONTEXT.context_arn, + pipeline_version_context_arn=PIPELINE_VERSION_CONTEXT.context_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + add_downstream_feature_group_data_associations_method.assert_called_once_with( + feature_group_output=FEATURE_GROUP_INPUT[0], + pipeline_context_arn=PIPELINE_CONTEXT.context_arn, + pipeline_version_context_arn=PIPELINE_VERSION_CONTEXT.context_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + add_upstream_raw_data_associations_method.assert_called_once_with( + raw_data_inputs=[], + pipeline_context_arn=PIPELINE_CONTEXT.context_arn, + pipeline_version_context_arn=PIPELINE_VERSION_CONTEXT.context_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + add_upstream_transformation_code_associations_method.assert_called_once_with( + transformation_code_artifact=TRANSFORMATION_CODE_ARTIFACT_1, + pipeline_version_context_arn=PIPELINE_VERSION_CONTEXT.context_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + add_pipeline_and_pipeline_version_association_method.assert_called_once_with( + pipeline_context_arn=PIPELINE_CONTEXT.context_arn, + pipeline_version_context_arn=PIPELINE_VERSION_CONTEXT.context_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + +def test_create_lineage_when_no_lineage_exists_with_raw_data_only(): + lineage_handler = FeatureProcessorLineageHandler( + pipeline_name=PIPELINE_NAME, + pipeline_arn=PIPELINE_ARN, + pipeline=PIPELINE, + inputs=RAW_DATA_INPUT, + output=FEATURE_GROUP_DATA_SOURCE[0].name, + transformation_code=TRANSFORMATION_CODE_INPUT_1, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + with ( + patch.object( + FeatureGroupLineageEntityHandler, + "retrieve_feature_group_context_arns", + side_effect=[ + FEATURE_GROUP_INPUT[0], + FEATURE_GROUP_INPUT[1], + FEATURE_GROUP_INPUT[0], + ], + ) as retrieve_feature_group_context_arns_method, + patch.object( + S3LineageEntityHandler, + "retrieve_raw_data_artifact", + side_effect=[ + RAW_DATA_INPUT_ARTIFACTS[0], + RAW_DATA_INPUT_ARTIFACTS[1], + RAW_DATA_INPUT_ARTIFACTS[2], + RAW_DATA_INPUT_ARTIFACTS[3], + ], + ) as retrieve_raw_data_artifact_method, + patch.object( + S3LineageEntityHandler, + "create_transformation_code_artifact", + return_value=TRANSFORMATION_CODE_ARTIFACT_1, + ) as create_transformation_code_artifact_method, + patch.object( + PipelineLineageEntityHandler, + "load_pipeline_context", + side_effect=RESOURCE_NOT_FOUND_EXCEPTION, + ) as load_pipeline_context_method, + patch.object( + PipelineLineageEntityHandler, + "create_pipeline_context", + return_value=PIPELINE_CONTEXT, + ), + patch.object( + PipelineVersionLineageEntityHandler, + "create_pipeline_version_context", + return_value=PIPELINE_VERSION_CONTEXT, + ), + patch.object( + PipelineVersionLineageEntityHandler, + "load_pipeline_version_context", + return_value=PIPELINE_VERSION_CONTEXT, + ) as 
load_pipeline_version_context_method, + patch.object( + LineageAssociationHandler, + "list_upstream_associations", + side_effect=[ + generate_pipeline_version_upstream_feature_group_list(), + [], + generate_pipeline_version_upstream_transformation_code(), + ], + ) as list_upstream_associations_method, + patch.object( + LineageAssociationHandler, + "list_downstream_associations", + return_value=generate_pipeline_version_downstream_feature_group(), + ) as list_downstream_associations_method, + patch.object( + PipelineLineageEntityHandler, + "update_pipeline_context", + ) as update_pipeline_context_method, + patch.object( + LineageAssociationHandler, "add_upstream_feature_group_data_associations" + ) as add_upstream_feature_group_data_associations_method, + patch.object( + LineageAssociationHandler, "add_downstream_feature_group_data_associations" + ) as add_downstream_feature_group_data_associations_method, + patch.object( + LineageAssociationHandler, "add_upstream_raw_data_associations" + ) as add_upstream_raw_data_associations_method, + patch.object( + LineageAssociationHandler, "add_upstream_transformation_code_associations" + ) as add_upstream_transformation_code_associations_method, + patch.object( + LineageAssociationHandler, "add_pipeline_and_pipeline_version_association" + ) as add_pipeline_and_pipeline_version_association_method, + patch.object( + Artifact, + "set_tags", + return_value={ + "Tags": [dict(Key="key_1", Value="value_1"), dict(Key="key_2", Value="value_2")] + }, + ) as artifact_set_tags, + ): + lineage_handler.create_lineage(TAGS) + + retrieve_feature_group_context_arns_method.assert_called_once_with( + feature_group_name=FEATURE_GROUP_DATA_SOURCE[0].name, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + retrieve_raw_data_artifact_method.assert_has_calls( + [ + call(raw_data=RAW_DATA_INPUT[0], sagemaker_session=SAGEMAKER_SESSION_MOCK), + call(raw_data=RAW_DATA_INPUT[1], sagemaker_session=SAGEMAKER_SESSION_MOCK), + call(raw_data=RAW_DATA_INPUT[2], sagemaker_session=SAGEMAKER_SESSION_MOCK), + call(raw_data=RAW_DATA_INPUT[3], sagemaker_session=SAGEMAKER_SESSION_MOCK), + ] + ) + assert 4 == retrieve_raw_data_artifact_method.call_count + + create_transformation_code_artifact_method.assert_called_once_with( + transformation_code=TRANSFORMATION_CODE_INPUT_1, + pipeline_last_update_time=PIPELINE["LastModifiedTime"].strftime("%s"), + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + load_pipeline_context_method.assert_called_once_with( + pipeline_name=PIPELINE_NAME, + creation_time=PIPELINE["CreationTime"].strftime("%s"), + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + load_pipeline_version_context_method.assert_not_called() + list_upstream_associations_method.assert_not_called() + list_downstream_associations_method.assert_not_called() + update_pipeline_context_method.assert_not_called() + + add_upstream_feature_group_data_associations_method.assert_called_once_with( + feature_group_inputs=[], + pipeline_context_arn=PIPELINE_CONTEXT.context_arn, + pipeline_version_context_arn=PIPELINE_VERSION_CONTEXT.context_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + add_downstream_feature_group_data_associations_method.assert_called_once_with( + feature_group_output=FEATURE_GROUP_INPUT[0], + pipeline_context_arn=PIPELINE_CONTEXT.context_arn, + pipeline_version_context_arn=PIPELINE_VERSION_CONTEXT.context_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + add_upstream_raw_data_associations_method.assert_called_once_with( + raw_data_inputs=RAW_DATA_INPUT_ARTIFACTS, + 
pipeline_context_arn=PIPELINE_CONTEXT.context_arn, + pipeline_version_context_arn=PIPELINE_VERSION_CONTEXT.context_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + add_upstream_transformation_code_associations_method.assert_called_once_with( + transformation_code_artifact=TRANSFORMATION_CODE_ARTIFACT_1, + pipeline_version_context_arn=PIPELINE_VERSION_CONTEXT.context_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + add_pipeline_and_pipeline_version_association_method.assert_called_once_with( + pipeline_context_arn=PIPELINE_CONTEXT.context_arn, + pipeline_version_context_arn=PIPELINE_VERSION_CONTEXT.context_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + artifact_set_tags.assert_called_once_with(TAGS) + + +def test_create_lineage_when_no_lineage_exists_with_fg_and_raw_data_with_tags(): + lineage_handler = FeatureProcessorLineageHandler( + pipeline_name=PIPELINE_NAME, + pipeline_arn=PIPELINE_ARN, + pipeline=PIPELINE, + inputs=RAW_DATA_INPUT + FEATURE_GROUP_DATA_SOURCE, + output=FEATURE_GROUP_DATA_SOURCE[0].name, + transformation_code=TRANSFORMATION_CODE_INPUT_1, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + with ( + patch.object( + FeatureGroupLineageEntityHandler, + "retrieve_feature_group_context_arns", + side_effect=[ + FEATURE_GROUP_INPUT[0], + FEATURE_GROUP_INPUT[1], + FEATURE_GROUP_INPUT[0], + ], + ) as retrieve_feature_group_context_arns_method, + patch.object( + S3LineageEntityHandler, + "retrieve_raw_data_artifact", + side_effect=[ + RAW_DATA_INPUT_ARTIFACTS[0], + RAW_DATA_INPUT_ARTIFACTS[1], + RAW_DATA_INPUT_ARTIFACTS[2], + RAW_DATA_INPUT_ARTIFACTS[3], + ], + ) as retrieve_raw_data_artifact_method, + patch.object( + S3LineageEntityHandler, + "create_transformation_code_artifact", + return_value=TRANSFORMATION_CODE_ARTIFACT_1, + ) as create_transformation_code_artifact_method, + patch.object( + PipelineLineageEntityHandler, + "load_pipeline_context", + side_effect=RESOURCE_NOT_FOUND_EXCEPTION, + ) as load_pipeline_context_method, + patch.object( + PipelineLineageEntityHandler, + "create_pipeline_context", + return_value=PIPELINE_CONTEXT, + ), + patch.object( + PipelineVersionLineageEntityHandler, + "create_pipeline_version_context", + return_value=PIPELINE_VERSION_CONTEXT, + ), + patch.object( + PipelineVersionLineageEntityHandler, + "load_pipeline_version_context", + return_value=PIPELINE_VERSION_CONTEXT, + ) as load_pipeline_version_context_method, + patch.object( + LineageAssociationHandler, + "list_upstream_associations", + side_effect=[ + generate_pipeline_version_upstream_feature_group_list(), + [], + generate_pipeline_version_upstream_transformation_code(), + ], + ) as list_upstream_associations_method, + patch.object( + LineageAssociationHandler, + "list_downstream_associations", + return_value=generate_pipeline_version_downstream_feature_group(), + ) as list_downstream_associations_method, + patch.object( + PipelineLineageEntityHandler, + "update_pipeline_context", + ) as update_pipeline_context_method, + patch.object( + LineageAssociationHandler, "add_upstream_feature_group_data_associations" + ) as add_upstream_feature_group_data_associations_method, + patch.object( + LineageAssociationHandler, "add_downstream_feature_group_data_associations" + ) as add_downstream_feature_group_data_associations_method, + patch.object( + LineageAssociationHandler, "add_upstream_raw_data_associations" + ) as add_upstream_raw_data_associations_method, + patch.object( + LineageAssociationHandler, "add_upstream_transformation_code_associations" + ) as 
add_upstream_transformation_code_associations_method, + patch.object( + LineageAssociationHandler, "add_pipeline_and_pipeline_version_association" + ) as add_pipeline_and_pipeline_version_association_method, + patch.object( + Artifact, + "set_tags", + return_value={ + "Tags": [dict(Key="key_1", Value="value_1"), dict(Key="key_2", Value="value_2")] + }, + ) as artifact_set_tags, + ): + lineage_handler.create_lineage(TAGS) + + retrieve_feature_group_context_arns_method.assert_has_calls( + [ + call( + feature_group_name=FEATURE_GROUP_DATA_SOURCE[0].name, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + call( + feature_group_name=FEATURE_GROUP_DATA_SOURCE[1].name, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + call( + feature_group_name=FEATURE_GROUP_DATA_SOURCE[0].name, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + ] + ) + assert 3 == retrieve_feature_group_context_arns_method.call_count + + retrieve_raw_data_artifact_method.assert_has_calls( + [ + call(raw_data=RAW_DATA_INPUT[0], sagemaker_session=SAGEMAKER_SESSION_MOCK), + call(raw_data=RAW_DATA_INPUT[1], sagemaker_session=SAGEMAKER_SESSION_MOCK), + call(raw_data=RAW_DATA_INPUT[2], sagemaker_session=SAGEMAKER_SESSION_MOCK), + call(raw_data=RAW_DATA_INPUT[3], sagemaker_session=SAGEMAKER_SESSION_MOCK), + ] + ) + assert 4 == retrieve_raw_data_artifact_method.call_count + + create_transformation_code_artifact_method.assert_called_once_with( + transformation_code=TRANSFORMATION_CODE_INPUT_1, + pipeline_last_update_time=PIPELINE["LastModifiedTime"].strftime("%s"), + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + load_pipeline_context_method.assert_called_once_with( + pipeline_name=PIPELINE_NAME, + creation_time=PIPELINE["CreationTime"].strftime("%s"), + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + load_pipeline_version_context_method.assert_not_called() + list_upstream_associations_method.assert_not_called() + list_downstream_associations_method.assert_not_called() + update_pipeline_context_method.assert_not_called() + + add_upstream_feature_group_data_associations_method.assert_called_once_with( + feature_group_inputs=FEATURE_GROUP_INPUT, + pipeline_context_arn=PIPELINE_CONTEXT.context_arn, + pipeline_version_context_arn=PIPELINE_VERSION_CONTEXT.context_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + add_downstream_feature_group_data_associations_method.assert_called_once_with( + feature_group_output=FEATURE_GROUP_INPUT[0], + pipeline_context_arn=PIPELINE_CONTEXT.context_arn, + pipeline_version_context_arn=PIPELINE_VERSION_CONTEXT.context_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + add_upstream_raw_data_associations_method.assert_called_once_with( + raw_data_inputs=RAW_DATA_INPUT_ARTIFACTS, + pipeline_context_arn=PIPELINE_CONTEXT.context_arn, + pipeline_version_context_arn=PIPELINE_VERSION_CONTEXT.context_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + add_upstream_transformation_code_associations_method.assert_called_once_with( + transformation_code_artifact=TRANSFORMATION_CODE_ARTIFACT_1, + pipeline_version_context_arn=PIPELINE_VERSION_CONTEXT.context_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + add_pipeline_and_pipeline_version_association_method.assert_called_once_with( + pipeline_context_arn=PIPELINE_CONTEXT.context_arn, + pipeline_version_context_arn=PIPELINE_VERSION_CONTEXT.context_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + artifact_set_tags.assert_called_once_with(TAGS) + + +def test_create_lineage_when_no_lineage_exists_with_no_transformation_code(): + 
lineage_handler = FeatureProcessorLineageHandler( + pipeline_name=PIPELINE_NAME, + pipeline_arn=PIPELINE_ARN, + pipeline=PIPELINE, + inputs=RAW_DATA_INPUT + FEATURE_GROUP_DATA_SOURCE, + output=FEATURE_GROUP_DATA_SOURCE[0].name, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + with ( + patch.object( + FeatureGroupLineageEntityHandler, + "retrieve_feature_group_context_arns", + side_effect=[ + FEATURE_GROUP_INPUT[0], + FEATURE_GROUP_INPUT[1], + FEATURE_GROUP_INPUT[0], + ], + ) as retrieve_feature_group_context_arns_method, + patch.object( + S3LineageEntityHandler, + "retrieve_raw_data_artifact", + side_effect=[ + RAW_DATA_INPUT_ARTIFACTS[0], + RAW_DATA_INPUT_ARTIFACTS[1], + RAW_DATA_INPUT_ARTIFACTS[2], + RAW_DATA_INPUT_ARTIFACTS[3], + ], + ) as retrieve_raw_data_artifact_method, + patch.object( + S3LineageEntityHandler, + "create_transformation_code_artifact", + return_value=None, + ) as create_transformation_code_artifact_method, + patch.object( + PipelineLineageEntityHandler, + "load_pipeline_context", + side_effect=RESOURCE_NOT_FOUND_EXCEPTION, + ) as load_pipeline_context_method, + patch.object( + PipelineLineageEntityHandler, + "create_pipeline_context", + return_value=PIPELINE_CONTEXT, + ), + patch.object( + PipelineVersionLineageEntityHandler, + "create_pipeline_version_context", + return_value=PIPELINE_VERSION_CONTEXT, + ), + patch.object( + PipelineVersionLineageEntityHandler, + "load_pipeline_version_context", + return_value=PIPELINE_VERSION_CONTEXT, + ) as load_pipeline_version_context_method, + patch.object( + LineageAssociationHandler, + "list_upstream_associations", + side_effect=[ + generate_pipeline_version_upstream_feature_group_list(), + [], + generate_pipeline_version_upstream_transformation_code(), + ], + ) as list_upstream_associations_method, + patch.object( + LineageAssociationHandler, + "list_downstream_associations", + return_value=generate_pipeline_version_downstream_feature_group(), + ) as list_downstream_associations_method, + patch.object( + PipelineLineageEntityHandler, + "update_pipeline_context", + ) as update_pipeline_context_method, + patch.object( + LineageAssociationHandler, "add_upstream_feature_group_data_associations" + ) as add_upstream_feature_group_data_associations_method, + patch.object( + LineageAssociationHandler, "add_downstream_feature_group_data_associations" + ) as add_downstream_feature_group_data_associations_method, + patch.object( + LineageAssociationHandler, "add_upstream_raw_data_associations" + ) as add_upstream_raw_data_associations_method, + patch.object( + LineageAssociationHandler, "add_upstream_transformation_code_associations" + ) as add_upstream_transformation_code_associations_method, + patch.object( + LineageAssociationHandler, "add_pipeline_and_pipeline_version_association" + ) as add_pipeline_and_pipeline_version_association_method, + patch.object( + Artifact, + "set_tags", + return_value={ + "Tags": [dict(Key="key_1", Value="value_1"), dict(Key="key_2", Value="value_2")] + }, + ) as artifact_set_tags, + ): + lineage_handler.create_lineage(TAGS) + + retrieve_feature_group_context_arns_method.assert_has_calls( + [ + call( + feature_group_name=FEATURE_GROUP_DATA_SOURCE[0].name, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + call( + feature_group_name=FEATURE_GROUP_DATA_SOURCE[1].name, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + call( + feature_group_name=FEATURE_GROUP_DATA_SOURCE[0].name, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + ] + ) + assert 3 == retrieve_feature_group_context_arns_method.call_count + + 
retrieve_raw_data_artifact_method.assert_has_calls( + [ + call(raw_data=RAW_DATA_INPUT[0], sagemaker_session=SAGEMAKER_SESSION_MOCK), + call(raw_data=RAW_DATA_INPUT[1], sagemaker_session=SAGEMAKER_SESSION_MOCK), + call(raw_data=RAW_DATA_INPUT[2], sagemaker_session=SAGEMAKER_SESSION_MOCK), + call(raw_data=RAW_DATA_INPUT[3], sagemaker_session=SAGEMAKER_SESSION_MOCK), + ] + ) + assert 4 == retrieve_raw_data_artifact_method.call_count + + create_transformation_code_artifact_method.assert_called_once_with( + transformation_code=None, + pipeline_last_update_time=PIPELINE["LastModifiedTime"].strftime("%s"), + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + load_pipeline_context_method.assert_called_once_with( + pipeline_name=PIPELINE_NAME, + creation_time=PIPELINE["CreationTime"].strftime("%s"), + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + load_pipeline_version_context_method.assert_not_called() + list_upstream_associations_method.assert_not_called() + list_downstream_associations_method.assert_not_called() + update_pipeline_context_method.assert_not_called() + + add_upstream_feature_group_data_associations_method.assert_called_once_with( + feature_group_inputs=FEATURE_GROUP_INPUT, + pipeline_context_arn=PIPELINE_CONTEXT.context_arn, + pipeline_version_context_arn=PIPELINE_VERSION_CONTEXT.context_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + add_downstream_feature_group_data_associations_method.assert_called_once_with( + feature_group_output=FEATURE_GROUP_INPUT[0], + pipeline_context_arn=PIPELINE_CONTEXT.context_arn, + pipeline_version_context_arn=PIPELINE_VERSION_CONTEXT.context_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + add_upstream_raw_data_associations_method.assert_called_once_with( + raw_data_inputs=RAW_DATA_INPUT_ARTIFACTS, + pipeline_context_arn=PIPELINE_CONTEXT.context_arn, + pipeline_version_context_arn=PIPELINE_VERSION_CONTEXT.context_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + add_upstream_transformation_code_associations_method.assert_not_called() + + add_pipeline_and_pipeline_version_association_method.assert_called_once_with( + pipeline_context_arn=PIPELINE_CONTEXT.context_arn, + pipeline_version_context_arn=PIPELINE_VERSION_CONTEXT.context_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + artifact_set_tags.assert_not_called() + + +def test_create_lineage_when_already_exist_with_no_version_change(): + transformation_code_1 = copy.deepcopy(TRANSFORMATION_CODE_ARTIFACT_1) + lineage_handler = FeatureProcessorLineageHandler( + pipeline_name=PIPELINE_NAME, + pipeline_arn=PIPELINE_ARN, + pipeline=PIPELINE, + inputs=RAW_DATA_INPUT + FEATURE_GROUP_DATA_SOURCE, + output=FEATURE_GROUP_DATA_SOURCE[0].name, + transformation_code=TRANSFORMATION_CODE_INPUT_1, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + with ( + patch.object( + FeatureGroupLineageEntityHandler, + "retrieve_feature_group_context_arns", + side_effect=[ + FEATURE_GROUP_INPUT[0], + FEATURE_GROUP_INPUT[1], + FEATURE_GROUP_INPUT[0], + ], + ) as retrieve_feature_group_context_arns_method, + patch.object( + S3LineageEntityHandler, + "retrieve_raw_data_artifact", + side_effect=[ + RAW_DATA_INPUT_ARTIFACTS[0], + RAW_DATA_INPUT_ARTIFACTS[1], + RAW_DATA_INPUT_ARTIFACTS[2], + RAW_DATA_INPUT_ARTIFACTS[3], + ], + ) as retrieve_raw_data_artifact_method, + patch.object( + S3LineageEntityHandler, + "create_transformation_code_artifact", + return_value=TRANSFORMATION_CODE_ARTIFACT_1, + ) as create_transformation_code_artifact_method, + patch.object( + PipelineLineageEntityHandler, + 
"load_pipeline_context", + return_value=PIPELINE_CONTEXT, + ) as load_pipeline_context_method, + patch.object( + PipelineVersionLineageEntityHandler, + "load_pipeline_version_context", + return_value=PIPELINE_VERSION_CONTEXT, + ) as load_pipeline_version_context_method, + patch.object( + LineageAssociationHandler, + "list_upstream_associations", + side_effect=[ + generate_pipeline_version_upstream_feature_group_list(), + generate_pipeline_version_upstream_raw_data_list(), + generate_pipeline_version_upstream_transformation_code(), + ], + ) as list_upstream_associations_method, + patch.object( + LineageAssociationHandler, + "list_downstream_associations", + return_value=generate_pipeline_version_downstream_feature_group(), + ) as list_downstream_associations_method, + patch.object( + S3LineageEntityHandler, + "load_artifact_from_arn", + return_value=transformation_code_1, + ) as load_artifact_from_arn_method, + patch.object( + S3LineageEntityHandler, + "update_transformation_code_artifact", + ) as update_transformation_code_artifact_method, + patch.object( + PipelineLineageEntityHandler, + "update_pipeline_context", + ) as update_pipeline_context_method, + patch.object( + PipelineVersionLineageEntityHandler, + "create_pipeline_version_context", + return_value=PIPELINE_VERSION_CONTEXT, + ) as create_pipeline_version_context_method, + patch.object( + LineageAssociationHandler, "add_upstream_feature_group_data_associations" + ) as add_upstream_feature_group_data_associations_method, + patch.object( + LineageAssociationHandler, "add_downstream_feature_group_data_associations" + ) as add_downstream_feature_group_data_associations_method, + patch.object( + LineageAssociationHandler, "add_upstream_raw_data_associations" + ) as add_upstream_raw_data_associations_method, + patch.object( + LineageAssociationHandler, "add_upstream_transformation_code_associations" + ) as add_upstream_transformation_code_associations_method, + patch.object( + LineageAssociationHandler, "add_pipeline_and_pipeline_version_association" + ) as add_pipeline_and_pipeline_version_association_method, + patch.object( + Artifact, + "set_tags", + return_value={ + "Tags": [dict(Key="key_1", Value="value_1"), dict(Key="key_2", Value="value_2")] + }, + ) as artifact_set_tags, + ): + lineage_handler.create_lineage(TAGS) + + retrieve_feature_group_context_arns_method.assert_has_calls( + [ + call( + feature_group_name=FEATURE_GROUP_DATA_SOURCE[0].name, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + call( + feature_group_name=FEATURE_GROUP_DATA_SOURCE[1].name, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + call( + feature_group_name=FEATURE_GROUP_DATA_SOURCE[0].name, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + ] + ) + assert 3 == retrieve_feature_group_context_arns_method.call_count + + retrieve_raw_data_artifact_method.assert_has_calls( + [ + call(raw_data=RAW_DATA_INPUT[0], sagemaker_session=SAGEMAKER_SESSION_MOCK), + call(raw_data=RAW_DATA_INPUT[1], sagemaker_session=SAGEMAKER_SESSION_MOCK), + call(raw_data=RAW_DATA_INPUT[2], sagemaker_session=SAGEMAKER_SESSION_MOCK), + call(raw_data=RAW_DATA_INPUT[3], sagemaker_session=SAGEMAKER_SESSION_MOCK), + ] + ) + assert 4 == retrieve_raw_data_artifact_method.call_count + + create_transformation_code_artifact_method.assert_called_once_with( + transformation_code=TRANSFORMATION_CODE_INPUT_1, + pipeline_last_update_time=PIPELINE["LastModifiedTime"].strftime("%s"), + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + load_pipeline_context_method.assert_has_calls( + [ + call( + 
pipeline_name=PIPELINE_NAME, + creation_time=PIPELINE["CreationTime"].strftime("%s"), + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + call( + pipeline_name=PIPELINE_NAME, + creation_time=PIPELINE["CreationTime"].strftime("%s"), + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + ] + ) + assert 2 == load_pipeline_context_method.call_count + load_pipeline_version_context_method.assert_called_once_with( + pipeline_name=PIPELINE_NAME, + last_update_time=PIPELINE_CONTEXT.properties["LastUpdateTime"], + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + list_upstream_associations_method.assert_has_calls( + [ + call( + entity_arn=PIPELINE_VERSION_CONTEXT.context_arn, + source_type="FeatureGroupPipelineVersion", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + call( + entity_arn=PIPELINE_VERSION_CONTEXT.context_arn, + source_type="DataSet", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + call( + entity_arn=PIPELINE_VERSION_CONTEXT.context_arn, + source_type="TransformationCode", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + ] + ) + assert 3 == list_upstream_associations_method.call_count + + list_downstream_associations_method.assert_called_once_with( + entity_arn=PIPELINE_VERSION_CONTEXT.context_arn, + destination_type="FeatureGroupPipelineVersion", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + load_artifact_from_arn_method.assert_called_once_with( + artifact_arn=TRANSFORMATION_CODE_ARTIFACT_1.artifact_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + transformation_code_2 = copy.deepcopy(TRANSFORMATION_CODE_ARTIFACT_1) + transformation_code_2.properties["state"] = TRANSFORMATION_CODE_STATUS_INACTIVE + transformation_code_2.properties["exclusive_end_date"] = PIPELINE["LastModifiedTime"].strftime( + "%s" + ) + update_transformation_code_artifact_method.assert_called_once_with( + transformation_code_artifact=transformation_code_2 + ) + + add_upstream_transformation_code_associations_method.assert_called_once_with( + transformation_code_artifact=TRANSFORMATION_CODE_ARTIFACT_1, + pipeline_version_context_arn=PIPELINE_VERSION_CONTEXT.context_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + create_pipeline_version_context_method.assert_not_called() + update_pipeline_context_method.assert_not_called() + add_upstream_feature_group_data_associations_method.assert_not_called() + add_downstream_feature_group_data_associations_method.assert_not_called() + add_upstream_raw_data_associations_method.assert_not_called() + add_pipeline_and_pipeline_version_association_method.assert_not_called() + + artifact_set_tags.assert_called_once_with(TAGS) + + +def test_create_lineage_when_already_exist_with_changed_raw_data(): + transformation_code_1 = copy.deepcopy(TRANSFORMATION_CODE_ARTIFACT_1) + pipeline_context = copy.deepcopy(PIPELINE_CONTEXT) + lineage_handler = FeatureProcessorLineageHandler( + pipeline_name=PIPELINE_NAME, + pipeline_arn=PIPELINE_ARN, + pipeline=PIPELINE, + inputs=[RAW_DATA_INPUT[0], RAW_DATA_INPUT[1]] + FEATURE_GROUP_DATA_SOURCE, + output=FEATURE_GROUP_DATA_SOURCE[0].name, + transformation_code=TRANSFORMATION_CODE_INPUT_1, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + with ( + patch.object( + FeatureGroupLineageEntityHandler, + "retrieve_feature_group_context_arns", + side_effect=[ + FEATURE_GROUP_INPUT[0], + FEATURE_GROUP_INPUT[1], + FEATURE_GROUP_INPUT[0], + ], + ) as retrieve_feature_group_context_arns_method, + patch.object( + S3LineageEntityHandler, + "retrieve_raw_data_artifact", + side_effect=[RAW_DATA_INPUT_ARTIFACTS[0], RAW_DATA_INPUT_ARTIFACTS[1]], + ) 
as retrieve_raw_data_artifact_method, + patch.object( + S3LineageEntityHandler, + "create_transformation_code_artifact", + return_value=TRANSFORMATION_CODE_ARTIFACT_1, + ) as create_transformation_code_artifact_method, + patch.object( + PipelineLineageEntityHandler, + "load_pipeline_context", + return_value=pipeline_context, + ) as load_pipeline_context_method, + patch.object( + PipelineVersionLineageEntityHandler, + "load_pipeline_version_context", + return_value=PIPELINE_VERSION_CONTEXT, + ) as load_pipeline_version_context_method, + patch.object( + LineageAssociationHandler, + "list_upstream_associations", + side_effect=[ + generate_pipeline_version_upstream_feature_group_list(), + generate_pipeline_version_upstream_raw_data_list(), + generate_pipeline_version_upstream_transformation_code(), + ], + ) as list_upstream_associations_method, + patch.object( + LineageAssociationHandler, + "list_downstream_associations", + return_value=generate_pipeline_version_downstream_feature_group(), + ) as list_downstream_associations_method, + patch.object( + S3LineageEntityHandler, + "load_artifact_from_arn", + return_value=transformation_code_1, + ) as load_artifact_from_arn_method, + patch.object( + S3LineageEntityHandler, + "update_transformation_code_artifact", + ) as update_transformation_code_artifact_method, + patch.object( + PipelineLineageEntityHandler, + "update_pipeline_context", + ) as update_pipeline_context_method, + patch.object( + PipelineVersionLineageEntityHandler, + "create_pipeline_version_context", + return_value=PIPELINE_VERSION_CONTEXT, + ), + patch.object( + LineageAssociationHandler, "add_upstream_feature_group_data_associations" + ) as add_upstream_feature_group_data_associations_method, + patch.object( + LineageAssociationHandler, "add_downstream_feature_group_data_associations" + ) as add_downstream_feature_group_data_associations_method, + patch.object( + LineageAssociationHandler, "add_upstream_raw_data_associations" + ) as add_upstream_raw_data_associations_method, + patch.object( + LineageAssociationHandler, "add_upstream_transformation_code_associations" + ) as add_upstream_transformation_code_associations_method, + patch.object( + LineageAssociationHandler, "add_pipeline_and_pipeline_version_association" + ) as add_pipeline_and_pipeline_version_association_method, + patch.object( + Artifact, + "set_tags", + return_value={ + "Tags": [dict(Key="key_1", Value="value_1"), dict(Key="key_2", Value="value_2")] + }, + ) as artifact_set_tags, + ): + lineage_handler.create_lineage(TAGS) + + retrieve_feature_group_context_arns_method.assert_has_calls( + [ + call( + feature_group_name=FEATURE_GROUP_DATA_SOURCE[0].name, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + call( + feature_group_name=FEATURE_GROUP_DATA_SOURCE[1].name, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + call( + feature_group_name=FEATURE_GROUP_DATA_SOURCE[0].name, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + ] + ) + assert 3 == retrieve_feature_group_context_arns_method.call_count + + retrieve_raw_data_artifact_method.assert_has_calls( + [ + call(raw_data=RAW_DATA_INPUT[0], sagemaker_session=SAGEMAKER_SESSION_MOCK), + call(raw_data=RAW_DATA_INPUT[1], sagemaker_session=SAGEMAKER_SESSION_MOCK), + ] + ) + assert 2 == retrieve_raw_data_artifact_method.call_count + + create_transformation_code_artifact_method.assert_called_once_with( + transformation_code=TRANSFORMATION_CODE_INPUT_1, + pipeline_last_update_time=PIPELINE["LastModifiedTime"].strftime("%s"), + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) 
+ + load_pipeline_context_method.assert_has_calls( + [ + call( + pipeline_name=PIPELINE_NAME, + creation_time=PIPELINE["CreationTime"].strftime("%s"), + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + call( + pipeline_name=PIPELINE_NAME, + creation_time=PIPELINE["CreationTime"].strftime("%s"), + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + ] + ) + assert 2 == load_pipeline_context_method.call_count + load_pipeline_version_context_method.assert_called_once_with( + pipeline_name=PIPELINE_NAME, + last_update_time=LAST_UPDATE_TIME, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + list_upstream_associations_method.assert_has_calls( + [ + call( + entity_arn=PIPELINE_VERSION_CONTEXT.context_arn, + source_type="FeatureGroupPipelineVersion", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + call( + entity_arn=PIPELINE_VERSION_CONTEXT.context_arn, + source_type="DataSet", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + call( + entity_arn=PIPELINE_VERSION_CONTEXT.context_arn, + source_type="TransformationCode", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + ] + ) + assert 3 == list_upstream_associations_method.call_count + + list_downstream_associations_method.assert_called_once_with( + entity_arn=PIPELINE_VERSION_CONTEXT.context_arn, + destination_type="FeatureGroupPipelineVersion", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + load_artifact_from_arn_method.assert_called_once_with( + artifact_arn=TRANSFORMATION_CODE_ARTIFACT_1.artifact_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + transformation_code_2 = copy.deepcopy(TRANSFORMATION_CODE_ARTIFACT_1) + transformation_code_2.properties["state"] = TRANSFORMATION_CODE_STATUS_INACTIVE + transformation_code_2.properties["exclusive_end_date"] = PIPELINE["LastModifiedTime"].strftime( + "%s" + ) + update_transformation_code_artifact_method.assert_called_once_with( + transformation_code_artifact=transformation_code_2 + ) + + add_upstream_feature_group_data_associations_method.assert_called_once_with( + feature_group_inputs=FEATURE_GROUP_INPUT, + pipeline_context_arn=PIPELINE_CONTEXT.context_arn, + pipeline_version_context_arn=PIPELINE_VERSION_CONTEXT.context_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + add_downstream_feature_group_data_associations_method.assert_called_once_with( + feature_group_output=FEATURE_GROUP_INPUT[0], + pipeline_context_arn=PIPELINE_CONTEXT.context_arn, + pipeline_version_context_arn=PIPELINE_VERSION_CONTEXT.context_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + add_upstream_raw_data_associations_method.assert_called_once_with( + raw_data_inputs=[RAW_DATA_INPUT_ARTIFACTS[0], RAW_DATA_INPUT_ARTIFACTS[1]], + pipeline_context_arn=PIPELINE_CONTEXT.context_arn, + pipeline_version_context_arn=PIPELINE_VERSION_CONTEXT.context_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + add_upstream_transformation_code_associations_method.assert_called_once_with( + transformation_code_artifact=TRANSFORMATION_CODE_ARTIFACT_1, + pipeline_version_context_arn=PIPELINE_VERSION_CONTEXT.context_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + assert pipeline_context.properties["LastUpdateTime"] == PIPELINE["LastModifiedTime"].strftime( + "%s" + ) + update_pipeline_context_method.assert_called_once_with(pipeline_context=pipeline_context) + + add_pipeline_and_pipeline_version_association_method.assert_called_once_with( + pipeline_context_arn=PIPELINE_CONTEXT.context_arn, + pipeline_version_context_arn=PIPELINE_VERSION_CONTEXT.context_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + 
artifact_set_tags.assert_called_once_with(TAGS) + + +def test_create_lineage_when_already_exist_with_changed_input_fg(): + transformation_code_1 = copy.deepcopy(TRANSFORMATION_CODE_ARTIFACT_1) + pipeline_context = copy.deepcopy(PIPELINE_CONTEXT) + lineage_handler = FeatureProcessorLineageHandler( + pipeline_name=PIPELINE_NAME, + pipeline_arn=PIPELINE_ARN, + pipeline=PIPELINE, + inputs=RAW_DATA_INPUT + [FEATURE_GROUP_DATA_SOURCE[0]], + output=FEATURE_GROUP_DATA_SOURCE[0].name, + transformation_code=TRANSFORMATION_CODE_INPUT_1, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + with ( + patch.object( + FeatureGroupLineageEntityHandler, + "retrieve_feature_group_context_arns", + side_effect=[FEATURE_GROUP_INPUT[0], FEATURE_GROUP_INPUT[0]], + ) as retrieve_feature_group_context_arns_method, + patch.object( + S3LineageEntityHandler, + "retrieve_raw_data_artifact", + side_effect=[ + RAW_DATA_INPUT_ARTIFACTS[0], + RAW_DATA_INPUT_ARTIFACTS[1], + RAW_DATA_INPUT_ARTIFACTS[2], + RAW_DATA_INPUT_ARTIFACTS[3], + ], + ) as retrieve_raw_data_artifact_method, + patch.object( + S3LineageEntityHandler, + "create_transformation_code_artifact", + return_value=TRANSFORMATION_CODE_ARTIFACT_1, + ) as create_transformation_code_artifact_method, + patch.object( + PipelineLineageEntityHandler, + "load_pipeline_context", + return_value=pipeline_context, + ) as load_pipeline_context_method, + patch.object( + PipelineVersionLineageEntityHandler, + "load_pipeline_version_context", + return_value=PIPELINE_VERSION_CONTEXT, + ) as load_pipeline_version_context_method, + patch.object( + LineageAssociationHandler, + "list_upstream_associations", + side_effect=[ + generate_pipeline_version_upstream_feature_group_list(), + generate_pipeline_version_upstream_raw_data_list(), + generate_pipeline_version_upstream_transformation_code(), + ], + ) as list_upstream_associations_method, + patch.object( + LineageAssociationHandler, + "list_downstream_associations", + return_value=generate_pipeline_version_downstream_feature_group(), + ) as list_downstream_associations_method, + patch.object( + S3LineageEntityHandler, + "load_artifact_from_arn", + return_value=transformation_code_1, + ) as load_artifact_from_arn_method, + patch.object( + S3LineageEntityHandler, + "update_transformation_code_artifact", + ) as update_transformation_code_artifact_method, + patch.object( + PipelineLineageEntityHandler, + "update_pipeline_context", + ) as update_pipeline_context_method, + patch.object( + PipelineVersionLineageEntityHandler, + "create_pipeline_version_context", + return_value=PIPELINE_VERSION_CONTEXT, + ), + patch.object( + LineageAssociationHandler, "add_upstream_feature_group_data_associations" + ) as add_upstream_feature_group_data_associations_method, + patch.object( + LineageAssociationHandler, "add_downstream_feature_group_data_associations" + ) as add_downstream_feature_group_data_associations_method, + patch.object( + LineageAssociationHandler, "add_upstream_raw_data_associations" + ) as add_upstream_raw_data_associations_method, + patch.object( + LineageAssociationHandler, "add_upstream_transformation_code_associations" + ) as add_upstream_transformation_code_associations_method, + patch.object( + LineageAssociationHandler, "add_pipeline_and_pipeline_version_association" + ) as add_pipeline_and_pipeline_version_association_method, + patch.object( + Artifact, + "set_tags", + return_value={ + "Tags": [dict(Key="key_1", Value="value_1"), dict(Key="key_2", Value="value_2")] + }, + ) as artifact_set_tags, + ): + 
lineage_handler.create_lineage(TAGS) + + retrieve_feature_group_context_arns_method.assert_has_calls( + [ + call( + feature_group_name=FEATURE_GROUP_DATA_SOURCE[0].name, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + call( + feature_group_name=FEATURE_GROUP_DATA_SOURCE[0].name, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + ] + ) + assert 2 == retrieve_feature_group_context_arns_method.call_count + + retrieve_raw_data_artifact_method.assert_has_calls( + [ + call(raw_data=RAW_DATA_INPUT[0], sagemaker_session=SAGEMAKER_SESSION_MOCK), + call(raw_data=RAW_DATA_INPUT[1], sagemaker_session=SAGEMAKER_SESSION_MOCK), + call(raw_data=RAW_DATA_INPUT[2], sagemaker_session=SAGEMAKER_SESSION_MOCK), + call(raw_data=RAW_DATA_INPUT[3], sagemaker_session=SAGEMAKER_SESSION_MOCK), + ] + ) + assert 4 == retrieve_raw_data_artifact_method.call_count + + create_transformation_code_artifact_method.assert_called_once_with( + transformation_code=TRANSFORMATION_CODE_INPUT_1, + pipeline_last_update_time=PIPELINE["LastModifiedTime"].strftime("%s"), + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + load_pipeline_context_method.assert_has_calls( + [ + call( + pipeline_name=PIPELINE_NAME, + creation_time=PIPELINE["CreationTime"].strftime("%s"), + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + call( + pipeline_name=PIPELINE_NAME, + creation_time=PIPELINE["CreationTime"].strftime("%s"), + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + ] + ) + assert 2 == load_pipeline_context_method.call_count + load_pipeline_version_context_method.assert_called_once_with( + pipeline_name=PIPELINE_NAME, + last_update_time=LAST_UPDATE_TIME, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + list_upstream_associations_method.assert_has_calls( + [ + call( + entity_arn=PIPELINE_VERSION_CONTEXT.context_arn, + source_type="FeatureGroupPipelineVersion", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + call( + entity_arn=PIPELINE_VERSION_CONTEXT.context_arn, + source_type="DataSet", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + call( + entity_arn=PIPELINE_VERSION_CONTEXT.context_arn, + source_type="TransformationCode", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + ] + ) + assert 3 == list_upstream_associations_method.call_count + + list_downstream_associations_method.assert_called_once_with( + entity_arn=PIPELINE_VERSION_CONTEXT.context_arn, + destination_type="FeatureGroupPipelineVersion", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + load_artifact_from_arn_method.assert_called_once_with( + artifact_arn=TRANSFORMATION_CODE_ARTIFACT_1.artifact_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + transformation_code_2 = copy.deepcopy(TRANSFORMATION_CODE_ARTIFACT_1) + transformation_code_2.properties["state"] = TRANSFORMATION_CODE_STATUS_INACTIVE + transformation_code_2.properties["exclusive_end_date"] = PIPELINE["LastModifiedTime"].strftime( + "%s" + ) + update_transformation_code_artifact_method.assert_called_once_with( + transformation_code_artifact=transformation_code_2 + ) + + add_upstream_feature_group_data_associations_method.assert_called_once_with( + feature_group_inputs=[FEATURE_GROUP_INPUT[0]], + pipeline_context_arn=PIPELINE_CONTEXT.context_arn, + pipeline_version_context_arn=PIPELINE_VERSION_CONTEXT.context_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + add_downstream_feature_group_data_associations_method.assert_called_once_with( + feature_group_output=FEATURE_GROUP_INPUT[0], + pipeline_context_arn=PIPELINE_CONTEXT.context_arn, + 
pipeline_version_context_arn=PIPELINE_VERSION_CONTEXT.context_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + add_upstream_raw_data_associations_method.assert_called_once_with( + raw_data_inputs=RAW_DATA_INPUT_ARTIFACTS, + pipeline_context_arn=PIPELINE_CONTEXT.context_arn, + pipeline_version_context_arn=PIPELINE_VERSION_CONTEXT.context_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + add_upstream_transformation_code_associations_method.assert_called_once_with( + transformation_code_artifact=TRANSFORMATION_CODE_ARTIFACT_1, + pipeline_version_context_arn=PIPELINE_VERSION_CONTEXT.context_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + assert pipeline_context.properties["LastUpdateTime"] == PIPELINE["LastModifiedTime"].strftime( + "%s" + ) + update_pipeline_context_method.assert_called_once_with(pipeline_context=pipeline_context) + + add_pipeline_and_pipeline_version_association_method.assert_called_once_with( + pipeline_context_arn=PIPELINE_CONTEXT.context_arn, + pipeline_version_context_arn=PIPELINE_VERSION_CONTEXT.context_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + artifact_set_tags.assert_called_once_with(TAGS) + + +def test_create_lineage_when_already_exist_with_changed_output_fg(): + transformation_code_1 = copy.deepcopy(TRANSFORMATION_CODE_ARTIFACT_1) + pipeline_context = copy.deepcopy(PIPELINE_CONTEXT) + lineage_handler = FeatureProcessorLineageHandler( + pipeline_name=PIPELINE_NAME, + pipeline_arn=PIPELINE_ARN, + pipeline=PIPELINE, + inputs=RAW_DATA_INPUT + FEATURE_GROUP_DATA_SOURCE, + output=FEATURE_GROUP_DATA_SOURCE[1].name, + transformation_code=TRANSFORMATION_CODE_INPUT_1, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + with ( + patch.object( + FeatureGroupLineageEntityHandler, + "retrieve_feature_group_context_arns", + side_effect=[ + FEATURE_GROUP_INPUT[0], + FEATURE_GROUP_INPUT[1], + FEATURE_GROUP_INPUT[1], + ], + ) as retrieve_feature_group_context_arns_method, + patch.object( + S3LineageEntityHandler, + "retrieve_raw_data_artifact", + side_effect=[ + RAW_DATA_INPUT_ARTIFACTS[0], + RAW_DATA_INPUT_ARTIFACTS[1], + RAW_DATA_INPUT_ARTIFACTS[2], + RAW_DATA_INPUT_ARTIFACTS[3], + ], + ) as retrieve_raw_data_artifact_method, + patch.object( + S3LineageEntityHandler, + "create_transformation_code_artifact", + return_value=TRANSFORMATION_CODE_ARTIFACT_1, + ) as create_transformation_code_artifact_method, + patch.object( + PipelineLineageEntityHandler, + "load_pipeline_context", + return_value=pipeline_context, + ) as load_pipeline_context_method, + patch.object( + PipelineVersionLineageEntityHandler, + "load_pipeline_version_context", + return_value=PIPELINE_VERSION_CONTEXT, + ) as load_pipeline_version_context_method, + patch.object( + LineageAssociationHandler, + "list_upstream_associations", + side_effect=[ + generate_pipeline_version_upstream_feature_group_list(), + generate_pipeline_version_upstream_raw_data_list(), + generate_pipeline_version_upstream_transformation_code(), + ], + ) as list_upstream_associations_method, + patch.object( + LineageAssociationHandler, + "list_downstream_associations", + return_value=generate_pipeline_version_downstream_feature_group(), + ) as list_downstream_associations_method, + patch.object( + S3LineageEntityHandler, + "load_artifact_from_arn", + return_value=transformation_code_1, + ) as load_artifact_from_arn_method, + patch.object( + S3LineageEntityHandler, + "update_transformation_code_artifact", + ) as update_transformation_code_artifact_method, + patch.object( + PipelineLineageEntityHandler, + 
"update_pipeline_context", + ) as update_pipeline_context_method, + patch.object( + PipelineVersionLineageEntityHandler, + "create_pipeline_version_context", + return_value=PIPELINE_VERSION_CONTEXT, + ), + patch.object( + LineageAssociationHandler, "add_upstream_feature_group_data_associations" + ) as add_upstream_feature_group_data_associations_method, + patch.object( + LineageAssociationHandler, "add_downstream_feature_group_data_associations" + ) as add_downstream_feature_group_data_associations_method, + patch.object( + LineageAssociationHandler, "add_upstream_raw_data_associations" + ) as add_upstream_raw_data_associations_method, + patch.object( + LineageAssociationHandler, "add_upstream_transformation_code_associations" + ) as add_upstream_transformation_code_associations_method, + patch.object( + LineageAssociationHandler, "add_pipeline_and_pipeline_version_association" + ) as add_pipeline_and_pipeline_version_association_method, + patch.object( + Artifact, + "set_tags", + return_value={ + "Tags": [dict(Key="key_1", Value="value_1"), dict(Key="key_2", Value="value_2")] + }, + ) as artifact_set_tags, + ): + lineage_handler.create_lineage(TAGS) + + retrieve_feature_group_context_arns_method.assert_has_calls( + [ + call( + feature_group_name=FEATURE_GROUP_DATA_SOURCE[0].name, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + call( + feature_group_name=FEATURE_GROUP_DATA_SOURCE[1].name, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + call( + feature_group_name=FEATURE_GROUP_DATA_SOURCE[1].name, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + ] + ) + assert 3 == retrieve_feature_group_context_arns_method.call_count + + retrieve_raw_data_artifact_method.assert_has_calls( + [ + call(raw_data=RAW_DATA_INPUT[0], sagemaker_session=SAGEMAKER_SESSION_MOCK), + call(raw_data=RAW_DATA_INPUT[1], sagemaker_session=SAGEMAKER_SESSION_MOCK), + call(raw_data=RAW_DATA_INPUT[2], sagemaker_session=SAGEMAKER_SESSION_MOCK), + call(raw_data=RAW_DATA_INPUT[3], sagemaker_session=SAGEMAKER_SESSION_MOCK), + ] + ) + assert 4 == retrieve_raw_data_artifact_method.call_count + + create_transformation_code_artifact_method.assert_called_once_with( + transformation_code=TRANSFORMATION_CODE_INPUT_1, + pipeline_last_update_time=PIPELINE["LastModifiedTime"].strftime("%s"), + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + load_pipeline_context_method.assert_has_calls( + [ + call( + pipeline_name=PIPELINE_NAME, + creation_time=PIPELINE["CreationTime"].strftime("%s"), + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + call( + pipeline_name=PIPELINE_NAME, + creation_time=PIPELINE["CreationTime"].strftime("%s"), + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + ] + ) + assert 2 == load_pipeline_context_method.call_count + load_pipeline_version_context_method.assert_called_once_with( + pipeline_name=PIPELINE_NAME, + last_update_time=LAST_UPDATE_TIME, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + list_upstream_associations_method.assert_has_calls( + [ + call( + entity_arn=PIPELINE_VERSION_CONTEXT.context_arn, + source_type="FeatureGroupPipelineVersion", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + call( + entity_arn=PIPELINE_VERSION_CONTEXT.context_arn, + source_type="DataSet", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + call( + entity_arn=PIPELINE_VERSION_CONTEXT.context_arn, + source_type="TransformationCode", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + ] + ) + assert 3 == list_upstream_associations_method.call_count + + list_downstream_associations_method.assert_called_once_with( + 
entity_arn=PIPELINE_VERSION_CONTEXT.context_arn, + destination_type="FeatureGroupPipelineVersion", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + load_artifact_from_arn_method.assert_called_once_with( + artifact_arn=TRANSFORMATION_CODE_ARTIFACT_1.artifact_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + transformation_code_2 = copy.deepcopy(TRANSFORMATION_CODE_ARTIFACT_1) + transformation_code_2.properties["state"] = TRANSFORMATION_CODE_STATUS_INACTIVE + transformation_code_2.properties["exclusive_end_date"] = PIPELINE["LastModifiedTime"].strftime( + "%s" + ) + update_transformation_code_artifact_method.assert_called_once_with( + transformation_code_artifact=transformation_code_2 + ) + + add_upstream_feature_group_data_associations_method.assert_called_once_with( + feature_group_inputs=FEATURE_GROUP_INPUT, + pipeline_context_arn=PIPELINE_CONTEXT.context_arn, + pipeline_version_context_arn=PIPELINE_VERSION_CONTEXT.context_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + add_downstream_feature_group_data_associations_method.assert_called_once_with( + feature_group_output=FEATURE_GROUP_INPUT[1], + pipeline_context_arn=PIPELINE_CONTEXT.context_arn, + pipeline_version_context_arn=PIPELINE_VERSION_CONTEXT.context_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + add_upstream_raw_data_associations_method.assert_called_once_with( + raw_data_inputs=RAW_DATA_INPUT_ARTIFACTS, + pipeline_context_arn=PIPELINE_CONTEXT.context_arn, + pipeline_version_context_arn=PIPELINE_VERSION_CONTEXT.context_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + add_upstream_transformation_code_associations_method.assert_called_once_with( + transformation_code_artifact=TRANSFORMATION_CODE_ARTIFACT_1, + pipeline_version_context_arn=PIPELINE_VERSION_CONTEXT.context_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + assert pipeline_context.properties["LastUpdateTime"] == PIPELINE["LastModifiedTime"].strftime( + "%s" + ) + update_pipeline_context_method.assert_called_once_with(pipeline_context=pipeline_context) + + add_pipeline_and_pipeline_version_association_method.assert_called_once_with( + pipeline_context_arn=PIPELINE_CONTEXT.context_arn, + pipeline_version_context_arn=PIPELINE_VERSION_CONTEXT.context_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + artifact_set_tags.assert_called_once_with(TAGS) + + +def test_create_lineage_when_already_exist_with_changed_transformation_code(): + transformation_code_1 = copy.deepcopy(TRANSFORMATION_CODE_ARTIFACT_1) + pipeline_context = copy.deepcopy(PIPELINE_CONTEXT) + lineage_handler = FeatureProcessorLineageHandler( + pipeline_name=PIPELINE_NAME, + pipeline_arn=PIPELINE_ARN, + pipeline=PIPELINE, + inputs=RAW_DATA_INPUT + FEATURE_GROUP_DATA_SOURCE, + output=FEATURE_GROUP_DATA_SOURCE[0].name, + transformation_code=TRANSFORMATION_CODE_INPUT_2, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + with ( + patch.object( + FeatureGroupLineageEntityHandler, + "retrieve_feature_group_context_arns", + side_effect=[ + FEATURE_GROUP_INPUT[0], + FEATURE_GROUP_INPUT[1], + FEATURE_GROUP_INPUT[0], + ], + ) as retrieve_feature_group_context_arns_method, + patch.object( + S3LineageEntityHandler, + "retrieve_raw_data_artifact", + side_effect=[ + RAW_DATA_INPUT_ARTIFACTS[0], + RAW_DATA_INPUT_ARTIFACTS[1], + RAW_DATA_INPUT_ARTIFACTS[2], + RAW_DATA_INPUT_ARTIFACTS[3], + ], + ) as retrieve_raw_data_artifact_method, + patch.object( + S3LineageEntityHandler, + "create_transformation_code_artifact", + return_value=TRANSFORMATION_CODE_ARTIFACT_2, + ) as 
create_transformation_code_artifact_method, + patch.object( + PipelineLineageEntityHandler, + "load_pipeline_context", + return_value=pipeline_context, + ) as load_pipeline_context_method, + patch.object( + PipelineVersionLineageEntityHandler, + "load_pipeline_version_context", + return_value=PIPELINE_VERSION_CONTEXT, + ) as load_pipeline_version_context_method, + patch.object( + LineageAssociationHandler, + "list_upstream_associations", + side_effect=[ + generate_pipeline_version_upstream_feature_group_list(), + generate_pipeline_version_upstream_raw_data_list(), + generate_pipeline_version_upstream_transformation_code(), + ], + ) as list_upstream_associations_method, + patch.object( + LineageAssociationHandler, + "list_downstream_associations", + return_value=generate_pipeline_version_downstream_feature_group(), + ) as list_downstream_associations_method, + patch.object( + S3LineageEntityHandler, + "load_artifact_from_arn", + return_value=transformation_code_1, + ) as load_artifact_from_arn_method, + patch.object( + S3LineageEntityHandler, + "update_transformation_code_artifact", + ) as update_transformation_code_artifact_method, + patch.object( + PipelineLineageEntityHandler, + "update_pipeline_context", + ) as update_pipeline_context_method, + patch.object( + PipelineVersionLineageEntityHandler, + "create_pipeline_version_context", + return_value=PIPELINE_VERSION_CONTEXT, + ), + patch.object( + LineageAssociationHandler, "add_upstream_feature_group_data_associations" + ) as add_upstream_feature_group_data_associations_method, + patch.object( + LineageAssociationHandler, "add_downstream_feature_group_data_associations" + ) as add_downstream_feature_group_data_associations_method, + patch.object( + LineageAssociationHandler, "add_upstream_raw_data_associations" + ) as add_upstream_raw_data_associations_method, + patch.object( + LineageAssociationHandler, "add_upstream_transformation_code_associations" + ) as add_upstream_transformation_code_associations_method, + patch.object( + LineageAssociationHandler, "add_pipeline_and_pipeline_version_association" + ) as add_pipeline_and_pipeline_version_association_method, + patch.object( + Artifact, + "set_tags", + return_value={ + "Tags": [dict(Key="key_1", Value="value_1"), dict(Key="key_2", Value="value_2")] + }, + ) as artifact_set_tags, + ): + lineage_handler.create_lineage(TAGS) + + retrieve_feature_group_context_arns_method.assert_has_calls( + [ + call( + feature_group_name=FEATURE_GROUP_DATA_SOURCE[0].name, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + call( + feature_group_name=FEATURE_GROUP_DATA_SOURCE[1].name, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + call( + feature_group_name=FEATURE_GROUP_DATA_SOURCE[0].name, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + ] + ) + assert 3 == retrieve_feature_group_context_arns_method.call_count + + retrieve_raw_data_artifact_method.assert_has_calls( + [ + call(raw_data=RAW_DATA_INPUT[0], sagemaker_session=SAGEMAKER_SESSION_MOCK), + call(raw_data=RAW_DATA_INPUT[1], sagemaker_session=SAGEMAKER_SESSION_MOCK), + call(raw_data=RAW_DATA_INPUT[2], sagemaker_session=SAGEMAKER_SESSION_MOCK), + call(raw_data=RAW_DATA_INPUT[3], sagemaker_session=SAGEMAKER_SESSION_MOCK), + ] + ) + assert 4 == retrieve_raw_data_artifact_method.call_count + + create_transformation_code_artifact_method.assert_called_once_with( + transformation_code=TRANSFORMATION_CODE_INPUT_2, + pipeline_last_update_time=PIPELINE["LastModifiedTime"].strftime("%s"), + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + 
load_pipeline_context_method.assert_has_calls( + [ + call( + pipeline_name=PIPELINE_NAME, + creation_time=PIPELINE["CreationTime"].strftime("%s"), + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + call( + pipeline_name=PIPELINE_NAME, + creation_time=PIPELINE["CreationTime"].strftime("%s"), + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + ] + ) + assert 2 == load_pipeline_context_method.call_count + load_pipeline_version_context_method.assert_called_once_with( + pipeline_name=PIPELINE_NAME, + last_update_time=LAST_UPDATE_TIME, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + list_upstream_associations_method.assert_has_calls( + [ + call( + entity_arn=PIPELINE_VERSION_CONTEXT.context_arn, + source_type="FeatureGroupPipelineVersion", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + call( + entity_arn=PIPELINE_VERSION_CONTEXT.context_arn, + source_type="DataSet", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + call( + entity_arn=PIPELINE_VERSION_CONTEXT.context_arn, + source_type="TransformationCode", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + ] + ) + assert 3 == list_upstream_associations_method.call_count + + list_downstream_associations_method.assert_called_once_with( + entity_arn=PIPELINE_VERSION_CONTEXT.context_arn, + destination_type="FeatureGroupPipelineVersion", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + load_artifact_from_arn_method.assert_called_once_with( + artifact_arn=TRANSFORMATION_CODE_ARTIFACT_1.artifact_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + transformation_code_2 = copy.deepcopy(TRANSFORMATION_CODE_ARTIFACT_1) + transformation_code_2.properties["state"] = TRANSFORMATION_CODE_STATUS_INACTIVE + transformation_code_2.properties["exclusive_end_date"] = PIPELINE["LastModifiedTime"].strftime( + "%s" + ) + update_transformation_code_artifact_method.assert_called_once_with( + transformation_code_artifact=transformation_code_2 + ) + + assert pipeline_context.properties["LastUpdateTime"] == LAST_UPDATE_TIME + + update_pipeline_context_method.assert_not_called() + add_upstream_feature_group_data_associations_method.assert_not_called() + add_downstream_feature_group_data_associations_method.assert_not_called() + add_upstream_raw_data_associations_method.assert_not_called() + add_pipeline_and_pipeline_version_association_method.assert_not_called() + + add_upstream_transformation_code_associations_method.assert_called_once_with( + transformation_code_artifact=TRANSFORMATION_CODE_ARTIFACT_2, + pipeline_version_context_arn=PIPELINE_VERSION_CONTEXT.context_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + artifact_set_tags.assert_called_once_with(TAGS) + + +def test_create_lineage_when_already_exist_with_last_transformation_code_as_none(): + transformation_code_1 = copy.deepcopy(TRANSFORMATION_CODE_ARTIFACT_1) + transformation_code_1.properties["state"] = TRANSFORMATION_CODE_STATUS_INACTIVE + transformation_code_1.properties["exclusive_end_date"] = PIPELINE["LastModifiedTime"].strftime( + "%s" + ) + pipeline_context = copy.deepcopy(PIPELINE_CONTEXT) + lineage_handler = FeatureProcessorLineageHandler( + pipeline_name=PIPELINE_NAME, + pipeline_arn=PIPELINE_ARN, + pipeline=PIPELINE, + inputs=RAW_DATA_INPUT + FEATURE_GROUP_DATA_SOURCE, + output=FEATURE_GROUP_DATA_SOURCE[0].name, + transformation_code=TRANSFORMATION_CODE_INPUT_2, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + with ( + patch.object( + FeatureGroupLineageEntityHandler, + "retrieve_feature_group_context_arns", + side_effect=[ + FEATURE_GROUP_INPUT[0], + FEATURE_GROUP_INPUT[1], + 
FEATURE_GROUP_INPUT[0], + ], + ) as retrieve_feature_group_context_arns_method, + patch.object( + S3LineageEntityHandler, + "retrieve_raw_data_artifact", + side_effect=[ + RAW_DATA_INPUT_ARTIFACTS[0], + RAW_DATA_INPUT_ARTIFACTS[1], + RAW_DATA_INPUT_ARTIFACTS[2], + RAW_DATA_INPUT_ARTIFACTS[3], + ], + ) as retrieve_raw_data_artifact_method, + patch.object( + S3LineageEntityHandler, + "create_transformation_code_artifact", + return_value=TRANSFORMATION_CODE_ARTIFACT_2, + ) as create_transformation_code_artifact_method, + patch.object( + PipelineLineageEntityHandler, + "load_pipeline_context", + return_value=pipeline_context, + ) as load_pipeline_context_method, + patch.object( + PipelineVersionLineageEntityHandler, + "load_pipeline_version_context", + return_value=PIPELINE_VERSION_CONTEXT, + ) as load_pipeline_version_context_method, + patch.object( + LineageAssociationHandler, + "list_upstream_associations", + side_effect=[ + generate_pipeline_version_upstream_feature_group_list(), + generate_pipeline_version_upstream_raw_data_list(), + generate_pipeline_version_upstream_transformation_code(), + ], + ) as list_upstream_associations_method, + patch.object( + LineageAssociationHandler, + "list_downstream_associations", + return_value=generate_pipeline_version_downstream_feature_group(), + ) as list_downstream_associations_method, + patch.object( + S3LineageEntityHandler, + "load_artifact_from_arn", + return_value=transformation_code_1, + ) as load_artifact_from_arn_method, + patch.object( + S3LineageEntityHandler, + "update_transformation_code_artifact", + ) as update_transformation_code_artifact_method, + patch.object( + PipelineLineageEntityHandler, + "update_pipeline_context", + ) as update_pipeline_context_method, + patch.object( + PipelineVersionLineageEntityHandler, + "create_pipeline_version_context", + return_value=PIPELINE_VERSION_CONTEXT, + ), + patch.object( + LineageAssociationHandler, "add_upstream_feature_group_data_associations" + ) as add_upstream_feature_group_data_associations_method, + patch.object( + LineageAssociationHandler, "add_downstream_feature_group_data_associations" + ) as add_downstream_feature_group_data_associations_method, + patch.object( + LineageAssociationHandler, "add_upstream_raw_data_associations" + ) as add_upstream_raw_data_associations_method, + patch.object( + LineageAssociationHandler, "add_upstream_transformation_code_associations" + ) as add_upstream_transformation_code_associations_method, + patch.object( + LineageAssociationHandler, "add_pipeline_and_pipeline_version_association" + ) as add_pipeline_and_pipeline_version_association_method, + patch.object( + Artifact, + "set_tags", + return_value={ + "Tags": [dict(Key="key_1", Value="value_1"), dict(Key="key_2", Value="value_2")] + }, + ) as artifact_set_tags, + ): + lineage_handler.create_lineage(TAGS) + + retrieve_feature_group_context_arns_method.assert_has_calls( + [ + call( + feature_group_name=FEATURE_GROUP_DATA_SOURCE[0].name, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + call( + feature_group_name=FEATURE_GROUP_DATA_SOURCE[1].name, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + call( + feature_group_name=FEATURE_GROUP_DATA_SOURCE[0].name, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + ] + ) + assert 3 == retrieve_feature_group_context_arns_method.call_count + + retrieve_raw_data_artifact_method.assert_has_calls( + [ + call(raw_data=RAW_DATA_INPUT[0], sagemaker_session=SAGEMAKER_SESSION_MOCK), + call(raw_data=RAW_DATA_INPUT[1], sagemaker_session=SAGEMAKER_SESSION_MOCK), + 
call(raw_data=RAW_DATA_INPUT[2], sagemaker_session=SAGEMAKER_SESSION_MOCK), + call(raw_data=RAW_DATA_INPUT[3], sagemaker_session=SAGEMAKER_SESSION_MOCK), + ] + ) + assert 4 == retrieve_raw_data_artifact_method.call_count + + create_transformation_code_artifact_method.assert_called_once_with( + transformation_code=TRANSFORMATION_CODE_INPUT_2, + pipeline_last_update_time=PIPELINE["LastModifiedTime"].strftime("%s"), + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + load_pipeline_context_method.assert_has_calls( + [ + call( + pipeline_name=PIPELINE_NAME, + creation_time=PIPELINE["CreationTime"].strftime("%s"), + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + call( + pipeline_name=PIPELINE_NAME, + creation_time=PIPELINE["CreationTime"].strftime("%s"), + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + ] + ) + assert 2 == load_pipeline_context_method.call_count + load_pipeline_version_context_method.assert_called_once_with( + pipeline_name=PIPELINE_NAME, + last_update_time=LAST_UPDATE_TIME, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + list_upstream_associations_method.assert_has_calls( + [ + call( + entity_arn=PIPELINE_VERSION_CONTEXT.context_arn, + source_type="FeatureGroupPipelineVersion", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + call( + entity_arn=PIPELINE_VERSION_CONTEXT.context_arn, + source_type="DataSet", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + call( + entity_arn=PIPELINE_VERSION_CONTEXT.context_arn, + source_type="TransformationCode", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + ] + ) + assert 3 == list_upstream_associations_method.call_count + + list_downstream_associations_method.assert_called_once_with( + entity_arn=PIPELINE_VERSION_CONTEXT.context_arn, + destination_type="FeatureGroupPipelineVersion", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + load_artifact_from_arn_method.assert_called_once_with( + artifact_arn=TRANSFORMATION_CODE_ARTIFACT_1.artifact_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + update_transformation_code_artifact_method.assert_not_called() + + assert pipeline_context.properties["LastUpdateTime"] == LAST_UPDATE_TIME + + update_pipeline_context_method.assert_not_called() + add_upstream_feature_group_data_associations_method.assert_not_called() + add_downstream_feature_group_data_associations_method.assert_not_called() + add_upstream_raw_data_associations_method.assert_not_called() + add_pipeline_and_pipeline_version_association_method.assert_not_called() + + add_upstream_transformation_code_associations_method.assert_called_once_with( + transformation_code_artifact=TRANSFORMATION_CODE_ARTIFACT_2, + pipeline_version_context_arn=PIPELINE_VERSION_CONTEXT.context_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + artifact_set_tags.assert_called_once_with(TAGS) + + +def test_create_lineage_when_already_exist_with_all_previous_transformation_code_as_none(): + pipeline_context = copy.deepcopy(PIPELINE_CONTEXT) + lineage_handler = FeatureProcessorLineageHandler( + pipeline_name=PIPELINE_NAME, + pipeline_arn=PIPELINE_ARN, + pipeline=PIPELINE, + inputs=RAW_DATA_INPUT + FEATURE_GROUP_DATA_SOURCE, + output=FEATURE_GROUP_DATA_SOURCE[0].name, + transformation_code=TRANSFORMATION_CODE_INPUT_2, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + with ( + patch.object( + FeatureGroupLineageEntityHandler, + "retrieve_feature_group_context_arns", + side_effect=[ + FEATURE_GROUP_INPUT[0], + FEATURE_GROUP_INPUT[1], + FEATURE_GROUP_INPUT[0], + ], + ) as retrieve_feature_group_context_arns_method, + patch.object( + 
S3LineageEntityHandler, + "retrieve_raw_data_artifact", + side_effect=[ + RAW_DATA_INPUT_ARTIFACTS[0], + RAW_DATA_INPUT_ARTIFACTS[1], + RAW_DATA_INPUT_ARTIFACTS[2], + RAW_DATA_INPUT_ARTIFACTS[3], + ], + ) as retrieve_raw_data_artifact_method, + patch.object( + S3LineageEntityHandler, + "create_transformation_code_artifact", + return_value=TRANSFORMATION_CODE_ARTIFACT_2, + ) as create_transformation_code_artifact_method, + patch.object( + PipelineLineageEntityHandler, + "load_pipeline_context", + return_value=pipeline_context, + ) as load_pipeline_context_method, + patch.object( + PipelineVersionLineageEntityHandler, + "load_pipeline_version_context", + return_value=PIPELINE_VERSION_CONTEXT, + ) as load_pipeline_version_context_method, + patch.object( + LineageAssociationHandler, + "list_upstream_associations", + side_effect=[ + generate_pipeline_version_upstream_feature_group_list(), + generate_pipeline_version_upstream_raw_data_list(), + iter([]), + ], + ) as list_upstream_associations_method, + patch.object( + LineageAssociationHandler, + "list_downstream_associations", + return_value=generate_pipeline_version_downstream_feature_group(), + ) as list_downstream_associations_method, + patch.object( + S3LineageEntityHandler, + "load_artifact_from_arn", + ) as load_artifact_from_arn_method, + patch.object( + S3LineageEntityHandler, + "update_transformation_code_artifact", + ) as update_transformation_code_artifact_method, + patch.object( + PipelineLineageEntityHandler, + "update_pipeline_context", + ) as update_pipeline_context_method, + patch.object( + PipelineVersionLineageEntityHandler, + "create_pipeline_version_context", + return_value=PIPELINE_VERSION_CONTEXT, + ), + patch.object( + LineageAssociationHandler, "add_upstream_feature_group_data_associations" + ) as add_upstream_feature_group_data_associations_method, + patch.object( + LineageAssociationHandler, "add_downstream_feature_group_data_associations" + ) as add_downstream_feature_group_data_associations_method, + patch.object( + LineageAssociationHandler, "add_upstream_raw_data_associations" + ) as add_upstream_raw_data_associations_method, + patch.object( + LineageAssociationHandler, "add_upstream_transformation_code_associations" + ) as add_upstream_transformation_code_associations_method, + patch.object( + LineageAssociationHandler, "add_pipeline_and_pipeline_version_association" + ) as add_pipeline_and_pipeline_version_association_method, + patch.object( + Artifact, + "set_tags", + return_value={ + "Tags": [dict(Key="key_1", Value="value_1"), dict(Key="key_2", Value="value_2")] + }, + ) as artifact_set_tags, + ): + lineage_handler.create_lineage(TAGS) + + retrieve_feature_group_context_arns_method.assert_has_calls( + [ + call( + feature_group_name=FEATURE_GROUP_DATA_SOURCE[0].name, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + call( + feature_group_name=FEATURE_GROUP_DATA_SOURCE[1].name, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + call( + feature_group_name=FEATURE_GROUP_DATA_SOURCE[0].name, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + ] + ) + assert 3 == retrieve_feature_group_context_arns_method.call_count + + retrieve_raw_data_artifact_method.assert_has_calls( + [ + call(raw_data=RAW_DATA_INPUT[0], sagemaker_session=SAGEMAKER_SESSION_MOCK), + call(raw_data=RAW_DATA_INPUT[1], sagemaker_session=SAGEMAKER_SESSION_MOCK), + call(raw_data=RAW_DATA_INPUT[2], sagemaker_session=SAGEMAKER_SESSION_MOCK), + call(raw_data=RAW_DATA_INPUT[3], sagemaker_session=SAGEMAKER_SESSION_MOCK), + ] + ) + assert 4 == 
retrieve_raw_data_artifact_method.call_count + + create_transformation_code_artifact_method.assert_called_once_with( + transformation_code=TRANSFORMATION_CODE_INPUT_2, + pipeline_last_update_time=PIPELINE["LastModifiedTime"].strftime("%s"), + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + load_pipeline_context_method.assert_has_calls( + [ + call( + pipeline_name=PIPELINE_NAME, + creation_time=PIPELINE["CreationTime"].strftime("%s"), + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + call( + pipeline_name=PIPELINE_NAME, + creation_time=PIPELINE["CreationTime"].strftime("%s"), + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + ] + ) + assert 2 == load_pipeline_context_method.call_count + load_pipeline_version_context_method.assert_called_once_with( + pipeline_name=PIPELINE_NAME, + last_update_time=LAST_UPDATE_TIME, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + list_upstream_associations_method.assert_has_calls( + [ + call( + entity_arn=PIPELINE_VERSION_CONTEXT.context_arn, + source_type="FeatureGroupPipelineVersion", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + call( + entity_arn=PIPELINE_VERSION_CONTEXT.context_arn, + source_type="DataSet", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + call( + entity_arn=PIPELINE_VERSION_CONTEXT.context_arn, + source_type="TransformationCode", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + ] + ) + assert 3 == list_upstream_associations_method.call_count + + list_downstream_associations_method.assert_called_once_with( + entity_arn=PIPELINE_VERSION_CONTEXT.context_arn, + destination_type="FeatureGroupPipelineVersion", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + load_artifact_from_arn_method.assert_not_called() + update_transformation_code_artifact_method.assert_not_called() + + assert pipeline_context.properties["LastUpdateTime"] == LAST_UPDATE_TIME + + update_pipeline_context_method.assert_not_called() + add_upstream_feature_group_data_associations_method.assert_not_called() + add_downstream_feature_group_data_associations_method.assert_not_called() + add_upstream_raw_data_associations_method.assert_not_called() + add_pipeline_and_pipeline_version_association_method.assert_not_called() + + add_upstream_transformation_code_associations_method.assert_called_once_with( + transformation_code_artifact=TRANSFORMATION_CODE_ARTIFACT_2, + pipeline_version_context_arn=PIPELINE_VERSION_CONTEXT.context_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + artifact_set_tags.assert_called_once_with(TAGS) + + +def test_create_lineage_when_already_exist_with_removed_transformation_code(): + transformation_code_1 = copy.deepcopy(TRANSFORMATION_CODE_ARTIFACT_1) + pipeline_context = copy.deepcopy(PIPELINE_CONTEXT) + lineage_handler = FeatureProcessorLineageHandler( + pipeline_name=PIPELINE_NAME, + pipeline_arn=PIPELINE_ARN, + pipeline=PIPELINE, + inputs=RAW_DATA_INPUT + FEATURE_GROUP_DATA_SOURCE, + output=FEATURE_GROUP_DATA_SOURCE[0].name, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + with ( + patch.object( + FeatureGroupLineageEntityHandler, + "retrieve_feature_group_context_arns", + side_effect=[ + FEATURE_GROUP_INPUT[0], + FEATURE_GROUP_INPUT[1], + FEATURE_GROUP_INPUT[0], + ], + ) as retrieve_feature_group_context_arns_method, + patch.object( + S3LineageEntityHandler, + "retrieve_raw_data_artifact", + side_effect=[ + RAW_DATA_INPUT_ARTIFACTS[0], + RAW_DATA_INPUT_ARTIFACTS[1], + RAW_DATA_INPUT_ARTIFACTS[2], + RAW_DATA_INPUT_ARTIFACTS[3], + ], + ) as retrieve_raw_data_artifact_method, + patch.object( + S3LineageEntityHandler, + 
"create_transformation_code_artifact", + return_value=None, + ) as create_transformation_code_artifact_method, + patch.object( + PipelineLineageEntityHandler, + "load_pipeline_context", + return_value=pipeline_context, + ) as load_pipeline_context_method, + patch.object( + PipelineVersionLineageEntityHandler, + "load_pipeline_version_context", + return_value=PIPELINE_VERSION_CONTEXT, + ) as load_pipeline_version_context_method, + patch.object( + LineageAssociationHandler, + "list_upstream_associations", + side_effect=[ + generate_pipeline_version_upstream_feature_group_list(), + generate_pipeline_version_upstream_raw_data_list(), + generate_pipeline_version_upstream_transformation_code(), + ], + ) as list_upstream_associations_method, + patch.object( + LineageAssociationHandler, + "list_downstream_associations", + return_value=generate_pipeline_version_downstream_feature_group(), + ) as list_downstream_associations_method, + patch.object( + S3LineageEntityHandler, + "load_artifact_from_arn", + return_value=transformation_code_1, + ) as load_artifact_from_arn_method, + patch.object( + S3LineageEntityHandler, + "update_transformation_code_artifact", + ) as update_transformation_code_artifact_method, + patch.object( + PipelineLineageEntityHandler, + "update_pipeline_context", + ) as update_pipeline_context_method, + patch.object( + PipelineVersionLineageEntityHandler, + "create_pipeline_version_context", + return_value=PIPELINE_VERSION_CONTEXT, + ), + patch.object( + LineageAssociationHandler, "add_upstream_feature_group_data_associations" + ) as add_upstream_feature_group_data_associations_method, + patch.object( + LineageAssociationHandler, "add_downstream_feature_group_data_associations" + ) as add_downstream_feature_group_data_associations_method, + patch.object( + LineageAssociationHandler, "add_upstream_raw_data_associations" + ) as add_upstream_raw_data_associations_method, + patch.object( + LineageAssociationHandler, "add_upstream_transformation_code_associations" + ) as add_upstream_transformation_code_associations_method, + patch.object( + LineageAssociationHandler, "add_pipeline_and_pipeline_version_association" + ) as add_pipeline_and_pipeline_version_association_method, + patch.object( + Artifact, + "set_tags", + return_value={ + "Tags": [dict(Key="key_1", Value="value_1"), dict(Key="key_2", Value="value_2")] + }, + ) as artifact_set_tags, + ): + lineage_handler.create_lineage(TAGS) + + retrieve_feature_group_context_arns_method.assert_has_calls( + [ + call( + feature_group_name=FEATURE_GROUP_DATA_SOURCE[0].name, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + call( + feature_group_name=FEATURE_GROUP_DATA_SOURCE[1].name, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + call( + feature_group_name=FEATURE_GROUP_DATA_SOURCE[0].name, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + ] + ) + assert 3 == retrieve_feature_group_context_arns_method.call_count + + retrieve_raw_data_artifact_method.assert_has_calls( + [ + call(raw_data=RAW_DATA_INPUT[0], sagemaker_session=SAGEMAKER_SESSION_MOCK), + call(raw_data=RAW_DATA_INPUT[1], sagemaker_session=SAGEMAKER_SESSION_MOCK), + call(raw_data=RAW_DATA_INPUT[2], sagemaker_session=SAGEMAKER_SESSION_MOCK), + call(raw_data=RAW_DATA_INPUT[3], sagemaker_session=SAGEMAKER_SESSION_MOCK), + ] + ) + assert 4 == retrieve_raw_data_artifact_method.call_count + + create_transformation_code_artifact_method.assert_called_once_with( + transformation_code=None, + pipeline_last_update_time=PIPELINE["LastModifiedTime"].strftime("%s"), + 
sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + load_pipeline_context_method.assert_has_calls( + [ + call( + pipeline_name=PIPELINE_NAME, + creation_time=PIPELINE["CreationTime"].strftime("%s"), + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + call( + pipeline_name=PIPELINE_NAME, + creation_time=PIPELINE["CreationTime"].strftime("%s"), + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + ] + ) + assert 2 == load_pipeline_context_method.call_count + load_pipeline_version_context_method.assert_called_once_with( + pipeline_name=PIPELINE_NAME, + last_update_time=LAST_UPDATE_TIME, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + list_upstream_associations_method.assert_has_calls( + [ + call( + entity_arn=PIPELINE_VERSION_CONTEXT.context_arn, + source_type="FeatureGroupPipelineVersion", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + call( + entity_arn=PIPELINE_VERSION_CONTEXT.context_arn, + source_type="DataSet", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + call( + entity_arn=PIPELINE_VERSION_CONTEXT.context_arn, + source_type="TransformationCode", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + ] + ) + assert 3 == list_upstream_associations_method.call_count + + list_downstream_associations_method.assert_called_once_with( + entity_arn=PIPELINE_VERSION_CONTEXT.context_arn, + destination_type="FeatureGroupPipelineVersion", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + load_artifact_from_arn_method.assert_called_once_with( + artifact_arn=TRANSFORMATION_CODE_ARTIFACT_1.artifact_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + transformation_code_2 = copy.deepcopy(TRANSFORMATION_CODE_ARTIFACT_1) + transformation_code_2.properties["state"] = TRANSFORMATION_CODE_STATUS_INACTIVE + transformation_code_2.properties["exclusive_end_date"] = PIPELINE["LastModifiedTime"].strftime( + "%s" + ) + update_transformation_code_artifact_method.assert_called_once_with( + transformation_code_artifact=transformation_code_2 + ) + + update_pipeline_context_method.assert_not_called() + add_upstream_feature_group_data_associations_method.assert_not_called() + add_downstream_feature_group_data_associations_method.assert_not_called() + add_upstream_raw_data_associations_method.assert_not_called() + add_upstream_transformation_code_associations_method.assert_not_called() + add_pipeline_and_pipeline_version_association_method.assert_not_called() + + artifact_set_tags.assert_not_called() + + +def test_get_pipeline_lineage_names_when_no_lineage_exists(): + lineage_handler = FeatureProcessorLineageHandler( + pipeline_name=PIPELINE_NAME, + pipeline_arn=PIPELINE_ARN, + pipeline=PIPELINE, + inputs=FEATURE_GROUP_DATA_SOURCE, + output=FEATURE_GROUP_DATA_SOURCE[0].name, + transformation_code=TRANSFORMATION_CODE_INPUT_1, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + with patch.object( + PipelineLineageEntityHandler, + "load_pipeline_context", + side_effect=RESOURCE_NOT_FOUND_EXCEPTION, + ) as load_pipeline_context_method: + return_value = lineage_handler.get_pipeline_lineage_names() + + assert return_value is None + + load_pipeline_context_method.assert_called_once_with( + pipeline_name=PIPELINE_NAME, + creation_time=PIPELINE["CreationTime"].strftime("%s"), + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + +def test_get_pipeline_lineage_names_when_lineage_exists(): + lineage_handler = FeatureProcessorLineageHandler( + pipeline_name=PIPELINE_NAME, + pipeline_arn=PIPELINE_ARN, + pipeline=PIPELINE, + inputs=FEATURE_GROUP_DATA_SOURCE, + output=FEATURE_GROUP_DATA_SOURCE[0].name, + 
transformation_code=TRANSFORMATION_CODE_INPUT_1, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + with ( + patch.object( + PipelineLineageEntityHandler, + "load_pipeline_context", + return_value=PIPELINE_CONTEXT, + ) as load_pipeline_context_method, + patch.object( + PipelineVersionLineageEntityHandler, + "load_pipeline_version_context", + return_value=PIPELINE_VERSION_CONTEXT, + ) as load_pipeline_version_context_method, + ): + return_value = lineage_handler.get_pipeline_lineage_names() + + assert return_value == dict( + pipeline_context_name=PIPELINE_CONTEXT.context_name, + pipeline_version_context_name=PIPELINE_VERSION_CONTEXT.context_name, + ) + + load_pipeline_context_method.assert_has_calls( + [ + call( + pipeline_name=PIPELINE_NAME, + creation_time=PIPELINE["CreationTime"].strftime("%s"), + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + call( + pipeline_name=PIPELINE_NAME, + creation_time=PIPELINE["CreationTime"].strftime("%s"), + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + ] + ) + assert 2 == load_pipeline_context_method.call_count + + load_pipeline_version_context_method.assert_called_once_with( + pipeline_name=PIPELINE_NAME, + last_update_time=PIPELINE_CONTEXT.properties["LastUpdateTime"], + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + +def test_create_schedule_lineage(): + lineage_handler = FeatureProcessorLineageHandler( + pipeline_name=PIPELINE_NAME, + pipeline_arn=PIPELINE_ARN, + pipeline=PIPELINE, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + with ( + patch.object( + PipelineLineageEntityHandler, + "load_pipeline_context", + return_value=PIPELINE_CONTEXT, + ) as load_pipeline_context_method, + patch.object( + PipelineVersionLineageEntityHandler, + "load_pipeline_version_context", + return_value=PIPELINE_VERSION_CONTEXT, + ) as load_pipeline_version_context_method, + patch.object( + S3LineageEntityHandler, + "retrieve_pipeline_schedule_artifact", + return_value=SCHEDULE_ARTIFACT_RESULT, + ) as retrieve_pipeline_schedule_artifact_method, + patch.object( + LineageAssociationHandler, + "add_upstream_schedule_associations", + ) as add_upstream_schedule_associations_method, + patch.object( + Artifact, + "set_tags", + return_value={ + "Tags": [dict(Key="key_1", Value="value_1"), dict(Key="key_2", Value="value_2")] + }, + ) as artifact_set_tags, + ): + lineage_handler.create_schedule_lineage( + pipeline_name=PIPELINE_NAME, + schedule_arn=SCHEDULE_ARN, + schedule_expression=SCHEDULE_EXPRESSION, + state=STATE, + start_date=START_DATE, + tags=TAGS, + ) + + load_pipeline_context_method.assert_called_once_with( + pipeline_name=PIPELINE_NAME, + creation_time=PIPELINE["CreationTime"].strftime("%s"), + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + load_pipeline_version_context_method.assert_called_once_with( + pipeline_name=PIPELINE_NAME, + last_update_time=PIPELINE_CONTEXT.properties["LastUpdateTime"], + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + retrieve_pipeline_schedule_artifact_method.assert_called_once_with( + pipeline_schedule=PipelineSchedule( + schedule_name=PIPELINE_NAME, + schedule_arn=SCHEDULE_ARN, + schedule_expression=SCHEDULE_EXPRESSION, + pipeline_name=PIPELINE_NAME, + state=STATE, + start_date=START_DATE.strftime("%s"), + ), + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + add_upstream_schedule_associations_method.assert_called_once_with( + schedule_artifact=SCHEDULE_ARTIFACT_RESULT, + pipeline_version_context_arn=PIPELINE_VERSION_CONTEXT.context_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + 
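# The schedule artifact is tagged with the caller-supplied tags.
+    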
artifact_set_tags.assert_called_once_with(TAGS) + + +def test_create_trigger_lineage(): + lineage_handler = FeatureProcessorLineageHandler( + pipeline_name=PIPELINE_NAME, + pipeline_arn=PIPELINE_ARN, + pipeline=PIPELINE, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + with ( + patch.object( + PipelineLineageEntityHandler, + "load_pipeline_context", + return_value=PIPELINE_CONTEXT, + ) as load_pipeline_context_method, + patch.object( + PipelineVersionLineageEntityHandler, + "load_pipeline_version_context", + return_value=PIPELINE_VERSION_CONTEXT, + ) as load_pipeline_version_context_method, + patch.object( + S3LineageEntityHandler, + "retrieve_pipeline_trigger_artifact", + return_value=PIPELINE_TRIGGER_ARTIFACT, + ) as retrieve_pipeline_trigger_artifact_method, + patch.object( + LineageAssociationHandler, + "_add_association", + ) as add_association_method, + patch.object( + Artifact, + "set_tags", + return_value={ + "Tags": [dict(Key="key_1", Value="value_1"), dict(Key="key_2", Value="value_2")] + }, + ) as artifact_set_tags, + ): + lineage_handler.create_trigger_lineage( + pipeline_name=PIPELINE_NAME, + trigger_arn=TRIGGER_ARN, + event_pattern=EVENT_PATTERN, + state=STATE, + tags=TAGS, + ) + + load_pipeline_context_method.assert_called_once_with( + pipeline_name=PIPELINE_NAME, + creation_time=PIPELINE["CreationTime"].strftime("%s"), + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + load_pipeline_version_context_method.assert_called_once_with( + pipeline_name=PIPELINE_NAME, + last_update_time=PIPELINE_CONTEXT.properties["LastUpdateTime"], + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + retrieve_pipeline_trigger_artifact_method.assert_called_once_with( + pipeline_trigger=PipelineTrigger( + trigger_name=PIPELINE_NAME, + trigger_arn=TRIGGER_ARN, + pipeline_name=PIPELINE_NAME, + event_pattern=EVENT_PATTERN, + state=STATE, + ), + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + add_association_method.assert_called_once_with( + source_arn=PIPELINE_TRIGGER_ARTIFACT.artifact_arn, + destination_arn=PIPELINE_VERSION_CONTEXT.context_arn, + association_type="ContributedTo", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + artifact_set_tags.assert_called_once_with(TAGS) + + +def test_upsert_tags_for_lineage_resources(): + pipeline_context = copy.deepcopy(PIPELINE_CONTEXT) + mock_session = Mock(Session) + lineage_handler = FeatureProcessorLineageHandler( + pipeline_name=PIPELINE_NAME, + pipeline_arn=PIPELINE_ARN, + pipeline=PIPELINE, + inputs=RAW_DATA_INPUT + FEATURE_GROUP_DATA_SOURCE, + output=FEATURE_GROUP_DATA_SOURCE[0].name, + transformation_code=TRANSFORMATION_CODE_INPUT_2, + sagemaker_session=mock_session, + ) + lineage_handler.sagemaker_session.boto_session = Mock() + lineage_handler.sagemaker_session.sagemaker_client = Mock() + with ( + patch.object( + S3LineageEntityHandler, + "retrieve_raw_data_artifact", + side_effect=[ + RAW_DATA_INPUT_ARTIFACTS[0], + RAW_DATA_INPUT_ARTIFACTS[1], + RAW_DATA_INPUT_ARTIFACTS[2], + RAW_DATA_INPUT_ARTIFACTS[3], + ], + ) as retrieve_raw_data_artifact_method, + patch.object( + PipelineLineageEntityHandler, + "load_pipeline_context", + return_value=pipeline_context, + ) as load_pipeline_context_method, + patch.object( + PipelineVersionLineageEntityHandler, + "load_pipeline_version_context", + return_value=PIPELINE_VERSION_CONTEXT, + ) as load_pipeline_version_context_method, + patch.object( + LineageAssociationHandler, + "list_upstream_associations", + side_effect=[ + generate_pipeline_version_upstream_feature_group_list(), + 
generate_pipeline_version_upstream_raw_data_list(),
+                iter([]),
+            ],
+        ) as list_upstream_associations_method,
+        patch.object(
+            LineageAssociationHandler,
+            "list_downstream_associations",
+            return_value=generate_pipeline_version_downstream_feature_group(),
+        ) as list_downstream_associations_method,
+        patch.object(
+            S3LineageEntityHandler, "load_artifact_from_arn", return_value=ARTIFACT_RESULT
+        ) as load_artifact_from_arn_method,
+        patch.object(
+            S3LineageEntityHandler, "_load_artifact_from_s3_uri", return_value=ARTIFACT_SUMMARY
+        ) as load_artifact_from_s3_uri_method,
+        patch.object(
+            Artifact,
+            "set_tags",
+            return_value={
+                "Tags": [dict(Key="key_1", Value="value_1"), dict(Key="key_2", Value="value_2")]
+            },
+        ) as artifact_set_tags,
+        patch.object(
+            Context,
+            "set_tags",
+            return_value={
+                "Tags": [dict(Key="key_1", Value="value_1"), dict(Key="key_2", Value="value_2")]
+            },
+        ) as context_set_tags,
+        patch.object(
+            EventBridgeSchedulerHelper, "describe_schedule", return_value=dict(Arn="schedule_arn")
+        ) as get_event_bridge_schedule,
+        patch.object(
+            EventBridgeRuleHelper, "describe_rule", return_value=dict(Arn="rule_arn")
+        ) as get_event_bridge_rule,
+    ):
+        lineage_handler.upsert_tags_for_lineage_resources(TAGS)
+
+    retrieve_raw_data_artifact_method.assert_has_calls(
+        [
+            call(raw_data=RAW_DATA_INPUT[0], sagemaker_session=mock_session),
+            call(raw_data=RAW_DATA_INPUT[1], sagemaker_session=mock_session),
+            call(raw_data=RAW_DATA_INPUT[2], sagemaker_session=mock_session),
+            call(raw_data=RAW_DATA_INPUT[3], sagemaker_session=mock_session),
+        ]
+    )
+
+    load_pipeline_context_method.assert_called_once_with(
+        pipeline_name=PIPELINE_NAME,
+        creation_time=PIPELINE["CreationTime"].strftime("%s"),
+        sagemaker_session=mock_session,
+    )
+
+    load_pipeline_version_context_method.assert_called_once_with(
+        pipeline_name=PIPELINE_NAME,
+        last_update_time=LAST_UPDATE_TIME,
+        sagemaker_session=mock_session,
+    )
+
+    list_upstream_associations_method.assert_not_called()
+    list_downstream_associations_method.assert_not_called()
+    load_artifact_from_s3_uri_method.assert_has_calls(
+        [
+            call(s3_uri="schedule_arn", sagemaker_session=mock_session),
+            call(s3_uri="rule_arn", sagemaker_session=mock_session),
+        ]
+    )
+    get_event_bridge_schedule.assert_called_once_with(PIPELINE_NAME)
+    get_event_bridge_rule.assert_called_once_with(PIPELINE_NAME)
+    load_artifact_from_arn_method.assert_called_with(
+        artifact_arn=ARTIFACT_SUMMARY.artifact_arn, sagemaker_session=mock_session
+    )
+
+    # Three raw data artifacts, one schedule artifact, and one trigger artifact
+    # are tagged.
+    artifact_set_tags.assert_has_calls(
+        [
+            call(TAGS),
+            call(TAGS),
+            call(TAGS),
+            call(TAGS),
+            call(TAGS),
+        ]
+    )
+    # The pipeline context and the current pipeline version context are tagged.
+    context_set_tags.assert_has_calls(
+        [
+            call(TAGS),
+            call(TAGS),
+        ]
+    )
+
+
+def generate_pipeline_version_upstream_feature_group_list() -> Iterator[AssociationSummary]:
+    pipeline_version_upstream_fg: List[AssociationSummary] = list()
+    for feature_group in FEATURE_GROUP_INPUT:
+        pipeline_version_upstream_fg.append(
+            AssociationSummary(
+                source_arn=feature_group.pipeline_version_context_arn,
+                source_name=f"{feature_group.name}-pipeline-version",
+                destination_arn=PIPELINE_VERSION_CONTEXT.context_arn,
+                destination_name=PIPELINE_VERSION_CONTEXT.context_name,
+                association_type="ContributedTo",
+            )
+        )
+    return iter(pipeline_version_upstream_fg)
+
+
+def generate_pipeline_version_upstream_raw_data_list() -> Iterator[AssociationSummary]:
+    pipeline_version_upstream_fg: 
List[AssociationSummary] = list() + for raw_data in RAW_DATA_INPUT_ARTIFACTS: + pipeline_version_upstream_fg.append( + AssociationSummary( + source_arn=raw_data.artifact_arn, + source_name="sm-fs-fe-raw-data", + destination_arn=PIPELINE_VERSION_CONTEXT.context_arn, + destination_name=PIPELINE_VERSION_CONTEXT.context_name, + association_type="ContributedTo", + ) + ) + return iter(pipeline_version_upstream_fg) + + +def generate_pipeline_version_upstream_transformation_code() -> Iterator[AssociationSummary]: + return iter( + [ + AssociationSummary( + source_arn=TRANSFORMATION_CODE_ARTIFACT_1.artifact_arn, + source_name=TRANSFORMATION_CODE_ARTIFACT_1.artifact_name, + destination_arn=PIPELINE_VERSION_CONTEXT.context_arn, + destination_name=PIPELINE_VERSION_CONTEXT.context_name, + association_type="ContributedTo", + ) + ] + ) + + +def generate_pipeline_version_downstream_feature_group() -> Iterator[AssociationSummary]: + return iter( + [ + AssociationSummary( + source_arn=PIPELINE_VERSION_CONTEXT.context_arn, + source_name=PIPELINE_VERSION_CONTEXT.context_name, + destination_arn=FEATURE_GROUP_INPUT[0].pipeline_version_context_arn, + destination_name=f"{FEATURE_GROUP_INPUT[0].name}-pipeline-version", + association_type="ContributedTo", + ) + ] + ) diff --git a/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/lineage/test_lineage_association_handler.py b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/lineage/test_lineage_association_handler.py new file mode 100644 index 0000000000..a3d24dd0b5 --- /dev/null +++ b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/lineage/test_lineage_association_handler.py @@ -0,0 +1,224 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
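+# Unit tests for LineageAssociationHandler, which creates and lists lineage
+# associations between feature groups, raw data, transformation code, and
+# pipeline contexts.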
+from __future__ import absolute_import +from mock import patch, call +import pytest + +from sagemaker.mlops.feature_store.feature_processor.lineage._lineage_association_handler import ( + LineageAssociationHandler, +) +from sagemaker.core.lineage.association import Association +from botocore.exceptions import ClientError + +from test_constants import ( + FEATURE_GROUP_INPUT, + RAW_DATA_INPUT_ARTIFACTS, + VALIDATION_EXCEPTION, + NON_VALIDATION_EXCEPTION, + SAGEMAKER_SESSION_MOCK, + TRANSFORMATION_CODE_ARTIFACT_1, +) + + +def test_add_upstream_feature_group_data_associations(): + with patch.object(Association, "create") as create_association_method: + LineageAssociationHandler.add_upstream_feature_group_data_associations( + feature_group_inputs=FEATURE_GROUP_INPUT, + pipeline_context_arn="pipeline-context-arn", + pipeline_version_context_arn="pipeline-version-context-arn", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + for feature_group in FEATURE_GROUP_INPUT: + create_association_method.assert_has_calls( + [ + call( + source_arn=feature_group.pipeline_context_arn, + destination_arn="pipeline-context-arn", + association_type="ContributedTo", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + call( + source_arn=feature_group.pipeline_version_context_arn, + destination_arn="pipeline-version-context-arn", + association_type="ContributedTo", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + ] + ) + assert len(FEATURE_GROUP_INPUT) * 2 == create_association_method.call_count + + +def test_add_upstream_feature_group_data_associations_when_association_already_exists(): + with patch.object( + Association, "create", side_effect=VALIDATION_EXCEPTION + ) as create_association_method: + LineageAssociationHandler.add_upstream_feature_group_data_associations( + feature_group_inputs=FEATURE_GROUP_INPUT, + pipeline_context_arn="pipeline-context-arn", + pipeline_version_context_arn="pipeline-version-context-arn", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + for feature_group in FEATURE_GROUP_INPUT: + create_association_method.assert_has_calls( + [ + call( + source_arn=feature_group.pipeline_context_arn, + destination_arn="pipeline-context-arn", + association_type="ContributedTo", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + call( + source_arn=feature_group.pipeline_version_context_arn, + destination_arn="pipeline-version-context-arn", + association_type="ContributedTo", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + ] + ) + assert len(FEATURE_GROUP_INPUT) * 2 == create_association_method.call_count + + +def test_add_upstream_feature_group_data_associations_when_non_validation_exception(): + with patch.object(Association, "create", side_effect=NON_VALIDATION_EXCEPTION): + with pytest.raises(ClientError): + LineageAssociationHandler.add_upstream_feature_group_data_associations( + feature_group_inputs=FEATURE_GROUP_INPUT, + pipeline_context_arn="pipeline-context-arn", + pipeline_version_context_arn="pipeline-version-context-arn", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + +def test_add_upstream_raw_data_associations(): + with patch.object(Association, "create") as create_association_method: + LineageAssociationHandler.add_upstream_raw_data_associations( + raw_data_inputs=RAW_DATA_INPUT_ARTIFACTS, + pipeline_context_arn="pipeline-context-arn", + pipeline_version_context_arn="pipeline-version-context-arn", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + for raw_data in RAW_DATA_INPUT_ARTIFACTS: + create_association_method.assert_has_calls( + [ + call( + 
source_arn=raw_data.artifact_arn, + destination_arn="pipeline-context-arn", + association_type="ContributedTo", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + call( + source_arn=raw_data.artifact_arn, + destination_arn="pipeline-version-context-arn", + association_type="ContributedTo", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + ] + ) + assert len(RAW_DATA_INPUT_ARTIFACTS) * 2 == create_association_method.call_count + + +def test_add_upstream_transformation_code_associations(): + with patch.object(Association, "create") as create_association_method: + LineageAssociationHandler.add_upstream_transformation_code_associations( + transformation_code_artifact=TRANSFORMATION_CODE_ARTIFACT_1, + pipeline_version_context_arn="pipeline-version-context-arn", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + create_association_method.assert_called_once_with( + source_arn=TRANSFORMATION_CODE_ARTIFACT_1.artifact_arn, + destination_arn="pipeline-version-context-arn", + association_type="ContributedTo", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + +def test_add_downstream_feature_group_data_associations(): + with patch.object(Association, "create") as create_association_method: + LineageAssociationHandler.add_downstream_feature_group_data_associations( + feature_group_output=FEATURE_GROUP_INPUT[0], + pipeline_context_arn="pipeline-context-arn", + pipeline_version_context_arn="pipeline-version-context-arn", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + create_association_method.assert_has_calls( + [ + call( + source_arn="pipeline-context-arn", + destination_arn=FEATURE_GROUP_INPUT[0].pipeline_context_arn, + association_type="ContributedTo", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + call( + source_arn="pipeline-version-context-arn", + destination_arn=FEATURE_GROUP_INPUT[0].pipeline_version_context_arn, + association_type="ContributedTo", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + ] + ) + assert 2 == create_association_method.call_count + + +def test_add_pipeline_and_pipeline_version_association(): + with patch.object(Association, "create") as create_association_method: + LineageAssociationHandler.add_pipeline_and_pipeline_version_association( + pipeline_context_arn="pipeline-context-arn", + pipeline_version_context_arn="pipeline-version-context-arn", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + create_association_method.assert_called_once_with( + source_arn="pipeline-context-arn", + destination_arn="pipeline-version-context-arn", + association_type="AssociatedWith", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + +def test_list_upstream_associations(): + with patch.object(Association, "list") as list_association_method: + LineageAssociationHandler.list_upstream_associations( + entity_arn="pipeline-context-arn", + source_type="FeatureEngineeringPipelineVersion", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + list_association_method.assert_called_once_with( + source_arn=None, + source_type="FeatureEngineeringPipelineVersion", + destination_arn="pipeline-context-arn", + destination_type=None, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + +def test_list_downstream_associations(): + with patch.object(Association, "list") as list_association_method: + LineageAssociationHandler.list_downstream_associations( + entity_arn="pipeline-context-arn", + destination_type="FeatureEngineeringPipelineVersion", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + list_association_method.assert_called_once_with( + source_arn="pipeline-context-arn", + 
source_type=None, + destination_arn=None, + destination_type="FeatureEngineeringPipelineVersion", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) diff --git a/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/lineage/test_pipeline_lineage_entity_handler.py b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/lineage/test_pipeline_lineage_entity_handler.py new file mode 100644 index 0000000000..deb76a7748 --- /dev/null +++ b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/lineage/test_pipeline_lineage_entity_handler.py @@ -0,0 +1,74 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +from mock import patch + +from sagemaker.mlops.feature_store.feature_processor.lineage._pipeline_lineage_entity_handler import ( + PipelineLineageEntityHandler, +) +from sagemaker.core.lineage.context import Context +from test_constants import ( + PIPELINE_NAME, + PIPELINE_ARN, + CREATION_TIME, + LAST_UPDATE_TIME, + SAGEMAKER_SESSION_MOCK, + CONTEXT_MOCK_01, +) + + +def test_create_pipeline_context(): + with patch.object(Context, "create", return_value=CONTEXT_MOCK_01) as create_method: + result = PipelineLineageEntityHandler.create_pipeline_context( + pipeline_name=PIPELINE_NAME, + pipeline_arn=PIPELINE_ARN, + creation_time=CREATION_TIME, + last_update_time=LAST_UPDATE_TIME, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + assert result == CONTEXT_MOCK_01 + create_method.assert_called_with( + context_name=f"sm-fs-fe-{PIPELINE_NAME}-{CREATION_TIME}-fep", + context_type="FeatureEngineeringPipeline", + source_uri=PIPELINE_ARN, + source_type=CREATION_TIME, + properties={ + "PipelineName": PIPELINE_NAME, + "PipelineCreationTime": CREATION_TIME, + "LastUpdateTime": LAST_UPDATE_TIME, + }, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + +def test_load_pipeline_context(): + with patch.object(Context, "load", return_value=CONTEXT_MOCK_01) as load_method: + result = PipelineLineageEntityHandler.load_pipeline_context( + pipeline_name=PIPELINE_NAME, + creation_time=CREATION_TIME, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + assert result == CONTEXT_MOCK_01 + load_method.assert_called_once_with( + context_name=f"sm-fs-fe-{PIPELINE_NAME}-{CREATION_TIME}-fep", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + +def test_update_pipeline_context(): + with patch.object(Context, "save", return_value=CONTEXT_MOCK_01): + PipelineLineageEntityHandler.update_pipeline_context(pipeline_context=CONTEXT_MOCK_01) + CONTEXT_MOCK_01.save.assert_called_once() diff --git a/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/lineage/test_pipeline_trigger.py b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/lineage/test_pipeline_trigger.py new file mode 100644 index 0000000000..c936c3c164 --- /dev/null +++ b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/lineage/test_pipeline_trigger.py @@ -0,0 +1,33 @@ +# Copyright Amazon.com, 
Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+from __future__ import absolute_import

+from sagemaker.mlops.feature_store.feature_processor.lineage._pipeline_trigger import PipelineTrigger


+def test_pipeline_trigger():
+
+    trigger = PipelineTrigger(
+        trigger_name="test_trigger",
+        trigger_arn="test_arn",
+        event_pattern="test_pattern",
+        pipeline_name="test_pipeline",
+        state="Enabled",
+    )
+
+    assert trigger.trigger_name == "test_trigger"
+    assert trigger.trigger_arn == "test_arn"
+    assert trigger.event_pattern == "test_pattern"
+    assert trigger.pipeline_name == "test_pipeline"
+    assert trigger.state == "Enabled"
diff --git a/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/lineage/test_pipeline_version_lineage_entity_handler.py b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/lineage/test_pipeline_version_lineage_entity_handler.py
new file mode 100644
index 0000000000..52e65749be
--- /dev/null
+++ b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/lineage/test_pipeline_version_lineage_entity_handler.py
@@ -0,0 +1,67 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
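+# Unit tests for PipelineVersionLineageEntityHandler's pipeline version
+# context create and load helpers.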
+from __future__ import absolute_import + +from mock import patch + +from sagemaker.mlops.feature_store.feature_processor.lineage._pipeline_version_lineage_entity_handler import ( + PipelineVersionLineageEntityHandler, +) + +from sagemaker.core.lineage.context import Context + +from test_constants import ( + PIPELINE_NAME, + PIPELINE_ARN, + LAST_UPDATE_TIME, + SAGEMAKER_SESSION_MOCK, + CONTEXT_MOCK_01, +) + + +def test_create_pipeline_version_context(): + with patch.object(Context, "create", return_value=CONTEXT_MOCK_01) as create_method: + result = PipelineVersionLineageEntityHandler.create_pipeline_version_context( + pipeline_name=PIPELINE_NAME, + pipeline_arn=PIPELINE_ARN, + last_update_time=LAST_UPDATE_TIME, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + assert result == CONTEXT_MOCK_01 + create_method.assert_called_with( + context_name=f"sm-fs-fe-{PIPELINE_NAME}-{LAST_UPDATE_TIME}-fep-ver", + context_type=f"FeatureEngineeringPipelineVersion-{PIPELINE_NAME}", + source_uri=PIPELINE_ARN, + source_type=LAST_UPDATE_TIME, + properties={ + "PipelineName": PIPELINE_NAME, + "LastUpdateTime": LAST_UPDATE_TIME, + }, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + +def test_load_pipeline_version_context(): + with patch.object(Context, "load", return_value=CONTEXT_MOCK_01) as load_method: + result = PipelineVersionLineageEntityHandler.load_pipeline_version_context( + pipeline_name=PIPELINE_NAME, + last_update_time=LAST_UPDATE_TIME, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + assert result == CONTEXT_MOCK_01 + load_method.assert_called_once_with( + context_name=f"sm-fs-fe-{PIPELINE_NAME}-{LAST_UPDATE_TIME}-fep-ver", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) diff --git a/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/lineage/test_s3_lineage_entity_handler.py b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/lineage/test_s3_lineage_entity_handler.py new file mode 100644 index 0000000000..8b34806f1d --- /dev/null +++ b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/lineage/test_s3_lineage_entity_handler.py @@ -0,0 +1,434 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
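+# Unit tests for S3LineageEntityHandler, covering raw data, transformation
+# code, schedule, and trigger artifacts.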
+from __future__ import absolute_import + +import copy + +from mock import patch +from test_constants import ( + ARTIFACT_RESULT, + ARTIFACT_SUMMARY, + PIPELINE_SCHEDULE, + PIPELINE_SCHEDULE_2, + PIPELINE_TRIGGER, + PIPELINE_TRIGGER_2, + PIPELINE_TRIGGER_ARTIFACT, + PIPELINE_TRIGGER_ARTIFACT_SUMMARY, + SCHEDULE_ARTIFACT_RESULT, + TRANSFORMATION_CODE_ARTIFACT_1, + TRANSFORMATION_CODE_INPUT_1, + LAST_UPDATE_TIME, + MockDataSource, +) +from test_pipeline_lineage_entity_handler import SAGEMAKER_SESSION_MOCK + +from sagemaker.mlops.feature_store.feature_processor import CSVDataSource +from sagemaker.mlops.feature_store.feature_processor.lineage._s3_lineage_entity_handler import ( + S3LineageEntityHandler, +) +from sagemaker.mlops.feature_store.feature_processor.lineage._transformation_code import ( + TransformationCode, +) +from sagemaker.core.lineage.artifact import Artifact + +raw_data = CSVDataSource( + s3_uri="s3://sagemaker-us-west-2-789975069016/transform-2023-04-28-21-50-14-616/" + "transform-2023-04-28-21-50-14-616/output/model.tar.gz" +) + + +def test_retrieve_raw_data_artifact_when_artifact_already_exist(): + with patch.object(Artifact, "list", return_value=[ARTIFACT_SUMMARY]) as artifact_list_method: + with patch.object(Artifact, "load", return_value=ARTIFACT_RESULT) as artifact_load_method: + with patch.object( + Artifact, "create", return_value=ARTIFACT_RESULT + ) as artifact_create_method: + result = S3LineageEntityHandler.retrieve_raw_data_artifact( + raw_data=raw_data, sagemaker_session=SAGEMAKER_SESSION_MOCK + ) + + assert result == ARTIFACT_RESULT + + artifact_list_method.assert_called_once_with( + source_uri=raw_data.s3_uri, sagemaker_session=SAGEMAKER_SESSION_MOCK + ) + + artifact_load_method.assert_called_once_with( + artifact_arn=ARTIFACT_SUMMARY.artifact_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + artifact_create_method.assert_not_called() + + +def test_retrieve_raw_data_artifact_when_artifact_does_not_exist(): + with patch.object(Artifact, "list", return_value=[]) as artifact_list_method: + with patch.object(Artifact, "load", return_value=ARTIFACT_RESULT) as artifact_load_method: + with patch.object( + Artifact, "create", return_value=ARTIFACT_RESULT + ) as artifact_create_method: + result = S3LineageEntityHandler.retrieve_raw_data_artifact( + raw_data=raw_data, sagemaker_session=SAGEMAKER_SESSION_MOCK + ) + + assert result == ARTIFACT_RESULT + + artifact_list_method.assert_called_once_with( + source_uri=raw_data.s3_uri, sagemaker_session=SAGEMAKER_SESSION_MOCK + ) + + artifact_load_method.assert_not_called() + + artifact_create_method.assert_called_once_with( + source_uri=raw_data.s3_uri, + artifact_type="DataSet", + artifact_name="sm-fs-fe-raw-data", + properties=None, + source_types=None, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + +def test_retrieve_user_defined_raw_data_artifact_when_artifact_already_exist(): + data_source = MockDataSource() + with patch.object(Artifact, "list", return_value=[ARTIFACT_SUMMARY]) as artifact_list_method: + with patch.object(Artifact, "load", return_value=ARTIFACT_RESULT) as artifact_load_method: + with patch.object( + Artifact, "create", return_value=ARTIFACT_RESULT + ) as artifact_create_method: + result = S3LineageEntityHandler.retrieve_raw_data_artifact( + raw_data=data_source, sagemaker_session=SAGEMAKER_SESSION_MOCK + ) + + assert result == ARTIFACT_RESULT + + artifact_list_method.assert_called_once_with( + source_uri=data_source.data_source_unique_id, sagemaker_session=SAGEMAKER_SESSION_MOCK + ) + + 
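# An existing artifact is loaded by its ARN rather than re-created.
+            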
artifact_load_method.assert_called_once_with( + artifact_arn=ARTIFACT_SUMMARY.artifact_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + artifact_create_method.assert_not_called() + + +def test_retrieve_user_defined_raw_data_artifact_when_artifact_does_not_exist(): + data_source = MockDataSource() + with patch.object(Artifact, "list", return_value=[]) as artifact_list_method: + with patch.object(Artifact, "load", return_value=ARTIFACT_RESULT) as artifact_load_method: + with patch.object( + Artifact, "create", return_value=ARTIFACT_RESULT + ) as artifact_create_method: + result = S3LineageEntityHandler.retrieve_raw_data_artifact( + raw_data=data_source, sagemaker_session=SAGEMAKER_SESSION_MOCK + ) + + assert result == ARTIFACT_RESULT + + artifact_list_method.assert_called_once_with( + source_uri=data_source.data_source_unique_id, sagemaker_session=SAGEMAKER_SESSION_MOCK + ) + + artifact_load_method.assert_not_called() + + artifact_create_method.assert_called_once_with( + source_uri=data_source.data_source_unique_id, + artifact_type="DataSet", + artifact_name=data_source.data_source_name, + properties=None, + source_types=None, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + +def test_create_transformation_code_artifact(): + with patch.object( + Artifact, "create", return_value=TRANSFORMATION_CODE_ARTIFACT_1 + ) as artifact_create_method: + + result = S3LineageEntityHandler.create_transformation_code_artifact( + transformation_code=TRANSFORMATION_CODE_INPUT_1, + pipeline_last_update_time=LAST_UPDATE_TIME, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + assert result == TRANSFORMATION_CODE_ARTIFACT_1 + + artifact_create_method.assert_called_once_with( + source_uri=TRANSFORMATION_CODE_INPUT_1.s3_uri, + source_types=[dict(SourceIdType="Custom", Value=LAST_UPDATE_TIME)], + artifact_type="TransformationCode", + artifact_name=f"sm-fs-fe-transformation-code-{LAST_UPDATE_TIME}", + properties=dict( + name=TRANSFORMATION_CODE_INPUT_1.name, + author=TRANSFORMATION_CODE_INPUT_1.author, + state="Active", + inclusive_start_date=LAST_UPDATE_TIME, + ), + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + +def test_create_transformation_code_artifact_when_no_author_or_name(): + transformation_code_input = TransformationCode(s3_uri=TRANSFORMATION_CODE_INPUT_1.s3_uri) + with patch.object( + Artifact, "create", return_value=TRANSFORMATION_CODE_ARTIFACT_1 + ) as artifact_create_method: + + result = S3LineageEntityHandler.create_transformation_code_artifact( + transformation_code=transformation_code_input, + pipeline_last_update_time=LAST_UPDATE_TIME, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + assert result == TRANSFORMATION_CODE_ARTIFACT_1 + + artifact_create_method.assert_called_once_with( + source_uri=TRANSFORMATION_CODE_INPUT_1.s3_uri, + source_types=[dict(SourceIdType="Custom", Value=LAST_UPDATE_TIME)], + artifact_type="TransformationCode", + artifact_name=f"sm-fs-fe-transformation-code-{LAST_UPDATE_TIME}", + properties=dict( + state="Active", + inclusive_start_date=LAST_UPDATE_TIME, + ), + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + +def test_create_transformation_code_artifact_when_no_code_provided(): + with patch.object( + Artifact, "create", return_value=TRANSFORMATION_CODE_ARTIFACT_1 + ) as artifact_create_method: + + result = S3LineageEntityHandler.create_transformation_code_artifact( + transformation_code=None, + pipeline_last_update_time=LAST_UPDATE_TIME, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + assert result is None + + 
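# No transformation code was supplied, so no artifact creation is attempted.
+        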
artifact_create_method.assert_not_called() + + +def test_retrieve_pipeline_schedule_artifact_when_artifact_does_not_exist(): + with patch.object(Artifact, "list", return_value=[]) as artifact_list_method: + with patch.object(Artifact, "load", return_value=ARTIFACT_RESULT) as artifact_load_method: + with patch.object( + Artifact, "create", return_value=ARTIFACT_RESULT + ) as artifact_create_method: + result = S3LineageEntityHandler.retrieve_pipeline_schedule_artifact( + pipeline_schedule=PIPELINE_SCHEDULE, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + assert result == ARTIFACT_RESULT + + artifact_list_method.assert_called_once_with( + source_uri=PIPELINE_SCHEDULE.schedule_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + artifact_load_method.assert_not_called() + + artifact_create_method.assert_called_once_with( + source_uri=PIPELINE_SCHEDULE.schedule_arn, + artifact_type="PipelineSchedule", + artifact_name=f"sm-fs-fe-{PIPELINE_SCHEDULE.schedule_name}", + properties=dict( + pipeline_name=PIPELINE_SCHEDULE.pipeline_name, + schedule_expression=PIPELINE_SCHEDULE.schedule_expression, + state=PIPELINE_SCHEDULE.state, + start_date=PIPELINE_SCHEDULE.start_date, + ), + source_types=None, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + +def test_retrieve_pipeline_schedule_artifact_when_artifact_exists(): + with patch.object(Artifact, "list", return_value=[ARTIFACT_SUMMARY]) as artifact_list_method: + with patch.object( + Artifact, "load", return_value=SCHEDULE_ARTIFACT_RESULT + ) as artifact_load_method: + with patch.object(SCHEDULE_ARTIFACT_RESULT, "save") as artifact_save_method: + with patch.object( + Artifact, "create", return_value=SCHEDULE_ARTIFACT_RESULT + ) as artifact_create_method: + result = S3LineageEntityHandler.retrieve_pipeline_schedule_artifact( + pipeline_schedule=PIPELINE_SCHEDULE, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + assert result == SCHEDULE_ARTIFACT_RESULT + + artifact_list_method.assert_called_once_with( + source_uri=PIPELINE_SCHEDULE.schedule_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + artifact_load_method.assert_called_once_with( + artifact_arn=ARTIFACT_SUMMARY.artifact_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + artifact_save_method.assert_called_once_with() + + artifact_create_method.assert_not_called() + + +def test_retrieve_pipeline_schedule_artifact_when_artifact_updated(): + schedule_artifact_result = copy.deepcopy(SCHEDULE_ARTIFACT_RESULT) + with patch.object(Artifact, "list", return_value=[ARTIFACT_SUMMARY]) as artifact_list_method: + with patch.object( + Artifact, "load", return_value=schedule_artifact_result + ) as artifact_load_method: + with patch.object(schedule_artifact_result, "save") as artifact_save_method: + with patch.object( + Artifact, "create", return_value=schedule_artifact_result + ) as artifact_create_method: + result = S3LineageEntityHandler.retrieve_pipeline_schedule_artifact( + pipeline_schedule=PIPELINE_SCHEDULE_2, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + assert result == schedule_artifact_result + assert schedule_artifact_result != SCHEDULE_ARTIFACT_RESULT + assert result.properties["pipeline_name"] == PIPELINE_SCHEDULE_2.pipeline_name + assert result.properties["schedule_expression"] == PIPELINE_SCHEDULE_2.schedule_expression + assert result.properties["state"] == PIPELINE_SCHEDULE_2.state + assert result.properties["start_date"] == PIPELINE_SCHEDULE_2.start_date + + artifact_list_method.assert_called_once_with( + source_uri=PIPELINE_SCHEDULE.schedule_arn, + 
sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + artifact_load_method.assert_called_once_with( + artifact_arn=ARTIFACT_SUMMARY.artifact_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + artifact_save_method.assert_called_once_with() + + artifact_create_method.assert_not_called() + + +def test_retrieve_pipeline_trigger_artifact_when_artifact_does_not_exist(): + with patch.object(Artifact, "list", return_value=[]) as artifact_list_method: + with patch.object( + Artifact, "load", return_value=PIPELINE_TRIGGER_ARTIFACT + ) as artifact_load_method: + with patch.object( + Artifact, "create", return_value=PIPELINE_TRIGGER_ARTIFACT + ) as artifact_create_method: + result = S3LineageEntityHandler.retrieve_pipeline_trigger_artifact( + pipeline_trigger=PIPELINE_TRIGGER, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + assert result == PIPELINE_TRIGGER_ARTIFACT + + artifact_list_method.assert_called_once_with( + source_uri=PIPELINE_TRIGGER.trigger_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + artifact_load_method.assert_not_called() + + artifact_create_method.assert_called_once_with( + source_uri=PIPELINE_TRIGGER.trigger_arn, + artifact_type="PipelineTrigger", + artifact_name=f"sm-fs-fe-trigger-{PIPELINE_TRIGGER.trigger_name}", + properties=dict( + pipeline_name=PIPELINE_TRIGGER.pipeline_name, + event_pattern=PIPELINE_TRIGGER.event_pattern, + state=PIPELINE_TRIGGER.state, + ), + source_types=None, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + +def test_retrieve_pipeline_trigger_artifact_when_artifact_exists(): + with patch.object( + Artifact, "list", return_value=[PIPELINE_TRIGGER_ARTIFACT_SUMMARY] + ) as artifact_list_method: + with patch.object( + Artifact, "load", return_value=PIPELINE_TRIGGER_ARTIFACT + ) as artifact_load_method: + with patch.object(PIPELINE_TRIGGER_ARTIFACT, "save") as artifact_save_method: + with patch.object( + Artifact, "create", return_value=PIPELINE_TRIGGER_ARTIFACT + ) as artifact_create_method: + result = S3LineageEntityHandler.retrieve_pipeline_trigger_artifact( + pipeline_trigger=PIPELINE_TRIGGER, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + assert result == PIPELINE_TRIGGER_ARTIFACT + + artifact_list_method.assert_called_once_with( + source_uri=PIPELINE_TRIGGER.trigger_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + artifact_load_method.assert_called_once_with( + artifact_arn=PIPELINE_TRIGGER_ARTIFACT_SUMMARY.artifact_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + artifact_save_method.assert_called_once_with() + + artifact_create_method.assert_not_called() + + +def test_retrieve_pipeline_trigger_artifact_when_artifact_updated(): + trigger_artifact_result = copy.deepcopy(PIPELINE_TRIGGER_ARTIFACT) + with patch.object( + Artifact, "list", return_value=[PIPELINE_TRIGGER_ARTIFACT_SUMMARY] + ) as artifact_list_method: + with patch.object( + Artifact, "load", return_value=trigger_artifact_result + ) as artifact_load_method: + with patch.object(trigger_artifact_result, "save") as artifact_save_method: + with patch.object( + Artifact, "create", return_value=trigger_artifact_result + ) as artifact_create_method: + result = S3LineageEntityHandler.retrieve_pipeline_trigger_artifact( + pipeline_trigger=PIPELINE_TRIGGER_2, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + assert result == trigger_artifact_result + assert trigger_artifact_result != PIPELINE_TRIGGER_ARTIFACT + assert result.properties["pipeline_name"] == PIPELINE_TRIGGER_2.pipeline_name + assert result.properties["event_pattern"] == 
PIPELINE_TRIGGER_2.event_pattern + assert result.properties["state"] == PIPELINE_TRIGGER_2.state + + artifact_list_method.assert_called_once_with( + source_uri=PIPELINE_TRIGGER.trigger_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + artifact_load_method.assert_called_once_with( + artifact_arn=PIPELINE_TRIGGER_ARTIFACT_SUMMARY.artifact_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + artifact_save_method.assert_called_once_with() + + artifact_create_method.assert_not_called() diff --git a/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_config_uploader.py b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_config_uploader.py new file mode 100644 index 0000000000..25ded72266 --- /dev/null +++ b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_config_uploader.py @@ -0,0 +1,317 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +import pytest +from mock import Mock, patch +from sagemaker.mlops.feature_store.feature_processor._config_uploader import ( + ConfigUploader, +) +from sagemaker.mlops.feature_store.feature_processor._constants import ( + SPARK_JAR_FILES_PATH, + SPARK_FILES_PATH, + SPARK_PY_FILES_PATH, + SAGEMAKER_WHL_FILE_S3_PATH, +) +from sagemaker.core.remote_function.job import ( + _JobSettings, + RUNTIME_SCRIPTS_CHANNEL_NAME, + REMOTE_FUNCTION_WORKSPACE, + SPARK_CONF_CHANNEL_NAME, +) +from sagemaker.core.remote_function.spark_config import SparkConfig +from sagemaker.core.helper.session_helper import Session + + +@pytest.fixture +def sagemaker_session(): + return Mock(Session) + + +@pytest.fixture +def wrapped_func(): + return Mock() + + +@pytest.fixture +def runtime_env_manager(): + mocked_runtime_env_manager = Mock() + mocked_runtime_env_manager.snapshot.return_value = "some_dependency_path" + return mocked_runtime_env_manager + + +def custom_file_filter(): + pass + + +@pytest.fixture +def remote_decorator_config(sagemaker_session): + return Mock( + _JobSettings, + sagemaker_session=sagemaker_session, + s3_root_uri="some_s3_uri", + s3_kms_key="some_kms", + spark_config=SparkConfig(), + dependencies=None, + include_local_workdir=True, + workdir_config=None, + pre_execution_commands="some_commands", + pre_execution_script="some_path", + python_sdk_whl_s3_uri=SAGEMAKER_WHL_FILE_S3_PATH, + custom_file_filter=None, + ) + + +@pytest.fixture +def config_uploader(remote_decorator_config, runtime_env_manager): + return ConfigUploader(remote_decorator_config, runtime_env_manager) + + +@pytest.fixture +def remote_decorator_config_with_filter(sagemaker_session): + return Mock( + _JobSettings, + sagemaker_session=sagemaker_session, + s3_root_uri="some_s3_uri", + s3_kms_key="some_kms", + spark_config=SparkConfig(), + dependencies=None, + include_local_workdir=True, + pre_execution_commands="some_commands", + pre_execution_script="some_path", + python_sdk_whl_s3_uri=SAGEMAKER_WHL_FILE_S3_PATH, + custom_file_filter=custom_file_filter, + 
)


@patch("sagemaker.mlops.feature_store.feature_processor._config_uploader.StoredFunction")
def test_prepare_and_upload_callable(
    mock_stored_function, config_uploader, wrapped_func, sagemaker_session
):
    mock_stored_function.return_value.save.return_value = None
    config_uploader._prepare_and_upload_callable(wrapped_func, "s3_base_uri", sagemaker_session)
    # The original `assert mock.called_once_with(...)` always passed because it only
    # creates a child mock; assert the interactions explicitly instead. Constructor
    # kwargs are not pinned here because StoredFunction may take additional arguments.
    mock_stored_function.assert_called_once()
    mock_stored_function.return_value.save.assert_called_once_with(wrapped_func)


@patch(
    "sagemaker.mlops.feature_store.feature_processor._config_uploader._prepare_and_upload_workspace",
    return_value="some_s3_uri",
)
def test_prepare_and_upload_workspace(mock_upload, config_uploader, sagemaker_session):
    remote_decorator_config = config_uploader.remote_decorator_config
    s3_path = config_uploader._prepare_and_upload_workspace(
        local_dependencies_path="some/path/to/dependency",
        include_local_workdir=True,
        pre_execution_commands=remote_decorator_config.pre_execution_commands,
        pre_execution_script_local_path=remote_decorator_config.pre_execution_script,
        s3_base_uri=remote_decorator_config.s3_root_uri,
        s3_kms_key=remote_decorator_config.s3_kms_key,
        sagemaker_session=sagemaker_session,
    )
    assert s3_path == mock_upload.return_value
    mock_upload.assert_called_once_with(
        local_dependencies_path="some/path/to/dependency",
        include_local_workdir=True,
        pre_execution_commands=remote_decorator_config.pre_execution_commands,
        pre_execution_script_local_path=remote_decorator_config.pre_execution_script,
        s3_base_uri=remote_decorator_config.s3_root_uri,
        s3_kms_key=remote_decorator_config.s3_kms_key,
        sagemaker_session=sagemaker_session,
        custom_file_filter=None,
    )


@patch(
    "sagemaker.mlops.feature_store.feature_processor._config_uploader._prepare_and_upload_workspace",
    return_value="some_s3_uri",
)
def test_prepare_and_upload_workspace_with_filter(
    mock_job_upload, remote_decorator_config_with_filter, runtime_env_manager, sagemaker_session
):
    config_uploader_with_filter = ConfigUploader(
        remote_decorator_config=remote_decorator_config_with_filter,
        runtime_env_manager=runtime_env_manager,
    )
    remote_decorator_config = config_uploader_with_filter.remote_decorator_config
    config_uploader_with_filter._prepare_and_upload_workspace(
        local_dependencies_path="some/path/to/dependency",
        include_local_workdir=True,
        pre_execution_commands=remote_decorator_config.pre_execution_commands,
        pre_execution_script_local_path=remote_decorator_config.pre_execution_script,
        s3_base_uri=remote_decorator_config.s3_root_uri,
        s3_kms_key=remote_decorator_config.s3_kms_key,
        sagemaker_session=sagemaker_session,
        custom_file_filter=remote_decorator_config_with_filter.custom_file_filter,
    )

    mock_job_upload.assert_called_once_with(
        local_dependencies_path="some/path/to/dependency",
        include_local_workdir=True,
        pre_execution_commands=remote_decorator_config.pre_execution_commands,
        pre_execution_script_local_path=remote_decorator_config.pre_execution_script,
        s3_base_uri=remote_decorator_config.s3_root_uri,
        s3_kms_key=remote_decorator_config.s3_kms_key,
        sagemaker_session=sagemaker_session,
        custom_file_filter=custom_file_filter,
    )


@patch(
    "sagemaker.mlops.feature_store.feature_processor._config_uploader._prepare_and_upload_runtime_scripts",
    return_value="some_s3_uri",
)
def test_prepare_and_upload_runtime_scripts(mock_upload, config_uploader, sagemaker_session):
    s3_path = config_uploader._prepare_and_upload_runtime_scripts(
        spark_config=config_uploader.remote_decorator_config.spark_config,
        s3_base_uri=config_uploader.remote_decorator_config.s3_root_uri,
        s3_kms_key=config_uploader.remote_decorator_config.s3_kms_key,
        sagemaker_session=sagemaker_session,
    )
    assert s3_path == mock_upload.return_value
    mock_upload.assert_called_once_with(
        spark_config=config_uploader.remote_decorator_config.spark_config,
        s3_base_uri=config_uploader.remote_decorator_config.s3_root_uri,
        s3_kms_key=config_uploader.remote_decorator_config.s3_kms_key,
        sagemaker_session=sagemaker_session,
    )


@patch(
    "sagemaker.mlops.feature_store.feature_processor._config_uploader._prepare_and_upload_spark_dependent_files",
    return_value=("path_a", "path_b", "path_c", "path_d"),
)
def test_prepare_and_upload_spark_dependent_files(mock_upload, config_uploader, sagemaker_session):
    s3_paths = config_uploader._prepare_and_upload_spark_dependent_files(
        spark_config=config_uploader.remote_decorator_config.spark_config,
        s3_base_uri=config_uploader.remote_decorator_config.s3_root_uri,
        s3_kms_key=config_uploader.remote_decorator_config.s3_kms_key,
        sagemaker_session=sagemaker_session,
    )
    assert s3_paths == mock_upload.return_value
    mock_upload.assert_called_once_with(
        spark_config=config_uploader.remote_decorator_config.spark_config,
        s3_base_uri=config_uploader.remote_decorator_config.s3_root_uri,
        s3_kms_key=config_uploader.remote_decorator_config.s3_kms_key,
        sagemaker_session=sagemaker_session,
    )


@patch("sagemaker.mlops.feature_store.feature_processor._config_uploader.Channel")
@patch("sagemaker.mlops.feature_store.feature_processor._config_uploader.DataSource")
@patch("sagemaker.mlops.feature_store.feature_processor._config_uploader.S3DataSource")
@patch(
    "sagemaker.mlops.feature_store.feature_processor._config_uploader._prepare_and_upload_spark_dependent_files",
    return_value=("path_a", "path_b", "path_c", "path_d"),
)
@patch(
    "sagemaker.mlops.feature_store.feature_processor._config_uploader._prepare_and_upload_workspace",
    return_value="some_s3_uri",
)
@patch(
    "sagemaker.mlops.feature_store.feature_processor._config_uploader._prepare_and_upload_runtime_scripts",
    return_value="some_s3_uri",
)
@patch("sagemaker.mlops.feature_store.feature_processor._config_uploader.StoredFunction")
def test_prepare_step_input_channel(
    mock_upload_callable,
    mock_script_upload,
    mock_dependency_upload,
    mock_spark_dependency_upload,
    mock_s3_data_source,
    mock_data_source,
    mock_channel,
    config_uploader,
    wrapped_func,
    sagemaker_session,
):
    (
        input_data_config,
        spark_dependency_paths,
    ) = config_uploader.prepare_step_input_channel_for_spark_mode(
        wrapped_func,
        config_uploader.remote_decorator_config.s3_root_uri,
        sagemaker_session,
    )
    remote_decorator_config = config_uploader.remote_decorator_config

    # `called_once_with` is not a real Mock assertion (it always passes); check the
    # StoredFunction interaction explicitly.
    mock_upload_callable.assert_called_once()

    mock_script_upload.assert_called_once_with(
        spark_config=config_uploader.remote_decorator_config.spark_config,
        s3_base_uri=config_uploader.remote_decorator_config.s3_root_uri,
        s3_kms_key="some_kms",
        sagemaker_session=sagemaker_session,
    )

    mock_dependency_upload.assert_called_once_with(
        local_dependencies_path="some_dependency_path",
        include_local_workdir=True,
        pre_execution_commands=remote_decorator_config.pre_execution_commands,
        pre_execution_script_local_path=remote_decorator_config.pre_execution_script,
        s3_base_uri=remote_decorator_config.s3_root_uri,
        s3_kms_key="some_kms",
        sagemaker_session=sagemaker_session,
        custom_file_filter=None,
    )

    mock_spark_dependency_upload.assert_called_once_with(
        spark_config=config_uploader.remote_decorator_config.spark_config,
        s3_base_uri=config_uploader.remote_decorator_config.s3_root_uri,
        s3_kms_key="some_kms",
        sagemaker_session=sagemaker_session,
    )

    # Verify input_data_config is a list of Channel objects
    assert isinstance(input_data_config, list)
    # 3 channels: runtime scripts, workspace, spark conf
    assert len(input_data_config) == 3

    # Verify each channel was constructed with the correct data
    # Channel 1: runtime scripts
    mock_s3_data_source.assert_any_call(
        s3_uri="some_s3_uri",
        s3_data_type="S3Prefix",
        s3_data_distribution_type="FullyReplicated",
    )
    # Channel 2: workspace
    mock_s3_data_source.assert_any_call(
        s3_uri=f"{config_uploader.remote_decorator_config.s3_root_uri}/sm_rf_user_ws",
        s3_data_type="S3Prefix",
        s3_data_distribution_type="FullyReplicated",
    )
    # Channel 3: spark conf
    mock_s3_data_source.assert_any_call(
        s3_uri="path_d",
        s3_data_type="S3Prefix",
        s3_data_distribution_type="FullyReplicated",
    )

    assert mock_s3_data_source.call_count == 3
    assert mock_data_source.call_count == 3
    assert mock_channel.call_count == 3

    # Verify channel names and input_mode
    channel_call_kwargs = [call.kwargs for call in mock_channel.call_args_list]
    channel_names = [kw["channel_name"] for kw in channel_call_kwargs]
    assert RUNTIME_SCRIPTS_CHANNEL_NAME in channel_names
    assert REMOTE_FUNCTION_WORKSPACE in channel_names
    assert SPARK_CONF_CHANNEL_NAME in channel_names
    for kw in channel_call_kwargs:
        assert kw["input_mode"] == "File"

    assert spark_dependency_paths == {
        SPARK_JAR_FILES_PATH: "path_a",
        SPARK_PY_FILES_PATH: "path_b",
        SPARK_FILES_PATH: "path_c",
    }
diff --git a/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_data_helpers.py b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_data_helpers.py
new file mode 100644
index 0000000000..e69499c5b1
--- /dev/null
+++ b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_data_helpers.py
@@ -0,0 +1,166 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+from __future__ import absolute_import + +import datetime +import json + +from dateutil.tz import tzlocal +from sagemaker.mlops.feature_store.feature_processor._data_source import ( + CSVDataSource, + FeatureGroupDataSource, +) +from sagemaker.mlops.feature_store.feature_processor._enums import FeatureProcessorMode +from sagemaker.mlops.feature_store.feature_processor._feature_processor_config import ( + FeatureProcessorConfig, +) + +INPUT_S3_URI = "s3://bucket/prefix/" +INPUT_FEATURE_GROUP_NAME = "input-fg" +INPUT_FEATURE_GROUP_ARN = "arn:aws:sagemaker:us-west-2:12345789012:feature-group/input-fg" +INPUT_FEATURE_GROUP_S3_URI = "s3://bucket/input-fg/" +INPUT_FEATURE_GROUP_RESOLVED_OUTPUT_S3_URI = ( + "s3://bucket/input-fg/feature-store/12345789012/" + "sagemaker/us-west-2/offline-store/input-fg-12345/data" +) + +FEATURE_GROUP_DATA_SOURCE = FeatureGroupDataSource(name=INPUT_FEATURE_GROUP_ARN) +S3_DATA_SOURCE = CSVDataSource(s3_uri=INPUT_S3_URI) +FEATURE_PROCESSOR_INPUTS = [FEATURE_GROUP_DATA_SOURCE, S3_DATA_SOURCE] +OUTPUT_FEATURE_GROUP_ARN = "arn:aws:sagemaker:us-west-2:12345789012:feature-group/output-fg" + +FEATURE_GROUP_SYSTEM_PARAMS = { + "feature_group_name": "input-fg", + "online_store_enabled": True, + "offline_store_enabled": False, + "offline_store_resolved_s3_uri": None, +} +SYSTEM_PARAMS = {"system": {"scheduled_time": "2023-03-25T02:01:26Z"}} +USER_INPUT_PARAMS = { + "some-key": "some-value", + "some-other-key": {"some-key": "some-value"}, +} + +DATA_SOURCE_UNIQUE_ID_TOO_LONG = """ +AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA\ +AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA\ +AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA\ +AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA\ +AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA\ +AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA\ +AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA\ +AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA\ +AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA\ +AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA\ +AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA\ +""" + +DESCRIBE_FEATURE_GROUP_RESPONSE = { + "FeatureGroupArn": INPUT_FEATURE_GROUP_ARN, + "FeatureGroupName": INPUT_FEATURE_GROUP_NAME, + "RecordIdentifierFeatureName": "id", + "EventTimeFeatureName": "ingest_time", + "FeatureDefinitions": [ + {"FeatureName": "id", "FeatureType": "String"}, + {"FeatureName": "model", "FeatureType": "String"}, + {"FeatureName": "model_year", "FeatureType": "String"}, + {"FeatureName": "status", "FeatureType": "String"}, + {"FeatureName": "mileage", "FeatureType": "String"}, + {"FeatureName": "price", "FeatureType": "String"}, + {"FeatureName": "msrp", "FeatureType": "String"}, + {"FeatureName": "ingest_time", "FeatureType": "Fractional"}, + ], + "CreationTime": datetime.datetime(2023, 3, 29, 19, 15, 47, 20000, tzinfo=tzlocal()), + "OnlineStoreConfig": {"EnableOnlineStore": True}, + "OfflineStoreConfig": { + "S3StorageConfig": { + "S3Uri": INPUT_FEATURE_GROUP_S3_URI, + "ResolvedOutputS3Uri": INPUT_FEATURE_GROUP_RESOLVED_OUTPUT_S3_URI, + }, + "DisableGlueTableCreation": False, + "DataCatalogConfig": { + "TableName": "input_fg_1680142547", + "Catalog": "AwsDataCatalog", + "Database": "sagemaker_featurestore", + }, + }, + "RoleArn": "arn:aws:iam::12345789012:role/role-name", + "FeatureGroupStatus": "Created", + "OnlineStoreTotalSizeBytes": 12345, + "ResponseMetadata": { + 
"RequestId": "d36d3647-1632-4f4e-9f7c-2a4e38e4c6f8", + "HTTPStatusCode": 200, + "HTTPHeaders": { + "x-amzn-requestid": "d36d3647-1632-4f4e-9f7c-2a4e38e4c6f8", + "content-type": "application/x-amz-json-1.1", + "content-length": "1311", + "date": "Fri, 31 Mar 2023 01:05:49 GMT", + }, + "RetryAttempts": 0, + }, +} + +PIPELINE = { + "PipelineArn": "some_pipeline_arn", + "RoleArn": "some_execution_role_arn", + "CreationTime": datetime.datetime(2023, 3, 29, 19, 15, 47, 20000, tzinfo=tzlocal()), + "PipelineDefinition": json.dumps( + { + "Steps": [ + { + "RetryPolicies": [ + { + "BackoffRate": 2.0, + "IntervalSeconds": 1, + "MaxAttempts": 5, + "ExceptionType": ["Step.SERVICE_FAULT", "Step.THROTTLING"], + }, + { + "BackoffRate": 2.0, + "IntervalSeconds": 1, + "MaxAttempts": 5, + "ExceptionType": [ + "SageMaker.JOB_INTERNAL_ERROR", + "SageMaker.CAPACITY_ERROR", + "SageMaker.RESOURCE_LIMIT", + ], + }, + ] + } + ] + } + ), +} + + +def create_fp_config( + inputs=None, + output=OUTPUT_FEATURE_GROUP_ARN, + mode=FeatureProcessorMode.PYSPARK, + target_stores=None, + enable_ingestion=True, + parameters=None, + spark_config=None, +): + """Helper method to create a FeatureProcessorConfig with fewer arguments.""" + + return FeatureProcessorConfig.create( + inputs=inputs or FEATURE_PROCESSOR_INPUTS, + output=output, + mode=mode, + target_stores=target_stores, + enable_ingestion=enable_ingestion, + parameters=parameters, + spark_config=spark_config, + ) diff --git a/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_data_source.py b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_data_source.py new file mode 100644 index 0000000000..51637fa979 --- /dev/null +++ b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_data_source.py @@ -0,0 +1,34 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +from pyspark.sql import DataFrame + +from sagemaker.mlops.feature_store.feature_processor._data_source import PySparkDataSource + + +def test_pyspark_data_source(): + class TestDataSource(PySparkDataSource): + + data_source_unique_id = "test_unique_id" + data_source_name = "test_source_name" + + def read_data(self, spark, params) -> DataFrame: + return None + + test_data_source = TestDataSource() + + assert test_data_source.data_source_name == "test_source_name" + assert test_data_source.data_source_unique_id == "test_unique_id" + assert test_data_source.read_data(spark=None, params=None) is None diff --git a/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_env.py b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_env.py new file mode 100644 index 0000000000..4aff330087 --- /dev/null +++ b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_env.py @@ -0,0 +1,122 @@ +# Copyright Amazon.com, Inc. or its affiliates. 
All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+from __future__ import absolute_import
+
+import json
+
+from mock import mock_open, patch
+import pytest
+from sagemaker.mlops.feature_store.feature_processor._env import EnvironmentHelper
+
+SINGLE_NODE_RESOURCE_CONFIG = {
+    "current_host": "algo-1",
+    "current_instance_type": "ml.m5.xlarge",
+    "current_group_name": "homogeneousCluster",
+    "hosts": ["algo-1"],
+    "instance_groups": [
+        {
+            "instance_group_name": "homogeneousCluster",
+            "instance_type": "ml.m5.xlarge",
+            "hosts": ["algo-1"],
+        }
+    ],
+    "network_interface_name": "eth0",
+}
+MULTI_NODE_COUNT = 3
+MULTI_NODE_RESOURCE_CONFIG = {
+    "current_host": "algo-1",
+    "current_instance_type": "ml.m5.xlarge",
+    "current_group_name": "homogeneousCluster",
+    "hosts": ["algo-1", "algo-2", "algo-3"],
+    "instance_groups": [
+        {
+            "instance_group_name": "homogeneousCluster",
+            "instance_type": "ml.m5.xlarge",
+            "hosts": ["algo-1"],
+        },
+        {
+            "instance_group_name": "homogeneousCluster",
+            "instance_type": "ml.m5.xlarge",
+            "hosts": ["algo-2"],
+        },
+        {
+            "instance_group_name": "homogeneousCluster",
+            "instance_type": "ml.m5.xlarge",
+            "hosts": ["algo-3"],
+        },
+    ],
+    "network_interface_name": "eth0",
+}
+
+
+@patch("builtins.open")
+def test_is_training_job(mocked_open):
+    mocked_open.side_effect = mock_open(read_data=json.dumps(SINGLE_NODE_RESOURCE_CONFIG))
+
+    assert EnvironmentHelper().is_training_job() is True
+
+    mocked_open.assert_called_once_with("/opt/ml/input/config/resourceconfig.json", "r")
+
+
+@patch("builtins.open")
+def test_is_not_training_job(mocked_open):
+    mocked_open.side_effect = FileNotFoundError()
+
+    assert EnvironmentHelper().is_training_job() is False
+
+
+@patch("builtins.open")
+def test_get_instance_count_single_node(mocked_open):
+    mocked_open.side_effect = mock_open(read_data=json.dumps(SINGLE_NODE_RESOURCE_CONFIG))
+
+    assert EnvironmentHelper().get_instance_count() == 1
+
+
+@patch("builtins.open")
+def test_get_instance_count_multi_node(mocked_open):
+    mocked_open.side_effect = mock_open(read_data=json.dumps(MULTI_NODE_RESOURCE_CONFIG))
+
+    assert EnvironmentHelper().get_instance_count() == MULTI_NODE_COUNT
+
+
+@patch("builtins.open")
+def test_load_training_resource_config(mocked_open):
+    mocked_open.side_effect = mock_open(read_data=json.dumps(SINGLE_NODE_RESOURCE_CONFIG))
+
+    assert EnvironmentHelper().load_training_resource_config() == SINGLE_NODE_RESOURCE_CONFIG
+
+
+@patch("builtins.open")
+def test_load_training_resource_config_none(mocked_open):
+    mocked_open.side_effect = FileNotFoundError()
+
+    assert EnvironmentHelper().load_training_resource_config() is None
+
+
+@pytest.mark.parametrize(
+    "is_training_result",
+    [True, False],
+)
+@patch("sagemaker.mlops.feature_store.feature_processor._env.EnvironmentHelper.is_training_job")
+def test_get_job_scheduled_time(mock_is_training, is_training_result, monkeypatch):
+    # Rewritten from a version that pytest never collected (the function name lacked
+    # the `test_` prefix) and whose patch targets were invalid ("datetime.now.strftime"
+    # is not a patchable target, and os.environ is a dict, not a callable). The
+    # "scheduled_time" environment variable and the datetime patch target below are
+    # assumptions about EnvironmentHelper's implementation; adjust them if the module
+    # differs.
+    mock_is_training.return_value = is_training_result
+    monkeypatch.setenv("scheduled_time", "test_scheduled_time")
+
+    with patch(
+        "sagemaker.mlops.feature_store.feature_processor._env.datetime"
+    ) as mock_datetime:
+        mock_datetime.now.return_value.strftime.return_value = "test_current_time"
+        output_time = EnvironmentHelper().get_job_scheduled_time()
+
+    if is_training_result:
+        assert output_time == "test_scheduled_time"
+    else:
+        assert output_time == "test_current_time"
diff --git a/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_event_bridge_rule_helper.py b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_event_bridge_rule_helper.py
new file mode 100644
index 0000000000..f44472d519
--- /dev/null
+++ b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_event_bridge_rule_helper.py
@@ -0,0 +1,301 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+from __future__ import absolute_import
+
+from mock import Mock
+from sagemaker.core.helper.session_helper import Session
+from sagemaker.mlops.feature_store.feature_processor._event_bridge_rule_helper import (
+    EventBridgeRuleHelper,
+)
+from botocore.exceptions import ClientError
+from sagemaker.mlops.feature_store.feature_processor._feature_processor_pipeline_events import (
+    FeatureProcessorPipelineEvents,
+    FeatureProcessorPipelineExecutionStatus,
+)
+import pytest
+
+
+@pytest.fixture
+def sagemaker_session():
+    boto_session = Mock()
+    boto_session.client("events").return_value = Mock()
+    return Mock(Session, boto_session=boto_session, sagemaker_client=Mock())
+
+
+@pytest.fixture
+def event_bridge_rule_helper(sagemaker_session):
+    return EventBridgeRuleHelper(sagemaker_session, sagemaker_session.boto_session.client("events"))
+
+
+def test_put_rule_without_event_pattern(event_bridge_rule_helper):
+    source_pipeline_events = [
+        FeatureProcessorPipelineEvents(
+            pipeline_name="test_pipeline",
+            pipeline_execution_status=[FeatureProcessorPipelineExecutionStatus.SUCCEEDED],
+        )
+    ]
+
+    event_bridge_rule_helper._generate_pipeline_arn_and_name = Mock(
+        return_value=dict(pipeline_arn="pipeline_arn", pipeline_name="pipeline_name")
+    )
+    event_bridge_rule_helper.event_bridge_rule_client.put_rule = Mock(
+        return_value=dict(RuleArn="rule_arn")
+    )
+    event_bridge_rule_helper.put_rule(
+        source_pipeline_events=source_pipeline_events,
+        target_pipeline="target_pipeline",
+        event_pattern=None,
+        state="Disabled",
+    )
+
+    event_bridge_rule_helper.event_bridge_rule_client.put_rule.assert_called_with(
+        Name="target_pipeline",
+        EventPattern=(
+            '{"detail-type": ["SageMaker Model Building Pipeline Execution Status Change"], '
+            '"source": ["aws.sagemaker"], "detail": {"currentPipelineExecutionStatus": '
+            '["Succeeded"], "pipelineArn": ["pipeline_arn"]}}'
+        ),
+        State="Disabled",
+    )
+
+
+def test_put_rule_with_event_pattern(event_bridge_rule_helper):
+    source_pipeline_events = [
+        FeatureProcessorPipelineEvents(
+            pipeline_name="test_pipeline",
+            pipeline_execution_status=[FeatureProcessorPipelineExecutionStatus.SUCCEEDED],
+        )
+    ]
+
+    event_bridge_rule_helper._generate_pipeline_arn_and_name = Mock(
+        return_value=dict(pipeline_arn="pipeline_arn", pipeline_name="pipeline_name")
+    )
+    event_bridge_rule_helper.event_bridge_rule_client.put_rule = Mock(
+        return_value=dict(RuleArn="rule_arn")
+    )
+    event_bridge_rule_helper.put_rule(
+        source_pipeline_events=source_pipeline_events,
+        target_pipeline="target_pipeline",
+        event_pattern="event_pattern",
+        state="Disabled",
+    )
+
+    event_bridge_rule_helper.event_bridge_rule_client.put_rule.assert_called_with(
+        Name="target_pipeline",
+        EventPattern="event_pattern",
+        State="Disabled",
+    )
+
+
+def test_put_targets_success(event_bridge_rule_helper):
+    event_bridge_rule_helper._generate_pipeline_arn_and_name = Mock(
+        return_value=dict(pipeline_arn="pipeline_arn", pipeline_name="pipeline_name")
+    )
+    event_bridge_rule_helper.event_bridge_rule_client.put_targets = Mock(
+        return_value=dict(FailedEntryCount=0)
+    )
+    event_bridge_rule_helper.put_target(
+        rule_name="rule_name",
+        target_pipeline="target_pipeline",
+        target_pipeline_parameters={"param": "value"},
+        role_arn="role_arn",
+    )
+
+    event_bridge_rule_helper.event_bridge_rule_client.put_targets.assert_called_with(
+        Rule="rule_name",
+        Targets=[
+            {
+                "Id": "pipeline_name",
+                "Arn": "pipeline_arn",
+                "RoleArn": "role_arn",
+                "SageMakerPipelineParameters": {"PipelineParameterList": {"param": "value"}},
+            }
+        ],
+    )
+
+
+def test_put_targets_failure(event_bridge_rule_helper):
+    event_bridge_rule_helper._generate_pipeline_arn_and_name = Mock(
+        return_value=dict(pipeline_arn="pipeline_arn", pipeline_name="pipeline_name")
+    )
+    event_bridge_rule_helper.event_bridge_rule_client.put_targets = Mock(
+        return_value=dict(
+            FailedEntryCount=1,
+            FailedEntries=[dict(ErrorMessage="test_error_message")],
+        )
+    )
+    with pytest.raises(
+        Exception, match="Failed to add target pipeline to rule. Failure reason: test_error_message"
+    ):
+        event_bridge_rule_helper.put_target(
+            rule_name="rule_name",
+            target_pipeline="target_pipeline",
+            target_pipeline_parameters={"param": "value"},
+            role_arn="role_arn",
+        )
+
+
+def test_delete_rule(event_bridge_rule_helper):
+    event_bridge_rule_helper.event_bridge_rule_client.delete_rule = Mock()
+    event_bridge_rule_helper.delete_rule("rule_name")
+
+    event_bridge_rule_helper.event_bridge_rule_client.delete_rule.assert_called_with(
+        Name="rule_name"
+    )
+
+
+def test_describe_rule_success(event_bridge_rule_helper):
+    mock_describe_response = dict(State="ENABLED", RuleName="rule_name")
+    event_bridge_rule_helper.event_bridge_rule_client.describe_rule = Mock(
+        return_value=mock_describe_response
+    )
+    assert event_bridge_rule_helper.describe_rule("rule_name") == mock_describe_response
+
+
+def test_describe_rule_non_existent(event_bridge_rule_helper):
+    mock_describe_response = dict(State="ENABLED", RuleName="rule_name")
+    event_bridge_rule_helper.event_bridge_rule_client.describe_rule = Mock(
+        return_value=mock_describe_response,
+        side_effect=ClientError(
+            error_response={"Error": {"Code": "ResourceNotFoundException"}},
+            operation_name="describe_rule",
+        ),
+    )
+    assert event_bridge_rule_helper.describe_rule("rule_name") is None
+
+
+def test_remove_targets(event_bridge_rule_helper):
+    event_bridge_rule_helper.event_bridge_rule_client.remove_targets = Mock()
+    event_bridge_rule_helper.remove_targets(rule_name="rule_name", ids=["target_pipeline"])
+    event_bridge_rule_helper.event_bridge_rule_client.remove_targets.assert_called_with(
+        Rule="rule_name",
+        Ids=["target_pipeline"],
+    )
+
+
+def test_enable_rule(event_bridge_rule_helper):
+    event_bridge_rule_helper.event_bridge_rule_client.enable_rule = Mock()
+    event_bridge_rule_helper.enable_rule("rule_name")
+
+    event_bridge_rule_helper.event_bridge_rule_client.enable_rule.assert_called_with(
+        Name="rule_name"
+    )
+
+
+def test_disable_rule(event_bridge_rule_helper):
+    event_bridge_rule_helper.event_bridge_rule_client.disable_rule = Mock()
+    event_bridge_rule_helper.disable_rule("rule_name")
+
+    event_bridge_rule_helper.event_bridge_rule_client.disable_rule.assert_called_with(
+        Name="rule_name"
+    )
+
+
+def test_add_tags(event_bridge_rule_helper):
+    event_bridge_rule_helper.event_bridge_rule_client.tag_resource = Mock()
+    event_bridge_rule_helper.add_tags("rule_arn", [{"key": "value"}])
+
+    event_bridge_rule_helper.event_bridge_rule_client.tag_resource.assert_called_with(
+        ResourceARN="rule_arn", Tags=[{"key": "value"}]
+    )
+
+
+def test_generate_event_pattern_from_feature_processor_pipeline_events(event_bridge_rule_helper):
+    event_bridge_rule_helper._generate_pipeline_arn_and_name = Mock(
+        return_value=dict(pipeline_arn="pipeline_arn", pipeline_name="pipeline_name")
+    )
+    event_pattern = (
+        event_bridge_rule_helper._generate_event_pattern_from_feature_processor_pipeline_events(
+            [
+                FeatureProcessorPipelineEvents(
+                    pipeline_name="test_pipeline_1",
+                    pipeline_execution_status=[FeatureProcessorPipelineExecutionStatus.SUCCEEDED],
+                ),
+                FeatureProcessorPipelineEvents(
+                    pipeline_name="test_pipeline_2",
+                    pipeline_execution_status=[FeatureProcessorPipelineExecutionStatus.FAILED],
+                ),
+            ]
+        )
+    )
+
+    # test_pipeline_1 is registered with SUCCEEDED and test_pipeline_2 with FAILED,
+    # so the expected "$or" pattern contains one entry per status.
+    assert (
+        event_pattern
+        == '{"detail-type": ["SageMaker Model Building Pipeline Execution Status Change"], '
+        '"$or": [{"source": ["aws.sagemaker"], "detail": {"currentPipelineExecutionStatus": '
+        '["Succeeded"], "pipelineArn": ["pipeline_arn"]}}, {"source": ["aws.sagemaker"], "detail": '
+        '{"currentPipelineExecutionStatus": ["Failed"], "pipelineArn": ["pipeline_arn"]}}]}'
+    )
+
+
+def test_validate_feature_processor_pipeline_events(event_bridge_rule_helper):
+    events = [
+        FeatureProcessorPipelineEvents(
+            pipeline_name="test_pipeline_1",
+            pipeline_execution_status=[FeatureProcessorPipelineExecutionStatus.SUCCEEDED],
+        ),
+        FeatureProcessorPipelineEvents(
+            pipeline_name="test_pipeline_1",
+            pipeline_execution_status=[FeatureProcessorPipelineExecutionStatus.FAILED],
+        ),
+    ]
+
+    with pytest.raises(ValueError, match="Pipeline names in pipeline_events must be unique."):
+        event_bridge_rule_helper._validate_feature_processor_pipeline_events(events)
+
+
+def test_aggregate_pipeline_events_with_same_desired_status(event_bridge_rule_helper):
+    events = [
+        FeatureProcessorPipelineEvents(
+            pipeline_name="test_pipeline_1",
+            pipeline_execution_status=[FeatureProcessorPipelineExecutionStatus.FAILED],
+        ),
+        FeatureProcessorPipelineEvents(
+            pipeline_name="test_pipeline_2",
+            pipeline_execution_status=[FeatureProcessorPipelineExecutionStatus.FAILED],
+        ),
+    ]
+
+    assert event_bridge_rule_helper._aggregate_pipeline_events_with_same_desired_status(events) == {
+        (FeatureProcessorPipelineExecutionStatus.FAILED,): [
+            "test_pipeline_1",
+            "test_pipeline_2",
+        ]
+    }
+
+
+@pytest.mark.parametrize(
+    "pipeline_uri,expected_result",
+    [
+        (
+            "arn:aws:sagemaker:us-west-2:123456789012:pipeline/test-pipeline",
+            dict(
+                pipeline_arn="arn:aws:sagemaker:us-west-2:123456789012:pipeline/test-pipeline",
+                pipeline_name="test-pipeline",
+            ),
+        ),
+        (
+            "test-pipeline",
+            dict(
+                pipeline_arn="test-pipeline-arn",
+                pipeline_name="test-pipeline",
+            ),
+        ),
+    ],
+)
+def test_generate_pipeline_arn_and_name(event_bridge_rule_helper, pipeline_uri, expected_result):
+    event_bridge_rule_helper.sagemaker_session.sagemaker_client.describe_pipeline = Mock(
+        return_value=dict(PipelineArn="test-pipeline-arn")
+    )
+    assert event_bridge_rule_helper._generate_pipeline_arn_and_name(pipeline_uri) == expected_result
diff --git a/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_event_bridge_scheduler_helper.py b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_event_bridge_scheduler_helper.py
new file mode 100644
index 0000000000..baca05e84d
--- /dev/null
+++ b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_event_bridge_scheduler_helper.py
@@ -0,0 +1,96 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+from __future__ import absolute_import
+
+from datetime import datetime
+import pytest
+from botocore.exceptions import ClientError
+
+from sagemaker.mlops.feature_store.feature_processor._event_bridge_scheduler_helper import (
+    EventBridgeSchedulerHelper,
+)
+from mock import Mock
+
+from sagemaker.core.helper.session_helper import Session
+
+SCHEDULE_NAME = "test_schedule"
+SCHEDULE_ARN = "test_schedule_arn"
+NEW_SCHEDULE_ARN = "test_new_schedule_arn"
+TARGET_ARN = "test_arn"
+CRON_SCHEDULE = "test_cron"
+STATE = "ENABLED"
+ROLE = "test_role"
+START_DATE = datetime.now()
+
+
+@pytest.fixture
+def sagemaker_session():
+    boto_session = Mock()
+    boto_session.client("scheduler").return_value = Mock()
+    return Mock(Session, boto_session=boto_session)
+
+
+@pytest.fixture
+def event_bridge_scheduler_helper(sagemaker_session):
+    return EventBridgeSchedulerHelper(
+        sagemaker_session, sagemaker_session.boto_session.client("scheduler")
+    )
+
+
+def test_upsert_schedule_already_exists(event_bridge_scheduler_helper):
+    event_bridge_scheduler_helper.event_bridge_scheduler_client.update_schedule.return_value = (
+        SCHEDULE_ARN
+    )
+    schedule_arn = event_bridge_scheduler_helper.upsert_schedule(
+        schedule_name=SCHEDULE_NAME,
+        pipeline_arn=TARGET_ARN,
+        schedule_expression=CRON_SCHEDULE,
+        state=STATE,
+        start_date=START_DATE,
+        role=ROLE,
+    )
+    assert schedule_arn == SCHEDULE_ARN
+    event_bridge_scheduler_helper.event_bridge_scheduler_client.create_schedule.assert_not_called()
+
+
+def test_upsert_schedule_not_exists(event_bridge_scheduler_helper):
+    event_bridge_scheduler_helper.event_bridge_scheduler_client.update_schedule.side_effect = (
+        ClientError(
+            error_response={"Error": {"Code": "ResourceNotFoundException"}},
+            operation_name="update_schedule",
+        )
+    )
+    event_bridge_scheduler_helper.event_bridge_scheduler_client.create_schedule.return_value = (
+        NEW_SCHEDULE_ARN
+    )
+
+    schedule_arn = event_bridge_scheduler_helper.upsert_schedule(
+        schedule_name=SCHEDULE_NAME,
+        pipeline_arn=TARGET_ARN,
+        schedule_expression=CRON_SCHEDULE,
+        state=STATE,
+        start_date=START_DATE,
+        role=ROLE,
+    )
+    assert schedule_arn == NEW_SCHEDULE_ARN
+    event_bridge_scheduler_helper.event_bridge_scheduler_client.create_schedule.assert_called_once()
+
+
+def test_delete_schedule(event_bridge_scheduler_helper):
+    event_bridge_scheduler_helper.sagemaker_session.boto_session = Mock()
+    event_bridge_scheduler_helper.sagemaker_session.sagemaker_client = Mock()
+    event_bridge_scheduler_helper.delete_schedule(schedule_name=TARGET_ARN)
+    event_bridge_scheduler_helper.event_bridge_scheduler_client.delete_schedule.assert_called_with(
+        Name=TARGET_ARN
+    )
diff --git a/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_factory.py b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_factory.py
new file mode 100644
index 0000000000..27a3b96231
--- /dev/null
+++ b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_factory.py
@@ -0,0 +1,75 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+from __future__ import absolute_import
+
+import pytest
+import test_data_helpers as tdh
+from mock import Mock, patch
+
+from sagemaker.mlops.feature_store.feature_processor._enums import FeatureProcessorMode
+from sagemaker.mlops.feature_store.feature_processor._factory import (
+    UDFWrapperFactory,
+    ValidatorFactory,
+)
+from sagemaker.mlops.feature_store.feature_processor._feature_processor_config import (
+    FeatureProcessorConfig,
+)
+from sagemaker.mlops.feature_store.feature_processor._udf_wrapper import UDFWrapper
+from sagemaker.mlops.feature_store.feature_processor._validation import (
+    FeatureProcessorArgValidator,
+    InputValidator,
+    SparkUDFSignatureValidator,
+    InputOffsetValidator,
+    BaseDataSourceValidator,
+)
+from sagemaker.core.helper.session_helper import Session
+
+
+def test_get_validation_chain():
+    fp_config = tdh.create_fp_config(mode=FeatureProcessorMode.PYSPARK)
+    result = ValidatorFactory.get_validation_chain(fp_config)
+
+    assert result.validators is not None
+    assert {
+        InputValidator,
+        FeatureProcessorArgValidator,
+        InputOffsetValidator,
+        BaseDataSourceValidator,
+        SparkUDFSignatureValidator,
+    } == {type(instance) for instance in result.validators}
+
+
+def test_get_udf_wrapper():
+    fp_config = tdh.create_fp_config(mode=FeatureProcessorMode.PYSPARK)
+    udf_wrapper = Mock(UDFWrapper)
+
+    with patch.object(
+        UDFWrapperFactory, "_get_spark_udf_wrapper", return_value=udf_wrapper
+    ) as get_udf_wrapper_method:
+        result = UDFWrapperFactory.get_udf_wrapper(fp_config)
+
+        assert result == udf_wrapper
+        get_udf_wrapper_method.assert_called_with(fp_config)
+
+
+def test_get_udf_wrapper_invalid_mode():
+    fp_config = Mock(FeatureProcessorConfig)
+    fp_config.mode = FeatureProcessorMode.PYTHON
+    fp_config.sagemaker_session = Mock(Session)
+
+    with pytest.raises(
+        ValueError,
+        match=r"FeatureProcessorMode FeatureProcessorMode.PYTHON is not supported\.",
+    ):
+        UDFWrapperFactory.get_udf_wrapper(fp_config)
diff --git a/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_feature_processor.py b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_feature_processor.py
new file mode 100644
index 0000000000..6f6c8471aa
--- /dev/null
+++ b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_feature_processor.py
@@ -0,0 +1,122 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+from __future__ import absolute_import + +from typing import Callable + +import pytest +import test_data_helpers as tdh +from mock import Mock, patch + +from sagemaker.mlops.feature_store.feature_processor._enums import FeatureProcessorMode +from sagemaker.mlops.feature_store.feature_processor._factory import ( + UDFWrapperFactory, + ValidatorFactory, +) +from sagemaker.mlops.feature_store.feature_processor._feature_processor_config import ( + FeatureProcessorConfig, +) +from sagemaker.mlops.feature_store.feature_processor._udf_wrapper import UDFWrapper +from sagemaker.mlops.feature_store.feature_processor._validation import ValidatorChain +from sagemaker.mlops.feature_store.feature_processor.feature_processor import ( + feature_processor, +) + + +@pytest.fixture +def udf(): + return Mock(Callable) + + +@pytest.fixture +def wrapped_udf(): + return Mock() + + +@pytest.fixture +def udf_wrapper(wrapped_udf): + mock = Mock(UDFWrapper) + mock.wrap.return_value = wrapped_udf + return mock + + +@pytest.fixture +def validator_chain(): + return Mock(ValidatorChain) + + +@pytest.fixture +def fp_config(): + mock = Mock(FeatureProcessorConfig) + mock.mode = FeatureProcessorMode.PYSPARK + return mock + + +def test_feature_processor(udf, udf_wrapper, validator_chain, fp_config, wrapped_udf): + with patch.object( + FeatureProcessorConfig, "create", return_value=fp_config + ) as fp_config_create_method: + with patch.object( + UDFWrapperFactory, "get_udf_wrapper", return_value=udf_wrapper + ) as get_udf_wrapper: + with patch.object( + ValidatorFactory, + "get_validation_chain", + return_value=validator_chain, + ) as get_validation_chain: + decorated_udf = feature_processor( + inputs=[tdh.FEATURE_GROUP_DATA_SOURCE], + output="", + )(udf) + + assert decorated_udf == wrapped_udf + + fp_config_create_method.assert_called() + get_udf_wrapper.assert_called_with(fp_config) + get_validation_chain.assert_called() + + validator_chain.validate.assert_called_with(fp_config=fp_config, udf=udf) + udf_wrapper.wrap.assert_called_with(fp_config=fp_config, udf=udf) + + assert decorated_udf.feature_processor_config == fp_config + + +def test_feature_processor_validation_fails(udf, udf_wrapper, validator_chain, fp_config): + with patch.object( + FeatureProcessorConfig, "create", return_value=fp_config + ) as fp_config_create_method: + with patch.object( + UDFWrapperFactory, "get_udf_wrapper", return_value=udf_wrapper + ) as get_udf_wrapper: + with patch.object( + ValidatorFactory, + "get_validation_chain", + return_value=validator_chain, + ) as get_validation_chain: + validator_chain.validate.side_effect = ValueError() + + # Verify validation error is raised to user. + with pytest.raises(ValueError): + feature_processor( + inputs=tdh.FEATURE_PROCESSOR_INPUTS, + output=tdh.OUTPUT_FEATURE_GROUP_ARN, + )(udf) + + # Verify validation failure causes execution to terminate early. + # Verify FeatureProcessorConfig interactions. 
+                    fp_config_create_method.assert_called()
+                    get_udf_wrapper.assert_called_once()
+                    get_validation_chain.assert_called_once()
+                    validator_chain.validate.assert_called_with(fp_config=fp_config, udf=udf)
+                    udf_wrapper.wrap.assert_not_called()
diff --git a/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_feature_processor_config.py b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_feature_processor_config.py
new file mode 100644
index 0000000000..2914bbb63a
--- /dev/null
+++ b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_feature_processor_config.py
@@ -0,0 +1,46 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+from __future__ import absolute_import
+
+import attr
+import pytest
+import test_data_helpers as tdh
+
+from sagemaker.mlops.feature_store.feature_processor._enums import FeatureProcessorMode
+from sagemaker.mlops.feature_store.feature_processor._feature_processor_config import (
+    FeatureProcessorConfig,
+)
+
+
+def test_feature_processor_config_is_immutable():
+    fp_config = FeatureProcessorConfig.create(
+        inputs=[tdh.FEATURE_GROUP_DATA_SOURCE, tdh.S3_DATA_SOURCE],
+        output=tdh.OUTPUT_FEATURE_GROUP_ARN,
+        mode=FeatureProcessorMode.PYSPARK,
+        target_stores=None,
+        enable_ingestion=True,
+        parameters=None,
+        spark_config=None,
+    )
+
+    with pytest.raises(attr.exceptions.FrozenInstanceError):
+        # Only attempting one field, as FrozenInstanceError indicates all fields are frozen
+        # (as opposed to FrozenAttributeError).
+        fp_config.inputs = []
+
+    with pytest.raises(
+        TypeError,
+        match="'FeatureProcessorConfig' object does not support item assignment",
+    ):
+        fp_config["inputs"] = []
diff --git a/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_feature_processor_pipeline_events.py b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_feature_processor_pipeline_events.py
new file mode 100644
index 0000000000..78c506313a
--- /dev/null
+++ b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_feature_processor_pipeline_events.py
@@ -0,0 +1,30 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+from __future__ import absolute_import + +from sagemaker.mlops.feature_store.feature_processor import ( + FeatureProcessorPipelineEvents, + FeatureProcessorPipelineExecutionStatus, +) + + +def test_feature_processor_pipeline_events(): + fe_pipeline_events = FeatureProcessorPipelineEvents( + pipeline_name="pipeline_name", + pipeline_execution_status=[FeatureProcessorPipelineExecutionStatus.EXECUTING], + ) + assert fe_pipeline_events.pipeline_name == "pipeline_name" + assert fe_pipeline_events.pipeline_execution_status == [ + FeatureProcessorPipelineExecutionStatus.EXECUTING + ] diff --git a/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_feature_scheduler.py b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_feature_scheduler.py new file mode 100644 index 0000000000..1cd7e381e0 --- /dev/null +++ b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_feature_scheduler.py @@ -0,0 +1,1057 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +from datetime import datetime +from typing import Callable + +import pytest +import json +from botocore.exceptions import ClientError +from mock import Mock, patch, call + +from sagemaker.mlops.feature_store.feature_processor.feature_scheduler import ( + FeatureProcessorLineageHandler, +) +from sagemaker.mlops.feature_store.feature_processor import ( + FeatureProcessorPipelineEvents, + FeatureProcessorPipelineExecutionStatus, +) +from sagemaker.core.lineage.context import Context +from sagemaker.core.remote_function.spark_config import SparkConfig + +from sagemaker.core.helper.session_helper import Session +from sagemaker.mlops.feature_store.feature_processor._enums import FeatureProcessorMode +from sagemaker.mlops.feature_store.feature_processor._constants import ( + FEATURE_PROCESSOR_TAG_KEY, + FEATURE_PROCESSOR_TAG_VALUE, + EXECUTION_TIME_PIPELINE_PARAMETER_FORMAT, + PIPELINE_NAME_MAXIMUM_LENGTH, +) +from sagemaker.mlops.feature_store.feature_processor.feature_scheduler import ( + schedule, + to_pipeline, + execute, + delete_schedule, + describe, + list_pipelines, + put_trigger, + enable_trigger, + disable_trigger, + delete_trigger, + _validate_fg_lineage_resources, + _validate_pipeline_lineage_resources, +) +from sagemaker.core.remote_function.job import ( + _JobSettings, + SPARK_APP_SCRIPT_PATH, + RUNTIME_SCRIPTS_CHANNEL_NAME, + REMOTE_FUNCTION_WORKSPACE, + ENTRYPOINT_SCRIPT_NAME, + SPARK_CONF_CHANNEL_NAME, +) +from sagemaker.core.workflow.parameters import Parameter, ParameterTypeEnum +from sagemaker.mlops.workflow.retry import ( + StepRetryPolicy, + StepExceptionTypeEnum, + SageMakerJobStepRetryPolicy, + SageMakerJobExceptionTypeEnum, +) +import test_data_helpers as tdh + +REGION = "us-west-2" +IMAGE = "image_uri" +BUCKET = "my-s3-bucket" +DEFAULT_BUCKET_PREFIX = "default_bucket_prefix" +S3_URI = f"s3://{BUCKET}/keyprefix" +DEFAULT_IMAGE = ( + 
"153931337802.dkr.ecr.us-west-2.amazonaws.com/sagemaker-spark-processing:3.2-cpu-py39-v1.1" +) +PIPELINE_ARN = "pipeline_arn" +SCHEDULE_ARN = "schedule_arn" +SCHEDULE_ROLE_ARN = "my_schedule_role_arn" +EXECUTION_ROLE_ARN = "my_execution_role_arn" +EVENT_BRIDGE_RULE_ARN = "arn:aws:events:us-west-2:123456789012:rule/test-rule" +VALID_SCHEDULE_STATE = "ENABLED" +INVALID_SCHEDULE_STATE = "invalid" +TEST_REGION = "us-west-2" +PIPELINE_CONTEXT_NAME_TAG_KEY = "sm-fs-fe:feature-engineering-pipeline-context-name" +PIPELINE_VERSION_CONTEXT_NAME_TAG_KEY = "sm-fs-fe:feature-engineering-pipeline-version-context-name" +NOW = datetime.now() +SAGEMAKER_SESSION_MOCK = Mock(Session) +CONTEXT_MOCK_01 = Mock(Context) +CONTEXT_MOCK_02 = Mock(Context) +CONTEXT_MOCK_03 = Mock(Context) +FEATURE_GROUP = tdh.DESCRIBE_FEATURE_GROUP_RESPONSE.copy() +PIPELINE = tdh.PIPELINE.copy() +TAGS = [dict(Key="key_1", Value="value_1"), dict(Key="key_2", Value="value_2")] + + +def mock_session(): + session = Mock() + session.default_bucket.return_value = BUCKET + session.default_bucket_prefix = DEFAULT_BUCKET_PREFIX + session.expand_role.return_value = EXECUTION_ROLE_ARN + session.boto_region_name = TEST_REGION + session.sagemaker_config = None + session._append_sagemaker_config_tags.return_value = [] + session.default_bucket_prefix = None + session.sagemaker_client = Mock() + return session + + +def mock_pipeline(): + pipeline = Mock() + pipeline.describe.return_value = {"PipelineArn": PIPELINE_ARN} + pipeline.upsert.return_value = None + return pipeline + + +def mock_event_bridge_scheduler_helper(): + helper = Mock() + helper.upsert_schedule.return_value = dict(ScheduleArn=SCHEDULE_ARN) + helper.delete_schedule.return_value = None + helper.describe_schedule.return_value = { + "Arn": "some_schedule_arn", + "ScheduleExpression": "some_schedule_expression", + "StartDate": NOW, + "State": VALID_SCHEDULE_STATE, + "Target": {"Arn": "some_pipeline_arn", "RoleArn": "some_schedule_role_arn"}, + } + return helper + + +def mock_event_bridge_rule_helper(): + helper = Mock() + helper.describe_rule.return_value = { + "Arn": "some_rule_arn", + "EventPattern": "some_event_pattern", + "State": "ENABLED", + } + return helper + + +def mock_feature_processor_lineage(): + return Mock(FeatureProcessorLineageHandler) + + +@pytest.fixture +def job_function(): + return Mock(Callable) + + +@pytest.fixture +def config_uploader(): + uploader = Mock() + uploader.return_value = "some_s3_uri" + uploader.prepare_and_upload_runtime_scripts.return_value = "some_s3_uri" + return uploader + + +@patch( + "sagemaker.mlops.feature_store.feature_processor.feature_scheduler._validate_fg_lineage_resources", + return_value=None, +) +@patch( + "sagemaker.mlops.feature_store.feature_processor.feature_scheduler.Pipeline", + return_value=mock_pipeline(), +) +@patch( + "sagemaker.core.remote_function.job._JobSettings._get_default_spark_image", + return_value="some_image_uri", +) +@patch("sagemaker.mlops.feature_store.feature_processor._config_uploader.TrainingInput") +@patch("sagemaker.mlops.feature_store.feature_processor.feature_scheduler.TrainingStep") +@patch("sagemaker.mlops.feature_store.feature_processor.feature_scheduler.ModelTrainer") +@patch( + "sagemaker.mlops.feature_store.feature_processor._config_uploader.ConfigUploader" + "._prepare_and_upload_spark_dependent_files", + return_value=("path_a", "path_b", "path_c", "path_d"), +) +@patch( + "sagemaker.mlops.feature_store.feature_processor._config_uploader.ConfigUploader._prepare_and_upload_workspace", + 
return_value="some_s3_uri", +) +@patch( + "sagemaker.mlops.feature_store.feature_processor._config_uploader.ConfigUploader._prepare_and_upload_runtime_scripts", + return_value="some_s3_uri", +) +@patch("sagemaker.mlops.feature_store.feature_processor.feature_scheduler.RuntimeEnvironmentManager") +@patch( + "sagemaker.mlops.feature_store.feature_processor._config_uploader.ConfigUploader._prepare_and_upload_callable" +) +@patch( + "sagemaker.mlops.feature_store.feature_processor.lineage." + "_feature_processor_lineage.FeatureProcessorLineageHandler.create_lineage" +) +@patch( + "sagemaker.mlops.feature_store.feature_processor.lineage." + "_feature_processor_lineage.FeatureProcessorLineageHandler.get_pipeline_lineage_names", + return_value=dict( + pipeline_context_name="pipeline-context-name", + pipeline_version_context_name="pipeline-version-context-name", + ), +) +@patch( + "sagemaker.mlops.feature_store.feature_processor.feature_scheduler.PipelineSession" +) +@patch("sagemaker.core.remote_function.job.Session", return_value=mock_session()) +@patch("sagemaker.core.remote_function.job.expand_role", side_effect=lambda session, role: role) +@patch("sagemaker.core.remote_function.job.get_execution_role", return_value=EXECUTION_ROLE_ARN) +def test_to_pipeline( + get_execution_role, + expand_role, + session, + mock_pipeline_session, + mock_get_pipeline_lineage_names, + mock_create_lineage, + mock_upload_callable, + mock_runtime_manager, + mock_script_upload, + mock_dependency_upload, + mock_spark_dependency_upload, + mock_model_trainer, + mock_training_step, + mock_training_input, + mock_spark_image, + pipeline, + lineage_validator, +): + session.sagemaker_config = None + session.boto_region_name = TEST_REGION + session.expand_role.return_value = EXECUTION_ROLE_ARN + + # Configure RuntimeEnvironmentManager mock to return proper string values + mock_runtime_manager_instance = mock_runtime_manager.return_value + mock_runtime_manager_instance._current_python_version.return_value = "3.10" + mock_runtime_manager_instance.snapshot.return_value = "/tmp/snapshot" + + spark_config = SparkConfig(submit_files=["file_a", "file_b", "file_c"]) + job_settings = _JobSettings( + spark_config=spark_config, + s3_root_uri=S3_URI, + role=EXECUTION_ROLE_ARN, + include_local_workdir=True, + instance_type="ml.m5.large", + encrypt_inter_container_traffic=True, + sagemaker_session=session, + ) + jobs_container_entrypoint = [ + "/bin/bash", + f"/opt/ml/input/data/{RUNTIME_SCRIPTS_CHANNEL_NAME}/{ENTRYPOINT_SCRIPT_NAME}", + ] + jobs_container_entrypoint.extend(["--jars", "path_a"]) + jobs_container_entrypoint.extend(["--py-files", "path_b"]) + jobs_container_entrypoint.extend(["--files", "path_c"]) + jobs_container_entrypoint.extend([SPARK_APP_SCRIPT_PATH]) + container_args = ["--s3_base_uri", f"{S3_URI}/pipeline_name"] + container_args.extend(["--region", session.boto_region_name]) + + mock_feature_processor_config = Mock( + mode=FeatureProcessorMode.PYSPARK, inputs=[tdh.FEATURE_PROCESSOR_INPUTS], output="some_fg" + ) + mock_feature_processor_config.mode.return_value = FeatureProcessorMode.PYSPARK + + wrapped_func = Mock( + Callable, + feature_processor_config=mock_feature_processor_config, + job_settings=job_settings, + wrapped_func=job_function, + ) + wrapped_func.feature_processor_config.return_value = mock_feature_processor_config + wrapped_func.job_settings.return_value = job_settings + wrapped_func.wrapped_func.return_value = job_function + + pipeline_arn = to_pipeline( + pipeline_name="pipeline_name", + 
step=wrapped_func,
+        role=EXECUTION_ROLE_ARN,
+        max_retries=1,
+        tags=[("tag_key_1", "tag_value_1"), ("tag_key_2", "tag_value_2")],
+        sagemaker_session=session,
+    )
+    assert pipeline_arn == PIPELINE_ARN
+
+    mock_upload_callable.assert_called_once()
+
+    mock_script_upload.assert_called_once_with(
+        spark_config,
+        f"{S3_URI}/pipeline_name",
+        None,
+        session,
+    )
+
+    mock_dependency_upload.assert_called_once_with(
+        "/tmp/snapshot",
+        True,
+        None,
+        None,
+        f"{S3_URI}/pipeline_name",
+        None,
+        session,
+        None,
+    )
+
+    mock_spark_dependency_upload.assert_called_once_with(
+        spark_config,
+        f"{S3_URI}/pipeline_name",
+        None,
+        session,
+    )
+
+    mock_model_trainer.assert_called_once()
+    # Verify ModelTrainer was configured correctly
+    model_trainer_call_kwargs = mock_model_trainer.call_args[1]
+    assert model_trainer_call_kwargs["training_image"] == "some_image_uri"
+    assert model_trainer_call_kwargs["role"] == EXECUTION_ROLE_ARN
+    assert model_trainer_call_kwargs["training_input_mode"] == "File"
+
+    # Verify PipelineSession was passed to ModelTrainer
+    assert model_trainer_call_kwargs["sagemaker_session"] == mock_pipeline_session.return_value
+    mock_pipeline_session.assert_called_once_with(
+        boto_session=session.boto_session,
+        default_bucket=session.default_bucket(),
+        default_bucket_prefix=session.default_bucket_prefix,
+    )
+
+    # Verify Compute config
+    compute_arg = model_trainer_call_kwargs["compute"]
+    assert compute_arg.instance_type == "ml.m5.large"
+    assert compute_arg.instance_count == 1
+    assert compute_arg.volume_size_in_gb == 30
+
+    # No VPC config was provided, so networking should be None
+    assert model_trainer_call_kwargs["networking"] is None
+
+    # max_runtime_in_seconds defaults to 86400 in _JobSettings
+    stopping_condition_arg = model_trainer_call_kwargs["stopping_condition"]
+    assert stopping_condition_arg.max_runtime_in_seconds == 86400
+
+    # Verify OutputDataConfig
+    output_data_config_arg = model_trainer_call_kwargs["output_data_config"]
+    assert output_data_config_arg.s3_output_path == f"{S3_URI}/pipeline_name"
+
+    # Verify SourceCode has a command string
+    source_code_arg = model_trainer_call_kwargs["source_code"]
+    assert source_code_arg.command is not None
+    assert len(source_code_arg.command) > 0
+    assert "--client_python_version 3.10" in source_code_arg.command
+
+    # No tags on _JobSettings, so tags should be None
+    assert model_trainer_call_kwargs["tags"] is None
+
+    # Verify train() was called with input_data_config
+    mock_model_trainer.return_value.train.assert_called_once()
+    train_call_kwargs = mock_model_trainer.return_value.train.call_args[1]
+    assert "input_data_config" in train_call_kwargs
+
+    mock_training_step.assert_called_once_with(
+        name="-".join(["pipeline_name", "feature-processor"]),
+        step_args=mock_model_trainer.return_value.train.return_value,
+        retry_policies=[
+            StepRetryPolicy(
+                exception_types=[
+                    StepExceptionTypeEnum.SERVICE_FAULT,
+                    StepExceptionTypeEnum.THROTTLING,
+                ],
+                max_attempts=1,
+            ),
+            SageMakerJobStepRetryPolicy(
+                exception_types=[
+                    SageMakerJobExceptionTypeEnum.INTERNAL_ERROR,
+                    SageMakerJobExceptionTypeEnum.CAPACITY_ERROR,
+                    SageMakerJobExceptionTypeEnum.RESOURCE_LIMIT,
+                ],
+                max_attempts=1,
+            ),
+        ],
+    )
+
+    pipeline.assert_called_once_with(
+        name="pipeline_name",
+        steps=[mock_training_step()],
+        sagemaker_session=session,
+        parameters=[Parameter(name="scheduled_time", parameter_type=ParameterTypeEnum.STRING)],
+    )
+
+    pipeline().upsert.assert_has_calls(
+        [
+            call(
+                role_arn=EXECUTION_ROLE_ARN,
tags=[ + dict(Key=FEATURE_PROCESSOR_TAG_KEY, Value=FEATURE_PROCESSOR_TAG_VALUE), + dict(Key="tag_key_1", Value="tag_value_1"), + dict(Key="tag_key_2", Value="tag_value_2"), + ], + ), + call( + role_arn=EXECUTION_ROLE_ARN, + tags=[ + { + "Key": PIPELINE_CONTEXT_NAME_TAG_KEY, + "Value": "pipeline-context-name", + }, + { + "Key": PIPELINE_VERSION_CONTEXT_NAME_TAG_KEY, + "Value": "pipeline-version-context-name", + }, + ], + ), + ] + ) + + +@patch("sagemaker.core.remote_function.job.Session", return_value=mock_session()) +@patch("sagemaker.core.remote_function.job.get_execution_role", return_value=EXECUTION_ROLE_ARN) +def test_to_pipeline_not_wrapped_by_feature_processor(get_execution_role, session): + job_settings = _JobSettings( + image_uri=IMAGE, + s3_root_uri=S3_URI, + role=EXECUTION_ROLE_ARN, + include_local_workdir=True, + instance_type="ml.m5.large", + encrypt_inter_container_traffic=True, + ) + wrapped_func = Mock( + Callable, + job_settings=job_settings, + wrapped_func=job_function, + ) + wrapped_func.job_settings.return_value = job_settings + wrapped_func.wrapped_func.return_value = job_function + + with pytest.raises( + ValueError, + match="Please wrap step parameter with feature_processor decorator in order to use to_pipeline API.", + ): + to_pipeline( + pipeline_name="pipeline_name", + step=wrapped_func, + role=EXECUTION_ROLE_ARN, + max_retries=1, + ) + + +@patch("sagemaker.core.remote_function.job.Session", return_value=mock_session()) +@patch("sagemaker.core.remote_function.job.get_execution_role", return_value=EXECUTION_ROLE_ARN) +def test_to_pipeline_not_wrapped_by_remote(get_execution_role, session): + mock_feature_processor_config = Mock(mode=FeatureProcessorMode.PYTHON) + wrapped_func = Mock( + Callable, + feature_processor_config=mock_feature_processor_config, + job_settings=None, + wrapped_func=job_function, + ) + wrapped_func.wrapped_func.return_value = job_function + + with pytest.raises( + ValueError, + match="Please wrap step parameter with remote decorator in order to use to_pipeline API.", + ): + to_pipeline( + pipeline_name="pipeline_name", + step=wrapped_func, + role=EXECUTION_ROLE_ARN, + max_retries=1, + ) + + +@patch("sagemaker.core.remote_function.job.Session", return_value=mock_session()) +@patch( + "sagemaker.core.remote_function.job._JobSettings._get_default_spark_image", + return_value="some_image_uri", +) +@patch("sagemaker.core.remote_function.job.get_execution_role", return_value=EXECUTION_ROLE_ARN) +def test_to_pipeline_wrong_mode(get_execution_role, mock_spark_image, session): + spark_config = SparkConfig(submit_files=["file_a", "file_b", "file_c"]) + job_settings = _JobSettings( + spark_config=spark_config, + s3_root_uri=S3_URI, + role=EXECUTION_ROLE_ARN, + include_local_workdir=True, + instance_type="ml.m5.large", + encrypt_inter_container_traffic=True, + ) + jobs_container_entrypoint = [ + "/bin/bash", + f"/opt/ml/input/data/{RUNTIME_SCRIPTS_CHANNEL_NAME}/{ENTRYPOINT_SCRIPT_NAME}", + ] + jobs_container_entrypoint.extend(["--jars", "path_a"]) + jobs_container_entrypoint.extend(["--py-files", "path_b"]) + jobs_container_entrypoint.extend(["--files", "path_c"]) + jobs_container_entrypoint.extend([SPARK_APP_SCRIPT_PATH]) + container_args = ["--s3_base_uri", f"{S3_URI}/pipeline_name"] + container_args.extend(["--region", TEST_REGION]) + + mock_feature_processor_config = Mock(mode=FeatureProcessorMode.PYTHON) + mock_feature_processor_config.mode.return_value = FeatureProcessorMode.PYTHON + + wrapped_func = Mock( + Callable, + 
feature_processor_config=mock_feature_processor_config, + job_settings=job_settings, + wrapped_func=job_function, + ) + wrapped_func.feature_processor_config.return_value = mock_feature_processor_config + wrapped_func.job_settings.return_value = job_settings + wrapped_func.wrapped_func.return_value = job_function + + with pytest.raises( + ValueError, + match="Mode FeatureProcessorMode.PYTHON is not supported by to_pipeline API.", + ): + to_pipeline( + pipeline_name="pipeline_name", + step=wrapped_func, + role=EXECUTION_ROLE_ARN, + max_retries=1, + ) + + +@patch("sagemaker.core.remote_function.job.Session", return_value=mock_session()) +@patch( + "sagemaker.core.remote_function.job._JobSettings._get_default_spark_image", + return_value="some_image_uri", +) +@patch("sagemaker.core.remote_function.job.get_execution_role", return_value=EXECUTION_ROLE_ARN) +def test_to_pipeline_pipeline_name_length_limit_exceeds( + get_execution_role, mock_spark_image, session +): + spark_config = SparkConfig(submit_files=["file_a", "file_b", "file_c"]) + job_settings = _JobSettings( + spark_config=spark_config, + s3_root_uri=S3_URI, + role=EXECUTION_ROLE_ARN, + include_local_workdir=True, + instance_type="ml.m5.large", + encrypt_inter_container_traffic=True, + ) + jobs_container_entrypoint = [ + "/bin/bash", + f"/opt/ml/input/data/{RUNTIME_SCRIPTS_CHANNEL_NAME}/{ENTRYPOINT_SCRIPT_NAME}", + ] + jobs_container_entrypoint.extend(["--jars", "path_a"]) + jobs_container_entrypoint.extend(["--py-files", "path_b"]) + jobs_container_entrypoint.extend(["--files", "path_c"]) + jobs_container_entrypoint.extend([SPARK_APP_SCRIPT_PATH]) + container_args = ["--s3_base_uri", f"{S3_URI}/pipeline_name"] + container_args.extend(["--region", TEST_REGION]) + + mock_feature_processor_config = Mock(mode=FeatureProcessorMode.PYSPARK) + mock_feature_processor_config.mode.return_value = FeatureProcessorMode.PYSPARK + + wrapped_func = Mock( + Callable, + feature_processor_config=mock_feature_processor_config, + job_settings=job_settings, + wrapped_func=job_function, + ) + wrapped_func.feature_processor_config.return_value = mock_feature_processor_config + wrapped_func.job_settings.return_value = job_settings + wrapped_func.wrapped_func.return_value = job_function + + with pytest.raises( + ValueError, + match="Pipeline name used by feature processor should be less than 80 " + "characters. 
Please choose another pipeline name.", + ): + to_pipeline( + pipeline_name="".join(["a" for _ in range(PIPELINE_NAME_MAXIMUM_LENGTH + 1)]), + step=wrapped_func, + role=EXECUTION_ROLE_ARN, + max_retries=1, + ) + + +@patch("sagemaker.core.remote_function.job.Session", return_value=mock_session()) +@patch( + "sagemaker.core.remote_function.job._JobSettings._get_default_spark_image", + return_value="some_image_uri", +) +@patch("sagemaker.core.remote_function.job.get_execution_role", return_value=EXECUTION_ROLE_ARN) +def test_to_pipeline_used_reserved_tags(get_execution_role, mock_spark_image, session): + session.sagemaker_config = None + session.boto_region_name = TEST_REGION + session.expand_role.return_value = EXECUTION_ROLE_ARN + spark_config = SparkConfig(submit_files=["file_a", "file_b", "file_c"]) + job_settings = _JobSettings( + spark_config=spark_config, + s3_root_uri=S3_URI, + role=EXECUTION_ROLE_ARN, + include_local_workdir=True, + instance_type="ml.m5.large", + encrypt_inter_container_traffic=True, + sagemaker_session=session, + ) + jobs_container_entrypoint = [ + "/bin/bash", + f"/opt/ml/input/data/{RUNTIME_SCRIPTS_CHANNEL_NAME}/{ENTRYPOINT_SCRIPT_NAME}", + ] + jobs_container_entrypoint.extend(["--jars", "path_a"]) + jobs_container_entrypoint.extend(["--py-files", "path_b"]) + jobs_container_entrypoint.extend(["--files", "path_c"]) + jobs_container_entrypoint.extend([SPARK_APP_SCRIPT_PATH]) + container_args = ["--s3_base_uri", f"{S3_URI}/pipeline_name"] + container_args.extend(["--region", session.boto_region_name]) + + mock_feature_processor_config = Mock( + mode=FeatureProcessorMode.PYSPARK, inputs=[tdh.FEATURE_PROCESSOR_INPUTS], output="some_fg" + ) + mock_feature_processor_config.mode.return_value = FeatureProcessorMode.PYSPARK + + wrapped_func = Mock( + Callable, + feature_processor_config=mock_feature_processor_config, + job_settings=job_settings, + wrapped_func=job_function, + ) + wrapped_func.feature_processor_config.return_value = mock_feature_processor_config + wrapped_func.job_settings.return_value = job_settings + wrapped_func.wrapped_func.return_value = job_function + + with pytest.raises( + ValueError, + match="sm-fs-fe:created-from is a reserved tag key for to_pipeline API. 
Please choose another tag.", + ): + to_pipeline( + pipeline_name="pipeline_name", + step=wrapped_func, + role=EXECUTION_ROLE_ARN, + max_retries=1, + tags=[("sm-fs-fe:created-from", "random")], + sagemaker_session=session, + ) + + +@patch( + "sagemaker.mlops.feature_store.feature_processor.feature_scheduler" + "._get_tags_from_pipeline_to_propagate_to_lineage_resources", + return_value=TAGS, +) +@patch( + "sagemaker.mlops.feature_store.feature_processor.feature_scheduler._validate_pipeline_lineage_resources", + return_value=None, +) +@patch( + "sagemaker.mlops.feature_store.feature_processor.feature_scheduler.EventBridgeSchedulerHelper", + return_value=mock_event_bridge_scheduler_helper(), +) +@patch( + "sagemaker.mlops.feature_store.feature_processor.feature_scheduler.FeatureProcessorLineageHandler", + return_value=mock_feature_processor_lineage(), +) +def test_schedule(lineage, helper, validation, get_tags): + session = Mock(Session, sagemaker_client=Mock(), boto_session=Mock()) + session.sagemaker_client.describe_pipeline = Mock( + return_value={"PipelineArn": "my:arn", "CreationTime": NOW} + ) + + schedule_arn = schedule( + schedule_expression="some_schedule", + state=VALID_SCHEDULE_STATE, + start_date=NOW, + pipeline_name=PIPELINE_ARN, + role_arn=SCHEDULE_ROLE_ARN, + sagemaker_session=session, + ) + + assert schedule_arn == SCHEDULE_ARN + + +@patch( + "sagemaker.mlops.feature_store.feature_processor.feature_scheduler.EventBridgeRuleHelper", + return_value=mock_event_bridge_rule_helper(), +) +@patch( + "sagemaker.mlops.feature_store.feature_processor.feature_scheduler.EventBridgeSchedulerHelper", + return_value=mock_event_bridge_scheduler_helper(), +) +def test_describe_both_exist(mock_scheduler_helper, mock_rule_helper): + session = Mock(Session, sagemaker_client=Mock(), boto_session=Mock()) + session.sagemaker_client.describe_pipeline.return_value = PIPELINE + describe_schedule_response = describe( + pipeline_name="some_pipeline_arn", sagemaker_session=session + ) + assert describe_schedule_response == dict( + pipeline_arn="some_pipeline_arn", + pipeline_execution_role_arn="some_execution_role_arn", + max_retries=5, + schedule_arn="some_schedule_arn", + schedule_expression="some_schedule_expression", + schedule_state=VALID_SCHEDULE_STATE, + schedule_start_date=NOW.strftime(EXECUTION_TIME_PIPELINE_PARAMETER_FORMAT), + schedule_role="some_schedule_role_arn", + trigger="some_rule_arn", + event_pattern="some_event_pattern", + trigger_state="ENABLED", + ) + + +@patch( + "sagemaker.mlops.feature_store.feature_processor.feature_scheduler.EventBridgeRuleHelper.describe_rule", + return_value=None, +) +@patch( + "sagemaker.mlops.feature_store.feature_processor.feature_scheduler.EventBridgeSchedulerHelper.describe_schedule", + return_value=None, +) +def test_describe_only_pipeline_exist(helper, mock_describe_rule): + session = Mock(Session, sagemaker_client=Mock(), boto_session=Mock()) + session.sagemaker_client.describe_pipeline.return_value = { + "PipelineArn": "some_pipeline_arn", + "RoleArn": "some_execution_role_arn", + "PipelineDefinition": json.dumps({"Steps": [{"Arguments": {}}]}), + } + helper.describe_schedule().return_value = None + describe_schedule_response = describe( + pipeline_name="some_pipeline_arn", sagemaker_session=session + ) + assert describe_schedule_response == dict( + pipeline_arn="some_pipeline_arn", + pipeline_execution_role_arn="some_execution_role_arn", + ) + + +def test_list_pipelines(): + session = Mock(Session, sagemaker_client=Mock(), boto_session=Mock()) + 
session.sagemaker_client.list_contexts.return_value = {
+        "ContextSummaries": [
+            {
+                "Source": {
+                    "SourceUri": "arn:aws:sagemaker:us-west-2:12345789012:pipeline/some_pipeline_name"
+                }
+            }
+        ]
+    }
+    list_response = list_pipelines(session)
+    assert list_response == [dict(pipeline_name="some_pipeline_name")]
+
+
+@patch(
+    "sagemaker.mlops.feature_store.feature_processor.feature_scheduler.EventBridgeSchedulerHelper",
+    return_value=mock_event_bridge_scheduler_helper(),
+)
+def test_delete_schedule_both_exist(helper):
+    session = Mock(Session, sagemaker_client=Mock(), boto_session=Mock())
+    session.sagemaker_client.delete_pipeline = Mock()
+    delete_schedule(pipeline_name=PIPELINE_ARN, sagemaker_session=session)
+    helper().delete_schedule.assert_called_once_with(PIPELINE_ARN)
+
+
+@patch(
+    "sagemaker.mlops.feature_store.feature_processor.feature_scheduler.EventBridgeSchedulerHelper",
+    return_value=mock_event_bridge_scheduler_helper(),
+)
+def test_delete_schedule_not_exist(helper):
+    helper.delete_schedule.side_effect = ClientError(
+        error_response={"Error": {"Code": "ResourceNotFoundException"}},
+        operation_name="update_schedule",
+    )
+    session = Mock(Session, sagemaker_client=Mock(), boto_session=Mock())
+    session.sagemaker_client.delete_pipeline = Mock()
+    delete_schedule(pipeline_name=PIPELINE_ARN, sagemaker_session=session)
+    helper().delete_schedule.assert_called_once_with(PIPELINE_ARN)
+
+
+@patch(
+    "sagemaker.mlops.feature_store.feature_processor.feature_scheduler._validate_pipeline_lineage_resources",
+    return_value=None,
+)
+def test_execute(validation):
+    session = Mock(Session, sagemaker_client=Mock(), boto_session=Mock())
+    session.sagemaker_client.describe_pipeline = Mock(
+        return_value={"PipelineArn": "my:arn", "CreationTime": NOW}
+    )
+    session.sagemaker_client.start_pipeline_execution = Mock(
+        return_value={"PipelineExecutionArn": "my:arn"}
+    )
+    execution_arn = execute(
+        pipeline_name="some_pipeline", execution_time=NOW, sagemaker_session=session
+    )
+    assert execution_arn == "my:arn"
+
+
+def test_validate_fg_lineage_resources_happy_case():
+    with patch.object(
+        SAGEMAKER_SESSION_MOCK, "describe_feature_group", return_value=FEATURE_GROUP
+    ) as fg_describe_method:
+        with patch.object(
+            Context, "load", side_effect=[CONTEXT_MOCK_01, CONTEXT_MOCK_02, CONTEXT_MOCK_03]
+        ) as context_load:
+            type(CONTEXT_MOCK_01).context_arn = "context-arn"
+            type(CONTEXT_MOCK_02).context_arn = "context-arn-fep"
+            type(CONTEXT_MOCK_03).context_arn = "context-arn-fep-ver"
+            _validate_fg_lineage_resources(
+                feature_group_name="some_fg",
+                sagemaker_session=SAGEMAKER_SESSION_MOCK,
+            )
+            fg_describe_method.assert_called_once_with(feature_group_name="some_fg")
+            context_load.assert_has_calls(
+                [
+                    call(
+                        context_name=f'{"some_fg"}-{FEATURE_GROUP["CreationTime"].strftime("%s")}'
+                        f"-feature-group-pipeline",
+                        sagemaker_session=SAGEMAKER_SESSION_MOCK,
+                    ),
+                    call(
+                        context_name=f'{"some_fg"}-{FEATURE_GROUP["CreationTime"].strftime("%s")}'
+                        f"-feature-group-pipeline-version",
+                        sagemaker_session=SAGEMAKER_SESSION_MOCK,
+                    ),
+                ]
+            )
+            assert 3 == context_load.call_count
+
+
+def test_validate_fg_lineage_resources_rnf():
+    with patch.object(SAGEMAKER_SESSION_MOCK, "describe_feature_group", return_value=FEATURE_GROUP):
+        with patch.object(
+            Context,
+            "load",
+            side_effect=ClientError(
+                error_response={"Error": {"Code": "ResourceNotFound"}},
+                operation_name="describe_context",
+            ),
+        ):
+            feature_group_name = "some_fg"
+            feature_group_creation_time =
FEATURE_GROUP["CreationTime"].strftime("%s") + context_name = f"{feature_group_name}-{feature_group_creation_time}" + with pytest.raises( + ValueError, + match=f"Lineage resource {context_name} has not yet been created for feature group" + f" {feature_group_name} or has already been deleted. Please try again later.", + ): + _validate_fg_lineage_resources( + feature_group_name="some_fg", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + +def test_validate_pipeline_lineage_resources_happy_case(): + session = Mock(Session, sagemaker_client=Mock(), boto_session=Mock()) + session.sagemaker_client.return_value = Mock() + pipeline_name = "some_pipeline" + with patch.object( + session.sagemaker_client, "describe_pipeline", return_value=PIPELINE + ) as pipeline_describe_method: + with patch.object( + Context, "load", side_effect=[CONTEXT_MOCK_01, CONTEXT_MOCK_02] + ) as context_load: + type(CONTEXT_MOCK_01).context_arn = "context-arn" + type(CONTEXT_MOCK_01).properties = {"LastUpdateTime": NOW.strftime("%s")} + type(CONTEXT_MOCK_02).context_arn = "context-arn-fep" + _validate_pipeline_lineage_resources( + pipeline_name=pipeline_name, + sagemaker_session=session, + ) + pipeline_describe_method.assert_called_once_with(PipelineName=pipeline_name) + pipeline_creation_time = PIPELINE["CreationTime"].strftime("%s") + last_updated_time = NOW.strftime("%s") + context_load.assert_has_calls( + [ + call( + context_name=f"sm-fs-fe-{pipeline_name}-{pipeline_creation_time}-fep", + sagemaker_session=session, + ), + call( + context_name=f"sm-fs-fe-{pipeline_name}-{last_updated_time}-fep-ver", + sagemaker_session=session, + ), + ] + ) + assert 2 == context_load.call_count + + +def test_validate_pipeline_lineage_resources_rnf(): + session = Mock(Session, sagemaker_client=Mock(), boto_session=Mock()) + session.sagemaker_client.return_value = Mock() + pipeline_name = "some_pipeline" + with patch.object(session.sagemaker_client, "describe_pipeline", return_value=PIPELINE): + with patch.object( + Context, + "load", + side_effect=ClientError( + error_response={"Error": {"Code": "ResourceNotFound"}}, + operation_name="describe_context", + ), + ): + with pytest.raises( + ValueError, + match="Pipeline lineage resources have not been created yet or have" + " already been deleted. 
Please try again later.", + ): + _validate_pipeline_lineage_resources( + pipeline_name=pipeline_name, + sagemaker_session=session, + ) + + +@patch("sagemaker.core.remote_function.job.Session", return_value=mock_session()) +@patch("sagemaker.core.remote_function.job.get_execution_role", return_value=EXECUTION_ROLE_ARN) +def test_remote_decorator_fields_consistency(get_execution_role, session): + expected_remote_decorator_attributes = { + "sagemaker_session", + "environment_variables", + "image_uri", + "dependencies", + "pre_execution_commands", + "pre_execution_script", + "include_local_workdir", + "instance_type", + "instance_count", + "volume_size", + "max_runtime_in_seconds", + "max_retry_attempts", + "keep_alive_period_in_seconds", + "spark_config", + "job_conda_env", + "job_name_prefix", + "encrypt_inter_container_traffic", + "enable_network_isolation", + "role", + "s3_root_uri", + "s3_kms_key", + "volume_kms_key", + "vpc_config", + "tags", + "use_spot_instances", + "max_wait_time_in_seconds", + "custom_file_filter", + "disable_output_compression", + "use_torchrun", + "use_mpirun", + "nproc_per_node", + } + + job_settings = _JobSettings( + image_uri=IMAGE, + s3_root_uri=S3_URI, + role=EXECUTION_ROLE_ARN, + include_local_workdir=True, + instance_type="ml.m5.large", + encrypt_inter_container_traffic=True, + ) + actual_attributes = {attribute for attribute, _ in job_settings.__dict__.items()} + + assert expected_remote_decorator_attributes == actual_attributes + + +@patch( + "sagemaker.mlops.feature_store.feature_processor.lineage." + "_feature_processor_lineage.FeatureProcessorLineageHandler.create_trigger_lineage" +) +@patch( + "sagemaker.mlops.feature_store.feature_processor._event_bridge_rule_helper.EventBridgeRuleHelper.describe_rule", + return_value={"EventPattern": "test-pattern"}, +) +@patch( + "sagemaker.mlops.feature_store.feature_processor._event_bridge_rule_helper.EventBridgeRuleHelper.add_tags" +) +@patch( + "sagemaker.mlops.feature_store.feature_processor.feature_scheduler." 
+ "_get_tags_from_pipeline_to_propagate_to_lineage_resources", + return_value=TAGS, +) +@patch( + "sagemaker.mlops.feature_store.feature_processor._event_bridge_rule_helper.EventBridgeRuleHelper.put_target" +) +@patch( + "sagemaker.mlops.feature_store.feature_processor._event_bridge_rule_helper.EventBridgeRuleHelper.put_rule", + return_value="arn:aws:events:us-west-2:123456789012:rule/test-rule", +) +def test_put_trigger( + mock_put_rule, + mock_put_target, + mock_get_tags, + mock_add_tags, + mock_describe_rule, + mock_create_trigger_lineage, +): + session = Mock( + Session, + sagemaker_client=Mock( + describe_pipeline=Mock(return_value={"PipelineArn": "test-pipeline-arn"}) + ), + boto_session=Mock(), + ) + source_pipeline_events = [ + FeatureProcessorPipelineEvents( + pipeline_name="test-pipeline", + pipeline_execution_status=[FeatureProcessorPipelineExecutionStatus.SUCCEEDED], + ) + ] + put_trigger( + source_pipeline_events=source_pipeline_events, + target_pipeline="test-target-pipeline", + state="Enabled", + event_pattern="test-pattern", + role_arn=SCHEDULE_ROLE_ARN, + sagemaker_session=session, + ) + + mock_put_rule.assert_called_once_with( + source_pipeline_events=source_pipeline_events, + target_pipeline="test-target-pipeline", + state="Enabled", + event_pattern="test-pattern", + ) + mock_put_target.assert_called_once_with( + rule_name="test-rule", + target_pipeline="test-target-pipeline", + target_pipeline_parameters=None, + role_arn=SCHEDULE_ROLE_ARN, + ) + mock_add_tags.assert_called_once_with(rule_arn=EVENT_BRIDGE_RULE_ARN, tags=TAGS) + mock_create_trigger_lineage.assert_called_once_with( + pipeline_name="test-target-pipeline", + trigger_arn=EVENT_BRIDGE_RULE_ARN, + state="Enabled", + tags=TAGS, + event_pattern="test-pattern", + ) + + +@patch( + "sagemaker.mlops.feature_store.feature_processor._event_bridge_rule_helper.EventBridgeRuleHelper.enable_rule" +) +def test_enable_trigger(mock_enable_rule): + session = Mock(Session, sagemaker_client=Mock(), boto_session=Mock()) + enable_trigger(pipeline_name="test-pipeline", sagemaker_session=session) + mock_enable_rule.assert_called_once_with(rule_name="test-pipeline") + + +@patch( + "sagemaker.mlops.feature_store.feature_processor._event_bridge_rule_helper.EventBridgeRuleHelper.disable_rule" +) +def test_disable_trigger(mock_disable_rule): + session = Mock(Session, sagemaker_client=Mock(), boto_session=Mock()) + disable_trigger(pipeline_name="test-pipeline", sagemaker_session=session) + mock_disable_rule.assert_called_once_with(rule_name="test-pipeline") + + +@patch( + "sagemaker.mlops.feature_store.feature_processor._event_bridge_rule_helper.EventBridgeRuleHelper.list_targets_by_rule", + return_value=[{"Targets": [{"Id": "target_pipeline"}]}], +) +@patch( + "sagemaker.mlops.feature_store.feature_processor._event_bridge_rule_helper.EventBridgeRuleHelper.remove_targets" +) +@patch( + "sagemaker.mlops.feature_store.feature_processor._event_bridge_rule_helper.EventBridgeRuleHelper.delete_rule" +) +def test_delete_trigger(mock_delete_rule, mock_remove_targets, mock_list_targets_by_rule): + session = Mock(Session, sagemaker_client=Mock(), boto_session=Mock()) + delete_trigger(pipeline_name="test-pipeline", sagemaker_session=session) + mock_delete_rule.assert_called_once_with("test-pipeline") + mock_list_targets_by_rule.assert_called_once_with("test-pipeline") + mock_remove_targets.assert_called_once_with(rule_name="test-pipeline", ids=["target_pipeline"]) diff --git 
a/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_input_loader.py b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_input_loader.py
new file mode 100644
index 0000000000..24d20f96c5
--- /dev/null
+++ b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_input_loader.py
@@ -0,0 +1,319 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+from __future__ import absolute_import
+
+import pytest
+import test_data_helpers as tdh
+from mock import Mock, patch, call
+from pyspark.sql import SparkSession, DataFrame
+from sagemaker.mlops.feature_store.feature_processor._data_source import (
+    CSVDataSource,
+    FeatureGroupDataSource,
+    ParquetDataSource,
+    IcebergTableDataSource,
+)
+from sagemaker.mlops.feature_store.feature_processor._input_loader import (
+    SparkDataFrameInputLoader,
+)
+from sagemaker.mlops.feature_store.feature_processor._spark_factory import SparkSessionFactory
+from sagemaker.mlops.feature_store.feature_processor._env import EnvironmentHelper
+from sagemaker.core.helper.session_helper import Session
+
+
+@pytest.fixture
+def describe_fg_response():
+    return tdh.DESCRIBE_FEATURE_GROUP_RESPONSE.copy()
+
+
+@pytest.fixture
+def sagemaker_session(describe_fg_response):
+    return Mock(Session, describe_feature_group=Mock(return_value=describe_fg_response))
+
+
+@pytest.fixture
+def spark_session(mock_data_frame):
+    return Mock(
+        SparkSession,
+        read=Mock(
+            csv=Mock(return_value=Mock()),
+            parquet=Mock(return_value=mock_data_frame),
+            conf=Mock(set=Mock()),
+        ),
+        table=Mock(return_value=mock_data_frame),
+    )
+
+
+@pytest.fixture
+def environment_helper():
+    return Mock(
+        EnvironmentHelper,
+        get_job_scheduled_time=Mock(return_value="2023-05-05T15:22:57Z"),
+    )
+
+
+@pytest.fixture
+def mock_data_frame():
+    return Mock(DataFrame, filter=Mock())
+
+
+@pytest.fixture
+def spark_session_factory(spark_session):
+    factory = Mock(SparkSessionFactory)
+    factory.spark_session = spark_session
+    factory.get_spark_session_with_iceberg_config = Mock(return_value=spark_session)
+    return factory
+
+
+@pytest.fixture
+def fp_config():
+    return tdh.create_fp_config()
+
+
+@pytest.fixture
+def input_loader(spark_session_factory, sagemaker_session, environment_helper):
+    return SparkDataFrameInputLoader(
+        spark_session_factory,
+        environment_helper,
+        sagemaker_session,
+    )
+
+
+def test_load_from_s3_with_csv_object(input_loader: SparkDataFrameInputLoader, spark_session):
+    s3_data_source = CSVDataSource(
+        s3_uri="s3://bucket/prefix/file.csv",
+        csv_header=True,
+        csv_infer_schema=True,
+    )
+
+    input_loader.load_from_s3(s3_data_source)
+
+    spark_session.read.csv.assert_called_with(
+        "s3a://bucket/prefix/file.csv", header=True, inferSchema=True
+    )
+
+
+def test_load_from_s3_with_parquet_object(input_loader, spark_session):
+    s3_data_source = ParquetDataSource(s3_uri="s3://bucket/prefix/file.parquet")
+
input_loader.load_from_s3(s3_data_source) + + spark_session.read.parquet.assert_called_with("s3a://bucket/prefix/file.parquet") + + +@pytest.mark.parametrize( + "condition", + [(None), ("condition")], +) +@patch( + "sagemaker.mlops.feature_store.feature_processor._input_loader." + "SparkDataFrameInputLoader._get_iceberg_offset_filter_condition" +) +def test_load_from_iceberg_table( + mock_get_filter_condition, + condition, + input_loader, + spark_session, + spark_session_factory, + mock_data_frame, +): + iceberg_table_data_source = IcebergTableDataSource( + warehouse_s3_uri="s3://bucket/prefix/", + catalog="Catalog", + database="Database", + table="Table", + ) + mock_get_filter_condition.return_value = condition + + input_loader.load_from_iceberg_table(iceberg_table_data_source, "event_time", "start", "end") + spark_session_factory.get_spark_session_with_iceberg_config.assert_called_with( + "s3://bucket/prefix/", "catalog" + ) + spark_session.table.assert_called_with("catalog.database.table") + mock_get_filter_condition.assert_called_with("event_time", "start", "end") + + if condition: + mock_data_frame.filter.assert_called_with(condition) + else: + mock_data_frame.filter.assert_not_called() + + +@patch( + "sagemaker.mlops.feature_store.feature_processor._input_loader.SparkDataFrameInputLoader.load_from_date_partitioned_s3" +) +def test_load_from_feature_group_with_arn( + mock_load_from_date_partitioned_s3, sagemaker_session, input_loader +): + fg_arn = tdh.INPUT_FEATURE_GROUP_ARN + fg_name = tdh.INPUT_FEATURE_GROUP_NAME + fg_data_source = FeatureGroupDataSource( + name=fg_arn, input_start_offset="start", input_end_offset="end" + ) + + input_loader.load_from_feature_group(fg_data_source) + + sagemaker_session.describe_feature_group.assert_called_with(fg_name) + mock_load_from_date_partitioned_s3.assert_called_with( + ParquetDataSource(tdh.INPUT_FEATURE_GROUP_RESOLVED_OUTPUT_S3_URI), + "start", + "end", + ) + + +def test_load_from_feature_group_offline_store_not_enabled(input_loader, describe_fg_response): + fg_name = tdh.INPUT_FEATURE_GROUP_NAME + fg_data_source = FeatureGroupDataSource(name=fg_name) + with pytest.raises( + ValueError, + match=( + f"Input Feature Groups must have an enabled Offline Store." + f" Feature Group: {fg_name} does not have an Offline Store enabled." 
+ ), + ): + del describe_fg_response["OfflineStoreConfig"] + input_loader.load_from_feature_group(fg_data_source) + + +def test_load_from_feature_group_with_default_table_format( + sagemaker_session, input_loader, spark_session +): + fg_name = tdh.INPUT_FEATURE_GROUP_NAME + fg_data_source = FeatureGroupDataSource(name=fg_name) + input_loader.load_from_feature_group(fg_data_source) + + sagemaker_session.describe_feature_group.assert_called_with(fg_name) + spark_session.read.parquet.assert_called_with( + tdh.INPUT_FEATURE_GROUP_RESOLVED_OUTPUT_S3_URI.replace("s3:", "s3a:") + ) + + +def test_load_from_feature_group_with_iceberg_table_format( + describe_fg_response, spark_session_factory, spark_session, environment_helper +): + describe_iceberg_fg_response = describe_fg_response.copy() + describe_iceberg_fg_response["OfflineStoreConfig"]["TableFormat"] = "Iceberg" + mocked_session = Mock( + Session, describe_feature_group=Mock(return_value=describe_iceberg_fg_response) + ) + mock_input_loader = SparkDataFrameInputLoader( + spark_session_factory, environment_helper, mocked_session + ) + + fg_name = tdh.INPUT_FEATURE_GROUP_NAME + fg_data_source = FeatureGroupDataSource(name=fg_name) + mock_input_loader.load_from_feature_group(fg_data_source) + + mocked_session.describe_feature_group.assert_called_with(fg_name) + spark_session.table.assert_called_with( + "awsdatacatalog.sagemaker_featurestore.input_fg_1680142547" + ) + + +@pytest.mark.parametrize( + "param", + [ + (None, None, None), + ("start", None, "event_time >= 'start_time'"), + (None, "end", "event_time < 'end_time'"), + ("start", "end", "event_time >= 'start_time' AND event_time < 'end_time'"), + ], +) +@patch( + "sagemaker.mlops.feature_store.feature_processor._input_offset_parser.InputOffsetParser.get_iso_format_offset_date", + side_effect=[ + "start_time", + "end_time", + ], +) +def test_get_iceberg_offset_filter_condition(mock_get_iso_date, param, input_loader): + start_offset, end_offset, expected_condition_str = param + + condition = input_loader._get_iceberg_offset_filter_condition( + "event_time", start_offset, end_offset + ) + + if start_offset or end_offset: + mock_get_iso_date.assert_has_calls([call(start_offset), call(end_offset)]) + else: + mock_get_iso_date.assert_not_called() + + assert condition == expected_condition_str + + +@pytest.mark.parametrize( + "param", + [ + (None, None, None), + ( + "start", + None, + "(year >= 'year_start') AND NOT ((year = 'year_start' AND month < 'month_start') OR " + "(year = 'year_start' AND month = 'month_start' AND day < 'day_start') OR " + "(year = 'year_start' AND month = 'month_start' AND day = 'day_start' AND hour < 'hour_start'))", + ), + ( + None, + "end", + "(year <= 'year_end') AND NOT ((year = 'year_end' AND month > 'month_end') OR " + "(year = 'year_end' AND month = 'month_end' AND day > 'day_end') OR (year = 'year_end' " + "AND month = 'month_end' AND day = 'day_end' AND hour >= 'hour_end'))", + ), + ( + "start", + "end", + "(year >= 'year_start' AND year <= 'year_end') AND NOT ((year = 'year_start' AND " + "month < 'month_start') OR (year = 'year_start' AND month = 'month_start' AND day < 'day_start') OR " + "(year = 'year_start' AND month = 'month_start' AND day = 'day_start' AND hour < 'hour_start') OR " + "(year = 'year_end' AND month > 'month_end') OR " + "(year = 'year_end' AND month = 'month_end' AND day > 'day_end') OR " + "(year = 'year_end' AND month = 'month_end' AND day = 'day_end' AND hour >= 'hour_end'))", + ), + ], +) +@patch( + 
"sagemaker.mlops.feature_store.feature_processor._input_offset_parser.InputOffsetParser." + "get_offset_date_year_month_day_hour", + side_effect=[ + ("year_start", "month_start", "day_start", "hour_start"), + ("year_end", "month_end", "day_end", "hour_end"), + ], +) +def test_get_s3_partitions_offset_filter_condition(mock_get_ymdh, param, input_loader): + start_offset, end_offset, expected_condition_str = param + + condition = input_loader._get_s3_partitions_offset_filter_condition(start_offset, end_offset) + + if start_offset or end_offset: + mock_get_ymdh.assert_has_calls([call(start_offset), call(end_offset)]) + else: + mock_get_ymdh.assert_not_called() + + assert condition == expected_condition_str + + +@pytest.mark.parametrize( + "condition", + [(None), ("condition")], +) +def test_load_from_date_partitioned_s3(input_loader, spark_session, mock_data_frame, condition): + input_loader._get_s3_partitions_offset_filter_condition = Mock(return_value=condition) + + input_loader.load_from_date_partitioned_s3( + ParquetDataSource("s3://path/to/file"), "start", "end" + ) + df = spark_session.read.parquet + df.assert_called_with("s3a://path/to/file") + + if condition: + mock_data_frame.filter.assert_called_with(condition) + else: + mock_data_frame.filter.assert_not_called() diff --git a/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_input_offset_parser.py b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_input_offset_parser.py new file mode 100644 index 0000000000..b244a4c2da --- /dev/null +++ b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_input_offset_parser.py @@ -0,0 +1,143 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+from __future__ import absolute_import
+
+from sagemaker.mlops.feature_store.feature_processor._input_offset_parser import (
+    InputOffsetParser,
+)
+from sagemaker.mlops.feature_store.feature_processor._constants import (
+    EXECUTION_TIME_PIPELINE_PARAMETER_FORMAT,
+)
+from datetime import datetime
+from dateutil.relativedelta import relativedelta
+import pytest
+
+
+@pytest.fixture
+def input_offset_parser():
+    time_spec = dict(year=2023, month=5, day=10, hour=17, minute=30, second=20)
+    return InputOffsetParser(now=datetime(**time_spec))
+
+
+@pytest.mark.parametrize(
+    "param",
+    [
+        (None, None),
+        ("1 hour", "2023-05-10T16:30:20Z"),
+        ("1 day", "2023-05-09T17:30:20Z"),
+        ("1 month", "2023-04-10T17:30:20Z"),
+        ("1 year", "2022-05-10T17:30:20Z"),
+    ],
+)
+def test_get_iso_format_offset_date(param, input_offset_parser):
+    input_offset, expected_offset_date = param
+    output_offset_date = input_offset_parser.get_iso_format_offset_date(input_offset)
+
+    assert output_offset_date == expected_offset_date
+
+
+@pytest.mark.parametrize(
+    "param",
+    [
+        (None, None),
+        (
+            "1 hour",
+            datetime.strptime("2023-05-10T16:30:20Z", EXECUTION_TIME_PIPELINE_PARAMETER_FORMAT),
+        ),
+        (
+            "1 day",
+            datetime.strptime("2023-05-09T17:30:20Z", EXECUTION_TIME_PIPELINE_PARAMETER_FORMAT),
+        ),
+        (
+            "1 month",
+            datetime.strptime("2023-04-10T17:30:20Z", EXECUTION_TIME_PIPELINE_PARAMETER_FORMAT),
+        ),
+        (
+            "1 year",
+            datetime.strptime("2022-05-10T17:30:20Z", EXECUTION_TIME_PIPELINE_PARAMETER_FORMAT),
+        ),
+    ],
+)
+def test_get_offset_datetime(param, input_offset_parser):
+    input_offset, expected_offset_datetime = param
+    output_offset_datetime = input_offset_parser.get_offset_datetime(input_offset)
+
+    assert output_offset_datetime == expected_offset_datetime
+
+
+@pytest.mark.parametrize(
+    "param",
+    [
+        (None, (None, None, None, None)),
+        ("1 hour", ("2023", "05", "10", "16")),
+        ("1 day", ("2023", "05", "09", "17")),
+        ("1 month", ("2023", "04", "10", "17")),
+        ("1 year", ("2022", "05", "10", "17")),
+    ],
+)
+def test_get_offset_date_year_month_day_hour(param, input_offset_parser):
+    input_offset, expected_date_tuple = param
+    output_date_tuple = input_offset_parser.get_offset_date_year_month_day_hour(input_offset)
+
+    assert output_date_tuple == expected_date_tuple
+
+
+@pytest.mark.parametrize(
+    "param",
+    [
+        (None, None),
+        ("1 hour", relativedelta(hours=-1)),
+        ("20 hours", relativedelta(hours=-20)),
+        ("1 day", relativedelta(days=-1)),
+        ("20 days", relativedelta(days=-20)),
+        ("1 month", relativedelta(months=-1)),
+        ("20 months", relativedelta(months=-20)),
+        ("1 year", relativedelta(years=-1)),
+        ("20 years", relativedelta(years=-20)),
+    ],
+)
+def test_parse_offset_to_timedelta(param, input_offset_parser):
+    input_offset, expected_timedelta = param
+    output_timedelta = input_offset_parser.parse_offset_to_timedelta(input_offset)
+
+    assert output_timedelta == expected_timedelta
+
+
+@pytest.mark.parametrize(
+    "param",
+    [
+        (
+            "random invalid string",
+            "[random invalid string] is not in a valid offset format. Please pass a valid offset e.g '1 day'.",
+        ),
+        (
+            "1 invalid string",
+            "[1 invalid string] is not in a valid offset format. Please pass a valid offset e.g '1 day'.",
+        ),
+        (
+            "2 days invalid string",
+            "[2 days invalid string] is not in a valid offset format. Please pass a valid offset e.g '1 day'.",
+        ),
+        (
+            "1 second",
+            "[second] is not a valid offset unit.
Supported units: ['hour', 'day', 'week', 'month', 'year']",
+        ),
+    ],
+)
+def test_parse_offset_to_timedelta_negative(param, input_offset_parser):
+    input_offset, expected_error_message = param
+
+    with pytest.raises(ValueError) as e:
+        input_offset_parser.parse_offset_to_timedelta(input_offset)
+
+    assert str(e.value) == expected_error_message
diff --git a/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_params_loader.py b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_params_loader.py
new file mode 100644
index 0000000000..1ac2042a55
--- /dev/null
+++ b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_params_loader.py
@@ -0,0 +1,85 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+from __future__ import absolute_import
+
+
+import pytest
+import test_data_helpers as tdh
+from mock import Mock
+
+from sagemaker.mlops.feature_store.feature_processor._env import EnvironmentHelper
+from sagemaker.mlops.feature_store.feature_processor._params_loader import (
+    ParamsLoader,
+    SystemParamsLoader,
+)
+
+
+@pytest.fixture
+def system_params_loader_mock():
+    system_params_loader = Mock(SystemParamsLoader)
+    system_params_loader.get_system_args.return_value = tdh.SYSTEM_PARAMS
+    return system_params_loader
+
+
+@pytest.fixture
+def environment_checker():
+    environment_checker = Mock(EnvironmentHelper)
+    environment_checker.is_training_job.return_value = False
+    environment_checker.get_job_scheduled_time = Mock(return_value="2023-05-05T15:22:57Z")
+    return environment_checker
+
+
+@pytest.fixture
+def params_loader(system_params_loader_mock):
+    return ParamsLoader(system_params_loader_mock)
+
+
+@pytest.fixture
+def system_params_loader(environment_checker):
+    return SystemParamsLoader(environment_checker)
+
+
+def test_get_parameter_args(params_loader, system_params_loader_mock):
+    fp_config = tdh.create_fp_config(
+        inputs=[tdh.FEATURE_GROUP_DATA_SOURCE],
+        output=tdh.OUTPUT_FEATURE_GROUP_ARN,
+        parameters=tdh.USER_INPUT_PARAMS,
+    )
+
+    params = params_loader.get_parameter_args(fp_config)
+
+    system_params_loader_mock.get_system_args.assert_called_once()
+    assert params == {"params": {**tdh.USER_INPUT_PARAMS, **tdh.SYSTEM_PARAMS}}
+
+
+def test_get_parameter_args_with_no_user_params(params_loader, system_params_loader_mock):
+    fp_config = tdh.create_fp_config(
+        inputs=[tdh.FEATURE_GROUP_DATA_SOURCE, tdh.S3_DATA_SOURCE],
+        output=tdh.OUTPUT_FEATURE_GROUP_ARN,
+        parameters=None,
+    )
+
+    params = params_loader.get_parameter_args(fp_config)
+
+    system_params_loader_mock.get_system_args.assert_called_once()
+    assert params == {"params": {**tdh.SYSTEM_PARAMS}}
+
+
+def test_get_system_arg_from_pipeline_execution(system_params_loader):
+    system_params = system_params_loader.get_system_args()
+
+    assert system_params == {
+        "system": {
+            "scheduled_time": "2023-05-05T15:22:57Z",
+        }
+    }
diff --git
a/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_spark_session_factory.py b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_spark_session_factory.py
new file mode 100644
index 0000000000..36ba54edd1
--- /dev/null
+++ b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_spark_session_factory.py
@@ -0,0 +1,173 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+from __future__ import absolute_import
+
+import feature_store_pyspark
+import pytest
+from mock import Mock, patch, call
+
+from sagemaker.mlops.feature_store.feature_processor._spark_factory import (
+    FeatureStoreManagerFactory,
+    SparkSessionFactory,
+)
+
+
+@pytest.fixture
+def env_helper():
+    return Mock(
+        is_training_job=Mock(return_value=False),
+        load_training_resource_config=Mock(return_value=None),
+    )
+
+
+def test_spark_session_factory_configuration():
+    env_helper = Mock()
+    spark_config = {"spark.test.key": "spark.test.value"}
+    spark_session_factory = SparkSessionFactory(env_helper, spark_config)
+    spark_configs = dict(spark_session_factory._get_spark_configs(is_training_job=False))
+    jsc_hadoop_configs = dict(spark_session_factory._get_jsc_hadoop_configs())
+
+    # General optimizations
+    assert spark_configs.get("spark.hadoop.fs.s3a.aws.credentials.provider") == ",".join(
+        [
+            "com.amazonaws.auth.ContainerCredentialsProvider",
+            "com.amazonaws.auth.profile.ProfileCredentialsProvider",
+            "com.amazonaws.auth.DefaultAWSCredentialsProviderChain",
+        ]
+    )
+
+    assert spark_configs.get("spark.hadoop.mapreduce.fileoutputcommitter.algorithm.version") == "2"
+    assert (
+        spark_configs.get("spark.hadoop.mapreduce.fileoutputcommitter.cleanup-failures.ignored")
+        == "true"
+    )
+    assert spark_configs.get("spark.hadoop.parquet.enable.summary-metadata") == "false"
+
+    assert spark_configs.get("spark.sql.parquet.mergeSchema") == "false"
+    assert spark_configs.get("spark.sql.parquet.filterPushdown") == "true"
+    assert spark_configs.get("spark.sql.hive.metastorePartitionPruning") == "true"
+
+    assert spark_configs.get("spark.hadoop.fs.s3a.threads.max") == "500"
+    assert spark_configs.get("spark.hadoop.fs.s3a.connection.maximum") == "500"
+    assert spark_configs.get("spark.hadoop.fs.s3a.experimental.input.fadvise") == "normal"
+    assert spark_configs.get("spark.hadoop.fs.s3a.block.size") == "128M"
+    assert spark_configs.get("spark.hadoop.fs.s3a.fast.upload.buffer") == "disk"
+    assert spark_configs.get("spark.hadoop.fs.trash.interval") == "0"
+    assert spark_configs.get("spark.port.maxRetries") == "50"
+
+    assert spark_configs.get("spark.test.key") == "spark.test.value"
+
+    assert jsc_hadoop_configs.get("mapreduce.fileoutputcommitter.marksuccessfuljobs") == "false"
+
+    # Verify configurations when not running on a training job
+    assert ",".join(feature_store_pyspark.classpath_jars()) in spark_configs.get("spark.jars")
+    assert ",".join(
+        [
"org.apache.hadoop:hadoop-aws:3.3.1", + "org.apache.hadoop:hadoop-common:3.3.1", + ] + ) in spark_configs.get("spark.jars.packages") + + +def test_spark_session_factory_configuration_on_training_job(): + env_helper = Mock() + spark_config = {"spark.test.key": "spark.test.value"} + spark_session_factory = SparkSessionFactory(env_helper, spark_config) + + spark_config = spark_session_factory._get_spark_configs(is_training_job=True) + assert dict(spark_config).get("spark.test.key") == "spark.test.value" + + assert all(tup[0] != "spark.jars" for tup in spark_config) + assert all(tup[0] != "spark.jars.packages" for tup in spark_config) + + +@patch("pyspark.context.SparkContext.getOrCreate") +def test_spark_session_factory(mock_spark_context): + env_helper = Mock() + env_helper.get_instance_count.return_value = 1 + spark_session_factory = SparkSessionFactory(env_helper) + + spark_session_factory.spark_session + + _, _, kw_args = mock_spark_context.mock_calls[0] + spark_conf = kw_args["conf"] + + mock_spark_context.assert_called_once() + assert spark_conf.get("spark.master") == "local[*]" + for cfg in spark_session_factory._get_spark_configs(True): + assert spark_conf.get(cfg[0]) == cfg[1] + + +@patch("pyspark.context.SparkContext.getOrCreate") +def test_spark_session_factory_with_iceberg_config(mock_spark_context): + mock_env_helper = Mock() + mock_spark_context.side_effect = [Mock(), Mock()] + + spark_session_factory = SparkSessionFactory(mock_env_helper) + + spark_session = spark_session_factory.spark_session + + spark_session_with_iceberg_config = spark_session_factory.get_spark_session_with_iceberg_config( + "warehouse", "catalog" + ) + + assert spark_session is spark_session_with_iceberg_config + mock_spark_conf = spark_session._jvm.SparkSession().conf() + expected_calls = [ + call.set(cfg[0], cfg[1]) + for cfg in spark_session_factory._get_iceberg_configs("warehouse", "catalog") + ] + + mock_spark_conf.assert_has_calls(expected_calls, any_order=False) + + +@patch("pyspark.context.SparkContext.getOrCreate") +def test_spark_session_factory_same_instance(mock_spark_context): + mock_env_helper = Mock() + mock_spark_context.side_effect = [Mock(), Mock()] + + spark_session_factory = SparkSessionFactory(mock_env_helper) + + a_reference = spark_session_factory.spark_session + another_reference = spark_session_factory.spark_session + + assert a_reference is another_reference + + +@patch("feature_store_pyspark.FeatureStoreManager.FeatureStoreManager") +def test_feature_store_manager_same_instance(mock_feature_store_manager): + mock_feature_store_manager.side_effect = [Mock(), Mock()] + + factory = FeatureStoreManagerFactory() + + assert factory.feature_store_manager is factory.feature_store_manager + + +def test_spark_session_factory_get_spark_session_with_iceberg_config(env_helper): + spark_session_factory = SparkSessionFactory(env_helper) + iceberg_configs = dict(spark_session_factory._get_iceberg_configs("s3://test/path", "Catalog")) + + assert ( + iceberg_configs.get("spark.sql.catalog.catalog") + == "smfs.shaded.org.apache.iceberg.spark.SparkCatalog" + ) + assert iceberg_configs.get("spark.sql.catalog.catalog.warehouse") == "s3://test/path" + assert ( + iceberg_configs.get("spark.sql.catalog.catalog.catalog-impl") + == "smfs.shaded.org.apache.iceberg.aws.glue.GlueCatalog" + ) + assert ( + iceberg_configs.get("spark.sql.catalog.catalog.io-impl") + == "smfs.shaded.org.apache.iceberg.aws.s3.S3FileIO" + ) + assert iceberg_configs.get("spark.sql.catalog.catalog.glue.skip-name-validation") == "true" 
diff --git a/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_udf_arg_provider.py b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_udf_arg_provider.py
new file mode 100644
index 0000000000..561cc09e76
--- /dev/null
+++ b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_udf_arg_provider.py
@@ -0,0 +1,280 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+from __future__ import absolute_import
+
+import pytest
+import test_data_helpers as tdh
+from mock import Mock, patch
+from pyspark.sql import DataFrame, SparkSession
+
+from sagemaker.mlops.feature_store.feature_processor._input_loader import InputLoader
+from sagemaker.mlops.feature_store.feature_processor._params_loader import ParamsLoader
+from sagemaker.mlops.feature_store.feature_processor._spark_factory import SparkSessionFactory
+from sagemaker.mlops.feature_store.feature_processor._udf_arg_provider import SparkArgProvider
+from sagemaker.mlops.feature_store.feature_processor._data_source import PySparkDataSource
+
+
+@pytest.fixture
+def params_loader():
+    params_loader = Mock(ParamsLoader)
+    params_loader.get_parameter_args = Mock(return_value={"params": {"key": "value"}})
+    return params_loader
+
+
+@pytest.fixture
+def feature_group_as_spark_df():
+    return Mock(DataFrame)
+
+
+@pytest.fixture
+def s3_uri_as_spark_df():
+    return Mock(DataFrame)
+
+
+@pytest.fixture
+def base_data_source_as_spark_df():
+    return Mock(DataFrame)
+
+
+@pytest.fixture
+def input_loader(feature_group_as_spark_df, s3_uri_as_spark_df):
+    input_loader = Mock(InputLoader)
+    input_loader.load_from_s3.return_value = s3_uri_as_spark_df
+    input_loader.load_from_feature_group.return_value = feature_group_as_spark_df
+
+    return input_loader
+
+
+@pytest.fixture
+def spark_session():
+    return Mock(SparkSession)
+
+
+@pytest.fixture
+def spark_session_factory(spark_session):
+    return Mock(SparkSessionFactory, spark_session=spark_session)
+
+
+@pytest.fixture
+def spark_arg_provider(params_loader, input_loader, spark_session_factory):
+    return SparkArgProvider(params_loader, input_loader, spark_session_factory)
+
+
+class MockDataSource(PySparkDataSource):
+
+    data_source_unique_id = "test_id"
+    data_source_name = "test_source"
+
+    def read_data(self, spark, params) -> DataFrame:
+        return Mock(DataFrame)
+
+
+def test_provide_additional_kw_args(spark_arg_provider, spark_session):
+    def udf(fg_input, s3_input, params, spark):
+        return None
+
+    additional_kw_args = spark_arg_provider.provide_additional_kwargs(udf)
+
+    assert additional_kw_args.keys() == {"spark"}
+    assert additional_kw_args["spark"] == spark_session
+
+
+def test_not_provide_additional_kw_args(spark_arg_provider):
+    def udf(input, params):
+        return None
+
+    additional_kw_args = spark_arg_provider.provide_additional_kwargs(udf)
+
+    assert additional_kw_args == {}
+
+
+def test_provide_params(spark_arg_provider, params_loader):
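+    # The UDF below declares a 'params' parameter, so the provider is expected
+    # to delegate to ParamsLoader and pass its result through unchanged.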
fp_config = tdh.create_fp_config( + inputs=[tdh.FEATURE_GROUP_DATA_SOURCE, tdh.S3_DATA_SOURCE], + output=tdh.OUTPUT_FEATURE_GROUP_ARN, + ) + + def udf(fg_input, s3_input, params, spark): + return None + + params = spark_arg_provider.provide_params_arg(udf, fp_config) + + params_loader.get_parameter_args.assert_called_with(fp_config) + assert params == params_loader.get_parameter_args.return_value + + +def test_not_provide_params(spark_arg_provider, params_loader): + fp_config = tdh.create_fp_config( + inputs=[tdh.FEATURE_GROUP_DATA_SOURCE, tdh.S3_DATA_SOURCE], + output=tdh.OUTPUT_FEATURE_GROUP_ARN, + ) + + def udf(fg_input, s3_input, spark): + return None + + params = spark_arg_provider.provide_params_arg(udf, fp_config) + + assert params == {} + + +def test_provide_input_args_with_no_input(spark_arg_provider): + fp_config = tdh.create_fp_config(inputs=[], output=tdh.OUTPUT_FEATURE_GROUP_ARN) + + def udf() -> DataFrame: + return Mock(DataFrame) + + with pytest.raises( + ValueError, match="Expected at least one input to the user defined function." + ): + spark_arg_provider.provide_input_args(udf, fp_config) + + +def test_provide_input_args_with_extra_udf_parameters(spark_arg_provider): + fp_config = tdh.create_fp_config( + inputs=[tdh.INPUT_S3_URI], + output=tdh.OUTPUT_FEATURE_GROUP_ARN, + ) + + def udf(input_fg=None, input_s3_uri=None) -> DataFrame: + return Mock(DataFrame) + + with pytest.raises( + ValueError, + match=r"The signature of the user defined function does not match the list of inputs requested." + r" Expected 1 parameter\(s\).", + ): + spark_arg_provider.provide_input_args(udf, fp_config) + + +def test_provide_input_args_with_extra_fp_config_inputs(spark_arg_provider): + fp_config = tdh.create_fp_config( + inputs=[tdh.FEATURE_GROUP_DATA_SOURCE, tdh.S3_DATA_SOURCE], + output=tdh.OUTPUT_FEATURE_GROUP_ARN, + ) + + def udf(input_fg=None) -> DataFrame: + return Mock(DataFrame) + + with pytest.raises( + ValueError, + match=r"The signature of the user defined function does not match the list of inputs requested." 
+ r" Expected 2 parameter\(s\).", + ): + spark_arg_provider.provide_input_args(udf, fp_config) + + +def test_provide_input_args( + spark_arg_provider, + feature_group_as_spark_df, + s3_uri_as_spark_df, +): + fp_config = tdh.create_fp_config( + inputs=[tdh.FEATURE_GROUP_DATA_SOURCE, tdh.S3_DATA_SOURCE], + output=tdh.OUTPUT_FEATURE_GROUP_ARN, + ) + + def udf(input_fg=None, input_s3_uri=None) -> DataFrame: + return Mock(DataFrame) + + inputs = spark_arg_provider.provide_input_args(udf, fp_config) + + assert inputs.keys() == {"input_fg", "input_s3_uri"} + assert inputs["input_fg"] == feature_group_as_spark_df + assert inputs["input_s3_uri"] == s3_uri_as_spark_df + + +def test_provide_input_args_with_reversed_inputs( + spark_arg_provider, + feature_group_as_spark_df, + s3_uri_as_spark_df, +): + fp_config = tdh.create_fp_config( + inputs=[tdh.S3_DATA_SOURCE, tdh.FEATURE_GROUP_DATA_SOURCE], + output=tdh.OUTPUT_FEATURE_GROUP_ARN, + ) + + def udf(input_fg=None, input_s3_uri=None) -> DataFrame: + return Mock(DataFrame) + + inputs = spark_arg_provider.provide_input_args(udf, fp_config) + + assert inputs.keys() == {"input_fg", "input_s3_uri"} + assert inputs["input_fg"] == s3_uri_as_spark_df + assert inputs["input_s3_uri"] == feature_group_as_spark_df + + +def test_provide_input_args_with_optional_args_out_of_order(spark_arg_provider): + fp_config = tdh.create_fp_config( + inputs=[tdh.FEATURE_GROUP_DATA_SOURCE, tdh.S3_DATA_SOURCE], + output=tdh.OUTPUT_FEATURE_GROUP_ARN, + ) + + def udf_spark_params(spark=None, params=None, input_fg=None, input_s3_uri=None) -> DataFrame: + return Mock(DataFrame) + + def udf_params_spark(params=None, spark=None, input_fg=None, input_s3_uri=None) -> DataFrame: + return Mock(DataFrame) + + def udf_spark(spark=None, input_fg=None, input_s3_uri=None) -> DataFrame: + return Mock(DataFrame) + + def udf_params(params=None, input_fg=None, input_s3_uri=None) -> DataFrame: + return Mock(DataFrame) + + for udf in [udf_spark_params, udf_params_spark, udf_spark, udf_params]: + with pytest.raises( + ValueError, + match="Expected at least one input to the user defined function.", + ): + spark_arg_provider.provide_input_args(udf, fp_config) + + +def test_provide_input_args_with_optional_args( + spark_arg_provider, feature_group_as_spark_df, s3_uri_as_spark_df +): + fp_config = tdh.create_fp_config( + inputs=[tdh.FEATURE_GROUP_DATA_SOURCE, tdh.S3_DATA_SOURCE], + output=tdh.OUTPUT_FEATURE_GROUP_ARN, + ) + + def udf_all_optional(input_fg=None, input_s3_uri=None, params=None, spark=None) -> DataFrame: + return Mock(DataFrame) + + def udf_no_optional(input_fg=None, input_s3_uri=None) -> DataFrame: + return Mock(DataFrame) + + def udf_only_params(input_fg=None, input_s3_uri=None, params=None) -> DataFrame: + return Mock(DataFrame) + + def udf_only_spark(input_fg=None, input_s3_uri=None, spark=None) -> DataFrame: + return Mock(DataFrame) + + for udf in [udf_all_optional, udf_no_optional, udf_only_params, udf_only_spark]: + inputs = spark_arg_provider.provide_input_args(udf, fp_config) + + assert inputs.keys() == {"input_fg", "input_s3_uri"} + assert inputs["input_fg"] == feature_group_as_spark_df + assert inputs["input_s3_uri"] == s3_uri_as_spark_df + + +def test_provide_input_arg_for_base_data_source(spark_arg_provider, params_loader, spark_session): + fp_config = tdh.create_fp_config(inputs=[MockDataSource()], output=tdh.OUTPUT_FEATURE_GROUP_ARN) + + def udf(input_df) -> DataFrame: + return input_df + + with patch.object(MockDataSource, "read_data", return_value=Mock(DataFrame)) as 
mock_read:
+        spark_arg_provider.provide_input_args(udf, fp_config)
+        mock_read.assert_called_with(spark=spark_session, params={"key": "value"})
+        params_loader.get_parameter_args.assert_called_with(fp_config)
diff --git a/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_udf_output_receiver.py b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_udf_output_receiver.py
new file mode 100644
index 0000000000..34770a011a
--- /dev/null
+++ b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_udf_output_receiver.py
@@ -0,0 +1,106 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+from __future__ import absolute_import
+
+import pytest
+import test_data_helpers as tdh
+from feature_store_pyspark.FeatureStoreManager import FeatureStoreManager
+from mock import Mock
+from py4j.protocol import Py4JJavaError
+from pyspark.sql import DataFrame
+
+from sagemaker.mlops.feature_store.feature_processor import IngestionError
+from sagemaker.mlops.feature_store.feature_processor._spark_factory import (
+    FeatureStoreManagerFactory,
+)
+from sagemaker.mlops.feature_store.feature_processor._udf_output_receiver import (
+    SparkOutputReceiver,
+)
+
+
+@pytest.fixture
+def df() -> Mock:
+    return Mock(DataFrame)
+
+
+@pytest.fixture
+def feature_store_manager():
+    return Mock(FeatureStoreManager)
+
+
+@pytest.fixture
+def feature_store_manager_factory(feature_store_manager):
+    return Mock(FeatureStoreManagerFactory, feature_store_manager=feature_store_manager)
+
+
+@pytest.fixture
+def spark_output_receiver(feature_store_manager_factory):
+    return SparkOutputReceiver(feature_store_manager_factory)
+
+
+def test_ingest_udf_output_enable_ingestion_false(df, feature_store_manager, spark_output_receiver):
+    fp_config = tdh.create_fp_config(enable_ingestion=False)
+    spark_output_receiver.ingest_udf_output(df, fp_config)
+
+    feature_store_manager.ingest_data.assert_not_called()
+
+
+def test_ingest_udf_output(df, feature_store_manager, spark_output_receiver):
+    fp_config = tdh.create_fp_config()
+    spark_output_receiver.ingest_udf_output(df, fp_config)
+
+    feature_store_manager.ingest_data.assert_called_with(
+        input_data_frame=df,
+        feature_group_arn=fp_config.output,
+        target_stores=fp_config.target_stores,
+    )
+
+
+def test_ingest_udf_output_failed_records(df, feature_store_manager, spark_output_receiver):
+    fp_config = tdh.create_fp_config()
+
+    # Simulate streaming ingestion failure.
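+    # A Py4JJavaError is mocked so that its Java class reports the name
+    # "StreamIngestionFailureException"; the receiver should then fetch the
+    # failed-records DataFrame, show it, and re-raise as IngestionError.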
+    mock_failed_records_df = Mock()
+    mock_java_exception = Mock(_target_id="")
+    mock_java_exception.getClass = Mock(
+        return_value=Mock(getSimpleName=Mock(return_value="StreamIngestionFailureException"))
+    )
+
+    feature_store_manager.ingest_data.side_effect = Py4JJavaError(
+        msg="", java_exception=mock_java_exception
+    )
+    feature_store_manager.get_failed_stream_ingestion_data_frame.return_value = (
+        mock_failed_records_df
+    )
+
+    with pytest.raises(IngestionError):
+        spark_output_receiver.ingest_udf_output(df, fp_config)
+
+    mock_failed_records_df.show.assert_called_with(n=20, truncate=False)
+
+
+def test_ingest_udf_output_all_py4j_error_raised(df, feature_store_manager, spark_output_receiver):
+    fp_config = tdh.create_fp_config()
+
+    # Simulate ingestion failure.
+    mock_java_exception = Mock(_target_id="")
+    mock_java_exception.getClass = Mock(
+        return_value=Mock(getSimpleName=Mock(return_value="ValidationError"))
+    )
+    feature_store_manager.ingest_data.side_effect = Py4JJavaError(
+        msg="", java_exception=mock_java_exception
+    )
+
+    with pytest.raises(Py4JJavaError):
+        spark_output_receiver.ingest_udf_output(df, fp_config)
diff --git a/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_udf_wrapper.py b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_udf_wrapper.py
new file mode 100644
index 0000000000..51747d78f9
--- /dev/null
+++ b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_udf_wrapper.py
@@ -0,0 +1,85 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+from __future__ import absolute_import
+
+from typing import Callable
+
+import pytest
+from mock import Mock
+
+from sagemaker.mlops.feature_store.feature_processor._feature_processor_config import (
+    FeatureProcessorConfig,
+)
+from sagemaker.mlops.feature_store.feature_processor._udf_arg_provider import UDFArgProvider
+from sagemaker.mlops.feature_store.feature_processor._udf_output_receiver import (
+    UDFOutputReceiver,
+)
+from sagemaker.mlops.feature_store.feature_processor._udf_wrapper import UDFWrapper
+
+
+@pytest.fixture
+def udf_arg_provider():
+    udf_arg_provider = Mock(UDFArgProvider)
+    udf_arg_provider.provide_input_args.return_value = {"input": Mock()}
+    udf_arg_provider.provide_params_arg.return_value = {"params": Mock()}
+    udf_arg_provider.provide_additional_kwargs.return_value = {"kwarg": Mock()}
+
+    return udf_arg_provider
+
+
+@pytest.fixture
+def udf_output_receiver():
+    udf_output_receiver = Mock(UDFOutputReceiver)
+    udf_output_receiver.ingest_udf_output.return_value = Mock()
+    return udf_output_receiver
+
+
+@pytest.fixture
+def udf_output():
+    udf_output = Mock(Callable)
+    return udf_output
+
+
+@pytest.fixture
+def udf(udf_output):
+    udf = Mock(Callable)
+    udf.return_value = udf_output
+    return udf
+
+
+@pytest.fixture
+def fp_config():
+    fp_config = Mock(FeatureProcessorConfig)
+    return fp_config
+
+
+def test_wrap(fp_config, udf_output, udf_arg_provider, udf_output_receiver):
+    def test_udf(input, params, kwarg):
+        # Verify wrapped function is called with auto-loaded arguments.
+        assert input is udf_arg_provider.provide_input_args.return_value["input"]
+        assert params is udf_arg_provider.provide_params_arg.return_value["params"]
+        assert kwarg is udf_arg_provider.provide_additional_kwargs.return_value["kwarg"]
+        return udf_output
+
+    udf_wrapper = UDFWrapper(udf_arg_provider, udf_output_receiver)
+
+    # Execute decorator function and the decorated function.
+    wrapped_udf = udf_wrapper.wrap(test_udf, fp_config)
+    wrapped_udf()
+
+    # Verify interactions with dependencies.
+    udf_arg_provider.provide_input_args.assert_called_with(test_udf, fp_config)
+    udf_arg_provider.provide_params_arg.assert_called_with(test_udf, fp_config)
+    udf_arg_provider.provide_additional_kwargs.assert_called_with(test_udf)
+    udf_output_receiver.ingest_udf_output.assert_called_with(udf_output, fp_config)
diff --git a/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_validation.py b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_validation.py
new file mode 100644
index 0000000000..16a8784586
--- /dev/null
+++ b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_validation.py
@@ -0,0 +1,192 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
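+
+# The cases below exercise three validators: ValidatorChain (runs validators in
+# order, propagating the first failure), SparkUDFSignatureValidator (data source
+# inputs first, then optional 'params'/'spark' parameters), and
+# BaseDataSourceValidator (data_source_name / data_source_unique_id patterns).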
+from __future__ import absolute_import + +from typing import Callable +from pyspark.sql import DataFrame + +import pytest + +import test_data_helpers as tdh +from mock import Mock + +from sagemaker.mlops.feature_store.feature_processor._validation import ( + SparkUDFSignatureValidator, + Validator, + ValidatorChain, + BaseDataSourceValidator, +) +from sagemaker.mlops.feature_store.feature_processor._data_source import ( + BaseDataSource, +) + + +def test_validator_chain(): + fp_config = tdh.create_fp_config() + udf = Mock(Callable) + + first_validator = Mock(Validator) + second_validator = Mock(Validator) + validator_chain = ValidatorChain([first_validator, second_validator]) + + validator_chain.validate(udf, fp_config) + + first_validator.validate.assert_called_with(udf, fp_config) + second_validator.validate.assert_called_with(udf, fp_config) + + +def test_validator_chain_validation_fails(): + fp_config = tdh.create_fp_config() + udf = Mock(Callable) + + first_validator = Mock(validate=Mock(side_effect=ValueError())) + second_validator = Mock(validate=Mock()) + validator_chain = ValidatorChain([first_validator, second_validator]) + + with pytest.raises(ValueError): + validator_chain.validate(udf, fp_config) + + +def test_spark_udf_signature_validator_valid(): + # One Input + fp_config = tdh.create_fp_config(inputs=[tdh.FEATURE_GROUP_DATA_SOURCE]) + + def one_data_source(fg_data_source, params, spark): + return None + + SparkUDFSignatureValidator().validate(one_data_source, fp_config) + + # Two Inputs + fp_config = tdh.create_fp_config(inputs=[tdh.FEATURE_GROUP_DATA_SOURCE, tdh.S3_DATA_SOURCE]) + + def two_data_sources(fg_data_source, s3_data_source, params, spark): + return None + + SparkUDFSignatureValidator().validate(two_data_sources, fp_config) + + # No Optional Args (params and spark) + fp_config = tdh.create_fp_config(inputs=[tdh.FEATURE_GROUP_DATA_SOURCE, tdh.S3_DATA_SOURCE]) + + def no_optional_args(fg_data_source, s3_data_source): + return None + + SparkUDFSignatureValidator().validate(no_optional_args, fp_config) + + # Optional Args (no params) + fp_config = tdh.create_fp_config(inputs=[tdh.FEATURE_GROUP_DATA_SOURCE, tdh.S3_DATA_SOURCE]) + + def no_optional_params_arg(fg_data_source, s3_data_source, spark): + return None + + SparkUDFSignatureValidator().validate(no_optional_params_arg, fp_config) + + # No Optional Args (no spark) + fp_config = tdh.create_fp_config(inputs=[tdh.FEATURE_GROUP_DATA_SOURCE, tdh.S3_DATA_SOURCE]) + + def no_optional_spark_arg(fg_data_source, s3_data_source, params): + return None + + SparkUDFSignatureValidator().validate(no_optional_spark_arg, fp_config) + + +def test_spark_udf_signature_validator_udf_input_mismatch(): + fp_config = tdh.create_fp_config(inputs=[tdh.FEATURE_GROUP_DATA_SOURCE, tdh.S3_DATA_SOURCE]) + + def one_input(one, params, spark): + return None + + def three_inputs(one, two, three, params, spark): + return None + + exception_string = ( + r"feature_processor expected a function with \(2\) parameter\(s\) before any" + r" optional 'params' or 'spark' parameters for the \(2\) requested data source\(s\)\." 
+ ) + + with pytest.raises(ValueError, match=exception_string): + SparkUDFSignatureValidator().validate(one_input, fp_config) + + with pytest.raises(ValueError, match=exception_string): + SparkUDFSignatureValidator().validate(three_inputs, fp_config) + + +def test_spark_udf_signature_validator_zero_input_params(): + def zero_inputs(params, spark): + return None + + with pytest.raises(ValueError, match="feature_processor expects at least 1 input parameter."): + fp_config = tdh.create_fp_config(inputs=[tdh.FEATURE_GROUP_DATA_SOURCE, tdh.S3_DATA_SOURCE]) + SparkUDFSignatureValidator().validate(zero_inputs, fp_config) + + +def test_spark_udf_signature_validator_udf_invalid_non_input_position(): + fp_config = tdh.create_fp_config(inputs=[tdh.FEATURE_GROUP_DATA_SOURCE, tdh.S3_DATA_SOURCE]) + with pytest.raises( + ValueError, + match="feature_processor expected the 'params' parameter to be the last or second last" + " parameter after input parameters.", + ): + + def invalid_params_position(params, fg_data_source, s3_data_source): + return None + + SparkUDFSignatureValidator().validate(invalid_params_position, fp_config) + + with pytest.raises( + ValueError, + match="feature_processor expected the 'spark' parameter to be the last or second last" + " parameter after input parameters.", + ): + + def invalid_spark_position(spark, fg_data_source, s3_data_source): + return None + + SparkUDFSignatureValidator().validate(invalid_spark_position, fp_config) + + +@pytest.mark.parametrize( + "data_source_name, data_source_unique_id, error_pattern", + [ + ("$_invalid_source", "unique_id", "data_source_name of input does not match pattern '.*'."), + ("", "unique_id", "data_source_name of input does not match pattern '.*'."), + ( + "source", + tdh.DATA_SOURCE_UNIQUE_ID_TOO_LONG, + "data_source_unique_id of input does not match pattern '.*'.", + ), + ("source", "", "data_source_unique_id of input does not match pattern '.*'."), + ], +) +def test_spark_udf_signature_validator_udf_invalid_base_data_source( + data_source_name, data_source_unique_id, error_pattern +): + class TestInValidCustomDataSource(BaseDataSource): + + data_source_name = None + data_source_unique_id = None + + def read_data(self, spark, params) -> DataFrame: + return None + + test_data_source = TestInValidCustomDataSource() + test_data_source.data_source_name = data_source_name + test_data_source.data_source_unique_id = data_source_unique_id + + fp_config = tdh.create_fp_config(inputs=[test_data_source]) + + def udf(input_data_source, params, spark): + return None + + with pytest.raises(ValueError, match=error_pattern): + BaseDataSourceValidator().validate(udf, fp_config) From e91d8d2fd69d28b52feb0b5f484a1d05cb210b76 Mon Sep 17 00:00:00 2001 From: Basssem Halim Date: Thu, 19 Feb 2026 16:56:30 -0800 Subject: [PATCH 03/21] integ tests --- .../feature_processor/car-data.csv | 27 + sagemaker-mlops/tests/integ/__init__.py | 5 + .../test_feature_processor_integ.py | 1352 +++++++++++++++++ 3 files changed, 1384 insertions(+) create mode 100644 sagemaker-mlops/tests/data/feature_store/feature_processor/car-data.csv create mode 100644 sagemaker-mlops/tests/integ/feature_store/feature_processor/test_feature_processor_integ.py diff --git a/sagemaker-mlops/tests/data/feature_store/feature_processor/car-data.csv b/sagemaker-mlops/tests/data/feature_store/feature_processor/car-data.csv new file mode 100644 index 0000000000..55e0cbcd17 --- /dev/null +++ b/sagemaker-mlops/tests/data/feature_store/feature_processor/car-data.csv @@ -0,0 +1,27 @@ 
+Id,Model,Year,Status,Mileage,Price,MSRP +0,2022 Acura TLX A-Spec,2022,New,Not available,"$49,445.0","MSRP $49,445.0" +1,2023 Acura RDX A-Spec,2023,New,Not available,"$50,895.0",Not specified +2,2023 Acura TLX Type S,2023,New,Not available,"$57,745.0",Not specified +3,2023 Acura TLX Type S,2023,New,Not available,"$57,545.0",Not specified +4,2019 Acura MDX Sport Hybrid 3.0L w/Technology Package,2019,Used,"32,675.0","$40,990.0",$600.0 +5,2023 Acura TLX A-Spec,2023,New,Not available,"$50,195.0","MSRP $50,195.0" +6,2023 Acura TLX A-Spec,2023,New,Not available,"$50,195.0","MSRP $50,195.0" +7,2023 Acura TLX Type S,2023,New,Not available,"$57,745.0",Not specified +8,2023 Acura TLX A-Spec,2023,New,Not available,"$47,995.0",Not specified +9,2022 Acura TLX A-Spec,2022,New,Not available,"$49,545.0",Not specified +10,2023 Acura Integra w/A-Spec Tech Package,2023,New,Not available,"$36,895.0","MSRP $36,895.0" +11,2023 Acura TLX A-Spec,2023,New,Not available,"$48,395.0","MSRP $48,395.0" +12,2023 Acura MDX Type S w/Advance Package,2023,New,Not available,"$75,590.0",Not specified +13,2023 Acura RDX A-Spec Advance,2023,New,Not available,"$55,345.0",Not specified +14,2023 Acura TLX A-Spec,2023,New,Not available,"$50,195.0","MSRP $50,195" +15,2023 Acura RDX A-Spec Advance,2023,New,Not available,"$55,045.0",Not specified +16,2023 Acura TLX Type S,2023,New,Not available,"$56,445.0",Not specified +17,2023 Acura TLX A-Spec,2023,New,Not available,"$47,495.0","MSRP $47,495.0" +18,2023 Acura TLX Advance,2023,New,Not available,"$52,245.0","MSRP $52,245.0" +19,2023 Acura TLX A-Spec,2023,New,Not available,"$50,595.0","MSRP $50,595.0" +20,2023 Acura RDX Base,2023,New,Not available,"$43,045.0",Not specified +21,2023 Acura RDX A-Spec,2023,New,Not available,"$51,195.0",Not specified +22,2023 Acura RDX A-Spec,2023,New,Not available,"$50,895.0",Not specified +23,2023 Acura TLX A-Spec,2023,New,Not available,"$50,195.0","MSRP $50,195.0" +24,2023 Acura RDX A-Spec,2023,New,Not available,"$50,895.0",Not specified +25,2022 Acura MDX Type S,2022,New,Not available,"$68,245.0",Not specified \ No newline at end of file diff --git a/sagemaker-mlops/tests/integ/__init__.py b/sagemaker-mlops/tests/integ/__init__.py index 8573814647..ca83f0a2c5 100644 --- a/sagemaker-mlops/tests/integ/__init__.py +++ b/sagemaker-mlops/tests/integ/__init__.py @@ -1 +1,6 @@ """Integration tests for SageMaker V3 pipeline examples.""" +from __future__ import absolute_import + +import os + +DATA_DIR = os.path.join(os.path.dirname(__file__), "..", "data") \ No newline at end of file diff --git a/sagemaker-mlops/tests/integ/feature_store/feature_processor/test_feature_processor_integ.py b/sagemaker-mlops/tests/integ/feature_store/feature_processor/test_feature_processor_integ.py new file mode 100644 index 0000000000..a49910cf07 --- /dev/null +++ b/sagemaker-mlops/tests/integ/feature_store/feature_processor/test_feature_processor_integ.py @@ -0,0 +1,1352 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+from __future__ import absolute_import + +import logging +import os +import subprocess +import sys +import time +from typing import Dict +from datetime import datetime +from pyspark.sql import DataFrame +import pytz + +import pytest +import pandas as pd +import numpy as np +import json +import attr +from boto3 import client + +from tests.integ import DATA_DIR +from sagemaker.core.helper.session_helper import Session, get_execution_role +from sagemaker.core.s3 import S3Uploader +from urllib.parse import urlparse +from sagemaker.core.remote_function import remote +from sagemaker.core.remote_function.spark_config import SparkConfig +from sagemaker.mlops.feature_store import ( + FeatureGroup, + FeatureDefinition, + FeatureTypeEnum, + OnlineStoreConfig, + OfflineStoreConfig, + S3StorageConfig, +) +from sagemaker.mlops.feature_store.feature_utils import create_athena_query +from sagemaker.mlops.feature_store.feature_processor import ( + feature_processor, + CSVDataSource, + PySparkDataSource, + FeatureProcessorPipelineEvents, + FeatureProcessorPipelineExecutionStatus, +) +from sagemaker.mlops.feature_store.feature_processor.feature_scheduler import ( + to_pipeline, + describe, + execute, + schedule, + delete_schedule, + put_trigger, + enable_trigger, + disable_trigger, + delete_trigger, +) +from sagemaker.mlops.workflow.pipeline import Pipeline + +CAR_SALES_FG_FEATURE_DEFINITIONS = [ + FeatureDefinition(feature_name="id", feature_type=FeatureTypeEnum.STRING.value), + FeatureDefinition(feature_name="model", feature_type=FeatureTypeEnum.STRING.value), + FeatureDefinition(feature_name="model_year", feature_type=FeatureTypeEnum.STRING.value), + FeatureDefinition(feature_name="status", feature_type=FeatureTypeEnum.STRING.value), + FeatureDefinition(feature_name="mileage", feature_type=FeatureTypeEnum.STRING.value), + FeatureDefinition(feature_name="price", feature_type=FeatureTypeEnum.STRING.value), + FeatureDefinition(feature_name="msrp", feature_type=FeatureTypeEnum.STRING.value), + FeatureDefinition(feature_name="ingest_time", feature_type=FeatureTypeEnum.FRACTIONAL.value), +] +CAR_SALES_FG_RECORD_IDENTIFIER_NAME = "id" +CAR_SALES_FG_EVENT_TIME_FEATURE_NAME = "ingest_time" + +AGG_CAR_SALES_FG_FEATURE_DEFINITIONS = [ + FeatureDefinition(feature_name="model_year_status", feature_type=FeatureTypeEnum.STRING.value), + FeatureDefinition(feature_name="avg_mileage", feature_type=FeatureTypeEnum.STRING.value), + FeatureDefinition(feature_name="max_mileage", feature_type=FeatureTypeEnum.STRING.value), + FeatureDefinition(feature_name="avg_price", feature_type=FeatureTypeEnum.STRING.value), + FeatureDefinition(feature_name="max_price", feature_type=FeatureTypeEnum.STRING.value), + FeatureDefinition(feature_name="avg_msrp", feature_type=FeatureTypeEnum.STRING.value), + FeatureDefinition(feature_name="max_msrp", feature_type=FeatureTypeEnum.STRING.value), + FeatureDefinition(feature_name="ingest_time", feature_type=FeatureTypeEnum.FRACTIONAL.value), +] +AGG_CAR_SALES_FG_RECORD_IDENTIFIER_NAME = "model_year_status" +AGG_CAR_SALES_FG_EVENT_TIME_FEATURE_NAME = "ingest_time" + +BUCKET_POLICY = { + "Version": "2012-10-17", + "Statement": [ + { + "Sid": "FeatureStoreOfflineStoreS3BucketPolicy", + "Effect": "Allow", + "Principal": {"Service": "sagemaker.amazonaws.com"}, + "Action": ["s3:PutObject", "s3:PutObjectAcl"], + "Resource": "arn:aws:s3:::{bucket_name}-{region_name}/*", + "Condition": {"StringEquals": {"s3:x-amz-acl": "bucket-owner-full-control"}}, + }, + { + "Sid": "FeatureStoreOfflineStoreS3BucketPolicy", 
+ "Effect": "Allow", + "Principal": {"Service": "sagemaker.amazonaws.com"}, + "Action": "s3:GetBucketAcl", + "Resource": "arn:aws:s3:::{bucket_name}-{region_name}", + }, + ], +} + +_FEATURE_PROCESSOR_DIR = os.path.join(DATA_DIR, "feature_store/feature_processor") + +SCHEDULE_EXPRESSION_TIMESTAMP_FORMAT = "%Y-%m-%dT%H:%M:%S" # 2023-01-01T07:00:00 + + +@pytest.fixture(scope="module") +def sagemaker_session(): + return Session() + + +@pytest.mark.slow_test +def test_feature_processor_transform_online_only_store_ingestion( + sagemaker_session, +): + car_data_feature_group_name = get_car_data_feature_group_name() + car_data_aggregated_feature_group_name = get_car_data_aggregated_feature_group_name() + try: + feature_groups = create_feature_groups( + sagemaker_session=sagemaker_session, + car_data_feature_group_name=car_data_feature_group_name, + car_data_aggregated_feature_group_name=car_data_aggregated_feature_group_name, + offline_store_s3_uri=get_offline_store_s3_uri(sagemaker_session=sagemaker_session), + ) + + raw_data_uri = get_raw_car_data_s3_uri(sagemaker_session=sagemaker_session) + + print("About to apply @feature_processor decorator...") + + @feature_processor( + inputs=[CSVDataSource(raw_data_uri)], + output=feature_groups["car_data_arn"], + target_stores=["OnlineStore"], + ) + def transform(raw_s3_data_as_df): + """Load data from S3, perform basic feature engineering, store it in a Feature Group""" + from pyspark.sql.functions import regexp_replace + from pyspark.sql.functions import lit + + transformed_df = ( + raw_s3_data_as_df + # Rename Columns + .withColumnRenamed("Id", "id") + .withColumnRenamed("Model", "model") + .withColumnRenamed("Year", "model_year") + .withColumnRenamed("Status", "status") + .withColumnRenamed("Mileage", "mileage") + .withColumnRenamed("Price", "price") + .withColumnRenamed("MSRP", "msrp") + # Add Event Time + .withColumn("ingest_time", lit(int(time.time()))) + # Remove punctuation and fluff; replace with NA + .withColumn("Price", regexp_replace("Price", "\$", "")) # noqa: W605 + .withColumn("mileage", regexp_replace("mileage", "(,)|(mi\.)", "")) # noqa: W605 + .withColumn("mileage", regexp_replace("mileage", "Not available", "NA")) + .withColumn("price", regexp_replace("price", ",", "")) + .withColumn("msrp", regexp_replace("msrp", "(^MSRP\s\\$)|(,)", "")) # noqa: W605 + .withColumn("msrp", regexp_replace("msrp", "Not specified", "NA")) + .withColumn("msrp", regexp_replace("msrp", "\\$\d+[a-zA-Z\s]+", "NA")) # noqa: W605 + .withColumn("model", regexp_replace("model", "^\d\d\d\d\s", "")) # noqa: W605 + ) + + transformed_df.show() + return transformed_df + + print("Decorator applied. 
About to call transform()...") + transform() + print("transform() completed.") + + featurestore_client = sagemaker_session.sagemaker_featurestore_runtime_client + results = featurestore_client.batch_get_record( + Identifiers=[ + { + "FeatureGroupName": car_data_feature_group_name, + "RecordIdentifiersValueAsString": [ + "0", + "1", + "2", + "3", + "4", + "5", + "6", + "7", + "8", + "9", + "10", + "11", + "12", + "13", + "14", + "15", + "16", + "17", + "18", + "19", + "20", + "21", + "22", + "23", + "24", + "25", + ], + }, + ] + ) + + assert len(results["Records"]) == 26 + + car_sales_query = create_athena_query(feature_group_name=car_data_feature_group_name, session=sagemaker_session) + query = f'SELECT * FROM "sagemaker_featurestore".{car_sales_query.table_name} LIMIT 1000;' + output_uri = "s3://{}/{}/input/data/{}".format( + sagemaker_session.default_bucket(), + "feature-processor-test", + "csv-data-fg-result", + ) + car_sales_query.run(query_string=query, output_location=output_uri) + car_sales_query.wait() + dataset = car_sales_query.as_dataframe() + assert dataset.empty + finally: + cleanup_offline_store( + feature_group=feature_groups["car_data_feature_group"], + sagemaker_session=sagemaker_session, + ) + cleanup_offline_store( + feature_group=feature_groups["car_data_aggregated_feature_group"], + sagemaker_session=sagemaker_session, + ) + cleanup_feature_group( + feature_groups["car_data_feature_group"], sagemaker_session=sagemaker_session + ) + cleanup_feature_group( + feature_groups["car_data_aggregated_feature_group"], sagemaker_session=sagemaker_session + ) + + +@pytest.mark.slow_test +def test_feature_processor_transform_with_customized_data_source( + sagemaker_session, +): + car_data_feature_group_name = get_car_data_feature_group_name() + car_data_aggregated_feature_group_name = get_car_data_aggregated_feature_group_name() + + try: + feature_groups = create_feature_groups( + sagemaker_session=sagemaker_session, + car_data_feature_group_name=car_data_feature_group_name, + car_data_aggregated_feature_group_name=car_data_aggregated_feature_group_name, + offline_store_s3_uri=get_offline_store_s3_uri(sagemaker_session=sagemaker_session), + ) + + raw_data_uri = get_raw_car_data_s3_uri(sagemaker_session=sagemaker_session) + + @attr.s + class TestCSVDataSource(PySparkDataSource): + + s3_uri = attr.ib() + data_source_name = "TestCSVDataSource" + data_source_unique_id = "s3_uri" + + def read_data(self, spark, params) -> DataFrame: + s3a_uri = self.s3_uri.replace("s3://", "s3a://") + return spark.read.csv(s3a_uri, header=True, inferSchema=False) + + @feature_processor( + inputs=[TestCSVDataSource(raw_data_uri)], + output=feature_groups["car_data_arn"], + target_stores=["OnlineStore"], + spark_config={ + "spark.hadoop.fs.s3a.aws.credentials.provider": ",".join( + [ + "com.amazonaws.auth.ContainerCredentialsProvider", + "com.amazonaws.auth.profile.ProfileCredentialsProvider", + "com.amazonaws.auth.DefaultAWSCredentialsProviderChain", + ] + ) + }, + ) + def transform(raw_s3_data_as_df): + """Load data from S3, perform basic feature engineering, store it in a Feature Group""" + from pyspark.sql.functions import regexp_replace + from pyspark.sql.functions import lit + + transformed_df = ( + raw_s3_data_as_df + # Rename Columns + .withColumnRenamed("Id", "id") + .withColumnRenamed("Model", "model") + .withColumnRenamed("Year", "model_year") + .withColumnRenamed("Status", "status") + .withColumnRenamed("Mileage", "mileage") + .withColumnRenamed("Price", "price") + .withColumnRenamed("MSRP", 
"msrp") + # Add Event Time + .withColumn("ingest_time", lit(int(time.time()))) + # Remove punctuation and fluff; replace with NA + .withColumn("Price", regexp_replace("Price", "\$", "")) # noqa: W605 + .withColumn("mileage", regexp_replace("mileage", "(,)|(mi\.)", "")) # noqa: W605 + .withColumn("mileage", regexp_replace("mileage", "Not available", "NA")) + .withColumn("price", regexp_replace("price", ",", "")) + .withColumn("msrp", regexp_replace("msrp", "(^MSRP\s\\$)|(,)", "")) # noqa: W605 + .withColumn("msrp", regexp_replace("msrp", "Not specified", "NA")) + .withColumn("msrp", regexp_replace("msrp", "\\$\d+[a-zA-Z\s]+", "NA")) # noqa: W605 + .withColumn("model", regexp_replace("model", "^\d\d\d\d\s", "")) # noqa: W605 + ) + + transformed_df.show() + return transformed_df + + transform() + + featurestore_client = sagemaker_session.sagemaker_featurestore_runtime_client + results = featurestore_client.batch_get_record( + Identifiers=[ + { + "FeatureGroupName": car_data_feature_group_name, + "RecordIdentifiersValueAsString": [ + "0", + "1", + "2", + "3", + "4", + "5", + "6", + "7", + "8", + "9", + "10", + "11", + "12", + "13", + "14", + "15", + "16", + "17", + "18", + "19", + "20", + "21", + "22", + "23", + "24", + "25", + ], + }, + ] + ) + + assert len(results["Records"]) == 26 + + car_sales_query = create_athena_query(feature_group_name=car_data_feature_group_name, session=sagemaker_session) + query = f'SELECT * FROM "sagemaker_featurestore".{car_sales_query.table_name} LIMIT 1000;' + output_uri = "s3://{}/{}/input/data/{}".format( + sagemaker_session.default_bucket(), + "feature-processor-test", + "csv-data-fg-result", + ) + car_sales_query.run(query_string=query, output_location=output_uri) + car_sales_query.wait() + dataset = car_sales_query.as_dataframe() + assert dataset.empty + finally: + cleanup_offline_store( + feature_group=feature_groups["car_data_feature_group"], + sagemaker_session=sagemaker_session, + ) + cleanup_offline_store( + feature_group=feature_groups["car_data_aggregated_feature_group"], + sagemaker_session=sagemaker_session, + ) + cleanup_feature_group( + feature_groups["car_data_feature_group"], sagemaker_session=sagemaker_session + ) + cleanup_feature_group( + feature_groups["car_data_aggregated_feature_group"], sagemaker_session=sagemaker_session + ) + + +@pytest.mark.slow_test +@pytest.mark.flaky(reruns=5, reruns_delay=2) +def test_feature_processor_transform_offline_only_store_ingestion( + sagemaker_session, +): + car_data_feature_group_name = get_car_data_feature_group_name() + car_data_aggregated_feature_group_name = get_car_data_aggregated_feature_group_name() + try: + feature_groups = create_feature_groups( + sagemaker_session=sagemaker_session, + car_data_feature_group_name=car_data_feature_group_name, + car_data_aggregated_feature_group_name=car_data_aggregated_feature_group_name, + offline_store_s3_uri=get_offline_store_s3_uri(sagemaker_session=sagemaker_session), + ) + + raw_data_uri = get_raw_car_data_s3_uri(sagemaker_session=sagemaker_session) + + @feature_processor( + inputs=[CSVDataSource(raw_data_uri)], + output=feature_groups["car_data_arn"], + target_stores=["OfflineStore"], + ) + def transform(raw_s3_data_as_df): + """Load data from S3, perform basic feature engineering, store it in a Feature Group""" + from pyspark.sql.functions import regexp_replace + from pyspark.sql.functions import lit + + transformed_df = ( + raw_s3_data_as_df + # Rename Columns + .withColumnRenamed("Id", "id") + .withColumnRenamed("Model", "model") + 
.withColumnRenamed("Year", "model_year") + .withColumnRenamed("Status", "status") + .withColumnRenamed("Mileage", "mileage") + .withColumnRenamed("Price", "price") + .withColumnRenamed("MSRP", "msrp") + # Add Event Time + .withColumn("ingest_time", lit(int(time.time()))) + # Remove punctuation and fluff; replace with NA + .withColumn("Price", regexp_replace("Price", "\$", "")) # noqa: W605 + .withColumn("mileage", regexp_replace("mileage", "(,)|(mi\.)", "")) # noqa: W605 + .withColumn("mileage", regexp_replace("mileage", "Not available", "NA")) + .withColumn("price", regexp_replace("price", ",", "")) + .withColumn("msrp", regexp_replace("msrp", "(^MSRP\s\\$)|(,)", "")) # noqa: W605 + .withColumn("msrp", regexp_replace("msrp", "Not specified", "NA")) + .withColumn("msrp", regexp_replace("msrp", "\\$\d+[a-zA-Z\s]+", "NA")) # noqa: W605 + .withColumn("model", regexp_replace("model", "^\d\d\d\d\s", "")) # noqa: W605 + ) + + transformed_df.show() + return transformed_df + + transform() + + featurestore_client = sagemaker_session.sagemaker_featurestore_runtime_client + results = featurestore_client.batch_get_record( + Identifiers=[ + { + "FeatureGroupName": car_data_feature_group_name, + "RecordIdentifiersValueAsString": [ + "0", + "1", + "2", + "3", + "4", + "5", + "6", + "7", + "8", + "9", + "10", + "11", + "12", + "13", + "14", + "15", + "16", + "17", + "18", + "19", + "20", + "21", + "22", + "23", + "24", + "25", + ], + }, + ] + ) + + assert len(results["Records"]) == 0 + + car_sales_query = create_athena_query(feature_group_name=car_data_feature_group_name, session=sagemaker_session) + query = f'SELECT * FROM "sagemaker_featurestore".{car_sales_query.table_name} LIMIT 1000;' + output_uri = "s3://{}/{}/input/data/{}".format( + sagemaker_session.default_bucket(), + "feature-processor-test", + "csv-data-fg-result", + ) + car_sales_query.run(query_string=query, output_location=output_uri) + car_sales_query.wait() + dataset = car_sales_query.as_dataframe() + dataset = dataset.drop( + columns=["ingest_time", "write_time", "api_invocation_time", "is_deleted"] + ) + + assert dataset.equals(get_expected_dataframe()) + finally: + cleanup_offline_store( + feature_group=feature_groups["car_data_feature_group"], + sagemaker_session=sagemaker_session, + ) + cleanup_offline_store( + feature_group=feature_groups["car_data_aggregated_feature_group"], + sagemaker_session=sagemaker_session, + ) + cleanup_feature_group( + feature_groups["car_data_feature_group"], sagemaker_session=sagemaker_session + ) + cleanup_feature_group( + feature_groups["car_data_aggregated_feature_group"], sagemaker_session=sagemaker_session + ) + + +@pytest.mark.slow_test +@pytest.mark.skipif( + not sys.version.startswith("3.9"), + reason="Only allow this test to run with py39", +) +def test_feature_processor_transform_offline_only_store_ingestion_run_with_remote( + sagemaker_session, +): + car_data_feature_group_name = get_car_data_feature_group_name() + car_data_aggregated_feature_group_name = get_car_data_aggregated_feature_group_name() + try: + feature_groups = create_feature_groups( + sagemaker_session=sagemaker_session, + car_data_feature_group_name=car_data_feature_group_name, + car_data_aggregated_feature_group_name=car_data_aggregated_feature_group_name, + offline_store_s3_uri=get_offline_store_s3_uri(sagemaker_session=sagemaker_session), + ) + + raw_data_uri = get_raw_car_data_s3_uri(sagemaker_session=sagemaker_session) + whl_file_uri = get_wheel_file_s3_uri(sagemaker_session=sagemaker_session) + whl_file_name = 
os.path.basename(whl_file_uri) + + pre_execution_commands = [ + f"aws s3 cp {whl_file_uri} ./", + f"/usr/local/bin/python3.9 -m pip install ./{whl_file_name} --force-reinstall", + ] + + @remote( + pre_execution_commands=pre_execution_commands, + spark_config=SparkConfig(), + instance_type="ml.m5.xlarge", + ) + @feature_processor( + inputs=[CSVDataSource(raw_data_uri)], + output=feature_groups["car_data_arn"], + target_stores=["OfflineStore"], + ) + def transform(raw_s3_data_as_df): + """Load data from S3, perform basic feature engineering, store it in a Feature Group""" + from pyspark.sql.functions import regexp_replace + from pyspark.sql.functions import lit + + transformed_df = ( + raw_s3_data_as_df + # Rename Columns + .withColumnRenamed("Id", "id") + .withColumnRenamed("Model", "model") + .withColumnRenamed("Year", "model_year") + .withColumnRenamed("Status", "status") + .withColumnRenamed("Mileage", "mileage") + .withColumnRenamed("Price", "price") + .withColumnRenamed("MSRP", "msrp") + # Add Event Time + .withColumn("ingest_time", lit(int(time.time()))) + # Remove punctuation and fluff; replace with NA + .withColumn("Price", regexp_replace("Price", "\$", "")) # noqa: W605 + .withColumn("mileage", regexp_replace("mileage", "(,)|(mi\.)", "")) # noqa: W605 + .withColumn("mileage", regexp_replace("mileage", "Not available", "NA")) + .withColumn("price", regexp_replace("price", ",", "")) + .withColumn("msrp", regexp_replace("msrp", "(^MSRP\s\\$)|(,)", "")) # noqa: W605 + .withColumn("msrp", regexp_replace("msrp", "Not specified", "NA")) + .withColumn("msrp", regexp_replace("msrp", "\\$\d+[a-zA-Z\s]+", "NA")) # noqa: W605 + .withColumn("model", regexp_replace("model", "^\d\d\d\d\s", "")) # noqa: W605 + ) + + transformed_df.show() + return transformed_df + + transform() + + featurestore_client = sagemaker_session.sagemaker_featurestore_runtime_client + results = featurestore_client.batch_get_record( + Identifiers=[ + { + "FeatureGroupName": car_data_feature_group_name, + "RecordIdentifiersValueAsString": [ + "0", + "1", + "2", + "3", + "4", + "5", + "6", + "7", + "8", + "9", + "10", + "11", + "12", + "13", + "14", + "15", + "16", + "17", + "18", + "19", + "20", + "21", + "22", + "23", + "24", + "25", + ], + }, + ] + ) + + assert len(results["Records"]) == 0 + + car_sales_query = create_athena_query(feature_group_name=car_data_feature_group_name, session=sagemaker_session) + query = f'SELECT * FROM "sagemaker_featurestore".{car_sales_query.table_name} LIMIT 1000;' + output_uri = "s3://{}/{}/input/data/{}".format( + sagemaker_session.default_bucket(), + "feature-processor-test", + "csv-data-fg-result", + ) + car_sales_query.run(query_string=query, output_location=output_uri) + car_sales_query.wait() + dataset = car_sales_query.as_dataframe() + dataset = dataset.drop( + columns=["ingest_time", "write_time", "api_invocation_time", "is_deleted"] + ) + + assert dataset.equals(get_expected_dataframe()) + finally: + cleanup_offline_store( + feature_group=feature_groups["car_data_feature_group"], + sagemaker_session=sagemaker_session, + ) + cleanup_offline_store( + feature_group=feature_groups["car_data_aggregated_feature_group"], + sagemaker_session=sagemaker_session, + ) + cleanup_feature_group( + feature_groups["car_data_feature_group"], sagemaker_session=sagemaker_session + ) + cleanup_feature_group( + feature_groups["car_data_aggregated_feature_group"], sagemaker_session=sagemaker_session + ) + + +@pytest.mark.slow_test +@pytest.mark.skipif( + not sys.version.startswith("3.9"), + reason="Only 
allow this test to run with py39", +) +def test_to_pipeline_and_execute( + sagemaker_session, +): + pipeline_name = "pipeline-name-01" + car_data_feature_group_name = get_car_data_feature_group_name() + car_data_aggregated_feature_group_name = get_car_data_aggregated_feature_group_name() + try: + feature_groups = create_feature_groups( + sagemaker_session=sagemaker_session, + car_data_feature_group_name=car_data_feature_group_name, + car_data_aggregated_feature_group_name=car_data_aggregated_feature_group_name, + offline_store_s3_uri=get_offline_store_s3_uri(sagemaker_session=sagemaker_session), + ) + + raw_data_uri = get_raw_car_data_s3_uri(sagemaker_session=sagemaker_session) + whl_file_uri = get_wheel_file_s3_uri(sagemaker_session=sagemaker_session) + whl_file_name = os.path.basename(whl_file_uri) + + pre_execution_commands = [ + f"aws s3 cp {whl_file_uri} ./", + f"/usr/local/bin/python3.9 -m pip install ./{whl_file_name} --force-reinstall", + ] + + @remote( + pre_execution_commands=pre_execution_commands, + spark_config=SparkConfig(), + instance_type="ml.m5.xlarge", + ) + @feature_processor( + inputs=[CSVDataSource(raw_data_uri)], + output=feature_groups["car_data_arn"], + target_stores=["OfflineStore"], + ) + def transform(raw_s3_data_as_df): + """Load data from S3, perform basic feature engineering, store it in a Feature Group""" + from pyspark.sql.functions import regexp_replace + from pyspark.sql.functions import lit + + transformed_df = ( + raw_s3_data_as_df + # Rename Columns + .withColumnRenamed("Id", "id") + .withColumnRenamed("Model", "model") + .withColumnRenamed("Year", "model_year") + .withColumnRenamed("Status", "status") + .withColumnRenamed("Mileage", "mileage") + .withColumnRenamed("Price", "price") + .withColumnRenamed("MSRP", "msrp") + # Add Event Time + .withColumn("ingest_time", lit(int(time.time()))) + # Remove punctuation and fluff; replace with NA + .withColumn("Price", regexp_replace("Price", "\$", "")) # noqa: W605 + .withColumn("mileage", regexp_replace("mileage", "(,)|(mi\.)", "")) # noqa: W605 + .withColumn("mileage", regexp_replace("mileage", "Not available", "NA")) + .withColumn("price", regexp_replace("price", ",", "")) + .withColumn("msrp", regexp_replace("msrp", "(^MSRP\s\\$)|(,)", "")) # noqa: W605 + .withColumn("msrp", regexp_replace("msrp", "Not specified", "NA")) + .withColumn("msrp", regexp_replace("msrp", "\\$\d+[a-zA-Z\s]+", "NA")) # noqa: W605 + .withColumn("model", regexp_replace("model", "^\d\d\d\d\s", "")) # noqa: W605 + ) + + transformed_df.show() + return transformed_df + + _wait_for_feature_group_lineage_contexts( + car_data_feature_group_name, sagemaker_session + ) + + pipeline_arn = to_pipeline( + pipeline_name=pipeline_name, + step=transform, + role=get_execution_role(sagemaker_session), + max_retries=2, + tags=[("integ_test_tag_key_1", "integ_test_tag_key_2")], + sagemaker_session=sagemaker_session, + ) + _sagemaker_client = get_sagemaker_client(sagemaker_session=sagemaker_session) + + assert pipeline_arn is not None + + tags = _sagemaker_client.list_tags(ResourceArn=pipeline_arn)["Tags"] + + tag_keys = [tag["Key"] for tag in tags] + assert "integ_test_tag_key_1" in tag_keys + + pipeline_description = Pipeline(name=pipeline_name).describe() + assert pipeline_arn == pipeline_description["PipelineArn"] + assert get_execution_role(sagemaker_session) == pipeline_description["RoleArn"] + + pipeline_definition = json.loads(pipeline_description["PipelineDefinition"]) + assert len(pipeline_definition["Steps"]) == 1 + for retry_policy in 
pipeline_definition["Steps"][0]["RetryPolicies"]: + assert retry_policy["MaxAttempts"] == 2 + + pipeline_execution_arn = execute( + pipeline_name=pipeline_name, sagemaker_session=sagemaker_session + ) + + status = _wait_for_pipeline_execution_to_reach_terminal_state( + pipeline_execution_arn=pipeline_execution_arn, + sagemaker_client=_sagemaker_client, + ) + assert status == "Succeeded" + + finally: + cleanup_offline_store( + feature_group=feature_groups["car_data_feature_group"], + sagemaker_session=sagemaker_session, + ) + cleanup_offline_store( + feature_group=feature_groups["car_data_aggregated_feature_group"], + sagemaker_session=sagemaker_session, + ) + cleanup_feature_group( + feature_groups["car_data_feature_group"], sagemaker_session=sagemaker_session + ) + cleanup_feature_group( + feature_groups["car_data_aggregated_feature_group"], sagemaker_session=sagemaker_session + ) + # cleanup_pipeline(pipeline_name="pipeline-name-01", sagemaker_session=sagemaker_session) + + +@pytest.mark.slow_test +@pytest.mark.skipif( + not sys.version.startswith("3.9"), + reason="Only allow this test to run with py39", +) +def test_schedule_and_event_trigger( + sagemaker_session, +): + pipeline_name = "pipeline-name-01" + car_data_feature_group_name = get_car_data_feature_group_name() + car_data_aggregated_feature_group_name = get_car_data_aggregated_feature_group_name() + try: + feature_groups = create_feature_groups( + sagemaker_session=sagemaker_session, + car_data_feature_group_name=car_data_feature_group_name, + car_data_aggregated_feature_group_name=car_data_aggregated_feature_group_name, + offline_store_s3_uri=get_offline_store_s3_uri(sagemaker_session=sagemaker_session), + ) + + raw_data_uri = get_raw_car_data_s3_uri(sagemaker_session=sagemaker_session) + whl_file_uri = get_wheel_file_s3_uri(sagemaker_session=sagemaker_session) + whl_file_name = os.path.basename(whl_file_uri) + + pre_execution_commands = [ + f"aws s3 cp {whl_file_uri} ./", + f"/usr/local/bin/python3.9 -m pip install ./{whl_file_name} --force-reinstall", + ] + + @remote( + pre_execution_commands=pre_execution_commands, + spark_config=SparkConfig(), + instance_type="ml.m5.xlarge", + ) + @feature_processor( + inputs=[CSVDataSource(raw_data_uri)], + output=feature_groups["car_data_arn"], + target_stores=["OfflineStore"], + ) + def transform(raw_s3_data_as_df): + """Load data from S3, perform basic feature engineering, store it in a Feature Group""" + from pyspark.sql.functions import regexp_replace + from pyspark.sql.functions import lit + + transformed_df = ( + raw_s3_data_as_df + # Rename Columns + .withColumnRenamed("Id", "id") + .withColumnRenamed("Model", "model") + .withColumnRenamed("Year", "model_year") + .withColumnRenamed("Status", "status") + .withColumnRenamed("Mileage", "mileage") + .withColumnRenamed("Price", "price") + .withColumnRenamed("MSRP", "msrp") + # Add Event Time + .withColumn("ingest_time", lit(int(time.time()))) + # Remove punctuation and fluff; replace with NA + .withColumn("Price", regexp_replace("Price", "\$", "")) # noqa: W605 + .withColumn("mileage", regexp_replace("mileage", "(,)|(mi\.)", "")) # noqa: W605 + .withColumn("mileage", regexp_replace("mileage", "Not available", "NA")) + .withColumn("price", regexp_replace("price", ",", "")) + .withColumn("msrp", regexp_replace("msrp", "(^MSRP\s\\$)|(,)", "")) # noqa: W605 + .withColumn("msrp", regexp_replace("msrp", "Not specified", "NA")) + .withColumn("msrp", regexp_replace("msrp", "\\$\d+[a-zA-Z\s]+", "NA")) # noqa: W605 + .withColumn("model", 
regexp_replace("model", "^\d\d\d\d\s", "")) # noqa: W605 + ) + + transformed_df.show() + return transformed_df + + _wait_for_feature_group_lineage_contexts( + car_data_feature_group_name, sagemaker_session + ) + + pipeline_arn = to_pipeline( + pipeline_name=pipeline_name, + step=transform, + role=get_execution_role(sagemaker_session), + max_retries=2, + sagemaker_session=sagemaker_session, + ) + + assert pipeline_arn is not None + + pipeline_description = Pipeline(name=pipeline_name).describe() + assert pipeline_arn == pipeline_description["PipelineArn"] + assert get_execution_role(sagemaker_session) == pipeline_description["RoleArn"] + + pipeline_definition = json.loads(pipeline_description["PipelineDefinition"]) + assert len(pipeline_definition["Steps"]) == 1 + for retry_policy in pipeline_definition["Steps"][0]["RetryPolicies"]: + assert retry_policy["MaxAttempts"] == 2 + now = datetime.now(tz=pytz.utc) + schedule_expression = f"at({now.strftime(SCHEDULE_EXPRESSION_TIMESTAMP_FORMAT)})" + schedule( + pipeline_name=pipeline_name, + schedule_expression=schedule_expression, + start_date=now, + sagemaker_session=sagemaker_session, + ) + time.sleep(60) + executions = sagemaker_session.sagemaker_client.list_pipeline_executions( + PipelineName=pipeline_name + ) + pipeline_execution_arn = executions["PipelineExecutionSummaries"][0]["PipelineExecutionArn"] + + status = _wait_for_pipeline_execution_to_reach_terminal_state( + pipeline_execution_arn=pipeline_execution_arn, + sagemaker_client=get_sagemaker_client(sagemaker_session=sagemaker_session), + ) + assert status == "Succeeded" + + featurestore_client = sagemaker_session.sagemaker_featurestore_runtime_client + results = featurestore_client.batch_get_record( + Identifiers=[ + { + "FeatureGroupName": car_data_feature_group_name, + "RecordIdentifiersValueAsString": [ + "0", + "1", + "2", + "3", + "4", + "5", + "6", + "7", + "8", + "9", + "10", + "11", + "12", + "13", + "14", + "15", + "16", + "17", + "18", + "19", + "20", + "21", + "22", + "23", + "24", + "25", + ], + }, + ] + ) + + assert len(results["Records"]) == 0 + + car_sales_query = create_athena_query(feature_group_name=car_data_feature_group_name, session=sagemaker_session) + query = f'SELECT * FROM "sagemaker_featurestore".{car_sales_query.table_name} LIMIT 1000;' + output_uri = "s3://{}/{}/input/data/{}".format( + sagemaker_session.default_bucket(), + "feature-processor-test", + "csv-data-fg-result", + ) + car_sales_query.run(query_string=query, output_location=output_uri) + car_sales_query.wait() + dataset = car_sales_query.as_dataframe() + dataset = dataset.drop( + columns=["ingest_time", "write_time", "api_invocation_time", "is_deleted"] + ) + + # assert dataset.equals(get_expected_dataframe()) + + put_trigger( + source_pipeline_events=[ + FeatureProcessorPipelineEvents( + pipeline_name=pipeline_name, + pipeline_execution_status=[FeatureProcessorPipelineExecutionStatus.FAILED], + ) + ], + target_pipeline=pipeline_name, + ) + + assert "trigger" in describe( + pipeline_name=pipeline_name, sagemaker_session=sagemaker_session + ) + assert describe(pipeline_name=pipeline_name, sagemaker_session=sagemaker_session)[ + "event_pattern" + ] == json.dumps( + { + "detail-type": ["SageMaker Model Building Pipeline Execution Status Change"], + "source": ["aws.sagemaker"], + "detail": { + "currentPipelineExecutionStatus": ["Failed"], + "pipelineArn": [pipeline_arn], + }, + } + ) + enable_trigger(pipeline_name=pipeline_name, sagemaker_session=sagemaker_session) + assert ( + 
describe(pipeline_name=pipeline_name, sagemaker_session=sagemaker_session)[ + "trigger_state" + ] + == "ENABLED" + ) + disable_trigger(pipeline_name=pipeline_name, sagemaker_session=sagemaker_session) + assert ( + describe(pipeline_name=pipeline_name, sagemaker_session=sagemaker_session)[ + "trigger_state" + ] + == "DISABLED" + ) + + delete_schedule(pipeline_name=pipeline_name, sagemaker_session=sagemaker_session) + assert "schedule_arn" not in describe( + pipeline_name=pipeline_name, sagemaker_session=sagemaker_session + ) + delete_trigger(pipeline_name=pipeline_name, sagemaker_session=sagemaker_session) + assert "trigger" not in describe( + pipeline_name=pipeline_name, sagemaker_session=sagemaker_session + ) + + finally: + cleanup_offline_store( + feature_group=feature_groups["car_data_feature_group"], + sagemaker_session=sagemaker_session, + ) + cleanup_offline_store( + feature_group=feature_groups["car_data_aggregated_feature_group"], + sagemaker_session=sagemaker_session, + ) + cleanup_feature_group( + feature_groups["car_data_feature_group"], sagemaker_session=sagemaker_session + ) + cleanup_feature_group( + feature_groups["car_data_aggregated_feature_group"], sagemaker_session=sagemaker_session + ) + + +def get_car_data_feature_group_name(): + return f"car-data-{int(time.time() * 10 ** 7)}" + + +def get_car_data_aggregated_feature_group_name(): + return f"car-data-aggregated-{int(time.time() * 10 ** 7)}" + + +def get_offline_store_s3_uri(sagemaker_session): + region_name = sagemaker_session.boto_region_name + bucket = f"sagemaker-test-featurestore-{region_name}-{sagemaker_session.account_id()}" + sagemaker_session._create_s3_bucket_if_it_does_not_exist(bucket, region_name) + s3 = sagemaker_session.boto_session.client("s3", region_name=region_name) + BUCKET_POLICY["Statement"][0]["Resource"] = f"arn:aws:s3:::{bucket}/*" + BUCKET_POLICY["Statement"][1]["Resource"] = f"arn:aws:s3:::{bucket}" + s3.put_bucket_policy( + Bucket=f"{bucket}", + Policy=json.dumps(BUCKET_POLICY), + ) + return f"s3://{bucket}" + + +def get_raw_car_data_s3_uri(sagemaker_session) -> str: + uri = "s3://{}/{}/input/data/{}".format( + sagemaker_session.default_bucket(), + "feature-processor-test", + "csv-data", + ) + print("About to upload raw car data to S3...") + raw_car_data_s3_uri = S3Uploader.upload( + os.path.join(_FEATURE_PROCESSOR_DIR, "car-data.csv"), + uri, + sagemaker_session=sagemaker_session, + ) + print(f"Upload complete: {raw_car_data_s3_uri}") + return raw_car_data_s3_uri + + +def get_wheel_file_s3_uri(sagemaker_session) -> str: + uri = "s3://{}/{}/wheel-file".format( + sagemaker_session.default_bucket(), "feature-processor-test" + ) + source = _generate_and_move_sagemaker_sdk_tar() + print(source) + raw_car_data_s3_uri = S3Uploader.upload( + source, + uri, + sagemaker_session=sagemaker_session, + ) + return raw_car_data_s3_uri + + +def create_feature_groups( + sagemaker_session, + car_data_feature_group_name, + car_data_aggregated_feature_group_name, + offline_store_s3_uri, +) -> Dict: + # Create Feature Group - Car sale records. 
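+    # Creation below is wrapped in try/except: a ResourceInUse error means the
+    # group already exists from a previous run, so it is fetched with
+    # FeatureGroup.get instead of recreated.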
+ car_sales_fg = None + agg_car_sales_fg = None + try: + car_sales_fg = FeatureGroup.create( + feature_group_name=car_data_feature_group_name, + record_identifier_feature_name=CAR_SALES_FG_RECORD_IDENTIFIER_NAME, + event_time_feature_name=CAR_SALES_FG_EVENT_TIME_FEATURE_NAME, + feature_definitions=CAR_SALES_FG_FEATURE_DEFINITIONS, + offline_store_config=OfflineStoreConfig( + s3_storage_config=S3StorageConfig(s3_uri=f"{offline_store_s3_uri}/car-data"), + ), + online_store_config=OnlineStoreConfig(enable_online_store=True), + role_arn=get_execution_role(sagemaker_session), + session=sagemaker_session.boto_session, + ) + print(f"Created feature group {car_sales_fg.feature_group_name}") + except Exception as e: + if "ResourceInUse" in str(e): + print("Feature Group already exists") + car_sales_fg = FeatureGroup.get( + feature_group_name=car_data_feature_group_name, + session=sagemaker_session.boto_session, + ) + else: + raise e + + # Create Feature Group - Aggregated car sales records. + try: + agg_car_sales_fg = FeatureGroup.create( + feature_group_name=car_data_aggregated_feature_group_name, + record_identifier_feature_name=AGG_CAR_SALES_FG_RECORD_IDENTIFIER_NAME, + event_time_feature_name=AGG_CAR_SALES_FG_EVENT_TIME_FEATURE_NAME, + feature_definitions=AGG_CAR_SALES_FG_FEATURE_DEFINITIONS, + offline_store_config=OfflineStoreConfig( + s3_storage_config=S3StorageConfig( + s3_uri=f"{offline_store_s3_uri}/car-data-aggregated" + ), + ), + online_store_config=OnlineStoreConfig(enable_online_store=True), + role_arn=get_execution_role(sagemaker_session), + session=sagemaker_session.boto_session, + ) + print(f"Created feature group {agg_car_sales_fg.feature_group_name}") + print("Sleeping for a bit, to let Feature Groups get ready.") + except Exception as e: + if "ResourceInUse" in str(e): + print("Feature Group already exists") + agg_car_sales_fg = FeatureGroup.get( + feature_group_name=car_data_aggregated_feature_group_name, + session=sagemaker_session.boto_session, + ) + else: + raise e + + _wait_for_feature_group_create(car_sales_fg) + _wait_for_feature_group_create(agg_car_sales_fg) + + return dict( + car_data_arn=car_sales_fg.feature_group_arn, + car_data_feature_group=car_sales_fg, + car_data_aggregated_arn=agg_car_sales_fg.feature_group_arn, + car_data_aggregated_feature_group=agg_car_sales_fg, + ) + + +def get_expected_dataframe(): + expected_dataframe = pd.read_csv(os.path.join(_FEATURE_PROCESSOR_DIR, "car-data.csv")) + expected_dataframe["Model"].replace("^\d\d\d\d\s", "", regex=True, inplace=True) # noqa: W605 + expected_dataframe["Mileage"].replace("(,)|(mi\.)", "", regex=True, inplace=True) # noqa: W605 + expected_dataframe["Mileage"].replace("Not available", np.nan, inplace=True) + expected_dataframe["Price"].replace("\$", "", regex=True, inplace=True) # noqa: W605 + expected_dataframe["Price"].replace(",", "", regex=True, inplace=True) + expected_dataframe["MSRP"].replace( + "(^MSRP\s\\$)|(,)", "", regex=True, inplace=True # noqa: W605 + ) + expected_dataframe["MSRP"].replace("Not specified", np.nan, inplace=True) + expected_dataframe["MSRP"].replace( + "\\$\d+[a-zA-Z\s]+", np.nan, regex=True, inplace=True # noqa: W605 + ) + expected_dataframe["Mileage"] = expected_dataframe["Mileage"].astype(float) + expected_dataframe["Price"] = expected_dataframe["Price"].astype(float) + expected_dataframe.rename( + columns={ + "Id": "id", + "Model": "model", + "Year": "model_year", + "Status": "status", + "Mileage": "mileage", + "Price": "price", + "MSRP": "msrp", + }, + inplace=True, + ) + return 
expected_dataframe + + +def _wait_for_feature_group_create(feature_group: FeatureGroup): + status = feature_group.feature_group_status + while status == "Creating": + print("Waiting for Feature Group Creation") + time.sleep(5) + feature_group.refresh() + status = feature_group.feature_group_status + if status != "Created": + print(f"FeatureGroup {feature_group.feature_group_name} status: {status}") + raise RuntimeError(f"Failed to create feature group {feature_group.feature_group_name}") + print(f"FeatureGroup {feature_group.feature_group_name} successfully created.") + + +def _wait_for_pipeline_execution_to_stop(pipeline_execution_arn: str, sagemaker_client: client): + status = sagemaker_client.describe_pipeline_execution( + PipelineExecutionArn=pipeline_execution_arn + )["PipelineExecutionStatus"] + if status != "Stopping" and status != "Stopped": + raise RuntimeError( + f"Pipeline execution Arn: {pipeline_execution_arn} " + f"status is not in Stopping or Stopped mode, instead is in {status} mode." + ) + while status == "Stopping": + print("Waiting for Pipeline Execution to Stop") + time.sleep(5) + status = sagemaker_client.describe_pipeline_execution( + PipelineExecutionArn=pipeline_execution_arn + )["PipelineExecutionStatus"] + if status != "Stopped": + raise RuntimeError(f"Failed to Stop pipeline execution {pipeline_execution_arn}") + logging.info(f"pipeline execution {pipeline_execution_arn} successfully Stopped.") + + +def _wait_for_pipeline_execution_to_reach_terminal_state( + pipeline_execution_arn: str, sagemaker_client: client +) -> str: + status = sagemaker_client.describe_pipeline_execution( + PipelineExecutionArn=pipeline_execution_arn + )["PipelineExecutionStatus"] + while status == "Stopping" or status == "Executing": + print("Waiting for Pipeline Execution to reach terminal state") + time.sleep(60) + status = sagemaker_client.describe_pipeline_execution( + PipelineExecutionArn=pipeline_execution_arn + )["PipelineExecutionStatus"] + logging.info( + f"pipeline execution {pipeline_execution_arn} successfully reached a terminal state {status}." 
+ ) + return status + + +def cleanup_feature_group(feature_group: FeatureGroup, sagemaker_session: Session): + try: + feature_group.delete() + print(f"{feature_group.feature_group_name} is deleted.") + except sagemaker_session.sagemaker_client.exceptions.ResourceNotFound: + print(f"{feature_group.feature_group_name} not found.") + pass + except Exception as e: + raise RuntimeError( + f"Failed to delete feature group with name {feature_group.feature_group_name}", e + ) + + +def cleanup_pipeline(pipeline_name: str, sagemaker_session: Session): + try: + pipeline = Pipeline(name=pipeline_name, sagemaker_session=sagemaker_session) + + sagemaker_client = get_sagemaker_client(sagemaker_session=sagemaker_session) + executions = sagemaker_client.list_pipeline_executions(PipelineName=pipeline_name) + for execution in executions["PipelineExecutionSummaries"]: + if execution["PipelineExecutionStatus"] == "Executing": + logging.info(f'Stopping pipeline execution: {execution["PipelineExecutionArn"]}') + sagemaker_client.stop_pipeline_execution( + PipelineExecutionArn=execution["PipelineExecutionArn"] + ) + _wait_for_pipeline_execution_to_stop( + pipeline_execution_arn=execution["PipelineExecutionArn"], + sagemaker_client=sagemaker_client, + ) + if execution["PipelineExecutionStatus"] == "Stopping": + _wait_for_pipeline_execution_to_stop( + pipeline_execution_arn=execution["PipelineExecutionArn"], + sagemaker_client=sagemaker_client, + ) + pipeline.delete() + logging.info(f"{pipeline_name} is deleted.") + except sagemaker_session.sagemaker_client.exceptions.ResourceNotFound: + print(f"{pipeline_name} not found.") + pass + except sagemaker_session.sagemaker_client.exceptions.ClientError as ce: + if ce.response["Error"]["Code"] == "ValidationException": + if ( + "Pipelines with running executions cannot be deleted." 
+                    in ce.response["Error"]["Message"]
+                ):
+                    cleanup_pipeline(pipeline_name=pipeline_name, sagemaker_session=sagemaker_session)
+                    return
+            raise RuntimeError(f"Failed to delete Pipeline with name {pipeline_name}", ce)
+    except Exception as e:
+        raise RuntimeError(f"Failed to delete Pipeline with name {pipeline_name}", e)
+
+
+def cleanup_offline_store(feature_group: FeatureGroup, sagemaker_session: Session):
+    feature_group_name = feature_group.feature_group_name
+    try:
+        feature_group.refresh()
+        s3_uri = feature_group.offline_store_config.s3_storage_config.resolved_output_s3_uri
+        parsed_uri = urlparse(s3_uri)
+        bucket_name, prefix = parsed_uri.netloc, parsed_uri.path
+        prefix = prefix.strip("/")
+        prefix = prefix[:-5] if prefix.endswith("/data") else prefix
+        region_name = sagemaker_session.boto_region_name
+        s3_client = sagemaker_session.boto_session.client(
+            service_name="s3", region_name=region_name
+        )
+        response = s3_client.list_objects_v2(Bucket=bucket_name, Prefix=prefix)
+        files_in_folder = response.get("Contents", [])
+        files_to_delete = []
+        for f in files_in_folder:
+            files_to_delete.append({"Key": f["Key"]})
+        if files_to_delete:
+            s3_client.delete_objects(Bucket=bucket_name, Delete={"Objects": files_to_delete})
+    except sagemaker_session.sagemaker_client.exceptions.ResourceNotFound:
+        print(f"{feature_group.feature_group_name} not found.")
+        pass
+    except Exception as e:
+        raise RuntimeError(f"Failed to delete data for feature_group {feature_group_name}", e)
+
+
+def get_sagemaker_client(sagemaker_session=Session) -> client:
+    region_name = sagemaker_session.boto_session.region_name
+    return sagemaker_session.boto_session.client(service_name="sagemaker", region_name=region_name)
+
+
+def _generate_and_move_sagemaker_sdk_tar():
+    """
+    Run setup.py sdist to generate the PySDK whl file
+    """
+    repo_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "..", "..", ".."))
+    subprocess.run("python -m build --wheel", shell=True, cwd=repo_root, check=True)
+    dist_dir = os.path.join(repo_root, "dist")
+    source_archive = os.listdir(dist_dir)[0]
+    source_path = os.path.join(dist_dir, source_archive)
+
+    return source_path
+
+
+def _wait_for_feature_group_lineage_contexts(
+    feature_group_name, sagemaker_session, max_attempts=12, delay=10
+):
+    """Wait for the lineage contexts to be created for a feature group.
+
+    A separate backend service asynchronously creates lineage contexts after a feature group
+    is created. This helper polls until they exist or the timeout is reached.
+    """
+    from sagemaker.mlops.feature_store.feature_processor.feature_scheduler import (
+        _validate_fg_lineage_resources,
+    )
+
+    for attempt in range(max_attempts):
+        try:
+            _validate_fg_lineage_resources(feature_group_name, sagemaker_session)
+            logging.info(
+                "Lineage contexts ready for %s after %d seconds.",
+                feature_group_name,
+                attempt * delay,
+            )
+            return
+        except Exception:
+            logging.info(
+                "Waiting for lineage contexts for %s (attempt %d/%d)...",
+                feature_group_name,
+                attempt + 1,
+                max_attempts,
+            )
+            time.sleep(delay)
+
+    raise TimeoutError(
+        f"Lineage contexts for feature group {feature_group_name} were not created "
+        f"after {max_attempts * delay} seconds."
+    )

From 4079a5d82c164e5cfe93f111e77b3ba1c59451ae Mon Sep 17 00:00:00 2001
From: Basssem Halim
Date: Wed, 25 Feb 2026 12:09:55 -0800
Subject: [PATCH 04/21] fix: add feature-processor extra and pass a Spark
 session to BaseDataSource.read_data

---
 pyproject.toml                                                 | 1 +
 .../mlops/feature_store/feature_processor/_udf_arg_provider.py | 3 ++-
 2 files changed, 3 insertions(+), 1 deletion(-)

diff --git a/pyproject.toml b/pyproject.toml
index 8680debc53..ba670e80f7 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -43,6 +43,7 @@ dependencies = [
 train = ["sagemaker-train"]
 serve = ["sagemaker-serve"]
 mlops = ["sagemaker-mlops"]
+feature-processor = ["sagemaker-mlops", "pyspark==3.3.2", "sagemaker-feature-store-pyspark-3.3"]
 all = ["sagemaker-train", "sagemaker-serve", "sagemaker-mlops"]
 
 [project.urls]
diff --git a/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_udf_arg_provider.py b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_udf_arg_provider.py
index bd21e804eb..00810455b0 100644
--- a/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_udf_arg_provider.py
+++ b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_udf_arg_provider.py
@@ -222,7 +222,8 @@ def _load_data_frame(
         return data_source.read_data(spark=spark_session, params=params)
 
     if isinstance(data_source, BaseDataSource):
-        return data_source.read_data(params=params)
+        spark_session = self.spark_session_factory.spark_session
+        return data_source.read_data(spark=spark_session, params=params)
 
     raise ValueError(f"Unknown data source type: {type(data_source)}")

From e6721ce648c5c8cdc46faf31c429bcc9be480b71 Mon Sep 17 00:00:00 2001
From: Basssem Halim
Date: Wed, 25 Feb 2026 13:56:22 -0800
Subject: [PATCH 05/21] chore(docs): Add API docs

---
 docs/api/sagemaker_mlops.rst | 9 +++++++++
 1 file changed, 9 insertions(+)

diff --git a/docs/api/sagemaker_mlops.rst b/docs/api/sagemaker_mlops.rst
index f67879111d..b9aa069c7d 100644
--- a/docs/api/sagemaker_mlops.rst
+++ b/docs/api/sagemaker_mlops.rst
@@ -28,3 +28,12 @@ Local Development
    :members:
    :undoc-members:
    :show-inheritance:
+
+
+Feature Store
+-------------
+
+.. automodule:: sagemaker.mlops.feature_store
+   :members:
+   :undoc-members:
+   :show-inheritance:
\ No newline at end of file

From 611f5775892a32b66652a97fb24ac0e69aa0663e Mon Sep 17 00:00:00 2001
From: Basssem Halim
Date: Thu, 26 Feb 2026 10:26:01 -0800
Subject: [PATCH 06/21] fix: Fix flaky integ tests

---
 .../feature_processor/requirements.txt |   1 +
 .../test_feature_processor_integ.py    | 150 +++++++++++-------
 2 files changed, 94 insertions(+), 57 deletions(-)
 create mode 100644 sagemaker-mlops/tests/data/feature_store/feature_processor/requirements.txt

diff --git a/sagemaker-mlops/tests/data/feature_store/feature_processor/requirements.txt b/sagemaker-mlops/tests/data/feature_store/feature_processor/requirements.txt
new file mode 100644
index 0000000000..2940fa2876
--- /dev/null
+++ b/sagemaker-mlops/tests/data/feature_store/feature_processor/requirements.txt
@@ -0,0 +1 @@
+# unreleased sagemaker is installed via pre_execution_commands
\ No newline at end of file
diff --git a/sagemaker-mlops/tests/integ/feature_store/feature_processor/test_feature_processor_integ.py b/sagemaker-mlops/tests/integ/feature_store/feature_processor/test_feature_processor_integ.py
index a49910cf07..a9f0d0c1df 100644
--- a/sagemaker-mlops/tests/integ/feature_store/feature_processor/test_feature_processor_integ.py
+++ b/sagemaker-mlops/tests/integ/feature_store/feature_processor/test_feature_processor_integ.py
@@ -12,10 +12,12 @@
 # language governing permissions and limitations under the License.
 from __future__ import absolute_import
 
+import glob
 import logging
 import os
 import subprocess
 import sys
+import tempfile
 import time
 from typing import Dict
 from datetime import datetime
@@ -121,6 +123,16 @@ def sagemaker_session():
     return Session()
 
 
+@pytest.fixture(scope="module")
+def pre_execution_commands(sagemaker_session):
+    return get_pre_execution_commands(sagemaker_session=sagemaker_session)
+
+
+@pytest.fixture(scope="module")
+def dependencies_path():
+    return os.path.join(_FEATURE_PROCESSOR_DIR, "requirements.txt")
+
+
 @pytest.mark.slow_test
 def test_feature_processor_transform_online_only_store_ingestion(
     sagemaker_session,
@@ -137,8 +149,6 @@ def test_feature_processor_transform_online_only_store_ingestion(
 
     raw_data_uri = get_raw_car_data_s3_uri(sagemaker_session=sagemaker_session)
 
-    print("About to apply @feature_processor decorator...")
-
     @feature_processor(
         inputs=[CSVDataSource(raw_data_uri)],
         output=feature_groups["car_data_arn"],
@@ -175,9 +185,8 @@ def transform(raw_s3_data_as_df):
         transformed_df.show()
         return transformed_df
 
-    print("Decorator applied. 
About to call transform()...") - transform() - print("transform() completed.") + # this calls spark 3.3 which requires java 11 + transform() featurestore_client = sagemaker_session.sagemaker_featurestore_runtime_client results = featurestore_client.batch_get_record( @@ -496,7 +505,10 @@ def transform(raw_s3_data_as_df): columns=["ingest_time", "write_time", "api_invocation_time", "is_deleted"] ) - assert dataset.equals(get_expected_dataframe()) + expected = get_expected_dataframe() + dataset_sorted = dataset.sort_values(by="id").reset_index(drop=True) + expected_sorted = expected.sort_values(by="id").reset_index(drop=True) + assert dataset_sorted.equals(expected_sorted) finally: cleanup_offline_store( feature_group=feature_groups["car_data_feature_group"], @@ -521,6 +533,8 @@ def transform(raw_s3_data_as_df): ) def test_feature_processor_transform_offline_only_store_ingestion_run_with_remote( sagemaker_session, + pre_execution_commands, + dependencies_path, ): car_data_feature_group_name = get_car_data_feature_group_name() car_data_aggregated_feature_group_name = get_car_data_aggregated_feature_group_name() @@ -533,16 +547,10 @@ def test_feature_processor_transform_offline_only_store_ingestion_run_with_remot ) raw_data_uri = get_raw_car_data_s3_uri(sagemaker_session=sagemaker_session) - whl_file_uri = get_wheel_file_s3_uri(sagemaker_session=sagemaker_session) - whl_file_name = os.path.basename(whl_file_uri) - - pre_execution_commands = [ - f"aws s3 cp {whl_file_uri} ./", - f"/usr/local/bin/python3.9 -m pip install ./{whl_file_name} --force-reinstall", - ] @remote( pre_execution_commands=pre_execution_commands, + dependencies=dependencies_path, spark_config=SparkConfig(), instance_type="ml.m5.xlarge", ) @@ -637,7 +645,10 @@ def transform(raw_s3_data_as_df): columns=["ingest_time", "write_time", "api_invocation_time", "is_deleted"] ) - assert dataset.equals(get_expected_dataframe()) + expected = get_expected_dataframe() + dataset_sorted = dataset.sort_values(by="id").reset_index(drop=True) + expected_sorted = expected.sort_values(by="id").reset_index(drop=True) + assert dataset_sorted.equals(expected_sorted) finally: cleanup_offline_store( feature_group=feature_groups["car_data_feature_group"], @@ -662,6 +673,8 @@ def transform(raw_s3_data_as_df): ) def test_to_pipeline_and_execute( sagemaker_session, + pre_execution_commands, + dependencies_path, ): pipeline_name = "pipeline-name-01" car_data_feature_group_name = get_car_data_feature_group_name() @@ -675,16 +688,10 @@ def test_to_pipeline_and_execute( ) raw_data_uri = get_raw_car_data_s3_uri(sagemaker_session=sagemaker_session) - whl_file_uri = get_wheel_file_s3_uri(sagemaker_session=sagemaker_session) - whl_file_name = os.path.basename(whl_file_uri) - - pre_execution_commands = [ - f"aws s3 cp {whl_file_uri} ./", - f"/usr/local/bin/python3.9 -m pip install ./{whl_file_name} --force-reinstall", - ] @remote( pre_execution_commands=pre_execution_commands, + dependencies=dependencies_path, spark_config=SparkConfig(), instance_type="ml.m5.xlarge", ) @@ -789,6 +796,8 @@ def transform(raw_s3_data_as_df): ) def test_schedule_and_event_trigger( sagemaker_session, + pre_execution_commands, + dependencies_path, ): pipeline_name = "pipeline-name-01" car_data_feature_group_name = get_car_data_feature_group_name() @@ -802,16 +811,10 @@ def test_schedule_and_event_trigger( ) raw_data_uri = get_raw_car_data_s3_uri(sagemaker_session=sagemaker_session) - whl_file_uri = get_wheel_file_s3_uri(sagemaker_session=sagemaker_session) - whl_file_name = 
os.path.basename(whl_file_uri)
-
-    pre_execution_commands = [
-        f"aws s3 cp {whl_file_uri} ./",
-        f"/usr/local/bin/python3.9 -m pip install ./{whl_file_name} --force-reinstall",
-    ]
 
     @remote(
         pre_execution_commands=pre_execution_commands,
+        dependencies=dependencies_path,
         spark_config=SparkConfig(),
         instance_type="ml.m5.xlarge",
     )
@@ -1042,7 +1045,6 @@ def get_raw_car_data_s3_uri(sagemaker_session) -> str:
         "feature-processor-test",
         "csv-data",
     )
-    print("About to upload raw car data to S3...")
     raw_car_data_s3_uri = S3Uploader.upload(
         os.path.join(_FEATURE_PROCESSOR_DIR, "car-data.csv"),
         uri,
@@ -1052,18 +1054,36 @@ def get_raw_car_data_s3_uri(sagemaker_session) -> str:
     return raw_car_data_s3_uri
 
 
-def get_wheel_file_s3_uri(sagemaker_session) -> str:
-    uri = "s3://{}/{}/wheel-file".format(
+def get_wheel_file_s3_uri(sagemaker_session):
+    """Upload all SDK wheels to S3 and return (s3_prefix, wheel_basenames).
+
+    Returns:
+        tuple: (s3_prefix, [sagemaker_whl, core_whl, mlops_whl]) where each
+        element is the basename of the corresponding wheel file.
+    """
+    s3_prefix = "s3://{}/{}/wheel-file".format(
         sagemaker_session.default_bucket(), "feature-processor-test"
     )
-    source = _generate_and_move_sagemaker_sdk_tar()
-    print(source)
-    raw_car_data_s3_uri = S3Uploader.upload(
-        source,
-        uri,
-        sagemaker_session=sagemaker_session,
-    )
-    return raw_car_data_s3_uri
+    sources = _generate_and_move_sagemaker_sdk_tar()
+    for source in sources:
+        print(source)
+        S3Uploader.upload(source, s3_prefix, sagemaker_session=sagemaker_session)
+    wheel_names = [os.path.basename(s) for s in sources]
+    return s3_prefix, wheel_names
+
+
+def get_pre_execution_commands(sagemaker_session):
+    """Build SDK wheels, upload to S3, and return pre-execution install commands."""
+    s3_prefix, wheel_names = get_wheel_file_s3_uri(sagemaker_session=sagemaker_session)
+    sagemaker_whl, core_whl, mlops_whl = wheel_names
+    print(f'{sagemaker_whl=}, {core_whl=}, {mlops_whl=}')
+    return [
+        f"aws s3 cp {s3_prefix}/ /tmp/packages/ --recursive",
+        "pip3 install 'setuptools<75'",
+        f"pip3 install --no-build-isolation '/tmp/packages/{sagemaker_whl}[feature-processor]' 'numpy<2.0.0' 'ml_dtypes<=0.4.1' 'setuptools<75' || true",
+        f"pip3 install --no-deps --force-reinstall /tmp/packages/{sagemaker_whl}",
+        f"pip3 install --no-deps --force-reinstall /tmp/packages/{core_whl} /tmp/packages/{mlops_whl}",
+    ]
 
 
 def create_feature_groups(
@@ -1170,16 +1190,7 @@ def get_expected_dataframe():
 
 
 def _wait_for_feature_group_create(feature_group: FeatureGroup):
-    status = feature_group.feature_group_status
-    while status == "Creating":
-        print("Waiting for Feature Group Creation")
-        time.sleep(5)
-        feature_group.refresh()
-        status = feature_group.feature_group_status
-    if status != "Created":
-        print(f"FeatureGroup {feature_group.feature_group_name} status: {status}")
-        raise RuntimeError(f"Failed to create feature group {feature_group.feature_group_name}")
-    print(f"FeatureGroup {feature_group.feature_group_name} successfully created.")
+    feature_group.wait_for_status(target_status="Created", poll=5)
 
 
 def _wait_for_pipeline_execution_to_stop(pipeline_execution_arn: str, sagemaker_client: client):
@@ -1304,16 +1315,41 @@ def get_sagemaker_client(sagemaker_session=Session) -> client:
 
 
 def _generate_and_move_sagemaker_sdk_tar():
-    """
-    Run setup.py sdist to generate the PySDK whl file
-    """
-    repo_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "..", "..", ".."))
-    subprocess.run("python -m build --wheel", shell=True, cwd=repo_root, 
check=True) + """Build all three SDK wheel files and return their paths.""" + repo_root = os.path.abspath( + os.path.join(os.path.dirname(__file__), "..", "..", "..", "..", "..") + ) dist_dir = os.path.join(repo_root, "dist") - source_archive = os.listdir(dist_dir)[0] - source_path = os.path.join(dist_dir, source_archive) - return source_path + # Build wheels for all three sub-packages into the shared dist/ directory + build_dirs = [ + repo_root, + os.path.join(repo_root, "sagemaker-core"), + os.path.join(repo_root, "sagemaker-mlops"), + ] + for build_dir in build_dirs: + subprocess.run( + f"python -m build --wheel --outdir {dist_dir}", + shell=True, + cwd=build_dir, + check=True, + ) + + # Locate the three expected wheels by prefix pattern + wheel_patterns = [ + "sagemaker-[0-9]*.whl", + "sagemaker_core-*.whl", + "sagemaker_mlops-*.whl", + ] + paths = [] + for pattern in wheel_patterns: + matches = glob.glob(os.path.join(dist_dir, pattern)) + if not matches: + raise FileNotFoundError( + f"No wheel found matching {pattern} in {dist_dir}" + ) + paths.append(matches[0]) + return paths def _wait_for_feature_group_lineage_contexts( From 6f4f490471f340446e5afe97af1cb2c1d0020897 Mon Sep 17 00:00:00 2001 From: Basssem Halim Date: Fri, 27 Feb 2026 11:29:42 -0800 Subject: [PATCH 07/21] fix diff --- .../src/sagemaker/mlops/feature_store/feature_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_utils.py b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_utils.py index 5ee04780be..3e3e7813df 100644 --- a/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_utils.py +++ b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_utils.py @@ -731,4 +731,4 @@ def _cast_object_to_string(data_frame: pandas.DataFrame) -> pandas.DataFrame: """ for label in data_frame.select_dtypes(["object", "O"]).columns.tolist(): data_frame[label] = data_frame[label].astype("str").astype("string") - return data_frame + return data_frame \ No newline at end of file From 09abb01f43f8ad93ec44435c719920025847e9ba Mon Sep 17 00:00:00 2001 From: Basssem Halim Date: Fri, 27 Feb 2026 13:03:18 -0800 Subject: [PATCH 08/21] chore: rename parameter + cleanup comments --- .../feature_processor/feature_scheduler.py | 11 ++++------- .../test_feature_processor_integ.py | 4 ++-- .../feature_processor/test_feature_scheduler.py | 12 ++++++------ 3 files changed, 12 insertions(+), 15 deletions(-) diff --git a/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/feature_scheduler.py b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/feature_scheduler.py index c9039d982c..4f67a9dc5d 100644 --- a/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/feature_scheduler.py +++ b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/feature_scheduler.py @@ -116,7 +116,7 @@ def to_pipeline( pipeline_name: str, step: Callable, - role: Optional[str] = None, + role_arn: Optional[str] = None, transformation_code: Optional[TransformationCode] = None, max_retries: Optional[int] = None, tags: Optional[List[Tuple[str, str]]] = None, @@ -133,7 +133,7 @@ def to_pipeline( pipeline_name (str): The name of the pipeline. step (Callable): A user provided function wrapped by feature_processor and optionally wrapped by remote_decorator. 
- role (Optional[str]): The Amazon Resource Name (ARN) of the role used by the pipeline to + role_arn (Optional[str]): The Amazon Resource Name (ARN) of the role used by the pipeline to access and create resources. If not specified, it will default to the credentials provided by the AWS configuration chain. transformation_code (Optional[str]): The data source for a reference to the transformation @@ -162,7 +162,7 @@ def to_pipeline( remote_decorator_config = _get_remote_decorator_config_from_input( wrapped_func=step, sagemaker_session=_sagemaker_session ) - _role = role or get_execution_role(_sagemaker_session) + _role = role_arn or get_execution_role(_sagemaker_session) runtime_env_manager = RuntimeEnvironmentManager() client_python_version = runtime_env_manager._current_python_version() @@ -863,7 +863,7 @@ def _prepare_model_trainer_from_remote_decorator_config( """ logger.info("Mapping remote decorator config to ModelTrainer params") - # Build environment dict from remote_decorator_config (strings only for Pydantic validation) + # Build environment dict from remote_decorator_config environment = dict(remote_decorator_config.environment_variables or {}) # Build command from container entry point and arguments @@ -946,9 +946,6 @@ def _prepare_model_trainer_from_remote_decorator_config( tags=tags, ) - # Inject SCHEDULED_TIME_PIPELINE_PARAMETER after construction to bypass Pydantic - # validation (Parameter is not a string). The @runnable_by_pipeline decorator resolves - # Parameter objects to strings during pipeline definition serialization. model_trainer.environment[EXECUTION_TIME_PIPELINE_PARAMETER] = SCHEDULED_TIME_PIPELINE_PARAMETER logger.info( diff --git a/sagemaker-mlops/tests/integ/feature_store/feature_processor/test_feature_processor_integ.py b/sagemaker-mlops/tests/integ/feature_store/feature_processor/test_feature_processor_integ.py index a9f0d0c1df..6994971f22 100644 --- a/sagemaker-mlops/tests/integ/feature_store/feature_processor/test_feature_processor_integ.py +++ b/sagemaker-mlops/tests/integ/feature_store/feature_processor/test_feature_processor_integ.py @@ -738,7 +738,7 @@ def transform(raw_s3_data_as_df): pipeline_arn = to_pipeline( pipeline_name=pipeline_name, step=transform, - role=get_execution_role(sagemaker_session), + role_arn=get_execution_role(sagemaker_session), max_retries=2, tags=[("integ_test_tag_key_1", "integ_test_tag_key_2")], sagemaker_session=sagemaker_session, @@ -861,7 +861,7 @@ def transform(raw_s3_data_as_df): pipeline_arn = to_pipeline( pipeline_name=pipeline_name, step=transform, - role=get_execution_role(sagemaker_session), + role_arn=get_execution_role(sagemaker_session), max_retries=2, sagemaker_session=sagemaker_session, ) diff --git a/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_feature_scheduler.py b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_feature_scheduler.py index 1cd7e381e0..5f3fb85934 100644 --- a/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_feature_scheduler.py +++ b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_feature_scheduler.py @@ -275,7 +275,7 @@ def test_to_pipeline( pipeline_arn = to_pipeline( pipeline_name="pipeline_name", step=wrapped_func, - role=EXECUTION_ROLE_ARN, + role_arn=EXECUTION_ROLE_ARN, max_retries=1, tags=[("tag_key_1", "tag_value_1"), ("tag_key_2", "tag_value_2")], sagemaker_session=session, @@ -437,7 +437,7 @@ def 
test_to_pipeline_not_wrapped_by_feature_processor(get_execution_role, sessio to_pipeline( pipeline_name="pipeline_name", step=wrapped_func, - role=EXECUTION_ROLE_ARN, + role_arn=EXECUTION_ROLE_ARN, max_retries=1, ) @@ -461,7 +461,7 @@ def test_to_pipeline_not_wrapped_by_remote(get_execution_role, session): to_pipeline( pipeline_name="pipeline_name", step=wrapped_func, - role=EXECUTION_ROLE_ARN, + role_arn=EXECUTION_ROLE_ARN, max_retries=1, ) @@ -513,7 +513,7 @@ def test_to_pipeline_wrong_mode(get_execution_role, mock_spark_image, session): to_pipeline( pipeline_name="pipeline_name", step=wrapped_func, - role=EXECUTION_ROLE_ARN, + role_arn=EXECUTION_ROLE_ARN, max_retries=1, ) @@ -568,7 +568,7 @@ def test_to_pipeline_pipeline_name_length_limit_exceeds( to_pipeline( pipeline_name="".join(["a" for _ in range(PIPELINE_NAME_MAXIMUM_LENGTH + 1)]), step=wrapped_func, - role=EXECUTION_ROLE_ARN, + role_arn=EXECUTION_ROLE_ARN, max_retries=1, ) @@ -626,7 +626,7 @@ def test_to_pipeline_used_reserved_tags(get_execution_role, mock_spark_image, se to_pipeline( pipeline_name="pipeline_name", step=wrapped_func, - role=EXECUTION_ROLE_ARN, + role_arn=EXECUTION_ROLE_ARN, max_retries=1, tags=[("sm-fs-fe:created-from", "random")], sagemaker_session=session, From aa654e4d6bcc741ef0f04f55123efd0e29b56b11 Mon Sep 17 00:00:00 2001 From: Aditi Sharma <165942273+Aditi2424@users.noreply.github.com> Date: Wed, 4 Feb 2026 11:58:41 -0800 Subject: [PATCH 09/21] Feature store v3 (#5490) * feat: Add Feature Store Support to V3 * Add feature store tests --------- Co-authored-by: adishaa --- .../src/sagemaker/mlops/feature_store/feature_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_utils.py b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_utils.py index 3e3e7813df..5ee04780be 100644 --- a/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_utils.py +++ b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_utils.py @@ -731,4 +731,4 @@ def _cast_object_to_string(data_frame: pandas.DataFrame) -> pandas.DataFrame: """ for label in data_frame.select_dtypes(["object", "O"]).columns.tolist(): data_frame[label] = data_frame[label].astype("str").astype("string") - return data_frame \ No newline at end of file + return data_frame From 7cdd07743d30962ae5ed4352a27dac1121ec27ea Mon Sep 17 00:00:00 2001 From: Basssem Halim Date: Fri, 27 Feb 2026 15:50:17 -0800 Subject: [PATCH 10/21] add pyspark to test deps --- sagemaker-mlops/pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/sagemaker-mlops/pyproject.toml b/sagemaker-mlops/pyproject.toml index f61cc5467b..9004523d0f 100644 --- a/sagemaker-mlops/pyproject.toml +++ b/sagemaker-mlops/pyproject.toml @@ -35,6 +35,7 @@ test = [ "pytest", "pytest-cov", "mock", + "pyspark==3.3.2" ] dev = [ "pytest", From 101824fa517cb6f1a105323cf085c65314f95f9a Mon Sep 17 00:00:00 2001 From: Basssem Halim Date: Fri, 27 Feb 2026 17:48:12 -0800 Subject: [PATCH 11/21] add test deps --- sagemaker-mlops/pyproject.toml | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/sagemaker-mlops/pyproject.toml b/sagemaker-mlops/pyproject.toml index 9004523d0f..519cf526a5 100644 --- a/sagemaker-mlops/pyproject.toml +++ b/sagemaker-mlops/pyproject.toml @@ -35,7 +35,10 @@ test = [ "pytest", "pytest-cov", "mock", - "pyspark==3.3.2" + "pyspark==3.3.2", + "sagemaker-feature-store-pyspark-3.3", + "pandas", + "numpy", ] dev = [ "pytest", From 1e2b5fea14af01f9df015f9bf427f40faf3fa0a2 Mon Sep 17 
00:00:00 2001
From: Basssem Halim
Date: Mon, 2 Mar 2026 11:10:55 -0800
Subject: [PATCH 12/21] fix unit test deps

---
 sagemaker-mlops/pyproject.toml | 1 +
 1 file changed, 1 insertion(+)

diff --git a/sagemaker-mlops/pyproject.toml b/sagemaker-mlops/pyproject.toml
index 519cf526a5..5e7291abb6 100644
--- a/sagemaker-mlops/pyproject.toml
+++ b/sagemaker-mlops/pyproject.toml
@@ -35,6 +35,7 @@ test = [
     "pytest",
     "pytest-cov",
     "mock",
+    "setuptools",
     "pyspark==3.3.2",
     "sagemaker-feature-store-pyspark-3.3",
     "pandas",

From bbf1a3fa46ee375bba7de739df20105131a3d0ab Mon Sep 17 00:00:00 2001
From: Basssem Halim
Date: Mon, 2 Mar 2026 12:30:45 -0800
Subject: [PATCH 13/21] pin setuptools<82 for feature-processor and unit tests

---
 pyproject.toml                 | 2 +-
 sagemaker-mlops/pyproject.toml | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index ba670e80f7..4568453b27 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -43,7 +43,7 @@ dependencies = [
 train = ["sagemaker-train"]
 serve = ["sagemaker-serve"]
 mlops = ["sagemaker-mlops"]
-feature-processor = ["sagemaker-mlops", "pyspark==3.3.2", "sagemaker-feature-store-pyspark-3.3"]
+feature-processor = ["sagemaker-mlops", "pyspark==3.3.2", "sagemaker-feature-store-pyspark-3.3", "setuptools<82"]
 all = ["sagemaker-train", "sagemaker-serve", "sagemaker-mlops"]
 
 [project.urls]
diff --git a/sagemaker-mlops/pyproject.toml b/sagemaker-mlops/pyproject.toml
index 5e7291abb6..55b5872387 100644
--- a/sagemaker-mlops/pyproject.toml
+++ b/sagemaker-mlops/pyproject.toml
@@ -35,7 +35,7 @@ test = [
     "pytest",
     "pytest-cov",
     "mock",
-    "setuptools",
+    "setuptools<82",
     "pyspark==3.3.2",
     "sagemaker-feature-store-pyspark-3.3",
     "pandas",

From 5ca86dc16094ac9bc4474c3bc6e69b0b9c9037f7 Mon Sep 17 00:00:00 2001
From: Basssem Halim
Date: Mon, 2 Mar 2026 13:06:00 -0800
Subject: [PATCH 14/21] Set JAVA_HOME for integ tests, which require Java

---
 sagemaker-mlops/tox.ini | 1 +
 1 file changed, 1 insertion(+)

diff --git a/sagemaker-mlops/tox.ini b/sagemaker-mlops/tox.ini
index 4c28a1769f..a8f00c5bdc 100644
--- a/sagemaker-mlops/tox.ini
+++ b/sagemaker-mlops/tox.ini
@@ -67,6 +67,7 @@ markers =
 [testenv]
 setenv =
     PYTHONHASHSEED=42
+    JAVA_HOME={env:JAVA_HOME:/usr/lib/jvm/default-java}
     #PYTHONPATH = {toxinidir}/../sagemaker_utils/src:{toxinidir}/src
 pip_version = pip==24.3
 passenv =

From beb078d2c3cc90a6f6fd99b87f169dc83b6394eb Mon Sep 17 00:00:00 2001
From: Basssem Halim
Date: Mon, 2 Mar 2026 14:28:11 -0800
Subject: [PATCH 15/21] fix spark session bug: set Iceberg configs via
 SparkSession.conf

---
 .../feature_processor/_spark_factory.py |  2 +-
 .../test_spark_session_factory.py       | 21 ++++++++++---------
 2 files changed, 12 insertions(+), 11 deletions(-)

diff --git a/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_spark_factory.py b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_spark_factory.py
index 0d6d41506e..d304185e85 100644
--- a/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_spark_factory.py
+++ b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_spark_factory.py
@@ -180,7 +180,7 @@ def get_spark_session_with_iceberg_config(self, warehouse_s3_uri, catalog) -> Sp
             SparkSession: A SparkSession ready to support reading and writing data from an
                 Iceberg Table.
""" - conf = self.spark_session._jvm.SparkSession().conf() + conf = self.spark_session.conf for cfg in self._get_iceberg_configs(warehouse_s3_uri, catalog): conf.set(cfg[0], cfg[1]) diff --git a/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_spark_session_factory.py b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_spark_session_factory.py index 36ba54edd1..09fade48ed 100644 --- a/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_spark_session_factory.py +++ b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_spark_session_factory.py @@ -117,19 +117,20 @@ def test_spark_session_factory_with_iceberg_config(mock_spark_context): spark_session_factory = SparkSessionFactory(mock_env_helper) spark_session = spark_session_factory.spark_session + mock_conf = Mock() - spark_session_with_iceberg_config = spark_session_factory.get_spark_session_with_iceberg_config( - "warehouse", "catalog" - ) + with patch.object(type(spark_session), "conf", new_callable=lambda: property(lambda self: mock_conf)): + spark_session_with_iceberg_config = spark_session_factory.get_spark_session_with_iceberg_config( + "warehouse", "catalog" + ) - assert spark_session is spark_session_with_iceberg_config - mock_spark_conf = spark_session._jvm.SparkSession().conf() - expected_calls = [ - call.set(cfg[0], cfg[1]) - for cfg in spark_session_factory._get_iceberg_configs("warehouse", "catalog") - ] + assert spark_session is spark_session_with_iceberg_config + expected_calls = [ + call.set(cfg[0], cfg[1]) + for cfg in spark_session_factory._get_iceberg_configs("warehouse", "catalog") + ] - mock_spark_conf.assert_has_calls(expected_calls, any_order=False) + mock_conf.assert_has_calls(expected_calls, any_order=False) @patch("pyspark.context.SparkContext.getOrCreate") From 7a659af9ad13e39cdcadb0a3f1b6b7e99385e386 Mon Sep 17 00:00:00 2001 From: Basssem Halim Date: Mon, 2 Mar 2026 15:27:14 -0800 Subject: [PATCH 16/21] fix(feature-processor): Fix Spark session config and Ivy cache race condition Isolate Ivy cache per Spark session via spark.jars.ivy to prevent concurrent pytest-xdist workers from corrupting shared /root/.ivy2/cache during Maven dependency resolution in CI. 
--- .../mlops/feature_store/feature_processor/_spark_factory.py | 6 ++++++ .../feature_processor/test_spark_session_factory.py | 4 ++++ 2 files changed, 10 insertions(+) diff --git a/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_spark_factory.py b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_spark_factory.py index d304185e85..1738b14f37 100644 --- a/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_spark_factory.py +++ b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_spark_factory.py @@ -13,6 +13,8 @@ """Contains factory classes for instantiating Spark objects.""" from __future__ import absolute_import +import os +import tempfile from functools import lru_cache from typing import List, Tuple, Dict @@ -134,6 +136,10 @@ def _get_spark_configs(self, is_training_job) -> List[Tuple[str, str]]: "spark.jars.packages", ",".join(fp_spark_packages), ), + ( + "spark.jars.ivy", + os.path.join(tempfile.mkdtemp(), ".ivy2"), + ), ) ) diff --git a/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_spark_session_factory.py b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_spark_session_factory.py index 09fade48ed..88631d1ea9 100644 --- a/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_spark_session_factory.py +++ b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_spark_session_factory.py @@ -79,6 +79,9 @@ def test_spark_session_factory_configuration(): ] ) in spark_configs.get("spark.jars.packages") + assert spark_configs.get("spark.jars.ivy") is not None + assert ".ivy2" in spark_configs.get("spark.jars.ivy") + def test_spark_session_factory_configuration_on_training_job(): env_helper = Mock() @@ -90,6 +93,7 @@ def test_spark_session_factory_configuration_on_training_job(): assert all(tup[0] != "spark.jars" for tup in spark_config) assert all(tup[0] != "spark.jars.packages" for tup in spark_config) + assert all(tup[0] != "spark.jars.ivy" for tup in spark_config) @patch("pyspark.context.SparkContext.getOrCreate") From 59b68e8f79f26e653c57214df4bda5bc0e9cf571 Mon Sep 17 00:00:00 2001 From: Basssem Halim Date: Mon, 2 Mar 2026 16:13:08 -0800 Subject: [PATCH 17/21] revert previous change + create different ivy cache per test to fix concurrent writes in CI --- .../feature_processor/_spark_factory.py | 6 ---- .../feature_processor/conftest.py | 35 +++++++++++++++++++ .../test_spark_session_factory.py | 4 --- 3 files changed, 35 insertions(+), 10 deletions(-) create mode 100644 sagemaker-mlops/tests/integ/feature_store/feature_processor/conftest.py diff --git a/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_spark_factory.py b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_spark_factory.py index 1738b14f37..d304185e85 100644 --- a/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_spark_factory.py +++ b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_spark_factory.py @@ -13,8 +13,6 @@ """Contains factory classes for instantiating Spark objects.""" from __future__ import absolute_import -import os -import tempfile from functools import lru_cache from typing import List, Tuple, Dict @@ -136,10 +134,6 @@ def _get_spark_configs(self, is_training_job) -> List[Tuple[str, str]]: "spark.jars.packages", ",".join(fp_spark_packages), ), - ( - "spark.jars.ivy", - os.path.join(tempfile.mkdtemp(), ".ivy2"), - ), ) ) diff --git 
a/sagemaker-mlops/tests/integ/feature_store/feature_processor/conftest.py b/sagemaker-mlops/tests/integ/feature_store/feature_processor/conftest.py new file mode 100644 index 0000000000..022431c9af --- /dev/null +++ b/sagemaker-mlops/tests/integ/feature_store/feature_processor/conftest.py @@ -0,0 +1,35 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Conftest for feature processor integration tests.""" +import os +import tempfile + +import pytest + +from sagemaker.mlops.feature_store.feature_processor._spark_factory import SparkSessionFactory + + +@pytest.fixture(autouse=True, scope="session") +def isolate_ivy_cache(): + """Give each pytest-xdist worker its own Ivy cache to prevent concurrent cache corruption.""" + ivy_dir = os.path.join(tempfile.mkdtemp(), ".ivy2") + original = SparkSessionFactory._get_spark_configs + + def _patched_get_spark_configs(self, is_training_job): + configs = original(self, is_training_job) + configs.append(("spark.jars.ivy", ivy_dir)) + return configs + + SparkSessionFactory._get_spark_configs = _patched_get_spark_configs + yield + SparkSessionFactory._get_spark_configs = original diff --git a/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_spark_session_factory.py b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_spark_session_factory.py index 88631d1ea9..09fade48ed 100644 --- a/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_spark_session_factory.py +++ b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_spark_session_factory.py @@ -79,9 +79,6 @@ def test_spark_session_factory_configuration(): ] ) in spark_configs.get("spark.jars.packages") - assert spark_configs.get("spark.jars.ivy") is not None - assert ".ivy2" in spark_configs.get("spark.jars.ivy") - def test_spark_session_factory_configuration_on_training_job(): env_helper = Mock() @@ -93,7 +90,6 @@ def test_spark_session_factory_configuration_on_training_job(): assert all(tup[0] != "spark.jars" for tup in spark_config) assert all(tup[0] != "spark.jars.packages" for tup in spark_config) - assert all(tup[0] != "spark.jars.ivy" for tup in spark_config) @patch("pyspark.context.SparkContext.getOrCreate") From a1b1bc3ecd67b34cc92333ca1e40857bbc1a243a Mon Sep 17 00:00:00 2001 From: Basssem Halim Date: Tue, 3 Mar 2026 12:03:30 -0800 Subject: [PATCH 18/21] revert changes to sagemaker-core --- .../sagemaker/core/helper/session_helper.py | 220 ------- .../tests/unit/session/test_session_helper.py | 545 ------------------ 2 files changed, 765 deletions(-) diff --git a/sagemaker-core/src/sagemaker/core/helper/session_helper.py b/sagemaker-core/src/sagemaker/core/helper/session_helper.py index b4c327ec09..41957e30a2 100644 --- a/sagemaker-core/src/sagemaker/core/helper/session_helper.py +++ b/sagemaker-core/src/sagemaker/core/helper/session_helper.py @@ -85,10 +85,6 @@ TAGS, SESSION_DEFAULT_S3_BUCKET_PATH, SESSION_DEFAULT_S3_OBJECT_KEY_PREFIX_PATH, - FEATURE_GROUP, - 
FEATURE_GROUP_ROLE_ARN_PATH, - FEATURE_GROUP_ONLINE_STORE_CONFIG_PATH, - FEATURE_GROUP_OFFLINE_STORE_CONFIG_PATH, ) # Setting LOGGER for backward compatibility, in case users import it... @@ -1611,222 +1607,6 @@ def delete_endpoint_config(self, endpoint_config_name): logger.info("Deleting endpoint configuration with name: %s", endpoint_config_name) self.sagemaker_client.delete_endpoint_config(EndpointConfigName=endpoint_config_name) - def delete_feature_group(self, feature_group_name): - """Delete an Amazon SageMaker Feature Group. - - Args: - feature_group_name (str): Name of the Amazon SageMaker Feature Group to delete. - """ - logger.info("Deleting feature group with name: %s", feature_group_name) - self.sagemaker_client.delete_feature_group(FeatureGroupName=feature_group_name) - - def create_feature_group( - self, - feature_group_name, - record_identifier_name, - event_time_feature_name, - feature_definitions, - role_arn=None, - online_store_config=None, - offline_store_config=None, - throughput_config=None, - description=None, - tags=None, - ): - """Create an Amazon SageMaker Feature Group. - - Args: - feature_group_name (str): Name of the Feature Group. - record_identifier_name (str): Name of the record identifier feature. - event_time_feature_name (str): Name of the event time feature. - feature_definitions (list): List of feature definitions. - role_arn (str): ARN of the role used to execute the API (default: None). - Resolved from SageMaker Config if not provided. - online_store_config (dict): Online store configuration (default: None). - offline_store_config (dict): Offline store configuration (default: None). - throughput_config (dict): Throughput configuration (default: None). - description (str): Description of the Feature Group (default: None). - tags (Optional[Tags]): Tags for labeling the Feature Group (default: None). - - Returns: - dict: Response from the CreateFeatureGroup API. - """ - tags = format_tags(tags) - tags = _append_project_tags(tags) - tags = self._append_sagemaker_config_tags( - tags, "{}.{}.{}".format(SAGEMAKER, FEATURE_GROUP, TAGS) - ) - - role_arn = resolve_value_from_config( - role_arn, FEATURE_GROUP_ROLE_ARN_PATH, sagemaker_session=self - ) - - inferred_online_store_config = update_nested_dictionary_with_values_from_config( - online_store_config, - FEATURE_GROUP_ONLINE_STORE_CONFIG_PATH, - sagemaker_session=self, - ) - if inferred_online_store_config is not None: - # OnlineStore should be handled differently because if you set KmsKeyId, then you - # need to set EnableOnlineStore key as well - inferred_online_store_config["EnableOnlineStore"] = True - - inferred_offline_store_config = update_nested_dictionary_with_values_from_config( - offline_store_config, - FEATURE_GROUP_OFFLINE_STORE_CONFIG_PATH, - sagemaker_session=self, - ) - - kwargs = dict( - FeatureGroupName=feature_group_name, - RecordIdentifierFeatureName=record_identifier_name, - EventTimeFeatureName=event_time_feature_name, - FeatureDefinitions=feature_definitions, - RoleArn=role_arn, - ) - update_args( - kwargs, - OnlineStoreConfig=inferred_online_store_config, - OfflineStoreConfig=inferred_offline_store_config, - ThroughputConfig=throughput_config, - Description=description, - Tags=tags, - ) - - logger.info("Creating feature group with name: %s", feature_group_name) - return self.sagemaker_client.create_feature_group(**kwargs) - - def describe_feature_group(self, feature_group_name, next_token=None): - """Describe an Amazon SageMaker Feature Group. 
- - Args: - feature_group_name (str): Name of the Amazon SageMaker Feature Group to describe. - next_token (str): A token for paginated results (default: None). - - Returns: - dict: Response from the DescribeFeatureGroup API. - """ - args = {"FeatureGroupName": feature_group_name} - update_args(args, NextToken=next_token) - return self.sagemaker_client.describe_feature_group(**args) - - def update_feature_group( - self, - feature_group_name, - feature_additions=None, - online_store_config=None, - throughput_config=None, - ): - """Update an Amazon SageMaker Feature Group. - - Args: - feature_group_name (str): Name of the Amazon SageMaker Feature Group to update. - feature_additions (list): List of feature definitions to add (default: None). - online_store_config (dict): Online store configuration updates (default: None). - throughput_config (dict): Throughput configuration updates (default: None). - - Returns: - dict: Response from the UpdateFeatureGroup API. - """ - args = {"FeatureGroupName": feature_group_name} - update_args( - args, - FeatureAdditions=feature_additions, - OnlineStoreConfig=online_store_config, - ThroughputConfig=throughput_config, - ) - return self.sagemaker_client.update_feature_group(**args) - - def list_feature_groups( - self, - name_contains=None, - feature_group_status_equals=None, - offline_store_status_equals=None, - creation_time_after=None, - creation_time_before=None, - sort_order=None, - sort_by=None, - max_results=None, - next_token=None, - ): - """List Amazon SageMaker Feature Groups. - - Args: - name_contains (str): Filter by name substring (default: None). - feature_group_status_equals (str): Filter by status (default: None). - offline_store_status_equals (str): Filter by offline store status (default: None). - creation_time_after (datetime): Filter by creation time lower bound (default: None). - creation_time_before (datetime): Filter by creation time upper bound (default: None). - sort_order (str): Sort order, 'Ascending' or 'Descending' (default: None). - sort_by (str): Sort by field (default: None). - max_results (int): Maximum number of results (default: None). - next_token (str): Pagination token (default: None). - - Returns: - dict: Response from the ListFeatureGroups API. - """ - args = {} - update_args( - args, - NameContains=name_contains, - FeatureGroupStatusEquals=feature_group_status_equals, - OfflineStoreStatusEquals=offline_store_status_equals, - CreationTimeAfter=creation_time_after, - CreationTimeBefore=creation_time_before, - SortOrder=sort_order, - SortBy=sort_by, - MaxResults=max_results, - NextToken=next_token, - ) - return self.sagemaker_client.list_feature_groups(**args) - - def update_feature_metadata( - self, - feature_group_name, - feature_name, - description=None, - parameter_additions=None, - parameter_removals=None, - ): - """Update metadata for a feature in an Amazon SageMaker Feature Group. - - Args: - feature_group_name (str): Name of the Feature Group. - feature_name (str): Name of the feature to update metadata for. - description (str): Updated description for the feature (default: None). - parameter_additions (list): Parameters to add (default: None). - parameter_removals (list): Parameters to remove (default: None). - - Returns: - dict: Response from the UpdateFeatureMetadata API. 
- """ - args = { - "FeatureGroupName": feature_group_name, - "FeatureName": feature_name, - } - update_args( - args, - Description=description, - ParameterAdditions=parameter_additions, - ParameterRemovals=parameter_removals, - ) - return self.sagemaker_client.update_feature_metadata(**args) - - def describe_feature_metadata(self, feature_group_name, feature_name): - """Describe metadata for a feature in an Amazon SageMaker Feature Group. - - Args: - feature_group_name (str): Name of the Feature Group. - feature_name (str): Name of the feature to describe metadata for. - - Returns: - dict: Response from the DescribeFeatureMetadata API. - """ - return self.sagemaker_client.describe_feature_metadata( - FeatureGroupName=feature_group_name, - FeatureName=feature_name, - ) - def wait_for_optimization_job(self, job, poll=5): """Wait for an Amazon SageMaker Optimization job to complete. diff --git a/sagemaker-core/tests/unit/session/test_session_helper.py b/sagemaker-core/tests/unit/session/test_session_helper.py index 7e2004c1d0..ca4fd81aa8 100644 --- a/sagemaker-core/tests/unit/session/test_session_helper.py +++ b/sagemaker-core/tests/unit/session/test_session_helper.py @@ -29,11 +29,6 @@ update_args, NOTEBOOK_METADATA_FILE, ) -from sagemaker.core.config.config_schema import ( - FEATURE_GROUP_ROLE_ARN_PATH, - FEATURE_GROUP_ONLINE_STORE_CONFIG_PATH, - FEATURE_GROUP_OFFLINE_STORE_CONFIG_PATH, -) class TestSession: @@ -456,543 +451,3 @@ def test_update_args_with_none_values(self): assert args["existing"] == "value" assert "new_key" not in args assert args["another_key"] == "another_value" - -class TestFeatureGroupSessionMethods: - """Test cases for Feature Group session methods""" - - @pytest.fixture - def session_with_mock_client(self): - """Create a Session with a mocked sagemaker_client.""" - mock_boto_session = Mock() - mock_boto_session.region_name = "us-west-2" - mock_boto_session.client.return_value = Mock() - mock_boto_session.resource.return_value = Mock() - session = Session(boto_session=mock_boto_session) - session.sagemaker_client = Mock() - return session - - # --- delete_feature_group --- - - def test_delete_feature_group(self, session_with_mock_client): - """Test delete_feature_group delegates to sagemaker_client.""" - session = session_with_mock_client - session.delete_feature_group("my-feature-group") - - session.sagemaker_client.delete_feature_group.assert_called_once_with( - FeatureGroupName="my-feature-group" - ) - - # --- describe_feature_group --- - - def test_describe_feature_group(self, session_with_mock_client): - """Test describe_feature_group delegates and returns response.""" - session = session_with_mock_client - expected = {"FeatureGroupName": "my-fg", "CreationTime": "2024-01-01"} - session.sagemaker_client.describe_feature_group.return_value = expected - - result = session.describe_feature_group("my-fg") - - session.sagemaker_client.describe_feature_group.assert_called_once_with( - FeatureGroupName="my-fg" - ) - assert result == expected - - def test_describe_feature_group_with_next_token(self, session_with_mock_client): - """Test describe_feature_group includes NextToken when provided.""" - session = session_with_mock_client - session.sagemaker_client.describe_feature_group.return_value = {} - - session.describe_feature_group("my-fg", next_token="abc123") - - session.sagemaker_client.describe_feature_group.assert_called_once_with( - FeatureGroupName="my-fg", NextToken="abc123" - ) - - def test_describe_feature_group_omits_none_next_token(self, 
-        """Test describe_feature_group omits NextToken when None."""
-        session = session_with_mock_client
-        session.sagemaker_client.describe_feature_group.return_value = {}
-
-        session.describe_feature_group("my-fg", next_token=None)
-
-        call_kwargs = session.sagemaker_client.describe_feature_group.call_args[1]
-        assert "NextToken" not in call_kwargs
-
-    # --- update_feature_group ---
-
-    def test_update_feature_group_all_params(self, session_with_mock_client):
-        """Test update_feature_group with all optional params provided."""
-        session = session_with_mock_client
-        expected = {"FeatureGroupArn": "arn:aws:sagemaker:us-west-2:123:feature-group/my-fg"}
-        session.sagemaker_client.update_feature_group.return_value = expected
-
-        additions = [{"FeatureName": "new_feat", "FeatureType": "String"}]
-        online_cfg = {"EnableOnlineStore": True}
-        throughput_cfg = {"ThroughputMode": "OnDemand"}
-
-        result = session.update_feature_group(
-            "my-fg",
-            feature_additions=additions,
-            online_store_config=online_cfg,
-            throughput_config=throughput_cfg,
-        )
-
-        session.sagemaker_client.update_feature_group.assert_called_once_with(
-            FeatureGroupName="my-fg",
-            FeatureAdditions=additions,
-            OnlineStoreConfig=online_cfg,
-            ThroughputConfig=throughput_cfg,
-        )
-        assert result == expected
-
-    def test_update_feature_group_omits_none_params(self, session_with_mock_client):
-        """Test update_feature_group omits None optional params."""
-        session = session_with_mock_client
-        session.sagemaker_client.update_feature_group.return_value = {}
-
-        session.update_feature_group("my-fg")
-
-        call_kwargs = session.sagemaker_client.update_feature_group.call_args[1]
-        assert call_kwargs == {"FeatureGroupName": "my-fg"}
-
-    def test_update_feature_group_partial_params(self, session_with_mock_client):
-        """Test update_feature_group with only some optional params."""
-        session = session_with_mock_client
-        session.sagemaker_client.update_feature_group.return_value = {}
-
-        throughput_cfg = {"ThroughputMode": "Provisioned"}
-        session.update_feature_group("my-fg", throughput_config=throughput_cfg)
-
-        call_kwargs = session.sagemaker_client.update_feature_group.call_args[1]
-        assert call_kwargs == {
-            "FeatureGroupName": "my-fg",
-            "ThroughputConfig": throughput_cfg,
-        }
-
-    # --- list_feature_groups ---
-
-    def test_list_feature_groups_no_params(self, session_with_mock_client):
-        """Test list_feature_groups with no filters delegates with empty args."""
-        session = session_with_mock_client
-        expected = {"FeatureGroupSummaries": []}
-        session.sagemaker_client.list_feature_groups.return_value = expected
-
-        result = session.list_feature_groups()
-
-        session.sagemaker_client.list_feature_groups.assert_called_once_with()
-        assert result == expected
-
-    def test_list_feature_groups_all_params(self, session_with_mock_client):
-        """Test list_feature_groups with all params provided."""
-        session = session_with_mock_client
-        session.sagemaker_client.list_feature_groups.return_value = {}
-
-        session.list_feature_groups(
-            name_contains="test",
-            feature_group_status_equals="Created",
-            offline_store_status_equals="Active",
-            creation_time_after="2024-01-01",
-            creation_time_before="2024-12-31",
-            sort_order="Ascending",
-            sort_by="Name",
-            max_results=10,
-            next_token="token123",
-        )
-
-        session.sagemaker_client.list_feature_groups.assert_called_once_with(
-            NameContains="test",
-            FeatureGroupStatusEquals="Created",
-            OfflineStoreStatusEquals="Active",
-            CreationTimeAfter="2024-01-01",
-            CreationTimeBefore="2024-12-31",
-            SortOrder="Ascending",
-            SortBy="Name",
-            MaxResults=10,
-            NextToken="token123",
-        )
-
-    def test_list_feature_groups_omits_none_params(self, session_with_mock_client):
-        """Test list_feature_groups omits None params."""
-        session = session_with_mock_client
-        session.sagemaker_client.list_feature_groups.return_value = {}
-
-        session.list_feature_groups(name_contains="test", max_results=5)
-
-        call_kwargs = session.sagemaker_client.list_feature_groups.call_args[1]
-        assert call_kwargs == {"NameContains": "test", "MaxResults": 5}
-
-    # --- update_feature_metadata ---
-
-    def test_update_feature_metadata_all_params(self, session_with_mock_client):
-        """Test update_feature_metadata with all optional params."""
-        session = session_with_mock_client
-        session.sagemaker_client.update_feature_metadata.return_value = {}
-
-        additions = [{"Key": "team", "Value": "ml"}]
-        removals = [{"Key": "deprecated"}]
-
-        result = session.update_feature_metadata(
-            "my-fg",
-            "my-feature",
-            description="Updated desc",
-            parameter_additions=additions,
-            parameter_removals=removals,
-        )
-
-        session.sagemaker_client.update_feature_metadata.assert_called_once_with(
-            FeatureGroupName="my-fg",
-            FeatureName="my-feature",
-            Description="Updated desc",
-            ParameterAdditions=additions,
-            ParameterRemovals=removals,
-        )
-        assert result == {}
-
-    def test_update_feature_metadata_omits_none_params(self, session_with_mock_client):
-        """Test update_feature_metadata omits None optional params."""
-        session = session_with_mock_client
-        session.sagemaker_client.update_feature_metadata.return_value = {}
-
-        session.update_feature_metadata("my-fg", "my-feature")
-
-        call_kwargs = session.sagemaker_client.update_feature_metadata.call_args[1]
-        assert call_kwargs == {
-            "FeatureGroupName": "my-fg",
-            "FeatureName": "my-feature",
-        }
-
-    def test_update_feature_metadata_partial_params(self, session_with_mock_client):
-        """Test update_feature_metadata with only description."""
-        session = session_with_mock_client
-        session.sagemaker_client.update_feature_metadata.return_value = {}
-
-        session.update_feature_metadata("my-fg", "my-feature", description="New desc")
-
-        call_kwargs = session.sagemaker_client.update_feature_metadata.call_args[1]
-        assert call_kwargs == {
-            "FeatureGroupName": "my-fg",
-            "FeatureName": "my-feature",
-            "Description": "New desc",
-        }
-
-    # --- describe_feature_metadata ---
-
-    def test_describe_feature_metadata(self, session_with_mock_client):
-        """Test describe_feature_metadata delegates and returns response."""
-        session = session_with_mock_client
-        expected = {"FeatureGroupName": "my-fg", "FeatureName": "my-feature"}
-        session.sagemaker_client.describe_feature_metadata.return_value = expected
-
-        result = session.describe_feature_metadata("my-fg", "my-feature")
-
-        session.sagemaker_client.describe_feature_metadata.assert_called_once_with(
-            FeatureGroupName="my-fg", FeatureName="my-feature"
-        )
-        assert result == expected
-
-MODULE = "sagemaker.core.helper.session_helper"
-
-
-class TestCreateFeatureGroup:
-    """Test cases for create_feature_group session method."""
-
-    @pytest.fixture
-    def session(self):
-        """Create a Session with a mocked sagemaker_client."""
-        mock_boto_session = Mock()
-        mock_boto_session.region_name = "us-west-2"
-        mock_boto_session.client.return_value = Mock()
-        mock_boto_session.resource.return_value = Mock()
-        session = Session(boto_session=mock_boto_session)
-        session.sagemaker_client = Mock()
-        return session
-
-    @pytest.fixture
-    def base_args(self):
-        """Minimal required arguments for create_feature_group."""
-        return dict(
-            feature_group_name="my-fg",
-            record_identifier_name="record_id",
-            event_time_feature_name="event_time",
-            feature_definitions=[{"FeatureName": "f1", "FeatureType": "String"}],
-        )
-
-    # --- Full parameter pass-through ---
-
-    def test_create_feature_group_all_params(self, session, base_args):
-        """Test that all parameters are passed through to sagemaker_client."""
-        role = "arn:aws:iam::123456789012:role/Role"
-        online_cfg = {"SecurityConfig": {"KmsKeyId": "key-123"}}
-        offline_cfg = {"S3StorageConfig": {"S3Uri": "s3://bucket"}}
-        throughput_cfg = {"ThroughputMode": "ON_DEMAND"}
-        description = "My feature group"
-        tags = [{"Key": "team", "Value": "ml"}]
-
-        expected_response = {"FeatureGroupArn": "arn:aws:sagemaker:us-west-2:123456789012:feature-group/my-fg"}
-        session.sagemaker_client.create_feature_group.return_value = expected_response
-
-        with patch(f"{MODULE}.format_tags", return_value=tags) as mock_format, \
-            patch(f"{MODULE}._append_project_tags", return_value=tags) as mock_proj, \
-            patch.object(session, "_append_sagemaker_config_tags", return_value=tags), \
-            patch(f"{MODULE}.resolve_value_from_config", return_value=role), \
-            patch(f"{MODULE}.update_nested_dictionary_with_values_from_config", side_effect=[online_cfg, offline_cfg]):
-
-            result = session.create_feature_group(
-                **base_args,
-                role_arn=role,
-                online_store_config=online_cfg,
-                offline_store_config=offline_cfg,
-                throughput_config=throughput_cfg,
-                description=description,
-                tags=tags,
-            )
-
-        assert result == expected_response
-        call_kwargs = session.sagemaker_client.create_feature_group.call_args[1]
-        assert call_kwargs["FeatureGroupName"] == "my-fg"
-        assert call_kwargs["RecordIdentifierFeatureName"] == "record_id"
-        assert call_kwargs["EventTimeFeatureName"] == "event_time"
-        assert call_kwargs["FeatureDefinitions"] == base_args["feature_definitions"]
-        assert call_kwargs["RoleArn"] == role
-        # EnableOnlineStore is set to True when online config is inferred
-        assert call_kwargs["OnlineStoreConfig"]["EnableOnlineStore"] is True
-        assert call_kwargs["OfflineStoreConfig"] == offline_cfg
-        assert call_kwargs["ThroughputConfig"] == throughput_cfg
-        assert call_kwargs["Description"] == description
-        assert call_kwargs["Tags"] == tags
-
-    # --- Tag processing pipeline ---
-
-    def test_tag_processing_pipeline_order(self, session, base_args):
-        """Test that tags go through format_tags -> _append_project_tags -> _append_sagemaker_config_tags."""
-        raw_tags = {"team": "ml"}
-        formatted = [{"Key": "team", "Value": "ml"}]
-        with_project = [{"Key": "team", "Value": "ml"}, {"Key": "project", "Value": "p1"}]
-        with_config = [{"Key": "team", "Value": "ml"}, {"Key": "project", "Value": "p1"}, {"Key": "cfg", "Value": "v"}]
-
-        with patch(f"{MODULE}.format_tags", return_value=formatted) as mock_format, \
-            patch(f"{MODULE}._append_project_tags", return_value=with_project) as mock_proj, \
-            patch.object(session, "_append_sagemaker_config_tags", return_value=with_config) as mock_cfg, \
-            patch(f"{MODULE}.resolve_value_from_config", return_value="arn:role"), \
-            patch(f"{MODULE}.update_nested_dictionary_with_values_from_config", return_value=None):
-
-            session.create_feature_group(**base_args, tags=raw_tags)
-
-        # format_tags is called with the raw input
-        mock_format.assert_called_once_with(raw_tags)
-        # _append_project_tags receives the formatted tags
-        mock_proj.assert_called_once_with(formatted)
-        # _append_sagemaker_config_tags receives the project-appended tags
-        mock_cfg.assert_called_once_with(with_project, "SageMaker.FeatureGroup.Tags")
-
-        # Final tags in the API call should be the config-appended tags
-        call_kwargs = session.sagemaker_client.create_feature_group.call_args[1]
-        assert call_kwargs["Tags"] == with_config
-
-    def test_tags_none_still_processed(self, session, base_args):
-        """Test that None tags still go through the pipeline (format_tags handles None)."""
-        with patch(f"{MODULE}.format_tags", return_value=None) as mock_format, \
-            patch(f"{MODULE}._append_project_tags", return_value=None) as mock_proj, \
-            patch.object(session, "_append_sagemaker_config_tags", return_value=None), \
-            patch(f"{MODULE}.resolve_value_from_config", return_value="arn:role"), \
-            patch(f"{MODULE}.update_nested_dictionary_with_values_from_config", return_value=None):
-
-            session.create_feature_group(**base_args, tags=None)
-
-        mock_format.assert_called_once_with(None)
-        mock_proj.assert_called_once_with(None)
-        # Tags=None should be omitted from the API call via update_args
-        call_kwargs = session.sagemaker_client.create_feature_group.call_args[1]
-        assert "Tags" not in call_kwargs
-
-    # --- role_arn resolution from config ---
-
-    def test_role_arn_resolved_from_config_when_none(self, session, base_args):
-        """Test that role_arn is resolved from SageMaker Config when not provided."""
-        config_role = "arn:aws:iam::123456789012:role/ConfigRole"
-
-        with patch(f"{MODULE}.format_tags", return_value=None), \
-            patch(f"{MODULE}._append_project_tags", return_value=None), \
-            patch.object(session, "_append_sagemaker_config_tags", return_value=None), \
-            patch(f"{MODULE}.resolve_value_from_config", return_value=config_role) as mock_resolve, \
-            patch(f"{MODULE}.update_nested_dictionary_with_values_from_config", return_value=None):
-
-            session.create_feature_group(**base_args, role_arn=None)
-
-        mock_resolve.assert_called_once_with(
-            None, FEATURE_GROUP_ROLE_ARN_PATH, sagemaker_session=session
-        )
-        call_kwargs = session.sagemaker_client.create_feature_group.call_args[1]
-        assert call_kwargs["RoleArn"] == config_role
-
-    def test_role_arn_passed_through_when_provided(self, session, base_args):
-        """Test that an explicit role_arn is passed to resolve_value_from_config (which returns it)."""
-        explicit_role = "arn:aws:iam::123456789012:role/ExplicitRole"
-
-        with patch(f"{MODULE}.format_tags", return_value=None), \
-            patch(f"{MODULE}._append_project_tags", return_value=None), \
-            patch.object(session, "_append_sagemaker_config_tags", return_value=None), \
-            patch(f"{MODULE}.resolve_value_from_config", return_value=explicit_role) as mock_resolve, \
-            patch(f"{MODULE}.update_nested_dictionary_with_values_from_config", return_value=None):
-
-            session.create_feature_group(**base_args, role_arn=explicit_role)
-
-        mock_resolve.assert_called_once_with(
-            explicit_role, FEATURE_GROUP_ROLE_ARN_PATH, sagemaker_session=session
-        )
-        call_kwargs = session.sagemaker_client.create_feature_group.call_args[1]
-        assert call_kwargs["RoleArn"] == explicit_role
-
-    # --- online_store_config merging and EnableOnlineStore ---
-
-    def test_online_store_config_merged_and_enable_set(self, session, base_args):
-        """Test that online_store_config is merged from config and EnableOnlineStore=True is set."""
-        inferred_online = {"SecurityConfig": {"KmsKeyId": "config-key"}}
-
-        with patch(f"{MODULE}.format_tags", return_value=None), \
-            patch(f"{MODULE}._append_project_tags", return_value=None), \
-            patch.object(session, "_append_sagemaker_config_tags", return_value=None), \
patch(f"{MODULE}.resolve_value_from_config", return_value="arn:role"), \ - patch(f"{MODULE}.update_nested_dictionary_with_values_from_config", - side_effect=[inferred_online, None]) as mock_update: - - session.create_feature_group(**base_args, online_store_config=None) - - # First call is for online store config - mock_update.assert_any_call( - None, FEATURE_GROUP_ONLINE_STORE_CONFIG_PATH, sagemaker_session=session - ) - call_kwargs = session.sagemaker_client.create_feature_group.call_args[1] - assert call_kwargs["OnlineStoreConfig"]["EnableOnlineStore"] is True - assert call_kwargs["OnlineStoreConfig"]["SecurityConfig"]["KmsKeyId"] == "config-key" - - def test_online_store_config_none_when_no_config(self, session, base_args): - """Test that OnlineStoreConfig is omitted when config returns None.""" - with patch(f"{MODULE}.format_tags", return_value=None), \ - patch(f"{MODULE}._append_project_tags", return_value=None), \ - patch.object(session, "_append_sagemaker_config_tags", return_value=None), \ - patch(f"{MODULE}.resolve_value_from_config", return_value="arn:role"), \ - patch(f"{MODULE}.update_nested_dictionary_with_values_from_config", return_value=None): - - session.create_feature_group(**base_args) - - call_kwargs = session.sagemaker_client.create_feature_group.call_args[1] - assert "OnlineStoreConfig" not in call_kwargs - - def test_online_store_config_explicit_gets_enable_set(self, session, base_args): - """Test that explicitly provided online_store_config also gets EnableOnlineStore=True.""" - explicit_online = {"SecurityConfig": {"KmsKeyId": "my-key"}} - # update_nested_dictionary returns the merged result - merged_online = {"SecurityConfig": {"KmsKeyId": "my-key"}} - - with patch(f"{MODULE}.format_tags", return_value=None), \ - patch(f"{MODULE}._append_project_tags", return_value=None), \ - patch.object(session, "_append_sagemaker_config_tags", return_value=None), \ - patch(f"{MODULE}.resolve_value_from_config", return_value="arn:role"), \ - patch(f"{MODULE}.update_nested_dictionary_with_values_from_config", - side_effect=[merged_online, None]): - - session.create_feature_group(**base_args, online_store_config=explicit_online) - - call_kwargs = session.sagemaker_client.create_feature_group.call_args[1] - assert call_kwargs["OnlineStoreConfig"]["EnableOnlineStore"] is True - - # --- offline_store_config merging --- - - def test_offline_store_config_merged_from_config(self, session, base_args): - """Test that offline_store_config is merged from SageMaker Config.""" - inferred_offline = {"S3StorageConfig": {"S3Uri": "s3://config-bucket"}} - - with patch(f"{MODULE}.format_tags", return_value=None), \ - patch(f"{MODULE}._append_project_tags", return_value=None), \ - patch.object(session, "_append_sagemaker_config_tags", return_value=None), \ - patch(f"{MODULE}.resolve_value_from_config", return_value="arn:role"), \ - patch(f"{MODULE}.update_nested_dictionary_with_values_from_config", - side_effect=[None, inferred_offline]) as mock_update: - - session.create_feature_group(**base_args, offline_store_config=None) - - # Second call is for offline store config - mock_update.assert_any_call( - None, FEATURE_GROUP_OFFLINE_STORE_CONFIG_PATH, sagemaker_session=session - ) - call_kwargs = session.sagemaker_client.create_feature_group.call_args[1] - assert call_kwargs["OfflineStoreConfig"] == inferred_offline - - def test_offline_store_config_none_when_no_config(self, session, base_args): - """Test that OfflineStoreConfig is omitted when config returns None.""" - with 
patch(f"{MODULE}.format_tags", return_value=None), \ - patch(f"{MODULE}._append_project_tags", return_value=None), \ - patch.object(session, "_append_sagemaker_config_tags", return_value=None), \ - patch(f"{MODULE}.resolve_value_from_config", return_value="arn:role"), \ - patch(f"{MODULE}.update_nested_dictionary_with_values_from_config", return_value=None): - - session.create_feature_group(**base_args) - - call_kwargs = session.sagemaker_client.create_feature_group.call_args[1] - assert "OfflineStoreConfig" not in call_kwargs - - # --- None optional parameters omitted --- - - def test_none_optional_params_omitted(self, session, base_args): - """Test that None optional params (throughput, description, tags) are omitted from API call.""" - with patch(f"{MODULE}.format_tags", return_value=None), \ - patch(f"{MODULE}._append_project_tags", return_value=None), \ - patch.object(session, "_append_sagemaker_config_tags", return_value=None), \ - patch(f"{MODULE}.resolve_value_from_config", return_value="arn:role"), \ - patch(f"{MODULE}.update_nested_dictionary_with_values_from_config", return_value=None): - - session.create_feature_group(**base_args) - - call_kwargs = session.sagemaker_client.create_feature_group.call_args[1] - assert "ThroughputConfig" not in call_kwargs - assert "Description" not in call_kwargs - assert "Tags" not in call_kwargs - assert "OnlineStoreConfig" not in call_kwargs - assert "OfflineStoreConfig" not in call_kwargs - # Required params should still be present - assert "FeatureGroupName" in call_kwargs - assert "RecordIdentifierFeatureName" in call_kwargs - assert "EventTimeFeatureName" in call_kwargs - assert "FeatureDefinitions" in call_kwargs - assert "RoleArn" in call_kwargs - - def test_partial_optional_params(self, session, base_args): - """Test that only provided optional params appear in the API call.""" - throughput = {"ThroughputMode": "ON_DEMAND"} - - with patch(f"{MODULE}.format_tags", return_value=None), \ - patch(f"{MODULE}._append_project_tags", return_value=None), \ - patch.object(session, "_append_sagemaker_config_tags", return_value=None), \ - patch(f"{MODULE}.resolve_value_from_config", return_value="arn:role"), \ - patch(f"{MODULE}.update_nested_dictionary_with_values_from_config", return_value=None): - - session.create_feature_group( - **base_args, - throughput_config=throughput, - description="test desc", - ) - - call_kwargs = session.sagemaker_client.create_feature_group.call_args[1] - assert call_kwargs["ThroughputConfig"] == throughput - assert call_kwargs["Description"] == "test desc" - assert "Tags" not in call_kwargs - assert "OnlineStoreConfig" not in call_kwargs - assert "OfflineStoreConfig" not in call_kwargs - - # --- Return value --- - - def test_returns_api_response(self, session, base_args): - """Test that the method returns the sagemaker_client response.""" - expected = {"FeatureGroupArn": "arn:fg"} - session.sagemaker_client.create_feature_group.return_value = expected - - with patch(f"{MODULE}.format_tags", return_value=None), \ - patch(f"{MODULE}._append_project_tags", return_value=None), \ - patch.object(session, "_append_sagemaker_config_tags", return_value=None), \ - patch(f"{MODULE}.resolve_value_from_config", return_value="arn:role"), \ - patch(f"{MODULE}.update_nested_dictionary_with_values_from_config", return_value=None): - - result = session.create_feature_group(**base_args) - - assert result == expected From 3cd468657da96901079073111a3d2f5fe52327cb Mon Sep 17 00:00:00 2001 From: Basssem Halim Date: Tue, 3 Mar 2026 13:30:00 
-0800 Subject: [PATCH 19/21] refactor(feature-processor): Migrate to FeatureGroup resource API - Replace sagemaker_session.describe_feature_group() calls with FeatureGroup.get() - Update _input_loader.py to use FeatureGroup resource attributes instead of dictionary access - Update feature_scheduler.py to use FeatureGroup.get() and access creation_time as attribute - Update _feature_group_lineage_entity_handler.py to return FeatureGroup resource instead of Dict - Remove unused imports (Dict, Any, FEATURE_GROUP, CREATION_TIME constants) - Replace dictionary key access with typed resource properties (offline_store_config, data_catalog_config, event_time_feature_name, etc.) - Update unit tests to reflect new FeatureGroup resource API usage - Improves type safety and reduces reliance on dictionary-based API responses --- .../feature_processor/_input_loader.py | 29 +++--- .../feature_processor/feature_scheduler.py | 5 +- .../_feature_group_lineage_entity_handler.py | 14 ++- .../lineage/test_constants.py | 1 + ...st_feature_group_lineage_entity_handler.py | 18 ++-- .../test_feature_scheduler.py | 19 ++-- .../feature_processor/test_input_loader.py | 93 ++++++++++++++----- 7 files changed, 119 insertions(+), 60 deletions(-) diff --git a/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_input_loader.py b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_input_loader.py index 627de943c1..7f8ef855b7 100644 --- a/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_input_loader.py +++ b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_input_loader.py @@ -34,6 +34,7 @@ InputOffsetParser, ) from sagemaker.mlops.feature_store.feature_processor._env import EnvironmentHelper +from sagemaker.core.resources import FeatureGroup T = TypeVar("T") @@ -96,8 +97,9 @@ def load_from_feature_group( sagemaker_session: Session = self.sagemaker_session or Session() feature_group_name = feature_group_data_source.name - feature_group = sagemaker_session.describe_feature_group( - self._parse_name_from_arn(feature_group_name) + feature_group = FeatureGroup.get( + feature_group_name=self._parse_name_from_arn(feature_group_name), + session=sagemaker_session.boto_session, ) logger.debug( "Called describe_feature_group with %s and received: %s", @@ -105,17 +107,20 @@ def load_from_feature_group( feature_group, ) - if "OfflineStoreConfig" not in feature_group: + if not feature_group.offline_store_config: raise ValueError( f"Input Feature Groups must have an enabled Offline Store." f" Feature Group: {feature_group_name} does not have an Offline Store enabled." 
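In sketch form, the access-pattern change looks like this (illustrative only:
"my-fg" and the `sagemaker_session` variable are placeholders, everything else
is taken from the hunks below):

    # Before: untyped dict returned by the DescribeFeatureGroup API
    fg = sagemaker_session.describe_feature_group("my-fg")
    s3_uri = fg["OfflineStoreConfig"]["S3StorageConfig"]["ResolvedOutputS3Uri"]

    # After: typed FeatureGroup resource with attribute access
    from sagemaker.core.resources import FeatureGroup

    fg = FeatureGroup.get(
        feature_group_name="my-fg",
        session=sagemaker_session.boto_session,
    )
    s3_uri = fg.offline_store_config.s3_storage_config.resolved_output_s3_uri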
---
 .../feature_processor/_input_loader.py        | 29 +++---
 .../feature_processor/feature_scheduler.py    |  5 +-
 .../_feature_group_lineage_entity_handler.py  | 14 ++-
 .../lineage/test_constants.py                 |  1 +
 ...st_feature_group_lineage_entity_handler.py | 18 ++--
 .../test_feature_scheduler.py                 | 19 ++--
 .../feature_processor/test_input_loader.py    | 93 ++++++++++++++-----
 7 files changed, 119 insertions(+), 60 deletions(-)

diff --git a/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_input_loader.py b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_input_loader.py
index 627de943c1..7f8ef855b7 100644
--- a/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_input_loader.py
+++ b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_input_loader.py
@@ -34,6 +34,7 @@
     InputOffsetParser,
 )
 from sagemaker.mlops.feature_store.feature_processor._env import EnvironmentHelper
+from sagemaker.core.resources import FeatureGroup
 
 T = TypeVar("T")
 
@@ -96,8 +97,9 @@ def load_from_feature_group(
         sagemaker_session: Session = self.sagemaker_session or Session()
 
         feature_group_name = feature_group_data_source.name
-        feature_group = sagemaker_session.describe_feature_group(
-            self._parse_name_from_arn(feature_group_name)
+        feature_group = FeatureGroup.get(
+            feature_group_name=self._parse_name_from_arn(feature_group_name),
+            session=sagemaker_session.boto_session,
         )
         logger.debug(
             "Called describe_feature_group with %s and received: %s",
@@ -105,17 +107,20 @@
             feature_group,
         )
 
-        if "OfflineStoreConfig" not in feature_group:
+        if not feature_group.offline_store_config:
             raise ValueError(
                 f"Input Feature Groups must have an enabled Offline Store."
                 f" Feature Group: {feature_group_name} does not have an Offline Store enabled."
             )
 
-        offline_store_uri = feature_group["OfflineStoreConfig"]["S3StorageConfig"][
-            "ResolvedOutputS3Uri"
-        ]
+        offline_store_config = feature_group.offline_store_config
+        offline_store_uri = offline_store_config.s3_storage_config.resolved_output_s3_uri
 
-        table_format = feature_group["OfflineStoreConfig"].get("TableFormat", None)
+        table_format = (
+            offline_store_config.table_format
+            if offline_store_config.table_format
+            else None
+        )
 
         if table_format not in self._supported_table_format:
             raise ValueError(
@@ -127,15 +132,15 @@
         end_offset = feature_group_data_source.input_end_offset
 
         if table_format == "Iceberg":
-            data_catalog_config = feature_group["OfflineStoreConfig"]["DataCatalogConfig"]
+            data_catalog_config = offline_store_config.data_catalog_config
             return self.load_from_iceberg_table(
                 IcebergTableDataSource(
                     offline_store_uri,
-                    data_catalog_config["Catalog"],
-                    data_catalog_config["Database"],
-                    data_catalog_config["TableName"],
+                    data_catalog_config.catalog,
+                    data_catalog_config.database,
+                    data_catalog_config.table_name,
                 ),
-                feature_group["EventTimeFeatureName"],
+                feature_group.event_time_feature_name,
                 start_offset,
                 end_offset,
             )
diff --git a/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/feature_scheduler.py b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/feature_scheduler.py
index 4f67a9dc5d..c18232cb08 100644
--- a/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/feature_scheduler.py
+++ b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/feature_scheduler.py
@@ -73,6 +73,7 @@
 )
 
 from sagemaker.core.s3 import s3_path_join
+from sagemaker.core.resources import FeatureGroup
 from sagemaker.core.helper.session_helper import Session, get_execution_role
 
 from sagemaker.mlops.feature_store.feature_processor._event_bridge_scheduler_helper import (
@@ -770,8 +771,8 @@
     groups.
""" - feature_group = sagemaker_session.describe_feature_group(feature_group_name=feature_group_name) - feature_group_creation_time = feature_group["CreationTime"].strftime("%s") + feature_group = FeatureGroup.get(feature_group_name=feature_group_name, session=sagemaker_session.boto_session) + feature_group_creation_time = feature_group.creation_time.strftime("%s") feature_group_context = _get_feature_group_lineage_context_name( feature_group_name=feature_group_name, feature_group_creation_time=feature_group_creation_time, diff --git a/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/lineage/_feature_group_lineage_entity_handler.py b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/lineage/_feature_group_lineage_entity_handler.py index 55230d7c1c..f8785a9a0a 100644 --- a/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/lineage/_feature_group_lineage_entity_handler.py +++ b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/lineage/_feature_group_lineage_entity_handler.py @@ -14,7 +14,6 @@ from __future__ import absolute_import import re -from typing import Dict, Any import logging from sagemaker.core.helper.session_helper import Session @@ -24,10 +23,9 @@ ) from sagemaker.mlops.feature_store.feature_processor.lineage.constants import ( SAGEMAKER, - FEATURE_GROUP, - CREATION_TIME, ) from sagemaker.core.lineage.context import Context +from sagemaker.core.resources import FeatureGroup # pylint: disable=C0301 from sagemaker.mlops.feature_store.feature_processor.lineage._feature_processor_lineage_name_helper import ( @@ -62,8 +60,8 @@ def retrieve_feature_group_context_arns( ), sagemaker_session=sagemaker_session, ) - feature_group_name = feature_group[FEATURE_GROUP] - feature_group_creation_time = feature_group[CREATION_TIME].strftime("%s") + feature_group_name = feature_group.feature_group_name + feature_group_creation_time = feature_group.creation_time.strftime("%s") feature_group_pipeline_context = ( FeatureGroupLineageEntityHandler._load_feature_group_pipeline_context( feature_group_name=feature_group_name, @@ -87,7 +85,7 @@ def retrieve_feature_group_context_arns( @staticmethod def _describe_feature_group( feature_group_name: str, sagemaker_session: Session - ) -> Dict[str, Any]: + ) -> FeatureGroup: """Retrieve the Feature Group. Arguments: @@ -97,9 +95,9 @@ def _describe_feature_group( function creates one using the default AWS configuration chain. Returns: - Dict[str, Any]: The Feature Group details. + FeatureGroup: The Feature Group resource. 
""" - feature_group = sagemaker_session.describe_feature_group(feature_group_name) + feature_group = FeatureGroup.get(feature_group_name=feature_group_name, session=sagemaker_session.boto_session) logger.debug( "Called describe_feature_group with %s and received: %s", feature_group_name, diff --git a/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/lineage/test_constants.py b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/lineage/test_constants.py index 12a323f871..9103e54b0f 100644 --- a/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/lineage/test_constants.py +++ b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/lineage/test_constants.py @@ -46,6 +46,7 @@ CREATION_TIME = "123123123" LAST_UPDATE_TIME = "234234234" SAGEMAKER_SESSION_MOCK = Mock(Session) +SAGEMAKER_SESSION_MOCK.boto_session = Mock() CONTEXT_MOCK_01 = Mock(Context) CONTEXT_MOCK_02 = Mock(Context) diff --git a/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/lineage/test_feature_group_lineage_entity_handler.py b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/lineage/test_feature_group_lineage_entity_handler.py index bc725570b9..f82c8e0e41 100644 --- a/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/lineage/test_feature_group_lineage_entity_handler.py +++ b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/lineage/test_feature_group_lineage_entity_handler.py @@ -11,12 +11,13 @@ # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. from __future__ import absolute_import -from mock import patch, call +from mock import patch, call, Mock from sagemaker.mlops.feature_store.feature_processor.lineage._feature_group_lineage_entity_handler import ( FeatureGroupLineageEntityHandler, ) from sagemaker.core.lineage.context import Context +from sagemaker.core.resources import FeatureGroup from test_constants import ( SAGEMAKER_SESSION_MOCK, @@ -26,11 +27,15 @@ FEATURE_GROUP_NAME, ) +FEATURE_GROUP_MOCK = Mock() +FEATURE_GROUP_MOCK.feature_group_name = FEATURE_GROUP["FeatureGroupName"] +FEATURE_GROUP_MOCK.creation_time = FEATURE_GROUP["CreationTime"] + def test_retrieve_feature_group_context_arns(): with patch.object( - SAGEMAKER_SESSION_MOCK, "describe_feature_group", return_value=FEATURE_GROUP - ) as fg_describe_method: + FeatureGroup, "get", return_value=FEATURE_GROUP_MOCK + ) as fg_get_method: with patch.object( Context, "load", side_effect=[CONTEXT_MOCK_01, CONTEXT_MOCK_02] ) as context_load: @@ -44,16 +49,17 @@ def test_retrieve_feature_group_context_arns(): assert result.name == FEATURE_GROUP_NAME assert result.pipeline_context_arn == "context-arn-fep" assert result.pipeline_version_context_arn == "context-arn-fep-ver" - fg_describe_method.assert_called_once_with(FEATURE_GROUP_NAME) + fg_get_method.assert_called_once_with(feature_group_name=FEATURE_GROUP_NAME, session=SAGEMAKER_SESSION_MOCK.boto_session) + creation_time_str = FEATURE_GROUP_MOCK.creation_time.strftime("%s") context_load.assert_has_calls( [ call( - context_name=f'{FEATURE_GROUP_NAME}-{FEATURE_GROUP["CreationTime"].strftime("%s")}' + context_name=f"{FEATURE_GROUP_NAME}-{creation_time_str}" f"-feature-group-pipeline", sagemaker_session=SAGEMAKER_SESSION_MOCK, ), call( - context_name=f'{FEATURE_GROUP_NAME}-{FEATURE_GROUP["CreationTime"].strftime("%s")}' + 
context_name=f"{FEATURE_GROUP_NAME}-{creation_time_str}" f"-feature-group-pipeline-version", sagemaker_session=SAGEMAKER_SESSION_MOCK, ), diff --git a/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_feature_scheduler.py b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_feature_scheduler.py index 5f3fb85934..9cd5655de5 100644 --- a/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_feature_scheduler.py +++ b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_feature_scheduler.py @@ -29,6 +29,7 @@ ) from sagemaker.core.lineage.context import Context from sagemaker.core.remote_function.spark_config import SparkConfig +from sagemaker.core.resources import FeatureGroup from sagemaker.core.helper.session_helper import Session from sagemaker.mlops.feature_store.feature_processor._enums import FeatureProcessorMode @@ -89,10 +90,13 @@ PIPELINE_VERSION_CONTEXT_NAME_TAG_KEY = "sm-fs-fe:feature-engineering-pipeline-version-context-name" NOW = datetime.now() SAGEMAKER_SESSION_MOCK = Mock(Session) +SAGEMAKER_SESSION_MOCK.boto_session = Mock() CONTEXT_MOCK_01 = Mock(Context) CONTEXT_MOCK_02 = Mock(Context) CONTEXT_MOCK_03 = Mock(Context) FEATURE_GROUP = tdh.DESCRIBE_FEATURE_GROUP_RESPONSE.copy() +FEATURE_GROUP_MOCK = Mock() +FEATURE_GROUP_MOCK.creation_time = FEATURE_GROUP["CreationTime"] PIPELINE = tdh.PIPELINE.copy() TAGS = [dict(Key="key_1", Value="value_1"), dict(Key="key_2", Value="value_2")] @@ -783,8 +787,8 @@ def test_execute(validation): def test_validate_fg_lineage_resources_happy_case(): with patch.object( - SAGEMAKER_SESSION_MOCK, "describe_feature_group", return_value=FEATURE_GROUP - ) as fg_describe_method: + FeatureGroup, "get", return_value=FEATURE_GROUP_MOCK + ) as fg_get_method: with patch.object( Context, "load", side_effect=[CONTEXT_MOCK_01, CONTEXT_MOCK_02, CONTEXT_MOCK_03] ) as context_load: @@ -795,16 +799,17 @@ def test_validate_fg_lineage_resources_happy_case(): feature_group_name="some_fg", sagemaker_session=SAGEMAKER_SESSION_MOCK, ) - fg_describe_method.assert_called_once_with(feature_group_name="some_fg") + fg_get_method.assert_called_once_with(feature_group_name="some_fg", session=SAGEMAKER_SESSION_MOCK.boto_session) + creation_time_str = FEATURE_GROUP_MOCK.creation_time.strftime("%s") context_load.assert_has_calls( [ call( - context_name=f'{"some_fg"}-{FEATURE_GROUP["CreationTime"].strftime("%s")}' + context_name=f'{"some_fg"}-{creation_time_str}' f"-feature-group-pipeline", sagemaker_session=SAGEMAKER_SESSION_MOCK, ), call( - context_name=f'{"some_fg"}-{FEATURE_GROUP["CreationTime"].strftime("%s")}' + context_name=f'{"some_fg"}-{creation_time_str}' f"-feature-group-pipeline-version", sagemaker_session=SAGEMAKER_SESSION_MOCK, ), @@ -814,7 +819,7 @@ def test_validate_fg_lineage_resources_happy_case(): def test_validete_fg_lineage_resources_rnf(): - with patch.object(SAGEMAKER_SESSION_MOCK, "describe_feature_group", return_value=FEATURE_GROUP): + with patch.object(FeatureGroup, "get", return_value=FEATURE_GROUP_MOCK): with patch.object( Context, "load", @@ -824,7 +829,7 @@ def test_validete_fg_lineage_resources_rnf(): ), ): feature_group_name = "some_fg" - feature_group_creation_time = FEATURE_GROUP["CreationTime"].strftime("%s") + feature_group_creation_time = FEATURE_GROUP_MOCK.creation_time.strftime("%s") context_name = f"{feature_group_name}-{feature_group_creation_time}" with pytest.raises( ValueError, diff --git 
index 24d20f96c5..1e054b57b1 100644
--- a/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_input_loader.py
+++ b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_input_loader.py
@@ -29,16 +29,53 @@
 from sagemaker.mlops.feature_store.feature_processor._spark_factory import SparkSessionFactory
 from sagemaker.mlops.feature_store.feature_processor._env import EnvironmentHelper
 from sagemaker.core.helper.session_helper import Session
+from sagemaker.core.resources import FeatureGroup
+
+
+def _build_fg_mock(response=None):
+    """Build a mock FeatureGroup object from a describe response dict."""
+    if response is None:
+        response = tdh.DESCRIBE_FEATURE_GROUP_RESPONSE.copy()
+    fg_mock = Mock()
+    fg_mock.feature_group_name = response["FeatureGroupName"]
+    fg_mock.event_time_feature_name = response["EventTimeFeatureName"]
+    fg_mock.creation_time = response["CreationTime"]
+
+    if "OfflineStoreConfig" in response:
+        osc = response["OfflineStoreConfig"]
+        fg_mock.offline_store_config = Mock()
+        fg_mock.offline_store_config.s3_storage_config = Mock()
+        fg_mock.offline_store_config.s3_storage_config.resolved_output_s3_uri = (
+            osc["S3StorageConfig"]["ResolvedOutputS3Uri"]
+        )
+        fg_mock.offline_store_config.table_format = osc.get("TableFormat", None)
+        if "DataCatalogConfig" in osc:
+            fg_mock.offline_store_config.data_catalog_config = Mock()
+            fg_mock.offline_store_config.data_catalog_config.catalog = osc["DataCatalogConfig"]["Catalog"]
+            fg_mock.offline_store_config.data_catalog_config.database = osc["DataCatalogConfig"]["Database"]
+            fg_mock.offline_store_config.data_catalog_config.table_name = osc["DataCatalogConfig"]["TableName"]
+    else:
+        fg_mock.offline_store_config = None
+
+    return fg_mock
 
 
 @pytest.fixture
-def describe_fg_response():
-    return tdh.DESCRIBE_FEATURE_GROUP_RESPONSE.copy()
+def fg_mock():
+    return _build_fg_mock()
 
 
 @pytest.fixture
-def sagemaker_session(describe_fg_response):
-    return Mock(Session, describe_feature_group=Mock(return_value=describe_fg_response))
+def sagemaker_session():
+    session = Mock(Session)
+    session.boto_session = Mock()
+    return session
+
+
+@pytest.fixture
+def mock_fg_get(fg_mock):
+    with patch.object(FeatureGroup, "get", return_value=fg_mock) as mock_get:
+        yield mock_get
 
 
 @pytest.fixture
@@ -152,7 +189,7 @@ def test_load_from_iceberg_table(
     "sagemaker.mlops.feature_store.feature_processor._input_loader.SparkDataFrameInputLoader.load_from_date_partitioned_s3"
 )
 def test_load_from_feature_group_with_arn(
-    mock_load_from_date_partitioned_s3, sagemaker_session, input_loader
+    mock_load_from_date_partitioned_s3, sagemaker_session, input_loader, mock_fg_get
 ):
     fg_arn = tdh.INPUT_FEATURE_GROUP_ARN
     fg_name = tdh.INPUT_FEATURE_GROUP_NAME
@@ -162,7 +199,7 @@
 
     input_loader.load_from_feature_group(fg_data_source)
 
-    sagemaker_session.describe_feature_group.assert_called_with(fg_name)
+    mock_fg_get.assert_called_with(feature_group_name=fg_name, session=sagemaker_session.boto_session)
     mock_load_from_date_partitioned_s3.assert_called_with(
         ParquetDataSource(tdh.INPUT_FEATURE_GROUP_RESOLVED_OUTPUT_S3_URI),
         "start",
@@ -170,50 +207,56 @@
     )
 
 
-def test_load_from_feature_group_offline_store_not_enabled(input_loader, describe_fg_response):
+def test_load_from_feature_group_offline_store_not_enabled(input_loader):
     fg_name = tdh.INPUT_FEATURE_GROUP_NAME
     fg_data_source = FeatureGroupDataSource(name=fg_name)
 
-    with pytest.raises(
-        ValueError,
-        match=(
-            f"Input Feature Groups must have an enabled Offline Store."
-            f" Feature Group: {fg_name} does not have an Offline Store enabled."
-        ),
-    ):
-        del describe_fg_response["OfflineStoreConfig"]
-        input_loader.load_from_feature_group(fg_data_source)
+    no_offline_response = tdh.DESCRIBE_FEATURE_GROUP_RESPONSE.copy()
+    del no_offline_response["OfflineStoreConfig"]
+    no_offline_mock = _build_fg_mock(no_offline_response)
+    with patch.object(FeatureGroup, "get", return_value=no_offline_mock):
+        with pytest.raises(
+            ValueError,
+            match=(
+                f"Input Feature Groups must have an enabled Offline Store."
+                f" Feature Group: {fg_name} does not have an Offline Store enabled."
+            ),
+        ):
+            input_loader.load_from_feature_group(fg_data_source)
 
 
 def test_load_from_feature_group_with_default_table_format(
-    sagemaker_session, input_loader, spark_session
+    sagemaker_session, input_loader, spark_session, mock_fg_get
 ):
     fg_name = tdh.INPUT_FEATURE_GROUP_NAME
     fg_data_source = FeatureGroupDataSource(name=fg_name)
 
     input_loader.load_from_feature_group(fg_data_source)
 
-    sagemaker_session.describe_feature_group.assert_called_with(fg_name)
+    mock_fg_get.assert_called_with(feature_group_name=fg_name, session=sagemaker_session.boto_session)
     spark_session.read.parquet.assert_called_with(
         tdh.INPUT_FEATURE_GROUP_RESOLVED_OUTPUT_S3_URI.replace("s3:", "s3a:")
     )
 
 
 def test_load_from_feature_group_with_iceberg_table_format(
-    describe_fg_response, spark_session_factory, spark_session, environment_helper
+    spark_session_factory, spark_session, environment_helper
 ):
-    describe_iceberg_fg_response = describe_fg_response.copy()
+    describe_iceberg_fg_response = tdh.DESCRIBE_FEATURE_GROUP_RESPONSE.copy()
     describe_iceberg_fg_response["OfflineStoreConfig"]["TableFormat"] = "Iceberg"
-    mocked_session = Mock(
-        Session, describe_feature_group=Mock(return_value=describe_iceberg_fg_response)
-    )
+    iceberg_fg_mock = _build_fg_mock(describe_iceberg_fg_response)
+    mocked_session = Mock(Session)
+    mocked_session.boto_session = Mock()
 
     mock_input_loader = SparkDataFrameInputLoader(
         spark_session_factory, environment_helper, mocked_session
     )
     fg_name = tdh.INPUT_FEATURE_GROUP_NAME
     fg_data_source = FeatureGroupDataSource(name=fg_name)
 
-    mock_input_loader.load_from_feature_group(fg_data_source)
+    with patch.object(
+        FeatureGroup, "get", return_value=iceberg_fg_mock
+    ) as mock_get:
+        mock_input_loader.load_from_feature_group(fg_data_source)
 
-    mocked_session.describe_feature_group.assert_called_with(fg_name)
+    mock_get.assert_called_with(feature_group_name=fg_name, session=mocked_session.boto_session)
     spark_session.table.assert_called_with(
         "awsdatacatalog.sagemaker_featurestore.input_fg_1680142547"
     )
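A side note on the creation_time.strftime("%s") calls this patch carries over
for lineage context names: %s is a glibc extension rather than a portable
strftime directive, so the derived names are only well-defined on POSIX
systems. A portable equivalent, sketched here as a hypothetical helper that is
not part of this change, would be:

    from datetime import datetime

    def epoch_seconds(dt: datetime) -> str:
        # For naive local-time datetimes, int(dt.timestamp()) yields the
        # same value glibc's strftime("%s") produces.
        return str(int(dt.timestamp()))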
From 39f1be2d0b0b10bb8dbad4424137e8c1a9ad0208 Mon Sep 17 00:00:00 2001
From: Basssem Halim
Date: Tue, 3 Mar 2026 14:24:19 -0800
Subject: [PATCH 20/21] add `build` to test_requirements

---
 requirements/extras/test_requirements.txt | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/requirements/extras/test_requirements.txt b/requirements/extras/test_requirements.txt
index cacbd2a165..898dd31608 100644
--- a/requirements/extras/test_requirements.txt
+++ b/requirements/extras/test_requirements.txt
@@ -11,4 +11,5 @@ scipy
 omegaconf
 graphene
 typing_extensions>=4.9.0
-tensorflow>=2.16.2,<=2.19.0
\ No newline at end of file
+tensorflow>=2.16.2,<=2.19.0
+build
\ No newline at end of file

From 4797633a1dce6a7d4a655ec811ac6d87f7105baf Mon Sep 17 00:00:00 2001
From: Basssem Halim
Date: Tue, 3 Mar 2026 14:30:58 -0800
Subject: [PATCH 21/21] add upper bounds for test dependencies

---
 sagemaker-mlops/pyproject.toml | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/sagemaker-mlops/pyproject.toml b/sagemaker-mlops/pyproject.toml
index 2383074ac7..b588d7a1cf 100644
--- a/sagemaker-mlops/pyproject.toml
+++ b/sagemaker-mlops/pyproject.toml
@@ -38,8 +38,8 @@ test = [
     "setuptools<82",
     "pyspark==3.3.2",
     "sagemaker-feature-store-pyspark-3.3",
-    "pandas",
-    "numpy",
+    "pandas<3.0",
+    "numpy<3.0",
 ]
 dev = [
     "pytest",