Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions .github/workflows/pythonpublish.yml
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ jobs:
cache-from: type=gha
cache-to: type=gha,mode=max

build-and-push-external-plugin-service-images:
build-and-push-flyteagent-images:
runs-on: ubuntu-latest
needs: deploy
steps:
Expand All @@ -161,12 +161,12 @@ jobs:
registry: ghcr.io
username: "${{ secrets.FLYTE_BOT_USERNAME }}"
password: "${{ secrets.FLYTE_BOT_PAT }}"
- name: Prepare External Plugin Service Image Names
id: external-plugin-service-names
- name: Prepare Flyte Agent Image Names
id: flyteagent-names
uses: docker/metadata-action@v3
with:
images: |
ghcr.io/${{ github.repository_owner }}/external-plugin-service
ghcr.io/${{ github.repository_owner }}/flyteagent
tags: |
latest
${{ github.sha }}
Expand All @@ -177,10 +177,10 @@ jobs:
context: "."
platforms: linux/arm64, linux/amd64
push: ${{ github.event_name == 'release' }}
tags: ${{ steps.external-plugin-service-names.outputs.tags }}
tags: ${{ steps.flyteagent-names.outputs.tags }}
build-args: |
VERSION=${{ needs.deploy.outputs.version }}
file: ./Dockerfile.external-plugin-service
file: ./Dockerfile.agent
cache-from: type=gha
cache-to: type=gha,mode=max

Expand Down
File renamed without changes.
2 changes: 1 addition & 1 deletion doc-requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ flask==2.2.3
# via mlflow
flatbuffers==23.1.21
# via tensorflow
flyteidl==1.5.6
flyteidl==1.5.10
# via flytekit
fonttools==4.38.0
# via matplotlib
Expand Down
14 changes: 7 additions & 7 deletions flytekit/clis/sdk_in_container/serve.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@

import click
import grpc
from flyteidl.service.external_plugin_service_pb2_grpc import add_ExternalPluginServiceServicer_to_server
from flyteidl.service.agent_pb2_grpc import add_AsyncAgentServiceServicer_to_server

from flytekit.extend.backend.external_plugin_service import BackendPluginServer
from flytekit.extend.backend.agent_service import AgentService

_serve_help = """Start a grpc server for the external plugin service."""
_serve_help = """Start a grpc server for the agent service."""


@click.command("serve", help=_serve_help)
Expand All @@ -15,7 +15,7 @@
default="8000",
is_flag=False,
type=int,
help="Grpc port for the external plugin service",
help="Grpc port for the agent service",
)
@click.option(
"--worker",
Expand All @@ -35,11 +35,11 @@
@click.pass_context
def serve(_: click.Context, port, worker, timeout):
"""
Start a grpc server for the external plugin service.
Start a grpc server for the agent service.
"""
click.secho("Starting the external plugin service...", fg="blue")
click.secho("Starting the agent service...", fg="blue")
server = grpc.server(futures.ThreadPoolExecutor(max_workers=worker))
add_ExternalPluginServiceServicer_to_server(BackendPluginServer(), server)
add_AsyncAgentServiceServicer_to_server(AgentService(), server)

server.add_insecure_port(f"[::]:{port}")
server.start()
Expand Down
54 changes: 54 additions & 0 deletions flytekit/extend/backend/agent_service.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import grpc
from flyteidl.admin.agent_pb2 import (
PERMANENT_FAILURE,
CreateTaskRequest,
CreateTaskResponse,
DeleteTaskRequest,
DeleteTaskResponse,
GetTaskRequest,
GetTaskResponse,
Resource,
)
from flyteidl.service.agent_pb2_grpc import AsyncAgentServiceServicer

from flytekit import logger
from flytekit.extend.backend.base_agent import AgentRegistry
from flytekit.models.literals import LiteralMap
from flytekit.models.task import TaskTemplate


class AgentService(AsyncAgentServiceServicer):
def CreateTask(self, request: CreateTaskRequest, context: grpc.ServicerContext) -> CreateTaskResponse:
try:
tmp = TaskTemplate.from_flyte_idl(request.template)
inputs = LiteralMap.from_flyte_idl(request.inputs) if request.inputs else None
agent = AgentRegistry.get_agent(context, tmp.type)
if agent is None:
return CreateTaskResponse()
return agent.create(context=context, inputs=inputs, output_prefix=request.output_prefix, task_template=tmp)
except Exception as e:
logger.error(f"failed to create task with error {e}")
context.set_code(grpc.StatusCode.INTERNAL)
context.set_details(f"failed to create task with error {e}")

def GetTask(self, request: GetTaskRequest, context: grpc.ServicerContext) -> GetTaskResponse:
try:
agent = AgentRegistry.get_agent(context, request.task_type)
if agent is None:
return GetTaskResponse(resource=Resource(state=PERMANENT_FAILURE))
return agent.get(context=context, resource_meta=request.resource_meta)
except Exception as e:
logger.error(f"failed to get task with error {e}")
context.set_code(grpc.StatusCode.INTERNAL)
context.set_details(f"failed to get task with error {e}")

def DeleteTask(self, request: DeleteTaskRequest, context: grpc.ServicerContext) -> DeleteTaskResponse:
try:
agent = AgentRegistry.get_agent(context, request.task_type)
if agent is None:
return DeleteTaskResponse()
return agent.delete(context=context, resource_meta=request.resource_meta)
except Exception as e:
logger.error(f"failed to delete task with error {e}")
context.set_code(grpc.StatusCode.INTERNAL)
context.set_details(f"failed to delete task with error {e}")
107 changes: 107 additions & 0 deletions flytekit/extend/backend/base_agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
import typing
from abc import ABC, abstractmethod

import grpc
from flyteidl.admin.agent_pb2 import (
RETRYABLE_FAILURE,
RUNNING,
SUCCEEDED,
CreateTaskResponse,
DeleteTaskResponse,
GetTaskResponse,
State,
)
from flyteidl.core.tasks_pb2 import TaskTemplate

from flytekit import logger
from flytekit.models.literals import LiteralMap


class AgentBase(ABC):
"""
This is the base class for all agents. It defines the interface that all agents must implement.
The agent service will be run either locally or in a pod, and will be responsible for
invoking agents. The propeller will communicate with the agent service
to create tasks, get the status of tasks, and delete tasks.

All the agents should be registered in the AgentRegistry. Agent Service
will look up the agent based on the task type. Every task type can only have one agent.
"""

def __init__(self, task_type: str):
self._task_type = task_type

@property
def task_type(self) -> str:
"""
task_type is the name of the task type that this agent supports.
"""
return self._task_type

@abstractmethod
def create(
self,
context: grpc.ServicerContext,
output_prefix: str,
task_template: TaskTemplate,
inputs: typing.Optional[LiteralMap] = None,
) -> CreateTaskResponse:
"""
Return a Unique ID for the task that was created. It should return error code if the task creation failed.
"""
pass

@abstractmethod
def get(self, context: grpc.ServicerContext, resource_meta: bytes) -> GetTaskResponse:
"""
Return the status of the task, and return the outputs in some cases. For example, bigquery job
can't write the structured dataset to the output location, so it returns the output literals to the propeller,
and the propeller will write the structured dataset to the blob store.
"""
pass

@abstractmethod
def delete(self, context: grpc.ServicerContext, resource_meta: bytes) -> DeleteTaskResponse:
"""
Delete the task. This call should be idempotent.
"""
pass


class AgentRegistry(object):
"""
This is the registry for all agents. The agent service will look up the agent
based on the task type.
"""

_REGISTRY: typing.Dict[str, AgentBase] = {}

@staticmethod
def register(agent: AgentBase):
if agent.task_type in AgentRegistry._REGISTRY:
raise ValueError(f"Duplicate agent for task type {agent.task_type}")
AgentRegistry._REGISTRY[agent.task_type] = agent
logger.info(f"Registering an agent for task type {agent.task_type}")

@staticmethod
def get_agent(context: grpc.ServicerContext, task_type: str) -> typing.Optional[AgentBase]:
if task_type not in AgentRegistry._REGISTRY:
logger.error(f"Cannot find agent for task type [{task_type}]")
context.set_code(grpc.StatusCode.NOT_FOUND)
context.set_details(f"Cannot find the agent for task type [{task_type}]")
return None
return AgentRegistry._REGISTRY[task_type]


def convert_to_flyte_state(state: str) -> State:
"""
Convert the state from the agent to the state in flyte.
"""
state = state.lower()
if state in ["failed"]:
return RETRYABLE_FAILURE
elif state in ["done", "succeeded"]:
return SUCCEEDED
elif state in ["running"]:
return RUNNING
raise ValueError(f"Unrecognized state: {state}")
107 changes: 0 additions & 107 deletions flytekit/extend/backend/base_plugin.py

This file was deleted.

Loading