"""Chain that runs an arbitrary python function.""" import functools import logging from collections.abc import Awaitable, Callable from typing import Any from langchain_core.callbacks import ( AsyncCallbackManagerForChainRun, CallbackManagerForChainRun, ) from pydantic import Field from typing_extensions import override from langchain_classic.chains.base import Chain logger = logging.getLogger(__name__) class TransformChain(Chain): """Chain that transforms the chain output. Example: ```python from langchain_classic.chains import TransformChain transform_chain = TransformChain(input_variables=["text"], output_variables["entities"], transform=func()) ``` """ input_variables: list[str] """The keys expected by the transform's input dictionary.""" output_variables: list[str] """The keys returned by the transform's output dictionary.""" transform_cb: Callable[[dict[str, str]], dict[str, str]] = Field(alias="transform") """The transform function.""" atransform_cb: Callable[[dict[str, Any]], Awaitable[dict[str, Any]]] | None = Field( None, alias="atransform" ) """The async coroutine transform function.""" @staticmethod @functools.lru_cache def _log_once(msg: str) -> None: """Log a message once.""" logger.warning(msg) @property def input_keys(self) -> list[str]: """Expect input keys.""" return self.input_variables @property def output_keys(self) -> list[str]: """Return output keys.""" return self.output_variables @override def _call( self, inputs: dict[str, str], run_manager: CallbackManagerForChainRun | None = None, ) -> dict[str, str]: return self.transform_cb(inputs) @override async def _acall( self, inputs: dict[str, Any], run_manager: AsyncCallbackManagerForChainRun | None = None, ) -> dict[str, Any]: if self.atransform_cb is not None: return await self.atransform_cb(inputs) self._log_once( "TransformChain's atransform is not provided, falling" " back to synchronous transform", ) return self.transform_cb(inputs)