Skip to content

Commit d736a9b

Browse files
committed
feat: Add more sample stats about divergences
1 parent fc83b9f commit d736a9b

File tree

3 files changed

+155
-1
lines changed

3 files changed

+155
-1
lines changed

src/cpu_potential.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,8 @@ impl<F: CpuLogpFunc, M: MassMatrix> Hamiltonian for EuclideanPotential<F, M> {
116116
let div_info = DivergenceInfo {
117117
logp_function_error: Some(Box::new(logp_error)),
118118
start_location: Some(start.q.clone()),
119+
start_gradient: Some(start.grad.clone()),
120+
start_momentum: Some(start.p.clone()),
119121
end_location: None,
120122
start_idx_in_trajectory: Some(start.idx_in_trajectory),
121123
end_idx_in_trajectory: None,
@@ -142,7 +144,9 @@ impl<F: CpuLogpFunc, M: MassMatrix> Hamiltonian for EuclideanPotential<F, M> {
142144
let divergence_info = DivergenceInfo {
143145
logp_function_error: None,
144146
start_location: Some(start.q.clone()),
147+
start_gradient: Some(start.grad.clone()),
145148
end_location: Some(out.q.clone()),
149+
start_momentum: Some(out.p.clone()),
146150
start_idx_in_trajectory: Some(start.index_in_trajectory()),
147151
end_idx_in_trajectory: Some(out.index_in_trajectory()),
148152
energy_error: Some(energy_error),

src/cpu_state.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -243,6 +243,10 @@ impl crate::nuts::State for State {
243243
out.copy_from_slice(&self.grad);
244244
}
245245

246+
fn write_momentum(&self, out: &mut [f64]) {
247+
out.copy_from_slice(&self.p);
248+
}
249+
246250
fn energy(&self) -> f64 {
247251
self.kinetic_energy + self.potential_energy
248252
}

src/nuts.rs

Lines changed: 147 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use arrow2::array::{MutableFixedSizeListArray, TryPush};
1+
use arrow2::array::{MutableFixedSizeListArray, MutableUtf8Array, TryPush};
22
#[cfg(feature = "arrow")]
33
use arrow2::{
44
array::{MutableArray, MutableBooleanArray, MutablePrimitiveArray, StructArray},
@@ -37,7 +37,9 @@ pub type Result<T> = std::result::Result<T, NutsError>;
3737
/// failed)
3838
#[derive(Debug)]
3939
pub struct DivergenceInfo {
40+
pub start_momentum: Option<Box<[f64]>>,
4041
pub start_location: Option<Box<[f64]>>,
42+
pub start_gradient: Option<Box<[f64]>>,
4143
pub end_location: Option<Box<[f64]>>,
4244
pub energy_error: Option<f64>,
4345
pub end_idx_in_trajectory: Option<i64>,
@@ -152,6 +154,9 @@ pub trait State: Clone + Debug {
152154
/// Write the gradient stored in the state to a different location
153155
fn write_gradient(&self, out: &mut [f64]);
154156

157+
/// Write the momentum in the state to a different location
158+
fn write_momentum(&self, out: &mut [f64]);
159+
155160
/// Compute the termination criterion for NUTS
156161
fn is_turning(&self, other: &Self) -> bool;
157162

@@ -523,6 +528,11 @@ pub struct StatsBuilder<H: Hamiltonian, A: AdaptStrategy> {
523528
hamiltonian: <H::Stats as ArrowRow>::Builder,
524529
adapt: <A::Stats as ArrowRow>::Builder,
525530
diverging: MutableBooleanArray,
531+
divergence_start: Option<MutableFixedSizeListArray<MutablePrimitiveArray<f64>>>,
532+
divergence_start_grad: Option<MutableFixedSizeListArray<MutablePrimitiveArray<f64>>>,
533+
divergence_end: Option<MutableFixedSizeListArray<MutablePrimitiveArray<f64>>>,
534+
divergence_momentum: Option<MutableFixedSizeListArray<MutablePrimitiveArray<f64>>>,
535+
divergence_msg: Option<MutableUtf8Array<i64>>,
526536
}
527537

528538
#[cfg(feature = "arrow")]
@@ -548,6 +558,40 @@ impl<H: Hamiltonian, A: AdaptStrategy> StatsBuilder<H, A> {
548558
None
549559
};
550560

561+
let (div_start, div_start_grad, div_end, div_mom, div_msg) = if settings.store_divergences {
562+
let start_location_prim = MutablePrimitiveArray::new();
563+
let start_location_list =
564+
MutableFixedSizeListArray::new_with_field(start_location_prim, "item", false, dim);
565+
566+
let start_grad_prim = MutablePrimitiveArray::new();
567+
let start_grad_list =
568+
MutableFixedSizeListArray::new_with_field(start_grad_prim, "item", false, dim);
569+
570+
let end_location_prim = MutablePrimitiveArray::new();
571+
let end_location_list =
572+
MutableFixedSizeListArray::new_with_field(end_location_prim, "item", false, dim);
573+
574+
let momentum_location_prim = MutablePrimitiveArray::new();
575+
let momentum_location_list = MutableFixedSizeListArray::new_with_field(
576+
momentum_location_prim,
577+
"item",
578+
false,
579+
dim,
580+
);
581+
582+
let msg_list = MutableUtf8Array::new();
583+
584+
(
585+
Some(start_location_list),
586+
Some(start_grad_list),
587+
Some(end_location_list),
588+
Some(momentum_location_list),
589+
Some(msg_list),
590+
)
591+
} else {
592+
(None, None, None, None, None)
593+
};
594+
551595
Self {
552596
depth: MutablePrimitiveArray::with_capacity(capacity),
553597
maxdepth_reached: MutableBooleanArray::with_capacity(capacity),
@@ -561,6 +605,11 @@ impl<H: Hamiltonian, A: AdaptStrategy> StatsBuilder<H, A> {
561605
hamiltonian: <H::Stats as ArrowRow>::new_builder(dim, settings),
562606
adapt: <A::Stats as ArrowRow>::new_builder(dim, settings),
563607
diverging: MutableBooleanArray::with_capacity(capacity),
608+
divergence_start: div_start,
609+
divergence_start_grad: div_start_grad,
610+
divergence_end: div_end,
611+
divergence_momentum: div_mom,
612+
divergence_msg: div_msg,
564613
}
565614
}
566615
}
@@ -601,6 +650,58 @@ impl<H: Hamiltonian, A: AdaptStrategy> ArrowBuilder<NutsSampleStats<H::Stats, A:
601650
.unwrap();
602651
}
603652

653+
let info_option = value.divergence_info();
654+
if let Some(div_start) = self.divergence_start.as_mut() {
655+
div_start
656+
.try_push(info_option.and_then(|info| {
657+
info.start_location
658+
.as_ref()
659+
.map(|vals| vals.iter().map(|&x| Some(x)))
660+
}))
661+
.unwrap();
662+
}
663+
664+
let info_option = value.divergence_info();
665+
if let Some(div_grad) = self.divergence_start_grad.as_mut() {
666+
div_grad
667+
.try_push(info_option.and_then(|info| {
668+
info.start_gradient
669+
.as_ref()
670+
.map(|vals| vals.iter().map(|&x| Some(x)))
671+
}))
672+
.unwrap();
673+
}
674+
675+
if let Some(div_end) = self.divergence_end.as_mut() {
676+
div_end
677+
.try_push(info_option.and_then(|info| {
678+
info.end_location
679+
.as_ref()
680+
.map(|vals| vals.iter().map(|&x| Some(x)))
681+
}))
682+
.unwrap();
683+
}
684+
685+
if let Some(div_mom) = self.divergence_momentum.as_mut() {
686+
div_mom
687+
.try_push(info_option.and_then(|info| {
688+
info.start_momentum
689+
.as_ref()
690+
.map(|vals| vals.iter().map(|&x| Some(x)))
691+
}))
692+
.unwrap();
693+
}
694+
695+
if let Some(div_msg) = self.divergence_msg.as_mut() {
696+
div_msg
697+
.try_push(info_option.and_then(|info| {
698+
info.logp_function_error
699+
.as_ref()
700+
.map(|err| format!("{}", err))
701+
}))
702+
.unwrap();
703+
}
704+
604705
self.hamiltonian.append_value(&value.potential_stats);
605706
self.adapt.append_value(&value.strategy_stats);
606707
}
@@ -655,6 +756,51 @@ impl<H: Hamiltonian, A: AdaptStrategy> ArrowBuilder<NutsSampleStats<H::Stats, A:
655756
arrays.push(unconstrained.as_box());
656757
}
657758

759+
if let Some(mut div_start) = self.divergence_start.take() {
760+
fields.push(Field::new(
761+
"divergence_start",
762+
div_start.data_type().clone(),
763+
true,
764+
));
765+
arrays.push(div_start.as_box());
766+
}
767+
768+
if let Some(mut div_start_grad) = self.divergence_start_grad.take() {
769+
fields.push(Field::new(
770+
"divergence_start_gradient",
771+
div_start_grad.data_type().clone(),
772+
true,
773+
));
774+
arrays.push(div_start_grad.as_box());
775+
}
776+
777+
if let Some(mut div_end) = self.divergence_end.take() {
778+
fields.push(Field::new(
779+
"divergence_end",
780+
div_end.data_type().clone(),
781+
true,
782+
));
783+
arrays.push(div_end.as_box());
784+
}
785+
786+
if let Some(mut div_mom) = self.divergence_momentum.take() {
787+
fields.push(Field::new(
788+
"divergence_momentum",
789+
div_mom.data_type().clone(),
790+
true,
791+
));
792+
arrays.push(div_mom.as_box());
793+
}
794+
795+
if let Some(mut div_msg) = self.divergence_msg.take() {
796+
fields.push(Field::new(
797+
"divergence_message",
798+
div_msg.data_type().clone(),
799+
true,
800+
));
801+
arrays.push(div_msg.as_box());
802+
}
803+
658804
Some(StructArray::new(DataType::Struct(fields), arrays, None))
659805
}
660806
}

0 commit comments

Comments
 (0)