from __future__ import annotations
import warnings
from collections.abc import AsyncIterator, Callable, Iterator, Sequence
from importlib import util
from typing import Any, Literal, TypeAlias, cast, overload
from langchain_core.language_models import (
BaseChatModel,
LanguageModelInput,
SimpleChatModel,
)
from langchain_core.language_models.chat_models import (
agenerate_from_stream,
generate_from_stream,
)
from langchain_core.messages import AIMessage, AnyMessage
from langchain_core.runnables import Runnable, RunnableConfig, ensure_config
from langchain_core.runnables.schema import StreamEvent
from langchain_core.tools import BaseTool
from langchain_core.tracers import RunLog, RunLogPatch
from pydantic import BaseModel
from typing_extensions import override
__all__ = [
# For backwards compatibility
"BaseChatModel",
"SimpleChatModel",
"agenerate_from_stream",
"generate_from_stream",
"init_chat_model",
]
@overload
def init_chat_model(
model: str,
*,
model_provider: str | None = None,
configurable_fields: None = None,
config_prefix: str | None = None,
**kwargs: Any,
) -> BaseChatModel: ...
@overload
def init_chat_model(
model: None = None,
*,
model_provider: str | None = None,
configurable_fields: None = None,
config_prefix: str | None = None,
**kwargs: Any,
) -> _ConfigurableModel: ...
@overload
def init_chat_model(
model: str | None = None,
*,
model_provider: str | None = None,
configurable_fields: Literal["any"] | list[str] | tuple[str, ...] = ...,
config_prefix: str | None = None,
**kwargs: Any,
) -> _ConfigurableModel: ...
# FOR CONTRIBUTORS: If adding support for a new provider, please append the provider
# name to the supported list in the docstring below. Do *not* change the order of the
# existing providers.
def init_chat_model(
model: str | None = None,
*,
model_provider: str | None = None,
configurable_fields: Literal["any"] | list[str] | tuple[str, ...] | None = None,
config_prefix: str | None = None,
**kwargs: Any,
) -> BaseChatModel | _ConfigurableModel:
"""Initialize a chat model from any supported provider using a unified interface.
**Two main use cases:**
1. **Fixed model**: specify the model upfront and get back a ready-to-use chat
model.
2. **Configurable model**: specify parameters (including the model name) at
runtime via `config`. This makes it easy to switch between models/providers
without changing your code.
!!! note
Requires the integration package for the chosen model provider to be installed.
See the `model_provider` parameter below for specific package names
(e.g., `pip install langchain-openai`).
Refer to the [provider integration's API reference](https://docs.langchain.com/oss/python/integrations/providers)
for supported model parameters to use as `**kwargs`.
Args:
model: The name or ID of the model, e.g. `'o3-mini'`, `'claude-sonnet-4-5-20250929'`.
You can also specify model and model provider in a single argument using
`'{model_provider}:{model}'` format, e.g. `'openai:o1'`.
Will attempt to infer `model_provider` from model if not specified.
The following providers will be inferred based on these model prefixes:
- `gpt-...` | `o1...` | `o3...` -> `openai`
- `claude...` -> `anthropic`
- `amazon...` -> `bedrock`
- `gemini...` -> `google_vertexai`
- `command...` -> `cohere`
- `accounts/fireworks...` -> `fireworks`
- `mistral...` -> `mistralai`
- `deepseek...` -> `deepseek`
- `grok...` -> `xai`
- `sonar...` -> `perplexity`
- `solar...` -> `upstage`
model_provider: The model provider if not specified as part of the model arg
(see above).
Supported `model_provider` values and the corresponding integration package
are:
- `openai` -> [`langchain-openai`](https://docs.langchain.com/oss/python/integrations/providers/openai)
- `anthropic` -> [`langchain-anthropic`](https://docs.langchain.com/oss/python/integrations/providers/anthropic)
- `azure_openai` -> [`langchain-openai`](https://docs.langchain.com/oss/python/integrations/providers/openai)
- `azure_ai` -> [`langchain-azure-ai`](https://docs.langchain.com/oss/python/integrations/providers/microsoft)
- `google_vertexai` -> [`langchain-google-vertexai`](https://docs.langchain.com/oss/python/integrations/providers/google)
- `google_genai` -> [`langchain-google-genai`](https://docs.langchain.com/oss/python/integrations/providers/google)
- `bedrock` -> [`langchain-aws`](https://docs.langchain.com/oss/python/integrations/providers/aws)
- `bedrock_converse` -> [`langchain-aws`](https://docs.langchain.com/oss/python/integrations/providers/aws)
- `cohere` -> [`langchain-cohere`](https://docs.langchain.com/oss/python/integrations/providers/cohere)
- `fireworks` -> [`langchain-fireworks`](https://docs.langchain.com/oss/python/integrations/providers/fireworks)
- `together` -> [`langchain-together`](https://docs.langchain.com/oss/python/integrations/providers/together)
- `mistralai` -> [`langchain-mistralai`](https://docs.langchain.com/oss/python/integrations/providers/mistralai)
- `huggingface` -> [`langchain-huggingface`](https://docs.langchain.com/oss/python/integrations/providers/huggingface)
- `groq` -> [`langchain-groq`](https://docs.langchain.com/oss/python/integrations/providers/groq)
- `ollama` -> [`langchain-ollama`](https://docs.langchain.com/oss/python/integrations/providers/ollama)
- `google_anthropic_vertex` -> [`langchain-google-vertexai`](https://docs.langchain.com/oss/python/integrations/providers/google)
- `deepseek` -> [`langchain-deepseek`](https://docs.langchain.com/oss/python/integrations/providers/deepseek)
- `ibm` -> [`langchain-ibm`](https://docs.langchain.com/oss/python/integrations/providers/ibm)
- `nvidia` -> [`langchain-nvidia-ai-endpoints`](https://docs.langchain.com/oss/python/integrations/providers/nvidia)
- `xai` -> [`langchain-xai`](https://docs.langchain.com/oss/python/integrations/providers/xai)
- `perplexity` -> [`langchain-perplexity`](https://docs.langchain.com/oss/python/integrations/providers/perplexity)
- `upstage` -> `langchain-upstage`
configurable_fields: Which model parameters are configurable at runtime:
- `None`: No configurable fields (i.e., a fixed model).
- `'any'`: All fields are configurable. **See security note below.**
- `list[str] | tuple[str, ...]`: Specified fields are configurable.
Fields are assumed to have `config_prefix` stripped if a `config_prefix` is
specified.
If `model` is specified, then defaults to `None`.
If `model` is not specified, then defaults to `("model", "model_provider")`.
!!! warning "Security note"
Setting `configurable_fields="any"` means fields like `api_key`,
`base_url`, etc., can be altered at runtime, potentially redirecting
model requests to a different service/user.
If you're accepting untrusted configurations, make sure to enumerate the
`configurable_fields=(...)` explicitly.
config_prefix: Optional prefix for configuration keys.
Useful when you have multiple configurable models in the same application.
If `config_prefix` is a non-empty string then `model` will be configurable
at runtime via the `config["configurable"]["{config_prefix}_{param}"]` keys.
See examples below.
If `config_prefix` is an empty string then `model` will be configurable via
`config["configurable"]["{param}"]`.
**kwargs: Additional model-specific keyword args to pass to the underlying
chat model's `__init__` method. Common parameters include:
- `temperature`: Model temperature for controlling randomness.
- `max_tokens`: Maximum number of output tokens.
- `timeout`: Maximum time (in seconds) to wait for a response.
- `max_retries`: Maximum number of retry attempts for failed requests.
- `base_url`: Custom API endpoint URL.
- `rate_limiter`: A
[`BaseRateLimiter`][langchain_core.rate_limiters.BaseRateLimiter]
instance to control request rate.
Refer to the specific model provider's
[integration reference](https://reference.langchain.com/python/integrations/)
for all available parameters.
Returns:
A [`BaseChatModel`][langchain_core.language_models.BaseChatModel] corresponding
to the specified `model` and `model_provider` if no fields are configurable.
Otherwise, a chat model emulator that initializes the underlying model at
runtime once a config is passed in.
Raises:
ValueError: If `model_provider` cannot be inferred or isn't supported.
ImportError: If the model provider integration package is not installed.
???+ example "Initialize a non-configurable model"
```python
# pip install langchain langchain-openai langchain-anthropic langchain-google-vertexai
from langchain_classic.chat_models import init_chat_model
o3_mini = init_chat_model("openai:o3-mini", temperature=0)
claude_sonnet = init_chat_model("anthropic:claude-sonnet-4-5-20250929", temperature=0)
gemini_2_5_flash = init_chat_model(
"google_vertexai:gemini-2.5-flash", temperature=0
)
o3_mini.invoke("what's your name")
claude_sonnet.invoke("what's your name")
gemini_2_5_flash.invoke("what's your name")
```
??? example "Partially configurable model with no default"
```python
# pip install langchain langchain-openai langchain-anthropic
from langchain_classic.chat_models import init_chat_model
# (We don't need to specify configurable=True if a model isn't specified.)
configurable_model = init_chat_model(temperature=0)
configurable_model.invoke(
"what's your name", config={"configurable": {"model": "gpt-4o"}}
)
# Use GPT-4o to generate the response
configurable_model.invoke(
"what's your name",
config={"configurable": {"model": "claude-sonnet-4-5-20250929"}},
)
```
??? example "Fully configurable model with a default"
```python
# pip install langchain langchain-openai langchain-anthropic
from langchain_classic.chat_models import init_chat_model
configurable_model_with_default = init_chat_model(
"openai:gpt-4o",
configurable_fields="any", # This allows us to configure other params like temperature, max_tokens, etc at runtime.
config_prefix="foo",
temperature=0,
)
configurable_model_with_default.invoke("what's your name")
# GPT-4o response with temperature 0 (as set in default)
configurable_model_with_default.invoke(
"what's your name",
config={
"configurable": {
"foo_model": "anthropic:claude-sonnet-4-5-20250929",
"foo_temperature": 0.6,
}
},
)
# Override default to use Sonnet 4.5 with temperature 0.6 to generate response
```
??? example "Bind tools to a configurable model"
You can call any of the declarative chat model methods (e.g. `bind_tools`) on a
configurable model in the same way that you would with a regular chat model:
```python
# pip install langchain langchain-openai langchain-anthropic
from langchain_classic.chat_models import init_chat_model
from pydantic import BaseModel, Field
class GetWeather(BaseModel):
'''Get the current weather in a given location'''
location: str = Field(
..., description="The city and state, e.g. San Francisco, CA"
)
class GetPopulation(BaseModel):
'''Get the current population in a given location'''
location: str = Field(
..., description="The city and state, e.g. San Francisco, CA"
)
configurable_model = init_chat_model(
"gpt-4o", configurable_fields=("model", "model_provider"), temperature=0
)
configurable_model_with_tools = configurable_model.bind_tools(
[
GetWeather,
GetPopulation,
]
)
configurable_model_with_tools.invoke(
"Which city is hotter today and which is bigger: LA or NY?"
)
# Use GPT-4o
configurable_model_with_tools.invoke(
"Which city is hotter today and which is bigger: LA or NY?",
config={"configurable": {"model": "claude-sonnet-4-5-20250929"}},
)
# Use Sonnet 4.5
```
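??? example "Explicitly enumerate configurable fields"
    A minimal sketch (model name and values are illustrative) of the security
    note above: enumerate `configurable_fields` explicitly so that untrusted
    configs can only override the listed params; keys outside the enumerated
    fields (e.g. `base_url`) are dropped at runtime.
    ```python
    # pip install langchain langchain-openai
    from langchain_classic.chat_models import init_chat_model
    restricted_model = init_chat_model(
        "openai:gpt-4o",
        configurable_fields=("temperature",),
        temperature=0,
    )
    restricted_model.invoke(
        "what's your name",
        # base_url is ignored here because it is not an enumerated configurable field
        config={"configurable": {"temperature": 0.6, "base_url": "https://example.invalid"}},
    )
    ```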
!!! warning "Behavior changed in `langchain` 0.2.8"
Support for `configurable_fields` and `config_prefix` added.
!!! warning "Behavior changed in `langchain` 0.2.12"
Support for Ollama via langchain-ollama package added
(`langchain_ollama.ChatOllama`). Previously,
the now-deprecated langchain-community version of Ollama was imported
(`langchain_community.chat_models.ChatOllama`).
Support for AWS Bedrock models via the Converse API added
(`model_provider="bedrock_converse"`).
!!! warning "Behavior changed in `langchain` 0.3.5"
Out of beta.
!!! warning "Behavior changed in `langchain` 0.3.19"
Support for Deepseek, IBM, Nvidia, and xAI models added.
""" # noqa: E501
if not model and not configurable_fields:
configurable_fields = ("model", "model_provider")
config_prefix = config_prefix or ""
if config_prefix and not configurable_fields:
warnings.warn(
f"{config_prefix=} has been set but no fields are configurable. Set "
f"`configurable_fields=(...)` to specify the model params that are "
f"configurable.",
stacklevel=2,
)
if not configurable_fields:
return _init_chat_model_helper(
cast("str", model),
model_provider=model_provider,
**kwargs,
)
if model:
kwargs["model"] = model
if model_provider:
kwargs["model_provider"] = model_provider
return _ConfigurableModel(
default_config=kwargs,
config_prefix=config_prefix,
configurable_fields=configurable_fields,
)
def _init_chat_model_helper(
model: str,
*,
model_provider: str | None = None,
**kwargs: Any,
) -> BaseChatModel:
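    """Instantiate the provider-specific chat model for a single model string.
    Resolves the provider via `_parse_model`, verifies that the corresponding
    integration package is installed via `_check_pkg`, and forwards `**kwargs`
    to the provider's chat model constructor.
    """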
model, model_provider = _parse_model(model, model_provider)
if model_provider == "openai":
_check_pkg("langchain_openai")
from langchain_openai import ChatOpenAI
return ChatOpenAI(model=model, **kwargs)
if model_provider == "anthropic":
_check_pkg("langchain_anthropic")
from langchain_anthropic import ChatAnthropic
return ChatAnthropic(model=model, **kwargs) # type: ignore[call-arg,unused-ignore]
if model_provider == "azure_openai":
_check_pkg("langchain_openai")
from langchain_openai import AzureChatOpenAI
return AzureChatOpenAI(model=model, **kwargs)
if model_provider == "azure_ai":
_check_pkg("langchain_azure_ai")
from langchain_azure_ai.chat_models import AzureAIChatCompletionsModel
return AzureAIChatCompletionsModel(model=model, **kwargs)
if model_provider == "cohere":
_check_pkg("langchain_cohere")
from langchain_cohere import ChatCohere
return ChatCohere(model=model, **kwargs)
if model_provider == "google_vertexai":
_check_pkg("langchain_google_vertexai")
from langchain_google_vertexai import ChatVertexAI
return ChatVertexAI(model=model, **kwargs)
if model_provider == "google_genai":
_check_pkg("langchain_google_genai")
from langchain_google_genai import ChatGoogleGenerativeAI
return ChatGoogleGenerativeAI(model=model, **kwargs)
if model_provider == "fireworks":
_check_pkg("langchain_fireworks")
from langchain_fireworks import ChatFireworks
return ChatFireworks(model=model, **kwargs)
if model_provider == "ollama":
try:
_check_pkg("langchain_ollama")
from langchain_ollama import ChatOllama
except ImportError:
# For backwards compatibility
try:
_check_pkg("langchain_community")
from langchain_community.chat_models import ChatOllama
except ImportError:
# If both langchain-ollama and langchain-community aren't available,
# raise an error related to langchain-ollama
_check_pkg("langchain_ollama")
return ChatOllama(model=model, **kwargs)
if model_provider == "together":
_check_pkg("langchain_together")
from langchain_together import ChatTogether
return ChatTogether(model=model, **kwargs)
if model_provider == "mistralai":
_check_pkg("langchain_mistralai")
from langchain_mistralai import ChatMistralAI
return ChatMistralAI(model=model, **kwargs) # type: ignore[call-arg,unused-ignore]
if model_provider == "huggingface":
_check_pkg("langchain_huggingface")
from langchain_huggingface import ChatHuggingFace
return ChatHuggingFace.from_model_id(model_id=model, **kwargs)
if model_provider == "groq":
_check_pkg("langchain_groq")
from langchain_groq import ChatGroq
return ChatGroq(model=model, **kwargs)
if model_provider == "bedrock":
_check_pkg("langchain_aws")
from langchain_aws import ChatBedrock
# TODO: update to use model= once ChatBedrock supports it
return ChatBedrock(model_id=model, **kwargs)
if model_provider == "bedrock_converse":
_check_pkg("langchain_aws")
from langchain_aws import ChatBedrockConverse
return ChatBedrockConverse(model=model, **kwargs)
if model_provider == "google_anthropic_vertex":
_check_pkg("langchain_google_vertexai")
from langchain_google_vertexai.model_garden import ChatAnthropicVertex
return ChatAnthropicVertex(model=model, **kwargs)
if model_provider == "deepseek":
_check_pkg("langchain_deepseek", pkg_kebab="langchain-deepseek")
from langchain_deepseek import ChatDeepSeek
return ChatDeepSeek(model=model, **kwargs)
if model_provider == "nvidia":
_check_pkg("langchain_nvidia_ai_endpoints")
from langchain_nvidia_ai_endpoints import ChatNVIDIA
return ChatNVIDIA(model=model, **kwargs)
if model_provider == "ibm":
_check_pkg("langchain_ibm")
from langchain_ibm import ChatWatsonx
return ChatWatsonx(model_id=model, **kwargs)
if model_provider == "xai":
_check_pkg("langchain_xai")
from langchain_xai import ChatXAI
return ChatXAI(model=model, **kwargs)
if model_provider == "perplexity":
_check_pkg("langchain_perplexity")
from langchain_perplexity import ChatPerplexity
return ChatPerplexity(model=model, **kwargs)
if model_provider == "upstage":
_check_pkg("langchain_upstage")
from langchain_upstage import ChatUpstage
return ChatUpstage(model=model, **kwargs)
supported = ", ".join(_SUPPORTED_PROVIDERS)
msg = (
f"Unsupported {model_provider=}.\n\nSupported model providers are: {supported}"
)
raise ValueError(msg)
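# Providers accepted by init_chat_model, either via the model_provider arg or as the
# prefix in the "{model_provider}:{model}" shorthand.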
_SUPPORTED_PROVIDERS = {
"openai",
"anthropic",
"azure_openai",
"azure_ai",
"cohere",
"google_vertexai",
"google_genai",
"fireworks",
"ollama",
"together",
"mistralai",
"huggingface",
"groq",
"bedrock",
"bedrock_converse",
"google_anthropic_vertex",
"deepseek",
"ibm",
"xai",
"perplexity",
"upstage",
}
def _attempt_infer_model_provider(model_name: str) -> str | None:
"""Attempt to infer model provider from model name.
Args:
model_name: The name of the model to infer provider for.
Returns:
The inferred provider name, or `None` if no provider could be inferred.
"""
model_lower = model_name.lower()
# OpenAI models (including newer models and aliases)
if any(
model_lower.startswith(pre)
for pre in (
"gpt-",
"o1",
"o3",
"chatgpt",
"text-davinci",
)
):
return "openai"
# Anthropic models
if model_lower.startswith("claude"):
return "anthropic"
# Cohere models
if model_lower.startswith("command"):
return "cohere"
# Fireworks models
if model_name.startswith("accounts/fireworks"):
return "fireworks"
# Google models
if model_lower.startswith("gemini"):
return "google_vertexai"
# AWS Bedrock models
if model_name.startswith("amazon.") or model_lower.startswith(
("anthropic.", "meta.")
):
return "bedrock"
# Mistral models
if model_lower.startswith(("mistral", "mixtral")):
return "mistralai"
# DeepSeek models
if model_lower.startswith("deepseek"):
return "deepseek"
# xAI models
if model_lower.startswith("grok"):
return "xai"
# Perplexity models
if model_lower.startswith("sonar"):
return "perplexity"
# Upstage models
if model_lower.startswith("solar"):
return "upstage"
return None
def _parse_model(model: str, model_provider: str | None) -> tuple[str, str]:
"""Parse model name and provider, inferring provider if necessary."""
if not model_provider and ":" in model:
prefix, suffix = model.split(":", 1)
if prefix in _SUPPORTED_PROVIDERS:
model_provider = prefix
model = suffix
else:
inferred = _attempt_infer_model_provider(prefix)
if inferred:
model_provider = inferred
model = suffix
model_provider = model_provider or _attempt_infer_model_provider(model)
if not model_provider:
supported_list = ", ".join(sorted(_SUPPORTED_PROVIDERS))
msg = (
f"Unable to infer model provider for {model=}. "
f"Please specify 'model_provider' directly.\n\n"
f"Supported providers: {supported_list}\n\n"
f"For help with specific providers, see: "
f"https://docs.langchain.com/oss/python/integrations/providers"
)
raise ValueError(msg)
# Normalize provider name
model_provider = model_provider.replace("-", "_").lower()
return model, model_provider
def _check_pkg(pkg: str, *, pkg_kebab: str | None = None) -> None:
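    """Raise an ImportError with a pip install hint if `pkg` is not installed."""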
if not util.find_spec(pkg):
pkg_kebab = pkg_kebab if pkg_kebab is not None else pkg.replace("_", "-")
msg = (
f"Unable to import {pkg}. Please install with `pip install -U {pkg_kebab}`"
)
raise ImportError(msg)
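# Declarative BaseChatModel methods that can be queued on a _ConfigurableModel before
# the underlying model exists (see _ConfigurableModel.__getattr__).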
_DECLARATIVE_METHODS = ("bind_tools", "with_structured_output")
class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
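    """Chat model emulator that instantiates the underlying model at runtime.
    Holds a default config, the set of runtime-configurable fields, and an optional
    config key prefix. Declarative operations (see `_DECLARATIVE_METHODS`) are queued
    and replayed once the real model is built in `_model()`.
    """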
def __init__(
self,
*,
default_config: dict | None = None,
configurable_fields: Literal["any"] | list[str] | tuple[str, ...] = "any",
config_prefix: str = "",
queued_declarative_operations: Sequence[tuple[str, tuple, dict]] = (),
) -> None:
self._default_config: dict = default_config or {}
self._configurable_fields: Literal["any"] | list[str] = (
configurable_fields
if configurable_fields == "any"
else list(configurable_fields)
)
self._config_prefix = (
config_prefix + "_"
if config_prefix and not config_prefix.endswith("_")
else config_prefix
)
self._queued_declarative_operations: list[tuple[str, tuple, dict]] = list(
queued_declarative_operations,
)
def __getattr__(self, name: str) -> Any:
if name in _DECLARATIVE_METHODS:
# Declarative operations that cannot be applied until after an actual model
# object is instantiated. So instead of returning the actual operation,
# we record the operation and its arguments in a queue. This queue is
# then applied in order whenever we actually instantiate the model (in
# self._model()).
def queue(*args: Any, **kwargs: Any) -> _ConfigurableModel:
queued_declarative_operations = list(
self._queued_declarative_operations,
)
queued_declarative_operations.append((name, args, kwargs))
return _ConfigurableModel(
default_config=dict(self._default_config),
configurable_fields=list(self._configurable_fields)
if isinstance(self._configurable_fields, list)
else self._configurable_fields,
config_prefix=self._config_prefix,
queued_declarative_operations=queued_declarative_operations,
)
return queue
if self._default_config and (model := self._model()) and hasattr(model, name):
return getattr(model, name)
msg = f"{name} is not a BaseChatModel attribute"
if self._default_config:
msg += " and is not implemented on the default model"
msg += "."
raise AttributeError(msg)
def _model(self, config: RunnableConfig | None = None) -> Runnable:
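        """Build the underlying chat model from the default config merged with runtime
        params, then replay any queued declarative operations (e.g. `bind_tools`)."""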
params = {**self._default_config, **self._model_params(config)}
model = _init_chat_model_helper(**params)
for name, args, kwargs in self._queued_declarative_operations:
model = getattr(model, name)(*args, **kwargs)
return model
def _model_params(self, config: RunnableConfig | None) -> dict:
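        """Extract model params from `config["configurable"]`, stripping
        `config_prefix` and, unless `configurable_fields` is "any", keeping only the
        enumerated fields."""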
config = ensure_config(config)
model_params = {
k.removeprefix(self._config_prefix): v
for k, v in config.get("configurable", {}).items()
if k.startswith(self._config_prefix)
}
if self._configurable_fields != "any":
model_params = {
k: v for k, v in model_params.items() if k in self._configurable_fields
}
return model_params
def with_config(
self,
config: RunnableConfig | None = None,
**kwargs: Any,
) -> _ConfigurableModel:
"""Bind config to a `Runnable`, returning a new `Runnable`."""
config = RunnableConfig(**(config or {}), **cast("RunnableConfig", kwargs))
model_params = self._model_params(config)
remaining_config = {k: v for k, v in config.items() if k != "configurable"}
remaining_config["configurable"] = {
k: v
for k, v in config.get("configurable", {}).items()
if k.removeprefix(self._config_prefix) not in model_params
}
queued_declarative_operations = list(self._queued_declarative_operations)
if remaining_config:
queued_declarative_operations.append(
(
"with_config",
(),
{"config": remaining_config},
),
)
return _ConfigurableModel(
default_config={**self._default_config, **model_params},
configurable_fields=list(self._configurable_fields)
if isinstance(self._configurable_fields, list)
else self._configurable_fields,
config_prefix=self._config_prefix,
queued_declarative_operations=queued_declarative_operations,
)
@property
@override
def InputType(self) -> TypeAlias:
"""Get the input type for this `Runnable`."""
from langchain_core.prompt_values import (
ChatPromptValueConcrete,
StringPromptValue,
)
# This is a version of LanguageModelInput which replaces the abstract
# base class BaseMessage with a union of its subclasses, which makes
# for a much better schema.
return str | StringPromptValue | ChatPromptValueConcrete | list[AnyMessage]
@override
def invoke(
self,
input: LanguageModelInput,
config: RunnableConfig | None = None,
**kwargs: Any,
) -> Any:
return self._model(config).invoke(input, config=config, **kwargs)
@override
async def ainvoke(
self,
input: LanguageModelInput,
config: RunnableConfig | None = None,
**kwargs: Any,
) -> Any:
return await self._model(config).ainvoke(input, config=config, **kwargs)
@override
def stream(
self,
input: LanguageModelInput,
config: RunnableConfig | None = None,
**kwargs: Any | None,
) -> Iterator[Any]:
yield from self._model(config).stream(input, config=config, **kwargs)
@override
async def astream(
self,
input: LanguageModelInput,
config: RunnableConfig | None = None,
**kwargs: Any | None,
) -> AsyncIterator[Any]:
async for x in self._model(config).astream(input, config=config, **kwargs):
yield x
def batch(
self,
inputs: list[LanguageModelInput],
config: RunnableConfig | list[RunnableConfig] | None = None,
*,
return_exceptions: bool = False,
**kwargs: Any | None,
) -> list[Any]:
config = config or None
# If <= 1 config, use the underlying model's batch implementation.
if config is None or isinstance(config, dict) or len(config) <= 1:
if isinstance(config, list):
config = config[0]
return self._model(config).batch(
inputs,
config=config,
return_exceptions=return_exceptions,
**kwargs,
)
# If there are multiple configs, default to Runnable.batch, which uses an
# executor to invoke in parallel.
return super().batch(
inputs,
config=config,
return_exceptions=return_exceptions,
**kwargs,
)
async def abatch(
self,
inputs: list[LanguageModelInput],
config: RunnableConfig | list[RunnableConfig] | None = None,
*,
return_exceptions: bool = False,
**kwargs: Any | None,
) -> list[Any]:
config = config or None
# If <= 1 config, use the underlying model's abatch implementation.
if config is None or isinstance(config, dict) or len(config) <= 1:
if isinstance(config, list):
config = config[0]
return await self._model(config).abatch(
inputs,
config=config,
return_exceptions=return_exceptions,
**kwargs,
)
# If there are multiple configs, default to Runnable.abatch, which runs the
# invocations concurrently.
return await super().abatch(
inputs,
config=config,
return_exceptions=return_exceptions,
**kwargs,
)
def batch_as_completed(
self,
inputs: Sequence[LanguageModelInput],
config: RunnableConfig | Sequence[RunnableConfig] | None = None,
*,
return_exceptions: bool = False,
**kwargs: Any,
) -> Iterator[tuple[int, Any | Exception]]:
config = config or None
# If <= 1 config, use the underlying model's batch_as_completed implementation.
if config is None or isinstance(config, dict) or len(config) <= 1:
if isinstance(config, list):
config = config[0]
yield from self._model(cast("RunnableConfig", config)).batch_as_completed( # type: ignore[call-overload]
inputs,
config=config,
return_exceptions=return_exceptions,
**kwargs,
)
# If there are multiple configs, default to Runnable.batch_as_completed, which
# uses an executor to invoke in parallel.
else:
yield from super().batch_as_completed( # type: ignore[call-overload]
inputs,
config=config,
return_exceptions=return_exceptions,
**kwargs,
)
async def abatch_as_completed(
self,
inputs: Sequence[LanguageModelInput],
config: RunnableConfig | Sequence[RunnableConfig] | None = None,
*,
return_exceptions: bool = False,
**kwargs: Any,
) -> AsyncIterator[tuple[int, Any]]:
config = config or None
# If <= 1 config, use the underlying model's abatch_as_completed implementation.
if config is None or isinstance(config, dict) or len(config) <= 1:
if isinstance(config, list):
config = config[0]
async for x in self._model(
cast("RunnableConfig", config),
).abatch_as_completed( # type: ignore[call-overload]
inputs,
config=config,
return_exceptions=return_exceptions,
**kwargs,
):
yield x
# If there are multiple configs, default to Runnable.abatch_as_completed, which
# runs the invocations concurrently.
else:
async for x in super().abatch_as_completed( # type: ignore[call-overload]
inputs,
config=config,
return_exceptions=return_exceptions,
**kwargs,
):
yield x
@override
def transform(
self,
input: Iterator[LanguageModelInput],
config: RunnableConfig | None = None,
**kwargs: Any | None,
) -> Iterator[Any]:
yield from self._model(config).transform(input, config=config, **kwargs)
@override
async def atransform(
self,
input: AsyncIterator[LanguageModelInput],
config: RunnableConfig | None = None,
**kwargs: Any | None,
) -> AsyncIterator[Any]:
async for x in self._model(config).atransform(input, config=config, **kwargs):
yield x
@overload
def astream_log(
self,
input: Any,
config: RunnableConfig | None = None,
*,
diff: Literal[True] = True,
with_streamed_output_list: bool = True,
include_names: Sequence[str] | None = None,
include_types: Sequence[str] | None = None,
include_tags: Sequence[str] | None = None,
exclude_names: Sequence[str] | None = None,
exclude_types: Sequence[str] | None = None,
exclude_tags: Sequence[str] | None = None,
**kwargs: Any,
) -> AsyncIterator[RunLogPatch]: ...
@overload
def astream_log(
self,
input: Any,
config: RunnableConfig | None = None,
*,
diff: Literal[False],
with_streamed_output_list: bool = True,
include_names: Sequence[str] | None = None,
include_types: Sequence[str] | None = None,
include_tags: Sequence[str] | None = None,
exclude_names: Sequence[str] | None = None,
exclude_types: Sequence[str] | None = None,
exclude_tags: Sequence[str] | None = None,
**kwargs: Any,
) -> AsyncIterator[RunLog]: ...
@override
async def astream_log(
self,
input: Any,
config: RunnableConfig | None = None,
*,
diff: bool = True,
with_streamed_output_list: bool = True,
include_names: Sequence[str] | None = None,
include_types: Sequence[str] | None = None,
include_tags: Sequence[str] | None = None,
exclude_names: Sequence[str] | None = None,
exclude_types: Sequence[str] | None = None,
exclude_tags: Sequence[str] | None = None,
**kwargs: Any,
) -> AsyncIterator[RunLogPatch] | AsyncIterator[RunLog]:
async for x in self._model(config).astream_log( # type: ignore[call-overload, misc]
input,
config=config,
diff=diff,
with_streamed_output_list=with_streamed_output_list,
include_names=include_names,
include_types=include_types,
include_tags=include_tags,
exclude_tags=exclude_tags,
exclude_types=exclude_types,
exclude_names=exclude_names,
**kwargs,
):
yield x
@override
async def astream_events(
self,
input: Any,
config: RunnableConfig | None = None,
*,
version: Literal["v1", "v2"] = "v2",
include_names: Sequence[str] | None = None,
include_types: Sequence[str] | None = None,
include_tags: Sequence[str] | None = None,
exclude_names: Sequence[str] | None = None,
exclude_types: Sequence[str] | None = None,
exclude_tags: Sequence[str] | None = None,
**kwargs: Any,
) -> AsyncIterator[StreamEvent]:
async for x in self._model(config).astream_events(
input,
config=config,
version=version,
include_names=include_names,
include_types=include_types,
include_tags=include_tags,
exclude_tags=exclude_tags,
exclude_types=exclude_types,
exclude_names=exclude_names,
**kwargs,
):
yield x
# Explicitly added to satisfy downstream linters.
def bind_tools(
self,
tools: Sequence[dict[str, Any] | type[BaseModel] | Callable | BaseTool],
**kwargs: Any,
) -> Runnable[LanguageModelInput, AIMessage]:
return self.__getattr__("bind_tools")(tools, **kwargs)
# Explicitly added to satisfy downstream linters.
def with_structured_output(
self,
schema: dict | type[BaseModel],
**kwargs: Any,
) -> Runnable[LanguageModelInput, dict | BaseModel]:
return self.__getattr__("with_structured_output")(schema, **kwargs)