397 lines
14 KiB
Python
397 lines
14 KiB
Python
|
|
"""Tool retry middleware for agents."""
|
||
|
|
|
||
|
|
from __future__ import annotations
|
||
|
|
|
||
|
|
import asyncio
|
||
|
|
import time
|
||
|
|
import warnings
|
||
|
|
from typing import TYPE_CHECKING
|
||
|
|
|
||
|
|
from langchain_core.messages import ToolMessage
|
||
|
|
|
||
|
|
from langchain.agents.middleware._retry import (
|
||
|
|
OnFailure,
|
||
|
|
RetryOn,
|
||
|
|
calculate_delay,
|
||
|
|
should_retry_exception,
|
||
|
|
validate_retry_params,
|
||
|
|
)
|
||
|
|
from langchain.agents.middleware.types import AgentMiddleware
|
||
|
|
|
||
|
|
if TYPE_CHECKING:
|
||
|
|
from collections.abc import Awaitable, Callable
|
||
|
|
|
||
|
|
from langgraph.types import Command
|
||
|
|
|
||
|
|
from langchain.agents.middleware.types import ToolCallRequest
|
||
|
|
from langchain.tools import BaseTool
|
||
|
|
|
||
|
|
|
||
|
|
class ToolRetryMiddleware(AgentMiddleware):
    """Middleware that automatically retries failed tool calls with configurable backoff.

    Supports retrying on specific exceptions and exponential backoff.

    Examples:
        !!! example "Basic usage with default settings (2 retries, exponential backoff)"

            ```python
            from langchain.agents import create_agent
            from langchain.agents.middleware import ToolRetryMiddleware

            agent = create_agent(model, tools=[search_tool], middleware=[ToolRetryMiddleware()])
            ```

        !!! example "Retry specific exceptions only"

            ```python
            from requests.exceptions import RequestException, Timeout

            retry = ToolRetryMiddleware(
                max_retries=4,
                retry_on=(RequestException, Timeout),
                backoff_factor=1.5,
            )
            ```

        !!! example "Custom exception filtering"

            ```python
            from requests.exceptions import HTTPError


            def should_retry(exc: Exception) -> bool:
                # Only retry on 5xx errors. The status code lives on the attached
                # response, which may be missing if the error was raised manually.
                if isinstance(exc, HTTPError) and exc.response is not None:
                    return 500 <= exc.response.status_code < 600
                return False


            retry = ToolRetryMiddleware(
                max_retries=3,
                retry_on=should_retry,
            )
            ```

        !!! example "Apply to specific tools with custom error handling"

            ```python
            def format_error(exc: Exception) -> str:
                return "Database temporarily unavailable. Please try again later."


            retry = ToolRetryMiddleware(
                max_retries=4,
                tools=["search_database"],
                on_failure=format_error,
            )
            ```

        !!! example "Apply to specific tools using `BaseTool` instances"

            ```python
            from langchain_core.tools import tool


            @tool
            def search_database(query: str) -> str:
                '''Search the database.'''
                return results


            retry = ToolRetryMiddleware(
                max_retries=4,
                tools=[search_database],  # Pass BaseTool instance
            )
            ```

        !!! example "Constant backoff (no exponential growth)"

            ```python
            retry = ToolRetryMiddleware(
                max_retries=5,
                backoff_factor=0.0,  # No exponential growth
                initial_delay=2.0,  # Always wait 2 seconds
            )
            ```

        !!! example "Raise exception on failure"

            ```python
            retry = ToolRetryMiddleware(
                max_retries=2,
                on_failure="error",  # Re-raise exception instead of returning message
            )
            ```
    """

    def __init__(
        self,
        *,
        max_retries: int = 2,
        tools: list[BaseTool | str] | None = None,
        retry_on: RetryOn = (Exception,),
        on_failure: OnFailure = "continue",
        backoff_factor: float = 2.0,
        initial_delay: float = 1.0,
        max_delay: float = 60.0,
        jitter: bool = True,
    ) -> None:
        """Initialize `ToolRetryMiddleware`.

        Args:
            max_retries: Maximum number of retry attempts after the initial call.

                Must be `>= 0`.
            tools: Optional list of tools or tool names to apply retry logic to.

                Can be a list of `BaseTool` instances or tool name strings.

                If `None`, applies to all tools.
            retry_on: Either a tuple of exception types to retry on, or a callable
                that takes an exception and returns `True` if it should be retried.

                Default is to retry on all exceptions.
            on_failure: Behavior when all retries are exhausted.

                Options:

                - `'continue'`: Return a `ToolMessage` with error details,
                    allowing the LLM to handle the failure and potentially recover.
                - `'error'`: Re-raise the exception, stopping agent execution.
                - **Custom callable:** Function that takes the exception and returns a
                    string for the `ToolMessage` content, allowing custom error
                    formatting.

                **Deprecated values** (for backwards compatibility):

                - `'return_message'`: Use `'continue'` instead.
                - `'raise'`: Use `'error'` instead.
            backoff_factor: Multiplier for exponential backoff.

                Each retry waits `initial_delay * (backoff_factor ** retry_number)`
                seconds.

                Set to `0.0` for constant delay.
            initial_delay: Initial delay in seconds before first retry.
            max_delay: Maximum delay in seconds between retries.

                Caps exponential backoff growth.
            jitter: Whether to add random jitter (`±25%`) to delay to avoid thundering herd.

        Raises:
            ValueError: If `max_retries < 0` or delays are negative.
        """
        super().__init__()

        # Validate parameters
        validate_retry_params(max_retries, initial_delay, max_delay, backoff_factor)

        # Map deprecated `on_failure` spellings onto their replacements, warning the
        # caller (stacklevel=2 points the warning at the code constructing the middleware).
        if on_failure == "raise":  # type: ignore[comparison-overlap]
            msg = (
                "on_failure='raise' is deprecated and will be removed in a future version. "
                "Use on_failure='error' instead."
            )
            warnings.warn(msg, DeprecationWarning, stacklevel=2)
            on_failure = "error"
        elif on_failure == "return_message":  # type: ignore[comparison-overlap]
            msg = (
                "on_failure='return_message' is deprecated and will be removed "
                "in a future version. Use on_failure='continue' instead."
            )
            warnings.warn(msg, DeprecationWarning, stacklevel=2)
            on_failure = "continue"

        self.max_retries = max_retries

        # Normalize the tool filter to plain names: strings are kept as-is, anything
        # else is treated as a tool object and contributes its `.name`.
        self._tool_filter: list[str] | None
        if tools is not None:
            self._tool_filter = [tool if isinstance(tool, str) else tool.name for tool in tools]
        else:
            self._tool_filter = None

        self.tools = []  # No additional tools registered by this middleware
        self.retry_on = retry_on
        self.on_failure = on_failure
        self.backoff_factor = backoff_factor
        self.initial_delay = initial_delay
        self.max_delay = max_delay
        self.jitter = jitter

    def _should_retry_tool(self, tool_name: str) -> bool:
        """Check if retry logic should apply to this tool.

        Args:
            tool_name: Name of the tool being called.

        Returns:
            `True` if retry logic should apply, `False` otherwise.
        """
        # No filter configured means every tool is eligible for retries.
        if self._tool_filter is None:
            return True
        return tool_name in self._tool_filter

    @staticmethod
    def _resolve_tool_name(request: ToolCallRequest) -> str:
        """Return the name of the tool targeted by `request`.

        Uses the tool object's name when `request.tool` is set, otherwise falls back
        to the `"name"` entry of the raw tool call.
        """
        return request.tool.name if request.tool else request.tool_call["name"]

    def _retry_delay(self, attempt: int) -> float:
        """Compute the backoff delay to wait after a failed attempt.

        Shared by the sync and async wrappers so both apply the same policy.

        Args:
            attempt: Zero-based index of the attempt that just failed.

        Returns:
            Delay in seconds, as computed by `calculate_delay` from this middleware's
            backoff settings.
        """
        return calculate_delay(
            attempt,
            backoff_factor=self.backoff_factor,
            initial_delay=self.initial_delay,
            max_delay=self.max_delay,
            jitter=self.jitter,
        )

    def _format_failure_message(self, tool_name: str, exc: Exception, attempts_made: int) -> str:
        """Format the failure message when retries are exhausted.

        Args:
            tool_name: Name of the tool that failed.
            exc: The exception that caused the failure.
            attempts_made: Number of attempts actually made.

        Returns:
            Formatted error message string.
        """
        exc_type = type(exc).__name__
        exc_msg = str(exc)
        attempt_word = "attempt" if attempts_made == 1 else "attempts"
        return (
            f"Tool '{tool_name}' failed after {attempts_made} {attempt_word} "
            f"with {exc_type}: {exc_msg}. Please try again."
        )

    def _handle_failure(
        self, tool_name: str, tool_call_id: str | None, exc: Exception, attempts_made: int
    ) -> ToolMessage:
        """Handle a terminal tool failure.

        Called when no further retry will be made: either the exception is not
        retryable or all retries are exhausted.

        Args:
            tool_name: Name of the tool that failed.
            tool_call_id: ID of the tool call (may be `None`).
            exc: The exception that caused the failure.
            attempts_made: Number of attempts actually made.

        Returns:
            `ToolMessage` with error details.

        Raises:
            Exception: If `on_failure` is `'error'`, re-raises the exception.
        """
        if self.on_failure == "error":
            raise exc

        if callable(self.on_failure):
            # Custom formatter supplied by the user.
            content = self.on_failure(exc)
        else:
            content = self._format_failure_message(tool_name, exc, attempts_made)

        return ToolMessage(
            content=content,
            tool_call_id=tool_call_id,
            name=tool_name,
            status="error",
        )

    def wrap_tool_call(
        self,
        request: ToolCallRequest,
        handler: Callable[[ToolCallRequest], ToolMessage | Command],
    ) -> ToolMessage | Command:
        """Intercept tool execution and retry on failure.

        Args:
            request: Tool call request with call dict, `BaseTool`, state, and runtime.
            handler: Callable to execute the tool (can be called multiple times).

        Returns:
            `ToolMessage` or `Command` (the final result).
        """
        tool_name = self._resolve_tool_name(request)

        # Tools outside the configured filter run exactly once, with no error handling.
        if not self._should_retry_tool(tool_name):
            return handler(request)

        tool_call_id = request.tool_call["id"]

        # Initial attempt + retries
        for attempt in range(self.max_retries + 1):
            try:
                return handler(request)
            except Exception as exc:
                attempts_made = attempt + 1  # attempt is 0-indexed

                # Stop when the exception is not retryable or the retry budget is
                # spent. The retryability check is evaluated first on every failure.
                if not should_retry_exception(exc, self.retry_on) or attempt >= self.max_retries:
                    return self._handle_failure(tool_name, tool_call_id, exc, attempts_made)

                # Back off before the next attempt; this blocks the calling thread.
                delay = self._retry_delay(attempt)
                if delay > 0:
                    time.sleep(delay)

        # Unreachable: loop always returns via handler success or _handle_failure
        msg = "Unexpected: retry loop completed without returning"
        raise RuntimeError(msg)

    async def awrap_tool_call(
        self,
        request: ToolCallRequest,
        handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]],
    ) -> ToolMessage | Command:
        """Intercept and control async tool execution with retry logic.

        Args:
            request: Tool call request with call `dict`, `BaseTool`, state, and runtime.
            handler: Async callable to execute the tool and returns `ToolMessage` or
                `Command`.

        Returns:
            `ToolMessage` or `Command` (the final result).
        """
        tool_name = self._resolve_tool_name(request)

        # Tools outside the configured filter run exactly once, with no error handling.
        if not self._should_retry_tool(tool_name):
            return await handler(request)

        tool_call_id = request.tool_call["id"]

        # Initial attempt + retries
        for attempt in range(self.max_retries + 1):
            try:
                return await handler(request)
            except Exception as exc:
                attempts_made = attempt + 1  # attempt is 0-indexed

                # Stop when the exception is not retryable or the retry budget is
                # spent. The retryability check is evaluated first on every failure.
                if not should_retry_exception(exc, self.retry_on) or attempt >= self.max_retries:
                    return self._handle_failure(tool_name, tool_call_id, exc, attempts_made)

                # Back off before the next attempt without blocking the event loop.
                delay = self._retry_delay(attempt)
                if delay > 0:
                    await asyncio.sleep(delay)

        # Unreachable: loop always returns via handler success or _handle_failure
        msg = "Unexpected: retry loop completed without returning"
        raise RuntimeError(msg)
|