diff --git a/src/common/mod.rs b/src/common/mod.rs
index 11f2e1ece..47c0a0349 100644
--- a/src/common/mod.rs
+++ b/src/common/mod.rs
@@ -1 +1,2 @@
+pub(crate) mod prebuilt;
 pub(crate) mod union;
diff --git a/src/common/prebuilt.rs b/src/common/prebuilt.rs
new file mode 100644
index 000000000..961123691
--- /dev/null
+++ b/src/common/prebuilt.rs
@@ -0,0 +1,51 @@
+use pyo3::intern;
+use pyo3::prelude::*;
+use pyo3::types::{PyAny, PyDict, PyType};
+
+use crate::tools::SchemaDict;
+
+/// Look up a prebuilt validator / serializer stored on the class referenced by `schema`.
+///
+/// `type_` is the schema's `type` value, `prebuilt_attr_name` is the key read from the class's own
+/// `__dict__`, and `extractor` converts the object found there into the caller's type.
+///
+/// Returns `Ok(None)` when `type_` is not eligible or the class is not marked
+/// `__pydantic_complete__`. Errors from reading `generic_origin` / `cls`, from the `__dict__` or
+/// `prebuilt_attr_name` lookup, or from `extractor` are propagated as `Err`.
+pub fn get_prebuilt<T>(
+    type_: &str,
+    schema: &Bound<'_, PyDict>,
+    prebuilt_attr_name: &str,
+    extractor: impl FnOnce(Bound<'_, PyAny>) -> PyResult<T>,
+) -> PyResult<Option<T>> {
+    let py = schema.py();
+
+    // we can only use prebuilt validators / serializers from models, typed dicts, and dataclasses
+    // however, we don't want to use a prebuilt structure from dataclasses if we have a generic_origin
+    // because the validator / serializer is cached on the unparametrized dataclass
+    if !matches!(type_, "model" | "typed-dict" | "dataclass")
+        || (type_ == "dataclass" && schema.contains(intern!(py, "generic_origin"))?)
+    {
+        return Ok(None);
+    }
+
+    let class: Bound<'_, PyType> = schema.get_as_req(intern!(py, "cls"))?;
+
+    // Note: we NEED to use the __dict__ here (and perform get_item calls rather than getattr)
+    // because we don't want to fetch prebuilt validators from parent classes.
+    // We don't downcast here because __dict__ on a class is a readonly mappingproxy,
+    // so we can just leave it as is and do get_item checks.
+    let class_dict = class.getattr(intern!(py, "__dict__"))?;
+
+    let is_complete: bool = class_dict
+        .get_item(intern!(py, "__pydantic_complete__"))
+        .is_ok_and(|b| b.extract().unwrap_or(false));
+
+    if !is_complete {
+        return Ok(None);
+    }
+
+    // Retrieve the prebuilt validator / serializer if available
+    let prebuilt: Bound<'_, PyAny> = class_dict.get_item(prebuilt_attr_name)?;
+    extractor(prebuilt).map(Some)
+}
diff --git a/src/serializers/mod.rs b/src/serializers/mod.rs
index 8fc1a4230..e652a75de 100644
--- a/src/serializers/mod.rs
+++ b/src/serializers/mod.rs
@@ -24,6 +24,7 @@ mod fields;
 mod filter;
 mod infer;
 mod ob_type;
+mod prebuilt;
 pub mod ser;
 mod shared;
 mod type_serializers;
diff --git a/src/serializers/prebuilt.rs b/src/serializers/prebuilt.rs
new file mode 100644
index 000000000..33d197d9b
--- /dev/null
+++ b/src/serializers/prebuilt.rs
@@ -0,0 +1,70 @@
+use std::borrow::Cow;
+
+use pyo3::prelude::*;
+use pyo3::types::PyDict;
+
+use crate::common::prebuilt::get_prebuilt;
+use crate::SchemaSerializer;
+
+use super::extra::Extra;
+use super::shared::{CombinedSerializer, TypeSerializer};
+
+/// Serializer that forwards every call to a `SchemaSerializer` found on the schema's class.
+#[derive(Debug)]
+pub struct PrebuiltSerializer {
+    schema_serializer: Py<SchemaSerializer>,
+}
+
+impl PrebuiltSerializer {
+    /// Wrap the class's `__pydantic_serializer__` entry when `get_prebuilt` yields one for `schema`.
+    pub fn try_get_from_schema(type_: &str, schema: &Bound<'_, PyDict>) -> PyResult<Option<CombinedSerializer>> {
+        get_prebuilt(type_, schema, "__pydantic_serializer__", |py_any| {
+            py_any
+                .extract::<Py<SchemaSerializer>>()
+                .map(|schema_serializer| Self { schema_serializer }.into())
+        })
+    }
+}
+
+impl_py_gc_traverse!(PrebuiltSerializer { schema_serializer });
+
+impl TypeSerializer for PrebuiltSerializer {
+    fn to_python(
+        &self,
+        value: &Bound<'_, PyAny>,
+        include: Option<&Bound<'_, PyAny>>,
+        exclude: Option<&Bound<'_, PyAny>>,
+        extra: &Extra,
+    ) -> PyResult<PyObject> {
+        self.schema_serializer
+            .get()
+            .serializer
+            .to_python(value, include, exclude, extra)
+    }
+
+    fn json_key<'a>(&self, key: &'a Bound<'_, PyAny>, extra: &Extra) -> PyResult<Cow<'a, str>> {
+        self.schema_serializer.get().serializer.json_key(key, extra)
+    }
+
+    fn serde_serialize<S: serde::ser::Serializer>(
+        &self,
+        value: &Bound<'_, PyAny>,
+        serializer: S,
+        include: Option<&Bound<'_, PyAny>>,
+        exclude: Option<&Bound<'_, PyAny>>,
+        extra: &Extra,
+    ) -> Result<S::Ok, S::Error> {
+        self.schema_serializer
+            .get()
+            .serializer
+            .serde_serialize(value, serializer, include, exclude, extra)
+    }
+
+    fn get_name(&self) -> &str {
+        self.schema_serializer.get().serializer.get_name()
+    }
+
+    fn retry_with_lax_check(&self) -> bool {
+        self.schema_serializer.get().serializer.retry_with_lax_check()
+    }
+}
diff --git a/src/serializers/shared.rs b/src/serializers/shared.rs
index f37810657..f7a018749 100644
--- a/src/serializers/shared.rs
+++ b/src/serializers/shared.rs
@@ -84,6 +84,8 @@ combined_serializer! {
         Function: super::type_serializers::function::FunctionPlainSerializer;
         FunctionWrap: super::type_serializers::function::FunctionWrapSerializer;
         Fields: super::fields::GeneralFieldsSerializer;
+        // prebuilt serializers are manually constructed, and thus manually added to the `CombinedSerializer` enum
+        Prebuilt: super::prebuilt::PrebuiltSerializer;
     }
     // `find_only` is for type_serializers which are built directly via the `type` key and `find_serializer`
     // but aren't actually used for serialization, e.g. their `build` method must return another serializer
@@ -195,7 +197,14 @@ impl CombinedSerializer {
         }
 
         let type_: Bound<'_, PyString> = schema.get_as_req(type_key)?;
-        Self::find_serializer(type_.to_str()?, schema, config, definitions)
+        let type_ = type_.to_str()?;
+
+        // if we have a SchemaSerializer on the type already, use it
+        if let Ok(Some(prebuilt_serializer)) = super::prebuilt::PrebuiltSerializer::try_get_from_schema(type_, schema) {
+            return Ok(prebuilt_serializer);
+        }
+
+        Self::find_serializer(type_, schema, config, definitions)
     }
 }
 
@@ -219,6 +228,7 @@ impl PyGcTraverse for CombinedSerializer {
             CombinedSerializer::Function(inner) => inner.py_gc_traverse(visit),
             CombinedSerializer::FunctionWrap(inner) => inner.py_gc_traverse(visit),
             CombinedSerializer::Fields(inner) => inner.py_gc_traverse(visit),
+            CombinedSerializer::Prebuilt(inner) => inner.py_gc_traverse(visit),
             CombinedSerializer::None(inner) => inner.py_gc_traverse(visit),
             CombinedSerializer::Nullable(inner) => inner.py_gc_traverse(visit),
             CombinedSerializer::Int(inner) => inner.py_gc_traverse(visit),
diff --git a/src/validators/mod.rs b/src/validators/mod.rs
index bc1851ac5..75f39df29 100644
--- a/src/validators/mod.rs
+++ b/src/validators/mod.rs
@@ -52,6 +52,7 @@ mod model;
 mod model_fields;
 mod none;
 mod nullable;
+mod prebuilt;
 mod set;
 mod string;
 mod time;
@@ -515,8 +516,15 @@ pub fn build_validator(
     definitions: &mut DefinitionsBuilder<CombinedValidator>,
 ) -> PyResult<CombinedValidator> {
     let dict = schema.downcast::<PyDict>()?;
-    let type_: Bound<'_, PyString> = dict.get_as_req(intern!(schema.py(), "type"))?;
+    let py = schema.py();
+    let type_: Bound<'_, PyString> = dict.get_as_req(intern!(py, "type"))?;
     let type_ = type_.to_str()?;
+
+    // if we have a SchemaValidator on the type already, use it
+    if let Ok(Some(prebuilt_validator)) = prebuilt::PrebuiltValidator::try_get_from_schema(type_, dict) {
+        return Ok(prebuilt_validator);
+    }
+
     validator_match!(
         type_,
         dict,
@@ -763,6 +771,8 @@ pub enum CombinedValidator {
     // input dependent
     JsonOrPython(json_or_python::JsonOrPython),
     Complex(complex::ComplexValidator),
+    // uses a reference to an existing SchemaValidator to reduce memory usage
+    Prebuilt(prebuilt::PrebuiltValidator),
 }
 
 /// This trait must be implemented by all validators, it allows various validators to be accessed consistently,
diff --git a/src/validators/prebuilt.rs b/src/validators/prebuilt.rs
new file mode 100644
index 000000000..c17acb9f9
--- /dev/null
+++ b/src/validators/prebuilt.rs
@@ -0,0 +1,43 @@
+use pyo3::prelude::*;
+use pyo3::types::PyDict;
+
+use crate::common::prebuilt::get_prebuilt;
+use crate::errors::ValResult;
+use crate::input::Input;
+
+use super::ValidationState;
+use super::{CombinedValidator, SchemaValidator, Validator};
+
+/// Validator that forwards every call to a `SchemaValidator` found on the schema's class.
+#[derive(Debug)]
+pub struct PrebuiltValidator {
+    schema_validator: Py<SchemaValidator>,
+}
+
+impl PrebuiltValidator {
+    /// Wrap the class's `__pydantic_validator__` entry when `get_prebuilt` yields one for `schema`.
+    pub fn try_get_from_schema(type_: &str, schema: &Bound<'_, PyDict>) -> PyResult<Option<CombinedValidator>> {
+        get_prebuilt(type_, schema, "__pydantic_validator__", |py_any| {
+            py_any
+                .extract::<Py<SchemaValidator>>()
+                .map(|schema_validator| Self { schema_validator }.into())
+        })
+    }
+}
+
+impl_py_gc_traverse!(PrebuiltValidator { schema_validator });
+
+impl Validator for PrebuiltValidator {
+    fn validate<'py>(
+        &self,
+        py: Python<'py>,
+        input: &(impl Input<'py> + ?Sized),
+        state: &mut ValidationState<'_, 'py>,
+    ) -> ValResult<PyObject> {
+        self.schema_validator.get().validator.validate(py, input, state)
+    }
+
+    fn get_name(&self) -> &str {
+        self.schema_validator.get().validator.get_name()
+    }
+}
diff --git a/tests/test_prebuilt.py b/tests/test_prebuilt.py
new file mode 100644
index 000000000..9cd5aa325
--- /dev/null
+++ b/tests/test_prebuilt.py
@@ -0,0 +1,48 @@
+from pydantic_core import SchemaSerializer, SchemaValidator, core_schema
+
+
+def test_prebuilt_val_and_ser_used() -> None:
+    class InnerModel:
+        x: int
+
+    inner_schema = core_schema.model_schema(
+        InnerModel,
+        schema=core_schema.model_fields_schema(
+            {'x': core_schema.model_field(schema=core_schema.int_schema())},
+        ),
+    )
+
+    inner_schema_validator = SchemaValidator(inner_schema)
+    inner_schema_serializer = SchemaSerializer(inner_schema)
+    InnerModel.__pydantic_complete__ = True  # pyright: ignore[reportAttributeAccessIssue]
+    InnerModel.__pydantic_validator__ = inner_schema_validator  # pyright: ignore[reportAttributeAccessIssue]
+    InnerModel.__pydantic_serializer__ = inner_schema_serializer  # pyright: ignore[reportAttributeAccessIssue]
+
+    class OuterModel:
+        inner: InnerModel
+
+    outer_schema = core_schema.model_schema(
+        OuterModel,
+        schema=core_schema.model_fields_schema(
+            {
+                'inner': core_schema.model_field(
+                    schema=core_schema.model_schema(
+                        InnerModel,
+                        schema=core_schema.model_fields_schema(
+                            # note, we use str schema here even though that's incorrect
+                            # in order to verify that the prebuilt validator is used
+                            # off of InnerModel with the correct int schema, not this str schema
+                            {'x': core_schema.model_field(schema=core_schema.str_schema())},
+                        ),
+                    )
+                )
+            }
+        ),
+    )
+
+    outer_validator = SchemaValidator(outer_schema)
+    outer_serializer = SchemaSerializer(outer_schema)
+
+    result = outer_validator.validate_python({'inner': {'x': 1}})
+    assert result.inner.x == 1
+    assert outer_serializer.to_python(result) == {'inner': {'x': 1}}