group-wbl/.venv/lib/python3.13/site-packages/chromadb/execution/executor/distributed.py
2026-01-09 09:12:25 +08:00

243 lines
8.8 KiB
Python

import threading
import random
from typing import Callable, Dict, List, Optional, TypeVar
import grpc
from overrides import overrides
from chromadb.api.types import GetResult, Metadata, QueryResult
from chromadb.config import System
from chromadb.execution.executor.abstract import Executor
from chromadb.execution.expression.operator import Scan
from chromadb.execution.expression.plan import CountPlan, GetPlan, KNNPlan
from chromadb.proto import convert
from chromadb.proto.query_executor_pb2_grpc import QueryExecutorStub
from chromadb.segment.impl.manager.distributed import DistributedSegmentManager
from chromadb.telemetry.opentelemetry.grpc import OtelInterceptor
from tenacity import (
RetryCallState,
Retrying,
stop_after_attempt,
wait_exponential_jitter,
retry_if_exception,
)
from opentelemetry.trace import Span
def _clean_metadata(metadata: Optional[Metadata]) -> Optional[Metadata]:
"""Remove any chroma-specific metadata keys that the client shouldn't see from a metadata map."""
if not metadata:
return None
result = {}
for k, v in metadata.items():
if not k.startswith("chroma:"):
result[k] = v
if len(result) == 0:
return None
return result
def _uri(metadata: Optional[Metadata]) -> Optional[str]:
"""Retrieve the uri (if any) from a Metadata map"""
if metadata and "chroma:uri" in metadata:
return str(metadata["chroma:uri"])
return None
# Type variables for the round-robin retry function:
# ``I`` is the type of the single argument passed to each candidate function,
# ``O`` is the type those functions return.
# The one-letter names are intentional, hence the E741 suppressions.
I = TypeVar("I") # noqa: E741
O = TypeVar("O") # noqa: E741
class DistributedExecutor(Executor):
    """Executor that runs count/get/knn plans on remote query executors over gRPC.

    Endpoints for a plan are obtained from the ``DistributedSegmentManager``;
    one ``QueryExecutorStub`` per endpoint URL is created lazily and cached.
    Each request is sent through ``_round_robin_retry`` so that a failing
    endpoint can be skipped in favour of the next one.
    """

    # Guards creation of / access to ``_grpc_stub_pool``.
    _mtx: threading.Lock
    # Cache of gRPC stubs keyed by endpoint URL.
    _grpc_stub_pool: Dict[str, QueryExecutorStub]
    _manager: DistributedSegmentManager
    _request_timeout_seconds: int
    _query_replication_factor: int

    def __init__(self, system: System):
        """Resolve the segment manager and read the executor settings from ``system``."""
        super().__init__(system)
        self._mtx = threading.Lock()
        self._grpc_stub_pool = {}
        self._manager = self.require(DistributedSegmentManager)
        # NOTE(review): this value is read and stored, but no method of this
        # class passes it to an RPC call -- confirm whether requests are
        # supposed to carry a deadline.
        self._request_timeout_seconds = system.settings.require(
            "chroma_query_request_timeout_seconds"
        )
        self._query_replication_factor = system.settings.require(
            "chroma_query_replication_factor"
        )

    def _round_robin_retry(self, funcs: List[Callable[[I], O]], args: I) -> O:
        """
        Retry a list of functions in a round-robin fashion until one of them succeeds.

        At most 5 attempts are made. A failed attempt is retried only when it
        raised a ``grpc.RpcError`` whose status code is ``UNAVAILABLE`` or
        ``UNKNOWN``; any other exception, or the exception of the last
        attempt, is re-raised to the caller. Attempt ``n`` (0-based) calls
        ``funcs[n % len(funcs)]``.

        funcs: List of functions to retry (must not be empty)
        args: Arguments to pass to each function

        Raises ``ValueError`` if ``funcs`` is empty.
        """
        if not funcs:
            # Without this guard the modulo below fails with an opaque
            # ZeroDivisionError.
            raise ValueError(
                "_round_robin_retry requires at least one function to call"
            )

        attempt_count = 0
        sleep_span: Optional[Span] = None

        def before_sleep(_: RetryCallState) -> None:
            # HACK(hammadb) 1/14/2024 - this is a hack to avoid the fact that tracer is not yet available and there are boot order issues
            # This should really use our component system to get the tracer. Since our grpc utils use this pattern
            # we are copying it here. This should be removed once we have a better way to get the tracer
            from chromadb.telemetry.opentelemetry import tracer

            nonlocal sleep_span
            if tracer is not None:
                sleep_span = tracer.start_span("Waiting to retry RPC")

        for attempt in Retrying(
            stop=stop_after_attempt(5),
            wait=wait_exponential_jitter(0.1, jitter=0.1),
            reraise=True,
            retry=retry_if_exception(
                lambda x: isinstance(x, grpc.RpcError)
                and x.code() in [grpc.StatusCode.UNAVAILABLE, grpc.StatusCode.UNKNOWN]
            ),
            before_sleep=before_sleep,
        ):
            # Close the span opened by ``before_sleep`` ahead of this attempt.
            if sleep_span is not None:
                sleep_span.end()
                sleep_span = None

            with attempt:
                return funcs[attempt_count % len(funcs)](args)
            # Only reached when the attempt above raised: move on to the next
            # function for the following attempt.
            attempt_count += 1

        # NOTE(hammadb) because Retrying() will always either return or raise an exception, this line should never be reached
        raise Exception("Unreachable code error - should never reach here")

    @overrides
    def count(self, plan: CountPlan) -> int:
        """Execute ``plan`` on the collection's query endpoints and return the count."""
        endpoints = self._get_grpc_endpoints(plan.scan)
        count_funcs = [self._get_stub(endpoint).Count for endpoint in endpoints]
        count_result = self._round_robin_retry(
            count_funcs, convert.to_proto_count_plan(plan)
        )
        return convert.from_proto_count_result(count_result)

    @overrides
    def get(self, plan: GetPlan) -> GetResult:
        """Execute ``plan`` and assemble the returned records into a ``GetResult``.

        ``ids`` is always populated. ``embeddings``, ``documents``, ``uris``
        and ``metadatas`` are lists only when the matching flag on
        ``plan.projection`` is set and ``None`` otherwise; ``data`` is always
        ``None``. URIs are taken from the record metadata via ``_uri`` and
        metadata is passed through ``_clean_metadata``.
        """
        endpoints = self._get_grpc_endpoints(plan.scan)
        get_funcs = [self._get_stub(endpoint).Get for endpoint in endpoints]
        get_result = self._round_robin_retry(get_funcs, convert.to_proto_get_plan(plan))
        records = convert.from_proto_get_result(get_result)

        ids = [record["id"] for record in records]
        embeddings = (
            [record["embedding"] for record in records]
            if plan.projection.embedding
            else None
        )
        documents = (
            [record["document"] for record in records]
            if plan.projection.document
            else None
        )
        uris = (
            [_uri(record["metadata"]) for record in records]
            if plan.projection.uri
            else None
        )
        metadatas = (
            [_clean_metadata(record["metadata"]) for record in records]
            if plan.projection.metadata
            else None
        )

        # TODO: Fix typing
        return GetResult(
            ids=ids,
            embeddings=embeddings,  # type: ignore[typeddict-item]
            documents=documents,  # type: ignore[typeddict-item]
            uris=uris,  # type: ignore[typeddict-item]
            data=None,
            metadatas=metadatas,  # type: ignore[typeddict-item]
            included=plan.projection.included,
        )

    @overrides
    def knn(self, plan: KNNPlan) -> QueryResult:
        """Execute ``plan`` and assemble the batched results into a ``QueryResult``.

        Every field is a list of lists: the outer list follows the batch of
        results returned by the server, the inner list its records. ``ids``
        is always populated; ``embeddings``, ``documents``, ``uris``,
        ``metadatas`` and ``distances`` are populated only when the matching
        flag on ``plan.projection`` is set (``rank`` for ``distances``) and
        are ``None`` otherwise; ``data`` is always ``None``.
        """
        endpoints = self._get_grpc_endpoints(plan.scan)
        knn_funcs = [self._get_stub(endpoint).KNN for endpoint in endpoints]
        knn_result = self._round_robin_retry(knn_funcs, convert.to_proto_knn_plan(plan))
        results = convert.from_proto_knn_batch_result(knn_result)

        ids = [[record["record"]["id"] for record in records] for records in results]
        embeddings = (
            [
                [record["record"]["embedding"] for record in records]
                for records in results
            ]
            if plan.projection.embedding
            else None
        )
        documents = (
            [
                [record["record"]["document"] for record in records]
                for records in results
            ]
            if plan.projection.document
            else None
        )
        uris = (
            [
                [_uri(record["record"]["metadata"]) for record in records]
                for records in results
            ]
            if plan.projection.uri
            else None
        )
        metadatas = (
            [
                [_clean_metadata(record["record"]["metadata"]) for record in records]
                for records in results
            ]
            if plan.projection.metadata
            else None
        )
        distances = (
            [[record["distance"] for record in records] for records in results]
            if plan.projection.rank
            else None
        )

        # TODO: Fix typing
        return QueryResult(
            ids=ids,
            embeddings=embeddings,  # type: ignore[typeddict-item]
            documents=documents,  # type: ignore[typeddict-item]
            uris=uris,  # type: ignore[typeddict-item]
            data=None,
            metadatas=metadatas,  # type: ignore[typeddict-item]
            distances=distances,  # type: ignore[typeddict-item]
            included=plan.projection.included,
        )

    def _get_grpc_endpoints(self, scan: Scan) -> List[str]:
        """Return the endpoint URLs for ``scan`` in a freshly shuffled order."""
        # Since the grpc endpoint is determined by collection uuid,
        # the endpoint should be the same for all segments of the same collection
        # Copy before shuffling: random.shuffle works in place, and the list
        # handed out by the manager must not be reordered behind its back.
        grpc_urls = list(
            self._manager.get_endpoints(scan.record, self._query_replication_factor)
        )
        # Shuffle the grpc urls to distribute the load evenly
        random.shuffle(grpc_urls)
        return grpc_urls

    def _get_stub(self, grpc_url: str) -> QueryExecutorStub:
        """Return the cached stub for ``grpc_url``, creating it on first use.

        The pool lookup and the creation both happen while holding ``_mtx``.
        """
        with self._mtx:
            if grpc_url not in self._grpc_stub_pool:
                channel = grpc.insecure_channel(
                    grpc_url,
                    options=[
                        ("grpc.max_concurrent_streams", 1000),
                        ("grpc.max_receive_message_length", 32000000),  # 32 MB
                    ],
                )
                # Wrap the channel so every call goes through the OpenTelemetry interceptor.
                interceptors = [OtelInterceptor()]
                channel = grpc.intercept_channel(channel, *interceptors)
                self._grpc_stub_pool[grpc_url] = QueryExecutorStub(channel)
            return self._grpc_stub_pool[grpc_url]