883 lines
31 KiB
Python
883 lines
31 KiB
Python
from typing import TypedDict, Dict, Any, Optional, cast, get_args
|
|
import json
|
|
from chromadb.api.types import (
|
|
Space,
|
|
CollectionMetadata,
|
|
UpdateMetadata,
|
|
EmbeddingFunction,
|
|
)
|
|
from chromadb.utils.embedding_functions import (
|
|
known_embedding_functions,
|
|
register_embedding_function,
|
|
)
|
|
from multiprocessing import cpu_count
|
|
import warnings
|
|
|
|
from chromadb.api.types import Schema
|
|
|
|
|
|
class HNSWConfiguration(TypedDict, total=False):
|
|
space: Space
|
|
ef_construction: int
|
|
max_neighbors: int
|
|
ef_search: int
|
|
num_threads: int
|
|
batch_size: int
|
|
sync_threshold: int
|
|
resize_factor: float
|
|
|
|
|
|
class SpannConfiguration(TypedDict, total=False):
|
|
search_nprobe: int
|
|
write_nprobe: int
|
|
space: Space
|
|
ef_construction: int
|
|
ef_search: int
|
|
max_neighbors: int
|
|
reassign_neighbor_count: int
|
|
split_threshold: int
|
|
merge_threshold: int
|
|
|
|
|
|
class CollectionConfiguration(TypedDict, total=True):
|
|
hnsw: Optional[HNSWConfiguration]
|
|
spann: Optional[SpannConfiguration]
|
|
embedding_function: Optional[EmbeddingFunction] # type: ignore
|
|
|
|
|
|
def load_collection_configuration_from_json_str(
|
|
config_json_str: str,
|
|
) -> CollectionConfiguration:
|
|
config_json_map = json.loads(config_json_str)
|
|
return load_collection_configuration_from_json(config_json_map)
|
|
|
|
|
|
# TODO: make warnings prettier and add link to migration docs
|
|
def load_collection_configuration_from_json(
|
|
config_json_map: Dict[str, Any]
|
|
) -> CollectionConfiguration:
|
|
if (
|
|
config_json_map.get("spann") is not None
|
|
and config_json_map.get("hnsw") is not None
|
|
):
|
|
raise ValueError("hnsw and spann cannot both be provided")
|
|
|
|
hnsw_config = None
|
|
spann_config = None
|
|
ef_config = None
|
|
|
|
# Process vector index configuration (HNSW or SPANN)
|
|
if config_json_map.get("hnsw") is not None:
|
|
hnsw_config = cast(HNSWConfiguration, config_json_map["hnsw"])
|
|
if config_json_map.get("spann") is not None:
|
|
spann_config = cast(SpannConfiguration, config_json_map["spann"])
|
|
|
|
# Process embedding function configuration
|
|
if config_json_map.get("embedding_function") is not None:
|
|
ef_config = config_json_map["embedding_function"]
|
|
if ef_config["type"] == "legacy":
|
|
warnings.warn(
|
|
"legacy embedding function config",
|
|
DeprecationWarning,
|
|
stacklevel=2,
|
|
)
|
|
ef = None
|
|
else:
|
|
try:
|
|
ef_name = ef_config["name"]
|
|
except KeyError:
|
|
raise ValueError(
|
|
f"Embedding function name not found in config: {ef_config}"
|
|
)
|
|
try:
|
|
ef = known_embedding_functions[ef_name]
|
|
except KeyError:
|
|
raise ValueError(
|
|
f"Embedding function {ef_name} not found. Add @register_embedding_function decorator to the class definition."
|
|
)
|
|
try:
|
|
ef = ef.build_from_config(ef_config["config"]) # type: ignore
|
|
except Exception as e:
|
|
raise ValueError(
|
|
f"Could not build embedding function {ef_config['name']} from config {ef_config['config']}: {e}"
|
|
)
|
|
else:
|
|
ef = None
|
|
|
|
return CollectionConfiguration(
|
|
hnsw=hnsw_config,
|
|
spann=spann_config,
|
|
embedding_function=ef, # type: ignore
|
|
)
|
|
|
|
|
|
def collection_configuration_to_json_str(config: CollectionConfiguration) -> str:
|
|
return json.dumps(collection_configuration_to_json(config))
|
|
|
|
|
|
def collection_configuration_to_json(config: CollectionConfiguration) -> Dict[str, Any]:
|
|
if isinstance(config, dict):
|
|
hnsw_config = config.get("hnsw")
|
|
spann_config = config.get("spann")
|
|
ef = config.get("embedding_function")
|
|
else:
|
|
try:
|
|
hnsw_config = config.get_parameter("hnsw").value
|
|
except ValueError:
|
|
hnsw_config = None
|
|
try:
|
|
spann_config = config.get_parameter("spann").value
|
|
except ValueError:
|
|
spann_config = None
|
|
try:
|
|
ef = config.get_parameter("embedding_function").value
|
|
except ValueError:
|
|
ef = None
|
|
|
|
ef_config: Dict[str, Any] | None = None
|
|
if hnsw_config is not None:
|
|
try:
|
|
hnsw_config = cast(HNSWConfiguration, hnsw_config)
|
|
except Exception as e:
|
|
raise ValueError(f"not a valid hnsw config: {e}")
|
|
if spann_config is not None:
|
|
try:
|
|
spann_config = cast(SpannConfiguration, spann_config)
|
|
except Exception as e:
|
|
raise ValueError(f"not a valid spann config: {e}")
|
|
|
|
if ef is None:
|
|
ef = None
|
|
ef_config = {"type": "legacy"}
|
|
|
|
if ef is not None:
|
|
try:
|
|
if ef.is_legacy():
|
|
ef_config = {"type": "legacy"}
|
|
else:
|
|
ef_config = {
|
|
"name": ef.name(),
|
|
"type": "known",
|
|
"config": ef.get_config(),
|
|
}
|
|
register_embedding_function(type(ef)) # type: ignore
|
|
except Exception as e:
|
|
warnings.warn(
|
|
f"legacy embedding function config: {e}",
|
|
DeprecationWarning,
|
|
stacklevel=2,
|
|
)
|
|
ef = None
|
|
ef_config = {"type": "legacy"}
|
|
|
|
return {
|
|
"hnsw": hnsw_config,
|
|
"spann": spann_config,
|
|
"embedding_function": ef_config,
|
|
}
|
|
|
|
|
|
class CreateHNSWConfiguration(TypedDict, total=False):
|
|
space: Space
|
|
ef_construction: int
|
|
max_neighbors: int
|
|
ef_search: int
|
|
num_threads: int
|
|
batch_size: int
|
|
sync_threshold: int
|
|
resize_factor: float
|
|
|
|
|
|
def json_to_create_hnsw_configuration(
|
|
json_map: Dict[str, Any]
|
|
) -> CreateHNSWConfiguration:
|
|
config: CreateHNSWConfiguration = {}
|
|
if "space" in json_map:
|
|
space_value = json_map["space"]
|
|
if space_value in get_args(Space):
|
|
config["space"] = space_value
|
|
else:
|
|
raise ValueError(f"not a valid space: {space_value}")
|
|
if "ef_construction" in json_map:
|
|
config["ef_construction"] = json_map["ef_construction"]
|
|
if "max_neighbors" in json_map:
|
|
config["max_neighbors"] = json_map["max_neighbors"]
|
|
if "ef_search" in json_map:
|
|
config["ef_search"] = json_map["ef_search"]
|
|
if "num_threads" in json_map:
|
|
config["num_threads"] = json_map["num_threads"]
|
|
if "batch_size" in json_map:
|
|
config["batch_size"] = json_map["batch_size"]
|
|
if "sync_threshold" in json_map:
|
|
config["sync_threshold"] = json_map["sync_threshold"]
|
|
if "resize_factor" in json_map:
|
|
config["resize_factor"] = json_map["resize_factor"]
|
|
return config
|
|
|
|
|
|
class CreateSpannConfiguration(TypedDict, total=False):
|
|
search_nprobe: int
|
|
write_nprobe: int
|
|
space: Space
|
|
ef_construction: int
|
|
ef_search: int
|
|
max_neighbors: int
|
|
reassign_neighbor_count: int
|
|
split_threshold: int
|
|
merge_threshold: int
|
|
|
|
|
|
def json_to_create_spann_configuration(
|
|
json_map: Dict[str, Any]
|
|
) -> CreateSpannConfiguration:
|
|
config: CreateSpannConfiguration = {}
|
|
if "search_nprobe" in json_map:
|
|
config["search_nprobe"] = json_map["search_nprobe"]
|
|
if "write_nprobe" in json_map:
|
|
config["write_nprobe"] = json_map["write_nprobe"]
|
|
if "space" in json_map:
|
|
space_value = json_map["space"]
|
|
if space_value in get_args(Space):
|
|
config["space"] = space_value
|
|
else:
|
|
raise ValueError(f"not a valid space: {space_value}")
|
|
if "ef_construction" in json_map:
|
|
config["ef_construction"] = json_map["ef_construction"]
|
|
if "ef_search" in json_map:
|
|
config["ef_search"] = json_map["ef_search"]
|
|
if "max_neighbors" in json_map:
|
|
config["max_neighbors"] = json_map["max_neighbors"]
|
|
return config
|
|
|
|
|
|
class CreateCollectionConfiguration(TypedDict, total=False):
|
|
hnsw: Optional[CreateHNSWConfiguration]
|
|
spann: Optional[CreateSpannConfiguration]
|
|
embedding_function: Optional[EmbeddingFunction] # type: ignore
|
|
|
|
|
|
def create_collection_configuration_from_legacy_collection_metadata(
|
|
metadata: CollectionMetadata,
|
|
) -> CreateCollectionConfiguration:
|
|
"""Create a CreateCollectionConfiguration from legacy collection metadata"""
|
|
return create_collection_configuration_from_legacy_metadata_dict(metadata)
|
|
|
|
|
|
def create_collection_configuration_from_legacy_metadata_dict(
|
|
metadata: Dict[str, Any],
|
|
) -> CreateCollectionConfiguration:
|
|
"""Create a CreateCollectionConfiguration from legacy collection metadata"""
|
|
old_to_new = {
|
|
"hnsw:space": "space",
|
|
"hnsw:construction_ef": "ef_construction",
|
|
"hnsw:M": "max_neighbors",
|
|
"hnsw:search_ef": "ef_search",
|
|
"hnsw:num_threads": "num_threads",
|
|
"hnsw:batch_size": "batch_size",
|
|
"hnsw:sync_threshold": "sync_threshold",
|
|
"hnsw:resize_factor": "resize_factor",
|
|
}
|
|
json_map = {}
|
|
for name, value in metadata.items():
|
|
if name in old_to_new:
|
|
json_map[old_to_new[name]] = value
|
|
hnsw_config = json_to_create_hnsw_configuration(json_map)
|
|
hnsw_config = populate_create_hnsw_defaults(hnsw_config)
|
|
|
|
return CreateCollectionConfiguration(hnsw=hnsw_config)
|
|
|
|
|
|
# TODO: make warnings prettier and add link to migration docs
|
|
def load_create_collection_configuration_from_json(
|
|
json_map: Dict[str, Any]
|
|
) -> CreateCollectionConfiguration:
|
|
if json_map.get("hnsw") is not None and json_map.get("spann") is not None:
|
|
raise ValueError("hnsw and spann cannot both be provided")
|
|
|
|
result = CreateCollectionConfiguration()
|
|
|
|
# Handle vector index configuration
|
|
if json_map.get("hnsw") is not None:
|
|
result["hnsw"] = json_to_create_hnsw_configuration(json_map["hnsw"])
|
|
|
|
if json_map.get("spann") is not None:
|
|
result["spann"] = json_to_create_spann_configuration(json_map["spann"])
|
|
|
|
# Handle embedding function configuration
|
|
if json_map.get("embedding_function") is not None:
|
|
ef_config = json_map["embedding_function"]
|
|
if ef_config["type"] == "legacy":
|
|
warnings.warn(
|
|
"legacy embedding function config",
|
|
DeprecationWarning,
|
|
stacklevel=2,
|
|
)
|
|
else:
|
|
ef = known_embedding_functions[ef_config["name"]]
|
|
result["embedding_function"] = ef.build_from_config(ef_config["config"])
|
|
|
|
return result
|
|
|
|
|
|
def create_collection_configuration_to_json_str(
|
|
config: CreateCollectionConfiguration,
|
|
metadata: Optional[CollectionMetadata] = None,
|
|
) -> str:
|
|
"""Convert a CreateCollection configuration to a JSON-serializable string"""
|
|
return json.dumps(create_collection_configuration_to_json(config, metadata))
|
|
|
|
|
|
# TODO: make warnings prettier and add link to migration docs
|
|
def create_collection_configuration_to_json(
|
|
config: CreateCollectionConfiguration,
|
|
metadata: Optional[CollectionMetadata] = None,
|
|
) -> Dict[str, Any]:
|
|
"""Convert a CreateCollection configuration to a JSON-serializable dict"""
|
|
ef_config: Dict[str, Any] | None = None
|
|
hnsw_config = config.get("hnsw")
|
|
spann_config = config.get("spann")
|
|
if hnsw_config is not None:
|
|
try:
|
|
hnsw_config = cast(CreateHNSWConfiguration, hnsw_config)
|
|
except Exception as e:
|
|
raise ValueError(f"not a valid hnsw config: {e}")
|
|
if spann_config is not None:
|
|
try:
|
|
spann_config = cast(CreateSpannConfiguration, spann_config)
|
|
except Exception as e:
|
|
raise ValueError(f"not a valid spann config: {e}")
|
|
|
|
if hnsw_config is not None and spann_config is not None:
|
|
raise ValueError("hnsw and spann cannot both be provided")
|
|
|
|
if config.get("embedding_function") is None:
|
|
ef = None
|
|
ef_config = {"type": "legacy"}
|
|
return {
|
|
"hnsw": hnsw_config,
|
|
"spann": spann_config,
|
|
"embedding_function": ef_config,
|
|
}
|
|
|
|
try:
|
|
ef = cast(EmbeddingFunction, config.get("embedding_function")) # type: ignore
|
|
if ef.is_legacy():
|
|
ef_config = {"type": "legacy"}
|
|
else:
|
|
# default space logic: if neither hnsw nor spann config is provided and metadata doesn't have space,
|
|
# then populate space from ef
|
|
# otherwise dont use default space from ef
|
|
|
|
# then validate the space afterwards based on the supported spaces of the embedding function,
|
|
# warn if space is not supported
|
|
|
|
if hnsw_config is None and spann_config is None:
|
|
if metadata is None or metadata.get("hnsw:space") is None:
|
|
# this populates space from ef if not provided in either config
|
|
hnsw_config = CreateHNSWConfiguration(space=ef.default_space())
|
|
|
|
# if hnsw config or spann config exists but space is not provided, populate it from ef
|
|
if hnsw_config is not None and hnsw_config.get("space") is None:
|
|
hnsw_config["space"] = ef.default_space()
|
|
if spann_config is not None and spann_config.get("space") is None:
|
|
spann_config["space"] = ef.default_space()
|
|
|
|
# Validate space compatibility with embedding function
|
|
if hnsw_config is not None:
|
|
if hnsw_config.get("space") not in ef.supported_spaces():
|
|
warnings.warn(
|
|
f"space {hnsw_config.get('space')} is not supported by {ef.name()}. Supported spaces: {ef.supported_spaces()}",
|
|
UserWarning,
|
|
stacklevel=2,
|
|
)
|
|
if spann_config is not None:
|
|
if spann_config.get("space") not in ef.supported_spaces():
|
|
warnings.warn(
|
|
f"space {spann_config.get('space')} is not supported by {ef.name()}. Supported spaces: {ef.supported_spaces()}",
|
|
UserWarning,
|
|
stacklevel=2,
|
|
)
|
|
|
|
# only validate space from metadata if config is not provided
|
|
if (
|
|
hnsw_config is None
|
|
and spann_config is None
|
|
and metadata is not None
|
|
and metadata.get("hnsw:space") is not None
|
|
):
|
|
if metadata.get("hnsw:space") not in ef.supported_spaces():
|
|
warnings.warn(
|
|
f"space {metadata.get('hnsw:space')} is not supported by {ef.name()}. Supported spaces: {ef.supported_spaces()}",
|
|
UserWarning,
|
|
stacklevel=2,
|
|
)
|
|
|
|
ef_config = {
|
|
"name": ef.name(),
|
|
"type": "known",
|
|
"config": ef.get_config(),
|
|
}
|
|
register_embedding_function(type(ef)) # type: ignore
|
|
except Exception as e:
|
|
warnings.warn(
|
|
f"legacy embedding function config: {e}",
|
|
DeprecationWarning,
|
|
stacklevel=2,
|
|
)
|
|
ef = None
|
|
ef_config = {"type": "legacy"}
|
|
|
|
return {
|
|
"hnsw": hnsw_config,
|
|
"spann": spann_config,
|
|
"embedding_function": ef_config,
|
|
}
|
|
|
|
|
|
def populate_create_hnsw_defaults(
|
|
config: CreateHNSWConfiguration, ef: Optional[EmbeddingFunction] = None # type: ignore
|
|
) -> CreateHNSWConfiguration:
|
|
"""Populate a CreateHNSW configuration with default values"""
|
|
if config.get("space") is None:
|
|
config["space"] = ef.default_space() if ef else "l2"
|
|
if config.get("ef_construction") is None:
|
|
config["ef_construction"] = 100
|
|
if config.get("max_neighbors") is None:
|
|
config["max_neighbors"] = 16
|
|
if config.get("ef_search") is None:
|
|
config["ef_search"] = 100
|
|
if config.get("num_threads") is None:
|
|
config["num_threads"] = cpu_count()
|
|
if config.get("batch_size") is None:
|
|
config["batch_size"] = 100
|
|
if config.get("sync_threshold") is None:
|
|
config["sync_threshold"] = 1000
|
|
if config.get("resize_factor") is None:
|
|
config["resize_factor"] = 1.2
|
|
return config
|
|
|
|
|
|
class UpdateHNSWConfiguration(TypedDict, total=False):
|
|
ef_search: int
|
|
num_threads: int
|
|
batch_size: int
|
|
sync_threshold: int
|
|
resize_factor: float
|
|
|
|
|
|
def json_to_update_hnsw_configuration(
|
|
json_map: Dict[str, Any]
|
|
) -> UpdateHNSWConfiguration:
|
|
config: UpdateHNSWConfiguration = {}
|
|
if "ef_search" in json_map:
|
|
config["ef_search"] = json_map["ef_search"]
|
|
if "num_threads" in json_map:
|
|
config["num_threads"] = json_map["num_threads"]
|
|
if "batch_size" in json_map:
|
|
config["batch_size"] = json_map["batch_size"]
|
|
if "sync_threshold" in json_map:
|
|
config["sync_threshold"] = json_map["sync_threshold"]
|
|
if "resize_factor" in json_map:
|
|
config["resize_factor"] = json_map["resize_factor"]
|
|
return config
|
|
|
|
|
|
class UpdateSpannConfiguration(TypedDict, total=False):
|
|
search_nprobe: int
|
|
ef_search: int
|
|
|
|
|
|
def json_to_update_spann_configuration(
|
|
json_map: Dict[str, Any]
|
|
) -> UpdateSpannConfiguration:
|
|
config: UpdateSpannConfiguration = {}
|
|
if "search_nprobe" in json_map:
|
|
config["search_nprobe"] = json_map["search_nprobe"]
|
|
if "ef_search" in json_map:
|
|
config["ef_search"] = json_map["ef_search"]
|
|
return config
|
|
|
|
|
|
class UpdateCollectionConfiguration(TypedDict, total=False):
|
|
hnsw: Optional[UpdateHNSWConfiguration]
|
|
spann: Optional[UpdateSpannConfiguration]
|
|
embedding_function: Optional[EmbeddingFunction] # type: ignore
|
|
|
|
|
|
def update_collection_configuration_from_legacy_collection_metadata(
|
|
metadata: CollectionMetadata,
|
|
) -> UpdateCollectionConfiguration:
|
|
"""Create an UpdateCollectionConfiguration from legacy collection metadata"""
|
|
old_to_new = {
|
|
"hnsw:search_ef": "ef_search",
|
|
"hnsw:num_threads": "num_threads",
|
|
"hnsw:batch_size": "batch_size",
|
|
"hnsw:sync_threshold": "sync_threshold",
|
|
"hnsw:resize_factor": "resize_factor",
|
|
}
|
|
json_map = {}
|
|
for name, value in metadata.items():
|
|
if name in old_to_new:
|
|
json_map[old_to_new[name]] = value
|
|
hnsw_config = json_to_update_hnsw_configuration(json_map)
|
|
return UpdateCollectionConfiguration(hnsw=hnsw_config)
|
|
|
|
|
|
def update_collection_configuration_from_legacy_update_metadata(
|
|
metadata: UpdateMetadata,
|
|
) -> UpdateCollectionConfiguration:
|
|
"""Create an UpdateCollectionConfiguration from legacy update metadata"""
|
|
old_to_new = {
|
|
"hnsw:search_ef": "ef_search",
|
|
"hnsw:num_threads": "num_threads",
|
|
"hnsw:batch_size": "batch_size",
|
|
"hnsw:sync_threshold": "sync_threshold",
|
|
"hnsw:resize_factor": "resize_factor",
|
|
}
|
|
json_map = {}
|
|
for name, value in metadata.items():
|
|
if name in old_to_new:
|
|
json_map[old_to_new[name]] = value
|
|
hnsw_config = json_to_update_hnsw_configuration(json_map)
|
|
return UpdateCollectionConfiguration(hnsw=hnsw_config)
|
|
|
|
|
|
def update_collection_configuration_to_json_str(
|
|
config: UpdateCollectionConfiguration,
|
|
) -> str:
|
|
"""Convert an UpdateCollectionConfiguration to a JSON-serializable string"""
|
|
json_dict = update_collection_configuration_to_json(config)
|
|
return json.dumps(json_dict)
|
|
|
|
|
|
def update_collection_configuration_to_json(
|
|
config: UpdateCollectionConfiguration,
|
|
) -> Dict[str, Any]:
|
|
"""Convert an UpdateCollectionConfiguration to a JSON-serializable dict"""
|
|
hnsw_config = config.get("hnsw")
|
|
spann_config = config.get("spann")
|
|
ef = config.get("embedding_function")
|
|
if hnsw_config is None and spann_config is None and ef is None:
|
|
return {}
|
|
|
|
if hnsw_config is not None:
|
|
try:
|
|
hnsw_config = cast(UpdateHNSWConfiguration, hnsw_config)
|
|
except Exception as e:
|
|
raise ValueError(f"not a valid hnsw config: {e}")
|
|
|
|
if spann_config is not None:
|
|
try:
|
|
spann_config = cast(UpdateSpannConfiguration, spann_config)
|
|
except Exception as e:
|
|
raise ValueError(f"not a valid spann config: {e}")
|
|
|
|
ef_config: Dict[str, Any] | None = None
|
|
if ef is not None:
|
|
if ef.is_legacy():
|
|
ef_config = {"type": "legacy"}
|
|
else:
|
|
ef.validate_config(ef.get_config())
|
|
ef_config = {
|
|
"name": ef.name(),
|
|
"type": "known",
|
|
"config": ef.get_config(),
|
|
}
|
|
register_embedding_function(type(ef)) # type: ignore
|
|
else:
|
|
ef_config = None
|
|
|
|
return {
|
|
"hnsw": hnsw_config,
|
|
"spann": spann_config,
|
|
"embedding_function": ef_config,
|
|
}
|
|
|
|
|
|
def load_update_collection_configuration_from_json_str(
|
|
json_str: str,
|
|
) -> UpdateCollectionConfiguration:
|
|
json_map = json.loads(json_str)
|
|
return load_update_collection_configuration_from_json(json_map)
|
|
|
|
|
|
# TODO: make warnings prettier and add link to migration docs
|
|
def load_update_collection_configuration_from_json(
|
|
json_map: Dict[str, Any]
|
|
) -> UpdateCollectionConfiguration:
|
|
"""Convert a JSON dict to an UpdateCollectionConfiguration"""
|
|
if json_map.get("hnsw") is not None and json_map.get("spann") is not None:
|
|
raise ValueError("hnsw and spann cannot both be provided")
|
|
|
|
result = UpdateCollectionConfiguration()
|
|
|
|
# Handle vector index configurations
|
|
if json_map.get("hnsw") is not None:
|
|
result["hnsw"] = json_to_update_hnsw_configuration(json_map["hnsw"])
|
|
|
|
if json_map.get("spann") is not None:
|
|
result["spann"] = json_to_update_spann_configuration(json_map["spann"])
|
|
|
|
# Handle embedding function
|
|
if json_map.get("embedding_function") is not None:
|
|
if json_map["embedding_function"]["type"] == "legacy":
|
|
warnings.warn(
|
|
"legacy embedding function config",
|
|
DeprecationWarning,
|
|
stacklevel=2,
|
|
)
|
|
else:
|
|
ef = known_embedding_functions[json_map["embedding_function"]["name"]]
|
|
result["embedding_function"] = ef.build_from_config(
|
|
json_map["embedding_function"]["config"]
|
|
)
|
|
|
|
return result
|
|
|
|
|
|
def overwrite_hnsw_configuration(
|
|
existing_hnsw_config: HNSWConfiguration, update_hnsw_config: UpdateHNSWConfiguration
|
|
) -> HNSWConfiguration:
|
|
"""Overwrite a HNSWConfiguration with a new configuration"""
|
|
# Create a copy of the existing config and update with new values
|
|
result = dict(existing_hnsw_config)
|
|
update_fields = [
|
|
"ef_search",
|
|
"num_threads",
|
|
"batch_size",
|
|
"sync_threshold",
|
|
"resize_factor",
|
|
]
|
|
|
|
for field in update_fields:
|
|
if field in update_hnsw_config:
|
|
result[field] = update_hnsw_config[field] # type: ignore
|
|
|
|
return cast(HNSWConfiguration, result)
|
|
|
|
|
|
def overwrite_spann_configuration(
|
|
existing_spann_config: SpannConfiguration,
|
|
update_spann_config: UpdateSpannConfiguration,
|
|
) -> SpannConfiguration:
|
|
"""Overwrite a SpannConfiguration with a new configuration"""
|
|
result = dict(existing_spann_config)
|
|
update_fields = [
|
|
"search_nprobe",
|
|
"ef_search",
|
|
]
|
|
|
|
for field in update_fields:
|
|
if field in update_spann_config:
|
|
result[field] = update_spann_config[field] # type: ignore
|
|
|
|
return cast(SpannConfiguration, result)
|
|
|
|
|
|
# TODO: make warnings prettier and add link to migration docs
|
|
def overwrite_embedding_function(
|
|
existing_embedding_function: EmbeddingFunction, # type: ignore
|
|
update_embedding_function: EmbeddingFunction, # type: ignore
|
|
) -> EmbeddingFunction: # type: ignore
|
|
"""Overwrite an EmbeddingFunction with a new configuration"""
|
|
# Check for legacy embedding functions
|
|
if existing_embedding_function.is_legacy() or update_embedding_function.is_legacy():
|
|
warnings.warn(
|
|
"cannot update legacy embedding function config",
|
|
DeprecationWarning,
|
|
stacklevel=2,
|
|
)
|
|
return existing_embedding_function
|
|
|
|
# Validate function compatibility
|
|
if existing_embedding_function.name() != update_embedding_function.name():
|
|
raise ValueError(
|
|
f"Cannot update embedding function: incompatible types "
|
|
f"({existing_embedding_function.name()} vs {update_embedding_function.name()})"
|
|
)
|
|
|
|
# Validate and apply the configuration update
|
|
update_embedding_function.validate_config_update(
|
|
existing_embedding_function.get_config(), update_embedding_function.get_config()
|
|
)
|
|
return update_embedding_function
|
|
|
|
|
|
def overwrite_collection_configuration(
|
|
existing_config: CollectionConfiguration,
|
|
update_config: UpdateCollectionConfiguration,
|
|
) -> CollectionConfiguration:
|
|
"""Overwrite a CollectionConfiguration with a new configuration"""
|
|
update_spann = update_config.get("spann")
|
|
update_hnsw = update_config.get("hnsw")
|
|
if update_spann is not None and update_hnsw is not None:
|
|
raise ValueError("hnsw and spann cannot both be provided")
|
|
|
|
# Handle HNSW configuration update
|
|
|
|
updated_hnsw_config = existing_config.get("hnsw")
|
|
if updated_hnsw_config is not None and update_hnsw is not None:
|
|
updated_hnsw_config = overwrite_hnsw_configuration(
|
|
updated_hnsw_config, update_hnsw
|
|
)
|
|
|
|
# Handle SPANN configuration update
|
|
updated_spann_config = existing_config.get("spann")
|
|
if updated_spann_config is not None and update_spann is not None:
|
|
updated_spann_config = overwrite_spann_configuration(
|
|
updated_spann_config, update_spann
|
|
)
|
|
|
|
# Handle embedding function update
|
|
updated_embedding_function = existing_config.get("embedding_function")
|
|
update_ef = update_config.get("embedding_function")
|
|
if update_ef is not None:
|
|
if updated_embedding_function is not None:
|
|
updated_embedding_function = overwrite_embedding_function(
|
|
updated_embedding_function, update_ef
|
|
)
|
|
else:
|
|
updated_embedding_function = update_ef
|
|
|
|
return CollectionConfiguration(
|
|
hnsw=updated_hnsw_config,
|
|
spann=updated_spann_config,
|
|
embedding_function=updated_embedding_function,
|
|
)
|
|
|
|
|
|
def validate_embedding_function_conflict_on_create(
|
|
embedding_function: Optional[EmbeddingFunction], # type: ignore
|
|
configuration_ef: Optional[EmbeddingFunction], # type: ignore
|
|
) -> None:
|
|
"""
|
|
Validates that there are no conflicting embedding functions between function parameter
|
|
and collection configuration.
|
|
|
|
Args:
|
|
embedding_function: The embedding function provided as a parameter
|
|
configuration_ef: The embedding function from collection configuration
|
|
|
|
Returns:
|
|
bool: True if there is a conflict, False otherwise
|
|
"""
|
|
# If ef provided in function params and collection config, check if they are the same
|
|
# If not, there's a conflict
|
|
# ef is by default "default" if not provided, so ignore that case.
|
|
if embedding_function is not None and configuration_ef is not None:
|
|
if (
|
|
embedding_function.name() != "default"
|
|
and embedding_function.name() != configuration_ef.name()
|
|
):
|
|
raise ValueError(
|
|
f"Multiple embedding functions provided. Please provide only one. Embedding function conflict: {embedding_function.name()} vs {configuration_ef.name()}"
|
|
)
|
|
return None
|
|
|
|
|
|
# The reason to use the config on get, rather than build the ef is because
|
|
# if there is an issue with deserializing the config, an error shouldn't be raised
|
|
# at get time. CollectionCommon.py will raise an error at _embed time if there is an issue deserializing.
|
|
def validate_embedding_function_conflict_on_get(
|
|
embedding_function: Optional[EmbeddingFunction], # type: ignore
|
|
persisted_ef_config: Optional[Dict[str, Any]],
|
|
) -> None:
|
|
"""
|
|
Validates that there are no conflicting embedding functions between function parameter
|
|
and collection configuration.
|
|
"""
|
|
if persisted_ef_config is not None and embedding_function is not None:
|
|
if (
|
|
embedding_function.name() != "default"
|
|
and persisted_ef_config.get("name") is not None
|
|
and persisted_ef_config.get("name") != embedding_function.name()
|
|
):
|
|
raise ValueError(
|
|
f"An embedding function already exists in the collection configuration, and a new one is provided. If this is intentional, please embed documents separately. Embedding function conflict: new: {embedding_function.name()} vs persisted: {persisted_ef_config.get('name')}"
|
|
)
|
|
return None
|
|
|
|
|
|
def update_schema_from_collection_configuration(
|
|
schema: "Schema", configuration: "UpdateCollectionConfiguration"
|
|
) -> "Schema":
|
|
"""
|
|
Updates a schema with configuration changes.
|
|
Only updates fields that are present in the configuration update.
|
|
|
|
Args:
|
|
schema: The existing Schema object
|
|
configuration: The configuration updates to apply
|
|
|
|
Returns:
|
|
Updated Schema object
|
|
"""
|
|
|
|
# Get the vector index from defaults and #embedding key
|
|
if (
|
|
schema.defaults.float_list is None
|
|
or schema.defaults.float_list.vector_index is None
|
|
):
|
|
raise ValueError("Schema is missing defaults.float_list.vector_index")
|
|
|
|
embedding_key = "#embedding"
|
|
if embedding_key not in schema.keys:
|
|
raise ValueError(f"Schema is missing keys[{embedding_key}]")
|
|
|
|
embedding_value_types = schema.keys[embedding_key]
|
|
if (
|
|
embedding_value_types.float_list is None
|
|
or embedding_value_types.float_list.vector_index is None
|
|
):
|
|
raise ValueError(
|
|
f"Schema is missing keys[{embedding_key}].float_list.vector_index"
|
|
)
|
|
|
|
# Update vector index config in both locations
|
|
for vector_index in [
|
|
schema.defaults.float_list.vector_index,
|
|
embedding_value_types.float_list.vector_index,
|
|
]:
|
|
if "hnsw" in configuration and configuration["hnsw"] is not None:
|
|
# Update HNSW config
|
|
if vector_index.config.hnsw is None:
|
|
raise ValueError("Trying to update HNSW config but schema has SPANN")
|
|
|
|
hnsw_config = vector_index.config.hnsw
|
|
update_hnsw = configuration["hnsw"]
|
|
|
|
# Only update fields that are present in the update
|
|
if "ef_search" in update_hnsw:
|
|
hnsw_config.ef_search = update_hnsw["ef_search"]
|
|
if "num_threads" in update_hnsw:
|
|
hnsw_config.num_threads = update_hnsw["num_threads"]
|
|
if "batch_size" in update_hnsw:
|
|
hnsw_config.batch_size = update_hnsw["batch_size"]
|
|
if "sync_threshold" in update_hnsw:
|
|
hnsw_config.sync_threshold = update_hnsw["sync_threshold"]
|
|
if "resize_factor" in update_hnsw:
|
|
hnsw_config.resize_factor = update_hnsw["resize_factor"]
|
|
|
|
elif "spann" in configuration and configuration["spann"] is not None:
|
|
# Update SPANN config
|
|
if vector_index.config.spann is None:
|
|
raise ValueError("Trying to update SPANN config but schema has HNSW")
|
|
|
|
spann_config = vector_index.config.spann
|
|
update_spann = configuration["spann"]
|
|
|
|
# Only update fields that are present in the update
|
|
if "search_nprobe" in update_spann:
|
|
spann_config.search_nprobe = update_spann["search_nprobe"]
|
|
if "ef_search" in update_spann:
|
|
spann_config.ef_search = update_spann["ef_search"]
|
|
|
|
# Update embedding function if present
|
|
if (
|
|
"embedding_function" in configuration
|
|
and configuration["embedding_function"] is not None
|
|
):
|
|
vector_index.config.embedding_function = configuration["embedding_function"]
|
|
|
|
return schema
|