Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions src/serializers/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use std::fmt::Debug;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;

use pyo3::prelude::*;
use pyo3::types::{PyBytes, PyDict, PyTuple, PyType};
Expand All @@ -24,6 +25,7 @@ mod fields;
mod filter;
mod infer;
mod ob_type;
mod prebuilt;
pub mod ser;
mod shared;
mod type_serializers;
Expand All @@ -37,7 +39,7 @@ pub enum WarningsArg {
#[pyclass(module = "pydantic_core._pydantic_core", frozen)]
#[derive(Debug)]
pub struct SchemaSerializer {
serializer: CombinedSerializer,
serializer: Arc<CombinedSerializer>,
definitions: Definitions<CombinedSerializer>,
expected_json_size: AtomicUsize,
config: SerializationConfig,
Expand Down Expand Up @@ -92,7 +94,7 @@ impl SchemaSerializer {
let mut definitions_builder = DefinitionsBuilder::new();
let serializer = CombinedSerializer::build(schema.downcast()?, config, &mut definitions_builder)?;
Ok(Self {
serializer,
serializer: Arc::new(serializer),
definitions: definitions_builder.finish()?,
expected_json_size: AtomicUsize::new(1024),
config: SerializationConfig::from_config(config)?,
Expand Down
93 changes: 93 additions & 0 deletions src/serializers/prebuilt.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
use std::borrow::Cow;
use std::sync::Arc;

use pyo3::exceptions::PyValueError;
use pyo3::intern;
use pyo3::prelude::*;
use pyo3::types::{PyBool, PyDict, PyType};

use crate::definitions::DefinitionsBuilder;
use crate::tools::SchemaDict;
use crate::SchemaSerializer;

use super::extra::Extra;
use super::shared::{BuildSerializer, CombinedSerializer, TypeSerializer};

#[derive(Debug)]
pub struct PrebuiltSerializer {
serializer: Arc<CombinedSerializer>,
}

impl BuildSerializer for PrebuiltSerializer {
const EXPECTED_TYPE: &'static str = "prebuilt";

fn build(
schema: &Bound<'_, PyDict>,
_config: Option<&Bound<'_, PyDict>>,
_definitions: &mut DefinitionsBuilder<CombinedSerializer>,
) -> PyResult<CombinedSerializer> {
let py = schema.py();
let class: Bound<'_, PyType> = schema.get_as_req(intern!(py, "cls"))?;

// note: we NEED to use the __dict__ here (and perform get_item calls rather than getattr)
// because we don't want to fetch prebuilt serializers from parent classes
let class_dict: Bound<'_, PyDict> = class.getattr(intern!(py, "__dict__"))?.extract()?;

// Ensure the class has completed its Pydantic setup
let is_complete: bool = class_dict
.get_as_req::<Bound<'_, PyBool>>(intern!(py, "__pydantic_complete__"))
.is_ok_and(|b| b.extract().unwrap_or(false));

if !is_complete {
return Err(PyValueError::new_err("Prebuilt serializer not found."));
}

// Retrieve the prebuilt validator if available
let prebuilt_serializer: Bound<'_, PyAny> = class_dict.get_as_req(intern!(py, "__pydantic_serializer__"))?;
let schema_serializer: PyRef<SchemaSerializer> = prebuilt_serializer.extract()?;
let combined_serializer: Arc<CombinedSerializer> = schema_serializer.serializer.clone();

Ok(Self {
serializer: combined_serializer,
}
.into())
}
}

impl_py_gc_traverse!(PrebuiltSerializer { serializer });

impl TypeSerializer for PrebuiltSerializer {
fn to_python(
&self,
value: &Bound<'_, PyAny>,
include: Option<&Bound<'_, PyAny>>,
exclude: Option<&Bound<'_, PyAny>>,
extra: &Extra,
) -> PyResult<PyObject> {
self.serializer.to_python(value, include, exclude, extra)
}

fn json_key<'a>(&self, key: &'a Bound<'_, PyAny>, extra: &Extra) -> PyResult<Cow<'a, str>> {
self.serializer.json_key(key, extra)
}

fn serde_serialize<S: serde::ser::Serializer>(
&self,
value: &Bound<'_, PyAny>,
serializer: S,
include: Option<&Bound<'_, PyAny>>,
exclude: Option<&Bound<'_, PyAny>>,
extra: &Extra,
) -> Result<S::Ok, S::Error> {
self.serializer
.serde_serialize(value, serializer, include, exclude, extra)
}

fn get_name(&self) -> &str {
self.serializer.get_name()
}

fn retry_with_lax_check(&self) -> bool {
self.serializer.retry_with_lax_check()
}
}
14 changes: 13 additions & 1 deletion src/serializers/shared.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,8 @@ combined_serializer! {
Function: super::type_serializers::function::FunctionPlainSerializer;
FunctionWrap: super::type_serializers::function::FunctionWrapSerializer;
Fields: super::fields::GeneralFieldsSerializer;
// prebuilt serializers are manually constructed, and thus manually added to the `CombinedSerializer` enum
Prebuilt: super::prebuilt::PrebuiltSerializer;
}
// `find_only` is for type_serializers which are built directly via the `type` key and `find_serializer`
// but aren't actually used for serialization, e.g. their `build` method must return another serializer
Expand Down Expand Up @@ -195,7 +197,16 @@ impl CombinedSerializer {
}

let type_: Bound<'_, PyString> = schema.get_as_req(type_key)?;
Self::find_serializer(type_.to_str()?, schema, config, definitions)
let type_ = type_.to_str()?;

// if we have a SchemaValidator on the type already, use it
if matches!(type_, "model" | "dataclass" | "typed-dict") {
if let Ok(prebuilt_serializer) = super::prebuilt::PrebuiltSerializer::build(schema, config, definitions) {
return Ok(prebuilt_serializer);
}
}

Self::find_serializer(type_, schema, config, definitions)
}
}

Expand All @@ -219,6 +230,7 @@ impl PyGcTraverse for CombinedSerializer {
CombinedSerializer::Function(inner) => inner.py_gc_traverse(visit),
CombinedSerializer::FunctionWrap(inner) => inner.py_gc_traverse(visit),
CombinedSerializer::Fields(inner) => inner.py_gc_traverse(visit),
CombinedSerializer::Prebuilt(inner) => inner.py_gc_traverse(visit),
CombinedSerializer::None(inner) => inner.py_gc_traverse(visit),
CombinedSerializer::Nullable(inner) => inner.py_gc_traverse(visit),
CombinedSerializer::Int(inner) => inner.py_gc_traverse(visit),
Expand Down
18 changes: 15 additions & 3 deletions src/validators/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use std::fmt::Debug;
use std::sync::Arc;

use enum_dispatch::enum_dispatch;
use jiter::{PartialMode, StringCacheMode};
Expand Down Expand Up @@ -52,6 +53,7 @@ mod model;
mod model_fields;
mod none;
mod nullable;
mod prebuilt;
mod set;
mod string;
mod time;
Expand Down Expand Up @@ -105,7 +107,7 @@ impl PySome {
#[pyclass(module = "pydantic_core._pydantic_core", frozen)]
#[derive(Debug)]
pub struct SchemaValidator {
validator: CombinedValidator,
validator: Arc<CombinedValidator>,
definitions: Definitions<CombinedValidator>,
// References to the Python schema and config objects are saved to enable
// reconstructing the object for cloudpickle support (see `__reduce__`).
Expand Down Expand Up @@ -146,7 +148,7 @@ impl SchemaValidator {
.get_as(intern!(py, "cache_strings"))?
.unwrap_or(StringCacheMode::All);
Ok(Self {
validator,
validator: Arc::new(validator),
definitions,
py_schema,
py_config,
Expand Down Expand Up @@ -455,7 +457,7 @@ impl<'py> SelfValidator<'py> {
};
let definitions = definitions_builder.finish()?;
Ok(SchemaValidator {
validator,
validator: Arc::new(validator),
definitions,
py_schema: py.None(),
py_config: None,
Expand Down Expand Up @@ -517,6 +519,14 @@ pub fn build_validator(
let dict = schema.downcast::<PyDict>()?;
let type_: Bound<'_, PyString> = dict.get_as_req(intern!(schema.py(), "type"))?;
let type_ = type_.to_str()?;

// if we have a SchemaValidator on the type already, use it
if matches!(type_, "model" | "dataclass" | "typed-dict") {
if let Ok(prebuilt_validator) = prebuilt::PrebuiltValidator::build(dict, config, definitions) {
return Ok(prebuilt_validator);
}
}

validator_match!(
type_,
dict,
Expand Down Expand Up @@ -763,6 +773,8 @@ pub enum CombinedValidator {
// input dependent
JsonOrPython(json_or_python::JsonOrPython),
Complex(complex::ComplexValidator),
// uses a reference to an existing SchemaValidator to reduce memory usage
Prebuilt(prebuilt::PrebuiltValidator),
}

/// This trait must be implemented by all validators, it allows various validators to be accessed consistently,
Expand Down
74 changes: 74 additions & 0 deletions src/validators/prebuilt.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
use std::sync::Arc;

use pyo3::exceptions::PyValueError;
use pyo3::intern;
use pyo3::prelude::*;
use pyo3::types::{PyBool, PyDict, PyType};

use crate::errors::ValResult;
use crate::input::Input;
use crate::tools::SchemaDict;

use super::ValidationState;
use super::{BuildValidator, CombinedValidator, DefinitionsBuilder, SchemaValidator, Validator};

#[derive(Debug)]
pub struct PrebuiltValidator {
validator: Arc<CombinedValidator>,
name: String,
}

impl BuildValidator for PrebuiltValidator {
const EXPECTED_TYPE: &'static str = "prebuilt";

fn build(
schema: &Bound<'_, PyDict>,
_config: Option<&Bound<'_, PyDict>>,
_definitions: &mut DefinitionsBuilder<CombinedValidator>,
) -> PyResult<CombinedValidator> {
let py = schema.py();
let class: Bound<'_, PyType> = schema.get_as_req(intern!(py, "cls"))?;

// note: we NEED to use the __dict__ here (and perform get_item calls rather than getattr)
// because we don't want to fetch prebuilt validators from parent classes
let class_dict: Bound<'_, PyDict> = class.getattr(intern!(py, "__dict__"))?.extract()?;

// Ensure the class has completed its Pydantic validation setup
let is_complete: bool = class_dict
.get_as_req::<Bound<'_, PyBool>>(intern!(py, "__pydantic_complete__"))
.is_ok_and(|b| b.extract().unwrap_or(false));

if !is_complete {
return Err(PyValueError::new_err("Prebuilt validator not found."));
}

// Retrieve the prebuilt validator if available
let prebuilt_validator: Bound<'_, PyAny> = class_dict.get_as_req(intern!(py, "__pydantic_validator__"))?;
let schema_validator: PyRef<SchemaValidator> = prebuilt_validator.extract()?;
let combined_validator: Arc<CombinedValidator> = schema_validator.validator.clone();
let name: String = class.getattr(intern!(py, "__name__"))?.extract()?;

Ok(Self {
validator: combined_validator,
name,
}
.into())
}
}

impl_py_gc_traverse!(PrebuiltValidator { validator });

impl Validator for PrebuiltValidator {
fn validate<'py>(
&self,
py: Python<'py>,
input: &(impl Input<'py> + ?Sized),
state: &mut ValidationState<'_, 'py>,
) -> ValResult<PyObject> {
self.validator.validate(py, input, state)
}

fn get_name(&self) -> &str {
&self.name
}
}