80 lines
2.2 KiB
Python
80 lines
2.2 KiB
Python
"""Chain that runs an arbitrary python function."""
|
|
|
|
import functools
|
|
import logging
|
|
from collections.abc import Awaitable, Callable
|
|
from typing import Any
|
|
|
|
from langchain_core.callbacks import (
|
|
AsyncCallbackManagerForChainRun,
|
|
CallbackManagerForChainRun,
|
|
)
|
|
from pydantic import Field
|
|
from typing_extensions import override
|
|
|
|
from langchain_classic.chains.base import Chain
|
|
|
|
# Module-level logger; TransformChain._log_once emits its fallback warning here.
logger: logging.Logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
class TransformChain(Chain):
    """Chain whose outputs are produced by a user-supplied transform function.

    The chain passes its input dictionary to ``transform`` and returns the
    dictionary that function produces. An optional ``atransform`` coroutine
    function is used on the async path; when it is not provided, the async
    path logs a one-time warning and calls the synchronous ``transform``.

    Example:
        ```python
        from langchain_classic.chains import TransformChain


        def extract_entities(inputs: dict) -> dict:
            return {"entities": inputs["text"].split()}


        transform_chain = TransformChain(
            input_variables=["text"],
            output_variables=["entities"],
            transform=extract_entities,
        )
        ```
    """

    input_variables: list[str]
    """The keys expected by the transform's input dictionary."""
    output_variables: list[str]
    """The keys returned by the transform's output dictionary."""
    # Value type is ``Any`` to agree with ``atransform_cb`` and ``_acall``;
    # pydantic only checks that the supplied object is callable.
    transform_cb: Callable[[dict[str, Any]], dict[str, Any]] = Field(alias="transform")
    """The synchronous transform function, supplied as ``transform=``."""
    atransform_cb: Callable[[dict[str, Any]], Awaitable[dict[str, Any]]] | None = Field(
        None, alias="atransform"
    )
    """The async coroutine transform function, supplied as ``atransform=``."""

    @staticmethod
    @functools.lru_cache
    def _log_once(msg: str) -> None:
        """Log ``msg`` as a warning at most once per distinct message.

        ``functools.lru_cache`` memoizes the call on ``msg``, so repeating an
        already-seen message is a no-op (subject to the cache's default size
        bound).
        """
        logger.warning(msg)

    @property
    def input_keys(self) -> list[str]:
        """Return the input keys this chain expects (``input_variables``)."""
        return self.input_variables

    @property
    def output_keys(self) -> list[str]:
        """Return the output keys this chain produces (``output_variables``)."""
        return self.output_variables

    @override
    def _call(
        self,
        inputs: dict[str, Any],
        run_manager: CallbackManagerForChainRun | None = None,
    ) -> dict[str, Any]:
        """Apply the synchronous transform to ``inputs``.

        Args:
            inputs: The chain's input dictionary.
            run_manager: Not used; the transform function receives no callbacks.

        Returns:
            The dictionary returned by ``transform_cb``.
        """
        return self.transform_cb(inputs)

    @override
    async def _acall(
        self,
        inputs: dict[str, Any],
        run_manager: AsyncCallbackManagerForChainRun | None = None,
    ) -> dict[str, Any]:
        """Apply the async transform to ``inputs``, falling back to the sync one.

        Args:
            inputs: The chain's input dictionary.
            run_manager: Not used; the transform function receives no callbacks.

        Returns:
            The dictionary produced by ``atransform_cb`` when it is set,
            otherwise the dictionary returned by ``transform_cb``.
        """
        if self.atransform_cb is not None:
            return await self.atransform_cb(inputs)
        # No coroutine was supplied: warn (once) and run the sync transform
        # directly on the event loop. NOTE(review): a slow ``transform`` blocks
        # the loop here; callers can pass ``atransform`` to avoid that.
        self._log_once(
            "TransformChain's atransform is not provided, falling"
            " back to synchronous transform",
        )
        return self.transform_cb(inputs)
|