Skip to content

Commit 9207c5f

Browse files
mldwgpgammans
authored andcommitted
polymorphic accessors now use builtin caching from underlying fields
1 parent 3cf751d commit 9207c5f

File tree

2 files changed

+36
-5
lines changed

2 files changed

+36
-5
lines changed

polymorphic/models.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -192,11 +192,15 @@ def __init__(self, *args, **kwargs):
192192
return
193193
self.__class__.polymorphic_super_sub_accessors_replaced = True
194194

195-
def create_accessor_function_for_model(model, accessor_name):
195+
def create_accessor_function_for_model(model, field):
196196
def accessor_function(self):
197-
objects = getattr(model, "_base_objects", model.objects)
198-
attr = objects.get(pk=self.pk)
199-
return attr
197+
try:
198+
rel_obj = field.get_cached_value(self)
199+
except KeyError:
200+
objects = getattr(model, "_base_objects", model.objects)
201+
rel_obj = objects.get(pk=self.pk)
202+
field.set_cached_value(self, rel_obj)
203+
return rel_obj
200204

201205
return accessor_function
202206

@@ -209,10 +213,14 @@ def accessor_function(self):
209213
type(orig_accessor),
210214
(ReverseOneToOneDescriptor, ForwardManyToOneDescriptor),
211215
):
216+
217+
field = orig_accessor.related \
218+
if isinstance(orig_accessor, ReverseOneToOneDescriptor) else orig_accessor.field
219+
212220
setattr(
213221
self.__class__,
214222
name,
215-
property(create_accessor_function_for_model(model, name)),
223+
property(create_accessor_function_for_model(model, field)),
216224
)
217225

218226
def _get_inheritance_relation_fields_and_models(self):

polymorphic/tests/test_orm.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -965,6 +965,29 @@ def test_parent_link_and_related_name(self):
965965
# test that we can delete the object
966966
t.delete()
967967

968+
def test_polymorphic__accessor_caching(self):
969+
blog_a = BlogA.objects.create(name="blog")
970+
971+
blog_base = BlogBase.objects.non_polymorphic().get(id=blog_a.id)
972+
blog_a = BlogA.objects.get(id=blog_a.id)
973+
974+
# test reverse accessor & check that we get back cached object on repeated access
975+
self.assertEqual(blog_base.bloga, blog_a)
976+
self.assertIs(blog_base.bloga, blog_base.bloga)
977+
cached_blog_a = blog_base.bloga
978+
979+
# test forward accessor & check that we get back cached object on repeated access
980+
self.assertEqual(blog_a.blogbase_ptr, blog_base)
981+
self.assertIs(blog_a.blogbase_ptr, blog_a.blogbase_ptr)
982+
cached_blog_base = blog_a.blogbase_ptr
983+
984+
# check that refresh_from_db correctly clears cached related objects
985+
blog_base.refresh_from_db()
986+
blog_a.refresh_from_db()
987+
988+
self.assertIsNot(cached_blog_a, blog_base.bloga)
989+
self.assertIsNot(cached_blog_base, blog_a.blogbase_ptr)
990+
968991
def test_polymorphic__aggregate(self):
969992
"""test ModelX___field syntax on aggregate (should work for annotate either)"""
970993

0 commit comments

Comments
 (0)