group-wbl/.venv/lib/python3.13/site-packages/chromadb/test/property/strategies.py
2026-01-09 09:48:03 +08:00

1101 lines
38 KiB
Python

import hashlib
import hypothesis
import hypothesis.strategies as st
from typing import Any, Optional, List, Dict, Union, cast, Tuple
from typing_extensions import TypedDict
import uuid
import numpy as np
import numpy.typing as npt
import chromadb.api.types as types
import re
from hypothesis.strategies._internal.strategies import SearchStrategy
from chromadb.test.api.test_schema_e2e import (
SimpleEmbeddingFunction,
DeterministicSparseEmbeddingFunction,
)
from chromadb.test.conftest import NOT_CLUSTER_ONLY
from dataclasses import dataclass
from chromadb.api.types import (
Documents,
Embeddable,
EmbeddingFunction,
Embeddings,
Metadata,
Schema,
CollectionMetadata,
VectorIndexConfig,
SparseVectorIndexConfig,
StringInvertedIndexConfig,
IntInvertedIndexConfig,
FloatInvertedIndexConfig,
BoolInvertedIndexConfig,
HnswIndexConfig,
SpannIndexConfig,
Space,
)
from chromadb.types import LiteralValue, WhereOperator, LogicalOperator
from chromadb.test.conftest import is_spann_disabled_mode
from chromadb.api.collection_configuration import (
CreateCollectionConfiguration,
CreateSpannConfiguration,
CreateHNSWConfiguration,
)
from chromadb.utils.embedding_functions import (
register_embedding_function,
)
# Set the random seed for reproducibility
np.random.seed(0) # unnecessary, hypothesis does this for us
# See Hypothesis documentation for creating strategies at
# https://hypothesis.readthedocs.io/en/latest/data.html
# NOTE: Because these strategies are used in state machines, we need to
# work around an issue with state machines, in which strategies that frequently
# are marked as invalid (i.e. through the use of `assume` or `.filter`) can cause the
# state machine tests to fail with an hypothesis.errors.Unsatisfiable.
# Ultimately this is because the entire state machine is run as a single Hypothesis
# example, which ends up drawing from the same strategies an enormous number of times.
# Whenever a strategy marks itself as invalid, Hypothesis tries to start the entire
# state machine run over. See https://github.com/HypothesisWorks/hypothesis/issues/3618
# Because strategy generation is all interrelated, seemingly small changes (especially
# ones called early in a test) can have an outside effect. Generating lists with
# unique=True, or dictionaries with a min size seems especially bad.
# Please make changes to these strategies incrementally, testing to make sure they don't
# start generating unsatisfiable examples.
test_hnsw_config = {
"hnsw:construction_ef": 128,
"hnsw:search_ef": 128,
"hnsw:M": 128,
}
class RecordSet(TypedDict):
"""
A generated set of embeddings, ids, metadatas, and documents that
represent what a user would pass to the API.
"""
ids: Union[types.ID, List[types.ID]]
embeddings: Optional[Union[types.Embeddings, types.Embedding]]
metadatas: Optional[Union[List[Optional[types.Metadata]], types.Metadata]]
documents: Optional[Union[List[types.Document], types.Document]]
class NormalizedRecordSet(TypedDict):
"""
A RecordSet, with all fields normalized to lists.
"""
ids: List[types.ID]
embeddings: Optional[types.Embeddings]
metadatas: Optional[List[Optional[types.Metadata]]]
documents: Optional[List[types.Document]]
class StateMachineRecordSet(TypedDict):
"""
Represents the internal state of a state machine in hypothesis tests.
"""
ids: List[types.ID]
embeddings: types.Embeddings
metadatas: List[Optional[types.Metadata]]
documents: List[Optional[types.Document]]
class Record(TypedDict):
"""
A single generated record.
"""
id: types.ID
embedding: Optional[types.Embedding]
metadata: Optional[types.Metadata]
document: Optional[types.Document]
# TODO: support arbitrary text everywhere so we don't SQL-inject ourselves.
# TODO: support empty strings everywhere
sql_alphabet = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789-_"
safe_text = st.text(alphabet=sql_alphabet, min_size=1)
sql_alphabet_minus_underscore = (
"abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789-"
)
safe_text_min_size_3 = st.text(alphabet=sql_alphabet_minus_underscore, min_size=3)
tenant_database_name = st.text(alphabet=sql_alphabet, min_size=3)
# Workaround for FastAPI json encoding peculiarities
# https://github.com/tiangolo/fastapi/blob/8ac8d70d52bb0dd9eb55ba4e22d3e383943da05c/fastapi/encoders.py#L104
safe_text = safe_text.filter(lambda s: not s.startswith("_sa"))
safe_text_min_size_3 = safe_text_min_size_3.filter(lambda s: not s.startswith("_sa"))
tenant_database_name = tenant_database_name.filter(lambda s: not s.startswith("_sa"))
safe_integers = st.integers(
min_value=-(2**31), max_value=2**31 - 1
) # TODO: handle longs
# In distributed chroma, floats are 32 bit hence we need to
# restrict the generation to generate only 32 bit floats.
safe_floats = st.floats(
allow_infinity=False,
allow_nan=False,
allow_subnormal=False,
width=32,
min_value=-1e6,
max_value=1e6,
) # TODO: handle infinity and NAN
safe_values: List[SearchStrategy[Union[int, float, str, bool]]] = [
safe_text,
safe_integers,
safe_floats,
st.booleans(),
]
def one_or_both(
strategy_a: st.SearchStrategy[Any], strategy_b: st.SearchStrategy[Any]
) -> st.SearchStrategy[Any]:
return st.one_of(
st.tuples(strategy_a, strategy_b),
st.tuples(strategy_a, st.none()),
st.tuples(st.none(), strategy_b),
)
# Temporarily generate only these to avoid SQL formatting issues.
legal_id_characters = (
"abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789-_./+"
)
float_types = [np.float16, np.float32, np.float64]
int_types = [np.int16, np.int32, np.int64] # TODO: handle int types
@st.composite
def collection_name(draw: st.DrawFn) -> str:
_collection_name_re = re.compile(r"^[a-zA-Z][a-zA-Z0-9-]{1,60}[a-zA-Z0-9]$")
_ipv4_address_re = re.compile(r"^([0-9]{1,3}\.){3}[0-9]{1,3}$")
_two_periods_re = re.compile(r"\.\.")
name: str = draw(st.from_regex(_collection_name_re)).strip()
hypothesis.assume(not _ipv4_address_re.match(name))
hypothesis.assume(not _two_periods_re.search(name))
return name
collection_metadata = st.one_of(
st.none(), st.dictionaries(safe_text, st.one_of(*safe_values))
)
# TODO: Use a hypothesis strategy while maintaining embedding uniqueness
# Or handle duplicate embeddings within a known epsilon
def create_embeddings(
dim: int,
count: int,
dtype: npt.DTypeLike,
) -> types.Embeddings:
embeddings: types.Embeddings = cast(
types.Embeddings,
(
np.random.uniform(
low=-1.0,
high=1.0,
size=(count, dim),
)
.astype(dtype)
.tolist()
),
)
return embeddings
def create_embeddings_ndarray(
dim: int,
count: int,
dtype: npt.DTypeLike,
) -> np.typing.NDArray[Any]:
return np.random.uniform(
low=-1.0,
high=1.0,
size=(count, dim),
).astype(dtype)
class hashing_embedding_function(types.EmbeddingFunction[Documents]):
def __init__(self, dim: int, dtype: npt.DTypeLike) -> None:
self.dim = dim
self.dtype = dtype
def __call__(self, input: types.Documents) -> types.Embeddings:
# Hash the texts and convert to hex strings
hashed_texts = [
list(hashlib.sha256(text.encode("utf-8")).hexdigest()) for text in input
]
# Pad with repetition, or truncate the hex strings to the desired dimension
padded_texts = [
text * (self.dim // len(text)) + text[: self.dim % len(text)]
for text in hashed_texts
]
# Convert the hex strings to dtype
embeddings: types.Embeddings = [
np.array([int(char, 16) / 15.0 for char in text], dtype=self.dtype)
for text in padded_texts
]
return embeddings
def __repr__(self) -> str:
return f"hashing_embedding_function(dim={self.dim}, dtype={self.dtype})"
class not_implemented_embedding_function(types.EmbeddingFunction[Documents]):
def __call__(self, input: Documents) -> Embeddings:
assert False, "This embedding function is not implemented"
def embedding_function_strategy(
dim: int, dtype: npt.DTypeLike
) -> st.SearchStrategy[types.EmbeddingFunction[Embeddable]]:
return st.just(
cast(EmbeddingFunction[Embeddable], hashing_embedding_function(dim, dtype))
)
@dataclass
class ExternalCollection:
"""
An external view of a collection.
This strategy only contains information about a collection that a client of Chroma
sees -- that is, it contains none of Chroma's internal bookkeeping. It should
be used to test the API and client code.
"""
name: str
metadata: Optional[types.Metadata]
embedding_function: Optional[types.EmbeddingFunction[Embeddable]]
@register_embedding_function
class SimpleIpEmbeddingFunction(SimpleEmbeddingFunction):
"""Simple embedding function with ip space for persistence tests."""
def default_space(self) -> str: # type: ignore[override]
return "ip"
@st.composite
def vector_index_config_strategy(draw: st.DrawFn) -> VectorIndexConfig:
"""Generate VectorIndexConfig with optional space, embedding_function, source_key, hnsw, spann."""
space = None
embedding_function = None
source_key = None
hnsw = None
spann = None
if draw(st.booleans()):
space = draw(st.sampled_from(["cosine", "l2", "ip"]))
if draw(st.booleans()):
embedding_function = SimpleIpEmbeddingFunction(
dim=draw(st.integers(min_value=1, max_value=1000))
)
if draw(st.booleans()):
source_key = draw(st.one_of(st.just("#document"), safe_text))
index_choice = draw(st.sampled_from(["hnsw", "spann", "none"]))
if index_choice == "hnsw":
hnsw = HnswIndexConfig(
ef_construction=draw(st.integers(min_value=1, max_value=1000))
if draw(st.booleans())
else None,
max_neighbors=draw(st.integers(min_value=1, max_value=1000))
if draw(st.booleans())
else None,
ef_search=draw(st.integers(min_value=1, max_value=1000))
if draw(st.booleans())
else None,
sync_threshold=draw(st.integers(min_value=2, max_value=10000))
if draw(st.booleans())
else None,
resize_factor=draw(st.floats(min_value=1.0, max_value=5.0))
if draw(st.booleans())
else None,
)
elif index_choice == "spann":
spann = SpannIndexConfig(
search_nprobe=draw(st.integers(min_value=1, max_value=128))
if draw(st.booleans())
else None,
write_nprobe=draw(st.integers(min_value=1, max_value=64))
if draw(st.booleans())
else None,
ef_construction=draw(st.integers(min_value=1, max_value=200))
if draw(st.booleans())
else None,
ef_search=draw(st.integers(min_value=1, max_value=200))
if draw(st.booleans())
else None,
max_neighbors=draw(st.integers(min_value=1, max_value=64))
if draw(st.booleans())
else None,
reassign_neighbor_count=draw(st.integers(min_value=1, max_value=64))
if draw(st.booleans())
else None,
split_threshold=draw(st.integers(min_value=50, max_value=200))
if draw(st.booleans())
else None,
merge_threshold=draw(st.integers(min_value=25, max_value=100))
if draw(st.booleans())
else None,
)
return VectorIndexConfig(
space=cast(Space, space),
embedding_function=embedding_function,
source_key=source_key,
hnsw=hnsw,
spann=spann,
)
@st.composite
def sparse_vector_index_config_strategy(draw: st.DrawFn) -> SparseVectorIndexConfig:
"""Generate SparseVectorIndexConfig with optional embedding_function, source_key, bm25."""
embedding_function = None
source_key = None
bm25 = None
if draw(st.booleans()):
embedding_function = DeterministicSparseEmbeddingFunction()
source_key = draw(st.one_of(st.just("#document"), safe_text))
if draw(st.booleans()):
bm25 = draw(st.booleans())
return SparseVectorIndexConfig(
embedding_function=embedding_function,
source_key=source_key,
bm25=bm25,
)
@st.composite
def schema_strategy(draw: st.DrawFn) -> Optional[Schema]:
"""Generate a Schema object with various create_index/delete_index operations."""
if draw(st.booleans()):
return None
schema = Schema()
# Decide how many operations to perform
num_operations = draw(st.integers(min_value=0, max_value=5))
sparse_index_created = False
for _ in range(num_operations):
operation = draw(st.sampled_from(["create_index", "delete_index"]))
config_type = draw(
st.sampled_from(
[
"string_inverted",
"int_inverted",
"float_inverted",
"bool_inverted",
"vector",
"sparse_vector",
]
)
)
# Decide if we're setting on a key or globally
use_key = draw(st.booleans())
key = None
if use_key and config_type != "vector":
# Vector indexes can't be set on specific keys, only globally
key = draw(safe_text)
if operation == "create_index":
if config_type == "string_inverted":
schema.create_index(config=StringInvertedIndexConfig(), key=key)
elif config_type == "int_inverted":
schema.create_index(config=IntInvertedIndexConfig(), key=key)
elif config_type == "float_inverted":
schema.create_index(config=FloatInvertedIndexConfig(), key=key)
elif config_type == "bool_inverted":
schema.create_index(config=BoolInvertedIndexConfig(), key=key)
elif config_type == "vector":
vector_config = draw(vector_index_config_strategy())
schema.create_index(config=vector_config, key=None)
elif (
config_type == "sparse_vector"
and not is_spann_disabled_mode
and not sparse_index_created
):
sparse_config = draw(sparse_vector_index_config_strategy())
# Sparse vector MUST have a key
if key is None:
key = draw(safe_text)
schema.create_index(config=sparse_config, key=key)
sparse_index_created = True
elif operation == "delete_index":
if config_type == "string_inverted":
schema.delete_index(config=StringInvertedIndexConfig(), key=key)
elif config_type == "int_inverted":
schema.delete_index(config=IntInvertedIndexConfig(), key=key)
elif config_type == "float_inverted":
schema.delete_index(config=FloatInvertedIndexConfig(), key=key)
elif config_type == "bool_inverted":
schema.delete_index(config=BoolInvertedIndexConfig(), key=key)
# Vector, FTS, and sparse_vector deletion is not currently supported
return schema
@st.composite
def metadata_with_hnsw_strategy(draw: st.DrawFn) -> Optional[CollectionMetadata]:
"""Generate metadata with hnsw parameters."""
metadata: CollectionMetadata = {}
if draw(st.booleans()):
metadata["hnsw:space"] = draw(st.sampled_from(["cosine", "l2", "ip"]))
if draw(st.booleans()):
metadata["hnsw:construction_ef"] = draw(
st.integers(min_value=1, max_value=1000)
)
if draw(st.booleans()):
metadata["hnsw:search_ef"] = draw(st.integers(min_value=1, max_value=1000))
if draw(st.booleans()):
metadata["hnsw:M"] = draw(st.integers(min_value=1, max_value=1000))
if draw(st.booleans()):
metadata["hnsw:resize_factor"] = draw(st.floats(min_value=1.0, max_value=5.0))
if draw(st.booleans()):
metadata["hnsw:sync_threshold"] = draw(
st.integers(min_value=2, max_value=10000)
)
return metadata if metadata else None
@st.composite
def create_configuration_strategy(
draw: st.DrawFn,
) -> Optional[CreateCollectionConfiguration]:
"""Generate CreateCollectionConfiguration with mutual exclusivity rules."""
configuration: CreateCollectionConfiguration = {}
# Optionally set embedding_function (independent)
if draw(st.booleans()):
configuration["embedding_function"] = SimpleIpEmbeddingFunction(
dim=draw(st.integers(min_value=1, max_value=1000))
)
# Decide: set space only, OR set hnsw config, OR set spann config
config_choice = draw(
st.sampled_from(
["space_only_hnsw", "space_only_spann", "hnsw", "spann", "none"]
)
)
if config_choice == "space_only_hnsw":
configuration["hnsw"] = CreateHNSWConfiguration(
space=draw(st.sampled_from(["cosine", "l2", "ip"]))
)
elif config_choice == "space_only_spann":
configuration["spann"] = CreateSpannConfiguration(
space=draw(st.sampled_from(["cosine", "l2", "ip"]))
)
elif config_choice == "hnsw":
# Set hnsw config (optionally with space)
hnsw_config: CreateHNSWConfiguration = {}
if draw(st.booleans()):
hnsw_config["space"] = draw(st.sampled_from(["cosine", "l2", "ip"]))
hnsw_config["ef_construction"] = draw(st.integers(min_value=1, max_value=1000))
hnsw_config["ef_search"] = draw(st.integers(min_value=1, max_value=1000))
hnsw_config["max_neighbors"] = draw(st.integers(min_value=1, max_value=1000))
hnsw_config["sync_threshold"] = draw(st.integers(min_value=2, max_value=10000))
hnsw_config["resize_factor"] = draw(st.floats(min_value=1.0, max_value=5.0))
configuration["hnsw"] = hnsw_config
elif config_choice == "spann":
# Set spann config (optionally with space)
spann_config: CreateSpannConfiguration = {}
if draw(st.booleans()):
spann_config["space"] = draw(st.sampled_from(["cosine", "l2", "ip"]))
spann_config["search_nprobe"] = draw(st.integers(min_value=1, max_value=128))
spann_config["write_nprobe"] = draw(st.integers(min_value=1, max_value=64))
spann_config["ef_construction"] = draw(st.integers(min_value=1, max_value=200))
spann_config["ef_search"] = draw(st.integers(min_value=1, max_value=200))
spann_config["max_neighbors"] = draw(st.integers(min_value=1, max_value=64))
spann_config["reassign_neighbor_count"] = draw(
st.integers(min_value=1, max_value=64)
)
spann_config["split_threshold"] = draw(st.integers(min_value=50, max_value=200))
spann_config["merge_threshold"] = draw(st.integers(min_value=25, max_value=100))
configuration["spann"] = spann_config
return configuration if configuration else None
@dataclass
class CollectionInputCombination:
"""
Input tuple for collection creation tests.
"""
metadata: Optional[CollectionMetadata]
configuration: Optional[CreateCollectionConfiguration]
schema: Optional[Schema]
schema_vector_info: Optional[Dict[str, Any]]
kind: str
def non_none_items(items: Dict[str, Any]) -> Dict[str, Any]:
return {k: v for k, v in items.items() if v is not None}
def vector_index_to_dict(config: VectorIndexConfig) -> Dict[str, Any]:
embedding_default_space: Optional[str] = None
if config.embedding_function is not None and hasattr(
config.embedding_function, "default_space"
):
embedding_default_space = cast(str, config.embedding_function.default_space())
return {
"space": config.space,
"hnsw": config.hnsw.model_dump(exclude_none=True) if config.hnsw else None,
"spann": config.spann.model_dump(exclude_none=True) if config.spann else None,
"embedding_function_default_space": embedding_default_space,
}
@st.composite
def _schema_input_strategy(
draw: st.DrawFn,
) -> Tuple[Schema, Dict[str, Any]]:
schema = Schema()
vector_config = draw(vector_index_config_strategy())
schema.create_index(config=vector_config, key=None)
return schema, vector_index_to_dict(vector_config)
@st.composite
def metadata_configuration_schema_strategy(
draw: st.DrawFn,
) -> CollectionInputCombination:
"""
Generate compatible combinations of metadata, configuration, and schema inputs.
"""
choice = draw(
st.sampled_from(
[
"none",
"metadata",
"configuration",
"metadata_configuration",
"schema",
]
)
)
metadata: Optional[CollectionMetadata] = None
configuration: Optional[CreateCollectionConfiguration] = None
schema: Optional[Schema] = None
schema_info: Optional[Dict[str, Any]] = None
if choice in ("metadata", "metadata_configuration"):
metadata = draw(
metadata_with_hnsw_strategy().filter(
lambda value: value is not None and len(value) > 0
)
)
if choice in ("configuration", "metadata_configuration"):
configuration = draw(
create_configuration_strategy().filter(
lambda value: value is not None
and (
(value.get("hnsw") is not None and len(value["hnsw"]) > 0)
or (value.get("spann") is not None and len(value["spann"]) > 0)
)
)
)
if choice == "schema":
schema, schema_info = draw(_schema_input_strategy())
return CollectionInputCombination(
metadata=metadata,
configuration=configuration,
schema=schema,
schema_vector_info=schema_info,
kind=choice,
)
@dataclass
class Collection(ExternalCollection):
"""
An internal view of a collection.
This strategy contains all the information Chroma uses internally to manage a
collection. It is a superset of ExternalCollection and should be used to test
internal Chroma logic.
"""
id: uuid.UUID
dimension: int
dtype: npt.DTypeLike
known_metadata_keys: types.Metadata
known_document_keywords: List[str]
has_documents: bool = False
has_embeddings: bool = False
collection_config: Optional[CreateCollectionConfiguration] = None
@st.composite
def collections(
draw: st.DrawFn,
add_filterable_data: bool = False,
with_hnsw_params: bool = False,
has_embeddings: Optional[bool] = None,
has_documents: Optional[bool] = None,
with_persistent_hnsw_params: st.SearchStrategy[bool] = st.just(False),
max_hnsw_batch_size: int = 2000,
max_hnsw_sync_threshold: int = 2000,
) -> Collection:
"""Strategy to generate a Collection object. If add_filterable_data is True, then known_metadata_keys and known_document_keywords will be populated with consistent data."""
assert not ((has_embeddings is False) and (has_documents is False))
name = draw(collection_name())
metadata = draw(collection_metadata)
dimension = draw(st.integers(min_value=2, max_value=2048))
dtype = draw(st.sampled_from(float_types))
use_persistent_hnsw_params = draw(with_persistent_hnsw_params)
if use_persistent_hnsw_params and not with_hnsw_params:
raise ValueError(
"with_persistent_hnsw_params requires with_hnsw_params to be true"
)
if with_hnsw_params:
if metadata is None:
metadata = {}
metadata.update(test_hnsw_config)
if use_persistent_hnsw_params:
metadata["hnsw:sync_threshold"] = draw(
st.integers(min_value=3, max_value=max_hnsw_sync_threshold)
)
metadata["hnsw:batch_size"] = draw(
st.integers(
min_value=3,
max_value=min(
[metadata["hnsw:sync_threshold"], max_hnsw_batch_size]
),
)
)
# Sometimes, select a space at random
if draw(st.booleans()):
# TODO: pull the distance functions from a source of truth that lives not
# in tests once https://github.com/chroma-core/issues/issues/61 lands
metadata["hnsw:space"] = draw(st.sampled_from(["cosine", "l2", "ip"]))
collection_config: Optional[CreateCollectionConfiguration] = None
# Generate a spann config if in spann mode
if not is_spann_disabled_mode:
# Use metadata["hnsw:space"] if it exists, otherwise default to "l2"
spann_space = metadata.get("hnsw:space", "l2") if metadata else "l2"
spann_config: CreateSpannConfiguration = {
"space": spann_space,
"write_nprobe": 4,
"reassign_neighbor_count": 4,
}
collection_config = {
"spann": spann_config,
}
known_metadata_keys: Dict[str, Union[int, str, float]] = {}
if add_filterable_data:
while len(known_metadata_keys) < 5:
key = draw(safe_text)
known_metadata_keys[key] = draw(st.one_of(*safe_values))
if has_documents is None:
has_documents = draw(st.booleans())
assert has_documents is not None
# For cluster tests, we want to avoid generating documents and where_document
# clauses of length < 3. We also don't want them to contain certan special
# characters like _ and % that implicitly involve searching for a regex in sqlite.
if not NOT_CLUSTER_ONLY:
if has_documents and add_filterable_data:
known_document_keywords = draw(
st.lists(safe_text_min_size_3, min_size=5, max_size=5)
)
else:
known_document_keywords = []
else:
if has_documents and add_filterable_data:
known_document_keywords = draw(st.lists(safe_text, min_size=5, max_size=5))
else:
known_document_keywords = []
if not has_documents:
has_embeddings = True
else:
if has_embeddings is None:
has_embeddings = draw(st.booleans())
assert has_embeddings is not None
embedding_function = draw(embedding_function_strategy(dimension, dtype))
return Collection(
id=uuid.uuid4(),
name=name,
metadata=metadata,
dimension=dimension,
dtype=dtype,
known_metadata_keys=known_metadata_keys,
has_documents=has_documents,
known_document_keywords=known_document_keywords,
has_embeddings=has_embeddings,
embedding_function=embedding_function,
collection_config=collection_config,
)
@st.composite
def metadata(
draw: st.DrawFn,
collection: Collection,
min_size: int = 0,
max_size: Optional[int] = None,
) -> Optional[types.Metadata]:
"""Strategy for generating metadata that could be a part of the given collection"""
# First draw a random dictionary.
metadata: types.Metadata = draw(
st.dictionaries(
safe_text, st.one_of(*safe_values), min_size=min_size, max_size=max_size
)
)
# Then, remove keys that overlap with the known keys for the coll
# to avoid type errors when comparing.
if collection.known_metadata_keys:
for key in collection.known_metadata_keys.keys():
if key in metadata:
del metadata[key] # type: ignore
# Finally, add in some of the known keys for the collection
sampling_dict: Dict[str, st.SearchStrategy[Union[str, int, float]]] = {
k: st.just(v)
for k, v in collection.known_metadata_keys.items()
if isinstance(v, (str, int, float))
}
metadata.update(draw(st.fixed_dictionaries({}, optional=sampling_dict))) # type: ignore
# We don't allow submitting empty metadata
if metadata == {}:
return None
return metadata
@st.composite
def document(draw: st.DrawFn, collection: Collection) -> types.Document:
"""Strategy for generating documents that could be a part of the given collection"""
# For cluster tests, we want to avoid generating documents of length < 3.
# We also don't want them to contain certan special
# characters like _ and % that implicitly involve searching for a regex in sqlite.
if not NOT_CLUSTER_ONLY:
# Blacklist certain unicode characters that affect sqlite processing.
# For example, the null (/x00) character makes sqlite stop processing a string.
# Also, blacklist _ and % for cluster tests.
blacklist_categories = ("Cc", "Cs", "Pc", "Po")
if collection.known_document_keywords:
known_words_st = st.sampled_from(collection.known_document_keywords)
else:
known_words_st = st.text(
min_size=3,
alphabet=st.characters(blacklist_categories=blacklist_categories), # type: ignore
)
random_words_st = st.text(
min_size=3, alphabet=st.characters(blacklist_categories=blacklist_categories) # type: ignore
)
words = draw(st.lists(st.one_of(known_words_st, random_words_st), min_size=1))
return " ".join(words)
# Blacklist certain unicode characters that affect sqlite processing.
# For example, the null (/x00) character makes sqlite stop processing a string.
blacklist_categories = ("Cc", "Cs") # type: ignore
if collection.known_document_keywords:
known_words_st = st.sampled_from(collection.known_document_keywords)
else:
known_words_st = st.text(
min_size=1,
alphabet=st.characters(blacklist_categories=blacklist_categories), # type: ignore
)
random_words_st = st.text(
min_size=1, alphabet=st.characters(blacklist_categories=blacklist_categories) # type: ignore
)
words = draw(st.lists(st.one_of(known_words_st, random_words_st), min_size=1))
return " ".join(words)
@st.composite
def recordsets(
draw: st.DrawFn,
collection_strategy: SearchStrategy[Collection] = collections(),
id_strategy: SearchStrategy[str] = safe_text,
min_size: int = 1,
max_size: int = 50,
# If num_unique_metadata is not None, then the number of metadata generations
# will be the size of the record set. If set, the number of metadata
# generations will be the value of num_unique_metadata.
num_unique_metadata: Optional[int] = None,
min_metadata_size: int = 0,
max_metadata_size: Optional[int] = None,
) -> RecordSet:
collection = draw(collection_strategy)
ids = list(
draw(st.lists(id_strategy, min_size=min_size, max_size=max_size, unique=True))
)
embeddings: Optional[Embeddings] = None
if collection.has_embeddings:
embeddings = create_embeddings(collection.dimension, len(ids), collection.dtype)
num_metadata = num_unique_metadata if num_unique_metadata is not None else len(ids)
generated_metadatas = draw(
st.lists(
metadata(
collection, min_size=min_metadata_size, max_size=max_metadata_size
),
min_size=num_metadata,
max_size=num_metadata,
)
)
metadatas = []
for i in range(len(ids)):
metadatas.append(generated_metadatas[i % len(generated_metadatas)])
documents: Optional[Documents] = None
if collection.has_documents:
documents = draw(
st.lists(document(collection), min_size=len(ids), max_size=len(ids))
)
# in the case where we have a single record, sometimes exercise
# the code that handles individual values rather than lists.
# In this case, any field may be a list or a single value.
if len(ids) == 1:
single_id: Union[str, List[str]] = ids[0] if draw(st.booleans()) else ids
single_embedding = (
embeddings[0]
if embeddings is not None and draw(st.booleans())
else embeddings
)
single_metadata: Union[Optional[Metadata], List[Optional[Metadata]]] = (
metadatas[0] if draw(st.booleans()) else metadatas
)
single_document = (
documents[0] if documents is not None and draw(st.booleans()) else documents
)
return {
"ids": single_id,
"embeddings": single_embedding,
"metadatas": single_metadata,
"documents": single_document,
}
return {
"ids": ids,
"embeddings": embeddings,
"metadatas": metadatas,
"documents": documents,
}
def opposite_value(value: LiteralValue) -> SearchStrategy[Any]:
"""
Returns a strategy that will generate all valid values except the input value - testing of $nin
"""
if isinstance(value, float):
return safe_floats.filter(lambda x: x != value)
elif isinstance(value, str):
return safe_text.filter(lambda x: x != value)
elif isinstance(value, bool):
return st.booleans().filter(lambda x: x != value)
elif isinstance(value, int):
return st.integers(min_value=-(2**31), max_value=2**31 - 1).filter(
lambda x: x != value
)
else:
return st.from_type(type(value)).filter(lambda x: x != value)
@st.composite
def where_clause(draw: st.DrawFn, collection: Collection) -> types.Where:
"""Generate a filter that could be used in a query against the given collection"""
known_keys = sorted(collection.known_metadata_keys.keys())
key = draw(st.sampled_from(known_keys))
value = collection.known_metadata_keys[key]
legal_ops: List[Optional[str]] = [None]
if isinstance(value, bool):
legal_ops.extend(["$eq", "$ne", "$in", "$nin"])
elif isinstance(value, float):
legal_ops.extend(["$gt", "$lt", "$lte", "$gte"])
elif isinstance(value, int):
legal_ops.extend(["$gt", "$lt", "$lte", "$gte", "$eq", "$ne", "$in", "$nin"])
elif isinstance(value, str):
legal_ops.extend(["$eq", "$ne", "$in", "$nin"])
else:
assert False, f"Unsupported type: {type(value)}"
if isinstance(value, float):
# Add or subtract a small number to avoid floating point rounding errors
value = value + draw(st.sampled_from([1e-6, -1e-6]))
# Truncate to 32 bit
value = float(np.float32(value))
op: WhereOperator = draw(st.sampled_from(legal_ops))
if op is None:
return {key: value}
elif op == "$in": # type: ignore
if isinstance(value, str) and not value:
return {}
return {key: {op: [value, *[draw(opposite_value(value)) for _ in range(3)]]}}
elif op == "$nin": # type: ignore
if isinstance(value, str) and not value:
return {}
return {key: {op: [draw(opposite_value(value)) for _ in range(3)]}}
else:
return {key: {op: value}} # type: ignore
@st.composite
def where_doc_clause(draw: st.DrawFn, collection: Collection) -> types.WhereDocument:
"""Generate a where_document filter that could be used against the given collection"""
# For cluster tests, we want to avoid generating where_document
# clauses of length < 3. We also don't want them to contain certan special
# characters like _ and % that implicitly involve searching for a regex in sqlite.
if not NOT_CLUSTER_ONLY:
if collection.known_document_keywords:
word = draw(st.sampled_from(collection.known_document_keywords))
else:
word = draw(safe_text_min_size_3)
else:
if collection.known_document_keywords:
word = draw(st.sampled_from(collection.known_document_keywords))
else:
word = draw(safe_text)
# This is hacky, but the distributed system does not support $not_contains
# so we need to avoid generating these operators for now in that case.
# TODO: Remove this once the distributed system supports $not_contains
op = draw(st.sampled_from(["$contains", "$not_contains"]))
if op == "$contains":
return {"$contains": word}
else:
assert op == "$not_contains"
return {"$not_contains": word}
def binary_operator_clause(
base_st: SearchStrategy[types.Where],
) -> SearchStrategy[types.Where]:
op: SearchStrategy[LogicalOperator] = st.sampled_from(["$and", "$or"])
return st.dictionaries(
keys=op,
values=st.lists(base_st, max_size=2, min_size=2),
min_size=1,
max_size=1,
)
def binary_document_operator_clause(
base_st: SearchStrategy[types.WhereDocument],
) -> SearchStrategy[types.WhereDocument]:
op: SearchStrategy[LogicalOperator] = st.sampled_from(["$and", "$or"])
return st.dictionaries(
keys=op,
values=st.lists(base_st, max_size=2, min_size=2),
min_size=1,
max_size=1,
)
@st.composite
def recursive_where_clause(draw: st.DrawFn, collection: Collection) -> types.Where:
base_st = where_clause(collection)
where: types.Where = draw(st.recursive(base_st, binary_operator_clause))
return where
@st.composite
def recursive_where_doc_clause(
draw: st.DrawFn, collection: Collection
) -> types.WhereDocument:
base_st = where_doc_clause(collection)
where: types.WhereDocument = draw(
st.recursive(base_st, binary_document_operator_clause)
)
return where
class Filter(TypedDict):
where: Optional[types.Where]
ids: Optional[Union[str, List[str]]]
where_document: Optional[types.WhereDocument]
@st.composite
def filters(
draw: st.DrawFn,
collection_st: st.SearchStrategy[Collection],
recordset_st: st.SearchStrategy[RecordSet],
include_all_ids: bool = False,
) -> Filter:
collection = draw(collection_st)
recordset = draw(recordset_st)
where_clause = draw(st.one_of(st.none(), recursive_where_clause(collection)))
where_document_clause = draw(
st.one_of(st.none(), recursive_where_doc_clause(collection))
)
ids: Optional[Union[List[types.ID], types.ID]]
# Record sets can be a value instead of a list of values if there is only one record
if isinstance(recordset["ids"], str):
ids = [recordset["ids"]]
else:
ids = recordset["ids"]
if not include_all_ids:
ids = draw(st.one_of(st.none(), st.lists(st.sampled_from(ids), min_size=1)))
if ids is not None:
# Remove duplicates since hypothesis samples with replacement
ids = list(set(ids))
# Test both the single value list and the unwrapped single value case
if ids is not None and len(ids) == 1 and draw(st.booleans()):
ids = ids[0]
return {"where": where_clause, "where_document": where_document_clause, "ids": ids}