2121_LOGGER = logging .getLogger (LOGGER_PATH + ".trigger" )
2222
2323
24+ STATE_RE = re .compile (r"[a-zA-Z]\w*\.[a-zA-Z]\w*$" )
25+
26+
2427def dt_now ():
2528 """Return current time."""
2629 return dt .datetime .now ()
@@ -101,32 +104,57 @@ async def wait_until(
101104 await asyncio .sleep (timeout )
102105 return {"trigger_type" : "timeout" }
103106 return {"trigger_type" : "none" }
104- state_trig_ident = None
105- state_trig_expr = None
107+ state_trig_ident = set ()
108+ state_trig_ident_any = set ()
109+ state_trig_eval = None
106110 event_trig_expr = None
107111 exc = None
108112 notify_q = asyncio .Queue (0 )
109113 if state_trigger is not None :
110- state_trig_expr = AstEval (
111- f"{ ast_ctx .name } state_trigger" ,
112- ast_ctx .get_global_ctx (),
113- logger_name = ast_ctx .get_logger_name (),
114- )
115- Function .install_ast_funcs (state_trig_expr )
116- state_trig_expr .parse (state_trigger )
117- exc = state_trig_expr .get_exception_obj ()
118- if exc is not None :
119- raise exc
114+ state_trig = []
115+ if isinstance (state_trigger , str ):
116+ state_trigger = [state_trigger ]
117+ elif isinstance (state_trigger , set ):
118+ state_trigger = list (state_trigger )
120119 #
121- # check straight away to see if the condition is met (to avoid race conditions)
120+ # separate out the entries that are just state var names, which mean trigger
121+ # on any change (no expr)
122122 #
123- state_trig_ok = await state_trig_expr .eval ()
124- exc = state_trig_expr .get_exception_obj ()
125- if exc is not None :
126- raise exc
127- if state_trig_ok :
128- return {"trigger_type" : "state" }
129- state_trig_ident = await state_trig_expr .get_names ()
123+ for trig in state_trigger :
124+ if STATE_RE .match (trig ):
125+ state_trig_ident_any .add (trig )
126+ else :
127+ state_trig .append (trig )
128+
129+ if len (state_trig ) > 0 :
130+ if len (state_trig ) == 1 :
131+ state_trig_expr = state_trig [0 ]
132+ else :
133+ state_trig_expr = f"any([{ ', ' .join (state_trig )} ])"
134+ state_trig_eval = AstEval (
135+ f"{ ast_ctx .name } state_trigger" ,
136+ ast_ctx .get_global_ctx (),
137+ logger_name = ast_ctx .get_logger_name (),
138+ )
139+ Function .install_ast_funcs (state_trig_eval )
140+ state_trig_eval .parse (state_trig_expr )
141+ state_trig_ident = await state_trig_eval .get_names ()
142+ exc = state_trig_eval .get_exception_obj ()
143+ if exc is not None :
144+ raise exc
145+
146+ state_trig_ident .update (state_trig_ident_any )
147+ if state_trig_eval :
148+ #
149+ # check straight away to see if the condition is met (to avoid race conditions)
150+ #
151+ state_trig_ok = await state_trig_eval .eval (State .notify_var_get (state_trig_ident , {}))
152+ exc = state_trig_eval .get_exception_obj ()
153+ if exc is not None :
154+ raise exc
155+ if state_trig_ok :
156+ return {"trigger_type" : "state" }
157+
130158 _LOGGER .debug (
131159 "trigger %s wait_until: watching vars %s" , ast_ctx .name , state_trig_ident ,
132160 )
@@ -145,7 +173,7 @@ async def wait_until(
145173 event_trig_expr .parse (event_trigger [1 ])
146174 exc = event_trig_expr .get_exception_obj ()
147175 if exc is not None :
148- if state_trig_ident :
176+ if len ( state_trig_ident ) > 0 :
149177 State .notify_del (state_trig_ident , notify_q )
150178 raise exc
151179 Event .notify_add (event_trigger [0 ], notify_q )
@@ -191,11 +219,19 @@ async def wait_until(
191219 ret ["trigger_time" ] = time_next
192220 break
193221 if notify_type == "state" :
194- new_vars = notify_info [0 ] if notify_info else None
195- state_trig_ok = await state_trig_expr .eval (new_vars )
196- exc = state_trig_expr .get_exception_obj ()
197- if exc is not None :
198- break
222+ if notify_info :
223+ new_vars , func_args = notify_info
224+ else :
225+ new_vars , func_args = None , {}
226+
227+ state_trig_ok = False
228+ if func_args .get ("var_name" , "" ) in state_trig_ident_any :
229+ state_trig_ok = True
230+ elif state_trig_eval :
231+ state_trig_ok = await state_trig_eval .eval (new_vars )
232+ exc = state_trig_eval .get_exception_obj ()
233+ if exc is not None :
234+ break
199235 if state_trig_ok :
200236 ret = notify_info [1 ] if notify_info else None
201237 break
@@ -215,7 +251,7 @@ async def wait_until(
215251 "trigger %s wait_until got unexpected queue message %s" , ast_ctx .name , notify_type ,
216252 )
217253
218- if state_trig_ident :
254+ if len ( state_trig_ident ) > 0 :
219255 State .notify_del (state_trig_ident , notify_q )
220256 if event_trigger is not None :
221257 Event .notify_del (event_trigger [0 ], notify_q )
@@ -454,7 +490,9 @@ def __init__(
454490 self .active_expr = None
455491 self .state_active_ident = None
456492 self .state_trig_expr = None
493+ self .state_trig_eval = None
457494 self .state_trig_ident = None
495+ self .state_trig_ident_any = set ()
458496 self .event_trig_expr = None
459497 self .have_trigger = False
460498 self .setup_ok = False
@@ -481,15 +519,36 @@ def __init__(
481519 self .run_on_startup = True
482520
483521 if self .state_trigger is not None :
484- self .state_trig_expr = AstEval (
485- f"{ self .name } @state_trigger()" , self .global_ctx , logger_name = self .name
486- )
487- Function .install_ast_funcs (self .state_trig_expr )
488- self .state_trig_expr .parse (self .state_trigger )
489- exc = self .state_trig_expr .get_exception_long ()
490- if exc is not None :
491- self .state_trig_expr .get_logger ().error (exc )
492- return
522+ state_trig = []
523+ for triggers in self .state_trigger :
524+ if isinstance (triggers , str ):
525+ triggers = [triggers ]
526+ elif isinstance (triggers , set ):
527+ triggers = list (triggers )
528+ #
529+ # separate out the entries that are just state var names, which mean trigger
530+ # on any change (no expr)
531+ #
532+ for trig in triggers :
533+ if STATE_RE .match (trig ):
534+ self .state_trig_ident_any .add (trig )
535+ else :
536+ state_trig .append (trig )
537+
538+ if len (state_trig ) > 0 :
539+ if len (state_trig ) == 1 :
540+ self .state_trig_expr = state_trig [0 ]
541+ else :
542+ self .state_trig_expr = f"any([{ ', ' .join (state_trig )} ])"
543+ self .state_trig_eval = AstEval (
544+ f"{ self .name } @state_trigger()" , self .global_ctx , logger_name = self .name
545+ )
546+ Function .install_ast_funcs (self .state_trig_eval )
547+ self .state_trig_eval .parse (self .state_trig_expr )
548+ exc = self .state_trig_eval .get_exception_long ()
549+ if exc is not None :
550+ self .state_trig_eval .get_logger ().error (exc )
551+ return
493552 self .have_trigger = True
494553
495554 if self .event_trigger is not None :
@@ -530,7 +589,10 @@ async def trigger_watch(self):
530589 try :
531590
532591 if self .state_trigger is not None :
533- self .state_trig_ident = await self .state_trig_expr .get_names ()
592+ self .state_trig_ident = set ()
593+ if self .state_trig_eval :
594+ self .state_trig_ident = await self .state_trig_eval .get_names ()
595+ self .state_trig_ident .update (self .state_trig_ident_any )
534596 _LOGGER .debug ("trigger %s: watching vars %s" , self .name , self .state_trig_ident )
535597 if len (self .state_trig_ident ) > 0 :
536598 State .notify_add (self .state_trig_ident , self .notify_q )
@@ -587,11 +649,14 @@ async def trigger_watch(self):
587649 if notify_type == "state" :
588650 new_vars , func_args = notify_info
589651
590- if self .state_trig_expr :
591- trig_ok = await self .state_trig_expr .eval (new_vars )
592- exc = self .state_trig_expr .get_exception_long ()
593- if exc is not None :
594- self .state_trig_expr .get_logger ().error (exc )
652+ if func_args ["var_name" ] not in self .state_trig_ident_any :
653+ if self .state_trig_eval :
654+ trig_ok = await self .state_trig_eval .eval (new_vars )
655+ exc = self .state_trig_eval .get_exception_long ()
656+ if exc is not None :
657+ self .state_trig_eval .get_logger ().error (exc )
658+ trig_ok = False
659+ else :
595660 trig_ok = False
596661
597662 elif notify_type == "event" :
0 commit comments