2727class Lexer :
2828 """JSONPath expression lexical scanner."""
2929
30- __slots__ = ("filter_depth" , "paren_stack" , "tokens" , "start" , "pos" , "query" )
30+ __slots__ = (
31+ "filter_depth" ,
32+ "func_call_stack" ,
33+ "bracket_stack" ,
34+ "tokens" ,
35+ "start" ,
36+ "pos" ,
37+ "query" ,
38+ )
3139
3240 def __init__ (self , query : str ) -> None :
3341 self .filter_depth = 0
3442 """Filter nesting level."""
3543
36- self .paren_stack : List [int ] = []
44+ self .func_call_stack : List [int ] = []
3745 """A running count of parentheses for each, possibly nested, function call.
3846
3947 If the stack is empty, we are not in a function call. Remember that
4048 function arguments can be arbitrarily nested in parentheses.
4149 """
4250
51+ self .bracket_stack : list [tuple [str , int ]] = []
52+ """A stack of opening (parentheses/bracket, index) pairs."""
53+
4354 self .tokens : List [Token ] = []
4455 """Tokens resulting from scanning a JSONPath expression."""
4556
@@ -133,7 +144,7 @@ def ignore_whitespace(self) -> bool:
133144
134145 def error (self , msg : str ) -> None :
135146 """Emit an error token."""
136- # better error messages.
147+ # TODO: better error messages.
137148 self .tokens .append (
138149 Token (
139150 TokenType .ERROR ,
@@ -179,6 +190,7 @@ def lex_segment(l: Lexer) -> Optional[StateFn]: # noqa: D103, PLR0911
179190
180191 if c == "[" :
181192 l .emit (TokenType .LBRACKET )
193+ l .bracket_stack .append ((c , l .pos - 1 ))
182194 return lex_inside_bracketed_segment
183195
184196 if l .filter_depth :
@@ -202,6 +214,7 @@ def lex_descendant_segment(l: Lexer) -> Optional[StateFn]: # noqa: D103
202214
203215 if c == "[" :
204216 l .emit (TokenType .LBRACKET )
217+ l .bracket_stack .append ((c , l .pos - 1 ))
205218 return lex_inside_bracketed_segment
206219
207220 l .backup ()
@@ -244,11 +257,17 @@ def lex_inside_bracketed_segment(l: Lexer) -> Optional[StateFn]: # noqa: PLR091
244257 c = l .next ()
245258
246259 if c == "]" :
260+ if not l .bracket_stack or l .bracket_stack [- 1 ][0 ] != "[" :
261+ l .backup ()
262+ l .error ("unbalanced brackets" )
263+ return None
264+
265+ l .bracket_stack .pop ()
247266 l .emit (TokenType .RBRACKET )
248267 return lex_segment
249268
250269 if c == "" :
251- l .error ("unclosed bracketed selection " )
270+ l .error ("unbalanced brackets " )
252271 return None
253272
254273 if c == "*" :
@@ -299,18 +318,14 @@ def lex_inside_filter(l: Lexer) -> Optional[StateFn]: # noqa: D103, PLR0915, PL
299318
300319 if c == "]" :
301320 l .filter_depth -= 1
302- if len (l .paren_stack ) == 1 :
303- l .error ("unbalanced parentheses" )
304- return None
305-
306321 l .backup ()
307322 return lex_inside_bracketed_segment
308323
309324 if c == "," :
310325 l .emit (TokenType .COMMA )
311326 # If we have unbalanced parens, we are inside a function call and a
312327 # comma separates arguments. Otherwise a comma separates selectors.
313- if l .paren_stack :
328+ if l .func_call_stack :
314329 continue
315330 l .filter_depth -= 1
316331 return lex_inside_bracketed_segment
@@ -323,19 +338,26 @@ def lex_inside_filter(l: Lexer) -> Optional[StateFn]: # noqa: D103, PLR0915, PL
323338
324339 if c == "(" :
325340 l .emit (TokenType .LPAREN )
341+ l .bracket_stack .append ((c , l .pos - 1 ))
326342 # Are we in a function call? If so, a function argument contains parens.
327- if l .paren_stack :
328- l .paren_stack [- 1 ] += 1
343+ if l .func_call_stack :
344+ l .func_call_stack [- 1 ] += 1
329345 continue
330346
331347 if c == ")" :
348+ if not l .bracket_stack or l .bracket_stack [- 1 ][0 ] != "(" :
349+ l .backup ()
350+ l .error ("unbalanced parentheses" )
351+ return None
352+
353+ l .bracket_stack .pop ()
332354 l .emit (TokenType .RPAREN )
333355 # Are we closing a function call or a parenthesized expression?
334- if l .paren_stack :
335- if l .paren_stack [- 1 ] == 1 :
336- l .paren_stack .pop ()
356+ if l .func_call_stack :
357+ if l .func_call_stack [- 1 ] == 1 :
358+ l .func_call_stack .pop ()
337359 else :
338- l .paren_stack [- 1 ] -= 1
360+ l .func_call_stack [- 1 ] -= 1
339361 continue
340362
341363 if c == "$" :
@@ -402,8 +424,9 @@ def lex_inside_filter(l: Lexer) -> Optional[StateFn]: # noqa: D103, PLR0915, PL
402424 l .emit (TokenType .INT )
403425 elif l .accept_match (RE_FUNCTION_NAME ) and l .peek () == "(" :
404426 # Keep track of parentheses for this function call.
405- l .paren_stack .append (1 )
427+ l .func_call_stack .append (1 )
406428 l .emit (TokenType .FUNCTION )
429+ l .bracket_stack .append (("(" , l .pos ))
407430 l .next ()
408431 l .ignore () # ignore LPAREN
409432 else :
@@ -486,6 +509,21 @@ def tokenize(query: str) -> List[Token]:
486509 lexer , tokens = lex (query )
487510 lexer .run ()
488511
512+ # Check for remaining opening brackets that have not been closes.
513+ if lexer .bracket_stack :
514+ ch , index = lexer .bracket_stack [0 ]
515+ msg = f"unbalanced { 'brackets' if ch == '[' else 'parentheses' } "
516+ raise JSONPathSyntaxError (
517+ msg ,
518+ token = Token (
519+ TokenType .ERROR ,
520+ lexer .query [index ],
521+ index ,
522+ lexer .query ,
523+ msg ,
524+ ),
525+ )
526+
489527 if tokens and tokens [- 1 ].type_ == TokenType .ERROR :
490528 raise JSONPathSyntaxError (tokens [- 1 ].message , token = tokens [- 1 ])
491529
0 commit comments