"""Prompt caching module for LangSmith SDK.

This module provides thread-safe LRU caches with background refresh
for prompt caching. Includes both sync and async implementations.
"""

from __future__ import annotations

import asyncio
import json
import logging
import threading
import time
from abc import ABC
from collections import OrderedDict
from collections.abc import Awaitable
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Optional, Union

if TYPE_CHECKING:
    pass

logger = logging.getLogger("langsmith.cache")


@dataclass
class CacheEntry:
    """A single cache entry with metadata for TTL tracking."""

    value: Any  # The cached value (e.g., PromptCommit)
    created_at: float  # time.time() when entry was created/refreshed

    def is_stale(self, ttl_seconds: Optional[float]) -> bool:
        """Check if entry is past its TTL (needs refresh)."""
        if ttl_seconds is None:
            return False  # Infinite TTL, never stale
        return (time.time() - self.created_at) > ttl_seconds
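

# A minimal illustration of the staleness check above (comments only, not
# part of the module's API): an entry created two hours ago is stale under
# a one-hour TTL, while a TTL of None never expires.
#
#     entry = CacheEntry(value="v", created_at=time.time() - 7200)
#     entry.is_stale(3600.0)  # True: 7200s elapsed > 3600s TTL
#     entry.is_stale(None)    # False: infinite TTL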


@dataclass
class CacheMetrics:
    """Cache performance metrics."""

    hits: int = 0
    misses: int = 0
    refreshes: int = 0
    refresh_errors: int = 0

    @property
    def total_requests(self) -> int:
        """Total cache requests (hits + misses)."""
        return self.hits + self.misses

    @property
    def hit_rate(self) -> float:
        """Cache hit rate (0.0 to 1.0)."""
        total = self.total_requests
        return self.hits / total if total > 0 else 0.0
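

# Worked example for the metrics above (comments only): after 3 hits and
# 1 miss, total_requests == 4 and hit_rate == 3 / 4 == 0.75; with no
# traffic, hit_rate is defined as 0.0 instead of dividing by zero.
#
#     m = CacheMetrics(hits=3, misses=1)
#     m.total_requests  # 4
#     m.hit_rate        # 0.75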


class _BasePromptCache(ABC):
    """Base class for prompt caches with shared LRU logic.

    Provides thread-safe in-memory LRU cache operations.
    Subclasses implement the background refresh mechanism.
    """

    __slots__ = [
        "_cache",
        "_lock",
        "_max_size",
        "_ttl_seconds",
        "_refresh_interval",
        "_metrics",
    ]

    def __init__(
        self,
        *,
        max_size: int = 100,
        ttl_seconds: Optional[float] = 3600.0,
        refresh_interval_seconds: float = 60.0,
    ) -> None:
        """Initialize the base cache.

        Args:
            max_size: Maximum entries in cache (LRU eviction when exceeded).
            ttl_seconds: Time before entry is considered stale. Set to None for
                infinite TTL (entries never expire, no background refresh).
            refresh_interval_seconds: How often to check for stale entries.
        """
        self._cache: OrderedDict[str, CacheEntry] = OrderedDict()
        self._lock = threading.RLock()
        self._max_size = max_size
        self._ttl_seconds = ttl_seconds
        self._refresh_interval = refresh_interval_seconds
        self._metrics = CacheMetrics()

    @property
    def metrics(self) -> CacheMetrics:
        """Get cache performance metrics."""
        return self._metrics

    def reset_metrics(self) -> None:
        """Reset all metrics to zero."""
        self._metrics = CacheMetrics()

    def get(self, key: str) -> Optional[Any]:
        """Get a value from cache.

        Args:
            key: The cache key (prompt identifier like "owner/name:hash").

        Returns:
            The cached value or None if not found.
            Stale entries are still returned (background refresh handles updates).
        """
        with self._lock:
            if key not in self._cache:
                self._metrics.misses += 1
                return None

            entry = self._cache[key]

            # Move to end for LRU
            self._cache.move_to_end(key)

            self._metrics.hits += 1
            return entry.value

    def set(self, key: str, value: Any) -> None:
        """Set a value in the cache.

        Args:
            key: The cache key (prompt identifier).
            value: The value to cache.
        """
        with self._lock:
            now = time.time()
            entry = CacheEntry(value=value, created_at=now)

            # Check if we need to evict
            if key not in self._cache and len(self._cache) >= self._max_size:
                # Evict oldest (first item in OrderedDict)
                oldest_key = next(iter(self._cache))
                self._cache.pop(oldest_key)
                logger.debug(f"Evicted oldest cache entry: {oldest_key}")

            self._cache[key] = entry
            self._cache.move_to_end(key)
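
    # Sketch of the LRU interplay between get() and set() (comments only,
    # using the PromptCache subclass defined below): with max_size=2, a
    # get() refreshes a key's recency, so inserting a third key evicts the
    # least-recently-used entry rather than the oldest insertion.
    #
    #     c = PromptCache(max_size=2, ttl_seconds=None)
    #     c.set("a", 1); c.set("b", 2)
    #     c.get("a")     # "a" touched; "b" is now least recently used
    #     c.set("c", 3)  # evicts "b"
    #     c.get("b")     # None (counted as a miss)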

    def invalidate(self, key: str) -> None:
        """Remove a specific entry from cache.

        Args:
            key: The cache key to invalidate.
        """
        with self._lock:
            self._cache.pop(key, None)

    def clear(self) -> None:
        """Clear all cache entries from memory."""
        with self._lock:
            self._cache.clear()

    def _get_stale_keys(self) -> list[str]:
        """Get list of stale cache keys (thread-safe)."""
        with self._lock:
            return [
                key
                for key, entry in self._cache.items()
                if entry.is_stale(self._ttl_seconds)
            ]

    def dump(self, path: Union[str, Path]) -> None:
        """Dump cache contents to a JSON file for offline use.

        Args:
            path: Path to the output JSON file.
        """
        from langsmith import schemas as ls_schemas

        path = Path(path)
        path.parent.mkdir(parents=True, exist_ok=True)

        with self._lock:
            entries = {}
            for key, entry in self._cache.items():
                # Serialize PromptCommit using Pydantic
                if isinstance(entry.value, ls_schemas.PromptCommit):
                    # Handle both pydantic v1 and v2
                    if hasattr(entry.value, "model_dump"):
                        value_data = entry.value.model_dump(mode="json")
                    else:
                        value_data = entry.value.dict()
                else:
                    # Fallback for other types
                    value_data = entry.value

                entries[key] = value_data

        data = {"entries": entries}

        # Atomic write: write to temp file then rename
        temp_path = path.with_suffix(".tmp")
        try:
            with open(temp_path, "w") as f:
                json.dump(data, f, indent=2)
            temp_path.replace(path)
            logger.debug(f"Dumped {len(entries)} cache entries to {path}")
        except Exception:
            # Clean up temp file on failure, then re-raise with the
            # original traceback intact
            if temp_path.exists():
                temp_path.unlink()
            raise

    def load(self, path: Union[str, Path]) -> int:
        """Load cache contents from a JSON file.

        Args:
            path: Path to the JSON file to load.

        Returns:
            Number of entries loaded.

        Loaded entries get a fresh TTL starting from load time.
        If the file doesn't exist or is corrupted, returns 0.
        """
        from langsmith import schemas as ls_schemas

        path = Path(path)

        if not path.exists():
            logger.debug(f"Cache file not found: {path}")
            return 0

        try:
            with open(path) as f:
                data = json.load(f)
        except (json.JSONDecodeError, OSError) as e:
            logger.warning(f"Failed to load cache file {path}: {e}")
            return 0

        entries = data.get("entries", {})
        loaded = 0
        now = time.time()

        with self._lock:
            for key, value_data in entries.items():
                if len(self._cache) >= self._max_size:
                    logger.debug(f"Reached max cache size, stopping load at {loaded}")
                    break

                try:
                    # Deserialize PromptCommit using Pydantic (v1 and v2 compatible)
                    if hasattr(ls_schemas.PromptCommit, "model_validate"):
                        value = ls_schemas.PromptCommit.model_validate(value_data)
                    else:
                        value = ls_schemas.PromptCommit.parse_obj(value_data)

                    # Fresh TTL from load time
                    entry = CacheEntry(value=value, created_at=now)
                    self._cache[key] = entry
                    loaded += 1
                except Exception as e:
                    logger.warning(f"Failed to load cache entry {key}: {e}")
                    continue

        logger.debug(f"Loaded {loaded} cache entries from {path}")
        return loaded
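

# Offline round trip sketch (comments only; "prompts.json" is a
# hypothetical path): dump() writes {"entries": {...}} atomically via a
# .tmp rename, and load() rebuilds entries with a fresh TTL, so a warmed
# cache can survive process restarts.
#
#     cache.dump("prompts.json")
#     restored = PromptCache(ttl_seconds=None)  # offline mode, never stale
#     restored.load("prompts.json")             # -> number of entries loaded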


class PromptCache(_BasePromptCache):
    """Thread-safe LRU cache with background thread refresh.

    For use with the synchronous Client.

    Features:
    - In-memory LRU cache with configurable max size
    - Background thread for refreshing stale entries
    - Stale-while-revalidate: returns stale data while refresh happens
    - Thread-safe for concurrent access

    Example:
        >>> def fetch_prompt(key: str) -> PromptCommit:
        ...     return client._fetch_prompt_from_api(key)
        >>> cache = PromptCache(
        ...     max_size=100,
        ...     ttl_seconds=3600,
        ...     fetch_func=fetch_prompt,
        ... )
        >>> cache.set("my-prompt:latest", prompt_commit)
        >>> cached = cache.get("my-prompt:latest")
        >>> cache.shutdown()
    """

    __slots__ = ["_fetch_func", "_shutdown_event", "_refresh_thread"]

    def __init__(
        self,
        *,
        max_size: int = 100,
        ttl_seconds: Optional[float] = 3600.0,
        refresh_interval_seconds: float = 60.0,
        fetch_func: Optional[Callable[[str], Any]] = None,
    ) -> None:
        """Initialize the sync prompt cache.

        Args:
            max_size: Maximum entries in cache (LRU eviction when exceeded).
            ttl_seconds: Time before entry is considered stale. Set to None for
                infinite TTL (offline mode - entries never expire).
            refresh_interval_seconds: How often to check for stale entries.
            fetch_func: Callback to fetch fresh data for a cache key.
                If provided, starts a background thread for refresh.
        """
        super().__init__(
            max_size=max_size,
            ttl_seconds=ttl_seconds,
            refresh_interval_seconds=refresh_interval_seconds,
        )
        self._fetch_func = fetch_func
        self._shutdown_event = threading.Event()
        self._refresh_thread: Optional[threading.Thread] = None

        # Start background refresh if fetch_func provided and TTL is set
        # (no refresh needed for infinite TTL)
        if self._fetch_func is not None and self._ttl_seconds is not None:
            self._start_refresh_thread()

    def shutdown(self) -> None:
        """Stop background refresh thread.

        Should be called when the client is being cleaned up.
        """
        self._shutdown_event.set()
        if self._refresh_thread is not None:
            self._refresh_thread.join(timeout=5.0)
            self._refresh_thread = None

    def _start_refresh_thread(self) -> None:
        """Start background thread for refreshing stale entries."""
        self._refresh_thread = threading.Thread(
            target=self._refresh_loop,
            daemon=True,
            name="PromptCache-refresh",
        )
        self._refresh_thread.start()
        logger.debug("Started cache refresh thread")

    def _refresh_loop(self) -> None:
        """Background loop to refresh stale entries."""
        while not self._shutdown_event.wait(self._refresh_interval):
            try:
                self._refresh_stale_entries()
            except Exception as e:
                # Log but don't die - keep the refresh loop running
                logger.exception(f"Unexpected error in cache refresh loop: {e}")

    def _refresh_stale_entries(self) -> None:
        """Check for stale entries and refresh them."""
        if self._fetch_func is None:
            return

        stale_keys = self._get_stale_keys()

        if not stale_keys:
            return

        logger.debug(f"Refreshing {len(stale_keys)} stale cache entries")

        for key in stale_keys:
            if self._shutdown_event.is_set():
                break
            try:
                new_value = self._fetch_func(key)
                self.set(key, new_value)
                self._metrics.refreshes += 1
                logger.debug(f"Refreshed cache entry: {key}")
            except Exception as e:
                # Keep stale data on refresh failure
                self._metrics.refresh_errors += 1
                logger.warning(f"Failed to refresh cache entry {key}: {e}")
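

# Lifecycle sketch (comments only; fetch_latest is a hypothetical
# callable): the refresh loop sleeps on Event.wait(refresh_interval), so
# shutdown() wakes it promptly, while readers keep receiving the possibly
# stale cached value in the meantime (stale-while-revalidate).
#
#     cache = PromptCache(ttl_seconds=300.0, fetch_func=fetch_latest)
#     try:
#         value = cache.get("owner/name:latest")
#     finally:
#         cache.shutdown()  # stop the daemon refresh thread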


class AsyncPromptCache(_BasePromptCache):
    """Thread-safe LRU cache with asyncio task refresh.

    For use with the asynchronous AsyncClient.

    Features:
    - In-memory LRU cache with configurable max size
    - Asyncio task for refreshing stale entries
    - Stale-while-revalidate: returns stale data while refresh happens
    - Thread-safe for concurrent access

    Example:
        >>> async def fetch_prompt(key: str) -> PromptCommit:
        ...     return await client._afetch_prompt_from_api(key)
        >>> cache = AsyncPromptCache(
        ...     max_size=100,
        ...     ttl_seconds=3600,
        ...     fetch_func=fetch_prompt,
        ... )
        >>> await cache.start()
        >>> cache.set("my-prompt:latest", prompt_commit)
        >>> cached = cache.get("my-prompt:latest")
        >>> await cache.stop()
    """

    __slots__ = ["_fetch_func", "_refresh_task"]

    def __init__(
        self,
        *,
        max_size: int = 100,
        ttl_seconds: Optional[float] = 3600.0,
        refresh_interval_seconds: float = 60.0,
        fetch_func: Optional[Callable[[str], Awaitable[Any]]] = None,
    ) -> None:
        """Initialize the async prompt cache.

        Args:
            max_size: Maximum entries in cache (LRU eviction when exceeded).
            ttl_seconds: Time before entry is considered stale. Set to None for
                infinite TTL (offline mode - entries never expire).
            refresh_interval_seconds: How often to check for stale entries.
            fetch_func: Async callback to fetch fresh data for a cache key.
        """
        super().__init__(
            max_size=max_size,
            ttl_seconds=ttl_seconds,
            refresh_interval_seconds=refresh_interval_seconds,
        )
        self._fetch_func = fetch_func
        self._refresh_task: Optional[asyncio.Task[None]] = None

    async def start(self) -> None:
        """Start async background refresh loop.

        Must be called from an async context. Creates an asyncio task that
        periodically checks for stale entries and refreshes them.
        Does nothing if ttl_seconds is None (infinite TTL mode).
        """
        if self._fetch_func is None or self._ttl_seconds is None:
            return

        if self._refresh_task is not None:
            # Already running
            return

        self._refresh_task = asyncio.create_task(
            self._refresh_loop(),
            name="AsyncPromptCache-refresh",
        )
        logger.debug("Started async cache refresh task")

    async def stop(self) -> None:
        """Stop async background refresh loop.

        Cancels the refresh task and waits for it to complete.
        """
        if self._refresh_task is None:
            return

        self._refresh_task.cancel()
        try:
            await self._refresh_task
        except asyncio.CancelledError:
            pass
        self._refresh_task = None
        logger.debug("Stopped async cache refresh task")

    async def _refresh_loop(self) -> None:
        """Async background loop to refresh stale entries."""
        while True:
            try:
                await asyncio.sleep(self._refresh_interval)
                await self._refresh_stale_entries()
            except asyncio.CancelledError:
                raise
            except Exception as e:
                # Log but don't die - keep the refresh loop running
                logger.exception(f"Unexpected error in async cache refresh loop: {e}")

    async def _refresh_stale_entries(self) -> None:
        """Check for stale entries and refresh them asynchronously."""
        if self._fetch_func is None:
            return

        stale_keys = self._get_stale_keys()

        if not stale_keys:
            return

        logger.debug(f"Async refreshing {len(stale_keys)} stale cache entries")

        for key in stale_keys:
            try:
                new_value = await self._fetch_func(key)
                self.set(key, new_value)
                self._metrics.refreshes += 1
                logger.debug(f"Async refreshed cache entry: {key}")
            except Exception as e:
                # Keep stale data on refresh failure
                self._metrics.refresh_errors += 1
                logger.warning(f"Failed to async refresh cache entry {key}: {e}")
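

# Async lifecycle sketch (comments only; afetch_latest is a hypothetical
# coroutine function): start() is idempotent and a no-op when fetch_func
# or ttl_seconds is unset; stop() cancels the task and suppresses only
# the resulting CancelledError. get()/set() remain synchronous.
#
#     cache = AsyncPromptCache(ttl_seconds=300.0, fetch_func=afetch_latest)
#     await cache.start()
#     try:
#         value = cache.get("owner/name:latest")
#     finally:
#         await cache.stop()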