168 lines
5.0 KiB
Python
168 lines
5.0 KiB
Python
from collections.abc import Sequence
|
|
from typing import Generic
|
|
|
|
from typing_extensions import Self
|
|
|
|
from langgraph._internal._typing import MISSING
|
|
from langgraph.channels.base import BaseChannel, Value
|
|
from langgraph.errors import EmptyChannelError, InvalidUpdateError
|
|
|
|
# Names exported by `from <this module> import *`: the two barrier channel classes defined below.
__all__ = ("NamedBarrierValue", "NamedBarrierValueAfterFinish")
|
|
|
|
|
|
class NamedBarrierValue(Generic[Value], BaseChannel[Value, Value, set[Value]]):
    """A channel that waits until all named values are received before making the value available.

    Every update must be one of the configured `names`. Once each name has
    been seen the channel reports itself as available, and `consume()` resets
    it so the barrier can be waited on again. The channel carries no payload:
    `get()` returns `None` when the barrier is complete.
    """

    __slots__ = ("names", "seen")

    # Every name that has to arrive before the barrier opens.
    names: set[Value]
    # Names received since the channel was created or last consumed.
    seen: set[Value]

    def __init__(self, typ: type[Value], names: set[Value]) -> None:
        """Create a barrier with nothing seen yet.

        Args:
            typ: Type passed to the base channel; also reported by
                `ValueType` and `UpdateType`.
            names: The names that must all be received.
        """
        super().__init__(typ)
        self.names = names
        self.seen = set()

    def __eq__(self, value: object) -> bool:
        """Return True for another `NamedBarrierValue` waiting on equal `names`.

        The progress recorded in `seen` is not part of the comparison.
        """
        return isinstance(value, NamedBarrierValue) and value.names == self.names

    @property
    def ValueType(self) -> type[Value]:
        """The type of the value stored in the channel."""
        return self.typ

    @property
    def UpdateType(self) -> type[Value]:
        """The type of the update received by the channel."""
        return self.typ

    def copy(self) -> Self:
        """Return a copy of the channel.

        The copy has the same `key` and `names` and its own independent copy
        of `seen`.
        """
        empty = self.__class__(self.typ, self.names)
        empty.key = self.key
        empty.seen = self.seen.copy()
        return empty

    def checkpoint(self) -> set[Value]:
        """Return a snapshot of the names seen so far.

        A copy is returned because `update()` adds to `self.seen` in place;
        handing out the live set would let later updates alter a checkpoint
        that was already taken.
        """
        return self.seen.copy()

    def from_checkpoint(self, checkpoint: set[Value]) -> Self:
        """Return a new channel with the same `key` and `names`, restored from `checkpoint`.

        Args:
            checkpoint: The previously seen names, or `MISSING` to start with
                nothing seen.
        """
        empty = self.__class__(self.typ, self.names)
        empty.key = self.key
        if checkpoint is not MISSING:
            # Copy instead of adopting the caller's object so that updates on
            # the restored channel never mutate the stored checkpoint data.
            empty.seen = set(checkpoint)
        return empty

    def update(self, values: Sequence[Value]) -> bool:
        """Record the received names.

        Args:
            values: Names received in this step.

        Returns:
            True if at least one name was newly recorded, False otherwise.

        Raises:
            InvalidUpdateError: If a value is not one of `names`. Values that
                precede the offending one in `values` stay recorded.
        """
        updated = False
        for value in values:
            if value not in self.names:
                raise InvalidUpdateError(
                    f"At key '{self.key}': Value {value} not in {self.names}"
                )
            if value not in self.seen:
                self.seen.add(value)
                updated = True
        return updated

    def get(self) -> Value:
        """Return `None` once every name has been seen.

        Raises:
            EmptyChannelError: If some names have not been received yet.
        """
        if self.seen != self.names:
            raise EmptyChannelError()
        # The barrier has no payload of its own; completion is the signal.
        return None

    def is_available(self) -> bool:
        """Return True when every name has been seen."""
        return self.seen == self.names

    def consume(self) -> bool:
        """Reset the barrier if it is complete.

        Returns:
            True if the barrier was complete and has been reset, else False.
        """
        if self.seen == self.names:
            # Rebind to a fresh set rather than clearing in place, so any
            # other holder of the previous set object is unaffected.
            self.seen = set()
            return True
        return False
|
|
|
|
|
|
class NamedBarrierValueAfterFinish(
    Generic[Value], BaseChannel[Value, Value, set[Value]]
):
    """A channel that waits until all named values are received before making the value ready to be made available. It is only made available after finish() is called.

    Every update must be one of the configured `names`. The channel reports
    itself as available only when each name has been seen AND `finish()` has
    marked it finished. `consume()` resets both pieces of state. The channel
    carries no payload: `get()` returns `None` when available.
    """

    __slots__ = ("names", "seen", "finished")

    # Every name that has to arrive before the barrier can be finished.
    names: set[Value]
    # Names received since the channel was created or last consumed.
    seen: set[Value]
    # Set by `finish()` once all names are seen; cleared by `consume()`.
    finished: bool

    def __init__(self, typ: type[Value], names: set[Value]) -> None:
        """Create an unfinished barrier with nothing seen yet.

        Args:
            typ: Type passed to the base channel; also reported by
                `ValueType` and `UpdateType`.
            names: The names that must all be received.
        """
        super().__init__(typ)
        self.names = names
        self.seen = set()
        self.finished = False

    def __eq__(self, value: object) -> bool:
        """Return True for another `NamedBarrierValueAfterFinish` waiting on equal `names`.

        Neither `seen` nor `finished` is part of the comparison.
        """
        return (
            isinstance(value, NamedBarrierValueAfterFinish)
            and value.names == self.names
        )

    @property
    def ValueType(self) -> type[Value]:
        """The type of the value stored in the channel."""
        return self.typ

    @property
    def UpdateType(self) -> type[Value]:
        """The type of the update received by the channel."""
        return self.typ

    def copy(self) -> Self:
        """Return a copy of the channel.

        The copy has the same `key`, `names` and `finished` flag and its own
        independent copy of `seen`.
        """
        empty = self.__class__(self.typ, self.names)
        empty.key = self.key
        empty.seen = self.seen.copy()
        empty.finished = self.finished
        return empty

    # NOTE(review): the checkpoint here is a `(seen, finished)` tuple, while
    # the class header parameterises BaseChannel with `set[Value]` as its
    # third argument — confirm whether the header should be updated.
    def checkpoint(self) -> tuple[set[Value], bool]:
        """Return a snapshot `(seen, finished)` of the channel state.

        `seen` is copied because `update()` adds to `self.seen` in place;
        handing out the live set would let later updates alter a checkpoint
        that was already taken.
        """
        return (self.seen.copy(), self.finished)

    def from_checkpoint(self, checkpoint: tuple[set[Value], bool]) -> Self:
        """Return a new channel with the same `key` and `names`, restored from `checkpoint`.

        Args:
            checkpoint: A `(seen, finished)` pair, or `MISSING` to start
                unfinished with nothing seen.
        """
        empty = self.__class__(self.typ, self.names)
        empty.key = self.key
        if checkpoint is not MISSING:
            seen, finished = checkpoint
            # Copy instead of adopting the caller's object so that updates on
            # the restored channel never mutate the stored checkpoint data.
            empty.seen = set(seen)
            empty.finished = finished
        return empty

    def update(self, values: Sequence[Value]) -> bool:
        """Record the received names.

        Args:
            values: Names received in this step.

        Returns:
            True if at least one name was newly recorded, False otherwise.

        Raises:
            InvalidUpdateError: If a value is not one of `names`. Values that
                precede the offending one in `values` stay recorded.
        """
        updated = False
        for value in values:
            if value not in self.names:
                raise InvalidUpdateError(
                    f"At key '{self.key}': Value {value} not in {self.names}"
                )
            if value not in self.seen:
                self.seen.add(value)
                updated = True
        return updated

    def get(self) -> Value:
        """Return `None` once the barrier is complete and finished.

        Raises:
            EmptyChannelError: If `finish()` has not taken effect or some
                names have not been received yet.
        """
        if not self.finished or self.seen != self.names:
            raise EmptyChannelError()
        # The barrier has no payload of its own; completion is the signal.
        return None

    def is_available(self) -> bool:
        """Return True when the channel is finished and every name has been seen."""
        return self.finished and self.seen == self.names

    def consume(self) -> bool:
        """Reset the barrier if it is finished and complete.

        Returns:
            True if the channel was available and has been reset, else False.
        """
        if self.finished and self.seen == self.names:
            self.finished = False
            # Rebind to a fresh set rather than clearing in place, so any
            # other holder of the previous set object is unaffected.
            self.seen = set()
            return True
        return False

    def finish(self) -> bool:
        """Mark the channel finished if every name has been seen.

        Returns:
            True if this call switched the channel to finished; False if it
            was already finished or some names are still missing.
        """
        if not self.finished and self.seen == self.names:
            self.finished = True
            return True
        else:
            return False
|