11import itertools
22import sys
3- from functools import partial
3+ from functools import cached_property , partial
44from typing import Any , Callable , Dict , List , Optional , Tuple , Type
55
66from mypy .build import PRI_MED , PRI_MYPY
1919)
2020from mypy .types import Type as MypyType
2121
22- import mypy_django_plugin .transformers .orm_lookups
2322from mypy_django_plugin .config import DjangoPluginConfig
2423from mypy_django_plugin .django .context import DjangoContext
2524from mypy_django_plugin .exceptions import UnregisteredModelError
3130 manytomany ,
3231 manytoone ,
3332 meta ,
33+ orm_lookups ,
3434 querysets ,
3535 request ,
3636 settings ,
@@ -60,10 +60,6 @@ def transform_form_class(ctx: ClassDefContext) -> None:
6060 forms .make_meta_nested_class_inherit_from_any (ctx )
6161
6262
63- def add_new_manager_base_hook (ctx : ClassDefContext ) -> None :
64- helpers .add_new_manager_base (ctx .api , ctx .cls .fullname )
65-
66-
6763class NewSemanalDjangoPlugin (Plugin ):
6864 def __init__ (self , options : Options ) -> None :
6965 super ().__init__ (options )
@@ -83,15 +79,6 @@ def _get_current_queryset_bases(self) -> Dict[str, int]:
8379 else :
8480 return {}
8581
86- def _get_current_manager_bases (self ) -> Dict [str , int ]:
87- model_sym = self .lookup_fully_qualified (fullnames .MANAGER_CLASS_FULLNAME )
88- if model_sym is not None and isinstance (model_sym .node , TypeInfo ):
89- bases = helpers .get_django_metadata_bases (model_sym .node , "manager_bases" )
90- bases [fullnames .MANAGER_CLASS_FULLNAME ] = 1
91- return bases
92- else :
93- return {}
94-
9582 def _get_current_form_bases (self ) -> Dict [str , int ]:
9683 model_sym = self .lookup_fully_qualified (fullnames .BASEFORM_CLASS_FULLNAME )
9784 if model_sym is not None and isinstance (model_sym .node , TypeInfo ):
@@ -165,10 +152,6 @@ def get_function_hook(self, fullname: str) -> Optional[Callable[[FunctionContext
165152 if fullname == "django.contrib.auth.get_user_model" :
166153 return partial (settings .get_user_model_hook , django_context = self .django_context )
167154
168- manager_bases = self ._get_current_manager_bases ()
169- if fullname in manager_bases :
170- return querysets .determine_proper_manager_type
171-
172155 info = self ._get_typeinfo_or_none (fullname )
173156 if info :
174157 if info .has_base (fullnames .FIELD_FULLNAME ):
@@ -177,8 +160,26 @@ def get_function_hook(self, fullname: str) -> Optional[Callable[[FunctionContext
177160 if helpers .is_model_type (info ):
178161 return partial (init_create .redefine_and_typecheck_model_init , django_context = self .django_context )
179162
163+ if info .has_base (fullnames .BASE_MANAGER_CLASS_FULLNAME ):
164+ return querysets .determine_proper_manager_type
165+
180166 return None
181167
168+ @cached_property
169+ def manager_and_queryset_method_hooks (self ) -> Dict [str , Callable [[MethodContext ], MypyType ]]:
170+ typecheck_filtering_method = partial (orm_lookups .typecheck_queryset_filter , django_context = self .django_context )
171+ return {
172+ "values" : partial (querysets .extract_proper_type_queryset_values , django_context = self .django_context ),
173+ "values_list" : partial (
174+ querysets .extract_proper_type_queryset_values_list , django_context = self .django_context
175+ ),
176+ "annotate" : partial (querysets .extract_proper_type_queryset_annotate , django_context = self .django_context ),
177+ "create" : partial (init_create .redefine_and_typecheck_model_create , django_context = self .django_context ),
178+ "filter" : typecheck_filtering_method ,
179+ "get" : typecheck_filtering_method ,
180+ "exclude" : typecheck_filtering_method ,
181+ }
182+
182183 def get_method_hook (self , fullname : str ) -> Optional [Callable [[MethodContext ], MypyType ]]:
183184 class_fullname , _ , method_name = fullname .rpartition ("." )
184185 # Methods called very often -- short circuit for minor speed up
@@ -208,38 +209,17 @@ def get_method_hook(self, fullname: str) -> Optional[Callable[[MethodContext], M
208209 }
209210 return hooks .get (class_fullname )
210211
211- manager_classes = self ._get_current_manager_bases ()
212-
213- if method_name == "values" :
212+ if method_name in self .manager_and_queryset_method_hooks :
214213 info = self ._get_typeinfo_or_none (class_fullname )
215- if info and info .has_base (fullnames .QUERYSET_CLASS_FULLNAME ) or class_fullname in manager_classes :
216- return partial (querysets .extract_proper_type_queryset_values , django_context = self .django_context )
217-
218- elif method_name == "values_list" :
219- info = self ._get_typeinfo_or_none (class_fullname )
220- if info and info .has_base (fullnames .QUERYSET_CLASS_FULLNAME ) or class_fullname in manager_classes :
221- return partial (querysets .extract_proper_type_queryset_values_list , django_context = self .django_context )
222-
223- elif method_name == "annotate" :
224- info = self ._get_typeinfo_or_none (class_fullname )
225- if info and info .has_base (fullnames .QUERYSET_CLASS_FULLNAME ) or class_fullname in manager_classes :
226- return partial (querysets .extract_proper_type_queryset_annotate , django_context = self .django_context )
227-
214+ if info and helpers .has_any_of_bases (
215+ info , [fullnames .QUERYSET_CLASS_FULLNAME , fullnames .MANAGER_CLASS_FULLNAME ]
216+ ):
217+ return self .manager_and_queryset_method_hooks [method_name ]
228218 elif method_name == "get_field" :
229219 info = self ._get_typeinfo_or_none (class_fullname )
230220 if info and info .has_base (fullnames .OPTIONS_CLASS_FULLNAME ):
231221 return partial (meta .return_proper_field_type_from_get_field , django_context = self .django_context )
232222
233- elif method_name == "create" :
234- # We need `BASE_MANAGER_CLASS_FULLNAME` to check abstract models.
235- if class_fullname in manager_classes or class_fullname == fullnames .BASE_MANAGER_CLASS_FULLNAME :
236- return partial (init_create .redefine_and_typecheck_model_create , django_context = self .django_context )
237- elif method_name in {"filter" , "get" , "exclude" } and class_fullname in manager_classes :
238- return partial (
239- mypy_django_plugin .transformers .orm_lookups .typecheck_queryset_filter ,
240- django_context = self .django_context ,
241- )
242-
243223 return None
244224
245225 def get_customize_class_mro_hook (self , fullname : str ) -> Optional [Callable [[ClassDefContext ], None ]]:
@@ -262,10 +242,6 @@ def get_base_class_hook(self, fullname: str) -> Optional[Callable[[ClassDefConte
262242 if sym is not None and isinstance (sym .node , TypeInfo ) and helpers .is_model_type (sym .node ):
263243 return partial (process_model_class , django_context = self .django_context )
264244
265- # Base class is a Manager class definition
266- if fullname in self ._get_current_manager_bases ():
267- return add_new_manager_base_hook
268-
269245 # Base class is a Form class definition
270246 if fullname in self ._get_current_form_bases ():
271247 return transform_form_class
0 commit comments