group-wbl/.venv/lib/python3.13/site-packages/chromadb/utils/statistics.py
2026-01-09 09:48:03 +08:00

273 lines
9.2 KiB
Python

"""Utility functions for managing collection statistics.
This module provides standalone functions for enabling, disabling, and retrieving
statistics for ChromaDB collections. These functions work with the attached function
system to automatically compute metadata value frequencies.
Example:
>>> from chromadb.utils.statistics import attach_statistics_function, get_statistics
>>> import chromadb
>>>
>>> client = chromadb.Client()
>>> collection = client.get_or_create_collection("my_collection")
>>>
>>> # Attach statistics function with output collection name
>>> attach_statistics_function(collection, "my_collection_statistics")
>>>
>>> # Add some data
>>> collection.add(
... ids=["id1", "id2"],
... documents=["doc1", "doc2"],
... metadatas=[{"category": "A"}, {"category": "B"}]
... )
>>>
>>> # Get statistics from the named output collection
>>> stats = get_statistics(collection, "my_collection_statistics")
>>> print(stats)
"""
from typing import TYPE_CHECKING, Optional, Dict, Any, cast, Tuple
from collections import defaultdict
from chromadb.api.types import OneOrMany, Where, maybe_cast_one_to_many
from chromadb.api.functions import STATISTICS_FUNCTION
if TYPE_CHECKING:
from chromadb.api.models.Collection import Collection
from chromadb.api.models.AttachedFunction import AttachedFunction
def get_statistics_fn_name(collection: "Collection") -> str:
    """Build the default attached-function name used for statistics.

    The name is the collection's ``name`` with a ``_stats`` suffix appended.

    Args:
        collection: The collection whose statistics function name is wanted.

    Returns:
        str: The derived statistics function name.
    """
    base_name = collection.name
    return "{}_stats".format(base_name)
def attach_statistics_function(
    collection: "Collection", stats_collection_name: str
) -> Tuple["AttachedFunction", bool]:
    """Attach the statistics function to ``collection``.

    Once attached, the statistics function computes and maintains metadata
    value frequencies as records are added, updated, or deleted. The attached
    function is named via ``get_statistics_fn_name`` and writes its output to
    ``stats_collection_name``; no extra parameters are passed.

    Args:
        collection: The collection to enable statistics for.
        stats_collection_name: Name of the collection where statistics will
            be stored.

    Returns:
        Tuple of (AttachedFunction, created), exactly as returned by
        ``collection.attach_function``; ``created`` is True when newly
        created and False if it already existed (idempotent request).

    Example:
        >>> attached_fn, created = attach_statistics_function(collection, "my_collection_statistics")
        >>> if created:
        ...     print("Statistics function newly attached")
        >>> collection.add(ids=["id1"], documents=["doc1"], metadatas=[{"key": "value"}])
        >>> # Statistics are automatically computed
        >>> stats = get_statistics(collection, "my_collection_statistics")
    """
    fn_name = get_statistics_fn_name(collection)
    attached = collection.attach_function(
        name=fn_name,
        function=STATISTICS_FUNCTION,
        output_collection=stats_collection_name,
        params=None,
    )
    return attached
def get_statistics_fn(collection: "Collection") -> "AttachedFunction":
    """Get the statistics attached function for a collection.

    Looks up the attached function named by ``get_statistics_fn_name`` and
    verifies that it really is the statistics function.

    Args:
        collection: The collection to get the statistics function for.

    Returns:
        AttachedFunction: The statistics function.

    Raises:
        NotFoundError: If statistics are not enabled (raised by
            ``collection.get_attached_function``).
        AssertionError: If the attached function is not a statistics function.
    """
    af = collection.get_attached_function(get_statistics_fn_name(collection))
    # Explicit check instead of ``assert``: a bare assert is stripped under
    # ``python -O``, which would silently return a non-statistics function.
    # AssertionError is kept because it is the documented exception type.
    if af.function_name != "statistics":
        raise AssertionError("Attached function is not a statistics function")
    return af
def detach_statistics_function(
    collection: "Collection", delete_stats_collection: bool = False
) -> bool:
    """Detach the statistics function from ``collection``.

    The statistics function is first resolved (and validated) through
    ``get_statistics_fn``, then detached by name.

    Args:
        collection: The collection to disable statistics for.
        delete_stats_collection: When True, the statistics output collection
            is deleted as well. Defaults to False.

    Returns:
        bool: Whatever ``collection.detach_function`` returns (True on success).

    Example:
        >>> detach_statistics_function(collection, delete_stats_collection=True)
    """
    stats_fn_name = get_statistics_fn(collection).name
    return collection.detach_function(
        stats_fn_name,
        delete_output_collection=delete_stats_collection,
    )
def get_statistics(
    collection: "Collection",
    stats_collection_name: str,
    keys: Optional[OneOrMany[str]] = None,
) -> Dict[str, Any]:
    """Get the current statistics for a collection.

    Statistics include frequency counts for all metadata key-value pairs,
    as well as a summary with the total record count.

    Args:
        collection: The collection to get statistics for.
        stats_collection_name: Name of the statistics collection to read from.
        keys: Optional metadata key(s) to filter statistics for. Can be a single
            key string or a list of keys. If provided, only returns statistics
            for those specific keys (summary records are always fetched so
            ``total_count`` remains available).

    Returns:
        Dict[str, Any]: A dictionary built only from plain ``dict`` objects,
        with the structure::

            {
                "statistics": {
                    "key1": {
                        "value1": {"count": count},
                        "value2": {"count": count}
                    },
                    "key2": {...},
                    ...
                },
                "summary": {
                    "total_count": count
                }
            }

        The ``"summary"`` entry is omitted when no ``total_count`` summary
        record is found in the statistics collection.

    Example:
        >>> attach_statistics_function(collection, "my_collection_statistics")
        >>> collection.add(
        ...     ids=["id1", "id2"],
        ...     documents=["doc1", "doc2"],
        ...     metadatas=[{"category": "A", "score": 10}, {"category": "B", "score": 10}]
        ... )
        >>> # Wait for statistics to be computed
        >>> stats = get_statistics(collection, "my_collection_statistics")
        >>> print(stats)
        {
            "statistics": {
                "category": {
                    "A": {"count": 1},
                    "B": {"count": 1}
                },
                "score": {
                    "10": {"count": 2}
                }
            },
            "summary": {
                "total_count": 2
            }
        }

    Raises:
        ValueError: If more than 30 keys are provided in the keys filter.
        AssertionError: (only when assertions are enabled) if a statistics
            record has a non-string key or value label, or a non-int count.

    Any error raised by ``collection._client.get_collection`` (presumably
    including the case where ``stats_collection_name`` does not exist)
    propagates unchanged.
    """
    # Normalize keys to a list; None means "no key filter".
    keys_list = maybe_cast_one_to_many(keys)

    # Validate keys count to avoid issues with large $in queries
    MAX_KEYS = 30
    if keys_list is not None and len(keys_list) > MAX_KEYS:
        raise ValueError(
            f"Too many keys provided: {len(keys_list)}. "
            f"Maximum allowed is {MAX_KEYS} keys per request. "
            "Consider calling get_statistics multiple times with smaller key batches."
        )

    # Import here to avoid circular dependency
    from chromadb.api.models.Collection import Collection

    # Get the statistics output collection model from the server
    stats_collection_model = collection._client.get_collection(
        name=stats_collection_name,
        tenant=collection.tenant,
        database=collection.database,
    )
    # Wrap it in a Collection object to access get/query methods
    stats_collection = Collection(
        client=collection._client,
        model=stats_collection_model,
        embedding_function=None,  # Statistics collections don't need embedding functions
        data_loader=None,
    )

    # When filtering by keys, also include "summary" entries to get total_count
    where_filter: Optional[Where] = (
        cast(Where, {"key": {"$in": [*keys_list, "summary"]}})
        if keys_list is not None
        else None
    )

    # metadata key -> value label -> {"count": n}. Only the outer level is a
    # defaultdict; the inner levels are plain dicts, so the returned result
    # contains no defaultdicts (which would print oddly and silently insert
    # entries when callers look up missing values).
    stats: Dict[str, Dict[str, Dict[str, int]]] = defaultdict(dict)
    summary: Dict[str, Any] = {}

    # Page through the stats collection until an empty page comes back.
    offset = 0
    while True:
        page = stats_collection.get(
            include=["metadatas"], offset=offset, where=where_filter
        )
        metadatas = page.get("metadatas") or []
        if not metadatas:
            break

        for metadata in metadatas:
            if metadata is None:
                continue
            meta_key = metadata.get("key")
            value = metadata.get("value")
            value_label = metadata.get("value_label")
            value_type = metadata.get("type")
            count = metadata.get("count")

            # Skip records that are missing any required field.
            if (
                meta_key is None
                or value is None
                or value_type is None
                or count is None
            ):
                continue

            if meta_key == "summary":
                if value == "total_count":
                    summary["total_count"] = count
                continue

            # Prioritize value_label if present, otherwise use value
            stats_key = value_label if value_label is not None else value
            assert isinstance(meta_key, str)
            assert isinstance(stats_key, str)
            assert isinstance(count, int)
            stats[meta_key][stats_key] = {"count": count}

        # Advance to next page using the actual number of items returned
        offset += len(metadatas)

    result: Dict[str, Any] = {"statistics": dict(stats)}
    if summary:
        result["summary"] = summary
    return result