Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changes/6907.enhance.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Introduce source-based structure in user resource policy
10 changes: 5 additions & 5 deletions docs/manager/graphql-reference/schema.graphql
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ type Query {
"""Added in 24.03.1"""
id: String
reference: String
architecture: String = "aarch64"
architecture: String = "x86_64"
): Image
images(
"""
Expand Down Expand Up @@ -2341,7 +2341,7 @@ type Mutation {
): RescanImages
preload_image(references: [String]!, target_agents: [String]!): PreloadImage
unload_image(references: [String]!, target_agents: [String]!): UnloadImage
modify_image(architecture: String = "aarch64", props: ModifyImageInput!, target: String!): ModifyImage
modify_image(architecture: String = "x86_64", props: ModifyImageInput!, target: String!): ModifyImage

"""Added in 25.6.0"""
clear_image_custom_resource_limit(key: ClearImageCustomResourceLimitKey!): ClearImageCustomResourceLimitPayload
Expand All @@ -2350,7 +2350,7 @@ type Mutation {
forget_image_by_id(image_id: String!): ForgetImageById

"""Deprecated since 25.4.0. Use `forget_image_by_id` instead."""
forget_image(architecture: String = "aarch64", reference: String!): ForgetImage @deprecated(reason: "Deprecated since 25.4.0. Use `forget_image_by_id` instead.")
forget_image(architecture: String = "x86_64", reference: String!): ForgetImage @deprecated(reason: "Deprecated since 25.4.0. Use `forget_image_by_id` instead.")

"""Added in 25.4.0"""
purge_image_by_id(
Expand All @@ -2362,7 +2362,7 @@ type Mutation {

"""Added in 24.03.1"""
untag_image_from_registry(image_id: String!): UntagImageFromRegistry
alias_image(alias: String!, architecture: String = "aarch64", target: String!): AliasImage
alias_image(alias: String!, architecture: String = "x86_64", target: String!): AliasImage
dealias_image(alias: String!): DealiasImage
clear_images(registry: String): ClearImages

Expand Down Expand Up @@ -2937,7 +2937,7 @@ type ClearImageCustomResourceLimitPayload {
"""Added in 25.6.0."""
input ClearImageCustomResourceLimitKey {
image_canonical: String!
architecture: String! = "aarch64"
architecture: String! = "x86_64"
}

"""Added in 24.03.0."""
Expand Down
10 changes: 5 additions & 5 deletions docs/manager/graphql-reference/supergraph.graphql
Original file line number Diff line number Diff line change
Expand Up @@ -1129,7 +1129,7 @@ input ClearImageCustomResourceLimitKey
@join__type(graph: GRAPHENE)
{
image_canonical: String!
architecture: String! = "aarch64"
architecture: String! = "x86_64"
}

"""Added in 25.6.0."""
Expand Down Expand Up @@ -4461,7 +4461,7 @@ type Mutation
): RescanImages @join__field(graph: GRAPHENE)
preload_image(references: [String]!, target_agents: [String]!): PreloadImage @join__field(graph: GRAPHENE)
unload_image(references: [String]!, target_agents: [String]!): UnloadImage @join__field(graph: GRAPHENE)
modify_image(architecture: String = "aarch64", props: ModifyImageInput!, target: String!): ModifyImage @join__field(graph: GRAPHENE)
modify_image(architecture: String = "x86_64", props: ModifyImageInput!, target: String!): ModifyImage @join__field(graph: GRAPHENE)

"""Added in 25.6.0"""
clear_image_custom_resource_limit(key: ClearImageCustomResourceLimitKey!): ClearImageCustomResourceLimitPayload @join__field(graph: GRAPHENE)
Expand All @@ -4470,7 +4470,7 @@ type Mutation
forget_image_by_id(image_id: String!): ForgetImageById @join__field(graph: GRAPHENE)

"""Deprecated since 25.4.0. Use `forget_image_by_id` instead."""
forget_image(architecture: String = "aarch64", reference: String!): ForgetImage @join__field(graph: GRAPHENE) @deprecated(reason: "Deprecated since 25.4.0. Use `forget_image_by_id` instead.")
forget_image(architecture: String = "x86_64", reference: String!): ForgetImage @join__field(graph: GRAPHENE) @deprecated(reason: "Deprecated since 25.4.0. Use `forget_image_by_id` instead.")

"""Added in 25.4.0"""
purge_image_by_id(
Expand All @@ -4482,7 +4482,7 @@ type Mutation

"""Added in 24.03.1"""
untag_image_from_registry(image_id: String!): UntagImageFromRegistry @join__field(graph: GRAPHENE)
alias_image(alias: String!, architecture: String = "aarch64", target: String!): AliasImage @join__field(graph: GRAPHENE)
alias_image(alias: String!, architecture: String = "x86_64", target: String!): AliasImage @join__field(graph: GRAPHENE)
dealias_image(alias: String!): DealiasImage @join__field(graph: GRAPHENE)
clear_images(registry: String): ClearImages @join__field(graph: GRAPHENE)

Expand Down Expand Up @@ -5512,7 +5512,7 @@ type Query
"""Added in 24.03.1"""
id: String
reference: String
architecture: String = "aarch64"
architecture: String = "x86_64"
): Image @join__field(graph: GRAPHENE)
images(
"""
Expand Down
13 changes: 13 additions & 0 deletions src/ai/backend/common/exception.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,7 @@ class ErrorDomain(enum.StrEnum):
MESSAGE_QUEUE = "message-queue"
NOTIFICATION = "notification"
HEALTH_CHECK = "health-check"
USER_RESOURCE_POLICY = "user-resource-policy"


class ErrorOperation(enum.StrEnum):
Expand Down Expand Up @@ -835,3 +836,15 @@ def error_code(self) -> ErrorCode:
operation=ErrorOperation.READ,
error_detail=ErrorDetail.UNAVAILABLE,
)


class UserResourcePolicyNotFound(BackendAIError, web.HTTPNotFound):
    """Error raised when a requested user resource policy cannot be found."""

    error_type = "https://api.backend.ai/probs/user-resource-policy-not-found"
    error_title = "User Resource Policy Not Found"

    def error_code(self) -> ErrorCode:
        """Return the error code: user-resource-policy domain, READ, NOT_FOUND."""
        code = ErrorCode(
            domain=ErrorDomain.USER_RESOURCE_POLICY,
            operation=ErrorOperation.READ,
            error_detail=ErrorDetail.NOT_FOUND,
        )
        return code
1 change: 1 addition & 0 deletions src/ai/backend/common/metrics/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,6 +438,7 @@ class LayerType(enum.StrEnum):
RESOURCE_PRESET_DB_SOURCE = "resource_preset_db_source"
SCHEDULE_DB_SOURCE = "schedule_db_source"
SCHEDULER_DB_SOURCE = "scheduler_db_source"
USER_RESOURCE_POLICY_DB_SOURCE = "user_resource_policy_db_source"

# Cache Source layers
AGENT_CACHE_SOURCE = "agent_cache_source"
Expand Down
10 changes: 10 additions & 0 deletions src/ai/backend/manager/models/resource_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,16 @@ def __init__(
self.max_session_count_per_model_session = max_session_count_per_model_session
self.max_customized_image_count = max_customized_image_count

@classmethod
def from_creator(cls, creator: UserResourcePolicyCreator) -> Self:
    """Build a row whose constructor arguments are copied from ``creator``."""
    # Every constructor keyword is read from the same-named creator attribute.
    init_kwargs = {
        "name": creator.name,
        "max_vfolder_count": creator.max_vfolder_count,
        "max_quota_scope_size": creator.max_quota_scope_size,
        "max_session_count_per_model_session": creator.max_session_count_per_model_session,
        "max_customized_image_count": creator.max_customized_image_count,
    }
    return cls(**init_kwargs)

@classmethod
def from_dataclass(cls, data: UserResourcePolicyData) -> Self:
return cls(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Optional

import sqlalchemy as sa

from ai.backend.common.exception import BackendAIError, UserResourcePolicyNotFound
from ai.backend.common.metrics.metric import DomainType, LayerType
from ai.backend.common.resilience.policies.metrics import MetricArgs, MetricPolicy
from ai.backend.common.resilience.policies.retry import BackoffStrategy, RetryArgs, RetryPolicy
from ai.backend.common.resilience.resilience import Resilience
from ai.backend.manager.data.resource.types import UserResourcePolicyData
from ai.backend.manager.models.resource_policy import UserResourcePolicyRow
from ai.backend.manager.services.user_resource_policy.actions.modify_user_resource_policy import (
UserResourcePolicyModifier,
)
from ai.backend.manager.services.user_resource_policy.types import UserResourcePolicyCreator

if TYPE_CHECKING:
from ai.backend.manager.models.utils import ExtendedAsyncSAEngine

# Metrics are labelled with the DB-source domain and this source's own layer.
_metric_policy = MetricPolicy(
    MetricArgs(domain=DomainType.DB_SOURCE, layer=LayerType.USER_RESOURCE_POLICY_DB_SOURCE)
)
# Up to 5 retries with a fixed 0.1 delay (presumably seconds -- confirm against
# RetryArgs); BackendAIError is listed as non-retryable.
_retry_policy = RetryPolicy(
    RetryArgs(
        max_retries=5,
        retry_delay=0.1,
        backoff_strategy=BackoffStrategy.FIXED,
        non_retryable_exceptions=(BackendAIError,),
    )
)
# Resilience wrapper applied (via ``.apply()``) to the DB source methods below.
user_resource_policy_db_source_resilience = Resilience(
    policies=[_metric_policy, _retry_policy]
)


class UserResourcePolicyDBSource:
"""
Database source for user resource policy operations.
Handles all database operations for user resource policies.
"""
Comment on lines +39 to +43
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is serializable applied here? It seems like applying Read Committed this time would be worthwhile, don't you think?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems that applying it in 1.4 would be limited... Let's leave it as is for now and revisit this after we move to 2.0.


_db: ExtendedAsyncSAEngine

def __init__(self, db: ExtendedAsyncSAEngine) -> None:
    """Store the database engine used by every operation of this source."""
    self._db = db

@user_resource_policy_db_source_resilience.apply()
async def create(self, creator: UserResourcePolicyCreator) -> UserResourcePolicyData:
    """Inserts a user resource policy built from ``creator`` and returns its data."""
    async with self._db.begin_session() as session:
        new_policy = UserResourcePolicyRow.from_creator(creator)
        session.add(new_policy)
        # Flush so the pending row is written before it is converted.
        await session.flush()
        return new_policy.to_dataclass()

@user_resource_policy_db_source_resilience.apply()
async def get_by_name(self, name: str) -> UserResourcePolicyData:
    """
    Looks up the user resource policy called ``name``.

    Raises UserResourcePolicyNotFound when no row has that name.
    """
    lookup = sa.select(UserResourcePolicyRow).where(UserResourcePolicyRow.name == name)
    async with self._db.begin_readonly_session() as session:
        found = await session.scalar(lookup)
        if found is not None:
            return found.to_dataclass()
        raise UserResourcePolicyNotFound(f"User resource policy with name {name} not found.")

Comment on lines +59 to +70
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems like this could also be integrated into the querier.

@user_resource_policy_db_source_resilience.apply()
async def update(
    self, name: str, modifier: UserResourcePolicyModifier
) -> UserResourcePolicyData:
    """
    Applies ``modifier`` to the policy called ``name`` and returns the updated data.

    Raises UserResourcePolicyNotFound when no policy has that name, either at
    the initial lookup or when the UPDATE returns no row.
    """
    async with self._db.begin_session() as session:
        # Look the policy up first so a missing name is reported before any write.
        current: Optional[UserResourcePolicyRow] = await session.scalar(
            sa.select(UserResourcePolicyRow).where(UserResourcePolicyRow.name == name)
        )
        if current is None:
            raise UserResourcePolicyNotFound(
                f"User resource policy with name {name} not found."
            )

        changes = modifier.fields_to_update()
        returning_update = (
            sa.update(UserResourcePolicyRow)
            .where(UserResourcePolicyRow.name == name)
            .values(**changes)
            .returning(UserResourcePolicyRow)
        )
        # Wrap UPDATE ... RETURNING in a select so the result comes back as an
        # ORM row; populate_existing overwrites the instance loaded above.
        orm_update = (
            sa.select(UserResourcePolicyRow)
            .from_statement(returning_update)
            .execution_options(populate_existing=True)
        )
        refreshed: Optional[UserResourcePolicyRow] = await session.scalar(orm_update)
        if refreshed is None:
            raise UserResourcePolicyNotFound(
                f"User resource policy with name {name} not found after update."
            )
        return refreshed.to_dataclass()

@user_resource_policy_db_source_resilience.apply()
async def delete(self, name: str) -> UserResourcePolicyData:
    """
    Deletes the user resource policy called ``name`` and returns its last data.

    Raises UserResourcePolicyNotFound when the DELETE returns no row.
    """
    async with self._db.begin_session() as db_sess:
        delete_stmt = (
            sa.delete(UserResourcePolicyRow)
            .where(UserResourcePolicyRow.name == name)
            .returning(UserResourcePolicyRow)
        )
        # Wrap DELETE ... RETURNING in a select so the removed row comes back
        # as an ORM object that can be converted to a dataclass.
        query_stmt = (
            sa.select(UserResourcePolicyRow)
            .from_statement(delete_stmt)
            .execution_options(populate_existing=True)
        )
        row: Optional[UserResourcePolicyRow] = await db_sess.scalar(query_stmt)
        if row is None:
            raise UserResourcePolicyNotFound(
                f"User resource policy with name {name} not found."
            )
        # The statement above has already removed the row, so no further
        # session-level delete is issued: a second DELETE for the same
        # primary key would match zero rows at flush time.
        return row.to_dataclass()
Original file line number Diff line number Diff line change
@@ -1,16 +1,23 @@
from typing import Any, Mapping
from __future__ import annotations

import sqlalchemy as sa
from typing import TYPE_CHECKING

from ai.backend.common.exception import BackendAIError
from ai.backend.common.metrics.metric import DomainType, LayerType
from ai.backend.common.resilience.policies.metrics import MetricArgs, MetricPolicy
from ai.backend.common.resilience.policies.retry import BackoffStrategy, RetryArgs, RetryPolicy
from ai.backend.common.resilience.resilience import Resilience
from ai.backend.manager.data.resource.types import UserResourcePolicyData
from ai.backend.manager.errors.common import ObjectNotFound
from ai.backend.manager.models.resource_policy import UserResourcePolicyRow
from ai.backend.manager.models.utils import ExtendedAsyncSAEngine
from ai.backend.manager.repositories.user_resource_policy.db_source.db_source import (
UserResourcePolicyDBSource,
)
from ai.backend.manager.services.user_resource_policy.actions.modify_user_resource_policy import (
UserResourcePolicyModifier,
)
from ai.backend.manager.services.user_resource_policy.types import UserResourcePolicyCreator

if TYPE_CHECKING:
from ai.backend.manager.models.utils import ExtendedAsyncSAEngine

user_resource_policy_repository_resilience = Resilience(
policies=[
Expand All @@ -21,7 +28,7 @@
),
RetryPolicy(
RetryArgs(
max_retries=10,
max_retries=5,
retry_delay=0.1,
backoff_strategy=BackoffStrategy.FIXED,
non_retryable_exceptions=(BackendAIError,),
Expand All @@ -32,49 +39,31 @@


class UserResourcePolicyRepository:
_db: ExtendedAsyncSAEngine
"""Repository for user resource policy data access."""

_db_source: UserResourcePolicyDBSource

def __init__(self, db: ExtendedAsyncSAEngine) -> None:
self._db = db
self._db_source = UserResourcePolicyDBSource(db)

@user_resource_policy_repository_resilience.apply()
async def create(self, fields: Mapping[str, Any]) -> UserResourcePolicyData:
async with self._db.begin_session() as db_sess:
db_row = UserResourcePolicyRow(**fields)
db_sess.add(db_row)
await db_sess.flush()
return db_row.to_dataclass()
async def create(self, creator: UserResourcePolicyCreator) -> UserResourcePolicyData:
"""Creates a new user resource policy."""
return await self._db_source.create(creator)

@user_resource_policy_repository_resilience.apply()
async def get_by_name(self, name: str) -> UserResourcePolicyData:
async with self._db.begin_readonly_session() as db_sess:
query = sa.select(UserResourcePolicyRow).where(UserResourcePolicyRow.name == name)
result = await db_sess.execute(query)
row = result.scalar_one_or_none()
if row is None:
raise ObjectNotFound(f"User resource policy with name {name} not found.")
return row.to_dataclass()
"""Retrieves a user resource policy by name."""
return await self._db_source.get_by_name(name)

@user_resource_policy_repository_resilience.apply()
async def update(self, name: str, fields: Mapping[str, Any]) -> UserResourcePolicyData:
async with self._db.begin_session() as db_sess:
query = sa.select(UserResourcePolicyRow).where(UserResourcePolicyRow.name == name)
result = await db_sess.execute(query)
row = result.scalar_one_or_none()
if row is None:
raise ObjectNotFound(f"User resource policy with name {name} not found.")
for key, value in fields.items():
setattr(row, key, value)
await db_sess.flush()
return row.to_dataclass()
async def update(
self, name: str, modifier: UserResourcePolicyModifier
) -> UserResourcePolicyData:
"""Updates an existing user resource policy."""
return await self._db_source.update(name, modifier)

@user_resource_policy_repository_resilience.apply()
async def delete(self, name: str) -> UserResourcePolicyData:
async with self._db.begin_session() as db_sess:
query = sa.select(UserResourcePolicyRow).where(UserResourcePolicyRow.name == name)
result = await db_sess.execute(query)
row = result.scalar_one_or_none()
if row is None:
raise ObjectNotFound(f"User resource policy with name {name} not found.")
await db_sess.delete(row)
return row.to_dataclass()
"""Deletes a user resource policy."""
return await self._db_source.delete(name)
12 changes: 3 additions & 9 deletions src/ai/backend/manager/services/user_resource_policy/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,23 +32,17 @@ def __init__(
async def create_user_resource_policy(
self, action: CreateUserResourcePolicyAction
) -> CreateUserResourcePolicyActionResult:
creator = action.creator
to_create = creator.fields_to_store()
result = await self._user_resource_policy_repository.create(to_create)
result = await self._user_resource_policy_repository.create(action.creator)
return CreateUserResourcePolicyActionResult(user_resource_policy=result)

async def modify_user_resource_policy(
self, action: ModifyUserResourcePolicyAction
) -> ModifyUserResourcePolicyActionResult:
name = action.name
modifier = action.modifier
to_update = modifier.fields_to_update()
result = await self._user_resource_policy_repository.update(name, to_update)
result = await self._user_resource_policy_repository.update(action.name, action.modifier)
return ModifyUserResourcePolicyActionResult(user_resource_policy=result)

async def delete_user_resource_policy(
self, action: DeleteUserResourcePolicyAction
) -> DeleteUserResourcePolicyActionResult:
name = action.name
result = await self._user_resource_policy_repository.delete(name)
result = await self._user_resource_policy_repository.delete(action.name)
return DeleteUserResourcePolicyActionResult(user_resource_policy=result)
Loading
Loading