@@ -133,7 +133,7 @@ def check_stat_dtype(self, step, idata):
133133 continue
134134 assert idata .sample_stats [stat ].dtype == np .dtype (dtype )
135135
136- def step_continuous (self , step_fn , draws ):
136+ def step_continuous (self , step_fn , draws , chains = 1 , tune = 1000 ):
137137 start , model , (mu , C ) = mv_simple ()
138138 unc = np .diag (C ) ** 0.5
139139 check = (("x" , np .mean , mu , unc / 10 ), ("x" , np .std , unc , unc / 10 ))
@@ -143,14 +143,19 @@ def step_continuous(self, step_fn, draws):
143143 with warnings .catch_warnings ():
144144 warnings .filterwarnings ("ignore" , "More chains .* than draws .*" , UserWarning )
145145 idata = pm .sample (
146- tune = 1000 ,
146+ tune = tune ,
147147 draws = draws ,
148- chains = 1 ,
148+ chains = chains ,
149149 step = step ,
150150 initvals = start ,
151151 model = model ,
152152 random_seed = 1 ,
153+ discard_tuned_samples = False ,
153154 )
155+ assert idata .warmup_posterior .sizes ["chain" ] == chains
156+ assert idata .warmup_posterior .sizes ["draw" ] == tune
157+ assert idata .posterior .sizes ["chain" ] == chains
158+ assert idata .posterior .sizes ["draw" ] == draws
154159 self .check_stat (check , idata , step .__class__ .__name__ )
155160 self .check_stat_dtype (idata , step )
156161
0 commit comments