Skip to content

Commit 3297370

Browse files
Track divergences in HMC stats
1 parent 6748186 commit 3297370

File tree

3 files changed

+10
-8
lines changed

3 files changed

+10
-8
lines changed

pymc/step_methods/hmc/base_hmc.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -275,6 +275,7 @@ def astep(self, q0: RaveledVars) -> tuple[RaveledVars, StatsType]:
275275
stats: dict[str, Any] = {
276276
"tune": self.tune,
277277
"diverging": diverging,
278+
"divergences": self.divergences,
278279
"perf_counter_diff": perf_end - perf_start,
279280
"process_time_diff": process_end - process_start,
280281
"perf_counter_start": perf_start,

pymc/step_methods/hmc/hmc.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -221,12 +221,12 @@ def _progressbar_config(n_chains=1):
221221

222222
return columns, stats
223223

224-
def _make_progressbar_update_functions(self):
224+
@staticmethod
225+
def _make_progressbar_update_functions():
225226
def update_stats(stats):
226-
divergences = self.divergences
227227
return {key: stats[key] for key in ("n_steps",)} | {
228-
"failing": divergences > 0,
229-
"divergences": divergences,
228+
"failing": stats["divergences"] > 0,
229+
"divergences": stats["divergences"],
230230
}
231231

232232
return (update_stats,)

pymc/step_methods/hmc/nuts.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,7 @@ class NUTS(BaseHMC):
115115
"step_size_bar": (np.float64, []),
116116
"tree_size": (np.float64, []),
117117
"diverging": (bool, []),
118+
"divergences": (int, []),
118119
"energy_error": (np.float64, []),
119120
"energy": (np.float64, []),
120121
"max_energy_error": (np.float64, []),
@@ -247,12 +248,12 @@ def _progressbar_config(n_chains=1):
247248

248249
return columns, stats
249250

250-
def _make_update_stats_functions(self):
251+
@staticmethod
252+
def _make_update_stats_functions():
251253
def update_stats(stats):
252-
divergences = self.divergences
253254
return {key: stats[key] for key in ("step_size", "tree_size")} | {
254-
"failing": divergences > 0,
255-
"divergences": divergences,
255+
"failing": stats["divergences"] > 0,
256+
"divergences": stats["divergences"],
256257
}
257258

258259
return (update_stats,)

0 commit comments

Comments
 (0)