Skip to content

Commit 1243f8c

Browse files
add unit tests for two_level and irt graphical simulators
1 parent b75c5f5 commit 1243f8c

File tree

2 files changed

+144
-1
lines changed

2 files changed

+144
-1
lines changed

tests/test_simulators/conftest.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -254,3 +254,24 @@ def single_level_simulator():
254254
from bayesflow.experimental.graphical_simulator.example_simulators import single_level
255255

256256
return single_level()
257+
258+
259+
@pytest.fixture()
260+
def two_level_simulator():
261+
from bayesflow.experimental.graphical_simulator.example_simulators import two_level
262+
263+
return two_level()
264+
265+
266+
@pytest.fixture()
267+
def two_level_repeated_roots_simulator():
268+
from bayesflow.experimental.graphical_simulator.example_simulators import two_level_repeated_roots
269+
270+
return two_level_repeated_roots()
271+
272+
273+
@pytest.fixture()
274+
def irt_simulator():
275+
from bayesflow.experimental.graphical_simulator.example_simulators import irt
276+
277+
return irt()

tests/test_simulators/test_graphical_simulator.py

Lines changed: 123 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,137 @@
44

55

66
def test_single_level_simulator(single_level_simulator):
7+
# prior -> likelihood
78
assert isinstance(single_level_simulator, bf.experimental.graphical_simulator.GraphicalSimulator)
89
assert isinstance(single_level_simulator.sample(5), dict)
910

10-
samples = single_level_simulator.sample((12,))
11+
samples = single_level_simulator.sample(12)
1112
expected_keys = ["N", "beta", "sigma", "x", "y"]
1213

1314
assert set(samples.keys()) == set(expected_keys)
1415
assert 5 <= samples["N"] < 15
16+
17+
# prior node
1518
assert np.shape(samples["beta"]) == (12, 2) # num_samples, beta_dim
1619
assert np.shape(samples["sigma"]) == (12, 1) # num_samples, sigma_dim
20+
21+
# likelihood node
1722
assert np.shape(samples["x"]) == (12, samples["N"])
1823
assert np.shape(samples["y"]) == (12, samples["N"])
24+
25+
26+
def test_two_level_simulator(two_level_simulator):
27+
# hypers
28+
# |
29+
# locals shared
30+
# \ /
31+
# \ /
32+
# y
33+
34+
assert isinstance(two_level_simulator, bf.experimental.graphical_simulator.GraphicalSimulator)
35+
assert isinstance(two_level_simulator.sample(5), dict)
36+
37+
samples = two_level_simulator.sample(15)
38+
expected_keys = ["hyper_mean", "hyper_std", "local_mean", "shared_std", "y"]
39+
40+
assert set(samples.keys()) == set(expected_keys)
41+
42+
# hypers node
43+
assert np.shape(samples["hyper_mean"]) == (15, 1)
44+
assert np.shape(samples["hyper_std"]) == (15, 1)
45+
46+
# locals node
47+
assert np.shape(samples["local_mean"]) == (15, 6, 1)
48+
49+
# shared node
50+
assert np.shape(samples["shared_std"]) == (15, 1)
51+
52+
# y node
53+
assert np.shape(samples["y"]) == (15, 6, 10, 1)
54+
55+
56+
def test_two_level_repeated_roots_simulator(two_level_repeated_roots_simulator):
57+
# hypers
58+
# |
59+
# locals shared
60+
# \ /
61+
# \ /
62+
# y
63+
64+
simulator = two_level_repeated_roots_simulator
65+
assert isinstance(simulator, bf.experimental.graphical_simulator.GraphicalSimulator)
66+
assert isinstance(simulator.sample(5), dict)
67+
68+
samples = simulator.sample(15)
69+
expected_keys = ["hyper_mean", "hyper_std", "local_mean", "shared_std", "y"]
70+
71+
assert set(samples.keys()) == set(expected_keys)
72+
73+
# hypers node
74+
assert np.shape(samples["hyper_mean"]) == (15, 5, 1)
75+
assert np.shape(samples["hyper_std"]) == (15, 5, 1)
76+
77+
# locals node
78+
assert np.shape(samples["local_mean"]) == (15, 5, 6, 1)
79+
80+
# shared node
81+
assert np.shape(samples["shared_std"]) == (15, 1)
82+
83+
# y node
84+
assert np.shape(samples["y"]) == (15, 5, 6, 10, 1)
85+
86+
87+
def test_irt_simulator(irt_simulator):
88+
# schools
89+
# / \
90+
# exams students
91+
# | |
92+
# questions |
93+
# \ /
94+
# observations
95+
96+
assert isinstance(irt_simulator, bf.experimental.graphical_simulator.GraphicalSimulator)
97+
assert isinstance(irt_simulator.sample(5), dict)
98+
99+
samples = irt_simulator.sample(22)
100+
expected_keys = [
101+
"mu_exam_mean",
102+
"sigma_exam_mean",
103+
"mu_exam_std",
104+
"sigma_exam_std",
105+
"exam_mean",
106+
"exam_std",
107+
"question_difficulty",
108+
"student_ability",
109+
"obs",
110+
"num_exams", # np.random.randint(2, 4)
111+
"num_questions", # np.random.randint(10, 21)
112+
"num_students", # np.random.randint(100, 201)
113+
]
114+
115+
assert set(samples.keys()) == set(expected_keys)
116+
117+
# schools node
118+
assert np.shape(samples["mu_exam_mean"]) == (22, 1)
119+
assert np.shape(samples["sigma_exam_mean"]) == (22, 1)
120+
assert np.shape(samples["mu_exam_std"]) == (22, 1)
121+
assert np.shape(samples["sigma_exam_std"]) == (22, 1)
122+
123+
# exams node
124+
assert np.shape(samples["exam_mean"]) == (22, samples["num_exams"], 1)
125+
assert np.shape(samples["exam_std"]) == (22, samples["num_exams"], 1)
126+
127+
# questions node
128+
assert np.shape(samples["question_difficulty"]) == (22, samples["num_exams"], samples["num_questions"], 1)
129+
130+
# students node
131+
assert np.shape(samples["student_ability"]) == (22, samples["num_students"], 1)
132+
133+
# observations node
134+
assert np.shape(samples["obs"]) == (
135+
22,
136+
samples["num_exams"],
137+
samples["num_students"],
138+
samples["num_questions"],
139+
1,
140+
)

0 commit comments

Comments
 (0)