136 lines
3.9 KiB
Python
136 lines
3.9 KiB
Python
"""Model fallback middleware for agents."""
|
|
|
|
from __future__ import annotations
|
|
|
|
from typing import TYPE_CHECKING
|
|
|
|
from langchain.agents.middleware.types import (
|
|
AgentMiddleware,
|
|
ModelCallResult,
|
|
ModelRequest,
|
|
ModelResponse,
|
|
)
|
|
from langchain.chat_models import init_chat_model
|
|
|
|
if TYPE_CHECKING:
|
|
from collections.abc import Awaitable, Callable
|
|
|
|
from langchain_core.language_models.chat_models import BaseChatModel
|
|
|
|
|
|
class ModelFallbackMiddleware(AgentMiddleware):
    """Automatic fallback to alternative models on errors.

    Retries failed model calls with alternative models in sequence until
    success or all models exhausted. Primary model specified in `create_agent`.

    Example:
        ```python
        from langchain.agents.middleware.model_fallback import ModelFallbackMiddleware
        from langchain.agents import create_agent

        fallback = ModelFallbackMiddleware(
            "openai:gpt-4o-mini",  # Try first on error
            "anthropic:claude-sonnet-4-5-20250929",  # Then this
        )

        agent = create_agent(
            model="openai:gpt-4o",  # Primary model
            middleware=[fallback],
        )

        # If primary fails: tries gpt-4o-mini, then claude-sonnet-4-5-20250929
        result = await agent.invoke({"messages": [HumanMessage("Hello")]})
        ```
    """

    def __init__(
        self,
        first_model: str | BaseChatModel,
        *additional_models: str | BaseChatModel,
    ) -> None:
        """Initialize model fallback middleware.

        Args:
            first_model: First fallback model (string name or instance).
            *additional_models: Additional fallbacks in order.
        """
        super().__init__()
        # Resolve string identifiers to model instances once, up front, so
        # every retry reuses the same initialized model objects.
        self.models: list[BaseChatModel] = [
            init_chat_model(model) if isinstance(model, str) else model
            for model in (first_model, *additional_models)
        ]

    def wrap_model_call(
        self,
        request: ModelRequest,
        handler: Callable[[ModelRequest], ModelResponse],
    ) -> ModelCallResult:
        """Try fallback models in sequence on errors.

        Args:
            request: Initial model request.
            handler: Callback to execute the model.

        Returns:
            AIMessage from successful model call.

        Raises:
            Exception: If all models fail, re-raises last exception.
        """
        # NOTE: broad `except Exception` is deliberate here — any model
        # failure (rate limit, auth, network, provider error) should trigger
        # the next fallback rather than propagate immediately.
        last_exception: Exception

        # Try the primary model configured on the agent first.
        try:
            return handler(request)
        except Exception as e:
            last_exception = e

        # Then each fallback model, in the order given at construction.
        for fallback_model in self.models:
            try:
                return handler(request.override(model=fallback_model))
            except Exception as e:
                last_exception = e

        # Every model failed; surface the most recent error to the caller.
        raise last_exception

    async def awrap_model_call(
        self,
        request: ModelRequest,
        handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
    ) -> ModelCallResult:
        """Try fallback models in sequence on errors (async version).

        Args:
            request: Initial model request.
            handler: Async callback to execute the model.

        Returns:
            AIMessage from successful model call.

        Raises:
            Exception: If all models fail, re-raises last exception.
        """
        # NOTE: broad `except Exception` is deliberate here — any model
        # failure should trigger the next fallback (see wrap_model_call).
        last_exception: Exception

        # Try the primary model configured on the agent first.
        try:
            return await handler(request)
        except Exception as e:
            last_exception = e

        # Then each fallback model, in the order given at construction.
        for fallback_model in self.models:
            try:
                return await handler(request.override(model=fallback_model))
            except Exception as e:
                last_exception = e

        # Every model failed; surface the most recent error to the caller.
        raise last_exception
|