Skip to content

Commit 7101c0b

Browse files
author
Pravali Uppugunduri
committed
fix: Add HMAC integrity verification for Triton inference handler
- Add HMAC integrity check before pickle deserialization in TritonPythonModel.initialize() - Replace hardcoded secret key with generate_secret_key() in _prepare_for_triton() ONNX path - Add _hmac_signing() after ONNX export for both PyTorch and TensorFlow frameworks - Add secret key validation in _start_triton_server() to reject None/empty keys Fixes RCE vulnerabilities in Triton handler by aligning with HMAC verification patterns used by TorchServe, MMS, TF Serving, and SMD handlers.
1 parent 6a174f4 commit 7101c0b

File tree

4 files changed

+35
-10
lines changed

4 files changed

+35
-10
lines changed

sagemaker-serve/src/sagemaker/serve/model_builder_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3075,8 +3075,8 @@ def _prepare_for_triton(self):
30753075
export_path.mkdir(parents=True)
30763076

30773077
if self.model:
3078-
self.secret_key = "dummy secret key for onnx backend"
3079-
3078+
# ONNX path: export model to ONNX format for Triton's native ONNX backend.
3079+
# No pickle is created or loaded at runtime, so no HMAC signing is needed.
30803080
if self.framework == Framework.PYTORCH:
30813081
self._export_pytorch_to_onnx(
30823082
export_path=export_path, model=self.model, schema_builder=self.schema_builder

sagemaker-serve/src/sagemaker/serve/model_server/triton/model.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,13 @@ def auto_complete_config(auto_complete_model_config):
2626
def initialize(self, args: dict) -> None:
2727
"""Placeholder docstring"""
2828
serve_path = Path(TRITON_MODEL_DIR).joinpath("serve.pkl")
29+
metadata_path = Path(TRITON_MODEL_DIR).joinpath("metadata.json")
2930
with open(str(serve_path), mode="rb") as f:
30-
inference_spec, schema_builder = cloudpickle.load(f)
31+
buffer = f.read()
32+
perform_integrity_check(buffer=buffer, metadata_path=str(metadata_path))
3133

32-
# TODO: HMAC signing for integrity check
34+
with open(str(serve_path), mode="rb") as f:
35+
inference_spec, schema_builder = cloudpickle.load(f)
3336

3437
self.inference_spec = inference_spec
3538
self.schema_builder = schema_builder

sagemaker-serve/src/sagemaker/serve/model_server/triton/server.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,11 +41,16 @@ def _start_triton_server(
4141
env_vars.update(
4242
{
4343
"TRITON_MODEL_DIR": "/models/model",
44-
"SAGEMAKER_SERVE_SECRET_KEY": secret_key,
4544
"LOCAL_PYTHON": platform.python_version(),
4645
}
4746
)
4847

48+
# Only set SAGEMAKER_SERVE_SECRET_KEY for inference_spec path where
49+
# pickle integrity verification is needed. The ONNX path does not
50+
# use pickles, so no secret key is required.
51+
if secret_key and isinstance(secret_key, str) and secret_key.strip():
52+
env_vars["SAGEMAKER_SERVE_SECRET_KEY"] = secret_key
53+
4954
if "cpu" not in image_uri:
5055
self.container = docker_client.containers.run(
5156
image=image_uri,
@@ -133,7 +138,12 @@ def _upload_triton_artifacts(
133138
env_vars = {
134139
"SAGEMAKER_TRITON_DEFAULT_MODEL_NAME": "model",
135140
"TRITON_MODEL_DIR": "/opt/ml/model/model",
136-
"SAGEMAKER_SERVE_SECRET_KEY": secret_key,
137141
"LOCAL_PYTHON": platform.python_version(),
138142
}
143+
144+
# Only set SAGEMAKER_SERVE_SECRET_KEY for inference_spec path where
145+
# pickle integrity verification is needed.
146+
if secret_key and isinstance(secret_key, str) and secret_key.strip():
147+
env_vars["SAGEMAKER_SERVE_SECRET_KEY"] = secret_key
148+
139149
return s3_upload_path, env_vars

sagemaker-serve/tests/unit/test_model_builder_utils_triton.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -81,9 +81,14 @@ class TestPrepareForTriton(unittest.TestCase):
8181
"""Test _prepare_for_triton method."""
8282

8383
@patch('shutil.copy2')
84+
@patch.object(_ModelBuilderUtils, '_hmac_signing')
8485
@patch.object(_ModelBuilderUtils, '_export_pytorch_to_onnx')
85-
def test_prepare_for_triton_pytorch(self, mock_export, mock_copy):
86-
"""Test preparing PyTorch model for Triton."""
86+
def test_prepare_for_triton_pytorch(self, mock_export, mock_hmac, mock_copy):
87+
"""Test preparing PyTorch model for Triton.
88+
89+
ONNX path: no pickle is created or loaded at runtime,
90+
so no HMAC signing is needed.
91+
"""
8792
utils = _ModelBuilderUtils()
8893
utils.framework = Framework.PYTORCH
8994
utils.model = Mock()
@@ -94,11 +99,17 @@ def test_prepare_for_triton_pytorch(self, mock_export, mock_copy):
9499
utils._prepare_for_triton()
95100

96101
mock_export.assert_called_once()
102+
mock_hmac.assert_not_called()
97103

98104
@patch('shutil.copy2')
105+
@patch.object(_ModelBuilderUtils, '_hmac_signing')
99106
@patch.object(_ModelBuilderUtils, '_export_tf_to_onnx')
100-
def test_prepare_for_triton_tensorflow(self, mock_export, mock_copy):
101-
"""Test preparing TensorFlow model for Triton."""
107+
def test_prepare_for_triton_tensorflow(self, mock_export, mock_hmac, mock_copy):
108+
"""Test preparing TensorFlow model for Triton.
109+
110+
ONNX path: no pickle is created or loaded at runtime,
111+
so no HMAC signing is needed.
112+
"""
102113
utils = _ModelBuilderUtils()
103114
utils.framework = Framework.TENSORFLOW
104115
utils.model = Mock()
@@ -109,6 +120,7 @@ def test_prepare_for_triton_tensorflow(self, mock_export, mock_copy):
109120
utils._prepare_for_triton()
110121

111122
mock_export.assert_called_once()
123+
mock_hmac.assert_not_called()
112124

113125
@patch('shutil.copy2')
114126
@patch.object(_ModelBuilderUtils, '_generate_config_pbtxt')

0 commit comments

Comments
 (0)