diff --git a/changes/6360.feature.md b/changes/6360.feature.md new file mode 100644 index 00000000000..8ad6abff03e --- /dev/null +++ b/changes/6360.feature.md @@ -0,0 +1 @@ +Apply RBAC validators to basic VFolder actions diff --git a/src/ai/backend/manager/actions/validator/args.py b/src/ai/backend/manager/actions/validator/args.py new file mode 100644 index 00000000000..83a7bd6c352 --- /dev/null +++ b/src/ai/backend/manager/actions/validator/args.py @@ -0,0 +1,12 @@ +from dataclasses import dataclass + +from .batch import BatchActionValidator +from .scope import ScopeActionValidator +from .single_entity import SingleEntityActionValidator + + +@dataclass +class ValidatorArgs: + batch: list[BatchActionValidator] + scope: list[ScopeActionValidator] + single_entity: list[SingleEntityActionValidator] diff --git a/src/ai/backend/manager/actions/validators/rbac/scope.py b/src/ai/backend/manager/actions/validators/rbac/scope.py index fbc39b8d0ef..7fedd9a7f60 100644 --- a/src/ai/backend/manager/actions/validators/rbac/scope.py +++ b/src/ai/backend/manager/actions/validators/rbac/scope.py @@ -6,6 +6,7 @@ from ai.backend.manager.data.permission.id import ScopeId from ai.backend.manager.data.permission.role import ScopePermissionCheckInput from ai.backend.manager.data.permission.types import EntityType, ScopeType +from ai.backend.manager.errors.rbac import RBACForbidden from ai.backend.manager.errors.user import UserNotFound from ai.backend.manager.repositories.permission_controller.repository import ( PermissionControllerRepository, @@ -30,7 +31,7 @@ async def validate(self, action: BaseScopeAction, meta: BaseActionTriggerMeta) - if user is None: raise UserNotFound("User not found in context") - await self._repository.check_permission_in_scope( + is_valid = await self._repository.check_permission_in_scope( ScopePermissionCheckInput( user_id=user.user_id, operation=action.permission_operation_type(), @@ -41,3 +42,8 @@ async def validate(self, action: BaseScopeAction, meta: BaseActionTriggerMeta) - ), ) ) + if not is_valid: + raise RBACForbidden( + "User does not have permission to perform this action in the specified scope " + f"({scope_type.value}:{scope_id})" + ) diff --git a/src/ai/backend/manager/actions/validators/rbac/single_entity.py b/src/ai/backend/manager/actions/validators/rbac/single_entity.py index 88461882f65..10ab902a1bb 100644 --- a/src/ai/backend/manager/actions/validators/rbac/single_entity.py +++ b/src/ai/backend/manager/actions/validators/rbac/single_entity.py @@ -6,6 +6,7 @@ from ai.backend.manager.data.permission.id import ObjectId from ai.backend.manager.data.permission.role import SingleEntityPermissionCheckInput from ai.backend.manager.data.permission.types import EntityType +from ai.backend.manager.errors.rbac import RBACForbidden from ai.backend.manager.errors.user import UserNotFound from ai.backend.manager.repositories.permission_controller.repository import ( PermissionControllerRepository, @@ -29,7 +30,7 @@ async def validate(self, action: BaseSingleEntityAction, meta: BaseActionTrigger if user is None: raise UserNotFound("User not found in context") - await self._repository.check_permission_of_entity( + is_valid = await self._repository.check_permission_of_entity( SingleEntityPermissionCheckInput( user_id=user.user_id, operation=action.permission_operation_type(), @@ -39,3 +40,8 @@ async def validate(self, action: BaseSingleEntityAction, meta: BaseActionTrigger ), ) ) + if not is_valid: + raise RBACForbidden( + "User does not have permission to perform this action on the specified entity " + f"({entity_type.value}:{entity_id})" + ) diff --git a/src/ai/backend/manager/errors/rbac.py b/src/ai/backend/manager/errors/rbac.py new file mode 100644 index 00000000000..faafd933867 --- /dev/null +++ b/src/ai/backend/manager/errors/rbac.py @@ -0,0 +1,22 @@ +from aiohttp import web + +from ai.backend.common.exception import ( + BackendAIError, + ErrorCode, + ErrorDetail, + ErrorDomain, + ErrorOperation, +) + + +class RBACForbidden(BackendAIError, web.HTTPForbidden): + error_type = "https://api.backend.ai/probs/rbac-forbidden" + error_title = "The operation is forbidden due to insufficient RBAC permissions." + + @classmethod + def error_code(cls) -> ErrorCode: + return ErrorCode( + domain=ErrorDomain.PERMISSION, + operation=ErrorOperation.ACCESS, + error_detail=ErrorDetail.FORBIDDEN, + ) diff --git a/src/ai/backend/manager/repositories/permission_controller/db_source.py b/src/ai/backend/manager/repositories/permission_controller/db_source.py index c4ff94cd596..00be8ac37c3 100644 --- a/src/ai/backend/manager/repositories/permission_controller/db_source.py +++ b/src/ai/backend/manager/repositories/permission_controller/db_source.py @@ -134,27 +134,16 @@ async def get_user_roles(self, user_id: uuid.UUID) -> list[RoleRow]: result = await db_session.scalars(stmt) return result.all() - async def get_entity_mapped_scopes( - self, target_object_id: ObjectId - ) -> list[AssociationScopesEntitiesRow]: - async with self._db.begin_readonly_session() as db_session: - stmt = sa.select(AssociationScopesEntitiesRow.scope_id).where( - AssociationScopesEntitiesRow.entity_id == target_object_id.entity_id, - AssociationScopesEntitiesRow.entity_type == target_object_id.entity_type.value, - ) - result = await db_session.scalars(stmt) - return result.all() - async def check_scope_permission_exist( self, user_id: uuid.UUID, scope_id: ScopeId, operation: OperationType, ) -> bool: - role_query = ( - sa.select(sa.func.exist()) + exist_query = sa.exists( + sa.select(1) .select_from( - sa.join(UserRoleRow, RoleRow.id == UserRoleRow.role_id) + sa.join(UserRoleRow, RoleRow, RoleRow.id == UserRoleRow.role_id) .join(PermissionGroupRow, RoleRow.id == PermissionGroupRow.role_id) .join(PermissionRow, PermissionGroupRow.id == PermissionRow.permission_group_id) ) @@ -169,12 +158,8 @@ async def check_scope_permission_exist( PermissionRow.operation == operation, ) ) - .options( - contains_eager(RoleRow.permission_group_rows).options( - selectinload(PermissionGroupRow.permission_rows) - ) - ) ) + role_query = sa.select(exist_query) async with self._db.begin_readonly_session() as db_session: result = await db_session.scalar(role_query) return result diff --git a/src/ai/backend/manager/repositories/permission_controller/repository.py b/src/ai/backend/manager/repositories/permission_controller/repository.py index 87754ad6cd9..c72aeb7bea4 100644 --- a/src/ai/backend/manager/repositories/permission_controller/repository.py +++ b/src/ai/backend/manager/repositories/permission_controller/repository.py @@ -1,6 +1,6 @@ import uuid from collections.abc import Mapping -from typing import Optional +from typing import Optional, Self from ai.backend.common.exception import BackendAIError from ai.backend.common.metrics.metric import DomainType, LayerType @@ -20,6 +20,7 @@ UserRoleAssignmentInput, ) from ...models.utils import ExtendedAsyncSAEngine +from ..types import RepositoryArgs from .db_source import PermissionDBSource permission_controller_repository_resilience = Resilience( @@ -47,6 +48,12 @@ class PermissionControllerRepository: def __init__(self, db: ExtendedAsyncSAEngine) -> None: self._db_source = PermissionDBSource(db) + @classmethod + def create(cls, args: RepositoryArgs) -> Self: + return cls( + db=args.db, + ) + @permission_controller_repository_resilience.apply() async def create_role(self, data: RoleCreateInput) -> RoleData: """ @@ -79,27 +86,20 @@ async def get_role(self, role_id: uuid.UUID) -> Optional[RoleData]: @permission_controller_repository_resilience.apply() async def check_permission_of_entity(self, data: SingleEntityPermissionCheckInput) -> bool: - target_object_id = data.target_object_id - roles = await self._db_source.get_user_roles(data.user_id) - associated_scopes = await self._db_source.get_entity_mapped_scopes(target_object_id) - associated_scopes_set = set([row.parsed_scope_id() for row in associated_scopes]) - for role in roles: - for object_perm in role.object_permission_rows: - if object_perm.operation != data.operation: - continue - if object_perm.object_id() == target_object_id: - return True - - for permission_group in role.permission_group_rows: - if permission_group.parsed_scope_id() not in associated_scopes_set: - continue - for permission in permission_group.permission_rows: - if permission.operation == data.operation: - return True - return False + """ + Check if the user has the requested operation permission on the given entity. + Returns True if the permission exists, False otherwise. + """ + return await self._db_source.check_object_permission_exist( + data.user_id, data.target_object_id, data.operation + ) @permission_controller_repository_resilience.apply() async def check_permission_in_scope(self, data: ScopePermissionCheckInput) -> bool: + """ + Check if the user has the requested operation permission in the given scope. + Returns True if the permission exists, False otherwise. + """ return await self._db_source.check_scope_permission_exist( data.user_id, data.target_scope_id, data.operation ) diff --git a/src/ai/backend/manager/repositories/repositories.py b/src/ai/backend/manager/repositories/repositories.py index 6af4bc6313e..0a417ba54f3 100644 --- a/src/ai/backend/manager/repositories/repositories.py +++ b/src/ai/backend/manager/repositories/repositories.py @@ -23,6 +23,9 @@ from ai.backend.manager.repositories.metric.repositories import MetricRepositories from ai.backend.manager.repositories.model_serving.repositories import ModelServingRepositories from ai.backend.manager.repositories.object_storage.repositories import ObjectStorageRepositories +from ai.backend.manager.repositories.permission_controller.repository import ( + PermissionControllerRepository, +) from ai.backend.manager.repositories.project_resource_policy.repositories import ( ProjectResourcePolicyRepositories, ) @@ -72,6 +75,7 @@ class Repositories: artifact: ArtifactRepositories artifact_registry: ArtifactRegistryRepositories storage_namespace: StorageNamespaceRepositories + permission_controller: PermissionControllerRepository @classmethod def create(cls, args: RepositoryArgs) -> Self: @@ -100,6 +104,7 @@ def create(cls, args: RepositoryArgs) -> Self: huggingface_registry_repositories = HuggingFaceRegistryRepositories.create(args) artifact_registries = ArtifactRegistryRepositories.create(args) storage_namespace_repositories = StorageNamespaceRepositories.create(args) + permission_controller_repository = PermissionControllerRepository.create(args) return cls( agent=agent_repositories, @@ -127,4 +132,5 @@ def create(cls, args: RepositoryArgs) -> Self: artifact=artifact_repositories, artifact_registry=artifact_registries, storage_namespace=storage_namespace_repositories, + permission_controller=permission_controller_repository, ) diff --git a/src/ai/backend/manager/server.py b/src/ai/backend/manager/server.py index 654bbd9d88a..da4438ab99a 100644 --- a/src/ai/backend/manager/server.py +++ b/src/ai/backend/manager/server.py @@ -653,6 +653,10 @@ async def processors_ctx(root_ctx: RootContext) -> AsyncIterator[None]: from .actions.monitors.audit_log import AuditLogMonitor from .actions.monitors.prometheus import PrometheusMonitor from .actions.monitors.reporter import ReporterMonitor + from .actions.validator.args import ValidatorArgs + from .actions.validators.rbac.batch import BatchActionRBACValidator + from .actions.validators.rbac.scope import ScopeActionRBACValidator + from .actions.validators.rbac.single_entity import SingleEntityActionRBACValidator from .reporters.hub import ReporterHub, ReporterHubArgs from .services.processors import ProcessorArgs, Processors, ServiceArgs @@ -666,6 +670,15 @@ async def processors_ctx(root_ctx: RootContext) -> AsyncIterator[None]: reporter_monitor = ReporterMonitor(reporter_hub) prometheus_monitor = PrometheusMonitor() audit_log_monitor = AuditLogMonitor(root_ctx.db) + batch_action_rbac_validator = BatchActionRBACValidator( + root_ctx.repositories.permission_controller + ) + single_entity_rbac_validator = SingleEntityActionRBACValidator( + root_ctx.repositories.permission_controller + ) + scope_action_rbac_validator = ScopeActionRBACValidator( + root_ctx.repositories.permission_controller + ) root_ctx.processors = Processors.create( ProcessorArgs( service_args=ServiceArgs( @@ -688,9 +701,14 @@ async def processors_ctx(root_ctx: RootContext) -> AsyncIterator[None]: deployment_controller=root_ctx.deployment_controller, event_producer=root_ctx.event_producer, agent_cache=root_ctx.agent_cache, - ) + ), + action_monitors=[reporter_monitor, prometheus_monitor, audit_log_monitor], + action_validator_args=ValidatorArgs( + batch=[batch_action_rbac_validator], + single_entity=[single_entity_rbac_validator], + scope=[scope_action_rbac_validator], + ), ), - [reporter_monitor, prometheus_monitor, audit_log_monitor], ) yield diff --git a/src/ai/backend/manager/services/processors.py b/src/ai/backend/manager/services/processors.py index 0174e12268f..9b759c39aaa 100644 --- a/src/ai/backend/manager/services/processors.py +++ b/src/ai/backend/manager/services/processors.py @@ -12,6 +12,7 @@ from ai.backend.common.plugin.monitor import ErrorPluginContext from ai.backend.manager.actions.monitors.monitor import ActionMonitor from ai.backend.manager.actions.types import AbstractProcessorPackage, ActionSpec +from ai.backend.manager.actions.validator.args import ValidatorArgs from ai.backend.manager.agent_cache import AgentRPCCache from ai.backend.manager.config.provider import ManagerConfigProvider from ai.backend.manager.idle import IdleCheckerHost @@ -318,6 +319,8 @@ def create(cls, args: ServiceArgs) -> Self: @dataclass class ProcessorArgs: service_args: ServiceArgs + action_monitors: list[ActionMonitor] + action_validator_args: ValidatorArgs @dataclass @@ -349,61 +352,65 @@ class Processors(AbstractProcessorPackage): storage_namespace: StorageNamespaceProcessors @classmethod - def create(cls, args: ProcessorArgs, action_monitors: list[ActionMonitor]) -> Self: + def create(cls, args: ProcessorArgs) -> Self: services = Services.create(args.service_args) - agent_processors = AgentProcessors(services.agent, action_monitors) - domain_processors = DomainProcessors(services.domain, action_monitors) - group_processors = GroupProcessors(services.group, action_monitors) - user_processors = UserProcessors(services.user, action_monitors) - image_processors = ImageProcessors(services.image, action_monitors) + agent_processors = AgentProcessors(services.agent, args.action_monitors) + domain_processors = DomainProcessors(services.domain, args.action_monitors) + group_processors = GroupProcessors(services.group, args.action_monitors) + user_processors = UserProcessors(services.user, args.action_monitors) + image_processors = ImageProcessors(services.image, args.action_monitors) container_registry_processors = ContainerRegistryProcessors( - services.container_registry, action_monitors + services.container_registry, args.action_monitors ) - vfolder_processors = VFolderProcessors(services.vfolder, action_monitors) - vfolder_file_processors = VFolderFileProcessors(services.vfolder_file, action_monitors) + vfolder_processors = VFolderProcessors( + services.vfolder, args.action_monitors, args.action_validator_args + ) + vfolder_file_processors = VFolderFileProcessors(services.vfolder_file, args.action_monitors) vfolder_invite_processors = VFolderInviteProcessors( - services.vfolder_invite, action_monitors + services.vfolder_invite, args.action_monitors ) - session_processors = SessionProcessors(services.session, action_monitors) + session_processors = SessionProcessors(services.session, args.action_monitors) keypair_resource_policy_processors = KeypairResourcePolicyProcessors( - services.keypair_resource_policy, action_monitors + services.keypair_resource_policy, args.action_monitors ) user_resource_policy_processors = UserResourcePolicyProcessors( - services.user_resource_policy, action_monitors + services.user_resource_policy, args.action_monitors ) project_resource_policy_processors = ProjectResourcePolicyProcessors( - services.project_resource_policy, action_monitors + services.project_resource_policy, args.action_monitors ) resource_preset_processors = ResourcePresetProcessors( - services.resource_preset, action_monitors + services.resource_preset, args.action_monitors + ) + model_serving_processors = ModelServingProcessors( + services.model_serving, args.action_monitors ) - model_serving_processors = ModelServingProcessors(services.model_serving, action_monitors) model_serving_auto_scaling_processors = ModelServingAutoScalingProcessors( - services.model_serving_auto_scaling, action_monitors + services.model_serving_auto_scaling, args.action_monitors ) utilization_metric_processors = UtilizationMetricProcessors( - services.utilization_metric, action_monitors + services.utilization_metric, args.action_monitors ) - auth = AuthProcessors(services.auth, action_monitors) + auth = AuthProcessors(services.auth, args.action_monitors) object_storage_processors = ObjectStorageProcessors( - services.object_storage, action_monitors + services.object_storage, args.action_monitors ) - vfs_storage_processors = VFSStorageProcessors(services.vfs_storage, action_monitors) - artifact_processors = ArtifactProcessors(services.artifact, action_monitors) + vfs_storage_processors = VFSStorageProcessors(services.vfs_storage, args.action_monitors) + artifact_processors = ArtifactProcessors(services.artifact, args.action_monitors) artifact_registry_processors = ArtifactRegistryProcessors( - services.artifact_registry, action_monitors + services.artifact_registry, args.action_monitors ) artifact_revision_processors = ArtifactRevisionProcessors( - services.artifact_revision, action_monitors + services.artifact_revision, args.action_monitors ) # Initialize deployment processors if service is available deployment_processors = None if services.deployment is not None: - deployment_processors = DeploymentProcessors(services.deployment, action_monitors) + deployment_processors = DeploymentProcessors(services.deployment, args.action_monitors) storage_namespace_processors = StorageNamespaceProcessors( - services.storage_namespace, action_monitors + services.storage_namespace, args.action_monitors ) return cls( diff --git a/src/ai/backend/manager/services/vfolder/processors/vfolder.py b/src/ai/backend/manager/services/vfolder/processors/vfolder.py index e4d3014b901..b4252b57a35 100644 --- a/src/ai/backend/manager/services/vfolder/processors/vfolder.py +++ b/src/ai/backend/manager/services/vfolder/processors/vfolder.py @@ -4,6 +4,7 @@ from ai.backend.manager.actions.processor.scope import ScopeActionProcessor from ai.backend.manager.actions.processor.single_entity import SingleEntityActionProcessor from ai.backend.manager.actions.types import AbstractProcessorPackage, ActionSpec +from ai.backend.manager.actions.validator.args import ValidatorArgs from ..actions.base import ( CloneVFolderAction, @@ -52,27 +53,42 @@ class VFolderProcessors(AbstractProcessorPackage): clone_vfolder: SingleEntityActionProcessor[CloneVFolderAction, CloneVFolderActionResult] get_task_logs: SingleEntityActionProcessor[GetTaskLogsAction, GetTaskLogsActionResult] - def __init__(self, service: VFolderService, action_monitors: list[ActionMonitor]): - self.create_vfolder = ScopeActionProcessor(service.create, action_monitors) - self.get_vfolder = SingleEntityActionProcessor(service.get, action_monitors) - self.list_vfolder = ScopeActionProcessor(service.list, action_monitors) + def __init__( + self, + service: VFolderService, + action_monitors: list[ActionMonitor], + action_validators: ValidatorArgs, + ) -> None: + self.create_vfolder = ScopeActionProcessor( + service.create, action_monitors, action_validators.scope + ) + self.get_vfolder = SingleEntityActionProcessor( + service.get, action_monitors, action_validators.single_entity + ) + self.list_vfolder = ScopeActionProcessor( + service.list, action_monitors, action_validators.scope + ) self.update_vfolder_attribute = SingleEntityActionProcessor( - service.update_attribute, action_monitors + service.update_attribute, action_monitors, action_validators.single_entity ) self.move_to_trash_vfolder = SingleEntityActionProcessor( - service.move_to_trash, action_monitors + service.move_to_trash, action_monitors, action_validators.single_entity ) self.restore_vfolder_from_trash = SingleEntityActionProcessor( - service.restore, action_monitors + service.restore, action_monitors, action_validators.single_entity ) self.delete_forever_vfolder = SingleEntityActionProcessor( - service.delete_forever, action_monitors + service.delete_forever, action_monitors, action_validators.single_entity ) self.force_delete_vfolder = SingleEntityActionProcessor( - service.force_delete, action_monitors + service.force_delete, action_monitors, action_validators.single_entity + ) + self.clone_vfolder = SingleEntityActionProcessor( + service.clone, action_monitors, action_validators.single_entity + ) + self.get_task_logs = SingleEntityActionProcessor( + service.get_task_logs, action_monitors, action_validators.single_entity ) - self.clone_vfolder = SingleEntityActionProcessor(service.clone, action_monitors) - self.get_task_logs = SingleEntityActionProcessor(service.get_task_logs, action_monitors) @override def supported_actions(self) -> list[ActionSpec]: