@@ -48,6 +48,7 @@ fn resolve_expression(
4848 local_variables : vec ! [ ] ,
4949 predicate_arguments : vec ! [ ] ,
5050 predicate_argument_types : vec ! [ ] ,
51+ predicates_allowed : false ,
5152 } ;
5253
5354 let new_expr = match node. kind ( ) {
@@ -882,42 +883,55 @@ impl Expression {
882883 return Self :: Invalid ;
883884 } ;
884885
885- let Some ( function) = function else {
886- // Check sub expressions anyway
887- sub_expr. for_each ( |n| {
888- Self :: from_expression_node ( n. clone ( ) , ctx) ;
889- } ) ;
890- assert ! ( ctx. diag. has_errors( ) ) ;
891- return Self :: Invalid ;
892- } ;
893- let LookupResult :: Callable ( function) = function else {
894- // Check sub expressions anyway
895- sub_expr. for_each ( |n| {
896- Self :: from_expression_node ( n. clone ( ) , ctx) ;
897- } ) ;
898- ctx. diag . push_error ( "The expression is not a function" . into ( ) , & node) ;
899- return Self :: Invalid ;
886+ let ( function, array_base_ty) = match function {
887+ Some ( LookupResult :: Callable ( function) ) => {
888+ let base = match & function {
889+ LookupResultCallable :: MemberFunction { base, member, .. } => match * * member {
890+ LookupResultCallable :: Callable ( Callable :: Builtin (
891+ BuiltinFunction :: ArrayAny | BuiltinFunction :: ArrayAll ,
892+ ) ) => match base. ty ( ) {
893+ Type :: Array ( ty) => {
894+ // only set predicates to be allowed if we are in one of the hardcoded array member functions
895+ ctx. predicates_allowed = true ;
896+ Some ( ( * ty) . clone ( ) )
897+ }
898+ _ => unreachable ! ( ) , // you won't have access to these member functions if the base is not an array
899+ } ,
900+ _ => None ,
901+ } ,
902+ _ => None ,
903+ } ;
904+
905+ ( function, base)
906+ }
907+ Some ( _) => {
908+ // Check sub expressions anyway
909+ sub_expr. for_each ( |n| {
910+ Self :: from_expression_node ( n. clone ( ) , ctx) ;
911+ } ) ;
912+ ctx. diag . push_error ( "The expression is not a function" . into ( ) , & node) ;
913+ ctx. predicates_allowed = false ;
914+ return Self :: Invalid ;
915+ }
916+ None => {
917+ // Check sub expressions anyway
918+ sub_expr. for_each ( |n| {
919+ Self :: from_expression_node ( n. clone ( ) , ctx) ;
920+ } ) ;
921+ assert ! ( ctx. diag. has_errors( ) ) ;
922+ return Self :: Invalid ;
923+ }
900924 } ;
901925
902926 // dirty hack to supply the type of the predicate argument for array member functions,
903927 // we check if we are dealing with an array builtin, then we push the type to the context,
904928 // to be popped at the end of this function after all the predicate's expression is resolved
905929 let mut should_pop_predicate_args = false ;
906- match & function {
907- LookupResultCallable :: MemberFunction { base, member, .. } => match * * member {
908- LookupResultCallable :: Callable ( Callable :: Builtin (
909- BuiltinFunction :: ArrayAny | BuiltinFunction :: ArrayAll ,
910- ) ) => {
911- let ty = match base. ty ( ) {
912- Type :: Array ( ty) => ( * ty) . clone ( ) ,
913- _ => unreachable ! ( ) ,
914- } ;
915-
916- should_pop_predicate_args = true ;
917- ctx. predicate_argument_types . push ( ty) ;
918- }
919- _ => ( ) ,
920- } ,
930+ match ( ctx. predicates_allowed , array_base_ty) {
931+ ( true , Some ( ty) ) => {
932+ should_pop_predicate_args = true ;
933+ ctx. predicate_argument_types . push ( ty) ;
934+ }
921935 _ => ( ) ,
922936 }
923937
@@ -990,10 +1004,13 @@ impl Expression {
9901004 }
9911005 } ;
9921006
1007+ // if we pushed the predicate argument type, we pop it now
9931008 if should_pop_predicate_args {
9941009 ctx. predicate_argument_types . pop ( ) ;
9951010 }
9961011
1012+ ctx. predicates_allowed = false ;
1013+
9971014 Expression :: FunctionCall { function, arguments, source_location : Some ( source_location) }
9981015 }
9991016
@@ -1289,6 +1306,14 @@ impl Expression {
12891306 }
12901307
12911308 fn from_predicate_node ( node : syntax_nodes:: Predicate , ctx : & mut LookupCtx ) -> Expression {
1309+ if !ctx. predicates_allowed {
1310+ ctx. diag . push_error (
1311+ "Predicate expressions are not permitted outside of array builtin function arguments" . to_string ( ) ,
1312+ & node,
1313+ ) ;
1314+ return Expression :: Invalid ;
1315+ }
1316+
12921317 let arg_name = node. DeclaredIdentifier ( ) . to_smolstr ( ) ;
12931318
12941319 ctx. predicate_arguments . push ( arg_name. clone ( ) ) ;
@@ -1299,6 +1324,7 @@ impl Expression {
12991324 format ! ( "Predicate expression must be of type bool, but is {}" , ty) ,
13001325 & node. Expression ( ) ,
13011326 ) ;
1327+ return Expression :: Invalid ;
13021328 }
13031329
13041330 ctx. predicate_arguments . pop ( ) ;
@@ -1746,6 +1772,7 @@ fn resolve_two_way_bindings(
17461772 local_variables : vec ! [ ] ,
17471773 predicate_arguments : vec ! [ ] ,
17481774 predicate_argument_types : vec ! [ ] ,
1775+ predicates_allowed : false ,
17491776 } ;
17501777
17511778 binding. expression = Expression :: Invalid ;
0 commit comments