1111from pytensor .graph .op import get_test_value
1212from pytensor .graph .replace import clone_replace
1313from pytensor .graph .traversal import explicit_graph_inputs
14+ from pytensor .graph .type import HasShape
1415from pytensor .graph .utils import MissingInputError , TestValueError
1516from pytensor .scan .op import Scan , ScanInfo
1617from pytensor .scan .utils import expand_empty , safe_new , until
@@ -706,6 +707,12 @@ def wrap_into_list(x):
706707 sit_sot_inner_outputs = []
707708 sit_sot_rightOrder = []
708709
710+ n_untraced_sit_sot_outs = 0
711+ untraced_sit_sot_scan_inputs = []
712+ untraced_sit_sot_inner_inputs = []
713+ untraced_sit_sot_inner_outputs = []
714+ untraced_sit_sot_rightOrder = []
715+
709716 # go through outputs picking up time slices as needed
710717 for i , init_out in enumerate (outs_info ):
711718 # Note that our convention dictates that if an output uses
@@ -741,17 +748,35 @@ def wrap_into_list(x):
741748 # We need now to allocate space for storing the output and copy
742749 # the initial state over. We do this using the expand function
743750 # defined in scan utils
744- sit_sot_scan_inputs .append (
745- expand_empty (
746- shape_padleft (actual_arg ),
747- actual_n_steps ,
751+ if isinstance (actual_arg .type , HasShape ):
752+ sit_sot_scan_inputs .append (
753+ expand_empty (
754+ shape_padleft (actual_arg ),
755+ actual_n_steps ,
756+ )
748757 )
749- )
758+ sit_sot_inner_slices .append (actual_arg )
759+
760+ sit_sot_inner_inputs .append (arg )
761+ sit_sot_rightOrder .append (i )
762+ n_sit_sot += 1
763+ else :
764+ # Assume variables without shape cannot be stacked (e.g., RNG variables)
765+ # Because this is new, issue a warning to inform the user, except for RNG, which were the main reason for this feature
766+ from pytensor .tensor .random .type import RandomType
750767
751- sit_sot_inner_slices .append (actual_arg )
752- sit_sot_inner_inputs .append (arg )
753- sit_sot_rightOrder .append (i )
754- n_sit_sot += 1
768+ if not isinstance (arg .type , RandomType ):
769+ warnings .warn (
770+ (
771+ f"Output { actual_arg } (index { i } ) with type { actual_arg .type } will be treated as untraced variable in scan. "
772+ "Only the last value will be returned, not the entire sequence."
773+ ),
774+ UserWarning ,
775+ )
776+ untraced_sit_sot_scan_inputs .append (actual_arg )
777+ untraced_sit_sot_inner_inputs .append (arg )
778+ n_untraced_sit_sot_outs += 1
779+ untraced_sit_sot_rightOrder .append (i )
755780
756781 elif init_out .get ("taps" , None ):
757782 if np .any (np .array (init_out .get ("taps" , [])) > 0 ):
@@ -802,9 +827,10 @@ def wrap_into_list(x):
802827 # a map); in that case we do not have to do anything ..
803828
804829 # Re-order args
805- max_mit_sot = np .max ([- 1 , * mit_sot_rightOrder ]) + 1
806- max_sit_sot = np .max ([- 1 , * sit_sot_rightOrder ]) + 1
807- n_elems = np .max ([max_mit_sot , max_sit_sot ])
830+ max_mit_sot = max (mit_sot_rightOrder , default = - 1 ) + 1
831+ max_sit_sot = max (sit_sot_rightOrder , default = - 1 ) + 1
832+ max_untraced_sit_sot_outs = max (untraced_sit_sot_rightOrder , default = - 1 ) + 1
833+ n_elems = np .max ((max_mit_sot , max_sit_sot , max_untraced_sit_sot_outs ))
808834 _ordered_args = [[] for x in range (n_elems )]
809835 offset = 0
810836 for idx in range (n_mit_sot ):
@@ -825,6 +851,11 @@ def wrap_into_list(x):
825851 else :
826852 _ordered_args [sit_sot_rightOrder [idx ]] = [sit_sot_inner_inputs [idx ]]
827853
854+ for idx in range (n_untraced_sit_sot_outs ):
855+ _ordered_args [untraced_sit_sot_rightOrder [idx ]] = [
856+ untraced_sit_sot_inner_inputs [idx ]
857+ ]
858+
828859 ordered_args = list (chain .from_iterable (_ordered_args ))
829860 if single_step_requested :
830861 args = inner_slices + ordered_args + non_seqs
@@ -842,6 +873,11 @@ def wrap_into_list(x):
842873 raw_inner_outputs = fn (* args )
843874
844875 condition , outputs , updates = get_updates_and_outputs (raw_inner_outputs )
876+ if updates :
877+ warnings .warn (
878+ "Updates functionality in Scan are deprecated. Use explicit outputs_info and build shared update expressions manually, even for RNGs." ,
879+ DeprecationWarning , # Only meant for developers for now, not users. Switch to FutureWarning later, before removing.
880+ )
845881 if condition is not None :
846882 as_while = True
847883 else :
@@ -883,6 +919,8 @@ def wrap_into_list(x):
883919 fake_outputs = clone_replace (
884920 outputs , replace = dict (zip (non_seqs , fake_nonseqs , strict = True ))
885921 )
922+ # TODO: Once we don't treat shared variables specially we should use `truncated_graph_inputs`
923+ # to find implicit inputs in a way that reduces the size of the inner function
886924 known_inputs = [* args , * fake_nonseqs ]
887925 extra_inputs = [
888926 x for x in explicit_graph_inputs (fake_outputs ) if x not in known_inputs
@@ -939,18 +977,19 @@ def wrap_into_list(x):
939977 if "taps" in out and out ["taps" ] != [- 1 ]:
940978 mit_sot_inner_outputs .append (outputs [i ])
941979
942- # Step 5.2 Outputs with tap equal to -1
980+ # Step 5.2 Outputs with tap equal to -1 (traced and untraced)
943981 for i , out in enumerate (outs_info ):
944982 if "taps" in out and out ["taps" ] == [- 1 ]:
945- sit_sot_inner_outputs .append (outputs [i ])
983+ output = outputs [i ]
984+ if isinstance (output .type , HasShape ):
985+ sit_sot_inner_outputs .append (output )
986+ else :
987+ untraced_sit_sot_inner_outputs .append (output )
946988
947989 # Step 5.3 Outputs that correspond to update rules of shared variables
948- inner_replacements = {}
949- n_shared_outs = 0
950- shared_scan_inputs = []
951- shared_inner_inputs = []
952- shared_inner_outputs = []
990+ # This whole special logic for shared variables is deprecated
953991 sit_sot_shared = []
992+ inner_replacements = {}
954993 no_update_shared_inputs = []
955994 for input in dummy_inputs :
956995 if not isinstance (input .variable , SharedVariable ):
@@ -1003,10 +1042,10 @@ def wrap_into_list(x):
10031042 sit_sot_shared .append (input .variable )
10041043
10051044 else :
1006- shared_inner_inputs .append (new_var )
1007- shared_scan_inputs .append (input .variable )
1008- shared_inner_outputs .append (input .update )
1009- n_shared_outs += 1
1045+ untraced_sit_sot_inner_inputs .append (new_var )
1046+ untraced_sit_sot_scan_inputs .append (input .variable )
1047+ untraced_sit_sot_inner_outputs .append (input .update )
1048+ n_untraced_sit_sot_outs += 1
10101049 else :
10111050 no_update_shared_inputs .append (input )
10121051
@@ -1071,7 +1110,7 @@ def wrap_into_list(x):
10711110 + mit_mot_inner_inputs
10721111 + mit_sot_inner_inputs
10731112 + sit_sot_inner_inputs
1074- + shared_inner_inputs
1113+ + untraced_sit_sot_inner_inputs
10751114 + other_shared_inner_args
10761115 + other_inner_args
10771116 )
@@ -1081,7 +1120,7 @@ def wrap_into_list(x):
10811120 + mit_sot_inner_outputs
10821121 + sit_sot_inner_outputs
10831122 + nit_sot_inner_outputs
1084- + shared_inner_outputs
1123+ + untraced_sit_sot_inner_outputs
10851124 )
10861125 if condition is not None :
10871126 inner_outs .append (condition )
@@ -1101,7 +1140,7 @@ def wrap_into_list(x):
11011140 mit_mot_out_slices = tuple (tuple (v ) for v in mit_mot_out_slices ),
11021141 mit_sot_in_slices = tuple (tuple (v ) for v in mit_sot_tap_array ),
11031142 sit_sot_in_slices = tuple ((- 1 ,) for x in range (n_sit_sot )),
1104- n_shared_outs = n_shared_outs ,
1143+ n_untraced_sit_sot_outs = n_untraced_sit_sot_outs ,
11051144 n_nit_sot = n_nit_sot ,
11061145 n_non_seqs = len (other_shared_inner_args ) + len (other_inner_args ),
11071146 as_while = as_while ,
@@ -1127,7 +1166,7 @@ def wrap_into_list(x):
11271166 + mit_mot_scan_inputs
11281167 + mit_sot_scan_inputs
11291168 + sit_sot_scan_inputs
1130- + shared_scan_inputs
1169+ + untraced_sit_sot_scan_inputs
11311170 + [actual_n_steps for x in range (n_nit_sot )]
11321171 + other_shared_scan_args
11331172 + other_scan_args
@@ -1173,13 +1212,28 @@ def remove_dimensions(outs, offsets=None):
11731212 nit_sot_outs = remove_dimensions (scan_outs [offset : offset + n_nit_sot ])
11741213
11751214 offset += n_nit_sot
1176- for idx , update_rule in enumerate (scan_outs [offset : offset + n_shared_outs ]):
1177- update_map [shared_scan_inputs [idx ]] = update_rule
11781215
1179- _scan_out_list = mit_sot_outs + sit_sot_outs + nit_sot_outs
1216+ # Legacy support for explicit untraced sit_sot and those built with update dictionary
1217+ # Switch to n_untraced_sit_sot_outs after deprecation period
1218+ n_explicit_untraced_sit_sot_outs = len (untraced_sit_sot_rightOrder )
1219+ untraced_sit_sot_outs = scan_outs [
1220+ offset : offset + n_explicit_untraced_sit_sot_outs
1221+ ]
1222+
1223+ # Legacy support: map shared outputs to their updates
1224+ offset += n_explicit_untraced_sit_sot_outs
1225+ for idx , update_rule in enumerate (scan_outs [offset :]):
1226+ update_map [untraced_sit_sot_scan_inputs [idx ]] = update_rule
1227+
1228+ _scan_out_list = mit_sot_outs + sit_sot_outs + nit_sot_outs + untraced_sit_sot_outs
11801229 # Step 10. I need to reorder the outputs to be in the order expected by
11811230 # the user
1182- rightOrder = mit_sot_rightOrder + sit_sot_rightOrder + nit_sot_rightOrder
1231+ rightOrder = (
1232+ mit_sot_rightOrder
1233+ + sit_sot_rightOrder
1234+ + untraced_sit_sot_rightOrder
1235+ + nit_sot_rightOrder
1236+ )
11831237 scan_out_list = [None ] * len (rightOrder )
11841238 for idx , pos in enumerate (rightOrder ):
11851239 if pos >= 0 :
0 commit comments