@@ -917,7 +917,12 @@ def test_switch_mixture():
917917 i_vv = I_rv .clone ()
918918 i_vv .name = "i"
919919
920+ # When I_rv == True, X_rv flows through otherwise Y_rv does
920921 Z1_rv = pt .switch (I_rv , X_rv , Y_rv )
922+
923+ assert Z1_rv .eval ({I_rv : 0 }) > 5
924+ assert Z1_rv .eval ({I_rv : 1 }) < - 5
925+
921926 z_vv = Z1_rv .clone ()
922927 z_vv .name = "z1"
923928
@@ -935,7 +940,10 @@ def test_switch_mixture():
935940
936941 # building the identical graph but with a stack to check that mixture computations are identical
937942
938- Z2_rv = pt .stack ((X_rv , Y_rv ))[I_rv ]
943+ Z2_rv = pt .stack ((Y_rv , X_rv ))[I_rv ]
944+
945+ assert Z2_rv .eval ({I_rv : 0 }) > 5
946+ assert Z2_rv .eval ({I_rv : 1 }) < - 5
939947
940948 fgraph2 , _ , _ = construct_ir_fgraph ({Z2_rv : z_vv , I_rv : i_vv })
941949
@@ -949,8 +957,8 @@ def test_switch_mixture():
949957 # below should follow immediately from the equal_computations assertion above
950958 assert equal_computations ([z1_logp_combined ], [z2_logp_combined ])
951959
952- np .testing .assert_almost_equal (0.69049938 , z1_logp_combined .eval ({z_vv : - 10 , i_vv : 0 }))
953- np .testing .assert_almost_equal (0.69049938 , z2_logp_combined .eval ({z_vv : - 10 , i_vv : 0 }))
960+ np .testing .assert_almost_equal (0.69049938 , z1_logp_combined .eval ({z_vv : - 10 , i_vv : 1 }))
961+ np .testing .assert_almost_equal (0.69049938 , z2_logp_combined .eval ({z_vv : - 10 , i_vv : 1 }))
954962
955963
956964def test_ifelse_mixture_one_component ():
0 commit comments