1- use arrow2:: array:: { MutableFixedSizeListArray , TryPush } ;
1+ use arrow2:: array:: { MutableFixedSizeListArray , MutableUtf8Array , TryPush } ;
22#[ cfg( feature = "arrow" ) ]
33use arrow2:: {
44 array:: { MutableArray , MutableBooleanArray , MutablePrimitiveArray , StructArray } ,
@@ -37,7 +37,9 @@ pub type Result<T> = std::result::Result<T, NutsError>;
3737/// failed)
3838#[ derive( Debug ) ]
3939pub struct DivergenceInfo {
40+ pub start_momentum : Option < Box < [ f64 ] > > ,
4041 pub start_location : Option < Box < [ f64 ] > > ,
42+ pub start_gradient : Option < Box < [ f64 ] > > ,
4143 pub end_location : Option < Box < [ f64 ] > > ,
4244 pub energy_error : Option < f64 > ,
4345 pub end_idx_in_trajectory : Option < i64 > ,
@@ -152,6 +154,9 @@ pub trait State: Clone + Debug {
152154 /// Write the gradient stored in the state to a different location
153155 fn write_gradient ( & self , out : & mut [ f64 ] ) ;
154156
157+ /// Write the momentum in the state to a different location
158+ fn write_momentum ( & self , out : & mut [ f64 ] ) ;
159+
155160 /// Compute the termination criterion for NUTS
156161 fn is_turning ( & self , other : & Self ) -> bool ;
157162
@@ -523,6 +528,11 @@ pub struct StatsBuilder<H: Hamiltonian, A: AdaptStrategy> {
523528 hamiltonian : <H :: Stats as ArrowRow >:: Builder ,
524529 adapt : <A :: Stats as ArrowRow >:: Builder ,
525530 diverging : MutableBooleanArray ,
531+ divergence_start : Option < MutableFixedSizeListArray < MutablePrimitiveArray < f64 > > > ,
532+ divergence_start_grad : Option < MutableFixedSizeListArray < MutablePrimitiveArray < f64 > > > ,
533+ divergence_end : Option < MutableFixedSizeListArray < MutablePrimitiveArray < f64 > > > ,
534+ divergence_momentum : Option < MutableFixedSizeListArray < MutablePrimitiveArray < f64 > > > ,
535+ divergence_msg : Option < MutableUtf8Array < i64 > > ,
526536}
527537
528538#[ cfg( feature = "arrow" ) ]
@@ -548,6 +558,40 @@ impl<H: Hamiltonian, A: AdaptStrategy> StatsBuilder<H, A> {
548558 None
549559 } ;
550560
561+ let ( div_start, div_start_grad, div_end, div_mom, div_msg) = if settings. store_divergences {
562+ let start_location_prim = MutablePrimitiveArray :: new ( ) ;
563+ let start_location_list =
564+ MutableFixedSizeListArray :: new_with_field ( start_location_prim, "item" , false , dim) ;
565+
566+ let start_grad_prim = MutablePrimitiveArray :: new ( ) ;
567+ let start_grad_list =
568+ MutableFixedSizeListArray :: new_with_field ( start_grad_prim, "item" , false , dim) ;
569+
570+ let end_location_prim = MutablePrimitiveArray :: new ( ) ;
571+ let end_location_list =
572+ MutableFixedSizeListArray :: new_with_field ( end_location_prim, "item" , false , dim) ;
573+
574+ let momentum_location_prim = MutablePrimitiveArray :: new ( ) ;
575+ let momentum_location_list = MutableFixedSizeListArray :: new_with_field (
576+ momentum_location_prim,
577+ "item" ,
578+ false ,
579+ dim,
580+ ) ;
581+
582+ let msg_list = MutableUtf8Array :: new ( ) ;
583+
584+ (
585+ Some ( start_location_list) ,
586+ Some ( start_grad_list) ,
587+ Some ( end_location_list) ,
588+ Some ( momentum_location_list) ,
589+ Some ( msg_list) ,
590+ )
591+ } else {
592+ ( None , None , None , None , None )
593+ } ;
594+
551595 Self {
552596 depth : MutablePrimitiveArray :: with_capacity ( capacity) ,
553597 maxdepth_reached : MutableBooleanArray :: with_capacity ( capacity) ,
@@ -561,6 +605,11 @@ impl<H: Hamiltonian, A: AdaptStrategy> StatsBuilder<H, A> {
561605 hamiltonian : <H :: Stats as ArrowRow >:: new_builder ( dim, settings) ,
562606 adapt : <A :: Stats as ArrowRow >:: new_builder ( dim, settings) ,
563607 diverging : MutableBooleanArray :: with_capacity ( capacity) ,
608+ divergence_start : div_start,
609+ divergence_start_grad : div_start_grad,
610+ divergence_end : div_end,
611+ divergence_momentum : div_mom,
612+ divergence_msg : div_msg,
564613 }
565614 }
566615}
@@ -601,6 +650,58 @@ impl<H: Hamiltonian, A: AdaptStrategy> ArrowBuilder<NutsSampleStats<H::Stats, A:
601650 . unwrap ( ) ;
602651 }
603652
653+ let info_option = value. divergence_info ( ) ;
654+ if let Some ( div_start) = self . divergence_start . as_mut ( ) {
655+ div_start
656+ . try_push ( info_option. and_then ( |info| {
657+ info. start_location
658+ . as_ref ( )
659+ . map ( |vals| vals. iter ( ) . map ( |& x| Some ( x) ) )
660+ } ) )
661+ . unwrap ( ) ;
662+ }
663+
664+ let info_option = value. divergence_info ( ) ;
665+ if let Some ( div_grad) = self . divergence_start_grad . as_mut ( ) {
666+ div_grad
667+ . try_push ( info_option. and_then ( |info| {
668+ info. start_gradient
669+ . as_ref ( )
670+ . map ( |vals| vals. iter ( ) . map ( |& x| Some ( x) ) )
671+ } ) )
672+ . unwrap ( ) ;
673+ }
674+
675+ if let Some ( div_end) = self . divergence_end . as_mut ( ) {
676+ div_end
677+ . try_push ( info_option. and_then ( |info| {
678+ info. end_location
679+ . as_ref ( )
680+ . map ( |vals| vals. iter ( ) . map ( |& x| Some ( x) ) )
681+ } ) )
682+ . unwrap ( ) ;
683+ }
684+
685+ if let Some ( div_mom) = self . divergence_momentum . as_mut ( ) {
686+ div_mom
687+ . try_push ( info_option. and_then ( |info| {
688+ info. start_momentum
689+ . as_ref ( )
690+ . map ( |vals| vals. iter ( ) . map ( |& x| Some ( x) ) )
691+ } ) )
692+ . unwrap ( ) ;
693+ }
694+
695+ if let Some ( div_msg) = self . divergence_msg . as_mut ( ) {
696+ div_msg
697+ . try_push ( info_option. and_then ( |info| {
698+ info. logp_function_error
699+ . as_ref ( )
700+ . map ( |err| format ! ( "{}" , err) )
701+ } ) )
702+ . unwrap ( ) ;
703+ }
704+
604705 self . hamiltonian . append_value ( & value. potential_stats ) ;
605706 self . adapt . append_value ( & value. strategy_stats ) ;
606707 }
@@ -655,6 +756,51 @@ impl<H: Hamiltonian, A: AdaptStrategy> ArrowBuilder<NutsSampleStats<H::Stats, A:
655756 arrays. push ( unconstrained. as_box ( ) ) ;
656757 }
657758
759+ if let Some ( mut div_start) = self . divergence_start . take ( ) {
760+ fields. push ( Field :: new (
761+ "divergence_start" ,
762+ div_start. data_type ( ) . clone ( ) ,
763+ true ,
764+ ) ) ;
765+ arrays. push ( div_start. as_box ( ) ) ;
766+ }
767+
768+ if let Some ( mut div_start_grad) = self . divergence_start_grad . take ( ) {
769+ fields. push ( Field :: new (
770+ "divergence_start_gradient" ,
771+ div_start_grad. data_type ( ) . clone ( ) ,
772+ true ,
773+ ) ) ;
774+ arrays. push ( div_start_grad. as_box ( ) ) ;
775+ }
776+
777+ if let Some ( mut div_end) = self . divergence_end . take ( ) {
778+ fields. push ( Field :: new (
779+ "divergence_end" ,
780+ div_end. data_type ( ) . clone ( ) ,
781+ true ,
782+ ) ) ;
783+ arrays. push ( div_end. as_box ( ) ) ;
784+ }
785+
786+ if let Some ( mut div_mom) = self . divergence_momentum . take ( ) {
787+ fields. push ( Field :: new (
788+ "divergence_momentum" ,
789+ div_mom. data_type ( ) . clone ( ) ,
790+ true ,
791+ ) ) ;
792+ arrays. push ( div_mom. as_box ( ) ) ;
793+ }
794+
795+ if let Some ( mut div_msg) = self . divergence_msg . take ( ) {
796+ fields. push ( Field :: new (
797+ "divergence_message" ,
798+ div_msg. data_type ( ) . clone ( ) ,
799+ true ,
800+ ) ) ;
801+ arrays. push ( div_msg. as_box ( ) ) ;
802+ }
803+
658804 Some ( StructArray :: new ( DataType :: Struct ( fields) , arrays, None ) )
659805 }
660806}
0 commit comments