group-wbl/.venv/lib/python3.13/site-packages/langchain_classic/chains/transform.py

80 lines
2.2 KiB
Python
Raw Permalink Normal View History

2026-01-09 09:48:03 +08:00
"""Chain that runs an arbitrary python function."""
import functools
import logging
from collections.abc import Awaitable, Callable
from typing import Any
from langchain_core.callbacks import (
AsyncCallbackManagerForChainRun,
CallbackManagerForChainRun,
)
from pydantic import Field
from typing_extensions import override
from langchain_classic.chains.base import Chain
logger = logging.getLogger(__name__)
class TransformChain(Chain):
"""Chain that transforms the chain output.
Example:
```python
from langchain_classic.chains import TransformChain
transform_chain = TransformChain(input_variables=["text"],
output_variables["entities"], transform=func())
```
"""
input_variables: list[str]
"""The keys expected by the transform's input dictionary."""
output_variables: list[str]
"""The keys returned by the transform's output dictionary."""
transform_cb: Callable[[dict[str, str]], dict[str, str]] = Field(alias="transform")
"""The transform function."""
atransform_cb: Callable[[dict[str, Any]], Awaitable[dict[str, Any]]] | None = Field(
None, alias="atransform"
)
"""The async coroutine transform function."""
@staticmethod
@functools.lru_cache
def _log_once(msg: str) -> None:
"""Log a message once."""
logger.warning(msg)
@property
def input_keys(self) -> list[str]:
"""Expect input keys."""
return self.input_variables
@property
def output_keys(self) -> list[str]:
"""Return output keys."""
return self.output_variables
@override
def _call(
self,
inputs: dict[str, str],
run_manager: CallbackManagerForChainRun | None = None,
) -> dict[str, str]:
return self.transform_cb(inputs)
@override
async def _acall(
self,
inputs: dict[str, Any],
run_manager: AsyncCallbackManagerForChainRun | None = None,
) -> dict[str, Any]:
if self.atransform_cb is not None:
return await self.atransform_cb(inputs)
self._log_once(
"TransformChain's atransform is not provided, falling"
" back to synchronous transform",
)
return self.transform_cb(inputs)