Skip to content

Commit 4f91c5d

Browse files
committed
feat(BA-2753): Add static method for constructing AgentRuntime
1 parent 95deceb commit 4f91c5d

File tree

4 files changed

+129
-82
lines changed

4 files changed

+129
-82
lines changed

src/ai/backend/agent/runtime.py

Lines changed: 100 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from __future__ import annotations
2+
13
import asyncio
24
import importlib
35
import signal
@@ -20,59 +22,120 @@
2022

2123
class AgentRuntime:
2224
_local_config: AgentUnifiedConfig
23-
_agents: dict[AgentId, AbstractAgent]
25+
_etcd_views: Mapping[AgentId, AgentEtcdClientView]
26+
_agents: Mapping[AgentId, AbstractAgent]
2427
_default_agent: AbstractAgent
2528
_kernel_registry: KernelRegistry
26-
_etcd: AsyncEtcd
27-
_etcd_views: Mapping[AgentId, AgentEtcdClientView]
29+
_metadata_server: Optional[MetadataServer]
2830

2931
_stop_signal: signal.Signals
3032

31-
def __init__(
32-
self,
33+
@classmethod
34+
async def create_runtime(
35+
cls,
3336
local_config: AgentUnifiedConfig,
3437
etcd: AsyncEtcd,
35-
) -> None:
36-
self._local_config = local_config
37-
self._agents = {}
38-
self._kernel_registry = KernelRegistry()
39-
self._etcd = etcd
40-
self._etcd_views = {
41-
AgentId(agent_config.agent.id): AgentEtcdClientView(self._etcd, agent_config)
42-
for agent_config in self._local_config.get_agent_configs()
43-
}
44-
self._metadata_server: MetadataServer | None = None
45-
46-
self._stop_signal = signal.SIGTERM
47-
48-
async def create_agents(
49-
self,
5038
stats_monitor: AgentStatsPluginContext,
5139
error_monitor: AgentErrorPluginContext,
5240
agent_public_key: Optional[PublicKey],
53-
) -> None:
54-
if self._local_config.agent_common.backend == AgentBackend.DOCKER:
55-
await self._initialize_metadata_server()
41+
) -> AgentRuntime:
42+
kernel_registry = KernelRegistry()
43+
44+
if local_config.agent_common.backend == AgentBackend.DOCKER:
45+
await cls._create_metadata_server(local_config, etcd, kernel_registry)
5646

57-
tasks: list[asyncio.Task] = []
47+
agent_configs = local_config.get_agent_configs()
48+
etcd_views: dict[AgentId, AgentEtcdClientView] = {}
49+
create_agent_tasks: list[asyncio.Task] = []
5850
async with asyncio.TaskGroup() as tg:
59-
for agent_config in self._local_config.get_agent_configs():
51+
for agent_config in agent_configs:
6052
agent_id = AgentId(agent_config.agent.id)
61-
tasks.append(
62-
tg.create_task(
63-
self._create_agent(
64-
self.get_etcd(agent_id),
65-
agent_config,
66-
stats_monitor,
67-
error_monitor,
68-
agent_public_key,
69-
)
53+
54+
etcd_view = AgentEtcdClientView(etcd, agent_config)
55+
etcd_views[agent_id] = etcd_view
56+
57+
create_agent_task = tg.create_task(
58+
cls._create_agent(
59+
local_config,
60+
etcd_view,
61+
kernel_registry,
62+
agent_config,
63+
stats_monitor,
64+
error_monitor,
65+
agent_public_key,
7066
)
7167
)
68+
create_agent_tasks.append(create_agent_task)
69+
agents_list = [task.result() for task in create_agent_tasks]
70+
default_agent = agents_list[0]
71+
agents = {agent.id: agent for agent in agents_list}
72+
73+
return AgentRuntime(
74+
local_config=local_config,
75+
etcd_views=etcd_views,
76+
agents=agents,
77+
default_agent=default_agent,
78+
kernel_registry=kernel_registry,
79+
)
80+
81+
@classmethod
82+
async def _create_metadata_server(
83+
cls,
84+
local_config: AgentUnifiedConfig,
85+
etcd: AsyncEtcd,
86+
kernel_registry: KernelRegistry,
87+
) -> MetadataServer:
88+
from .docker.metadata.server import MetadataServer
7289

73-
agents = [task.result() for task in tasks]
74-
self._default_agent = agents[0]
75-
self._agents = {agent.id: agent for agent in agents}
90+
metadata_server = await MetadataServer.new(
91+
local_config,
92+
etcd,
93+
kernel_registry=kernel_registry.global_view(),
94+
)
95+
await metadata_server.start_server()
96+
return metadata_server
97+
98+
@classmethod
99+
async def _create_agent(
100+
cls,
101+
local_config: AgentUnifiedConfig,
102+
etcd_view: AgentEtcdClientView,
103+
kernel_registry: KernelRegistry,
104+
agent_config: AgentUnifiedConfig,
105+
stats_monitor: AgentStatsPluginContext,
106+
error_monitor: AgentErrorPluginContext,
107+
agent_public_key: Optional[PublicKey],
108+
) -> AbstractAgent:
109+
agent_kwargs = {
110+
"kernel_registry": kernel_registry,
111+
"stats_monitor": stats_monitor,
112+
"error_monitor": error_monitor,
113+
"agent_public_key": agent_public_key,
114+
}
115+
116+
backend = local_config.agent_common.backend
117+
agent_mod = importlib.import_module(f"ai.backend.agent.{backend.value}")
118+
agent_cls: Type[AbstractAgent] = agent_mod.get_agent_cls()
119+
120+
return await agent_cls.new(etcd_view, agent_config, **agent_kwargs)
121+
122+
def __init__(
123+
self,
124+
local_config: AgentUnifiedConfig,
125+
etcd_views: Mapping[AgentId, AgentEtcdClientView],
126+
agents: dict[AgentId, AbstractAgent],
127+
default_agent: AbstractAgent,
128+
kernel_registry: KernelRegistry,
129+
metadata_server: Optional[MetadataServer] = None,
130+
) -> None:
131+
self._local_config = local_config
132+
self._etcd_views = etcd_views
133+
self._agents = agents
134+
self._default_agent = default_agent
135+
self._kernel_registry = kernel_registry
136+
self._metadata_server = metadata_server
137+
138+
self._stop_signal = signal.SIGTERM
76139

77140
async def __aexit__(self, *exc_info) -> None:
78141
for agent in self._agents.values():
@@ -107,34 +170,3 @@ def mark_stop_signal(self, stop_signal: signal.Signals) -> None:
107170
async def update_status(self, status, agent_id: AgentId) -> None:
108171
etcd = self.get_etcd(agent_id)
109172
await etcd.put("", status, scope=ConfigScopes.NODE)
110-
111-
async def _create_agent(
112-
self,
113-
etcd_view: AgentEtcdClientView,
114-
agent_config: AgentUnifiedConfig,
115-
stats_monitor: AgentStatsPluginContext,
116-
error_monitor: AgentErrorPluginContext,
117-
agent_public_key: Optional[PublicKey],
118-
) -> AbstractAgent:
119-
agent_kwargs = {
120-
"kernel_registry": self._kernel_registry,
121-
"stats_monitor": stats_monitor,
122-
"error_monitor": error_monitor,
123-
"agent_public_key": agent_public_key,
124-
}
125-
126-
backend = self._local_config.agent_common.backend
127-
agent_mod = importlib.import_module(f"ai.backend.agent.{backend.value}")
128-
agent_cls: Type[AbstractAgent] = agent_mod.get_agent_cls()
129-
130-
return await agent_cls.new(etcd_view, agent_config, **agent_kwargs)
131-
132-
async def _initialize_metadata_server(self) -> None:
133-
from .docker.metadata.server import MetadataServer
134-
135-
self._metadata_server = await MetadataServer.new(
136-
self._local_config,
137-
self._etcd,
138-
kernel_registry=self._kernel_registry.global_view(),
139-
)
140-
await self._metadata_server.start_server()

src/ai/backend/agent/server.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -285,15 +285,9 @@ def __init__(
285285
self.loop = current_loop()
286286
self.etcd = etcd
287287
self.local_config = local_config
288-
self.runtime = AgentRuntime(self.local_config, self.etcd)
289288
self.skip_detect_manager = skip_detect_manager
290289

291290
async def __ainit__(self) -> None:
292-
# Start serving requests.
293-
async with asyncio.TaskGroup() as tg:
294-
for agent_id in self.local_config.agent_ids:
295-
tg.create_task(self.update_status("starting", agent_id))
296-
297291
if not self.skip_detect_manager:
298292
await self.detect_manager()
299293

@@ -338,12 +332,19 @@ async def __ainit__(self) -> None:
338332
self.rpc_auth_agent_secret_key = None
339333
auth_handler = None
340334

341-
await self.runtime.create_agents(
335+
self.runtime = await AgentRuntime.create_runtime(
336+
self.local_config,
337+
self.etcd,
342338
self.stats_monitor,
343339
self.error_monitor,
344340
self.rpc_auth_agent_public_key,
345341
)
346342

343+
# Start serving requests.
344+
async with asyncio.TaskGroup() as tg:
345+
for agent_id in self.local_config.agent_ids:
346+
tg.create_task(self.update_status("starting", agent_id))
347+
347348
rpc_addr = self.local_config.agent_common.rpc_listen_addr
348349
self.rpc_server = Peer(
349350
bind=ZeroMQAddress(f"tcp://{rpc_addr.address}"),

tests/agent/conftest.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -295,9 +295,13 @@ async def agent_runtime(
295295
mock_stats_monitor = Mock()
296296
mock_error_monitor = Mock()
297297

298-
runtime = AgentRuntime(local_config, etcd)
299-
300-
await runtime.create_agents(mock_stats_monitor, mock_error_monitor, None)
298+
runtime = await AgentRuntime.create_runtime(
299+
local_config,
300+
etcd,
301+
mock_stats_monitor,
302+
mock_error_monitor,
303+
None,
304+
)
301305

302306
try:
303307
yield runtime

tests/agent/test_agent_runtime.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -102,8 +102,13 @@ async def test_runtime_creates_agents_from_config(
102102
mock_stats_monitor = Mock()
103103
mock_error_monitor = Mock()
104104

105-
runtime = AgentRuntime(local_config, etcd)
106-
await runtime.create_agents(mock_stats_monitor, mock_error_monitor, None)
105+
runtime = await AgentRuntime.create_runtime(
106+
local_config,
107+
etcd,
108+
mock_stats_monitor,
109+
mock_error_monitor,
110+
None,
111+
)
107112

108113
try:
109114
# Verify agents were created
@@ -135,8 +140,13 @@ async def test_runtime_shutdown_cleans_up_agents(
135140
mock_stats_monitor = Mock()
136141
mock_error_monitor = Mock()
137142

138-
runtime = AgentRuntime(local_config, etcd)
139-
await runtime.create_agents(mock_stats_monitor, mock_error_monitor, None)
143+
runtime = await AgentRuntime.create_runtime(
144+
local_config,
145+
etcd,
146+
mock_stats_monitor,
147+
mock_error_monitor,
148+
None,
149+
)
140150

141151
# Verify agents exist before shutdown
142152
agents = runtime.get_agents()

0 commit comments

Comments
 (0)