770 lines
27 KiB
Python
770 lines
27 KiB
Python
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import concurrent.futures
|
|
import inspect
|
|
import threading
|
|
import time
|
|
import weakref
|
|
from collections.abc import (
|
|
AsyncIterator,
|
|
Awaitable,
|
|
Callable,
|
|
Iterable,
|
|
Iterator,
|
|
Sequence,
|
|
)
|
|
from functools import partial
|
|
from typing import (
|
|
Any,
|
|
Generic,
|
|
TypeVar,
|
|
cast,
|
|
)
|
|
|
|
from langchain_core.callbacks import Callbacks
|
|
|
|
from langgraph._internal._constants import (
|
|
CONF,
|
|
CONFIG_KEY_CALL,
|
|
CONFIG_KEY_SCRATCHPAD,
|
|
ERROR,
|
|
INTERRUPT,
|
|
NO_WRITES,
|
|
RESUME,
|
|
RETURN,
|
|
)
|
|
from langgraph._internal._future import chain_future, run_coroutine_threadsafe
|
|
from langgraph._internal._scratchpad import PregelScratchpad
|
|
from langgraph._internal._typing import MISSING
|
|
from langgraph.constants import TAG_HIDDEN
|
|
from langgraph.errors import GraphBubbleUp, GraphInterrupt
|
|
from langgraph.pregel._algo import Call
|
|
from langgraph.pregel._executor import Submit
|
|
from langgraph.pregel._retry import arun_with_retry, run_with_retry
|
|
from langgraph.types import (
|
|
CachePolicy,
|
|
PregelExecutableTask,
|
|
RetryPolicy,
|
|
)
|
|
|
|
# Constrained type variables that let FuturesDict hold either thread-pool
# futures with a threading.Event, or asyncio futures with an asyncio.Event.
F = TypeVar("F", concurrent.futures.Future, asyncio.Future)
E = TypeVar("E", threading.Event, asyncio.Event)


# Filename suffixes whose frames are trimmed from the traceback of exceptions
# re-raised by PregelRunner. Starting from the outermost frame, frames are
# dropped one at a time while their filename ends with one of these suffixes;
# the innermost frame is always kept.
# Both the current private module paths (``pregel/_retry.py`` etc., which this
# module imports from) and the older public paths are listed so trimming works
# regardless of which package layout is installed.
EXCLUDED_FRAME_FNAMES = (
    "langgraph/pregel/retry.py",
    "langgraph/pregel/runner.py",
    "langgraph/pregel/executor.py",
    "langgraph/utils/runnable.py",
    "langgraph/pregel/_retry.py",
    "langgraph/pregel/_runner.py",
    "langgraph/pregel/_executor.py",
    # NOTE(review): presumed new home of utils/runnable.py - confirm
    "langgraph/_internal/_runnable.py",
    "langchain_core/runnables/config.py",
    "concurrent/futures/thread.py",
    "concurrent/futures/_base.py",
)


# Futures of call() child tasks. Their errors are delivered to the parent task,
# so they are excluded from "stop other tasks" checks and from re-raising at the
# end of a tick. Held weakly so finished futures can be garbage collected.
SKIP_RERAISE_SET: weakref.WeakSet[concurrent.futures.Future | asyncio.Future] = (
    weakref.WeakSet()
)
|
|
|
|
|
|
class FuturesDict(Generic[F, E], dict[F, PregelExecutableTask | None]):
    """Map each in-flight future to the Pregel task it runs.

    Entries stored with a ``None`` value (e.g. a waiter future) are kept in the
    mapping but not tracked. Every other entry is counted as pending. When its
    future resolves, the outcome is forwarded to ``callback`` and the future is
    recorded in ``done``. ``event`` is set once no tracked future is pending,
    or as soon as ``_should_stop_others`` reports a failure among the finished
    futures.
    """

    # signalled when tracked futures have all settled or one has failed
    event: E
    # weak reference to the (task, exception-or-None) completion callback
    callback: weakref.ref[Callable[[PregelExecutableTask, BaseException | None], None]]
    # how many tracked futures have not resolved yet
    counter: int
    # tracked futures that have resolved
    done: set[F]
    # guards counter, done and event
    lock: threading.Lock

    def __init__(
        self,
        event: E,
        callback: weakref.ref[
            Callable[[PregelExecutableTask, BaseException | None], None]
        ],
        future_type: type[F],
        # only used for generic typing; newer Pythons support FutureDict[...](...)
    ) -> None:
        super().__init__()
        self.lock = threading.Lock()
        self.event = event
        self.callback = callback
        self.counter = 0
        self.done: set[F] = set()

    def __setitem__(
        self,
        key: F,
        value: PregelExecutableTask | None,
    ) -> None:
        super().__setitem__(key, value)  # type: ignore[index]
        if value is None:
            # untracked entry: no counting, no completion callback
            return
        with self.lock:
            self.event.clear()
            self.counter += 1
        # Registered outside the lock: a concurrent future that is already done
        # runs the callback immediately, and on_done takes the same lock.
        key.add_done_callback(partial(self.on_done, value))

    def on_done(
        self,
        task: PregelExecutableTask,
        fut: F,
    ) -> None:
        """Forward ``fut``'s outcome for ``task``, then update the bookkeeping."""
        try:
            handler = self.callback()
            if handler is not None:
                handler(task, _exception(fut))
        finally:
            # bookkeeping must happen even if the callback raised
            with self.lock:
                self.done.add(fut)
                self.counter -= 1
                all_settled = self.counter == 0
                if all_settled or _should_stop_others(self.done):
                    self.event.set()
|
|
|
|
|
|
def _strip_excluded_frames(exc: BaseException) -> None:
    """Drop leading internal frames from ``exc``'s traceback, in place.

    Starting from the outermost frame, frames whose filename ends with an
    entry of ``EXCLUDED_FRAME_FNAMES`` are skipped. The last (innermost) frame
    is always kept. Does nothing if ``exc`` has no traceback.
    """
    tb = exc.__traceback__
    if tb is None:
        return
    while tb.tb_next is not None and any(
        tb.tb_frame.f_code.co_filename.endswith(name)
        for name in EXCLUDED_FRAME_FNAMES
    ):
        tb = tb.tb_next
    exc.__traceback__ = tb


class PregelRunner:
    """Responsible for executing a set of Pregel tasks concurrently, committing
    their writes, yielding control to caller when there is output to emit, and
    interrupting other tasks if appropriate."""

    def __init__(
        self,
        *,
        submit: weakref.ref[Submit],
        put_writes: weakref.ref[Callable[[str, Sequence[tuple[str, Any]]], None]],
        use_astream: bool = False,
        node_finished: Callable[[str], None] | None = None,
    ) -> None:
        """Store the weakly-referenced executor hooks.

        Args:
            submit: weak reference to the executor's submit callable.
            put_writes: weak reference to the callable that persists a task's
                writes, invoked as ``put_writes()(task_id, writes)``.
            use_astream: forwarded as ``stream=`` to ``arun_with_retry`` and
                ``_acall`` in ``atick``.
            node_finished: optional callback invoked with the task name when a
                task that is not tagged hidden completes successfully.
        """
        self.submit = submit
        self.put_writes = put_writes
        self.use_astream = use_astream
        self.node_finished = node_finished

    def tick(
        self,
        tasks: Iterable[PregelExecutableTask],
        *,
        reraise: bool = True,
        timeout: float | None = None,
        retry_policy: Sequence[RetryPolicy] | None = None,
        get_waiter: Callable[[], concurrent.futures.Future[None]] | None = None,
        schedule_task: Callable[
            [PregelExecutableTask, int, Call | None],
            PregelExecutableTask | None,
        ],
    ) -> Iterator[None]:
        """Run ``tasks`` with the synchronous executor, yielding between steps.

        A single task with no timeout and no waiter runs inline in the calling
        thread. Otherwise every task is submitted through ``self.submit`` and
        awaited with ``concurrent.futures.wait``. Each finished task is passed
        to ``commit``.

        Args:
            tasks: the tasks of this super-step.
            reraise: re-raise task failures (after trimming internal frames)
                instead of only committing them.
            timeout: seconds to wait for the tasks; ``None`` or ``0`` means
                no limit.
            retry_policy: forwarded to ``run_with_retry`` and ``_call``.
            get_waiter: optional factory for an extra future stored with a
                ``None`` task. It is replaced each time it fires while tasks
                are still in flight.
            schedule_task: forwarded to ``_call`` to create child tasks.

        Yields:
            ``None`` at points where the caller may emit output.

        Raises:
            The first task error (when ``reraise``), a combined
            ``GraphInterrupt``, or ``TimeoutError`` if tasks are still in
            flight when the timeout elapses.
        """
        tasks = tuple(tasks)
        futures = FuturesDict(
            callback=weakref.WeakMethod(self.commit),
            event=threading.Event(),
            future_type=concurrent.futures.Future,
        )
        # give control back to the caller
        yield
        # fast path if single task with no timeout and no waiter
        if len(tasks) == 0:
            return
        elif len(tasks) == 1 and timeout is None and get_waiter is None:
            t = tasks[0]
            try:
                run_with_retry(
                    t,
                    retry_policy,
                    configurable={
                        CONFIG_KEY_CALL: partial(
                            _call,
                            weakref.ref(t),
                            retry_policy=retry_policy,
                            futures=weakref.ref(futures),
                            schedule_task=schedule_task,
                            submit=self.submit,
                        ),
                    },
                )
                self.commit(t, None)
            except Exception as exc:
                self.commit(t, exc)
                if reraise and futures:
                    # `t` scheduled child tasks: defer the error until they
                    # are done, _panic_or_proceed re-raises it from `done`
                    fut: concurrent.futures.Future = concurrent.futures.Future()
                    fut.set_exception(exc)
                    futures.done.add(fut)
                elif reraise:
                    _strip_excluded_frames(exc)
                    raise
            if not futures:  # maybe `t` scheduled another task
                return
            else:
                tasks = ()  # don't reschedule this task
        # add waiter task if requested
        if get_waiter is not None:
            futures[get_waiter()] = None
        # schedule tasks
        for t in tasks:
            fut = self.submit()(  # type: ignore[misc]
                run_with_retry,
                t,
                retry_policy,
                configurable={
                    CONFIG_KEY_CALL: partial(
                        _call,
                        weakref.ref(t),
                        retry_policy=retry_policy,
                        futures=weakref.ref(futures),
                        schedule_task=schedule_task,
                        submit=self.submit,
                    ),
                },
                __reraise_on_exit__=reraise,
            )
            futures[fut] = t
        # execute tasks, and wait for one to fail or all to finish.
        # each task is independent from all other concurrent tasks
        # yield updates/debug output as each task finishes
        end_time = timeout + time.monotonic() if timeout else None
        while len(futures) > (1 if get_waiter is not None else 0):
            done, inflight = concurrent.futures.wait(
                futures,
                return_when=concurrent.futures.FIRST_COMPLETED,
                timeout=(max(0, end_time - time.monotonic()) if end_time else None),
            )
            if not done:
                break  # timed out
            for fut in done:
                task = futures.pop(fut)
                if task is None:
                    # waiter task finished, schedule another
                    if inflight and get_waiter is not None:
                        futures[get_waiter()] = None
            else:
                # remove references to loop vars
                del fut, task
            # maybe stop other tasks
            if _should_stop_others(done):
                break
            # give control back to the caller
            yield
        # wait for done callbacks (returns without raising on timeout)
        futures.event.wait(
            timeout=(max(0, end_time - time.monotonic()) if end_time else None)
        )
        # give control back to the caller
        yield
        # panic on failure or timeout
        try:
            _panic_or_proceed(
                futures.done.union(f for f, t in futures.items() if t is not None),
                panic=reraise,
            )
        except Exception as exc:
            _strip_excluded_frames(exc)
            raise

    async def atick(
        self,
        tasks: Iterable[PregelExecutableTask],
        *,
        reraise: bool = True,
        timeout: float | None = None,
        retry_policy: Sequence[RetryPolicy] | None = None,
        get_waiter: Callable[[], asyncio.Future[None]] | None = None,
        schedule_task: Callable[
            [PregelExecutableTask, int, Call | None],
            Awaitable[PregelExecutableTask | None],
        ],
    ) -> AsyncIterator[None]:
        """Async counterpart of ``tick``: run ``tasks`` on the event loop.

        Follows the same steps as ``tick``, using ``arun_with_retry``,
        ``_acall`` and ``asyncio.wait``. Waiting for the done callbacks never
        raises on timeout; a timeout is reported by ``_panic_or_proceed``
        (as ``asyncio.TimeoutError``) only after the remaining futures have
        been cancelled. A task error or interrupt takes precedence over the
        timeout.

        Yields:
            ``None`` at points where the caller may emit output.
        """
        try:
            loop = asyncio.get_event_loop()
        except RuntimeError:
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
        tasks = tuple(tasks)
        futures = FuturesDict(
            callback=weakref.WeakMethod(self.commit),
            event=asyncio.Event(),
            future_type=asyncio.Future,
        )
        # give control back to the caller
        yield
        # fast path if single task with no waiter and no timeout
        if len(tasks) == 0:
            return
        elif len(tasks) == 1 and get_waiter is None and timeout is None:
            t = tasks[0]
            try:
                await arun_with_retry(
                    t,
                    retry_policy,
                    stream=self.use_astream,
                    configurable={
                        CONFIG_KEY_CALL: partial(
                            _acall,
                            weakref.ref(t),
                            stream=self.use_astream,
                            retry_policy=retry_policy,
                            futures=weakref.ref(futures),
                            schedule_task=schedule_task,
                            submit=self.submit,
                            loop=loop,
                        ),
                    },
                )
                self.commit(t, None)
            except Exception as exc:
                self.commit(t, exc)
                if reraise and futures:
                    # `t` scheduled child tasks: defer the error until they
                    # are done, _panic_or_proceed re-raises it from `done`
                    fut: asyncio.Future = loop.create_future()
                    fut.set_exception(exc)
                    futures.done.add(fut)
                elif reraise:
                    _strip_excluded_frames(exc)
                    raise
            if not futures:  # maybe `t` scheduled another task
                return
            else:
                tasks = ()  # don't reschedule this task
        # add waiter task if requested
        if get_waiter is not None:
            futures[get_waiter()] = None
        # schedule tasks
        for t in tasks:
            fut = cast(
                asyncio.Future,
                self.submit()(  # type: ignore[misc]
                    arun_with_retry,
                    t,
                    retry_policy,
                    stream=self.use_astream,
                    configurable={
                        CONFIG_KEY_CALL: partial(
                            _acall,
                            weakref.ref(t),
                            retry_policy=retry_policy,
                            stream=self.use_astream,
                            futures=weakref.ref(futures),
                            schedule_task=schedule_task,
                            submit=self.submit,
                            loop=loop,
                        ),
                    },
                    __name__=t.name,
                    __cancel_on_exit__=True,
                    __reraise_on_exit__=reraise,
                ),
            )
            futures[fut] = t
        # execute tasks, and wait for one to fail or all to finish.
        # each task is independent from all other concurrent tasks
        # yield updates/debug output as each task finishes
        end_time = timeout + loop.time() if timeout else None
        while len(futures) > (1 if get_waiter is not None else 0):
            done, inflight = await asyncio.wait(
                futures,
                return_when=asyncio.FIRST_COMPLETED,
                timeout=(max(0, end_time - loop.time()) if end_time else None),
            )
            if not done:
                break  # timed out
            for fut in done:
                task = futures.pop(fut)
                if task is None:
                    # waiter task finished, schedule another
                    if inflight and get_waiter is not None:
                        futures[get_waiter()] = None
            else:
                # remove references to loop vars
                del fut, task
            # maybe stop other tasks
            if _should_stop_others(done):
                break
            # give control back to the caller
            yield
        # wait for done callbacks. Unlike asyncio.wait_for, asyncio.wait does
        # not raise on timeout, matching threading.Event.wait in tick(): the
        # cleanup below and _panic_or_proceed must still run on timeout.
        callbacks_settled = asyncio.ensure_future(futures.event.wait())
        try:
            await asyncio.wait(
                [callbacks_settled],
                timeout=(max(0, end_time - loop.time()) if end_time else None),
            )
        finally:
            if not callbacks_settled.done():
                callbacks_settled.cancel()
        # give control back to the caller
        yield
        # cancel waiter task
        for fut in futures:
            fut.cancel()
        # panic on failure or timeout
        try:
            _panic_or_proceed(
                futures.done.union(f for f, t in futures.items() if t is not None),
                timeout_exc_cls=asyncio.TimeoutError,
                panic=reraise,
            )
        except Exception as exc:
            _strip_excluded_frames(exc)
            raise

    def commit(
        self,
        task: PregelExecutableTask,
        exception: BaseException | None,
    ) -> None:
        """Persist the outcome of a finished ``task`` through ``put_writes``.

        - ``asyncio.CancelledError``: append ``(ERROR, exc)`` and save writes.
        - ``GraphInterrupt`` carrying interrupts: save an ``INTERRUPT`` write
          plus any ``RESUME`` writes already recorded on the task.
        - other ``GraphBubbleUp``: nothing is saved here.
        - any other exception: append ``(ERROR, exc)`` and save writes.
        - success: notify ``node_finished`` unless the task is tagged hidden,
          add a ``NO_WRITES`` marker if the task wrote nothing, save writes.
        """
        if isinstance(exception, asyncio.CancelledError):
            # for cancelled tasks, also save error in task,
            # so loop can finish super-step
            task.writes.append((ERROR, exception))
            self.put_writes()(task.id, task.writes)  # type: ignore[misc]
        elif exception:
            if isinstance(exception, GraphInterrupt):
                # save interrupt to checkpointer
                if exception.args[0]:
                    writes = [(INTERRUPT, exception.args[0])]
                    if resumes := [w for w in task.writes if w[0] == RESUME]:
                        writes.extend(resumes)
                    self.put_writes()(task.id, writes)  # type: ignore[misc]
            elif isinstance(exception, GraphBubbleUp):
                # exception will be raised in _panic_or_proceed
                pass
            else:
                # save error to checkpointer
                task.writes.append((ERROR, exception))
                self.put_writes()(task.id, task.writes)  # type: ignore[misc]
        else:
            if self.node_finished and (
                task.config is None or TAG_HIDDEN not in task.config.get("tags", [])
            ):
                self.node_finished(task.name)
            if not task.writes:
                # add no writes marker
                task.writes.append((NO_WRITES, None))
            # save task writes to checkpointer
            self.put_writes()(task.id, task.writes)  # type: ignore[misc]
|
|
|
|
|
|
def _should_stop_others(
    done: set[F],
) -> bool:
    """Report whether a finished future failed in a way that should stop the rest.

    Cancelled futures, successful futures, ``GraphBubbleUp`` errors (such as
    interrupts) and futures listed in ``SKIP_RERAISE_SET`` do not count.
    """
    for candidate in done:
        if candidate.cancelled():
            continue
        error = candidate.exception()
        if not error:
            continue
        if isinstance(error, GraphBubbleUp):
            continue
        if candidate in SKIP_RERAISE_SET:
            continue
        return True
    return False
|
|
|
|
|
|
def _exception(
    fut: concurrent.futures.Future[Any] | asyncio.Future[Any],
) -> BaseException | None:
    """Return ``fut``'s exception (``None`` on success) without raising.

    ``Future.exception()`` raises for a cancelled future, so in that case a
    fresh ``CancelledError`` matching the future's flavour is returned instead.
    """
    if not fut.cancelled():
        return fut.exception()
    error_cls = (
        asyncio.CancelledError
        if isinstance(fut, asyncio.Future)
        else concurrent.futures.CancelledError
    )
    return error_cls()
|
|
|
|
|
|
def _panic_or_proceed(
    futs: set[concurrent.futures.Future] | set[asyncio.Future],
    *,
    timeout_exc_cls: type[Exception] = TimeoutError,
    panic: bool = True,
) -> None:
    """Settle a tick: cancel leftovers and raise failures, interrupts or timeouts.

    Cancelled futures are ignored. If any finished future holds an exception,
    every still-pending future is cancelled. With ``panic`` set:

    - the first non-interrupt error not in ``SKIP_RERAISE_SET`` is raised;
    - otherwise, any ``GraphInterrupt`` errors are merged into one.

    If no finished future failed but some are still pending, they are
    cancelled and ``timeout_exc_cls("Timed out")`` is raised.
    """
    finished: set[concurrent.futures.Future[Any] | asyncio.Future[Any]] = set()
    pending: set[concurrent.futures.Future[Any] | asyncio.Future[Any]] = set()
    # classify every future exactly once, its state may change concurrently
    for candidate in futs:
        if candidate.cancelled():
            continue
        if candidate.done():
            finished.add(candidate)
        else:
            pending.add(candidate)

    def cancel_pending() -> None:
        while pending:
            pending.pop().cancel()

    collected: list[GraphInterrupt] = []
    while finished:
        candidate = finished.pop()
        error = _exception(candidate)
        if not error:
            continue
        # a failure stops everything that is still running
        cancel_pending()
        if not panic:
            continue
        if isinstance(error, GraphInterrupt):
            collected.append(error)
        elif candidate not in SKIP_RERAISE_SET:
            raise error
    if collected:
        raise GraphInterrupt(
            tuple(item for interrupt in collected for item in interrupt.args[0])
        )
    if pending:
        # still-pending futures without any failure means the tick timed out
        cancel_pending()
        raise timeout_exc_cls("Timed out")
|
|
|
|
|
|
def _call(
    task: weakref.ref[PregelExecutableTask],
    func: Callable[[Any], Awaitable[Any] | Any],
    input: Any,
    *,
    retry_policy: Sequence[RetryPolicy] | None = None,
    cache_policy: CachePolicy | None = None,
    callbacks: Callbacks = None,
    futures: weakref.ref[FuturesDict],
    schedule_task: Callable[
        [PregelExecutableTask, int, Call | None], PregelExecutableTask | None
    ],
    submit: weakref.ref[Submit],
) -> concurrent.futures.Future[Any]:
    """Run ``func(input)`` as a child task of ``task`` on the sync runner.

    ``schedule_task`` turns the call into a child task. Three cases:

    - A future for a task with the same id is already tracked in ``futures``
      (e.g. the parent was retried): that future is reused.
    - The child already has writes: a completed future is built from its
      ``RETURN`` or ``ERROR`` write.
    - Otherwise the child is submitted through ``submit``. Its future is
      registered in ``futures`` and ``SKIP_RERAISE_SET``, since its errors are
      raised into the parent.

    Returns:
        A new future chained to the child's future, so the child's done
        callbacks (including ``commit``) run before it resolves. If
        ``schedule_task`` returns no task, the returned future fails with
        ``RuntimeError("Task not scheduled")``, matching ``_acall_impl``.

    Raises:
        RuntimeError: if ``func`` is a coroutine function.
    """
    if inspect.iscoroutinefunction(func):
        raise RuntimeError("In an sync context async tasks cannot be called")

    fut: concurrent.futures.Future | None = None
    # hold a strong reference to the parent for the whole call
    parent = task()
    # schedule PUSH tasks, collect futures
    scratchpad: PregelScratchpad = parent.config[CONF][CONFIG_KEY_SCRATCHPAD]  # type: ignore[union-attr]
    # schedule the next task, if the callback returns one
    if next_task := schedule_task(
        parent,  # type: ignore[arg-type]
        scratchpad.call_counter(),
        Call(
            func,
            input,
            retry_policy=retry_policy,
            cache_policy=cache_policy,
            callbacks=callbacks,
        ),
    ):
        if fut := next(
            (
                f
                for f, t in futures().items()  # type: ignore[union-attr]
                # dict values are tasks, so compare their ids
                if t is not None and t.id == next_task.id
            ),
            None,
        ):
            # if the parent task was retried,
            # the next task might already be running
            pass
        elif next_task.writes:
            # if it already ran, return the result
            fut = concurrent.futures.Future()
            ret = next((v for c, v in next_task.writes if c == RETURN), MISSING)
            if ret is not MISSING:
                fut.set_result(ret)
            elif exc := next((v for c, v in next_task.writes if c == ERROR), None):
                fut.set_exception(
                    exc if isinstance(exc, BaseException) else Exception(exc)
                )
            else:
                fut.set_result(None)
        else:
            # schedule the next task
            fut = submit()(  # type: ignore[misc]
                run_with_retry,
                next_task,
                retry_policy,
                configurable={
                    CONFIG_KEY_CALL: partial(
                        _call,
                        weakref.ref(next_task),
                        futures=futures,
                        retry_policy=retry_policy,
                        callbacks=callbacks,
                        schedule_task=schedule_task,
                        submit=submit,
                    ),
                },
                __reraise_on_exit__=False,
                # starting a new task in the next tick ensures
                # updates from this tick are committed/streamed first
                __next_tick__=True,
            )
            # exceptions for call() tasks are raised into the parent task
            # so we should not re-raise at the end of the tick
            SKIP_RERAISE_SET.add(fut)
            futures()[fut] = next_task  # type: ignore[index]
    if fut is None:
        # nothing was scheduled: report it through the returned future
        failed: concurrent.futures.Future[Any] = concurrent.futures.Future()
        failed.set_exception(RuntimeError("Task not scheduled"))
        return failed
    # return a chained future to ensure commit() callback is called
    # before the returned future is resolved, to ensure stream order etc
    return chain_future(fut, concurrent.futures.Future())
|
|
|
|
|
|
def _acall(
    task: weakref.ref[PregelExecutableTask],
    func: Callable[[Any], Awaitable[Any] | Any],
    input: Any,
    *,
    retry_policy: Sequence[RetryPolicy] | None = None,
    cache_policy: CachePolicy | None = None,
    callbacks: Callbacks = None,
    # injected dependencies
    futures: weakref.ref[FuturesDict],
    schedule_task: Callable[
        [PregelExecutableTask, int, Call | None],
        Awaitable[PregelExecutableTask | None],
    ],
    submit: weakref.ref[Submit],
    loop: asyncio.AbstractEventLoop,
    stream: bool = False,
) -> asyncio.Future[Any] | concurrent.futures.Future[Any]:
    """Run ``func(input)`` as a child task of ``task`` on the async runner.

    The scheduling itself happens in ``_acall_impl`` on ``loop`` (submitted
    with ``run_coroutine_threadsafe(..., lazy=False)``). This function only
    creates and returns the future that ``_acall_impl`` resolves:

    - inside an asyncio task: an ``asyncio.Future`` bound to ``loop``;
    - otherwise: a ``concurrent.futures.Future`` that can be waited on
      synchronously.
    """
    try:
        running = asyncio.current_task()
    except RuntimeError:
        # no running event loop in this thread
        running = None

    result: asyncio.Future[Any] | concurrent.futures.Future[Any]
    if running is None:
        result = concurrent.futures.Future()
    else:
        result = asyncio.Future(loop=loop)

    scheduler = _acall_impl(
        result,
        task,
        func,
        input,
        retry_policy=retry_policy,
        cache_policy=cache_policy,
        callbacks=callbacks,
        futures=futures,
        schedule_task=schedule_task,
        submit=submit,
        loop=loop,
        stream=stream,
    )
    run_coroutine_threadsafe(scheduler, loop, lazy=False)
    return result
|
|
|
|
|
|
async def _acall_impl(
    destination: asyncio.Future[Any] | concurrent.futures.Future[Any],
    task: weakref.ref[PregelExecutableTask],
    func: Callable[[Any], Awaitable[Any] | Any],
    input: Any,
    *,
    retry_policy: Sequence[RetryPolicy] | None = None,
    cache_policy: CachePolicy | None = None,
    callbacks: Callbacks = None,
    # injected dependencies
    futures: weakref.ref[FuturesDict[asyncio.Future, asyncio.Event]],
    schedule_task: Callable[
        [PregelExecutableTask, int, Call | None],
        Awaitable[PregelExecutableTask | None],
    ],
    submit: weakref.ref[Submit],
    loop: asyncio.AbstractEventLoop,
    stream: bool = False,
) -> None:
    """Schedule ``func(input)`` as a child of ``task`` and resolve ``destination``.

    ``schedule_task`` is awaited to obtain the child task. Three cases:

    - A tracked future for a task with the same id is reused.
    - A child that already has writes gets a completed future built from its
      ``RETURN`` or ``ERROR`` write, which is also registered in ``futures``.
    - Otherwise the child is submitted through ``submit`` and registered in
      ``futures`` and ``SKIP_RERAISE_SET``.

    The child's future is then chained into ``destination``. If no task was
    scheduled, ``destination`` fails with ``RuntimeError("Task not
    scheduled")``. Any ``Exception`` raised along the way is set on
    ``destination`` instead of propagating.
    """
    try:
        fut: asyncio.Future | None = None
        # hold a strong reference to the parent for the whole call
        parent = task()
        # schedule PUSH tasks, collect futures
        scratchpad: PregelScratchpad = parent.config[CONF][CONFIG_KEY_SCRATCHPAD]  # type: ignore[union-attr]
        # schedule the next task, if the callback returns one
        if next_task := await schedule_task(
            parent,  # type: ignore[arg-type]
            scratchpad.call_counter(),
            Call(
                func,
                input,
                retry_policy=retry_policy,
                cache_policy=cache_policy,
                callbacks=callbacks,
            ),
        ):
            if fut := next(
                (
                    f
                    for f, t in futures().items()  # type: ignore[union-attr]
                    # dict values are tasks, so compare their ids
                    if t is not None and t.id == next_task.id
                ),
                None,
            ):
                # if the parent task was retried,
                # the next task might already be running
                pass
            elif next_task.writes:
                # if it already ran, return the result
                fut = asyncio.Future(loop=loop)
                ret = next((v for c, v in next_task.writes if c == RETURN), MISSING)
                if ret is not MISSING:
                    fut.set_result(ret)
                elif exc := next((v for c, v in next_task.writes if c == ERROR), None):
                    fut.set_exception(
                        exc if isinstance(exc, BaseException) else Exception(exc)
                    )
                else:
                    fut.set_result(None)
                futures()[fut] = next_task  # type: ignore[index]
            else:
                # schedule the next task
                fut = cast(
                    asyncio.Future,
                    submit()(  # type: ignore[misc]
                        arun_with_retry,
                        next_task,
                        retry_policy,
                        stream=stream,
                        configurable={
                            CONFIG_KEY_CALL: partial(
                                _acall,
                                weakref.ref(next_task),
                                stream=stream,
                                futures=futures,
                                schedule_task=schedule_task,
                                submit=submit,
                                loop=loop,
                            ),
                        },
                        __name__=next_task.name,
                        __cancel_on_exit__=True,
                        __reraise_on_exit__=False,
                        # starting a new task in the next tick ensures
                        # updates from this tick are committed/streamed first
                        __next_tick__=True,
                    ),
                )
                # exceptions for call() tasks are raised into the parent task
                # so we should not re-raise at the end of the tick
                SKIP_RERAISE_SET.add(fut)
                futures()[fut] = next_task  # type: ignore[index]
        if fut is not None:
            chain_future(fut, destination)
        else:
            destination.set_exception(RuntimeError("Task not scheduled"))
    except Exception as exc:
        destination.set_exception(exc)
|