276 lines
8.6 KiB
Python
276 lines
8.6 KiB
Python
from __future__ import annotations
|
|
|
|
import sys
|
|
import typing
|
|
import warnings
|
|
from contextlib import nullcontext
|
|
from dataclasses import is_dataclass
|
|
from functools import lru_cache
|
|
from typing import (
|
|
Any,
|
|
cast,
|
|
overload,
|
|
)
|
|
|
|
from pydantic import (
|
|
BaseModel,
|
|
ConfigDict,
|
|
Field,
|
|
RootModel,
|
|
)
|
|
from pydantic import (
|
|
create_model as _create_model_base,
|
|
)
|
|
from pydantic.fields import FieldInfo
|
|
from pydantic.json_schema import (
|
|
DEFAULT_REF_TEMPLATE,
|
|
GenerateJsonSchema,
|
|
JsonSchemaMode,
|
|
)
|
|
from typing_extensions import TypedDict
|
|
|
|
|
|
@overload
def get_fields(model: type[BaseModel]) -> dict[str, FieldInfo]: ...


@overload
def get_fields(model: BaseModel) -> dict[str, FieldInfo]: ...


def get_fields(
    model: type[BaseModel] | BaseModel,
) -> dict[str, FieldInfo]:
    """Return the field definitions of a Pydantic model.

    Args:
        model: A Pydantic model class or an instance of one.

    Returns:
        The mapping stored in ``model_fields`` or, when that attribute is
        missing, the mapping stored in ``__fields__``.

    Raises:
        TypeError: If neither attribute can be found on ``model``.
    """
    # Pydantic 2.11 deprecated reading ``model_fields`` through an instance, so
    # the class is consulted first. The instance itself is kept as a fallback
    # so that duck-typed objects carrying the attribute only on the instance
    # keep working.
    if isinstance(model, type):
        candidates: tuple[Any, ...] = (model,)
    else:
        candidates = (type(model), model)

    for candidate in candidates:
        if hasattr(candidate, "model_fields"):
            return candidate.model_fields
        if hasattr(candidate, "__fields__"):
            return candidate.__fields__

    msg = f"Expected a Pydantic model. Got {type(model)}"
    raise TypeError(msg)
|
|
|
|
|
|
# Model configuration passed as ``__config__`` whenever this module builds a
# model through pydantic's ``create_model``.
_SchemaConfig = ConfigDict(
    arbitrary_types_allowed=True,
    frozen=True,
    protected_namespaces=(),
)

# Sentinel meaning "no default supplied"; compared by identity so that ``None``
# can still be passed as a real default value.
NO_DEFAULT = object()
|
|
|
|
|
|
def _create_root_model(
    name: str,
    type_: Any,
    module_name: str | None = None,
    default_: object = NO_DEFAULT,
) -> type[BaseModel]:
    """Build a named ``RootModel`` subclass whose ``root`` field has type ``type_``.

    The generated class overrides ``schema`` and ``model_json_schema`` so that
    the schema they return always carries ``name`` as its ``title``.

    Args:
        name: Class name of the generated model; also written into the schema
            title.
        type_: Annotation used for the ``root`` field.
        module_name: Value stored in the class ``__module__``. Falls back to
            ``"langchain_core.runnables.utils"`` when omitted or empty.
        default_: Default value for ``root``. ``NO_DEFAULT`` (the default)
            leaves the field without a default.

    Returns:
        The newly created model class.
    """

    def schema(
        cls: type[BaseModel],
        by_alias: bool = True,  # noqa: FBT001,FBT002
        ref_template: str = DEFAULT_REF_TEMPLATE,
    ) -> dict[str, Any]:
        # ``super`` is anchored on the generated class rather than on ``cls``:
        # with ``super(cls, cls)`` a subclass of the generated model would
        # resolve straight back to this method and recurse without end.
        schema_ = super(custom_root_type, cls).schema(  # type: ignore[misc]
            by_alias=by_alias, ref_template=ref_template
        )
        schema_["title"] = name
        return schema_

    def model_json_schema(
        cls: type[BaseModel],
        by_alias: bool = True,  # noqa: FBT001,FBT002
        ref_template: str = DEFAULT_REF_TEMPLATE,
        schema_generator: type[GenerateJsonSchema] = GenerateJsonSchema,
        mode: JsonSchemaMode = "validation",
    ) -> dict[str, Any]:
        # Same anchoring as in ``schema`` above, for the same reason.
        schema_ = super(custom_root_type, cls).model_json_schema(  # type: ignore[misc]
            by_alias=by_alias,
            ref_template=ref_template,
            schema_generator=schema_generator,
            mode=mode,
        )
        schema_["title"] = name
        return schema_

    base_class_attributes = {
        "__annotations__": {"root": type_},
        "model_config": ConfigDict(arbitrary_types_allowed=True),
        "schema": classmethod(schema),
        "model_json_schema": classmethod(model_json_schema),
        "__module__": module_name or "langchain_core.runnables.utils",
    }

    # Only attach a class-level ``root`` value when a default was supplied;
    # the identity check lets ``None`` act as a genuine default.
    if default_ is not NO_DEFAULT:
        base_class_attributes["root"] = default_

    # Any change made to the warning filters while the class is being built
    # is rolled back when the context exits.
    with warnings.catch_warnings():
        custom_root_type = type(name, (RootModel,), base_class_attributes)
    return cast("type[BaseModel]", custom_root_type)
|
|
|
|
|
|
@lru_cache(maxsize=256)
def _create_root_model_cached(
    model_name: str,
    type_: Any,
    *,
    module_name: str | None = None,
    default_: object = NO_DEFAULT,
) -> type[BaseModel]:
    """Memoized wrapper around ``_create_root_model``.

    Up to 256 distinct argument combinations are cached. Because the cache
    hashes its arguments, calling this with an unhashable value raises
    ``TypeError``.
    """
    return _create_root_model(
        name=model_name,
        type_=type_,
        module_name=module_name,
        default_=default_,
    )
|
|
|
|
|
|
@lru_cache(maxsize=256)
def _create_model_cached(
    model_name: str,
    /,
    **field_definitions: Any,
) -> type[BaseModel]:
    """Memoized model factory built on pydantic's ``create_model``.

    ``model_name`` is positional-only, so a field may itself be called
    ``model_name``. Field definitions are remapped before they reach pydantic
    and the shared ``_SchemaConfig`` is applied to the resulting model.
    """
    safe_definitions = _remap_field_definitions(field_definitions)
    return _create_model_base(
        model_name,
        __config__=_SchemaConfig,
        **safe_definitions,
    )
|
|
|
|
|
|
# Every public (non-underscore) attribute of ``BaseModel`` is treated as
# reserved. The set is derived from ``dir(BaseModel)`` so that it follows
# whichever pydantic version is installed instead of a hand-maintained list.
# For reference, the reserved names are:
# "construct", "copy", "dict", "from_orm", "json", "parse_file", "parse_obj",
# "parse_raw", "schema", "schema_json", "update_forward_refs", "validate",
# "model_computed_fields", "model_config", "model_construct", "model_copy",
# "model_dump", "model_dump_json", "model_extra", "model_fields",
# "model_fields_set", "model_json_schema", "model_parametrized_name",
# "model_post_init", "model_rebuild", "model_validate", "model_validate_json",
# "model_validate_strings"
_RESERVED_NAMES = {attr for attr in dir(BaseModel) if attr[:1] != "_"}
|
|
|
|
|
|
def _remap_field_definitions(field_definitions: dict[str, Any]) -> dict[str, Any]:
    """Rename fields whose names would clash with pydantic internals.

    A field is renamed to ``private_<name>`` when its name starts with ``_``
    or appears in ``_RESERVED_NAMES``. The original name is kept as both the
    alias and the serialization alias. All other fields pass through
    untouched, in their original order.

    Raises:
        NotImplementedError: If a field that needs renaming is already
            declared as a ``FieldInfo`` instance.
    """
    result: dict[str, Any] = {}
    for field_name, definition in field_definitions.items():
        clashes = field_name.startswith("_") or field_name in _RESERVED_NAMES
        if not clashes:
            result[field_name] = definition
            continue

        # A ready-made FieldInfo cannot be rewrapped with an alias here.
        if isinstance(definition, FieldInfo):
            msg = (
                f"Remapping for fields starting with '_' or fields with a name "
                f"matching a reserved name {_RESERVED_NAMES} is not supported if "
                f" the field is a pydantic Field instance. Got {field_name}."
            )
            raise NotImplementedError(msg)

        # Anything else is expected to be an ``(annotation, default)`` pair.
        annotation, default_value = definition
        display_title = field_name.lstrip("_").replace("_", " ").title()
        result[f"private_{field_name}"] = (
            annotation,
            Field(
                default=default_value,
                alias=field_name,
                serialization_alias=field_name,
                title=display_title,
            ),
        )
    return result
|
|
|
|
|
|
def create_model(
    model_name: str,
    *,
    field_definitions: dict[str, Any] | None = None,
    root: Any | None = None,
) -> type[BaseModel]:
    """Create a pydantic model with the given field definitions.

    Attention:
        Please do not use outside of langchain packages. This API
        is subject to change at any time.

    Args:
        model_name: The name of the model.
        field_definitions: The field definitions for the model.
        root: Type for a root model (RootModel), or a ``(type, default)``
            tuple. Must not be combined with ``field_definitions``.

    Returns:
        Type[BaseModel]: The created model.

    Raises:
        NotImplementedError: If both ``root`` and ``field_definitions`` are
            provided, or if a field that has to be renamed is declared as a
            pydantic ``FieldInfo`` instance.
    """
    field_definitions = field_definitions or {}

    if root:
        if field_definitions:
            msg = (
                "When specifying __root__ no other "
                f"fields should be provided. Got {field_definitions}"
            )
            raise NotImplementedError(msg)

        if isinstance(root, tuple):
            kwargs = {"type_": root[0], "default_": root[1]}
        else:
            kwargs = {"type_": root}

        try:
            named_root_model = _create_root_model_cached(model_name, **kwargs)
        except TypeError:
            # something in the arguments into _create_root_model_cached is not
            # hashable, so build the model without going through the cache
            named_root_model = _create_root_model(
                model_name,
                **kwargs,
            )
        return named_root_model

    # No root, just field definitions.
    # Field names beginning with "model" (e.g., model_id or model_name) may
    # make pydantic emit warnings while the class is built; those builds run
    # with warnings silenced.
    capture_warnings = any(name.startswith("model") for name in field_definitions)

    with warnings.catch_warnings() if capture_warnings else nullcontext():
        if capture_warnings:
            warnings.filterwarnings(action="ignore")
        try:
            return _create_model_cached(model_name, **field_definitions)
        except TypeError:
            # something in field definitions is not hashable
            return _create_model_base(
                model_name,
                __config__=_SchemaConfig,
                **_remap_field_definitions(field_definitions),
            )
|
|
|
|
|
|
def is_supported_by_pydantic(type_: Any) -> bool:
    """Tell whether a "complex" type can be handled by pydantic.

    Only container-like types are recognised: dataclasses, ``BaseModel``
    subclasses and classes that list ``TypedDict`` directly among their
    ``__orig_bases__``. Primitive types such as ``int`` or ``str`` yield
    ``False``.
    """
    if is_dataclass(type_):
        return True

    if isinstance(type_, type) and issubclass(type_, BaseModel):
        return True

    # Pydantic supports typing.TypedDict from Python 3.12; on older versions
    # only typing_extensions.TypedDict is supported.
    stdlib_typed_dict_ok = sys.version_info >= (3, 12)

    for base in getattr(type_, "__orig_bases__", ()):
        if base is TypedDict:
            return True
        # ignoring TID251 since it's OK to use typing.TypedDict in this case.
        if base is typing.TypedDict and stdlib_typed_dict_ok:  # noqa: TID251
            return True
    return False
|