3333from pymc .initial_point import make_initial_point_fn
3434from pymc .pytensorf import compile_pymc
3535from pymc .smc .kernels import IMH
36- from pymc .testing import SeededTest
3736
3837
39- class TestSimulator ( SeededTest ) :
38+ class TestSimulator :
4039 @staticmethod
4140 def count_rvs (end_node ):
4241 return len (
@@ -60,7 +59,6 @@ def quantiles(x):
6059 return np .quantile (x , [0.25 , 0.5 , 0.75 ])
6160
6261 def setup_class (self ):
63- super ().setup_class ()
6462 self .data = np .random .normal (loc = 0 , scale = 1 , size = 1000 )
6563
6664 with pm .Model () as self .SMABC_test :
@@ -75,7 +73,7 @@ def setup_class(self):
7573 c = pm .Potential ("c" , pm .math .switch (a > 0 , 0 , - np .inf ))
7674 s = pm .Simulator ("s" , self .normal_sim , a , b , observed = self .data )
7775
78- def test_one_gaussian (self ):
76+ def test_one_gaussian (self , seeded_test ):
7977 assert self .count_rvs (self .SMABC_test .logp ()) == 1
8078
8179 with self .SMABC_test :
@@ -95,7 +93,7 @@ def test_one_gaussian(self):
9593 assert abs (self .data .std () - po_p ["s" ].std ()) < 0.10
9694
9795 @pytest .mark .parametrize ("floatX" , ["float32" , "float64" ])
98- def test_custom_dist_sum_stat (self , floatX ):
96+ def test_custom_dist_sum_stat (self , seeded_test , floatX ):
9997 with pytensor .config .change_flags (floatX = floatX ):
10098 with pm .Model () as m :
10199 a = pm .Normal ("a" , mu = 0 , sigma = 1 )
@@ -118,7 +116,7 @@ def test_custom_dist_sum_stat(self, floatX):
118116 pm .sample_smc (draws = 100 )
119117
120118 @pytest .mark .parametrize ("floatX" , ["float32" , "float64" ])
121- def test_custom_dist_sum_stat_scalar (self , floatX ):
119+ def test_custom_dist_sum_stat_scalar (self , seeded_test , floatX ):
122120 """
123121 Test that automatically wrapped functions cope well with scalar inputs
124122 """
@@ -149,22 +147,22 @@ def test_custom_dist_sum_stat_scalar(self, floatX):
149147 )
150148 assert self .count_rvs (m .logp ()) == 1
151149
152- def test_model_with_potential (self ):
150+ def test_model_with_potential (self , seeded_test ):
153151 assert self .count_rvs (self .SMABC_potential .logp ()) == 1
154152
155153 with self .SMABC_potential :
156154 trace = pm .sample_smc (draws = 100 , chains = 1 , return_inferencedata = False )
157155 assert np .all (trace ["a" ] >= 0 )
158156
159- def test_simulator_metropolis_mcmc (self ):
157+ def test_simulator_metropolis_mcmc (self , seeded_test ):
160158 with self .SMABC_test as m :
161159 step = pm .Metropolis ([m .rvs_to_values [m ["a" ]], m .rvs_to_values [m ["b" ]]])
162160 trace = pm .sample (step = step , return_inferencedata = False )
163161
164162 assert abs (self .data .mean () - trace ["a" ].mean ()) < 0.05
165163 assert abs (self .data .std () - trace ["b" ].mean ()) < 0.05
166164
167- def test_multiple_simulators (self ):
165+ def test_multiple_simulators (self , seeded_test ):
168166 true_a = 2
169167 true_b = - 2
170168
@@ -214,9 +212,9 @@ def test_multiple_simulators(self):
214212 assert abs (true_a - trace ["a" ].mean ()) < 0.05
215213 assert abs (true_b - trace ["b" ].mean ()) < 0.05
216214
217- def test_nested_simulators (self ):
215+ def test_nested_simulators (self , seeded_test ):
218216 true_a = 2
219- rng = self . get_random_state ( )
217+ rng = np . random . RandomState ( 20160911 )
220218 data = rng .normal (true_a , 0.1 , size = 1000 )
221219
222220 with pm .Model () as m :
@@ -244,7 +242,7 @@ def test_nested_simulators(self):
244242
245243 assert np .abs (true_a - trace ["sim1" ].mean ()) < 0.1
246244
247- def test_upstream_rngs_not_in_compiled_logp (self ):
245+ def test_upstream_rngs_not_in_compiled_logp (self , seeded_test ):
248246 smc = IMH (model = self .SMABC_test )
249247 smc .initialize_population ()
250248 smc ._initialize_kernel ()
@@ -263,7 +261,7 @@ def test_upstream_rngs_not_in_compiled_logp(self):
263261 ]
264262 assert len (shared_rng_vars ) == 1
265263
266- def test_simulator_error_msg (self ):
264+ def test_simulator_error_msg (self , seeded_test ):
267265 msg = "The distance metric not_real is not implemented"
268266 with pytest .raises (ValueError , match = msg ):
269267 with pm .Model () as m :
@@ -280,7 +278,7 @@ def test_simulator_error_msg(self):
280278 sim = pm .Simulator ("sim" , self .normal_sim , 0 , params = (1 ))
281279
282280 @pytest .mark .xfail (reason = "KL not refactored" )
283- def test_automatic_use_of_sort (self ):
281+ def test_automatic_use_of_sort (self , seeded_test ):
284282 with pm .Model () as model :
285283 s_k = pm .Simulator (
286284 "s_k" ,
@@ -292,7 +290,7 @@ def test_automatic_use_of_sort(self):
292290 )
293291 assert s_k .distribution .sum_stat is pm .distributions .simulator .identity
294292
295- def test_name_is_string_type (self ):
293+ def test_name_is_string_type (self , seeded_test ):
296294 with self .SMABC_potential :
297295 assert not self .SMABC_potential .name
298296 with warnings .catch_warnings ():
@@ -303,7 +301,7 @@ def test_name_is_string_type(self):
303301 trace = pm .sample_smc (draws = 10 , chains = 1 , return_inferencedata = False )
304302 assert isinstance (trace ._straces [0 ].name , str )
305303
306- def test_named_model (self ):
304+ def test_named_model (self , seeded_test ):
307305 # Named models used to fail with Simulator because the arguments to the
308306 # random fn used to be passed by name. This is no longer true.
309307 # https://github.com/pymc-devs/pymc/pull/4365#issuecomment-761221146
@@ -323,7 +321,7 @@ def test_named_model(self):
323321 @pytest .mark .parametrize ("mu" , [0 , np .arange (3 )], ids = str )
324322 @pytest .mark .parametrize ("sigma" , [1 , np .array ([1 , 2 , 5 ])], ids = str )
325323 @pytest .mark .parametrize ("size" , [None , 3 , (5 , 3 )], ids = str )
326- def test_simulator_moment (self , mu , sigma , size ):
324+ def test_simulator_moment (self , seeded_test , mu , sigma , size ):
327325 def normal_sim (rng , mu , sigma , size ):
328326 return rng .normal (mu , sigma , size = size )
329327
@@ -357,7 +355,7 @@ def normal_sim(rng, mu, sigma, size):
357355
358356 assert np .all (np .abs ((result - expected_sample_mean ) / expected_sample_mean_std ) < cutoff )
359357
360- def test_dist (self ):
358+ def test_dist (self , seeded_test ):
361359 x = pm .Simulator .dist (self .normal_sim , 0 , 1 , sum_stat = "sort" , shape = (3 ,))
362360 x = cloudpickle .loads (cloudpickle .dumps (x ))
363361
0 commit comments