from __future__ import annotations import asyncio import concurrent.futures import contextvars import inspect import sys import types from collections.abc import Awaitable, Coroutine, Generator from typing import TypeVar, cast T = TypeVar("T") AnyFuture = asyncio.Future | concurrent.futures.Future CONTEXT_NOT_SUPPORTED = sys.version_info < (3, 11) EAGER_NOT_SUPPORTED = sys.version_info < (3, 12) def _get_loop(fut: asyncio.Future) -> asyncio.AbstractEventLoop: # Tries to call Future.get_loop() if it's available. # Otherwise fallbacks to using the old '_loop' property. try: get_loop = fut.get_loop except AttributeError: pass else: return get_loop() return fut._loop def _convert_future_exc(exc: BaseException) -> BaseException: exc_class = type(exc) if exc_class is concurrent.futures.CancelledError: return asyncio.CancelledError(*exc.args) elif exc_class is concurrent.futures.TimeoutError: return asyncio.TimeoutError(*exc.args) elif exc_class is concurrent.futures.InvalidStateError: return asyncio.InvalidStateError(*exc.args) else: return exc def _set_concurrent_future_state( concurrent: concurrent.futures.Future, source: AnyFuture, ) -> None: """Copy state from a future to a concurrent.futures.Future.""" assert source.done() if source.cancelled(): concurrent.cancel() if not concurrent.set_running_or_notify_cancel(): return exception = source.exception() if exception is not None: concurrent.set_exception(_convert_future_exc(exception)) else: result = source.result() concurrent.set_result(result) def _copy_future_state(source: AnyFuture, dest: asyncio.Future) -> None: """Internal helper to copy state from another Future. The other Future may be a concurrent.futures.Future. """ if dest.done(): return assert source.done() if dest.cancelled(): return if source.cancelled(): dest.cancel() else: exception = source.exception() if exception is not None: dest.set_exception(_convert_future_exc(exception)) else: result = source.result() dest.set_result(result) def _chain_future(source: AnyFuture, destination: AnyFuture) -> None: """Chain two futures so that when one completes, so does the other. The result (or exception) of source will be copied to destination. If destination is cancelled, source gets cancelled too. Compatible with both asyncio.Future and concurrent.futures.Future. """ if not asyncio.isfuture(source) and not isinstance( source, concurrent.futures.Future ): raise TypeError("A future is required for source argument") if not asyncio.isfuture(destination) and not isinstance( destination, concurrent.futures.Future ): raise TypeError("A future is required for destination argument") source_loop = _get_loop(source) if asyncio.isfuture(source) else None dest_loop = _get_loop(destination) if asyncio.isfuture(destination) else None def _set_state(future: AnyFuture, other: AnyFuture) -> None: if asyncio.isfuture(future): _copy_future_state(other, future) else: _set_concurrent_future_state(future, other) def _call_check_cancel(destination: AnyFuture) -> None: if destination.cancelled(): if source_loop is None or source_loop is dest_loop: source.cancel() else: source_loop.call_soon_threadsafe(source.cancel) def _call_set_state(source: AnyFuture) -> None: if destination.cancelled() and dest_loop is not None and dest_loop.is_closed(): return if dest_loop is None or dest_loop is source_loop: _set_state(destination, source) else: if dest_loop.is_closed(): return dest_loop.call_soon_threadsafe(_set_state, destination, source) destination.add_done_callback(_call_check_cancel) source.add_done_callback(_call_set_state) def chain_future(source: AnyFuture, destination: AnyFuture) -> AnyFuture: # adapted from asyncio.run_coroutine_threadsafe try: _chain_future(source, destination) return destination except (SystemExit, KeyboardInterrupt): raise except BaseException as exc: if isinstance(destination, concurrent.futures.Future): if destination.set_running_or_notify_cancel(): destination.set_exception(exc) else: destination.set_exception(exc) raise def _ensure_future( coro_or_future: Coroutine[None, None, T] | Awaitable[T], *, loop: asyncio.AbstractEventLoop, name: str | None = None, context: contextvars.Context | None = None, lazy: bool = True, ) -> asyncio.Task[T]: called_wrap_awaitable = False if not asyncio.iscoroutine(coro_or_future): if inspect.isawaitable(coro_or_future): coro_or_future = cast( Coroutine[None, None, T], _wrap_awaitable(coro_or_future) ) called_wrap_awaitable = True else: raise TypeError( "An asyncio.Future, a coroutine or an awaitable is required." f" Got {type(coro_or_future).__name__} instead." ) try: if CONTEXT_NOT_SUPPORTED: return loop.create_task(coro_or_future, name=name) elif EAGER_NOT_SUPPORTED or lazy: return loop.create_task(coro_or_future, name=name, context=context) else: return asyncio.eager_task_factory( loop, coro_or_future, name=name, context=context ) except RuntimeError: if not called_wrap_awaitable: coro_or_future.close() raise @types.coroutine def _wrap_awaitable(awaitable: Awaitable[T]) -> Generator[None, None, T]: """Helper for asyncio.ensure_future(). Wraps awaitable (an object with __await__) into a coroutine that will later be wrapped in a Task by ensure_future(). """ return (yield from awaitable.__await__()) def run_coroutine_threadsafe( coro: Coroutine[None, None, T], loop: asyncio.AbstractEventLoop, *, lazy: bool, name: str | None = None, context: contextvars.Context | None = None, ) -> asyncio.Future[T]: """Submit a coroutine object to a given event loop. Return an asyncio.Future to access the result. """ if asyncio._get_running_loop() is loop: return _ensure_future(coro, loop=loop, name=name, context=context, lazy=lazy) else: future: asyncio.Future[T] = asyncio.Future(loop=loop) def callback() -> None: try: chain_future( _ensure_future(coro, loop=loop, name=name, context=context), future, ) except (SystemExit, KeyboardInterrupt): raise except BaseException as exc: future.set_exception(exc) raise loop.call_soon_threadsafe(callback, context=context) return future