154 lines
4.8 KiB
Python
154 lines
4.8 KiB
Python
|
|
import asyncio
|
||
|
|
from typing import Any, Callable, Generator, cast, Dict, Tuple
|
||
|
|
from unittest.mock import MagicMock, patch
|
||
|
|
import chromadb
|
||
|
|
from chromadb.config import Settings, System
|
||
|
|
from chromadb.api import ClientAPI
|
||
|
|
import chromadb.server.fastapi
|
||
|
|
from chromadb.api.fastapi import FastAPI
|
||
|
|
import pytest
|
||
|
|
import tempfile
|
||
|
|
import os
|
||
|
|
|
||
|
|
|
||
|
|
@pytest.fixture
def ephemeral_api() -> Generator[ClientAPI, None, None]:
    """Yield a client built by ``chromadb.EphemeralClient``.

    The fixture skips the requesting test when the
    ``CHROMA_INTEGRATION_TEST_ONLY`` environment variable is set to a
    non-empty value. After the test has run, the client's system cache is
    cleared.
    """
    integration_only = os.environ.get("CHROMA_INTEGRATION_TEST_ONLY")
    if integration_only:
        pytest.skip("Integration test only")

    api = chromadb.EphemeralClient()
    yield api

    # Teardown: runs once the requesting test has finished.
    api.clear_system_cache()
|
||
|
|
|
||
|
|
|
||
|
|
@pytest.fixture
def persistent_api() -> Generator[ClientAPI, None, None]:
    """Yield a client built by ``chromadb.PersistentClient``.

    The client stores its data under ``<system temp dir>/test_server``. The
    fixture skips the requesting test when the
    ``CHROMA_INTEGRATION_TEST_ONLY`` environment variable is set to a
    non-empty value, and clears the client's system cache after the test.
    """
    integration_only = os.environ.get("CHROMA_INTEGRATION_TEST_ONLY")
    if integration_only:
        pytest.skip("Integration test only")

    storage_dir = tempfile.gettempdir() + "/test_server"
    api = chromadb.PersistentClient(path=storage_dir)
    yield api

    # Teardown: runs once the requesting test has finished.
    api.clear_system_cache()
|
||
|
|
|
||
|
|
|
||
|
|
# Type of the callables yielded by the ``http_api_factory`` fixture: they
# accept arbitrary arguments and are annotated as returning a ``ClientAPI``.
HttpAPIFactory = Callable[..., ClientAPI]
|
||
|
|
|
||
|
|
|
||
|
|
@pytest.fixture(params=["sync_client", "async_client"])
def http_api_factory(
    request: pytest.FixtureRequest,
) -> Generator[HttpAPIFactory, None, None]:
    """Yield a callable that builds an HTTP client, once per client flavour.

    The fixture is parametrized over ``"sync_client"`` and ``"async_client"``.
    While the fixture is active, ``_validate_tenant_database`` and
    ``get_user_identity`` are patched on the matching client class
    (``Client`` for the sync case, ``AsyncClient`` for the async case).

    * sync: yields ``chromadb.HttpClient`` itself.
    * async: yields a wrapper that drives ``chromadb.AsyncHttpClient`` to
      completion on the current event loop and returns its result.
    """
    if request.param != "sync_client":

        def build_async_client(*args: Any, **kwargs: Any) -> Any:
            # Resolve the loop first, then create and run the coroutine, so
            # callers get the finished client rather than an awaitable.
            loop = asyncio.get_event_loop()
            client = loop.run_until_complete(
                chromadb.AsyncHttpClient(*args, **kwargs)
            )
            return client

        validate_patch = patch(
            "chromadb.api.async_client.AsyncClient._validate_tenant_database"
        )
        identity_patch = patch(
            "chromadb.api.async_client.AsyncClient.get_user_identity"
        )
        with validate_patch, identity_patch:
            yield cast(HttpAPIFactory, build_async_client)
        return

    validate_patch = patch("chromadb.api.client.Client._validate_tenant_database")
    identity_patch = patch("chromadb.api.client.Client.get_user_identity")
    with validate_patch, identity_patch:
        yield chromadb.HttpClient
|
||
|
|
|
||
|
|
|
||
|
|
@pytest.fixture()
def http_api(http_api_factory: HttpAPIFactory) -> Generator[ClientAPI, None, None]:
    """Yield an HTTP client built by ``http_api_factory``.

    When the ``CHROMA_SERVER_HTTP_PORT`` environment variable is set, its
    integer value is passed to the factory as ``port``; otherwise the factory
    is called with no arguments. The client's system cache is cleared after
    the test.
    """
    configured_port = os.environ.get("CHROMA_SERVER_HTTP_PORT")

    factory_kwargs: Dict[str, Any] = {}
    if configured_port is not None:
        factory_kwargs["port"] = int(configured_port)

    api = http_api_factory(**factory_kwargs)
    yield api

    # Teardown: runs once the requesting test has finished.
    api.clear_system_cache()
|
||
|
|
|
||
|
|
|
||
|
|
def test_ephemeral_client(ephemeral_api: ClientAPI) -> None:
    """The ephemeral client's settings must report ``is_persistent`` as False."""
    assert ephemeral_api.get_settings().is_persistent is False
|
||
|
|
|
||
|
|
|
||
|
|
def test_persistent_client(persistent_api: ClientAPI) -> None:
    """The persistent client's settings must report ``is_persistent`` as True."""
    assert persistent_api.get_settings().is_persistent is True
|
||
|
|
|
||
|
|
|
||
|
|
def test_http_client(http_api: ClientAPI) -> None:
    """The HTTP client must be configured with one of the FastAPI implementations.

    Either the sync or the async implementation path is accepted, because the
    ``http_api`` fixture is exercised with both client flavours.
    """
    accepted_impls = (
        "chromadb.api.fastapi.FastAPI",
        "chromadb.api.async_fastapi.AsyncFastAPI",
    )
    configured_impl = http_api.get_settings().chroma_api_impl
    assert configured_impl in accepted_impls
|
||
|
|
|
||
|
|
|
||
|
|
def test_http_client_with_inconsistent_host_settings(
    http_api_factory: HttpAPIFactory,
) -> None:
    """A settings host that differs from the client's host must raise ``ValueError``.

    The factory is called without an explicit host while the settings carry
    ``127.0.0.1``; the expected error message reports that value against
    ``localhost``.
    """
    # pytest.raises makes the test fail when no ValueError is raised at all.
    # A plain try/except with the assertion inside the handler would pass
    # silently in that case.
    with pytest.raises(ValueError) as exc_info:
        http_api_factory(settings=Settings(chroma_server_host="127.0.0.1"))

    assert (
        str(exc_info.value)
        == "Chroma server host provided in settings[127.0.0.1] is different to the one provided in HttpClient: [localhost]"
    )
|
||
|
|
|
||
|
|
|
||
|
|
def test_http_client_with_inconsistent_port_settings(
    http_api_factory: HttpAPIFactory,
) -> None:
    """A settings port that differs from the client's port must raise ``ValueError``.

    The factory receives ``port=8002`` while the settings carry
    ``chroma_server_http_port=8001``; the expected error message reports both
    values.
    """
    # pytest.raises makes the test fail when no ValueError is raised at all.
    # A plain try/except with the assertion inside the handler would pass
    # silently in that case.
    with pytest.raises(ValueError) as exc_info:
        http_api_factory(
            port=8002,
            settings=Settings(
                chroma_server_http_port=8001,
            ),
        )

    assert (
        str(exc_info.value)
        == "Chroma server http port provided in settings[8001] is different to the one provided in HttpClient: [8002]"
    )
|
||
|
|
|
||
|
|
|
||
|
|
def make_sync_client_factory() -> Tuple[Callable[..., Any], Dict[str, Any]]:
    """Build a stand-in client constructor that records its keyword arguments.

    Returns:
        A ``(factory, captured)`` pair. Each call to ``factory`` merges the
        keyword arguments it received into the ``captured`` dict and returns
        a fresh ``MagicMock`` whose ``headers`` attribute is an empty dict.
    """
    seen_kwargs: Dict[str, Any] = {}

    # Positional arguments are accepted and ignored to match httpx.Client.
    def fake_client(*_ignored: Any, **kwargs: Any) -> Any:
        seen_kwargs.update(kwargs)
        # Constructor keyword arguments configure attributes on the mock.
        return MagicMock(headers={})

    return fake_client, seen_kwargs
|
||
|
|
|
||
|
|
|
||
|
|
def test_fastapi_uses_http_limits_from_settings() -> None:
    """``FastAPI`` must pass the HTTP settings on to ``httpx.Client``.

    ``httpx.Client`` is replaced by a recording factory and ``FastAPI.require``
    by mocks, so constructing the API only captures the keyword arguments
    handed to the client constructor. The captured ``limits``, ``timeout`` and
    ``verify`` values are then checked against the settings.
    """
    http_settings = Settings(
        chroma_api_impl="chromadb.api.fastapi.FastAPI",
        chroma_server_host="localhost",
        chroma_server_http_port=9000,
        chroma_server_ssl_verify=True,
        chroma_http_keepalive_secs=12.5,
        chroma_http_max_connections=64,
        chroma_http_max_keepalive_connections=16,
    )
    system = System(http_settings)

    fake_client, seen_kwargs = make_sync_client_factory()

    require_patch = patch.object(
        FastAPI, "require", side_effect=[MagicMock(), MagicMock()]
    )
    client_patch = patch("chromadb.api.fastapi.httpx.Client", side_effect=fake_client)
    with require_patch, client_patch:
        api = FastAPI(system)

    api.stop()

    http_limits = seen_kwargs["limits"]
    assert http_limits.keepalive_expiry == 12.5
    assert http_limits.max_connections == 64
    assert http_limits.max_keepalive_connections == 16
    assert seen_kwargs["timeout"] is None
    assert seen_kwargs["verify"] is True
|