80 lines
2.3 KiB
Python
80 lines
2.3 KiB
Python
|
|
from __future__ import annotations
|
||
|
|
|
||
|
|
from collections.abc import Sequence
|
||
|
|
from typing import Any, Generic
|
||
|
|
|
||
|
|
from typing_extensions import Self
|
||
|
|
|
||
|
|
from langgraph._internal._typing import MISSING
|
||
|
|
from langgraph.channels.base import BaseChannel, Value
|
||
|
|
from langgraph.errors import EmptyChannelError, InvalidUpdateError
|
||
|
|
|
||
|
|
# Names exported by `from <this module> import *`; only the channel class is public.
__all__ = ("EphemeralValue",)
|
||
|
|
|
||
|
|
|
||
|
|
class EphemeralValue(Generic[Value], BaseChannel[Value, Value, Value]):
    """Channel that only holds what was written in the immediately preceding step.

    A step with no writes clears the stored value again. With ``guard=True``
    (the default) a step may write at most one value; with ``guard=False``
    several writes are accepted and the last one in the sequence is kept.
    """

    __slots__ = ("value", "guard")

    # Current payload, or the MISSING sentinel when the channel is empty.
    value: Value | Any
    # When True, more than one write within a single update is rejected.
    guard: bool

    def __init__(self, typ: Any, guard: bool = True) -> None:
        super().__init__(typ)
        # A freshly built channel starts out empty.
        self.value = MISSING
        self.guard = guard

    def __eq__(self, value: object) -> bool:
        """Two ephemeral channels compare equal when their ``guard`` flags match.

        The stored value, type and key are deliberately not part of the comparison.
        """
        if not isinstance(value, EphemeralValue):
            return False
        return value.guard == self.guard

    @property
    def ValueType(self) -> type[Value]:
        """Type of the value this channel exposes via ``get``."""
        return self.typ

    @property
    def UpdateType(self) -> type[Value]:
        """Type of each item this channel accepts in ``update``."""
        return self.typ

    def _empty_like(self) -> Self:
        """Build a new, empty channel sharing this one's type, guard flag and key."""
        blank = type(self)(self.typ, self.guard)
        blank.key = self.key
        return blank

    def copy(self) -> Self:
        """Return a copy of the channel, including its current value."""
        clone = self._empty_like()
        clone.value = self.value
        return clone

    def from_checkpoint(self, checkpoint: Value) -> Self:
        """Return a new channel restored from ``checkpoint``.

        A MISSING checkpoint yields an empty channel.
        """
        restored = self._empty_like()
        if checkpoint is not MISSING:
            restored.value = checkpoint
        return restored

    def update(self, values: Sequence[Value]) -> bool:
        """Apply one step's writes and report whether the channel changed.

        Args:
            values: The writes received in this step.

        Returns:
            True if the stored value was replaced or cleared, False if the
            channel was already empty and nothing was written.

        Raises:
            InvalidUpdateError: If ``guard`` is set and more than one value
                was written.
        """
        count = len(values)
        if count == 0:
            # No writes this step: forget the previous value. Only report a
            # change when there actually was something to forget.
            had_value = self.value is not MISSING
            self.value = MISSING
            return had_value

        if self.guard and count != 1:
            raise InvalidUpdateError(
                f"At key '{self.key}': EphemeralValue(guard=True) can receive only one value per step. Use guard=False if you want to store any one of multiple values."
            )

        # With guard disabled, the last write wins.
        self.value = values[-1]
        return True

    def get(self) -> Value:
        """Return the stored value.

        Raises:
            EmptyChannelError: If the channel currently holds nothing.
        """
        current = self.value
        if current is MISSING:
            raise EmptyChannelError()
        return current

    def is_available(self) -> bool:
        """Whether ``get`` would currently succeed."""
        return self.value is not MISSING

    def checkpoint(self) -> Value:
        """Return the raw stored value (possibly the MISSING sentinel) for persistence."""
        return self.value
|