From 1f06017c665ac380cc30c25da1a92cbe0e1f81f3 Mon Sep 17 00:00:00 2001 From: Michael Choi Date: Sat, 1 Mar 2025 08:10:15 +0000 Subject: [PATCH 1/3] Update celery forwarder to use greenlets instead of processes --- .../inference/forwarding/celery_forwarder.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/model-engine/model_engine_server/inference/forwarding/celery_forwarder.py b/model-engine/model_engine_server/inference/forwarding/celery_forwarder.py index 27007969a..a46fe9b2f 100644 --- a/model-engine/model_engine_server/inference/forwarding/celery_forwarder.py +++ b/model-engine/model_engine_server/inference/forwarding/celery_forwarder.py @@ -4,6 +4,7 @@ from typing import Any, Dict, Optional, TypedDict, Union from celery import Celery, Task, states +from gevent import monkey from model_engine_server.common.constants import DEFAULT_CELERY_TASK_NAME, LIRA_CELERY_TASK_NAME from model_engine_server.common.dtos.model_endpoints import BrokerType from model_engine_server.common.dtos.tasks import EndpointPredictV1Request @@ -23,7 +24,9 @@ from model_engine_server.inference.infra.gateways.datadog_inference_monitoring_metrics_gateway import ( DatadogInferenceMonitoringMetricsGateway, ) -from requests import ConnectionError +from request import ConnectionError + +monkey.patch_all() logger = make_logger(logger_name()) @@ -144,7 +147,7 @@ def exec_func(payload, arrival_timestamp, *ignored_args, **ignored_kwargs): # Don't fail the celery task even if there's a status code # (otherwise we can't really control what gets put in the result attribute) # in the task (https://docs.celeryq.dev/en/stable/reference/celery.result.html#celery.result.AsyncResult.status) - result = forwarder(payload) + result = forwarder.forward(payload) request_duration = datetime.now() - arrival_timestamp if request_duration > timedelta(seconds=DEFAULT_TASK_VISIBILITY_SECONDS): monitoring_metrics_gateway.emit_async_task_stuck_metric(queue_name) @@ -177,12 +180,7 @@ def start_celery_service( concurrency=concurrency, loglevel="INFO", optimization="fair", - # Don't use pool="solo" so we can send multiple concurrent requests over - # Historically, pool="solo" argument fixes the known issues of celery and some of the libraries. - # Particularly asyncio and torchvision transformers. This isn't relevant since celery-forwarder - # is quite lightweight - # TODO: we should probably use eventlet or gevent for the pool, since - # the forwarder is nearly the most extreme example of IO bound. + pool="gevent", ) worker.start() From f491aae3258bfebca30d82ac9f41c7da569e86dc Mon Sep 17 00:00:00 2001 From: Michael Choi Date: Tue, 4 Mar 2025 22:13:10 +0000 Subject: [PATCH 2/3] Add imports --- .../inference/forwarding/celery_forwarder.py | 2 +- model-engine/requirements.in | 1 + model-engine/requirements.txt | 14 ++++++++++++-- 3 files changed, 14 insertions(+), 3 deletions(-) diff --git a/model-engine/model_engine_server/inference/forwarding/celery_forwarder.py b/model-engine/model_engine_server/inference/forwarding/celery_forwarder.py index a46fe9b2f..23cf50de6 100644 --- a/model-engine/model_engine_server/inference/forwarding/celery_forwarder.py +++ b/model-engine/model_engine_server/inference/forwarding/celery_forwarder.py @@ -24,7 +24,7 @@ from model_engine_server.inference.infra.gateways.datadog_inference_monitoring_metrics_gateway import ( DatadogInferenceMonitoringMetricsGateway, ) -from request import ConnectionError +from requests import ConnectionError monkey.patch_all() diff --git a/model-engine/requirements.in b/model-engine/requirements.in index d503f7b83..949413bf5 100644 --- a/model-engine/requirements.in +++ b/model-engine/requirements.in @@ -25,6 +25,7 @@ ddtrace==1.8.3 deprecation~=2.1 docker~=5.0 fastapi~=0.110.0 +gevent~=24.11.1 gitdb2~=2.0 gunicorn~=20.0 httptools==0.5.0 diff --git a/model-engine/requirements.txt b/model-engine/requirements.txt index 6e784ecc9..54af7eac0 100644 --- a/model-engine/requirements.txt +++ b/model-engine/requirements.txt @@ -167,6 +167,8 @@ frozenlist==1.3.3 # aiosignal fsspec==2023.10.0 # via huggingface-hub +gevent==24.11.1 + # via -r model-engine/requirements.in gitdb==4.0.10 # via gitpython gitdb2==2.0.6 @@ -175,8 +177,10 @@ gitpython==3.1.41 # via -r model-engine/requirements.in google-auth==2.21.0 # via kubernetes -greenlet==2.0.2 - # via sqlalchemy +greenlet==3.1.1 + # via + # gevent + # sqlalchemy gunicorn==20.1.0 # via -r model-engine/requirements.in h11==0.14.0 @@ -569,6 +573,10 @@ yarl==1.9.2 # aiohttp zipp==3.16.0 # via importlib-metadata +zope-event==5.0 + # via gevent +zope-interface==7.2 + # via gevent # The following packages are considered to be unsafe in a requirements file: setuptools==69.0.3 @@ -576,3 +584,5 @@ setuptools==69.0.3 # gunicorn # kubernetes # kubernetes-asyncio + # zope-event + # zope-interface From 2fb2264c5c0a15d072e38882b8b0a91a1e045e2d Mon Sep 17 00:00:00 2001 From: Michael Choi Date: Wed, 5 Mar 2025 04:01:16 +0000 Subject: [PATCH 3/3] fix gevent import order --- .ruff.toml | 3 +++ .../inference/forwarding/celery_forwarder.py | 9 +++++---- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/.ruff.toml b/.ruff.toml index 69f832536..f5e81157b 100644 --- a/.ruff.toml +++ b/.ruff.toml @@ -3,3 +3,6 @@ line-length = 100 ignore = ["E501"] exclude = ["gen", "alembic"] + +[lint.per-file-ignores] +"model-engine/model_engine_server/inference/forwarding/celery_forwarder.py" = ["E402"] \ No newline at end of file diff --git a/model-engine/model_engine_server/inference/forwarding/celery_forwarder.py b/model-engine/model_engine_server/inference/forwarding/celery_forwarder.py index 23cf50de6..4b387faad 100644 --- a/model-engine/model_engine_server/inference/forwarding/celery_forwarder.py +++ b/model-engine/model_engine_server/inference/forwarding/celery_forwarder.py @@ -1,10 +1,13 @@ +from gevent import monkey + +monkey.patch_all() + import argparse import json from datetime import datetime, timedelta from typing import Any, Dict, Optional, TypedDict, Union from celery import Celery, Task, states -from gevent import monkey from model_engine_server.common.constants import DEFAULT_CELERY_TASK_NAME, LIRA_CELERY_TASK_NAME from model_engine_server.common.dtos.model_endpoints import BrokerType from model_engine_server.common.dtos.tasks import EndpointPredictV1Request @@ -26,8 +29,6 @@ ) from requests import ConnectionError -monkey.patch_all() - logger = make_logger(logger_name()) @@ -147,7 +148,7 @@ def exec_func(payload, arrival_timestamp, *ignored_args, **ignored_kwargs): # Don't fail the celery task even if there's a status code # (otherwise we can't really control what gets put in the result attribute) # in the task (https://docs.celeryq.dev/en/stable/reference/celery.result.html#celery.result.AsyncResult.status) - result = forwarder.forward(payload) + result = forwarder(payload) request_duration = datetime.now() - arrival_timestamp if request_duration > timedelta(seconds=DEFAULT_TASK_VISIBILITY_SECONDS): monitoring_metrics_gateway.emit_async_task_stuck_metric(queue_name)