# Source listing metadata (from the code viewer): 761 lines, 28 KiB, Python.
"""Middleware that exposes a persistent shell tool to agents."""
|
||
|
||
from __future__ import annotations
|
||
|
||
import contextlib
|
||
import logging
|
||
import os
|
||
import queue
|
||
import signal
|
||
import subprocess
|
||
import tempfile
|
||
import threading
|
||
import time
|
||
import uuid
|
||
import weakref
|
||
from dataclasses import dataclass, field
|
||
from pathlib import Path
|
||
from typing import TYPE_CHECKING, Annotated, Any, Literal, cast
|
||
|
||
from langchain_core.messages import ToolMessage
|
||
from langchain_core.tools.base import ToolException
|
||
from langgraph.channels.untracked_value import UntrackedValue
|
||
from pydantic import BaseModel, model_validator
|
||
from pydantic.json_schema import SkipJsonSchema
|
||
from typing_extensions import NotRequired, override
|
||
|
||
from langchain.agents.middleware._execution import (
|
||
SHELL_TEMP_PREFIX,
|
||
BaseExecutionPolicy,
|
||
CodexSandboxExecutionPolicy,
|
||
DockerExecutionPolicy,
|
||
HostExecutionPolicy,
|
||
)
|
||
from langchain.agents.middleware._redaction import (
|
||
PIIDetectionError,
|
||
PIIMatch,
|
||
RedactionRule,
|
||
ResolvedRedactionRule,
|
||
)
|
||
from langchain.agents.middleware.types import AgentMiddleware, AgentState, PrivateStateAttr
|
||
from langchain.tools import ToolRuntime, tool
|
||
|
||
if TYPE_CHECKING:
|
||
from collections.abc import Mapping, Sequence
|
||
|
||
from langgraph.runtime import Runtime
|
||
|
||
|
||
# Module-level logger, named after this module.
LOGGER = logging.getLogger(__name__)

# Prefix of the sentinel line that `ShellSession.execute` asks the shell to print
# after each command; a random hex suffix is appended per command.
_DONE_MARKER_PREFIX = "__LC_SHELL_DONE__"

# Description shown to the model when the caller does not supply `tool_description`.
# One sentence per entry; the sentences are joined with single spaces.
DEFAULT_TOOL_DESCRIPTION = " ".join(
    (
        "Execute a shell command inside a persistent session.",
        (
            "Before running a command, confirm the working directory is correct "
            "(e.g., inspect with `ls` or `pwd`) and ensure any parent directories exist."
        ),
        (
            "Prefer absolute paths and quote paths containing spaces, "
            'such as `cd "/path/with spaces"`.'
        ),
        "Chain multiple commands with `&&` or `;` instead of embedding newlines.",
        "Avoid unnecessary `cd` usage unless explicitly required so the session remains stable.",
        (
            "Outputs may be truncated when they become very large, and long running "
            "commands will be terminated once their configured timeout elapses."
        ),
    )
)

# Default name under which the shell tool is registered.
SHELL_TOOL_NAME = "shell"
|
||
|
||
|
||
def _cleanup_resources(
|
||
session: ShellSession, tempdir: tempfile.TemporaryDirectory[str] | None, timeout: float
|
||
) -> None:
|
||
with contextlib.suppress(Exception):
|
||
session.stop(timeout)
|
||
if tempdir is not None:
|
||
with contextlib.suppress(Exception):
|
||
tempdir.cleanup()
|
||
|
||
|
||
@dataclass
class _SessionResources:
    """Bundle of the per-run shell resources together with their cleanup hook.

    On construction a `weakref.finalize` callback is registered that invokes
    `_cleanup_resources` with the session, the temporary directory and the
    policy's termination timeout.
    """

    # The persistent shell session.
    session: ShellSession
    # Temporary workspace directory, or `None` when there is none to remove.
    tempdir: tempfile.TemporaryDirectory[str] | None
    # Execution policy; its `termination_timeout` is handed to the cleanup hook.
    policy: BaseExecutionPolicy
    # Cleanup hook; set in `__post_init__`, excluded from `__init__` and `repr`.
    finalizer: weakref.finalize = field(init=False, repr=False)

    def __post_init__(self) -> None:
        """Register the cleanup finalizer for this resource bundle."""
        grace_period = self.policy.termination_timeout
        # Only the members are passed to the callback, never `self`, so the
        # finalizer does not hold a reference to this object.
        self.finalizer = weakref.finalize(
            self, _cleanup_resources, self.session, self.tempdir, grace_period
        )
|
||
|
||
|
||
class ShellToolState(AgentState):
    """Agent state extended with the handle to the shell session resources."""

    # Optional key holding the `_SessionResources` of the current run (or `None`).
    # NOTE(review): `UntrackedValue` / `PrivateStateAttr` presumably keep the value out
    # of checkpoints and out of the public state schema — confirm against their docs.
    shell_session_resources: NotRequired[
        Annotated[_SessionResources | None, UntrackedValue, PrivateStateAttr]
    ]
|
||
|
||
|
||
@dataclass(frozen=True)
class CommandExecutionResult:
    """Immutable record describing the outcome of one shell command."""

    # Collected command output (stdout plus `[stderr]`-prefixed stderr lines).
    output: str
    # Exit status parsed from the done marker; `None` if unavailable.
    exit_code: int | None
    # True when the command did not finish before its deadline.
    timed_out: bool
    # True when output was dropped because the line limit was exceeded.
    truncated_by_lines: bool
    # True when output was dropped because the byte limit was exceeded.
    truncated_by_bytes: bool
    # Number of output lines observed, including dropped ones.
    total_lines: int
    # Number of UTF-8 bytes observed, including dropped ones.
    total_bytes: int
|
||
|
||
|
||
class ShellSession:
|
||
"""Persistent shell session that supports sequential command execution."""
|
||
|
||
def __init__(
|
||
self,
|
||
workspace: Path,
|
||
policy: BaseExecutionPolicy,
|
||
command: tuple[str, ...],
|
||
environment: Mapping[str, str],
|
||
) -> None:
|
||
self._workspace = workspace
|
||
self._policy = policy
|
||
self._command = command
|
||
self._environment = dict(environment)
|
||
self._process: subprocess.Popen[str] | None = None
|
||
self._stdin: Any = None
|
||
self._queue: queue.Queue[tuple[str, str | None]] = queue.Queue()
|
||
self._lock = threading.Lock()
|
||
self._stdout_thread: threading.Thread | None = None
|
||
self._stderr_thread: threading.Thread | None = None
|
||
self._terminated = False
|
||
|
||
def start(self) -> None:
    """Start the shell subprocess and reader threads.

    Does nothing when a process is already running. Otherwise spawns a new
    process through the execution policy, resets the output queue and starts one
    daemon reader thread per output stream.

    Raises:
        RuntimeError: If the spawned process lacks a stdin, stdout or stderr
            pipe. The process is killed and discarded before raising so it is
            not leaked and the session is not left half-initialised.
    """
    if self._process and self._process.poll() is None:
        return

    process = self._policy.spawn(
        workspace=self._workspace,
        env=self._environment,
        command=self._command,
    )
    self._process = process

    if process.stdin is None or process.stdout is None or process.stderr is None:
        # Without all three pipes the session is unusable: tear the child down
        # rather than leaving an orphan that `execute()` would treat as alive.
        try:
            self._kill_process()
            with contextlib.suppress(Exception):
                process.wait(timeout=self._policy.termination_timeout)
        finally:
            self._process = None
        msg = "Failed to initialize shell session pipes."
        raise RuntimeError(msg)

    self._stdin = process.stdin
    self._terminated = False
    # Fresh queue so output from a previous process is not attributed to this one.
    self._queue = queue.Queue()

    self._stdout_thread = threading.Thread(
        target=self._enqueue_stream,
        args=(process.stdout, "stdout"),
        daemon=True,
    )
    self._stderr_thread = threading.Thread(
        target=self._enqueue_stream,
        args=(process.stderr, "stderr"),
        daemon=True,
    )
    self._stdout_thread.start()
    self._stderr_thread.start()
|
||
|
||
def restart(self) -> None:
    """Stop the current shell process, then start a new one.

    The policy's `termination_timeout` is used as the stop timeout.
    """
    grace_period = self._policy.termination_timeout
    self.stop(grace_period)
    self.start()
|
||
|
||
def stop(self, timeout: float) -> None:
    """Stop the shell subprocess.

    Asks the shell to exit by writing `exit` to its stdin, waits up to `timeout`
    seconds, and force-kills the process if it is still running. After a forced
    kill the process is waited on once more so that it is reaped instead of
    lingering as a zombie.

    Args:
        timeout: Seconds to wait for a graceful exit, and again for the reap
            after a forced kill.
    """
    process = self._process
    if not process:
        return

    if process.poll() is None and not self._terminated:
        try:
            self._stdin.write("exit\n")
            self._stdin.flush()
        except (BrokenPipeError, OSError):
            LOGGER.debug(
                "Failed to write exit command; terminating shell session.",
                exc_info=True,
            )

    try:
        try:
            still_running = process.wait(timeout=timeout) is None
        except subprocess.TimeoutExpired:
            still_running = True
        if still_running:
            # `_kill_process` reads `self._process`, so call it before the reset below.
            self._kill_process()
            # Collect the exit status of the killed child; bounded so a process
            # that cannot be killed does not block shutdown indefinitely.
            with contextlib.suppress(subprocess.TimeoutExpired):
                process.wait(timeout=timeout)
    finally:
        self._terminated = True
        with contextlib.suppress(Exception):
            self._stdin.close()
        self._process = None
|
||
|
||
def execute(self, command: str, *, timeout: float) -> CommandExecutionResult:
    """Execute a command in the persistent shell.

    The command is written to the shell's stdin, followed by a `printf` that
    emits a unique done marker together with the command's exit status. Output
    is then collected until that marker is seen or the timeout elapses.

    The session lock is held for the whole write-and-collect cycle. Releasing it
    after the write would let a concurrent call drain the queue and consume this
    command's output and marker, turning a successful command into a timeout.
    The deadline is computed after the lock is acquired so that waiting for
    another command does not eat into this command's time budget.

    Args:
        command: Shell command text; a trailing newline is added if missing.
        timeout: Maximum number of seconds to wait for the command to finish.

    Returns:
        The structured result produced by `_collect_output`.

    Raises:
        RuntimeError: If the shell process is not running.
    """
    with self._lock:
        if not self._process or self._process.poll() is not None:
            msg = "Shell session is not running."
            raise RuntimeError(msg)

        marker = f"{_DONE_MARKER_PREFIX}{uuid.uuid4().hex}"
        payload = command if command.endswith("\n") else f"{command}\n"
        deadline = time.monotonic() + timeout

        # Discard leftovers from earlier commands before issuing the new one.
        self._drain_queue()
        self._stdin.write(payload)
        self._stdin.write(f"printf '{marker} %s\\n' $?\n")
        self._stdin.flush()

        return self._collect_output(marker, deadline, timeout)
|
||
|
||
def _collect_output(
    self,
    marker: str,
    deadline: float,
    timeout: float,
) -> CommandExecutionResult:
    """Gather queued output until the done marker appears or the deadline passes.

    The marker is located anywhere in a stdout line, not only at its start. When
    a command's output does not end with a newline (for example `printf foo`),
    the marker is printed on the same line as the last output; requiring the
    line to start with the marker would miss it and time the command out. Text
    preceding the marker on that line is kept as regular output.

    Lines beyond the policy's line or byte limits are counted but not stored.
    On timeout the shell session is restarted and an empty output is returned.

    Args:
        marker: Unique done marker written after the command.
        deadline: `time.monotonic()` value after which collection stops.
        timeout: Original timeout in seconds; used only in the log message.

    Returns:
        The command result, including truncation flags and observed totals.
    """
    collected: list[str] = []
    total_lines = 0
    total_bytes = 0
    truncated_by_lines = False
    truncated_by_bytes = False
    exit_code: int | None = None
    timed_out = False

    while True:
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            timed_out = True
            break
        try:
            source, data = self._queue.get(timeout=remaining)
        except queue.Empty:
            timed_out = True
            break

        if data is None:
            # End-of-stream sentinel from a reader thread; nothing to record.
            continue

        marker_seen = source == "stdout" and marker in data
        if marker_seen:
            # Split "<trailing output><marker> <status>\n" into its parts.
            data, _, status = data.partition(marker)
            exit_code = self._safe_int(status.strip())

        if data:
            total_lines += 1
            total_bytes += len(data.encode("utf-8", "replace"))

            if total_lines > self._policy.max_output_lines:
                truncated_by_lines = True
            elif (
                self._policy.max_output_bytes is not None
                and total_bytes > self._policy.max_output_bytes
            ):
                truncated_by_bytes = True
            elif source == "stderr":
                stripped = data.rstrip("\n")
                collected.append(f"[stderr] {stripped}")
                if data.endswith("\n"):
                    collected.append("\n")
            else:
                collected.append(data)

        if marker_seen:
            # The stderr reader thread runs independently, so stderr output may
            # still be in flight when the stdout marker arrives.
            self._drain_remaining_stderr(collected, deadline)
            break

    if timed_out:
        LOGGER.warning(
            "Command timed out after %.2f seconds; restarting shell session.",
            timeout,
        )
        self.restart()
        return CommandExecutionResult(
            output="",
            exit_code=None,
            timed_out=True,
            truncated_by_lines=truncated_by_lines,
            truncated_by_bytes=truncated_by_bytes,
            total_lines=total_lines,
            total_bytes=total_bytes,
        )

    output = "".join(collected)
    return CommandExecutionResult(
        output=output,
        exit_code=exit_code,
        timed_out=False,
        truncated_by_lines=truncated_by_lines,
        truncated_by_bytes=truncated_by_bytes,
        total_lines=total_lines,
        total_bytes=total_bytes,
    )
|
||
|
||
def _kill_process(self) -> None:
|
||
if not self._process:
|
||
return
|
||
|
||
if hasattr(os, "killpg"):
|
||
with contextlib.suppress(ProcessLookupError):
|
||
os.killpg(os.getpgid(self._process.pid), signal.SIGKILL)
|
||
else: # pragma: no cover
|
||
with contextlib.suppress(ProcessLookupError):
|
||
self._process.kill()
|
||
|
||
def _enqueue_stream(self, stream: Any, label: str) -> None:
|
||
for line in iter(stream.readline, ""):
|
||
self._queue.put((label, line))
|
||
self._queue.put((label, None))
|
||
|
||
def _drain_queue(self) -> None:
|
||
while True:
|
||
try:
|
||
self._queue.get_nowait()
|
||
except queue.Empty:
|
||
break
|
||
|
||
def _drain_remaining_stderr(
|
||
self, collected: list[str], deadline: float, drain_timeout: float = 0.05
|
||
) -> None:
|
||
"""Drain any stderr output that arrived concurrently with the done marker.
|
||
|
||
The stdout and stderr reader threads run independently. When a command writes to
|
||
stderr just before exiting, the stderr output may still be in transit when the
|
||
done marker arrives on stdout. This method briefly polls the queue to capture
|
||
such output.
|
||
|
||
Args:
|
||
collected: The list to append collected stderr lines to.
|
||
deadline: The original command deadline (used as an upper bound).
|
||
drain_timeout: Maximum time to wait for additional stderr output.
|
||
"""
|
||
drain_deadline = min(time.monotonic() + drain_timeout, deadline)
|
||
while True:
|
||
remaining = drain_deadline - time.monotonic()
|
||
if remaining <= 0:
|
||
break
|
||
try:
|
||
source, data = self._queue.get(timeout=remaining)
|
||
except queue.Empty:
|
||
break
|
||
if data is None or source != "stderr":
|
||
continue
|
||
stripped = data.rstrip("\n")
|
||
collected.append(f"[stderr] {stripped}")
|
||
if data.endswith("\n"):
|
||
collected.append("\n")
|
||
|
||
@staticmethod
|
||
def _safe_int(value: str) -> int | None:
|
||
with contextlib.suppress(ValueError):
|
||
return int(value)
|
||
return None
|
||
|
||
|
||
class _ShellToolInput(BaseModel):
    """Input schema for the persistent shell tool."""

    command: str | None = None
    """The shell command to execute."""

    restart: bool | None = None
    """Whether to restart the shell session."""

    runtime: Annotated[Any, SkipJsonSchema()] = None
    """The runtime for the shell tool.

    Included as a workaround at the moment bc args_schema doesn't work with
    injected ToolRuntime.
    """

    @model_validator(mode="after")
    def validate_payload(self) -> _ShellToolInput:
        """Require exactly one of `command` or a truthy `restart`.

        Raises:
            ValueError: If neither is given, or if both are given.
        """
        has_command = self.command is not None
        wants_restart = bool(self.restart)

        if not has_command and not wants_restart:
            msg = "Shell tool requires either 'command' or 'restart'."
            raise ValueError(msg)
        if has_command and wants_restart:
            msg = "Specify only one of 'command' or 'restart'."
            raise ValueError(msg)
        return self
|
||
|
||
|
||
class ShellToolMiddleware(AgentMiddleware[ShellToolState, Any]):
|
||
"""Middleware that registers a persistent shell tool for agents.
|
||
|
||
The middleware exposes a single long-lived shell session. Use the execution policy
|
||
to match your deployment's security posture:
|
||
|
||
* `HostExecutionPolicy` – full host access; best for trusted environments where the
|
||
agent already runs inside a container or VM that provides isolation.
|
||
* `CodexSandboxExecutionPolicy` – reuses the Codex CLI sandbox for additional
|
||
syscall/filesystem restrictions when the CLI is available.
|
||
* `DockerExecutionPolicy` – launches a separate Docker container for each agent run,
|
||
providing harder isolation, optional read-only root filesystems, and user
|
||
remapping.
|
||
|
||
When no policy is provided the middleware defaults to `HostExecutionPolicy`.
|
||
"""
|
||
|
||
state_schema = ShellToolState
|
||
|
||
def __init__(
    self,
    workspace_root: str | Path | None = None,
    *,
    startup_commands: tuple[str, ...] | list[str] | str | None = None,
    shutdown_commands: tuple[str, ...] | list[str] | str | None = None,
    execution_policy: BaseExecutionPolicy | None = None,
    redaction_rules: tuple[RedactionRule, ...] | list[RedactionRule] | None = None,
    tool_description: str | None = None,
    tool_name: str = SHELL_TOOL_NAME,
    shell_command: Sequence[str] | str | None = None,
    env: Mapping[str, Any] | None = None,
) -> None:
    """Configure the middleware and register its shell tool.

    Args:
        workspace_root: Base directory for the shell session.

            When falsy, a temporary directory is created together with the
            session resources and removed when they are cleaned up.
        startup_commands: Commands run in order right after the session starts.
            A single string is treated as one command.
        shutdown_commands: Commands run before the session is shut down.
            A single string is treated as one command.
        execution_policy: Policy supplying timeouts, output limits and the
            process spawner.

            A `HostExecutionPolicy` is created when omitted.
        redaction_rules: Rules applied to command output before it is returned
            to the model.
        tool_description: Replacement for the default tool description.
        tool_name: Name under which the tool is registered.

            Defaults to `"shell"`.
        shell_command: Shell executable (string) or argument sequence used to
            launch the session.

            Defaults to `/bin/bash`.
        env: Environment variables for the shell session; values are converted
            to strings.

    Raises:
        ValueError: If `shell_command` is an empty sequence.
        TypeError: If `env` contains a key that is not a string.
    """
    super().__init__()

    self._workspace_root = Path(workspace_root) if workspace_root else None
    self._tool_name = tool_name
    self._shell_command = self._normalize_shell_command(shell_command)
    self._environment = self._normalize_env(env)
    self._execution_policy = (
        HostExecutionPolicy() if execution_policy is None else execution_policy
    )
    self._redaction_rules: tuple[ResolvedRedactionRule, ...] = tuple(
        rule.resolve() for rule in (redaction_rules or ())
    )
    self._startup_commands = self._normalize_commands(startup_commands)
    self._shutdown_commands = self._normalize_commands(shutdown_commands)

    # The tool runs the command itself, so no call interception is required.
    description = tool_description or DEFAULT_TOOL_DESCRIPTION

    @tool(self._tool_name, args_schema=_ShellToolInput, description=description)
    def shell_tool(
        *,
        runtime: ToolRuntime[None, ShellToolState],
        command: str | None = None,
        restart: bool = False,
    ) -> ToolMessage | str:
        # Reuse the session stored in the agent state, creating it on first use.
        session_resources = self._get_or_create_resources(runtime.state)
        request = {"command": command, "restart": restart}
        return self._run_shell_tool(
            session_resources,
            request,
            tool_call_id=runtime.tool_call_id,
        )

    self._shell_tool = shell_tool
    self.tools = [self._shell_tool]
|
||
|
||
@staticmethod
|
||
def _normalize_commands(
|
||
commands: tuple[str, ...] | list[str] | str | None,
|
||
) -> tuple[str, ...]:
|
||
if commands is None:
|
||
return ()
|
||
if isinstance(commands, str):
|
||
return (commands,)
|
||
return tuple(commands)
|
||
|
||
@staticmethod
|
||
def _normalize_shell_command(
|
||
shell_command: Sequence[str] | str | None,
|
||
) -> tuple[str, ...]:
|
||
if shell_command is None:
|
||
return ("/bin/bash",)
|
||
normalized = (shell_command,) if isinstance(shell_command, str) else tuple(shell_command)
|
||
if not normalized:
|
||
msg = "Shell command must contain at least one argument."
|
||
raise ValueError(msg)
|
||
return normalized
|
||
|
||
@staticmethod
|
||
def _normalize_env(env: Mapping[str, Any] | None) -> dict[str, str] | None:
|
||
if env is None:
|
||
return None
|
||
normalized: dict[str, str] = {}
|
||
for key, value in env.items():
|
||
if not isinstance(key, str):
|
||
msg = "Environment variable names must be strings."
|
||
raise TypeError(msg)
|
||
normalized[key] = str(value)
|
||
return normalized
|
||
|
||
@override
def before_agent(self, state: ShellToolState, runtime: Runtime) -> dict[str, Any] | None:
    """Ensure session resources exist and publish them as a state update.

    Existing resources in `state` are reused; otherwise a new session is started
    (which also runs the startup commands).
    """
    return {"shell_session_resources": self._get_or_create_resources(state)}
|
||
|
||
async def abefore_agent(self, state: ShellToolState, runtime: Runtime) -> dict[str, Any] | None:
    """Async variant of `before_agent`.

    Starting the session spawns a process and may run startup commands, all of
    which block. The synchronous implementation is therefore executed in a worker
    thread so the event loop stays responsive.
    """
    import asyncio  # local import: asyncio is not imported at module level

    return await asyncio.to_thread(self.before_agent, state, runtime)
|
||
|
||
@override
def after_agent(self, state: ShellToolState, runtime: Runtime) -> None:
    """Run the shutdown commands and release the session resources.

    Does nothing when the state holds no `_SessionResources`. The finalizer is
    invoked even if running the shutdown commands raises.
    """
    resources = state.get("shell_session_resources")
    if isinstance(resources, _SessionResources):
        try:
            self._run_shutdown_commands(resources.session)
        finally:
            resources.finalizer()
|
||
|
||
async def aafter_agent(self, state: ShellToolState, runtime: Runtime) -> None:
    """Async variant of `after_agent`.

    Shutdown commands and stopping the shell process block while waiting on the
    subprocess. The synchronous implementation is therefore executed in a worker
    thread so the event loop stays responsive.
    """
    import asyncio  # local import: asyncio is not imported at module level

    return await asyncio.to_thread(self.after_agent, state, runtime)
|
||
|
||
def _get_or_create_resources(self, state: ShellToolState) -> _SessionResources:
    """Return the session resources stored in `state`, creating them if absent.

    Reusing resources already present in the state (for example after an
    interrupt) keeps the same shell session alive across resumptions. Newly
    created resources are written back into `state`.

    Args:
        state: Agent state that may already hold shell session resources.

    Returns:
        The existing or newly created session resources.
    """
    resources = state.get("shell_session_resources")
    if not isinstance(resources, _SessionResources):
        resources = self._create_resources()
        # The cast lets the typed state be mutated like a plain dict.
        cast("dict[str, Any]", state)["shell_session_resources"] = resources
    return resources
|
||
|
||
def _create_resources(self) -> _SessionResources:
    """Create the workspace, start a shell session and run the startup commands.

    Without a configured workspace root a temporary directory is created;
    otherwise the configured directory is created if it does not exist.

    If starting the session or a startup command fails, the session and the
    temporary directory are torn down through `_cleanup_resources`. That helper
    suppresses errors raised by the individual teardown steps, so a failing
    `stop()` can neither skip removal of the temporary directory nor replace the
    original error, which is re-raised.

    Returns:
        The resources wrapping the running session.
    """
    tempdir: tempfile.TemporaryDirectory[str] | None = None
    if self._workspace_root is None:
        tempdir = tempfile.TemporaryDirectory(prefix=SHELL_TEMP_PREFIX)
        workspace_path = Path(tempdir.name)
    else:
        workspace_path = self._workspace_root
        workspace_path.mkdir(parents=True, exist_ok=True)

    policy = self._execution_policy
    session = ShellSession(
        workspace_path,
        policy,
        self._shell_command,
        self._environment or {},
    )
    try:
        session.start()
        LOGGER.info("Started shell session in %s", workspace_path)
        self._run_startup_commands(session)
    except BaseException:
        # BaseException so cleanup also happens on interrupts; always re-raised.
        LOGGER.exception("Starting shell session failed; cleaning up resources.")
        _cleanup_resources(session, tempdir, policy.termination_timeout)
        raise

    return _SessionResources(session=session, tempdir=tempdir, policy=policy)
|
||
|
||
def _run_startup_commands(self, session: ShellSession) -> None:
|
||
if not self._startup_commands:
|
||
return
|
||
for command in self._startup_commands:
|
||
result = session.execute(command, timeout=self._execution_policy.startup_timeout)
|
||
if result.timed_out or (result.exit_code not in (0, None)):
|
||
msg = f"Startup command '{command}' failed with exit code {result.exit_code}"
|
||
raise RuntimeError(msg)
|
||
|
||
def _run_shutdown_commands(self, session: ShellSession) -> None:
|
||
if not self._shutdown_commands:
|
||
return
|
||
for command in self._shutdown_commands:
|
||
try:
|
||
result = session.execute(command, timeout=self._execution_policy.command_timeout)
|
||
if result.timed_out:
|
||
LOGGER.warning("Shutdown command '%s' timed out.", command)
|
||
elif result.exit_code not in (0, None):
|
||
LOGGER.warning(
|
||
"Shutdown command '%s' exited with %s.", command, result.exit_code
|
||
)
|
||
except (RuntimeError, ToolException, OSError) as exc:
|
||
LOGGER.warning(
|
||
"Failed to run shutdown command '%s': %s", command, exc, exc_info=True
|
||
)
|
||
|
||
def _apply_redactions(self, content: str) -> tuple[str, dict[str, list[PIIMatch]]]:
|
||
"""Apply configured redaction rules to command output."""
|
||
matches_by_type: dict[str, list[PIIMatch]] = {}
|
||
updated = content
|
||
for rule in self._redaction_rules:
|
||
updated, matches = rule.apply(updated)
|
||
if matches:
|
||
matches_by_type.setdefault(rule.pii_type, []).extend(matches)
|
||
return updated, matches_by_type
|
||
|
||
def _run_shell_tool(
    self,
    resources: _SessionResources,
    payload: dict[str, Any],
    *,
    tool_call_id: str | None,
) -> Any:
    """Handle one shell tool invocation: restart the session or run a command.

    Args:
        resources: Session resources whose session is used.
        payload: Tool arguments; reads the `"restart"` and `"command"` keys.
        tool_call_id: Id of the originating tool call, or `None`.

    Returns:
        Whatever `_format_tool_message` produces: the restart confirmation, a
        timeout error, a redaction-block error, or the (possibly truncated and
        redacted) command output with an artifact describing the execution.

    Raises:
        ToolException: If a requested restart fails, or if no non-empty command
            string is supplied when no restart is requested.
    """
    session = resources.session

    if payload.get("restart"):
        LOGGER.info("Restarting shell session on request.")
        try:
            session.restart()
            self._run_startup_commands(session)
        except Exception as err:
            # Only ordinary errors become a tool error. KeyboardInterrupt,
            # SystemExit and task cancellation derive from BaseException and
            # must propagate instead of being reported as a failed restart.
            LOGGER.exception("Restarting shell session failed; session remains unavailable.")
            msg = "Failed to restart shell session."
            raise ToolException(msg) from err
        message = "Shell session restarted."
        return self._format_tool_message(message, tool_call_id, status="success")

    command = payload.get("command")
    if not command or not isinstance(command, str):
        msg = "Shell tool expects a 'command' string when restart is not requested."
        raise ToolException(msg)

    LOGGER.info("Executing shell command: %s", command)
    result = session.execute(command, timeout=self._execution_policy.command_timeout)

    if result.timed_out:
        timeout_seconds = self._execution_policy.command_timeout
        message = f"Error: Command timed out after {timeout_seconds:.1f} seconds."
        return self._format_tool_message(
            message,
            tool_call_id,
            status="error",
            artifact={
                "timed_out": True,
                "exit_code": None,
            },
        )

    try:
        sanitized_output, matches = self._apply_redactions(result.output)
    except PIIDetectionError as error:
        # A redaction rule refused the output; report the type, not the content.
        LOGGER.warning("Blocking command output due to detected %s.", error.pii_type)
        message = f"Output blocked: detected {error.pii_type}."
        return self._format_tool_message(
            message,
            tool_call_id,
            status="error",
            artifact={
                "timed_out": False,
                "exit_code": result.exit_code,
                "matches": {error.pii_type: error.matches},
            },
        )

    sanitized_output = sanitized_output or "<no output>"
    if result.truncated_by_lines:
        sanitized_output = (
            f"{sanitized_output.rstrip()}\n\n"
            f"... Output truncated at {self._execution_policy.max_output_lines} lines "
            f"(observed {result.total_lines})."
        )
    if result.truncated_by_bytes and self._execution_policy.max_output_bytes is not None:
        sanitized_output = (
            f"{sanitized_output.rstrip()}\n\n"
            f"... Output truncated at {self._execution_policy.max_output_bytes} bytes "
            f"(observed {result.total_bytes})."
        )

    # An exit code of `None` (status unavailable) is not treated as a failure.
    if result.exit_code not in (0, None):
        sanitized_output = f"{sanitized_output.rstrip()}\n\nExit code: {result.exit_code}"
        final_status: Literal["success", "error"] = "error"
    else:
        final_status = "success"

    artifact = {
        "timed_out": False,
        "exit_code": result.exit_code,
        "truncated_by_lines": result.truncated_by_lines,
        "truncated_by_bytes": result.truncated_by_bytes,
        "total_lines": result.total_lines,
        "total_bytes": result.total_bytes,
        "redaction_matches": matches,
    }

    return self._format_tool_message(
        sanitized_output,
        tool_call_id,
        status=final_status,
        artifact=artifact,
    )
|
||
|
||
def _format_tool_message(
|
||
self,
|
||
content: str,
|
||
tool_call_id: str | None,
|
||
*,
|
||
status: Literal["success", "error"],
|
||
artifact: dict[str, Any] | None = None,
|
||
) -> ToolMessage | str:
|
||
artifact = artifact or {}
|
||
if tool_call_id is None:
|
||
return content
|
||
return ToolMessage(
|
||
content=content,
|
||
tool_call_id=tool_call_id,
|
||
name=self._tool_name,
|
||
status=status,
|
||
artifact=artifact,
|
||
)
|
||
|
||
|
||
# Public names of this module: the middleware plus the execution policies and
# redaction rule it re-exports.
__all__ = [
    "CodexSandboxExecutionPolicy",
    "DockerExecutionPolicy",
    "HostExecutionPolicy",
    "RedactionRule",
    "ShellToolMiddleware",
]
|