Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 16 additions & 2 deletions src/serializers/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@ use pyo3::exceptions::PyTypeError;
use pyo3::intern;
use pyo3::prelude::*;
use pyo3::pybacked::PyBackedStr;
use pyo3::types::PyComplex;
use pyo3::types::{PyByteArray, PyBytes, PyDict, PyFrozenSet, PyIterator, PyList, PySet, PyString, PyTuple};
use pyo3::types::{
PyByteArray, PyBytes, PyComplex, PyDict, PyFrozenSet, PyIterator, PyList, PyModule, PySet, PyString, PyTuple,
};

use pyo3::IntoPyObjectExt;
use serde::ser::{Error, Serialize, SerializeMap, SerializeSeq, Serializer};
Expand Down Expand Up @@ -241,6 +242,10 @@ pub(crate) fn infer_to_python_known(
let complex_str = type_serializers::complex::complex_to_str(v);
complex_str.into_py_any(py)?
}
ObType::Module => {
let v = value.downcast::<PyModule>()?;
v.name()?.into()
}
ObType::Path => value.str()?.into_py_any(py)?,
ObType::Pattern => value.getattr(intern!(py, "pattern"))?.unbind(),
ObType::Unknown => {
Expand Down Expand Up @@ -554,6 +559,11 @@ pub(crate) fn infer_serialize_known<S: Serializer>(
.map_err(py_err_se_err)?;
serializer.serialize_str(&s)
}
ObType::Module => {
let v = value.downcast::<PyModule>().map_err(py_err_se_err)?;
let s: PyBackedStr = v.name().and_then(|name| name.extract()).map_err(py_err_se_err)?;
serializer.serialize_str(&s)
}
ObType::Unknown => {
if let Some(fallback) = extra.fallback {
let next_value = fallback.call1((value,)).map_err(py_err_se_err)?;
Expand Down Expand Up @@ -678,6 +688,10 @@ pub(crate) fn infer_json_key_known<'a>(
let v = key.downcast::<PyComplex>()?;
Ok(type_serializers::complex::complex_to_str(v).into())
}
ObType::Module => {
let v = key.downcast::<PyModule>()?;
Ok(Cow::Owned(v.name()?.to_string_lossy().into_owned()))
}
ObType::Pattern => Ok(Cow::Owned(
key.getattr(intern!(key.py(), "pattern"))?
.str()?
Expand Down
12 changes: 9 additions & 3 deletions src/serializers/ob_type.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use pyo3::prelude::*;
use pyo3::sync::PyOnceLock;
use pyo3::types::{
PyBool, PyByteArray, PyBytes, PyComplex, PyDate, PyDateTime, PyDelta, PyDict, PyFloat, PyFrozenSet, PyInt,
PyIterator, PyList, PyNone, PySet, PyString, PyTime, PyTuple, PyType,
PyIterator, PyList, PyModule, PyNone, PySet, PyString, PyTime, PyTuple, PyType,
};
use pyo3::{intern, PyTypeInfo};

Expand Down Expand Up @@ -48,6 +48,7 @@ pub struct ObTypeLookup {
pattern_object: Py<PyAny>,
// uuid type
uuid_object: Py<PyAny>,
module_object: usize,
complex: usize,
}

Expand Down Expand Up @@ -87,6 +88,7 @@ impl ObTypeLookup {
path_object: py.import("pathlib").unwrap().getattr("Path").unwrap().unbind(),
pattern_object: py.import("re").unwrap().getattr("Pattern").unwrap().unbind(),
uuid_object: py.import("uuid").unwrap().getattr("UUID").unwrap().unbind(),
module_object: PyModule::type_object_raw(py) as usize,
complex: PyComplex::type_object_raw(py) as usize,
}
}
Expand Down Expand Up @@ -157,8 +159,9 @@ impl ObTypeLookup {
ObType::Path => self.path_object.as_ptr() as usize == ob_type,
ObType::Pattern => self.path_object.as_ptr() as usize == ob_type,
ObType::Uuid => self.uuid_object.as_ptr() as usize == ob_type,
ObType::Unknown => false,
ObType::Complex => self.complex == ob_type,
ObType::Module => self.module_object == ob_type,
ObType::Unknown => false,
};

if ans {
Expand Down Expand Up @@ -241,6 +244,8 @@ impl ObTypeLookup {
ObType::Complex
} else if ob_type == self.uuid_object.as_ptr() as usize {
ObType::Uuid
} else if ob_type == self.module_object {
ObType::Module
} else if is_pydantic_serializable(op_value) {
ObType::PydanticSerializable
} else if is_dataclass(op_value) {
Expand Down Expand Up @@ -414,9 +419,10 @@ pub enum ObType {
Pattern,
// Uuid
Uuid,
Complex,
Module,
// unknown type
Unknown,
Complex,
}

impl PartialEq for ObType {
Expand Down
24 changes: 21 additions & 3 deletions tests/serializers/test_infer.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,46 @@
import os
from enum import Enum

from pydantic_core import SchemaSerializer, core_schema


# serializing enum calls methods in serializers::infer
def test_infer_to_python():
def test_infer_complex_to_python():
class MyEnum(Enum):
complex_ = complex(1, 2)

v = SchemaSerializer(core_schema.enum_schema(MyEnum, list(MyEnum.__members__.values())))
assert v.to_python(MyEnum.complex_, mode='json') == '1+2j'


def test_infer_serialize():
def test_infer_complex_serialize():
class MyEnum(Enum):
complex_ = complex(1, 2)

v = SchemaSerializer(core_schema.enum_schema(MyEnum, list(MyEnum.__members__.values())))
assert v.to_json(MyEnum.complex_) == b'"1+2j"'


def test_infer_json_key():
def test_infer_complex_json_key():
class MyEnum(Enum):
complex_ = {complex(1, 2): 1}

v = SchemaSerializer(core_schema.enum_schema(MyEnum, list(MyEnum.__members__.values())))
assert v.to_json(MyEnum.complex_) == b'{"1+2j":1}'


def test_infer_module_type():
v = SchemaSerializer(core_schema.any_schema())
assert v.to_python(os) is os
assert v.to_json(os).decode('utf-8') == '"os"'
assert v.to_python(os, serialize_as_any=True) is os
assert v.to_json(os, serialize_as_any=True).decode('utf-8') == '"os"'

v_as_key = SchemaSerializer(
core_schema.dict_schema(keys_schema=core_schema.any_schema(), values_schema=core_schema.any_schema())
)

assert v_as_key.to_python({os: 1}) == {os: 1}
assert v_as_key.to_json({os: 1}).decode('utf-8') == '{"os":1}'
assert v_as_key.to_python({os: 1}, serialize_as_any=True) == {os: 1}
assert v_as_key.to_json({os: 1}, serialize_as_any=True).decode('utf-8') == '{"os":1}'
Loading