Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changes/6258.feature.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Implement `artifact_verifier` type plugin in storage-proxy
70 changes: 70 additions & 0 deletions src/ai/backend/common/artifact_storage.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any

from ai.backend.common.data.storage.registries.types import ModelTarget
from ai.backend.common.data.storage.types import ArtifactStorageImportStep
from ai.backend.common.types import StreamReader


class AbstractStorage(ABC):
@abstractmethod
async def stream_upload(
self,
filepath: str,
data_stream: StreamReader,
) -> None:
raise NotImplementedError

@abstractmethod
async def stream_download(self, filepath: str) -> StreamReader:
raise NotImplementedError

@abstractmethod
async def delete_file(self, filepath: str) -> None:
raise NotImplementedError

@abstractmethod
# TODO: Remove Any and define a proper return type
async def get_file_info(self, filepath: str) -> Any:
raise NotImplementedError


class AbstractStoragePool(ABC):
"""Abstract base class for storage pool interface"""

@abstractmethod
def get_storage(self, name: str) -> AbstractStorage:
"""Get storage by name"""
raise NotImplementedError

@abstractmethod
def add_storage(self, name: str, storage: AbstractStorage) -> None:
"""Add a storage to the pool"""
raise NotImplementedError

@abstractmethod
def remove_storage(self, name: str) -> None:
"""Remove a storage from the pool"""
raise NotImplementedError

@abstractmethod
def list_storages(self) -> list[str]:
"""List all storage names in the pool"""
raise NotImplementedError

@abstractmethod
def has_storage(self, name: str) -> bool:
"""Check if storage exists in the pool"""
raise NotImplementedError


@dataclass
class ImportStepContext:
"""Context shared across import steps"""

model: ModelTarget
registry_name: str
storage_pool: AbstractStoragePool
storage_step_mappings: dict[ArtifactStorageImportStep, str]
step_metadata: dict[str, Any] # For passing data between steps
37 changes: 37 additions & 0 deletions src/ai/backend/common/events/event_types/artifact/anycast.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,43 @@ class ModelMetadataInfo:
size: int


@dataclass
class ModelVerifyingEvent(BaseArtifactEvent):
"""
Mark the model revision's status to verifying.
"""

model_id: str
revision: str
registry_type: ArtifactRegistryType
registry_name: str

@classmethod
@override
def event_name(cls) -> str:
return "model_verifying"

def serialize(self) -> tuple:
return (self.model_id, self.revision, self.registry_type, self.registry_name)

@classmethod
def deserialize(cls, value: tuple):
return cls(
model_id=value[0],
revision=value[1],
registry_type=value[2],
registry_name=value[3],
)

@override
def domain_id(self) -> Optional[str]:
return None

@override
def user_event(self) -> Optional[UserEvent]:
return None


@dataclass
class ModelImportDoneEvent(BaseArtifactEvent):
model_id: str
Expand Down
6 changes: 6 additions & 0 deletions src/ai/backend/manager/event_dispatcher/dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from ai.backend.common.events.event_types.artifact.anycast import (
ModelImportDoneEvent,
ModelMetadataFetchDoneEvent,
ModelVerifyingEvent,
)
from ai.backend.common.events.event_types.artifact_registry.anycast import (
DoPullReservoirRegistryEvent,
Expand Down Expand Up @@ -585,6 +586,11 @@ def _dispatch_artifact_events(self, event_dispatcher: EventDispatcher) -> None:
None,
self._artifact_event_handler.handle_model_metadata_fetch_done,
)
evd.consume(
ModelVerifyingEvent,
None,
self._artifact_event_handler.handle_model_verifying,
)
evd.consume(
ModelImportDoneEvent,
None,
Expand Down
42 changes: 42 additions & 0 deletions src/ai/backend/manager/event_dispatcher/handlers/artifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from ai.backend.common.events.event_types.artifact.anycast import (
ModelImportDoneEvent,
ModelMetadataFetchDoneEvent,
ModelVerifyingEvent,
)
from ai.backend.common.types import (
AgentId,
Expand Down Expand Up @@ -41,6 +42,47 @@ def __init__(
self._reservoir_repository = reservoir_repository
self._config_provider = config_provider

async def handle_model_verifying(
self,
context: None,
source: AgentId,
event: ModelVerifyingEvent,
) -> None:
try:
registry_type = ArtifactRegistryType(event.registry_type)
except Exception:
raise InvalidArtifactRegistryTypeError(
f"Unsupported artifact registry type: {event.registry_type}"
)
registry_id: UUID
match registry_type:
case ArtifactRegistryType.HUGGINGFACE:
huggingface_registry_data = (
await self._huggingface_repository.get_registry_data_by_name(
event.registry_name
)
)
registry_id = huggingface_registry_data.id
case ArtifactRegistryType.RESERVOIR:
registry_data = await self._reservoir_repository.get_registry_data_by_name(
event.registry_name
)
registry_id = registry_data.id

artifact = await self._artifact_repository.get_model_artifact(
event.model_id, registry_id=registry_id
)

# Get the specific revision
revision = await self._artifact_repository.get_artifact_revision(
artifact.id, revision=event.revision
)

if revision.status == ArtifactStatus.PULLING:
await self._artifact_repository.update_artifact_revision_status(
revision.id, ArtifactStatus.VERIFYING
)

async def handle_model_import_done(
self,
context: None,
Expand Down
3 changes: 3 additions & 0 deletions src/ai/backend/storage/api/v1/registries/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,8 @@ async def import_models(
registry_configs=self._huggingface_service._registry_configs,
transfer_manager=self._huggingface_service._transfer_manager,
storage_step_mappings=body.parsed.storage_step_mappings,
artifact_verifier_ctx=self._huggingface_service._artifact_verifier_ctx,
event_producer=self._huggingface_service._event_producer,
)

task_id = await self._huggingface_service.import_models_batch(
Expand Down Expand Up @@ -217,6 +219,7 @@ def create_app(ctx: RootContext) -> web.Application:
storage_pool=ctx.storage_pool,
registry_configs=huggingface_registry_configs,
event_producer=ctx.event_producer,
artifact_verifier_ctx=ctx.artifact_verifier_ctx,
)
)
huggingface_api_handler = HuggingFaceRegistryAPIHandler(huggingface_service=huggingface_service)
Expand Down
3 changes: 3 additions & 0 deletions src/ai/backend/storage/api/v1/registries/reservoir.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ async def import_models(
registry_configs=self._reservoir_service._reservoir_registry_configs,
storage_step_mappings=body.parsed.storage_step_mappings,
transfer_manager=self._reservoir_service._transfer_manager,
artifact_verifier_ctx=self._reservoir_service._artifact_verifier_ctx,
event_producer=self._reservoir_service._event_producer,
)

task_id = await self._reservoir_service.import_models_batch(
Expand Down Expand Up @@ -100,6 +102,7 @@ def create_app(ctx: RootContext) -> web.Application:
event_producer=ctx.event_producer,
storage_pool=ctx.storage_pool,
reservoir_registry_configs=reservoir_registry_configs,
artifact_verifier_ctx=ctx.artifact_verifier_ctx,
)
)
reservoir_api_handler = ReservoirRegistryAPIHandler(
Expand Down
14 changes: 14 additions & 0 deletions src/ai/backend/storage/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,11 @@
from ai.backend.logging import BraceStyleAdapter

from .config.unified import StorageProxyUnifiedConfig
from .context_types import ArtifactVerifierContext
from .exception import InvalidVolumeError
from .plugin import (
StorageArtifactVerifierPluginContext,
)
from .services.service import VolumeService
from .storages.storage_pool import StoragePool
from .types import VolumeInfo
Expand Down Expand Up @@ -96,6 +100,16 @@ class RootContext:
# volume backend states
backends: MutableMapping[str, type[AbstractVolume]]
volumes: MutableMapping[str, AbstractVolume]
artifact_verifier_ctx: ArtifactVerifierContext

async def init_storage_artifact_verifier_plugin(self) -> None:
plugin_ctx = StorageArtifactVerifierPluginContext(self.etcd, self.local_config.model_dump())
await plugin_ctx.init()
plugins = {}
for plugin_name, plugin_instance in plugin_ctx.plugins.items():
log.info("Loading artifact verifier storage plugin: {0}", plugin_name)
plugins[plugin_name] = plugin_instance
self.artifact_verifier_ctx.load_verifiers(plugins)

def list_volumes(self) -> Mapping[str, VolumeInfo]:
return {name: info.to_dataclass() for name, info in self.local_config.volume.items()}
Expand Down
11 changes: 11 additions & 0 deletions src/ai/backend/storage/context_types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from ai.backend.storage.plugin import AbstractArtifactVerifierPlugin


class ArtifactVerifierContext:
_verifiers: dict[str, AbstractArtifactVerifierPlugin]

def __init__(self) -> None:
self._verifiers = {}

def load_verifiers(self, verifiers: dict[str, AbstractArtifactVerifierPlugin]) -> None:
self._verifiers.update(verifiers)
26 changes: 26 additions & 0 deletions src/ai/backend/storage/exception.py
Original file line number Diff line number Diff line change
Expand Up @@ -456,3 +456,29 @@ def error_code(cls) -> ErrorCode:
operation=ErrorOperation.GENERIC,
error_detail=ErrorDetail.BAD_REQUEST,
)


class ArtifactVerifyStorageTypeInvalid(BackendAIError, web.HTTPBadRequest):
error_type = "https://api.backend.ai/probs/storage/artifact/verify/storage-type/invalid"
error_title = "Artifact Verify Storage Type Invalid"

@classmethod
def error_code(cls) -> ErrorCode:
return ErrorCode(
domain=ErrorDomain.STORAGE_PROXY,
operation=ErrorOperation.GENERIC,
error_detail=ErrorDetail.BAD_REQUEST,
)


class ArtifactVerificationFailedError(BackendAIError, web.HTTPBadRequest):
error_type = "https://api.backend.ai/probs/storage/artifact/verification/failed"
error_title = "Artifact Verification Failed"

@classmethod
def error_code(cls) -> ErrorCode:
return ErrorCode(
domain=ErrorDomain.STORAGE_PROXY,
operation=ErrorOperation.GENERIC,
error_detail=ErrorDetail.BAD_REQUEST,
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure if it's correct to set the status code for this error to 400.

)
2 changes: 2 additions & 0 deletions src/ai/backend/storage/migration.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from .config.loaders import load_local_config, make_etcd
from .config.unified import StorageProxyUnifiedConfig
from .context import DEFAULT_BACKENDS, EVENT_DISPATCHER_CONSUMER_GROUP, RootContext
from .context_types import ArtifactVerifierContext
from .types import VFolderID
from .volumes.abc import CAP_FAST_SIZE, AbstractVolume

Expand Down Expand Up @@ -268,6 +269,7 @@ async def check_and_upgrade(
volume_pool=None, # type: ignore[arg-type]
storage_pool=None, # type: ignore[arg-type]
background_task_manager=None, # type: ignore[arg-type]
artifact_verifier_ctx=ArtifactVerifierContext(), # type: ignore[arg-type]
metric_registry=CommonMetricRegistry(),
cors_options={},
backends={**DEFAULT_BACKENDS},
Expand Down
29 changes: 29 additions & 0 deletions src/ai/backend/storage/plugin.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
from __future__ import annotations

from abc import ABCMeta, abstractmethod
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING, Iterator, Optional

from aiohttp import web

from ai.backend.common.plugin import AbstractPlugin, BasePluginContext
from ai.backend.storage.services.artifacts.types import ImportStepContext

if TYPE_CHECKING:
from .api.types import CORSOptions, WebMiddleware
Expand All @@ -20,6 +23,18 @@ def get_volume_class(
raise NotImplementedError


@dataclass
class VerificationResult:
scanned_count: int
infected_count: int


class AbstractArtifactVerifierPlugin(AbstractPlugin, metaclass=ABCMeta):
@abstractmethod
async def verify(self, artifact_path: Path, context: ImportStepContext) -> VerificationResult:
raise NotImplementedError


class StoragePluginContext(BasePluginContext[AbstractStoragePlugin]):
plugin_group = "backendai_storage_v10"

Expand All @@ -34,6 +49,20 @@ def discover_plugins(
yield from scanned_plugins


class StorageArtifactVerifierPluginContext(BasePluginContext[AbstractArtifactVerifierPlugin]):
plugin_group = "backendai_storage_artifact_verifier_v1"

@classmethod
def discover_plugins(
cls,
plugin_group: str,
allowlist: Optional[set[str]] = None,
blocklist: Optional[set[str]] = None,
) -> Iterator[tuple[str, type[AbstractArtifactVerifierPlugin]]]:
scanned_plugins = [*super().discover_plugins(plugin_group, allowlist, blocklist)]
yield from scanned_plugins


class StorageManagerWebappPlugin(AbstractPlugin, metaclass=ABCMeta):
@abstractmethod
async def create_app(
Expand Down
2 changes: 2 additions & 0 deletions src/ai/backend/storage/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@
from ai.backend.common.utils import env_info
from ai.backend.logging import BraceStyleAdapter, Logger, LogLevel
from ai.backend.logging.otel import OpenTelemetrySpec
from ai.backend.storage.context_types import ArtifactVerifierContext

from . import __version__ as VERSION
from .config.loaders import load_local_config, make_etcd
Expand Down Expand Up @@ -614,6 +615,7 @@ async def server_main(
volumes={
NOOP_STORAGE_VOLUME_NAME: init_noop_volume(etcd, event_dispatcher, event_producer)
},
artifact_verifier_ctx=ArtifactVerifierContext(),
)
if pidx == 0:
await check_latest(root_ctx)
Expand Down
Loading
Loading