"""Character text splitters.""" from __future__ import annotations import re from typing import Any, Literal from langchain_text_splitters.base import Language, TextSplitter class CharacterTextSplitter(TextSplitter): """Splitting text that looks at characters.""" def __init__( self, separator: str = "\n\n", is_separator_regex: bool = False, # noqa: FBT001,FBT002 **kwargs: Any, ) -> None: """Create a new TextSplitter.""" super().__init__(**kwargs) self._separator = separator self._is_separator_regex = is_separator_regex def split_text(self, text: str) -> list[str]: """Split into chunks without re-inserting lookaround separators.""" # 1. Determine split pattern: raw regex or escaped literal sep_pattern = ( self._separator if self._is_separator_regex else re.escape(self._separator) ) # 2. Initial split (keep separator if requested) splits = _split_text_with_regex( text, sep_pattern, keep_separator=self._keep_separator ) # 3. Detect zero-width lookaround so we never re-insert it lookaround_prefixes = ("(?=", "(? don't re-insert # - else -> re-insert literal separator merge_sep = "" if not (self._keep_separator or is_lookaround): merge_sep = self._separator # 5. Merge adjacent splits and return return self._merge_splits(splits, merge_sep) def _split_text_with_regex( text: str, separator: str, *, keep_separator: bool | Literal["start", "end"] ) -> list[str]: # Now that we have the separator, split the text if separator: if keep_separator: # The parentheses in the pattern keep the delimiters in the result. splits_ = re.split(f"({separator})", text) splits = ( ([splits_[i] + splits_[i + 1] for i in range(0, len(splits_) - 1, 2)]) if keep_separator == "end" else ([splits_[i] + splits_[i + 1] for i in range(1, len(splits_), 2)]) ) if len(splits_) % 2 == 0: splits += splits_[-1:] splits = ( ([*splits, splits_[-1]]) if keep_separator == "end" else ([splits_[0], *splits]) ) else: splits = re.split(separator, text) else: splits = list(text) return [s for s in splits if s] class RecursiveCharacterTextSplitter(TextSplitter): """Splitting text by recursively look at characters. Recursively tries to split by different characters to find one that works. """ def __init__( self, separators: list[str] | None = None, keep_separator: bool | Literal["start", "end"] = True, # noqa: FBT001,FBT002 is_separator_regex: bool = False, # noqa: FBT001,FBT002 **kwargs: Any, ) -> None: """Create a new TextSplitter.""" super().__init__(keep_separator=keep_separator, **kwargs) self._separators = separators or ["\n\n", "\n", " ", ""] self._is_separator_regex = is_separator_regex def _split_text(self, text: str, separators: list[str]) -> list[str]: """Split incoming text and return chunks.""" final_chunks = [] # Get appropriate separator to use separator = separators[-1] new_separators = [] for i, _s in enumerate(separators): separator_ = _s if self._is_separator_regex else re.escape(_s) if not _s: separator = _s break if re.search(separator_, text): separator = _s new_separators = separators[i + 1 :] break separator_ = separator if self._is_separator_regex else re.escape(separator) splits = _split_text_with_regex( text, separator_, keep_separator=self._keep_separator ) # Now go merging things, recursively splitting longer texts. good_splits = [] separator_ = "" if self._keep_separator else separator for s in splits: if self._length_function(s) < self._chunk_size: good_splits.append(s) else: if good_splits: merged_text = self._merge_splits(good_splits, separator_) final_chunks.extend(merged_text) good_splits = [] if not new_separators: final_chunks.append(s) else: other_info = self._split_text(s, new_separators) final_chunks.extend(other_info) if good_splits: merged_text = self._merge_splits(good_splits, separator_) final_chunks.extend(merged_text) return final_chunks def split_text(self, text: str) -> list[str]: """Split the input text into smaller chunks based on predefined separators. Args: text: The input text to be split. Returns: A list of text chunks obtained after splitting. """ return self._split_text(text, self._separators) @classmethod def from_language( cls, language: Language, **kwargs: Any ) -> RecursiveCharacterTextSplitter: """Return an instance of this class based on a specific language. This method initializes the text splitter with language-specific separators. Args: language: The language to configure the text splitter for. **kwargs: Additional keyword arguments to customize the splitter. Returns: An instance of the text splitter configured for the specified language. """ separators = cls.get_separators_for_language(language) return cls(separators=separators, is_separator_regex=True, **kwargs) @staticmethod def get_separators_for_language(language: Language) -> list[str]: """Retrieve a list of separators specific to the given language. Args: language: The language for which to get the separators. Returns: A list of separators appropriate for the specified language. """ if language in {Language.C, Language.CPP}: return [ # Split along class definitions "\nclass ", # Split along function definitions "\nvoid ", "\nint ", "\nfloat ", "\ndouble ", # Split along control flow statements "\nif ", "\nfor ", "\nwhile ", "\nswitch ", "\ncase ", # Split by the normal type of lines "\n\n", "\n", " ", "", ] if language == Language.GO: return [ # Split along function definitions "\nfunc ", "\nvar ", "\nconst ", "\ntype ", # Split along control flow statements "\nif ", "\nfor ", "\nswitch ", "\ncase ", # Split by the normal type of lines "\n\n", "\n", " ", "", ] if language == Language.JAVA: return [ # Split along class definitions "\nclass ", # Split along method definitions "\npublic ", "\nprotected ", "\nprivate ", "\nstatic ", # Split along control flow statements "\nif ", "\nfor ", "\nwhile ", "\nswitch ", "\ncase ", # Split by the normal type of lines "\n\n", "\n", " ", "", ] if language == Language.KOTLIN: return [ # Split along class definitions "\nclass ", # Split along method definitions "\npublic ", "\nprotected ", "\nprivate ", "\ninternal ", "\ncompanion ", "\nfun ", "\nval ", "\nvar ", # Split along control flow statements "\nif ", "\nfor ", "\nwhile ", "\nwhen ", "\ncase ", "\nelse ", # Split by the normal type of lines "\n\n", "\n", " ", "", ] if language == Language.JS: return [ # Split along function definitions "\nfunction ", "\nconst ", "\nlet ", "\nvar ", "\nclass ", # Split along control flow statements "\nif ", "\nfor ", "\nwhile ", "\nswitch ", "\ncase ", "\ndefault ", # Split by the normal type of lines "\n\n", "\n", " ", "", ] if language == Language.TS: return [ "\nenum ", "\ninterface ", "\nnamespace ", "\ntype ", # Split along class definitions "\nclass ", # Split along function definitions "\nfunction ", "\nconst ", "\nlet ", "\nvar ", # Split along control flow statements "\nif ", "\nfor ", "\nwhile ", "\nswitch ", "\ncase ", "\ndefault ", # Split by the normal type of lines "\n\n", "\n", " ", "", ] if language == Language.PHP: return [ # Split along function definitions "\nfunction ", # Split along class definitions "\nclass ", # Split along control flow statements "\nif ", "\nforeach ", "\nwhile ", "\ndo ", "\nswitch ", "\ncase ", # Split by the normal type of lines "\n\n", "\n", " ", "", ] if language == Language.PROTO: return [ # Split along message definitions "\nmessage ", # Split along service definitions "\nservice ", # Split along enum definitions "\nenum ", # Split along option definitions "\noption ", # Split along import statements "\nimport ", # Split along syntax declarations "\nsyntax ", # Split by the normal type of lines "\n\n", "\n", " ", "", ] if language == Language.PYTHON: return [ # First, try to split along class definitions "\nclass ", "\ndef ", "\n\tdef ", # Now split by the normal type of lines "\n\n", "\n", " ", "", ] if language == Language.R: return [ # Split along function definitions "\nfunction ", # Split along S4 class and method definitions "\nsetClass\\(", "\nsetMethod\\(", "\nsetGeneric\\(", # Split along control flow statements "\nif ", "\nelse ", "\nfor ", "\nwhile ", "\nrepeat ", # Split along package loading "\nlibrary\\(", "\nrequire\\(", # Split by the normal type of lines "\n\n", "\n", " ", "", ] if language == Language.RST: return [ # Split along section titles "\n=+\n", "\n-+\n", "\n\\*+\n", # Split along directive markers "\n\n.. *\n\n", # Split by the normal type of lines "\n\n", "\n", " ", "", ] if language == Language.RUBY: return [ # Split along method definitions "\ndef ", "\nclass ", # Split along control flow statements "\nif ", "\nunless ", "\nwhile ", "\nfor ", "\ndo ", "\nbegin ", "\nrescue ", # Split by the normal type of lines "\n\n", "\n", " ", "", ] if language == Language.ELIXIR: return [ # Split along method function and module definition "\ndef ", "\ndefp ", "\ndefmodule ", "\ndefprotocol ", "\ndefmacro ", "\ndefmacrop ", # Split along control flow statements "\nif ", "\nunless ", "\nwhile ", "\ncase ", "\ncond ", "\nwith ", "\nfor ", "\ndo ", # Split by the normal type of lines "\n\n", "\n", " ", "", ] if language == Language.RUST: return [ # Split along function definitions "\nfn ", "\nconst ", "\nlet ", # Split along control flow statements "\nif ", "\nwhile ", "\nfor ", "\nloop ", "\nmatch ", "\nconst ", # Split by the normal type of lines "\n\n", "\n", " ", "", ] if language == Language.SCALA: return [ # Split along class definitions "\nclass ", "\nobject ", # Split along method definitions "\ndef ", "\nval ", "\nvar ", # Split along control flow statements "\nif ", "\nfor ", "\nwhile ", "\nmatch ", "\ncase ", # Split by the normal type of lines "\n\n", "\n", " ", "", ] if language == Language.SWIFT: return [ # Split along function definitions "\nfunc ", # Split along class definitions "\nclass ", "\nstruct ", "\nenum ", # Split along control flow statements "\nif ", "\nfor ", "\nwhile ", "\ndo ", "\nswitch ", "\ncase ", # Split by the normal type of lines "\n\n", "\n", " ", "", ] if language == Language.MARKDOWN: return [ # First, try to split along Markdown headings (starting with level 2) "\n#{1,6} ", # Note the alternative syntax for headings (below) is not handled here # Heading level 2 # --------------- # End of code block "```\n", # Horizontal lines "\n\\*\\*\\*+\n", "\n---+\n", "\n___+\n", # Note that this splitter doesn't handle horizontal lines defined # by *three or more* of ***, ---, or ___, but this is not handled "\n\n", "\n", " ", "", ] if language == Language.LATEX: return [ # First, try to split along Latex sections "\n\\\\chapter{", "\n\\\\section{", "\n\\\\subsection{", "\n\\\\subsubsection{", # Now split by environments "\n\\\\begin{enumerate}", "\n\\\\begin{itemize}", "\n\\\\begin{description}", "\n\\\\begin{list}", "\n\\\\begin{quote}", "\n\\\\begin{quotation}", "\n\\\\begin{verse}", "\n\\\\begin{verbatim}", # Now split by math environments "\n\\\\begin{align}", "$$", "$", # Now split by the normal type of lines " ", "", ] if language == Language.HTML: return [ # First, try to split along HTML tags "