|
1 | 1 | import asyncio |
2 | 2 | import logging |
| 3 | +import signal |
3 | 4 | import traceback |
4 | 5 | from collections.abc import Sequence |
5 | 6 | from typing import Any |
@@ -139,13 +140,64 @@ async def call_tool(name: str, arguments: Any) -> Sequence[TextContent]: |
139 | 140 | async def main() -> None: |
140 | 141 | """Main entry point for the MCP shell server""" |
141 | 142 | logger.info(f"Starting MCP shell server v{__version__}") |
| 143 | + |
| 144 | + # Setup signal handling |
| 145 | + loop = asyncio.get_running_loop() |
| 146 | + stop_event = asyncio.Event() |
| 147 | + |
| 148 | + def handle_signal(): |
| 149 | + if not stop_event.is_set(): # Prevent duplicate handling |
| 150 | + logger.info("Received shutdown signal, starting cleanup...") |
| 151 | + stop_event.set() |
| 152 | + |
| 153 | + # Register signal handlers |
| 154 | + for sig in (signal.SIGTERM, signal.SIGINT): |
| 155 | + loop.add_signal_handler(sig, handle_signal) |
| 156 | + |
142 | 157 | try: |
143 | 158 | from mcp.server.stdio import stdio_server |
144 | 159 |
|
145 | 160 | async with stdio_server() as (read_stream, write_stream): |
146 | | - await app.run( |
147 | | - read_stream, write_stream, app.create_initialization_options() |
| 161 | + # Run the server until stop_event is set |
| 162 | + server_task = asyncio.create_task( |
| 163 | + app.run(read_stream, write_stream, app.create_initialization_options()) |
148 | 164 | ) |
| 165 | + |
| 166 | + # Create task for stop event |
| 167 | + stop_task = asyncio.create_task(stop_event.wait()) |
| 168 | + |
| 169 | + # Wait for either server completion or stop signal |
| 170 | + done, pending = await asyncio.wait( |
| 171 | + [server_task, stop_task], return_when=asyncio.FIRST_COMPLETED |
| 172 | + ) |
| 173 | + |
| 174 | + # Check for exceptions in completed tasks |
| 175 | + for task in done: |
| 176 | + try: |
| 177 | + await task |
| 178 | + except Exception: |
| 179 | + raise # Re-raise the exception |
| 180 | + |
| 181 | + # Cancel any pending tasks |
| 182 | + for task in pending: |
| 183 | + task.cancel() |
| 184 | + try: |
| 185 | + await task |
| 186 | + except asyncio.CancelledError: |
| 187 | + pass |
| 188 | + |
149 | 189 | except Exception as e: |
150 | 190 | logger.error(f"Server error: {str(e)}") |
151 | 191 | raise |
| 192 | + finally: |
| 193 | + # Cleanup signal handlers |
| 194 | + for sig in (signal.SIGTERM, signal.SIGINT): |
| 195 | + loop.remove_signal_handler(sig) |
| 196 | + |
| 197 | + # Ensure all processes are terminated |
| 198 | + if hasattr(tool_handler, "executor") and hasattr( |
| 199 | + tool_handler.executor, "process_manager" |
| 200 | + ): |
| 201 | + await tool_handler.executor.process_manager.cleanup_processes() |
| 202 | + |
| 203 | + logger.info("Server shutdown complete") |
0 commit comments