915 lines
31 KiB
Python
915 lines
31 KiB
Python
|
|
from __future__ import annotations
|
||
|
|
|
||
|
|
import asyncio
|
||
|
|
import enum
|
||
|
|
import inspect
|
||
|
|
import sys
|
||
|
|
import warnings
|
||
|
|
from collections.abc import (
|
||
|
|
AsyncIterator,
|
||
|
|
Awaitable,
|
||
|
|
Callable,
|
||
|
|
Coroutine,
|
||
|
|
Generator,
|
||
|
|
Iterator,
|
||
|
|
Sequence,
|
||
|
|
)
|
||
|
|
from contextlib import AsyncExitStack, contextmanager
|
||
|
|
from contextvars import Context, Token, copy_context
|
||
|
|
from functools import partial, wraps
|
||
|
|
from typing import (
|
||
|
|
Any,
|
||
|
|
Optional,
|
||
|
|
Protocol,
|
||
|
|
TypeGuard,
|
||
|
|
cast,
|
||
|
|
)
|
||
|
|
|
||
|
|
from langchain_core.runnables.base import (
|
||
|
|
Runnable,
|
||
|
|
RunnableConfig,
|
||
|
|
RunnableLambda,
|
||
|
|
RunnableParallel,
|
||
|
|
RunnableSequence,
|
||
|
|
)
|
||
|
|
from langchain_core.runnables.base import (
|
||
|
|
RunnableLike as LCRunnableLike,
|
||
|
|
)
|
||
|
|
from langchain_core.runnables.config import (
|
||
|
|
run_in_executor,
|
||
|
|
var_child_runnable_config,
|
||
|
|
)
|
||
|
|
from langchain_core.runnables.utils import Input, Output
|
||
|
|
from langchain_core.tracers.langchain import LangChainTracer
|
||
|
|
from langgraph.store.base import BaseStore
|
||
|
|
|
||
|
|
from langgraph._internal._config import (
|
||
|
|
ensure_config,
|
||
|
|
get_async_callback_manager_for_config,
|
||
|
|
get_callback_manager_for_config,
|
||
|
|
patch_config,
|
||
|
|
)
|
||
|
|
from langgraph._internal._constants import (
|
||
|
|
CONF,
|
||
|
|
CONFIG_KEY_RUNTIME,
|
||
|
|
)
|
||
|
|
from langgraph._internal._typing import MISSING
|
||
|
|
from langgraph.types import StreamWriter
|
||
|
|
|
||
|
|
try:
|
||
|
|
from langchain_core.tracers._streaming import _StreamingCallbackHandler
|
||
|
|
except ImportError:
|
||
|
|
_StreamingCallbackHandler = None # type: ignore
|
||
|
|
|
||
|
|
|
||
|
|
def _set_config_context(
|
||
|
|
config: RunnableConfig, run: Any = None
|
||
|
|
) -> Token[RunnableConfig | None]:
|
||
|
|
"""Set the child Runnable config + tracing context.
|
||
|
|
|
||
|
|
Args:
|
||
|
|
config: The config to set.
|
||
|
|
"""
|
||
|
|
config_token = var_child_runnable_config.set(config)
|
||
|
|
if run is not None:
|
||
|
|
from langsmith.run_helpers import _set_tracing_context
|
||
|
|
|
||
|
|
_set_tracing_context({"parent": run})
|
||
|
|
return config_token
|
||
|
|
|
||
|
|
|
||
|
|
def _unset_config_context(token: Token[RunnableConfig | None], run: Any = None) -> None:
|
||
|
|
"""Set the child Runnable config + tracing context.
|
||
|
|
|
||
|
|
Args:
|
||
|
|
token: The config token to reset.
|
||
|
|
"""
|
||
|
|
var_child_runnable_config.reset(token)
|
||
|
|
if run is not None:
|
||
|
|
from langsmith.run_helpers import _set_tracing_context
|
||
|
|
|
||
|
|
_set_tracing_context(
|
||
|
|
{
|
||
|
|
"parent": None,
|
||
|
|
"project_name": None,
|
||
|
|
"tags": None,
|
||
|
|
"metadata": None,
|
||
|
|
"enabled": None,
|
||
|
|
"client": None,
|
||
|
|
}
|
||
|
|
)
|
||
|
|
|
||
|
|
|
||
|
|
@contextmanager
|
||
|
|
def set_config_context(
|
||
|
|
config: RunnableConfig, run: Any = None
|
||
|
|
) -> Generator[Context, None, None]:
|
||
|
|
"""Set the child Runnable config + tracing context.
|
||
|
|
|
||
|
|
Args:
|
||
|
|
config: The config to set.
|
||
|
|
"""
|
||
|
|
ctx = copy_context()
|
||
|
|
config_token = ctx.run(_set_config_context, config, run)
|
||
|
|
try:
|
||
|
|
yield ctx
|
||
|
|
finally:
|
||
|
|
ctx.run(_unset_config_context, config_token, run)
|
||
|
|
|
||
|
|
|
||
|
|
# Before Python 3.11 native StrEnum is not available
|
||
|
|
class StrEnum(str, enum.Enum):
|
||
|
|
"""A string enum."""
|
||
|
|
|
||
|
|
|
||
|
|
# Special type to denote any type is accepted
|
||
|
|
ANY_TYPE = object()
|
||
|
|
|
||
|
|
ASYNCIO_ACCEPTS_CONTEXT = sys.version_info >= (3, 11)
|
||
|
|
|
||
|
|
# List of keyword arguments that can be injected into nodes / tasks / tools at runtime.
|
||
|
|
# A named argument may appear multiple times if it appears with distinct types.
|
||
|
|
KWARGS_CONFIG_KEYS: tuple[tuple[str, tuple[Any, ...], str, Any], ...] = (
|
||
|
|
(
|
||
|
|
"config",
|
||
|
|
(
|
||
|
|
RunnableConfig,
|
||
|
|
"RunnableConfig",
|
||
|
|
Optional[RunnableConfig], # noqa: UP045
|
||
|
|
"Optional[RunnableConfig]",
|
||
|
|
inspect.Parameter.empty,
|
||
|
|
),
|
||
|
|
# for now, use config directly, eventually, will pop off of Runtime
|
||
|
|
"N/A",
|
||
|
|
inspect.Parameter.empty,
|
||
|
|
),
|
||
|
|
(
|
||
|
|
"writer",
|
||
|
|
(StreamWriter, "StreamWriter", inspect.Parameter.empty),
|
||
|
|
"stream_writer",
|
||
|
|
lambda _: None,
|
||
|
|
),
|
||
|
|
(
|
||
|
|
"store",
|
||
|
|
(
|
||
|
|
BaseStore,
|
||
|
|
"BaseStore",
|
||
|
|
inspect.Parameter.empty,
|
||
|
|
),
|
||
|
|
"store",
|
||
|
|
inspect.Parameter.empty,
|
||
|
|
),
|
||
|
|
(
|
||
|
|
"store",
|
||
|
|
(
|
||
|
|
Optional[BaseStore], # noqa: UP045
|
||
|
|
"Optional[BaseStore]",
|
||
|
|
),
|
||
|
|
"store",
|
||
|
|
None,
|
||
|
|
),
|
||
|
|
(
|
||
|
|
"previous",
|
||
|
|
(ANY_TYPE,),
|
||
|
|
"previous",
|
||
|
|
inspect.Parameter.empty,
|
||
|
|
),
|
||
|
|
(
|
||
|
|
"runtime",
|
||
|
|
(ANY_TYPE,),
|
||
|
|
# we never hit this block, we just inject runtime directly
|
||
|
|
"N/A",
|
||
|
|
inspect.Parameter.empty,
|
||
|
|
),
|
||
|
|
)
|
||
|
|
"""List of kwargs that can be passed to functions, and their corresponding
|
||
|
|
config keys, default values and type annotations.
|
||
|
|
|
||
|
|
Used to configure keyword arguments that can be injected at runtime
|
||
|
|
from the `Runtime` object as kwargs to `invoke`, `ainvoke`, `stream` and `astream`.
|
||
|
|
|
||
|
|
For a keyword to be injected from the config object, the function signature
|
||
|
|
must contain a kwarg with the same name and a matching type annotation.
|
||
|
|
|
||
|
|
Each tuple contains:
|
||
|
|
- the name of the kwarg in the function signature
|
||
|
|
- the type annotation(s) for the kwarg
|
||
|
|
- the `Runtime` attribute for fetching the value (N/A if not applicable)
|
||
|
|
|
||
|
|
This is fully internal and should be further refactored to use `get_type_hints`
|
||
|
|
to resolve forward references and optional types formatted like BaseStore | None.
|
||
|
|
"""
|
||
|
|
|
||
|
|
VALID_KINDS = (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY)
|
||
|
|
|
||
|
|
|
||
|
|
class _RunnableWithWriter(Protocol[Input, Output]):
|
||
|
|
def __call__(self, state: Input, *, writer: StreamWriter) -> Output: ...
|
||
|
|
|
||
|
|
|
||
|
|
class _RunnableWithStore(Protocol[Input, Output]):
|
||
|
|
def __call__(self, state: Input, *, store: BaseStore) -> Output: ...
|
||
|
|
|
||
|
|
|
||
|
|
class _RunnableWithWriterStore(Protocol[Input, Output]):
|
||
|
|
def __call__(
|
||
|
|
self, state: Input, *, writer: StreamWriter, store: BaseStore
|
||
|
|
) -> Output: ...
|
||
|
|
|
||
|
|
|
||
|
|
class _RunnableWithConfigWriter(Protocol[Input, Output]):
|
||
|
|
def __call__(
|
||
|
|
self, state: Input, *, config: RunnableConfig, writer: StreamWriter
|
||
|
|
) -> Output: ...
|
||
|
|
|
||
|
|
|
||
|
|
class _RunnableWithConfigStore(Protocol[Input, Output]):
|
||
|
|
def __call__(
|
||
|
|
self, state: Input, *, config: RunnableConfig, store: BaseStore
|
||
|
|
) -> Output: ...
|
||
|
|
|
||
|
|
|
||
|
|
class _RunnableWithConfigWriterStore(Protocol[Input, Output]):
|
||
|
|
def __call__(
|
||
|
|
self,
|
||
|
|
state: Input,
|
||
|
|
*,
|
||
|
|
config: RunnableConfig,
|
||
|
|
writer: StreamWriter,
|
||
|
|
store: BaseStore,
|
||
|
|
) -> Output: ...
|
||
|
|
|
||
|
|
|
||
|
|
RunnableLike = (
|
||
|
|
LCRunnableLike
|
||
|
|
| _RunnableWithWriter[Input, Output]
|
||
|
|
| _RunnableWithStore[Input, Output]
|
||
|
|
| _RunnableWithWriterStore[Input, Output]
|
||
|
|
| _RunnableWithConfigWriter[Input, Output]
|
||
|
|
| _RunnableWithConfigStore[Input, Output]
|
||
|
|
| _RunnableWithConfigWriterStore[Input, Output]
|
||
|
|
)
|
||
|
|
|
||
|
|
|
||
|
|
class RunnableCallable(Runnable):
|
||
|
|
"""A much simpler version of RunnableLambda that requires sync and async functions."""
|
||
|
|
|
||
|
|
def __init__(
|
||
|
|
self,
|
||
|
|
func: Callable[..., Any | Runnable] | None,
|
||
|
|
afunc: Callable[..., Awaitable[Any | Runnable]] | None = None,
|
||
|
|
*,
|
||
|
|
name: str | None = None,
|
||
|
|
tags: Sequence[str] | None = None,
|
||
|
|
trace: bool = True,
|
||
|
|
recurse: bool = True,
|
||
|
|
explode_args: bool = False,
|
||
|
|
**kwargs: Any,
|
||
|
|
) -> None:
|
||
|
|
self.name = name
|
||
|
|
if self.name is None:
|
||
|
|
if func:
|
||
|
|
try:
|
||
|
|
if func.__name__ != "<lambda>":
|
||
|
|
self.name = func.__name__
|
||
|
|
except AttributeError:
|
||
|
|
pass
|
||
|
|
elif afunc:
|
||
|
|
try:
|
||
|
|
self.name = afunc.__name__
|
||
|
|
except AttributeError:
|
||
|
|
pass
|
||
|
|
self.func = func
|
||
|
|
self.afunc = afunc
|
||
|
|
self.tags = tags
|
||
|
|
self.kwargs = kwargs
|
||
|
|
self.trace = trace
|
||
|
|
self.recurse = recurse
|
||
|
|
self.explode_args = explode_args
|
||
|
|
# check signature
|
||
|
|
if func is None and afunc is None:
|
||
|
|
raise ValueError("At least one of func or afunc must be provided.")
|
||
|
|
|
||
|
|
self.func_accepts: dict[str, tuple[str, Any]] = {}
|
||
|
|
params = inspect.signature(cast(Callable, func or afunc)).parameters
|
||
|
|
|
||
|
|
for kw, typ, runtime_key, default in KWARGS_CONFIG_KEYS:
|
||
|
|
p = params.get(kw)
|
||
|
|
|
||
|
|
if p is None or p.kind not in VALID_KINDS:
|
||
|
|
# If parameter is not found or is not a valid kind, skip
|
||
|
|
continue
|
||
|
|
|
||
|
|
if typ != (ANY_TYPE,) and p.annotation not in typ:
|
||
|
|
# A specific type is required, but the function annotation does
|
||
|
|
# not match the expected type.
|
||
|
|
|
||
|
|
# If this is a config parameter with incorrect typing, emit a warning
|
||
|
|
# because we used to support any type but are moving towards more correct typing
|
||
|
|
if kw == "config" and p.annotation != inspect.Parameter.empty:
|
||
|
|
warnings.warn(
|
||
|
|
f"The 'config' parameter should be typed as 'RunnableConfig' or "
|
||
|
|
f"'RunnableConfig | None', not '{p.annotation}'. ",
|
||
|
|
UserWarning,
|
||
|
|
stacklevel=4,
|
||
|
|
)
|
||
|
|
continue
|
||
|
|
|
||
|
|
# If the kwarg is accepted by the function, store the key / runtime attribute to inject
|
||
|
|
self.func_accepts[kw] = (runtime_key, default)
|
||
|
|
|
||
|
|
def __repr__(self) -> str:
|
||
|
|
repr_args = {
|
||
|
|
k: v
|
||
|
|
for k, v in self.__dict__.items()
|
||
|
|
if k not in {"name", "func", "afunc", "config", "kwargs", "trace"}
|
||
|
|
}
|
||
|
|
return f"{self.get_name()}({', '.join(f'{k}={v!r}' for k, v in repr_args.items())})"
|
||
|
|
|
||
|
|
def invoke(
|
||
|
|
self, input: Any, config: RunnableConfig | None = None, **kwargs: Any
|
||
|
|
) -> Any:
|
||
|
|
if self.func is None:
|
||
|
|
raise TypeError(
|
||
|
|
f'No synchronous function provided to "{self.name}".'
|
||
|
|
"\nEither initialize with a synchronous function or invoke"
|
||
|
|
" via the async API (ainvoke, astream, etc.)"
|
||
|
|
)
|
||
|
|
if config is None:
|
||
|
|
config = ensure_config()
|
||
|
|
if self.explode_args:
|
||
|
|
args, _kwargs = input
|
||
|
|
kwargs = {**self.kwargs, **_kwargs, **kwargs}
|
||
|
|
else:
|
||
|
|
args = (input,)
|
||
|
|
kwargs = {**self.kwargs, **kwargs}
|
||
|
|
|
||
|
|
runtime = config.get(CONF, {}).get(CONFIG_KEY_RUNTIME)
|
||
|
|
|
||
|
|
for kw, (runtime_key, default) in self.func_accepts.items():
|
||
|
|
# If the kwarg is already set, use the set value
|
||
|
|
if kw in kwargs:
|
||
|
|
continue
|
||
|
|
|
||
|
|
kw_value: Any = MISSING
|
||
|
|
if kw == "config":
|
||
|
|
kw_value = config
|
||
|
|
elif runtime:
|
||
|
|
if kw == "runtime":
|
||
|
|
kw_value = runtime
|
||
|
|
else:
|
||
|
|
try:
|
||
|
|
kw_value = getattr(runtime, runtime_key)
|
||
|
|
except AttributeError:
|
||
|
|
pass
|
||
|
|
|
||
|
|
if kw_value is MISSING:
|
||
|
|
if default is inspect.Parameter.empty:
|
||
|
|
raise ValueError(
|
||
|
|
f"Missing required config key '{runtime_key}' for '{self.name}'."
|
||
|
|
)
|
||
|
|
kw_value = default
|
||
|
|
kwargs[kw] = kw_value
|
||
|
|
|
||
|
|
if self.trace:
|
||
|
|
callback_manager = get_callback_manager_for_config(config, self.tags)
|
||
|
|
run_manager = callback_manager.on_chain_start(
|
||
|
|
None,
|
||
|
|
input,
|
||
|
|
name=config.get("run_name") or self.get_name(),
|
||
|
|
run_id=config.pop("run_id", None),
|
||
|
|
)
|
||
|
|
try:
|
||
|
|
child_config = patch_config(config, callbacks=run_manager.get_child())
|
||
|
|
# get the run
|
||
|
|
for h in run_manager.handlers:
|
||
|
|
if isinstance(h, LangChainTracer):
|
||
|
|
run = h.run_map.get(str(run_manager.run_id))
|
||
|
|
break
|
||
|
|
else:
|
||
|
|
run = None
|
||
|
|
# run in context
|
||
|
|
with set_config_context(child_config, run) as context:
|
||
|
|
ret = context.run(self.func, *args, **kwargs)
|
||
|
|
except BaseException as e:
|
||
|
|
run_manager.on_chain_error(e)
|
||
|
|
raise
|
||
|
|
else:
|
||
|
|
run_manager.on_chain_end(ret)
|
||
|
|
else:
|
||
|
|
ret = self.func(*args, **kwargs)
|
||
|
|
if self.recurse and isinstance(ret, Runnable):
|
||
|
|
return ret.invoke(input, config)
|
||
|
|
return ret
|
||
|
|
|
||
|
|
async def ainvoke(
|
||
|
|
self, input: Any, config: RunnableConfig | None = None, **kwargs: Any
|
||
|
|
) -> Any:
|
||
|
|
if not self.afunc:
|
||
|
|
return self.invoke(input, config)
|
||
|
|
if config is None:
|
||
|
|
config = ensure_config()
|
||
|
|
if self.explode_args:
|
||
|
|
args, _kwargs = input
|
||
|
|
kwargs = {**self.kwargs, **_kwargs, **kwargs}
|
||
|
|
else:
|
||
|
|
args = (input,)
|
||
|
|
kwargs = {**self.kwargs, **kwargs}
|
||
|
|
|
||
|
|
runtime = config.get(CONF, {}).get(CONFIG_KEY_RUNTIME)
|
||
|
|
|
||
|
|
for kw, (runtime_key, default) in self.func_accepts.items():
|
||
|
|
# If the kwarg has already been set, use the set value
|
||
|
|
if kw in kwargs:
|
||
|
|
continue
|
||
|
|
|
||
|
|
kw_value: Any = MISSING
|
||
|
|
if kw == "config":
|
||
|
|
kw_value = config
|
||
|
|
elif runtime:
|
||
|
|
if kw == "runtime":
|
||
|
|
kw_value = runtime
|
||
|
|
else:
|
||
|
|
try:
|
||
|
|
kw_value = getattr(runtime, runtime_key)
|
||
|
|
except AttributeError:
|
||
|
|
pass
|
||
|
|
if kw_value is MISSING:
|
||
|
|
if default is inspect.Parameter.empty:
|
||
|
|
raise ValueError(
|
||
|
|
f"Missing required config key '{runtime_key}' for '{self.name}'."
|
||
|
|
)
|
||
|
|
kw_value = default
|
||
|
|
kwargs[kw] = kw_value
|
||
|
|
|
||
|
|
if self.trace:
|
||
|
|
callback_manager = get_async_callback_manager_for_config(config, self.tags)
|
||
|
|
run_manager = await callback_manager.on_chain_start(
|
||
|
|
None,
|
||
|
|
input,
|
||
|
|
name=config.get("run_name") or self.name,
|
||
|
|
run_id=config.pop("run_id", None),
|
||
|
|
)
|
||
|
|
try:
|
||
|
|
child_config = patch_config(config, callbacks=run_manager.get_child())
|
||
|
|
coro = cast(Coroutine[None, None, Any], self.afunc(*args, **kwargs))
|
||
|
|
if ASYNCIO_ACCEPTS_CONTEXT:
|
||
|
|
for h in run_manager.handlers:
|
||
|
|
if isinstance(h, LangChainTracer):
|
||
|
|
run = h.run_map.get(str(run_manager.run_id))
|
||
|
|
break
|
||
|
|
else:
|
||
|
|
run = None
|
||
|
|
with set_config_context(child_config, run) as context:
|
||
|
|
ret = await asyncio.create_task(coro, context=context)
|
||
|
|
else:
|
||
|
|
ret = await coro
|
||
|
|
except BaseException as e:
|
||
|
|
await run_manager.on_chain_error(e)
|
||
|
|
raise
|
||
|
|
else:
|
||
|
|
await run_manager.on_chain_end(ret)
|
||
|
|
else:
|
||
|
|
ret = await self.afunc(*args, **kwargs)
|
||
|
|
if self.recurse and isinstance(ret, Runnable):
|
||
|
|
return await ret.ainvoke(input, config)
|
||
|
|
return ret
|
||
|
|
|
||
|
|
|
||
|
|
def is_async_callable(
|
||
|
|
func: Any,
|
||
|
|
) -> TypeGuard[Callable[..., Awaitable]]:
|
||
|
|
"""Check if a function is async."""
|
||
|
|
return (
|
||
|
|
inspect.iscoroutinefunction(func)
|
||
|
|
or hasattr(func, "__call__")
|
||
|
|
and inspect.iscoroutinefunction(func.__call__)
|
||
|
|
)
|
||
|
|
|
||
|
|
|
||
|
|
def is_async_generator(
|
||
|
|
func: Any,
|
||
|
|
) -> TypeGuard[Callable[..., AsyncIterator]]:
|
||
|
|
"""Check if a function is an async generator."""
|
||
|
|
return (
|
||
|
|
inspect.isasyncgenfunction(func)
|
||
|
|
or hasattr(func, "__call__")
|
||
|
|
and inspect.isasyncgenfunction(func.__call__)
|
||
|
|
)
|
||
|
|
|
||
|
|
|
||
|
|
def coerce_to_runnable(
|
||
|
|
thing: RunnableLike, *, name: str | None, trace: bool
|
||
|
|
) -> Runnable:
|
||
|
|
"""Coerce a runnable-like object into a Runnable.
|
||
|
|
|
||
|
|
Args:
|
||
|
|
thing: A runnable-like object.
|
||
|
|
|
||
|
|
Returns:
|
||
|
|
A Runnable.
|
||
|
|
"""
|
||
|
|
if isinstance(thing, Runnable):
|
||
|
|
return thing
|
||
|
|
elif is_async_generator(thing) or inspect.isgeneratorfunction(thing):
|
||
|
|
return RunnableLambda(thing, name=name)
|
||
|
|
elif callable(thing):
|
||
|
|
if is_async_callable(thing):
|
||
|
|
return RunnableCallable(None, thing, name=name, trace=trace)
|
||
|
|
else:
|
||
|
|
return RunnableCallable(
|
||
|
|
thing,
|
||
|
|
wraps(thing)(partial(run_in_executor, None, thing)), # type: ignore[arg-type]
|
||
|
|
name=name,
|
||
|
|
trace=trace,
|
||
|
|
)
|
||
|
|
elif isinstance(thing, dict):
|
||
|
|
return RunnableParallel(thing)
|
||
|
|
else:
|
||
|
|
raise TypeError(
|
||
|
|
f"Expected a Runnable, callable or dict."
|
||
|
|
f"Instead got an unsupported type: {type(thing)}"
|
||
|
|
)
|
||
|
|
|
||
|
|
|
||
|
|
class RunnableSeq(Runnable):
|
||
|
|
"""Sequence of `Runnable`, where the output of each is the input of the next.
|
||
|
|
|
||
|
|
`RunnableSeq` is a simpler version of `RunnableSequence` that is internal to
|
||
|
|
LangGraph.
|
||
|
|
"""
|
||
|
|
|
||
|
|
def __init__(
|
||
|
|
self,
|
||
|
|
*steps: RunnableLike,
|
||
|
|
name: str | None = None,
|
||
|
|
trace_inputs: Callable[[Any], Any] | None = None,
|
||
|
|
) -> None:
|
||
|
|
"""Create a new RunnableSeq.
|
||
|
|
|
||
|
|
Args:
|
||
|
|
steps: The steps to include in the sequence.
|
||
|
|
name: The name of the `Runnable`.
|
||
|
|
|
||
|
|
Raises:
|
||
|
|
ValueError: If the sequence has less than 2 steps.
|
||
|
|
"""
|
||
|
|
steps_flat: list[Runnable] = []
|
||
|
|
for step in steps:
|
||
|
|
if isinstance(step, RunnableSequence):
|
||
|
|
steps_flat.extend(step.steps)
|
||
|
|
elif isinstance(step, RunnableSeq):
|
||
|
|
steps_flat.extend(step.steps)
|
||
|
|
else:
|
||
|
|
steps_flat.append(coerce_to_runnable(step, name=None, trace=True))
|
||
|
|
if len(steps_flat) < 2:
|
||
|
|
raise ValueError(
|
||
|
|
f"RunnableSeq must have at least 2 steps, got {len(steps_flat)}"
|
||
|
|
)
|
||
|
|
self.steps = steps_flat
|
||
|
|
self.name = name
|
||
|
|
self.trace_inputs = trace_inputs
|
||
|
|
|
||
|
|
def __or__(
|
||
|
|
self,
|
||
|
|
other: Any,
|
||
|
|
) -> Runnable:
|
||
|
|
if isinstance(other, RunnableSequence):
|
||
|
|
return RunnableSeq(
|
||
|
|
*self.steps,
|
||
|
|
other.first,
|
||
|
|
*other.middle,
|
||
|
|
other.last,
|
||
|
|
name=self.name or other.name,
|
||
|
|
)
|
||
|
|
elif isinstance(other, RunnableSeq):
|
||
|
|
return RunnableSeq(
|
||
|
|
*self.steps,
|
||
|
|
*other.steps,
|
||
|
|
name=self.name or other.name,
|
||
|
|
)
|
||
|
|
else:
|
||
|
|
return RunnableSeq(
|
||
|
|
*self.steps,
|
||
|
|
coerce_to_runnable(other, name=None, trace=True),
|
||
|
|
name=self.name,
|
||
|
|
)
|
||
|
|
|
||
|
|
def __ror__(
|
||
|
|
self,
|
||
|
|
other: Any,
|
||
|
|
) -> Runnable:
|
||
|
|
if isinstance(other, RunnableSequence):
|
||
|
|
return RunnableSequence(
|
||
|
|
other.first,
|
||
|
|
*other.middle,
|
||
|
|
other.last,
|
||
|
|
*self.steps,
|
||
|
|
name=other.name or self.name,
|
||
|
|
)
|
||
|
|
elif isinstance(other, RunnableSeq):
|
||
|
|
return RunnableSeq(
|
||
|
|
*other.steps,
|
||
|
|
*self.steps,
|
||
|
|
name=other.name or self.name,
|
||
|
|
)
|
||
|
|
else:
|
||
|
|
return RunnableSequence(
|
||
|
|
coerce_to_runnable(other, name=None, trace=True),
|
||
|
|
*self.steps,
|
||
|
|
name=self.name,
|
||
|
|
)
|
||
|
|
|
||
|
|
def invoke(
|
||
|
|
self, input: Input, config: RunnableConfig | None = None, **kwargs: Any
|
||
|
|
) -> Any:
|
||
|
|
if config is None:
|
||
|
|
config = ensure_config()
|
||
|
|
# setup callbacks and context
|
||
|
|
callback_manager = get_callback_manager_for_config(config)
|
||
|
|
# start the root run
|
||
|
|
run_manager = callback_manager.on_chain_start(
|
||
|
|
None,
|
||
|
|
self.trace_inputs(input) if self.trace_inputs is not None else input,
|
||
|
|
name=config.get("run_name") or self.get_name(),
|
||
|
|
run_id=config.pop("run_id", None),
|
||
|
|
)
|
||
|
|
# invoke all steps in sequence
|
||
|
|
try:
|
||
|
|
for i, step in enumerate(self.steps):
|
||
|
|
# mark each step as a child run
|
||
|
|
config = patch_config(
|
||
|
|
config, callbacks=run_manager.get_child(f"seq:step:{i + 1}")
|
||
|
|
)
|
||
|
|
# 1st step is the actual node,
|
||
|
|
# others are writers which don't need to be run in context
|
||
|
|
if i == 0:
|
||
|
|
# get the run object
|
||
|
|
for h in run_manager.handlers:
|
||
|
|
if isinstance(h, LangChainTracer):
|
||
|
|
run = h.run_map.get(str(run_manager.run_id))
|
||
|
|
break
|
||
|
|
else:
|
||
|
|
run = None
|
||
|
|
# run in context
|
||
|
|
with set_config_context(config, run) as context:
|
||
|
|
input = context.run(step.invoke, input, config, **kwargs)
|
||
|
|
else:
|
||
|
|
input = step.invoke(input, config)
|
||
|
|
# finish the root run
|
||
|
|
except BaseException as e:
|
||
|
|
run_manager.on_chain_error(e)
|
||
|
|
raise
|
||
|
|
else:
|
||
|
|
run_manager.on_chain_end(input)
|
||
|
|
return input
|
||
|
|
|
||
|
|
async def ainvoke(
|
||
|
|
self,
|
||
|
|
input: Input,
|
||
|
|
config: RunnableConfig | None = None,
|
||
|
|
**kwargs: Any | None,
|
||
|
|
) -> Any:
|
||
|
|
if config is None:
|
||
|
|
config = ensure_config()
|
||
|
|
# setup callbacks
|
||
|
|
callback_manager = get_async_callback_manager_for_config(config)
|
||
|
|
# start the root run
|
||
|
|
run_manager = await callback_manager.on_chain_start(
|
||
|
|
None,
|
||
|
|
self.trace_inputs(input) if self.trace_inputs is not None else input,
|
||
|
|
name=config.get("run_name") or self.get_name(),
|
||
|
|
run_id=config.pop("run_id", None),
|
||
|
|
)
|
||
|
|
|
||
|
|
# invoke all steps in sequence
|
||
|
|
try:
|
||
|
|
for i, step in enumerate(self.steps):
|
||
|
|
# mark each step as a child run
|
||
|
|
config = patch_config(
|
||
|
|
config, callbacks=run_manager.get_child(f"seq:step:{i + 1}")
|
||
|
|
)
|
||
|
|
# 1st step is the actual node,
|
||
|
|
# others are writers which don't need to be run in context
|
||
|
|
if i == 0:
|
||
|
|
if ASYNCIO_ACCEPTS_CONTEXT:
|
||
|
|
# get the run object
|
||
|
|
for h in run_manager.handlers:
|
||
|
|
if isinstance(h, LangChainTracer):
|
||
|
|
run = h.run_map.get(str(run_manager.run_id))
|
||
|
|
break
|
||
|
|
else:
|
||
|
|
run = None
|
||
|
|
# run in context
|
||
|
|
with set_config_context(config, run) as context:
|
||
|
|
input = await asyncio.create_task(
|
||
|
|
step.ainvoke(input, config, **kwargs), context=context
|
||
|
|
)
|
||
|
|
else:
|
||
|
|
input = await step.ainvoke(input, config, **kwargs)
|
||
|
|
else:
|
||
|
|
input = await step.ainvoke(input, config)
|
||
|
|
# finish the root run
|
||
|
|
except BaseException as e:
|
||
|
|
await run_manager.on_chain_error(e)
|
||
|
|
raise
|
||
|
|
else:
|
||
|
|
await run_manager.on_chain_end(input)
|
||
|
|
return input
|
||
|
|
|
||
|
|
def stream(
|
||
|
|
self,
|
||
|
|
input: Input,
|
||
|
|
config: RunnableConfig | None = None,
|
||
|
|
**kwargs: Any | None,
|
||
|
|
) -> Iterator[Any]:
|
||
|
|
if config is None:
|
||
|
|
config = ensure_config()
|
||
|
|
# setup callbacks
|
||
|
|
callback_manager = get_callback_manager_for_config(config)
|
||
|
|
# start the root run
|
||
|
|
run_manager = callback_manager.on_chain_start(
|
||
|
|
None,
|
||
|
|
self.trace_inputs(input) if self.trace_inputs is not None else input,
|
||
|
|
name=config.get("run_name") or self.get_name(),
|
||
|
|
run_id=config.pop("run_id", None),
|
||
|
|
)
|
||
|
|
# get the run object
|
||
|
|
for h in run_manager.handlers:
|
||
|
|
if isinstance(h, LangChainTracer):
|
||
|
|
run = h.run_map.get(str(run_manager.run_id))
|
||
|
|
break
|
||
|
|
else:
|
||
|
|
run = None
|
||
|
|
# create first step config
|
||
|
|
config = patch_config(
|
||
|
|
config,
|
||
|
|
callbacks=run_manager.get_child(f"seq:step:{1}"),
|
||
|
|
)
|
||
|
|
# run all in context
|
||
|
|
with set_config_context(config, run) as context:
|
||
|
|
try:
|
||
|
|
# stream the last steps
|
||
|
|
# transform the input stream of each step with the next
|
||
|
|
# steps that don't natively support transforming an input stream will
|
||
|
|
# buffer input in memory until all available, and then start emitting output
|
||
|
|
for idx, step in enumerate(self.steps):
|
||
|
|
if idx == 0:
|
||
|
|
iterator = step.stream(input, config, **kwargs)
|
||
|
|
else:
|
||
|
|
config = patch_config(
|
||
|
|
config,
|
||
|
|
callbacks=run_manager.get_child(f"seq:step:{idx + 1}"),
|
||
|
|
)
|
||
|
|
iterator = step.transform(iterator, config)
|
||
|
|
# populates streamed_output in astream_log() output if needed
|
||
|
|
if _StreamingCallbackHandler is not None:
|
||
|
|
for h in run_manager.handlers:
|
||
|
|
if isinstance(h, _StreamingCallbackHandler):
|
||
|
|
iterator = h.tap_output_iter(run_manager.run_id, iterator)
|
||
|
|
# consume into final output
|
||
|
|
output = context.run(_consume_iter, iterator)
|
||
|
|
# sequence doesn't emit output, yield to mark as generator
|
||
|
|
yield
|
||
|
|
except BaseException as e:
|
||
|
|
run_manager.on_chain_error(e)
|
||
|
|
raise
|
||
|
|
else:
|
||
|
|
run_manager.on_chain_end(output)
|
||
|
|
|
||
|
|
async def astream(
|
||
|
|
self,
|
||
|
|
input: Input,
|
||
|
|
config: RunnableConfig | None = None,
|
||
|
|
**kwargs: Any | None,
|
||
|
|
) -> AsyncIterator[Any]:
|
||
|
|
if config is None:
|
||
|
|
config = ensure_config()
|
||
|
|
# setup callbacks
|
||
|
|
callback_manager = get_async_callback_manager_for_config(config)
|
||
|
|
# start the root run
|
||
|
|
run_manager = await callback_manager.on_chain_start(
|
||
|
|
None,
|
||
|
|
self.trace_inputs(input) if self.trace_inputs is not None else input,
|
||
|
|
name=config.get("run_name") or self.get_name(),
|
||
|
|
run_id=config.pop("run_id", None),
|
||
|
|
)
|
||
|
|
# stream the last steps
|
||
|
|
# transform the input stream of each step with the next
|
||
|
|
# steps that don't natively support transforming an input stream will
|
||
|
|
# buffer input in memory until all available, and then start emitting output
|
||
|
|
if ASYNCIO_ACCEPTS_CONTEXT:
|
||
|
|
# get the run object
|
||
|
|
for h in run_manager.handlers:
|
||
|
|
if isinstance(h, LangChainTracer):
|
||
|
|
run = h.run_map.get(str(run_manager.run_id))
|
||
|
|
break
|
||
|
|
else:
|
||
|
|
run = None
|
||
|
|
# create first step config
|
||
|
|
config = patch_config(
|
||
|
|
config,
|
||
|
|
callbacks=run_manager.get_child(f"seq:step:{1}"),
|
||
|
|
)
|
||
|
|
# run all in context
|
||
|
|
with set_config_context(config, run) as context:
|
||
|
|
try:
|
||
|
|
async with AsyncExitStack() as stack:
|
||
|
|
for idx, step in enumerate(self.steps):
|
||
|
|
if idx == 0:
|
||
|
|
aiterator = step.astream(input, config, **kwargs)
|
||
|
|
else:
|
||
|
|
config = patch_config(
|
||
|
|
config,
|
||
|
|
callbacks=run_manager.get_child(
|
||
|
|
f"seq:step:{idx + 1}"
|
||
|
|
),
|
||
|
|
)
|
||
|
|
aiterator = step.atransform(aiterator, config)
|
||
|
|
if hasattr(aiterator, "aclose"):
|
||
|
|
stack.push_async_callback(aiterator.aclose)
|
||
|
|
# populates streamed_output in astream_log() output if needed
|
||
|
|
if _StreamingCallbackHandler is not None:
|
||
|
|
for h in run_manager.handlers:
|
||
|
|
if isinstance(h, _StreamingCallbackHandler):
|
||
|
|
aiterator = h.tap_output_aiter(
|
||
|
|
run_manager.run_id, aiterator
|
||
|
|
)
|
||
|
|
# consume into final output
|
||
|
|
output = await asyncio.create_task(
|
||
|
|
_consume_aiter(aiterator), context=context
|
||
|
|
)
|
||
|
|
# sequence doesn't emit output, yield to mark as generator
|
||
|
|
yield
|
||
|
|
except BaseException as e:
|
||
|
|
await run_manager.on_chain_error(e)
|
||
|
|
raise
|
||
|
|
else:
|
||
|
|
await run_manager.on_chain_end(output)
|
||
|
|
else:
|
||
|
|
try:
|
||
|
|
async with AsyncExitStack() as stack:
|
||
|
|
for idx, step in enumerate(self.steps):
|
||
|
|
config = patch_config(
|
||
|
|
config,
|
||
|
|
callbacks=run_manager.get_child(f"seq:step:{idx + 1}"),
|
||
|
|
)
|
||
|
|
if idx == 0:
|
||
|
|
aiterator = step.astream(input, config, **kwargs)
|
||
|
|
else:
|
||
|
|
aiterator = step.atransform(aiterator, config)
|
||
|
|
if hasattr(aiterator, "aclose"):
|
||
|
|
stack.push_async_callback(aiterator.aclose)
|
||
|
|
# populates streamed_output in astream_log() output if needed
|
||
|
|
if _StreamingCallbackHandler is not None:
|
||
|
|
for h in run_manager.handlers:
|
||
|
|
if isinstance(h, _StreamingCallbackHandler):
|
||
|
|
aiterator = h.tap_output_aiter(
|
||
|
|
run_manager.run_id, aiterator
|
||
|
|
)
|
||
|
|
# consume into final output
|
||
|
|
output = await _consume_aiter(aiterator)
|
||
|
|
# sequence doesn't emit output, yield to mark as generator
|
||
|
|
yield
|
||
|
|
except BaseException as e:
|
||
|
|
await run_manager.on_chain_error(e)
|
||
|
|
raise
|
||
|
|
else:
|
||
|
|
await run_manager.on_chain_end(output)
|
||
|
|
|
||
|
|
|
||
|
|
def _consume_iter(it: Iterator[Any]) -> Any:
|
||
|
|
"""Consume an iterator."""
|
||
|
|
output: Any = None
|
||
|
|
add_supported = False
|
||
|
|
for chunk in it:
|
||
|
|
# collect final output
|
||
|
|
if output is None:
|
||
|
|
output = chunk
|
||
|
|
elif add_supported:
|
||
|
|
try:
|
||
|
|
output = output + chunk
|
||
|
|
except TypeError:
|
||
|
|
output = chunk
|
||
|
|
add_supported = False
|
||
|
|
else:
|
||
|
|
output = chunk
|
||
|
|
return output
|
||
|
|
|
||
|
|
|
||
|
|
async def _consume_aiter(it: AsyncIterator[Any]) -> Any:
|
||
|
|
"""Consume an async iterator."""
|
||
|
|
output: Any = None
|
||
|
|
add_supported = False
|
||
|
|
async for chunk in it:
|
||
|
|
# collect final output
|
||
|
|
if add_supported:
|
||
|
|
try:
|
||
|
|
output = output + chunk
|
||
|
|
except TypeError:
|
||
|
|
output = chunk
|
||
|
|
add_supported = False
|
||
|
|
else:
|
||
|
|
output = chunk
|
||
|
|
return output
|