95 lines
2.8 KiB
Python
95 lines
2.8 KiB
Python
from __future__ import annotations
|
|
|
|
from collections.abc import Iterator, Sequence
|
|
from typing import Any, Generic
|
|
|
|
from typing_extensions import Self
|
|
|
|
from langgraph._internal._typing import MISSING
|
|
from langgraph.channels.base import BaseChannel, Value
|
|
from langgraph.errors import EmptyChannelError
|
|
|
|
# Public API of this module: only `Topic` is exported by `from ... import *`.
__all__ = ("Topic",)
|
|
|
|
|
|
def _flatten(values: Sequence[Value | list[Value]]) -> Iterator[Value]:
    """Yield every published value, expanding list updates by one level.

    Only ``list`` instances are expanded, and only one level deep. Any other
    item (including tuples, strings, or ``None``) is yielded unchanged, as are
    the elements of an expanded list.
    """
    for item in values:
        if not isinstance(item, list):
            yield item
            continue
        yield from item
|
|
|
|
|
|
class Topic(
    Generic[Value],
    BaseChannel[Sequence[Value], Value | list[Value], list[Value]],
):
    """A configurable PubSub Topic.

    Each update may be a single value or a list of values. List updates are
    expanded one level, so each element is published individually.

    Args:
        typ: The type of the value stored in the channel.
        accumulate: Whether to accumulate values across steps. If `False`, the channel will be emptied after each step.
    """

    __slots__ = ("values", "accumulate")

    def __init__(self, typ: type[Value], accumulate: bool = False) -> None:
        super().__init__(typ)
        # attrs
        self.accumulate = accumulate
        # state
        # Invariant: this list is never mutated in place. `update` always
        # rebinds it to a fresh list, so a list returned by `checkpoint()` or
        # adopted by `from_checkpoint()` is never modified afterwards.
        self.values = list[Value]()

    def __eq__(self, value: object) -> bool:
        """Return True if `value` is a `Topic` with the same `accumulate` flag.

        Neither the stored type nor the current values take part in the
        comparison.
        """
        return isinstance(value, Topic) and value.accumulate == self.accumulate

    @property
    def ValueType(self) -> Any:
        """The type of the value stored in the channel."""
        return Sequence[self.typ]  # type: ignore[name-defined]

    @property
    def UpdateType(self) -> Any:
        """The type of the update received by the channel."""
        return self.typ | list[self.typ]  # type: ignore[name-defined]

    def copy(self) -> Self:
        """Return a copy of the channel that owns its own list of values."""
        empty = self.__class__(self.typ, self.accumulate)
        empty.key = self.key
        # `list(...)` also accepts a non-list sequence restored by
        # `from_checkpoint`, unlike `.copy()`.
        empty.values = list(self.values)
        return empty

    def checkpoint(self) -> list[Value]:
        """Return the currently stored values for persistence.

        The live list is returned without copying. This is safe because
        `update` never mutates it in place.
        """
        return self.values

    def from_checkpoint(self, checkpoint: list[Value]) -> Self:
        """Build a new channel of the same configuration from saved state.

        Args:
            checkpoint: The saved values, or `MISSING` to start empty. A tuple
                is treated as an older checkpoint format whose element at
                index 1 holds the values.

        Returns:
            A new `Topic` with the same type, `accumulate` flag and key.
        """
        empty = self.__class__(self.typ, self.accumulate)
        empty.key = self.key
        if checkpoint is not MISSING:
            if isinstance(checkpoint, tuple):
                # backwards compatibility
                empty.values = checkpoint[1]
            else:
                empty.values = checkpoint
        return empty

    def update(self, values: Sequence[Value | list[Value]]) -> bool:
        """Publish a batch of updates to the topic.

        Args:
            values: The updates for this step. Each item is either a single
                value or a list of values to publish individually.

        Returns:
            True if the channel's contents changed: new values were
            published, or (when not accumulating) previously stored values
            were cleared.
        """
        new_values = list(_flatten(values))
        if self.accumulate:
            if not new_values:
                return False
            # Rebind to a new list rather than `extend` in place, so a list
            # previously returned by `checkpoint()` or adopted by
            # `from_checkpoint()` is not altered by later steps.
            self.values = [*self.values, *new_values]
            return True
        # Non-accumulating: values from the previous step are discarded.
        # Clearing a non-empty channel counts as an update.
        updated = bool(self.values) or bool(new_values)
        self.values = new_values
        return updated

    def get(self) -> Sequence[Value]:
        """Return a copy of the currently stored values.

        Raises:
            EmptyChannelError: If no values are currently stored.
        """
        if not self.values:
            raise EmptyChannelError
        return list(self.values)

    def is_available(self) -> bool:
        """Return True if at least one value is currently stored."""
        return bool(self.values)
|