Skip to content

Commit 7b9da64

Browse files
committed
feat: Allow to define model_factory for mapping when registering mapping
1 parent 0e91f49 commit 7b9da64

File tree

3 files changed

+159
-9
lines changed

3 files changed

+159
-9
lines changed

README.md

Lines changed: 41 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ class Address:
149149
number: int
150150
zip_code: int
151151
city: str
152-
152+
153153
class PersonInfo:
154154
def __init__(self, name: str, age: int, address: Address):
155155
self.name = name
@@ -181,6 +181,43 @@ print("Target public_info.address is same as source address: ", address is publi
181181
* [TortoiseORM](https://github.com/tortoise/tortoise-orm)
182182
* [SQLAlchemy](https://www.sqlalchemy.org/)
183183

184+
## Complex mapping support
185+
186+
Support to pass a model factory method or cunstructor when registering a mapping.
187+
This allows for mapping to value objects or other complex types.
188+
189+
```python
190+
class SourceEnum(Enum):
191+
VALUE1 = "value1"
192+
VALUE2 = "value2"
193+
VALUE3 = "value3"
194+
195+
class NameEnum(Enum):
196+
VALUE1 = 1
197+
VALUE2 = 2
198+
VALUE3 = 3
199+
200+
class ValueEnum(Enum):
201+
A = "value1"
202+
B = "value2"
203+
C = "value3"
204+
205+
206+
class ValueObject:
207+
value: str
208+
209+
def __init__(self, value: Union[float, int, Decimal]):
210+
self.value = str(value)
211+
212+
mapper.add(SourceEnum, NameEnum, model_factory=lambda x: NameEnum[x.name])
213+
mapper.map(SourceEnum.VALUE1) # NameEnum.VALUE1
214+
215+
mapper.add(ValueEnum, SourceEnum, model_factory=lambda x: SourceEnum(x.value))
216+
mapper.map(ValueEnum.B) # SourceEnum.VALUE2
217+
218+
mapper.to(ValueObject).map(Decimal("42"), model_factory=ValueObject) # ValueObject(42)
219+
```
220+
184221
## Pydantic/FastAPI Support
185222
Out of the box Pydantic models support:
186223
```python
@@ -273,7 +310,7 @@ class PublicUserInfo(Base):
273310
id = Column(Integer, primary_key=True)
274311
public_name = Column(String)
275312
hobbies = Column(String)
276-
313+
277314
obj = UserInfo(
278315
id=2,
279316
full_name="Danny DeVito",
@@ -304,7 +341,7 @@ class TargetClass:
304341
def __init__(self, **kwargs):
305342
self.name = kwargs["name"]
306343
self.age = kwargs["age"]
307-
344+
308345
@staticmethod
309346
def get_fields(cls):
310347
return ["name", "age"]
@@ -358,7 +395,7 @@ T = TypeVar("T")
358395

359396
def class_has_fields_property(target_cls: Type[T]) -> bool:
360397
return callable(getattr(target_cls, "fields", None))
361-
398+
362399
mapper.add_spec(class_has_fields_property, lambda t: getattr(t, "fields")())
363400

364401
target_obj = mapper.to(TargetClass).map(source_obj)

automapper/mapper.py

Lines changed: 38 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ def map(
6262
skip_none_values: bool = False,
6363
fields_mapping: FieldsMap = None,
6464
use_deepcopy: bool = True,
65+
model_factory: Optional[Callable[[S], T]] = None,
6566
) -> T:
6667
"""Produces output object mapped from source object and custom arguments.
6768
@@ -72,6 +73,9 @@ def map(
7273
Specify dictionary in format {"field_name": value_object}. Defaults to None.
7374
use_deepcopy (bool, optional): Apply deepcopy to all child objects when copy from source to target object.
7475
Defaults to True.
76+
model_factory (Callable, optional): Custom factory funtion
77+
factory function that is used to create the target_obj. Called with the
78+
source as parameter. Mutually exclusive with fields_mapping. Defaults to None
7579
7680
Raises:
7781
CircularReferenceError: Circular references in `source class` object are not allowed yet.
@@ -86,13 +90,14 @@ def map(
8690
skip_none_values=skip_none_values,
8791
custom_mapping=fields_mapping,
8892
use_deepcopy=use_deepcopy,
93+
model_factory=model_factory,
8994
)
9095

9196

9297
class Mapper:
9398
def __init__(self) -> None:
9499
"""Initializes internal containers"""
95-
self._mappings: Dict[Type[S], Tuple[T, FieldsMap]] = {} # type: ignore [valid-type]
100+
self._mappings: Dict[Type[S], Tuple[T, FieldsMap, Callable[[S], [T]]]] = {} # type: ignore [valid-type]
96101
self._class_specs: Dict[Type[T], SpecFunction[T]] = {} # type: ignore [valid-type]
97102
self._classifier_specs: Dict[ # type: ignore [valid-type]
98103
ClassifierFunction[T], SpecFunction[T]
@@ -147,6 +152,7 @@ def add(
147152
target_cls: Type[T],
148153
override: bool = False,
149154
fields_mapping: FieldsMap = None,
155+
model_factory: Optional[Callable[[S], T]] = None,
150156
) -> None:
151157
"""Adds mapping between object of `source class` to an object of `target class`.
152158
@@ -156,7 +162,12 @@ def add(
156162
override (bool, optional): Override existing `source class` mapping to use new `target class`.
157163
Defaults to False.
158164
fields_mapping (FieldsMap, optional): Custom mapping.
159-
Specify dictionary in format {"field_name": value_object}. Defaults to None.
165+
Specify dictionary in format {"field_name": value_objecture_obj}.
166+
Can take a lamdba funtion as argument, that will get the source_cls
167+
as argument. Defaults to None.
168+
model_factory (Callable, optional): Custom factory funtion
169+
factory function that is used to create the target_obj. Called with the
170+
source as parameter. Mutually exclusive with fields_mapping. Defaults to None
160171
161172
Raises:
162173
DuplicatedRegistrationError: Same mapping for `source class` was added.
@@ -168,7 +179,7 @@ def add(
168179
raise DuplicatedRegistrationError(
169180
f"source_cls {source_cls} was already added for mapping"
170181
)
171-
self._mappings[source_cls] = (target_cls, fields_mapping)
182+
self._mappings[source_cls] = (target_cls, fields_mapping, model_factory)
172183

173184
def map(
174185
self,
@@ -201,7 +212,7 @@ def map(
201212
raise MappingError(f"Missing mapping type for input type {obj_type}")
202213
obj_type_prefix = f"{obj_type.__name__}."
203214

204-
target_cls, target_cls_field_mappings = self._mappings[obj_type]
215+
target_cls, target_cls_field_mappings, target_cls_model_factory= self._mappings[obj_type]
205216

206217
common_fields_mapping = fields_mapping
207218
if target_cls_field_mappings:
@@ -221,13 +232,16 @@ def map(
221232
**fields_mapping,
222233
} # merge two dict into one, fields_mapping has priority
223234

235+
236+
224237
return self._map_common(
225238
obj,
226239
target_cls,
227240
set(),
228241
skip_none_values=skip_none_values,
229242
custom_mapping=common_fields_mapping,
230243
use_deepcopy=use_deepcopy,
244+
model_factory=target_cls_model_factory,
231245
)
232246

233247
def _get_fields(self, target_cls: Type[T]) -> Iterable[str]:
@@ -257,7 +271,7 @@ def _map_subobject(
257271
raise CircularReferenceError()
258272

259273
if type(obj) in self._mappings:
260-
target_cls, _ = self._mappings[type(obj)]
274+
target_cls, _, _ = self._mappings[type(obj)]
261275
result: Any = self._map_common(
262276
obj, target_cls, _visited_stack, skip_none_values=skip_none_values
263277
)
@@ -297,6 +311,7 @@ def _map_common(
297311
skip_none_values: bool = False,
298312
custom_mapping: FieldsMap = None,
299313
use_deepcopy: bool = True,
314+
model_factory: Optional[Callable[[S], T]] = None,
300315
) -> T:
301316
"""Produces output object mapped from source object and custom arguments.
302317
@@ -309,6 +324,9 @@ def _map_common(
309324
Specify dictionary in format {"field_name": value_object}. Defaults to None.
310325
use_deepcopy (bool, optional): Apply deepcopy to all child objects when copy from source to target object.
311326
Defaults to True.
327+
model_factory (Callable, optional): Custom factory funtion
328+
factory function that is used to create the target_obj. Called with the
329+
source as parameter. Mutually exclusive with fields_mapping. Defaults to None
312330
313331
Raises:
314332
CircularReferenceError: Circular references in `source class` object are not allowed yet.
@@ -320,10 +338,25 @@ def _map_common(
320338

321339
if obj_id in _visited_stack:
322340
raise CircularReferenceError()
341+
342+
target_cls_fields_mapping = None
343+
if type(obj) in self._mappings:
344+
_, target_cls_fields_mapping, a = self._mappings[type(obj)]
345+
346+
if model_factory is not None and target_cls_fields_mapping:
347+
raise ValueError(
348+
"Cannot specify both model_factory and fields_mapping. "
349+
"Use one of them to customize mapping."
350+
)
351+
352+
if model_factory is not None and callable(model_factory):
353+
return model_factory(obj)
354+
323355
_visited_stack.add(obj_id)
324356

325357
target_cls_fields = self._get_fields(target_cls)
326358

359+
327360
mapped_values: Dict[str, Any] = {}
328361
for field_name in target_cls_fields:
329362
value_found, value = _try_get_field_value(field_name, obj, custom_mapping)

tests/test_model_factory_mapping.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
from decimal import Decimal
2+
from enum import Enum
3+
from typing import Union
4+
from unittest import TestCase
5+
6+
import pytest
7+
8+
from automapper import create_mapper
9+
10+
11+
class SourceEnum(Enum):
12+
VALUE1 = "value1"
13+
VALUE2 = "value2"
14+
VALUE3 = "value3"
15+
16+
class NameEnum(Enum):
17+
VALUE1 = 1
18+
VALUE2 = 2
19+
VALUE3 = 3
20+
21+
class ValueEnum(Enum):
22+
A = "value1"
23+
B = "value2"
24+
C = "value3"
25+
26+
class ValueObject:
27+
value: str
28+
29+
def __init__(self, value: Union[float, int, Decimal]):
30+
self.value = str(value)
31+
32+
def __repr__(self):
33+
return f"ValueObject(value={self.value})"
34+
35+
def __str__(self):
36+
return f"ValueObject(value={self.value})"
37+
38+
class AutomapperModelFactoryTest(TestCase):
39+
def setUp(self) -> None:
40+
self.mapper = create_mapper()
41+
42+
def test_map__with_registered_lambda_factory(self):
43+
self.mapper.add(SourceEnum, NameEnum, model_factory=lambda x: NameEnum[x.name])
44+
self.mapper.add(ValueEnum, SourceEnum, model_factory=lambda x: SourceEnum(x.value))
45+
46+
self.assertEqual(self.mapper.map(SourceEnum.VALUE3), NameEnum.VALUE3)
47+
self.assertEqual(self.mapper.map(ValueEnum.B), SourceEnum.VALUE2)
48+
49+
50+
def test_map__with_lambda_factory(self):
51+
name_enum = self.mapper.to(NameEnum).map(SourceEnum.VALUE3, model_factory=lambda x: NameEnum[x.name])
52+
value_enum = self.mapper.to(SourceEnum).map(ValueEnum.B, model_factory=lambda x: SourceEnum(x.value))
53+
54+
self.assertEqual(name_enum, NameEnum.VALUE3)
55+
self.assertEqual(value_enum, SourceEnum.VALUE2)
56+
57+
58+
def test_map__with_registered_constructor_factory(self):
59+
self.mapper.add(Decimal, ValueObject, model_factory=ValueObject) # pyright: ignore[reportArgumentType]
60+
61+
self.assertEqual(self.mapper.map(Decimal("42")).value, ValueObject(42).value)
62+
63+
64+
def test_map__with_constructor_factory(self):
65+
result = self.mapper.to(ValueObject).map(Decimal("42"), model_factory=ValueObject) # pyright: ignore[reportArgumentType]
66+
67+
print(result)
68+
self.assertEqual(result.value, ValueObject(42).value)
69+
70+
71+
def test_map__with_factory_and_fields_mapping_raises_error(self):
72+
self.mapper.add(
73+
ValueEnum,
74+
ValueObject,
75+
model_factory=lambda s: ValueObject(int(s.value)),
76+
fields_mapping={"value": lambda x: x.value}
77+
)
78+
79+
with pytest.raises(ValueError):
80+
self.mapper.map(ValueEnum.A)

0 commit comments

Comments
 (0)