1028 lines
35 KiB
Python
1028 lines
35 KiB
Python
import functools
|
|
from typing import (
|
|
TYPE_CHECKING,
|
|
Callable,
|
|
Dict,
|
|
Generic,
|
|
Optional,
|
|
Any,
|
|
Set,
|
|
TypeVar,
|
|
Union,
|
|
cast,
|
|
List,
|
|
)
|
|
from chromadb.types import Metadata
|
|
import numpy as np
|
|
from uuid import UUID
|
|
|
|
from chromadb.api.types import (
|
|
URI,
|
|
Schema,
|
|
SparseVectorIndexConfig,
|
|
URIs,
|
|
AddRequest,
|
|
BaseRecordSet,
|
|
CollectionMetadata,
|
|
DataLoader,
|
|
DeleteRequest,
|
|
Embedding,
|
|
Embeddings,
|
|
FilterSet,
|
|
GetRequest,
|
|
PyEmbedding,
|
|
Embeddable,
|
|
GetResult,
|
|
Include,
|
|
Loadable,
|
|
Document,
|
|
Image,
|
|
QueryRequest,
|
|
QueryResult,
|
|
IDs,
|
|
EmbeddingFunction,
|
|
SparseEmbeddingFunction,
|
|
ID,
|
|
OneOrMany,
|
|
UpdateRequest,
|
|
UpsertRequest,
|
|
get_default_embeddable_record_set_fields,
|
|
maybe_cast_one_to_many,
|
|
normalize_base_record_set,
|
|
normalize_insert_record_set,
|
|
validate_base_record_set,
|
|
validate_ids,
|
|
validate_include,
|
|
validate_insert_record_set,
|
|
validate_metadata,
|
|
validate_metadatas,
|
|
validate_embedding_function,
|
|
validate_sparse_embedding_function,
|
|
validate_n_results,
|
|
validate_record_set_contains_any,
|
|
validate_record_set_for_embedding,
|
|
validate_filter_set,
|
|
DefaultEmbeddingFunction,
|
|
EMBEDDING_KEY,
|
|
DOCUMENT_KEY,
|
|
)
|
|
from chromadb.api.collection_configuration import (
|
|
UpdateCollectionConfiguration,
|
|
overwrite_collection_configuration,
|
|
load_collection_configuration_from_json,
|
|
CollectionConfiguration,
|
|
)
|
|
|
|
# TODO: We should rename the types in chromadb.types to be Models where
# appropriate. This will help to distinguish between manipulation objects,
# which are essentially API views, and the actual data models, which are
# stored / retrieved / transmitted.
|
|
from chromadb.types import Collection as CollectionModel, Where, WhereDocument
|
|
import logging
|
|
|
|
# Logger for this module, named after the module itself.
logger = logging.getLogger(name=__name__)

if TYPE_CHECKING:
    # Needed only for the static type of ClientT below; not imported at runtime.
    from chromadb.api import AsyncServerAPI, ServerAPI

# Client type a collection is parameterized by: ServerAPI or AsyncServerAPI.
# Spelled as string forward references because those classes are only
# imported under TYPE_CHECKING.
ClientT = TypeVar("ClientT", "ServerAPI", "AsyncServerAPI")

# Return type preserved by the validation_context decorator.
T = TypeVar("T")
|
|
|
|
|
|
def validation_context(name: str) -> Callable[[Callable[..., T]], Callable[..., T]]:
    """Decorate a method so any exception it raises names the failing operation.

    The wrapped method's exception is re-raised unchanged in type and
    traceback. Only its message is rewritten:

    * if the exception has args, the first arg becomes ``f"{e} in {name}."``
      and any remaining args are kept;
    * if it has no args, the message becomes ``f"Error in {name}."`` so the
      context is never silently dropped.

    This lets validators raise plain errors without knowing which public
    operation (add, get, query, ...) invoked them.

    Args:
        name: Operation name appended to the error message.

    Returns:
        A decorator for methods (the wrapper takes ``self`` explicitly).
    """

    def decorator(func: Callable[..., T]) -> Callable[..., T]:
        @functools.wraps(func)
        def wrapper(self: Any, *args: Any, **kwargs: Any) -> T:
            try:
                return func(self, *args, **kwargs)
            except Exception as e:
                if e.args:
                    # Rewrite only the leading message; keep any extra args.
                    e.args = (f"{e} in {name}.",) + e.args[1:]
                else:
                    # Arg-less exceptions would otherwise lose the context.
                    e.args = (f"Error in {name}.",)
                # Re-raise the same exception object with the modified message.
                raise

        return wrapper

    return decorator
|
|
|
|
|
|
class CollectionCommon(Generic[ClientT]):
    """Collection behaviour shared across client types.

    Generic over ``ClientT`` (``ServerAPI`` or ``AsyncServerAPI``). An
    instance holds the collection model, the client, an optional embedding
    function and an optional data loader. It provides:

    * read-only accessors for the model's fields;
    * ``_validate_and_prepare_*_request`` helpers that normalize, validate and,
      when needed, embed user input into request objects;
    * ``_transform_*_response`` helpers that post-process results;
    * embedding helpers, including embedding of string queries found inside
      ``Knn`` / rank / ``Search`` expressions.
    """

    _model: CollectionModel
    _client: ClientT
    _embedding_function: Optional[EmbeddingFunction[Embeddable]]
    _data_loader: Optional[DataLoader[Loadable]]

    def __init__(
        self,
        client: ClientT,
        model: CollectionModel,
        # NOTE: this default is evaluated once, when the class body executes, so
        # every collection built without an explicit embedding function shares
        # the same DefaultEmbeddingFunction instance.
        embedding_function: Optional[
            EmbeddingFunction[Embeddable]
        ] = DefaultEmbeddingFunction(),  # type: ignore
        data_loader: Optional[DataLoader[Loadable]] = None,
    ):
        """Initializes a new instance of the Collection class.

        Args:
            client: The client (``ServerAPI`` / ``AsyncServerAPI``) this
                collection is bound to.
            model: The collection model exposed through the properties below.
            embedding_function: Used to compute embeddings when none are
                supplied. Checked with ``validate_embedding_function`` unless
                None.
            data_loader: Used to load data from URIs. Optional.
        """
        self._client = client
        self._model = model

        # Check to make sure the embedding function has the right signature,
        # as defined by the EmbeddingFunction protocol.
        if embedding_function is not None:
            validate_embedding_function(embedding_function)

        self._embedding_function = embedding_function
        self._data_loader = data_loader

    # Expose the model properties as read-only properties on the Collection class

    @property
    def id(self) -> UUID:
        """The collection's id, as stored on the model."""
        return self._model.id

    @property
    def name(self) -> str:
        """The collection's name, as stored on the model."""
        return self._model.name

    @property
    def configuration(self) -> CollectionConfiguration:
        """The collection configuration, parsed from the model's configuration JSON."""
        return load_collection_configuration_from_json(self._model.configuration_json)

    @property
    def configuration_json(self) -> Dict[str, Any]:
        """The raw configuration JSON stored on the model."""
        return self._model.configuration_json

    @property
    def schema(self) -> Optional[Schema]:
        """The collection schema, deserialized from the model on every access.

        An empty dict is deserialized when the model has no serialized schema.
        """
        return Schema.deserialize_from_json(
            self._model.serialized_schema if self._model.serialized_schema else {}
        )

    @property
    def metadata(self) -> CollectionMetadata:
        """The model's metadata, typed as ``CollectionMetadata``."""
        return cast(CollectionMetadata, self._model.metadata)

    @property
    def tenant(self) -> str:
        """The tenant the collection belongs to."""
        return self._model.tenant

    @property
    def database(self) -> str:
        """The database the collection belongs to."""
        return self._model.database

    # NOTE: this class defines __eq__ without __hash__, so instances are
    # unhashable unless a subclass defines __hash__.
    def __eq__(self, other: object) -> bool:
        """Compare two collections field by field.

        Id, name, configuration JSON, schema, metadata, tenant, database,
        embedding function and data loader must all be equal. Objects that
        are not ``CollectionCommon`` instances never compare equal.
        """
        if not isinstance(other, CollectionCommon):
            return False
        id_match = self.id == other.id
        name_match = self.name == other.name
        configuration_match = self.configuration_json == other.configuration_json
        schema_match = self.schema == other.schema
        metadata_match = self.metadata == other.metadata
        tenant_match = self.tenant == other.tenant
        database_match = self.database == other.database
        embedding_function_match = self._embedding_function == other._embedding_function
        data_loader_match = self._data_loader == other._data_loader
        return (
            id_match
            and name_match
            and configuration_match
            and schema_match
            and metadata_match
            and tenant_match
            and database_match
            and embedding_function_match
            and data_loader_match
        )

    def __repr__(self) -> str:
        return f"Collection(name={self.name})"

    def get_model(self) -> CollectionModel:
        """Return the underlying collection model (the stored object, not a copy)."""
        return self._model

    @validation_context("add")
    def _validate_and_prepare_add_request(
        self,
        ids: OneOrMany[ID],
        embeddings: Optional[
            Union[
                OneOrMany[Embedding],
                OneOrMany[PyEmbedding],
            ]
        ],
        metadatas: Optional[OneOrMany[Metadata]],
        documents: Optional[OneOrMany[Document]],
        images: Optional[OneOrMany[Image]],
        uris: Optional[OneOrMany[URI]],
    ) -> AddRequest:
        """Normalize, validate and embed the arguments of ``add``.

        When no embeddings are supplied, they are computed with
        ``_embed_record_set``. Schema-configured sparse embeddings are written
        into the metadatas. Any exception is re-raised with " in add."
        appended to its message (see ``validation_context``).
        """
        # Unpack
        add_records = normalize_insert_record_set(
            ids=ids,
            embeddings=embeddings,
            metadatas=metadatas,
            documents=documents,
            images=images,
            uris=uris,
        )

        # Validate
        validate_insert_record_set(record_set=add_records)
        validate_record_set_contains_any(record_set=add_records, contains_any={"ids"})

        # Prepare
        if add_records["embeddings"] is None:
            validate_record_set_for_embedding(record_set=add_records)
            add_embeddings = self._embed_record_set(record_set=add_records)
        else:
            add_embeddings = add_records["embeddings"]

        add_metadatas = self._apply_sparse_embeddings_to_metadatas(
            add_records["metadatas"], add_records["documents"]
        )

        return AddRequest(
            ids=add_records["ids"],
            embeddings=add_embeddings,
            metadatas=add_metadatas,
            documents=add_records["documents"],
            uris=add_records["uris"],
        )

    @validation_context("get")
    def _validate_and_prepare_get_request(
        self,
        ids: Optional[OneOrMany[ID]],
        where: Optional[Where],
        where_document: Optional[WhereDocument],
        include: Include,
    ) -> GetRequest:
        """Normalize and validate the arguments of ``get``.

        ``"distances"`` is passed to ``validate_include`` as disallowed.
        Requesting ``"data"`` requires a data loader. It also adds ``"uris"``
        to the request's include list so the data can be loaded from the
        returned URIs. That list is a copy; the caller's ``include`` is never
        modified.

        Raises:
            ValueError: If ``"data"`` is requested but no data loader is set.
        """
        # Unpack
        unpacked_ids: Optional[IDs] = maybe_cast_one_to_many(target=ids)
        filters = FilterSet(where=where, where_document=where_document)

        # Validate
        if unpacked_ids is not None:
            validate_ids(ids=unpacked_ids)

        validate_filter_set(filter_set=filters)
        validate_include(include=include, dissalowed=["distances"])

        if "data" in include and self._data_loader is None:
            raise ValueError(
                "You must set a data loader on the collection if loading from URIs."
            )

        # Prepare
        # Copy so that appending "uris" below never mutates the caller's list.
        request_include = list(include)
        # We need to include uris in the result from the API to load data
        if "data" in request_include and "uris" not in request_include:
            request_include.append("uris")

        return GetRequest(
            ids=unpacked_ids,
            where=filters["where"],
            where_document=filters["where_document"],
            include=request_include,
        )

    @validation_context("query")
    def _validate_and_prepare_query_request(
        self,
        query_embeddings: Optional[
            Union[
                OneOrMany[Embedding],
                OneOrMany[PyEmbedding],
            ]
        ],
        query_texts: Optional[OneOrMany[Document]],
        query_images: Optional[OneOrMany[Image]],
        query_uris: Optional[OneOrMany[URI]],
        ids: Optional[OneOrMany[ID]],
        n_results: int,
        where: Optional[Where],
        where_document: Optional[WhereDocument],
        include: Include,
    ) -> QueryRequest:
        """Normalize, validate and embed the arguments of ``query``.

        When no query embeddings are given, they are computed from the
        texts/images/uris with ``is_query=True``. Requesting ``"data"`` adds
        ``"uris"`` to the request's include list. That list is a copy; the
        caller's ``include`` is never modified.
        """
        # Unpack
        query_records = normalize_base_record_set(
            embeddings=query_embeddings,
            documents=query_texts,
            images=query_images,
            uris=query_uris,
        )

        filter_ids = maybe_cast_one_to_many(ids)

        filters = FilterSet(
            where=where,
            where_document=where_document,
        )

        # Validate
        validate_base_record_set(record_set=query_records)
        validate_filter_set(filter_set=filters)
        validate_include(include=include)
        validate_n_results(n_results=n_results)

        # Prepare
        if query_records["embeddings"] is None:
            validate_record_set_for_embedding(record_set=query_records)
            request_embeddings = self._embed_record_set(
                record_set=query_records, is_query=True
            )
        else:
            request_embeddings = query_records["embeddings"]

        request_where = filters["where"]
        request_where_document = filters["where_document"]

        # We need to manually include uris in the result from the API to load
        # data. Copy first so the caller's list is never mutated.
        request_include = list(include)
        if "data" in request_include and "uris" not in request_include:
            request_include.append("uris")

        return QueryRequest(
            embeddings=request_embeddings,
            ids=filter_ids,
            where=request_where,
            where_document=request_where_document,
            include=request_include,
            n_results=n_results,
        )

    @validation_context("update")
    def _validate_and_prepare_update_request(
        self,
        ids: OneOrMany[ID],
        embeddings: Optional[
            Union[
                OneOrMany[Embedding],
                OneOrMany[PyEmbedding],
            ]
        ],
        metadatas: Optional[OneOrMany[Metadata]],
        documents: Optional[OneOrMany[Document]],
        images: Optional[OneOrMany[Image]],
        uris: Optional[OneOrMany[URI]],
    ) -> UpdateRequest:
        """Normalize, validate and (re-)embed the arguments of ``update``.

        Embeddings are recomputed only when none are supplied and documents or
        images are being updated. Otherwise the request carries ``None``
        embeddings.
        """
        # Unpack
        update_records = normalize_insert_record_set(
            ids=ids,
            embeddings=embeddings,
            metadatas=metadatas,
            documents=documents,
            images=images,
            uris=uris,
        )

        # Validate
        validate_insert_record_set(record_set=update_records)

        # Prepare
        if update_records["embeddings"] is None:
            # TODO: Handle URI updates.
            if (
                update_records["documents"] is not None
                or update_records["images"] is not None
            ):
                validate_record_set_for_embedding(
                    update_records, embeddable_fields={"documents", "images"}
                )
                update_embeddings = self._embed_record_set(record_set=update_records)
            else:
                update_embeddings = None
        else:
            update_embeddings = update_records["embeddings"]

        update_metadatas = self._apply_sparse_embeddings_to_metadatas(
            update_records["metadatas"], update_records["documents"]
        )

        return UpdateRequest(
            ids=update_records["ids"],
            embeddings=update_embeddings,
            metadatas=update_metadatas,
            documents=update_records["documents"],
            uris=update_records["uris"],
        )

    @validation_context("upsert")
    def _validate_and_prepare_upsert_request(
        self,
        ids: OneOrMany[ID],
        embeddings: Optional[
            Union[
                OneOrMany[Embedding],
                OneOrMany[PyEmbedding],
            ]
        ] = None,
        metadatas: Optional[OneOrMany[Metadata]] = None,
        documents: Optional[OneOrMany[Document]] = None,
        images: Optional[OneOrMany[Image]] = None,
        uris: Optional[OneOrMany[URI]] = None,
    ) -> UpsertRequest:
        """Normalize, validate and embed the arguments of ``upsert``.

        When no embeddings are supplied, the record set is validated for
        embedding from documents/images and then embedded.
        """
        # Unpack
        upsert_records = normalize_insert_record_set(
            ids=ids,
            embeddings=embeddings,
            metadatas=metadatas,
            documents=documents,
            images=images,
            uris=uris,
        )

        # Validate
        validate_insert_record_set(record_set=upsert_records)

        # Prepare
        if upsert_records["embeddings"] is None:
            validate_record_set_for_embedding(
                record_set=upsert_records, embeddable_fields={"documents", "images"}
            )
            upsert_embeddings = self._embed_record_set(record_set=upsert_records)
        else:
            upsert_embeddings = upsert_records["embeddings"]

        upsert_metadatas = self._apply_sparse_embeddings_to_metadatas(
            upsert_records["metadatas"], upsert_records["documents"]
        )

        return UpsertRequest(
            ids=upsert_records["ids"],
            metadatas=upsert_metadatas,
            embeddings=upsert_embeddings,
            documents=upsert_records["documents"],
            uris=upsert_records["uris"],
        )

    @validation_context("delete")
    def _validate_and_prepare_delete_request(
        self,
        ids: Optional[IDs],
        where: Optional[Where],
        where_document: Optional[WhereDocument],
    ) -> DeleteRequest:
        """Validate the arguments of ``delete`` and build the request.

        Raises:
            ValueError: If none of ``ids``, ``where`` or ``where_document`` is
                provided.
        """
        if ids is None and where is None and where_document is None:
            raise ValueError(
                "At least one of ids, where, or where_document must be provided"
            )

        # Unpack
        if ids is not None:
            request_ids = cast(IDs, maybe_cast_one_to_many(ids))
        else:
            request_ids = None
        filters = FilterSet(where=where, where_document=where_document)

        # Validate
        if request_ids is not None:
            validate_ids(ids=request_ids)
        validate_filter_set(filter_set=filters)

        # Build from the validated filter set, like the get/query requests do.
        return DeleteRequest(
            ids=request_ids,
            where=filters["where"],
            where_document=filters["where_document"],
        )

    def _transform_peek_response(self, response: GetResult) -> GetResult:
        """Convert a peek result's embeddings to a NumPy array, in place."""
        if response["embeddings"] is not None:
            response["embeddings"] = np.array(response["embeddings"])

        return response

    def _transform_get_response(
        self, response: GetResult, include: Include
    ) -> GetResult:
        """Post-process a get result in place and return it.

        The following steps are applied:

        * loads ``data`` from the returned URIs when ``"data"`` was requested
          and a data loader is set;
        * converts embeddings to a NumPy array when they were requested and
          are present;
        * drops URIs from the result when ``"uris"`` is not in ``include``.
        """
        if (
            "data" in include
            and self._data_loader is not None
            and response["uris"] is not None
        ):
            response["data"] = self._data_loader(response["uris"])

        # Guard against None, as _transform_query_response does; np.array(None)
        # would otherwise produce a 0-d object array.
        if "embeddings" in include and response["embeddings"] is not None:
            response["embeddings"] = np.array(response["embeddings"])

        # Remove URIs from the result if they weren't requested
        if "uris" not in include:
            response["uris"] = None

        return response

    def _transform_query_response(
        self, response: QueryResult, include: Include
    ) -> QueryResult:
        """Post-process a query result in place and return it.

        Each query's result list is processed separately:

        * data is loaded per query from its URIs when requested;
        * each query's embeddings are converted to a NumPy array;
        * URIs are dropped from the result when not requested.
        """
        if (
            "data" in include
            and self._data_loader is not None
            and response["uris"] is not None
        ):
            response["data"] = [self._data_loader(uris) for uris in response["uris"]]

        if "embeddings" in include and response["embeddings"] is not None:
            response["embeddings"] = [
                np.array(embedding) for embedding in response["embeddings"]
            ]

        # Remove URIs from the result if they weren't requested
        if "uris" not in include:
            response["uris"] = None

        return response

    def _validate_modify_request(self, metadata: Optional[CollectionMetadata]) -> None:
        """Validate metadata passed to ``modify``.

        Raises:
            ValueError: If the metadata contains ``"hnsw:space"``; changing the
                distance function after creation is not supported.
        """
        if metadata is not None:
            validate_metadata(metadata)
            if "hnsw:space" in metadata:
                raise ValueError(
                    "Changing the distance function of a collection once it is created is not supported currently."
                )

    def _update_model_after_modify_success(
        self,
        name: Optional[str],
        metadata: Optional[CollectionMetadata],
        configuration: Optional[UpdateCollectionConfiguration],
    ) -> None:
        """Mirror a successful ``modify`` onto the local model.

        Falsy arguments are ignored. For example, an empty metadata dict
        leaves the stored metadata as it was. A configuration update is
        merged into the model's configuration and, when a schema is present,
        applied to the schema too.
        """
        if name:
            self._model["name"] = name
        if metadata:
            self._model["metadata"] = metadata
        if configuration:
            self._model.set_configuration(
                overwrite_collection_configuration(
                    self._model.get_configuration(), configuration
                )
            )

            # If schema exists, also update it with the configuration changes.
            # Deserialize it once rather than on every property access.
            schema = self.schema
            if schema:
                from chromadb.api.collection_configuration import (
                    update_schema_from_collection_configuration,
                )

                updated_schema = update_schema_from_collection_configuration(
                    schema, configuration
                )
                self._model["serialized_schema"] = updated_schema.serialize_to_json()

    def _get_sparse_embedding_targets(self) -> Dict[str, "SparseVectorIndexConfig"]:
        """Map schema keys to their sparse vector index configs.

        A key is included only when its sparse vector index is present and
        enabled, and has both an embedding function and a source key.
        """
        schema = self.schema
        if schema is None:
            return {}

        targets: Dict[str, "SparseVectorIndexConfig"] = {}
        for key, value_types in schema.keys.items():
            if value_types.sparse_vector is None:
                continue
            sparse_index = value_types.sparse_vector.sparse_vector_index
            if sparse_index is None or not sparse_index.enabled:
                continue
            config = sparse_index.config
            if config.embedding_function is None or config.source_key is None:
                continue
            targets[key] = config

        return targets

    def _apply_sparse_embeddings_to_metadatas(
        self,
        metadatas: Optional[List[Metadata]],
        documents: Optional[List[Document]] = None,
    ) -> Optional[List[Metadata]]:
        """Fill schema-configured sparse embeddings into per-record metadatas.

        For each target from ``_get_sparse_embedding_targets``, the source
        text is embedded and stored under the target key. The source is the
        record's document when ``source_key`` is ``DOCUMENT_KEY``; otherwise
        it is the string stored under ``source_key`` in the record's metadata.
        Records that already carry the target key, or whose source value is
        not a string, are skipped.

        Args:
            metadatas: Per-record metadatas, or None.
            documents: Per-record documents, or None.

        Returns:
            The return value depends on the inputs:

            * ``metadatas`` unchanged when the schema defines no sparse targets;
            * None when both ``metadatas`` and ``documents`` are None;
            * otherwise a new list of copied metadata dicts, with empty dicts
              turned into None. The input dicts are not mutated.

        Raises:
            ValueError: If a sparse embedding function returns a different
                number of embeddings than inputs.
        """
        sparse_targets = self._get_sparse_embedding_targets()
        if not sparse_targets:
            return metadatas

        # If no metadatas provided, create empty dicts based on documents length
        if metadatas is None:
            if documents is None:
                return None
            metadatas = [{} for _ in range(len(documents))]

        # Create copies, converting None to empty dict
        updated_metadatas: List[Dict[str, Any]] = [
            dict(metadata) if metadata is not None else {} for metadata in metadatas
        ]

        documents_list = list(documents) if documents is not None else None

        for target_key, config in sparse_targets.items():
            source_key = config.source_key
            embedding_func = config.embedding_function
            if source_key is None or embedding_func is None:
                continue

            if not isinstance(embedding_func, SparseEmbeddingFunction):
                embedding_func = cast(SparseEmbeddingFunction[Any], embedding_func)
            validate_sparse_embedding_function(embedding_func)

            # Texts to embed in one batch, and the record index each belongs to.
            inputs: List[str] = []
            positions: List[int] = []

            if source_key == DOCUMENT_KEY:
                # Special case: the source is the record's document.
                if documents_list is None:
                    continue
                for idx, metadata in enumerate(updated_metadatas):
                    # Skip records that already have the target, or have no document.
                    if target_key in metadata or idx >= len(documents_list):
                        continue
                    doc = documents_list[idx]
                    if isinstance(doc, str):
                        inputs.append(doc)
                        positions.append(idx)
            else:
                # Normal case: the source is a metadata field.
                for idx, metadata in enumerate(updated_metadatas):
                    if target_key in metadata:
                        continue
                    source_value = metadata.get(source_key)
                    if isinstance(source_value, str):
                        inputs.append(source_value)
                        positions.append(idx)

            self._store_sparse_embeddings(
                updated_metadatas, target_key, inputs, positions, embedding_func
            )

        # Convert empty dicts back to None, validation requires non-empty dicts or None
        result_metadatas: List[Optional[Metadata]] = [
            metadata if metadata else None for metadata in updated_metadatas
        ]

        validate_metadatas(cast(List[Metadata], result_metadatas))
        return cast(List[Metadata], result_metadatas)

    def _store_sparse_embeddings(
        self,
        updated_metadatas: List[Dict[str, Any]],
        target_key: str,
        inputs: List[str],
        positions: List[int],
        embedding_func: SparseEmbeddingFunction[Any],
    ) -> None:
        """Embed ``inputs`` in one batch and write each result at its record position.

        Does nothing when ``inputs`` is empty. ``updated_metadatas`` is
        modified in place.

        Raises:
            ValueError: If the function returns a different number of
                embeddings than ``positions``.
        """
        if not inputs:
            return

        sparse_embeddings = self._sparse_embed(
            input=inputs,
            sparse_embedding_function=embedding_func,
        )

        if len(sparse_embeddings) != len(positions):
            raise ValueError(
                "Sparse embedding function returned unexpected number of embeddings."
            )

        for position, embedding in zip(positions, sparse_embeddings):
            updated_metadatas[position][target_key] = embedding

    def _embed_record_set(
        self,
        record_set: BaseRecordSet,
        embeddable_fields: Optional[Set[str]] = None,
        is_query: bool = False,
    ) -> Embeddings:
        """Embed the first non-None embeddable field of ``record_set``.

        Fields are tried in sorted order. Iterating the set directly would
        make the choice depend on string hash order, which varies between
        interpreter runs, whenever more than one field is populated. URIs are
        first turned into data with the data loader.

        Raises:
            ValueError: If a URI field is selected but no data loader is set,
                or if no embeddable field is non-None.
        """
        if embeddable_fields is None:
            embeddable_fields = get_default_embeddable_record_set_fields()

        for field in sorted(embeddable_fields):
            field_value = record_set[field]  # type: ignore[literal-required]
            if field_value is None:
                continue
            # uris require special handling
            if field == "uris":
                if self._data_loader is None:
                    raise ValueError(
                        "You must set a data loader on the collection if loading from URIs."
                    )
                return self._embed(
                    input=self._data_loader(uris=cast(URIs, field_value)),
                    is_query=is_query,
                )
            return self._embed(input=field_value, is_query=is_query)

        raise ValueError(
            "Record does not contain any non-None fields that can be embedded. "
            f"Embeddable Fields: {embeddable_fields}. "
            f"Record Fields: {record_set}"
        )

    def _embed(self, input: Any, is_query: bool = False) -> Embeddings:
        """Embed ``input`` with the most specific embedding function available.

        Resolution order:

        1. ``self._embedding_function``, if set and not a ``DefaultEmbeddingFunction``.
        2. The ``embedding_function`` entry of the collection configuration.
        3. The schema's embedding function for ``EMBEDDING_KEY``, else the one
           in the schema defaults.
        4. ``self._embedding_function`` (possibly the default one).

        With ``is_query=True``, ``embed_query`` is used instead of calling the
        function directly. For schema functions this happens only if they
        define it.

        Raises:
            ValueError: If no embedding function is available.
        """
        if self._embedding_function is not None and not isinstance(
            self._embedding_function, DefaultEmbeddingFunction
        ):
            if is_query:
                return self._embedding_function.embed_query(input=input)
            else:
                return self._embedding_function(input=input)

        config_ef = self.configuration.get("embedding_function")
        if config_ef is not None:
            if is_query:
                return config_ef.embed_query(input=input)
            else:
                return config_ef(input=input)

        schema = self.schema
        schema_embedding_function: Optional[EmbeddingFunction[Embeddable]] = None
        if schema is not None:
            override = schema.keys.get(EMBEDDING_KEY)
            if (
                override is not None
                and override.float_list is not None
                and override.float_list.vector_index is not None
                and override.float_list.vector_index.config.embedding_function
                is not None
            ):
                schema_embedding_function = cast(
                    EmbeddingFunction[Embeddable],
                    override.float_list.vector_index.config.embedding_function,
                )
            elif (
                schema.defaults.float_list is not None
                and schema.defaults.float_list.vector_index is not None
                and schema.defaults.float_list.vector_index.config.embedding_function
                is not None
            ):
                schema_embedding_function = cast(
                    EmbeddingFunction[Embeddable],
                    schema.defaults.float_list.vector_index.config.embedding_function,
                )

        if schema_embedding_function is not None:
            if is_query and hasattr(schema_embedding_function, "embed_query"):
                return schema_embedding_function.embed_query(input=input)
            return schema_embedding_function(input=input)

        if self._embedding_function is None:
            raise ValueError(
                "You must provide an embedding function to compute embeddings. "
                "https://docs.trychroma.com/guides/embeddings"
            )

        if is_query:
            return self._embedding_function.embed_query(input=input)
        else:
            return self._embedding_function(input=input)

    def _sparse_embed(
        self,
        input: Any,
        sparse_embedding_function: SparseEmbeddingFunction[Any],
        is_query: bool = False,
    ) -> Any:
        """Call ``sparse_embedding_function`` on ``input``.

        Uses ``embed_query`` when ``is_query`` is True.
        """
        if is_query:
            return sparse_embedding_function.embed_query(input=input)
        return sparse_embedding_function(input=input)

    def _embed_knn_string_queries(self, knn: Any) -> Any:
        """Embed string queries in Knn objects using the appropriate embedding function.

        Inputs that are not ``Knn``, or whose query is not a string, are
        returned unchanged. Otherwise the embedding function is chosen as
        follows:

        * for ``EMBEDDING_KEY``, the collection's ``_embed`` (query mode);
        * for other keys, the schema's enabled sparse vector index embedding
          function, else its enabled dense (``float_list``) index embedding
          function.

        Args:
            knn: A Knn object that may have a string query

        Returns:
            A new Knn object with the string query replaced by an embedding,
            all other fields copied from ``knn``.

        Raises:
            ValueError: If the query is a string but no embedding function is
                available, or an embedding function returns other than
                exactly one embedding.
        """
        from chromadb.execution.expression.operator import Knn

        if not isinstance(knn, Knn):
            return knn

        # If query is not a string, nothing to do
        if not isinstance(knn.query, str):
            return knn

        query_text = knn.query
        key = knn.key

        def with_query(embedded_query: Any) -> Any:
            # Rebuild the Knn with the embedded query, copying every other field.
            return Knn(
                query=embedded_query,
                key=knn.key,
                limit=knn.limit,
                default=knn.default,
                return_rank=knn.return_rank,
            )

        # Handle main embedding field
        if key == EMBEDDING_KEY:
            # Use the collection's main embedding function
            embedding = self._embed(input=[query_text], is_query=True)
            # Explicit None/len checks: `not embedding` is ambiguous (or wrong)
            # if an embedding function returns a NumPy array.
            if embedding is None or len(embedding) != 1:
                raise ValueError(
                    "Embedding function returned unexpected number of embeddings"
                )
            return with_query(embedding[0])

        # Handle metadata field with potential sparse embedding
        schema = self.schema
        if schema is None or key not in schema.keys:
            raise ValueError(
                f"Cannot embed string query for key '{key}': "
                f"key not found in schema. Please provide an embedded vector or "
                f"configure an embedding function for this key in the schema."
            )

        value_type = schema.keys[key]

        # Check for sparse vector with embedding function
        if value_type.sparse_vector is not None:
            sparse_index = value_type.sparse_vector.sparse_vector_index
            if sparse_index is not None and sparse_index.enabled:
                sparse_config = sparse_index.config
                if sparse_config.embedding_function is not None:
                    embedding_func = sparse_config.embedding_function
                    if not isinstance(embedding_func, SparseEmbeddingFunction):
                        embedding_func = cast(
                            SparseEmbeddingFunction[Any], embedding_func
                        )
                    validate_sparse_embedding_function(embedding_func)

                    # Embed the query
                    sparse_embedding = self._sparse_embed(
                        input=[query_text],
                        sparse_embedding_function=embedding_func,
                        is_query=True,
                    )

                    if sparse_embedding is None or len(sparse_embedding) != 1:
                        raise ValueError(
                            "Sparse embedding function returned unexpected number of embeddings"
                        )

                    return with_query(sparse_embedding[0])

        # Check for dense vector with embedding function (float_list)
        if value_type.float_list is not None:
            vector_index = value_type.float_list.vector_index
            if vector_index is not None and vector_index.enabled:
                dense_config = vector_index.config
                if dense_config.embedding_function is not None:
                    embedding_func = dense_config.embedding_function
                    validate_embedding_function(embedding_func)

                    # Prefer the query-specific entry point when it exists. A
                    # hasattr check (as in _embed) does not hide AttributeErrors
                    # raised inside embed_query itself.
                    if hasattr(embedding_func, "embed_query"):
                        embeddings = embedding_func.embed_query(input=[query_text])
                    else:
                        embeddings = embedding_func(input=[query_text])

                    if embeddings is None or len(embeddings) != 1:
                        raise ValueError(
                            "Embedding function returned unexpected number of embeddings"
                        )

                    return with_query(embeddings[0])

        raise ValueError(
            f"Cannot embed string query for key '{key}': "
            f"no embedding function configured for this key in the schema. "
            f"Please provide an embedded vector or configure an embedding function."
        )

    def _embed_rank_string_queries(self, rank: Any) -> Any:
        """Recursively embed string queries in Rank expressions.

        ``Knn`` leaves are embedded with ``_embed_knn_string_queries``. ``Val``
        leaves are kept. Composite ranks are rebuilt with their children
        embedded. Unknown rank types, and None, are returned unchanged.

        Args:
            rank: A Rank expression that may contain Knn objects with string queries

        Returns:
            A Rank expression with all string queries embedded
        """
        # Import here to avoid circular dependency
        from chromadb.execution.expression.operator import (
            Knn,
            Abs,
            Div,
            Exp,
            Log,
            Max,
            Min,
            Mul,
            Sub,
            Sum,
            Val,
            Rrf,
        )

        if rank is None:
            return None

        # Base case: Knn - embed if it has a string query
        if isinstance(rank, Knn):
            return self._embed_knn_string_queries(rank)

        # Base case: Val - no embedding needed
        if isinstance(rank, Val):
            return rank

        # Recursive cases: walk through child ranks
        if isinstance(rank, Abs):
            return Abs(self._embed_rank_string_queries(rank.rank))

        if isinstance(rank, Div):
            return Div(
                self._embed_rank_string_queries(rank.left),
                self._embed_rank_string_queries(rank.right),
            )

        if isinstance(rank, Exp):
            return Exp(self._embed_rank_string_queries(rank.rank))

        if isinstance(rank, Log):
            return Log(self._embed_rank_string_queries(rank.rank))

        if isinstance(rank, Max):
            return Max([self._embed_rank_string_queries(r) for r in rank.ranks])

        if isinstance(rank, Min):
            return Min([self._embed_rank_string_queries(r) for r in rank.ranks])

        if isinstance(rank, Mul):
            return Mul([self._embed_rank_string_queries(r) for r in rank.ranks])

        if isinstance(rank, Sub):
            return Sub(
                self._embed_rank_string_queries(rank.left),
                self._embed_rank_string_queries(rank.right),
            )

        if isinstance(rank, Sum):
            return Sum([self._embed_rank_string_queries(r) for r in rank.ranks])

        if isinstance(rank, Rrf):
            return Rrf(
                ranks=[self._embed_rank_string_queries(r) for r in rank.ranks],
                k=rank.k,
                weights=rank.weights,
                normalize=rank.normalize,
            )

        # Unknown rank type - return as is
        return rank

    def _embed_search_string_queries(self, search: Any) -> Any:
        """Embed string queries in a Search object.

        Non-``Search`` inputs are returned unchanged. Otherwise a new
        ``Search`` is built with the same where/group_by/limit/select and an
        embedded rank.

        Args:
            search: A Search object that may contain Knn objects with string queries

        Returns:
            A Search object with all string queries embedded
        """
        # Import here to avoid circular dependency
        from chromadb.execution.expression.plan import Search

        if not isinstance(search, Search):
            return search

        # Embed the rank expression if it exists
        embedded_rank = self._embed_rank_string_queries(search._rank)

        # Create a new Search with the embedded rank
        return Search(
            where=search._where,
            rank=embedded_rank,
            group_by=search._group_by,
            limit=search._limit,
            select=search._select,
        )
|