135 lines
4.4 KiB
Python
135 lines
4.4 KiB
Python
|
|
import collections.abc
|
||
|
|
from collections.abc import Callable, Sequence
|
||
|
|
from typing import Any, Generic
|
||
|
|
|
||
|
|
from typing_extensions import NotRequired, Required, Self
|
||
|
|
|
||
|
|
from langgraph._internal._constants import OVERWRITE
|
||
|
|
from langgraph._internal._typing import MISSING
|
||
|
|
from langgraph.channels.base import BaseChannel, Value
|
||
|
|
from langgraph.errors import (
|
||
|
|
EmptyChannelError,
|
||
|
|
ErrorCode,
|
||
|
|
InvalidUpdateError,
|
||
|
|
create_error_message,
|
||
|
|
)
|
||
|
|
from langgraph.types import Overwrite
|
||
|
|
|
||
|
|
__all__ = ("BinaryOperatorAggregate",)
|
||
|
|
|
||
|
|
|
||
|
|
# Adapted from typing_extensions
|
||
|
|
def _strip_extras(t): # type: ignore[no-untyped-def]
|
||
|
|
"""Strips Annotated, Required and NotRequired from a given type."""
|
||
|
|
if hasattr(t, "__origin__"):
|
||
|
|
return _strip_extras(t.__origin__)
|
||
|
|
if hasattr(t, "__origin__") and t.__origin__ in (Required, NotRequired):
|
||
|
|
return _strip_extras(t.__args__[0])
|
||
|
|
|
||
|
|
return t
|
||
|
|
|
||
|
|
|
||
|
|
def _get_overwrite(value: Any) -> tuple[bool, Any]:
    """Detect whether `value` is an overwrite request and extract its payload.

    Two spellings are recognised: an `Overwrite` instance, and a dict whose
    only key is `OVERWRITE`.

    Returns:
        `(True, payload)` when `value` is an overwrite request, otherwise
        `(False, None)`.
    """
    if isinstance(value, Overwrite):
        payload = value.value
    elif isinstance(value, dict) and {OVERWRITE} == set(value.keys()):
        payload = value[OVERWRITE]
    else:
        # Ordinary update: nothing to unwrap.
        return False, None
    return True, payload
|
||
|
|
|
||
|
|
|
||
|
|
class BinaryOperatorAggregate(Generic[Value], BaseChannel[Value, Value, Value]):
    """Stores the result of applying a binary operator to the current value and each new value.

    ```python
    import operator

    total = Channels.BinaryOperatorAggregate(int, operator.add)
    ```

    The channel starts from `typ()` when `typ` can be called without
    arguments and is otherwise empty (`MISSING`) until the first update
    arrives. An update recognised by `_get_overwrite` replaces the stored
    value instead of being combined through the operator.
    """

    __slots__ = ("value", "operator")

    def __init__(self, typ: type[Value], operator: Callable[[Value, Value], Value]):
        """Create the channel.

        Args:
            typ: Declared type of the stored value; also used to build the
                initial value.
            operator: Binary function called as
                `operator(current_value, new_value)` for each update.
        """
        super().__init__(typ)
        self.operator = operator
        # special forms from typing or collections.abc are not instantiable
        # so we need to replace them with their concrete counterparts
        typ = _strip_extras(typ)
        if typ in (collections.abc.Sequence, collections.abc.MutableSequence):
            typ = list
        if typ in (collections.abc.Set, collections.abc.MutableSet):
            typ = set
        if typ in (collections.abc.Mapping, collections.abc.MutableMapping):
            typ = dict
        try:
            self.value = typ()
        except Exception:
            # Best effort: a type that cannot be built without arguments
            # simply leaves the channel empty until the first update.
            self.value = MISSING

    def __eq__(self, value: object) -> bool:
        """Compare two channels by their operator.

        Operators are compared by identity, except that when either one is a
        lambda any two `BinaryOperatorAggregate` instances are considered equal.
        """
        if not isinstance(value, BinaryOperatorAggregate):
            return False
        # getattr with a default: callables such as functools.partial objects
        # have no __name__ and are then compared by identity.
        other_name = getattr(value.operator, "__name__", None)
        own_name = getattr(self.operator, "__name__", None)
        if other_name != "<lambda>" and own_name != "<lambda>":
            return value.operator is self.operator
        return True

    @property
    def ValueType(self) -> type[Value]:
        """The type of the value stored in the channel."""
        return self.typ

    @property
    def UpdateType(self) -> type[Value]:
        """The type of the update received by the channel."""
        return self.typ

    def copy(self) -> Self:
        """Return a copy of the channel.

        The copy has the same key and references the same value object as
        this channel (the value itself is not copied).
        """
        empty = self.__class__(self.typ, self.operator)
        empty.key = self.key
        empty.value = self.value
        return empty

    def from_checkpoint(self, checkpoint: Value) -> Self:
        """Return a new channel with the same type, operator and key.

        Its value is `checkpoint`, unless `checkpoint` is `MISSING`, in which
        case the freshly initialised value is kept.
        """
        empty = self.__class__(self.typ, self.operator)
        empty.key = self.key
        if checkpoint is not MISSING:
            empty.value = checkpoint
        return empty

    def update(self, values: Sequence[Value]) -> bool:
        """Fold `values` into the stored value, in order.

        An overwrite value (as recognised by `_get_overwrite`) replaces the
        stored value; once one has been seen, the remaining non-overwrite
        values are ignored. If the channel is empty, the first ordinary value
        is stored as-is rather than passed through the operator.

        Args:
            values: The updates to apply.

        Returns:
            `False` if `values` is empty, `True` otherwise.

        Raises:
            InvalidUpdateError: If `values` contains more than one overwrite.
        """
        if not values:
            return False
        seen_overwrite: bool = False
        for value in values:
            is_overwrite, overwrite_value = _get_overwrite(value)
            if is_overwrite:
                if seen_overwrite:
                    msg = create_error_message(
                        message="Can receive only one Overwrite value per super-step.",
                        error_code=ErrorCode.INVALID_CONCURRENT_GRAPH_UPDATE,
                    )
                    raise InvalidUpdateError(msg)
                self.value = overwrite_value
                seen_overwrite = True
                continue
            if seen_overwrite:
                # An overwrite takes precedence over ordinary updates.
                continue
            if self.value is MISSING:
                # Empty channel: the first ordinary value seeds it. Doing this
                # inside the loop means an overwrite arriving first is
                # unwrapped above instead of being stored as the wrapper.
                self.value = value
            else:
                self.value = self.operator(self.value, value)
        return True

    def get(self) -> Value:
        """Return the stored value.

        Raises:
            EmptyChannelError: If the channel holds no value.
        """
        if self.value is MISSING:
            raise EmptyChannelError()
        return self.value

    def is_available(self) -> bool:
        """Return whether the channel currently holds a value."""
        return self.value is not MISSING

    def checkpoint(self) -> Value:
        """Return the stored value (possibly `MISSING`) for checkpointing."""
        return self.value
|