from __future__ import annotations

import itertools
import logging
import os
import uuid
from collections.abc import Iterable
from io import BufferedReader
from typing import Literal, Optional, Union, cast

from langsmith import schemas as ls_schemas
from langsmith._internal import _orjson
from langsmith._internal._compressed_traces import CompressedTraces
from langsmith._internal._multipart import MultipartPart, MultipartPartsAndContext
from langsmith._internal._serde import dumps_json as _dumps_json

logger = logging.getLogger(__name__)


class SerializedRunOperation:
    """A run create ("post") or update ("patch") operation with pre-serialized fields."""

    operation: Literal["post", "patch"]
    id: uuid.UUID
    trace_id: uuid.UUID

    # this is the whole object, minus the other fields which
    # are popped (inputs/outputs/events/attachments)
    _none: bytes

    inputs: Optional[bytes]
    outputs: Optional[bytes]
    events: Optional[bytes]
    extra: Optional[bytes]
    error: Optional[bytes]
    serialized: Optional[bytes]
    attachments: Optional[ls_schemas.Attachments]

    __slots__ = (
        "operation",
        "id",
        "trace_id",
        "_none",
        "inputs",
        "outputs",
        "events",
        "extra",
        "error",
        "serialized",
        "attachments",
    )

    def __init__(
        self,
        operation: Literal["post", "patch"],
        id: uuid.UUID,
        trace_id: uuid.UUID,
        _none: bytes,
        inputs: Optional[bytes] = None,
        outputs: Optional[bytes] = None,
        events: Optional[bytes] = None,
        extra: Optional[bytes] = None,
        error: Optional[bytes] = None,
        serialized: Optional[bytes] = None,
        attachments: Optional[ls_schemas.Attachments] = None,
    ) -> None:
        self.operation = operation
        self.id = id
        self.trace_id = trace_id
        self._none = _none
        self.inputs = inputs
        self.outputs = outputs
        self.events = events
        self.extra = extra
        self.error = error
        self.serialized = serialized
        self.attachments = attachments

    def calculate_serialized_size(self) -> int:
        """Calculate actual serialized size of this operation."""
        size = 0
        if self._none:
            size += len(self._none)
        if self.inputs:
            size += len(self.inputs)
        if self.outputs:
            size += len(self.outputs)
        if self.events:
            size += len(self.events)
        if self.extra:
            size += len(self.extra)
        if self.error:
            size += len(self.error)
        if self.serialized:
            size += len(self.serialized)
        if self.attachments:
            for content_type, data_or_path in self.attachments.values():
                if isinstance(data_or_path, bytes):
                    size += len(data_or_path)
        return size

    def deserialize_run_info(self) -> dict:
        """Deserialize the main run info (_none plus extra, error, and serialized)."""
        run_info = _orjson.loads(self._none)
        if self.extra is not None:
            run_info["extra"] = _orjson.loads(self.extra)

        if self.error is not None:
            run_info["error"] = _orjson.loads(self.error)

        if self.serialized is not None:
            run_info["serialized"] = _orjson.loads(self.serialized)

        return run_info

    def __eq__(self, other: object) -> bool:
        return isinstance(other, SerializedRunOperation) and (
            self.operation,
            self.id,
            self.trace_id,
            self._none,
            self.inputs,
            self.outputs,
            self.events,
            self.extra,
            self.error,
            self.serialized,
            self.attachments,
        ) == (
            other.operation,
            other.id,
            other.trace_id,
            other._none,
            other.inputs,
            other.outputs,
            other.events,
            other.extra,
            other.error,
            other.serialized,
            other.attachments,
        )
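

# Example (illustrative sketch, not part of the module): round-tripping the
# run-info payload on a hand-built operation. The field values below are
# hypothetical.
#
#   run_id = uuid.uuid4()
#   op = SerializedRunOperation(
#       operation="post",
#       id=run_id,
#       trace_id=run_id,
#       _none=_dumps_json({"id": str(run_id), "name": "my_run"}),
#       extra=_dumps_json({"metadata": {"env": "dev"}}),
#   )
#   op.deserialize_run_info()  # -> {"id": ..., "name": "my_run", "extra": {...}}
#   op.calculate_serialized_size()  # total bytes across the populated fields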


class SerializedFeedbackOperation:
    """A feedback create operation with a pre-serialized payload."""

    id: uuid.UUID
    trace_id: uuid.UUID
    feedback: bytes

    __slots__ = ("id", "trace_id", "feedback")

    def __init__(self, id: uuid.UUID, trace_id: uuid.UUID, feedback: bytes) -> None:
        self.id = id
        self.trace_id = trace_id
        self.feedback = feedback

    def calculate_serialized_size(self) -> int:
        """Calculate actual serialized size of this operation."""
        return len(self.feedback)

    def __eq__(self, other: object) -> bool:
        return isinstance(other, SerializedFeedbackOperation) and (
            self.id,
            self.trace_id,
            self.feedback,
        ) == (other.id, other.trace_id, other.feedback)


def serialize_feedback_dict(
    feedback: Union[ls_schemas.FeedbackCreate, dict],
) -> SerializedFeedbackOperation:
    """Serialize a feedback payload, filling in missing id/trace_id values."""
    if hasattr(feedback, "model_dump") and callable(getattr(feedback, "model_dump")):
        feedback_create: dict = feedback.model_dump()  # type: ignore
    else:
        feedback_create = cast(dict, feedback)
    if "id" not in feedback_create:
        feedback_create["id"] = uuid.uuid4()
    elif isinstance(feedback_create["id"], str):
        feedback_create["id"] = uuid.UUID(feedback_create["id"])
    if "trace_id" not in feedback_create:
        feedback_create["trace_id"] = uuid.uuid4()
    elif isinstance(feedback_create["trace_id"], str):
        feedback_create["trace_id"] = uuid.UUID(feedback_create["trace_id"])

    return SerializedFeedbackOperation(
        id=feedback_create["id"],
        trace_id=feedback_create["trace_id"],
        feedback=_dumps_json(feedback_create),
    )
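

# Example (illustrative sketch, not part of the module): serializing a plain
# feedback dict. The field names and values below are hypothetical.
#
#   fb_op = serialize_feedback_dict(
#       {
#           "run_id": str(uuid.uuid4()),
#           "key": "correctness",
#           "score": 1,
#       }
#   )
#   # "id" and "trace_id" were missing, so fresh UUIDs were generated.
#   isinstance(fb_op.feedback, bytes)  # True; payload is already JSON-encoded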


def serialize_run_dict(
    operation: Literal["post", "patch"], payload: dict
) -> SerializedRunOperation:
    """Split a run payload into separately serialized byte fields."""
    inputs = payload.pop("inputs", None)
    outputs = payload.pop("outputs", None)
    events = payload.pop("events", None)
    extra = payload.pop("extra", None)
    error = payload.pop("error", None)
    serialized = payload.pop("serialized", None)
    attachments = payload.pop("attachments", None)
    return SerializedRunOperation(
        operation=operation,
        id=payload["id"],
        trace_id=payload["trace_id"],
        _none=_dumps_json(payload),
        inputs=_dumps_json(inputs) if inputs is not None else None,
        outputs=_dumps_json(outputs) if outputs is not None else None,
        events=_dumps_json(events) if events is not None else None,
        extra=_dumps_json(extra) if extra is not None else None,
        error=_dumps_json(error) if error is not None else None,
        serialized=_dumps_json(serialized) if serialized is not None else None,
        attachments=attachments if attachments is not None else None,
    )
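

# Example (illustrative sketch, not part of the module): serializing a minimal
# "post" payload. The field values below are hypothetical.
#
#   run_id = uuid.uuid4()
#   op = serialize_run_dict(
#       "post",
#       {
#           "id": run_id,
#           "trace_id": run_id,
#           "name": "my_chain",
#           "inputs": {"question": "hi"},
#       },
#   )
#   op.operation  # "post"
#   op.inputs     # JSON bytes for {"question": "hi"}
#   op.outputs    # None (not provided)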


def combine_serialized_queue_operations(
    ops: list[Union[SerializedRunOperation, SerializedFeedbackOperation]],
) -> list[Union[SerializedRunOperation, SerializedFeedbackOperation]]:
    """Merge queued "patch" run operations into their matching "post" operations."""
    create_ops_by_id = {
        op.id: op
        for op in ops
        if isinstance(op, SerializedRunOperation) and op.operation == "post"
    }
    passthrough_ops: list[
        Union[SerializedRunOperation, SerializedFeedbackOperation]
    ] = []
    for op in ops:
        if isinstance(op, SerializedRunOperation):
            if op.operation == "post":
                continue

            # must be patch
            create_op = create_ops_by_id.get(op.id)
            if create_op is None:
                passthrough_ops.append(op)
                continue

            if op._none is not None and op._none != create_op._none:
                # TODO optimize this more - this would currently be slowest
                # for large payloads
                create_op_dict = _orjson.loads(create_op._none)
                op_dict = {
                    k: v for k, v in _orjson.loads(op._none).items() if v is not None
                }
                create_op_dict.update(op_dict)
                create_op._none = _orjson.dumps(create_op_dict)

            if op.inputs is not None:
                create_op.inputs = op.inputs
            if op.outputs is not None:
                create_op.outputs = op.outputs
            if op.events is not None:
                create_op.events = op.events
            if op.extra is not None:
                create_op.extra = op.extra
            if op.error is not None:
                create_op.error = op.error
            if op.serialized is not None:
                create_op.serialized = op.serialized
            if op.attachments is not None:
                if create_op.attachments is None:
                    create_op.attachments = {}
                create_op.attachments.update(op.attachments)
        else:
            passthrough_ops.append(op)
    return list(itertools.chain(create_ops_by_id.values(), passthrough_ops))
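

# Example (illustrative sketch, not part of the module): a "patch" for the same
# run id is folded into its "post", so only one combined operation remains.
# The payloads below are hypothetical.
#
#   run_id = uuid.uuid4()
#   post = serialize_run_dict("post", {"id": run_id, "trace_id": run_id, "name": "r"})
#   patch = serialize_run_dict(
#       "patch", {"id": run_id, "trace_id": run_id, "outputs": {"answer": "42"}}
#   )
#   combined = combine_serialized_queue_operations([post, patch])
#   len(combined)          # 1
#   combined[0].operation  # "post"
#   combined[0].outputs    # taken from the patch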


def serialized_feedback_operation_to_multipart_parts_and_context(
    op: SerializedFeedbackOperation,
) -> MultipartPartsAndContext:
    """Convert a serialized feedback operation into multipart form-data parts."""
    return MultipartPartsAndContext(
        [
            (
                f"feedback.{op.id}",
                (
                    None,
                    op.feedback,
                    "application/json",
                    {"Content-Length": str(len(op.feedback))},
                ),
            )
        ],
        f"trace={op.trace_id},id={op.id}",
    )


def serialized_run_operation_to_multipart_parts_and_context(
    op: SerializedRunOperation,
) -> tuple[MultipartPartsAndContext, dict[str, BufferedReader]]:
    """Convert a serialized run operation into multipart parts.

    Returns the parts plus a dict of attachment files opened here, so the
    caller can close them once the parts have been consumed.
    """
    acc_parts: list[MultipartPart] = []
    opened_files_dict: dict[str, BufferedReader] = {}
    # this is the main object, minus inputs/outputs/events/attachments
    acc_parts.append(
        (
            f"{op.operation}.{op.id}",
            (
                None,
                op._none,
                "application/json",
                {"Content-Length": str(len(op._none))},
            ),
        )
    )
    for key, value in (
        ("inputs", op.inputs),
        ("outputs", op.outputs),
        ("events", op.events),
        ("extra", op.extra),
        ("error", op.error),
        ("serialized", op.serialized),
    ):
        if value is None:
            continue
        valb = value
        acc_parts.append(
            (
                f"{op.operation}.{op.id}.{key}",
                (
                    None,
                    valb,
                    "application/json",
                    {"Content-Length": str(len(valb))},
                ),
            ),
        )
    if op.attachments:
        for n, (content_type, data_or_path) in op.attachments.items():
            if "." in n:
                logger.warning(
                    f"Skipping logging of attachment '{n}' "
                    f"for run {op.id}:"
                    " Invalid attachment name. Attachment names must not contain"
                    " periods ('.'). Please rename the attachment and try again."
                )
                continue

            if isinstance(data_or_path, bytes):
                acc_parts.append(
                    (
                        f"attachment.{op.id}.{n}",
                        (
                            None,
                            data_or_path,
                            content_type,
                            {"Content-Length": str(len(data_or_path))},
                        ),
                    )
                )
            else:
                try:
                    file_size = os.path.getsize(data_or_path)
                    file = open(data_or_path, "rb")
                except FileNotFoundError:
                    logger.warning(
                        "Attachment file not found for run %s: %s", op.id, data_or_path
                    )
                    continue
                opened_files_dict[str(data_or_path) + str(uuid.uuid4())] = file
                acc_parts.append(
                    (
                        f"attachment.{op.id}.{n}",
                        (
                            None,
                            file,
                            f"{content_type}; length={file_size}",
                            {},
                        ),
                    )
                )
    return (
        MultipartPartsAndContext(acc_parts, f"trace={op.trace_id},id={op.id}"),
        opened_files_dict,
    )
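

# Example (illustrative sketch, not part of the module): the part names produced
# for a "post" operation like the one built with serialize_run_dict above
# (values hypothetical).
#
#   parts_and_context, opened_files = (
#       serialized_run_operation_to_multipart_parts_and_context(op)
#   )
#   [name for name, _ in parts_and_context.parts]
#   # ["post.<run id>", "post.<run id>.inputs", ...]
#   parts_and_context.context  # "trace=<trace id>,id=<run id>"
#   opened_files               # {} unless path-based attachments were opened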


def encode_multipart_parts_and_context(
    parts_and_context: MultipartPartsAndContext,
    boundary: str,
) -> Iterable[tuple[bytes, Union[bytes, BufferedReader]]]:
    """Yield (encoded part headers, part body) pairs for a multipart request."""
    for part_name, (filename, data, content_type, headers) in parts_and_context.parts:
        header_parts = [
            f"--{boundary}\r\n",
            f'Content-Disposition: form-data; name="{part_name}"',
        ]

        if filename:
            header_parts.append(f'; filename="{filename}"')

        header_parts.extend(
            [
                f"\r\nContent-Type: {content_type}\r\n",
                *[f"{k}: {v}\r\n" for k, v in headers.items()],
                "\r\n",
            ]
        )

        yield ("".join(header_parts).encode(), data)
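

# Example (illustrative sketch, not part of the module): the encoded header block
# for a single JSON part, given a hypothetical boundary string.
#
#   headers, body = next(
#       iter(encode_multipart_parts_and_context(parts_and_context, "BOUNDARY"))
#   )
#   headers.decode()  # one string; shown here split at each "\r\n":
#   #   --BOUNDARY
#   #   Content-Disposition: form-data; name="post.<run id>"
#   #   Content-Type: application/json
#   #   Content-Length: <n>
#   #   (blank line terminating the headers)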


def compress_multipart_parts_and_context(
    parts_and_context: MultipartPartsAndContext,
    compressed_traces: CompressedTraces,
    boundary: str,
) -> bool:
    """Compress multipart parts into the shared compressed buffer.

    Returns True if the parts were enqueued into the compressed buffer, or False
    if they were rejected because the configured in-memory size limit would be
    exceeded.
    """
    write = compressed_traces.compressor_writer.write

    parts: list[tuple[bytes, bytes]] = []
    op_uncompressed_size = 0

    for headers, data in encode_multipart_parts_and_context(
        parts_and_context, boundary
    ):
        # Normalise to bytes
        if not isinstance(data, (bytes, bytearray)):
            data = (
                data.read() if isinstance(data, BufferedReader) else str(data).encode()
            )

        parts.append((headers, data))
        op_uncompressed_size += len(data)

    max_bytes = getattr(compressed_traces, "max_uncompressed_size_bytes", None)
    if max_bytes is not None and max_bytes > 0:
        current_size = compressed_traces.uncompressed_size
        if current_size > 0 and current_size + op_uncompressed_size > max_bytes:
            logger.warning(
                "Compressed traces queue size limit (%s bytes) exceeded. "
                "Dropping trace data with context: %s. "
                "Current queue size: %s bytes, attempted addition: %s bytes.",
                max_bytes,
                parts_and_context.context,
                current_size,
                op_uncompressed_size,
            )
            return False

    for headers, data in parts:
        write(headers)
        compressed_traces.uncompressed_size += len(data)
        write(data)
        write(b"\r\n")  # part terminator

    compressed_traces._context.append(parts_and_context.context)
    return True