221 lines
7.1 KiB
Python
221 lines
7.1 KiB
Python
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import concurrent.futures
|
|
import contextvars
|
|
import inspect
|
|
import sys
|
|
import types
|
|
from collections.abc import Awaitable, Coroutine, Generator
|
|
from typing import TypeVar, cast
|
|
|
|
# Generic result type used by the coroutine/awaitable helpers in this module.
T = TypeVar("T")
# Either flavor of future accepted by the chaining helpers.
AnyFuture = asyncio.Future | concurrent.futures.Future

# True on interpreters older than 3.11; _ensure_future then creates tasks
# without passing a `context=` argument.
CONTEXT_NOT_SUPPORTED = sys.version_info < (3, 11)
# True on interpreters older than 3.12; _ensure_future then never goes
# through asyncio.eager_task_factory.
EAGER_NOT_SUPPORTED = sys.version_info < (3, 12)
|
|
|
|
|
|
def _get_loop(fut: asyncio.Future) -> asyncio.AbstractEventLoop:
|
|
# Tries to call Future.get_loop() if it's available.
|
|
# Otherwise fallbacks to using the old '_loop' property.
|
|
try:
|
|
get_loop = fut.get_loop
|
|
except AttributeError:
|
|
pass
|
|
else:
|
|
return get_loop()
|
|
return fut._loop
|
|
|
|
|
|
def _convert_future_exc(exc: BaseException) -> BaseException:
|
|
exc_class = type(exc)
|
|
if exc_class is concurrent.futures.CancelledError:
|
|
return asyncio.CancelledError(*exc.args)
|
|
elif exc_class is concurrent.futures.TimeoutError:
|
|
return asyncio.TimeoutError(*exc.args)
|
|
elif exc_class is concurrent.futures.InvalidStateError:
|
|
return asyncio.InvalidStateError(*exc.args)
|
|
else:
|
|
return exc
|
|
|
|
|
|
def _set_concurrent_future_state(
|
|
concurrent: concurrent.futures.Future,
|
|
source: AnyFuture,
|
|
) -> None:
|
|
"""Copy state from a future to a concurrent.futures.Future."""
|
|
assert source.done()
|
|
if source.cancelled():
|
|
concurrent.cancel()
|
|
if not concurrent.set_running_or_notify_cancel():
|
|
return
|
|
exception = source.exception()
|
|
if exception is not None:
|
|
concurrent.set_exception(_convert_future_exc(exception))
|
|
else:
|
|
result = source.result()
|
|
concurrent.set_result(result)
|
|
|
|
|
|
def _copy_future_state(source: AnyFuture, dest: asyncio.Future) -> None:
|
|
"""Internal helper to copy state from another Future.
|
|
|
|
The other Future may be a concurrent.futures.Future.
|
|
"""
|
|
if dest.done():
|
|
return
|
|
assert source.done()
|
|
if dest.cancelled():
|
|
return
|
|
if source.cancelled():
|
|
dest.cancel()
|
|
else:
|
|
exception = source.exception()
|
|
if exception is not None:
|
|
dest.set_exception(_convert_future_exc(exception))
|
|
else:
|
|
result = source.result()
|
|
dest.set_result(result)
|
|
|
|
|
|
def _chain_future(source: AnyFuture, destination: AnyFuture) -> None:
|
|
"""Chain two futures so that when one completes, so does the other.
|
|
|
|
The result (or exception) of source will be copied to destination.
|
|
If destination is cancelled, source gets cancelled too.
|
|
Compatible with both asyncio.Future and concurrent.futures.Future.
|
|
"""
|
|
if not asyncio.isfuture(source) and not isinstance(
|
|
source, concurrent.futures.Future
|
|
):
|
|
raise TypeError("A future is required for source argument")
|
|
if not asyncio.isfuture(destination) and not isinstance(
|
|
destination, concurrent.futures.Future
|
|
):
|
|
raise TypeError("A future is required for destination argument")
|
|
source_loop = _get_loop(source) if asyncio.isfuture(source) else None
|
|
dest_loop = _get_loop(destination) if asyncio.isfuture(destination) else None
|
|
|
|
def _set_state(future: AnyFuture, other: AnyFuture) -> None:
|
|
if asyncio.isfuture(future):
|
|
_copy_future_state(other, future)
|
|
else:
|
|
_set_concurrent_future_state(future, other)
|
|
|
|
def _call_check_cancel(destination: AnyFuture) -> None:
|
|
if destination.cancelled():
|
|
if source_loop is None or source_loop is dest_loop:
|
|
source.cancel()
|
|
else:
|
|
source_loop.call_soon_threadsafe(source.cancel)
|
|
|
|
def _call_set_state(source: AnyFuture) -> None:
|
|
if destination.cancelled() and dest_loop is not None and dest_loop.is_closed():
|
|
return
|
|
if dest_loop is None or dest_loop is source_loop:
|
|
_set_state(destination, source)
|
|
else:
|
|
if dest_loop.is_closed():
|
|
return
|
|
dest_loop.call_soon_threadsafe(_set_state, destination, source)
|
|
|
|
destination.add_done_callback(_call_check_cancel)
|
|
source.add_done_callback(_call_set_state)
|
|
|
|
|
|
def chain_future(source: AnyFuture, destination: AnyFuture) -> AnyFuture:
    """Chain *source* to *destination* and return *destination*.

    Adapted from asyncio.run_coroutine_threadsafe. If setting up the chain
    fails, the failure is recorded on *destination* when it is a future that
    can still accept an outcome, and the original exception is re-raised.
    SystemExit and KeyboardInterrupt propagate untouched.

    Raises:
        Whatever ``_chain_future`` raises (e.g. TypeError when an argument
        is not a future).
    """
    try:
        _chain_future(source, destination)
        return destination
    except (SystemExit, KeyboardInterrupt):
        raise
    except BaseException as exc:
        if isinstance(destination, concurrent.futures.Future):
            # A future that already finished cannot take another outcome;
            # set_running_or_notify_cancel() would raise on it and hide `exc`.
            finished = destination.done() and not destination.cancelled()
            if not finished and destination.set_running_or_notify_cancel():
                destination.set_exception(exc)
        elif asyncio.isfuture(destination) and not destination.done():
            # Skipped for non-futures (the very reason _chain_future may have
            # failed) and for already-done futures, so the original error is
            # never replaced by AttributeError / InvalidStateError.
            destination.set_exception(exc)
        raise
|
|
|
|
|
|
def _ensure_future(
|
|
coro_or_future: Coroutine[None, None, T] | Awaitable[T],
|
|
*,
|
|
loop: asyncio.AbstractEventLoop,
|
|
name: str | None = None,
|
|
context: contextvars.Context | None = None,
|
|
lazy: bool = True,
|
|
) -> asyncio.Task[T]:
|
|
called_wrap_awaitable = False
|
|
if not asyncio.iscoroutine(coro_or_future):
|
|
if inspect.isawaitable(coro_or_future):
|
|
coro_or_future = cast(
|
|
Coroutine[None, None, T], _wrap_awaitable(coro_or_future)
|
|
)
|
|
called_wrap_awaitable = True
|
|
else:
|
|
raise TypeError(
|
|
"An asyncio.Future, a coroutine or an awaitable is required."
|
|
f" Got {type(coro_or_future).__name__} instead."
|
|
)
|
|
|
|
try:
|
|
if CONTEXT_NOT_SUPPORTED:
|
|
return loop.create_task(coro_or_future, name=name)
|
|
elif EAGER_NOT_SUPPORTED or lazy:
|
|
return loop.create_task(coro_or_future, name=name, context=context)
|
|
else:
|
|
return asyncio.eager_task_factory(
|
|
loop, coro_or_future, name=name, context=context
|
|
)
|
|
except RuntimeError:
|
|
if not called_wrap_awaitable:
|
|
coro_or_future.close()
|
|
raise
|
|
|
|
|
|
@types.coroutine
|
|
def _wrap_awaitable(awaitable: Awaitable[T]) -> Generator[None, None, T]:
|
|
"""Helper for asyncio.ensure_future().
|
|
|
|
Wraps awaitable (an object with __await__) into a coroutine
|
|
that will later be wrapped in a Task by ensure_future().
|
|
"""
|
|
return (yield from awaitable.__await__())
|
|
|
|
|
|
def run_coroutine_threadsafe(
    coro: Coroutine[None, None, T],
    loop: asyncio.AbstractEventLoop,
    *,
    lazy: bool,
    name: str | None = None,
    context: contextvars.Context | None = None,
) -> asyncio.Future[T]:
    """Submit a coroutine object to a given event loop.

    When *loop* is the loop currently running in this thread, the coroutine
    is wrapped in a task immediately and that task is returned. Otherwise a
    new future bound to *loop* is returned right away; a callback scheduled
    with ``loop.call_soon_threadsafe`` creates the task on the loop and
    chains the task to that future.

    Args:
        coro: the coroutine to run.
        loop: the event loop that should run it.
        lazy: forwarded to ``_ensure_future`` on the same-loop path.
        name: optional task name.
        context: optional context for the task and the scheduling callback.

    Returns:
        An asyncio.Future to access the result.
    """
    if asyncio._get_running_loop() is loop:
        return _ensure_future(coro, loop=loop, name=name, context=context, lazy=lazy)

    future: asyncio.Future[T] = asyncio.Future(loop=loop)

    def callback() -> None:
        try:
            chain_future(
                # NOTE(review): `lazy` is not forwarded on this path, so
                # _ensure_future's default (lazy=True) applies. Confirm this
                # is intended.
                _ensure_future(coro, loop=loop, name=name, context=context),
                future,
            )
        except (SystemExit, KeyboardInterrupt):
            raise
        except BaseException as exc:
            # The future may already be resolved (cancelled by the caller,
            # or already failed by chain_future); setting it again would
            # raise InvalidStateError and hide `exc`.
            if not future.done():
                future.set_exception(exc)
            raise

    loop.call_soon_threadsafe(callback, context=context)
    return future
|