From 75bde54b871c8422234ef4f37561e25da114e471 Mon Sep 17 00:00:00 2001 From: sydney-runkle Date: Fri, 31 Jan 2025 09:54:33 -0500 Subject: [PATCH 01/11] using prebuilt validators --- src/validators/mod.rs | 19 ++++++++++-- src/validators/prebuilt.rs | 63 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 79 insertions(+), 3 deletions(-) create mode 100644 src/validators/prebuilt.rs diff --git a/src/validators/mod.rs b/src/validators/mod.rs index bc1851ac5..8404d226c 100644 --- a/src/validators/mod.rs +++ b/src/validators/mod.rs @@ -1,4 +1,5 @@ use std::fmt::Debug; +use std::sync::Arc; use enum_dispatch::enum_dispatch; use jiter::{PartialMode, StringCacheMode}; @@ -52,6 +53,7 @@ mod model; mod model_fields; mod none; mod nullable; +mod prebuilt; mod set; mod string; mod time; @@ -105,7 +107,7 @@ impl PySome { #[pyclass(module = "pydantic_core._pydantic_core", frozen)] #[derive(Debug)] pub struct SchemaValidator { - validator: CombinedValidator, + validator: Arc, definitions: Definitions, // References to the Python schema and config objects are saved to enable // reconstructing the object for cloudpickle support (see `__reduce__`). @@ -146,7 +148,7 @@ impl SchemaValidator { .get_as(intern!(py, "cache_strings"))? .unwrap_or(StringCacheMode::All); Ok(Self { - validator, + validator: Arc::new(validator), definitions, py_schema, py_config, @@ -455,7 +457,7 @@ impl<'py> SelfValidator<'py> { }; let definitions = definitions_builder.finish()?; Ok(SchemaValidator { - validator, + validator: Arc::new(validator), definitions, py_schema: py.None(), py_config: None, @@ -517,6 +519,15 @@ pub fn build_validator( let dict = schema.downcast::()?; let type_: Bound<'_, PyString> = dict.get_as_req(intern!(schema.py(), "type"))?; let type_ = type_.to_str()?; + + // if we have a SchemaValidator on the type already, use it + if matches!(type_, "model" | "dataclass" | "typed-dict") { + match prebuilt::PrebuiltValidator::build(dict, config, definitions) { + Ok(prebuilt_validator) => return Ok(prebuilt_validator), + Err(_) => (), + } + } + validator_match!( type_, dict, @@ -763,6 +774,8 @@ pub enum CombinedValidator { // input dependent JsonOrPython(json_or_python::JsonOrPython), Complex(complex::ComplexValidator), + // uses a reference to an existing SchemaValidator to reduce memory usage + Prebuilt(prebuilt::PrebuiltValidator), } /// This trait must be implemented by all validators, it allows various validators to be accessed consistently, diff --git a/src/validators/prebuilt.rs b/src/validators/prebuilt.rs new file mode 100644 index 000000000..7c1559ccd --- /dev/null +++ b/src/validators/prebuilt.rs @@ -0,0 +1,63 @@ +use std::sync::Arc; + +use pyo3::intern; +use pyo3::prelude::*; +use pyo3::types::{PyDict, PyType}; +use pyo3::exceptions::PyValueError; + +use crate::errors::ValResult; +use crate::input::Input; +use crate::tools::SchemaDict; + +use super::ValidationState; +use super::{BuildValidator, CombinedValidator, DefinitionsBuilder, Validator, SchemaValidator}; + +#[derive(Debug)] +pub struct PrebuiltValidator { + validator: Arc, + name: String, +} + +impl BuildValidator for PrebuiltValidator { + const EXPECTED_TYPE: &'static str = "prebuilt"; + + fn build( + schema: &Bound<'_, PyDict>, + _config: Option<&Bound<'_, PyDict>>, + _definitions: &mut DefinitionsBuilder, + ) -> PyResult { + let py = schema.py(); + let class: Bound<'_, PyType> = schema.get_as_req(intern!(py, "cls"))?; + + if class + .getattr(intern!(py, "__pydantic_complete__")) + .map_or(false, |pc| pc.extract::().unwrap_or(false)) + { + if let Ok(prebuilt_validator) = class.getattr(intern!(py, "__pydantic_validator__")) { + let schema_validator: PyRef = prebuilt_validator.extract()?; + let combined_validator: Arc = schema_validator.validator.clone(); + let name = class.getattr(intern!(py, "__name__"))?.extract()?; + + return Ok( Self { validator: combined_validator, name}.into()) + } + } + Err(PyValueError::new_err("Prebuilt validator not found.")) + } +} + +impl_py_gc_traverse!(PrebuiltValidator { validator }); + +impl Validator for PrebuiltValidator { + fn validate<'py>( + &self, + py: Python<'py>, + input: &(impl Input<'py> + ?Sized), + state: &mut ValidationState<'_, 'py>, + ) -> ValResult { + self.validator.validate(py, input, state) + } + + fn get_name(&self) -> &str { + &self.name + } +} From 5835c5972d09ce096a56ac6a44df5e2550523aa3 Mon Sep 17 00:00:00 2001 From: sydney-runkle Date: Fri, 31 Jan 2025 09:59:19 -0500 Subject: [PATCH 02/11] linting --- src/validators/prebuilt.rs | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/src/validators/prebuilt.rs b/src/validators/prebuilt.rs index 7c1559ccd..9ca055197 100644 --- a/src/validators/prebuilt.rs +++ b/src/validators/prebuilt.rs @@ -1,16 +1,16 @@ use std::sync::Arc; +use pyo3::exceptions::PyValueError; use pyo3::intern; use pyo3::prelude::*; use pyo3::types::{PyDict, PyType}; -use pyo3::exceptions::PyValueError; use crate::errors::ValResult; use crate::input::Input; use crate::tools::SchemaDict; use super::ValidationState; -use super::{BuildValidator, CombinedValidator, DefinitionsBuilder, Validator, SchemaValidator}; +use super::{BuildValidator, CombinedValidator, DefinitionsBuilder, SchemaValidator, Validator}; #[derive(Debug)] pub struct PrebuiltValidator { @@ -38,7 +38,11 @@ impl BuildValidator for PrebuiltValidator { let combined_validator: Arc = schema_validator.validator.clone(); let name = class.getattr(intern!(py, "__name__"))?.extract()?; - return Ok( Self { validator: combined_validator, name}.into()) + return Ok(Self { + validator: combined_validator, + name, + } + .into()); } } Err(PyValueError::new_err("Prebuilt validator not found.")) From e5245b5961a73d5e7d1bbf006c7f1fdc605ba36f Mon Sep 17 00:00:00 2001 From: sydney-runkle Date: Fri, 31 Jan 2025 10:30:24 -0500 Subject: [PATCH 03/11] linting and inheritance fix --- src/validators/mod.rs | 5 ++--- src/validators/prebuilt.rs | 39 ++++++++++++++++++++++---------------- 2 files changed, 25 insertions(+), 19 deletions(-) diff --git a/src/validators/mod.rs b/src/validators/mod.rs index 8404d226c..2c1562229 100644 --- a/src/validators/mod.rs +++ b/src/validators/mod.rs @@ -522,9 +522,8 @@ pub fn build_validator( // if we have a SchemaValidator on the type already, use it if matches!(type_, "model" | "dataclass" | "typed-dict") { - match prebuilt::PrebuiltValidator::build(dict, config, definitions) { - Ok(prebuilt_validator) => return Ok(prebuilt_validator), - Err(_) => (), + if let Ok(prebuilt_validator) = prebuilt::PrebuiltValidator::build(dict, config, definitions) { + return Ok(prebuilt_validator); } } diff --git a/src/validators/prebuilt.rs b/src/validators/prebuilt.rs index 9ca055197..d4e9289dd 100644 --- a/src/validators/prebuilt.rs +++ b/src/validators/prebuilt.rs @@ -3,7 +3,7 @@ use std::sync::Arc; use pyo3::exceptions::PyValueError; use pyo3::intern; use pyo3::prelude::*; -use pyo3::types::{PyDict, PyType}; +use pyo3::types::{PyBool, PyDict, PyType}; use crate::errors::ValResult; use crate::input::Input; @@ -29,23 +29,30 @@ impl BuildValidator for PrebuiltValidator { let py = schema.py(); let class: Bound<'_, PyType> = schema.get_as_req(intern!(py, "cls"))?; - if class - .getattr(intern!(py, "__pydantic_complete__")) - .map_or(false, |pc| pc.extract::().unwrap_or(false)) - { - if let Ok(prebuilt_validator) = class.getattr(intern!(py, "__pydantic_validator__")) { - let schema_validator: PyRef = prebuilt_validator.extract()?; - let combined_validator: Arc = schema_validator.validator.clone(); - let name = class.getattr(intern!(py, "__name__"))?.extract()?; + // note: we NEED to use the __dict__ here (and perform get_item calls rather than getattr) + // because we don't want to fetch prebuilt validators from parent classes + let class_dict: Bound<'_, PyDict> = class.getattr(intern!(py, "__dict__"))?.extract()?; - return Ok(Self { - validator: combined_validator, - name, - } - .into()); - } + // Ensure the class has completed its Pydantic validation setup + let is_complete: bool = class_dict + .get_as_req::>(intern!(py, "__pydantic_complete__")) + .is_ok_and(|b| b.extract().unwrap_or(false)); + + if !is_complete { + return Err(PyValueError::new_err("Prebuilt validator not found.")); + } + + // Retrieve the prebuilt validator if available + let prebuilt_validator: Bound<'_, PyAny> = class_dict.get_as_req(intern!(py, "__pydantic_validator__"))?; + let schema_validator: PyRef = prebuilt_validator.extract()?; + let combined_validator: Arc = schema_validator.validator.clone(); + let name: String = class.getattr(intern!(py, "__name__"))?.extract()?; + + Ok(Self { + validator: combined_validator, + name, } - Err(PyValueError::new_err("Prebuilt validator not found.")) + .into()) } } From 980fa20ca3b5ff831e64867b7d67f99e736dec45 Mon Sep 17 00:00:00 2001 From: sydney-runkle Date: Fri, 31 Jan 2025 10:58:38 -0500 Subject: [PATCH 04/11] serializer reuse logic as well --- src/serializers/mod.rs | 6 ++- src/serializers/prebuilt.rs | 93 +++++++++++++++++++++++++++++++++++++ src/serializers/shared.rs | 14 +++++- 3 files changed, 110 insertions(+), 3 deletions(-) create mode 100644 src/serializers/prebuilt.rs diff --git a/src/serializers/mod.rs b/src/serializers/mod.rs index 8fc1a4230..ed524e09b 100644 --- a/src/serializers/mod.rs +++ b/src/serializers/mod.rs @@ -1,5 +1,6 @@ use std::fmt::Debug; use std::sync::atomic::{AtomicUsize, Ordering}; +use std::sync::Arc; use pyo3::prelude::*; use pyo3::types::{PyBytes, PyDict, PyTuple, PyType}; @@ -24,6 +25,7 @@ mod fields; mod filter; mod infer; mod ob_type; +mod prebuilt; pub mod ser; mod shared; mod type_serializers; @@ -37,7 +39,7 @@ pub enum WarningsArg { #[pyclass(module = "pydantic_core._pydantic_core", frozen)] #[derive(Debug)] pub struct SchemaSerializer { - serializer: CombinedSerializer, + serializer: Arc, definitions: Definitions, expected_json_size: AtomicUsize, config: SerializationConfig, @@ -92,7 +94,7 @@ impl SchemaSerializer { let mut definitions_builder = DefinitionsBuilder::new(); let serializer = CombinedSerializer::build(schema.downcast()?, config, &mut definitions_builder)?; Ok(Self { - serializer, + serializer: Arc::new(serializer), definitions: definitions_builder.finish()?, expected_json_size: AtomicUsize::new(1024), config: SerializationConfig::from_config(config)?, diff --git a/src/serializers/prebuilt.rs b/src/serializers/prebuilt.rs new file mode 100644 index 000000000..5b81985e0 --- /dev/null +++ b/src/serializers/prebuilt.rs @@ -0,0 +1,93 @@ +use std::borrow::Cow; +use std::sync::Arc; + +use pyo3::exceptions::PyValueError; +use pyo3::intern; +use pyo3::prelude::*; +use pyo3::types::{PyBool, PyDict, PyType}; + +use crate::definitions::DefinitionsBuilder; +use crate::tools::SchemaDict; +use crate::SchemaSerializer; + +use super::extra::Extra; +use super::shared::{BuildSerializer, CombinedSerializer, TypeSerializer}; + +#[derive(Debug)] +pub struct PrebuiltSerializer { + serializer: Arc, +} + +impl BuildSerializer for PrebuiltSerializer { + const EXPECTED_TYPE: &'static str = "prebuilt"; + + fn build( + schema: &Bound<'_, PyDict>, + _config: Option<&Bound<'_, PyDict>>, + _definitions: &mut DefinitionsBuilder, + ) -> PyResult { + let py = schema.py(); + let class: Bound<'_, PyType> = schema.get_as_req(intern!(py, "cls"))?; + + // note: we NEED to use the __dict__ here (and perform get_item calls rather than getattr) + // because we don't want to fetch prebuilt serializers from parent classes + let class_dict: Bound<'_, PyDict> = class.getattr(intern!(py, "__dict__"))?.extract()?; + + // Ensure the class has completed its Pydantic setup + let is_complete: bool = class_dict + .get_as_req::>(intern!(py, "__pydantic_complete__")) + .is_ok_and(|b| b.extract().unwrap_or(false)); + + if !is_complete { + return Err(PyValueError::new_err("Prebuilt serializer not found.")); + } + + // Retrieve the prebuilt validator if available + let prebuilt_serializer: Bound<'_, PyAny> = class_dict.get_as_req(intern!(py, "__pydantic_serializer__"))?; + let schema_serializer: PyRef = prebuilt_serializer.extract()?; + let combined_serializer: Arc = schema_serializer.serializer.clone(); + + Ok(Self { + serializer: combined_serializer, + } + .into()) + } +} + +impl_py_gc_traverse!(PrebuiltSerializer { serializer }); + +impl TypeSerializer for PrebuiltSerializer { + fn to_python( + &self, + value: &Bound<'_, PyAny>, + include: Option<&Bound<'_, PyAny>>, + exclude: Option<&Bound<'_, PyAny>>, + extra: &Extra, + ) -> PyResult { + self.serializer.to_python(value, include, exclude, extra) + } + + fn json_key<'a>(&self, key: &'a Bound<'_, PyAny>, extra: &Extra) -> PyResult> { + self.serializer.json_key(key, extra) + } + + fn serde_serialize( + &self, + value: &Bound<'_, PyAny>, + serializer: S, + include: Option<&Bound<'_, PyAny>>, + exclude: Option<&Bound<'_, PyAny>>, + extra: &Extra, + ) -> Result { + self.serializer + .serde_serialize(value, serializer, include, exclude, extra) + } + + fn get_name(&self) -> &str { + self.serializer.get_name() + } + + fn retry_with_lax_check(&self) -> bool { + self.serializer.retry_with_lax_check() + } +} diff --git a/src/serializers/shared.rs b/src/serializers/shared.rs index f37810657..741ce19c2 100644 --- a/src/serializers/shared.rs +++ b/src/serializers/shared.rs @@ -84,6 +84,8 @@ combined_serializer! { Function: super::type_serializers::function::FunctionPlainSerializer; FunctionWrap: super::type_serializers::function::FunctionWrapSerializer; Fields: super::fields::GeneralFieldsSerializer; + // prebuilt serializers are manually constructed, and thus manually added to the `CombinedSerializer` enum + Prebuilt: super::prebuilt::PrebuiltSerializer; } // `find_only` is for type_serializers which are built directly via the `type` key and `find_serializer` // but aren't actually used for serialization, e.g. their `build` method must return another serializer @@ -195,7 +197,16 @@ impl CombinedSerializer { } let type_: Bound<'_, PyString> = schema.get_as_req(type_key)?; - Self::find_serializer(type_.to_str()?, schema, config, definitions) + let type_ = type_.to_str()?; + + // if we have a SchemaValidator on the type already, use it + if matches!(type_, "model" | "dataclass" | "typed-dict") { + if let Ok(prebuilt_serializer) = super::prebuilt::PrebuiltSerializer::build(schema, config, definitions) { + return Ok(prebuilt_serializer); + } + } + + Self::find_serializer(type_, schema, config, definitions) } } @@ -219,6 +230,7 @@ impl PyGcTraverse for CombinedSerializer { CombinedSerializer::Function(inner) => inner.py_gc_traverse(visit), CombinedSerializer::FunctionWrap(inner) => inner.py_gc_traverse(visit), CombinedSerializer::Fields(inner) => inner.py_gc_traverse(visit), + CombinedSerializer::Prebuilt(inner) => inner.py_gc_traverse(visit), CombinedSerializer::None(inner) => inner.py_gc_traverse(visit), CombinedSerializer::Nullable(inner) => inner.py_gc_traverse(visit), CombinedSerializer::Int(inner) => inner.py_gc_traverse(visit), From 8bd3ef27cefc5e49e412e739d650714dc5f65fd7 Mon Sep 17 00:00:00 2001 From: sydney-runkle Date: Fri, 31 Jan 2025 11:57:36 -0500 Subject: [PATCH 05/11] handling for mappingproxy rather than dict on classes --- src/serializers/prebuilt.rs | 15 ++++++++------- src/validators/prebuilt.rs | 15 ++++++++------- 2 files changed, 16 insertions(+), 14 deletions(-) diff --git a/src/serializers/prebuilt.rs b/src/serializers/prebuilt.rs index 5b81985e0..fd0974c09 100644 --- a/src/serializers/prebuilt.rs +++ b/src/serializers/prebuilt.rs @@ -4,7 +4,7 @@ use std::sync::Arc; use pyo3::exceptions::PyValueError; use pyo3::intern; use pyo3::prelude::*; -use pyo3::types::{PyBool, PyDict, PyType}; +use pyo3::types::{PyDict, PyType}; use crate::definitions::DefinitionsBuilder; use crate::tools::SchemaDict; @@ -29,13 +29,14 @@ impl BuildSerializer for PrebuiltSerializer { let py = schema.py(); let class: Bound<'_, PyType> = schema.get_as_req(intern!(py, "cls"))?; - // note: we NEED to use the __dict__ here (and perform get_item calls rather than getattr) - // because we don't want to fetch prebuilt serializers from parent classes - let class_dict: Bound<'_, PyDict> = class.getattr(intern!(py, "__dict__"))?.extract()?; + // Note: we NEED to use the __dict__ here (and perform get_item calls rather than getattr) + // because we don't want to fetch prebuilt validators from parent classes. + // We don't downcast here because __dict__ on a class is a readonly mappingproxy, + // so we can just leave it as is and do get_item checks. + let class_dict = class.getattr(intern!(py, "__dict__"))?; - // Ensure the class has completed its Pydantic setup let is_complete: bool = class_dict - .get_as_req::>(intern!(py, "__pydantic_complete__")) + .get_item(intern!(py, "__pydantic_complete__")) .is_ok_and(|b| b.extract().unwrap_or(false)); if !is_complete { @@ -43,7 +44,7 @@ impl BuildSerializer for PrebuiltSerializer { } // Retrieve the prebuilt validator if available - let prebuilt_serializer: Bound<'_, PyAny> = class_dict.get_as_req(intern!(py, "__pydantic_serializer__"))?; + let prebuilt_serializer: Bound<'_, PyAny> = class_dict.get_item(intern!(py, "__pydantic_serializer__"))?; let schema_serializer: PyRef = prebuilt_serializer.extract()?; let combined_serializer: Arc = schema_serializer.serializer.clone(); diff --git a/src/validators/prebuilt.rs b/src/validators/prebuilt.rs index d4e9289dd..eb74c3574 100644 --- a/src/validators/prebuilt.rs +++ b/src/validators/prebuilt.rs @@ -3,7 +3,7 @@ use std::sync::Arc; use pyo3::exceptions::PyValueError; use pyo3::intern; use pyo3::prelude::*; -use pyo3::types::{PyBool, PyDict, PyType}; +use pyo3::types::{PyDict, PyType}; use crate::errors::ValResult; use crate::input::Input; @@ -29,13 +29,14 @@ impl BuildValidator for PrebuiltValidator { let py = schema.py(); let class: Bound<'_, PyType> = schema.get_as_req(intern!(py, "cls"))?; - // note: we NEED to use the __dict__ here (and perform get_item calls rather than getattr) - // because we don't want to fetch prebuilt validators from parent classes - let class_dict: Bound<'_, PyDict> = class.getattr(intern!(py, "__dict__"))?.extract()?; + // Note: we NEED to use the __dict__ here (and perform get_item calls rather than getattr) + // because we don't want to fetch prebuilt validators from parent classes. + // We don't downcast here because __dict__ on a class is a readonly mappingproxy, + // so we can just leave it as is and do get_item checks. + let class_dict = class.getattr(intern!(py, "__dict__"))?; - // Ensure the class has completed its Pydantic validation setup let is_complete: bool = class_dict - .get_as_req::>(intern!(py, "__pydantic_complete__")) + .get_item(intern!(py, "__pydantic_complete__")) .is_ok_and(|b| b.extract().unwrap_or(false)); if !is_complete { @@ -43,7 +44,7 @@ impl BuildValidator for PrebuiltValidator { } // Retrieve the prebuilt validator if available - let prebuilt_validator: Bound<'_, PyAny> = class_dict.get_as_req(intern!(py, "__pydantic_validator__"))?; + let prebuilt_validator = class_dict.get_item(intern!(py, "__pydantic_validator__"))?; let schema_validator: PyRef = prebuilt_validator.extract()?; let combined_validator: Arc = schema_validator.validator.clone(); let name: String = class.getattr(intern!(py, "__name__"))?.extract()?; From 8e754c4e4b8bb6cb526626488e51a256620170da Mon Sep 17 00:00:00 2001 From: sydney-runkle Date: Mon, 3 Feb 2025 15:41:07 -0500 Subject: [PATCH 06/11] using more simple py approach --- src/serializers/mod.rs | 5 ++--- src/serializers/prebuilt.rs | 24 ++++++++++++------------ src/validators/mod.rs | 7 +++---- src/validators/prebuilt.rs | 17 +++++------------ 4 files changed, 22 insertions(+), 31 deletions(-) diff --git a/src/serializers/mod.rs b/src/serializers/mod.rs index ed524e09b..e652a75de 100644 --- a/src/serializers/mod.rs +++ b/src/serializers/mod.rs @@ -1,6 +1,5 @@ use std::fmt::Debug; use std::sync::atomic::{AtomicUsize, Ordering}; -use std::sync::Arc; use pyo3::prelude::*; use pyo3::types::{PyBytes, PyDict, PyTuple, PyType}; @@ -39,7 +38,7 @@ pub enum WarningsArg { #[pyclass(module = "pydantic_core._pydantic_core", frozen)] #[derive(Debug)] pub struct SchemaSerializer { - serializer: Arc, + serializer: CombinedSerializer, definitions: Definitions, expected_json_size: AtomicUsize, config: SerializationConfig, @@ -94,7 +93,7 @@ impl SchemaSerializer { let mut definitions_builder = DefinitionsBuilder::new(); let serializer = CombinedSerializer::build(schema.downcast()?, config, &mut definitions_builder)?; Ok(Self { - serializer: Arc::new(serializer), + serializer, definitions: definitions_builder.finish()?, expected_json_size: AtomicUsize::new(1024), config: SerializationConfig::from_config(config)?, diff --git a/src/serializers/prebuilt.rs b/src/serializers/prebuilt.rs index fd0974c09..238e262a4 100644 --- a/src/serializers/prebuilt.rs +++ b/src/serializers/prebuilt.rs @@ -1,5 +1,4 @@ use std::borrow::Cow; -use std::sync::Arc; use pyo3::exceptions::PyValueError; use pyo3::intern; @@ -15,7 +14,7 @@ use super::shared::{BuildSerializer, CombinedSerializer, TypeSerializer}; #[derive(Debug)] pub struct PrebuiltSerializer { - serializer: Arc, + serializer: Py, } impl BuildSerializer for PrebuiltSerializer { @@ -45,13 +44,9 @@ impl BuildSerializer for PrebuiltSerializer { // Retrieve the prebuilt validator if available let prebuilt_serializer: Bound<'_, PyAny> = class_dict.get_item(intern!(py, "__pydantic_serializer__"))?; - let schema_serializer: PyRef = prebuilt_serializer.extract()?; - let combined_serializer: Arc = schema_serializer.serializer.clone(); + let serializer: Py = prebuilt_serializer.extract()?; - Ok(Self { - serializer: combined_serializer, - } - .into()) + Ok(Self { serializer }.into()) } } @@ -65,11 +60,14 @@ impl TypeSerializer for PrebuiltSerializer { exclude: Option<&Bound<'_, PyAny>>, extra: &Extra, ) -> PyResult { - self.serializer.to_python(value, include, exclude, extra) + self.serializer + .get() + .serializer + .to_python(value, include, exclude, extra) } fn json_key<'a>(&self, key: &'a Bound<'_, PyAny>, extra: &Extra) -> PyResult> { - self.serializer.json_key(key, extra) + self.serializer.get().serializer.json_key(key, extra) } fn serde_serialize( @@ -81,14 +79,16 @@ impl TypeSerializer for PrebuiltSerializer { extra: &Extra, ) -> Result { self.serializer + .get() + .serializer .serde_serialize(value, serializer, include, exclude, extra) } fn get_name(&self) -> &str { - self.serializer.get_name() + self.serializer.get().serializer.get_name() } fn retry_with_lax_check(&self) -> bool { - self.serializer.retry_with_lax_check() + self.serializer.get().serializer.retry_with_lax_check() } } diff --git a/src/validators/mod.rs b/src/validators/mod.rs index 2c1562229..5f2fa8d70 100644 --- a/src/validators/mod.rs +++ b/src/validators/mod.rs @@ -1,5 +1,4 @@ use std::fmt::Debug; -use std::sync::Arc; use enum_dispatch::enum_dispatch; use jiter::{PartialMode, StringCacheMode}; @@ -107,7 +106,7 @@ impl PySome { #[pyclass(module = "pydantic_core._pydantic_core", frozen)] #[derive(Debug)] pub struct SchemaValidator { - validator: Arc, + validator: CombinedValidator, definitions: Definitions, // References to the Python schema and config objects are saved to enable // reconstructing the object for cloudpickle support (see `__reduce__`). @@ -148,7 +147,7 @@ impl SchemaValidator { .get_as(intern!(py, "cache_strings"))? .unwrap_or(StringCacheMode::All); Ok(Self { - validator: Arc::new(validator), + validator, definitions, py_schema, py_config, @@ -457,7 +456,7 @@ impl<'py> SelfValidator<'py> { }; let definitions = definitions_builder.finish()?; Ok(SchemaValidator { - validator: Arc::new(validator), + validator, definitions, py_schema: py.None(), py_config: None, diff --git a/src/validators/prebuilt.rs b/src/validators/prebuilt.rs index eb74c3574..8900e5093 100644 --- a/src/validators/prebuilt.rs +++ b/src/validators/prebuilt.rs @@ -1,5 +1,3 @@ -use std::sync::Arc; - use pyo3::exceptions::PyValueError; use pyo3::intern; use pyo3::prelude::*; @@ -14,7 +12,7 @@ use super::{BuildValidator, CombinedValidator, DefinitionsBuilder, SchemaValidat #[derive(Debug)] pub struct PrebuiltValidator { - validator: Arc, + schema_validator: Py, name: String, } @@ -45,19 +43,14 @@ impl BuildValidator for PrebuiltValidator { // Retrieve the prebuilt validator if available let prebuilt_validator = class_dict.get_item(intern!(py, "__pydantic_validator__"))?; - let schema_validator: PyRef = prebuilt_validator.extract()?; - let combined_validator: Arc = schema_validator.validator.clone(); + let schema_validator = prebuilt_validator.extract::>()?; let name: String = class.getattr(intern!(py, "__name__"))?.extract()?; - Ok(Self { - validator: combined_validator, - name, - } - .into()) + Ok(Self { schema_validator, name }.into()) } } -impl_py_gc_traverse!(PrebuiltValidator { validator }); +impl_py_gc_traverse!(PrebuiltValidator { schema_validator }); impl Validator for PrebuiltValidator { fn validate<'py>( @@ -66,7 +59,7 @@ impl Validator for PrebuiltValidator { input: &(impl Input<'py> + ?Sized), state: &mut ValidationState<'_, 'py>, ) -> ValResult { - self.validator.validate(py, input, state) + self.schema_validator.get().validator.validate(py, input, state) } fn get_name(&self) -> &str { From 57671e8d2009041b4e6ae27a2a095260a30f09aa Mon Sep 17 00:00:00 2001 From: sydney-runkle Date: Tue, 4 Feb 2025 10:37:23 -0500 Subject: [PATCH 07/11] edge case for generic dataclasses --- src/serializers/shared.rs | 6 +++++- src/validators/mod.rs | 9 +++++++-- 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/src/serializers/shared.rs b/src/serializers/shared.rs index 741ce19c2..a989ae6da 100644 --- a/src/serializers/shared.rs +++ b/src/serializers/shared.rs @@ -200,7 +200,11 @@ impl CombinedSerializer { let type_ = type_.to_str()?; // if we have a SchemaValidator on the type already, use it - if matches!(type_, "model" | "dataclass" | "typed-dict") { + // however, we don't want to use a prebuilt validator for dataclasses if we have a generic_origin + // because __pydantic_serializer__ is cached on the unparametrized dataclass + if matches!(type_, "model" | "typed-dict") + || matches!(type_, "dataclass") && !schema.contains(intern!(py, "generic_origin"))? + { if let Ok(prebuilt_serializer) = super::prebuilt::PrebuiltSerializer::build(schema, config, definitions) { return Ok(prebuilt_serializer); } diff --git a/src/validators/mod.rs b/src/validators/mod.rs index 5f2fa8d70..db3bdf506 100644 --- a/src/validators/mod.rs +++ b/src/validators/mod.rs @@ -516,11 +516,16 @@ pub fn build_validator( definitions: &mut DefinitionsBuilder, ) -> PyResult { let dict = schema.downcast::()?; - let type_: Bound<'_, PyString> = dict.get_as_req(intern!(schema.py(), "type"))?; + let py = schema.py(); + let type_: Bound<'_, PyString> = dict.get_as_req(intern!(py, "type"))?; let type_ = type_.to_str()?; // if we have a SchemaValidator on the type already, use it - if matches!(type_, "model" | "dataclass" | "typed-dict") { + // however, we don't want to use a prebuilt validator for dataclasses if we have a generic_origin + // because __pydantic_validator__ is cached on the unparametrized dataclass + if matches!(type_, "model" | "typed-dict") + || matches!(type_, "dataclass") && !dict.contains(intern!(py, "generic_origin"))? + { if let Ok(prebuilt_validator) = prebuilt::PrebuiltValidator::build(dict, config, definitions) { return Ok(prebuilt_validator); } From a8a7c5ee2a009ce004e07e06b8813c7b0c449d9e Mon Sep 17 00:00:00 2001 From: sydney-runkle Date: Wed, 5 Feb 2025 10:22:26 -0500 Subject: [PATCH 08/11] fix name function for validator --- src/validators/prebuilt.rs | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/validators/prebuilt.rs b/src/validators/prebuilt.rs index 8900e5093..da3aea161 100644 --- a/src/validators/prebuilt.rs +++ b/src/validators/prebuilt.rs @@ -13,7 +13,6 @@ use super::{BuildValidator, CombinedValidator, DefinitionsBuilder, SchemaValidat #[derive(Debug)] pub struct PrebuiltValidator { schema_validator: Py, - name: String, } impl BuildValidator for PrebuiltValidator { @@ -44,9 +43,8 @@ impl BuildValidator for PrebuiltValidator { // Retrieve the prebuilt validator if available let prebuilt_validator = class_dict.get_item(intern!(py, "__pydantic_validator__"))?; let schema_validator = prebuilt_validator.extract::>()?; - let name: String = class.getattr(intern!(py, "__name__"))?.extract()?; - Ok(Self { schema_validator, name }.into()) + Ok(Self { schema_validator }.into()) } } @@ -63,6 +61,6 @@ impl Validator for PrebuiltValidator { } fn get_name(&self) -> &str { - &self.name + self.schema_validator.get().validator.get_name() } } From 78f18b3bed539ee52682c0ebca346eacab92765a Mon Sep 17 00:00:00 2001 From: sydney-runkle Date: Wed, 5 Feb 2025 10:35:51 -0500 Subject: [PATCH 09/11] restructuring recs from david --- src/serializers/prebuilt.rs | 28 +++++++++++++++------------- src/serializers/shared.rs | 10 ++-------- src/validators/mod.rs | 10 ++-------- src/validators/prebuilt.rs | 27 +++++++++++++++------------ 4 files changed, 34 insertions(+), 41 deletions(-) diff --git a/src/serializers/prebuilt.rs b/src/serializers/prebuilt.rs index 238e262a4..2debf9432 100644 --- a/src/serializers/prebuilt.rs +++ b/src/serializers/prebuilt.rs @@ -1,31 +1,33 @@ use std::borrow::Cow; -use pyo3::exceptions::PyValueError; use pyo3::intern; use pyo3::prelude::*; use pyo3::types::{PyDict, PyType}; -use crate::definitions::DefinitionsBuilder; use crate::tools::SchemaDict; use crate::SchemaSerializer; use super::extra::Extra; -use super::shared::{BuildSerializer, CombinedSerializer, TypeSerializer}; +use super::shared::{CombinedSerializer, TypeSerializer}; #[derive(Debug)] pub struct PrebuiltSerializer { serializer: Py, } -impl BuildSerializer for PrebuiltSerializer { - const EXPECTED_TYPE: &'static str = "prebuilt"; - - fn build( - schema: &Bound<'_, PyDict>, - _config: Option<&Bound<'_, PyDict>>, - _definitions: &mut DefinitionsBuilder, - ) -> PyResult { +impl PrebuiltSerializer { + pub fn try_get_from_schema(type_: &str, schema: &Bound<'_, PyDict>) -> PyResult> { let py = schema.py(); + + // we can only use prebuilt serializeres from models, typed dicts, and dataclasses + // however, we don't want to use a prebuilt serializer for dataclasses if we have a generic_origin + // because __pydantic_serializer__ is cached on the unparametrized dataclass + if !matches!(type_, "model" | "typed-dict") + || matches!(type_, "dataclass") && schema.contains(intern!(py, "generic_origin"))? + { + return Ok(None); + } + let class: Bound<'_, PyType> = schema.get_as_req(intern!(py, "cls"))?; // Note: we NEED to use the __dict__ here (and perform get_item calls rather than getattr) @@ -39,14 +41,14 @@ impl BuildSerializer for PrebuiltSerializer { .is_ok_and(|b| b.extract().unwrap_or(false)); if !is_complete { - return Err(PyValueError::new_err("Prebuilt serializer not found.")); + return Ok(None); } // Retrieve the prebuilt validator if available let prebuilt_serializer: Bound<'_, PyAny> = class_dict.get_item(intern!(py, "__pydantic_serializer__"))?; let serializer: Py = prebuilt_serializer.extract()?; - Ok(Self { serializer }.into()) + Ok(Some(Self { serializer }.into())) } } diff --git a/src/serializers/shared.rs b/src/serializers/shared.rs index a989ae6da..f7a018749 100644 --- a/src/serializers/shared.rs +++ b/src/serializers/shared.rs @@ -200,14 +200,8 @@ impl CombinedSerializer { let type_ = type_.to_str()?; // if we have a SchemaValidator on the type already, use it - // however, we don't want to use a prebuilt validator for dataclasses if we have a generic_origin - // because __pydantic_serializer__ is cached on the unparametrized dataclass - if matches!(type_, "model" | "typed-dict") - || matches!(type_, "dataclass") && !schema.contains(intern!(py, "generic_origin"))? - { - if let Ok(prebuilt_serializer) = super::prebuilt::PrebuiltSerializer::build(schema, config, definitions) { - return Ok(prebuilt_serializer); - } + if let Ok(Some(prebuilt_serializer)) = super::prebuilt::PrebuiltSerializer::try_get_from_schema(type_, schema) { + return Ok(prebuilt_serializer); } Self::find_serializer(type_, schema, config, definitions) diff --git a/src/validators/mod.rs b/src/validators/mod.rs index db3bdf506..75f39df29 100644 --- a/src/validators/mod.rs +++ b/src/validators/mod.rs @@ -521,14 +521,8 @@ pub fn build_validator( let type_ = type_.to_str()?; // if we have a SchemaValidator on the type already, use it - // however, we don't want to use a prebuilt validator for dataclasses if we have a generic_origin - // because __pydantic_validator__ is cached on the unparametrized dataclass - if matches!(type_, "model" | "typed-dict") - || matches!(type_, "dataclass") && !dict.contains(intern!(py, "generic_origin"))? - { - if let Ok(prebuilt_validator) = prebuilt::PrebuiltValidator::build(dict, config, definitions) { - return Ok(prebuilt_validator); - } + if let Ok(Some(prebuilt_validator)) = prebuilt::PrebuiltValidator::try_get_from_schema(type_, dict) { + return Ok(prebuilt_validator); } validator_match!( diff --git a/src/validators/prebuilt.rs b/src/validators/prebuilt.rs index da3aea161..c3fffe334 100644 --- a/src/validators/prebuilt.rs +++ b/src/validators/prebuilt.rs @@ -1,4 +1,3 @@ -use pyo3::exceptions::PyValueError; use pyo3::intern; use pyo3::prelude::*; use pyo3::types::{PyDict, PyType}; @@ -8,22 +7,26 @@ use crate::input::Input; use crate::tools::SchemaDict; use super::ValidationState; -use super::{BuildValidator, CombinedValidator, DefinitionsBuilder, SchemaValidator, Validator}; +use super::{CombinedValidator, SchemaValidator, Validator}; #[derive(Debug)] pub struct PrebuiltValidator { schema_validator: Py, } -impl BuildValidator for PrebuiltValidator { - const EXPECTED_TYPE: &'static str = "prebuilt"; - - fn build( - schema: &Bound<'_, PyDict>, - _config: Option<&Bound<'_, PyDict>>, - _definitions: &mut DefinitionsBuilder, - ) -> PyResult { +impl PrebuiltValidator { + pub fn try_get_from_schema(type_: &str, schema: &Bound<'_, PyDict>) -> PyResult> { let py = schema.py(); + + // we can only use prebuilt validators from models, typed dicts, and dataclasses + // however, we don't want to use a prebuilt validator for dataclasses if we have a generic_origin + // because __pydantic_validator__ is cached on the unparametrized dataclass + if !matches!(type_, "model" | "typed-dict") + || matches!(type_, "dataclass") && schema.contains(intern!(py, "generic_origin"))? + { + return Ok(None); + } + let class: Bound<'_, PyType> = schema.get_as_req(intern!(py, "cls"))?; // Note: we NEED to use the __dict__ here (and perform get_item calls rather than getattr) @@ -37,14 +40,14 @@ impl BuildValidator for PrebuiltValidator { .is_ok_and(|b| b.extract().unwrap_or(false)); if !is_complete { - return Err(PyValueError::new_err("Prebuilt validator not found.")); + return Ok(None); } // Retrieve the prebuilt validator if available let prebuilt_validator = class_dict.get_item(intern!(py, "__pydantic_validator__"))?; let schema_validator = prebuilt_validator.extract::>()?; - Ok(Self { schema_validator }.into()) + Ok(Some(Self { schema_validator }.into())) } } From bbaef68a4187091bfdb357fa61b435b8003fb964 Mon Sep 17 00:00:00 2001 From: sydney-runkle Date: Wed, 5 Feb 2025 10:52:31 -0500 Subject: [PATCH 10/11] confirming prebuilt usage via a test --- tests/validators/test_prebuilt.py | 48 +++++++++++++++++++++++++++++++ 1 file changed, 48 insertions(+) create mode 100644 tests/validators/test_prebuilt.py diff --git a/tests/validators/test_prebuilt.py b/tests/validators/test_prebuilt.py new file mode 100644 index 000000000..9cd5aa325 --- /dev/null +++ b/tests/validators/test_prebuilt.py @@ -0,0 +1,48 @@ +from pydantic_core import SchemaSerializer, SchemaValidator, core_schema + + +def test_prebuilt_val_and_ser_used() -> None: + class InnerModel: + x: int + + inner_schema = core_schema.model_schema( + InnerModel, + schema=core_schema.model_fields_schema( + {'x': core_schema.model_field(schema=core_schema.int_schema())}, + ), + ) + + inner_schema_validator = SchemaValidator(inner_schema) + inner_schema_serializer = SchemaSerializer(inner_schema) + InnerModel.__pydantic_complete__ = True # pyright: ignore[reportAttributeAccessIssue] + InnerModel.__pydantic_validator__ = inner_schema_validator # pyright: ignore[reportAttributeAccessIssue] + InnerModel.__pydantic_serializer__ = inner_schema_serializer # pyright: ignore[reportAttributeAccessIssue] + + class OuterModel: + inner: InnerModel + + outer_schema = core_schema.model_schema( + OuterModel, + schema=core_schema.model_fields_schema( + { + 'inner': core_schema.model_field( + schema=core_schema.model_schema( + InnerModel, + schema=core_schema.model_fields_schema( + # note, we use str schema here even though that's incorrect + # in order to verify that the prebuilt validator is used + # off of InnerModel with the correct int schema, not this str schema + {'x': core_schema.model_field(schema=core_schema.str_schema())}, + ), + ) + ) + } + ), + ) + + outer_validator = SchemaValidator(outer_schema) + outer_serializer = SchemaSerializer(outer_schema) + + result = outer_validator.validate_python({'inner': {'x': 1}}) + assert result.inner.x == 1 + assert outer_serializer.to_python(result) == {'inner': {'x': 1}} From 99136a5fc3d819c1cdb808e7fb3d638bd08246e0 Mon Sep 17 00:00:00 2001 From: sydney-runkle Date: Wed, 5 Feb 2025 11:27:20 -0500 Subject: [PATCH 11/11] refactor common extraction logic --- src/common/mod.rs | 1 + src/common/prebuilt.rs | 43 +++++++++++++++++++ src/serializers/prebuilt.rs | 56 +++++++------------------ src/validators/prebuilt.rs | 42 ++++--------------- tests/{validators => }/test_prebuilt.py | 0 5 files changed, 65 insertions(+), 77 deletions(-) create mode 100644 src/common/prebuilt.rs rename tests/{validators => }/test_prebuilt.py (100%) diff --git a/src/common/mod.rs b/src/common/mod.rs index 11f2e1ece..47c0a0349 100644 --- a/src/common/mod.rs +++ b/src/common/mod.rs @@ -1 +1,2 @@ +pub(crate) mod prebuilt; pub(crate) mod union; diff --git a/src/common/prebuilt.rs b/src/common/prebuilt.rs new file mode 100644 index 000000000..961123691 --- /dev/null +++ b/src/common/prebuilt.rs @@ -0,0 +1,43 @@ +use pyo3::intern; +use pyo3::prelude::*; +use pyo3::types::{PyAny, PyDict, PyType}; + +use crate::tools::SchemaDict; + +pub fn get_prebuilt( + type_: &str, + schema: &Bound<'_, PyDict>, + prebuilt_attr_name: &str, + extractor: impl FnOnce(Bound<'_, PyAny>) -> PyResult, +) -> PyResult> { + let py = schema.py(); + + // we can only use prebuilt validators / serializers from models, typed dicts, and dataclasses + // however, we don't want to use a prebuilt structure from dataclasses if we have a generic_origin + // because the validator / serializer is cached on the unparametrized dataclass + if !matches!(type_, "model" | "typed-dict") + || matches!(type_, "dataclass") && schema.contains(intern!(py, "generic_origin"))? + { + return Ok(None); + } + + let class: Bound<'_, PyType> = schema.get_as_req(intern!(py, "cls"))?; + + // Note: we NEED to use the __dict__ here (and perform get_item calls rather than getattr) + // because we don't want to fetch prebuilt validators from parent classes. + // We don't downcast here because __dict__ on a class is a readonly mappingproxy, + // so we can just leave it as is and do get_item checks. + let class_dict = class.getattr(intern!(py, "__dict__"))?; + + let is_complete: bool = class_dict + .get_item(intern!(py, "__pydantic_complete__")) + .is_ok_and(|b| b.extract().unwrap_or(false)); + + if !is_complete { + return Ok(None); + } + + // Retrieve the prebuilt validator / serializer if available + let prebuilt: Bound<'_, PyAny> = class_dict.get_item(prebuilt_attr_name)?; + extractor(prebuilt).map(Some) +} diff --git a/src/serializers/prebuilt.rs b/src/serializers/prebuilt.rs index 2debf9432..33d197d9b 100644 --- a/src/serializers/prebuilt.rs +++ b/src/serializers/prebuilt.rs @@ -1,10 +1,9 @@ use std::borrow::Cow; -use pyo3::intern; use pyo3::prelude::*; -use pyo3::types::{PyDict, PyType}; +use pyo3::types::PyDict; -use crate::tools::SchemaDict; +use crate::common::prebuilt::get_prebuilt; use crate::SchemaSerializer; use super::extra::Extra; @@ -12,47 +11,20 @@ use super::shared::{CombinedSerializer, TypeSerializer}; #[derive(Debug)] pub struct PrebuiltSerializer { - serializer: Py, + schema_serializer: Py, } impl PrebuiltSerializer { pub fn try_get_from_schema(type_: &str, schema: &Bound<'_, PyDict>) -> PyResult> { - let py = schema.py(); - - // we can only use prebuilt serializeres from models, typed dicts, and dataclasses - // however, we don't want to use a prebuilt serializer for dataclasses if we have a generic_origin - // because __pydantic_serializer__ is cached on the unparametrized dataclass - if !matches!(type_, "model" | "typed-dict") - || matches!(type_, "dataclass") && schema.contains(intern!(py, "generic_origin"))? - { - return Ok(None); - } - - let class: Bound<'_, PyType> = schema.get_as_req(intern!(py, "cls"))?; - - // Note: we NEED to use the __dict__ here (and perform get_item calls rather than getattr) - // because we don't want to fetch prebuilt validators from parent classes. - // We don't downcast here because __dict__ on a class is a readonly mappingproxy, - // so we can just leave it as is and do get_item checks. - let class_dict = class.getattr(intern!(py, "__dict__"))?; - - let is_complete: bool = class_dict - .get_item(intern!(py, "__pydantic_complete__")) - .is_ok_and(|b| b.extract().unwrap_or(false)); - - if !is_complete { - return Ok(None); - } - - // Retrieve the prebuilt validator if available - let prebuilt_serializer: Bound<'_, PyAny> = class_dict.get_item(intern!(py, "__pydantic_serializer__"))?; - let serializer: Py = prebuilt_serializer.extract()?; - - Ok(Some(Self { serializer }.into())) + get_prebuilt(type_, schema, "__pydantic_serializer__", |py_any| { + py_any + .extract::>() + .map(|schema_serializer| Self { schema_serializer }.into()) + }) } } -impl_py_gc_traverse!(PrebuiltSerializer { serializer }); +impl_py_gc_traverse!(PrebuiltSerializer { schema_serializer }); impl TypeSerializer for PrebuiltSerializer { fn to_python( @@ -62,14 +34,14 @@ impl TypeSerializer for PrebuiltSerializer { exclude: Option<&Bound<'_, PyAny>>, extra: &Extra, ) -> PyResult { - self.serializer + self.schema_serializer .get() .serializer .to_python(value, include, exclude, extra) } fn json_key<'a>(&self, key: &'a Bound<'_, PyAny>, extra: &Extra) -> PyResult> { - self.serializer.get().serializer.json_key(key, extra) + self.schema_serializer.get().serializer.json_key(key, extra) } fn serde_serialize( @@ -80,17 +52,17 @@ impl TypeSerializer for PrebuiltSerializer { exclude: Option<&Bound<'_, PyAny>>, extra: &Extra, ) -> Result { - self.serializer + self.schema_serializer .get() .serializer .serde_serialize(value, serializer, include, exclude, extra) } fn get_name(&self) -> &str { - self.serializer.get().serializer.get_name() + self.schema_serializer.get().serializer.get_name() } fn retry_with_lax_check(&self) -> bool { - self.serializer.get().serializer.retry_with_lax_check() + self.schema_serializer.get().serializer.retry_with_lax_check() } } diff --git a/src/validators/prebuilt.rs b/src/validators/prebuilt.rs index c3fffe334..c17acb9f9 100644 --- a/src/validators/prebuilt.rs +++ b/src/validators/prebuilt.rs @@ -1,10 +1,9 @@ -use pyo3::intern; use pyo3::prelude::*; -use pyo3::types::{PyDict, PyType}; +use pyo3::types::PyDict; +use crate::common::prebuilt::get_prebuilt; use crate::errors::ValResult; use crate::input::Input; -use crate::tools::SchemaDict; use super::ValidationState; use super::{CombinedValidator, SchemaValidator, Validator}; @@ -16,38 +15,11 @@ pub struct PrebuiltValidator { impl PrebuiltValidator { pub fn try_get_from_schema(type_: &str, schema: &Bound<'_, PyDict>) -> PyResult> { - let py = schema.py(); - - // we can only use prebuilt validators from models, typed dicts, and dataclasses - // however, we don't want to use a prebuilt validator for dataclasses if we have a generic_origin - // because __pydantic_validator__ is cached on the unparametrized dataclass - if !matches!(type_, "model" | "typed-dict") - || matches!(type_, "dataclass") && schema.contains(intern!(py, "generic_origin"))? - { - return Ok(None); - } - - let class: Bound<'_, PyType> = schema.get_as_req(intern!(py, "cls"))?; - - // Note: we NEED to use the __dict__ here (and perform get_item calls rather than getattr) - // because we don't want to fetch prebuilt validators from parent classes. - // We don't downcast here because __dict__ on a class is a readonly mappingproxy, - // so we can just leave it as is and do get_item checks. - let class_dict = class.getattr(intern!(py, "__dict__"))?; - - let is_complete: bool = class_dict - .get_item(intern!(py, "__pydantic_complete__")) - .is_ok_and(|b| b.extract().unwrap_or(false)); - - if !is_complete { - return Ok(None); - } - - // Retrieve the prebuilt validator if available - let prebuilt_validator = class_dict.get_item(intern!(py, "__pydantic_validator__"))?; - let schema_validator = prebuilt_validator.extract::>()?; - - Ok(Some(Self { schema_validator }.into())) + get_prebuilt(type_, schema, "__pydantic_validator__", |py_any| { + py_any + .extract::>() + .map(|schema_validator| Self { schema_validator }.into()) + }) } } diff --git a/tests/validators/test_prebuilt.py b/tests/test_prebuilt.py similarity index 100% rename from tests/validators/test_prebuilt.py rename to tests/test_prebuilt.py