group-wbl/.venv/lib/python3.13/site-packages/chromadb/execution/expression/operator.py
2026-01-09 09:48:03 +08:00

1513 lines
50 KiB
Python

from dataclasses import dataclass, field
from typing import Optional, List, Dict, Set, Any, Union, cast
import numpy as np
from numpy.typing import NDArray
from chromadb.api.types import (
Embeddings,
IDs,
Include,
OneOrMany,
SparseVector,
TYPE_KEY,
SPARSE_VECTOR_TYPE_VALUE,
maybe_cast_one_to_many,
normalize_embeddings,
validate_embeddings,
)
from chromadb.types import (
Collection,
RequestVersionContext,
Segment,
)
@dataclass
class Scan:
    """Groups a collection with the knn, metadata and record segments."""

    collection: Collection
    knn: Segment
    metadata: Segment
    record: Segment

    @property
    def version(self) -> RequestVersionContext:
        """Build a RequestVersionContext from the collection's version and log position."""
        source = self.collection
        return RequestVersionContext(
            collection_version=source.version,
            log_position=source.log_position,
        )
# Where expression types for filtering
@dataclass
class Where:
    """Root of the Where expression tree (algebraic data type).

    Expressions compose through Python's bitwise operators:
        - AND: where1 & where2
        - OR: where1 | where2

    Examples:
        # Simple conditions
        where1 = Key("status") == "active"
        where2 = Key("score") > 0.5

        # Combining with AND
        combined_and = where1 & where2

        # Combining with OR
        combined_or = where1 | where2

        # Complex expressions
        complex_where = (Key("status") == "active") & ((Key("score") > 0.5) | (Key("priority") == "high"))
    """

    def to_dict(self) -> Dict[str, Any]:
        """Serialize the expression to a JSON-friendly dict; subclasses override this."""
        raise NotImplementedError("Subclasses must implement to_dict()")

    @staticmethod
    def from_dict(data: Dict[str, Any]) -> "Where":
        """Parse a MongoDB-style filter dict into a Where expression.

        Accepted shapes:
            - {"field": "value"}                      -> Key("field") == "value" (equality shorthand)
            - {"field": {"$eq": value}}               -> Key("field") == value
            - {"field": {"$ne": value}}               -> Key("field") != value
            - {"field": {"$gt": value}}               -> Key("field") > value
            - {"field": {"$gte": value}}              -> Key("field") >= value
            - {"field": {"$lt": value}}               -> Key("field") < value
            - {"field": {"$lte": value}}              -> Key("field") <= value
            - {"field": {"$in": [values]}}            -> Key("field").is_in([values])
            - {"field": {"$nin": [values]}}           -> Key("field").not_in([values])
            - {"field": {"$contains": "text"}}        -> Key("field").contains("text")
            - {"field": {"$not_contains": "text"}}    -> Key("field").not_contains("text")
            - {"field": {"$regex": "pattern"}}        -> Key("field").regex("pattern")
            - {"field": {"$not_regex": "pattern"}}    -> Key("field").not_regex("pattern")
            - {"$and": [conditions]}                  -> condition1 & condition2 & ...
            - {"$or": [conditions]}                   -> condition1 | condition2 | ...

        Raises:
            TypeError: when the dict, a field name, or an operator argument has the wrong type.
            ValueError: when the dict is empty, has extra entries, or names an unknown operator.
        """
        if not isinstance(data, dict):
            raise TypeError(f"Expected dict for Where, got {type(data).__name__}")
        if not data:
            raise ValueError("Where dict cannot be empty")

        # Logical combinators. "$and" is looked for before "$or".
        for logical_op in ("$and", "$or"):
            if logical_op not in data:
                continue
            branches = data[logical_op]
            if not isinstance(branches, list):
                raise TypeError(
                    f"{logical_op} must be a list, got {type(branches).__name__}"
                )
            if len(branches) == 0:
                raise ValueError(f"{logical_op} requires at least one condition")
            if len(data) > 1:
                raise ValueError(
                    f"{logical_op} cannot be combined with other fields in the same dict"
                )
            parsed = [Where.from_dict(branch) for branch in branches]
            # A single branch is returned as-is; more are folded left to right.
            combined = parsed[0]
            for extra in parsed[1:]:
                if logical_op == "$and":
                    combined = combined & extra
                else:
                    combined = combined | extra
            return combined

        # Anything else must be exactly one {field: condition} entry.
        if len(data) != 1:
            raise ValueError(
                f"Where dict must contain exactly one field, got {len(data)}"
            )
        field_name, condition = next(iter(data.items()))
        if not isinstance(field_name, str):
            raise TypeError(
                f"Field name must be a string, got {type(field_name).__name__}"
            )

        # A bare value is shorthand for equality.
        if not isinstance(condition, dict):
            return Key(field_name) == condition

        if not condition:
            raise ValueError(
                f"Operator dict for field '{field_name}' cannot be empty"
            )
        if len(condition) != 1:
            raise ValueError(
                f"Operator dict for field '{field_name}' must contain exactly one operator"
            )
        op, value = next(iter(condition.items()))

        # Operators that accept any value map straight onto Key's comparison overloads.
        comparisons = {
            "$eq": lambda key, operand: key == operand,
            "$ne": lambda key, operand: key != operand,
            "$gt": lambda key, operand: key > operand,
            "$gte": lambda key, operand: key >= operand,
            "$lt": lambda key, operand: key < operand,
            "$lte": lambda key, operand: key <= operand,
        }
        if op in comparisons:
            return comparisons[op](Key(field_name), value)

        if op in ("$in", "$nin"):
            if not isinstance(value, list):
                raise TypeError(
                    f"{op} requires a list, got {type(value).__name__}"
                )
            target = Key(field_name)
            return target.is_in(value) if op == "$in" else target.not_in(value)

        if op in ("$contains", "$not_contains"):
            if not isinstance(value, str):
                raise TypeError(
                    f"{op} requires a string, got {type(value).__name__}"
                )
            target = Key(field_name)
            if op == "$contains":
                return target.contains(value)
            return target.not_contains(value)

        if op in ("$regex", "$not_regex"):
            if not isinstance(value, str):
                raise TypeError(
                    f"{op} requires a string pattern, got {type(value).__name__}"
                )
            target = Key(field_name)
            if op == "$regex":
                return target.regex(value)
            return target.not_regex(value)

        raise ValueError(f"Unknown operator: {op}")

    def __and__(self, other: "Where") -> "And":
        """Overload ``&``: build an And node, flattening operands that are already And."""
        left_terms = self.conditions if isinstance(self, And) else [self]
        right_terms = other.conditions if isinstance(other, And) else [other]
        return And(left_terms + right_terms)

    def __or__(self, other: "Where") -> "Or":
        """Overload ``|``: build an Or node, flattening operands that are already Or."""
        left_terms = self.conditions if isinstance(self, Or) else [self]
        right_terms = other.conditions if isinstance(other, Or) else [other]
        return Or(left_terms + right_terms)
@dataclass
class And(Where):
    """Logical AND over a list of where conditions."""

    conditions: List[Where]

    def to_dict(self) -> Dict[str, Any]:
        """Serialize as ``{"$and": [...]}``, serializing each child in order."""
        serialized = [child.to_dict() for child in self.conditions]
        return {"$and": serialized}
@dataclass
class Or(Where):
    """Logical OR over a list of where conditions."""

    conditions: List[Where]

    def to_dict(self) -> Dict[str, Any]:
        """Serialize as ``{"$or": [...]}``, serializing each child in order."""
        serialized = [child.to_dict() for child in self.conditions]
        return {"$or": serialized}
@dataclass
class Eq(Where):
    """Equality comparison on a single key."""

    key: str
    value: Any

    def to_dict(self) -> Dict[str, Any]:
        """Serialize as ``{key: {"$eq": value}}``."""
        clause = {"$eq": self.value}
        return {self.key: clause}
@dataclass
class Ne(Where):
    """Inequality (not-equal) comparison on a single key."""

    key: str
    value: Any

    def to_dict(self) -> Dict[str, Any]:
        """Serialize as ``{key: {"$ne": value}}``."""
        clause = {"$ne": self.value}
        return {self.key: clause}
@dataclass
class Gt(Where):
    """Strictly-greater-than comparison on a single key."""

    key: str
    value: Any

    def to_dict(self) -> Dict[str, Any]:
        """Serialize as ``{key: {"$gt": value}}``."""
        clause = {"$gt": self.value}
        return {self.key: clause}
@dataclass
class Gte(Where):
    """Greater-than-or-equal comparison on a single key."""

    key: str
    value: Any

    def to_dict(self) -> Dict[str, Any]:
        """Serialize as ``{key: {"$gte": value}}``."""
        clause = {"$gte": self.value}
        return {self.key: clause}
@dataclass
class Lt(Where):
    """Strictly-less-than comparison on a single key."""

    key: str
    value: Any

    def to_dict(self) -> Dict[str, Any]:
        """Serialize as ``{key: {"$lt": value}}``."""
        clause = {"$lt": self.value}
        return {self.key: clause}
@dataclass
class Lte(Where):
    """Less-than-or-equal comparison on a single key."""

    key: str
    value: Any

    def to_dict(self) -> Dict[str, Any]:
        """Serialize as ``{key: {"$lte": value}}``."""
        clause = {"$lte": self.value}
        return {self.key: clause}
@dataclass
class In(Where):
    """Membership comparison: the key's value is in a list."""

    key: str
    values: List[Any]

    def to_dict(self) -> Dict[str, Any]:
        """Serialize as ``{key: {"$in": values}}``."""
        clause = {"$in": self.values}
        return {self.key: clause}
@dataclass
class Nin(Where):
    """Negated membership comparison: the key's value is not in a list."""

    key: str
    values: List[Any]

    def to_dict(self) -> Dict[str, Any]:
        """Serialize as ``{key: {"$nin": values}}``."""
        clause = {"$nin": self.values}
        return {self.key: clause}
@dataclass
class Contains(Where):
    """Contains comparison for document content."""

    key: str
    content: str

    def to_dict(self) -> Dict[str, Any]:
        """Serialize as ``{key: {"$contains": content}}``."""
        clause = {"$contains": self.content}
        return {self.key: clause}
@dataclass
class NotContains(Where):
    """Negated contains comparison for document content."""

    key: str
    content: str

    def to_dict(self) -> Dict[str, Any]:
        """Serialize as ``{key: {"$not_contains": content}}``."""
        clause = {"$not_contains": self.content}
        return {self.key: clause}
@dataclass
class Regex(Where):
    """Regular-expression match on a single key."""

    key: str
    pattern: str

    def to_dict(self) -> Dict[str, Any]:
        """Serialize as ``{key: {"$regex": pattern}}``."""
        clause = {"$regex": self.pattern}
        return {self.key: clause}
@dataclass
class NotRegex(Where):
    """Negated regular-expression match on a single key."""

    key: str
    pattern: str

    def to_dict(self) -> Dict[str, Any]:
        """Serialize as ``{key: {"$not_regex": pattern}}``."""
        clause = {"$not_regex": self.pattern}
        return {self.key: clause}
# Field proxy for building Where conditions
class Key:
    """Field proxy that turns comparisons into Where conditions.

    A Key names a field in one of three ways:
        1. Predefined constants for special fields: K.EMBEDDING, K.DOCUMENT, K.SCORE, etc.
        2. String literals with # prefix for special fields: Key("#embedding")
        3. Metadata field names without # prefix: Key("my_metadata_field")

    Predefined field constants (special fields with # prefix):
        Key.ID - ID field (equivalent to Key("#id"))
        Key.DOCUMENT - Document field (equivalent to Key("#document"))
        Key.EMBEDDING - Embedding field (equivalent to Key("#embedding"))
        Key.METADATA - Metadata field (equivalent to Key("#metadata"))
        Key.SCORE - Score field (equivalent to Key("#score"))

    Note: K is an alias for Key, so K.DOCUMENT and Key.DOCUMENT are interchangeable.

    Note: ``==`` and ``!=`` are overloaded to build Eq / Ne expressions, so they
    never return a plain bool.

    Examples:
        # Using predefined keys with K alias for special fields
        from chromadb.execution.expression import K
        K.DOCUMENT.contains("search text")  # Searches document field

        # Custom metadata field names (without # prefix)
        K("status") == "active"  # Metadata field named "status"
        K("category").is_in(["science", "tech"])  # Metadata field named "category"
        K("sparse_embedding")  # Example: metadata field (could store anything)

        # Using with Knn for different fields
        Knn(query=[0.1, 0.2])  # Default: searches "#embedding"
        Knn(query=[0.1, 0.2], key=K.EMBEDDING)  # Explicit: searches "#embedding"
        Knn(query=sparse, key="sparse_embedding")  # Example: searches a metadata field

        # Combining conditions
        (K("status") == "active") & (K.SCORE > 0.5)
    """

    # Predefined key constants; values are assigned after the class definition.
    ID: "Key"
    DOCUMENT: "Key"
    EMBEDDING: "Key"
    METADATA: "Key"
    SCORE: "Key"

    def __init__(self, name: str):
        self.name = name

    def __hash__(self) -> int:
        """Hash by field name so Key instances can be stored in sets."""
        return hash(self.name)

    # ---- comparison operators -------------------------------------------

    def __eq__(self, value: Any) -> Eq:  # type: ignore[override]
        """Equality: Key('field') == value"""
        return Eq(key=self.name, value=value)

    def __ne__(self, value: Any) -> Ne:  # type: ignore[override]
        """Not equal: Key('field') != value"""
        return Ne(key=self.name, value=value)

    def __gt__(self, value: Any) -> Gt:
        """Greater than: Key('field') > value"""
        return Gt(key=self.name, value=value)

    def __ge__(self, value: Any) -> Gte:
        """Greater than or equal: Key('field') >= value"""
        return Gte(key=self.name, value=value)

    def __lt__(self, value: Any) -> Lt:
        """Less than: Key('field') < value"""
        return Lt(key=self.name, value=value)

    def __le__(self, value: Any) -> Lte:
        """Less than or equal: Key('field') <= value"""
        return Lte(key=self.name, value=value)

    # ---- builder methods for conditions without an operator ---------------

    def is_in(self, values: List[Any]) -> In:
        """Check if field value is in list: Key('field').is_in(['a', 'b'])"""
        return In(key=self.name, values=values)

    def not_in(self, values: List[Any]) -> Nin:
        """Check if field value is not in list: Key('field').not_in(['a', 'b'])"""
        return Nin(key=self.name, values=values)

    def regex(self, pattern: str) -> Regex:
        """Match field against regex: Key('field').regex('^pattern')"""
        return Regex(key=self.name, pattern=pattern)

    def not_regex(self, pattern: str) -> NotRegex:
        """Field should not match regex: Key('field').not_regex('^pattern')"""
        return NotRegex(key=self.name, pattern=pattern)

    def contains(self, content: str) -> Contains:
        """Check if field contains text: Key('field').contains('text')"""
        return Contains(key=self.name, content=content)

    def not_contains(self, content: str) -> NotContains:
        """Check if field doesn't contain text: Key('field').not_contains('text')"""
        return NotContains(key=self.name, content=content)
# Initialize predefined key constants.
# These are assigned after the class body because each value is itself a Key
# instance, which cannot be constructed until the class exists.
Key.ID = Key("#id")
Key.DOCUMENT = Key("#document")
Key.EMBEDDING = Key("#embedding")
Key.METADATA = Key("#metadata")
Key.SCORE = Key("#score")
# Alias for Key (K("field") and K.DOCUMENT behave exactly like Key(...) / Key.DOCUMENT)
K = Key
@dataclass
class Filter:
    """Optional filtering criteria: explicit ids plus legacy where clauses.

    Every field defaults to None.
    """

    user_ids: Optional[IDs] = None
    # Old Where type from chromadb.types
    where: Optional[Any] = None
    # Old WhereDocument type
    where_document: Optional[Any] = None
@dataclass
class KNN:
    """Pairs query embeddings with a fetch count.

    NOTE(review): ``fetch`` is presumably the number of neighbours to retrieve
    per query embedding -- confirm against the executor that consumes it.
    """

    embeddings: Embeddings
    fetch: int
@dataclass
class Limit:
    """Pagination window: skip ``offset`` results, then return at most ``limit``.

    ``limit=None`` means no explicit cap is serialized.
    """

    offset: int = 0
    limit: Optional[int] = None

    def to_dict(self) -> Dict[str, Any]:
        """Convert the Limit to a dictionary for JSON serialization.

        ``limit`` is only included when it is not None.
        """
        result: Dict[str, Any] = {"offset": self.offset}
        if self.limit is not None:
            result["limit"] = self.limit
        return result

    @staticmethod
    def from_dict(data: Dict[str, Any]) -> "Limit":
        """Create Limit from dictionary.

        Examples:
            - {"offset": 10} -> Limit(offset=10)
            - {"offset": 10, "limit": 20} -> Limit(offset=10, limit=20)
            - {"limit": 20} -> Limit(offset=0, limit=20)

        Raises:
            TypeError: if ``data`` is not a dict, or offset/limit is not an
                integer (booleans are rejected even though ``bool`` subclasses
                ``int``).
            ValueError: if offset is negative, limit is not positive, or the
                dict contains keys other than "offset" and "limit".
        """
        if not isinstance(data, dict):
            raise TypeError(f"Expected dict for Limit, got {type(data).__name__}")

        offset = data.get("offset", 0)
        # bool is a subclass of int; without the explicit check True/False
        # would silently be accepted as 1/0.
        if isinstance(offset, bool) or not isinstance(offset, int):
            raise TypeError(
                f"Limit offset must be an integer, got {type(offset).__name__}"
            )
        if offset < 0:
            raise ValueError(f"Limit offset must be non-negative, got {offset}")

        limit = data.get("limit")
        if limit is not None:
            if isinstance(limit, bool) or not isinstance(limit, int):
                raise TypeError(
                    f"Limit limit must be an integer, got {type(limit).__name__}"
                )
            if limit <= 0:
                raise ValueError(f"Limit limit must be positive, got {limit}")

        # Check for unexpected keys
        allowed_keys = {"offset", "limit"}
        unexpected_keys = set(data.keys()) - allowed_keys
        if unexpected_keys:
            raise ValueError(f"Unexpected keys in Limit dict: {unexpected_keys}")

        return Limit(offset=offset, limit=limit)
@dataclass
class Projection:
    """Boolean switches for the record parts to include; all default to False."""

    document: bool = False
    embedding: bool = False
    metadata: bool = False
    rank: bool = False
    uri: bool = False

    @property
    def included(self) -> Include:
        """List the include names for the enabled flags.

        The order is fixed: documents, embeddings, metadatas, distances, uris.
        The ``rank`` flag maps to "distances".
        """
        flag_names = (
            (self.document, "documents"),
            (self.embedding, "embeddings"),
            (self.metadata, "metadatas"),
            (self.rank, "distances"),
            (self.uri, "uris"),
        )
        return [name for enabled, name in flag_names if enabled]  # type: ignore[return-value]
# Rank expression types for hybrid search
@dataclass
class Rank:
    """Base class for Rank expressions (algebraic data type).

    Supports arithmetic operations for combining rank expressions:
        - Addition: rank1 + rank2, rank + 0.5
        - Subtraction: rank1 - rank2, rank - 0.5
        - Multiplication: rank1 * rank2, rank * 0.8
        - Division: rank1 / rank2, rank / 2.0
        - Negation: -rank
        - Absolute value: abs(rank)

    Supports mathematical functions:
        - Exponential: rank.exp()
        - Logarithm: rank.log()
        - Maximum: rank.max(other)
        - Minimum: rank.min(other)

    Examples:
        # Weighted combination
        Knn(query=[0.1, 0.2]) * 0.8 + Val(0.5) * 0.2

        # Normalization
        Knn(query=[0.1, 0.2]) / Val(10.0)

        # Clamping
        Knn(query=[0.1, 0.2]).min(1.0).max(0.0)
    """

    def to_dict(self) -> Dict[str, Any]:
        """Convert the Rank expression to a dictionary for JSON serialization"""
        raise NotImplementedError("Subclasses must implement to_dict()")

    # ------------------------------------------------------------------
    # Deserialization
    # ------------------------------------------------------------------

    @staticmethod
    def from_dict(data: Dict[str, Any]) -> "Rank":
        """Create Rank expression from dictionary.

        Supports operators:
            - {"$val": number} -> Val(number)
            - {"$knn": {...}} -> Knn(...)
            - {"$sum": [ranks]} -> rank1 + rank2 + ...
            - {"$sub": {"left": ..., "right": ...}} -> left - right
            - {"$mul": [ranks]} -> rank1 * rank2 * ...
            - {"$div": {"left": ..., "right": ...}} -> left / right
            - {"$abs": rank} -> abs(rank)
            - {"$exp": rank} -> rank.exp()
            - {"$log": rank} -> rank.log()
            - {"$max": [ranks]} -> rank1.max(rank2).max(rank3)...
            - {"$min": [ranks]} -> rank1.min(rank2).min(rank3)...

        Raises:
            TypeError: when ``data`` or an operator payload has the wrong type
                (booleans are not accepted where a number is required).
            ValueError: when the dict is empty, holds more than one operator,
                names an unknown operator, or a payload is structurally invalid.
        """
        if not isinstance(data, dict):
            raise TypeError(f"Expected dict for Rank, got {type(data).__name__}")
        if not data:
            raise ValueError("Rank dict cannot be empty")
        if len(data) != 1:
            raise ValueError(
                f"Rank dict must contain exactly one operator, got {len(data)}"
            )

        op, payload = next(iter(data.items()))

        if op == "$val":
            if not Rank._is_number(payload):
                raise TypeError(
                    f"$val requires a number, got {type(payload).__name__}"
                )
            return Val(payload)

        if op == "$knn":
            return Rank._knn_from_dict(payload)

        if op in ("$sum", "$mul", "$max", "$min"):
            operands = Rank._parse_operands(op, payload)
            # Fold left to right so the operators' own flattening applies.
            result = operands[0]
            for operand in operands[1:]:
                if op == "$sum":
                    result = result + operand
                elif op == "$mul":
                    result = result * operand
                elif op == "$max":
                    result = result.max(operand)
                else:
                    result = result.min(operand)
            return result

        if op in ("$sub", "$div"):
            left, right = Rank._parse_left_right(op, payload)
            return left - right if op == "$sub" else left / right

        if op in ("$abs", "$exp", "$log"):
            child = Rank._parse_child(op, payload)
            if op == "$abs":
                return abs(child)
            if op == "$exp":
                return child.exp()
            return child.log()

        raise ValueError(f"Unknown rank operator: {op}")

    @staticmethod
    def _is_number(value: Any) -> bool:
        """Return True for int/float values, excluding bool (a subclass of int)."""
        return isinstance(value, (int, float)) and not isinstance(value, bool)

    @staticmethod
    def _parse_operands(op: str, payload: Any) -> List["Rank"]:
        """Validate and parse the operand list of an n-ary operator ($sum/$mul/$max/$min)."""
        if not isinstance(payload, (list, tuple)):
            raise TypeError(f"{op} requires a list, got {type(payload).__name__}")
        if len(payload) < 2:
            raise ValueError(f"{op} requires at least 2 ranks, got {len(payload)}")
        return [Rank.from_dict(item) for item in payload]

    @staticmethod
    def _parse_left_right(op: str, payload: Any) -> List["Rank"]:
        """Validate and parse the ``left``/``right`` payload of $sub or $div; returns [left, right]."""
        if not isinstance(payload, dict):
            raise TypeError(
                f"{op} requires a dict with 'left' and 'right', got {type(payload).__name__}"
            )
        if "left" not in payload or "right" not in payload:
            raise ValueError(f"{op} requires 'left' and 'right' fields")
        left = Rank.from_dict(payload["left"])
        right = Rank.from_dict(payload["right"])
        return [left, right]

    @staticmethod
    def _parse_child(op: str, payload: Any) -> "Rank":
        """Validate and parse the single child of a unary operator ($abs/$exp/$log)."""
        if not isinstance(payload, dict):
            raise TypeError(
                f"{op} requires a rank dict, got {type(payload).__name__}"
            )
        return Rank.from_dict(payload)

    @staticmethod
    def _knn_from_dict(knn_data: Any) -> "Knn":
        """Validate a ``$knn`` payload and build the corresponding Knn."""
        if not isinstance(knn_data, dict):
            raise TypeError(f"$knn requires a dict, got {type(knn_data).__name__}")
        if "query" not in knn_data:
            raise ValueError("$knn requires 'query' field")

        query = knn_data["query"]
        if isinstance(query, dict):
            # SparseVector case - deserialize from transport format
            if query.get(TYPE_KEY) != SPARSE_VECTOR_TYPE_VALUE:
                raise ValueError(
                    f"Expected dict with {TYPE_KEY}='{SPARSE_VECTOR_TYPE_VALUE}', got {query}"
                )
            query = SparseVector.from_dict(query)
        elif isinstance(query, (list, tuple, np.ndarray)):
            # Dense vector case - normalize then validate
            normalized = normalize_embeddings(query)
            if not normalized or len(normalized) > 1:
                raise ValueError("$knn requires exactly one query embedding")
            validate_embeddings(normalized)
            query = normalized[0]
        else:
            raise TypeError(
                f"$knn query must be a list, numpy array, or SparseVector dict, got {type(query).__name__}"
            )

        key = knn_data.get("key", "#embedding")
        if not isinstance(key, str):
            raise TypeError(f"$knn key must be a string, got {type(key).__name__}")

        limit = knn_data.get("limit", 16)
        # Reject bool explicitly: isinstance(True, int) is True.
        if isinstance(limit, bool) or not isinstance(limit, int):
            raise TypeError(
                f"$knn limit must be an integer, got {type(limit).__name__}"
            )
        if limit <= 0:
            raise ValueError(f"$knn limit must be positive, got {limit}")

        return_rank = knn_data.get("return_rank", False)
        if not isinstance(return_rank, bool):
            raise TypeError(
                f"$knn return_rank must be a boolean, got {type(return_rank).__name__}"
            )

        # Validated like the other fields, so a bad value fails at parse time.
        default = knn_data.get("default")
        if default is not None and not Rank._is_number(default):
            raise TypeError(
                f"$knn default must be a number, got {type(default).__name__}"
            )

        return Knn(
            query=query,
            key=key,
            limit=limit,
            default=default,
            return_rank=return_rank,
        )

    # ------------------------------------------------------------------
    # Operator helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _as_rank(operand: Union["Rank", float, int]) -> "Rank":
        """Wrap a plain number in Val; pass anything else through unchanged."""
        return Val(operand) if isinstance(operand, (int, float)) else operand

    @staticmethod
    def _combine(node_type: Any, left: "Rank", right: "Rank") -> Any:
        """Build an n-ary node of ``node_type``, flattening operands already of that type."""
        left_terms = left.ranks if isinstance(left, node_type) else [left]
        right_terms = right.ranks if isinstance(right, node_type) else [right]
        return node_type(left_terms + right_terms)

    # ------------------------------------------------------------------
    # Arithmetic operators
    # ------------------------------------------------------------------

    def __add__(self, other: Union["Rank", float, int]) -> "Sum":
        """Addition: rank1 + rank2 or rank + value"""
        return Rank._combine(Sum, self, Rank._as_rank(other))

    def __radd__(self, other: Union[float, int]) -> "Sum":
        """Right addition: value + rank"""
        return Val(other) + self

    def __sub__(self, other: Union["Rank", float, int]) -> "Sub":
        """Subtraction: rank1 - rank2 or rank - value"""
        return Sub(self, Rank._as_rank(other))

    def __rsub__(self, other: Union[float, int]) -> "Sub":
        """Right subtraction: value - rank"""
        return Sub(Val(other), self)

    def __mul__(self, other: Union["Rank", float, int]) -> "Mul":
        """Multiplication: rank1 * rank2 or rank * value"""
        return Rank._combine(Mul, self, Rank._as_rank(other))

    def __rmul__(self, other: Union[float, int]) -> "Mul":
        """Right multiplication: value * rank"""
        return Val(other) * self

    def __truediv__(self, other: Union["Rank", float, int]) -> "Div":
        """Division: rank1 / rank2 or rank / value"""
        return Div(self, Rank._as_rank(other))

    def __rtruediv__(self, other: Union[float, int]) -> "Div":
        """Right division: value / rank"""
        return Div(Val(other), self)

    def __neg__(self) -> "Mul":
        """Negation: -rank (equivalent to -1 * rank)"""
        # Built directly rather than via ``*`` so it is never flattened.
        return Mul([Val(-1), self])

    def __abs__(self) -> "Abs":
        """Absolute value: abs(rank)"""
        return Abs(self)

    def abs(self) -> "Abs":
        """Absolute value builder: rank.abs()"""
        return Abs(self)

    # ------------------------------------------------------------------
    # Builder methods for functions
    # ------------------------------------------------------------------

    def exp(self) -> "Exp":
        """Exponential: e^rank"""
        return Exp(self)

    def log(self) -> "Log":
        """Natural logarithm: ln(rank)"""
        return Log(self)

    def max(self, other: Union["Rank", float, int]) -> "Max":
        """Maximum of this rank and another: rank.max(rank2)"""
        return Rank._combine(Max, self, Rank._as_rank(other))

    def min(self, other: Union["Rank", float, int]) -> "Min":
        """Minimum of this rank and another: rank.min(rank2)"""
        return Rank._combine(Min, self, Rank._as_rank(other))
@dataclass
class Abs(Rank):
    """Absolute value of a rank expression."""

    rank: Rank

    def to_dict(self) -> Dict[str, Any]:
        """Serialize as ``{"$abs": <child>}``."""
        inner = self.rank.to_dict()
        return {"$abs": inner}
@dataclass
class Div(Rank):
    """Division of one rank expression (left) by another (right)."""

    left: Rank
    right: Rank

    def to_dict(self) -> Dict[str, Any]:
        """Serialize as ``{"$div": {"left": ..., "right": ...}}``."""
        operands = {"left": self.left.to_dict(), "right": self.right.to_dict()}
        return {"$div": operands}
@dataclass
class Exp(Rank):
    """Exponentiation of a rank expression."""

    rank: Rank

    def to_dict(self) -> Dict[str, Any]:
        """Serialize as ``{"$exp": <child>}``."""
        inner = self.rank.to_dict()
        return {"$exp": inner}
@dataclass
class Log(Rank):
    """Logarithm of a rank expression."""

    rank: Rank

    def to_dict(self) -> Dict[str, Any]:
        """Serialize as ``{"$log": <child>}``."""
        inner = self.rank.to_dict()
        return {"$log": inner}
@dataclass
class Max(Rank):
    """Maximum over a list of rank expressions."""

    ranks: List[Rank]

    def to_dict(self) -> Dict[str, Any]:
        """Serialize as ``{"$max": [...]}``, serializing each operand in order."""
        operands = [operand.to_dict() for operand in self.ranks]
        return {"$max": operands}
@dataclass
class Min(Rank):
    """Minimum over a list of rank expressions."""

    ranks: List[Rank]

    def to_dict(self) -> Dict[str, Any]:
        """Serialize as ``{"$min": [...]}``, serializing each operand in order."""
        operands = [operand.to_dict() for operand in self.ranks]
        return {"$min": operands}
@dataclass
class Mul(Rank):
    """Product of a list of rank expressions."""

    ranks: List[Rank]

    def to_dict(self) -> Dict[str, Any]:
        """Serialize as ``{"$mul": [...]}``, serializing each operand in order."""
        operands = [operand.to_dict() for operand in self.ranks]
        return {"$mul": operands}
@dataclass
class Knn(Rank):
    """KNN-based ranking

    Args:
        query: The query for KNN search. Can be:
            - A string (will be automatically embedded using the collection's embedding function)
            - A dense vector (list or numpy array)
            - A sparse vector (SparseVector dict)
        key: The embedding key to search against. Can be:
            - Key.EMBEDDING (default) - searches the main embedding field
            - A metadata field name (e.g., "my_custom_field") - searches that metadata field
        limit: Maximum number of results to consider (default: 16)
        default: Default score for records not in KNN results (default: None)
        return_rank: If True, return the rank position (0, 1, 2, ...) instead of distance (default: False)

    Examples:
        # Search with string query (automatically embedded)
        Knn(query="hello world")  # Will use collection's embedding function

        # Search main embeddings with vectors (equivalent forms)
        Knn(query=[0.1, 0.2])  # Uses default key="#embedding"
        Knn(query=[0.1, 0.2], key=K.EMBEDDING)
        Knn(query=[0.1, 0.2], key="#embedding")

        # Search sparse embeddings stored in metadata with string
        Knn(query="hello world", key="custom_embedding")  # Will use schema's embedding function

        # Search sparse embeddings stored in metadata with vector
        Knn(query=my_vector, key="custom_embedding")  # Example: searches a metadata field
    """

    query: Union[
        str,
        List[float],
        SparseVector,
        "NDArray[np.float32]",
        "NDArray[np.float64]",
        "NDArray[np.int32]",
    ]
    key: Union[Key, str] = K.EMBEDDING
    limit: int = 16
    default: Optional[float] = None
    return_rank: bool = False

    def to_dict(self) -> Dict[str, Any]:
        """Serialize as ``{"$knn": {...}}`` in transport format.

        SparseVector queries are converted with their own ``to_dict()``, numpy
        arrays with ``tolist()``, and Key objects are replaced by their name.
        ``default`` and ``return_rank`` are emitted only when they differ from
        the defaults (None / False).
        """
        raw_query = self.query
        if isinstance(raw_query, SparseVector):
            wire_query = raw_query.to_dict()
        elif isinstance(raw_query, np.ndarray):
            wire_query = raw_query.tolist()
        else:
            wire_query = raw_query

        key_name = self.key.name if isinstance(self.key, Key) else self.key

        payload: Dict[str, Any] = {
            "query": wire_query,
            "key": key_name,
            "limit": self.limit,
        }
        # Optional fields are left out when unset to keep the JSON minimal.
        if self.default is not None:
            payload["default"] = self.default
        if self.return_rank:
            payload["return_rank"] = self.return_rank
        return {"$knn": payload}
@dataclass
class Sub(Rank):
    """Subtraction of one rank expression (right) from another (left)."""

    left: Rank
    right: Rank

    def to_dict(self) -> Dict[str, Any]:
        """Serialize as ``{"$sub": {"left": ..., "right": ...}}``."""
        operands = {"left": self.left.to_dict(), "right": self.right.to_dict()}
        return {"$sub": operands}
@dataclass
class Sum(Rank):
    """Sum of a list of rank expressions."""

    ranks: List[Rank]

    def to_dict(self) -> Dict[str, Any]:
        """Serialize as ``{"$sum": [...]}``, serializing each operand in order."""
        operands = [operand.to_dict() for operand in self.ranks]
        return {"$sum": operands}
@dataclass
class Val(Rank):
    """Constant rank value."""

    value: float

    def to_dict(self) -> Dict[str, Any]:
        """Serialize as ``{"$val": value}``."""
        constant = self.value
        return {"$val": constant}
@dataclass
class Rrf(Rank):
    """Reciprocal Rank Fusion for combining multiple ranking strategies.

    RRF formula: score = -sum(weight_i / (k + rank_i)) for each ranking strategy

    The negative is used because RRF produces higher scores for better results,
    but Chroma uses ascending order (lower scores = better results).

    Args:
        ranks: List of Rank expressions to fuse (must have at least one)
        k: Smoothing constant (default: 60, standard in literature)
        weights: Optional weights for each ranking strategy. If not provided,
            all ranks are weighted equally (weight=1.0 each).
        normalize: If True, normalize weights to sum to 1.0 (default: False).
            When False, weights are used as-is for relative importance.
            When True, weights are scaled so they sum to 1.0.

    Examples:
        # Note: metadata fields (like "custom_embedding" below) are user-defined and can store any data.
        # The field name is just an example - use whatever name matches your metadata structure.

        # Basic RRF combining KNN rankings (equal weight)
        Rrf([
            Knn(query=[0.1, 0.2], return_rank=True),
            Knn(query=another_vector, key="custom_embedding", return_rank=True)  # Example metadata field
        ])

        # Weighted RRF with relative weights (not normalized)
        Rrf(
            ranks=[
                Knn(query=[0.1, 0.2], return_rank=True),
                Knn(query=another_vector, key="custom_embedding", return_rank=True)  # Example metadata field
            ],
            weights=[2.0, 1.0],  # First ranking is 2x more important
            k=100
        )

        # Weighted RRF with normalized weights
        Rrf(
            ranks=[
                Knn(query=[0.1, 0.2], return_rank=True),
                Knn(query=another_vector, key="custom_embedding", return_rank=True)  # Example metadata field
            ],
            weights=[3.0, 1.0],  # Will be normalized to [0.75, 0.25]
            normalize=True,
            k=100
        )
    """

    ranks: List[Rank]
    k: int = 60
    weights: Optional[List[float]] = None
    normalize: bool = False

    def to_dict(self) -> Dict[str, Any]:
        """Expand RRF into existing rank operators and serialize the result.

        Builds ``-sum(weight_i / (k + rank_i))`` with the overloaded Rank
        operators, then returns that expression's ``to_dict()``.

        Raises:
            ValueError: if there are no ranks, ``k`` is not positive, the
                weights do not match the ranks in number, any weight is
                negative, or ``normalize`` is set and the weights sum to zero.
        """
        if not self.ranks:
            raise ValueError("RRF requires at least one rank")
        if self.k <= 0:
            raise ValueError(f"k must be positive, got {self.k}")

        supplied = self.weights
        if supplied is not None:
            if len(supplied) != len(self.ranks):
                raise ValueError(
                    f"Number of weights ({len(supplied)}) must match number of ranks ({len(self.ranks)})"
                )
            if any(w < 0.0 for w in supplied):
                raise ValueError("All weights must be non-negative")

        # Fall back to equal weighting when no weights were supplied.
        effective = supplied if supplied else [1.0] * len(self.ranks)

        if self.normalize:
            total = sum(effective)
            if total == 0:
                raise ValueError("Sum of weights must be positive when normalize=True")
            effective = [w / total for w in effective]

        # Accumulate weight / (k + rank) term by term; ranks is non-empty so
        # the accumulator is always set by the first iteration.
        fused = None
        for weight, rank in zip(effective, self.ranks):
            term = weight / (self.k + rank)
            fused = term if fused is None else fused + term

        # Negate: RRF scores higher-is-better, Chroma orders lower-is-better.
        return (-fused).to_dict()
@dataclass
class Select:
    """Selection configuration for search results.

    Fields can be:
        - Key.DOCUMENT - Select document key (equivalent to Key("#document"))
        - Key.EMBEDDING - Select embedding key (equivalent to Key("#embedding"))
        - Key.SCORE - Select score key (equivalent to Key("#score"))
        - Any other string - Select specific metadata property

    Note: You can use K as an alias for Key for more concise code.

    Examples:
        # Select predefined keys using K alias (K is shorthand for Key)
        from chromadb.execution.expression import K
        Select(keys={K.DOCUMENT, K.SCORE})

        # Select specific metadata properties
        Select(keys={"title", "author", "date"})

        # Mixed selection
        Select(keys={K.DOCUMENT, "title", "author"})
    """

    keys: Set[Union[Key, str]] = field(default_factory=set)

    def to_dict(self) -> Dict[str, Any]:
        """Serialize as ``{"keys": [...]}`` with Key objects replaced by their names.

        Duplicate names (for example a Key and the equal string) appear once.
        """
        names = [k.name if isinstance(k, Key) else k for k in self.keys]
        # dict.fromkeys drops duplicates while keeping first-seen order.
        return {"keys": list(dict.fromkeys(names))}

    @staticmethod
    def from_dict(data: Dict[str, Any]) -> "Select":
        """Create Select from dictionary.

        Every entry is turned into a Key: the special names "#id", "#document",
        "#embedding", "#metadata" and "#score" map to the predefined Key
        constants, anything else becomes ``Key(name)``.

        Examples:
            - {"keys": ["#document", "#score"]} -> Select(keys={Key.DOCUMENT, Key.SCORE})
            - {"keys": ["title", "author"]} -> Select(keys={Key("title"), Key("author")})

        Raises:
            TypeError: if ``data`` is not a dict, "keys" is not a list/tuple/set,
                or an entry is not a string.
            ValueError: if the dict has entries other than "keys".
        """
        if not isinstance(data, dict):
            raise TypeError(f"Expected dict for Select, got {type(data).__name__}")

        raw_keys = data.get("keys", [])
        if not isinstance(raw_keys, (list, tuple, set)):
            raise TypeError(
                f"Select keys must be a list/tuple/set, got {type(raw_keys).__name__}"
            )

        # Special field names resolve to the shared predefined Key instances.
        predefined = {
            "#id": Key.ID,
            "#document": Key.DOCUMENT,
            "#embedding": Key.EMBEDDING,
            "#metadata": Key.METADATA,
            "#score": Key.SCORE,
        }
        resolved = []
        for name in raw_keys:
            if not isinstance(name, str):
                raise TypeError(f"Select key must be a string, got {type(name).__name__}")
            if name in predefined:
                resolved.append(predefined[name])
            else:
                # Regular metadata field
                resolved.append(Key(name))

        # Reject anything besides "keys" in the input dict.
        extras = set(data.keys()) - {"keys"}
        if extras:
            raise ValueError(f"Unexpected keys in Select dict: {extras}")

        return Select(keys=set(resolved))
# GroupBy and Aggregate types for grouping search results
def _keys_to_strings(keys: OneOrMany[Union[Key, str]]) -> List[str]:
    """Turn a single key or a list of keys into a list of names for serialization.

    Key instances are replaced by their ``name``; any other entry is passed
    through unchanged.
    """
    as_list = cast(List[Union[Key, str]], maybe_cast_one_to_many(keys))
    names: List[str] = []
    for entry in as_list:
        if isinstance(entry, Key):
            names.append(entry.name)
        else:
            names.append(entry)
    return names
def _strings_to_keys(keys: Union[List[Any], tuple[Any, ...]]) -> List[Union[Key, str]]:
    """Wrap each string entry in a Key for deserialization.

    Entries that are not strings are kept exactly as given.
    """
    converted: List[Union[Key, str]] = []
    for entry in keys:
        converted.append(Key(entry) if isinstance(entry, str) else entry)
    return converted
def _parse_k_aggregate(
    op: str, data: Dict[str, Any]
) -> tuple[List[Union[Key, str]], int]:
    """Parse and validate the payload shared by the MinK/MaxK operators.

    Args:
        op: The operator name (e.g., "$min_k" or "$max_k"); it must be a key
            of ``data``.
        data: The dict containing the operator, shaped like
            ``{op: {"keys": [...], "k": n}}``.

    Returns:
        Tuple of (keys, k): ``keys`` has every string entry wrapped in a Key
        (non-string entries are passed through unchanged) and ``k`` is the
        validated positive integer.

    Raises:
        KeyError: If ``op`` is not present in ``data``.
        TypeError: If the payload is not a dict, ``keys`` is not a list or
            tuple, or ``k`` is not an integer (booleans are rejected).
        ValueError: If ``keys`` or ``k`` is missing, ``keys`` is empty, or
            ``k`` is not positive.
    """
    agg_data = data[op]
    if not isinstance(agg_data, dict):
        raise TypeError(f"{op} requires a dict, got {type(agg_data).__name__}")

    # Both fields are mandatory; report the first one that is absent.
    if "keys" not in agg_data:
        raise ValueError(f"{op} requires 'keys' field")
    if "k" not in agg_data:
        raise ValueError(f"{op} requires 'k' field")

    keys = agg_data["keys"]
    if not isinstance(keys, (list, tuple)):
        raise TypeError(f"{op} keys must be a list, got {type(keys).__name__}")
    if not keys:
        raise ValueError(f"{op} keys cannot be empty")

    k = agg_data["k"]
    # bool is a subclass of int, so a plain isinstance(k, int) check would let
    # True/False through (True would act as k=1). Reject booleans explicitly.
    if isinstance(k, bool) or not isinstance(k, int):
        raise TypeError(f"{op} k must be an integer, got {type(k).__name__}")
    if k <= 0:
        raise ValueError(f"{op} k must be positive, got {k}")

    return _strings_to_keys(keys), k
@dataclass
class Aggregate:
    """Base class for aggregation expressions applied within each group.

    An aggregation decides which records of a group are kept:
        - MinK: keep k records with minimum values (ascending order)
        - MaxK: keep k records with maximum values (descending order)

    Examples:
        # Keep top 3 by score per group (single key)
        MinK(keys=Key.SCORE, k=3)

        # Keep top 5 by priority, then score as tiebreaker (multiple keys)
        MinK(keys=[Key("priority"), Key.SCORE], k=5)

        # Keep bottom 2 by score per group
        MaxK(keys=Key.SCORE, k=2)
    """

    def to_dict(self) -> Dict[str, Any]:
        """Serialize the aggregate to a JSON-compatible dict (subclasses override)."""
        raise NotImplementedError("Subclasses must implement to_dict()")

    @staticmethod
    def from_dict(data: Dict[str, Any]) -> "Aggregate":
        """Build the matching Aggregate subclass from its dictionary form.

        Supported shapes:
            - {"$min_k": {"keys": [...], "k": n}} -> MinK(keys=[...], k=n)
            - {"$max_k": {"keys": [...], "k": n}} -> MaxK(keys=[...], k=n)

        Raises:
            TypeError: If ``data`` is not a dict (or the operator payload has
                invalid types).
            ValueError: If ``data`` is empty, holds more than one operator,
                names an unknown operator, or the payload is invalid.
        """
        if not isinstance(data, dict):
            raise TypeError(f"Expected dict for Aggregate, got {type(data).__name__}")
        if not data:
            raise ValueError("Aggregate dict cannot be empty")
        if len(data) != 1:
            raise ValueError(
                f"Aggregate dict must contain exactly one operator, got {len(data)}"
            )

        # Exactly one entry remains; its key is the operator name.
        (op,) = data
        if op not in ("$min_k", "$max_k"):
            raise ValueError(f"Unknown aggregate operator: {op}")

        keys, k = _parse_k_aggregate(op, data)
        if op == "$min_k":
            return MinK(keys=keys, k=k)
        return MaxK(keys=keys, k=k)
@dataclass
class MinK(Aggregate):
    """Aggregate that keeps the k records with minimum key values per group."""

    # Key(s) whose values rank the records; a single key or a list of keys.
    keys: OneOrMany[Union[Key, str]]
    # Number of records to keep in each group.
    k: int

    def to_dict(self) -> Dict[str, Any]:
        """Serialize as ``{"$min_k": {"keys": [...], "k": k}}``."""
        payload: Dict[str, Any] = {"keys": _keys_to_strings(self.keys), "k": self.k}
        return {"$min_k": payload}
@dataclass
class MaxK(Aggregate):
    """Aggregate that keeps the k records with maximum key values per group."""

    # Key(s) whose values rank the records; a single key or a list of keys.
    keys: OneOrMany[Union[Key, str]]
    # Number of records to keep in each group.
    k: int

    def to_dict(self) -> Dict[str, Any]:
        """Serialize as ``{"$max_k": {"keys": [...], "k": k}}``."""
        payload: Dict[str, Any] = {"keys": _keys_to_strings(self.keys), "k": self.k}
        return {"$max_k": payload}
@dataclass
class GroupBy:
    """Group results by metadata keys and aggregate within each group.

    Search results are grouped by one or more metadata fields, and an
    aggregation (MinK or MaxK) then selects the records kept from each group.
    The final output is flattened and sorted by score.

    Args:
        keys: Metadata key(s) to group by. Can be a single key or a list of keys.
            E.g., Key("category") or [Key("category"), Key("author")]
        aggregate: Aggregation to apply within each group (MinK or MaxK)

    Note: Both keys and aggregate must be specified together; when either is
    absent the instance serializes to an empty dict.

    Examples:
        # Top 3 documents per category (single key)
        GroupBy(
            keys=Key("category"),
            aggregate=MinK(keys=Key.SCORE, k=3)
        )

        # Top 2 per (year, category) combination (multiple keys)
        GroupBy(
            keys=[Key("year"), Key("category")],
            aggregate=MinK(keys=Key.SCORE, k=2)
        )

        # Top 1 per category by priority, score as tiebreaker
        GroupBy(
            keys=Key("category"),
            aggregate=MinK(keys=[Key("priority"), Key.SCORE], k=1)
        )
    """

    # Key(s) to group by; empty by default, meaning no grouping.
    keys: OneOrMany[Union[Key, str]] = field(default_factory=list)
    # Aggregation applied inside each group; None by default.
    aggregate: Optional[Aggregate] = None

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to ``{"keys": [...], "aggregate": {...}}``, or ``{}`` when
        either the keys or the aggregate is missing."""
        grouping_configured = bool(self.keys) and self.aggregate is not None
        if not grouping_configured:
            # Default GroupBy (no keys, no aggregate) serializes to {}.
            return {}
        return {
            "keys": _keys_to_strings(self.keys),
            "aggregate": self.aggregate.to_dict(),
        }

    @staticmethod
    def from_dict(data: Dict[str, Any]) -> "GroupBy":
        """Build a GroupBy from its dictionary form.

        Examples:
            - {} -> GroupBy() (default, no grouping)
            - {"keys": ["category"], "aggregate": {"$min_k": {"keys": ["#score"], "k": 3}}}

        Raises:
            TypeError: If ``data`` is not a dict, "keys" is not a list or
                tuple, or "aggregate" is not a dict.
            ValueError: If a non-empty ``data`` lacks "keys" or "aggregate",
                or "keys" is empty.
        """
        if not isinstance(data, dict):
            raise TypeError(f"Expected dict for GroupBy, got {type(data).__name__}")

        # An empty dict is the serialized form of the default (no grouping).
        if not data:
            return GroupBy()

        # Anything else must carry both fields.
        if "keys" not in data:
            raise ValueError("GroupBy requires 'keys' field")
        if "aggregate" not in data:
            raise ValueError("GroupBy requires 'aggregate' field")

        keys = data["keys"]
        if not isinstance(keys, (list, tuple)):
            raise TypeError(f"GroupBy keys must be a list, got {type(keys).__name__}")
        if not keys:
            raise ValueError("GroupBy keys cannot be empty")

        aggregate_data = data["aggregate"]
        if not isinstance(aggregate_data, dict):
            raise TypeError(
                f"GroupBy aggregate must be a dict, got {type(aggregate_data).__name__}"
            )

        # Parse the aggregate first, then wrap the grouping keys.
        parsed_aggregate = Aggregate.from_dict(aggregate_data)
        parsed_keys = _strings_to_keys(keys)
        return GroupBy(keys=parsed_keys, aggregate=parsed_aggregate)