@@ -69,7 +69,7 @@ def sample(self, batch_shape: Shape | int, **kwargs) -> dict[str, np.ndarray]:
69
69
if not parent_nodes :
70
70
# root node: generate independent samples
71
71
node_samples = [
72
- {"__batch_idx" : batch_idx , f"__{ node } _idx" : i } | self ._call_sampling_fn (sampling_fn , {})
72
+ {"__batch_idx" : batch_idx , f"__{ node } _idx" : i } | self ._call_sample_fn (sampling_fn , {})
73
73
for i in range (1 , reps + 1 )
74
74
]
75
75
else :
@@ -86,7 +86,7 @@ def sample(self, batch_shape: Shape | int, **kwargs) -> dict[str, np.ndarray]:
86
86
[
87
87
index_entries
88
88
| {f"__{ node } _idx" : i }
89
- | self ._call_sampling_fn (sampling_fn , sampling_fn_input )
89
+ | self ._call_sample_fn (sampling_fn , sampling_fn_input )
90
90
for i in range (1 , reps + 1 )
91
91
]
92
92
)
@@ -169,12 +169,12 @@ def _output_shape(self, samples, variable):
169
169
170
170
return tuple (output_shape )
171
171
172
- def _call_sampling_fn (self , sampling_fn , args ):
173
- signature = inspect .signature (sampling_fn )
172
+ def _call_sample_fn (self , sample_fn , args ):
173
+ signature = inspect .signature (sample_fn )
174
174
fn_args = signature .parameters
175
175
accepted_args = {k : v for k , v in args .items () if k in fn_args }
176
176
177
- return sampling_fn (** accepted_args )
177
+ return sample_fn (** accepted_args )
178
178
179
179
180
180
def sorted_ancestors (graph , node ):
0 commit comments