Skip to content

Commit 210b7c5

Browse files
committed
Abstract special behavior of NUTS divergences in ProgressBar
Every step sampler can now decide whether sampling is failing or not by setting "failing" in the returned update dict
1 parent 7823727 commit 210b7c5

File tree

4 files changed

+34
-13
lines changed

4 files changed

+34
-13
lines changed

pymc/step_methods/hmc/base_hmc.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,7 @@ def __init__(
184184

185185
self._step_rand = step_rand
186186
self._num_divs_sample = 0
187+
self.divergences = 0
187188

188189
@abstractmethod
189190
def _hamiltonian_step(self, start, p0, step_size) -> HMCStepData:
@@ -266,11 +267,15 @@ def astep(self, q0: RaveledVars) -> tuple[RaveledVars, StatsType]:
266267
divergence_info=info_store,
267268
)
268269

270+
diverging = bool(hmc_step.divergence_info)
271+
if not self.tune:
272+
self.divergences += diverging
269273
self.iter_count += 1
270274

271275
stats: dict[str, Any] = {
272276
"tune": self.tune,
273-
"diverging": bool(hmc_step.divergence_info),
277+
"diverging": diverging,
278+
"divergences": self.divergences,
274279
"perf_counter_diff": perf_end - perf_start,
275280
"process_time_diff": process_end - process_start,
276281
"perf_counter_start": perf_start,
@@ -288,6 +293,8 @@ def reset_tuning(self, start=None):
288293
self.reset(start=None)
289294

290295
def reset(self, start=None):
296+
self.iter_count = 0
297+
self.divergences = 0
291298
self.tune = True
292299
self.potential.reset()
293300

pymc/step_methods/hmc/hmc.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ class HamiltonianMC(BaseHMC):
5555
"accept": (np.float64, []),
5656
"diverging": (bool, []),
5757
"energy_error": (np.float64, []),
58+
"divergences": (np.int64, []),
5859
"energy": (np.float64, []),
5960
"path_length": (np.float64, []),
6061
"accepted": (bool, []),

pymc/step_methods/hmc/nuts.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,7 @@ class NUTS(BaseHMC):
115115
"step_size_bar": (np.float64, []),
116116
"tree_size": (np.float64, []),
117117
"diverging": (bool, []),
118+
"divergences": (np.int64, []),
118119
"energy_error": (np.float64, []),
119120
"energy": (np.float64, []),
120121
"max_energy_error": (np.float64, []),
@@ -250,7 +251,9 @@ def _progressbar_config(n_chains=1):
250251
@staticmethod
251252
def _make_update_stats_functions():
252253
def update_stats(stats):
253-
return {key: stats[key] for key in ("diverging", "step_size", "tree_size")}
254+
return {key: stats[key] for key in ("divergences", "step_size", "tree_size")} | {
255+
"failing": stats["divergences"] > 0
256+
}
254257

255258
return (update_stats,)
256259

pymc/util.py

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -657,22 +657,22 @@ class DivergenceBarColumn(BarColumn):
657657
def __init__(self, *args, diverging_color="red", **kwargs):
658658
from matplotlib.colors import to_rgb
659659

660-
self.diverging_color = diverging_color
661-
self.diverging_rgb = [int(x * 255) for x in to_rgb(self.diverging_color)]
660+
self.failing_color = diverging_color
661+
self.failing_rgb = [int(x * 255) for x in to_rgb(self.failing_color)]
662662

663663
super().__init__(*args, **kwargs)
664664

665-
self.non_diverging_style = self.complete_style
666-
self.non_diverging_finished_style = self.finished_style
665+
self.default_complete_style = self.complete_style
666+
self.default_finished_style = self.finished_style
667667

668668
def callbacks(self, task: "Task"):
669-
divergences = task.fields.get("divergences", 0)
670-
if isinstance(divergences, float | int) and divergences > 0:
671-
self.complete_style = Style.parse("rgb({},{},{})".format(*self.diverging_rgb))
672-
self.finished_style = Style.parse("rgb({},{},{})".format(*self.diverging_rgb))
669+
if task.fields["failing"]:
670+
self.complete_style = Style.parse("rgb({},{},{})".format(*self.failing_rgb))
671+
self.finished_style = Style.parse("rgb({},{},{})".format(*self.failing_rgb))
673672
else:
674-
self.complete_style = self.non_diverging_style
675-
self.finished_style = self.non_diverging_finished_style
673+
# Recovered from failing yay
674+
self.complete_style = self.default_complete_style
675+
self.finished_style = self.default_finished_style
676676

677677

678678
class ProgressBarManager:
@@ -794,6 +794,7 @@ def _initialize_tasks(self):
794794
chain_idx=0,
795795
sampling_speed=0,
796796
speed_unit="draws/s",
797+
failing=False,
797798
**{stat: value[0] for stat, value in self.progress_stats.items()},
798799
)
799800
]
@@ -808,6 +809,7 @@ def _initialize_tasks(self):
808809
chain_idx=chain_idx,
809810
sampling_speed=0,
810811
speed_unit="draws/s",
812+
failing=False,
811813
**{stat: value[chain_idx] for stat, value in self.progress_stats.items()},
812814
)
813815
for chain_idx in range(self.chains)
@@ -829,16 +831,22 @@ def update(self, chain_idx, is_last, draw, tuning, stats):
829831
self.divergences += 1
830832

831833
if self.full_stats:
834+
failing = False
835+
all_step_stats = {}
836+
832837
# TODO: Index by chain already?
833838
chain_progress_stats = [
834839
update_states_fn(step_stats)
835840
for update_states_fn, step_stats in zip(
836841
self.update_stats_functions, stats, strict=True
837842
)
838843
]
839-
all_step_stats = {}
840844
for step_stats in chain_progress_stats:
841845
for key, val in step_stats.items():
846+
if key == "failing":
847+
failing |= val
848+
continue
849+
842850
if key in all_step_stats:
843851
continue
844852
count = (
@@ -849,6 +857,7 @@ def update(self, chain_idx, is_last, draw, tuning, stats):
849857
all_step_stats[key] = val
850858

851859
else:
860+
failing = False
852861
all_step_stats = {}
853862

854863
# more_updates = (
@@ -863,6 +872,7 @@ def update(self, chain_idx, is_last, draw, tuning, stats):
863872
draws=draw,
864873
sampling_speed=speed,
865874
speed_unit=unit,
875+
failing=failing,
866876
**all_step_stats,
867877
)
868878

0 commit comments

Comments
 (0)