Skip to content

Commit d524b7e

Browse files
fix(parsing): parse extra field types
1 parent f3a9a92 commit d524b7e

File tree

2 files changed

+51
-3
lines changed

2 files changed

+51
-3
lines changed

src/openai/_models.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -233,14 +233,18 @@ def construct( # pyright: ignore[reportIncompatibleMethodOverride]
233233
else:
234234
fields_values[name] = field_get_default(field)
235235

236+
extra_field_type = _get_extra_fields_type(__cls)
237+
236238
_extra = {}
237239
for key, value in values.items():
238240
if key not in model_fields:
241+
parsed = construct_type(value=value, type_=extra_field_type) if extra_field_type is not None else value
242+
239243
if PYDANTIC_V2:
240-
_extra[key] = value
244+
_extra[key] = parsed
241245
else:
242246
_fields_set.add(key)
243-
fields_values[key] = value
247+
fields_values[key] = parsed
244248

245249
object.__setattr__(m, "__dict__", fields_values)
246250

@@ -395,6 +399,23 @@ def _construct_field(value: object, field: FieldInfo, key: str) -> object:
395399
return construct_type(value=value, type_=type_, metadata=getattr(field, "metadata", None))
396400

397401

402+
def _get_extra_fields_type(cls: type[pydantic.BaseModel]) -> type | None:
403+
if not PYDANTIC_V2:
404+
# TODO
405+
return None
406+
407+
schema = cls.__pydantic_core_schema__
408+
if schema["type"] == "model":
409+
fields = schema["schema"]
410+
if fields["type"] == "model-fields":
411+
extras = fields.get("extras_schema")
412+
if extras and "cls" in extras:
413+
# mypy can't narrow the type
414+
return extras["cls"] # type: ignore[no-any-return]
415+
416+
return None
417+
418+
398419
def is_basemodel(type_: type) -> bool:
399420
"""Returns whether or not the given type is either a `BaseModel` or a union of `BaseModel`"""
400421
if is_union(type_):

tests/test_models.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import json
2-
from typing import Any, Dict, List, Union, Optional, cast
2+
from typing import TYPE_CHECKING, Any, Dict, List, Union, Optional, cast
33
from datetime import datetime, timezone
44
from typing_extensions import Literal, Annotated, TypeAliasType
55

@@ -934,3 +934,30 @@ class Type2(BaseModel):
934934
)
935935
assert isinstance(model, Type1)
936936
assert isinstance(model.value, InnerType2)
937+
938+
939+
@pytest.mark.skipif(not PYDANTIC_V2, reason="this is only supported in pydantic v2 for now")
940+
def test_extra_properties() -> None:
941+
class Item(BaseModel):
942+
prop: int
943+
944+
class Model(BaseModel):
945+
__pydantic_extra__: Dict[str, Item] = Field(init=False) # pyright: ignore[reportIncompatibleVariableOverride]
946+
947+
other: str
948+
949+
if TYPE_CHECKING:
950+
951+
def __getattr__(self, attr: str) -> Item: ...
952+
953+
model = construct_type(
954+
type_=Model,
955+
value={
956+
"a": {"prop": 1},
957+
"other": "foo",
958+
},
959+
)
960+
assert isinstance(model, Model)
961+
assert model.a.prop == 1
962+
assert isinstance(model.a, Item)
963+
assert model.other == "foo"

0 commit comments

Comments
 (0)