from __future__ import annotations

import asyncio
import enum
import inspect
import sys
import warnings
from collections.abc import (
    AsyncIterator,
    Awaitable,
    Callable,
    Coroutine,
    Generator,
    Iterator,
    Sequence,
)
from contextlib import AsyncExitStack, contextmanager
from contextvars import Context, Token, copy_context
from functools import partial, wraps
from typing import (
    Any,
    Optional,
    Protocol,
    TypeGuard,
    cast,
)

from langchain_core.runnables.base import (
    Runnable,
    RunnableConfig,
    RunnableLambda,
    RunnableParallel,
    RunnableSequence,
)
from langchain_core.runnables.base import (
    RunnableLike as LCRunnableLike,
)
from langchain_core.runnables.config import (
    run_in_executor,
    var_child_runnable_config,
)
from langchain_core.runnables.utils import Input, Output
from langchain_core.tracers.langchain import LangChainTracer
from langgraph.store.base import BaseStore

from langgraph._internal._config import (
    ensure_config,
    get_async_callback_manager_for_config,
    get_callback_manager_for_config,
    patch_config,
)
from langgraph._internal._constants import (
    CONF,
    CONFIG_KEY_RUNTIME,
)
from langgraph._internal._typing import MISSING
from langgraph.types import StreamWriter

try:
    from langchain_core.tracers._streaming import _StreamingCallbackHandler
except ImportError:
    # Older langchain_core releases do not ship the streaming handler.
    _StreamingCallbackHandler = None  # type: ignore


def _set_config_context(
    config: RunnableConfig, run: Any = None
) -> Token[RunnableConfig | None]:
    """Install `config` as the child Runnable config in the current context.

    When `run` is provided it is also installed as the parent of the
    LangSmith tracing context.

    Args:
        config: The config to install.
        run: Optional tracing run to use as the tracing parent.

    Returns:
        The token to pass to `_unset_config_context` to undo the config change.
    """
    token = var_child_runnable_config.set(config)
    if run is None:
        return token
    # Imported lazily; only needed when a tracing run is present.
    from langsmith.run_helpers import _set_tracing_context

    _set_tracing_context({"parent": run})
    return token


def _unset_config_context(token: Token[RunnableConfig | None], run: Any = None) -> None:
    """Undo `_set_config_context`.

    Resets the child Runnable config using `token`. When `run` is provided,
    the tracing context fields (parent, project_name, tags, metadata, enabled,
    client) are all cleared to `None`.

    Args:
        token: The token returned by `_set_config_context`.
        run: The tracing run that was passed to `_set_config_context`, if any.
    """
    var_child_runnable_config.reset(token)
    if run is None:
        return
    from langsmith.run_helpers import _set_tracing_context

    _set_tracing_context(
        dict.fromkeys(
            ("parent", "project_name", "tags", "metadata", "enabled", "client")
        )
    )


@contextmanager
def set_config_context(
    config: RunnableConfig, run: Any = None
) -> Generator[Context, None, None]:
    """Yield a copied `Context` in which `config` (and `run`) are installed.

    The changes are applied only inside the copied context, never to the
    caller's context. They are undone inside that copy when the `with` block
    exits.

    Args:
        config: The config to install as the child Runnable config.
        run: Optional tracing run to use as the tracing parent.
    """
    context = copy_context()
    token = context.run(_set_config_context, config, run)
    try:
        yield context
    finally:
        context.run(_unset_config_context, token, run)


# Native enum.StrEnum only exists from Python 3.11 onwards.
class StrEnum(str, enum.Enum):
    """A string enum."""


# Sentinel meaning "any annotation is accepted" for a kwarg.
ANY_TYPE = object()

# asyncio.create_task() gained its `context=` argument in Python 3.11.
ASYNCIO_ACCEPTS_CONTEXT = sys.version_info >= (3, 11)

# Keyword arguments that can be injected into nodes / tasks / tools at runtime.
# The same name may appear more than once when it is accepted with distinct types.
KWARGS_CONFIG_KEYS: tuple[tuple[str, tuple[Any, ...], str, Any], ...] = (
    (
        "config",
        (
            RunnableConfig,
            "RunnableConfig",
            Optional[RunnableConfig],  # noqa: UP045
            "Optional[RunnableConfig]",
            inspect.Parameter.empty,
        ),
        # the config is injected as-is for now; eventually it will come from Runtime
        "N/A",
        inspect.Parameter.empty,
    ),
    (
        "writer",
        (StreamWriter, "StreamWriter", inspect.Parameter.empty),
        "stream_writer",
        lambda _: None,
    ),
    (
        "store",
        (
            BaseStore,
            "BaseStore",
            inspect.Parameter.empty,
        ),
        "store",
        inspect.Parameter.empty,
    ),
    (
        "store",
        (
            Optional[BaseStore],  # noqa: UP045
            "Optional[BaseStore]",
        ),
        "store",
        None,
    ),
    (
        "previous",
        (ANY_TYPE,),
        "previous",
        inspect.Parameter.empty,
    ),
    (
        "runtime",
        (ANY_TYPE,),
        # never looked up by attribute: the Runtime object itself is injected
        "N/A",
        inspect.Parameter.empty,
    ),
)
"""Keyword arguments that can be injected into a function at runtime.

A keyword is injected only when the function signature declares a parameter
with the same name and one of the accepted type annotations.

Each entry is a 4-tuple of:
- the kwarg name in the function signature
- the accepted type annotation(s) (`ANY_TYPE` accepts any annotation)
- the `Runtime` attribute to read the value from ("N/A" if not applicable)
- the default value used when no value is available
  (`inspect.Parameter.empty` means the value is required)

This is fully internal. It should eventually be refactored to use
`get_type_hints`, so that forward references and optional types written like
`BaseStore | None` are resolved.
"""

# Parameter kinds that can receive an injected keyword argument.
VALID_KINDS = (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY)


class _RunnableWithWriter(Protocol[Input, Output]):
    def __call__(self, state: Input, *, writer: StreamWriter) -> Output: ...


class _RunnableWithStore(Protocol[Input, Output]):
    def __call__(self, state: Input, *, store: BaseStore) -> Output: ...


class _RunnableWithWriterStore(Protocol[Input, Output]):
    def __call__(
        self, state: Input, *, writer: StreamWriter, store: BaseStore
    ) -> Output: ...


class _RunnableWithConfigWriter(Protocol[Input, Output]):
    def __call__(
        self, state: Input, *, config: RunnableConfig, writer: StreamWriter
    ) -> Output: ...


class _RunnableWithConfigStore(Protocol[Input, Output]):
    def __call__(
        self, state: Input, *, config: RunnableConfig, store: BaseStore
    ) -> Output: ...


class _RunnableWithConfigWriterStore(Protocol[Input, Output]):
    def __call__(
        self,
        state: Input,
        *,
        config: RunnableConfig,
        writer: StreamWriter,
        store: BaseStore,
    ) -> Output: ...
RunnableLike = (
    LCRunnableLike
    | _RunnableWithWriter[Input, Output]
    | _RunnableWithStore[Input, Output]
    | _RunnableWithWriterStore[Input, Output]
    | _RunnableWithConfigWriter[Input, Output]
    | _RunnableWithConfigStore[Input, Output]
    | _RunnableWithConfigWriterStore[Input, Output]
)


def _get_tracer_run(run_manager: Any) -> Any:
    """Return the run recorded by the first `LangChainTracer` handler, if any.

    Looks up the run for `run_manager.run_id` in the first `LangChainTracer`
    found among `run_manager.handlers`. Returns `None` if there is no such
    handler or it has no entry for this run id.
    """
    for handler in run_manager.handlers:
        if isinstance(handler, LangChainTracer):
            return handler.run_map.get(str(run_manager.run_id))
    return None


class RunnableCallable(Runnable):
    """A much simpler version of RunnableLambda that requires sync and async functions.

    Keyword arguments listed in `KWARGS_CONFIG_KEYS` are injected automatically
    when the wrapped function's signature declares them with a matching
    annotation.
    """

    def __init__(
        self,
        func: Callable[..., Any | Runnable] | None,
        afunc: Callable[..., Awaitable[Any | Runnable]] | None = None,
        *,
        name: str | None = None,
        tags: Sequence[str] | None = None,
        trace: bool = True,
        recurse: bool = True,
        explode_args: bool = False,
        **kwargs: Any,
    ) -> None:
        """Create a new RunnableCallable.

        Args:
            func: Sync function used by `invoke`.
            afunc: Async function used by `ainvoke`. If omitted, `ainvoke`
                falls back to `invoke`.
            name: Name of the runnable. Defaults to the function's `__name__`,
                except for lambdas.
            tags: Tags passed to the callback manager when tracing.
            trace: Whether each call is reported to callbacks as a chain run.
            recurse: If the function returns a `Runnable`, invoke it with the
                same input and config and return its result.
            explode_args: Treat the input as an `(args, kwargs)` pair to unpack
                into the call instead of passing it as the single argument.
            **kwargs: Extra keyword arguments passed to every call.

        Raises:
            ValueError: If neither `func` nor `afunc` is provided.
        """
        self.name = name
        if self.name is None:
            if func:
                try:
                    # All lambdas share the uninformative name "<lambda>"; leave
                    # the name unset so get_name() falls back to the class name.
                    if func.__name__ != "<lambda>":
                        self.name = func.__name__
                except AttributeError:
                    pass
            elif afunc:
                try:
                    self.name = afunc.__name__
                except AttributeError:
                    pass
        self.func = func
        self.afunc = afunc
        self.tags = tags
        self.kwargs = kwargs
        self.trace = trace
        self.recurse = recurse
        self.explode_args = explode_args
        # check signature
        if func is None and afunc is None:
            raise ValueError("At least one of func or afunc must be provided.")

        # kwarg name -> (Runtime attribute to read, default value)
        self.func_accepts: dict[str, tuple[str, Any]] = {}
        params = inspect.signature(cast(Callable, func or afunc)).parameters

        for kw, typ, runtime_key, default in KWARGS_CONFIG_KEYS:
            p = params.get(kw)

            if p is None or p.kind not in VALID_KINDS:
                # If parameter is not found or is not a valid kind, skip
                continue

            if typ != (ANY_TYPE,) and p.annotation not in typ:
                # A specific type is required, but the function annotation does
                # not match the expected type.

                # Config used to accept any annotation; warn (rather than fail)
                # while moving towards correct typing.
                if kw == "config" and p.annotation != inspect.Parameter.empty:
                    warnings.warn(
                        f"The 'config' parameter should be typed as 'RunnableConfig' or "
                        f"'RunnableConfig | None', not '{p.annotation}'. ",
                        UserWarning,
                        stacklevel=4,
                    )
                continue

            # If the kwarg is accepted by the function, store the key / runtime attribute to inject
            self.func_accepts[kw] = (runtime_key, default)

    def __repr__(self) -> str:
        repr_args = {
            k: v
            for k, v in self.__dict__.items()
            if k not in {"name", "func", "afunc", "config", "kwargs", "trace"}
        }
        return f"{self.get_name()}({', '.join(f'{k}={v!r}' for k, v in repr_args.items())})"

    def _prepare_call(
        self, input: Any, config: RunnableConfig, kwargs: dict[str, Any]
    ) -> tuple[Any, dict[str, Any]]:
        """Build the positional args and keyword args for calling the function.

        Merges `self.kwargs`, the exploded input kwargs (if `explode_args`) and
        the call-site `kwargs`, later ones winning. Then injects each accepted
        runtime kwarg that was not passed explicitly.

        Raises:
            ValueError: If a required injectable kwarg has no value available.
        """
        if self.explode_args:
            args, input_kwargs = input
            call_kwargs = {**self.kwargs, **input_kwargs, **kwargs}
        else:
            args = (input,)
            call_kwargs = {**self.kwargs, **kwargs}

        runtime = config.get(CONF, {}).get(CONFIG_KEY_RUNTIME)

        for kw, (runtime_key, default) in self.func_accepts.items():
            # If the kwarg is already set, use the set value
            if kw in call_kwargs:
                continue

            kw_value: Any = MISSING
            if kw == "config":
                kw_value = config
            elif runtime:
                if kw == "runtime":
                    kw_value = runtime
                else:
                    try:
                        kw_value = getattr(runtime, runtime_key)
                    except AttributeError:
                        pass

            if kw_value is MISSING:
                if default is inspect.Parameter.empty:
                    raise ValueError(
                        f"Missing required config key '{runtime_key}' for '{self.name}'."
                    )
                kw_value = default
            call_kwargs[kw] = kw_value
        return args, call_kwargs

    def invoke(
        self, input: Any, config: RunnableConfig | None = None, **kwargs: Any
    ) -> Any:
        if self.func is None:
            raise TypeError(
                f'No synchronous function provided to "{self.name}".'
                "\nEither initialize with a synchronous function or invoke"
                " via the async API (ainvoke, astream, etc.)"
            )
        if config is None:
            config = ensure_config()
        args, kwargs = self._prepare_call(input, config, kwargs)

        if self.trace:
            callback_manager = get_callback_manager_for_config(config, self.tags)
            run_manager = callback_manager.on_chain_start(
                None,
                input,
                name=config.get("run_name") or self.get_name(),
                run_id=config.pop("run_id", None),
            )
            try:
                child_config = patch_config(config, callbacks=run_manager.get_child())
                run = _get_tracer_run(run_manager)
                # run in context
                with set_config_context(child_config, run) as context:
                    ret = context.run(self.func, *args, **kwargs)
            except BaseException as e:
                run_manager.on_chain_error(e)
                raise
            else:
                run_manager.on_chain_end(ret)
        else:
            ret = self.func(*args, **kwargs)
        if self.recurse and isinstance(ret, Runnable):
            return ret.invoke(input, config)
        return ret

    async def ainvoke(
        self, input: Any, config: RunnableConfig | None = None, **kwargs: Any
    ) -> Any:
        if not self.afunc:
            # Forward kwargs too, so they reach the sync function as they would
            # have reached the async one.
            return self.invoke(input, config, **kwargs)
        if config is None:
            config = ensure_config()
        args, kwargs = self._prepare_call(input, config, kwargs)

        if self.trace:
            callback_manager = get_async_callback_manager_for_config(config, self.tags)
            run_manager = await callback_manager.on_chain_start(
                None,
                input,
                # same naming as invoke(): fall back to the class name if unnamed
                name=config.get("run_name") or self.get_name(),
                run_id=config.pop("run_id", None),
            )
            try:
                child_config = patch_config(config, callbacks=run_manager.get_child())
                coro = cast(Coroutine[None, None, Any], self.afunc(*args, **kwargs))
                if ASYNCIO_ACCEPTS_CONTEXT:
                    run = _get_tracer_run(run_manager)
                    with set_config_context(child_config, run) as context:
                        ret = await asyncio.create_task(coro, context=context)
                else:
                    ret = await coro
            except BaseException as e:
                await run_manager.on_chain_error(e)
                raise
            else:
                await run_manager.on_chain_end(ret)
        else:
            ret = await self.afunc(*args, **kwargs)
        if self.recurse and isinstance(ret, Runnable):
            return await ret.ainvoke(input, config)
        return ret


def is_async_callable(
    func: Any,
) -> TypeGuard[Callable[..., Awaitable]]:
    """Check if a function is async (directly or via its `__call__`)."""
    return inspect.iscoroutinefunction(func) or (
        hasattr(func, "__call__") and inspect.iscoroutinefunction(func.__call__)
    )


def is_async_generator(
    func: Any,
) -> TypeGuard[Callable[..., AsyncIterator]]:
    """Check if a function is an async generator (directly or via its `__call__`)."""
    return inspect.isasyncgenfunction(func) or (
        hasattr(func, "__call__") and inspect.isasyncgenfunction(func.__call__)
    )


def coerce_to_runnable(
    thing: RunnableLike, *, name: str | None, trace: bool
) -> Runnable:
    """Coerce a runnable-like object into a Runnable.

    Args:
        thing: A runnable-like object.
        name: Name given to the wrapper created for generator functions and
            plain callables (unused for existing Runnables and dicts).
        trace: Whether the `RunnableCallable` wrapper reports its runs to
            callbacks (only used for non-generator callables).

    Returns:
        A Runnable.

    Raises:
        TypeError: If `thing` is not a Runnable, callable or dict.
    """
    if isinstance(thing, Runnable):
        return thing
    elif is_async_generator(thing) or inspect.isgeneratorfunction(thing):
        return RunnableLambda(thing, name=name)
    elif callable(thing):
        if is_async_callable(thing):
            return RunnableCallable(None, thing, name=name, trace=trace)
        else:
            # Sync callables get an async counterpart that runs them in an executor.
            return RunnableCallable(
                thing,
                wraps(thing)(partial(run_in_executor, None, thing)),  # type: ignore[arg-type]
                name=name,
                trace=trace,
            )
    elif isinstance(thing, dict):
        return RunnableParallel(thing)
    else:
        raise TypeError(
            "Expected a Runnable, callable or dict. "
            f"Instead got an unsupported type: {type(thing)}"
        )


class RunnableSeq(Runnable):
    """Sequence of `Runnable`, where the output of each is the input of the next.

    `RunnableSeq` is a simpler version of `RunnableSequence` that is internal to
    LangGraph.
    """

    def __init__(
        self,
        *steps: RunnableLike,
        name: str | None = None,
        trace_inputs: Callable[[Any], Any] | None = None,
    ) -> None:
        """Create a new RunnableSeq.

        Args:
            steps: The steps to include in the sequence. `RunnableSequence` and
                `RunnableSeq` steps are flattened into their own steps; anything
                else is coerced with `coerce_to_runnable`.
            name: The name of the `Runnable`.
            trace_inputs: Optional function applied to the input before it is
                reported to callbacks as the root run's input.

        Raises:
            ValueError: If the sequence has less than 2 steps.
        """
        steps_flat: list[Runnable] = []
        for step in steps:
            if isinstance(step, (RunnableSequence, RunnableSeq)):
                steps_flat.extend(step.steps)
            else:
                steps_flat.append(coerce_to_runnable(step, name=None, trace=True))
        if len(steps_flat) < 2:
            raise ValueError(
                f"RunnableSeq must have at least 2 steps, got {len(steps_flat)}"
            )
        self.steps = steps_flat
        self.name = name
        self.trace_inputs = trace_inputs

    def __or__(
        self,
        other: Any,
    ) -> Runnable:
        if isinstance(other, RunnableSequence):
            return RunnableSeq(
                *self.steps,
                other.first,
                *other.middle,
                other.last,
                name=self.name or other.name,
            )
        elif isinstance(other, RunnableSeq):
            return RunnableSeq(
                *self.steps,
                *other.steps,
                name=self.name or other.name,
            )
        else:
            return RunnableSeq(
                *self.steps,
                coerce_to_runnable(other, name=None, trace=True),
                name=self.name,
            )

    def __ror__(
        self,
        other: Any,
    ) -> Runnable:
        if isinstance(other, RunnableSequence):
            return RunnableSequence(
                other.first,
                *other.middle,
                other.last,
                *self.steps,
                name=other.name or self.name,
            )
        elif isinstance(other, RunnableSeq):
            return RunnableSeq(
                *other.steps,
                *self.steps,
                name=other.name or self.name,
            )
        else:
            return RunnableSequence(
                coerce_to_runnable(other, name=None, trace=True),
                *self.steps,
                name=self.name,
            )

    def invoke(
        self, input: Input, config: RunnableConfig | None = None, **kwargs: Any
    ) -> Any:
        if config is None:
            config = ensure_config()
        # setup callbacks and context
        callback_manager = get_callback_manager_for_config(config)
        # start the root run
        run_manager = callback_manager.on_chain_start(
            None,
            self.trace_inputs(input) if self.trace_inputs is not None else input,
            name=config.get("run_name") or self.get_name(),
            run_id=config.pop("run_id", None),
        )
        # invoke all steps in sequence
        try:
            for i, step in enumerate(self.steps):
                # mark each step as a child run
                config = patch_config(
                    config, callbacks=run_manager.get_child(f"seq:step:{i + 1}")
                )
                # 1st step is the actual node,
                # others are writers which don't need to be run in context
                if i == 0:
                    run = _get_tracer_run(run_manager)
                    # run in context
                    with set_config_context(config, run) as context:
                        input = context.run(step.invoke, input, config, **kwargs)
                else:
                    input = step.invoke(input, config)
        # finish the root run
        except BaseException as e:
            run_manager.on_chain_error(e)
            raise
        else:
            run_manager.on_chain_end(input)
            return input

    async def ainvoke(
        self,
        input: Input,
        config: RunnableConfig | None = None,
        **kwargs: Any | None,
    ) -> Any:
        if config is None:
            config = ensure_config()
        # setup callbacks
        callback_manager = get_async_callback_manager_for_config(config)
        # start the root run
        run_manager = await callback_manager.on_chain_start(
            None,
            self.trace_inputs(input) if self.trace_inputs is not None else input,
            name=config.get("run_name") or self.get_name(),
            run_id=config.pop("run_id", None),
        )

        # invoke all steps in sequence
        try:
            for i, step in enumerate(self.steps):
                # mark each step as a child run
                config = patch_config(
                    config, callbacks=run_manager.get_child(f"seq:step:{i + 1}")
                )
                # 1st step is the actual node,
                # others are writers which don't need to be run in context
                if i == 0:
                    if ASYNCIO_ACCEPTS_CONTEXT:
                        run = _get_tracer_run(run_manager)
                        # run in context
                        with set_config_context(config, run) as context:
                            input = await asyncio.create_task(
                                step.ainvoke(input, config, **kwargs), context=context
                            )
                    else:
                        input = await step.ainvoke(input, config, **kwargs)
                else:
                    input = await step.ainvoke(input, config)
        # finish the root run
        except BaseException as e:
            await run_manager.on_chain_error(e)
            raise
        else:
            await run_manager.on_chain_end(input)
            return input

    def stream(
        self,
        input: Input,
        config: RunnableConfig | None = None,
        **kwargs: Any | None,
    ) -> Iterator[Any]:
        if config is None:
            config = ensure_config()
        # setup callbacks
        callback_manager = get_callback_manager_for_config(config)
        # start the root run
        run_manager = callback_manager.on_chain_start(
            None,
            self.trace_inputs(input) if self.trace_inputs is not None else input,
            name=config.get("run_name") or self.get_name(),
            run_id=config.pop("run_id", None),
        )
        run = _get_tracer_run(run_manager)
        # create first step config
        config = patch_config(
            config,
            callbacks=run_manager.get_child("seq:step:1"),
        )
        # run all in context
        with set_config_context(config, run) as context:
            try:
                # Chain the steps: the first step streams, each later step
                # transforms the previous iterator. Steps that don't natively
                # support transforming an input stream buffer their input in
                # memory until it is complete, then start emitting output.
                for idx, step in enumerate(self.steps):
                    if idx == 0:
                        iterator = step.stream(input, config, **kwargs)
                    else:
                        config = patch_config(
                            config,
                            callbacks=run_manager.get_child(f"seq:step:{idx + 1}"),
                        )
                        iterator = step.transform(iterator, config)
                # populates streamed_output in astream_log() output if needed
                if _StreamingCallbackHandler is not None:
                    for h in run_manager.handlers:
                        if isinstance(h, _StreamingCallbackHandler):
                            iterator = h.tap_output_iter(run_manager.run_id, iterator)
                # consume into final output
                output = context.run(_consume_iter, iterator)
                # sequence doesn't emit output, yield to mark as generator
                yield
            except BaseException as e:
                run_manager.on_chain_error(e)
                raise
            else:
                run_manager.on_chain_end(output)

    async def astream(
        self,
        input: Input,
        config: RunnableConfig | None = None,
        **kwargs: Any | None,
    ) -> AsyncIterator[Any]:
        if config is None:
            config = ensure_config()
        # setup callbacks
        callback_manager = get_async_callback_manager_for_config(config)
        # start the root run
        run_manager = await callback_manager.on_chain_start(
            None,
            self.trace_inputs(input) if self.trace_inputs is not None else input,
            name=config.get("run_name") or self.get_name(),
            run_id=config.pop("run_id", None),
        )
        # Chain the steps: the first step streams, each later step transforms
        # the previous iterator. Steps that don't natively support transforming
        # an input stream buffer their input in memory until it is complete,
        # then start emitting output.
        if ASYNCIO_ACCEPTS_CONTEXT:
            run = _get_tracer_run(run_manager)
            # create first step config
            config = patch_config(
                config,
                callbacks=run_manager.get_child("seq:step:1"),
            )
            # run all in context
            with set_config_context(config, run) as context:
                try:
                    async with AsyncExitStack() as stack:
                        for idx, step in enumerate(self.steps):
                            if idx == 0:
                                aiterator = step.astream(input, config, **kwargs)
                            else:
                                config = patch_config(
                                    config,
                                    callbacks=run_manager.get_child(
                                        f"seq:step:{idx + 1}"
                                    ),
                                )
                                aiterator = step.atransform(aiterator, config)
                            # close every step's iterator when the stack unwinds
                            if hasattr(aiterator, "aclose"):
                                stack.push_async_callback(aiterator.aclose)
                        # populates streamed_output in astream_log() output if needed
                        if _StreamingCallbackHandler is not None:
                            for h in run_manager.handlers:
                                if isinstance(h, _StreamingCallbackHandler):
                                    aiterator = h.tap_output_aiter(
                                        run_manager.run_id, aiterator
                                    )
                        # consume into final output
                        output = await asyncio.create_task(
                            _consume_aiter(aiterator), context=context
                        )
                        # sequence doesn't emit output, yield to mark as generator
                        yield
                except BaseException as e:
                    await run_manager.on_chain_error(e)
                    raise
                else:
                    await run_manager.on_chain_end(output)
        else:
            try:
                async with AsyncExitStack() as stack:
                    for idx, step in enumerate(self.steps):
                        config = patch_config(
                            config,
                            callbacks=run_manager.get_child(f"seq:step:{idx + 1}"),
                        )
                        if idx == 0:
                            aiterator = step.astream(input, config, **kwargs)
                        else:
                            aiterator = step.atransform(aiterator, config)
                        # close every step's iterator when the stack unwinds
                        if hasattr(aiterator, "aclose"):
                            stack.push_async_callback(aiterator.aclose)
                    # populates streamed_output in astream_log() output if needed
                    if _StreamingCallbackHandler is not None:
                        for h in run_manager.handlers:
                            if isinstance(h, _StreamingCallbackHandler):
                                aiterator = h.tap_output_aiter(
                                    run_manager.run_id, aiterator
                                )
                    # consume into final output
                    output = await _consume_aiter(aiterator)
                    # sequence doesn't emit output, yield to mark as generator
                    yield
            except BaseException as e:
                await run_manager.on_chain_error(e)
                raise
            else:
                await run_manager.on_chain_end(output)


def _consume_iter(it: Iterator[Any]) -> Any:
    """Exhaust `it` and return its last chunk (`None` if it yielded nothing).

    Chunks are not concatenated; earlier chunks are discarded.
    """
    output: Any = None
    for chunk in it:
        output = chunk
    return output


async def _consume_aiter(it: AsyncIterator[Any]) -> Any:
    """Exhaust `it` and return its last chunk (`None` if it yielded nothing).

    Chunks are not concatenated; earlier chunks are discarded.
    """
    output: Any = None
    async for chunk in it:
        output = chunk
    return output