group-wbl/.venv/lib/python3.13/site-packages/langchain/agents/middleware/model_retry.py

301 lines
10 KiB
Python
Raw Permalink Normal View History

2026-01-09 09:48:03 +08:00
"""Model retry middleware for agents."""
from __future__ import annotations
import asyncio
import time
from typing import TYPE_CHECKING
from langchain_core.messages import AIMessage
from langchain.agents.middleware._retry import (
OnFailure,
RetryOn,
calculate_delay,
should_retry_exception,
validate_retry_params,
)
from langchain.agents.middleware.types import AgentMiddleware, ModelResponse
if TYPE_CHECKING:
from collections.abc import Awaitable, Callable
from langchain.agents.middleware.types import ModelRequest
class ModelRetryMiddleware(AgentMiddleware):
"""Middleware that automatically retries failed model calls with configurable backoff.
Supports retrying on specific exceptions and exponential backoff.
Examples:
!!! example "Basic usage with default settings (2 retries, exponential backoff)"
```python
from langchain.agents import create_agent
from langchain.agents.middleware import ModelRetryMiddleware
agent = create_agent(model, tools=[search_tool], middleware=[ModelRetryMiddleware()])
```
!!! example "Retry specific exceptions only"
```python
from anthropic import RateLimitError
from openai import APITimeoutError
retry = ModelRetryMiddleware(
max_retries=4,
retry_on=(APITimeoutError, RateLimitError),
backoff_factor=1.5,
)
```
!!! example "Custom exception filtering"
```python
from anthropic import APIStatusError
def should_retry(exc: Exception) -> bool:
# Only retry on 5xx errors
if isinstance(exc, APIStatusError):
return 500 <= exc.status_code < 600
return False
retry = ModelRetryMiddleware(
max_retries=3,
retry_on=should_retry,
)
```
!!! example "Custom error handling"
```python
def format_error(exc: Exception) -> str:
return "Model temporarily unavailable. Please try again later."
retry = ModelRetryMiddleware(
max_retries=4,
on_failure=format_error,
)
```
!!! example "Constant backoff (no exponential growth)"
```python
retry = ModelRetryMiddleware(
max_retries=5,
backoff_factor=0.0, # No exponential growth
initial_delay=2.0, # Always wait 2 seconds
)
```
!!! example "Raise exception on failure"
```python
retry = ModelRetryMiddleware(
max_retries=2,
on_failure="error", # Re-raise exception instead of returning message
)
```
"""
def __init__(
self,
*,
max_retries: int = 2,
retry_on: RetryOn = (Exception,),
on_failure: OnFailure = "continue",
backoff_factor: float = 2.0,
initial_delay: float = 1.0,
max_delay: float = 60.0,
jitter: bool = True,
) -> None:
"""Initialize `ModelRetryMiddleware`.
Args:
max_retries: Maximum number of retry attempts after the initial call.
Must be `>= 0`.
retry_on: Either a tuple of exception types to retry on, or a callable
that takes an exception and returns `True` if it should be retried.
Default is to retry on all exceptions.
on_failure: Behavior when all retries are exhausted.
Options:
- `'continue'`: Return an `AIMessage` with error details,
allowing the agent to continue with an error response.
- `'error'`: Re-raise the exception, stopping agent execution.
- **Custom callable:** Function that takes the exception and returns a
string for the `AIMessage` content, allowing custom error
formatting.
backoff_factor: Multiplier for exponential backoff.
Each retry waits `initial_delay * (backoff_factor ** retry_number)`
seconds.
Set to `0.0` for constant delay.
initial_delay: Initial delay in seconds before first retry.
max_delay: Maximum delay in seconds between retries.
Caps exponential backoff growth.
jitter: Whether to add random jitter (`±25%`) to delay to avoid thundering herd.
Raises:
ValueError: If `max_retries < 0` or delays are negative.
"""
super().__init__()
# Validate parameters
validate_retry_params(max_retries, initial_delay, max_delay, backoff_factor)
self.max_retries = max_retries
self.tools = [] # No additional tools registered by this middleware
self.retry_on = retry_on
self.on_failure = on_failure
self.backoff_factor = backoff_factor
self.initial_delay = initial_delay
self.max_delay = max_delay
self.jitter = jitter
def _format_failure_message(self, exc: Exception, attempts_made: int) -> AIMessage:
"""Format the failure message when retries are exhausted.
Args:
exc: The exception that caused the failure.
attempts_made: Number of attempts actually made.
Returns:
`AIMessage` with formatted error message.
"""
exc_type = type(exc).__name__
exc_msg = str(exc)
attempt_word = "attempt" if attempts_made == 1 else "attempts"
content = (
f"Model call failed after {attempts_made} {attempt_word} with {exc_type}: {exc_msg}"
)
return AIMessage(content=content)
def _handle_failure(self, exc: Exception, attempts_made: int) -> ModelResponse:
"""Handle failure when all retries are exhausted.
Args:
exc: The exception that caused the failure.
attempts_made: Number of attempts actually made.
Returns:
`ModelResponse` with error details.
Raises:
Exception: If `on_failure` is `'error'`, re-raises the exception.
"""
if self.on_failure == "error":
raise exc
if callable(self.on_failure):
content = self.on_failure(exc)
ai_msg = AIMessage(content=content)
else:
ai_msg = self._format_failure_message(exc, attempts_made)
return ModelResponse(result=[ai_msg])
def wrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], ModelResponse],
) -> ModelResponse | AIMessage:
"""Intercept model execution and retry on failure.
Args:
request: Model request with model, messages, state, and runtime.
handler: Callable to execute the model (can be called multiple times).
Returns:
`ModelResponse` or `AIMessage` (the final result).
"""
# Initial attempt + retries
for attempt in range(self.max_retries + 1):
try:
return handler(request)
except Exception as exc:
attempts_made = attempt + 1 # attempt is 0-indexed
# Check if we should retry this exception
if not should_retry_exception(exc, self.retry_on):
# Exception is not retryable, handle failure immediately
return self._handle_failure(exc, attempts_made)
# Check if we have more retries left
if attempt < self.max_retries:
# Calculate and apply backoff delay
delay = calculate_delay(
attempt,
backoff_factor=self.backoff_factor,
initial_delay=self.initial_delay,
max_delay=self.max_delay,
jitter=self.jitter,
)
if delay > 0:
time.sleep(delay)
# Continue to next retry
else:
# No more retries, handle failure
return self._handle_failure(exc, attempts_made)
# Unreachable: loop always returns via handler success or _handle_failure
msg = "Unexpected: retry loop completed without returning"
raise RuntimeError(msg)
async def awrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
) -> ModelResponse | AIMessage:
"""Intercept and control async model execution with retry logic.
Args:
request: Model request with model, messages, state, and runtime.
handler: Async callable to execute the model and returns `ModelResponse`.
Returns:
`ModelResponse` or `AIMessage` (the final result).
"""
# Initial attempt + retries
for attempt in range(self.max_retries + 1):
try:
return await handler(request)
except Exception as exc:
attempts_made = attempt + 1 # attempt is 0-indexed
# Check if we should retry this exception
if not should_retry_exception(exc, self.retry_on):
# Exception is not retryable, handle failure immediately
return self._handle_failure(exc, attempts_made)
# Check if we have more retries left
if attempt < self.max_retries:
# Calculate and apply backoff delay
delay = calculate_delay(
attempt,
backoff_factor=self.backoff_factor,
initial_delay=self.initial_delay,
max_delay=self.max_delay,
jitter=self.jitter,
)
if delay > 0:
await asyncio.sleep(delay)
# Continue to next retry
else:
# No more retries, handle failure
return self._handle_failure(exc, attempts_made)
# Unreachable: loop always returns via handler success or _handle_failure
msg = "Unexpected: retry loop completed without returning"
raise RuntimeError(msg)