@@ -657,22 +657,22 @@ class DivergenceBarColumn(BarColumn):
657
657
def __init__ (self , * args , diverging_color = "red" , ** kwargs ):
658
658
from matplotlib .colors import to_rgb
659
659
660
- self .diverging_color = diverging_color
661
- self .diverging_rgb = [int (x * 255 ) for x in to_rgb (self .diverging_color )]
660
+ self .failing_color = diverging_color
661
+ self .failing_rgb = [int (x * 255 ) for x in to_rgb (self .failing_color )]
662
662
663
663
super ().__init__ (* args , ** kwargs )
664
664
665
- self .non_diverging_style = self .complete_style
666
- self .non_diverging_finished_style = self .finished_style
665
+ self .default_complete_style = self .complete_style
666
+ self .default_finished_style = self .finished_style
667
667
668
668
def callbacks (self , task : "Task" ):
669
- divergences = task .fields .get ("divergences" , 0 )
670
- if isinstance (divergences , float | int ) and divergences > 0 :
671
- self .complete_style = Style .parse ("rgb({},{},{})" .format (* self .diverging_rgb ))
672
- self .finished_style = Style .parse ("rgb({},{},{})" .format (* self .diverging_rgb ))
669
+ if task .fields ["failing" ]:
670
+ self .complete_style = Style .parse ("rgb({},{},{})" .format (* self .failing_rgb ))
671
+ self .finished_style = Style .parse ("rgb({},{},{})" .format (* self .failing_rgb ))
673
672
else :
674
- self .complete_style = self .non_diverging_style
675
- self .finished_style = self .non_diverging_finished_style
673
+ # Recovered from failing yay
674
+ self .complete_style = self .default_complete_style
675
+ self .finished_style = self .default_finished_style
676
676
677
677
678
678
class ProgressBarManager :
@@ -794,6 +794,7 @@ def _initialize_tasks(self):
794
794
chain_idx = 0 ,
795
795
sampling_speed = 0 ,
796
796
speed_unit = "draws/s" ,
797
+ failing = False ,
797
798
** {stat : value [0 ] for stat , value in self .progress_stats .items ()},
798
799
)
799
800
]
@@ -808,6 +809,7 @@ def _initialize_tasks(self):
808
809
chain_idx = chain_idx ,
809
810
sampling_speed = 0 ,
810
811
speed_unit = "draws/s" ,
812
+ failing = False ,
811
813
** {stat : value [chain_idx ] for stat , value in self .progress_stats .items ()},
812
814
)
813
815
for chain_idx in range (self .chains )
@@ -829,16 +831,22 @@ def update(self, chain_idx, is_last, draw, tuning, stats):
829
831
self .divergences += 1
830
832
831
833
if self .full_stats :
834
+ failing = False
835
+ all_step_stats = {}
836
+
832
837
# TODO: Index by chain already?
833
838
chain_progress_stats = [
834
839
update_states_fn (step_stats )
835
840
for update_states_fn , step_stats in zip (
836
841
self .update_stats_functions , stats , strict = True
837
842
)
838
843
]
839
- all_step_stats = {}
840
844
for step_stats in chain_progress_stats :
841
845
for key , val in step_stats .items ():
846
+ if key == "failing" :
847
+ failing |= val
848
+ continue
849
+
842
850
if key in all_step_stats :
843
851
continue
844
852
count = (
@@ -849,6 +857,7 @@ def update(self, chain_idx, is_last, draw, tuning, stats):
849
857
all_step_stats [key ] = val
850
858
851
859
else :
860
+ failing = False
852
861
all_step_stats = {}
853
862
854
863
# more_updates = (
@@ -863,6 +872,7 @@ def update(self, chain_idx, is_last, draw, tuning, stats):
863
872
draws = draw ,
864
873
sampling_speed = speed ,
865
874
speed_unit = unit ,
875
+ failing = failing ,
866
876
** all_step_stats ,
867
877
)
868
878
0 commit comments