group-wbl/.venv/lib/python3.13/site-packages/chromadb/api/models/CollectionCommon.py
2026-01-09 09:48:03 +08:00

1028 lines
35 KiB
Python

import functools
from typing import (
TYPE_CHECKING,
Callable,
Dict,
Generic,
Optional,
Any,
Set,
TypeVar,
Union,
cast,
List,
)
from chromadb.types import Metadata
import numpy as np
from uuid import UUID
from chromadb.api.types import (
URI,
Schema,
SparseVectorIndexConfig,
URIs,
AddRequest,
BaseRecordSet,
CollectionMetadata,
DataLoader,
DeleteRequest,
Embedding,
Embeddings,
FilterSet,
GetRequest,
PyEmbedding,
Embeddable,
GetResult,
Include,
Loadable,
Document,
Image,
QueryRequest,
QueryResult,
IDs,
EmbeddingFunction,
SparseEmbeddingFunction,
ID,
OneOrMany,
UpdateRequest,
UpsertRequest,
get_default_embeddable_record_set_fields,
maybe_cast_one_to_many,
normalize_base_record_set,
normalize_insert_record_set,
validate_base_record_set,
validate_ids,
validate_include,
validate_insert_record_set,
validate_metadata,
validate_metadatas,
validate_embedding_function,
validate_sparse_embedding_function,
validate_n_results,
validate_record_set_contains_any,
validate_record_set_for_embedding,
validate_filter_set,
DefaultEmbeddingFunction,
EMBEDDING_KEY,
DOCUMENT_KEY,
)
from chromadb.api.collection_configuration import (
UpdateCollectionConfiguration,
overwrite_collection_configuration,
load_collection_configuration_from_json,
CollectionConfiguration,
)
# TODO: We should rename the types in chromadb.types to be Models where
# appropriate. This will help to distinguish between manipulation objects
# which are essentially API views. And the actual data models which are
# stored / retrieved / transmitted.
from chromadb.types import Collection as CollectionModel, Where, WhereDocument
import logging
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from chromadb.api import ServerAPI, AsyncServerAPI
ClientT = TypeVar("ClientT", "ServerAPI", "AsyncServerAPI")
T = TypeVar("T")
def validation_context(name: str) -> Callable[[Callable[..., T]], Callable[..., T]]:
"""A decorator that wraps a method with a try-except block that catches
exceptions and adds the method name to the error message. This allows us to
provide more context when an error occurs, without rewriting validators.
"""
def decorator(func: Callable[..., T]) -> Callable[..., T]:
@functools.wraps(func)
def wrapper(self: Any, *args: Any, **kwargs: Any) -> T:
try:
return func(self, *args, **kwargs)
except Exception as e:
msg = f"{str(e)} in {name}."
# add the rest of the args to the error message if they exist
e.args = (msg,) + e.args[1:] if e.args else ()
# raise the same error that was caught with the modified message
raise
return wrapper
return decorator
class CollectionCommon(Generic[ClientT]):
_model: CollectionModel
_client: ClientT
_embedding_function: Optional[EmbeddingFunction[Embeddable]]
_data_loader: Optional[DataLoader[Loadable]]
def __init__(
self,
client: ClientT,
model: CollectionModel,
embedding_function: Optional[
EmbeddingFunction[Embeddable]
] = DefaultEmbeddingFunction(), # type: ignore
data_loader: Optional[DataLoader[Loadable]] = None,
):
"""Initializes a new instance of the Collection class."""
self._client = client
self._model = model
# Check to make sure the embedding function has the right signature, as defined by the EmbeddingFunction protocol
if embedding_function is not None:
validate_embedding_function(embedding_function)
self._embedding_function = embedding_function
self._data_loader = data_loader
# Expose the model properties as read-only properties on the Collection class
@property
def id(self) -> UUID:
return self._model.id
@property
def name(self) -> str:
return self._model.name
@property
def configuration(self) -> CollectionConfiguration:
return load_collection_configuration_from_json(self._model.configuration_json)
@property
def configuration_json(self) -> Dict[str, Any]:
return self._model.configuration_json
@property
def schema(self) -> Optional[Schema]:
return Schema.deserialize_from_json(
self._model.serialized_schema if self._model.serialized_schema else {}
)
@property
def metadata(self) -> CollectionMetadata:
return cast(CollectionMetadata, self._model.metadata)
@property
def tenant(self) -> str:
return self._model.tenant
@property
def database(self) -> str:
return self._model.database
def __eq__(self, other: object) -> bool:
if not isinstance(other, CollectionCommon):
return False
id_match = self.id == other.id
name_match = self.name == other.name
configuration_match = self.configuration_json == other.configuration_json
schema_match = self.schema == other.schema
metadata_match = self.metadata == other.metadata
tenant_match = self.tenant == other.tenant
database_match = self.database == other.database
embedding_function_match = self._embedding_function == other._embedding_function
data_loader_match = self._data_loader == other._data_loader
return (
id_match
and name_match
and configuration_match
and schema_match
and metadata_match
and tenant_match
and database_match
and embedding_function_match
and data_loader_match
)
def __repr__(self) -> str:
return f"Collection(name={self.name})"
def get_model(self) -> CollectionModel:
return self._model
@validation_context("add")
def _validate_and_prepare_add_request(
self,
ids: OneOrMany[ID],
embeddings: Optional[
Union[
OneOrMany[Embedding],
OneOrMany[PyEmbedding],
]
],
metadatas: Optional[OneOrMany[Metadata]],
documents: Optional[OneOrMany[Document]],
images: Optional[OneOrMany[Image]],
uris: Optional[OneOrMany[URI]],
) -> AddRequest:
# Unpack
add_records = normalize_insert_record_set(
ids=ids,
embeddings=embeddings,
metadatas=metadatas,
documents=documents,
images=images,
uris=uris,
)
# Validate
validate_insert_record_set(record_set=add_records)
validate_record_set_contains_any(record_set=add_records, contains_any={"ids"})
# Prepare
if add_records["embeddings"] is None:
validate_record_set_for_embedding(record_set=add_records)
add_embeddings = self._embed_record_set(record_set=add_records)
else:
add_embeddings = add_records["embeddings"]
add_metadatas = self._apply_sparse_embeddings_to_metadatas(
add_records["metadatas"], add_records["documents"]
)
return AddRequest(
ids=add_records["ids"],
embeddings=add_embeddings,
metadatas=add_metadatas,
documents=add_records["documents"],
uris=add_records["uris"],
)
@validation_context("get")
def _validate_and_prepare_get_request(
self,
ids: Optional[OneOrMany[ID]],
where: Optional[Where],
where_document: Optional[WhereDocument],
include: Include,
) -> GetRequest:
# Unpack
unpacked_ids: Optional[IDs] = maybe_cast_one_to_many(target=ids)
filters = FilterSet(where=where, where_document=where_document)
# Validate
if unpacked_ids is not None:
validate_ids(ids=unpacked_ids)
validate_filter_set(filter_set=filters)
validate_include(include=include, dissalowed=["distances"])
if "data" in include and self._data_loader is None:
raise ValueError(
"You must set a data loader on the collection if loading from URIs."
)
# Prepare
request_include = include
# We need to include uris in the result from the API to load datas
if "data" in include and "uris" not in include:
request_include.append("uris")
return GetRequest(
ids=unpacked_ids,
where=filters["where"],
where_document=filters["where_document"],
include=request_include,
)
@validation_context("query")
def _validate_and_prepare_query_request(
self,
query_embeddings: Optional[
Union[
OneOrMany[Embedding],
OneOrMany[PyEmbedding],
]
],
query_texts: Optional[OneOrMany[Document]],
query_images: Optional[OneOrMany[Image]],
query_uris: Optional[OneOrMany[URI]],
ids: Optional[OneOrMany[ID]],
n_results: int,
where: Optional[Where],
where_document: Optional[WhereDocument],
include: Include,
) -> QueryRequest:
# Unpack
query_records = normalize_base_record_set(
embeddings=query_embeddings,
documents=query_texts,
images=query_images,
uris=query_uris,
)
filter_ids = maybe_cast_one_to_many(ids)
filters = FilterSet(
where=where,
where_document=where_document,
)
# Validate
validate_base_record_set(record_set=query_records)
validate_filter_set(filter_set=filters)
validate_include(include=include)
validate_n_results(n_results=n_results)
# Prepare
if query_records["embeddings"] is None:
validate_record_set_for_embedding(record_set=query_records)
request_embeddings = self._embed_record_set(
record_set=query_records, is_query=True
)
else:
request_embeddings = query_records["embeddings"]
request_where = filters["where"]
request_where_document = filters["where_document"]
# We need to manually include uris in the result from the API to load datas
request_include = include
if "data" in request_include and "uris" not in request_include:
request_include.append("uris")
return QueryRequest(
embeddings=request_embeddings,
ids=filter_ids,
where=request_where,
where_document=request_where_document,
include=request_include,
n_results=n_results,
)
@validation_context("update")
def _validate_and_prepare_update_request(
self,
ids: OneOrMany[ID],
embeddings: Optional[
Union[
OneOrMany[Embedding],
OneOrMany[PyEmbedding],
]
],
metadatas: Optional[OneOrMany[Metadata]],
documents: Optional[OneOrMany[Document]],
images: Optional[OneOrMany[Image]],
uris: Optional[OneOrMany[URI]],
) -> UpdateRequest:
# Unpack
update_records = normalize_insert_record_set(
ids=ids,
embeddings=embeddings,
metadatas=metadatas,
documents=documents,
images=images,
uris=uris,
)
# Validate
validate_insert_record_set(record_set=update_records)
# Prepare
if update_records["embeddings"] is None:
# TODO: Handle URI updates.
if (
update_records["documents"] is not None
or update_records["images"] is not None
):
validate_record_set_for_embedding(
update_records, embeddable_fields={"documents", "images"}
)
update_embeddings = self._embed_record_set(record_set=update_records)
else:
update_embeddings = None
else:
update_embeddings = update_records["embeddings"]
update_metadatas = self._apply_sparse_embeddings_to_metadatas(
update_records["metadatas"], update_records["documents"]
)
return UpdateRequest(
ids=update_records["ids"],
embeddings=update_embeddings,
metadatas=update_metadatas,
documents=update_records["documents"],
uris=update_records["uris"],
)
@validation_context("upsert")
def _validate_and_prepare_upsert_request(
self,
ids: OneOrMany[ID],
embeddings: Optional[
Union[
OneOrMany[Embedding],
OneOrMany[PyEmbedding],
]
] = None,
metadatas: Optional[OneOrMany[Metadata]] = None,
documents: Optional[OneOrMany[Document]] = None,
images: Optional[OneOrMany[Image]] = None,
uris: Optional[OneOrMany[URI]] = None,
) -> UpsertRequest:
# Unpack
upsert_records = normalize_insert_record_set(
ids=ids,
embeddings=embeddings,
metadatas=metadatas,
documents=documents,
images=images,
uris=uris,
)
# Validate
validate_insert_record_set(record_set=upsert_records)
# Prepare
if upsert_records["embeddings"] is None:
validate_record_set_for_embedding(
record_set=upsert_records, embeddable_fields={"documents", "images"}
)
upsert_embeddings = self._embed_record_set(record_set=upsert_records)
else:
upsert_embeddings = upsert_records["embeddings"]
upsert_metadatas = self._apply_sparse_embeddings_to_metadatas(
upsert_records["metadatas"], upsert_records["documents"]
)
return UpsertRequest(
ids=upsert_records["ids"],
metadatas=upsert_metadatas,
embeddings=upsert_embeddings,
documents=upsert_records["documents"],
uris=upsert_records["uris"],
)
@validation_context("delete")
def _validate_and_prepare_delete_request(
self,
ids: Optional[IDs],
where: Optional[Where],
where_document: Optional[WhereDocument],
) -> DeleteRequest:
if ids is None and where is None and where_document is None:
raise ValueError(
"At least one of ids, where, or where_document must be provided"
)
# Unpack
if ids is not None:
request_ids = cast(IDs, maybe_cast_one_to_many(ids))
else:
request_ids = None
filters = FilterSet(where=where, where_document=where_document)
# Validate
if request_ids is not None:
validate_ids(ids=request_ids)
validate_filter_set(filter_set=filters)
return DeleteRequest(
ids=request_ids, where=where, where_document=where_document
)
def _transform_peek_response(self, response: GetResult) -> GetResult:
if response["embeddings"] is not None:
response["embeddings"] = np.array(response["embeddings"])
return response
def _transform_get_response(
self, response: GetResult, include: Include
) -> GetResult:
if (
"data" in include
and self._data_loader is not None
and response["uris"] is not None
):
response["data"] = self._data_loader(response["uris"])
if "embeddings" in include:
response["embeddings"] = np.array(response["embeddings"])
# Remove URIs from the result if they weren't requested
if "uris" not in include:
response["uris"] = None
return response
def _transform_query_response(
self, response: QueryResult, include: Include
) -> QueryResult:
if (
"data" in include
and self._data_loader is not None
and response["uris"] is not None
):
response["data"] = [self._data_loader(uris) for uris in response["uris"]]
if "embeddings" in include and response["embeddings"] is not None:
response["embeddings"] = [
np.array(embedding) for embedding in response["embeddings"]
]
# Remove URIs from the result if they weren't requested
if "uris" not in include:
response["uris"] = None
return response
def _validate_modify_request(self, metadata: Optional[CollectionMetadata]) -> None:
if metadata is not None:
validate_metadata(metadata)
if "hnsw:space" in metadata:
raise ValueError(
"Changing the distance function of a collection once it is created is not supported currently."
)
def _update_model_after_modify_success(
self,
name: Optional[str],
metadata: Optional[CollectionMetadata],
configuration: Optional[UpdateCollectionConfiguration],
) -> None:
if name:
self._model["name"] = name
if metadata:
self._model["metadata"] = metadata
if configuration:
self._model.set_configuration(
overwrite_collection_configuration(
self._model.get_configuration(), configuration
)
)
# If schema exists, also update it with the configuration changes
if self.schema:
from chromadb.api.collection_configuration import (
update_schema_from_collection_configuration,
)
updated_schema = update_schema_from_collection_configuration(
self.schema, configuration
)
self._model["serialized_schema"] = updated_schema.serialize_to_json()
def _get_sparse_embedding_targets(self) -> Dict[str, "SparseVectorIndexConfig"]:
schema = self.schema
if schema is None:
return {}
targets: Dict[str, "SparseVectorIndexConfig"] = {}
for key, value_types in schema.keys.items():
if value_types.sparse_vector is None:
continue
sparse_index = value_types.sparse_vector.sparse_vector_index
if sparse_index is None or not sparse_index.enabled:
continue
config = sparse_index.config
if config.embedding_function is None or config.source_key is None:
continue
targets[key] = config
return targets
def _apply_sparse_embeddings_to_metadatas(
self,
metadatas: Optional[List[Metadata]],
documents: Optional[List[Document]] = None,
) -> Optional[List[Metadata]]:
sparse_targets = self._get_sparse_embedding_targets()
if not sparse_targets:
return metadatas
# If no metadatas provided, create empty dicts based on documents length
if metadatas is None:
if documents is None:
return None
metadatas = [{} for _ in range(len(documents))]
# Create copies, converting None to empty dict
updated_metadatas: List[Dict[str, Any]] = [
dict(metadata) if metadata is not None else {} for metadata in metadatas
]
documents_list = list(documents) if documents is not None else None
for target_key, config in sparse_targets.items():
source_key = config.source_key
embedding_func = config.embedding_function
if source_key is None or embedding_func is None:
continue
if not isinstance(embedding_func, SparseEmbeddingFunction):
embedding_func = cast(SparseEmbeddingFunction[Any], embedding_func)
validate_sparse_embedding_function(embedding_func)
# Initialize collection lists for batch processing
inputs: List[str] = []
positions: List[int] = []
# Handle special case: source_key is "#document"
if source_key == DOCUMENT_KEY:
if documents_list is None:
continue
# Collect documents that need embedding
for idx, metadata in enumerate(updated_metadatas):
# Skip if target already exists in metadata
if target_key in metadata:
continue
# Get document at this position
if idx < len(documents_list):
doc = documents_list[idx]
if isinstance(doc, str):
inputs.append(doc)
positions.append(idx)
# Generate embeddings for all collected documents
if len(inputs) == 0:
continue
sparse_embeddings = self._sparse_embed(
input=inputs,
sparse_embedding_function=embedding_func,
)
if len(sparse_embeddings) != len(positions):
raise ValueError(
"Sparse embedding function returned unexpected number of embeddings."
)
for position, embedding in zip(positions, sparse_embeddings):
updated_metadatas[position][target_key] = embedding
continue # Skip the metadata-based logic below
# Handle normal case: source_key is a metadata field
for idx, metadata in enumerate(updated_metadatas):
if target_key in metadata:
continue
source_value = metadata.get(source_key)
if not isinstance(source_value, str):
continue
inputs.append(source_value)
positions.append(idx)
if len(inputs) == 0:
continue
sparse_embeddings = self._sparse_embed(
input=inputs,
sparse_embedding_function=embedding_func,
)
if len(sparse_embeddings) != len(positions):
raise ValueError(
"Sparse embedding function returned unexpected number of embeddings."
)
for position, embedding in zip(positions, sparse_embeddings):
updated_metadatas[position][target_key] = embedding
# Convert empty dicts back to None, validation requires non-empty dicts or None
result_metadatas: List[Optional[Metadata]] = [
metadata if metadata else None for metadata in updated_metadatas
]
validate_metadatas(cast(List[Metadata], result_metadatas))
return cast(List[Metadata], result_metadatas)
def _embed_record_set(
self,
record_set: BaseRecordSet,
embeddable_fields: Optional[Set[str]] = None,
is_query: bool = False,
) -> Embeddings:
if embeddable_fields is None:
embeddable_fields = get_default_embeddable_record_set_fields()
for field in embeddable_fields:
if record_set[field] is not None: # type: ignore[literal-required]
# uris require special handling
if field == "uris":
if self._data_loader is None:
raise ValueError(
"You must set a data loader on the collection if loading from URIs."
)
return self._embed(
input=self._data_loader(uris=cast(URIs, record_set[field])), # type: ignore[literal-required]
is_query=is_query,
)
else:
return self._embed(
input=record_set[field], # type: ignore[literal-required]
is_query=is_query,
)
raise ValueError(
"Record does not contain any non-None fields that can be embedded."
f"Embeddable Fields: {embeddable_fields}"
f"Record Fields: {record_set}"
)
def _embed(self, input: Any, is_query: bool = False) -> Embeddings:
if self._embedding_function is not None and not isinstance(
self._embedding_function, DefaultEmbeddingFunction
):
if is_query:
return self._embedding_function.embed_query(input=input)
else:
return self._embedding_function(input=input)
config_ef = self.configuration.get("embedding_function")
if config_ef is not None:
if is_query:
return config_ef.embed_query(input=input)
else:
return config_ef(input=input)
schema = self.schema
schema_embedding_function: Optional[EmbeddingFunction[Embeddable]] = None
if schema is not None:
override = schema.keys.get(EMBEDDING_KEY)
if (
override is not None
and override.float_list is not None
and override.float_list.vector_index is not None
and override.float_list.vector_index.config.embedding_function
is not None
):
schema_embedding_function = cast(
EmbeddingFunction[Embeddable],
override.float_list.vector_index.config.embedding_function,
)
elif (
schema.defaults.float_list is not None
and schema.defaults.float_list.vector_index is not None
and schema.defaults.float_list.vector_index.config.embedding_function
is not None
):
schema_embedding_function = cast(
EmbeddingFunction[Embeddable],
schema.defaults.float_list.vector_index.config.embedding_function,
)
if schema_embedding_function is not None:
if is_query and hasattr(schema_embedding_function, "embed_query"):
return schema_embedding_function.embed_query(input=input)
return schema_embedding_function(input=input)
if self._embedding_function is None:
raise ValueError(
"You must provide an embedding function to compute embeddings."
"https://docs.trychroma.com/guides/embeddings"
)
if is_query:
return self._embedding_function.embed_query(input=input)
else:
return self._embedding_function(input=input)
def _sparse_embed(
self,
input: Any,
sparse_embedding_function: SparseEmbeddingFunction[Any],
is_query: bool = False,
) -> Any:
if is_query:
return sparse_embedding_function.embed_query(input=input)
return sparse_embedding_function(input=input)
def _embed_knn_string_queries(self, knn: Any) -> Any:
"""Embed string queries in Knn objects using the appropriate embedding function.
Args:
knn: A Knn object that may have a string query
Returns:
A Knn object with the string query replaced by an embedding
Raises:
ValueError: If the query is a string but no embedding function is available
"""
from chromadb.execution.expression.operator import Knn
if not isinstance(knn, Knn):
return knn
# If query is not a string, nothing to do
if not isinstance(knn.query, str):
return knn
query_text = knn.query
key = knn.key
# Handle main embedding field
if key == EMBEDDING_KEY:
# Use the collection's main embedding function
embedding = self._embed(input=[query_text], is_query=True)
if not embedding or len(embedding) != 1:
raise ValueError(
"Embedding function returned unexpected number of embeddings"
)
# Return a new Knn with the embedded query
return Knn(
query=embedding[0],
key=knn.key,
limit=knn.limit,
default=knn.default,
return_rank=knn.return_rank,
)
# Handle metadata field with potential sparse embedding
schema = self.schema
if schema is None or key not in schema.keys:
raise ValueError(
f"Cannot embed string query for key '{key}': "
f"key not found in schema. Please provide an embedded vector or "
f"configure an embedding function for this key in the schema."
)
value_type = schema.keys[key]
# Check for sparse vector with embedding function
if value_type.sparse_vector is not None:
sparse_index = value_type.sparse_vector.sparse_vector_index
if sparse_index is not None and sparse_index.enabled:
sparse_config = sparse_index.config
if sparse_config.embedding_function is not None:
embedding_func = sparse_config.embedding_function
if not isinstance(embedding_func, SparseEmbeddingFunction):
embedding_func = cast(
SparseEmbeddingFunction[Any], embedding_func
)
validate_sparse_embedding_function(embedding_func)
# Embed the query
sparse_embedding = self._sparse_embed(
input=[query_text],
sparse_embedding_function=embedding_func,
is_query=True,
)
if not sparse_embedding or len(sparse_embedding) != 1:
raise ValueError(
"Sparse embedding function returned unexpected number of embeddings"
)
# Return a new Knn with the sparse embedding
return Knn(
query=sparse_embedding[0],
key=knn.key,
limit=knn.limit,
default=knn.default,
return_rank=knn.return_rank,
)
# Check for dense vector with embedding function (float_list)
if value_type.float_list is not None:
vector_index = value_type.float_list.vector_index
if vector_index is not None and vector_index.enabled:
dense_config = vector_index.config
if dense_config.embedding_function is not None:
embedding_func = dense_config.embedding_function
validate_embedding_function(embedding_func)
# Embed the query using the schema's embedding function
try:
embeddings = embedding_func.embed_query(input=[query_text])
except AttributeError:
# Fallback if embed_query doesn't exist
embeddings = embedding_func([query_text])
if not embeddings or len(embeddings) != 1:
raise ValueError(
"Embedding function returned unexpected number of embeddings"
)
# Return a new Knn with the dense embedding
return Knn(
query=embeddings[0],
key=knn.key,
limit=knn.limit,
default=knn.default,
return_rank=knn.return_rank,
)
raise ValueError(
f"Cannot embed string query for key '{key}': "
f"no embedding function configured for this key in the schema. "
f"Please provide an embedded vector or configure an embedding function."
)
def _embed_rank_string_queries(self, rank: Any) -> Any:
"""Recursively embed string queries in Rank expressions.
Args:
rank: A Rank expression that may contain Knn objects with string queries
Returns:
A Rank expression with all string queries embedded
"""
# Import here to avoid circular dependency
from chromadb.execution.expression.operator import (
Knn,
Abs,
Div,
Exp,
Log,
Max,
Min,
Mul,
Sub,
Sum,
Val,
Rrf,
)
if rank is None:
return None
# Base case: Knn - embed if it has a string query
if isinstance(rank, Knn):
return self._embed_knn_string_queries(rank)
# Base case: Val - no embedding needed
if isinstance(rank, Val):
return rank
# Recursive cases: walk through child ranks
if isinstance(rank, Abs):
return Abs(self._embed_rank_string_queries(rank.rank))
if isinstance(rank, Div):
return Div(
self._embed_rank_string_queries(rank.left),
self._embed_rank_string_queries(rank.right),
)
if isinstance(rank, Exp):
return Exp(self._embed_rank_string_queries(rank.rank))
if isinstance(rank, Log):
return Log(self._embed_rank_string_queries(rank.rank))
if isinstance(rank, Max):
return Max([self._embed_rank_string_queries(r) for r in rank.ranks])
if isinstance(rank, Min):
return Min([self._embed_rank_string_queries(r) for r in rank.ranks])
if isinstance(rank, Mul):
return Mul([self._embed_rank_string_queries(r) for r in rank.ranks])
if isinstance(rank, Sub):
return Sub(
self._embed_rank_string_queries(rank.left),
self._embed_rank_string_queries(rank.right),
)
if isinstance(rank, Sum):
return Sum([self._embed_rank_string_queries(r) for r in rank.ranks])
if isinstance(rank, Rrf):
return Rrf(
ranks=[self._embed_rank_string_queries(r) for r in rank.ranks],
k=rank.k,
weights=rank.weights,
normalize=rank.normalize,
)
# Unknown rank type - return as is
return rank
def _embed_search_string_queries(self, search: Any) -> Any:
"""Embed string queries in a Search object.
Args:
search: A Search object that may contain Knn objects with string queries
Returns:
A Search object with all string queries embedded
"""
# Import here to avoid circular dependency
from chromadb.execution.expression.plan import Search
if not isinstance(search, Search):
return search
# Embed the rank expression if it exists
embedded_rank = self._embed_rank_string_queries(search._rank)
# Create a new Search with the embedded rank
return Search(
where=search._where,
rank=embedded_rank,
group_by=search._group_by,
limit=search._limit,
select=search._select,
)