"""Chain that carries on a conversation and calls an LLM."""

from langchain_core._api import deprecated
from langchain_core.prompts import BasePromptTemplate
from pydantic import ConfigDict, Field, model_validator
from typing_extensions import Self, override

from langchain_classic.base_memory import BaseMemory
from langchain_classic.chains.conversation.prompt import PROMPT
from langchain_classic.chains.llm import LLMChain
from langchain_classic.memory.buffer import ConversationBufferMemory

@deprecated(
    since="0.2.7",
    alternative="langchain_core.runnables.history.RunnableWithMessageHistory",
    removal="1.0",
)
class ConversationChain(LLMChain):
"""Chain to have a conversation and load context from memory.
This class is deprecated in favor of `RunnableWithMessageHistory`. Please refer
to this tutorial for more detail: https://python.langchain.com/docs/tutorials/chatbot/
`RunnableWithMessageHistory` offers several benefits, including:
- Stream, batch, and async support;
- More flexible memory handling, including the ability to manage memory
outside the chain;
- Support for multiple threads.
Below is a minimal implementation, analogous to using `ConversationChain` with
the default `ConversationBufferMemory`:
```python
from langchain_core.chat_history import InMemoryChatMessageHistory
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_openai import ChatOpenAI
store = {} # memory is maintained outside the chain
def get_session_history(session_id: str) -> InMemoryChatMessageHistory:
if session_id not in store:
store[session_id] = InMemoryChatMessageHistory()
return store[session_id]
model = ChatOpenAI(model="gpt-3.5-turbo-0125")
chain = RunnableWithMessageHistory(model, get_session_history)
chain.invoke(
"Hi I'm Bob.",
config={"configurable": {"session_id": "1"}},
) # session_id determines thread
```
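
    Because the wrapped chain is a standard `Runnable`, the same object also
    streams, batches, and runs async; a minimal sketch continuing the snippet
    above:

    ```python
    for chunk in chain.stream(
        "What's my name?",
        config={"configurable": {"session_id": "1"}},
    ):
        print(chunk.content, end="")
    ```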
    Memory objects can also be incorporated into the `get_session_history` callable:

    ```python
    from langchain_classic.memory import ConversationBufferWindowMemory
    from langchain_core.chat_history import InMemoryChatMessageHistory
    from langchain_core.runnables.history import RunnableWithMessageHistory
    from langchain_openai import ChatOpenAI

    store = {}  # memory is maintained outside the chain


    def get_session_history(session_id: str) -> InMemoryChatMessageHistory:
        if session_id not in store:
            store[session_id] = InMemoryChatMessageHistory()
            return store[session_id]

        memory = ConversationBufferWindowMemory(
            chat_memory=store[session_id],
            k=3,
            return_messages=True,
        )
        assert len(memory.memory_variables) == 1
        key = memory.memory_variables[0]
        messages = memory.load_memory_variables({})[key]
        store[session_id] = InMemoryChatMessageHistory(messages=messages)
        return store[session_id]


    model = ChatOpenAI(model="gpt-3.5-turbo-0125")

    chain = RunnableWithMessageHistory(model, get_session_history)
    chain.invoke(
        "Hi I'm Bob.",
        config={"configurable": {"session_id": "1"}},
    )  # session_id determines thread
    ```
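
    Note the asymmetry in `get_session_history` above: a brand-new session gets a
    fresh, empty history, while an existing session is first rewritten to its last
    `k = 3` exchanges, so the prompt sent to the model never grows without bound.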
    Example:

    ```python
    from langchain_classic.chains import ConversationChain
    from langchain_openai import OpenAI

    conversation = ConversationChain(llm=OpenAI())
    ```
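
    A custom prompt also works, provided its input variables match the memory
    key(s) plus the input key (this is enforced by `validate_prompt_input_variables`
    below); a sketch continuing the snippet above, assuming the default
    `ConversationBufferMemory`, whose single memory key is `history`:

    ```python
    from langchain_core.prompts import PromptTemplate

    template = "Current conversation:\n{history}\nHuman: {input}\nAI:"
    conversation = ConversationChain(
        llm=OpenAI(),
        prompt=PromptTemplate(input_variables=["history", "input"], template=template),
    )
    ```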
"""

    memory: BaseMemory = Field(default_factory=ConversationBufferMemory)
    """Default memory store."""
    prompt: BasePromptTemplate = PROMPT
    """Default conversation prompt to use."""
    input_key: str = "input"
    output_key: str = "response"

    model_config = ConfigDict(
        arbitrary_types_allowed=True,
        extra="forbid",
    )

    @classmethod
    @override
    def is_lc_serializable(cls) -> bool:
        return False

    @property
    def input_keys(self) -> list[str]:
        """Use this since some prompt vars come from history."""
        return [self.input_key]

    @model_validator(mode="after")
    def validate_prompt_input_variables(self) -> Self:
        """Validate that prompt input variables are consistent."""
        memory_keys = self.memory.memory_variables
        input_key = self.input_key
        if input_key in memory_keys:
            msg = (
                f"The input key {input_key} was also found in the memory keys "
                f"({memory_keys}) - please provide keys that don't overlap."
            )
            raise ValueError(msg)
        prompt_variables = self.prompt.input_variables
        expected_keys = [*memory_keys, input_key]
        if set(expected_keys) != set(prompt_variables):
            msg = (
                "Got unexpected prompt input variables. The prompt expects "
                f"{prompt_variables}, but got {memory_keys} as inputs from "
                f"memory, and {input_key} as the normal input key."
            )
            raise ValueError(msg)
        return self
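

# A hypothetical sketch of what the validator above enforces: with the default
# ConversationBufferMemory (memory key "history"), a prompt that omits the
# memory key fails fast at construction time. Names below are illustrative,
# not part of this module.
#
#     from langchain_core.prompts import PromptTemplate
#
#     bad_prompt = PromptTemplate(input_variables=["input"], template="{input}")
#     ConversationChain(llm=OpenAI(), prompt=bad_prompt)
#     # ValueError: Got unexpected prompt input variables. The prompt expects
#     # ['input'], but got ['history'] as inputs from memory, and input as the
#     # normal input key.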