group-wbl/.venv/lib/python3.13/site-packages/langgraph/channels/binop.py

135 lines
4.4 KiB
Python
Raw Normal View History

2026-01-09 09:12:25 +08:00
import collections.abc
from collections.abc import Callable, Sequence
from typing import Any, Generic
from typing_extensions import NotRequired, Required, Self
from langgraph._internal._constants import OVERWRITE
from langgraph._internal._typing import MISSING
from langgraph.channels.base import BaseChannel, Value
from langgraph.errors import (
EmptyChannelError,
ErrorCode,
InvalidUpdateError,
create_error_message,
)
from langgraph.types import Overwrite
# Public API of this module: only the channel class is exported by `import *`.
__all__ = ("BinaryOperatorAggregate",)
# Adapted from typing_extensions
def _strip_extras(t): # type: ignore[no-untyped-def]
"""Strips Annotated, Required and NotRequired from a given type."""
if hasattr(t, "__origin__"):
return _strip_extras(t.__origin__)
if hasattr(t, "__origin__") and t.__origin__ in (Required, NotRequired):
return _strip_extras(t.__args__[0])
return t
def _get_overwrite(value: Any) -> tuple[bool, Any]:
    """Detect whether an update asks to overwrite the channel value.

    Two spellings are recognised: an ``Overwrite`` instance, or a plain dict
    whose only key is ``OVERWRITE``.

    Returns:
        ``(True, replacement)`` for an overwrite request, otherwise
        ``(False, None)``.
    """
    no_overwrite: tuple[bool, Any] = (False, None)
    if isinstance(value, Overwrite):
        return (True, value.value)
    if not isinstance(value, dict):
        return no_overwrite
    if set(value.keys()) != {OVERWRITE}:
        return no_overwrite
    return (True, value[OVERWRITE])
class BinaryOperatorAggregate(Generic[Value], BaseChannel[Value, Value, Value]):
    """Stores the result of applying a binary operator to the current value and each new value.

    Each regular update is folded into the stored value as
    ``operator(current, update)``. An update may instead be an ``Overwrite``
    (or a dict whose only key is the ``OVERWRITE`` constant). The wrapped
    value then replaces the stored value, and regular updates in the same
    ``update`` call no longer affect the result. At most one overwrite is
    accepted per ``update`` call.

    ```python
    import operator

    from langgraph.channels.binop import BinaryOperatorAggregate

    total = BinaryOperatorAggregate(int, operator.add)
    ```
    """

    __slots__ = ("value", "operator")

    def __init__(self, typ: type[Value], operator: Callable[[Value, Value], Value]):
        """Create the channel and try to seed it with ``typ()``.

        Args:
            typ: Declared type of the channel value. ``Annotated``,
                ``Required`` and ``NotRequired`` wrappers are stripped before
                instantiation.
            operator: Binary function combining the current value with an update.
        """
        super().__init__(typ)
        self.operator = operator
        # NOTE: only the local `typ` is normalized below; whatever
        # BaseChannel.__init__ recorded from the original argument is untouched.
        # special forms from typing or collections.abc are not instantiable
        # so we need to replace them with their concrete counterparts
        typ = _strip_extras(typ)
        if typ in (collections.abc.Sequence, collections.abc.MutableSequence):
            typ = list
        if typ in (collections.abc.Set, collections.abc.MutableSet):
            typ = set
        if typ in (collections.abc.Mapping, collections.abc.MutableMapping):
            typ = dict
        try:
            self.value = typ()
        except Exception:
            # Deliberately best-effort: a type that cannot be built without
            # arguments leaves the channel empty until its first update.
            self.value = MISSING

    def __eq__(self, value: object) -> bool:
        """Compare two channels by their operator.

        Channels are equal when the operators are the same object, or when
        either operator is a lambda. The stored value and type are not compared.
        """
        if not isinstance(value, BinaryOperatorAggregate):
            return False
        # Callables such as functools.partial objects or callable instances
        # have no __name__; treat them as non-lambdas instead of raising.
        own_name = getattr(self.operator, "__name__", None)
        other_name = getattr(value.operator, "__name__", None)
        if own_name == "<lambda>" or other_name == "<lambda>":
            return True
        return value.operator is self.operator

    @property
    def ValueType(self) -> type[Value]:
        """The type of the value stored in the channel."""
        return self.typ

    @property
    def UpdateType(self) -> type[Value]:
        """The type of the update received by the channel."""
        return self.typ

    def copy(self) -> Self:
        """Return a copy of the channel.

        The stored value object is shared with the copy, not duplicated.
        """
        empty = self.__class__(self.typ, self.operator)
        empty.key = self.key
        empty.value = self.value
        return empty

    def from_checkpoint(self, checkpoint: Value) -> Self:
        """Return a new channel restored from ``checkpoint``.

        A ``MISSING`` checkpoint keeps the default value produced by ``__init__``.
        """
        empty = self.__class__(self.typ, self.operator)
        empty.key = self.key
        if checkpoint is not MISSING:
            empty.value = checkpoint
        return empty

    def update(self, values: Sequence[Value]) -> bool:
        """Fold a batch of updates into the stored value.

        Args:
            values: Updates to apply. Each is a regular value, combined via
                ``operator(current, value)``, or an overwrite request (see
                ``_get_overwrite``) whose wrapped value replaces the stored value.

        Returns:
            ``False`` if ``values`` is empty, otherwise ``True``.

        Raises:
            InvalidUpdateError: if ``values`` contains more than one overwrite.
        """
        if not values:
            return False
        # An empty channel is seeded with its first update rather than folding
        # it into a non-existent value. An overwrite request must not be stored
        # verbatim (that would leave the wrapper object as the channel value),
        # so in that case the loop below handles it and unwraps it.
        if self.value is MISSING and not _get_overwrite(values[0])[0]:
            self.value = values[0]
            values = values[1:]
        seen_overwrite: bool = False
        for value in values:
            is_overwrite, overwrite_value = _get_overwrite(value)
            if is_overwrite:
                if seen_overwrite:
                    msg = create_error_message(
                        message="Can receive only one Overwrite value per super-step.",
                        error_code=ErrorCode.INVALID_CONCURRENT_GRAPH_UPDATE,
                    )
                    raise InvalidUpdateError(msg)
                self.value = overwrite_value
                seen_overwrite = True
            elif not seen_overwrite:
                # Once an overwrite has been applied, later regular updates in
                # this batch are skipped.
                self.value = self.operator(self.value, value)
        return True

    def get(self) -> Value:
        """Return the stored value.

        Raises:
            EmptyChannelError: if the channel has no value yet.
        """
        if self.value is MISSING:
            raise EmptyChannelError()
        return self.value

    def is_available(self) -> bool:
        """Return whether the channel currently holds a value."""
        return self.value is not MISSING

    def checkpoint(self) -> Value:
        """Return the value to persist; ``MISSING`` when the channel is empty."""
        return self.value