55import textwrap
66import tokenize
77import warnings
8- from ast import PyCF_ONLY_AST as _AST_FLAG
98from bisect import bisect_right
9+ from types import CodeType
1010from types import FrameType
1111from typing import Iterator
1212from typing import List
1818import py
1919
2020from _pytest .compat import overload
21+ from _pytest .compat import TYPE_CHECKING
22+
23+ if TYPE_CHECKING :
24+ from typing_extensions import Literal
2125
2226
2327class Source :
@@ -121,7 +125,7 @@ def getstatement(self, lineno: int) -> "Source":
121125 start , end = self .getstatementrange (lineno )
122126 return self [start :end ]
123127
124- def getstatementrange (self , lineno : int ):
128+ def getstatementrange (self , lineno : int ) -> Tuple [ int , int ] :
125129 """ return (start, end) tuple which spans the minimal
126130 statement region which containing the given lineno.
127131 """
@@ -159,14 +163,36 @@ def isparseable(self, deindent: bool = True) -> bool:
159163 def __str__ (self ) -> str :
160164 return "\n " .join (self .lines )
161165
166+ @overload
162167 def compile (
163168 self ,
164- filename = None ,
165- mode = "exec" ,
169+ filename : Optional [str ] = ...,
170+ mode : str = ...,
171+ flag : "Literal[0]" = ...,
172+ dont_inherit : int = ...,
173+ _genframe : Optional [FrameType ] = ...,
174+ ) -> CodeType :
175+ raise NotImplementedError ()
176+
177+ @overload # noqa: F811
178+ def compile ( # noqa: F811
179+ self ,
180+ filename : Optional [str ] = ...,
181+ mode : str = ...,
182+ flag : int = ...,
183+ dont_inherit : int = ...,
184+ _genframe : Optional [FrameType ] = ...,
185+ ) -> Union [CodeType , ast .AST ]:
186+ raise NotImplementedError ()
187+
188+ def compile ( # noqa: F811
189+ self ,
190+ filename : Optional [str ] = None ,
191+ mode : str = "exec" ,
166192 flag : int = 0 ,
167193 dont_inherit : int = 0 ,
168194 _genframe : Optional [FrameType ] = None ,
169- ):
195+ ) -> Union [ CodeType , ast . AST ] :
170196 """ return compiled code object. if filename is None
171197 invent an artificial filename which displays
172198 the source/line position of the caller frame.
@@ -196,8 +222,10 @@ def compile(
196222 newex .text = ex .text
197223 raise newex
198224 else :
199- if flag & _AST_FLAG :
225+ if flag & ast .PyCF_ONLY_AST :
226+ assert isinstance (co , ast .AST )
200227 return co
228+ assert isinstance (co , CodeType )
201229 lines = [(x + "\n " ) for x in self .lines ]
202230 # Type ignored because linecache.cache is private.
203231 linecache .cache [filename ] = (1 , None , lines , filename ) # type: ignore
@@ -209,22 +237,52 @@ def compile(
209237#
210238
211239
212- def compile_ (source , filename = None , mode = "exec" , flags : int = 0 , dont_inherit : int = 0 ):
240+ @overload
241+ def compile_ (
242+ source : Union [str , bytes , ast .mod , ast .AST ],
243+ filename : Optional [str ] = ...,
244+ mode : str = ...,
245+ flags : "Literal[0]" = ...,
246+ dont_inherit : int = ...,
247+ ) -> CodeType :
248+ raise NotImplementedError ()
249+
250+
251+ @overload # noqa: F811
252+ def compile_ ( # noqa: F811
253+ source : Union [str , bytes , ast .mod , ast .AST ],
254+ filename : Optional [str ] = ...,
255+ mode : str = ...,
256+ flags : int = ...,
257+ dont_inherit : int = ...,
258+ ) -> Union [CodeType , ast .AST ]:
259+ raise NotImplementedError ()
260+
261+
262+ def compile_ ( # noqa: F811
263+ source : Union [str , bytes , ast .mod , ast .AST ],
264+ filename : Optional [str ] = None ,
265+ mode : str = "exec" ,
266+ flags : int = 0 ,
267+ dont_inherit : int = 0 ,
268+ ) -> Union [CodeType , ast .AST ]:
213269 """ compile the given source to a raw code object,
214270 and maintain an internal cache which allows later
215271 retrieval of the source code for the code object
216272 and any recursively created code objects.
217273 """
218274 if isinstance (source , ast .AST ):
219275 # XXX should Source support having AST?
220- return compile (source , filename , mode , flags , dont_inherit )
276+ assert filename is not None
277+ co = compile (source , filename , mode , flags , dont_inherit )
278+ assert isinstance (co , (CodeType , ast .AST ))
279+ return co
221280 _genframe = sys ._getframe (1 ) # the caller
222281 s = Source (source )
223- co = s .compile (filename , mode , flags , _genframe = _genframe )
224- return co
282+ return s .compile (filename , mode , flags , _genframe = _genframe )
225283
226284
227- def getfslineno (obj ):
285+ def getfslineno (obj ) -> Tuple [ Union [ str , py . path . local ], int ] :
228286 """ Return source location (path, lineno) for the given object.
229287 If the source cannot be determined return ("", -1).
230288
@@ -321,7 +379,7 @@ def getstatementrange_ast(
321379 # don't produce duplicate warnings when compiling source to find ast
322380 with warnings .catch_warnings ():
323381 warnings .simplefilter ("ignore" )
324- astnode = compile (content , "source" , "exec" , _AST_FLAG )
382+ astnode = ast . parse (content , "source" , "exec" )
325383
326384 start , end = get_statement_startend2 (lineno , astnode )
327385 # we need to correct the end:
0 commit comments