Skip to content

Commit 4035c16

Browse files
rename examples in graphical_simulator.example_simulators
1 parent 55b6dfd commit 4035c16

File tree

7 files changed

+39
-84
lines changed

7 files changed

+39
-84
lines changed
Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
from .single_level import single_level
2-
from .two_level import two_level
3-
from .two_level_repeated_roots import two_level_repeated_roots
4-
from .irt import irt
1+
from .single_level_simulator import single_level_simulator
2+
from .two_level_simulator import two_level_simulator
3+
from .crossed_design_irt_simulator import crossed_design_irt_simulator

bayesflow/experimental/graphical_simulator/example_simulators/irt.py renamed to bayesflow/experimental/graphical_simulator/example_simulators/crossed_design_irt_simulator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@
33
from ..graphical_simulator import GraphicalSimulator
44

55

6-
def irt():
6+
def crossed_design_irt_simulator():
77
r"""
8-
Item Response Theory (IRT) model implemented as a graphical simultor.
8+
Item Response Theory (IRT) model implemented as a graphical simulator.
99
1010
schools
1111
/ \

bayesflow/experimental/graphical_simulator/example_simulators/single_level.py renamed to bayesflow/experimental/graphical_simulator/example_simulators/single_level_simulator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from ..graphical_simulator import GraphicalSimulator
44

55

6-
def single_level():
6+
def single_level_simulator():
77
"""
88
Simple single-level simulator that implements the same model as in
99
https://bayesflow.org/main/_examples/Linear_Regression_Starter.html

bayesflow/experimental/graphical_simulator/example_simulators/two_level_repeated_roots.py

Lines changed: 0 additions & 56 deletions
This file was deleted.

bayesflow/experimental/graphical_simulator/example_simulators/two_level.py renamed to bayesflow/experimental/graphical_simulator/example_simulators/two_level_simulator.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from ..graphical_simulator import GraphicalSimulator
44

55

6-
def two_level():
6+
def two_level_simulator(repeated_roots=False):
77
r"""
88
Simple hierarchical model with two levels of parameters: hyperparameters
99
and local parameters, along with a shared parameter:
@@ -15,6 +15,10 @@ def two_level():
1515
\ /
1616
y
1717
18+
Parameters
19+
----------
20+
repeated_roots : bool, default false.
21+
1822
"""
1923

2024
def sample_hypers():
@@ -39,7 +43,11 @@ def sample_y(local_mean, shared_std):
3943
return {"y": y}
4044

4145
simulator = GraphicalSimulator()
42-
simulator.add_node("hypers", sample_fn=sample_hypers)
46+
47+
if not repeated_roots:
48+
simulator.add_node("hypers", sample_fn=sample_hypers)
49+
else:
50+
simulator.add_node("hypers", sample_fn=sample_hypers, reps=5)
4351

4452
simulator.add_node(
4553
"locals",

tests/test_simulators/conftest.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -251,27 +251,27 @@ def simulator(request):
251251

252252
@pytest.fixture()
253253
def single_level_simulator():
254-
from bayesflow.experimental.graphical_simulator.example_simulators import single_level
254+
from bayesflow.experimental.graphical_simulator.example_simulators import single_level_simulator
255255

256-
return single_level()
256+
return single_level_simulator()
257257

258258

259259
@pytest.fixture()
260260
def two_level_simulator():
261-
from bayesflow.experimental.graphical_simulator.example_simulators import two_level
261+
from bayesflow.experimental.graphical_simulator.example_simulators import two_level_simulator
262262

263-
return two_level()
263+
return two_level_simulator()
264264

265265

266266
@pytest.fixture()
267267
def two_level_repeated_roots_simulator():
268-
from bayesflow.experimental.graphical_simulator.example_simulators import two_level_repeated_roots
268+
from bayesflow.experimental.graphical_simulator.example_simulators import two_level_simulator
269269

270-
return two_level_repeated_roots()
270+
return two_level_simulator(repeated_roots=True)
271271

272272

273273
@pytest.fixture()
274-
def irt_simulator():
275-
from bayesflow.experimental.graphical_simulator.example_simulators import irt
274+
def crossed_design_irt_simulator():
275+
from bayesflow.experimental.graphical_simulator.example_simulators import crossed_design_irt_simulator
276276

277-
return irt()
277+
return crossed_design_irt_simulator()

tests/test_simulators/test_graphical_simulator.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,12 @@
55

66
def test_single_level_simulator(single_level_simulator):
77
# prior -> likelihood
8-
assert isinstance(single_level_simulator, bf.experimental.graphical_simulator.GraphicalSimulator)
9-
assert isinstance(single_level_simulator.sample(5), dict)
108

11-
samples = single_level_simulator.sample(12)
9+
simulator = single_level_simulator
10+
assert isinstance(simulator, bf.experimental.graphical_simulator.GraphicalSimulator)
11+
assert isinstance(simulator.sample(5), dict)
12+
13+
samples = simulator.sample(12)
1214
expected_keys = ["N", "beta", "sigma", "x", "y"]
1315

1416
assert set(samples.keys()) == set(expected_keys)
@@ -31,10 +33,11 @@ def test_two_level_simulator(two_level_simulator):
3133
# \ /
3234
# y
3335

34-
assert isinstance(two_level_simulator, bf.experimental.graphical_simulator.GraphicalSimulator)
35-
assert isinstance(two_level_simulator.sample(5), dict)
36+
simulator = two_level_simulator
37+
assert isinstance(simulator, bf.experimental.graphical_simulator.GraphicalSimulator)
38+
assert isinstance(simulator.sample(5), dict)
3639

37-
samples = two_level_simulator.sample(15)
40+
samples = simulator.sample(15)
3841
expected_keys = ["hyper_mean", "hyper_std", "local_mean", "shared_std", "y"]
3942

4043
assert set(samples.keys()) == set(expected_keys)
@@ -84,7 +87,7 @@ def test_two_level_repeated_roots_simulator(two_level_repeated_roots_simulator):
8487
assert np.shape(samples["y"]) == (15, 5, 6, 10, 1)
8588

8689

87-
def test_irt_simulator(irt_simulator):
90+
def test_crossed_design_irt_simulator(crossed_design_irt_simulator):
8891
# schools
8992
# / \
9093
# exams students
@@ -93,10 +96,11 @@ def test_irt_simulator(irt_simulator):
9396
# \ /
9497
# observations
9598

96-
assert isinstance(irt_simulator, bf.experimental.graphical_simulator.GraphicalSimulator)
97-
assert isinstance(irt_simulator.sample(5), dict)
99+
simulator = crossed_design_irt_simulator
100+
assert isinstance(simulator, bf.experimental.graphical_simulator.GraphicalSimulator)
101+
assert isinstance(simulator.sample(5), dict)
98102

99-
samples = irt_simulator.sample(22)
103+
samples = simulator.sample(22)
100104
expected_keys = [
101105
"mu_exam_mean",
102106
"sigma_exam_mean",

0 commit comments

Comments
 (0)