@@ -62,6 +62,7 @@ def map(
62
62
skip_none_values : bool = False ,
63
63
fields_mapping : FieldsMap = None ,
64
64
use_deepcopy : bool = True ,
65
+ model_factory : Optional [Callable [[S ], T ]] = None ,
65
66
) -> T :
66
67
"""Produces output object mapped from source object and custom arguments.
67
68
@@ -72,6 +73,9 @@ def map(
72
73
Specify dictionary in format {"field_name": value_object}. Defaults to None.
73
74
use_deepcopy (bool, optional): Apply deepcopy to all child objects when copy from source to target object.
74
75
Defaults to True.
76
+ model_factory (Callable, optional): Custom factory funtion
77
+ factory function that is used to create the target_obj. Called with the
78
+ source as parameter. Mutually exclusive with fields_mapping. Defaults to None
75
79
76
80
Raises:
77
81
CircularReferenceError: Circular references in `source class` object are not allowed yet.
@@ -86,13 +90,14 @@ def map(
86
90
skip_none_values = skip_none_values ,
87
91
custom_mapping = fields_mapping ,
88
92
use_deepcopy = use_deepcopy ,
93
+ model_factory = model_factory ,
89
94
)
90
95
91
96
92
97
class Mapper :
93
98
def __init__ (self ) -> None :
94
99
"""Initializes internal containers"""
95
- self ._mappings : Dict [Type [S ], Tuple [T , FieldsMap ]] = {} # type: ignore [valid-type]
100
+ self ._mappings : Dict [Type [S ], Tuple [T , FieldsMap , Callable [[ S ], [ T ]] ]] = {} # type: ignore [valid-type]
96
101
self ._class_specs : Dict [Type [T ], SpecFunction [T ]] = {} # type: ignore [valid-type]
97
102
self ._classifier_specs : Dict [ # type: ignore [valid-type]
98
103
ClassifierFunction [T ], SpecFunction [T ]
@@ -147,6 +152,7 @@ def add(
147
152
target_cls : Type [T ],
148
153
override : bool = False ,
149
154
fields_mapping : FieldsMap = None ,
155
+ model_factory : Optional [Callable [[S ], T ]] = None ,
150
156
) -> None :
151
157
"""Adds mapping between object of `source class` to an object of `target class`.
152
158
@@ -156,7 +162,12 @@ def add(
156
162
override (bool, optional): Override existing `source class` mapping to use new `target class`.
157
163
Defaults to False.
158
164
fields_mapping (FieldsMap, optional): Custom mapping.
159
- Specify dictionary in format {"field_name": value_object}. Defaults to None.
165
+ Specify dictionary in format {"field_name": value_objecture_obj}.
166
+ Can take a lamdba funtion as argument, that will get the source_cls
167
+ as argument. Defaults to None.
168
+ model_factory (Callable, optional): Custom factory funtion
169
+ factory function that is used to create the target_obj. Called with the
170
+ source as parameter. Mutually exclusive with fields_mapping. Defaults to None
160
171
161
172
Raises:
162
173
DuplicatedRegistrationError: Same mapping for `source class` was added.
@@ -168,7 +179,7 @@ def add(
168
179
raise DuplicatedRegistrationError (
169
180
f"source_cls { source_cls } was already added for mapping"
170
181
)
171
- self ._mappings [source_cls ] = (target_cls , fields_mapping )
182
+ self ._mappings [source_cls ] = (target_cls , fields_mapping , model_factory )
172
183
173
184
def map (
174
185
self ,
@@ -201,7 +212,7 @@ def map(
201
212
raise MappingError (f"Missing mapping type for input type { obj_type } " )
202
213
obj_type_prefix = f"{ obj_type .__name__ } ."
203
214
204
- target_cls , target_cls_field_mappings = self ._mappings [obj_type ]
215
+ target_cls , target_cls_field_mappings , target_cls_model_factory = self ._mappings [obj_type ]
205
216
206
217
common_fields_mapping = fields_mapping
207
218
if target_cls_field_mappings :
@@ -221,13 +232,16 @@ def map(
221
232
** fields_mapping ,
222
233
} # merge two dict into one, fields_mapping has priority
223
234
235
+
236
+
224
237
return self ._map_common (
225
238
obj ,
226
239
target_cls ,
227
240
set (),
228
241
skip_none_values = skip_none_values ,
229
242
custom_mapping = common_fields_mapping ,
230
243
use_deepcopy = use_deepcopy ,
244
+ model_factory = target_cls_model_factory ,
231
245
)
232
246
233
247
def _get_fields (self , target_cls : Type [T ]) -> Iterable [str ]:
@@ -257,7 +271,7 @@ def _map_subobject(
257
271
raise CircularReferenceError ()
258
272
259
273
if type (obj ) in self ._mappings :
260
- target_cls , _ = self ._mappings [type (obj )]
274
+ target_cls , _ , _ = self ._mappings [type (obj )]
261
275
result : Any = self ._map_common (
262
276
obj , target_cls , _visited_stack , skip_none_values = skip_none_values
263
277
)
@@ -297,6 +311,7 @@ def _map_common(
297
311
skip_none_values : bool = False ,
298
312
custom_mapping : FieldsMap = None ,
299
313
use_deepcopy : bool = True ,
314
+ model_factory : Optional [Callable [[S ], T ]] = None ,
300
315
) -> T :
301
316
"""Produces output object mapped from source object and custom arguments.
302
317
@@ -309,6 +324,9 @@ def _map_common(
309
324
Specify dictionary in format {"field_name": value_object}. Defaults to None.
310
325
use_deepcopy (bool, optional): Apply deepcopy to all child objects when copy from source to target object.
311
326
Defaults to True.
327
+ model_factory (Callable, optional): Custom factory funtion
328
+ factory function that is used to create the target_obj. Called with the
329
+ source as parameter. Mutually exclusive with fields_mapping. Defaults to None
312
330
313
331
Raises:
314
332
CircularReferenceError: Circular references in `source class` object are not allowed yet.
@@ -320,10 +338,25 @@ def _map_common(
320
338
321
339
if obj_id in _visited_stack :
322
340
raise CircularReferenceError ()
341
+
342
+ target_cls_fields_mapping = None
343
+ if type (obj ) in self ._mappings :
344
+ _ , target_cls_fields_mapping , a = self ._mappings [type (obj )]
345
+
346
+ if model_factory is not None and target_cls_fields_mapping :
347
+ raise ValueError (
348
+ "Cannot specify both model_factory and fields_mapping. "
349
+ "Use one of them to customize mapping."
350
+ )
351
+
352
+ if model_factory is not None and callable (model_factory ):
353
+ return model_factory (obj )
354
+
323
355
_visited_stack .add (obj_id )
324
356
325
357
target_cls_fields = self ._get_fields (target_cls )
326
358
359
+
327
360
mapped_values : Dict [str , Any ] = {}
328
361
for field_name in target_cls_fields :
329
362
value_found , value = _try_get_field_value (field_name , obj , custom_mapping )
0 commit comments