# type: ignore
import os
import shutil
import sys
import tempfile
import traceback
from datetime import datetime, timedelta
from typing import Any
import httpx
import numpy as np
import pytest
import chromadb
import chromadb.server.fastapi
from chromadb.api.fastapi import FastAPI
from chromadb.api.types import (
Document,
EmbeddingFunction,
QueryResult,
TYPE_KEY,
SPARSE_VECTOR_TYPE_VALUE,
)
from chromadb.config import Settings
from chromadb.errors import (
ChromaError,
NotFoundError,
InvalidArgumentError,
)
from chromadb.utils.embedding_functions import DefaultEmbeddingFunction
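# Fixtures: each test gets a fresh persistent SegmentAPI client backed by a
# temporary directory, torn down (system cache cleared, directory removed) afterwards.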
@pytest.fixture
def persist_dir():
return tempfile.mkdtemp()
@pytest.fixture
def local_persist_api(persist_dir):
client = chromadb.Client(
Settings(
chroma_api_impl="chromadb.api.segment.SegmentAPI",
chroma_sysdb_impl="chromadb.db.impl.sqlite.SqliteDB",
chroma_producer_impl="chromadb.db.impl.sqlite.SqliteDB",
chroma_consumer_impl="chromadb.db.impl.sqlite.SqliteDB",
chroma_segment_manager_impl="chromadb.segment.impl.manager.local.LocalSegmentManager",
allow_reset=True,
is_persistent=True,
persist_directory=persist_dir,
),
)
yield client
client.clear_system_cache()
if os.path.exists(persist_dir):
shutil.rmtree(persist_dir, ignore_errors=True)
# https://docs.pytest.org/en/6.2.x/fixture.html#fixtures-can-be-requested-more-than-once-per-test-return-values-are-cached
@pytest.fixture
def local_persist_api_cache_bust(persist_dir):
client = chromadb.Client(
Settings(
chroma_api_impl="chromadb.api.segment.SegmentAPI",
chroma_sysdb_impl="chromadb.db.impl.sqlite.SqliteDB",
chroma_producer_impl="chromadb.db.impl.sqlite.SqliteDB",
chroma_consumer_impl="chromadb.db.impl.sqlite.SqliteDB",
chroma_segment_manager_impl="chromadb.segment.impl.manager.local.LocalSegmentManager",
allow_reset=True,
is_persistent=True,
persist_directory=persist_dir,
),
)
yield client
client.clear_system_cache()
if os.path.exists(persist_dir):
shutil.rmtree(persist_dir, ignore_errors=True)
def approx_equal(a, b, tolerance=1e-6) -> bool:
return abs(a - b) < tolerance
def vector_approx_equal(a, b, tolerance: float = 1e-6) -> bool:
if len(a) != len(b):
return False
    return all(approx_equal(x, y, tolerance) for x, y in zip(a, b))
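# For example, vector_approx_equal([1.0, 2.0], [1.0, 2.0 + 1e-9]) is True,
# while vector_approx_equal([1.0], [1.0, 2.0]) is False (length mismatch).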
@pytest.mark.parametrize("api_fixture", [local_persist_api])
def test_persist_index_loading(api_fixture, request):
client = request.getfixturevalue("local_persist_api")
client.reset()
collection = client.create_collection("test")
collection.add(ids="id1", documents="hello")
api2 = request.getfixturevalue("local_persist_api_cache_bust")
collection = api2.get_collection("test")
includes = ["embeddings", "documents", "metadatas", "distances"]
nn = collection.query(
query_texts="hello",
n_results=1,
include=["embeddings", "documents", "metadatas", "distances"],
)
for key in nn.keys():
if (key in includes) or (key == "ids"):
assert len(nn[key]) == 1
elif key == "included":
assert set(nn[key]) == set(includes)
else:
assert nn[key] is None
@pytest.mark.parametrize("api_fixture", [local_persist_api])
def test_persist_index_loading_embedding_function(api_fixture, request):
class TestEF(EmbeddingFunction[Document]):
def __call__(self, input):
return [np.array([1, 2, 3]) for _ in range(len(input))]
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
def name(self) -> str:
return "test"
def build_from_config(self, config: dict[str, Any]) -> None:
pass
def get_config(self) -> dict[str, Any]:
return {}
client = request.getfixturevalue("local_persist_api")
client.reset()
collection = client.create_collection("test", embedding_function=TestEF())
collection.add(ids="id1", documents="hello")
client2 = request.getfixturevalue("local_persist_api_cache_bust")
collection = client2.get_collection("test", embedding_function=TestEF())
includes = ["embeddings", "documents", "metadatas", "distances"]
nn = collection.query(
query_texts="hello",
n_results=1,
include=includes,
)
for key in nn.keys():
if (key in includes) or (key == "ids"):
assert len(nn[key]) == 1
elif key == "included":
assert set(nn[key]) == set(includes)
else:
assert nn[key] is None
@pytest.mark.parametrize("api_fixture", [local_persist_api])
def test_persist_index_get_or_create_embedding_function(api_fixture, request):
class TestEF(EmbeddingFunction[Document]):
def __call__(self, input):
return [np.array([1, 2, 3]) for _ in range(len(input))]
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
def name(self) -> str:
return "test"
def build_from_config(self, config: dict[str, Any]) -> None:
pass
def get_config(self) -> dict[str, Any]:
return {}
api = request.getfixturevalue("local_persist_api")
api.reset()
collection = api.get_or_create_collection("test", embedding_function=TestEF())
collection.add(ids="id1", documents="hello")
api2 = request.getfixturevalue("local_persist_api_cache_bust")
collection = api2.get_or_create_collection("test", embedding_function=TestEF())
includes = ["embeddings", "documents", "metadatas", "distances"]
nn = collection.query(
query_texts="hello",
n_results=1,
include=includes,
)
for key in nn.keys():
if (key in includes) or (key == "ids"):
assert len(nn[key]) == 1
elif key == "included":
assert set(nn[key]) == set(includes)
else:
assert nn[key] is None
assert nn["ids"] == [["id1"]]
assert nn["embeddings"][0][0].tolist() == [1, 2, 3]
assert nn["documents"] == [["hello"]]
assert nn["distances"] == [[0]]
@pytest.mark.parametrize("api_fixture", [local_persist_api])
def test_persist(api_fixture, request):
client = request.getfixturevalue(api_fixture.__name__)
client.reset()
collection = client.create_collection("testspace")
collection.add(**batch_records)
assert collection.count() == 2
client = request.getfixturevalue(api_fixture.__name__)
collection = client.get_collection("testspace")
assert collection.count() == 2
client.delete_collection("testspace")
client = request.getfixturevalue(api_fixture.__name__)
assert client.list_collections() == []
def test_heartbeat(client):
heartbeat_ns = client.heartbeat()
assert isinstance(heartbeat_ns, int)
heartbeat_s = heartbeat_ns // 10**9
heartbeat = datetime.fromtimestamp(heartbeat_s)
assert heartbeat > datetime.now() - timedelta(seconds=10)
def test_max_batch_size(client):
batch_size = client.get_max_batch_size()
assert batch_size > 0
def test_supports_base64_encoding(client):
if not isinstance(client, FastAPI):
pytest.skip("Not a FastAPI instance")
client.reset()
supports_base64_encoding = client.supports_base64_encoding()
assert supports_base64_encoding is True
def test_supports_base64_encoding_legacy(client):
if not isinstance(client, FastAPI):
pytest.skip("Not a FastAPI instance")
client.reset()
    # legacy servers do not return supports_base64_encoding in their pre-flight response
client.pre_flight_checks = {
"max_batch_size": 100,
}
assert client.supports_base64_encoding() is False
assert client.get_max_batch_size() == 100
def test_pre_flight_checks(client):
if not isinstance(client, FastAPI):
pytest.skip("Not a FastAPI instance")
resp = httpx.get(f"{client._api_url}/pre-flight-checks")
assert resp.status_code == 200
assert resp.json() is not None
assert "max_batch_size" in resp.json().keys()
assert "supports_base64_encoding" in resp.json().keys()
batch_records = {
"embeddings": [[1.1, 2.3, 3.2], [1.2, 2.24, 3.2]],
"ids": ["https://example.com/1", "https://example.com/2"],
}
def test_add(client):
client.reset()
collection = client.create_collection("testspace")
collection.add(**batch_records)
assert collection.count() == 2
def test_collection_add_with_invalid_collection_throws(client):
client.reset()
collection = client.create_collection("test")
client.delete_collection("test")
with pytest.raises(NotFoundError, match=r"Collection .* does not exist"):
collection.add(**batch_records)
def test_get_or_create(client):
client.reset()
collection = client.create_collection("testspace")
collection.add(**batch_records)
assert collection.count() == 2
with pytest.raises(Exception):
collection = client.create_collection("testspace")
collection = client.get_or_create_collection("testspace")
assert collection.count() == 2
minimal_records = {
"embeddings": [[1.1, 2.3, 3.2], [1.2, 2.24, 3.2]],
"ids": ["https://example.com/1", "https://example.com/2"],
}
def test_add_minimal(client):
client.reset()
collection = client.create_collection("testspace")
collection.add(**minimal_records)
assert collection.count() == 2
def test_get_from_db(client):
client.reset()
collection = client.create_collection("testspace")
collection.add(**batch_records)
includes = ["embeddings", "documents", "metadatas"]
records = collection.get(include=includes)
for key in records.keys():
if (key in includes) or (key == "ids"):
assert len(records[key]) == 2
elif key == "included":
assert set(records[key]) == set(includes)
else:
assert records[key] is None
def test_collection_get_with_invalid_collection_throws(client):
client.reset()
collection = client.create_collection("test")
client.delete_collection("test")
with pytest.raises(NotFoundError, match=r"Collection .* does not exist"):
collection.get()
def test_reset_db(client):
client.reset()
collection = client.create_collection("testspace")
collection.add(**batch_records)
assert collection.count() == 2
client.reset()
assert len(client.list_collections()) == 0
def test_get_nearest_neighbors(client):
client.reset()
collection = client.create_collection("testspace")
collection.add(**batch_records)
includes = ["embeddings", "documents", "metadatas", "distances"]
nn = collection.query(
query_embeddings=[1.1, 2.3, 3.2],
n_results=1,
include=includes,
)
for key in nn.keys():
if (key in includes) or (key == "ids"):
assert len(nn[key]) == 1
elif key == "included":
assert set(nn[key]) == set(includes)
else:
assert nn[key] is None
nn = collection.query(
query_embeddings=[[1.1, 2.3, 3.2]],
n_results=1,
include=includes,
)
for key in nn.keys():
if (key in includes) or (key == "ids"):
assert len(nn[key]) == 1
elif key == "included":
assert set(nn[key]) == set(includes)
else:
assert nn[key] is None
nn = collection.query(
query_embeddings=[[1.1, 2.3, 3.2], [0.1, 2.3, 4.5]],
n_results=1,
include=includes,
)
for key in nn.keys():
if (key in includes) or (key == "ids"):
assert len(nn[key]) == 2
elif key == "included":
assert set(nn[key]) == set(includes)
else:
assert nn[key] is None
def test_delete(client):
client.reset()
collection = client.create_collection("testspace")
collection.add(**batch_records)
assert collection.count() == 2
with pytest.raises(Exception):
collection.delete()
def test_delete_returns_none(client):
client.reset()
collection = client.create_collection("testspace")
collection.add(**batch_records)
assert collection.count() == 2
assert collection.delete(ids=batch_records["ids"]) is None
def test_delete_with_index(client):
client.reset()
collection = client.create_collection("testspace")
collection.add(**batch_records)
assert collection.count() == 2
collection.query(query_embeddings=[[1.1, 2.3, 3.2]], n_results=1)
def test_collection_delete_with_invalid_collection_throws(client):
client.reset()
collection = client.create_collection("test")
client.delete_collection("test")
with pytest.raises(NotFoundError, match=r"Collection .* does not exist"):
collection.delete(ids=["id1"])
def test_count(client):
client.reset()
collection = client.create_collection("testspace")
assert collection.count() == 0
collection.add(**batch_records)
assert collection.count() == 2
def test_collection_count_with_invalid_collection_throws(client):
client.reset()
collection = client.create_collection("test")
client.delete_collection("test")
with pytest.raises(NotFoundError, match=r"Collection .* does not exist"):
collection.count()
def test_modify(client):
client.reset()
collection = client.create_collection("testspace")
collection.modify(name="testspace2")
    # collection name is modified
assert collection.name == "testspace2"
def test_collection_modify_with_invalid_collection_throws(client):
client.reset()
collection = client.create_collection("test")
client.delete_collection("test")
with pytest.raises(NotFoundError, match=r"Collection .* does not exist"):
collection.modify(name="test2")
def test_modify_error_on_existing_name(client):
client.reset()
client.create_collection("testspace")
c2 = client.create_collection("testspace2")
with pytest.raises(Exception):
c2.modify(name="testspace")
def test_modify_warn_on_DF_change(client, caplog):
client.reset()
collection = client.create_collection("testspace")
with pytest.raises(Exception, match="not supported"):
collection.modify(metadata={"hnsw:space": "cosine"})
def test_metadata_cru(client):
client.reset()
metadata_a = {"a": 1, "b": 2}
# Test create metadata
collection = client.create_collection("testspace", metadata=metadata_a)
assert collection.metadata is not None
assert collection.metadata["a"] == 1
assert collection.metadata["b"] == 2
# Test get metadata
collection = client.get_collection("testspace")
assert collection.metadata is not None
assert collection.metadata["a"] == 1
assert collection.metadata["b"] == 2
# Test modify metadata
collection.modify(metadata={"a": 2, "c": 3})
assert collection.metadata["a"] == 2
assert collection.metadata["c"] == 3
assert "b" not in collection.metadata
# Test get after modify metadata
collection = client.get_collection("testspace")
assert collection.metadata is not None
assert collection.metadata["a"] == 2
assert collection.metadata["c"] == 3
assert "b" not in collection.metadata
# Test name exists get_or_create_metadata
collection = client.get_or_create_collection("testspace")
assert collection.metadata is not None
assert collection.metadata["a"] == 2
assert collection.metadata["c"] == 3
# Test name exists create metadata
collection = client.get_or_create_collection("testspace2")
assert collection.metadata is None
# Test list collections
collections = client.list_collections()
for collection in collections:
if collection.name == "testspace":
assert collection.metadata is not None
assert collection.metadata["a"] == 2
assert collection.metadata["c"] == 3
elif collection.name == "testspace2":
assert collection.metadata is None
def test_increment_index_on(client):
client.reset()
collection = client.create_collection("testspace")
collection.add(**batch_records)
assert collection.count() == 2
includes = ["embeddings", "documents", "metadatas", "distances"]
# increment index
nn = collection.query(
query_embeddings=[[1.1, 2.3, 3.2]],
n_results=1,
include=includes,
)
for key in nn.keys():
if (key in includes) or (key == "ids"):
assert len(nn[key]) == 1
elif key == "included":
assert set(nn[key]) == set(includes)
else:
assert nn[key] is None
def test_add_a_collection(client):
client.reset()
client.create_collection("testspace")
# get collection does not throw an error
collection = client.get_collection("testspace")
assert collection.name == "testspace"
# get collection should throw an error if collection does not exist
with pytest.raises(Exception):
collection = client.get_collection("testspace2")
def test_error_includes_trace_id(http_client):
http_client.reset()
with pytest.raises(ChromaError) as error:
http_client.get_collection("testspace2")
assert error.value.trace_id is not None
def test_list_collections(client):
client.reset()
client.create_collection("testspace")
client.create_collection("testspace2")
# get collection does not throw an error
collections = client.list_collections()
assert len(collections) == 2
def test_reset(client):
client.reset()
client.create_collection("testspace")
client.create_collection("testspace2")
# get collection does not throw an error
collections = client.list_collections()
assert len(collections) == 2
client.reset()
collections = client.list_collections()
assert len(collections) == 0
def test_peek(client):
client.reset()
collection = client.create_collection("testspace")
collection.add(**batch_records)
assert collection.count() == 2
# peek
peek = collection.peek()
print(peek)
for key in peek.keys():
if key in ["embeddings", "documents", "metadatas"] or key == "ids":
assert len(peek[key]) == 2
elif key == "included":
assert set(peek[key]) == set(["embeddings", "metadatas", "documents"])
else:
assert peek[key] is None
def test_collection_peek_with_invalid_collection_throws(client):
client.reset()
collection = client.create_collection("test")
client.delete_collection("test")
with pytest.raises(NotFoundError, match=r"Collection .* does not exist"):
collection.peek()
def test_collection_query_with_invalid_collection_throws(client):
client.reset()
collection = client.create_collection("test")
client.delete_collection("test")
with pytest.raises(NotFoundError, match=r"Collection .* does not exist"):
collection.query(query_texts=["test"])
def test_collection_update_with_invalid_collection_throws(client):
client.reset()
collection = client.create_collection("test")
client.delete_collection("test")
with pytest.raises(NotFoundError, match=r"Collection .* does not exist"):
collection.update(ids=["id1"], documents=["test"])
# TEST METADATA AND METADATA FILTERING
# region
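# id2 deliberately omits string_value and float_value so the where-filter tests
# can distinguish the two records.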
metadata_records = {
"embeddings": [[1.1, 2.3, 3.2], [1.2, 2.24, 3.2]],
"ids": ["id1", "id2"],
"metadatas": [
{"int_value": 1, "string_value": "one", "float_value": 1.001},
{"int_value": 2},
],
}
def test_metadata_add_get_int_float(client):
client.reset()
collection = client.create_collection("test_int")
collection.add(**metadata_records)
items = collection.get(ids=["id1", "id2"])
assert items["metadatas"][0]["int_value"] == 1
assert items["metadatas"][0]["float_value"] == 1.001
assert items["metadatas"][1]["int_value"] == 2
assert isinstance(items["metadatas"][0]["int_value"], int)
assert isinstance(items["metadatas"][0]["float_value"], float)
def test_metadata_add_query_int_float(client):
client.reset()
collection = client.create_collection("test_int")
collection.add(**metadata_records)
items: QueryResult = collection.query(
query_embeddings=[[1.1, 2.3, 3.2]], n_results=1
)
assert items["metadatas"] is not None
assert items["metadatas"][0][0]["int_value"] == 1
assert items["metadatas"][0][0]["float_value"] == 1.001
assert isinstance(items["metadatas"][0][0]["int_value"], int)
assert isinstance(items["metadatas"][0][0]["float_value"], float)
def test_metadata_get_where_string(client):
client.reset()
collection = client.create_collection("test_int")
collection.add(**metadata_records)
items = collection.get(where={"string_value": "one"})
assert items["metadatas"][0]["int_value"] == 1
assert items["metadatas"][0]["string_value"] == "one"
def test_metadata_get_where_int(client):
client.reset()
collection = client.create_collection("test_int")
collection.add(**metadata_records)
items = collection.get(where={"int_value": 1})
assert items["metadatas"][0]["int_value"] == 1
assert items["metadatas"][0]["string_value"] == "one"
def test_metadata_get_where_float(client):
client.reset()
collection = client.create_collection("test_int")
collection.add(**metadata_records)
items = collection.get(where={"float_value": 1.001})
assert items["metadatas"][0]["int_value"] == 1
assert items["metadatas"][0]["string_value"] == "one"
assert items["metadatas"][0]["float_value"] == 1.001
def test_metadata_update_get_int_float(client):
client.reset()
collection = client.create_collection("test_int")
collection.add(**metadata_records)
collection.update(
ids=["id1"],
metadatas=[{"int_value": 2, "string_value": "two", "float_value": 2.002}],
)
items = collection.get(ids=["id1"])
assert items["metadatas"][0]["int_value"] == 2
assert items["metadatas"][0]["string_value"] == "two"
assert items["metadatas"][0]["float_value"] == 2.002
bad_metadata_records = {
"embeddings": [[1.1, 2.3, 3.2], [1.2, 2.24, 3.2]],
"ids": ["id1", "id2"],
"metadatas": [{"value": {"nested": "5"}}, {"value": [1, 2, 3]}],
}
def test_metadata_validation_add(client):
client.reset()
collection = client.create_collection("test_metadata_validation")
with pytest.raises(ValueError, match="metadata"):
collection.add(**bad_metadata_records)
def test_metadata_validation_update(client):
client.reset()
collection = client.create_collection("test_metadata_validation")
collection.add(**metadata_records)
with pytest.raises(ValueError, match="metadata"):
collection.update(ids=["id1"], metadatas={"value": {"nested": "5"}})
def test_where_validation_get(client):
client.reset()
collection = client.create_collection("test_where_validation")
with pytest.raises(ValueError, match="where"):
collection.get(where={"value": {"nested": "5"}})
def test_where_validation_query(client):
client.reset()
collection = client.create_collection("test_where_validation")
with pytest.raises(ValueError, match="where"):
collection.query(query_embeddings=[0, 0, 0], where={"value": {"nested": "5"}})
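# Both records carry int, float, and string values for the comparison-operator tests.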
operator_records = {
"embeddings": [[1.1, 2.3, 3.2], [1.2, 2.24, 3.2]],
"ids": ["id1", "id2"],
"metadatas": [
{"int_value": 1, "string_value": "one", "float_value": 1.001},
{"int_value": 2, "float_value": 2.002, "string_value": "two"},
],
}
def test_where_lt(client):
client.reset()
collection = client.create_collection("test_where_lt")
collection.add(**operator_records)
items = collection.get(where={"int_value": {"$lt": 2}})
assert len(items["metadatas"]) == 1
def test_where_lte(client):
client.reset()
collection = client.create_collection("test_where_lte")
collection.add(**operator_records)
items = collection.get(where={"int_value": {"$lte": 2.0}})
assert len(items["metadatas"]) == 2
def test_where_gt(client):
client.reset()
collection = client.create_collection("test_where_lte")
collection.add(**operator_records)
items = collection.get(where={"float_value": {"$gt": -1.4}})
assert len(items["metadatas"]) == 2
def test_where_gte(client):
client.reset()
collection = client.create_collection("test_where_lte")
collection.add(**operator_records)
items = collection.get(where={"float_value": {"$gte": 2.002}})
assert len(items["metadatas"]) == 1
def test_where_ne_string(client):
client.reset()
collection = client.create_collection("test_where_lte")
collection.add(**operator_records)
items = collection.get(where={"string_value": {"$ne": "two"}})
assert len(items["metadatas"]) == 1
def test_where_ne_eq_number(client):
client.reset()
collection = client.create_collection("test_where_lte")
collection.add(**operator_records)
items = collection.get(where={"int_value": {"$ne": 1}})
assert len(items["metadatas"]) == 1
items = collection.get(where={"float_value": {"$eq": 2.002}})
assert len(items["metadatas"]) == 1
def test_where_valid_operators(client):
client.reset()
collection = client.create_collection("test_where_valid_operators")
collection.add(**operator_records)
with pytest.raises(ValueError):
collection.get(where={"int_value": {"$invalid": 2}})
with pytest.raises(ValueError):
collection.get(where={"int_value": {"$lt": "2"}})
with pytest.raises(ValueError):
collection.get(where={"int_value": {"$lt": 2, "$gt": 1}})
# Test invalid $and, $or
with pytest.raises(ValueError):
collection.get(where={"$and": {"int_value": {"$lt": 2}}})
with pytest.raises(ValueError):
collection.get(
where={"int_value": {"$lt": 2}, "$or": {"int_value": {"$gt": 1}}}
)
with pytest.raises(ValueError):
collection.get(
where={"$gt": [{"int_value": {"$lt": 2}}, {"int_value": {"$gt": 1}}]}
)
with pytest.raises(ValueError):
collection.get(where={"$or": [{"int_value": {"$lt": 2}}]})
with pytest.raises(ValueError):
collection.get(where={"$or": []})
with pytest.raises(ValueError):
collection.get(where={"a": {"$contains": "test"}})
with pytest.raises(ValueError):
collection.get(
where={
"$or": [
{"a": {"$contains": "first"}}, # invalid
{"$contains": "second"}, # valid
]
}
)
# TODO: Define the dimensionality of these embeddings in terms of the default record
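# 4-dimensional payloads that should be rejected by the 3-dimensional collections
# used in these tests; bad_number_of_results_query asks for more results than exist.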
bad_dimensionality_records = {
"embeddings": [[1.1, 2.3, 3.2, 4.5], [1.2, 2.24, 3.2, 4.5]],
"ids": ["id1", "id2"],
}
bad_dimensionality_query = {
"query_embeddings": [[1.1, 2.3, 3.2, 4.5], [1.2, 2.24, 3.2, 4.5]],
}
bad_number_of_results_query = {
"query_embeddings": [[1.1, 2.3, 3.2], [1.2, 2.24, 3.2]],
"n_results": 100,
}
def test_dimensionality_validation_add(client):
client.reset()
collection = client.create_collection("test_dimensionality_validation")
collection.add(**minimal_records)
with pytest.raises(Exception) as e:
collection.add(**bad_dimensionality_records)
assert "dimension" in str(e.value)
def test_dimensionality_validation_query(client):
client.reset()
collection = client.create_collection("test_dimensionality_validation_query")
collection.add(**minimal_records)
with pytest.raises(Exception) as e:
collection.query(**bad_dimensionality_query)
assert "dimension" in str(e.value)
def test_query_document_valid_operators(client):
client.reset()
collection = client.create_collection("test_where_valid_operators")
collection.add(**operator_records)
with pytest.raises(ValueError, match="where document"):
collection.get(where_document={"$lt": {"$nested": 2}})
with pytest.raises(ValueError, match="where document"):
collection.query(query_embeddings=[0, 0, 0], where_document={"$contains": 2})
with pytest.raises(ValueError, match="where document"):
collection.get(where_document={"$contains": []})
# Test invalid $contains
with pytest.raises(ValueError, match="where document"):
collection.get(where_document={"$contains": {"text": "hello"}})
# Test invalid $not_contains
with pytest.raises(ValueError, match="where document"):
collection.get(where_document={"$not_contains": {"text": "hello"}})
# Test invalid $and, $or
with pytest.raises(ValueError):
collection.get(where_document={"$and": {"$unsupported": "doc"}})
with pytest.raises(ValueError):
collection.get(
where_document={"$or": [{"$unsupported": "doc"}, {"$unsupported": "doc"}]}
)
with pytest.raises(ValueError):
collection.get(where_document={"$or": [{"$contains": "doc"}]})
with pytest.raises(ValueError):
collection.get(where_document={"$or": []})
with pytest.raises(ValueError):
collection.get(
where_document={
"$or": [{"$and": [{"$contains": "doc"}]}, {"$contains": "doc"}]
}
)
contains_records = {
"embeddings": [[1.1, 2.3, 3.2], [1.2, 2.24, 3.2]],
"documents": ["this is doc1 and it's great!", "doc2 is also great!"],
"ids": ["id1", "id2"],
"metadatas": [
{"int_value": 1, "string_value": "one", "float_value": 1.001},
{"int_value": 2, "float_value": 2.002, "string_value": "two"},
],
}
def test_get_where_document(client):
client.reset()
collection = client.create_collection("test_get_where_document")
collection.add(**contains_records)
items = collection.get(where_document={"$contains": "doc1"})
assert len(items["metadatas"]) == 1
items = collection.get(where_document={"$contains": "great"})
assert len(items["metadatas"]) == 2
items = collection.get(where_document={"$contains": "bad"})
assert len(items["metadatas"]) == 0
def test_query_where_document(client):
client.reset()
collection = client.create_collection("test_query_where_document")
collection.add(**contains_records)
items = collection.query(
query_embeddings=[1, 0, 0], where_document={"$contains": "doc1"}, n_results=1
)
assert len(items["metadatas"][0]) == 1
items = collection.query(
query_embeddings=[0, 0, 0], where_document={"$contains": "great"}, n_results=2
)
assert len(items["metadatas"][0]) == 2
with pytest.raises(Exception) as e:
items = collection.query(
query_embeddings=[0, 0, 0], where_document={"$contains": "bad"}, n_results=1
)
assert "datapoints" in str(e.value)
def test_delete_where_document(client):
client.reset()
collection = client.create_collection("test_delete_where_document")
collection.add(**contains_records)
collection.delete(where_document={"$contains": "doc1"})
assert collection.count() == 1
collection.delete(where_document={"$contains": "bad"})
assert collection.count() == 1
collection.delete(where_document={"$contains": "great"})
assert collection.count() == 0
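# Four records combining metadata and documents, used by the $and / $or tests.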
logical_operator_records = {
"embeddings": [
[1.1, 2.3, 3.2],
[1.2, 2.24, 3.2],
[1.3, 2.25, 3.2],
[1.4, 2.26, 3.2],
],
"ids": ["id1", "id2", "id3", "id4"],
"metadatas": [
{"int_value": 1, "string_value": "one", "float_value": 1.001, "is": "doc"},
{"int_value": 2, "float_value": 2.002, "string_value": "two", "is": "doc"},
{"int_value": 3, "float_value": 3.003, "string_value": "three", "is": "doc"},
{"int_value": 4, "float_value": 4.004, "string_value": "four", "is": "doc"},
],
"documents": [
"this document is first and great",
"this document is second and great",
"this document is third and great",
"this document is fourth and great",
],
}
def test_where_logical_operators(client):
client.reset()
collection = client.create_collection("test_logical_operators")
collection.add(**logical_operator_records)
items = collection.get(
where={
"$and": [
{"$or": [{"int_value": {"$gte": 3}}, {"float_value": {"$lt": 1.9}}]},
{"is": "doc"},
]
}
)
assert len(items["metadatas"]) == 3
items = collection.get(
where={
"$or": [
{
"$and": [
{"int_value": {"$eq": 3}},
{"string_value": {"$eq": "three"}},
]
},
{
"$and": [
{"int_value": {"$eq": 4}},
{"string_value": {"$eq": "four"}},
]
},
]
}
)
assert len(items["metadatas"]) == 2
items = collection.get(
where={
"$and": [
{
"$or": [
{"int_value": {"$eq": 1}},
{"string_value": {"$eq": "two"}},
]
},
{
"$or": [
{"int_value": {"$eq": 2}},
{"string_value": {"$eq": "one"}},
]
},
]
}
)
assert len(items["metadatas"]) == 2
def test_where_document_logical_operators(client):
client.reset()
collection = client.create_collection("test_document_logical_operators")
collection.add(**logical_operator_records)
items = collection.get(
where_document={
"$and": [
{"$contains": "first"},
{"$contains": "doc"},
]
}
)
assert len(items["metadatas"]) == 1
items = collection.get(
where_document={
"$or": [
{"$contains": "first"},
{"$contains": "second"},
]
}
)
assert len(items["metadatas"]) == 2
items = collection.get(
where_document={
"$or": [
{"$contains": "first"},
{"$contains": "second"},
]
},
where={
"int_value": {"$ne": 2},
},
)
assert len(items["metadatas"]) == 1
# endregion
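# General-purpose records (embeddings, metadatas, documents) for the include/order/update tests.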
records = {
"embeddings": [[0, 0, 0], [1.2, 2.24, 3.2]],
"ids": ["id1", "id2"],
"metadatas": [
{"int_value": 1, "string_value": "one", "float_value": 1.001},
{"int_value": 2},
],
"documents": ["this document is first", "this document is second"],
}
def test_query_include(client):
client.reset()
collection = client.create_collection("test_query_include")
collection.add(**records)
include = ["metadatas", "documents", "distances"]
items = collection.query(
query_embeddings=[0, 0, 0],
include=include,
n_results=1,
)
assert items["embeddings"] is None
assert items["ids"][0][0] == "id1"
assert items["metadatas"][0][0]["int_value"] == 1
assert set(items["included"]) == set(include)
include = ["embeddings", "documents", "distances"]
items = collection.query(
query_embeddings=[0, 0, 0],
include=include,
n_results=1,
)
assert items["metadatas"] is None
assert items["ids"][0][0] == "id1"
assert set(items["included"]) == set(include)
items = collection.query(
query_embeddings=[[0, 0, 0], [1, 2, 1.2]],
include=[],
n_results=2,
)
assert items["documents"] is None
assert items["metadatas"] is None
assert items["embeddings"] is None
assert items["distances"] is None
assert items["ids"][0][0] == "id1"
assert items["ids"][0][1] == "id2"
def test_get_include(client):
client.reset()
collection = client.create_collection("test_get_include")
collection.add(**records)
include = ["metadatas", "documents"]
items = collection.get(include=include, where={"int_value": 1})
assert items["embeddings"] is None
assert items["ids"][0] == "id1"
assert items["metadatas"][0]["int_value"] == 1
assert items["documents"][0] == "this document is first"
assert set(items["included"]) == set(include)
include = ["embeddings", "documents"]
items = collection.get(include=include)
assert items["metadatas"] is None
assert items["ids"][0] == "id1"
assert approx_equal(items["embeddings"][1][0], 1.2)
assert set(items["included"]) == set(include)
items = collection.get(include=[])
assert items["documents"] is None
assert items["metadatas"] is None
assert items["embeddings"] is None
assert items["ids"][0] == "id1"
assert items["included"] == []
with pytest.raises(ValueError, match="include"):
items = collection.get(include=["metadatas", "undefined"])
with pytest.raises(ValueError, match="include"):
items = collection.get(include=None)
# make sure query results are returned in the right order
def test_query_order(client):
client.reset()
collection = client.create_collection("test_query_order")
collection.add(**records)
items = collection.query(
query_embeddings=[1.2, 2.24, 3.2],
include=["metadatas", "documents", "distances"],
n_results=2,
)
assert items["documents"][0][0] == "this document is second"
assert items["documents"][0][1] == "this document is first"
# test to make sure add, get, delete error on invalid id input
def test_invalid_id(client):
client.reset()
collection = client.create_collection("test_invalid_id")
# Add with non-string id
with pytest.raises(ValueError) as e:
collection.add(embeddings=[0, 0, 0], ids=[1], metadatas=[{}])
assert "ID" in str(e.value)
# Get with non-list id
with pytest.raises(ValueError) as e:
collection.get(ids=1)
assert "ID" in str(e.value)
# Delete with malformed ids
with pytest.raises(ValueError) as e:
collection.delete(ids=["valid", 0])
assert "ID" in str(e.value)
def test_index_params(client):
EPS = 1e-12
# first standard add
client.reset()
collection = client.create_collection(name="test_index_params")
collection.add(**records)
items = collection.query(
query_embeddings=[0.6, 1.12, 1.6],
n_results=1,
)
assert items["distances"][0][0] > 4
# cosine
client.reset()
collection = client.create_collection(
name="test_index_params",
metadata={"hnsw:space": "cosine", "hnsw:construction_ef": 20, "hnsw:M": 5},
)
collection.add(**records)
items = collection.query(
query_embeddings=[0.6, 1.12, 1.6],
n_results=1,
)
assert items["distances"][0][0] > 0 - EPS
assert items["distances"][0][0] < 1 + EPS
# ip
client.reset()
collection = client.create_collection(
name="test_index_params", metadata={"hnsw:space": "ip"}
)
collection.add(**records)
items = collection.query(
query_embeddings=[0.6, 1.12, 1.6],
n_results=1,
)
assert items["distances"][0][0] < -5
def test_invalid_index_params(client):
client.reset()
with pytest.raises(InvalidArgumentError):
collection = client.create_collection(
name="test_index_params", metadata={"hnsw:space": "foobar"}
)
collection.add(**records)
def test_persist_index_loading_params(client, request):
client = request.getfixturevalue("local_persist_api")
client.reset()
collection = client.create_collection(
"test",
metadata={"hnsw:space": "ip"},
)
collection.add(ids="id1", documents="hello")
api2 = request.getfixturevalue("local_persist_api_cache_bust")
collection = api2.get_collection(
"test",
)
assert collection.metadata["hnsw:space"] == "ip"
includes = ["embeddings", "documents", "metadatas", "distances"]
nn = collection.query(
query_texts="hello",
n_results=1,
include=includes,
)
for key in nn.keys():
if (key in includes) or (key == "ids"):
assert len(nn[key]) == 1
elif key == "included":
assert set(nn[key]) == set(includes)
else:
assert nn[key] is None
def test_add_large(client):
client.reset()
collection = client.create_collection("testspace")
# Test adding a large number of records
large_records = np.random.rand(2000, 512).astype(np.float32).tolist()
collection.add(
embeddings=large_records,
ids=[f"http://example.com/{i}" for i in range(len(large_records))],
)
assert collection.count() == len(large_records)
# test get_version
def test_get_version(client):
client.reset()
version = client.get_version()
# assert version matches the pattern x.y.z
import re
assert re.match(r"\d+\.\d+\.\d+", version)
# test delete_collection
def test_delete_collection(client):
client.reset()
collection = client.create_collection("test_delete_collection")
collection.add(**records)
assert len(client.list_collections()) == 1
client.delete_collection("test_delete_collection")
assert len(client.list_collections()) == 0
# test default embedding function
def test_default_embedding():
embedding_function = DefaultEmbeddingFunction()
docs = ["this is a test" for _ in range(64)]
embeddings = embedding_function(docs)
assert len(embeddings) == 64
def test_multiple_collections(client):
embeddings1 = np.random.rand(10, 512).astype(np.float32).tolist()
embeddings2 = np.random.rand(10, 512).astype(np.float32).tolist()
ids1 = [f"http://example.com/1/{i}" for i in range(len(embeddings1))]
ids2 = [f"http://example.com/2/{i}" for i in range(len(embeddings2))]
client.reset()
coll1 = client.create_collection("coll1")
coll1.add(embeddings=embeddings1, ids=ids1)
coll2 = client.create_collection("coll2")
coll2.add(embeddings=embeddings2, ids=ids2)
assert len(client.list_collections()) == 2
assert coll1.count() == len(embeddings1)
assert coll2.count() == len(embeddings2)
results1 = coll1.query(query_embeddings=embeddings1[0], n_results=1)
results2 = coll2.query(query_embeddings=embeddings2[0], n_results=1)
# progressively check the results are what we expect so we can debug when/if flakes happen
assert len(results1["ids"]) > 0
assert len(results2["ids"]) > 0
assert len(results1["ids"][0]) > 0
assert len(results2["ids"][0]) > 0
assert results1["ids"][0][0] == ids1[0]
assert results2["ids"][0][0] == ids2[0]
def test_update_query(client):
client.reset()
collection = client.create_collection("test_update_query")
collection.add(**records)
updated_records = {
"ids": [records["ids"][0]],
"embeddings": [[0.1, 0.2, 0.3]],
"documents": ["updated document"],
"metadatas": [{"foo": "bar"}],
}
collection.update(**updated_records)
# test query
results = collection.query(
query_embeddings=updated_records["embeddings"],
n_results=1,
include=["embeddings", "documents", "metadatas"],
)
assert len(results["ids"][0]) == 1
assert results["ids"][0][0] == updated_records["ids"][0]
assert results["documents"][0][0] == updated_records["documents"][0]
assert results["metadatas"][0][0]["foo"] == "bar"
assert vector_approx_equal(
results["embeddings"][0][0], updated_records["embeddings"][0]
)
def test_get_nearest_neighbors_where_n_results_more_than_element(client):
client.reset()
collection = client.create_collection("testspace")
collection.add(**records)
includes = ["embeddings", "documents", "metadatas", "distances"]
results = collection.query(
query_embeddings=[[1.1, 2.3, 3.2]],
n_results=5,
include=includes,
)
for key in results.keys():
if key in includes or key == "ids":
assert len(results[key][0]) == 2
elif key == "included":
assert set(results[key]) == set(includes)
else:
assert results[key] is None
def test_invalid_n_results_param(client):
client.reset()
collection = client.create_collection("testspace")
collection.add(**records)
with pytest.raises(TypeError) as exc:
collection.query(
query_embeddings=[[1.1, 2.3, 3.2]],
n_results=-1,
include=["embeddings", "documents", "metadatas", "distances"],
)
assert "Number of requested results -1, cannot be negative, or zero." in str(
exc.value
)
assert exc.type == TypeError
with pytest.raises(ValueError) as exc:
collection.query(
query_embeddings=[[1.1, 2.3, 3.2]],
n_results="one",
include=["embeddings", "documents", "metadatas", "distances"],
)
assert "int" in str(exc.value)
assert exc.type == ValueError
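# Seed data for the upsert tests below.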
initial_records = {
"embeddings": [[0, 0, 0], [1.2, 2.24, 3.2], [2.2, 3.24, 4.2]],
"ids": ["id1", "id2", "id3"],
"metadatas": [
{"int_value": 1, "string_value": "one", "float_value": 1.001},
{"int_value": 2},
{"string_value": "three"},
],
"documents": [
"this document is first",
"this document is second",
"this document is third",
],
}
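# Overlaps initial_records on id1 (exercising the update path) and adds a new id4.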
new_records = {
"embeddings": [[3.0, 3.0, 1.1], [3.2, 4.24, 5.2]],
"ids": ["id1", "id4"],
"metadatas": [
{"int_value": 1, "string_value": "one_of_one", "float_value": 1.001},
{"int_value": 4},
],
"documents": [
"this document is even more first",
"this document is new and fourth",
],
}
def test_upsert(client):
client.reset()
collection = client.create_collection("test")
collection.add(**initial_records)
assert collection.count() == 3
collection.upsert(**new_records)
assert collection.count() == 4
get_result = collection.get(
include=["embeddings", "metadatas", "documents"], ids=new_records["ids"][0]
)
assert vector_approx_equal(
get_result["embeddings"][0], new_records["embeddings"][0]
)
assert get_result["metadatas"][0] == new_records["metadatas"][0]
assert get_result["documents"][0] == new_records["documents"][0]
query_result = collection.query(
query_embeddings=get_result["embeddings"],
n_results=1,
include=["embeddings", "metadatas", "documents"],
)
assert vector_approx_equal(
query_result["embeddings"][0][0], new_records["embeddings"][0]
)
assert query_result["metadatas"][0][0] == new_records["metadatas"][0]
assert query_result["documents"][0][0] == new_records["documents"][0]
collection.delete(ids=initial_records["ids"][2])
collection.upsert(
ids=initial_records["ids"][2],
embeddings=[[1.1, 0.99, 2.21]],
metadatas=[{"string_value": "a new string value"}],
)
assert collection.count() == 4
get_result = collection.get(
include=["embeddings", "metadatas", "documents"], ids=["id3"]
)
assert vector_approx_equal(get_result["embeddings"][0], [1.1, 0.99, 2.21])
assert get_result["metadatas"][0] == {"string_value": "a new string value"}
assert get_result["documents"][0] is None
def test_collection_upsert_with_invalid_collection_throws(client):
client.reset()
collection = client.create_collection("test")
client.delete_collection("test")
with pytest.raises(NotFoundError, match=r"Collection .* does not exist"):
collection.upsert(**initial_records)
# test to make sure add, query, update, upsert error on invalid embeddings input
def test_invalid_embeddings(client):
client.reset()
collection = client.create_collection("test_invalid_embeddings")
# Add with string embeddings
invalid_records = {
"embeddings": [["0", "0", "0"], ["1.2", "2.24", "3.2"]],
"ids": ["id1", "id2"],
}
with pytest.raises(ValueError) as e:
collection.add(**invalid_records)
assert "embedding" in str(e.value)
# Query with invalid embeddings
with pytest.raises(ValueError) as e:
collection.query(
query_embeddings=[["1.1", "2.3", "3.2"]],
n_results=1,
)
assert "embedding" in str(e.value)
# Update with invalid embeddings
invalid_records = {
"embeddings": [[[0], [0], [0]], [[1.2], [2.24], [3.2]]],
"ids": ["id1", "id2"],
}
with pytest.raises(ValueError) as e:
collection.update(**invalid_records)
assert "embedding" in str(e.value)
# Upsert with invalid embeddings
invalid_records = {
"embeddings": [[[1.1, 2.3, 3.2]], [[1.2, 2.24, 3.2]]],
"ids": ["id1", "id2"],
}
with pytest.raises(ValueError) as e:
collection.upsert(**invalid_records)
assert "embedding" in str(e.value)
# test to make sure update shows exception for bad dimensionality
def test_dimensionality_exception_update(client):
client.reset()
collection = client.create_collection("test_dimensionality_update_exception")
collection.add(**minimal_records)
with pytest.raises(Exception) as e:
collection.update(**bad_dimensionality_records)
assert "dimension" in str(e.value)
# test to make sure upsert shows exception for bad dimensionality
def test_dimensionality_exception_upsert(client):
client.reset()
collection = client.create_collection("test_dimensionality_upsert_exception")
collection.add(**minimal_records)
with pytest.raises(Exception) as e:
collection.upsert(**bad_dimensionality_records)
assert "dimension" in str(e.value)
# this may be flaky on windows, so we rerun it
@pytest.mark.flaky(reruns=3, condition=sys.platform.startswith("win32"))
def test_ssl_self_signed(client_ssl):
if os.environ.get("CHROMA_INTEGRATION_TEST_ONLY"):
pytest.skip("Skipping test for integration test")
client_ssl.heartbeat()
# this may be flaky on windows, so we rerun it
@pytest.mark.flaky(reruns=3, condition=sys.platform.startswith("win32"))
def test_ssl_self_signed_without_ssl_verify(client_ssl):
if os.environ.get("CHROMA_INTEGRATION_TEST_ONLY"):
pytest.skip("Skipping test for integration test")
client_ssl.heartbeat()
_port = client_ssl._server._settings.chroma_server_http_port
with pytest.raises(ValueError) as e:
chromadb.HttpClient(ssl=True, port=_port)
stack_trace = traceback.format_exception(
type(e.value), e.value, e.value.__traceback__
)
client_ssl.clear_system_cache()
assert "CERTIFICATE_VERIFY_FAILED" in "".join(stack_trace)
def test_query_id_filtering_small_dataset(client):
client.reset()
collection = client.create_collection("test_query_id_filtering_small")
num_vectors = 100
dim = 512
    small_records = np.random.rand(num_vectors, dim).astype(np.float32).tolist()
ids = [f"{i}" for i in range(num_vectors)]
collection.add(
embeddings=small_records,
ids=ids,
)
query_ids = [f"{i}" for i in range(0, num_vectors, 10)]
query_embedding = np.random.rand(dim).astype(np.float32).tolist()
results = collection.query(
query_embeddings=query_embedding,
ids=query_ids,
n_results=num_vectors,
include=[],
)
all_returned_ids = [item for sublist in results["ids"] for item in sublist]
assert all(id in query_ids for id in all_returned_ids)
def test_query_id_filtering_medium_dataset(client):
client.reset()
collection = client.create_collection("test_query_id_filtering_medium")
num_vectors = 1000
dim = 512
medium_records = np.random.rand(num_vectors, dim).astype(np.float32).tolist()
ids = [f"{i}" for i in range(num_vectors)]
collection.add(
embeddings=medium_records,
ids=ids,
)
query_ids = [f"{i}" for i in range(0, num_vectors, 10)]
query_embedding = np.random.rand(dim).astype(np.float32).tolist()
results = collection.query(
query_embeddings=query_embedding,
ids=query_ids,
n_results=num_vectors,
include=[],
)
all_returned_ids = [item for sublist in results["ids"] for item in sublist]
assert all(id in query_ids for id in all_returned_ids)
multi_query_embeddings = [
np.random.rand(dim).astype(np.float32).tolist() for _ in range(3)
]
multi_results = collection.query(
query_embeddings=multi_query_embeddings,
ids=query_ids,
n_results=10,
include=[],
)
for result_set in multi_results["ids"]:
assert all(id in query_ids for id in result_set)
def test_query_id_filtering_e2e(client):
client.reset()
collection = client.create_collection("test_query_id_filtering_e2e")
dim = 512
num_vectors = 100
embeddings = np.random.rand(num_vectors, dim).astype(np.float32).tolist()
ids = [f"{i}" for i in range(num_vectors)]
metadatas = [{"index": i} for i in range(num_vectors)]
collection.add(
embeddings=embeddings,
ids=ids,
metadatas=metadatas,
)
ids_to_delete = [f"{i}" for i in range(10, 30)]
collection.delete(ids=ids_to_delete)
# modify some existing ids, and add some new ones to check query returns updated metadata
ids_to_upsert_existing = [f"{i}" for i in range(30, 50)]
new_num_vectors = num_vectors + 20
ids_to_upsert_new = [f"{i}" for i in range(num_vectors, new_num_vectors)]
upsert_embeddings = (
np.random.rand(len(ids_to_upsert_existing) + len(ids_to_upsert_new), dim)
.astype(np.float32)
.tolist()
)
upsert_metadatas = [
{"index": i, "upserted": True} for i in range(len(upsert_embeddings))
]
collection.upsert(
embeddings=upsert_embeddings,
ids=ids_to_upsert_existing + ids_to_upsert_new,
metadatas=upsert_metadatas,
)
valid_query_ids = (
[f"{i}" for i in range(5, 10)] # subset of existing ids
+ [f"{i}" for i in range(35, 45)] # subset of existing, but upserted
+ [
f"{i}" for i in range(num_vectors + 5, num_vectors + 15)
] # subset of new upserted ids
)
includes = ["metadatas"]
query_embedding = np.random.rand(dim).astype(np.float32).tolist()
results = collection.query(
query_embeddings=query_embedding,
ids=valid_query_ids,
n_results=new_num_vectors,
include=includes,
)
all_returned_ids = [item for sublist in results["ids"] for item in sublist]
assert all(id in valid_query_ids for id in all_returned_ids)
for result_index, id_list in enumerate(results["ids"]):
for item_index, item_id in enumerate(id_list):
if item_id in ids_to_upsert_existing or item_id in ids_to_upsert_new:
# checks if metadata correctly has upserted flag
assert results["metadatas"][result_index][item_index]["upserted"]
upserted_id = ids_to_upsert_existing[0]
# test single id filtering
results = collection.query(
query_embeddings=query_embedding,
ids=upserted_id,
n_results=1,
include=includes,
)
assert results["metadatas"][0][0]["upserted"]
deleted_id = ids_to_delete[0]
# test deleted id filter raises
with pytest.raises(Exception) as error:
collection.query(
query_embeddings=query_embedding,
ids=deleted_id,
n_results=1,
include=includes,
)
assert "Error finding id" in str(error.value)
def test_validate_sparse_vector():
"""Test SparseVector validation in __post_init__."""
from chromadb.base_types import SparseVector
# Test 1: Valid sparse vector - should not raise
SparseVector(indices=[0, 2, 5], values=[0.1, 0.5, 0.9])
# Test 2: Valid sparse vector with empty lists - should not raise
SparseVector(indices=[], values=[])
# Test 4: Invalid - indices not a list
with pytest.raises(ValueError, match="Expected SparseVector indices to be a list"):
SparseVector(indices="not_a_list", values=[0.1, 0.2]) # type: ignore
# Test 5: Invalid - values not a list
with pytest.raises(ValueError, match="Expected SparseVector values to be a list"):
SparseVector(indices=[0, 1], values="not_a_list") # type: ignore
# Test 6: Invalid - mismatched lengths
with pytest.raises(
ValueError, match="indices and values must have the same length"
):
SparseVector(indices=[0, 1, 2], values=[0.1, 0.2])
# Test 7: Invalid - non-integer index
with pytest.raises(ValueError, match="SparseVector indices must be integers"):
SparseVector(indices=[0, "not_int", 2], values=[0.1, 0.2, 0.3]) # type: ignore
# Test 8: Invalid - negative index
with pytest.raises(ValueError, match="SparseVector indices must be non-negative"):
SparseVector(indices=[0, -1, 2], values=[0.1, 0.2, 0.3])
# Test 9: Invalid - non-numeric value
with pytest.raises(ValueError, match="SparseVector values must be numbers"):
SparseVector(indices=[0, 1, 2], values=[0.1, "not_number", 0.3]) # type: ignore
# Test 10: Invalid - float indices (not integers)
with pytest.raises(ValueError, match="SparseVector indices must be integers"):
SparseVector(indices=[0.0, 1.0, 2.0], values=[0.1, 0.2, 0.3]) # type: ignore
# Test 11: Valid - integer values (not just floats)
SparseVector(indices=[0, 1, 2], values=[1, 2, 3])
# Test 12: Valid - mixed int and float values
SparseVector(indices=[0, 1, 2], values=[1, 2.5, 3])
# Test 13: Valid - large indices
SparseVector(indices=[100, 1000, 10000], values=[0.1, 0.2, 0.3])
# Test 14: Invalid - None as value
with pytest.raises(ValueError, match="SparseVector values must be numbers"):
SparseVector(indices=[0, 1], values=[0.1, None]) # type: ignore
# Test 15: Invalid - None as index
with pytest.raises(ValueError, match="SparseVector indices must be integers"):
SparseVector(indices=[0, None], values=[0.1, 0.2]) # type: ignore
# Test 16: Valid - single element
SparseVector(indices=[42], values=[3.14])
# Test 17: Boolean values are actually valid (bool is subclass of int in Python)
SparseVector(indices=[0, 1], values=[True, False]) # True=1, False=0
# Test 18: Invalid - unsorted indices
with pytest.raises(
ValueError, match="indices must be sorted in strictly ascending order"
):
SparseVector(indices=[0, 2, 1], values=[0.1, 0.2, 0.3])
# Test 19: Invalid - duplicate indices (not strictly ascending)
with pytest.raises(
ValueError, match="indices must be sorted in strictly ascending order"
):
SparseVector(indices=[0, 1, 1, 2], values=[0.1, 0.2, 0.3, 0.4])
# Test 20: Invalid - descending order
with pytest.raises(
ValueError, match="indices must be sorted in strictly ascending order"
):
SparseVector(indices=[5, 3, 1], values=[0.5, 0.3, 0.1])
def test_sparse_vector_in_metadata_validation():
"""Test that sparse vectors are properly validated in metadata."""
from chromadb.api.types import validate_metadata
from chromadb.base_types import SparseVector
# Test 1: Valid metadata with sparse vectors
sparse_vector_1 = SparseVector(indices=[0, 2, 5], values=[0.1, 0.5, 0.9])
sparse_vector_2 = SparseVector(indices=[1, 3, 4], values=[0.2, 0.4, 0.6])
metadata_1 = {
"text": "document 1",
"sparse_embedding": sparse_vector_1,
"score": 0.5,
}
metadata_2 = {
"text": "document 2",
"sparse_embedding": sparse_vector_2,
"score": 0.8,
}
validate_metadata(metadata_1)
validate_metadata(metadata_2)
# Test 2: Valid metadata with empty sparse vector
metadata_empty = {
"text": "empty sparse",
"sparse_vec": SparseVector(indices=[], values=[]),
}
validate_metadata(metadata_empty)
# Test 3: Invalid sparse vector in metadata (construction fails)
with pytest.raises(
ValueError, match="indices and values must have the same length"
):
invalid_metadata = {
"text": "invalid",
"sparse_embedding": SparseVector(indices=[0, 1], values=[0.1]),
}
# Test 4: Invalid dict in metadata (not a SparseVector dataclass)
invalid_metadata_2 = {
"text": "missing indices",
"sparse_embedding": {"values": [0.1, 0.2]},
}
with pytest.raises(
ValueError,
match="Expected metadata value to be a str, int, float, bool, SparseVector, or None",
):
validate_metadata(invalid_metadata_2)
# Test 5: Invalid sparse vector - negative index (construction fails)
with pytest.raises(ValueError, match="SparseVector indices must be non-negative"):
invalid_metadata_3 = {
"text": "negative index",
"sparse_embedding": SparseVector(
indices=[0, -1, 2], values=[0.1, 0.2, 0.3]
),
}
# Test 6: Invalid sparse vector - non-numeric value (construction fails)
with pytest.raises(ValueError, match="SparseVector values must be numbers"):
invalid_metadata_4 = {
"text": "non-numeric value",
"sparse_embedding": SparseVector(
indices=[0, 1], values=[0.1, "not_a_number"]
), # type: ignore
}
# Test 7: Multiple sparse vectors in metadata
metadata_multiple = {
"text": "multiple sparse vectors",
"sparse_1": SparseVector(indices=[0, 1], values=[0.1, 0.2]),
"sparse_2": SparseVector(indices=[2, 3, 4], values=[0.3, 0.4, 0.5]),
"regular_field": 42,
}
validate_metadata(metadata_multiple)
# Test 8: Regular dict (not SparseVector) should be rejected
metadata_nested = {
"config": "some_config",
"sparse_vector": {"indices": [0, 1, 2], "values": [1.0, 2.0, 3.0]},
}
with pytest.raises(
ValueError,
match="Expected metadata value to be a str, int, float, bool, SparseVector, or None",
):
validate_metadata(metadata_nested)
# Test 9: Large sparse vector
large_sparse = SparseVector(
indices=list(range(1000)),
values=[float(i) * 0.001 for i in range(1000)],
)
metadata_large = {"text": "large sparse", "large_sparse_vec": large_sparse}
validate_metadata(metadata_large)
def test_sparse_vector_dict_format_normalization():
"""Test that dict-format sparse vectors are normalized to SparseVector instances."""
from chromadb.api.types import normalize_metadata, validate_metadata
from chromadb.base_types import SparseVector
# Test 1: Dict format with #type='sparse_vector' should be converted
metadata_dict_format = {
"text": "test document",
"sparse": {
TYPE_KEY: SPARSE_VECTOR_TYPE_VALUE,
"indices": [0, 2, 5],
"values": [1.0, 2.0, 3.0],
},
}
normalized = normalize_metadata(metadata_dict_format)
assert isinstance(normalized["sparse"], SparseVector)
assert normalized["sparse"].indices == [0, 2, 5]
assert normalized["sparse"].values == [1.0, 2.0, 3.0]
# Should pass validation after normalization
validate_metadata(normalized)
# Test 2: SparseVector instance should pass through unchanged
sparse_instance = SparseVector(indices=[1, 3, 4], values=[0.5, 1.5, 2.5])
metadata_instance_format = {
"text": "test document",
"sparse": sparse_instance,
}
normalized2 = normalize_metadata(metadata_instance_format)
assert normalized2["sparse"] is sparse_instance # Same object
validate_metadata(normalized2)
# Test 3: Dict format with unsorted indices should be rejected during normalization
metadata_unsorted = {
"text": "unsorted",
"sparse": {
TYPE_KEY: SPARSE_VECTOR_TYPE_VALUE,
"indices": [5, 0, 2],
"values": [3.0, 1.0, 2.0],
},
}
with pytest.raises(
ValueError, match="indices must be sorted in strictly ascending order"
):
normalize_metadata(metadata_unsorted)
# Test 4: Dict format with duplicate indices should be rejected
metadata_duplicates = {
"text": "duplicates",
"sparse": {
TYPE_KEY: SPARSE_VECTOR_TYPE_VALUE,
"indices": [0, 2, 2],
"values": [1.0, 2.0, 3.0],
},
}
with pytest.raises(
ValueError, match="indices must be sorted in strictly ascending order"
):
normalize_metadata(metadata_duplicates)
# Test 5: Dict format with negative indices should be rejected
metadata_negative = {
"text": "negative",
"sparse": {
TYPE_KEY: SPARSE_VECTOR_TYPE_VALUE,
"indices": [-1, 0, 2],
"values": [1.0, 2.0, 3.0],
},
}
with pytest.raises(ValueError, match="indices must be non-negative"):
normalize_metadata(metadata_negative)
# Test 6: Dict format with length mismatch should be rejected
metadata_mismatch = {
"text": "mismatch",
"sparse": {
TYPE_KEY: SPARSE_VECTOR_TYPE_VALUE,
"indices": [0, 2],
"values": [1.0, 2.0, 3.0],
},
}
with pytest.raises(
ValueError, match="indices and values must have the same length"
):
normalize_metadata(metadata_mismatch)
# Test 7: Regular dict without #type should not be converted
metadata_regular_dict = {
"text": "regular",
"config": {"key": "value"},
}
normalized3 = normalize_metadata(metadata_regular_dict)
assert isinstance(normalized3["config"], dict)
assert normalized3["config"]["key"] == "value"
# Test 8: Empty sparse vector in dict format
metadata_empty = {
"text": "empty",
"sparse": {TYPE_KEY: SPARSE_VECTOR_TYPE_VALUE, "indices": [], "values": []},
}
normalized4 = normalize_metadata(metadata_empty)
assert isinstance(normalized4["sparse"], SparseVector)
assert normalized4["sparse"].indices == []
assert normalized4["sparse"].values == []
# Test 9: Multiple sparse vectors in dict format
metadata_multiple = {
"sparse1": {
TYPE_KEY: SPARSE_VECTOR_TYPE_VALUE,
"indices": [0, 1],
"values": [1.0, 2.0],
},
"sparse2": {
TYPE_KEY: SPARSE_VECTOR_TYPE_VALUE,
"indices": [2, 3],
"values": [3.0, 4.0],
},
"regular": 42,
}
normalized5 = normalize_metadata(metadata_multiple)
assert isinstance(normalized5["sparse1"], SparseVector)
assert isinstance(normalized5["sparse2"], SparseVector)
assert normalized5["regular"] == 42
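# The "strictly ascending" rule exercised in Tests 3-4 above amounts to the
# pairwise check sketched here (`strictly_ascending` is a local illustrative
# helper, not a chromadb API).
def test_sparse_indices_ordering_sketch():
    """Illustrative only: duplicates and unsorted indices fail the same check."""
    def strictly_ascending(indices):
        return all(a < b for a, b in zip(indices, indices[1:]))
    assert strictly_ascending([0, 2, 5])
    assert not strictly_ascending([0, 2, 2])  # duplicate index
    assert not strictly_ascending([5, 0, 2])  # unsorted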
def test_sparse_vector_dict_format_in_record_set():
"""Test that dict-format sparse vectors work in normalize_insert_record_set."""
from chromadb.api.types import (
normalize_insert_record_set,
validate_insert_record_set,
)
from chromadb.base_types import SparseVector
# Test 1: Mix of dict format and SparseVector instances
record_set = normalize_insert_record_set(
ids=["doc1", "doc2", "doc3"],
embeddings=None,
metadatas=[
{
"text": "test1",
"sparse": {
TYPE_KEY: SPARSE_VECTOR_TYPE_VALUE,
"indices": [0, 2],
"values": [1.0, 2.0],
},
},
{
"text": "test2",
"sparse": SparseVector(indices=[1, 3], values=[1.5, 2.5]),
},
{"text": "test3"}, # No sparse vector
],
documents=["doc one", "doc two", "doc three"],
)
    # Both should now be SparseVector instances (the dict converted, the
    # instance passed through)
assert isinstance(record_set["metadatas"][0]["sparse"], SparseVector)
assert isinstance(record_set["metadatas"][1]["sparse"], SparseVector)
assert "sparse" not in record_set["metadatas"][2]
# Validation should pass
validate_insert_record_set(record_set)
# Test 2: Verify values are correct after normalization
assert record_set["metadatas"][0]["sparse"].indices == [0, 2]
assert record_set["metadatas"][0]["sparse"].values == [1.0, 2.0]
assert record_set["metadatas"][1]["sparse"].indices == [1, 3]
assert record_set["metadatas"][1]["sparse"].values == [1.5, 2.5]
def test_search_result_rows() -> None:
"""Test the SearchResult.rows() method for converting column-major to row-major format."""
from chromadb.api.types import SearchResult
# Test 1: Basic single payload with all fields
result = SearchResult(
{
"ids": [["id1", "id2", "id3"]],
"documents": [["doc1", "doc2", "doc3"]],
"embeddings": [[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]],
"metadatas": [[{"key": "a"}, {"key": "b"}, {"key": "c"}]],
"scores": [[0.9, 0.8, 0.7]],
"select": [["document", "score", "metadata"]],
}
)
rows = result.rows()
assert len(rows) == 1 # One payload
assert len(rows[0]) == 3 # Three results
# Check first row
assert rows[0][0]["id"] == "id1"
assert rows[0][0]["document"] == "doc1"
assert rows[0][0]["embedding"] == [1.0, 2.0]
assert rows[0][0]["metadata"] == {"key": "a"}
assert rows[0][0]["score"] == 0.9
# Check all rows have all fields
for row in rows[0]:
assert "id" in row
assert "document" in row
assert "embedding" in row
assert "metadata" in row
assert "score" in row
# Test 2: Multiple payloads
result = SearchResult(
{
"ids": [["a1", "a2"], ["b1", "b2", "b3"]],
"documents": [["doc_a1", "doc_a2"], ["doc_b1", "doc_b2", "doc_b3"]],
"embeddings": [
None,
[[1.0], [2.0], [3.0]],
], # First payload has no embeddings
"metadatas": [[{"x": 1}, {"x": 2}], None], # Second payload has no metadata
"scores": [[0.5, 0.4], [0.9, 0.8, 0.7]],
"select": [["document", "score"], ["embedding", "score"]],
}
)
rows = result.rows()
assert len(rows) == 2 # Two payloads
assert len(rows[0]) == 2 # First payload has 2 results
assert len(rows[1]) == 3 # Second payload has 3 results
# First payload - has docs, metadata, scores but no embeddings
assert rows[0][0] == {
"id": "a1",
"document": "doc_a1",
"metadata": {"x": 1},
"score": 0.5,
}
assert rows[0][1] == {
"id": "a2",
"document": "doc_a2",
"metadata": {"x": 2},
"score": 0.4,
}
# Second payload - has docs, embeddings, scores but no metadata
assert rows[1][0] == {
"id": "b1",
"document": "doc_b1",
"embedding": [1.0],
"score": 0.9,
}
assert rows[1][1] == {
"id": "b2",
"document": "doc_b2",
"embedding": [2.0],
"score": 0.8,
}
assert rows[1][2] == {
"id": "b3",
"document": "doc_b3",
"embedding": [3.0],
"score": 0.7,
}
# Test 3: Empty result
result = SearchResult(
{
"ids": [],
"documents": [],
"embeddings": [],
"metadatas": [],
"scores": [],
"select": [],
}
)
rows = result.rows()
assert rows == []
# Test 4: Sparse data with None values in lists
result = SearchResult(
{
"ids": [["id1", "id2", "id3"]],
"documents": [[None, "doc2", None]], # Sparse documents
"embeddings": None, # No embeddings at all
"metadatas": [[{"a": 1}, None, {"c": 3}]], # Sparse metadata
"scores": [[0.9, None, 0.7]], # Sparse scores
"select": [["document", "metadata", "score"]],
}
)
rows = result.rows()
assert len(rows) == 1
assert len(rows[0]) == 3
# First row - only has metadata and score
assert rows[0][0] == {"id": "id1", "metadata": {"a": 1}, "score": 0.9}
# Second row - only has document
assert rows[0][1] == {"id": "id2", "document": "doc2"}
# Third row - has metadata and score
assert rows[0][2] == {"id": "id3", "metadata": {"c": 3}, "score": 0.7}
# Test 5: Only IDs (minimal result)
result = SearchResult(
{
"ids": [["id1", "id2"]],
"documents": None,
"embeddings": None,
"metadatas": None,
"scores": None,
"select": [[]],
}
)
rows = result.rows()
assert len(rows) == 1
assert len(rows[0]) == 2
assert rows[0][0] == {"id": "id1"}
assert rows[0][1] == {"id": "id2"}
# Test 6: SearchResult works as dict (backward compatibility)
result = SearchResult(
{
"ids": [["test"]],
"documents": [["test doc"]],
"metadatas": [[{"test": True}]],
"embeddings": [[[0.1, 0.2]]],
"scores": [[0.99]],
"select": [["all"]],
}
)
# Should work as dict
assert result["ids"] == [["test"]]
assert result.get("documents") == [["test doc"]]
assert "metadatas" in result
assert len(result) == 6 # Should have 6 keys
# Should also have rows() method
rows = result.rows()
assert len(rows[0]) == 1
assert rows[0][0]["id"] == "test"
print("All SearchResult.rows() tests passed!")
def test_rrf_to_dict() -> None:
"""Test the Rrf (Reciprocal Rank Fusion) to_dict conversion."""
# Note: In these tests, "sparse_embedding" is just an example metadata field name.
# Users can store any data in metadata fields and reference them by name (without # prefix).
# The "#embedding" key refers to the special main embedding field.
from chromadb.execution.expression.operator import Rrf, Knn, Val
# Test 1: Basic RRF with two KNN rankings (equal weight)
rrf = Rrf(
[
Knn(query=[0.1, 0.2], return_rank=True),
Knn(query=[0.3, 0.4], key="sparse_embedding", return_rank=True),
]
)
result = rrf.to_dict()
# RRF formula: -sum(weight_i / (k + rank_i))
# With default k=60 and equal weights (1.0 each)
# Expected: -(1.0/(60 + knn1) + 1.0/(60 + knn2))
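    # Worked example with hypothetical 1-based ranks: if the two KNN clauses
    # rank a record 1 and 2 respectively, the score evaluates to
    # -(1/(60+1) + 1/(60+2)) = -(1/61 + 1/62) ~= -0.0325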
expected = {
"$mul": [
{"$val": -1},
{
"$sum": [
{
"$div": {
"left": {"$val": 1.0},
"right": {
"$sum": [
{"$val": 60},
{
"$knn": {
"query": [0.1, 0.2],
"key": "#embedding",
"limit": 16,
"return_rank": True,
}
},
]
},
}
},
{
"$div": {
"left": {"$val": 1.0},
"right": {
"$sum": [
{"$val": 60},
{
"$knn": {
"query": [0.3, 0.4],
"key": "sparse_embedding",
"limit": 16,
"return_rank": True,
}
},
]
},
}
},
]
},
]
}
assert result == expected
# Test 2: RRF with custom weights and k
rrf_weighted = Rrf(
ranks=[
Knn(query=[0.1, 0.2], return_rank=True),
Knn(query=[0.3, 0.4], key="sparse_embedding", return_rank=True),
],
weights=[2.0, 1.0], # Dense is 2x more important
k=100,
)
result_weighted = rrf_weighted.to_dict()
# Expected: -(2.0/(100 + knn1) + 1.0/(100 + knn2))
expected_weighted = {
"$mul": [
{"$val": -1},
{
"$sum": [
{
"$div": {
"left": {"$val": 2.0},
"right": {
"$sum": [
{"$val": 100},
{
"$knn": {
"query": [0.1, 0.2],
"key": "#embedding",
"limit": 16,
"return_rank": True,
}
},
]
},
}
},
{
"$div": {
"left": {"$val": 1.0},
"right": {
"$sum": [
{"$val": 100},
{
"$knn": {
"query": [0.3, 0.4],
"key": "sparse_embedding",
"limit": 16,
"return_rank": True,
}
},
]
},
}
},
]
},
]
}
assert result_weighted == expected_weighted
# Test 3: RRF with three rankings
rrf_three = Rrf(
[
Knn(query=[0.1, 0.2], return_rank=True),
Knn(query=[0.3, 0.4], key="sparse_embedding", return_rank=True),
Val(5.0), # Can also include constant rank
]
)
result_three = rrf_three.to_dict()
# Verify it has three terms in the sum
assert "$mul" in result_three
assert "$sum" in result_three["$mul"][1]
terms = result_three["$mul"][1]["$sum"]
assert len(terms) == 3 # Three ranking strategies
# Test 4: Error case - mismatched weights
with pytest.raises(
ValueError, match="Number of weights .* must match number of ranks"
):
rrf_bad = Rrf(
ranks=[
Knn(query=[0.1, 0.2], return_rank=True),
Knn(query=[0.3, 0.4], return_rank=True),
],
weights=[1.0], # Only one weight for two ranks
)
rrf_bad.to_dict()
# Test 5: Error case - negative weights
with pytest.raises(ValueError, match="All weights must be non-negative"):
rrf_negative = Rrf(
ranks=[
Knn(query=[0.1, 0.2], return_rank=True),
Knn(query=[0.3, 0.4], return_rank=True),
],
weights=[1.0, -1.0], # Negative weight
)
rrf_negative.to_dict()
# Test 6: Error case - empty ranks list
with pytest.raises(ValueError, match="RRF requires at least one rank"):
rrf_empty = Rrf([])
rrf_empty.to_dict() # Validation happens in to_dict()
# Test 7: Error case - negative k value
with pytest.raises(ValueError, match="k must be positive"):
rrf_neg_k = Rrf([Val(1.0)], k=-5)
rrf_neg_k.to_dict() # Validation happens in to_dict()
# Test 8: Error case - zero k value
with pytest.raises(ValueError, match="k must be positive"):
rrf_zero_k = Rrf([Val(1.0)], k=0)
rrf_zero_k.to_dict() # Validation happens in to_dict()
# Test 9: Normalize flag with weights
rrf_normalized = Rrf(
ranks=[
Knn(query=[0.1, 0.2], return_rank=True),
Knn(query=[0.3, 0.4], key="sparse_embedding", return_rank=True),
],
weights=[3.0, 1.0], # Will be normalized to [0.75, 0.25]
normalize=True,
k=100,
)
result_normalized = rrf_normalized.to_dict()
# Expected: -(0.75/(100 + knn1) + 0.25/(100 + knn2))
expected_normalized = {
"$mul": [
{"$val": -1},
{
"$sum": [
{
"$div": {
"left": {"$val": 0.75},
"right": {
"$sum": [
{"$val": 100},
{
"$knn": {
"query": [0.1, 0.2],
"key": "#embedding",
"limit": 16,
"return_rank": True,
}
},
]
},
}
},
{
"$div": {
"left": {"$val": 0.25},
"right": {
"$sum": [
{"$val": 100},
{
"$knn": {
"query": [0.3, 0.4],
"key": "sparse_embedding",
"limit": 16,
"return_rank": True,
}
},
]
},
}
},
]
},
]
}
assert result_normalized == expected_normalized
# Test 10: Normalize flag without weights (should work with defaults)
rrf_normalize_defaults = Rrf(
ranks=[
Knn(query=[0.1, 0.2], return_rank=True),
Knn(query=[0.3, 0.4], return_rank=True),
],
normalize=True, # Will normalize [1.0, 1.0] to [0.5, 0.5]
)
result_defaults = rrf_normalize_defaults.to_dict()
# Both weights should be 0.5 after normalization
expected_defaults = {
"$mul": [
{"$val": -1},
{
"$sum": [
{
"$div": {
"left": {"$val": 0.5},
"right": {
"$sum": [
{"$val": 60}, # Default k=60
{
"$knn": {
"query": [0.1, 0.2],
"key": "#embedding",
"limit": 16,
"return_rank": True,
}
},
]
},
}
},
{
"$div": {
"left": {"$val": 0.5},
"right": {
"$sum": [
{"$val": 60},
{
"$knn": {
"query": [0.3, 0.4],
"key": "#embedding",
"limit": 16,
"return_rank": True,
}
},
]
},
}
},
]
},
]
}
assert result_defaults == expected_defaults
# Test 11: Error case - normalize with all zero weights
with pytest.raises(ValueError, match="Sum of weights must be positive"):
rrf_zero_weights = Rrf(
ranks=[
Knn(query=[0.1, 0.2], return_rank=True),
Knn(query=[0.3, 0.4], return_rank=True),
],
weights=[0.0, 0.0],
normalize=True,
)
rrf_zero_weights.to_dict()
print("All RRF tests passed!")
def test_group_by_serialization() -> None:
"""Test GroupBy, MinK, and MaxK serialization and deserialization."""
from chromadb.execution.expression.operator import (
GroupBy,
MinK,
MaxK,
Key,
Aggregate,
)
# to_dict with OneOrMany keys
group_by = GroupBy(keys=Key("category"), aggregate=MinK(keys=Key.SCORE, k=3))
assert group_by.to_dict() == {
"keys": ["category"],
"aggregate": {"$min_k": {"keys": ["#score"], "k": 3}},
}
# to_dict with multiple keys and MaxK
group_by = GroupBy(
keys=[Key("year"), Key("category")],
aggregate=MaxK(keys=[Key.SCORE, Key("priority")], k=5),
)
assert group_by.to_dict() == {
"keys": ["year", "category"],
"aggregate": {"$max_k": {"keys": ["#score", "priority"], "k": 5}},
}
# Round-trip
original = GroupBy(keys=[Key("category")], aggregate=MinK(keys=[Key.SCORE], k=3))
assert GroupBy.from_dict(original.to_dict()).to_dict() == original.to_dict()
# Empty GroupBy serializes to {} and from_dict({}) returns default GroupBy
empty_group_by = GroupBy()
assert empty_group_by.to_dict() == {}
assert GroupBy.from_dict({}).to_dict() == {}
# Error cases
with pytest.raises(ValueError, match="requires 'keys' field"):
GroupBy.from_dict({"aggregate": {"$min_k": {"keys": ["#score"], "k": 3}}})
with pytest.raises(ValueError, match="requires 'aggregate' field"):
GroupBy.from_dict({"keys": ["category"]})
with pytest.raises(ValueError, match="keys cannot be empty"):
GroupBy.from_dict(
{"keys": [], "aggregate": {"$min_k": {"keys": ["#score"], "k": 3}}}
)
with pytest.raises(ValueError, match="Unknown aggregate operator"):
Aggregate.from_dict({"$unknown": {"keys": ["#score"], "k": 3}})
# Expression API Tests - Testing dict support and from_dict methods
class TestSearchDictSupport:
"""Test Search class dict input support."""
def test_search_with_dict_where(self):
"""Test Search accepts dict for where parameter."""
from chromadb.execution.expression.plan import Search
from chromadb.execution.expression.operator import Where
# Simple equality
search = Search(where={"status": "active"})
assert search._where is not None
assert isinstance(search._where, Where)
# Complex where with operators
search = Search(where={"$and": [{"status": "active"}, {"score": {"$gt": 0.5}}]})
assert search._where is not None
def test_search_with_dict_rank(self):
"""Test Search accepts dict for rank parameter."""
from chromadb.execution.expression.plan import Search
from chromadb.execution.expression.operator import Rank
# KNN ranking
search = Search(rank={"$knn": {"query": [0.1, 0.2]}})
assert search._rank is not None
assert isinstance(search._rank, Rank)
# Val ranking
search = Search(rank={"$val": 0.5})
assert search._rank is not None
def test_search_with_dict_limit(self):
"""Test Search accepts dict and int for limit parameter."""
from chromadb.execution.expression.plan import Search
# Dict limit
search = Search(limit={"limit": 10, "offset": 5})
assert search._limit.limit == 10
assert search._limit.offset == 5
# Int limit (creates Limit with offset=0)
search = Search(limit=10)
assert search._limit.limit == 10
assert search._limit.offset == 0
def test_search_with_dict_select(self):
"""Test Search accepts dict, list, and set for select parameter."""
from chromadb.execution.expression.plan import Search
# Dict select
search = Search(select={"keys": ["#document", "#score"]})
assert search._select is not None
# List select
search = Search(select=["#document", "#metadata"])
assert search._select is not None
# Set select
search = Search(select={"#document", "#embedding"})
assert search._select is not None
def test_search_mixed_inputs(self):
"""Test Search with mixed expression and dict inputs."""
from chromadb.execution.expression.plan import Search
from chromadb.execution.expression.operator import Key
search = Search(
where=Key("status") == "active", # Expression
rank={"$knn": {"query": [0.1, 0.2]}}, # Dict
limit=10, # Int
select=["#document"], # List
)
assert search._where is not None
assert search._rank is not None
assert search._limit.limit == 10
assert search._select is not None
def test_search_builder_methods_with_dicts(self):
"""Test Search builder methods accept dicts."""
from chromadb.execution.expression.plan import Search
search = Search().where({"status": "active"}).rank({"$val": 0.5})
assert search._where is not None
assert search._rank is not None
def test_search_invalid_inputs(self):
"""Test Search rejects invalid input types."""
from chromadb.execution.expression.plan import Search
with pytest.raises(TypeError, match="where must be"):
Search(where="invalid")
with pytest.raises(TypeError, match="rank must be"):
Search(rank=0.5) # Primitive numbers not allowed
with pytest.raises(TypeError, match="limit must be"):
Search(limit="10")
with pytest.raises(TypeError, match="select must be"):
Search(select=123)
def test_search_with_group_by(self):
"""Test Search accepts group_by as dict, object, and builder method."""
from chromadb.execution.expression.plan import Search
from chromadb.execution.expression.operator import GroupBy, MinK, Key
# Dict input
search = Search(
group_by={
"keys": ["category"],
"aggregate": {"$min_k": {"keys": ["#score"], "k": 3}},
}
)
assert isinstance(search._group_by, GroupBy)
# Object input and builder method
group_by = GroupBy(keys=Key("category"), aggregate=MinK(keys=Key.SCORE, k=3))
assert Search(group_by=group_by)._group_by is group_by
assert Search().group_by(group_by)._group_by.aggregate is not None
# Invalid inputs
with pytest.raises(TypeError, match="group_by must be"):
Search(group_by="invalid")
with pytest.raises(ValueError, match="requires 'aggregate' field"):
Search(group_by={"keys": ["category"]})
def test_search_group_by_serialization(self):
"""Test Search serializes group_by correctly."""
from chromadb.execution.expression.plan import Search
from chromadb.execution.expression.operator import GroupBy, MinK, Key, Knn
# Without group_by - empty dict
search = Search().rank(Knn(query=[0.1, 0.2])).limit(10)
assert search.to_dict()["group_by"] == {}
# With group_by - has keys and aggregate
search = Search().group_by(
GroupBy(keys=Key("category"), aggregate=MinK(keys=Key.SCORE, k=3))
)
result = search.to_dict()["group_by"]
assert result["keys"] == ["category"]
assert result["aggregate"] == {"$min_k": {"keys": ["#score"], "k": 3}}
class TestWhereFromDict:
"""Test Where.from_dict() conversion."""
def test_simple_equality(self):
"""Test simple equality conversion."""
from chromadb.execution.expression.operator import Where, Eq
# Shorthand for equality
where = Where.from_dict({"status": "active"})
assert isinstance(where, Eq)
# Explicit $eq
where = Where.from_dict({"status": {"$eq": "active"}})
assert isinstance(where, Eq)
def test_comparison_operators(self):
"""Test comparison operator conversions."""
from chromadb.execution.expression.operator import Where, Ne, Gt, Gte, Lt, Lte
# $ne
where = Where.from_dict({"status": {"$ne": "inactive"}})
assert isinstance(where, Ne)
# $gt
where = Where.from_dict({"score": {"$gt": 0.5}})
assert isinstance(where, Gt)
# $gte
where = Where.from_dict({"score": {"$gte": 0.5}})
assert isinstance(where, Gte)
# $lt
where = Where.from_dict({"score": {"$lt": 1.0}})
assert isinstance(where, Lt)
# $lte
where = Where.from_dict({"score": {"$lte": 1.0}})
assert isinstance(where, Lte)
def test_membership_operators(self):
"""Test membership operator conversions."""
from chromadb.execution.expression.operator import Where, In, Nin
# $in
where = Where.from_dict({"status": {"$in": ["active", "pending"]}})
assert isinstance(where, In)
# $nin (not in)
where = Where.from_dict({"status": {"$nin": ["deleted", "archived"]}})
assert isinstance(where, Nin)
def test_string_operators(self):
"""Test string operator conversions."""
from chromadb.execution.expression.operator import (
Where,
Contains,
NotContains,
Regex,
NotRegex,
)
# $contains
where = Where.from_dict({"text": {"$contains": "hello"}})
assert isinstance(where, Contains)
# $not_contains
where = Where.from_dict({"text": {"$not_contains": "spam"}})
assert isinstance(where, NotContains)
# $regex
where = Where.from_dict({"text": {"$regex": "^test.*"}})
assert isinstance(where, Regex)
# $not_regex
where = Where.from_dict({"text": {"$not_regex": r"\d+"}})
assert isinstance(where, NotRegex)
def test_logical_operators(self):
"""Test logical operator conversions."""
from chromadb.execution.expression.operator import Where, And, Or
# $and
where = Where.from_dict(
{"$and": [{"status": "active"}, {"score": {"$gt": 0.5}}]}
)
assert isinstance(where, And)
# $or
where = Where.from_dict({"$or": [{"status": "active"}, {"status": "pending"}]})
assert isinstance(where, Or)
def test_nested_logical_operators(self):
"""Test nested logical operations."""
from chromadb.execution.expression.operator import Where, And
where = Where.from_dict(
{
"$and": [
{"$or": [{"status": "active"}, {"status": "pending"}]},
{"score": {"$gte": 0.5}},
]
}
)
assert isinstance(where, And)
def test_special_keys(self):
"""Test special key handling."""
from chromadb.execution.expression.operator import Where, In
# ID key
where = Where.from_dict({"#id": {"$in": ["id1", "id2"]}})
assert isinstance(where, In)
def test_invalid_where_dicts(self):
"""Test invalid Where dict inputs."""
from chromadb.execution.expression.operator import Where
with pytest.raises(TypeError, match="Expected dict"):
Where.from_dict("not a dict")
with pytest.raises(ValueError, match="cannot be empty"):
Where.from_dict({})
with pytest.raises(ValueError, match="requires at least one condition"):
Where.from_dict({"$and": []})
class TestRankFromDict:
"""Test Rank.from_dict() conversion."""
def test_val_conversion(self):
"""Test Val conversion."""
from chromadb.execution.expression.operator import Rank, Val
rank = Rank.from_dict({"$val": 0.5})
assert isinstance(rank, Val)
assert rank.value == 0.5
def test_knn_conversion(self):
"""Test KNN conversion."""
from chromadb.execution.expression.operator import Rank, Knn
# Basic KNN with defaults
rank = Rank.from_dict({"$knn": {"query": [0.1, 0.2]}})
assert isinstance(rank, Knn)
# Handle both list and numpy array cases
if isinstance(rank.query, np.ndarray):
# Use allclose for floating point comparison with dtype tolerance
assert np.allclose(rank.query, np.array([0.1, 0.2]))
else:
assert rank.query == [0.1, 0.2]
assert rank.key == "#embedding" # default
assert rank.limit == 16 # default
# KNN with custom parameters
rank = Rank.from_dict(
{
"$knn": {
"query": [0.1, 0.2],
"key": "sparse_embedding",
"limit": 256,
"return_rank": True,
}
}
)
assert rank.key == "sparse_embedding"
assert rank.limit == 256
assert rank.return_rank
def test_arithmetic_operators(self):
"""Test arithmetic operator conversions."""
from chromadb.execution.expression.operator import Rank, Sum, Sub, Mul, Div
# $sum
rank = Rank.from_dict({"$sum": [{"$val": 0.5}, {"$val": 0.3}]})
assert isinstance(rank, Sum)
# $sub
rank = Rank.from_dict({"$sub": {"left": {"$val": 1.0}, "right": {"$val": 0.3}}})
assert isinstance(rank, Sub)
# $mul
rank = Rank.from_dict({"$mul": [{"$val": 2.0}, {"$val": 0.5}]})
assert isinstance(rank, Mul)
# $div
rank = Rank.from_dict({"$div": {"left": {"$val": 1.0}, "right": {"$val": 2.0}}})
assert isinstance(rank, Div)
def test_math_functions(self):
"""Test math function conversions."""
from chromadb.execution.expression.operator import Rank, Abs, Exp, Log
# $abs
rank = Rank.from_dict({"$abs": {"$val": -0.5}})
assert isinstance(rank, Abs)
# $exp
rank = Rank.from_dict({"$exp": {"$val": 1.0}})
assert isinstance(rank, Exp)
# $log
rank = Rank.from_dict({"$log": {"$val": 2.0}})
assert isinstance(rank, Log)
def test_aggregation_functions(self):
"""Test min/max conversions."""
from chromadb.execution.expression.operator import Rank, Max, Min
# $max
rank = Rank.from_dict({"$max": [{"$val": 0.5}, {"$val": 0.8}]})
assert isinstance(rank, Max)
# $min
rank = Rank.from_dict({"$min": [{"$val": 0.5}, {"$val": 0.8}]})
assert isinstance(rank, Min)
def test_complex_rank_expression(self):
"""Test complex nested rank expressions."""
from chromadb.execution.expression.operator import Rank, Sum
rank = Rank.from_dict(
{
"$sum": [
{"$mul": [{"$knn": {"query": [0.1, 0.2]}}, {"$val": 0.8}]},
{"$mul": [{"$val": 0.5}, {"$val": 0.2}]},
]
}
)
assert isinstance(rank, Sum)
def test_invalid_rank_dicts(self):
"""Test invalid Rank dict inputs."""
from chromadb.execution.expression.operator import Rank
with pytest.raises(TypeError, match="Expected dict"):
Rank.from_dict("not a dict")
with pytest.raises(ValueError, match="cannot be empty"):
Rank.from_dict({})
with pytest.raises(ValueError, match="exactly one operator"):
Rank.from_dict({"$val": 0.5, "$knn": {"query": [0.1]}})
with pytest.raises(TypeError, match="requires a number"):
Rank.from_dict({"$val": "not a number"})
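    # Rank expressions also compose with Python operators (as the round-trip
    # test below relies on); a short sketch, assuming `+` serializes to $sum.
    def test_rank_operator_composition_sketch(self):
        """Illustrative only: operator composition mirrors the dict form."""
        from chromadb.execution.expression.operator import Knn, Val
        expr = Knn(query=[0.1, 0.2]) * 0.8 + Val(0.5) * 0.2
        assert "$sum" in expr.to_dict()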
class TestLimitFromDict:
"""Test Limit.from_dict() conversion."""
def test_limit_only(self):
"""Test limit without offset."""
from chromadb.execution.expression.operator import Limit
limit = Limit.from_dict({"limit": 20})
assert limit.limit == 20
assert limit.offset == 0 # default
def test_offset_only(self):
"""Test offset without limit."""
from chromadb.execution.expression.operator import Limit
limit = Limit.from_dict({"offset": 10})
assert limit.offset == 10
assert limit.limit is None
def test_limit_and_offset(self):
"""Test both limit and offset."""
from chromadb.execution.expression.operator import Limit
limit = Limit.from_dict({"limit": 20, "offset": 10})
assert limit.limit == 20
assert limit.offset == 10
def test_validation(self):
"""Test Limit validation."""
from chromadb.execution.expression.operator import Limit
# Negative limit
with pytest.raises(ValueError, match="must be positive"):
Limit.from_dict({"limit": -1})
# Zero limit
with pytest.raises(ValueError, match="must be positive"):
Limit.from_dict({"limit": 0})
# Negative offset
with pytest.raises(ValueError, match="must be non-negative"):
Limit.from_dict({"offset": -1})
def test_invalid_types(self):
"""Test type validation."""
from chromadb.execution.expression.operator import Limit
with pytest.raises(TypeError, match="Expected dict"):
Limit.from_dict("not a dict")
with pytest.raises(TypeError, match="must be an integer"):
Limit.from_dict({"limit": "20"})
with pytest.raises(TypeError, match="must be an integer"):
Limit.from_dict({"offset": 10.5})
def test_unexpected_keys(self):
"""Test rejection of unexpected keys."""
from chromadb.execution.expression.operator import Limit
with pytest.raises(ValueError, match="Unexpected keys"):
Limit.from_dict({"limit": 10, "invalid": "key"})
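    # For intuition: limit/offset is assumed to follow standard pagination
    # semantics, sketched here with a plain list (no chromadb involved).
    def test_limit_offset_pagination_sketch(self):
        """Illustrative only: limit=20, offset=10 would select items 10..29."""
        results = list(range(100))
        page = results[10 : 10 + 20]
        assert page[0] == 10
        assert len(page) == 20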
class TestSelectFromDict:
"""Test Select.from_dict() conversion."""
def test_special_keys(self):
"""Test special key conversion."""
from chromadb.execution.expression.operator import Select, Key
select = Select.from_dict(
{"keys": ["#document", "#embedding", "#metadata", "#score"]}
)
assert Key.DOCUMENT in select.keys
assert Key.EMBEDDING in select.keys
assert Key.METADATA in select.keys
assert Key.SCORE in select.keys
def test_metadata_keys(self):
"""Test regular metadata field keys."""
from chromadb.execution.expression.operator import Select, Key
select = Select.from_dict({"keys": ["title", "author", "date"]})
assert Key("title") in select.keys
assert Key("author") in select.keys
assert Key("date") in select.keys
def test_mixed_keys(self):
"""Test mix of special and metadata keys."""
from chromadb.execution.expression.operator import Select, Key
select = Select.from_dict({"keys": ["#document", "title", "#score"]})
assert Key.DOCUMENT in select.keys
assert Key("title") in select.keys
assert Key.SCORE in select.keys
def test_empty_keys(self):
"""Test empty keys list."""
from chromadb.execution.expression.operator import Select
select = Select.from_dict({"keys": []})
assert len(select.keys) == 0
def test_validation(self):
"""Test Select validation."""
from chromadb.execution.expression.operator import Select
with pytest.raises(TypeError, match="Expected dict"):
Select.from_dict("not a dict")
with pytest.raises(TypeError, match="must be a list/tuple/set"):
Select.from_dict({"keys": "not a list"})
with pytest.raises(TypeError, match="must be a string"):
Select.from_dict({"keys": [123]})
def test_unexpected_keys(self):
"""Test rejection of unexpected keys."""
from chromadb.execution.expression.operator import Select
with pytest.raises(ValueError, match="Unexpected keys"):
Select.from_dict({"keys": [], "invalid": "key"})
class TestRoundTripConversion:
"""Test that to_dict() and from_dict() round-trip correctly."""
def test_where_round_trip(self):
"""Test Where round-trip conversion."""
from chromadb.execution.expression.operator import Where, And, Key
original = And([Key("status") == "active", Key("score") > 0.5])
dict_form = original.to_dict()
restored = Where.from_dict(dict_form)
assert restored.to_dict() == dict_form
def test_rank_round_trip(self):
"""Test Rank round-trip conversion."""
from chromadb.execution.expression.operator import Rank, Knn, Val
original = Knn(query=[0.1, 0.2]) * 0.8 + Val(0.5) * 0.2
dict_form = original.to_dict()
restored = Rank.from_dict(dict_form)
restored_dict = restored.to_dict()
# Compare with float32 precision tolerance for KNN queries
# The normalize_embeddings function converts to float32, causing precision differences
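        # e.g. float(np.float32(0.1)) == 0.10000000149011612, so exact equality
        # on the query lists would fail after the float32 round-trip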
def compare_dicts(d1, d2):
if isinstance(d1, dict) and isinstance(d2, dict):
if "$knn" in d1 and "$knn" in d2:
# Special handling for KNN queries
knn1, knn2 = d1["$knn"], d2["$knn"]
if "query" in knn1 and "query" in knn2:
# Compare queries with float32 precision
q1 = np.array(knn1["query"], dtype=np.float32)
q2 = np.array(knn2["query"], dtype=np.float32)
if not np.allclose(q1, q2):
return False
# Compare other fields exactly
for key in knn1:
if key != "query" and knn1[key] != knn2.get(key):
return False
return True
# Recursively compare other dict structures
if set(d1.keys()) != set(d2.keys()):
return False
for key in d1:
if not compare_dicts(d1[key], d2[key]):
return False
return True
elif isinstance(d1, list) and isinstance(d2, list):
if len(d1) != len(d2):
return False
return all(compare_dicts(a, b) for a, b in zip(d1, d2))
else:
return d1 == d2
assert compare_dicts(restored_dict, dict_form)
def test_limit_round_trip(self):
"""Test Limit round-trip conversion."""
from chromadb.execution.expression.operator import Limit
original = Limit(limit=20, offset=10)
dict_form = original.to_dict()
restored = Limit.from_dict(dict_form)
assert restored.to_dict() == dict_form
def test_select_round_trip(self):
"""Test Select round-trip conversion."""
from chromadb.execution.expression.operator import Select, Key
original = Select(keys={Key.DOCUMENT, Key("title"), Key.SCORE})
dict_form = original.to_dict()
restored = Select.from_dict(dict_form)
# Note: Set order might differ, so compare sets
assert set(restored.to_dict()["keys"]) == set(dict_form["keys"])
def test_search_round_trip(self):
"""Test Search round-trip through dict inputs."""
from chromadb.execution.expression.plan import Search
from chromadb.execution.expression.operator import Key, Knn, Limit, Select
original_search = Search(
where=Key("status") == "active",
rank=Knn(query=[0.1, 0.2]),
limit=Limit(limit=10),
select=Select(keys={Key.DOCUMENT}),
)
# Convert to dict
search_dict = original_search.to_dict()
# Create new Search from dicts
new_search = Search(
where=search_dict["filter"] if search_dict["filter"] else None,
rank=search_dict["rank"] if search_dict["rank"] else None,
limit=search_dict["limit"],
select=search_dict["select"],
)
# Get new dict
new_dict = new_search.to_dict()
        # Compare with float32 tolerance for KNN queries, using a variant of
        # the helper from test_rank_round_trip that unwraps the "rank" field
def compare_search_dicts(d1, d2):
if isinstance(d1, dict) and isinstance(d2, dict):
# Special handling for rank field with KNN
if "rank" in d1 and "rank" in d2:
rank1, rank2 = d1["rank"], d2["rank"]
                    if (
                        isinstance(rank1, dict)
                        and isinstance(rank2, dict)
                        and "$knn" in rank1
                        and "$knn" in rank2
                    ):
                        knn1, knn2 = rank1["$knn"], rank2["$knn"]
                        if "query" in knn1 and "query" in knn2:
                            q1 = np.array(knn1["query"], dtype=np.float32)
                            q2 = np.array(knn2["query"], dtype=np.float32)
                            if not np.allclose(q1, q2):
                                return False
                        # Compare other KNN fields
                        for key in knn1:
                            if key != "query" and knn1[key] != knn2.get(key):
                                return False
                    elif rank1 != rank2:
                        # Non-KNN ranks fall back to plain equality
                        return False
# Compare other fields in the dict
for key in d1:
if key != "rank" and d1[key] != d2.get(key):
return False
return True
# Normal dict comparison
if set(d1.keys()) != set(d2.keys()):
return False
for key in d1:
if isinstance(d1[key], dict) and isinstance(d2[key], dict):
if not compare_search_dicts(d1[key], d2[key]):
return False
elif d1[key] != d2[key]:
return False
return True
else:
return d1 == d2
assert compare_search_dicts(new_dict, search_dict)
def test_search_round_trip_with_group_by(self):
"""Test Search round-trip with group_by."""
from chromadb.execution.expression.plan import Search
from chromadb.execution.expression.operator import Key, GroupBy, MinK
original = Search(
where=Key("status") == "active",
group_by=GroupBy(
keys=[Key("category")],
aggregate=MinK(keys=[Key.SCORE], k=3),
),
)
# Verify group_by round-trip
search_dict = original.to_dict()
assert search_dict["group_by"]["keys"] == ["category"]
assert search_dict["group_by"]["aggregate"] == {
"$min_k": {"keys": ["#score"], "k": 3}
}
# Reconstruct and compare group_by
restored = Search(group_by=GroupBy.from_dict(search_dict["group_by"]))
assert restored.to_dict()["group_by"] == search_dict["group_by"]