227 lines
8.2 KiB
Python
227 lines
8.2 KiB
Python
"""Load summarizing chains."""
|
|
|
|
from collections.abc import Mapping
|
|
from typing import Any, Protocol
|
|
|
|
from langchain_core.callbacks import Callbacks
|
|
from langchain_core.language_models import BaseLanguageModel
|
|
from langchain_core.prompts import BasePromptTemplate
|
|
|
|
from langchain_classic.chains.combine_documents.base import BaseCombineDocumentsChain
|
|
from langchain_classic.chains.combine_documents.map_reduce import (
|
|
MapReduceDocumentsChain,
|
|
)
|
|
from langchain_classic.chains.combine_documents.reduce import ReduceDocumentsChain
|
|
from langchain_classic.chains.combine_documents.refine import RefineDocumentsChain
|
|
from langchain_classic.chains.combine_documents.stuff import StuffDocumentsChain
|
|
from langchain_classic.chains.llm import LLMChain
|
|
from langchain_classic.chains.summarize import (
|
|
map_reduce_prompt,
|
|
refine_prompts,
|
|
stuff_prompt,
|
|
)
|
|
|
|
|
|
class LoadingCallable(Protocol):
    """Structural type for the chain loader functions used in this module.

    Any callable that accepts a language model plus arbitrary keyword
    arguments and returns a combine-documents chain satisfies this protocol.
    """

    def __call__(
        self,
        llm: BaseLanguageModel,
        **kwargs: Any,
    ) -> BaseCombineDocumentsChain:
        """Build and return a combine-documents chain for `llm`.

        Args:
            llm: Language model the resulting chain should use.
            **kwargs: Loader-specific keyword arguments.

        Returns:
            The constructed combine-documents chain.
        """
        ...
|
|
|
|
|
|
def _load_stuff_chain(
    llm: BaseLanguageModel,
    *,
    prompt: BasePromptTemplate = stuff_prompt.PROMPT,
    document_variable_name: str = "text",
    verbose: bool | None = None,
    **kwargs: Any,
) -> StuffDocumentsChain:
    """Load a StuffDocumentsChain for summarization.

    Args:
        llm: Language Model to use in the chain.
        prompt: Prompt template given to the underlying `LLMChain`.
        document_variable_name: Variable name in the prompt template where the
            document text will be inserted.
        verbose: Verbosity flag forwarded to both the inner `LLMChain` and the
            returned `StuffDocumentsChain`.
        **kwargs: Additional keyword arguments passed to the StuffDocumentsChain.

    Returns:
        A StuffDocumentsChain wrapping an `LLMChain` built from `llm` and
        `prompt`.
    """
    # The docstring must be the first statement of the body; otherwise it is
    # just a discarded string expression and `__doc__` stays `None`.
    llm_chain = LLMChain(llm=llm, prompt=prompt, verbose=verbose)
    return StuffDocumentsChain(
        llm_chain=llm_chain,
        document_variable_name=document_variable_name,
        verbose=verbose,
        **kwargs,
    )
|
|
|
|
|
|
def _load_map_reduce_chain(
    llm: BaseLanguageModel,
    *,
    map_prompt: BasePromptTemplate = map_reduce_prompt.PROMPT,
    combine_prompt: BasePromptTemplate = map_reduce_prompt.PROMPT,
    combine_document_variable_name: str = "text",
    map_reduce_document_variable_name: str = "text",
    collapse_prompt: BasePromptTemplate | None = None,
    reduce_llm: BaseLanguageModel | None = None,
    collapse_llm: BaseLanguageModel | None = None,
    verbose: bool | None = None,
    token_max: int = 3000,
    callbacks: Callbacks = None,
    collapse_max_retries: int | None = None,
    **kwargs: Any,
) -> MapReduceDocumentsChain:
    """Load a MapReduceDocumentsChain for summarization.

    This chain first applies a "map" step to summarize each document,
    then applies a "reduce" step to combine the summaries into a
    final result. Optionally, a "collapse" step can be used to handle
    long intermediate results.

    Args:
        llm: Language Model to use for the map step, and for the reduce and
            collapse steps unless `reduce_llm` / `collapse_llm` override it.
        map_prompt: Prompt used to summarize each document in the map step.
        combine_prompt: Prompt used to combine summaries in the reduce step.
        combine_document_variable_name: Variable name in the `combine_prompt`
            (and in the `collapse_prompt`, when given) where the mapped
            summaries are inserted.
        map_reduce_document_variable_name: Variable name in the `map_prompt`
            where document text is inserted.
        collapse_prompt: Optional prompt used to collapse intermediate summaries
            if they exceed the token limit (`token_max`). When `None`, no
            collapse chain is configured.
        reduce_llm: Optional separate LLM for the reduce step. Defaults to
            `llm`, i.e. the same model as the map step.
        collapse_llm: Optional separate LLM for the collapse step. Defaults to
            `llm`, i.e. the same model as the map step. May only be given
            together with `collapse_prompt`.
        verbose: Verbosity flag forwarded to the constituent chains.
        token_max: Token threshold that triggers the collapse step during reduction.
        callbacks: Optional callbacks forwarded to the constituent chains.
        collapse_max_retries: Maximum retries for the collapse step if it fails.
        **kwargs: Additional keyword arguments passed to the MapReduceDocumentsChain.

    Returns:
        A MapReduceDocumentsChain that maps each document to a summary,
        then reduces all summaries into a single cohesive result.

    Raises:
        ValueError: If `collapse_llm` is provided without `collapse_prompt`.
    """
    # The docstring must be the first statement of the body; otherwise it is
    # just a discarded string expression and `__doc__` stays `None`.
    map_chain = LLMChain(
        llm=llm,
        prompt=map_prompt,
        verbose=verbose,
        callbacks=callbacks,
    )
    _reduce_llm = reduce_llm or llm
    reduce_chain = LLMChain(
        llm=_reduce_llm,
        prompt=combine_prompt,
        verbose=verbose,
        callbacks=callbacks,
    )
    combine_documents_chain = StuffDocumentsChain(
        llm_chain=reduce_chain,
        document_variable_name=combine_document_variable_name,
        verbose=verbose,
        callbacks=callbacks,
    )
    if collapse_prompt is None:
        collapse_chain = None
        # A collapse LLM without a collapse prompt can never be used, so
        # reject the combination instead of silently ignoring the argument.
        if collapse_llm is not None:
            msg = (
                "collapse_llm provided, but collapse_prompt was not: please "
                "provide one or stop providing collapse_llm."
            )
            raise ValueError(msg)
    else:
        _collapse_llm = collapse_llm or llm
        # NOTE: only the inner LLMChain receives `verbose`/`callbacks` here;
        # the wrapping StuffDocumentsChain is left at its defaults.
        collapse_chain = StuffDocumentsChain(
            llm_chain=LLMChain(
                llm=_collapse_llm,
                prompt=collapse_prompt,
                verbose=verbose,
                callbacks=callbacks,
            ),
            document_variable_name=combine_document_variable_name,
        )
    reduce_documents_chain = ReduceDocumentsChain(
        combine_documents_chain=combine_documents_chain,
        collapse_documents_chain=collapse_chain,
        token_max=token_max,
        verbose=verbose,
        callbacks=callbacks,
        collapse_max_retries=collapse_max_retries,
    )
    return MapReduceDocumentsChain(
        llm_chain=map_chain,
        reduce_documents_chain=reduce_documents_chain,
        document_variable_name=map_reduce_document_variable_name,
        verbose=verbose,
        callbacks=callbacks,
        **kwargs,
    )
|
|
|
|
|
|
def _load_refine_chain(
    llm: BaseLanguageModel,
    *,
    question_prompt: BasePromptTemplate = refine_prompts.PROMPT,
    refine_prompt: BasePromptTemplate = refine_prompts.REFINE_PROMPT,
    document_variable_name: str = "text",
    initial_response_name: str = "existing_answer",
    refine_llm: BaseLanguageModel | None = None,
    verbose: bool | None = None,
    **kwargs: Any,
) -> RefineDocumentsChain:
    """Load a RefineDocumentsChain for summarization.

    Args:
        llm: Language Model used for the initial chain, and for the refine
            chain unless `refine_llm` is given.
        question_prompt: Prompt given to the initial `LLMChain`.
        refine_prompt: Prompt given to the refine `LLMChain`.
        document_variable_name: Forwarded to the RefineDocumentsChain as
            `document_variable_name`.
        initial_response_name: Forwarded to the RefineDocumentsChain as
            `initial_response_name`.
        refine_llm: Optional separate LLM for the refine chain. Falls back
            to `llm` when not provided.
        verbose: Verbosity flag forwarded to both LLM chains and to the
            returned RefineDocumentsChain.
        **kwargs: Additional keyword arguments passed to the
            RefineDocumentsChain.

    Returns:
        A RefineDocumentsChain built from the initial and refine LLM chains.
    """
    # Keyword arguments are evaluated left to right, so the initial chain is
    # still constructed before the refine chain.
    return RefineDocumentsChain(
        initial_llm_chain=LLMChain(
            llm=llm,
            prompt=question_prompt,
            verbose=verbose,
        ),
        refine_llm_chain=LLMChain(
            llm=refine_llm or llm,
            prompt=refine_prompt,
            verbose=verbose,
        ),
        document_variable_name=document_variable_name,
        initial_response_name=initial_response_name,
        verbose=verbose,
        **kwargs,
    )
|
|
|
|
|
|
def load_summarize_chain(
    llm: BaseLanguageModel,
    chain_type: str = "stuff",
    verbose: bool | None = None,  # noqa: FBT001
    **kwargs: Any,
) -> BaseCombineDocumentsChain:
    """Load a summarizing chain of the requested type.

    Args:
        llm: Language Model to use in the chain.
        chain_type: Which document-combining strategy to build. One of
            "stuff", "map_reduce", or "refine".
        verbose: Whether chains should be run in verbose mode or not. The
            value is handed to the selected loader, which applies it to the
            chains it builds.
        **kwargs: Additional keyword arguments forwarded to the selected
            loader.

    Returns:
        A chain to use for summarizing.

    Raises:
        ValueError: If `chain_type` is not one of the supported types.
    """
    loaders: Mapping[str, LoadingCallable] = {
        "stuff": _load_stuff_chain,
        "map_reduce": _load_map_reduce_chain,
        "refine": _load_refine_chain,
    }
    # Single lookup; a missing key is reported as an unsupported chain type.
    loader = loaders.get(chain_type)
    if loader is None:
        msg = (
            f"Got unsupported chain type: {chain_type}. "
            f"Should be one of {loaders.keys()}"
        )
        raise ValueError(msg)
    return loader(llm, verbose=verbose, **kwargs)
|