Skip to content

Commit 7529366

Browse files
committed
Cleanup Pydantic models and handle deprecation warnings
1 parent 42944f3 commit 7529366

File tree

17 files changed

+125
-358
lines changed

17 files changed

+125
-358
lines changed

libs/core/langchain_core/_api/deprecation.py

+4-6
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@
2323
cast,
2424
)
2525

26+
from pydantic.fields import FieldInfo
27+
from pydantic.v1.fields import FieldInfo as FieldInfoV1
2628
from typing_extensions import ParamSpec
2729

2830
from langchain_core._api.internal import is_caller_internal
@@ -152,10 +154,6 @@ def deprecate(
152154
_package: str = package,
153155
) -> T:
154156
"""Implementation of the decorator returned by `deprecated`."""
155-
from langchain_core.utils.pydantic import ( # type: ignore[attr-defined]
156-
FieldInfoV1,
157-
FieldInfoV2,
158-
)
159157

160158
def emit_warning() -> None:
161159
"""Emit the warning."""
@@ -249,7 +247,7 @@ def finalize(wrapper: Callable[..., Any], new_doc: str) -> T: # noqa: ARG001
249247
),
250248
)
251249

252-
elif isinstance(obj, FieldInfoV2):
250+
elif isinstance(obj, FieldInfo):
253251
wrapped = None
254252
if not _obj_type:
255253
_obj_type = "attribute"
@@ -261,7 +259,7 @@ def finalize(wrapper: Callable[..., Any], new_doc: str) -> T: # noqa: ARG001
261259
def finalize(wrapper: Callable[..., Any], new_doc: str) -> T: # noqa: ARG001
262260
return cast(
263261
"T",
264-
FieldInfoV2(
262+
FieldInfo(
265263
default=obj.default,
266264
default_factory=obj.default_factory,
267265
description=new_doc,

libs/core/langchain_core/language_models/chat_models.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -584,7 +584,7 @@ def _get_invocation_params(
584584
stop: Optional[list[str]] = None,
585585
**kwargs: Any,
586586
) -> dict:
587-
params = self.dict()
587+
params = self.model_dump()
588588
params["stop"] = stop
589589
return {**params, **kwargs}
590590

@@ -1245,7 +1245,7 @@ def _llm_type(self) -> str:
12451245
"""Return type of chat model."""
12461246

12471247
@override
1248-
def dict(self, **kwargs: Any) -> dict:
1248+
def model_dump(self, **kwargs: Any) -> dict:
12491249
"""Return a dictionary of the LLM."""
12501250
starter_dict = dict(self._identifying_params)
12511251
starter_dict["_type"] = self._llm_type

libs/core/langchain_core/language_models/llms.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -528,7 +528,7 @@ def stream(
528528
else:
529529
prompt = self._convert_input(input).to_string()
530530
config = ensure_config(config)
531-
params = self.dict()
531+
params = self.model_dump()
532532
params["stop"] = stop
533533
params = {**params, **kwargs}
534534
options = {"stop": stop}
@@ -598,7 +598,7 @@ async def astream(
598598

599599
prompt = self._convert_input(input).to_string()
600600
config = ensure_config(config)
601-
params = self.dict()
601+
params = self.model_dump()
602602
params["stop"] = stop
603603
params = {**params, **kwargs}
604604
options = {"stop": stop}
@@ -941,7 +941,7 @@ def generate(
941941
] * len(prompts)
942942
run_name_list = [cast("Optional[str]", run_name)] * len(prompts)
943943
run_ids_list = self._get_run_ids_list(run_id, prompts)
944-
params = self.dict()
944+
params = self.model_dump()
945945
params["stop"] = stop
946946
options = {"stop": stop}
947947
(
@@ -1193,7 +1193,7 @@ async def agenerate(
11931193
] * len(prompts)
11941194
run_name_list = [cast("Optional[str]", run_name)] * len(prompts)
11951195
run_ids_list = self._get_run_ids_list(run_id, prompts)
1196-
params = self.dict()
1196+
params = self.model_dump()
11971197
params["stop"] = stop
11981198
options = {"stop": stop}
11991199
(
@@ -1400,7 +1400,7 @@ def _llm_type(self) -> str:
14001400
"""Return type of llm."""
14011401

14021402
@override
1403-
def dict(self, **kwargs: Any) -> dict:
1403+
def model_dump(self, **kwargs: Any) -> dict:
14041404
"""Return a dictionary of the LLM."""
14051405
starter_dict = dict(self._identifying_params)
14061406
starter_dict["_type"] = self._llm_type
@@ -1427,7 +1427,7 @@ def save(self, file_path: Union[Path, str]) -> None:
14271427
directory_path.mkdir(parents=True, exist_ok=True)
14281428

14291429
# Fetch dictionary to save
1430-
prompt_dict = self.dict()
1430+
prompt_dict = self.model_dump()
14311431

14321432
if save_path.suffix == ".json":
14331433
with save_path.open("w") as f:

libs/core/langchain_core/output_parsers/base.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -324,9 +324,9 @@ def _type(self) -> str:
324324
)
325325
raise NotImplementedError(msg)
326326

327-
def dict(self, **kwargs: Any) -> dict:
327+
def model_dump(self, **kwargs: Any) -> dict:
328328
"""Return dictionary representation of output parser."""
329-
output_parser_dict = super().dict(**kwargs)
329+
output_parser_dict = super().model_dump(**kwargs)
330330
with contextlib.suppress(NotImplementedError):
331331
output_parser_dict["_type"] = self._type
332332
return output_parser_dict

libs/core/langchain_core/output_parsers/json.py

+3-9
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import jsonpatch # type: ignore[import-untyped]
1010
import pydantic
1111
from pydantic import SkipValidation
12+
from pydantic.v1 import BaseModel
1213

1314
from langchain_core.exceptions import OutputParserException
1415
from langchain_core.output_parsers.format_instructions import JSON_FORMAT_INSTRUCTIONS
@@ -19,16 +20,9 @@
1920
parse_json_markdown,
2021
parse_partial_json,
2122
)
22-
from langchain_core.utils.pydantic import IS_PYDANTIC_V1
2323

24-
if IS_PYDANTIC_V1:
25-
PydanticBaseModel = pydantic.BaseModel
26-
27-
else:
28-
from pydantic.v1 import BaseModel
29-
30-
# Union type needs to be last assignment to PydanticBaseModel to make mypy happy.
31-
PydanticBaseModel = Union[BaseModel, pydantic.BaseModel] # type: ignore[assignment,misc]
24+
# Union type needs to be last assignment to PydanticBaseModel to make mypy happy.
25+
PydanticBaseModel = Union[BaseModel, pydantic.BaseModel]
3226

3327
TBaseModel = TypeVar("TBaseModel", bound=PydanticBaseModel)
3428

libs/core/langchain_core/output_parsers/openai_functions.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
import jsonpatch # type: ignore[import-untyped]
99
from pydantic import BaseModel, model_validator
10+
from pydantic.v1 import BaseModel as BaseModelV1
1011
from typing_extensions import override
1112

1213
from langchain_core.exceptions import OutputParserException
@@ -274,10 +275,13 @@ def parse_result(self, result: list[Generation], *, partial: bool = False) -> An
274275
pydantic_schema = self.pydantic_schema[fn_name]
275276
else:
276277
pydantic_schema = self.pydantic_schema
277-
if hasattr(pydantic_schema, "model_validate_json"):
278+
if isinstance(pydantic_schema, BaseModel):
278279
pydantic_args = pydantic_schema.model_validate_json(_args)
279-
else:
280+
elif isinstance(pydantic_schema, BaseModelV1):
280281
pydantic_args = pydantic_schema.parse_raw(_args)
282+
else:
283+
msg = f"Unsupported pydantic schema: {pydantic_schema}"
284+
raise ValueError(msg)
281285
return pydantic_args
282286

283287

libs/core/langchain_core/output_parsers/pydantic.py

+9-16
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
from langchain_core.output_parsers import JsonOutputParser
1212
from langchain_core.outputs import Generation
1313
from langchain_core.utils.pydantic import (
14-
IS_PYDANTIC_V2,
1514
PydanticBaseModel,
1615
TBaseModel,
1716
)
@@ -24,22 +23,16 @@ class PydanticOutputParser(JsonOutputParser, Generic[TBaseModel]):
2423
"""The pydantic model to parse."""
2524

2625
def _parse_obj(self, obj: dict) -> TBaseModel:
27-
if IS_PYDANTIC_V2:
28-
try:
29-
if issubclass(self.pydantic_object, pydantic.BaseModel):
30-
return self.pydantic_object.model_validate(obj)
31-
if issubclass(self.pydantic_object, pydantic.v1.BaseModel):
32-
return self.pydantic_object.parse_obj(obj)
33-
msg = f"Unsupported model version for PydanticOutputParser: \
34-
{self.pydantic_object.__class__}"
35-
raise OutputParserException(msg)
36-
except (pydantic.ValidationError, pydantic.v1.ValidationError) as e:
37-
raise self._parser_exception(e, obj) from e
38-
else: # pydantic v1
39-
try:
26+
try:
27+
if issubclass(self.pydantic_object, pydantic.BaseModel):
28+
return self.pydantic_object.model_validate(obj)
29+
if issubclass(self.pydantic_object, pydantic.v1.BaseModel):
4030
return self.pydantic_object.parse_obj(obj)
41-
except pydantic.ValidationError as e:
42-
raise self._parser_exception(e, obj) from e
31+
msg = f"Unsupported model version for PydanticOutputParser: \
32+
{self.pydantic_object.__class__}"
33+
raise OutputParserException(msg)
34+
except (pydantic.ValidationError, pydantic.v1.ValidationError) as e:
35+
raise self._parser_exception(e, obj) from e
4336

4437
def _parser_exception(
4538
self, e: Exception, json_object: dict

libs/core/langchain_core/output_parsers/transform.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ def _transform(self, input: Iterator[Union[str, BaseMessage]]) -> Iterator[Any]:
125125
chunk_gen = ChatGenerationChunk(message=chunk)
126126
elif isinstance(chunk, BaseMessage):
127127
chunk_gen = ChatGenerationChunk(
128-
message=BaseMessageChunk(**chunk.dict())
128+
message=BaseMessageChunk(**chunk.model_dump())
129129
)
130130
else:
131131
chunk_gen = GenerationChunk(text=chunk)
@@ -151,7 +151,7 @@ async def _atransform(
151151
chunk_gen = ChatGenerationChunk(message=chunk)
152152
elif isinstance(chunk, BaseMessage):
153153
chunk_gen = ChatGenerationChunk(
154-
message=BaseMessageChunk(**chunk.dict())
154+
message=BaseMessageChunk(**chunk.model_dump())
155155
)
156156
else:
157157
chunk_gen = GenerationChunk(text=chunk)

libs/core/langchain_core/prompts/base.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -331,7 +331,7 @@ def _prompt_type(self) -> str:
331331
"""Return the prompt type key."""
332332
raise NotImplementedError
333333

334-
def dict(self, **kwargs: Any) -> dict:
334+
def model_dump(self, **kwargs: Any) -> dict:
335335
"""Return dictionary representation of prompt.
336336
337337
Args:
@@ -369,7 +369,7 @@ def save(self, file_path: Union[Path, str]) -> None:
369369
raise ValueError(msg)
370370

371371
# Fetch dictionary to save
372-
prompt_dict = self.dict()
372+
prompt_dict = self.model_dump()
373373
if "_type" not in prompt_dict:
374374
msg = f"Prompt {self} does not support saving."
375375
raise NotImplementedError(msg)

libs/core/langchain_core/pydantic_v1/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
try:
1919
from pydantic.v1 import * # noqa: F403
2020
except ImportError:
21-
from pydantic import * # type: ignore[assignment,no-redef] # noqa: F403
21+
from pydantic import * # type: ignore[assignment,deprecated,no-redef] # noqa: F403
2222

2323

2424
try:

libs/core/langchain_core/tools/base.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -543,10 +543,13 @@ def _parse_input(
543543
)
544544
raise ValueError(msg)
545545
key_ = next(iter(get_fields(input_args).keys()))
546-
if hasattr(input_args, "model_validate"):
546+
if issubclass(input_args, BaseModel):
547547
input_args.model_validate({key_: tool_input})
548-
else:
548+
elif issubclass(input_args, BaseModelV1):
549549
input_args.parse_obj({key_: tool_input})
550+
else:
551+
msg = f"args_schema must be a Pydantic BaseModel, got {input_args}"
552+
raise TypeError(msg)
550553
return tool_input
551554
if input_args is not None:
552555
if isinstance(input_args, dict):

libs/core/langchain_core/tracers/schemas.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22

33
from __future__ import annotations
44

5-
import datetime
65
import warnings
6+
from datetime import datetime, timezone
77
from typing import Any, Optional
88
from uuid import UUID
99

@@ -32,7 +32,7 @@ def RunTypeEnum() -> type[RunTypeEnumDep]: # noqa: N802
3232
class TracerSessionV1Base(BaseModelV1):
3333
"""Base class for TracerSessionV1."""
3434

35-
start_time: datetime.datetime = FieldV1(default_factory=datetime.datetime.utcnow)
35+
start_time: datetime = FieldV1(default_factory=lambda: datetime.now(timezone.utc))
3636
name: Optional[str] = None
3737
extra: Optional[dict[str, Any]] = None
3838

@@ -69,8 +69,8 @@ class BaseRun(BaseModelV1):
6969

7070
uuid: str
7171
parent_uuid: Optional[str] = None
72-
start_time: datetime.datetime = FieldV1(default_factory=datetime.datetime.utcnow)
73-
end_time: datetime.datetime = FieldV1(default_factory=datetime.datetime.utcnow)
72+
start_time: datetime = FieldV1(default_factory=lambda: datetime.now(timezone.utc))
73+
end_time: datetime = FieldV1(default_factory=lambda: datetime.now(timezone.utc))
7474
extra: Optional[dict[str, Any]] = None
7575
execution_order: int
7676
child_execution_order: int

0 commit comments

Comments
 (0)