903 lines
28 KiB
Python
903 lines
28 KiB
Python
|
|
"""Generic utility functions."""
|
||
|
|
|
||
|
|
from __future__ import annotations
|
||
|
|
|
||
|
|
import contextlib
|
||
|
|
import contextvars
|
||
|
|
import copy
|
||
|
|
import enum
|
||
|
|
import functools
|
||
|
|
import logging
|
||
|
|
import os
|
||
|
|
import pathlib
|
||
|
|
import socket
|
||
|
|
import subprocess
|
||
|
|
import sys
|
||
|
|
import threading
|
||
|
|
import traceback
|
||
|
|
from collections.abc import Generator, Iterable, Iterator, Mapping, Sequence
|
||
|
|
from concurrent.futures import Future, ThreadPoolExecutor
|
||
|
|
from typing import (
|
||
|
|
Any,
|
||
|
|
Callable,
|
||
|
|
Literal,
|
||
|
|
Optional,
|
||
|
|
TypeVar,
|
||
|
|
Union,
|
||
|
|
cast,
|
||
|
|
overload,
|
||
|
|
)
|
||
|
|
from urllib import parse as urllib_parse
|
||
|
|
|
||
|
|
import httpx
|
||
|
|
import requests
|
||
|
|
from typing_extensions import ParamSpec
|
||
|
|
from urllib3.util import Retry # type: ignore[import-untyped]
|
||
|
|
|
||
|
|
from langsmith import schemas as ls_schemas
|
||
|
|
|
||
|
|
# Module-level logger for this utilities module, named after the module path.
_LOGGER = logging.getLogger(__name__)
|
||
|
|
|
||
|
|
|
||
|
|
class LangSmithError(Exception):
    """Base error raised when communicating with the LangSmith API fails."""
|
||
|
|
|
||
|
|
|
||
|
|
class LangSmithAPIError(LangSmithError):
    """Raised when LangSmith reports an internal server error."""
|
||
|
|
|
||
|
|
|
||
|
|
class LangSmithRequestTimeout(LangSmithError):
    """Raised when the client takes too long to send the request body."""
|
||
|
|
|
||
|
|
|
||
|
|
class LangSmithUserError(LangSmithError):
    """Raised when a user-side mistake makes a LangSmith request fail."""
|
||
|
|
|
||
|
|
|
||
|
|
class LangSmithRateLimitError(LangSmithError):
    """Raised when the LangSmith API rate limit has been exceeded."""
|
||
|
|
|
||
|
|
|
||
|
|
class LangSmithAuthError(LangSmithError):
    """Raised when authentication with the LangSmith API fails."""
|
||
|
|
|
||
|
|
|
||
|
|
class LangSmithNotFoundError(LangSmithError):
    """Raised when the requested LangSmith resource cannot be found."""
|
||
|
|
|
||
|
|
|
||
|
|
class LangSmithConflictError(LangSmithError):
    """Raised when the resource being created already exists."""
|
||
|
|
|
||
|
|
|
||
|
|
class LangSmithConnectionError(LangSmithError):
    """Raised when no connection to the LangSmith API can be established."""
|
||
|
|
|
||
|
|
|
||
|
|
class LangSmithExceptionGroup(LangSmithError):
    """Collect several exceptions into a single LangSmith error.

    A minimal port of the built-in `ExceptionGroup` for Python < 3.11.
    """

    def __init__(
        self, *args: Any, exceptions: Sequence[Exception], **kwargs: Any
    ) -> None:
        """Initialize the group.

        Args:
            *args: Positional arguments forwarded to the base exception.
            exceptions: The grouped exceptions, exposed as `self.exceptions`.
            **kwargs: Keyword arguments forwarded to the base exception.
        """
        self.exceptions = exceptions
        super().__init__(*args, **kwargs)
|
||
|
|
|
||
|
|
|
||
|
|
## Warning classes
|
||
|
|
|
||
|
|
|
||
|
|
class LangSmithWarning(UserWarning):
    """Common base class for all LangSmith warnings."""
|
||
|
|
|
||
|
|
|
||
|
|
class LangSmithMissingAPIKeyWarning(LangSmithWarning):
    """Warning category for a missing LangSmith API key."""
|
||
|
|
|
||
|
|
|
||
|
|
def tracing_is_enabled(ctx: Optional[dict] = None) -> Union[bool, Literal["local"]]:
    """Report whether tracing should be active for the current call.

    Resolution order (first match wins):

    1. An explicit `enabled` value in the tracing context (`ctx` when it is
       truthy, otherwise the ambient context from `get_tracing_context`).
    2. An active run tree, i.e. we are already inside a trace.
    3. The process-wide fallback `_context._GLOBAL_TRACING_ENABLED`.
    4. The `TRACING_V2` / `TRACING` environment variables (exactly "true").

    Args:
        ctx: Optional tracing-context mapping; it must contain an `enabled` key.

    Returns:
        The resolved setting. Context and global override values are returned
        as-is, so they may be `"local"` rather than a bool.
    """
    # Import lazily and read the fallback through the module object so later
    # updates to the global value are observed (avoids stale references).
    import langsmith._internal._context as _context
    from langsmith.run_helpers import get_current_run_tree, get_tracing_context

    tracing_ctx = ctx if ctx else get_tracing_context()
    # A context-level override beats everything; checking it before the run
    # tree lets a branch inside an ongoing trace be switched off.
    override = tracing_ctx["enabled"]
    if override is not None:
        return override
    # Already mid-trace.
    if get_current_run_tree():
        return True
    global_fallback = _context._GLOBAL_TRACING_ENABLED
    if global_fallback is not None:
        return global_fallback
    # Lastly consult the environment.
    env_value = get_env_var("TRACING_V2", default=get_env_var("TRACING", default=""))
    return env_value == "true"
|
||
|
|
|
||
|
|
|
||
|
|
def test_tracking_is_disabled() -> bool:
    """Return True if test tracking has been explicitly disabled.

    Tracking counts as disabled only when the `TEST_TRACKING` environment
    variable (resolved through `get_env_var`, i.e. `LANGSMITH_TEST_TRACKING`
    or `LANGCHAIN_TEST_TRACKING` by default) is exactly `"false"`. Any other
    value, or no value, leaves tracking enabled.
    """
    return get_env_var("TEST_TRACKING", default="") == "false"
|
||
|
|
|
||
|
|
|
||
|
|
def xor_args(*arg_groups: tuple[str, ...]) -> Callable:
    """Build a decorator requiring exactly one keyword argument per group.

    Each tuple in `arg_groups` lists mutually exclusive keyword-argument
    names. On every call of the decorated function, exactly one name in each
    group must be passed as a keyword with a non-`None` value. Positional
    arguments are not inspected.

    Args:
        *arg_groups: Tuples of keyword-argument names.

    Returns:
        A decorator that validates the groups before delegating to the
        wrapped function.

    Raises:
        ValueError: At call time, naming every group whose count of non-`None`
            keyword arguments is not exactly one.
    """

    def decorator(func: Callable) -> Callable:
        @functools.wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            # Collect the groups that do not have exactly one non-None value.
            bad_groups = []
            for group in arg_groups:
                provided = [name for name in group if kwargs.get(name) is not None]
                if len(provided) != 1:
                    bad_groups.append(", ".join(group))
            if bad_groups:
                raise ValueError(
                    "Exactly one argument in each of the following"
                    " groups must be defined:"
                    f" {', '.join(bad_groups)}"
                )
            return func(*args, **kwargs)

        return wrapper

    return decorator
|
||
|
|
|
||
|
|
|
||
|
|
def raise_for_status_with_text(
    response: Union[requests.Response, httpx.Response],
) -> None:
    """Raise the client library's HTTP error, enriched with the response body.

    Works with both `requests` and `httpx` responses. For a successful status
    this is a no-op.

    Args:
        response: The HTTP response to check.

    Raises:
        requests.HTTPError: For a failing `requests` response. `args` is
            `(original message, response text)`, and `.response` still refers
            to the failing response.
        httpx.HTTPStatusError: For a failing `httpx` response, with the
            response text appended to the message.
    """
    try:
        response.raise_for_status()
    except requests.HTTPError as e:
        # Pass `response=` so handlers can still inspect `e.response`; without
        # it the re-raised error would carry `response=None`.
        raise requests.HTTPError(  # type: ignore[call-arg]
            str(e),
            response.text,
            response=response,  # type: ignore[arg-type]
        ) from e
    except httpx.HTTPStatusError as e:
        raise httpx.HTTPStatusError(
            f"{str(e)}: {response.text}",
            request=response.request,  # type: ignore[arg-type]
            response=response,  # type: ignore[arg-type]
        ) from e
|
||
|
|
|
||
|
|
|
||
|
|
def get_enum_value(enu: Union[enum.Enum, str]) -> str:
    """Return the underlying string of `enu`.

    Enum members are unwrapped to their `.value`; any other input (expected
    to be a plain string) is returned unchanged.
    """
    return enu.value if isinstance(enu, enum.Enum) else enu
|
||
|
|
|
||
|
|
|
||
|
|
@functools.lru_cache(maxsize=128)
def log_once(level: int, message: str) -> None:
    """Log `message` at `level` the first time this pair is seen.

    De-duplication is done by memoizing on `(level, message)`. Up to 128
    distinct pairs are remembered, so alternating messages are each emitted
    only once. Beyond that, the least-recently-used pair is evicted and may be
    logged again, which keeps memory bounded if messages contain dynamic
    content.

    Args:
        level: A `logging` level such as `logging.WARNING`.
        message: The message to emit.
    """
    _LOGGER.log(level, message)
|
||
|
|
|
||
|
|
|
||
|
|
def _get_message_type(message: Mapping[str, Any]) -> str:
|
||
|
|
if not message:
|
||
|
|
raise ValueError("Message is empty.")
|
||
|
|
if "lc" in message:
|
||
|
|
if "id" not in message:
|
||
|
|
raise ValueError(
|
||
|
|
f"Unexpected format for serialized message: {message}"
|
||
|
|
" Message does not have an id."
|
||
|
|
)
|
||
|
|
return message["id"][-1].replace("Message", "").lower()
|
||
|
|
else:
|
||
|
|
if "type" not in message:
|
||
|
|
raise ValueError(
|
||
|
|
f"Unexpected format for stored message: {message}"
|
||
|
|
" Message does not have a type."
|
||
|
|
)
|
||
|
|
return message["type"]
|
||
|
|
|
||
|
|
|
||
|
|
def _get_message_fields(message: Mapping[str, Any]) -> Mapping[str, Any]:
|
||
|
|
if not message:
|
||
|
|
raise ValueError("Message is empty.")
|
||
|
|
if "lc" in message:
|
||
|
|
if "kwargs" not in message:
|
||
|
|
raise ValueError(
|
||
|
|
f"Unexpected format for serialized message: {message}"
|
||
|
|
" Message does not have kwargs."
|
||
|
|
)
|
||
|
|
return message["kwargs"]
|
||
|
|
else:
|
||
|
|
if "data" not in message:
|
||
|
|
raise ValueError(
|
||
|
|
f"Unexpected format for stored message: {message}"
|
||
|
|
" Message does not have data."
|
||
|
|
)
|
||
|
|
return message["data"]
|
||
|
|
|
||
|
|
|
||
|
|
def _convert_message(message: Mapping[str, Any]) -> dict[str, Any]:
    """Normalize a serialized or stored message into `{"type", "data"}` form.

    Raises:
        ValueError: Propagated from `_get_message_type` /
            `_get_message_fields` when the message is empty or malformed.
    """
    return {
        "type": _get_message_type(message),
        "data": _get_message_fields(message),
    }
|
||
|
|
|
||
|
|
|
||
|
|
def get_messages_from_inputs(inputs: Mapping[str, Any]) -> list[dict[str, Any]]:
    """Extract and normalize the chat messages found in run inputs.

    A `"messages"` sequence takes precedence; otherwise a single `"message"`
    is wrapped in a one-element list.

    Args:
        inputs: The run's inputs mapping.

    Returns:
        The messages, each normalized to `{"type": ..., "data": ...}`.

    Raises:
        ValueError: If neither `"messages"` nor `"message"` is present.
    """
    if "messages" in inputs:
        raw_messages = inputs["messages"]
    elif "message" in inputs:
        raw_messages = [inputs["message"]]
    else:
        raise ValueError(f"Could not find message(s) in run with inputs {inputs}.")
    return [_convert_message(raw) for raw in raw_messages]
|
||
|
|
|
||
|
|
|
||
|
|
def get_message_generation_from_outputs(outputs: Mapping[str, Any]) -> dict[str, Any]:
    """Retrieve the single chat-message generation from run outputs.

    Args:
        outputs: The run's outputs; expected to hold a `"generations"`
            sequence with exactly one entry that itself has a `"message"`.

    Returns:
        The generation's message normalized to `{"type": ..., "data": ...}`.

    Raises:
        ValueError: If there is no `"generations"` key, the number of
            generations is not exactly one, or the generation has no message.
    """
    if "generations" not in outputs:
        raise ValueError(f"No generations found in run with output: {outputs}.")
    generations = outputs["generations"]
    if len(generations) != 1:
        raise ValueError(
            "Chat examples expect exactly one generation."
            f" Found {len(generations)} generations: {generations}."
        )
    first_generation = generations[0]
    if "message" not in first_generation:
        raise ValueError(
            f"Unexpected format for generation: {first_generation}."
            " Generation does not have a message."
        )
    return _convert_message(first_generation["message"])
|
||
|
|
|
||
|
|
|
||
|
|
def get_prompt_from_inputs(inputs: Mapping[str, Any]) -> str:
    """Retrieve the prompt from the given inputs.

    A `"prompt"` key wins; otherwise a `"prompts"` sequence must contain
    exactly one entry.

    Args:
        inputs: The inputs dictionary.

    Returns:
        str: The prompt.

    Raises:
        ValueError: If no prompt is found (including an empty `"prompts"`
            list) or if multiple prompts are present.
    """
    if "prompt" in inputs:
        return inputs["prompt"]
    if "prompts" in inputs:
        prompts = inputs["prompts"]
        if len(prompts) == 1:
            return prompts[0]
        if len(prompts) > 1:
            raise ValueError(
                f"Multiple prompts in run with inputs {inputs}."
                " Please create example manually."
            )
        # An empty list means there is no prompt, not "multiple" prompts.
    raise ValueError(f"Could not find prompt in run with inputs {inputs}.")
|
||
|
|
|
||
|
|
|
||
|
|
def get_llm_generation_from_outputs(outputs: Mapping[str, Any]) -> str:
    """Get the text of the single LLM generation in run outputs.

    Args:
        outputs: The run's outputs; expected to hold a `"generations"`
            sequence with exactly one entry that has a `"text"` field.

    Returns:
        The generation's text.

    Raises:
        ValueError: If there is no `"generations"` key, there are zero or
            several generations, or the generation has no text.
    """
    if "generations" not in outputs:
        raise ValueError(f"No generations found in run with output: {outputs}.")
    generations = outputs["generations"]
    if len(generations) != 1:
        # Report an empty list accurately instead of calling it "Multiple".
        qualifier = "No" if len(generations) == 0 else "Multiple"
        raise ValueError(f"{qualifier} generations in run: {generations}")
    first_generation = generations[0]
    if "text" not in first_generation:
        raise ValueError(f"No text in generation: {first_generation}")
    return first_generation["text"]
|
||
|
|
|
||
|
|
|
||
|
|
@functools.lru_cache(maxsize=1)
def get_docker_compose_command() -> list[str]:
    """Detect which docker compose CLI is available on this system.

    Probes `docker compose --version` first, then the standalone
    `docker-compose --version`, discarding their output. The result is
    memoized, so the probe runs at most once per process.

    Returns:
        `["docker", "compose"]` or `["docker-compose"]`.

    Raises:
        ValueError: If neither command runs successfully.
    """

    def _probe(command: list[str]) -> list[str]:
        # Raises CalledProcessError / FileNotFoundError if unusable.
        subprocess.check_call(
            [*command, "--version"],
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
        )
        return command

    probe_errors = (subprocess.CalledProcessError, FileNotFoundError)
    try:
        return _probe(["docker", "compose"])
    except probe_errors:
        try:
            return _probe(["docker-compose"])
        except probe_errors:
            raise ValueError(
                "Neither 'docker compose' nor 'docker-compose'"
                " commands are available. Please install the Docker"
                " server following the instructions for your operating"
                " system at https://docs.docker.com/engine/install/"
            )
|
||
|
|
|
||
|
|
|
||
|
|
def convert_langchain_message(message: ls_schemas.BaseMessageLike) -> dict:
    """Convert a LangChain-style message into an example message dict.

    The result is `{"type": ..., "data": {"content": ...}}`. A shallow copy of
    `additional_kwargs` is added under `data` only when it is non-empty.
    """
    message_type = message.type
    payload: dict[str, Any] = {"content": message.content}
    extra = message.additional_kwargs
    if extra:
        payload["additional_kwargs"] = {**extra}
    return {"type": message_type, "data": payload}
|
||
|
|
|
||
|
|
|
||
|
|
def is_base_message_like(obj: object) -> bool:
    """Duck-type check for objects shaped like `BaseMessage`.

    Args:
        obj: The object to check.

    Returns:
        bool: True when `obj` has a string `content`, a dict
        `additional_kwargs`, and a string `type`; `False` otherwise.
    """
    content = getattr(obj, "content", None)
    extra = getattr(obj, "additional_kwargs", None)
    kind = getattr(obj, "type", None)
    return (
        isinstance(content, str)
        and isinstance(extra, dict)
        and isinstance(kind, str)
    )
|
||
|
|
|
||
|
|
|
||
|
|
def is_env_var_truish(value: Optional[str]) -> bool:
    """Tell whether the environment variable named `value` is truish.

    `value` is the variable name without its namespace prefix. It is resolved
    through `get_env_var` and interpreted by `is_truish`; `None` yields
    `False`.
    """
    return value is not None and is_truish(get_env_var(value))
|
||
|
|
|
||
|
|
|
||
|
|
# Typing-only overload: with a `str` default the result is always a `str`.
@overload
def get_env_var(
    name: str,
    default: str,
    *,
    namespaces: tuple = ("LANGSMITH", "LANGCHAIN"),
) -> str: ...
|
||
|
|
|
||
|
|
|
||
|
|
# Typing-only overload: without a default (or with `None`) the result may be None.
@overload
def get_env_var(
    name: str,
    default: None = None,
    *,
    namespaces: tuple = ("LANGSMITH", "LANGCHAIN"),
) -> Optional[str]: ...
|
||
|
|
|
||
|
|
|
||
|
|
@functools.lru_cache(maxsize=100)
def get_env_var(
    name: str,
    default: Optional[str] = None,
    *,
    namespaces: tuple = ("LANGSMITH", "LANGCHAIN"),
) -> Optional[str]:
    """Retrieve an environment variable from a list of namespaces.

    The namespaces are tried in order, and the first `{namespace}_{name}`
    variable with a non-blank value wins. Empty or whitespace-only values are
    treated as unset.

    Note:
        Results are memoized with `functools.lru_cache`, so changes to
        `os.environ` made after the first lookup of a given argument
        combination are not seen until `get_env_var.cache_clear()` is called.

    Args:
        name: The name of the environment variable, without namespace prefix.
        default: The value to return if no namespaced variable is set.
        namespaces: A tuple of namespaces to search, in priority order.
            Defaults to `('LANGSMITH', 'LANGCHAIN')`.

    Returns:
        The value of the first matching environment variable, otherwise the
        default value.
    """
    # Use a distinct loop variable so the `name` parameter is not shadowed.
    for namespace in namespaces:
        value = os.environ.get(f"{namespace}_{name}")
        if value is not None and value.strip() != "":
            return value
    return default
|
||
|
|
|
||
|
|
|
||
|
|
@functools.lru_cache(maxsize=1)
def get_tracer_project(return_default_value=True) -> Optional[str]:
    """Resolve the project name a LangSmith tracer should log to.

    Precedence, highest first:

    1. `HOSTED_LANGSERVE_PROJECT_NAME`: hosted LangServe deployments always
       use their associated project, even if the customer configured
       another project name in the environment.
    2. The namespaced `PROJECT` variable (e.g. `LANGSMITH_PROJECT`).
    3. The namespaced `SESSION` variable, the legacy name for the project,
       so it ranks below `PROJECT`.
    4. `"default"` when `return_default_value` is true, otherwise `None`.

    The result is memoized (cache size 1).
    """
    fallback = "default" if return_default_value else None
    legacy_session = get_env_var("SESSION", default=fallback)
    configured = get_env_var("PROJECT", default=legacy_session)
    return os.environ.get("HOSTED_LANGSERVE_PROJECT_NAME", configured)
|
||
|
|
|
||
|
|
|
||
|
|
class FilterPoolFullWarning(logging.Filter):
    """Filter `urllib3` warnings logged when the connection pool isn't reused."""

    def __init__(self, name: str = "", host: str = "") -> None:
        """Initialize the `FilterPoolFullWarning` filter.

        Args:
            name: The name of the filter. Defaults to `""`.
            host: The host whose pool-full warnings are dropped. Defaults to
                `""`, which matches every pool-full warning.
        """
        super().__init__(name)
        self._host = host

    def filter(self, record) -> bool:
        """urllib3.connectionpool:Connection pool is full, discarding connection: ..."""
        # Keep everything except pool-full warnings that mention our host.
        text = record.getMessage()
        is_pool_full = "Connection pool is full, discarding connection" in text
        return not (is_pool_full and self._host in text)
|
||
|
|
|
||
|
|
|
||
|
|
class FilterLangSmithRetry(logging.Filter):
    """Drop log records about retries performed by this library."""

    def filter(self, record) -> bool:
        """Return False for any record whose message mentions `LangSmithRetry`."""
        # Retries are re-raised or logged manually elsewhere, so these records
        # would be redundant.
        return record.getMessage().find("LangSmithRetry") == -1
|
||
|
|
|
||
|
|
|
||
|
|
class LangSmithRetry(Retry):
    """`urllib3` `Retry` subclass whose distinct name lets its logs be filtered.

    It adds no behavior of its own; log records mentioning `LangSmithRetry`
    can be suppressed by name (see `FilterLangSmithRetry`).
    """
|
||
|
|
|
||
|
|
|
||
|
|
# Re-entrant lock serializing logger filter add/remove in `filter_logs`.
_FILTER_LOCK = threading.RLock()
|
||
|
|
|
||
|
|
|
||
|
|
@contextlib.contextmanager
def filter_logs(
    logger: logging.Logger, filters: Sequence[logging.Filter]
) -> Generator[None, None, None]:
    """Temporarily add `filters` to `logger` for the duration of the block.

    Args:
        logger: The logger to which the filters will be added.
        filters: A sequence of `logging.Filter` objects added on entry and
            removed on exit.

    Yields:
        None.
    """
    try:
        # Adding happens inside `try` so filters already attached are still
        # removed if a later `addFilter` call raises.
        with _FILTER_LOCK:
            for log_filter in filters:
                logger.addFilter(log_filter)
        # Not actually perfectly thread-safe, but it's only log filters.
        yield
    finally:
        with _FILTER_LOCK:
            for log_filter in filters:
                try:
                    # Removing a filter that was never added is a no-op.
                    logger.removeFilter(log_filter)
                except Exception:
                    # Best effort: cleanup must not mask the block's outcome.
                    _LOGGER.warning("Failed to remove filter", exc_info=True)
|
||
|
|
|
||
|
|
|
||
|
|
def get_cache_dir(cache: Optional[str]) -> Optional[str]:
    """Get the testing cache directory.

    Args:
        cache: An explicit cache path; used whenever it is not `None`.

    Returns:
        `cache` if given, otherwise the `TEST_CACHE` environment variable
        (e.g. `LANGSMITH_TEST_CACHE`) or `None` when that is unset.
    """
    return cache if cache is not None else get_env_var("TEST_CACHE", default=None)
|
||
|
|
|
||
|
|
|
||
|
|
def filter_request_headers(
    request: Any,
    *,
    ignore_hosts: Optional[Sequence[str]] = None,
    allow_hosts: Optional[Sequence[str]] = None,
) -> Any:
    """Decide whether a request may be cached and strip its headers.

    Args:
        request: The request; must expose `url` and an assignable `headers`.
        ignore_hosts: Legacy deny-list of URL prefixes.
        allow_hosts: Allow-list. Entries starting with `http://`/`https://`
            are matched as URL prefixes; bare hostnames match the request
            host exactly or as a parent domain (`openai.com` matches
            `api.openai.com`).

    Returns:
        The same request with `headers` replaced by an empty dict, or `None`
        when the request must not be cached.
    """
    # Legacy behavior: simple URL-prefix deny-list.
    if ignore_hosts and any(request.url.startswith(prefix) for prefix in ignore_hosts):
        return None

    if allow_hosts:
        try:
            parsed = urllib_parse.urlparse(request.url)
        except Exception:
            # Unparseable URL: do not cache, to be safe.
            return None
        hostname = parsed.hostname or ""

        def _is_allowed(entry: str) -> bool:
            if entry.startswith(("http://", "https://")):
                return request.url.startswith(entry)
            return hostname == entry or hostname.endswith(f".{entry}")

        if not any(_is_allowed(entry) for entry in allow_hosts):
            return None

    request.headers = {}
    return request
|
||
|
|
|
||
|
|
|
||
|
|
@contextlib.contextmanager
def with_cache(
    path: Union[str, pathlib.Path],
    ignore_hosts: Optional[Sequence[str]] = None,
    allow_hosts: Optional[Sequence[str]] = None,
) -> Generator[None, None, None]:
    """Record and replay HTTP requests made inside the block using vcrpy.

    Requests already in the cassette at `path` are replayed; new ones are
    recorded (`record_mode="new_episodes"`). Requests match on URI, method,
    path and body. The `authorization` and `Set-Cookie` headers are
    filtered, and `filter_request_headers` decides which requests get
    recorded.

    Args:
        path: Cassette file path. A `.yaml`/`.yml` suffix selects the YAML
            serializer; anything else uses JSON.
        ignore_hosts: Legacy URL-prefix deny-list for `filter_request_headers`.
        allow_hosts: Host allow-list for `filter_request_headers`.

    Raises:
        ImportError: If `vcrpy` is not installed.
    """
    try:
        import vcr  # type: ignore[import-untyped]
    except ImportError as e:
        # Keep the trailing space so the message reads "Install with: pip ...".
        raise ImportError(
            "vcrpy is required to use caching. Install with: "
            'pip install -U "langsmith[vcr]"'
        ) from e
    # Fix concurrency issue in vcrpy's patching
    from langsmith._internal import _patch as patch_urllib3

    patch_urllib3.patch_urllib3()

    cache_dir, cache_file = os.path.split(path)
    serializer = "yaml" if cache_file.endswith((".yaml", ".yml")) else "json"

    ls_vcr = vcr.VCR(
        serializer=serializer,
        cassette_library_dir=cache_dir,
        # Replay previous requests, record new ones
        # TODO: Support other modes
        record_mode="new_episodes",
        match_on=["uri", "method", "path", "body"],
        filter_headers=["authorization", "Set-Cookie"],
        before_record_request=lambda request: filter_request_headers(
            request, ignore_hosts=ignore_hosts, allow_hosts=allow_hosts
        ),
    )
    with ls_vcr.use_cassette(cache_file):
        yield
|
||
|
|
|
||
|
|
|
||
|
|
@contextlib.contextmanager
def with_optional_cache(
    path: Optional[Union[str, pathlib.Path]],
    ignore_hosts: Optional[Sequence[str]] = None,
    allow_hosts: Optional[Sequence[str]] = None,
) -> Generator[None, None, None]:
    """Use a request cache at `path`, or do nothing when `path` is `None`.

    Args:
        path: Cassette path forwarded to `with_cache`, or `None` to disable.
        ignore_hosts: Forwarded to `with_cache`.
        allow_hosts: Forwarded to `with_cache`.
    """
    if path is None:
        yield
        return
    with with_cache(path, ignore_hosts, allow_hosts):
        yield
|
||
|
|
|
||
|
|
|
||
|
|
def _format_exc() -> str:
|
||
|
|
# Used internally to format exceptions without cluttering the traceback
|
||
|
|
tb_lines = traceback.format_exception(*sys.exc_info())
|
||
|
|
filtered_lines = [line for line in tb_lines if "langsmith/" not in line]
|
||
|
|
return "".join(filtered_lines)
|
||
|
|
|
||
|
|
|
||
|
|
# Generic type variable used by the copy helpers and the executor signatures.
T = TypeVar("T")
|
||
|
|
|
||
|
|
|
||
|
|
def _middle_copy(
|
||
|
|
val: T, memo: dict[int, Any], max_depth: int = 4, _depth: int = 0
|
||
|
|
) -> T:
|
||
|
|
cls = type(val)
|
||
|
|
|
||
|
|
copier = getattr(cls, "__deepcopy__", None)
|
||
|
|
if copier is not None:
|
||
|
|
try:
|
||
|
|
return copier(memo)
|
||
|
|
except BaseException:
|
||
|
|
pass
|
||
|
|
if _depth >= max_depth:
|
||
|
|
return val
|
||
|
|
if isinstance(val, dict):
|
||
|
|
return { # type: ignore[return-value]
|
||
|
|
_middle_copy(k, memo, max_depth, _depth + 1): _middle_copy(
|
||
|
|
v, memo, max_depth, _depth + 1
|
||
|
|
)
|
||
|
|
for k, v in val.items()
|
||
|
|
}
|
||
|
|
if isinstance(val, list):
|
||
|
|
return [_middle_copy(item, memo, max_depth, _depth + 1) for item in val] # type: ignore[return-value]
|
||
|
|
if isinstance(val, tuple):
|
||
|
|
return tuple(_middle_copy(item, memo, max_depth, _depth + 1) for item in val) # type: ignore[return-value]
|
||
|
|
if isinstance(val, set):
|
||
|
|
return {_middle_copy(item, memo, max_depth, _depth + 1) for item in val} # type: ignore[return-value]
|
||
|
|
|
||
|
|
return val
|
||
|
|
|
||
|
|
|
||
|
|
def deepish_copy(val: T) -> T:
    """Deep copy `val`, degrading gracefully for uncopyable members.

    `copy.deepcopy` is tried first. Generators, locks and similar objects
    cannot be copied; they raise (typically a `TypeError` mentioning pickling,
    since the same dunder methods are reused for copying). In that case the
    failure is logged at debug level and a best-effort copy from
    `_middle_copy` is returned, reusing the same memo.

    Args:
        val: The value to be deep copied.

    Returns:
        The deep copied value (possibly only partially copied).
    """
    shared_memo: dict[int, Any] = {}
    try:
        copied = copy.deepcopy(val, shared_memo)
    except BaseException as exc:
        _LOGGER.debug("Failed to deepcopy input: %s", repr(exc))
        return _middle_copy(val, shared_memo)
    return copied
|
||
|
|
|
||
|
|
|
||
|
|
def is_version_greater_or_equal(current_version: str, target_version: str) -> bool:
    """Return True if `current_version` is at least `target_version`.

    Both strings are parsed and compared with `packaging.version`.
    """
    from packaging.version import parse as parse_version

    return parse_version(current_version) >= parse_version(target_version)
|
||
|
|
|
||
|
|
|
||
|
|
def parse_prompt_identifier(identifier: str) -> tuple[str, str, str]:
    """Parse a string in the format of `owner/name:hash`, `name:hash`, `owner/name`, or `name`.

    Args:
        identifier: The prompt identifier to parse.

    Returns:
        A tuple containing `(owner, name, hash)`. The owner defaults to `"-"`
        and the hash to `"latest"` when they are omitted.

    Raises:
        ValueError: If the identifier doesn't match the expected formats,
            including when the owner, the name, or the hash after `:` is empty.
    """  # noqa: E501
    if (
        not identifier
        or identifier.count("/") > 1
        or identifier.startswith("/")
        or identifier.endswith("/")
    ):
        raise ValueError(f"Invalid identifier format: {identifier}")

    owner_name, sep, commit = identifier.partition(":")
    if not sep:
        commit = "latest"
    elif not commit:
        # A trailing ":" would otherwise produce an empty commit hash.
        raise ValueError(f"Invalid identifier format: {identifier}")

    if "/" in owner_name:
        owner, name = owner_name.split("/", 1)
        if not owner or not name:
            raise ValueError(f"Invalid identifier format: {identifier}")
        return owner, name, commit
    if not owner_name:
        raise ValueError(f"Invalid identifier format: {identifier}")
    return "-", owner_name, commit
|
||
|
|
|
||
|
|
|
||
|
|
# Parameter-specification variable used to type `ContextThreadPoolExecutor.submit`.
P = ParamSpec("P")
|
||
|
|
|
||
|
|
|
||
|
|
class ContextThreadPoolExecutor(ThreadPoolExecutor):
    """ThreadPoolExecutor that copies the context to the child thread."""

    def submit(  # type: ignore[override]
        self,
        func: Callable[P, T],
        *args: P.args,
        **kwargs: P.kwargs,
    ) -> Future[T]:
        """Submit a function to the executor.

        The caller's `contextvars` context is copied at submission time, and
        `func` runs inside that copy on the worker thread.

        Args:
            func (Callable[..., T]): The function to submit.
            *args (Any): The positional arguments to the function.
            **kwargs (Any): The keyword arguments to the function.

        Returns:
            Future[T]: The future for the function.
        """
        # String form of the cast target: `T` is only needed by type checkers.
        return super().submit(
            cast(
                "Callable[..., T]",
                functools.partial(
                    contextvars.copy_context().run, func, *args, **kwargs
                ),
            )
        )

    def map(
        self,
        fn: Callable[..., T],
        *iterables: Iterable[Any],
        timeout: Optional[float] = None,
        chunksize: int = 1,
    ) -> Iterator[T]:
        """Return an iterator equivalent to stdlib map.

        Each function call receives its own copy of the parent thread's
        context. Any iterables are accepted, including generators and other
        iterators without a length.

        Args:
            fn: A callable that will take as many arguments as there are
                passed iterables.
            timeout: The maximum number of seconds to wait. If None, then there
                is no limit on the wait time.
            chunksize: The size of the chunks the iterable will be broken into
                before being passed to a child process. This argument is only
                used by ProcessPoolExecutor; it is ignored by
                ThreadPoolExecutor.

        Returns:
            An iterator equivalent to: map(func, *iterables) but the calls may
            be evaluated out-of-order.

        Raises:
            TimeoutError: If the entire result iterator could not be generated
                before the given timeout.
            Exception: If fn(*args) raises for any values.
        """
        if not iterables:
            # Nothing to zip over; the base class yields no results.
            return super().map(fn, timeout=timeout, chunksize=chunksize)

        # Endless stream of fresh context copies. The base class collects
        # the iterables immediately, in the calling thread, so the copies are
        # taken from the caller's context. It is passed last so zip stops on
        # the shortest real iterable before requesting another copy.
        contexts = iter(contextvars.copy_context, None)

        def _run_in_own_context(*args: Any) -> T:
            *call_args, ctx = args
            return ctx.run(fn, *call_args)

        return super().map(
            _run_in_own_context,
            *iterables,
            contexts,
            timeout=timeout,
            chunksize=chunksize,
        )
|
||
|
|
|
||
|
|
|
||
|
|
def get_api_url(api_url: Optional[str]) -> str:
    """Get the LangSmith API URL from the environment or the given value.

    A truthy `api_url` wins; otherwise the `ENDPOINT` environment variable
    (via `get_env_var`) is used, defaulting to
    `https://api.smith.langchain.com`. Surrounding whitespace and quotes and
    any trailing `/` are removed.

    Raises:
        LangSmithUserError: If the URL is empty after that cleanup (e.g.
            whitespace only, `"/"`, or a quoted empty string like `'""'`).
    """
    raw_url = api_url or cast(
        str,
        get_env_var(
            "ENDPOINT",
            default="https://api.smith.langchain.com",
        ),
    )
    cleaned = raw_url.strip().strip('"').strip("'").rstrip("/")
    # Validate after cleanup so quoted-empty values are rejected too.
    if not cleaned:
        raise LangSmithUserError("LangSmith API URL cannot be empty")
    return cleaned
|
||
|
|
|
||
|
|
|
||
|
|
def get_api_key(api_key: Optional[str]) -> Optional[str]:
    """Get the API key from the environment or the given value.

    An explicit `api_key` (even an empty one) takes precedence over the
    `API_KEY` environment variable. Surrounding whitespace and quotes are
    removed.

    Returns:
        The cleaned key, or `None` if it is missing or empty after cleanup
        (including a quoted empty string such as `'""'`).
    """
    raw_key = api_key if api_key is not None else get_env_var("API_KEY", default=None)
    if raw_key is None:
        return None
    cleaned = raw_key.strip().strip('"').strip("'")
    return cleaned or None
|
||
|
|
|
||
|
|
|
||
|
|
def get_workspace_id(workspace_id: Optional[str]) -> Optional[str]:
    """Get workspace ID.

    An explicit `workspace_id` (even an empty one) takes precedence over the
    `WORKSPACE_ID` environment variable. Surrounding whitespace and quotes are
    removed.

    Returns:
        The cleaned workspace ID, or `None` if it is missing or empty after
        cleanup (including a quoted empty string such as `'""'`).
    """
    raw_id = (
        workspace_id
        if workspace_id is not None
        else get_env_var("WORKSPACE_ID", default=None)
    )
    if raw_id is None:
        return None
    cleaned = raw_id.strip().strip('"').strip("'")
    return cleaned or None
|
||
|
|
|
||
|
|
|
||
|
|
def _is_localhost(url: str) -> bool:
|
||
|
|
"""Check if the URL is localhost.
|
||
|
|
|
||
|
|
Parameters
|
||
|
|
----------
|
||
|
|
url : str
|
||
|
|
The URL to check.
|
||
|
|
|
||
|
|
Returns:
|
||
|
|
-------
|
||
|
|
bool
|
||
|
|
True if the URL is localhost, False otherwise.
|
||
|
|
"""
|
||
|
|
try:
|
||
|
|
netloc = urllib_parse.urlsplit(url).netloc.split(":")[0]
|
||
|
|
ip = socket.gethostbyname(netloc)
|
||
|
|
return ip == "127.0.0.1" or ip.startswith("0.0.0.0") or ip.startswith("::")
|
||
|
|
except socket.gaierror:
|
||
|
|
return False
|
||
|
|
|
||
|
|
|
||
|
|
@functools.lru_cache(maxsize=2)
def get_host_url(web_url: Optional[str], api_url: str):
    """Derive the LangSmith web-app URL.

    A truthy `web_url` is returned unchanged. Otherwise the URL is inferred
    from `api_url`:

    * local API hosts map to `http://localhost`;
    * a path ending in `/api` or `/api/v1` has that suffix trimmed off;
    * `eu.`, `dev.` and `beta.` host prefixes map to the matching hosted app;
    * anything else maps to `https://smith.langchain.com`.

    Results are memoized (cache size 2).
    """
    if web_url:
        return web_url
    parsed = urllib_parse.urlparse(api_url)
    if _is_localhost(api_url):
        return "http://localhost"
    path = str(parsed.path)
    for suffix in ("/api", "/api/v1"):
        if path.endswith(suffix):
            trimmed = path.rsplit(suffix, 1)[0]
            return urllib_parse.urlunparse(parsed._replace(path=trimmed))
    netloc = str(parsed.netloc)
    for prefix in ("eu.", "dev.", "beta."):
        if netloc.startswith(prefix):
            return f"https://{prefix}smith.langchain.com"
    return "https://smith.langchain.com"
|
||
|
|
|
||
|
|
|
||
|
|
def _get_function_name(fn: Callable, depth: int = 0) -> str:
|
||
|
|
if depth > 2 or not callable(fn):
|
||
|
|
return str(fn)
|
||
|
|
|
||
|
|
if hasattr(fn, "__name__"):
|
||
|
|
return fn.__name__
|
||
|
|
|
||
|
|
if isinstance(fn, functools.partial):
|
||
|
|
return _get_function_name(fn.func, depth + 1)
|
||
|
|
|
||
|
|
if hasattr(fn, "__call__"):
|
||
|
|
if hasattr(fn, "__class__") and hasattr(fn.__class__, "__name__"):
|
||
|
|
return fn.__class__.__name__
|
||
|
|
return _get_function_name(fn.__call__, depth + 1)
|
||
|
|
|
||
|
|
return str(fn)
|
||
|
|
|
||
|
|
|
||
|
|
def is_truish(val: Any) -> bool:
    """Check if the value is truish.

    Strings are truish only when they equal `"1"` or `"true"`
    (case-insensitive). Any other value uses its normal truthiness.

    Args:
        val (Any): The value to check.

    Returns:
        bool: True if the value is truish, False otherwise.
    """
    if not isinstance(val, str):
        return bool(val)
    return val == "1" or val.lower() == "true"
|