"""Schemas for the LangSmith API."""
from __future__ import annotations
import contextvars
import functools
import json
import logging
import sys
import threading
import urllib.parse
from collections.abc import Mapping, Sequence
from datetime import datetime, timezone
from typing import Any, Optional, Union, cast
from uuid import UUID
from pydantic import ConfigDict, Field, model_validator
from typing_extensions import TypedDict
import langsmith._internal._context as _context
from langsmith import schemas as ls_schemas
from langsmith import utils
from langsmith._internal._uuid import uuid7, uuid7_deterministic
from langsmith.client import ID_TYPE, RUN_TYPE_T, Client, _dumps_json, _ensure_uuid
from langsmith.uuid import uuid7_from_datetime
logger = logging.getLogger(__name__)
class WriteReplica(TypedDict, total=False):
api_url: Optional[str]
api_key: Optional[str]
project_name: Optional[str]
updates: Optional[dict]
LANGSMITH_PREFIX = "langsmith-"
LANGSMITH_DOTTED_ORDER = sys.intern(f"{LANGSMITH_PREFIX}trace")
LANGSMITH_DOTTED_ORDER_BYTES = LANGSMITH_DOTTED_ORDER.encode("utf-8")
LANGSMITH_METADATA = sys.intern(f"{LANGSMITH_PREFIX}metadata")
LANGSMITH_TAGS = sys.intern(f"{LANGSMITH_PREFIX}tags")
LANGSMITH_PROJECT = sys.intern(f"{LANGSMITH_PREFIX}project")
LANGSMITH_REPLICAS = sys.intern(f"{LANGSMITH_PREFIX}replicas")
OVERRIDE_OUTPUTS = sys.intern("__omit_auto_outputs")
NOT_PROVIDED = cast(None, object())
_LOCK = threading.Lock()
# Context variables
_REPLICAS = contextvars.ContextVar[Optional[Sequence[WriteReplica]]](
"_REPLICAS", default=None
)
_DISTRIBUTED_PARENT_ID = contextvars.ContextVar[Optional[str]](
"_DISTRIBUTED_PARENT_ID", default=None
)
_SENTINEL = cast(None, object())
TIMESTAMP_LENGTH = 36
# Note: this is called directly by langchain. Do not remove.
def get_cached_client(**init_kwargs: Any) -> Client:
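"""Return the process-wide cached Client, creating it on first use.
Uses double-checked locking so concurrent callers share a single instance.
Example (illustrative; assumes LangSmith credentials are configured):
>>> client = get_cached_client()
>>> client is get_cached_client()  # later calls reuse the same instance
True
"""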
global _CLIENT
if _CLIENT is None:
with _LOCK:
if _CLIENT is None:
_CLIENT = Client(**init_kwargs)
return _CLIENT
def configure(
client: Optional[Client] = _SENTINEL,
enabled: Optional[bool] = _SENTINEL,
project_name: Optional[str] = _SENTINEL,
tags: Optional[list[str]] = _SENTINEL,
metadata: Optional[dict[str, Any]] = _SENTINEL,
):
"""Configure global LangSmith tracing context.
This function allows you to set global configuration options for LangSmith
tracing that will be applied to all subsequent traced operations. It modifies
context variables that control tracing behavior across your application.
Do this once at startup to configure the global settings in code.
To configure tracing for a single invocation only, use the
`tracing_context` context manager instead.
Args:
client: A LangSmith Client instance to use for all tracing operations.
If provided, this client will be used instead of creating new clients.
Pass `None` to explicitly clear the global client.
enabled: Whether tracing is enabled.
Can be:
- `True`: Enable tracing and send data to LangSmith
- `False`: Disable tracing completely
- `'local'`: Enable tracing but only store data locally
- `None`: Clear the setting (falls back to environment variables)
project_name: The LangSmith project name where traces will be sent.
This determines which project dashboard will display your traces.
Pass `None` to explicitly clear the project name.
tags: A list of tags to be applied to all traced runs.
Tags are useful for filtering and organizing runs in the LangSmith UI.
Pass `None` to explicitly clear all global tags.
metadata: A dictionary of metadata to attach to all traced runs.
Metadata can store any additional context about your runs.
Pass `None` to explicitly clear all global metadata.
Examples:
Basic configuration:
>>> import langsmith as ls
>>> # Enable tracing with a specific project
>>> ls.configure(enabled=True, project_name="my-project")
Set global trace masking:
>>> def hide_keys(data):
... if not data:
... return {}
... return {k: v for k, v in data.items() if k not in ["key1", "key2"]}
>>> ls.configure(
... client=ls.Client(
... hide_inputs=hide_keys,
... hide_outputs=hide_keys,
... )
... )
Adding global tags and metadata:
>>> ls.configure(
... tags=["production", "v1.0"],
... metadata={"environment": "prod", "version": "1.0.0"},
... )
Disabling tracing:
>>> ls.configure(enabled=False)
"""
global _CLIENT
with _LOCK:
if client is not _SENTINEL:
_CLIENT = client
if enabled is not _SENTINEL:
_context._TRACING_ENABLED.set(enabled)
_context._GLOBAL_TRACING_ENABLED = enabled
if project_name is not _SENTINEL:
_context._PROJECT_NAME.set(project_name)
_context._GLOBAL_PROJECT_NAME = project_name
if tags is not _SENTINEL:
_context._TAGS.set(tags)
_context._GLOBAL_TAGS = tags
if metadata is not _SENTINEL:
_context._METADATA.set(metadata)
_context._GLOBAL_METADATA = metadata
def validate_extracted_usage_metadata(
data: ls_schemas.ExtractedUsageMetadata,
) -> ls_schemas.ExtractedUsageMetadata:
"""Validate that the dict only contains allowed keys."""
allowed_keys = {
"input_tokens",
"output_tokens",
"total_tokens",
"input_token_details",
"output_token_details",
"input_cost",
"output_cost",
"total_cost",
"input_cost_details",
"output_cost_details",
}
extra_keys = set(data.keys()) - allowed_keys
if extra_keys:
raise ValueError(f"Unexpected keys in usage metadata: {extra_keys}")
return data # type: ignore
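# Illustrative behavior of validate_extracted_usage_metadata:
#
#   validate_extracted_usage_metadata({"input_tokens": 10, "total_tokens": 15})
#   # -> returns the dict unchanged
#   validate_extracted_usage_metadata({"tokens": 15})
#   # -> raises ValueError: Unexpected keys in usage metadata: {'tokens'}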
class RunTree(ls_schemas.RunBase):
"""Run Schema with back-references for posting runs."""
name: str
id: UUID = Field(default_factory=uuid7)
run_type: str = Field(default="chain")
start_time: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
# Note: no longer set.
parent_run: Optional[RunTree] = Field(default=None, exclude=True)
parent_dotted_order: Optional[str] = Field(default=None, exclude=True)
child_runs: list[RunTree] = Field(
default_factory=list,
exclude=True,
)
session_name: str = Field(
default_factory=lambda: utils.get_tracer_project() or "default",
alias="project_name",
)
session_id: Optional[UUID] = Field(default=None, alias="project_id")
extra: dict = Field(default_factory=dict)
tags: Optional[list[str]] = Field(default_factory=list)
events: list[dict] = Field(default_factory=list)
"""List of events associated with the run, like
start and end events."""
ls_client: Optional[Any] = Field(default=None, exclude=True)
dotted_order: str = Field(
default="", description="The order of the run in the tree."
)
trace_id: UUID = Field(default="", description="The trace id of the run.") # type: ignore
dangerously_allow_filesystem: Optional[bool] = Field(
default=False, description="Whether to allow filesystem access for attachments."
)
replicas: Optional[Sequence[WriteReplica]] = Field(
default=None,
description="Projects to replicate this run to with optional updates.",
)
model_config = ConfigDict(
arbitrary_types_allowed=True,
populate_by_name=True,
extra="ignore",
)
@model_validator(mode="before")
def infer_defaults(cls, values: dict[str, Any]) -> dict[str, Any]:
"""Assign name to the run."""
if values.get("name") is None and values.get("serialized") is not None:
if "name" in values["serialized"]:
values["name"] = values["serialized"]["name"]
elif "id" in values["serialized"]:
values["name"] = values["serialized"]["id"][-1]
if values.get("name") is None:
values["name"] = "Unnamed"
if "client" in values: # Handle user-constructed clients
values["ls_client"] = values.pop("client")
elif "_client" in values:
values["ls_client"] = values.pop("_client")
if not values.get("ls_client"):
values["ls_client"] = None
parent_run = values.pop("parent_run", None)
if parent_run is not None:
values["parent_run_id"] = parent_run.id
values["parent_dotted_order"] = parent_run.dotted_order
if "id" not in values:
# Generate UUID from start_time if available
if "start_time" in values and values["start_time"] is not None:
values["id"] = uuid7_from_datetime(values["start_time"])
else:
now = datetime.now(timezone.utc)
values["start_time"] = now
values["id"] = uuid7_from_datetime(now)
if "trace_id" not in values:
if parent_run is not None:
values["trace_id"] = parent_run.trace_id
else:
values["trace_id"] = values["id"]
cast(dict, values.setdefault("extra", {}))
if values.get("events") is None:
values["events"] = []
if values.get("tags") is None:
values["tags"] = []
if values.get("outputs") is None:
values["outputs"] = {}
if values.get("attachments") is None:
values["attachments"] = {}
if values.get("replicas") is None:
values["replicas"] = _REPLICAS.get()
values["replicas"] = _ensure_write_replicas(values["replicas"])
return values
@model_validator(mode="after")
def ensure_dotted_order(self) -> RunTree:
"""Ensure the dotted order of the run."""
current_dotted_order = self.dotted_order
if current_dotted_order and current_dotted_order.strip():
return self
current_dotted_order = _create_current_dotted_order(self.start_time, self.id)
parent_dotted_order = self.parent_dotted_order
if parent_dotted_order is not None:
self.dotted_order = parent_dotted_order + "." + current_dotted_order
else:
self.dotted_order = current_dotted_order
return self
@property
def client(self) -> Client:
"""Return the client."""
# Lazily load the client
# If you never use this for API calls, it will never be loaded
if self.ls_client is None:
self.ls_client = get_cached_client()
return self.ls_client
@property
def _client(self) -> Optional[Client]:
# For backwards compat
return self.ls_client
def __setattr__(self, name, value):
"""Set the `_client` specially."""
# For backwards compat
if name == "_client":
self.ls_client = value
else:
return super().__setattr__(name, value)
def set(
self,
*,
inputs: Optional[Mapping[str, Any]] = NOT_PROVIDED,
outputs: Optional[Mapping[str, Any]] = NOT_PROVIDED,
tags: Optional[Sequence[str]] = NOT_PROVIDED,
metadata: Optional[Mapping[str, Any]] = NOT_PROVIDED,
usage_metadata: Optional[ls_schemas.ExtractedUsageMetadata] = NOT_PROVIDED,
) -> None:
"""Set the inputs, outputs, tags, and metadata of the run.
If outputs are set here, the end() method will ignore the outputs
that the @traceable decorator would otherwise add automatically.
If your LangChain or LangGraph versions are sufficiently up-to-date,
this will also override the default behavior of `LangChainTracer`.
Args:
inputs: The inputs to set.
outputs: The outputs to set.
tags: The tags to set.
metadata: The metadata to set.
usage_metadata: Usage information to set.
Returns:
None
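Example (illustrative; assumes ``run`` is an existing RunTree):
>>> run.set(outputs={"answer": 42})  # end() will no longer overwrite outputs
>>> run.set(metadata={"env": "prod"})  # merged into extra["metadata"]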
"""
if tags is not NOT_PROVIDED:
self.tags = list(tags) if tags is not None else []
if metadata is not NOT_PROVIDED:
self.extra.setdefault("metadata", {}).update(metadata or {})
if inputs is not NOT_PROVIDED:
# Used by LangChain core to determine whether to
# re-upload the inputs upon run completion
self.extra["inputs_is_truthy"] = False
if inputs is None:
self.inputs = {}
else:
self.inputs = dict(inputs)
if outputs is not NOT_PROVIDED:
self.extra[OVERRIDE_OUTPUTS] = True
if outputs is None:
self.outputs = {}
else:
self.outputs = dict(outputs)
if usage_metadata is not NOT_PROVIDED and usage_metadata is not None:
self.extra.setdefault("metadata", {})["usage_metadata"] = (
validate_extracted_usage_metadata(usage_metadata)
)
def add_tags(self, tags: Union[Sequence[str], str]) -> None:
"""Add tags to the run."""
if isinstance(tags, str):
tags = [tags]
if self.tags is None:
self.tags = []
self.tags.extend(tags)
def add_metadata(self, metadata: dict[str, Any]) -> None:
"""Add metadata to the run."""
if self.extra is None:
self.extra = {}
metadata_: dict = cast(dict, self.extra).setdefault("metadata", {})
metadata_.update(metadata)
def add_outputs(self, outputs: dict[str, Any]) -> None:
"""Upsert the given outputs into the run.
Args:
outputs: A dictionary containing the outputs to be added.
"""
if self.outputs is None:
self.outputs = {}
self.outputs.update(outputs)
def add_inputs(self, inputs: dict[str, Any]) -> None:
"""Upsert the given inputs into the run.
Args:
inputs: A dictionary containing the inputs to be added.
"""
if self.inputs is None:
self.inputs = {}
self.inputs.update(inputs)
# Set to False so LangChain thinks it needs to
# re-upload inputs
self.extra["inputs_is_truthy"] = False
def add_event(
self,
events: Union[
ls_schemas.RunEvent,
Sequence[ls_schemas.RunEvent],
Sequence[dict],
dict,
str,
],
) -> None:
"""Add an event to the list of events.
Args:
events: The event(s) to be added. It can be a single event, a sequence
of events, a sequence of dictionaries, a dictionary, or a string.
Returns:
None
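Example (illustrative):
>>> run.add_event("retrying after rate limit")  # stored with a UTC timestamp
>>> run.add_event({"name": "tool_end", "time": "2024-01-01T00:00:00+00:00"})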
"""
if self.events is None:
self.events = []
if isinstance(events, dict):
self.events.append(events) # type: ignore[arg-type]
elif isinstance(events, str):
self.events.append(
{
"name": "event",
"time": datetime.now(timezone.utc).isoformat(),
"message": events,
}
)
else:
self.events.extend(events) # type: ignore[arg-type]
def end(
self,
*,
outputs: Optional[dict] = None,
error: Optional[str] = None,
end_time: Optional[datetime] = None,
events: Optional[Sequence[ls_schemas.RunEvent]] = None,
metadata: Optional[dict[str, Any]] = None,
) -> None:
"""Set the end time of the run and all child runs."""
self.end_time = end_time or datetime.now(timezone.utc)
# We've already 'set' the outputs, so ignore
# the ones that are automatically included
if not self.extra.get(OVERRIDE_OUTPUTS):
if outputs is not None:
if not self.outputs:
self.outputs = outputs
else:
self.outputs.update(outputs)
if error is not None:
self.error = error
if events is not None:
self.add_event(events)
if metadata is not None:
self.add_metadata(metadata)
def create_child(
self,
name: str,
run_type: RUN_TYPE_T = "chain",
*,
run_id: Optional[ID_TYPE] = None,
serialized: Optional[dict] = None,
inputs: Optional[dict] = None,
outputs: Optional[dict] = None,
error: Optional[str] = None,
reference_example_id: Optional[UUID] = None,
start_time: Optional[datetime] = None,
end_time: Optional[datetime] = None,
tags: Optional[list[str]] = None,
extra: Optional[dict] = None,
attachments: Optional[ls_schemas.Attachments] = None,
) -> RunTree:
"""Add a child run to the run tree."""
# Ensure child start_time is never earlier than parent start_time
# to prevent timestamp ordering violations in dotted_order
if start_time is not None and self.start_time is not None:
if start_time < self.start_time:
logger.debug(
f"Adjusting child run '{name}' start_time from {start_time} "
f"to {self.start_time} to maintain timestamp ordering with "
f"parent '{self.name}'"
)
start_time = max(start_time, self.start_time)
serialized_ = serialized or {"name": name}
run = RunTree(
name=name,
id=_ensure_uuid(run_id),
serialized=serialized_,
inputs=inputs or {},
outputs=outputs or {},
error=error,
run_type=run_type,
reference_example_id=reference_example_id,
start_time=start_time or datetime.now(timezone.utc),
end_time=end_time,
extra=extra or {},
parent_run=self,
project_name=self.session_name,
replicas=self.replicas,
ls_client=self.ls_client,
tags=tags,
attachments=attachments or {}, # type: ignore
dangerously_allow_filesystem=self.dangerously_allow_filesystem,
)
return run
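# Example (illustrative): nesting a child span under an existing run and
# posting it. The names and payloads here are hypothetical.
#
#   child = parent.create_child(
#       "retrieve-docs", run_type="retriever", inputs={"query": "what is langsmith?"}
#   )
#   child.post()
#   child.end(outputs={"num_docs": 3})
#   child.patch()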
def _get_dicts_safe(self):
# Things like generators cannot be copied
self_dict = self.model_dump(
exclude={"child_runs", "inputs", "outputs"}, exclude_none=True
)
if self.inputs is not None:
# shallow copy. deep copying will occur in the client
inputs_ = {}
attachments = self_dict.get("attachments", {})
for k, v in self.inputs.items():
if isinstance(v, ls_schemas.Attachment):
attachments[k] = v
else:
inputs_[k] = v
self_dict["inputs"] = inputs_
if attachments:
self_dict["attachments"] = attachments
if self.outputs is not None:
# shallow copy; deep copying will occur in the client
self_dict["outputs"] = self.outputs.copy()
return self_dict
def _slice_parent_id(self, parent_id: str, run_dict: dict) -> None:
"""Slice the parent id from dotted order.
Additionally check if the current run is a child of the parent. If so, update
the parent_run_id to None, and set the trace id to the new root id after
parent_id.
"""
if dotted_order := run_dict.get("dotted_order"):
segs = dotted_order.split(".")
start_idx = None
parent_id = str(parent_id)
# TODO(angus): potentially use binary search to find the index
for idx, part in enumerate(segs):
seg_id = part[-TIMESTAMP_LENGTH:]
if str(seg_id) == parent_id:
start_idx = idx
break
if start_idx is not None:
# Trim segments to start after parent_id (exclusive)
trimmed_segs = segs[start_idx + 1 :]
# Rebuild dotted_order
run_dict["dotted_order"] = ".".join(trimmed_segs)
if trimmed_segs:
run_dict["trace_id"] = UUID(trimmed_segs[0][-TIMESTAMP_LENGTH:])
else:
run_dict["trace_id"] = run_dict["id"]
if str(run_dict.get("parent_run_id")) == parent_id:
# We've found the new root node.
run_dict.pop("parent_run_id", None)
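# Illustrative walk-through of _slice_parent_id: given dotted_order "A.B.C"
# whose segments end in ids a, b, c, and parent_id b, the trimmed dotted_order
# is "C", trace_id becomes c, and parent_run_id is cleared because c's parent
# was b -- i.e. c becomes the new local root.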
def _remap_for_project(
self, project_name: str, updates: Optional[dict] = None
) -> dict:
"""Rewrites ids/dotted_order for a given project with optional updates."""
run_dict = self._get_dicts_safe()
if project_name == self.session_name:
return run_dict
if updates and updates.get("reroot", False):
distributed_parent_id = _DISTRIBUTED_PARENT_ID.get()
if distributed_parent_id:
self._slice_parent_id(distributed_parent_id, run_dict)
old_id = run_dict["id"]
new_id = uuid7_deterministic(UUID(str(old_id)), project_name)
# trace id
old_trace = run_dict.get("trace_id")
if old_trace:
new_trace = uuid7_deterministic(UUID(str(old_trace)), project_name)
else:
new_trace = None
# parent id
parent = run_dict.get("parent_run_id")
if parent:
new_parent = uuid7_deterministic(UUID(str(parent)), project_name)
else:
new_parent = None
# dotted order
if run_dict.get("dotted_order"):
segs = run_dict["dotted_order"].split(".")
rebuilt = []
for part in segs[:-1]:
seg_id = UUID(part[-TIMESTAMP_LENGTH:])
repl = uuid7_deterministic(seg_id, project_name)
rebuilt.append(part[:-TIMESTAMP_LENGTH] + str(repl))
rebuilt.append(segs[-1][:-TIMESTAMP_LENGTH] + str(new_id))
dotted = ".".join(rebuilt)
else:
dotted = None
dup = utils.deepish_copy(run_dict)
dup.update(
{
"id": new_id,
"trace_id": new_trace,
"parent_run_id": new_parent,
"dotted_order": dotted,
"session_name": project_name,
}
)
if updates:
dup.update(updates)
return dup
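# Replica id remapping is deterministic: uuid7_deterministic(run_id, project)
# always yields the same replica id for the same inputs, so retried writes to a
# replica project stay idempotent. Illustrative sketch (hypothetical project name):
#
#   replica_id = uuid7_deterministic(run.id, "replica-project")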
def post(self, exclude_child_runs: bool = True) -> None:
"""Post the run tree to the API asynchronously."""
if self.replicas:
for replica in self.replicas:
project_name = replica.get("project_name") or self.session_name
updates = replica.get("updates")
run_dict = self._remap_for_project(project_name, updates)
self.client.create_run(
**run_dict,
api_key=replica.get("api_key"),
api_url=replica.get("api_url"),
)
else:
kwargs = self._get_dicts_safe()
self.client.create_run(**kwargs)
if self.attachments:
keys = [str(name) for name in self.attachments]
self.events.append(
{
"name": "uploaded_attachment",
"time": datetime.now(timezone.utc).isoformat(),
"message": set(keys),
}
)
if not exclude_child_runs:
for child_run in self.child_runs:
child_run.post(exclude_child_runs=False)
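# Example (illustrative): manually managing a run's lifecycle. The run name and
# payloads are hypothetical.
#
#   run = RunTree(name="my-pipeline", inputs={"question": "hi"})
#   run.post()
#   run.end(outputs={"answer": "hello"})
#   run.patch()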
def patch(self, *, exclude_inputs: bool = False) -> None:
"""Patch the run tree to the API in a background thread.
Args:
exclude_inputs: Whether to exclude inputs from the patch request.
"""
if not self.end_time:
self.end()
attachments = {
a: v for a, v in self.attachments.items() if isinstance(v, tuple)
}
try:
# Avoid loading the same attachment twice
if attachments:
uploaded = next(
(
ev
for ev in self.events
if ev.get("name") == "uploaded_attachment"
),
None,
)
if uploaded:
attachments = {
a: v
for a, v in attachments.items()
if a not in uploaded["message"]
}
except Exception as e:
logger.warning(f"Error filtering attachments to upload: {e}")
if self.replicas:
for replica in self.replicas:
project_name = replica.get("project_name") or self.session_name
updates = replica.get("updates")
run_dict = self._remap_for_project(project_name, updates)
self.client.update_run(
name=run_dict["name"],
run_id=run_dict["id"],
run_type=run_dict.get("run_type"),
start_time=run_dict.get("start_time"),
inputs=None if exclude_inputs else run_dict["inputs"],
outputs=run_dict["outputs"],
error=run_dict.get("error"),
parent_run_id=run_dict.get("parent_run_id"),
session_name=run_dict.get("session_name"),
reference_example_id=run_dict.get("reference_example_id"),
end_time=run_dict.get("end_time"),
dotted_order=run_dict.get("dotted_order"),
trace_id=run_dict.get("trace_id"),
events=run_dict.get("events"),
tags=run_dict.get("tags"),
extra=run_dict.get("extra"),
attachments=attachments,
api_key=replica.get("api_key"),
api_url=replica.get("api_url"),
)
else:
self.client.update_run(
name=self.name,
run_id=self.id,
run_type=cast(RUN_TYPE_T, self.run_type),
start_time=self.start_time,
inputs=(
None
if exclude_inputs
else (self.inputs.copy() if self.inputs else None)
),
outputs=self.outputs.copy() if self.outputs else None,
error=self.error,
parent_run_id=self.parent_run_id,
session_name=self.session_name,
reference_example_id=self.reference_example_id,
end_time=self.end_time,
dotted_order=self.dotted_order,
trace_id=self.trace_id,
events=self.events,
tags=self.tags,
extra=self.extra,
attachments=attachments,
)
def wait(self) -> None:
"""Wait for all `_futures` to complete."""
pass
def get_url(self) -> str:
"""Return the URL of the run."""
return self.client.get_run_url(run=self)
@classmethod
def from_dotted_order(
cls,
dotted_order: str,
**kwargs: Any,
) -> RunTree:
"""Create a new 'child' span from the provided dotted order.
Returns:
RunTree: The new span.
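Example (illustrative; ``dotted_order`` received from another process):
>>> span = RunTree.from_dotted_order(dotted_order, name="remote-parent")
>>> child = span.create_child("local-step")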
"""
headers = {
LANGSMITH_DOTTED_ORDER: dotted_order,
}
return cast(RunTree, cls.from_headers(headers, **kwargs)) # type: ignore[arg-type]
@classmethod
def from_runnable_config(
cls,
config: Optional[dict],
**kwargs: Any,
) -> Optional[RunTree]:
"""Create a new 'child' span from the provided runnable config.
Requires `langchain` to be installed.
Returns:
The new span or `None` if no parent span information is found.
"""
try:
from langchain_core.callbacks.manager import (
AsyncCallbackManager,
CallbackManager,
)
from langchain_core.runnables import RunnableConfig, ensure_config
from langchain_core.tracers.langchain import LangChainTracer
except ImportError as e:
raise ImportError(
"RunTree.from_runnable_config requires langchain-core to be installed. "
"You can install it with `pip install langchain-core`."
) from e
if config is None:
config_ = ensure_config()
else:
config_ = cast(RunnableConfig, config)
if (
(cb := config_.get("callbacks"))
and isinstance(cb, (CallbackManager, AsyncCallbackManager))
and cb.parent_run_id
and (
tracer := next(
(t for t in cb.handlers if isinstance(t, LangChainTracer)),
None,
)
)
):
if (run := tracer.run_map.get(str(cb.parent_run_id))) and run.dotted_order:
dotted_order = run.dotted_order
kwargs["run_type"] = run.run_type
kwargs["inputs"] = run.inputs
kwargs["outputs"] = run.outputs
kwargs["start_time"] = run.start_time
kwargs["end_time"] = run.end_time
kwargs["tags"] = sorted(set(run.tags or [] + kwargs.get("tags", [])))
kwargs["name"] = run.name
extra_ = kwargs.setdefault("extra", {})
metadata_ = extra_.setdefault("metadata", {})
metadata_.update(run.metadata)
elif hasattr(tracer, "order_map") and cb.parent_run_id in tracer.order_map:
dotted_order = tracer.order_map[cb.parent_run_id][1]
else:
return None
kwargs["client"] = tracer.client
kwargs["project_name"] = tracer.project_name
return RunTree.from_dotted_order(dotted_order, **kwargs)
return None
@classmethod
def from_headers(
cls, headers: Mapping[Union[str, bytes], Union[str, bytes]], **kwargs: Any
) -> Optional[RunTree]:
"""Create a new 'parent' span from the provided headers.
Extracts parent span information from the headers and creates a new span.
Metadata and tags are extracted from the baggage header.
The dotted order and trace id are extracted from the trace header.
Returns:
The new span or `None` if no parent span information is found.
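Example (illustrative; ``request_headers`` is an incoming mapping that may
contain the ``langsmith-trace`` and ``baggage`` headers):
>>> parent = RunTree.from_headers(request_headers)
>>> if parent is not None:
...     child = parent.create_child("downstream-step")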
"""
init_args = kwargs.copy()
langsmith_trace = cast(Optional[str], headers.get(LANGSMITH_DOTTED_ORDER))
if not langsmith_trace:
langsmith_trace_bytes = cast(
Optional[bytes], headers.get(LANGSMITH_DOTTED_ORDER_BYTES)
)
if not langsmith_trace_bytes:
return # type: ignore[return-value]
langsmith_trace = langsmith_trace_bytes.decode("utf-8")
parent_dotted_order = langsmith_trace.strip()
parsed_dotted_order = _parse_dotted_order(parent_dotted_order)
trace_id = parsed_dotted_order[0][1]
init_args["trace_id"] = trace_id
init_args["id"] = parsed_dotted_order[-1][1]
init_args["dotted_order"] = parent_dotted_order
if len(parsed_dotted_order) >= 2:
# Has a parent
init_args["parent_run_id"] = parsed_dotted_order[-2][1]
# All placeholders. We assume the source process
# handles the life-cycle of the run.
init_args["start_time"] = init_args.get("start_time") or datetime.now(
timezone.utc
)
init_args["run_type"] = init_args.get("run_type") or "chain"
init_args["name"] = init_args.get("name") or "parent"
baggage = _Baggage.from_headers(headers)
if baggage.metadata or baggage.tags:
init_args["extra"] = init_args.setdefault("extra", {})
init_args["extra"]["metadata"] = init_args["extra"].setdefault(
"metadata", {}
)
metadata = {**baggage.metadata, **init_args["extra"]["metadata"]}
init_args["extra"]["metadata"] = metadata
tags = sorted(set(baggage.tags + init_args.get("tags", [])))
init_args["tags"] = tags
if baggage.project_name:
init_args["project_name"] = baggage.project_name
if baggage.replicas:
init_args["replicas"] = baggage.replicas
run_tree = RunTree(**init_args)
# Set the distributed parent ID to this run's ID for rerooting
_DISTRIBUTED_PARENT_ID.set(str(run_tree.id))
return run_tree
def to_headers(self) -> dict[str, str]:
"""Return the `RunTree` as a dictionary of headers."""
headers = {}
if self.trace_id:
headers[f"{LANGSMITH_DOTTED_ORDER}"] = self.dotted_order
baggage = _Baggage(
metadata=self.extra.get("metadata", {}),
tags=self.tags,
project_name=self.session_name,
replicas=self.replicas,
)
headers["baggage"] = baggage.to_header()
return headers
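# Example (illustrative): propagating trace context across a service boundary,
# assuming the `requests` package and a downstream service that reconstructs
# the span with RunTree.from_headers(). URL and payload are hypothetical.
#
#   requests.post("https://downstream.example/invoke", json=payload,
#                 headers=run.to_headers())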
def __repr__(self):
"""Return a string representation of the `RunTree` object."""
return (
f"RunTree(id={self.id}, name='{self.name}', "
f"run_type='{self.run_type}', dotted_order='{self.dotted_order}')"
)
class _Baggage:
"""Baggage header information."""
def __init__(
self,
metadata: Optional[dict[str, str]] = None,
tags: Optional[list[str]] = None,
project_name: Optional[str] = None,
replicas: Optional[Sequence[WriteReplica]] = None,
):
"""Initialize the Baggage object."""
self.metadata = metadata or {}
self.tags = tags or []
self.project_name = project_name
self.replicas = replicas or []
@classmethod
def from_header(cls, header_value: Optional[str]) -> _Baggage:
"""Create a Baggage object from the given header value."""
if not header_value:
return cls()
metadata = {}
tags = []
project_name = None
replicas: Optional[list[WriteReplica]] = None
try:
for item in header_value.split(","):
key, value = item.split("=", 1)
if key == LANGSMITH_METADATA:
metadata = json.loads(urllib.parse.unquote(value))
elif key == LANGSMITH_TAGS:
tags = urllib.parse.unquote(value).split(",")
elif key == LANGSMITH_PROJECT:
project_name = urllib.parse.unquote(value)
elif key == LANGSMITH_REPLICAS:
replicas_data = json.loads(urllib.parse.unquote(value))
parsed_replicas: list[WriteReplica] = []
for replica_item in replicas_data:
if (
isinstance(replica_item, (tuple, list))
and len(replica_item) == 2
):
# Convert legacy format to WriteReplica
parsed_replicas.append(
WriteReplica(
api_url=None,
api_key=None,
project_name=str(replica_item[0]),
updates=replica_item[1],
)
)
elif isinstance(replica_item, dict):
# New WriteReplica format: preserve as dict
parsed_replicas.append(cast(WriteReplica, replica_item))
else:
logger.warning(
f"Unknown replica format in baggage: {replica_item}"
)
continue
replicas = parsed_replicas
except Exception as e:
logger.warning(f"Error parsing baggage header: {e}")
return cls(
metadata=metadata, tags=tags, project_name=project_name, replicas=replicas
)
@classmethod
def from_headers(cls, headers: Mapping[Union[str, bytes], Any]) -> _Baggage:
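"""Create a Baggage object from a headers mapping with str or bytes keys."""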
if "baggage" in headers:
return cls.from_header(headers["baggage"])
elif b"baggage" in headers:
return cls.from_header(cast(bytes, headers[b"baggage"]).decode("utf-8"))
else:
return cls.from_header(None)
def to_header(self) -> str:
"""Return the Baggage object as a header value."""
items = []
if self.metadata:
serialized_metadata = _dumps_json(self.metadata)
items.append(
f"{LANGSMITH_PREFIX}metadata={urllib.parse.quote(serialized_metadata)}"
)
if self.tags:
serialized_tags = ",".join(self.tags)
items.append(
f"{LANGSMITH_PREFIX}tags={urllib.parse.quote(serialized_tags)}"
)
if self.project_name:
items.append(
f"{LANGSMITH_PREFIX}project={urllib.parse.quote(self.project_name)}"
)
if self.replicas:
serialized_replicas = _dumps_json(self.replicas)
items.append(
f"{LANGSMITH_PREFIX}replicas={urllib.parse.quote(serialized_replicas)}"
)
return ",".join(items)
@functools.lru_cache(maxsize=1)
def _parse_write_replicas_from_env_var(env_var: Optional[str]) -> list[WriteReplica]:
"""Parse write replicas from LANGSMITH_RUNS_ENDPOINTS environment variable value.
Supports array format [{"api_url": "x", "api_key": "y"}] and object format
{"url": "key"}.
"""
if not env_var:
return []
try:
parsed = json.loads(env_var)
if isinstance(parsed, list):
replicas = []
for item in parsed:
if not isinstance(item, dict):
logger.warning(
f"Invalid item type in LANGSMITH_RUNS_ENDPOINTS: "
f"expected dict, got {type(item).__name__}"
)
continue
api_url = item.get("api_url")
api_key = item.get("api_key")
if not isinstance(api_url, str):
logger.warning(
f"Invalid api_url type in LANGSMITH_RUNS_ENDPOINTS: "
f"expected string, got {type(api_url).__name__}"
)
continue
if not isinstance(api_key, str):
logger.warning(
f"Invalid api_key type in LANGSMITH_RUNS_ENDPOINTS: "
f"expected string, got {type(api_key).__name__}"
)
continue
replicas.append(
WriteReplica(
api_url=api_url.rstrip("/"),
api_key=api_key,
project_name=None,
updates=None,
)
)
return replicas
elif isinstance(parsed, dict):
_check_endpoint_env_unset(parsed)
replicas = []
for url, key in parsed.items():
url = url.rstrip("/")
if isinstance(key, str):
replicas.append(
WriteReplica(
api_url=url,
api_key=key,
project_name=None,
updates=None,
)
)
else:
logger.warning(
f"Invalid value type in LANGSMITH_RUNS_ENDPOINTS for URL "
f"{url}: "
f"expected string, got {type(key).__name__}"
)
continue
return replicas
else:
logger.warning(
"Invalid LANGSMITH_RUNS_ENDPOINTS: must be a valid JSON list of "
"objects with api_url and api_key properties, or an object mapping "
f"url -> api key; got {type(parsed).__name__}"
)
return []
except utils.LangSmithUserError:
raise
except Exception as e:
logger.warning(
"Invalid LANGSMITH_RUNS_ENDPOINTS: must be a valid JSON list of "
"objects with api_url and api_key properties, or an object mapping "
f"url -> api key: {e}"
)
return []
def _get_write_replicas_from_env() -> list[WriteReplica]:
"""Get write replicas from LANGSMITH_RUNS_ENDPOINTS environment variable."""
env_var = utils.get_env_var("RUNS_ENDPOINTS")
return _parse_write_replicas_from_env_var(env_var)
def _check_endpoint_env_unset(parsed: dict[str, str]) -> None:
"""Check if endpoint environment variables conflict with runs endpoints."""
import os
if parsed and (os.getenv("LANGSMITH_ENDPOINT") or os.getenv("LANGCHAIN_ENDPOINT")):
raise utils.LangSmithUserError(
"You cannot provide both LANGSMITH_ENDPOINT / LANGCHAIN_ENDPOINT "
"and LANGSMITH_RUNS_ENDPOINTS."
)
def _ensure_write_replicas(
replicas: Optional[Sequence[WriteReplica]],
) -> list[WriteReplica]:
"""Convert replicas to WriteReplica format."""
if replicas is None:
return _get_write_replicas_from_env()
# All replicas should now be WriteReplica dicts
return list(replicas)
def _parse_dotted_order(dotted_order: str) -> list[tuple[datetime, UUID]]:
"""Parse the dotted order string."""
parts = dotted_order.split(".")
return [
(
datetime.strptime(part[:-TIMESTAMP_LENGTH], "%Y%m%dT%H%M%S%fZ"),
UUID(part[-TIMESTAMP_LENGTH:]),
)
for part in parts
]
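# Illustrative segment layout (hypothetical values): each dotted-order segment
# is a "%Y%m%dT%H%M%S%fZ" timestamp immediately followed by a 36-character
# UUID, and segments are joined with ".":
#
#   20240101T000000000000Z11111111-1111-7111-8111-111111111111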
_CLIENT: Optional[Client] = _context._GLOBAL_CLIENT
__all__ = ["RunTree"]
def _create_current_dotted_order(
start_time: Optional[datetime], run_id: Optional[UUID]
) -> str:
"""Create the current dotted order."""
st = start_time or datetime.now(timezone.utc)
id_ = run_id or uuid7_from_datetime(st)
return st.strftime("%Y%m%dT%H%M%S%fZ") + str(id_)