3030from dataclasses import dataclass
3131from decimal import Decimal
3232from io import SEEK_END , BytesIO
33+ from itertools import chain
3334from pathlib import Path
3435from types import TracebackType
3536from typing import (
174175from ai .backend .common .types import (
175176 MODEL_SERVICE_RUNTIME_PROFILES ,
176177 AbuseReportValue ,
177- AcceleratorMetadata ,
178178 AgentId ,
179179 AutoPullBehavior ,
180180 BinarySize ,
233233from .observer .heartbeat import HeartbeatObserver
234234from .observer .host_port import HostPortObserver
235235from .resources import (
236- AbstractComputeDevice ,
237236 AbstractComputePlugin ,
238237 ComputerContext ,
239238 KernelResourceSpec ,
240239 Mount ,
240+ ResourcePartitioner ,
241241 align_memory ,
242242 allocate ,
243243 known_slot_types ,
@@ -765,7 +765,10 @@ class AbstractAgent(
765765 etcd : AsyncEtcd
766766 local_instance_id : str
767767 kernel_registry : MutableMapping [KernelId , AbstractKernel ]
768+ resource_partitioner : ResourcePartitioner
768769 computers : MutableMapping [DeviceName , ComputerContext ]
770+ total_slots : Mapping [SlotName , Decimal ]
771+ reserved_slots : Mapping [SlotName , Decimal ]
769772 images : Mapping [ImageCanonical , ScannedImage ]
770773 port_pool : set [int ]
771774
@@ -836,6 +839,7 @@ def __init__(
836839 error_monitor : ErrorPluginContext ,
837840 skip_initial_scan : bool = False ,
838841 agent_public_key : Optional [PublicKey ],
842+ resource_partitioner : ResourcePartitioner ,
839843 ) -> None :
840844 self ._skip_initial_scan = skip_initial_scan
841845 self .loop = current_loop ()
@@ -845,7 +849,10 @@ def __init__(
845849 self .local_instance_id = generate_local_instance_id (__file__ )
846850 self .agent_public_key = agent_public_key
847851 self .kernel_registry = {}
852+ self .resource_partitioner = resource_partitioner
848853 self .computers = {}
854+ self .total_slots = {}
855+ self .reserved_slots = {}
849856 self .images = {}
850857 self .restarting_kernels = {}
851858 self .stat_ctx = StatContext (
@@ -934,14 +941,17 @@ async def __ainit__(self) -> None:
934941 alloc_map_mod .log_alloc_map = self .local_config .debug .log_alloc_map
935942 computers = await self .load_resources ()
936943
937- all_devices : list [AbstractComputeDevice ] = []
938- metadatas : list [AcceleratorMetadata ] = []
939944 for name , computer in computers .items ():
940945 devices = await computer .list_devices ()
941- all_devices .extend (devices )
942946 alloc_map = await computer .create_alloc_map ()
943947 self .computers [name ] = ComputerContext (computer , devices , alloc_map )
944- metadatas .append (computer .get_metadata ())
948+
949+ self .total_slots = self .resource_partitioner .calculate_total_slots (
950+ self .computers , self .local_config .resource
951+ )
952+ self .reserved_slots = self .resource_partitioner .restrict_computer_resources (
953+ self .computers , self .total_slots
954+ )
945955
946956 self .slots = await self .update_slots ()
947957 log .info ("Resource slots: {!r}" , self .slots )
@@ -950,12 +960,16 @@ async def __ainit__(self) -> None:
950960
951961 # Use ValkeyStatClient batch operations for better performance
952962 field_value_map = {}
953- for metadata in metadatas :
963+ for computer_ctx in self .computers .values ():
964+ metadata = computer_ctx .instance .get_metadata ()
954965 field_value_map [metadata ["slot_name" ]] = dump_json_str (metadata ).encode ()
955966
956967 if field_value_map :
957968 await self .valkey_stat_client .store_computer_metadata (field_value_map )
958969
970+ all_devices = list (
971+ chain .from_iterable (computer .devices for computer in self .computers .values ())
972+ )
959973 self .affinity_map = AffinityMap .build (all_devices )
960974
961975 if not self ._skip_initial_scan :
@@ -1949,6 +1963,7 @@ async def load_resources(
19491963 """
19501964 Detect available resources attached on the system and load corresponding device plugin.
19511965 """
1966+ raise NotImplementedError
19521967
19531968 @abstractmethod
19541969 async def scan_available_resources (
@@ -1957,6 +1972,7 @@ async def scan_available_resources(
19571972 """
19581973 Scan and define the amount of available resource slots in this node.
19591974 """
1975+ raise NotImplementedError
19601976
19611977 async def update_slots (
19621978 self ,
@@ -1967,14 +1983,9 @@ async def update_slots(
19671983 """
19681984 scanned_slots = await self .scan_available_resources ()
19691985 usable_slots : dict [SlotName , Decimal ] = {}
1970- reserved_slots = {
1971- SlotName ("cpu" ): Decimal (self .local_config .resource .reserved_cpu ),
1972- SlotName ("mem" ): Decimal (self .local_config .resource .reserved_mem ),
1973- SlotName ("disk" ): Decimal (self .local_config .resource .reserved_disk ),
1974- }
19751986 for slot_name , slot_capacity in scanned_slots .items ():
19761987 if slot_name == SlotName ("mem" ):
1977- mem_reserved = int (reserved_slots .get (slot_name , 0 ))
1988+ mem_reserved = int (self . reserved_slots .get (slot_name , 0 ))
19781989 mem_align = int (self .local_config .resource .memory_align_size )
19791990 mem_usable , mem_reserved = align_memory (
19801991 int (slot_capacity ), mem_reserved , align = mem_align
@@ -1988,7 +1999,7 @@ async def update_slots(
19881999 )
19892000 else :
19902001 usable_capacity = max (
1991- Decimal (0 ), slot_capacity - reserved_slots .get (slot_name , Decimal (0 ))
2002+ Decimal (0 ), slot_capacity - self . reserved_slots .get (slot_name , Decimal (0 ))
19922003 )
19932004 usable_slots [slot_name ] = usable_capacity
19942005 return usable_slots
@@ -2100,6 +2111,7 @@ async def scan_images(self) -> ScanImagesResult:
21002111 This is called periodically to keep the image list up-to-date and allow
21012112 manual image addition and deletions by admins.
21022113 """
2114+ raise NotImplementedError
21032115
21042116 async def _scan_images_wrapper (self , interval : float ) -> None :
21052117 result = await self .scan_images ()
@@ -2120,6 +2132,7 @@ async def push_image(
21202132 """
21212133 Push the given image to the given registry.
21222134 """
2135+ raise NotImplementedError
21232136
21242137 @abstractmethod
21252138 async def pull_image (
@@ -2132,12 +2145,14 @@ async def pull_image(
21322145 """
21332146 Pull the given image from the given registry.
21342147 """
2148+ raise NotImplementedError
21352149
21362150 @abstractmethod
21372151 async def purge_images (self , request : PurgeImagesReq ) -> PurgeImagesResp :
21382152 """
21392153 Purge the given images from the agent.
21402154 """
2155+ raise NotImplementedError
21412156
21422157 async def check_and_pull (
21432158 self ,
@@ -2269,7 +2284,7 @@ async def check_image(
22692284 Check the availability of the image and return a boolean flag that indicates whether
22702285 the agent should try pulling the image from a registry.
22712286 """
2272- return False
2287+ raise NotImplementedError
22732288
22742289 async def scan_running_kernels (self ) -> None :
22752290 """
@@ -3491,6 +3506,7 @@ async def destroy_kernel(
34913506 * Send SIGTERM to the kernel's main process.
34923507 * Send SIGKILL if it's not terminated within a few seconds.
34933508 """
3509+ raise NotImplementedError
34943510
34953511 @abstractmethod
34963512 async def clean_kernel (
@@ -3514,6 +3530,7 @@ async def clean_kernel(
35143530 The ``container_id`` may be ``None`` if the container has already gone away.
35153531 In such cases, skip container-specific cleanups.
35163532 """
3533+ raise NotImplementedError
35173534
35183535 @abstractmethod
35193536 async def create_local_network (self , network_name : str ) -> None :
@@ -3525,6 +3542,7 @@ async def create_local_network(self, network_name: str) -> None:
35253542 It may raise :exc:`NotImplementedError` and then the manager
35263543 will cancel creation of the session.
35273544 """
3545+ raise NotImplementedError
35283546
35293547 @abstractmethod
35303548 async def destroy_local_network (self , network_name : str ) -> None :
@@ -3533,6 +3551,7 @@ async def destroy_local_network(self, network_name: str) -> None:
35333551
35343552 This is called by the manager after kernel destruction.
35353553 """
3554+ raise NotImplementedError
35363555
35373556 @abstractmethod
35383557 async def restart_kernel__load_config (
@@ -3543,7 +3562,7 @@ async def restart_kernel__load_config(
35433562 """
35443563 Restore the cluster config from a previous launch of the kernel.
35453564 """
3546- pass
3565+ raise NotImplementedError
35473566
35483567 @abstractmethod
35493568 async def restart_kernel__store_config (
@@ -3556,7 +3575,7 @@ async def restart_kernel__store_config(
35563575 Store the cluster config to a kernel-related storage (e.g., scratch space),
35573576 so that restarts of this kernel can reuse the configuration.
35583577 """
3559- pass
3578+ raise NotImplementedError
35603579
35613580 async def restart_kernel (
35623581 self ,
0 commit comments