group-wbl/.venv/lib/python3.13/site-packages/langchain_classic/runnables/openai_functions.py

55 lines
1.9 KiB
Python
Raw Normal View History

2026-01-09 09:12:25 +08:00
from collections.abc import Callable, Mapping
from operator import itemgetter
from typing import Any
from langchain_core.messages import BaseMessage
from langchain_core.output_parsers.openai_functions import JsonOutputFunctionsParser
from langchain_core.runnables import RouterRunnable, Runnable
from langchain_core.runnables.base import RunnableBindingBase
from typing_extensions import TypedDict
class OpenAIFunction(TypedDict):
    """Typed mapping describing one callable function offered to `ChatOpenAI`.

    All three keys are required. When used with `OpenAIFunctionsRouter`,
    ``name`` must be one of the keys of the router's ``runnables`` mapping.
    """

    # Identifier of the function.
    name: str
    # Human-readable explanation of what the function does.
    description: str
    # Specification of the function's arguments.
    parameters: dict
class OpenAIFunctionsRouter(RunnableBindingBase[BaseMessage, Any]):  # type: ignore[no-redef]
    """A runnable that routes to the selected function.

    The input message is parsed by `JsonOutputFunctionsParser` into a mapping
    with ``"name"`` and ``"arguments"`` entries. The ``name`` selects which of
    the configured runnables is invoked, and ``arguments`` is passed to it as
    input via `RouterRunnable`.
    """

    # Function descriptions that were checked against the runnables at
    # construction time, or None if the router was built without them.
    functions: list[OpenAIFunction] | None

    def __init__(
        self,
        runnables: Mapping[
            str,
            Runnable[dict, Any] | Callable[[dict], Any],
        ],
        functions: list[OpenAIFunction] | None = None,
    ):
        """Initialize the `OpenAIFunctionsRouter`.

        Args:
            runnables: A mapping of function names to runnables.
            functions: Optional list of functions to check against the runnables.
                When given, its function names must correspond one-to-one with
                the keys of `runnables`.

        Raises:
            ValueError: If `functions` is given and its length differs from the
                number of runnables, one of its names is not a key of
                `runnables`, or it contains duplicate names.
        """
        if functions is not None:
            if len(functions) != len(runnables):
                msg = "The number of functions does not match the number of runnables."
                raise ValueError(msg)
            names = [func["name"] for func in functions]
            if not all(name in runnables for name in names):
                msg = "One or more function names are not found in runnables."
                raise ValueError(msg)
            # Equal length + all names present is not enough on its own:
            # duplicates would leave some runnable without a description.
            # Checked last so inputs rejected before keep their old message.
            if len(set(names)) != len(names):
                msg = "Function names must be unique."
                raise ValueError(msg)
        router = (
            JsonOutputFunctionsParser(args_only=False)
            # Reshape the parser's {"name", "arguments"} output into the
            # {"key", "input"} mapping that is piped into RouterRunnable.
            | {"key": itemgetter("name"), "input": itemgetter("arguments")}
            | RouterRunnable(runnables)
        )
        super().__init__(bound=router, kwargs={}, functions=functions)