group-wbl/.venv/lib/python3.13/site-packages/chromadb/test/property/test_add.py
2026-01-09 09:48:03 +08:00

525 lines
17 KiB
Python

import uuid
from random import randint
from typing import cast, List, Any, Dict
import hypothesis
import numpy as np
import pytest
import hypothesis.strategies as st
from hypothesis import given, settings
from chromadb.api import ClientAPI
from chromadb.api.types import Embeddings, Metadatas
from chromadb.test.conftest import (
NOT_CLUSTER_ONLY,
override_hypothesis_profile,
create_isolated_database,
)
import chromadb.test.property.strategies as strategies
import chromadb.test.property.invariants as invariants
from chromadb.test.utils.wait_for_version_increase import wait_for_version_increase
from chromadb.utils.batch_utils import create_batches
# Shared collection strategy. Because it is wrapped in `st.shared` under the
# key "coll", every place that uses `collection_st` within one Hypothesis
# example draws the same collection, so a test's `collection` argument and the
# record set strategy built from `collection_st` agree with each other.
collection_st = st.shared(strategies.collections(with_hnsw_params=True), key="coll")
@given(
    collection=collection_st,
    record_set=strategies.recordsets(collection_st, min_size=1, max_size=5),
)
@settings(
    max_examples=2,
    deadline=None,
    parent=override_hypothesis_profile(
        normal=hypothesis.settings(max_examples=500),
        fast=hypothesis.settings(max_examples=200),
    ),
)
def test_add_miniscule(
    client: ClientAPI,
    collection: strategies.Collection,
    record_set: strategies.RecordSet,
) -> None:
    """Add a very small record set (1 to 5 records) and run the shared add checks.

    Calls `_test_add` with `should_compact=True` and `always_compact=True`.
    Skipped when the client is the async FastAPI implementation.
    """
    api_impl = client.get_settings().chroma_api_impl
    if api_impl == "chromadb.api.async_fastapi.AsyncFastAPI":
        pytest.skip(
            "TODO @jai, come back and debug why CI runners fail with async + sync"
        )

    _test_add(
        client,
        collection,
        record_set,
        should_compact=True,
        always_compact=True,
    )
# Hypothesis tends to generate smaller values, so we explicitly segregate the
# tests into tiers: Small and Medium. Hypothesis struggles to generate large
# record sets, so we explicitly create a large record set without using Hypothesis.
@given(
    collection=collection_st,
    record_set=strategies.recordsets(collection_st, min_size=1, max_size=500),
    should_compact=st.booleans(),
)
@settings(
    parent=override_hypothesis_profile(
        normal=hypothesis.settings(max_examples=500),
        fast=hypothesis.settings(max_examples=200),
    ),
    deadline=None,
)
def test_add_small(
    client: ClientAPI,
    collection: strategies.Collection,
    record_set: strategies.RecordSet,
    should_compact: bool,
) -> None:
    """Add a record set of 1 to 500 records and run the shared add checks.

    `should_compact` is drawn by Hypothesis and forwarded to `_test_add`.
    Skipped when the client is the async FastAPI implementation.
    """
    api_impl = client.get_settings().chroma_api_impl
    if api_impl == "chromadb.api.async_fastapi.AsyncFastAPI":
        pytest.skip(
            "TODO @jai, come back and debug why CI runners fail with async + sync"
        )

    _test_add(client, collection, record_set, should_compact)
@given(
    collection=collection_st,
    record_set=strategies.recordsets(
        collection_st,
        min_size=250,
        max_size=500,
        num_unique_metadata=5,
        min_metadata_size=1,
        max_metadata_size=5,
    ),
    should_compact=st.booleans(),
)
@settings(
    parent=override_hypothesis_profile(
        normal=hypothesis.settings(max_examples=10),
        fast=hypothesis.settings(max_examples=5),
    ),
    deadline=None,
    suppress_health_check=[
        hypothesis.HealthCheck.too_slow,
        hypothesis.HealthCheck.data_too_large,
        hypothesis.HealthCheck.large_base_example,
        hypothesis.HealthCheck.function_scoped_fixture,
    ],
)
def test_add_medium(
    client: ClientAPI,
    collection: strategies.Collection,
    record_set: strategies.RecordSet,
    should_compact: bool,
) -> None:
    """Add a record set of 250 to 500 records and run the shared add checks.

    The ANN accuracy check is run in batches (`batch_ann_accuracy=True`).
    Skipped when the client is the async FastAPI implementation.
    """
    api_impl = client.get_settings().chroma_api_impl
    if api_impl == "chromadb.api.async_fastapi.AsyncFastAPI":
        pytest.skip(
            "TODO @jai, come back and debug why CI runners fail with async + sync"
        )

    # Cluster tests send their results over gRPC, which has a payload limit.
    # Unbatched, the ann_accuracy invariant breaks because the vector reader
    # returns a payload the size of the dataset, so the queries are batched.
    _test_add(
        client,
        collection,
        record_set,
        should_compact,
        batch_ann_accuracy=True,
    )
def _test_add(
    client: ClientAPI,
    collection: strategies.Collection,
    record_set: strategies.RecordSet,
    should_compact: bool,
    batch_ann_accuracy: bool = False,
    always_compact: bool = False,
) -> None:
    """Create a collection, add `record_set` in batches, and check invariants.

    Args:
        client: Client used to create the isolated database and collection.
        collection: Generated collection description (name, metadata,
            embedding function, configuration).
        record_set: Records to add.
        should_compact: When running in cluster mode, wait for the collection
            version to increase before checking invariants.
        batch_ann_accuracy: Run the ANN accuracy invariant over query indices
            in chunks of 10 instead of in a single call.
        always_compact: Wait for the version increase even when the record
            set has 10 or fewer records.
    """
    create_isolated_database(client)

    # TODO: Generative embedding functions
    coll = client.create_collection(
        name=collection.name,
        metadata=collection.metadata,  # type: ignore
        embedding_function=collection.embedding_function,
        configuration=collection.collection_config,
    )
    initial_version = cast(int, coll.get_model()["version"])
    normalized_record_set = invariants.wrap_all(record_set)

    # TODO: The type of add() is incorrect as it does not allow for metadatas
    # like [{"a": 1}, None, {"a": 3}]
    batches = create_batches(
        api=client,
        ids=cast(List[str], record_set["ids"]),
        embeddings=cast(Embeddings, record_set["embeddings"]),
        metadatas=cast(Metadatas, record_set["metadatas"]),
        documents=cast(List[str], record_set["documents"]),
    )
    for batch in batches:
        coll.add(*batch)

    record_count = len(normalized_record_set["ids"])

    # Only wait for compaction when the collection has some minimal size,
    # unless the caller explicitly asked to always wait.
    worth_waiting = always_compact or record_count > 10
    if should_compact and worth_waiting and not NOT_CLUSTER_ONLY:
        # Wait for the model to be updated
        wait_for_version_increase(client, collection.name, initial_version)

    expected = cast(strategies.RecordSet, normalized_record_set)
    invariants.count(coll, expected)

    n_results = max(1, record_count // 10)
    if not batch_ann_accuracy:
        invariants.ann_accuracy(
            coll,
            expected,
            n_results=n_results,
            embedding_function=collection.embedding_function,
        )
        return

    batch_size = 10
    for start in range(0, record_count, batch_size):
        stop = min(start + batch_size, record_count)
        invariants.ann_accuracy(
            coll,
            expected,
            n_results=n_results,
            embedding_function=collection.embedding_function,
            query_indices=list(range(start, stop)),
        )
# Hypothesis struggles to generate large record sets so we explicitly create
# a large record set
def create_large_recordset(
    min_size: int = 45000,
    max_size: int = 50000,
) -> strategies.RecordSet:
    """Build a synthetic record set without going through Hypothesis.

    The number of records is picked with `randint(min_size, max_size)`, so
    both bounds are inclusive. Every record gets a fresh UUID4 string id, the
    fixed embedding `[1, 2, 3]`, a `{"some_key": "<index>"}` metadata dict and
    a `"Document <index>"` document.
    """
    size = randint(min_size, max_size)
    record_set: Dict[str, List[Any]] = {
        "ids": [str(uuid.uuid4()) for _ in range(size)],
        "embeddings": cast(Embeddings, [[1, 2, 3] for _ in range(size)]),
        "metadatas": [{"some_key": f"{i}"} for i in range(size)],
        "documents": [f"Document {i}" for i in range(size)],
    }
    return cast(strategies.RecordSet, record_set)
@given(collection=collection_st, should_compact=st.booleans())
@settings(deadline=None, max_examples=5)
def test_add_large(
    client: ClientAPI, collection: strategies.Collection, should_compact: bool
) -> None:
    """Add a hand-built record set of 10000 to 50000 records and check the count.

    The records come from `create_large_recordset` rather than Hypothesis.
    Skipped when the client is the async FastAPI implementation.
    """
    create_isolated_database(client)

    api_impl = client.get_settings().chroma_api_impl
    if api_impl == "chromadb.api.async_fastapi.AsyncFastAPI":
        pytest.skip(
            "TODO @jai, come back and debug why CI runners fail with async + sync"
        )

    record_set = create_large_recordset(min_size=10000, max_size=50000)
    coll = client.create_collection(
        name=collection.name,
        metadata=collection.metadata,  # type: ignore
        embedding_function=collection.embedding_function,
    )
    normalized_record_set = invariants.wrap_all(record_set)
    initial_version = cast(int, coll.get_model()["version"])

    batches = create_batches(
        api=client,
        ids=cast(List[str], record_set["ids"]),
        embeddings=cast(Embeddings, record_set["embeddings"]),
        metadatas=cast(Metadatas, record_set["metadatas"]),
        documents=cast(List[str], record_set["documents"]),
    )
    for batch in batches:
        coll.add(*batch)

    large_enough = len(normalized_record_set["ids"]) > 10
    if should_compact and large_enough and not NOT_CLUSTER_ONLY:
        # Wait for the model to be updated; the record set is larger here, so
        # allow some additional time.
        wait_for_version_increase(
            client, collection.name, initial_version, additional_time=240
        )

    invariants.count(coll, cast(strategies.RecordSet, normalized_record_set))
@given(collection=collection_st)
@settings(deadline=None, max_examples=1)
def test_add_large_exceeding(
    client: ClientAPI, collection: strategies.Collection
) -> None:
    """Adding more records than the max batch size in one call must fail.

    Builds a record set that is strictly larger than
    `client.get_max_batch_size()` (by 1 to 100 records), passes it to a single
    `add()` call, and expects an exception whose message mentions
    "batch size". Skipped when the client is the async FastAPI implementation.
    """
    create_isolated_database(client)
    if (
        client.get_settings().chroma_api_impl
        == "chromadb.api.async_fastapi.AsyncFastAPI"
    ):
        pytest.skip(
            "TODO @jai, come back and debug why CI runners fail with async + sync"
        )

    max_batch_size = client.get_max_batch_size()
    # create_large_recordset picks its size with randint, which includes both
    # endpoints. Starting at max_batch_size itself could therefore yield a set
    # that does not exceed the limit, so the lower bound is one above it.
    record_set = create_large_recordset(
        min_size=max_batch_size + 1,
        max_size=max_batch_size + 100,  # Exceed the max batch size by up to 100 records
    )
    coll = client.create_collection(
        name=collection.name,
        metadata=collection.metadata,  # type: ignore
        embedding_function=collection.embedding_function,
    )

    with pytest.raises(Exception) as e:
        coll.add(**record_set)  # type: ignore[arg-type]
    # Checked after the `with` block: code placed after the raising call
    # inside the block would never execute.
    assert "batch size" in str(e.value)
# TODO: This test fails right now because the ids are not sorted by the input order
@pytest.mark.xfail(
    reason=(
        "This is expected to fail right now. We should change the API to sort the "
        "ids by input order."
    )
)
def test_out_of_order_ids(client: ClientAPI) -> None:
    """Check that `get` returns ids in the order they were requested.

    Marked xfail. Skipped when the client is the async FastAPI implementation.
    """
    api_impl = client.get_settings().chroma_api_impl
    if api_impl == "chromadb.api.async_fastapi.AsyncFastAPI":
        pytest.skip(
            "TODO @jai, come back and debug why CI runners fail with async + sync"
        )

    # Deliberately unsorted ids, mixing zero-padded and unpadded numerals.
    ooo_ids = [
        "40", "05", "8", "6", "10",
        "01", "00", "3", "04", "20",
        "02", "9", "30", "11", "13",
        "2", "0", "7", "06", "5",
        "50", "12", "03", "4", "1",
    ]

    coll = client.create_collection(
        "test",
        embedding_function=lambda input: [[1, 2, 3] for _ in input],  # type: ignore
    )
    embeddings: Embeddings = [np.array([1, 2, 3]) for _ in ooo_ids]
    coll.add(ids=ooo_ids, embeddings=embeddings)

    assert coll.get(ids=ooo_ids)["ids"] == ooo_ids
def test_add_partial(client: ClientAPI) -> None:
    """Add records where some metadata and document entries are None.

    Skipped when the client is the async FastAPI implementation.
    """
    create_isolated_database(client)

    api_impl = client.get_settings().chroma_api_impl
    if api_impl == "chromadb.api.async_fastapi.AsyncFastAPI":
        pytest.skip(
            "TODO @jai, come back and debug why CI runners fail with async + sync"
        )

    coll = client.create_collection("test")

    # TODO: We need to clean up the api types to support this typing
    coll.add(
        ids=["1", "2", "3"],
        # Embeddings are all-or-nothing: give every one, or pass None.
        embeddings=[[1, 2, 3], [1, 2, 3], [1, 2, 3]],  # type: ignore
        # Individual metadata entries may be missing.
        metadatas=[{"a": 1}, None, {"a": 3}],  # type: ignore
        # A document may be omitted when embeddings are supplied.
        documents=["a", "b", None],  # type: ignore
    )

    fetched = coll.get()
    assert fetched["ids"] == ["1", "2", "3"]
    assert fetched["metadatas"] == [{"a": 1}, None, {"a": 3}]
    assert fetched["documents"] == ["a", "b", None]
@pytest.mark.skipif(
    NOT_CLUSTER_ONLY,
    reason="GroupBy is only supported in distributed mode",
)
def test_search_group_by(client: ClientAPI) -> None:
    """Exercise GroupBy with one key, several keys, and several ranking keys."""
    from chromadb.execution.expression.operator import GroupBy, MinK, Key
    from chromadb.execution.expression.plan import Search
    from chromadb.execution.expression import Knn

    create_isolated_database(client)
    coll = client.create_collection(name="test_group_by")

    # Fixture: 12 records across 3 categories and 2 years, one row per record
    # as (id, embedding, metadata). The science vectors lie nearest the query
    # [1, 0, 0, 0]. Within tech and within arts, the *_2024_2 vector is the
    # nearest to the query and *_2023_1 the farthest (by Euclidean distance).
    records = [
        ("sci_2023_1", [1.0, 0.0, 0.0, 0.0], {"category": "science", "year": 2023, "priority": 1}),
        ("sci_2023_2", [0.9, 0.1, 0.0, 0.0], {"category": "science", "year": 2023, "priority": 2}),
        ("sci_2024_1", [0.8, 0.2, 0.0, 0.0], {"category": "science", "year": 2024, "priority": 1}),
        ("sci_2024_2", [0.7, 0.3, 0.0, 0.0], {"category": "science", "year": 2024, "priority": 3}),
        ("tech_2023_1", [0.0, 1.0, 0.0, 0.0], {"category": "tech", "year": 2023, "priority": 2}),
        ("tech_2023_2", [0.0, 0.9, 0.1, 0.0], {"category": "tech", "year": 2023, "priority": 1}),
        ("tech_2024_1", [0.0, 0.8, 0.2, 0.0], {"category": "tech", "year": 2024, "priority": 1}),
        ("tech_2024_2", [0.0, 0.7, 0.3, 0.0], {"category": "tech", "year": 2024, "priority": 2}),
        ("arts_2023_1", [0.0, 0.0, 1.0, 0.0], {"category": "arts", "year": 2023, "priority": 3}),
        ("arts_2023_2", [0.0, 0.0, 0.9, 0.1], {"category": "arts", "year": 2023, "priority": 1}),
        ("arts_2024_1", [0.0, 0.0, 0.8, 0.2], {"category": "arts", "year": 2024, "priority": 2}),
        ("arts_2024_2", [0.0, 0.0, 0.7, 0.3], {"category": "arts", "year": 2024, "priority": 1}),
    ]
    ids = [record_id for record_id, _, _ in records]
    embeddings = cast(Embeddings, [vector for _, vector, _ in records])
    metadatas = cast(Metadatas, [meta for _, _, meta in records])
    documents = [f"doc_{record_id}" for record_id in ids]

    coll.add(
        ids=ids,
        embeddings=embeddings,
        metadatas=metadatas,
        documents=documents,
    )

    query = [1.0, 0.0, 0.0, 0.0]

    def grouped_ids(group_by: Any) -> Any:
        """Run a 12-nearest search with `group_by` and return the first id list."""
        results = coll.search(
            Search().rank(Knn(query=query, limit=12)).group_by(group_by).limit(12)
        )
        assert results["ids"] is not None
        return results["ids"][0]

    # Case 1: group by a single key, keeping the 2 best-scoring records in
    # each category (science, tech, arts).
    by_category = grouped_ids(
        GroupBy(keys=Key("category"), aggregate=MinK(keys=Key.SCORE, k=2))
    )
    assert len(by_category) == 6
    assert set(by_category) == {
        "sci_2023_1",
        "sci_2023_2",
        "tech_2024_2",
        "tech_2024_1",
        "arts_2024_2",
        "arts_2024_1",
    }

    # Case 2: group by (category, year), keeping the single best-scoring
    # record in each of the 6 combinations.
    by_category_and_year = grouped_ids(
        GroupBy(
            keys=[Key("category"), Key("year")],
            aggregate=MinK(keys=Key.SCORE, k=1),
        )
    )
    assert len(by_category_and_year) == 6
    assert set(by_category_and_year) == {
        "sci_2023_1",
        "sci_2024_1",
        "tech_2023_2",
        "tech_2024_2",
        "arts_2023_2",
        "arts_2024_2",
    }

    # Case 3: several ranking keys. Top 2 per category, ordered by priority
    # (ascending) with score (ascending) as the tiebreaker.
    by_priority_then_score = grouped_ids(
        GroupBy(
            keys=Key("category"),
            aggregate=MinK(keys=[Key("priority"), Key.SCORE], k=2),
        )
    )
    assert len(by_priority_then_score) == 6
    assert set(by_priority_then_score) == {
        "sci_2023_1",
        "sci_2024_1",
        "tech_2024_1",
        "tech_2023_2",
        "arts_2024_2",
        "arts_2023_2",
    }