Skip to content

Commit 0961627

Browse files
committed
WIP
1 parent bb67044 commit 0961627

File tree

7 files changed

+212
-1
lines changed

7 files changed

+212
-1
lines changed

python/pydantic_core/core_schema.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1384,6 +1384,24 @@ def uuid_schema(
13841384
)
13851385

13861386

1387+
class NestedModelSchema(TypedDict, total=False):
1388+
type: Required[Literal['nested-model']]
1389+
model: Required[Type[Any]]
1390+
metadata: Any
1391+
1392+
1393+
def nested_model_schema(
1394+
*,
1395+
model: Type[Any],
1396+
metadata: Any = None,
1397+
) -> NestedModelSchema:
1398+
return _dict_not_none(
1399+
type='nested-model',
1400+
model=model,
1401+
metadata=metadata,
1402+
)
1403+
1404+
13871405
class IncExSeqSerSchema(TypedDict, total=False):
13881406
type: Required[Literal['include-exclude-sequence']]
13891407
include: Set[int]
@@ -3796,6 +3814,7 @@ def definition_reference_schema(
37963814
DefinitionsSchema,
37973815
DefinitionReferenceSchema,
37983816
UuidSchema,
3817+
NestedModelSchema,
37993818
]
38003819
elif False:
38013820
CoreSchema: TypeAlias = Mapping[str, Any]
@@ -3851,6 +3870,7 @@ def definition_reference_schema(
38513870
'definitions',
38523871
'definition-ref',
38533872
'uuid',
3873+
'nested-model',
38543874
]
38553875

38563876
CoreSchemaFieldType = Literal['model-field', 'dataclass-field', 'typed-dict-field', 'computed-field']

src/serializers/shared.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,7 @@ combined_serializer! {
142142
Enum: super::type_serializers::enum_::EnumSerializer;
143143
Recursive: super::type_serializers::definitions::DefinitionRefSerializer;
144144
Tuple: super::type_serializers::tuple::TupleSerializer;
145+
NestedModel: super::type_serializers::nested_model::NestedModelSerializer;
145146
}
146147
}
147148

@@ -251,6 +252,7 @@ impl PyGcTraverse for CombinedSerializer {
251252
CombinedSerializer::Recursive(inner) => inner.py_gc_traverse(visit),
252253
CombinedSerializer::Tuple(inner) => inner.py_gc_traverse(visit),
253254
CombinedSerializer::Uuid(inner) => inner.py_gc_traverse(visit),
255+
CombinedSerializer::NestedModel(inner) => inner.py_gc_traverse(visit),
254256
}
255257
}
256258
}

src/serializers/type_serializers/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ pub mod json_or_python;
1515
pub mod list;
1616
pub mod literal;
1717
pub mod model;
18+
pub mod nested_model;
1819
pub mod nullable;
1920
pub mod other;
2021
pub mod set_frozenset;
Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
use std::borrow::Cow;
2+
3+
use pyo3::{
4+
intern,
5+
types::{PyAnyMethods, PyDict, PyDictMethods, PyType},
6+
Bound, Py, PyAny, PyObject, PyResult, Python,
7+
};
8+
9+
use crate::{
10+
definitions::DefinitionsBuilder,
11+
serializers::{
12+
shared::{BuildSerializer, TypeSerializer},
13+
CombinedSerializer, Extra,
14+
},
15+
SchemaSerializer,
16+
};
17+
18+
#[derive(Debug, Clone)]
19+
pub struct NestedModelSerializer {
20+
model: Py<PyType>,
21+
name: String,
22+
}
23+
24+
impl_py_gc_traverse!(NestedModelSerializer { model });
25+
26+
impl BuildSerializer for NestedModelSerializer {
27+
const EXPECTED_TYPE: &'static str = "nested-model";
28+
29+
fn build(
30+
schema: &Bound<'_, PyDict>,
31+
_config: Option<&Bound<'_, PyDict>>,
32+
_definitions: &mut DefinitionsBuilder<CombinedSerializer>,
33+
) -> PyResult<CombinedSerializer> {
34+
let py = schema.py();
35+
let model = schema
36+
.get_item(intern!(py, "model"))?
37+
.expect("Invalid core schema for `nested-model` type")
38+
.downcast::<PyType>()
39+
.expect("Invalid core schema for `nested-model` type")
40+
.clone();
41+
42+
let name = model.getattr(intern!(py, "__name__"))?.extract()?;
43+
44+
Ok(CombinedSerializer::NestedModel(NestedModelSerializer {
45+
model: model.clone().unbind(),
46+
name,
47+
}))
48+
}
49+
}
50+
51+
impl NestedModelSerializer {
52+
fn nested_serializer<'py>(&self, py: Python<'py>) -> Bound<'py, SchemaSerializer> {
53+
self.model
54+
.bind(py)
55+
.call_method(intern!(py, "model_rebuild"), (), None)
56+
.unwrap();
57+
58+
self.model
59+
.getattr(py, intern!(py, "__pydantic_serializer__"))
60+
.unwrap()
61+
.downcast_bound::<SchemaSerializer>(py)
62+
.unwrap()
63+
.clone()
64+
65+
// crate::schema_cache::retrieve_schema(py, self.model.as_any().clone())
66+
// .downcast_bound::<SchemaSerializer>(py)
67+
// // FIXME: This actually will always trigger as we cache a `CoreSchema` lol
68+
// .expect("Cached validator was not a `SchemaSerializer`")
69+
// .clone()
70+
}
71+
}
72+
73+
impl TypeSerializer for NestedModelSerializer {
74+
fn to_python(
75+
&self,
76+
value: &Bound<'_, PyAny>,
77+
include: Option<&Bound<'_, PyAny>>,
78+
exclude: Option<&Bound<'_, PyAny>>,
79+
extra: &Extra,
80+
) -> PyResult<PyObject> {
81+
self.nested_serializer(value.py())
82+
.get()
83+
.serializer
84+
.to_python(value, include, exclude, extra)
85+
}
86+
87+
fn json_key<'a>(&self, key: &'a Bound<'_, PyAny>, extra: &Extra) -> PyResult<Cow<'a, str>> {
88+
self.nested_serializer(key.py()).get().serializer.json_key(key, extra)
89+
}
90+
91+
fn serde_serialize<S: serde::ser::Serializer>(
92+
&self,
93+
value: &Bound<'_, PyAny>,
94+
serializer: S,
95+
include: Option<&Bound<'_, PyAny>>,
96+
exclude: Option<&Bound<'_, PyAny>>,
97+
extra: &Extra,
98+
) -> Result<S::Ok, S::Error> {
99+
self.nested_serializer(value.py())
100+
.get()
101+
.serializer
102+
.serde_serialize(value, serializer, include, exclude, extra)
103+
}
104+
105+
fn get_name(&self) -> &str {
106+
&self.name
107+
}
108+
}

src/validators/mod.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ mod list;
4848
mod literal;
4949
mod model;
5050
mod model_fields;
51+
mod nested_model;
5152
mod none;
5253
mod nullable;
5354
mod set;
@@ -582,6 +583,7 @@ pub fn build_validator(
582583
// recursive (self-referencing) models
583584
definitions::DefinitionRefValidator,
584585
definitions::DefinitionsValidatorBuilder,
586+
nested_model::NestedModelValidator,
585587
)
586588
}
587589

@@ -735,6 +737,8 @@ pub enum CombinedValidator {
735737
DefinitionRef(definitions::DefinitionRefValidator),
736738
// input dependent
737739
JsonOrPython(json_or_python::JsonOrPython),
740+
// Schema for a model inside of another schema
741+
NestedModel(nested_model::NestedModelValidator),
738742
}
739743

740744
/// This trait must be implemented by all validators, it allows various validators to be accessed consistently,

src/validators/model.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ impl BuildValidator for ModelValidator {
7777

7878
let class: Bound<'_, PyType> = schema.get_as_req(intern!(py, "cls"))?;
7979
let sub_schema = schema.get_as_req(intern!(py, "schema"))?;
80-
let validator = build_validator(&sub_schema, config.as_ref(), definitions)?;
80+
let validator: CombinedValidator = build_validator(&sub_schema, config.as_ref(), definitions)?;
8181
let name = class.getattr(intern!(py, "__name__"))?.extract()?;
8282

8383
Ok(Self {

src/validators/nested_model.rs

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
use pyo3::{
2+
intern,
3+
types::{PyAnyMethods, PyDict, PyDictMethods, PyType},
4+
Bound, Py, PyObject, PyResult, Python,
5+
};
6+
7+
use crate::{definitions::DefinitionsBuilder, errors::ValResult, input::Input};
8+
9+
use super::{BuildValidator, CombinedValidator, SchemaValidator, ValidationState, Validator};
10+
11+
#[derive(Debug, Clone)]
12+
pub struct NestedModelValidator {
13+
model: Py<PyType>,
14+
name: String,
15+
}
16+
17+
impl_py_gc_traverse!(NestedModelValidator { model });
18+
19+
impl BuildValidator for NestedModelValidator {
20+
const EXPECTED_TYPE: &'static str = "nested-model";
21+
22+
fn build(
23+
schema: &Bound<'_, PyDict>,
24+
_config: Option<&Bound<'_, PyDict>>,
25+
_definitions: &mut DefinitionsBuilder<super::CombinedValidator>,
26+
) -> PyResult<super::CombinedValidator> {
27+
let py = schema.py();
28+
let model = schema
29+
.get_item(intern!(py, "model"))?
30+
.expect("Invalid core schema for `nested-model` type")
31+
.downcast::<PyType>()
32+
.expect("Invalid core schema for `nested-model` type")
33+
.clone();
34+
35+
let name = model.getattr(intern!(py, "__name__"))?.extract()?;
36+
37+
Ok(CombinedValidator::NestedModel(NestedModelValidator {
38+
model: model.clone().unbind(),
39+
name,
40+
}))
41+
}
42+
}
43+
44+
impl Validator for NestedModelValidator {
45+
fn validate<'py>(
46+
&self,
47+
py: Python<'py>,
48+
input: &(impl Input<'py> + ?Sized),
49+
state: &mut ValidationState<'_, 'py>,
50+
) -> ValResult<PyObject> {
51+
self.model
52+
.bind(py)
53+
.call_method(intern!(py, "model_rebuild"), (), None)
54+
.unwrap();
55+
56+
let validator = self
57+
.model
58+
.getattr(py, intern!(py, "__pydantic_validator__"))
59+
.unwrap()
60+
.downcast_bound::<SchemaValidator>(py)
61+
.unwrap()
62+
.clone();
63+
64+
// let validator = crate::schema_cache::retrieve_schema(py, self.model.as_any().clone())
65+
// .downcast_bound::<SchemaValidator>(py)
66+
// // FIXME: This actually will always trigger as we cache a `CoreSchema` lol
67+
// .expect("Cached validator was not a `SchemaValidator`")
68+
// .clone();
69+
70+
validator.get().validator.validate(py, input, state)
71+
}
72+
73+
fn get_name(&self) -> &str {
74+
&self.name
75+
}
76+
}

0 commit comments

Comments
 (0)