|
| 1 | +from __future__ import annotations |
| 2 | + |
1 | 3 | import asyncio |
2 | 4 | import importlib |
3 | 5 | import signal |
|
20 | 22 |
|
21 | 23 | class AgentRuntime: |
22 | 24 | _local_config: AgentUnifiedConfig |
23 | | - _agents: dict[AgentId, AbstractAgent] |
| 25 | + _etcd_views: Mapping[AgentId, AgentEtcdClientView] |
| 26 | + _agents: Mapping[AgentId, AbstractAgent] |
24 | 27 | _default_agent: AbstractAgent |
25 | 28 | _kernel_registry: KernelRegistry |
26 | | - _etcd: AsyncEtcd |
27 | | - _etcd_views: Mapping[AgentId, AgentEtcdClientView] |
| 29 | + _metadata_server: Optional[MetadataServer] |
28 | 30 |
|
29 | 31 | _stop_signal: signal.Signals |
30 | 32 |
|
31 | | - def __init__( |
32 | | - self, |
| 33 | + @classmethod |
| 34 | + async def create_runtime( |
| 35 | + cls, |
33 | 36 | local_config: AgentUnifiedConfig, |
34 | 37 | etcd: AsyncEtcd, |
35 | | - ) -> None: |
36 | | - self._local_config = local_config |
37 | | - self._agents = {} |
38 | | - self._kernel_registry = KernelRegistry() |
39 | | - self._etcd = etcd |
40 | | - self._etcd_views = { |
41 | | - AgentId(agent_config.agent.id): AgentEtcdClientView(self._etcd, agent_config) |
42 | | - for agent_config in self._local_config.get_agent_configs() |
43 | | - } |
44 | | - self._metadata_server: MetadataServer | None = None |
45 | | - |
46 | | - self._stop_signal = signal.SIGTERM |
47 | | - |
48 | | - async def create_agents( |
49 | | - self, |
50 | 38 | stats_monitor: AgentStatsPluginContext, |
51 | 39 | error_monitor: AgentErrorPluginContext, |
52 | 40 | agent_public_key: Optional[PublicKey], |
53 | | - ) -> None: |
54 | | - if self._local_config.agent_common.backend == AgentBackend.DOCKER: |
55 | | - await self._initialize_metadata_server() |
| 41 | + ) -> AgentRuntime: |
| 42 | + kernel_registry = KernelRegistry() |
| 43 | + |
| 44 | + if local_config.agent_common.backend == AgentBackend.DOCKER: |
| 45 | + await cls._create_metadata_server(local_config, etcd, kernel_registry) |
56 | 46 |
|
57 | | - tasks: list[asyncio.Task] = [] |
| 47 | + agent_configs = local_config.get_agent_configs() |
| 48 | + etcd_views: dict[AgentId, AgentEtcdClientView] = {} |
| 49 | + create_agent_tasks: list[asyncio.Task] = [] |
58 | 50 | async with asyncio.TaskGroup() as tg: |
59 | | - for agent_config in self._local_config.get_agent_configs(): |
| 51 | + for agent_config in agent_configs: |
60 | 52 | agent_id = AgentId(agent_config.agent.id) |
61 | | - tasks.append( |
62 | | - tg.create_task( |
63 | | - self._create_agent( |
64 | | - self.get_etcd(agent_id), |
65 | | - agent_config, |
66 | | - stats_monitor, |
67 | | - error_monitor, |
68 | | - agent_public_key, |
69 | | - ) |
| 53 | + |
| 54 | + etcd_view = AgentEtcdClientView(etcd, agent_config) |
| 55 | + etcd_views[agent_id] = etcd_view |
| 56 | + |
| 57 | + create_agent_task = tg.create_task( |
| 58 | + cls._create_agent( |
| 59 | + local_config, |
| 60 | + etcd_view, |
| 61 | + kernel_registry, |
| 62 | + agent_config, |
| 63 | + stats_monitor, |
| 64 | + error_monitor, |
| 65 | + agent_public_key, |
70 | 66 | ) |
71 | 67 | ) |
| 68 | + create_agent_tasks.append(create_agent_task) |
| 69 | + agents_list = [task.result() for task in create_agent_tasks] |
| 70 | + default_agent = agents_list[0] |
| 71 | + agents = {agent.id: agent for agent in agents_list} |
| 72 | + |
| 73 | + return AgentRuntime( |
| 74 | + local_config=local_config, |
| 75 | + etcd_views=etcd_views, |
| 76 | + agents=agents, |
| 77 | + default_agent=default_agent, |
| 78 | + kernel_registry=kernel_registry, |
| 79 | + ) |
| 80 | + |
| 81 | + @classmethod |
| 82 | + async def _create_metadata_server( |
| 83 | + cls, |
| 84 | + local_config: AgentUnifiedConfig, |
| 85 | + etcd: AsyncEtcd, |
| 86 | + kernel_registry: KernelRegistry, |
| 87 | + ) -> MetadataServer: |
| 88 | + from .docker.metadata.server import MetadataServer |
72 | 89 |
|
73 | | - agents = [task.result() for task in tasks] |
74 | | - self._default_agent = agents[0] |
75 | | - self._agents = {agent.id: agent for agent in agents} |
| 90 | + metadata_server = await MetadataServer.new( |
| 91 | + local_config, |
| 92 | + etcd, |
| 93 | + kernel_registry=kernel_registry.global_view(), |
| 94 | + ) |
| 95 | + await metadata_server.start_server() |
| 96 | + return metadata_server |
| 97 | + |
| 98 | + @classmethod |
| 99 | + async def _create_agent( |
| 100 | + cls, |
| 101 | + local_config: AgentUnifiedConfig, |
| 102 | + etcd_view: AgentEtcdClientView, |
| 103 | + kernel_registry: KernelRegistry, |
| 104 | + agent_config: AgentUnifiedConfig, |
| 105 | + stats_monitor: AgentStatsPluginContext, |
| 106 | + error_monitor: AgentErrorPluginContext, |
| 107 | + agent_public_key: Optional[PublicKey], |
| 108 | + ) -> AbstractAgent: |
| 109 | + agent_kwargs = { |
| 110 | + "kernel_registry": kernel_registry, |
| 111 | + "stats_monitor": stats_monitor, |
| 112 | + "error_monitor": error_monitor, |
| 113 | + "agent_public_key": agent_public_key, |
| 114 | + } |
| 115 | + |
| 116 | + backend = local_config.agent_common.backend |
| 117 | + agent_mod = importlib.import_module(f"ai.backend.agent.{backend.value}") |
| 118 | + agent_cls: Type[AbstractAgent] = agent_mod.get_agent_cls() |
| 119 | + |
| 120 | + return await agent_cls.new(etcd_view, agent_config, **agent_kwargs) |
| 121 | + |
| 122 | + def __init__( |
| 123 | + self, |
| 124 | + local_config: AgentUnifiedConfig, |
| 125 | + etcd_views: Mapping[AgentId, AgentEtcdClientView], |
| 126 | + agents: dict[AgentId, AbstractAgent], |
| 127 | + default_agent: AbstractAgent, |
| 128 | + kernel_registry: KernelRegistry, |
| 129 | + metadata_server: Optional[MetadataServer] = None, |
| 130 | + ) -> None: |
| 131 | + self._local_config = local_config |
| 132 | + self._etcd_views = etcd_views |
| 133 | + self._agents = agents |
| 134 | + self._default_agent = default_agent |
| 135 | + self._kernel_registry = kernel_registry |
| 136 | + self._metadata_server = metadata_server |
| 137 | + |
| 138 | + self._stop_signal = signal.SIGTERM |
76 | 139 |
|
77 | 140 | async def __aexit__(self, *exc_info) -> None: |
78 | 141 | for agent in self._agents.values(): |
@@ -107,34 +170,3 @@ def mark_stop_signal(self, stop_signal: signal.Signals) -> None: |
107 | 170 | async def update_status(self, status, agent_id: AgentId) -> None: |
108 | 171 | etcd = self.get_etcd(agent_id) |
109 | 172 | await etcd.put("", status, scope=ConfigScopes.NODE) |
110 | | - |
111 | | - async def _create_agent( |
112 | | - self, |
113 | | - etcd_view: AgentEtcdClientView, |
114 | | - agent_config: AgentUnifiedConfig, |
115 | | - stats_monitor: AgentStatsPluginContext, |
116 | | - error_monitor: AgentErrorPluginContext, |
117 | | - agent_public_key: Optional[PublicKey], |
118 | | - ) -> AbstractAgent: |
119 | | - agent_kwargs = { |
120 | | - "kernel_registry": self._kernel_registry, |
121 | | - "stats_monitor": stats_monitor, |
122 | | - "error_monitor": error_monitor, |
123 | | - "agent_public_key": agent_public_key, |
124 | | - } |
125 | | - |
126 | | - backend = self._local_config.agent_common.backend |
127 | | - agent_mod = importlib.import_module(f"ai.backend.agent.{backend.value}") |
128 | | - agent_cls: Type[AbstractAgent] = agent_mod.get_agent_cls() |
129 | | - |
130 | | - return await agent_cls.new(etcd_view, agent_config, **agent_kwargs) |
131 | | - |
132 | | - async def _initialize_metadata_server(self) -> None: |
133 | | - from .docker.metadata.server import MetadataServer |
134 | | - |
135 | | - self._metadata_server = await MetadataServer.new( |
136 | | - self._local_config, |
137 | | - self._etcd, |
138 | | - kernel_registry=self._kernel_registry.global_view(), |
139 | | - ) |
140 | | - await self._metadata_server.start_server() |
|
0 commit comments