Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changes/6360.feature.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Apply RBAC validators to basic VFolder actions
12 changes: 12 additions & 0 deletions src/ai/backend/manager/actions/validator/args.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from dataclasses import dataclass

from .batch import BatchActionValidator
from .scope import ScopeActionValidator
from .single_entity import SingleEntityActionValidator


@dataclass
class ValidatorArgs:
    """Container grouping action validators by the kind of action they target.

    Holds one list per validator type: batch, scope, and single-entity.
    """

    # Validators of type BatchActionValidator.
    batch: list[BatchActionValidator]
    # Validators of type ScopeActionValidator.
    scope: list[ScopeActionValidator]
    # Validators of type SingleEntityActionValidator.
    single_entity: list[SingleEntityActionValidator]
8 changes: 7 additions & 1 deletion src/ai/backend/manager/actions/validators/rbac/scope.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from ai.backend.manager.data.permission.id import ScopeId
from ai.backend.manager.data.permission.role import ScopePermissionCheckInput
from ai.backend.manager.data.permission.types import EntityType, ScopeType
from ai.backend.manager.errors.rbac import RBACForbidden
from ai.backend.manager.errors.user import UserNotFound
from ai.backend.manager.repositories.permission_controller.repository import (
PermissionControllerRepository,
Expand All @@ -30,7 +31,7 @@ async def validate(self, action: BaseScopeAction, meta: BaseActionTriggerMeta) -
if user is None:
raise UserNotFound("User not found in context")

await self._repository.check_permission_in_scope(
is_valid = await self._repository.check_permission_in_scope(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about just `valid` rather than `is_valid`?

ScopePermissionCheckInput(
user_id=user.user_id,
operation=action.permission_operation_type(),
Expand All @@ -41,3 +42,8 @@ async def validate(self, action: BaseScopeAction, meta: BaseActionTriggerMeta) -
),
)
)
if not is_valid:
raise RBACForbidden(
"User does not have permission to perform this action in the specified scope "
f"({scope_type.value}:{scope_id})"
)
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from ai.backend.manager.data.permission.id import ObjectId
from ai.backend.manager.data.permission.role import SingleEntityPermissionCheckInput
from ai.backend.manager.data.permission.types import EntityType
from ai.backend.manager.errors.rbac import RBACForbidden
from ai.backend.manager.errors.user import UserNotFound
from ai.backend.manager.repositories.permission_controller.repository import (
PermissionControllerRepository,
Expand All @@ -29,7 +30,7 @@ async def validate(self, action: BaseSingleEntityAction, meta: BaseActionTrigger
if user is None:
raise UserNotFound("User not found in context")

await self._repository.check_permission_of_entity(
is_valid = await self._repository.check_permission_of_entity(
SingleEntityPermissionCheckInput(
user_id=user.user_id,
operation=action.permission_operation_type(),
Expand All @@ -39,3 +40,8 @@ async def validate(self, action: BaseSingleEntityAction, meta: BaseActionTrigger
),
)
)
if not is_valid:
raise RBACForbidden(
"User does not have permission to perform this action on the specified entity "
f"({entity_type.value}:{entity_id})"
)
22 changes: 22 additions & 0 deletions src/ai/backend/manager/errors/rbac.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from aiohttp import web

from ai.backend.common.exception import (
BackendAIError,
ErrorCode,
ErrorDetail,
ErrorDomain,
ErrorOperation,
)


class RBACForbidden(BackendAIError, web.HTTPForbidden):
    """Error signalling that an operation was refused for lack of RBAC permission.

    Derives from both ``BackendAIError`` and aiohttp's ``HTTPForbidden``.
    """

    error_type = "https://api.backend.ai/probs/rbac-forbidden"
    error_title = "The operation is forbidden due to insufficient RBAC permissions."

    @classmethod
    def error_code(cls) -> ErrorCode:
        """Return the error code: PERMISSION domain, ACCESS operation, FORBIDDEN detail."""
        domain = ErrorDomain.PERMISSION
        operation = ErrorOperation.ACCESS
        detail = ErrorDetail.FORBIDDEN
        return ErrorCode(domain=domain, operation=operation, error_detail=detail)
Original file line number Diff line number Diff line change
Expand Up @@ -134,27 +134,16 @@ async def get_user_roles(self, user_id: uuid.UUID) -> list[RoleRow]:
result = await db_session.scalars(stmt)
return result.all()

async def get_entity_mapped_scopes(
self, target_object_id: ObjectId
) -> list[AssociationScopesEntitiesRow]:
async with self._db.begin_readonly_session() as db_session:
stmt = sa.select(AssociationScopesEntitiesRow.scope_id).where(
AssociationScopesEntitiesRow.entity_id == target_object_id.entity_id,
AssociationScopesEntitiesRow.entity_type == target_object_id.entity_type.value,
)
result = await db_session.scalars(stmt)
return result.all()

async def check_scope_permission_exist(
self,
user_id: uuid.UUID,
scope_id: ScopeId,
operation: OperationType,
) -> bool:
role_query = (
sa.select(sa.func.exist())
exist_query = sa.exists(
sa.select(1)
.select_from(
sa.join(UserRoleRow, RoleRow.id == UserRoleRow.role_id)
sa.join(UserRoleRow, RoleRow, RoleRow.id == UserRoleRow.role_id)
.join(PermissionGroupRow, RoleRow.id == PermissionGroupRow.role_id)
.join(PermissionRow, PermissionGroupRow.id == PermissionRow.permission_group_id)
)
Expand All @@ -169,12 +158,8 @@ async def check_scope_permission_exist(
PermissionRow.operation == operation,
)
)
.options(
contains_eager(RoleRow.permission_group_rows).options(
selectinload(PermissionGroupRow.permission_rows)
)
)
)
role_query = sa.select(exist_query)
async with self._db.begin_readonly_session() as db_session:
result = await db_session.scalar(role_query)
return result
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import uuid
from collections.abc import Mapping
from typing import Optional
from typing import Optional, Self

from ai.backend.common.exception import BackendAIError
from ai.backend.common.metrics.metric import DomainType, LayerType
Expand All @@ -20,6 +20,7 @@
UserRoleAssignmentInput,
)
from ...models.utils import ExtendedAsyncSAEngine
from ..types import RepositoryArgs
from .db_source import PermissionDBSource

permission_controller_repository_resilience = Resilience(
Expand Down Expand Up @@ -47,6 +48,12 @@ class PermissionControllerRepository:
def __init__(self, db: ExtendedAsyncSAEngine) -> None:
    """Wrap ``db`` in a ``PermissionDBSource`` and keep it as ``self._db_source``."""
    self._db_source = PermissionDBSource(db)

@classmethod
def create(cls, args: RepositoryArgs) -> Self:
    """Alternate constructor: build the repository from the ``db`` held by ``args``."""
    engine = args.db
    return cls(db=engine)

@permission_controller_repository_resilience.apply()
async def create_role(self, data: RoleCreateInput) -> RoleData:
"""
Expand Down Expand Up @@ -79,27 +86,20 @@ async def get_role(self, role_id: uuid.UUID) -> Optional[RoleData]:

@permission_controller_repository_resilience.apply()
async def check_permission_of_entity(self, data: SingleEntityPermissionCheckInput) -> bool:
target_object_id = data.target_object_id
roles = await self._db_source.get_user_roles(data.user_id)
associated_scopes = await self._db_source.get_entity_mapped_scopes(target_object_id)
associated_scopes_set = set([row.parsed_scope_id() for row in associated_scopes])
for role in roles:
for object_perm in role.object_permission_rows:
if object_perm.operation != data.operation:
continue
if object_perm.object_id() == target_object_id:
return True

for permission_group in role.permission_group_rows:
if permission_group.parsed_scope_id() not in associated_scopes_set:
continue
for permission in permission_group.permission_rows:
if permission.operation == data.operation:
return True
return False
"""
Check if the user has the requested operation permission on the given entity.
Returns True if the permission exists, False otherwise.
"""
return await self._db_source.check_object_permission_exist(
data.user_id, data.target_object_id, data.operation
)
Comment on lines +93 to +95
Copy link

Copilot AI Oct 23, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The refactored check_permission_of_entity method now directly delegates to _db_source.check_object_permission_exist, removing the inline permission checking logic. However, the original inline implementation iterated through roles, object permissions, and permission groups to check permissions comprehensively. Verify that check_object_permission_exist in db_source.py fully replicates this logic, including checking both direct object permissions and scope-based permission groups, to avoid breaking existing permission checks.

Suggested change
return await self._db_source.check_object_permission_exist(
data.user_id, data.target_object_id, data.operation
)
# Comprehensive permission check: direct object permissions and scope-based permission groups
# 1. Get all roles assigned to the user
roles = await self._db_source.get_roles_of_user(data.user_id)
for role in roles:
# 2. Check direct object permissions for the role
if await self._db_source.check_object_permission_exist(
role.id, data.target_object_id, data.operation
):
return True
# 3. Check permission groups (scope-based)
permission_groups = await self._db_source.get_permission_groups_of_role(role.id)
for group in permission_groups:
if await self._db_source.check_scope_permission_exist(
group.scope_id, data.target_object_id, data.operation
):
return True
return False

Copilot uses AI. Check for mistakes.

@permission_controller_repository_resilience.apply()
async def check_permission_in_scope(self, data: ScopePermissionCheckInput) -> bool:
    """
    Tell whether the user may perform the requested operation in the target scope.

    Forwards ``data.user_id``, ``data.target_scope_id`` and ``data.operation``
    to the DB source's ``check_scope_permission_exist`` and returns its result
    (True when the permission exists, False otherwise).
    """
    db_source = self._db_source
    permitted = await db_source.check_scope_permission_exist(
        data.user_id,
        data.target_scope_id,
        data.operation,
    )
    return permitted
Expand Down
6 changes: 6 additions & 0 deletions src/ai/backend/manager/repositories/repositories.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@
from ai.backend.manager.repositories.metric.repositories import MetricRepositories
from ai.backend.manager.repositories.model_serving.repositories import ModelServingRepositories
from ai.backend.manager.repositories.object_storage.repositories import ObjectStorageRepositories
from ai.backend.manager.repositories.permission_controller.repository import (
PermissionControllerRepository,
)
from ai.backend.manager.repositories.project_resource_policy.repositories import (
ProjectResourcePolicyRepositories,
)
Expand Down Expand Up @@ -72,6 +75,7 @@ class Repositories:
artifact: ArtifactRepositories
artifact_registry: ArtifactRegistryRepositories
storage_namespace: StorageNamespaceRepositories
permission_controller: PermissionControllerRepository

@classmethod
def create(cls, args: RepositoryArgs) -> Self:
Expand Down Expand Up @@ -100,6 +104,7 @@ def create(cls, args: RepositoryArgs) -> Self:
huggingface_registry_repositories = HuggingFaceRegistryRepositories.create(args)
artifact_registries = ArtifactRegistryRepositories.create(args)
storage_namespace_repositories = StorageNamespaceRepositories.create(args)
permission_controller_repository = PermissionControllerRepository.create(args)

return cls(
agent=agent_repositories,
Expand Down Expand Up @@ -127,4 +132,5 @@ def create(cls, args: RepositoryArgs) -> Self:
artifact=artifact_repositories,
artifact_registry=artifact_registries,
storage_namespace=storage_namespace_repositories,
permission_controller=permission_controller_repository,
)
22 changes: 20 additions & 2 deletions src/ai/backend/manager/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -653,6 +653,10 @@ async def processors_ctx(root_ctx: RootContext) -> AsyncIterator[None]:
from .actions.monitors.audit_log import AuditLogMonitor
from .actions.monitors.prometheus import PrometheusMonitor
from .actions.monitors.reporter import ReporterMonitor
from .actions.validator.args import ValidatorArgs
from .actions.validators.rbac.batch import BatchActionRBACValidator
from .actions.validators.rbac.scope import ScopeActionRBACValidator
from .actions.validators.rbac.single_entity import SingleEntityActionRBACValidator
from .reporters.hub import ReporterHub, ReporterHubArgs
from .services.processors import ProcessorArgs, Processors, ServiceArgs

Expand All @@ -666,6 +670,15 @@ async def processors_ctx(root_ctx: RootContext) -> AsyncIterator[None]:
reporter_monitor = ReporterMonitor(reporter_hub)
prometheus_monitor = PrometheusMonitor()
audit_log_monitor = AuditLogMonitor(root_ctx.db)
batch_action_rbac_validator = BatchActionRBACValidator(
root_ctx.repositories.permission_controller
)
single_entity_rbac_validator = SingleEntityActionRBACValidator(
root_ctx.repositories.permission_controller
)
scope_action_rbac_validator = ScopeActionRBACValidator(
root_ctx.repositories.permission_controller
)
root_ctx.processors = Processors.create(
ProcessorArgs(
service_args=ServiceArgs(
Expand All @@ -688,9 +701,14 @@ async def processors_ctx(root_ctx: RootContext) -> AsyncIterator[None]:
deployment_controller=root_ctx.deployment_controller,
event_producer=root_ctx.event_producer,
agent_cache=root_ctx.agent_cache,
)
),
action_monitors=[reporter_monitor, prometheus_monitor, audit_log_monitor],
action_validator_args=ValidatorArgs(
batch=[batch_action_rbac_validator],
single_entity=[single_entity_rbac_validator],
scope=[scope_action_rbac_validator],
),
),
[reporter_monitor, prometheus_monitor, audit_log_monitor],
)
yield

Expand Down
59 changes: 33 additions & 26 deletions src/ai/backend/manager/services/processors.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from ai.backend.common.plugin.monitor import ErrorPluginContext
from ai.backend.manager.actions.monitors.monitor import ActionMonitor
from ai.backend.manager.actions.types import AbstractProcessorPackage, ActionSpec
from ai.backend.manager.actions.validator.args import ValidatorArgs
from ai.backend.manager.agent_cache import AgentRPCCache
from ai.backend.manager.config.provider import ManagerConfigProvider
from ai.backend.manager.idle import IdleCheckerHost
Expand Down Expand Up @@ -318,6 +319,8 @@ def create(cls, args: ServiceArgs) -> Self:
@dataclass
class ProcessorArgs:
    """Arguments consumed by ``Processors.create`` to build the processor packages."""

    # Passed to ``Services.create`` to build the services.
    service_args: ServiceArgs
    # Monitors handed to every processor package constructor.
    action_monitors: list[ActionMonitor]
    # Validator lists; in the visible code these are passed only to VFolderProcessors.
    action_validator_args: ValidatorArgs


@dataclass
Expand Down Expand Up @@ -349,61 +352,65 @@ class Processors(AbstractProcessorPackage):
storage_namespace: StorageNamespaceProcessors

@classmethod
def create(cls, args: ProcessorArgs, action_monitors: list[ActionMonitor]) -> Self:
def create(cls, args: ProcessorArgs) -> Self:
services = Services.create(args.service_args)
agent_processors = AgentProcessors(services.agent, action_monitors)
domain_processors = DomainProcessors(services.domain, action_monitors)
group_processors = GroupProcessors(services.group, action_monitors)
user_processors = UserProcessors(services.user, action_monitors)
image_processors = ImageProcessors(services.image, action_monitors)
agent_processors = AgentProcessors(services.agent, args.action_monitors)
domain_processors = DomainProcessors(services.domain, args.action_monitors)
group_processors = GroupProcessors(services.group, args.action_monitors)
user_processors = UserProcessors(services.user, args.action_monitors)
image_processors = ImageProcessors(services.image, args.action_monitors)
container_registry_processors = ContainerRegistryProcessors(
services.container_registry, action_monitors
services.container_registry, args.action_monitors
)
vfolder_processors = VFolderProcessors(services.vfolder, action_monitors)
vfolder_file_processors = VFolderFileProcessors(services.vfolder_file, action_monitors)
vfolder_processors = VFolderProcessors(
services.vfolder, args.action_monitors, args.action_validator_args
)
vfolder_file_processors = VFolderFileProcessors(services.vfolder_file, args.action_monitors)
vfolder_invite_processors = VFolderInviteProcessors(
services.vfolder_invite, action_monitors
services.vfolder_invite, args.action_monitors
)
session_processors = SessionProcessors(services.session, action_monitors)
session_processors = SessionProcessors(services.session, args.action_monitors)
keypair_resource_policy_processors = KeypairResourcePolicyProcessors(
services.keypair_resource_policy, action_monitors
services.keypair_resource_policy, args.action_monitors
)
user_resource_policy_processors = UserResourcePolicyProcessors(
services.user_resource_policy, action_monitors
services.user_resource_policy, args.action_monitors
)
project_resource_policy_processors = ProjectResourcePolicyProcessors(
services.project_resource_policy, action_monitors
services.project_resource_policy, args.action_monitors
)
resource_preset_processors = ResourcePresetProcessors(
services.resource_preset, action_monitors
services.resource_preset, args.action_monitors
)
model_serving_processors = ModelServingProcessors(
services.model_serving, args.action_monitors
)
model_serving_processors = ModelServingProcessors(services.model_serving, action_monitors)
model_serving_auto_scaling_processors = ModelServingAutoScalingProcessors(
services.model_serving_auto_scaling, action_monitors
services.model_serving_auto_scaling, args.action_monitors
)
utilization_metric_processors = UtilizationMetricProcessors(
services.utilization_metric, action_monitors
services.utilization_metric, args.action_monitors
)
auth = AuthProcessors(services.auth, action_monitors)
auth = AuthProcessors(services.auth, args.action_monitors)
object_storage_processors = ObjectStorageProcessors(
services.object_storage, action_monitors
services.object_storage, args.action_monitors
)
vfs_storage_processors = VFSStorageProcessors(services.vfs_storage, action_monitors)
artifact_processors = ArtifactProcessors(services.artifact, action_monitors)
vfs_storage_processors = VFSStorageProcessors(services.vfs_storage, args.action_monitors)
artifact_processors = ArtifactProcessors(services.artifact, args.action_monitors)
artifact_registry_processors = ArtifactRegistryProcessors(
services.artifact_registry, action_monitors
services.artifact_registry, args.action_monitors
)
artifact_revision_processors = ArtifactRevisionProcessors(
services.artifact_revision, action_monitors
services.artifact_revision, args.action_monitors
)

# Initialize deployment processors if service is available
deployment_processors = None
if services.deployment is not None:
deployment_processors = DeploymentProcessors(services.deployment, action_monitors)
deployment_processors = DeploymentProcessors(services.deployment, args.action_monitors)

storage_namespace_processors = StorageNamespaceProcessors(
services.storage_namespace, action_monitors
services.storage_namespace, args.action_monitors
)

return cls(
Expand Down
Loading
Loading