Skip to content

Commit e6c1f4b

Browse files
committed
feat(BA-2851): Add resource isolation options for multi-agent setup
This change adds configuration for partitioning resources rather than every agent always seeing the full resource pool. This prevents unintended over-allocation that could crash kernels. SHARED mode allows all agents to see full resources (useful for stress testing). This is the same behavior as before. AUTO_SPLIT automatically divides resources equally among agents. MANUAL mode lets users specify exact per-agent allocations for all resources. Single-agent deployments remain unaffected and retain access to all available hardware resources.
1 parent 9f12687 commit e6c1f4b

File tree

9 files changed

+1386
-44
lines changed

9 files changed

+1386
-44
lines changed

src/ai/backend/agent/agent.py

Lines changed: 36 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from dataclasses import dataclass
3131
from decimal import Decimal
3232
from io import SEEK_END, BytesIO
33+
from itertools import chain
3334
from pathlib import Path
3435
from types import TracebackType
3536
from typing import (
@@ -174,7 +175,6 @@
174175
from ai.backend.common.types import (
175176
MODEL_SERVICE_RUNTIME_PROFILES,
176177
AbuseReportValue,
177-
AcceleratorMetadata,
178178
AgentId,
179179
AutoPullBehavior,
180180
BinarySize,
@@ -233,11 +233,11 @@
233233
from .observer.heartbeat import HeartbeatObserver
234234
from .observer.host_port import HostPortObserver
235235
from .resources import (
236-
AbstractComputeDevice,
237236
AbstractComputePlugin,
238237
ComputerContext,
239238
KernelResourceSpec,
240239
Mount,
240+
ResourcePartitioner,
241241
align_memory,
242242
allocate,
243243
known_slot_types,
@@ -765,7 +765,10 @@ class AbstractAgent(
765765
etcd: AsyncEtcd
766766
local_instance_id: str
767767
kernel_registry: MutableMapping[KernelId, AbstractKernel]
768+
resource_partitioner: ResourcePartitioner
768769
computers: MutableMapping[DeviceName, ComputerContext]
770+
total_slots: Mapping[SlotName, Decimal]
771+
reserved_slots: Mapping[SlotName, Decimal]
769772
images: Mapping[ImageCanonical, ScannedImage]
770773
port_pool: set[int]
771774

@@ -836,6 +839,7 @@ def __init__(
836839
error_monitor: ErrorPluginContext,
837840
skip_initial_scan: bool = False,
838841
agent_public_key: Optional[PublicKey],
842+
resource_partitioner: ResourcePartitioner,
839843
) -> None:
840844
self._skip_initial_scan = skip_initial_scan
841845
self.loop = current_loop()
@@ -845,7 +849,10 @@ def __init__(
845849
self.local_instance_id = generate_local_instance_id(__file__)
846850
self.agent_public_key = agent_public_key
847851
self.kernel_registry = {}
852+
self.resource_partitioner = resource_partitioner
848853
self.computers = {}
854+
self.total_slots = {}
855+
self.reserved_slots = {}
849856
self.images = {}
850857
self.restarting_kernels = {}
851858
self.stat_ctx = StatContext(
@@ -934,14 +941,17 @@ async def __ainit__(self) -> None:
934941
alloc_map_mod.log_alloc_map = self.local_config.debug.log_alloc_map
935942
computers = await self.load_resources()
936943

937-
all_devices: list[AbstractComputeDevice] = []
938-
metadatas: list[AcceleratorMetadata] = []
939944
for name, computer in computers.items():
940945
devices = await computer.list_devices()
941-
all_devices.extend(devices)
942946
alloc_map = await computer.create_alloc_map()
943947
self.computers[name] = ComputerContext(computer, devices, alloc_map)
944-
metadatas.append(computer.get_metadata())
948+
949+
self.total_slots = self.resource_partitioner.calculate_total_slots(
950+
self.computers, self.local_config.resource
951+
)
952+
self.reserved_slots = self.resource_partitioner.restrict_computer_resources(
953+
self.computers, self.total_slots
954+
)
945955

946956
self.slots = await self.update_slots()
947957
log.info("Resource slots: {!r}", self.slots)
@@ -950,12 +960,16 @@ async def __ainit__(self) -> None:
950960

951961
# Use ValkeyStatClient batch operations for better performance
952962
field_value_map = {}
953-
for metadata in metadatas:
963+
for computer_ctx in self.computers.values():
964+
metadata = computer_ctx.instance.get_metadata()
954965
field_value_map[metadata["slot_name"]] = dump_json_str(metadata).encode()
955966

956967
if field_value_map:
957968
await self.valkey_stat_client.store_computer_metadata(field_value_map)
958969

970+
all_devices = list(
971+
chain.from_iterable(computer.devices for computer in self.computers.values())
972+
)
959973
self.affinity_map = AffinityMap.build(all_devices)
960974

961975
if not self._skip_initial_scan:
@@ -1949,6 +1963,7 @@ async def load_resources(
19491963
"""
19501964
Detect available resources attached on the system and load corresponding device plugin.
19511965
"""
1966+
raise NotImplementedError
19521967

19531968
@abstractmethod
19541969
async def scan_available_resources(
@@ -1957,6 +1972,7 @@ async def scan_available_resources(
19571972
"""
19581973
Scan and define the amount of available resource slots in this node.
19591974
"""
1975+
raise NotImplementedError
19601976

19611977
async def update_slots(
19621978
self,
@@ -1967,14 +1983,9 @@ async def update_slots(
19671983
"""
19681984
scanned_slots = await self.scan_available_resources()
19691985
usable_slots: dict[SlotName, Decimal] = {}
1970-
reserved_slots = {
1971-
SlotName("cpu"): Decimal(self.local_config.resource.reserved_cpu),
1972-
SlotName("mem"): Decimal(self.local_config.resource.reserved_mem),
1973-
SlotName("disk"): Decimal(self.local_config.resource.reserved_disk),
1974-
}
19751986
for slot_name, slot_capacity in scanned_slots.items():
19761987
if slot_name == SlotName("mem"):
1977-
mem_reserved = int(reserved_slots.get(slot_name, 0))
1988+
mem_reserved = int(self.reserved_slots.get(slot_name, 0))
19781989
mem_align = int(self.local_config.resource.memory_align_size)
19791990
mem_usable, mem_reserved = align_memory(
19801991
int(slot_capacity), mem_reserved, align=mem_align
@@ -1988,7 +1999,7 @@ async def update_slots(
19881999
)
19892000
else:
19902001
usable_capacity = max(
1991-
Decimal(0), slot_capacity - reserved_slots.get(slot_name, Decimal(0))
2002+
Decimal(0), slot_capacity - self.reserved_slots.get(slot_name, Decimal(0))
19922003
)
19932004
usable_slots[slot_name] = usable_capacity
19942005
return usable_slots
@@ -2100,6 +2111,7 @@ async def scan_images(self) -> ScanImagesResult:
21002111
This is called periodically to keep the image list up-to-date and allow
21012112
manual image addition and deletions by admins.
21022113
"""
2114+
raise NotImplementedError
21032115

21042116
async def _scan_images_wrapper(self, interval: float) -> None:
21052117
result = await self.scan_images()
@@ -2120,6 +2132,7 @@ async def push_image(
21202132
"""
21212133
Push the given image to the given registry.
21222134
"""
2135+
raise NotImplementedError
21232136

21242137
@abstractmethod
21252138
async def pull_image(
@@ -2132,12 +2145,14 @@ async def pull_image(
21322145
"""
21332146
Pull the given image from the given registry.
21342147
"""
2148+
raise NotImplementedError
21352149

21362150
@abstractmethod
21372151
async def purge_images(self, request: PurgeImagesReq) -> PurgeImagesResp:
21382152
"""
21392153
Purge the given images from the agent.
21402154
"""
2155+
raise NotImplementedError
21412156

21422157
async def check_and_pull(
21432158
self,
@@ -2269,7 +2284,7 @@ async def check_image(
22692284
Check the availability of the image and return a boolean flag that indicates whether
22702285
the agent should try pulling the image from a registry.
22712286
"""
2272-
return False
2287+
raise NotImplementedError
22732288

22742289
async def scan_running_kernels(self) -> None:
22752290
"""
@@ -3491,6 +3506,7 @@ async def destroy_kernel(
34913506
* Send SIGTERM to the kernel's main process.
34923507
* Send SIGKILL if it's not terminated within a few seconds.
34933508
"""
3509+
raise NotImplementedError
34943510

34953511
@abstractmethod
34963512
async def clean_kernel(
@@ -3514,6 +3530,7 @@ async def clean_kernel(
35143530
The ``container_id`` may be ``None`` if the container has already gone away.
35153531
In such cases, skip container-specific cleanups.
35163532
"""
3533+
raise NotImplementedError
35173534

35183535
@abstractmethod
35193536
async def create_local_network(self, network_name: str) -> None:
@@ -3525,6 +3542,7 @@ async def create_local_network(self, network_name: str) -> None:
35253542
It may raise :exc:`NotImplementedError` and then the manager
35263543
will cancel creation of the session.
35273544
"""
3545+
raise NotImplementedError
35283546

35293547
@abstractmethod
35303548
async def destroy_local_network(self, network_name: str) -> None:
@@ -3533,6 +3551,7 @@ async def destroy_local_network(self, network_name: str) -> None:
35333551
35343552
This is called by the manager after kernel destruction.
35353553
"""
3554+
raise NotImplementedError
35363555

35373556
@abstractmethod
35383557
async def restart_kernel__load_config(
@@ -3543,7 +3562,7 @@ async def restart_kernel__load_config(
35433562
"""
35443563
Restore the cluster config from a previous launch of the kernel.
35453564
"""
3546-
pass
3565+
raise NotImplementedError
35473566

35483567
@abstractmethod
35493568
async def restart_kernel__store_config(
@@ -3556,7 +3575,7 @@ async def restart_kernel__store_config(
35563575
Store the cluster config to a kernel-related storage (e.g., scratch space),
35573576
so that restarts of this kernel can reuse the configuration.
35583577
"""
3559-
pass
3578+
raise NotImplementedError
35603579

35613580
async def restart_kernel(
35623581
self,

0 commit comments

Comments
 (0)