group-wbl/.venv/lib/python3.13/site-packages/langchain/agents/middleware/_retry.py
2026-01-09 09:48:03 +08:00

124 lines
3.7 KiB
Python

"""Shared retry utilities for agent middleware.
This module contains common constants, utilities, and logic used by both
model and tool retry middleware implementations.
"""
from __future__ import annotations
import random
from collections.abc import Callable
from typing import Literal
# Type aliases
#
# NOTE: these are ordinary runtime assignments, not annotations, so the
# `X | Y` union expressions below are evaluated at import time;
# `from __future__ import annotations` does not defer them.
#
# `RetryOn` is consumed by `should_retry_exception` in this module. `OnFailure`
# is not referenced by the code visible here; presumably it is imported by the
# model/tool retry middleware -- verify against callers.
RetryOn = tuple[type[Exception], ...] | Callable[[Exception], bool]
"""Type for specifying which exceptions to retry on.
Can be either:
- A tuple of exception types to retry on (based on `isinstance` checks)
- A callable that takes an exception and returns `True` if it should be retried
"""
OnFailure = Literal["error", "continue"] | Callable[[Exception], str]
"""Type for specifying failure handling behavior.
Can be either:
- A literal action string (`'error'` or `'continue'`)
- `'error'`: Re-raise the exception, stopping agent execution.
- `'continue'`: Inject a message with the error details, allowing the agent to continue.
For tool retries, a `ToolMessage` with the error details will be injected.
For model retries, an `AIMessage` with the error details will be returned.
- A callable that takes an exception and returns a string for error message content
"""
def validate_retry_params(
    max_retries: int,
    initial_delay: float,
    max_delay: float,
    backoff_factor: float,
) -> None:
    """Validate retry parameters.

    Each value must be a non-negative number. NaN is rejected as well, since
    it is neither negative nor a usable delay/multiplier.

    Args:
        max_retries: Maximum number of retry attempts.
        initial_delay: Initial delay in seconds before first retry.
        max_delay: Maximum delay in seconds between retries.
        backoff_factor: Multiplier for exponential backoff.

    Raises:
        ValueError: If any parameter is invalid (negative or NaN). Parameters
            are checked in signature order and the first failure is reported.
    """
    # The checks are written as `not (x >= 0)` rather than `x < 0` on purpose:
    # every ordered comparison against NaN is False, so `nan < 0` would let a
    # NaN slip through, whereas `not (nan >= 0)` correctly rejects it.
    if not (max_retries >= 0):
        msg = "max_retries must be >= 0"
        raise ValueError(msg)
    if not (initial_delay >= 0):
        msg = "initial_delay must be >= 0"
        raise ValueError(msg)
    if not (max_delay >= 0):
        msg = "max_delay must be >= 0"
        raise ValueError(msg)
    if not (backoff_factor >= 0):
        msg = "backoff_factor must be >= 0"
        raise ValueError(msg)
def should_retry_exception(
    exc: Exception,
    retry_on: RetryOn,
) -> bool:
    """Check if an exception should trigger a retry.

    Args:
        exc: The exception that occurred.
        retry_on: Either a tuple of exception types to retry on, or a callable
            that takes an exception and returns `True` if it should be retried.

            A single exception class (not wrapped in a tuple) is also accepted
            and is matched with `isinstance`, like a one-element tuple.

    Returns:
        `True` if the exception should be retried, `False` otherwise.
    """
    # Exception classes are themselves callable. Without this guard, passing a
    # bare class such as `ValueError` would be treated as a predicate, and
    # `ValueError(exc)` returns a (truthy) exception instance -- i.e. every
    # exception would be retried. Treat such classes as `isinstance` targets.
    is_exception_class = isinstance(retry_on, type) and issubclass(
        retry_on, BaseException
    )
    if callable(retry_on) and not is_exception_class:
        return retry_on(exc)
    return isinstance(exc, retry_on)
def calculate_delay(
    retry_number: int,
    *,
    backoff_factor: float,
    initial_delay: float,
    max_delay: float,
    jitter: bool,
) -> float:
    """Calculate delay for a retry attempt with exponential backoff and optional jitter.

    The base delay is `initial_delay * backoff_factor**retry_number`, capped at
    `max_delay`. If the exponential term is too large to represent, the delay
    is capped instead of raising `OverflowError`.

    Args:
        retry_number: The retry attempt number (0-indexed).
        backoff_factor: Multiplier for exponential backoff.

            Set to `0.0` for constant delay.
        initial_delay: Initial delay in seconds before first retry.
        max_delay: Maximum delay in seconds between retries.

            Caps exponential backoff growth.
        jitter: Whether to add random jitter to delay to avoid thundering herd.

    Returns:
        Delay in seconds before next retry.
    """
    if backoff_factor == 0.0:
        delay = initial_delay
    else:
        try:
            delay = initial_delay * (backoff_factor**retry_number)
        except OverflowError:
            # `float ** int` (and `float * huge_int`) raise instead of
            # saturating to infinity. The un-capped value would exceed any
            # finite cap, so use the cap directly -- unless there is no delay
            # to scale in the first place.
            delay = max_delay if initial_delay > 0 else initial_delay

    # Cap at max_delay
    delay = min(delay, max_delay)

    if jitter and delay > 0:
        jitter_amount = delay * 0.25  # ±25% jitter
        delay = delay + random.uniform(-jitter_amount, jitter_amount)  # noqa: S311
        # Ensure delay is not negative after jitter
        delay = max(0, delay)

    return delay