group-wbl/.venv/lib/python3.13/site-packages/langgraph/channels/named_barrier_value.py
2026-01-09 09:12:25 +08:00

168 lines
5.0 KiB
Python

from collections.abc import Sequence
from typing import Generic
from typing_extensions import Self
from langgraph._internal._typing import MISSING
from langgraph.channels.base import BaseChannel, Value
from langgraph.errors import EmptyChannelError, InvalidUpdateError
__all__ = ("NamedBarrierValue", "NamedBarrierValueAfterFinish")
class NamedBarrierValue(Generic[Value], BaseChannel[Value, Value, set[Value]]):
    """A channel that waits until all named values are received before making the value available.

    The channel records which of the expected ``names`` have been written to
    it. It carries no payload of its own: once every name has been seen,
    ``get()`` returns ``None`` instead of raising ``EmptyChannelError``.
    """

    __slots__ = ("names", "seen")

    # Every name that must be received before the barrier is satisfied.
    names: set[Value]
    # The names received since the barrier was created or last consumed.
    seen: set[Value]

    def __init__(self, typ: type[Value], names: set[Value]) -> None:
        """Create a barrier that has not yet seen any name.

        Args:
            typ: Value type, forwarded to ``BaseChannel.__init__``.
            names: Names that must all be received. Stored by reference,
                not copied.
        """
        super().__init__(typ)
        self.names = names
        self.seen = set()

    def __eq__(self, value: object) -> bool:
        """Return True when ``value`` is a ``NamedBarrierValue`` with equal ``names``.

        Only ``names`` is compared; ``seen`` does not take part.
        """
        return isinstance(value, NamedBarrierValue) and value.names == self.names

    @property
    def ValueType(self) -> type[Value]:
        """The type of the value stored in the channel."""
        return self.typ

    @property
    def UpdateType(self) -> type[Value]:
        """The type of the update received by the channel."""
        return self.typ

    def copy(self) -> Self:
        """Return a copy of the channel.

        The copy shares ``names`` with this instance but gets its own
        ``seen`` set, so updating one does not affect the other.
        """
        empty = self.__class__(self.typ, self.names)
        empty.key = self.key
        empty.seen = self.seen.copy()
        return empty

    def checkpoint(self) -> set[Value]:
        """Return a snapshot of the names seen so far.

        A copy is returned because ``update()`` mutates ``seen`` in place;
        handing out the live set would let later updates alter a snapshot
        the caller already holds.
        """
        return self.seen.copy()

    def from_checkpoint(self, checkpoint: set[Value]) -> Self:
        """Return a new channel with the same ``typ``, ``names`` and ``key``.

        Args:
            checkpoint: Names to restore as already seen, or ``MISSING`` to
                start with nothing seen.

        Returns:
            A fresh instance; this instance is left unchanged.
        """
        empty = self.__class__(self.typ, self.names)
        empty.key = self.key
        if checkpoint is not MISSING:
            # Copy so that updating the restored channel never mutates the
            # checkpoint object supplied by the caller.
            empty.seen = set(checkpoint)
        return empty

    def update(self, values: Sequence[Value]) -> bool:
        """Record each of ``values`` as seen.

        Args:
            values: Names to record; each must be a member of ``names``.

        Returns:
            True if at least one name was newly recorded, False otherwise.

        Raises:
            InvalidUpdateError: If a value is not in ``names``. Values that
                precede the offending one in ``values`` remain recorded.
        """
        updated = False
        for value in values:
            if value in self.names:
                if value not in self.seen:
                    self.seen.add(value)
                    updated = True
            else:
                raise InvalidUpdateError(
                    f"At key '{self.key}': Value {value} not in {self.names}"
                )
        return updated

    def get(self) -> Value:
        """Return ``None`` once every name has been seen.

        Raises:
            EmptyChannelError: If some name in ``names`` has not been seen.
        """
        if self.seen != self.names:
            raise EmptyChannelError()
        return None

    def is_available(self) -> bool:
        """Return True when every name in ``names`` has been seen."""
        return self.seen == self.names

    def consume(self) -> bool:
        """Reset the barrier if it is currently satisfied.

        Returns:
            True if all names had been seen and ``seen`` was reset to an
            empty set, False if the barrier was not satisfied.
        """
        if self.seen == self.names:
            # Rebind rather than clear() so previously shared sets are left intact.
            self.seen = set()
            return True
        return False
class NamedBarrierValueAfterFinish(
    Generic[Value], BaseChannel[Value, Value, set[Value]]
):
    """A channel that waits until all named values are received before making the value ready to be made available.

    It is only made available after finish() is called. The channel carries
    no payload: once available, ``get()`` returns ``None`` instead of raising
    ``EmptyChannelError``.
    """

    __slots__ = ("names", "seen", "finished")

    # Every name that must be received before the barrier can be finished.
    names: set[Value]
    # The names received since the barrier was created or last consumed.
    seen: set[Value]
    # True once finish() has succeeded and until the next successful consume().
    finished: bool

    def __init__(self, typ: type[Value], names: set[Value]) -> None:
        """Create an unfinished barrier that has not yet seen any name.

        Args:
            typ: Value type, forwarded to ``BaseChannel.__init__``.
            names: Names that must all be received. Stored by reference,
                not copied.
        """
        super().__init__(typ)
        self.names = names
        self.seen = set()
        self.finished = False

    def __eq__(self, value: object) -> bool:
        """Return True when ``value`` is a ``NamedBarrierValueAfterFinish`` with equal ``names``.

        Only ``names`` is compared; ``seen`` and ``finished`` do not take part.
        """
        return (
            isinstance(value, NamedBarrierValueAfterFinish)
            and value.names == self.names
        )

    @property
    def ValueType(self) -> type[Value]:
        """The type of the value stored in the channel."""
        return self.typ

    @property
    def UpdateType(self) -> type[Value]:
        """The type of the update received by the channel."""
        return self.typ

    def copy(self) -> Self:
        """Return a copy of the channel.

        The copy shares ``names`` with this instance but gets its own
        ``seen`` set and the current ``finished`` flag.
        """
        empty = self.__class__(self.typ, self.names)
        empty.key = self.key
        empty.seen = self.seen.copy()
        empty.finished = self.finished
        return empty

    def checkpoint(self) -> tuple[set[Value], bool]:
        """Return a ``(seen, finished)`` snapshot of the channel state.

        The set is copied because ``update()`` mutates ``seen`` in place;
        handing out the live set would let later updates alter a snapshot
        the caller already holds.
        """
        return (self.seen.copy(), self.finished)

    def from_checkpoint(self, checkpoint: tuple[set[Value], bool]) -> Self:
        """Return a new channel with the same ``typ``, ``names`` and ``key``.

        Args:
            checkpoint: A ``(seen, finished)`` pair to restore, or ``MISSING``
                to start unfinished with nothing seen.

        Returns:
            A fresh instance; this instance is left unchanged.
        """
        empty = self.__class__(self.typ, self.names)
        empty.key = self.key
        if checkpoint is not MISSING:
            seen, finished = checkpoint
            # Copy so that updating the restored channel never mutates the
            # checkpoint object supplied by the caller.
            empty.seen = set(seen)
            empty.finished = finished
        return empty

    def update(self, values: Sequence[Value]) -> bool:
        """Record each of ``values`` as seen.

        Args:
            values: Names to record; each must be a member of ``names``.

        Returns:
            True if at least one name was newly recorded, False otherwise.

        Raises:
            InvalidUpdateError: If a value is not in ``names``. Values that
                precede the offending one in ``values`` remain recorded.
        """
        updated = False
        for value in values:
            if value in self.names:
                if value not in self.seen:
                    self.seen.add(value)
                    updated = True
            else:
                raise InvalidUpdateError(
                    f"At key '{self.key}': Value {value} not in {self.names}"
                )
        return updated

    def get(self) -> Value:
        """Return ``None`` once the barrier is finished and every name was seen.

        Raises:
            EmptyChannelError: If ``finish()`` has not succeeded yet or some
                name in ``names`` has not been seen.
        """
        if not self.finished or self.seen != self.names:
            raise EmptyChannelError()
        return None

    def is_available(self) -> bool:
        """Return True when the barrier is finished and every name has been seen."""
        return self.finished and self.seen == self.names

    def consume(self) -> bool:
        """Reset the barrier if it is currently available.

        Returns:
            True if the barrier was finished with all names seen, in which
            case ``finished`` is cleared and ``seen`` is reset to an empty
            set; False otherwise.
        """
        if self.finished and self.seen == self.names:
            self.finished = False
            # Rebind rather than clear() so previously shared sets are left intact.
            self.seen = set()
            return True
        return False

    def finish(self) -> bool:
        """Mark the barrier as finished if every name has been seen.

        Returns:
            True if this call changed ``finished`` from False to True; False
            if the barrier was already finished or names are still missing.
        """
        if not self.finished and self.seen == self.names:
            self.finished = True
            return True
        return False