group-wbl/.venv/lib/python3.13/site-packages/chromadb/api/base_http_client.py

146 lines
5.1 KiB
Python
Raw Normal View History

2026-01-09 09:48:03 +08:00
from typing import Any, Dict, Mapping, Optional, TypeVar
from urllib.parse import quote, urlparse, urlunparse
import logging
import orjson as json
import httpx
import chromadb.errors as errors
from chromadb.config import Component, Settings, System
# Module logger; resolve_url uses it to note when the port is skipped for full URLs.
logger = logging.getLogger(__name__)
# Inherits from Component so that it can create an init function to use System;
# this way it can build connection limits from the settings in System.
class BaseHTTPClient(Component):
    """Shared HTTP plumbing for Chroma HTTP clients.

    Builds ``httpx.Limits`` from the system settings, normalises the server
    URL, and converts non-2xx responses into Chroma exceptions.
    """

    _settings: Settings
    pre_flight_checks: Any = None
    # Keep-alive expiry used when ``chroma_http_keepalive_secs`` is not set.
    DEFAULT_KEEPALIVE_SECS: float = 40.0

    def __init__(self, system: System):
        super().__init__(system)
        self._settings = system.settings
        keepalive_setting = self._settings.chroma_http_keepalive_secs
        # Fall back to the class default only when the setting is absent (None).
        self.keepalive_secs: Optional[float] = (
            keepalive_setting
            if keepalive_setting is not None
            else BaseHTTPClient.DEFAULT_KEEPALIVE_SECS
        )
        self._http_limits = self._build_limits()

    def _build_limits(self) -> httpx.Limits:
        """Build connection-pool limits from the client settings.

        Only values that are not None are passed to ``httpx.Limits``, so any
        limit left unset keeps httpx's own default.
        """
        limit_kwargs: Dict[str, Any] = {}
        if self.keepalive_secs is not None:
            limit_kwargs["keepalive_expiry"] = self.keepalive_secs
        max_connections = self._settings.chroma_http_max_connections
        if max_connections is not None:
            limit_kwargs["max_connections"] = max_connections
        max_keepalive_connections = self._settings.chroma_http_max_keepalive_connections
        if max_keepalive_connections is not None:
            limit_kwargs["max_keepalive_connections"] = max_keepalive_connections
        return httpx.Limits(**limit_kwargs)

    @property
    def http_limits(self) -> httpx.Limits:
        """Connection-pool limits computed once in ``__init__``."""
        return self._http_limits

    @staticmethod
    def _validate_host(host: str) -> None:
        """Reject host strings that look like URLs but are not usable http(s) URLs.

        A host without any "/" (e.g. ``"localhost"``) is accepted as-is.

        Raises:
            ValueError: If the host contains "/" but has no ``scheme://`` prefix,
                uses a scheme other than http/https, or does not literally
                start with "http".
        """
        if "/" not in host:
            return

        missing_protocol = (
            "Invalid URL. "
            "Seems that you are trying to pass URL as a host but without "
            "specifying the protocol. "
            "Please add http:// or https:// to the host."
        )
        # Check for a missing "scheme://" first: urlparse reports an empty (or
        # bogus, e.g. "localhost" for "localhost:8000/api") scheme in that case,
        # which would otherwise surface as an unhelpful "Unrecognized protocol".
        if "://" not in host:
            raise ValueError(missing_protocol)
        parsed = urlparse(host)
        if parsed.scheme not in {"http", "https"}:
            raise ValueError(
                "Invalid URL. " f"Unrecognized protocol - {parsed.scheme}."
            )
        # urlparse lower-cases the scheme; resolve_url detects full URLs with a
        # case-sensitive startswith("http"), so keep rejecting e.g. "HTTP://".
        if not host.startswith("http"):
            raise ValueError(missing_protocol)

    @staticmethod
    def resolve_url(
        chroma_server_host: str,
        chroma_server_ssl_enabled: Optional[bool] = False,
        default_api_path: Optional[str] = "",
        chroma_server_http_port: Optional[int] = 8000,
    ) -> str:
        """Build the base URL of the Chroma API.

        Args:
            chroma_server_host: A bare host name (e.g. ``"localhost"``) or a
                full ``http(s)://`` URL. A full URL keeps whatever port it
                carries and no port is appended.
            chroma_server_ssl_enabled: When true, the ``https`` scheme is used
                regardless of the scheme in the host.
            default_api_path: Appended to the path unless the path already
                ends with it. None is treated as "".
            chroma_server_http_port: Port appended to a bare host name. If it
                is None, no port is appended.

        Returns:
            The assembled URL, with the path percent-quoted and runs of "/"
            collapsed.

        Raises:
            ValueError: If ``_validate_host`` rejects the host, or if urlparse
                reports an invalid port.
        """
        BaseHTTPClient._validate_host(chroma_server_host)

        is_full_url = chroma_server_host.startswith("http")
        if is_full_url:
            logger.debug("Skipping port as the user is passing a full URL")

        parsed = urlparse(chroma_server_host)
        scheme = "https" if chroma_server_ssl_enabled else (parsed.scheme or "http")
        # A bare host name parses into .path (not .netloc), hence the fallback.
        net_loc = parsed.netloc or parsed.hostname or chroma_server_host

        port = ""
        if not is_full_url:
            port_number = parsed.port or chroma_server_http_port
            # Omit the port rather than emitting an invalid ":None" suffix.
            if port_number is not None:
                port = f":{port_number}"

        api_path = default_api_path or ""
        path = parsed.path or api_path
        if path == net_loc:
            # The bare host name itself landed in .path; it is not a URL path.
            path = api_path
        if not path.endswith(api_path):
            path += api_path
        # Collapse every run of slashes (e.g. "/" + "/api/v2") into one.
        while "//" in path:
            path = path.replace("//", "/")

        return urlunparse((scheme, f"{net_loc}{port}", quote(path), "", "", ""))

    # requests removes None values from the built query string, but httpx
    # includes them as empty values, so callers strip them with _clean_params.
    T = TypeVar("T", bound=Dict[Any, Any])

    @staticmethod
    def _clean_params(params: T) -> T:
        """Return a copy of ``params`` without the entries whose value is None."""
        return {k: v for k, v in params.items() if v is not None}  # type: ignore

    @staticmethod
    def _chroma_error_from_response(resp: httpx.Response) -> Optional[Exception]:
        """Best-effort: build the Chroma exception named in an error response body.

        The body is expected to be JSON with ``"error"`` (a key of
        ``errors.error_types``) and ``"message"`` fields. The ``chroma-trace-id``
        response header, if present, is copied onto the exception.

        Returns:
            The constructed exception, or None when the body is not JSON, lacks
            those fields, names an unknown error type, or cannot be interpreted.
        """
        try:
            body = json.loads(resp.text)
            if "error" not in body or body["error"] not in errors.error_types:
                return None
            chroma_error = errors.error_types[body["error"]](body["message"])
        except Exception:
            # Deliberate best-effort: any malformed body falls back to a generic
            # error. Catch Exception, not BaseException, so KeyboardInterrupt and
            # SystemExit are never swallowed here.
            return None
        trace_id = resp.headers.get("chroma-trace-id")
        if trace_id:
            chroma_error.trace_id = trace_id
        return chroma_error

    @staticmethod
    def _raise_chroma_error(resp: httpx.Response) -> None:
        """Raise if ``resp`` is not successful, preferring a typed Chroma error.

        Returns normally when ``resp.raise_for_status()`` does not raise.

        Raises:
            ChromaError subclass: When the body names a known Chroma error type.
            Exception: Otherwise, carrying the response text (plus the
                ``chroma-trace-id`` header when present), chained from the
                underlying ``httpx.HTTPStatusError``.
        """
        try:
            resp.raise_for_status()
            return
        except httpx.HTTPStatusError as exc:
            status_error = exc

        chroma_error = BaseHTTPClient._chroma_error_from_response(resp)
        if chroma_error is not None:
            raise chroma_error

        trace_id = resp.headers.get("chroma-trace-id")
        if trace_id:
            raise Exception(f"{resp.text} (trace ID: {trace_id})") from status_error
        raise Exception(resp.text) from status_error

    def get_request_headers(self) -> Mapping[str, str]:
        """Return headers used for HTTP requests (none in this base class)."""
        return {}

    def get_api_url(self) -> str:
        """Return the API URL for this client (empty in this base class)."""
        return ""