564 lines
20 KiB
Python
564 lines
20 KiB
Python
|
|
from __future__ import annotations
|
||
|
|
|
||
|
|
import functools
|
||
|
|
import inspect
|
||
|
|
import warnings
|
||
|
|
from collections.abc import Awaitable, Callable, Sequence
|
||
|
|
from dataclasses import dataclass
|
||
|
|
from typing import (
|
||
|
|
Any,
|
||
|
|
Generic,
|
||
|
|
TypeVar,
|
||
|
|
cast,
|
||
|
|
get_args,
|
||
|
|
get_origin,
|
||
|
|
overload,
|
||
|
|
)
|
||
|
|
|
||
|
|
from langgraph.cache.base import BaseCache
|
||
|
|
from langgraph.checkpoint.base import BaseCheckpointSaver
|
||
|
|
from langgraph.store.base import BaseStore
|
||
|
|
from typing_extensions import Unpack
|
||
|
|
|
||
|
|
from langgraph._internal._constants import CACHE_NS_WRITES, PREVIOUS
|
||
|
|
from langgraph._internal._typing import MISSING, DeprecatedKwargs
|
||
|
|
from langgraph.channels.ephemeral_value import EphemeralValue
|
||
|
|
from langgraph.channels.last_value import LastValue
|
||
|
|
from langgraph.constants import END, START
|
||
|
|
from langgraph.pregel import Pregel
|
||
|
|
from langgraph.pregel._call import (
|
||
|
|
P,
|
||
|
|
SyncAsyncFuture,
|
||
|
|
T,
|
||
|
|
call,
|
||
|
|
get_runnable_for_entrypoint,
|
||
|
|
identifier,
|
||
|
|
)
|
||
|
|
from langgraph.pregel._read import PregelNode
|
||
|
|
from langgraph.pregel._write import ChannelWrite, ChannelWriteEntry
|
||
|
|
from langgraph.types import _DC_KWARGS, CachePolicy, RetryPolicy, StreamMode
|
||
|
|
from langgraph.typing import ContextT
|
||
|
|
from langgraph.warnings import LangGraphDeprecatedSinceV05, LangGraphDeprecatedSinceV10
|
||
|
|
|
||
|
|
# Names exported by `from langgraph.func import *`: the two public decorators.
__all__ = ("task", "entrypoint")
|
||
|
|
|
||
|
|
|
||
|
|
class _TaskFunction(Generic[P, T]):
    """Callable wrapper produced by the `task` decorator.

    Invoking an instance hands the wrapped function, together with this task's
    retry and cache policies, to `call(...)` and returns its `SyncAsyncFuture`.
    """

    def __init__(
        self,
        func: Callable[P, Awaitable[T]] | Callable[P, T],
        *,
        retry_policy: Sequence[RetryPolicy],
        cache_policy: CachePolicy[Callable[P, str | bytes]] | None = None,
        name: str | None = None,
    ) -> None:
        """Store the function and policies, optionally renaming the function.

        Args:
            func: The sync or async callable to wrap.
            retry_policy: Retry policies applied when the task is called.
            cache_policy: Optional cache policy applied when the task is called.
            name: Optional replacement for the wrapped callable's `__name__`.
        """
        if name is not None:
            if not hasattr(func, "__func__"):
                # Regular functions / partials / callable instances are renamed
                # in place.
                # NOTE(review): this mutates the caller's object, so reusing one
                # function for several differently-named tasks leaves all of them
                # with the last name assigned — confirm this is intended.
                func.__name__ = name
            else:
                # Bound (instance or class) methods: rebind through a fresh
                # partial so the shared underlying function keeps its own name.
                renamed = functools.partial(func.__func__, func.__self__)  # type: ignore [union-attr]
                renamed.__name__ = name  # type: ignore [attr-defined]
                func = renamed
        self.func = func
        self.retry_policy = retry_policy
        self.cache_policy = cache_policy
        # Copy __name__, __doc__, __wrapped__, etc. from the (possibly renamed)
        # function onto this wrapper; runs last, as in the original ordering.
        functools.update_wrapper(self, func)

    def __call__(self, *args: P.args, **kwargs: P.kwargs) -> SyncAsyncFuture[T]:
        """Schedule the wrapped function and return a future for its result."""
        return call(
            self.func,
            *args,
            retry_policy=self.retry_policy,
            cache_policy=self.cache_policy,
            **kwargs,
        )

    def clear_cache(self, cache: BaseCache) -> None:
        """Clear the cache for this task.

        Does nothing when the task was created without a cache policy.
        """
        if self.cache_policy is None:
            return
        namespace = (CACHE_NS_WRITES, identifier(self.func) or "__dynamic__")
        cache.clear((namespace,))

    async def aclear_cache(self, cache: BaseCache) -> None:
        """Clear the cache for this task (async variant of `clear_cache`).

        Does nothing when the task was created without a cache policy.
        """
        if self.cache_policy is None:
            return
        namespace = (CACHE_NS_WRITES, identifier(self.func) or "__dynamic__")
        await cache.aclear((namespace,))
|
||
|
|
|
||
|
|
|
||
|
|
# Overload: `@task(...)` called with keyword options only — returns a decorator
# that wraps the function it is later applied to.
@overload
def task(
    __func_or_none__: None = None,
    *,
    name: str | None = None,
    retry_policy: RetryPolicy | Sequence[RetryPolicy] | None = None,
    cache_policy: CachePolicy[Callable[P, str | bytes]] | None = None,
    **kwargs: Unpack[DeprecatedKwargs],
) -> Callable[
    [Callable[P, Awaitable[T]] | Callable[P, T]],
    _TaskFunction[P, T],
]: ...


# Overload: bare `@task` applied directly to an async function.
@overload
def task(__func_or_none__: Callable[P, Awaitable[T]]) -> _TaskFunction[P, T]: ...


# Overload: bare `@task` applied directly to a sync function.
@overload
def task(__func_or_none__: Callable[P, T]) -> _TaskFunction[P, T]: ...
|
||
|
|
|
||
|
|
|
||
|
|
def task(
    __func_or_none__: Callable[P, Awaitable[T]] | Callable[P, T] | None = None,
    *,
    name: str | None = None,
    retry_policy: RetryPolicy | Sequence[RetryPolicy] | None = None,
    cache_policy: CachePolicy[Callable[P, str | bytes]] | None = None,
    **kwargs: Unpack[DeprecatedKwargs],
) -> (
    Callable[[Callable[P, Awaitable[T]] | Callable[P, T]], _TaskFunction[P, T]]
    | _TaskFunction[P, T]
):
    """Define a LangGraph task using the `task` decorator.

    !!! important "Requires python 3.11 or higher for async functions"
        The `task` decorator supports both sync and async functions. To use async
        functions, ensure that you are using Python 3.11 or higher.

    Tasks can only be called from within an [`entrypoint`][langgraph.func.entrypoint] or
    from within a `StateGraph`. A task can be called like a regular function with the
    following differences:

    - When a checkpointer is enabled, the function inputs and outputs must be serializable.
    - The decorated function can only be called from within an entrypoint or `StateGraph`.
    - Calling the function produces a future. This makes it easy to parallelize tasks.

    Args:
        name: An optional name for the task. If not provided, the function name will be used.
        retry_policy: An optional retry policy (or list of policies) to use for the task in case of a failure.
        cache_policy: An optional cache policy to use for the task. This allows caching of the task results.
        **kwargs: Deprecated keyword arguments. `retry` is accepted as an alias for
            `retry_policy` (used only when `retry_policy` is not given) and emits a
            deprecation warning.

    Returns:
        The wrapped task when applied directly (`@task`), or a decorator that
        produces the wrapped task when called with options (`@task(...)`).

    Example: Sync Task
        ```python
        from langgraph.func import entrypoint, task


        @task
        def add_one_task(a: int) -> int:
            return a + 1


        @entrypoint()
        def add_one(numbers: list[int]) -> list[int]:
            futures = [add_one_task(n) for n in numbers]
            results = [f.result() for f in futures]
            return results


        # Call the entrypoint
        add_one.invoke([1, 2, 3])  # Returns [2, 3, 4]
        ```

    Example: Async Task
        ```python
        import asyncio
        from langgraph.func import entrypoint, task


        @task
        async def add_one_task(a: int) -> int:
            return a + 1


        @entrypoint()
        async def add_one(numbers: list[int]) -> list[int]:
            futures = [add_one_task(n) for n in numbers]
            return await asyncio.gather(*futures)


        # Call the entrypoint
        await add_one.ainvoke([1, 2, 3])  # Returns [2, 3, 4]
        ```
    """
    # Backwards compatibility: honour the deprecated `retry` kwarg, but an
    # explicitly supplied `retry_policy` always takes precedence.
    if (retry := kwargs.get("retry", MISSING)) is not MISSING:
        warnings.warn(
            "`retry` is deprecated and will be removed. Please use `retry_policy` instead.",
            category=LangGraphDeprecatedSinceV05,
            stacklevel=2,
        )
        if retry_policy is None:
            retry_policy = retry  # type: ignore[assignment]

    # Normalise to a sequence: None -> (), a single policy -> 1-tuple,
    # anything else is assumed to already be a sequence of policies.
    retry_policies: Sequence[RetryPolicy] = (
        ()
        if retry_policy is None
        else (retry_policy,)
        if isinstance(retry_policy, RetryPolicy)
        else retry_policy
    )

    def decorator(
        func: Callable[P, Awaitable[T]] | Callable[P, T],
    ) -> _TaskFunction[P, T]:
        """Wrap `func` in a `_TaskFunction` carrying the configured policies."""
        return _TaskFunction(
            func, retry_policy=retry_policies, cache_policy=cache_policy, name=name
        )

    # Bare `@task` usage: the function is passed positionally, so wrap it now.
    if __func_or_none__ is not None:
        return decorator(__func_or_none__)

    # `@task(...)` usage: return the decorator to be applied to the function.
    return decorator
|
||
|
|
|
||
|
|
|
||
|
|
# Type parameters of `entrypoint.final`: R types the `value` handed back to the
# caller, S types the `save` value persisted for the next run's `previous`.
R, S = TypeVar("R"), TypeVar("S")
|
||
|
|
|
||
|
|
|
||
|
|
# The decorator was wrapped in a class to support the `final` attribute.
# In this form, the `final` attribute should play nicely with IDE autocompletion,
# and type checking tools.
# In addition, we'll be able to surface this information in the API Reference.
class entrypoint(Generic[ContextT]):
    """Define a LangGraph workflow using the `entrypoint` decorator.

    ### Function signature

    The decorated function must accept a **single parameter**, which serves as the input
    to the function. This input parameter can be of any type. Use a dictionary
    to pass **multiple parameters** to the function.

    ### Injectable parameters

    The decorated function can request access to additional parameters
    that will be injected automatically at run time. These parameters include:

    | Parameter        | Description                                                                                           |
    |------------------|-------------------------------------------------------------------------------------------------------|
    | **`config`**     | A configuration object (aka `RunnableConfig`) that holds run-time configuration values.               |
    | **`previous`**   | The previous return value for the given thread (available only when a checkpointer is provided).     |
    | **`runtime`**    | A `Runtime` object that contains information about the current run, including context, store, writer. |

    The entrypoint decorator can be applied to sync functions or async functions.

    ### State management

    The **`previous`** parameter can be used to access the return value of the previous
    invocation of the entrypoint on the same thread id. This value is only available
    when a checkpointer is provided.

    If you want **`previous`** to be different from the return value, you can use the
    `entrypoint.final` object to return a value while saving a different value to the
    checkpoint.

    Args:
        checkpointer: Specify a checkpointer to create a workflow that can persist
            its state across runs.
        store: A generalized key-value store. Some implementations may support
            semantic search capabilities through an optional `index` configuration.
        cache: A cache to use for caching the results of the workflow.
        context_schema: Specifies the schema for the context object that will be
            passed to the workflow.
        cache_policy: A cache policy to use for caching the results of the workflow.
        retry_policy: A retry policy (or list of policies) to use for the workflow in case of a failure.

    !!! warning "`config_schema` Deprecated"
        The `config_schema` parameter is deprecated in v0.6.0 and support will be removed in v2.0.0.
        Please use `context_schema` instead to specify the schema for run-scoped context.


    Example: Using entrypoint and tasks
        ```python
        import time

        from langgraph.func import entrypoint, task
        from langgraph.types import interrupt, Command
        from langgraph.checkpoint.memory import InMemorySaver

        @task
        def compose_essay(topic: str) -> str:
            time.sleep(1.0)  # Simulate slow operation
            return f"An essay about {topic}"

        @entrypoint(checkpointer=InMemorySaver())
        def review_workflow(topic: str) -> dict:
            \"\"\"Manages the workflow for generating and reviewing an essay.

            The workflow includes:
            1. Generating an essay about the given topic.
            2. Interrupting the workflow for human review of the generated essay.

            Upon resuming the workflow, compose_essay task will not be re-executed
            as its result is cached by the checkpointer.

            Args:
                topic: The subject of the essay.

            Returns:
                dict: A dictionary containing the generated essay and the human review.
            \"\"\"
            essay_future = compose_essay(topic)
            essay = essay_future.result()
            human_review = interrupt({
                \"question\": \"Please provide a review\",
                \"essay\": essay
            })
            return {
                \"essay\": essay,
                \"review\": human_review,
            }

        # Example configuration for the workflow
        config = {
            \"configurable\": {
                \"thread_id\": \"some_thread\"
            }
        }

        # Topic for the essay
        topic = \"cats\"

        # Stream the workflow to generate the essay and await human review
        for result in review_workflow.stream(topic, config):
            print(result)

        # Example human review provided after the interrupt
        human_review = \"This essay is great.\"

        # Resume the workflow with the provided human review
        for result in review_workflow.stream(Command(resume=human_review), config):
            print(result)
        ```

    Example: Accessing the previous return value
        When a checkpointer is enabled the function can access the previous return value
        of the previous invocation on the same thread id.

        ```python
        from typing import Optional

        from langgraph.checkpoint.memory import InMemorySaver

        from langgraph.func import entrypoint


        @entrypoint(checkpointer=InMemorySaver())
        def my_workflow(input_data: str, previous: Optional[str] = None) -> str:
            return "world"


        config = {"configurable": {"thread_id": "some_thread"}}
        my_workflow.invoke("hello", config)
        ```

    Example: Using `entrypoint.final` to save a value
        The `entrypoint.final` object allows you to return a value while saving
        a different value to the checkpoint. This value will be accessible
        in the next invocation of the entrypoint via the `previous` parameter, as
        long as the same thread id is used.

        ```python
        from typing import Any

        from langgraph.checkpoint.memory import InMemorySaver

        from langgraph.func import entrypoint


        @entrypoint(checkpointer=InMemorySaver())
        def my_workflow(
            number: int,
            *,
            previous: Any = None,
        ) -> entrypoint.final[int, int]:
            previous = previous or 0
            # This will return the previous value to the caller, saving
            # 2 * number to the checkpoint, which will be used in the next invocation
            # for the `previous` parameter.
            return entrypoint.final(value=previous, save=2 * number)


        config = {"configurable": {"thread_id": "some_thread"}}

        my_workflow.invoke(3, config)  # 0 (previous was None)
        my_workflow.invoke(1, config)  # 6 (previous was 3 * 2 from the previous invocation)
        ```
    """

    def __init__(
        self,
        checkpointer: BaseCheckpointSaver | None = None,
        store: BaseStore | None = None,
        cache: BaseCache | None = None,
        context_schema: type[ContextT] | None = None,
        cache_policy: CachePolicy | None = None,
        retry_policy: RetryPolicy | Sequence[RetryPolicy] | None = None,
        **kwargs: Unpack[DeprecatedKwargs],
    ) -> None:
        """Initialize the entrypoint decorator.

        Deprecated kwargs `config_schema` and `retry` are accepted (each emits a
        warning) and used only when `context_schema` / `retry_policy` are unset.
        """
        if (config_schema := kwargs.get("config_schema", MISSING)) is not MISSING:
            warnings.warn(
                "`config_schema` is deprecated and will be removed. Please use `context_schema` instead.",
                category=LangGraphDeprecatedSinceV10,
                stacklevel=2,
            )
            if context_schema is None:
                context_schema = cast(type[ContextT], config_schema)

        if (retry := kwargs.get("retry", MISSING)) is not MISSING:
            warnings.warn(
                "`retry` is deprecated and will be removed. Please use `retry_policy` instead.",
                category=LangGraphDeprecatedSinceV05,
                stacklevel=2,
            )
            if retry_policy is None:
                retry_policy = cast("RetryPolicy | Sequence[RetryPolicy]", retry)

        self.checkpointer = checkpointer
        self.store = store
        self.cache = cache
        self.cache_policy = cache_policy
        self.retry_policy = retry_policy
        self.context_schema = context_schema

    @dataclass(**_DC_KWARGS)
    class final(Generic[R, S]):
        """A primitive that can be returned from an entrypoint.

        This primitive allows to save a value to the checkpointer distinct from the
        return value from the entrypoint.

        Example: Decoupling the return value and the save value
            ```python
            from typing import Any

            from langgraph.checkpoint.memory import InMemorySaver
            from langgraph.func import entrypoint


            @entrypoint(checkpointer=InMemorySaver())
            def my_workflow(
                number: int,
                *,
                previous: Any = None,
            ) -> entrypoint.final[int, int]:
                previous = previous or 0
                # This will return the previous value to the caller, saving
                # 2 * number to the checkpoint, which will be used in the next invocation
                # for the `previous` parameter.
                return entrypoint.final(value=previous, save=2 * number)


            config = {"configurable": {"thread_id": "1"}}

            my_workflow.invoke(3, config)  # 0 (previous was None)
            my_workflow.invoke(1, config)  # 6 (previous was 3 * 2 from the previous invocation)
            ```
        """

        value: R
        """Value to return. A value will always be returned even if it is `None`."""
        save: S
        """The value for the state for the next checkpoint.

        A value will always be saved even if it is `None`.
        """

    def __call__(self, func: Callable[..., Any]) -> Pregel:
        """Convert a function into a Pregel graph.

        Args:
            func: The function to convert. Supports both sync and async functions.

        Returns:
            A Pregel graph.

        Raises:
            NotImplementedError: If `func` is a sync or async generator function.
            ValueError: If `func` takes no parameters.
            TypeError: If the return annotation is a parameterized
                `entrypoint.final` without exactly two type arguments.
        """
        # Generator functions are rejected outright; the Functional API does not
        # support them.
        if inspect.isgeneratorfunction(func) or inspect.isasyncgenfunction(func):
            raise NotImplementedError(
                "Generators are not supported in the Functional API."
            )

        bound = get_runnable_for_entrypoint(func)
        stream_mode: StreamMode = "updates"

        # get input and output types
        # NOTE(review): annotations are read as stored on the function; if the
        # decorated function's module uses `from __future__ import annotations`
        # they are strings and fall through to the plain-annotation branches
        # below — confirm this is acceptable.
        sig = inspect.signature(func)
        first_parameter_name = next(iter(sig.parameters.keys()), None)
        if not first_parameter_name:
            raise ValueError("Entrypoint function must have at least one parameter")
        input_type = (
            sig.parameters[first_parameter_name].annotation
            if sig.parameters[first_parameter_name].annotation
            is not inspect.Signature.empty
            else Any
        )

        def _pluck_return_value(value: Any) -> Any:
            """Extract the return value from an entrypoint.final object, or pass through."""
            return value.value if isinstance(value, entrypoint.final) else value

        def _pluck_save_value(value: Any) -> Any:
            """Extract the save value from an entrypoint.final object, or pass through."""
            return value.save if isinstance(value, entrypoint.final) else value

        output_type, save_type = Any, Any
        if sig.return_annotation is not inspect.Signature.empty:
            if sig.return_annotation is entrypoint.final:
                # Un-parameterized entrypoint.final: nothing is known about the types.
                output_type = save_type = Any
            else:
                origin = get_origin(sig.return_annotation)
                if origin is entrypoint.final:
                    type_annotations = get_args(sig.return_annotation)
                    if len(type_annotations) != 2:
                        raise TypeError(
                            "Please provide an annotation for both the return and "
                            "the save values. "
                            "For example, `-> entrypoint.final[int, str]` would assign "
                            "the return value a type of `int` and the save value a type of `str`."
                        )
                    output_type, save_type = type_annotations
                else:
                    # A plain annotation describes both the returned and saved value.
                    output_type = save_type = sig.return_annotation

        # Single-node graph: START feeds the function; its result is split into
        # the value returned to the caller (END) and the value saved for the
        # next invocation's `previous` (PREVIOUS).
        return Pregel(
            nodes={
                func.__name__: PregelNode(
                    bound=bound,
                    triggers=[START],
                    channels=START,
                    writers=[
                        ChannelWrite(
                            [
                                ChannelWriteEntry(END, mapper=_pluck_return_value),
                                ChannelWriteEntry(PREVIOUS, mapper=_pluck_save_value),
                            ]
                        )
                    ],
                )
            },
            channels={
                START: EphemeralValue(input_type),
                END: LastValue(output_type, END),
                PREVIOUS: LastValue(save_type, PREVIOUS),
            },
            input_channels=START,
            output_channels=END,
            stream_channels=END,
            stream_mode=stream_mode,
            stream_eager=True,
            checkpointer=self.checkpointer,
            store=self.store,
            cache=self.cache,
            cache_policy=self.cache_policy,
            retry_policy=self.retry_policy or (),
            context_schema=self.context_schema,  # type: ignore[arg-type]
        )
|