# Query-plan expression types: Where filters, Key field proxies, Rank
# expressions for hybrid search, and plan components (Scan, Filter, Limit,
# Projection).
from dataclasses import dataclass, field
|
|
from typing import Optional, List, Dict, Set, Any, Union, cast
|
|
|
|
import numpy as np
|
|
from numpy.typing import NDArray
|
|
from chromadb.api.types import (
|
|
Embeddings,
|
|
IDs,
|
|
Include,
|
|
OneOrMany,
|
|
SparseVector,
|
|
TYPE_KEY,
|
|
SPARSE_VECTOR_TYPE_VALUE,
|
|
maybe_cast_one_to_many,
|
|
normalize_embeddings,
|
|
validate_embeddings,
|
|
)
|
|
from chromadb.types import (
|
|
Collection,
|
|
RequestVersionContext,
|
|
Segment,
|
|
)
|
|
|
|
|
|
@dataclass
class Scan:
    """Bundle of a collection and the segments a query plan reads from:
    the knn (vector), metadata, and record segments."""

    collection: Collection
    knn: Segment
    metadata: Segment
    record: Segment

    @property
    def version(self) -> RequestVersionContext:
        # Snapshot the collection's version and log position so the request
        # executes against a consistent view of the collection.
        return RequestVersionContext(
            collection_version=self.collection.version,
            log_position=self.collection.log_position,
        )
|
|
|
|
|
|
# Where expression types for filtering
|
|
@dataclass
class Where:
    """Base class for Where filter expressions (an algebraic data type).

    Conditions compose with the bitwise operators:
    - AND: where1 & where2
    - OR:  where1 | where2

    Examples:
        active = Key("status") == "active"
        scored = Key("score") > 0.5

        both = active & scored
        either = active | scored
        nested = active & (scored | (Key("priority") == "high"))
    """

    def to_dict(self) -> Dict[str, Any]:
        """Convert the Where expression to a dictionary for JSON serialization"""
        raise NotImplementedError("Subclasses must implement to_dict()")

    @staticmethod
    def from_dict(data: Dict[str, Any]) -> "Where":
        """Build a Where expression from a MongoDB-style query dict.

        Supported shapes:
        - {"field": value}                        -> Key("field") == value (shorthand)
        - {"field": {"$eq"|"$ne"|"$gt"|"$gte"|"$lt"|"$lte": value}}
        - {"field": {"$in"|"$nin": [values]}}
        - {"field": {"$contains"|"$not_contains": "text"}}
        - {"field": {"$regex"|"$not_regex": "pattern"}}
        - {"$and": [conditions]} / {"$or": [conditions]}

        Raises:
            TypeError: when data or an operand has the wrong type.
            ValueError: on empty dicts, mixed operator fields, or unknown operators.
        """
        if not isinstance(data, dict):
            raise TypeError(f"Expected dict for Where, got {type(data).__name__}")
        if not data:
            raise ValueError("Where dict cannot be empty")

        def fold(op_name: str, combine) -> "Where":
            # Validate a $and/$or payload and left-fold its children.
            children = data[op_name]
            if not isinstance(children, list):
                raise TypeError(
                    f"{op_name} must be a list, got {type(children).__name__}"
                )
            if len(children) == 0:
                raise ValueError(f"{op_name} requires at least one condition")
            if len(data) > 1:
                raise ValueError(
                    f"{op_name} cannot be combined with other fields in the same dict"
                )
            parsed = [Where.from_dict(child) for child in children]
            merged = parsed[0]
            for extra in parsed[1:]:
                merged = combine(merged, extra)
            return merged

        # Logical combinators take the whole dict.
        if "$and" in data:
            return fold("$and", lambda a, b: a & b)
        if "$or" in data:
            return fold("$or", lambda a, b: a | b)

        # Otherwise this must be a single-field condition.
        if len(data) != 1:
            raise ValueError(
                f"Where dict must contain exactly one field, got {len(data)}"
            )
        (field, condition), = data.items()
        if not isinstance(field, str):
            raise TypeError(
                f"Field name must be a string, got {type(field).__name__}"
            )

        if not isinstance(condition, dict):
            # A bare value is shorthand for an equality test.
            return Key(field) == condition

        # Operator-based condition: exactly one {"$op": operand} entry.
        if not condition:
            raise ValueError(f"Operator dict for field '{field}' cannot be empty")
        if len(condition) != 1:
            raise ValueError(
                f"Operator dict for field '{field}' must contain exactly one operator"
            )
        (op, value), = condition.items()

        comparison_builders = {
            "$eq": lambda: Key(field) == value,
            "$ne": lambda: Key(field) != value,
            "$gt": lambda: Key(field) > value,
            "$gte": lambda: Key(field) >= value,
            "$lt": lambda: Key(field) < value,
            "$lte": lambda: Key(field) <= value,
        }
        if op in comparison_builders:
            return comparison_builders[op]()

        if op in ("$in", "$nin"):
            if not isinstance(value, list):
                raise TypeError(
                    f"{op} requires a list, got {type(value).__name__}"
                )
            return Key(field).is_in(value) if op == "$in" else Key(field).not_in(value)

        if op in ("$contains", "$not_contains"):
            if not isinstance(value, str):
                raise TypeError(
                    f"{op} requires a string, got {type(value).__name__}"
                )
            if op == "$contains":
                return Key(field).contains(value)
            return Key(field).not_contains(value)

        if op in ("$regex", "$not_regex"):
            if not isinstance(value, str):
                raise TypeError(
                    f"{op} requires a string pattern, got {type(value).__name__}"
                )
            return Key(field).regex(value) if op == "$regex" else Key(field).not_regex(value)

        raise ValueError(f"Unknown operator: {op}")

    def __and__(self, other: "Where") -> "And":
        """Overload & to build a conjunction, flattening nested And nodes."""
        left = self.conditions if isinstance(self, And) else [self]
        right = other.conditions if isinstance(other, And) else [other]
        return And(left + right)

    def __or__(self, other: "Where") -> "Or":
        """Overload | to build a disjunction, flattening nested Or nodes."""
        left = self.conditions if isinstance(self, Or) else [self]
        right = other.conditions if isinstance(other, Or) else [other]
        return Or(left + right)
|
|
|
|
|
|
@dataclass
class And(Where):
    """Conjunction node: every child condition must hold."""

    conditions: List[Where]

    def to_dict(self) -> Dict[str, Any]:
        """Serialize as ``{"$and": [child dicts]}``."""
        serialized = [child.to_dict() for child in self.conditions]
        return {"$and": serialized}
|
|
|
|
|
|
@dataclass
class Or(Where):
    """Disjunction node: at least one child condition must hold."""

    conditions: List[Where]

    def to_dict(self) -> Dict[str, Any]:
        """Serialize as ``{"$or": [child dicts]}``."""
        serialized = [child.to_dict() for child in self.conditions]
        return {"$or": serialized}
|
|
|
|
|
|
@dataclass
class Eq(Where):
    """Leaf condition: field ``key`` must equal ``value``."""

    key: str
    value: Any

    def to_dict(self) -> Dict[str, Any]:
        """Serialize as ``{key: {"$eq": value}}``."""
        clause = {"$eq": self.value}
        return {self.key: clause}
|
|
|
|
|
|
@dataclass
class Ne(Where):
    """Leaf condition: field ``key`` must not equal ``value``."""

    key: str
    value: Any

    def to_dict(self) -> Dict[str, Any]:
        """Serialize as ``{key: {"$ne": value}}``."""
        clause = {"$ne": self.value}
        return {self.key: clause}
|
|
|
|
|
|
@dataclass
class Gt(Where):
    """Leaf condition: field ``key`` must be greater than ``value``."""

    key: str
    value: Any

    def to_dict(self) -> Dict[str, Any]:
        """Serialize as ``{key: {"$gt": value}}``."""
        clause = {"$gt": self.value}
        return {self.key: clause}
|
|
|
|
|
|
@dataclass
class Gte(Where):
    """Leaf condition: field ``key`` must be greater than or equal to ``value``."""

    key: str
    value: Any

    def to_dict(self) -> Dict[str, Any]:
        """Serialize as ``{key: {"$gte": value}}``."""
        clause = {"$gte": self.value}
        return {self.key: clause}
|
|
|
|
|
|
@dataclass
class Lt(Where):
    """Leaf condition: field ``key`` must be less than ``value``."""

    key: str
    value: Any

    def to_dict(self) -> Dict[str, Any]:
        """Serialize as ``{key: {"$lt": value}}``."""
        clause = {"$lt": self.value}
        return {self.key: clause}
|
|
|
|
|
|
@dataclass
class Lte(Where):
    """Leaf condition: field ``key`` must be less than or equal to ``value``."""

    key: str
    value: Any

    def to_dict(self) -> Dict[str, Any]:
        """Serialize as ``{key: {"$lte": value}}``."""
        clause = {"$lte": self.value}
        return {self.key: clause}
|
|
|
|
|
|
@dataclass
class In(Where):
    """Leaf condition: field ``key`` must take one of ``values``."""

    key: str
    values: List[Any]

    def to_dict(self) -> Dict[str, Any]:
        """Serialize as ``{key: {"$in": values}}``."""
        clause = {"$in": self.values}
        return {self.key: clause}
|
|
|
|
|
|
@dataclass
class Nin(Where):
    """Leaf condition: field ``key`` must not take any of ``values``."""

    key: str
    values: List[Any]

    def to_dict(self) -> Dict[str, Any]:
        """Serialize as ``{key: {"$nin": values}}``."""
        clause = {"$nin": self.values}
        return {self.key: clause}
|
|
|
|
|
|
@dataclass
class Contains(Where):
    """Leaf condition: field ``key`` must contain the text ``content``."""

    key: str
    content: str

    def to_dict(self) -> Dict[str, Any]:
        """Serialize as ``{key: {"$contains": content}}``."""
        clause = {"$contains": self.content}
        return {self.key: clause}
|
|
|
|
|
|
@dataclass
class NotContains(Where):
    """Leaf condition: field ``key`` must not contain the text ``content``."""

    key: str
    content: str

    def to_dict(self) -> Dict[str, Any]:
        """Serialize as ``{key: {"$not_contains": content}}``."""
        clause = {"$not_contains": self.content}
        return {self.key: clause}
|
|
|
|
|
|
@dataclass
class Regex(Where):
    """Leaf condition: field ``key`` must match the regex ``pattern``."""

    key: str
    pattern: str

    def to_dict(self) -> Dict[str, Any]:
        """Serialize as ``{key: {"$regex": pattern}}``."""
        clause = {"$regex": self.pattern}
        return {self.key: clause}
|
|
|
|
|
|
@dataclass
class NotRegex(Where):
    """Leaf condition: field ``key`` must not match the regex ``pattern``."""

    key: str
    pattern: str

    def to_dict(self) -> Dict[str, Any]:
        """Serialize as ``{key: {"$not_regex": pattern}}``."""
        clause = {"$not_regex": self.pattern}
        return {self.key: clause}
|
|
|
|
|
|
# Field proxy for building Where conditions
|
|
class Key:
    """Field reference for building ``Where`` conditions via operator overloading.

    A ``Key`` names either a special record field (prefixed with ``#``) or an
    arbitrary metadata field (no prefix). The special fields are exposed as
    class constants so call sites never spell the ``#`` strings:

        Key.ID        -> Key("#id")
        Key.DOCUMENT  -> Key("#document")
        Key.EMBEDDING -> Key("#embedding")
        Key.METADATA  -> Key("#metadata")
        Key.SCORE     -> Key("#score")

    ``K`` (defined below) is a short alias for ``Key``, so ``K.DOCUMENT`` and
    ``Key.DOCUMENT`` are interchangeable.

    Examples:
        # Special fields via the K alias
        from chromadb.execution.expression import K
        K.DOCUMENT.contains("search text")

        # Custom metadata fields by name (no # prefix)
        K("status") == "active"
        K("category").is_in(["science", "tech"])

        # As the target of a Knn rank expression
        Knn(query=[0.1, 0.2])                        # default: "#embedding"
        Knn(query=[0.1, 0.2], key=K.EMBEDDING)       # explicit main embedding
        Knn(query=sparse, key="sparse_embedding")    # metadata-stored embedding

        # Combining conditions
        (K("status") == "active") & (K.SCORE > 0.5)
    """

    # Predefined key constants (assigned right after the class definition below)
    ID: "Key"
    DOCUMENT: "Key"
    EMBEDDING: "Key"
    METADATA: "Key"
    SCORE: "Key"

    def __init__(self, name: str):
        self.name = name

    def __hash__(self) -> int:
        """Hash by field name so Keys can live in sets and dict keys."""
        return hash(self.name)

    # Comparison operators build Where leaf nodes instead of returning bools.
    def __eq__(self, value: Any) -> Eq:  # type: ignore[override]
        """``Key('field') == value`` -> equality condition"""
        return Eq(self.name, value)

    def __ne__(self, value: Any) -> Ne:  # type: ignore[override]
        """``Key('field') != value`` -> inequality condition"""
        return Ne(self.name, value)

    def __gt__(self, value: Any) -> Gt:
        """``Key('field') > value`` -> greater-than condition"""
        return Gt(self.name, value)

    def __ge__(self, value: Any) -> Gte:
        """``Key('field') >= value`` -> greater-or-equal condition"""
        return Gte(self.name, value)

    def __lt__(self, value: Any) -> Lt:
        """``Key('field') < value`` -> less-than condition"""
        return Lt(self.name, value)

    def __le__(self, value: Any) -> Lte:
        """``Key('field') <= value`` -> less-or-equal condition"""
        return Lte(self.name, value)

    # Named builders for conditions that have no natural Python operator.
    def is_in(self, values: List[Any]) -> In:
        """``Key('field').is_in(['a', 'b'])`` -> membership condition"""
        return In(self.name, values)

    def not_in(self, values: List[Any]) -> Nin:
        """``Key('field').not_in(['a', 'b'])`` -> negated membership condition"""
        return Nin(self.name, values)

    def regex(self, pattern: str) -> Regex:
        """``Key('field').regex('^pattern')`` -> regex-match condition"""
        return Regex(self.name, pattern)

    def not_regex(self, pattern: str) -> NotRegex:
        """``Key('field').not_regex('^pattern')`` -> negated regex condition"""
        return NotRegex(self.name, pattern)

    def contains(self, content: str) -> Contains:
        """``Key('field').contains('text')`` -> substring condition"""
        return Contains(self.name, content)

    def not_contains(self, content: str) -> NotContains:
        """``Key('field').not_contains('text')`` -> negated substring condition"""
        return NotContains(self.name, content)
|
|
|
|
|
|
# Initialize the predefined Key constants declared in the class body above;
# they must be assigned after the class definition because each one is an
# instance of Key itself.
Key.ID = Key("#id")
Key.DOCUMENT = Key("#document")
Key.EMBEDDING = Key("#embedding")
Key.METADATA = Key("#metadata")
Key.SCORE = Key("#score")

# Short alias so call sites can write K("field") or K.DOCUMENT.
K = Key
|
|
|
|
|
|
@dataclass
class Filter:
    """Filtering component of a query plan.

    NOTE(review): ``where``/``where_document`` are typed ``Any`` because they
    carry the legacy filter types from ``chromadb.types``, not the ``Where``
    ADT defined in this module — confirm before tightening the annotations.
    """

    user_ids: Optional[IDs] = None  # restrict results to these record IDs
    where: Optional[Any] = None  # Old Where type from chromadb.types
    where_document: Optional[Any] = None  # Old WhereDocument type
|
|
|
|
|
|
@dataclass
class KNN:
    """K-nearest-neighbours component of a query plan.

    NOTE(review): distinct from the ``Knn`` Rank expression defined later in
    this module.
    """

    embeddings: Embeddings  # query embeddings to search with
    fetch: int  # how many neighbours to fetch — presumably per query; confirm
|
|
|
|
|
|
@dataclass
class Limit:
    """Pagination window: skip ``offset`` records, then take up to ``limit``."""

    offset: int = 0
    limit: Optional[int] = None

    def to_dict(self) -> Dict[str, Any]:
        """Serialize for JSON transport; ``limit`` is omitted when unset."""
        payload: Dict[str, Any] = {"offset": self.offset}
        if self.limit is not None:
            payload["limit"] = self.limit
        return payload

    @staticmethod
    def from_dict(data: Dict[str, Any]) -> "Limit":
        """Build a Limit from a dict.

        Examples:
            - {"offset": 10}               -> Limit(offset=10)
            - {"offset": 10, "limit": 20}  -> Limit(offset=10, limit=20)
            - {"limit": 20}                -> Limit(offset=0, limit=20)

        Raises:
            TypeError: when the payload or a field has the wrong type.
            ValueError: for negative offset, non-positive limit, or extra keys.
        """
        if not isinstance(data, dict):
            raise TypeError(f"Expected dict for Limit, got {type(data).__name__}")

        # Offset defaults to 0 and must be a non-negative integer.
        offset = data.get("offset", 0)
        if not isinstance(offset, int):
            raise TypeError(
                f"Limit offset must be an integer, got {type(offset).__name__}"
            )
        if offset < 0:
            raise ValueError(f"Limit offset must be non-negative, got {offset}")

        # Limit is optional; when given it must be a positive integer.
        limit = data.get("limit")
        if limit is not None:
            if not isinstance(limit, int):
                raise TypeError(
                    f"Limit limit must be an integer, got {type(limit).__name__}"
                )
            if limit <= 0:
                raise ValueError(f"Limit limit must be positive, got {limit}")

        # Reject anything beyond the two recognized keys.
        unexpected_keys = set(data.keys()) - {"offset", "limit"}
        if unexpected_keys:
            raise ValueError(f"Unexpected keys in Limit dict: {unexpected_keys}")

        return Limit(offset=offset, limit=limit)
|
|
|
|
|
|
@dataclass
class Projection:
    """Selects which record fields a query should return."""

    document: bool = False
    embedding: bool = False
    metadata: bool = False
    rank: bool = False
    uri: bool = False

    @property
    def included(self) -> Include:
        """Translate the boolean flags into an Include list.

        Note: the ``rank`` flag maps to ``"distances"`` in the Include list.
        """
        flag_names = (
            (self.document, "documents"),
            (self.embedding, "embeddings"),
            (self.metadata, "metadatas"),
            (self.rank, "distances"),
            (self.uri, "uris"),
        )
        return [name for enabled, name in flag_names if enabled]  # type: ignore[return-value]
|
|
|
|
|
|
# Rank expression types for hybrid search
|
|
@dataclass
class Rank:
    """Base class for Rank expressions (algebraic data type).

    Supports arithmetic operations for combining rank expressions:
    - Addition: rank1 + rank2, rank + 0.5
    - Subtraction: rank1 - rank2, rank - 0.5
    - Multiplication: rank1 * rank2, rank * 0.8
    - Division: rank1 / rank2, rank / 2.0
    - Negation: -rank
    - Absolute value: abs(rank)

    Supports mathematical functions:
    - Exponential: rank.exp()
    - Logarithm: rank.log()
    - Maximum: rank.max(other)
    - Minimum: rank.min(other)

    Examples:
        # Weighted combination
        Knn(query=[0.1, 0.2]) * 0.8 + Val(0.5) * 0.2

        # Normalization
        Knn(query=[0.1, 0.2]) / Val(10.0)

        # Clamping
        Knn(query=[0.1, 0.2]).min(1.0).max(0.0)
    """

    def to_dict(self) -> Dict[str, Any]:
        """Convert the Score expression to a dictionary for JSON serialization"""
        raise NotImplementedError("Subclasses must implement to_dict()")

    @staticmethod
    def from_dict(data: Dict[str, Any]) -> "Rank":
        """Create Rank expression from dictionary.

        Supports operators:
        - {"$val": number} -> Val(number)
        - {"$knn": {...}} -> Knn(...)
        - {"$sum": [ranks]} -> rank1 + rank2 + ...
        - {"$sub": {"left": ..., "right": ...}} -> left - right
        - {"$mul": [ranks]} -> rank1 * rank2 * ...
        - {"$div": {"left": ..., "right": ...}} -> left / right
        - {"$abs": rank} -> abs(rank)
        - {"$exp": rank} -> rank.exp()
        - {"$log": rank} -> rank.log()
        - {"$max": [ranks]} -> rank1.max(rank2).max(rank3)...
        - {"$min": [ranks]} -> rank1.min(rank2).min(rank3)...
        """
        # A rank dict is always a single-operator wrapper: {"$op": payload}.
        if not isinstance(data, dict):
            raise TypeError(f"Expected dict for Rank, got {type(data).__name__}")

        if not data:
            raise ValueError("Rank dict cannot be empty")

        if len(data) != 1:
            raise ValueError(
                f"Rank dict must contain exactly one operator, got {len(data)}"
            )

        op = next(iter(data.keys()))

        if op == "$val":
            value = data["$val"]
            # NOTE(review): isinstance(..., (int, float)) also admits bools
            # (bool subclasses int) — confirm whether that is intended.
            if not isinstance(value, (int, float)):
                raise TypeError(f"$val requires a number, got {type(value).__name__}")
            return Val(value)

        elif op == "$knn":
            knn_data = data["$knn"]
            if not isinstance(knn_data, dict):
                raise TypeError(f"$knn requires a dict, got {type(knn_data).__name__}")

            if "query" not in knn_data:
                raise ValueError("$knn requires 'query' field")

            query = knn_data["query"]

            if isinstance(query, dict):
                # SparseVector case - deserialize from transport format
                if query.get(TYPE_KEY) == SPARSE_VECTOR_TYPE_VALUE:
                    query = SparseVector.from_dict(query)
                else:
                    # Old format or invalid - try to construct directly
                    raise ValueError(
                        f"Expected dict with {TYPE_KEY}='{SPARSE_VECTOR_TYPE_VALUE}', got {query}"
                    )

            elif isinstance(query, (list, tuple, np.ndarray)):
                # Dense vector case - normalize then validate.
                # normalize_embeddings returns a list of embeddings; exactly
                # one is expected here since $knn carries a single query.
                normalized = normalize_embeddings(query)
                if not normalized or len(normalized) > 1:
                    raise ValueError("$knn requires exactly one query embedding")

                # Validate the normalized version
                validate_embeddings(normalized)

                query = normalized[0]

            else:
                raise TypeError(
                    f"$knn query must be a list, numpy array, or SparseVector dict, got {type(query).__name__}"
                )

            # Optional fields fall back to the Knn dataclass defaults.
            key = knn_data.get("key", "#embedding")
            if not isinstance(key, str):
                raise TypeError(f"$knn key must be a string, got {type(key).__name__}")

            limit = knn_data.get("limit", 16)
            if not isinstance(limit, int):
                raise TypeError(
                    f"$knn limit must be an integer, got {type(limit).__name__}"
                )
            if limit <= 0:
                raise ValueError(f"$knn limit must be positive, got {limit}")

            return_rank = knn_data.get("return_rank", False)
            if not isinstance(return_rank, bool):
                raise TypeError(
                    f"$knn return_rank must be a boolean, got {type(return_rank).__name__}"
                )

            return Knn(
                query=query,
                key=key,
                limit=limit,
                default=knn_data.get("default"),
                return_rank=return_rank,
            )

        elif op == "$sum":
            ranks_data = data["$sum"]
            if not isinstance(ranks_data, (list, tuple)):
                raise TypeError(
                    f"$sum requires a list, got {type(ranks_data).__name__}"
                )
            if len(ranks_data) < 2:
                raise ValueError(
                    f"$sum requires at least 2 ranks, got {len(ranks_data)}"
                )

            # Left-fold with +; __add__ flattens nested Sum nodes.
            ranks = [Rank.from_dict(r) for r in ranks_data]
            result = ranks[0]
            for r in ranks[1:]:
                result = result + r
            return result

        elif op == "$sub":
            sub_data = data["$sub"]
            if not isinstance(sub_data, dict):
                raise TypeError(
                    f"$sub requires a dict with 'left' and 'right', got {type(sub_data).__name__}"
                )
            if "left" not in sub_data or "right" not in sub_data:
                raise ValueError("$sub requires 'left' and 'right' fields")

            left = Rank.from_dict(sub_data["left"])
            right = Rank.from_dict(sub_data["right"])
            return left - right

        elif op == "$mul":
            ranks_data = data["$mul"]
            if not isinstance(ranks_data, (list, tuple)):
                raise TypeError(
                    f"$mul requires a list, got {type(ranks_data).__name__}"
                )
            if len(ranks_data) < 2:
                raise ValueError(
                    f"$mul requires at least 2 ranks, got {len(ranks_data)}"
                )

            # Left-fold with *; __mul__ flattens nested Mul nodes.
            ranks = [Rank.from_dict(r) for r in ranks_data]
            result = ranks[0]
            for r in ranks[1:]:
                result = result * r
            return result

        elif op == "$div":
            div_data = data["$div"]
            if not isinstance(div_data, dict):
                raise TypeError(
                    f"$div requires a dict with 'left' and 'right', got {type(div_data).__name__}"
                )
            if "left" not in div_data or "right" not in div_data:
                raise ValueError("$div requires 'left' and 'right' fields")

            left = Rank.from_dict(div_data["left"])
            right = Rank.from_dict(div_data["right"])
            return left / right

        elif op == "$abs":
            child_data = data["$abs"]
            if not isinstance(child_data, dict):
                raise TypeError(
                    f"$abs requires a rank dict, got {type(child_data).__name__}"
                )
            return abs(Rank.from_dict(child_data))

        elif op == "$exp":
            child_data = data["$exp"]
            if not isinstance(child_data, dict):
                raise TypeError(
                    f"$exp requires a rank dict, got {type(child_data).__name__}"
                )
            return Rank.from_dict(child_data).exp()

        elif op == "$log":
            child_data = data["$log"]
            if not isinstance(child_data, dict):
                raise TypeError(
                    f"$log requires a rank dict, got {type(child_data).__name__}"
                )
            return Rank.from_dict(child_data).log()

        elif op == "$max":
            ranks_data = data["$max"]
            if not isinstance(ranks_data, (list, tuple)):
                raise TypeError(
                    f"$max requires a list, got {type(ranks_data).__name__}"
                )
            if len(ranks_data) < 2:
                raise ValueError(
                    f"$max requires at least 2 ranks, got {len(ranks_data)}"
                )

            # Left-fold with .max(); Max nodes flatten as they combine.
            ranks = [Rank.from_dict(r) for r in ranks_data]
            result = ranks[0]
            for r in ranks[1:]:
                result = result.max(r)
            return result

        elif op == "$min":
            ranks_data = data["$min"]
            if not isinstance(ranks_data, (list, tuple)):
                raise TypeError(
                    f"$min requires a list, got {type(ranks_data).__name__}"
                )
            if len(ranks_data) < 2:
                raise ValueError(
                    f"$min requires at least 2 ranks, got {len(ranks_data)}"
                )

            # Left-fold with .min(); Min nodes flatten as they combine.
            ranks = [Rank.from_dict(r) for r in ranks_data]
            result = ranks[0]
            for r in ranks[1:]:
                result = result.min(r)
            return result

        else:
            raise ValueError(f"Unknown rank operator: {op}")

    # Arithmetic operators
    def __add__(self, other: Union["Rank", float, int]) -> "Sum":
        """Addition: rank1 + rank2 or rank + value"""
        # Plain numbers are promoted to constant Val ranks.
        other_rank = Val(other) if isinstance(other, (int, float)) else other
        # Flatten if already Sum
        if isinstance(self, Sum):
            if isinstance(other_rank, Sum):
                return Sum(self.ranks + other_rank.ranks)
            return Sum(self.ranks + [other_rank])
        elif isinstance(other_rank, Sum):
            return Sum([self] + other_rank.ranks)
        return Sum([self, other_rank])

    def __radd__(self, other: Union[float, int]) -> "Sum":
        """Right addition: value + rank"""
        return Val(other) + self

    def __sub__(self, other: Union["Rank", float, int]) -> "Sub":
        """Subtraction: rank1 - rank2 or rank - value"""
        other_rank = Val(other) if isinstance(other, (int, float)) else other
        return Sub(self, other_rank)

    def __rsub__(self, other: Union[float, int]) -> "Sub":
        """Right subtraction: value - rank"""
        return Sub(Val(other), self)

    def __mul__(self, other: Union["Rank", float, int]) -> "Mul":
        """Multiplication: rank1 * rank2 or rank * value"""
        other_rank = Val(other) if isinstance(other, (int, float)) else other
        # Flatten if already Mul
        if isinstance(self, Mul):
            if isinstance(other_rank, Mul):
                return Mul(self.ranks + other_rank.ranks)
            return Mul(self.ranks + [other_rank])
        elif isinstance(other_rank, Mul):
            return Mul([self] + other_rank.ranks)
        return Mul([self, other_rank])

    def __rmul__(self, other: Union[float, int]) -> "Mul":
        """Right multiplication: value * rank"""
        return Val(other) * self

    def __truediv__(self, other: Union["Rank", float, int]) -> "Div":
        """Division: rank1 / rank2 or rank / value"""
        other_rank = Val(other) if isinstance(other, (int, float)) else other
        return Div(self, other_rank)

    def __rtruediv__(self, other: Union[float, int]) -> "Div":
        """Right division: value / rank"""
        return Div(Val(other), self)

    def __neg__(self) -> "Mul":
        """Negation: -rank (equivalent to -1 * rank)"""
        return Mul([Val(-1), self])

    def __abs__(self) -> "Abs":
        """Absolute value: abs(rank)"""
        return Abs(self)

    def abs(self) -> "Abs":
        """Absolute value builder: rank.abs()"""
        return Abs(self)

    # Builder methods for functions
    def exp(self) -> "Exp":
        """Exponential: e^rank"""
        return Exp(self)

    def log(self) -> "Log":
        """Natural logarithm: ln(rank)"""
        return Log(self)

    def max(self, other: Union["Rank", float, int]) -> "Max":
        """Maximum of this rank and another: rank.max(rank2)"""
        other_rank = Val(other) if isinstance(other, (int, float)) else other

        # Flatten if already Max
        if isinstance(self, Max):
            if isinstance(other_rank, Max):
                return Max(self.ranks + other_rank.ranks)
            return Max(self.ranks + [other_rank])
        elif isinstance(other_rank, Max):
            return Max([self] + other_rank.ranks)
        return Max([self, other_rank])

    def min(self, other: Union["Rank", float, int]) -> "Min":
        """Minimum of this rank and another: rank.min(rank2)"""
        other_rank = Val(other) if isinstance(other, (int, float)) else other

        # Flatten if already Min
        if isinstance(self, Min):
            if isinstance(other_rank, Min):
                return Min(self.ranks + other_rank.ranks)
            return Min(self.ranks + [other_rank])
        elif isinstance(other_rank, Min):
            return Min([self] + other_rank.ranks)
        return Min([self, other_rank])
|
|
|
|
|
|
@dataclass
class Abs(Rank):
    """Absolute value of a child rank expression."""

    rank: Rank

    def to_dict(self) -> Dict[str, Any]:
        """Serialize as ``{"$abs": child}``."""
        child = self.rank.to_dict()
        return {"$abs": child}
|
|
|
|
|
|
@dataclass
class Div(Rank):
    """Division of one rank expression by another."""

    left: Rank
    right: Rank

    def to_dict(self) -> Dict[str, Any]:
        """Serialize as ``{"$div": {"left": ..., "right": ...}}``."""
        operands = {"left": self.left.to_dict(), "right": self.right.to_dict()}
        return {"$div": operands}
|
|
|
|
|
|
@dataclass
class Exp(Rank):
    """Exponential (e^x) of a child rank expression."""

    rank: Rank

    def to_dict(self) -> Dict[str, Any]:
        """Serialize as ``{"$exp": child}``."""
        child = self.rank.to_dict()
        return {"$exp": child}
|
|
|
|
|
|
@dataclass
class Log(Rank):
    """Natural logarithm of a child rank expression."""

    rank: Rank

    def to_dict(self) -> Dict[str, Any]:
        """Serialize as ``{"$log": child}``."""
        child = self.rank.to_dict()
        return {"$log": child}
|
|
|
|
|
|
@dataclass
class Max(Rank):
    """Element-wise maximum over multiple rank expressions."""

    ranks: List[Rank]

    def to_dict(self) -> Dict[str, Any]:
        """Serialize as ``{"$max": [children]}``."""
        children = [child.to_dict() for child in self.ranks]
        return {"$max": children}
|
|
|
|
|
|
@dataclass
class Min(Rank):
    """Element-wise minimum over multiple rank expressions."""

    ranks: List[Rank]

    def to_dict(self) -> Dict[str, Any]:
        """Serialize as ``{"$min": [children]}``."""
        children = [child.to_dict() for child in self.ranks]
        return {"$min": children}
|
|
|
|
|
|
@dataclass
class Mul(Rank):
    """Product of multiple rank expressions."""

    ranks: List[Rank]

    def to_dict(self) -> Dict[str, Any]:
        """Serialize as ``{"$mul": [children]}``."""
        children = [child.to_dict() for child in self.ranks]
        return {"$mul": children}
|
|
|
|
|
|
@dataclass
class Knn(Rank):
    """Rank records by K-nearest-neighbour search against an embedding field.

    Args:
        query: What to search with. One of:
            - str: embedded automatically with the collection's embedding function
            - dense vector: a list of floats or a numpy array
            - SparseVector: a sparse query vector
        key: Which embedding field to search. ``K.EMBEDDING`` / ``"#embedding"``
            (the default) targets the main embedding field; any other string
            names a metadata field (e.g. ``"custom_embedding"``).
        limit: Maximum number of results to consider (default: 16).
        default: Score assigned to records outside the KNN results (default: None).
        return_rank: When True, produce the rank position (0, 1, 2, ...) instead
            of the distance (default: False).

    Examples:
        # String query, embedded with the collection's embedding function
        Knn(query="hello world")

        # Dense vector against the main embedding (equivalent forms)
        Knn(query=[0.1, 0.2])
        Knn(query=[0.1, 0.2], key=K.EMBEDDING)
        Knn(query=[0.1, 0.2], key="#embedding")

        # String or vector against an embedding stored in a metadata field
        Knn(query="hello world", key="custom_embedding")
        Knn(query=my_vector, key="custom_embedding")
    """

    query: Union[
        str,
        List[float],
        SparseVector,
        "NDArray[np.float32]",
        "NDArray[np.float64]",
        "NDArray[np.int32]",
    ]
    key: Union[Key, str] = K.EMBEDDING
    limit: int = 16
    default: Optional[float] = None
    return_rank: bool = False

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to the transport format ``{"$knn": {...}}``."""
        # Convert the query into JSON-friendly types.
        if isinstance(self.query, SparseVector):
            transport_query: Any = self.query.to_dict()
        elif isinstance(self.query, np.ndarray):
            transport_query = self.query.tolist()
        else:
            transport_query = self.query

        # Keys serialize as their plain string names.
        transport_key = self.key.name if isinstance(self.key, Key) else self.key

        body: Dict[str, Any] = {
            "query": transport_query,
            "key": transport_key,
            "limit": self.limit,
        }
        # Omit optional fields left at their defaults to keep the JSON clean.
        if self.default is not None:
            body["default"] = self.default
        if self.return_rank:
            body["return_rank"] = self.return_rank

        return {"$knn": body}
|
|
|
|
|
|
@dataclass
class Sub(Rank):
    """Rank expression computing the difference of two ranks (left - right)."""

    left: Rank
    right: Rank

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to the transport format {"$sub": {"left": ..., "right": ...}}."""
        operands = {"left": self.left.to_dict(), "right": self.right.to_dict()}
        return {"$sub": operands}
|
|
|
|
|
|
@dataclass
class Sum(Rank):
    """Rank expression adding together an arbitrary list of ranks."""

    ranks: List[Rank]

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to the transport format {"$sum": [...]}."""
        serialized = [rank.to_dict() for rank in self.ranks]
        return {"$sum": serialized}
|
|
|
|
|
|
@dataclass
class Val(Rank):
    """Rank expression wrapping a constant numeric score."""

    value: float

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to the transport format {"$val": <number>}."""
        return {"$val": self.value}
|
|
|
|
|
|
@dataclass
class Rrf(Rank):
    """Reciprocal Rank Fusion for combining multiple ranking strategies.

    RRF formula: score = -sum(weight_i / (k + rank_i)) for each ranking strategy
    The negative is used because RRF produces higher scores for better results,
    but Chroma uses ascending order (lower scores = better results).

    Args:
        ranks: List of Rank expressions to fuse (must have at least one)
        k: Smoothing constant (default: 60, standard in literature)
        weights: Optional weights for each ranking strategy. If not provided,
            all ranks are weighted equally (weight=1.0 each).
        normalize: If True, normalize weights to sum to 1.0 (default: False).
            When False, weights are used as-is for relative importance.
            When True, weights are scaled so they sum to 1.0.

    Examples:
        # Note: metadata fields (like "custom_embedding" below) are user-defined and can store any data.
        # The field name is just an example - use whatever name matches your metadata structure.
        # Basic RRF combining KNN rankings (equal weight)
        Rrf([
            Knn(query=[0.1, 0.2], return_rank=True),
            Knn(query=another_vector, key="custom_embedding", return_rank=True)  # Example metadata field
        ])

        # Weighted RRF with relative weights (not normalized)
        Rrf(
            ranks=[
                Knn(query=[0.1, 0.2], return_rank=True),
                Knn(query=another_vector, key="custom_embedding", return_rank=True)  # Example metadata field
            ],
            weights=[2.0, 1.0],  # First ranking is 2x more important
            k=100
        )

        # Weighted RRF with normalized weights
        Rrf(
            ranks=[
                Knn(query=[0.1, 0.2], return_rank=True),
                Knn(query=another_vector, key="custom_embedding", return_rank=True)  # Example metadata field
            ],
            weights=[3.0, 1.0],  # Will be normalized to [0.75, 0.25]
            normalize=True,
            k=100
        )
    """

    ranks: List[Rank]
    k: int = 60
    weights: Optional[List[float]] = None
    normalize: bool = False

    def to_dict(self) -> Dict[str, Any]:
        """Convert RRF to a composition of existing expression operators.

        Builds: -sum(weight_i / (k + rank_i)) for each rank
        Using Python's overloaded operators for cleaner code.

        Raises:
            ValueError: If ranks is empty, k is not positive, or weights are
                invalid (wrong length, negative, or summing to zero when
                normalize=True).
        """
        # Validate RRF parameters
        if not self.ranks:
            raise ValueError("RRF requires at least one rank")
        if self.k <= 0:
            raise ValueError(f"k must be positive, got {self.k}")

        # Validate weights if provided
        if self.weights is not None:
            if len(self.weights) != len(self.ranks):
                raise ValueError(
                    f"Number of weights ({len(self.weights)}) must match number of ranks ({len(self.ranks)})"
                )
            if any(w < 0.0 for w in self.weights):
                raise ValueError("All weights must be non-negative")

        # Default to equal weighting when no weights were supplied.
        # (Explicit `is not None`: an empty weights list is already rejected
        # above by the length check whenever ranks is non-empty.)
        weights = self.weights if self.weights is not None else [1.0] * len(self.ranks)

        # Normalize weights if requested
        if self.normalize:
            weight_sum = sum(weights)
            if weight_sum == 0:
                raise ValueError("Sum of weights must be positive when normalize=True")
            weights = [w / weight_sum for w in weights]

        # weight / (k + rank) per strategy; relies on Rank's overloaded
        # arithmetic operators to build the expression tree.
        terms = [w / (self.k + rank) for w, rank in zip(weights, self.ranks)]

        # Sum all terms - guaranteed non-empty by the validation above
        rrf_sum: Rank = terms[0]
        for term in terms[1:]:
            rrf_sum = rrf_sum + term

        # Negate (RRF gives higher scores for better, Chroma needs lower for better)
        return (-rrf_sum).to_dict()
|
|
|
|
|
|
@dataclass
class Select:
    """Selection configuration for search results.

    Fields can be:
    - Key.DOCUMENT - Select document key (equivalent to Key("#document"))
    - Key.EMBEDDING - Select embedding key (equivalent to Key("#embedding"))
    - Key.SCORE - Select score key (equivalent to Key("#score"))
    - Any other string - Select specific metadata property

    Note: You can use K as an alias for Key for more concise code.

    Examples:
        # Select predefined keys using K alias (K is shorthand for Key)
        from chromadb.execution.expression import K
        Select(keys={K.DOCUMENT, K.SCORE})

        # Select specific metadata properties
        Select(keys={"title", "author", "date"})

        # Mixed selection
        Select(keys={K.DOCUMENT, "title", "author"})
    """

    keys: Set[Union[Key, str]] = field(default_factory=set)

    def to_dict(self) -> Dict[str, Any]:
        """Convert the Select to a dictionary for JSON serialization"""
        # Key instances serialize via their .name; plain strings pass through.
        names = [k.name if isinstance(k, Key) else k for k in self.keys]
        # dict.fromkeys deduplicates while keeping first-seen order.
        return {"keys": list(dict.fromkeys(names))}

    @staticmethod
    def from_dict(data: Dict[str, Any]) -> "Select":
        """Create Select from dictionary.

        Every string is wrapped in a Key instance; reserved names map onto
        the predefined Key singletons.

        Examples:
            - {"keys": ["#document", "#score"]} -> Select(keys={Key.DOCUMENT, Key.SCORE})
            - {"keys": ["title", "author"]} -> Select(keys={Key("title"), Key("author")})
        """
        if not isinstance(data, dict):
            raise TypeError(f"Expected dict for Select, got {type(data).__name__}")

        raw_keys = data.get("keys", [])
        if not isinstance(raw_keys, (list, tuple, set)):
            raise TypeError(
                f"Select keys must be a list/tuple/set, got {type(raw_keys).__name__}"
            )

        # Reserved names map to their predefined Key instances; anything else
        # is treated as a regular metadata field.
        special = {
            "#id": Key.ID,
            "#document": Key.DOCUMENT,
            "#embedding": Key.EMBEDDING,
            "#metadata": Key.METADATA,
            "#score": Key.SCORE,
        }
        parsed: List[Union[Key, str]] = []
        for raw in raw_keys:
            if not isinstance(raw, str):
                raise TypeError(f"Select key must be a string, got {type(raw).__name__}")
            parsed.append(special[raw] if raw in special else Key(raw))

        # Reject any fields beyond "keys".
        unexpected_keys = set(data.keys()) - {"keys"}
        if unexpected_keys:
            raise ValueError(f"Unexpected keys in Select dict: {unexpected_keys}")

        return Select(keys=set(parsed))
|
|
|
|
|
|
# GroupBy and Aggregate types for grouping search results
|
|
|
|
|
|
def _keys_to_strings(keys: OneOrMany[Union[Key, str]]) -> List[str]:
    """Serialize one-or-many Key/str values into a flat list of strings."""
    as_list = cast(List[Union[Key, str]], maybe_cast_one_to_many(keys))
    result: List[str] = []
    for item in as_list:
        result.append(item.name if isinstance(item, Key) else item)
    return result
|
|
|
|
|
|
def _strings_to_keys(keys: Union[List[Any], tuple[Any, ...]]) -> List[Union[Key, str]]:
    """Deserialize a sequence into Key instances; non-string items pass through."""
    converted: List[Union[Key, str]] = []
    for item in keys:
        converted.append(Key(item) if isinstance(item, str) else item)
    return converted
|
|
|
|
|
|
def _parse_k_aggregate(
    op: str, data: Dict[str, Any]
) -> tuple[List[Union[Key, str]], int]:
    """Parse common fields for MinK/MaxK from dict.

    Args:
        op: The operator name (e.g., "$min_k" or "$max_k")
        data: The dict containing the operator

    Returns:
        Tuple of (keys, k) where keys is List[Union[Key, str]] and k is int

    Raises:
        TypeError: If data types are invalid
        ValueError: If required fields are missing or invalid
    """
    agg_data = data[op]
    if not isinstance(agg_data, dict):
        raise TypeError(f"{op} requires a dict, got {type(agg_data).__name__}")
    if "keys" not in agg_data:
        raise ValueError(f"{op} requires 'keys' field")
    if "k" not in agg_data:
        raise ValueError(f"{op} requires 'k' field")

    keys = agg_data["keys"]
    if not isinstance(keys, (list, tuple)):
        raise TypeError(f"{op} keys must be a list, got {type(keys).__name__}")
    if not keys:
        raise ValueError(f"{op} keys cannot be empty")

    k = agg_data["k"]
    # bool is a subclass of int, so a bare isinstance(k, int) check would
    # silently accept {"k": True} as k=1 — reject booleans explicitly.
    if isinstance(k, bool) or not isinstance(k, int):
        raise TypeError(f"{op} k must be an integer, got {type(k).__name__}")
    if k <= 0:
        raise ValueError(f"{op} k must be positive, got {k}")

    return _strings_to_keys(keys), k
|
|
|
|
|
|
@dataclass
class Aggregate:
    """Base class for aggregation expressions within groups.

    Aggregations determine which records to keep from each group:
    - MinK: Keep k records with minimum values (ascending order)
    - MaxK: Keep k records with maximum values (descending order)

    Examples:
        # Keep top 3 by score per group (single key)
        MinK(keys=Key.SCORE, k=3)

        # Keep top 5 by priority, then score as tiebreaker (multiple keys)
        MinK(keys=[Key("priority"), Key.SCORE], k=5)

        # Keep bottom 2 by score per group
        MaxK(keys=Key.SCORE, k=2)
    """

    def to_dict(self) -> Dict[str, Any]:
        """Convert the Aggregate expression to a dictionary for JSON serialization"""
        raise NotImplementedError("Subclasses must implement to_dict()")

    @staticmethod
    def from_dict(data: Dict[str, Any]) -> "Aggregate":
        """Create Aggregate expression from dictionary.

        Supports:
        - {"$min_k": {"keys": [...], "k": n}} -> MinK(keys=[...], k=n)
        - {"$max_k": {"keys": [...], "k": n}} -> MaxK(keys=[...], k=n)
        """
        if not isinstance(data, dict):
            raise TypeError(f"Expected dict for Aggregate, got {type(data).__name__}")
        if not data:
            raise ValueError("Aggregate dict cannot be empty")
        if len(data) != 1:
            raise ValueError(
                f"Aggregate dict must contain exactly one operator, got {len(data)}"
            )

        # Exactly one entry, so tuple-unpack the operator name directly.
        (op,) = data

        if op == "$min_k":
            parsed_keys, parsed_k = _parse_k_aggregate(op, data)
            return MinK(keys=parsed_keys, k=parsed_k)
        if op == "$max_k":
            parsed_keys, parsed_k = _parse_k_aggregate(op, data)
            return MaxK(keys=parsed_keys, k=parsed_k)
        raise ValueError(f"Unknown aggregate operator: {op}")
|
|
|
|
|
|
@dataclass
class MinK(Aggregate):
    """Keep the k lowest-valued records (by the given keys) within each group."""

    keys: OneOrMany[Union[Key, str]]
    k: int

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to the transport format {"$min_k": {...}}."""
        payload = {"keys": _keys_to_strings(self.keys), "k": self.k}
        return {"$min_k": payload}
|
|
|
|
|
|
@dataclass
class MaxK(Aggregate):
    """Keep the k highest-valued records (by the given keys) within each group."""

    keys: OneOrMany[Union[Key, str]]
    k: int

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to the transport format {"$max_k": {...}}."""
        payload = {"keys": _keys_to_strings(self.keys), "k": self.k}
        return {"$max_k": payload}
|
|
|
|
|
|
@dataclass
class GroupBy:
    """Group results by metadata keys and aggregate within each group.

    Groups search results by one or more metadata fields, then applies an
    aggregation (MinK or MaxK) to select records within each group.
    The final output is flattened and sorted by score.

    Args:
        keys: Metadata key(s) to group by. Can be a single key or a list of keys.
            E.g., Key("category") or [Key("category"), Key("author")]
        aggregate: Aggregation to apply within each group (MinK or MaxK)

    Note: Both keys and aggregate must be specified together.

    Examples:
        # Top 3 documents per category (single key)
        GroupBy(
            keys=Key("category"),
            aggregate=MinK(keys=Key.SCORE, k=3)
        )

        # Top 2 per (year, category) combination (multiple keys)
        GroupBy(
            keys=[Key("year"), Key("category")],
            aggregate=MinK(keys=Key.SCORE, k=2)
        )

        # Top 1 per category by priority, score as tiebreaker
        GroupBy(
            keys=Key("category"),
            aggregate=MinK(keys=[Key("priority"), Key.SCORE], k=1)
        )
    """

    keys: OneOrMany[Union[Key, str]] = field(default_factory=list)
    aggregate: Optional[Aggregate] = None

    def to_dict(self) -> Dict[str, Any]:
        """Convert the GroupBy to a dictionary for JSON serialization"""
        # The default (no keys and/or no aggregate) serializes as an empty dict.
        if not self.keys or self.aggregate is None:
            return {}
        return {
            "keys": _keys_to_strings(self.keys),
            "aggregate": self.aggregate.to_dict(),
        }

    @staticmethod
    def from_dict(data: Dict[str, Any]) -> "GroupBy":
        """Create GroupBy from dictionary.

        Examples:
            - {} -> GroupBy() (default, no grouping)
            - {"keys": ["category"], "aggregate": {"$min_k": {"keys": ["#score"], "k": 3}}}
        """
        if not isinstance(data, dict):
            raise TypeError(f"Expected dict for GroupBy, got {type(data).__name__}")

        # Empty dict means "no grouping".
        if not data:
            return GroupBy()

        # Any non-empty dict must carry both fields.
        for required in ("keys", "aggregate"):
            if required not in data:
                raise ValueError(f"GroupBy requires '{required}' field")

        keys = data["keys"]
        if not isinstance(keys, (list, tuple)):
            raise TypeError(f"GroupBy keys must be a list, got {type(keys).__name__}")
        if not keys:
            raise ValueError("GroupBy keys cannot be empty")

        aggregate_data = data["aggregate"]
        if not isinstance(aggregate_data, dict):
            raise TypeError(
                f"GroupBy aggregate must be a dict, got {type(aggregate_data).__name__}"
            )

        return GroupBy(
            keys=_strings_to_keys(keys),
            aggregate=Aggregate.from_dict(aggregate_data),
        )
|