"""Context editing middleware. Mirrors Anthropic's context editing capabilities by clearing older tool results once the conversation grows beyond a configurable token threshold. The implementation is intentionally model-agnostic so it can be used with any LangChain chat model. """ from __future__ import annotations from collections.abc import Awaitable, Callable, Iterable, Sequence from copy import deepcopy from dataclasses import dataclass from typing import Literal from langchain_core.messages import ( AIMessage, AnyMessage, BaseMessage, ToolMessage, ) from langchain_core.messages.utils import count_tokens_approximately from typing_extensions import Protocol from langchain.agents.middleware.types import ( AgentMiddleware, ModelCallResult, ModelRequest, ModelResponse, ) DEFAULT_TOOL_PLACEHOLDER = "[cleared]" TokenCounter = Callable[ [Sequence[BaseMessage]], int, ] class ContextEdit(Protocol): """Protocol describing a context editing strategy.""" def apply( self, messages: list[AnyMessage], *, count_tokens: TokenCounter, ) -> None: """Apply an edit to the message list in place.""" ... 
@dataclass(slots=True)
class ClearToolUsesEdit(ContextEdit):
    """Configuration for clearing tool outputs when token limits are exceeded."""

    trigger: int = 100_000
    """Token count that triggers the edit."""

    clear_at_least: int = 0
    """Minimum number of tokens to reclaim when the edit runs."""

    keep: int = 3
    """Number of most recent tool results that must be preserved."""

    clear_tool_inputs: bool = False
    """Whether to clear the originating tool call parameters on the AI message."""

    exclude_tools: Sequence[str] = ()
    """List of tool names to exclude from clearing."""

    placeholder: str = DEFAULT_TOOL_PLACEHOLDER
    """Placeholder text inserted for cleared tool outputs."""

    def apply(
        self,
        messages: list[AnyMessage],
        *,
        count_tokens: TokenCounter,
    ) -> None:
        """Apply the clear-tool-uses strategy.

        Args:
            messages: Conversation history, edited in place.
            count_tokens: Callable used to measure the token footprint of the
                current message list.
        """
        tokens = count_tokens(messages)
        # Nothing to do until the conversation exceeds the trigger threshold.
        if tokens <= self.trigger:
            return

        # Tool results in conversation order (oldest first); the newest
        # ``keep`` results are protected from clearing.
        candidates = [
            (idx, msg) for idx, msg in enumerate(messages) if isinstance(msg, ToolMessage)
        ]
        if self.keep >= len(candidates):
            candidates = []
        elif self.keep:
            candidates = candidates[: -self.keep]

        excluded_tools = set(self.exclude_tools)

        for idx, tool_message in candidates:
            # Skip results that a previous run already cleared.
            if tool_message.response_metadata.get("context_editing", {}).get("cleared"):
                continue

            # Locate the nearest preceding AI message, which should carry the
            # tool call that produced this result.
            ai_message = next(
                (m for m in reversed(messages[:idx]) if isinstance(m, AIMessage)), None
            )
            if ai_message is None:
                continue

            tool_call = next(
                (
                    call
                    for call in ai_message.tool_calls
                    if call.get("id") == tool_message.tool_call_id
                ),
                None,
            )
            if tool_call is None:
                continue

            if (tool_message.name or tool_call["name"]) in excluded_tools:
                continue

            # Replace the tool output with the placeholder and record the edit
            # in response_metadata so subsequent runs skip this message.
            messages[idx] = tool_message.model_copy(
                update={
                    "artifact": None,
                    "content": self.placeholder,
                    "response_metadata": {
                        **tool_message.response_metadata,
                        "context_editing": {
                            "cleared": True,
                            "strategy": "clear_tool_uses",
                        },
                    },
                }
            )

            if self.clear_tool_inputs:
                # Also blank out the originating tool call's arguments.
                messages[messages.index(ai_message)] = self._build_cleared_tool_input_message(
                    ai_message,
                    tool_message.tool_call_id,
                )

            # With a reclaim target, stop as soon as enough tokens were freed;
            # without one (clear_at_least == 0), clear every candidate.
            if self.clear_at_least > 0:
                new_token_count = count_tokens(messages)
                cleared_tokens = max(0, tokens - new_token_count)
                if cleared_tokens >= self.clear_at_least:
                    break

    def _build_cleared_tool_input_message(
        self,
        message: AIMessage,
        tool_call_id: str,
    ) -> AIMessage:
        """Return a copy of ``message`` with the given tool call's args emptied.

        The cleared tool call id is recorded under
        ``response_metadata["context_editing"]["cleared_tool_inputs"]``.
        """
        updated_tool_calls = []
        cleared_any = False

        for tool_call in message.tool_calls:
            updated_call = dict(tool_call)
            if updated_call.get("id") == tool_call_id:
                updated_call["args"] = {}
                cleared_any = True
            updated_tool_calls.append(updated_call)

        metadata = dict(getattr(message, "response_metadata", {}))
        context_entry = dict(metadata.get("context_editing", {}))

        if cleared_any:
            # Track cleared ids as a sorted list for deterministic metadata.
            cleared_ids = set(context_entry.get("cleared_tool_inputs", []))
            cleared_ids.add(tool_call_id)
            context_entry["cleared_tool_inputs"] = sorted(cleared_ids)
            metadata["context_editing"] = context_entry

        return message.model_copy(
            update={
                "tool_calls": updated_tool_calls,
                "response_metadata": metadata,
            }
        )


class ContextEditingMiddleware(AgentMiddleware):
    """Automatically prune tool results to manage context size.

    The middleware applies a sequence of edits when the total input token
    count exceeds configured thresholds. Currently the `ClearToolUsesEdit`
    strategy is supported, aligning with Anthropic's
    `clear_tool_uses_20250919` behavior
    [(read more)](https://platform.claude.com/docs/en/agents-and-tools/tool-use/memory-tool).
    """

    edits: list[ContextEdit]
    token_count_method: Literal["approximate", "model"]

    def __init__(
        self,
        *,
        edits: Iterable[ContextEdit] | None = None,
        token_count_method: Literal["approximate", "model"] = "approximate",  # noqa: S107
    ) -> None:
        """Initialize an instance of context editing middleware.

        Args:
            edits: Sequence of edit strategies to apply. Defaults to a single
                `ClearToolUsesEdit` mirroring Anthropic defaults.
            token_count_method: Whether to use approximate token counting
                (faster, less accurate) or exact counting implemented by the
                chat model (potentially slower, more accurate).
        """
        super().__init__()
        self.edits = list(edits or (ClearToolUsesEdit(),))
        self.token_count_method = token_count_method

    def _apply_edits(self, request: ModelRequest) -> list[AnyMessage]:
        """Return a deep copy of the request messages with all edits applied.

        Shared by the sync and async wrappers; token counting itself is
        synchronous in both paths, so one helper serves both.
        """
        if self.token_count_method == "approximate":  # noqa: S105

            def count_tokens(messages: Sequence[BaseMessage]) -> int:
                return count_tokens_approximately(messages)

        else:
            # Exact counting delegates to the chat model; the system message
            # (when present) is included so counts match what the model sees.
            system_msg = [request.system_message] if request.system_message else []

            def count_tokens(messages: Sequence[BaseMessage]) -> int:
                return request.model.get_num_tokens_from_messages(
                    system_msg + list(messages), request.tools
                )

        # Deep-copy so edit strategies never mutate the caller's messages.
        edited_messages = deepcopy(list(request.messages))
        for edit in self.edits:
            edit.apply(edited_messages, count_tokens=count_tokens)
        return edited_messages

    def wrap_model_call(
        self,
        request: ModelRequest,
        handler: Callable[[ModelRequest], ModelResponse],
    ) -> ModelCallResult:
        """Apply context edits before invoking the model via handler."""
        if not request.messages:
            return handler(request)
        return handler(request.override(messages=self._apply_edits(request)))

    async def awrap_model_call(
        self,
        request: ModelRequest,
        handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
    ) -> ModelCallResult:
        """Apply context edits before invoking the model via handler (async version)."""
        if not request.messages:
            return await handler(request)
        return await handler(request.override(messages=self._apply_edits(request)))


__all__ = [
    "ClearToolUsesEdit",
    "ContextEditingMiddleware",
]