Skip to content

core: Cleanup Pydantic models and handle deprecation warnings #30799

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 4 additions & 6 deletions libs/core/langchain_core/_api/deprecation.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
cast,
)

from pydantic.fields import FieldInfo
from pydantic.v1.fields import FieldInfo as FieldInfoV1
from typing_extensions import ParamSpec

from langchain_core._api.internal import is_caller_internal
Expand Down Expand Up @@ -152,10 +154,6 @@ def deprecate(
_package: str = package,
) -> T:
"""Implementation of the decorator returned by `deprecated`."""
from langchain_core.utils.pydantic import ( # type: ignore[attr-defined]
FieldInfoV1,
FieldInfoV2,
)

def emit_warning() -> None:
"""Emit the warning."""
Expand Down Expand Up @@ -249,7 +247,7 @@ def finalize(wrapper: Callable[..., Any], new_doc: str) -> T: # noqa: ARG001
),
)

elif isinstance(obj, FieldInfoV2):
elif isinstance(obj, FieldInfo):
wrapped = None
if not _obj_type:
_obj_type = "attribute"
Expand All @@ -261,7 +259,7 @@ def finalize(wrapper: Callable[..., Any], new_doc: str) -> T: # noqa: ARG001
def finalize(wrapper: Callable[..., Any], new_doc: str) -> T: # noqa: ARG001
return cast(
"T",
FieldInfoV2(
FieldInfo(
default=obj.default,
default_factory=obj.default_factory,
description=new_doc,
Expand Down
4 changes: 2 additions & 2 deletions libs/core/langchain_core/language_models/chat_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -621,7 +621,7 @@ def _get_invocation_params(
stop: Optional[list[str]] = None,
**kwargs: Any,
) -> dict:
params = self.dict()
params = self.model_dump()
params["stop"] = stop
return {**params, **kwargs}

Expand Down Expand Up @@ -1288,7 +1288,7 @@ def _llm_type(self) -> str:
"""Return type of chat model."""

@override
def dict(self, **kwargs: Any) -> dict:
def model_dump(self, **kwargs: Any) -> dict:
"""Return a dictionary of the LLM."""
starter_dict = dict(self._identifying_params)
starter_dict["_type"] = self._llm_type
Expand Down
12 changes: 6 additions & 6 deletions libs/core/langchain_core/language_models/llms.py
Original file line number Diff line number Diff line change
Expand Up @@ -528,7 +528,7 @@ def stream(
else:
prompt = self._convert_input(input).to_string()
config = ensure_config(config)
params = self.dict()
params = self.model_dump()
params["stop"] = stop
params = {**params, **kwargs}
options = {"stop": stop}
Expand Down Expand Up @@ -598,7 +598,7 @@ async def astream(

prompt = self._convert_input(input).to_string()
config = ensure_config(config)
params = self.dict()
params = self.model_dump()
params["stop"] = stop
params = {**params, **kwargs}
options = {"stop": stop}
Expand Down Expand Up @@ -941,7 +941,7 @@ def generate(
] * len(prompts)
run_name_list = [cast("Optional[str]", run_name)] * len(prompts)
run_ids_list = self._get_run_ids_list(run_id, prompts)
params = self.dict()
params = self.model_dump()
params["stop"] = stop
options = {"stop": stop}
(
Expand Down Expand Up @@ -1193,7 +1193,7 @@ async def agenerate(
] * len(prompts)
run_name_list = [cast("Optional[str]", run_name)] * len(prompts)
run_ids_list = self._get_run_ids_list(run_id, prompts)
params = self.dict()
params = self.model_dump()
params["stop"] = stop
options = {"stop": stop}
(
Expand Down Expand Up @@ -1400,7 +1400,7 @@ def _llm_type(self) -> str:
"""Return type of llm."""

@override
def dict(self, **kwargs: Any) -> dict:
def model_dump(self, **kwargs: Any) -> dict:
"""Return a dictionary of the LLM."""
starter_dict = dict(self._identifying_params)
starter_dict["_type"] = self._llm_type
Expand All @@ -1427,7 +1427,7 @@ def save(self, file_path: Union[Path, str]) -> None:
directory_path.mkdir(parents=True, exist_ok=True)

# Fetch dictionary to save
prompt_dict = self.dict()
prompt_dict = self.model_dump()

if save_path.suffix == ".json":
with save_path.open("w") as f:
Expand Down
4 changes: 2 additions & 2 deletions libs/core/langchain_core/output_parsers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,9 +324,9 @@ def _type(self) -> str:
)
raise NotImplementedError(msg)

def dict(self, **kwargs: Any) -> dict:
def model_dump(self, **kwargs: Any) -> dict:
"""Return dictionary representation of output parser."""
output_parser_dict = super().dict(**kwargs)
output_parser_dict = super().model_dump(**kwargs)
with contextlib.suppress(NotImplementedError):
output_parser_dict["_type"] = self._type
return output_parser_dict
12 changes: 3 additions & 9 deletions libs/core/langchain_core/output_parsers/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import jsonpatch # type: ignore[import-untyped]
import pydantic
from pydantic import SkipValidation
from pydantic.v1 import BaseModel

from langchain_core.exceptions import OutputParserException
from langchain_core.output_parsers.format_instructions import JSON_FORMAT_INSTRUCTIONS
Expand All @@ -19,16 +20,9 @@
parse_json_markdown,
parse_partial_json,
)
from langchain_core.utils.pydantic import IS_PYDANTIC_V1

if IS_PYDANTIC_V1:
PydanticBaseModel = pydantic.BaseModel

else:
from pydantic.v1 import BaseModel

# Union type needs to be last assignment to PydanticBaseModel to make mypy happy.
PydanticBaseModel = Union[BaseModel, pydantic.BaseModel] # type: ignore[assignment,misc]
# Union type needs to be last assignment to PydanticBaseModel to make mypy happy.
PydanticBaseModel = Union[BaseModel, pydantic.BaseModel]

TBaseModel = TypeVar("TBaseModel", bound=PydanticBaseModel)

Expand Down
8 changes: 6 additions & 2 deletions libs/core/langchain_core/output_parsers/openai_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import jsonpatch # type: ignore[import-untyped]
from pydantic import BaseModel, model_validator
from pydantic.v1 import BaseModel as BaseModelV1
from typing_extensions import override

from langchain_core.exceptions import OutputParserException
Expand Down Expand Up @@ -274,10 +275,13 @@ def parse_result(self, result: list[Generation], *, partial: bool = False) -> An
pydantic_schema = self.pydantic_schema[fn_name]
else:
pydantic_schema = self.pydantic_schema
if hasattr(pydantic_schema, "model_validate_json"):
if issubclass(pydantic_schema, BaseModel):
pydantic_args = pydantic_schema.model_validate_json(_args)
else:
elif issubclass(pydantic_schema, BaseModelV1):
pydantic_args = pydantic_schema.parse_raw(_args)
else:
msg = f"Unsupported pydantic schema: {pydantic_schema}"
raise ValueError(msg)
return pydantic_args


Expand Down
25 changes: 9 additions & 16 deletions libs/core/langchain_core/output_parsers/pydantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from langchain_core.output_parsers import JsonOutputParser
from langchain_core.outputs import Generation
from langchain_core.utils.pydantic import (
IS_PYDANTIC_V2,
PydanticBaseModel,
TBaseModel,
)
Expand All @@ -24,22 +23,16 @@ class PydanticOutputParser(JsonOutputParser, Generic[TBaseModel]):
"""The pydantic model to parse."""

def _parse_obj(self, obj: dict) -> TBaseModel:
if IS_PYDANTIC_V2:
try:
if issubclass(self.pydantic_object, pydantic.BaseModel):
return self.pydantic_object.model_validate(obj)
if issubclass(self.pydantic_object, pydantic.v1.BaseModel):
return self.pydantic_object.parse_obj(obj)
msg = f"Unsupported model version for PydanticOutputParser: \
{self.pydantic_object.__class__}"
raise OutputParserException(msg)
except (pydantic.ValidationError, pydantic.v1.ValidationError) as e:
raise self._parser_exception(e, obj) from e
else: # pydantic v1
try:
try:
if issubclass(self.pydantic_object, pydantic.BaseModel):
return self.pydantic_object.model_validate(obj)
if issubclass(self.pydantic_object, pydantic.v1.BaseModel):
return self.pydantic_object.parse_obj(obj)
except pydantic.ValidationError as e:
raise self._parser_exception(e, obj) from e
msg = f"Unsupported model version for PydanticOutputParser: \
{self.pydantic_object.__class__}"
raise OutputParserException(msg)
except (pydantic.ValidationError, pydantic.v1.ValidationError) as e:
raise self._parser_exception(e, obj) from e

def _parser_exception(
self, e: Exception, json_object: dict
Expand Down
4 changes: 2 additions & 2 deletions libs/core/langchain_core/output_parsers/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def _transform(self, input: Iterator[Union[str, BaseMessage]]) -> Iterator[Any]:
chunk_gen = ChatGenerationChunk(message=chunk)
elif isinstance(chunk, BaseMessage):
chunk_gen = ChatGenerationChunk(
message=BaseMessageChunk(**chunk.dict())
message=BaseMessageChunk(**chunk.model_dump())
)
else:
chunk_gen = GenerationChunk(text=chunk)
Expand All @@ -151,7 +151,7 @@ async def _atransform(
chunk_gen = ChatGenerationChunk(message=chunk)
elif isinstance(chunk, BaseMessage):
chunk_gen = ChatGenerationChunk(
message=BaseMessageChunk(**chunk.dict())
message=BaseMessageChunk(**chunk.model_dump())
)
else:
chunk_gen = GenerationChunk(text=chunk)
Expand Down
4 changes: 2 additions & 2 deletions libs/core/langchain_core/prompts/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,7 +331,7 @@ def _prompt_type(self) -> str:
"""Return the prompt type key."""
raise NotImplementedError

def dict(self, **kwargs: Any) -> dict:
def model_dump(self, **kwargs: Any) -> dict:
"""Return dictionary representation of prompt.

Args:
Expand Down Expand Up @@ -369,7 +369,7 @@ def save(self, file_path: Union[Path, str]) -> None:
raise ValueError(msg)

# Fetch dictionary to save
prompt_dict = self.dict()
prompt_dict = self.model_dump()
if "_type" not in prompt_dict:
msg = f"Prompt {self} does not support saving."
raise NotImplementedError(msg)
Expand Down
19 changes: 2 additions & 17 deletions libs/core/langchain_core/pydantic_v1/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,24 +2,9 @@

from importlib import metadata

from langchain_core._api.deprecation import warn_deprecated

# Create namespaces for pydantic v1 and v2.
# This code must stay at the top of the file before other modules may
# attempt to import pydantic since it adds pydantic_v1 and pydantic_v2 to sys.modules.
#
# This hack is done for the following reasons:
# * Langchain will attempt to remain compatible with both pydantic v1 and v2 since
# both dependencies and dependents may be stuck on either version of v1 or v2.
# * Creating namespaces for pydantic v1 and v2 should allow us to write code that
# unambiguously uses either v1 or v2 API.
# * This change is easier to roll out and roll back.

try:
from pydantic.v1 import * # noqa: F403
except ImportError:
from pydantic import * # type: ignore[assignment,no-redef] # noqa: F403
from pydantic.v1 import * # noqa: F403

from langchain_core._api.deprecation import warn_deprecated

try:
_PYDANTIC_MAJOR_VERSION: int = int(metadata.version("pydantic").split(".")[0])
Expand Down
7 changes: 2 additions & 5 deletions libs/core/langchain_core/pydantic_v1/dataclasses.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,8 @@
"""Pydantic v1 compatibility shim."""

from langchain_core._api import warn_deprecated
from pydantic.v1.dataclasses import * # noqa: F403

try:
from pydantic.v1.dataclasses import * # noqa: F403
except ImportError:
from pydantic.dataclasses import * # type: ignore[no-redef] # noqa: F403
from langchain_core._api import warn_deprecated

warn_deprecated(
"0.3.0",
Expand Down
7 changes: 2 additions & 5 deletions libs/core/langchain_core/pydantic_v1/main.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,8 @@
"""Pydantic v1 compatibility shim."""

from langchain_core._api import warn_deprecated
from pydantic.v1.main import * # noqa: F403

try:
from pydantic.v1.main import * # noqa: F403
except ImportError:
from pydantic.main import * # type: ignore[assignment,no-redef] # noqa: F403
from langchain_core._api import warn_deprecated

warn_deprecated(
"0.3.0",
Expand Down
7 changes: 5 additions & 2 deletions libs/core/langchain_core/tools/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -543,10 +543,13 @@ def _parse_input(
)
raise ValueError(msg)
key_ = next(iter(get_fields(input_args).keys()))
if hasattr(input_args, "model_validate"):
if issubclass(input_args, BaseModel):
input_args.model_validate({key_: tool_input})
else:
elif issubclass(input_args, BaseModelV1):
input_args.parse_obj({key_: tool_input})
else:
msg = f"args_schema must be a Pydantic BaseModel, got {input_args}"
raise TypeError(msg)
return tool_input
if input_args is not None:
if isinstance(input_args, dict):
Expand Down
8 changes: 4 additions & 4 deletions libs/core/langchain_core/tracers/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@

from __future__ import annotations

import datetime
import warnings
from datetime import datetime, timezone
from typing import Any, Optional
from uuid import UUID

Expand Down Expand Up @@ -32,7 +32,7 @@ def RunTypeEnum() -> type[RunTypeEnumDep]: # noqa: N802
class TracerSessionV1Base(BaseModelV1):
"""Base class for TracerSessionV1."""

start_time: datetime.datetime = FieldV1(default_factory=datetime.datetime.utcnow)
start_time: datetime = FieldV1(default_factory=lambda: datetime.now(timezone.utc))
name: Optional[str] = None
extra: Optional[dict[str, Any]] = None

Expand Down Expand Up @@ -69,8 +69,8 @@ class BaseRun(BaseModelV1):

uuid: str
parent_uuid: Optional[str] = None
start_time: datetime.datetime = FieldV1(default_factory=datetime.datetime.utcnow)
end_time: datetime.datetime = FieldV1(default_factory=datetime.datetime.utcnow)
start_time: datetime = FieldV1(default_factory=lambda: datetime.now(timezone.utc))
end_time: datetime = FieldV1(default_factory=lambda: datetime.now(timezone.utc))
extra: Optional[dict[str, Any]] = None
execution_order: int
child_execution_order: int
Expand Down
Loading
Loading