Skip to content

Commit ab323b9

Browse files
committed
Unwind some more future/coroutine issues
1 parent f7c80c8 commit ab323b9

File tree

5 files changed

+228
-19
lines changed

5 files changed

+228
-19
lines changed

salt/minion.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3114,8 +3114,9 @@ def cleanup_subprocesses(self):
31143114
# Add an extra fallback in case a forked process leaks through
31153115
multiprocessing.active_children()
31163116
self.subprocess_list.cleanup()
3117-
if self.schedule:
3118-
self.schedule.cleanup_subprocesses()
3117+
schedule = getattr(self, "schedule", None)
3118+
if schedule:
3119+
schedule.cleanup_subprocesses()
31193120

31203121
def _setup_core(self):
31213122
"""
@@ -3405,7 +3406,11 @@ def destroy(self):
34053406
self.req_channel.close()
34063407
if hasattr(self, "periodic_callbacks"):
34073408
for cb in self.periodic_callbacks.values():
3408-
cb.stop()
3409+
if hasattr(cb, "stop"):
3410+
cb.stop()
3411+
elif asyncio.isfuture(cb) or isinstance(cb, asyncio.Task):
3412+
cb.cancel()
3413+
self.periodic_callbacks.clear()
34093414

34103415
# pylint: disable=W1701
34113416
def __del__(self):

salt/transport/tcp.py

Lines changed: 78 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1513,7 +1513,12 @@ async def publish(
15131513
"""
15141514
if not self.pub_sock:
15151515
self.connect()
1516-
self.pub_sock.send(payload)
1516+
log.debug("TCP PublishServer publishing payload (%d bytes)", len(payload))
1517+
result = self.pub_sock.send(payload)
1518+
if asyncio.iscoroutine(result) or isinstance(result, asyncio.Future):
1519+
await result
1520+
else:
1521+
log.debug("TCP PublishServer publish returned %r", result)
15171522

15181523
def close(self):
15191524
self._closing = True
@@ -1607,15 +1612,23 @@ def __init__(self, host, port, path, io_loop=None):
16071612
self.stream = None
16081613
self.unpacker = salt.utils.msgpack.Unpacker(raw=False)
16091614
self._connecting_future = None
1615+
self._asyncio_loop = None
1616+
if hasattr(self.io_loop, "asyncio_loop"):
1617+
self._asyncio_loop = self.io_loop.asyncio_loop
16101618

16111619
def connected(self):
16121620
return self.stream is not None and not self.stream.closed()
16131621

1614-
def connect(self, callback=None, timeout=None):
1622+
def _legacy_connect(self, callback=None, timeout=None):
16151623
"""
16161624
Connect to the IPC socket
16171625
"""
16181626
if self._connecting_future is not None and not self._connecting_future.done():
1627+
log.debug(
1628+
"%s connect reuse pending future (closing=%s)",
1629+
self.__class__.__name__,
1630+
self._closing,
1631+
)
16191632
future = self._connecting_future
16201633
else:
16211634
if self._connecting_future is not None:
@@ -1624,6 +1637,12 @@ def connect(self, callback=None, timeout=None):
16241637
future = tornado.concurrent.Future()
16251638
self._connecting_future = future
16261639
# self._connect(timeout)
1640+
log.debug(
1641+
"%s connect spawning _connect (closing=%s, timeout=%s)",
1642+
self.__class__.__name__,
1643+
self._closing,
1644+
timeout,
1645+
)
16271646
self.io_loop.spawn_callback(self._connect, timeout)
16281647

16291648
if callback is not None:
@@ -1636,6 +1655,34 @@ def handle_future(future):
16361655

16371656
return future
16381657

1658+
async def connect(self, callback=None, timeout=None):
1659+
if self._asyncio_loop is None:
1660+
# Fall back to the legacy tornado.Future based implementation
1661+
future = self._legacy_connect(callback=callback, timeout=timeout)
1662+
loop = asyncio.get_running_loop()
1663+
await salt.utils.asynchronous._ensure_task(loop, future)
1664+
return True
1665+
1666+
if self._connecting_future is None or self._connecting_future.done():
1667+
self._connecting_future = self._asyncio_loop.create_future()
1668+
1669+
async def runner():
1670+
try:
1671+
await self._connect(timeout=timeout)
1672+
except Exception as exc: # pylint: disable=broad-except
1673+
if not self._connecting_future.done():
1674+
self._connecting_future.set_exception(exc)
1675+
else:
1676+
if not self._connecting_future.done():
1677+
self._connecting_future.set_result(True)
1678+
1679+
asyncio.ensure_future(runner(), loop=self._asyncio_loop)
1680+
1681+
result = await self._connecting_future
1682+
if callback is not None:
1683+
self.io_loop.spawn_callback(callback, result)
1684+
return result
1685+
16391686
async def _connect(self, timeout=None):
16401687
"""
16411688
Connect to a running IPCServer
@@ -1652,9 +1699,19 @@ async def _connect(self, timeout=None):
16521699
self.stream = None
16531700
if timeout is not None:
16541701
timeout_at = time.monotonic() + timeout
1702+
log.debug(
1703+
"%s will timeout connection at %.3f (in %.3fs)",
1704+
self.__class__.__name__,
1705+
timeout_at,
1706+
timeout,
1707+
)
16551708

16561709
while True:
16571710
if self._closing:
1711+
log.debug(
1712+
"%s connect aborting due to closing flag",
1713+
self.__class__.__name__,
1714+
)
16581715
break
16591716

16601717
if self.stream is None:
@@ -1664,6 +1721,12 @@ async def _connect(self, timeout=None):
16641721
)
16651722
try:
16661723
await self.stream.connect(sock_addr)
1724+
log.debug(
1725+
"%s connected to %s (closing=%s)",
1726+
self.__class__.__name__,
1727+
sock_addr,
1728+
self._closing,
1729+
)
16671730
self._connecting_future.set_result(True)
16681731
break
16691732
except Exception as e: # pylint: disable=broad-except
@@ -1674,6 +1737,12 @@ async def _connect(self, timeout=None):
16741737
if self.stream is not None:
16751738
self.stream.close()
16761739
self.stream = None
1740+
log.debug(
1741+
"%s connection to %s failed: %s",
1742+
self.__class__.__name__,
1743+
sock_addr,
1744+
e,
1745+
)
16771746
self._connecting_future.set_exception(e)
16781747
break
16791748

@@ -1728,7 +1797,14 @@ async def send(self, msg, timeout=None, tries=None):
17281797
if not self.connected():
17291798
await self.connect()
17301799
pack = salt.transport.frame.frame_msg_ipc(msg, raw_body=True)
1800+
log.debug(
1801+
"%s sending %d bytes (closing=%s)",
1802+
self.__class__.__name__,
1803+
len(pack),
1804+
self._closing,
1805+
)
17311806
await self.stream.write(pack)
1807+
log.debug("%s send complete", self.__class__.__name__)
17321808

17331809

17341810
class RequestClient(salt.transport.base.RequestClient):

salt/transport/zeromq.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -575,7 +575,8 @@ async def request_handler(self):
575575
continue
576576
except Exception as exc: # pylint: disable=broad-except
577577
log.error(
578-
"Exception in request handler",
578+
"Exception in request handler: %r",
579+
exc,
579580
exc_info_on_loglevel=logging.DEBUG,
580581
)
581582
continue

salt/utils/asynchronous.py

Lines changed: 116 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,16 @@
1919
from typing import Any
2020

2121
try: # Optional dependency during the transition away from tornado.
22+
import tornado # type: ignore
23+
import tornado.concurrent # type: ignore
2224
import tornado.ioloop # type: ignore
25+
from tornado.platform.asyncio import to_asyncio_future # type: ignore
2326
except ImportError: # pragma: no cover - tornado optional
2427
tornado = None # type: ignore
28+
to_asyncio_future = None # type: ignore
29+
_TORNADO_FUTURE_TYPES: tuple[type[Any], ...] = ()
30+
else:
31+
_TORNADO_FUTURE_TYPES = (tornado.concurrent.Future,) # type: ignore[attr-defined]
2532

2633
log = logging.getLogger(__name__)
2734

@@ -32,8 +39,94 @@ def _ensure_task(loop: asyncio.AbstractEventLoop, result: Any) -> Any:
3239
return it unchanged.
3340
"""
3441

35-
if asyncio.iscoroutine(result) or isinstance(result, asyncio.Future):
36-
return loop.create_task(result)
42+
if asyncio.iscoroutine(result):
43+
return asyncio.ensure_future(result, loop=loop)
44+
45+
if isinstance(result, asyncio.Future):
46+
try:
47+
future_loop = result.get_loop()
48+
except AttributeError:
49+
future_loop = loop
50+
if future_loop is loop:
51+
return asyncio.ensure_future(result, loop=loop)
52+
53+
proxy = loop.create_future()
54+
55+
def _relay(src_future: asyncio.Future):
56+
if proxy.done():
57+
return
58+
if src_future.cancelled():
59+
loop.call_soon_threadsafe(proxy.cancel)
60+
return
61+
exc = src_future.exception()
62+
if exc is not None:
63+
loop.call_soon_threadsafe(proxy.set_exception, exc)
64+
return
65+
loop.call_soon_threadsafe(proxy.set_result, src_future.result())
66+
67+
result.add_done_callback(_relay)
68+
return proxy
69+
70+
if _TORNADO_FUTURE_TYPES and isinstance(
71+
result, _TORNADO_FUTURE_TYPES # type: ignore[arg-type]
72+
):
73+
converted = None
74+
if to_asyncio_future is not None:
75+
try:
76+
converted = to_asyncio_future(result)
77+
except TypeError:
78+
converted = None
79+
except Exception: # pylint: disable=broad-except
80+
log.exception(
81+
"Failed to convert future %r to asyncio future using tornado helper",
82+
result,
83+
)
84+
if converted is not None:
85+
try:
86+
return asyncio.ensure_future(converted, loop=loop)
87+
except TypeError:
88+
converted = None
89+
if converted is None:
90+
converted = loop.create_future()
91+
92+
def _relay(src_future):
93+
if converted.done():
94+
return
95+
if src_future.cancelled():
96+
loop.call_soon_threadsafe(converted.cancel)
97+
return
98+
exc = src_future.exception()
99+
if exc is not None:
100+
loop.call_soon_threadsafe(converted.set_exception, exc)
101+
return
102+
loop.call_soon_threadsafe(converted.set_result, src_future.result())
103+
104+
result.add_done_callback(_relay)
105+
return asyncio.ensure_future(converted, loop=loop)
106+
107+
if hasattr(result, "add_done_callback"):
108+
try:
109+
wrapped = asyncio.wrap_future(result, loop=loop)
110+
return asyncio.ensure_future(wrapped, loop=loop)
111+
except TypeError:
112+
bridge = loop.create_future()
113+
114+
def _relay(src_future):
115+
if bridge.done():
116+
return
117+
if getattr(src_future, "cancelled", lambda: False)():
118+
loop.call_soon_threadsafe(bridge.cancel)
119+
return
120+
exc = getattr(src_future, "exception", lambda: None)()
121+
if exc is not None:
122+
loop.call_soon_threadsafe(bridge.set_exception, exc)
123+
return
124+
result_value = getattr(src_future, "result", lambda: None)()
125+
loop.call_soon_threadsafe(bridge.set_result, result_value)
126+
127+
result.add_done_callback(_relay)
128+
return bridge
129+
37130
return result
38131

39132

@@ -103,9 +196,23 @@ def remove_timeout(self, handle: asyncio.TimerHandle):
103196
handle.cancel()
104197

105198
def add_future(self, future: Awaitable, callback: Callable[[asyncio.Future], Any]):
106-
fut = asyncio.ensure_future(future, loop=self._loop)
107-
fut.add_done_callback(lambda done: self._loop.call_soon(callback, done))
108-
return fut
199+
scheduled = _ensure_task(self._loop, future)
200+
if isinstance(scheduled, asyncio.Future):
201+
scheduled.add_done_callback(
202+
lambda done: self._loop.call_soon(callback, done)
203+
)
204+
return scheduled
205+
if hasattr(scheduled, "add_done_callback"):
206+
207+
def _relay(done):
208+
try:
209+
loop = self._loop
210+
loop.call_soon_threadsafe(callback, done)
211+
except RuntimeError:
212+
loop.call_soon(callback, done)
213+
214+
scheduled.add_done_callback(_relay)
215+
return scheduled
109216

110217
def create_task(self, coro: Awaitable):
111218
return self._loop.create_task(coro)
@@ -120,16 +227,17 @@ def run_sync(
120227
**kwargs,
121228
):
122229
result = func(*args, **kwargs)
123-
if asyncio.iscoroutine(result) or isinstance(result, asyncio.Future):
230+
scheduled = _ensure_task(self._loop, result)
231+
if asyncio.iscoroutine(scheduled) or isinstance(scheduled, asyncio.Future):
124232
if self._loop.is_running():
125233
raise RuntimeError("Cannot run_sync on a running event loop")
126234
policy = asyncio.get_event_loop_policy()
127235
try:
128236
policy.set_event_loop(self._loop)
129-
return self._loop.run_until_complete(result)
237+
return self._loop.run_until_complete(scheduled)
130238
finally:
131239
policy.set_event_loop(None)
132-
return result
240+
return scheduled
133241

134242
def start(self):
135243
policy = asyncio.get_event_loop_policy()

0 commit comments

Comments
 (0)