Skip to content

Commit fc83b9f

Browse files
committed
fix: Handle initial zero gradients better
1 parent d545424 commit fc83b9f

File tree

3 files changed

+15
-2
lines changed

3 files changed

+15
-2
lines changed

src/adapt_strategy.rs

Lines changed: 1 addition & 1 deletion
@@ -352,7 +352,7 @@ impl<F: CpuLogpFunc> AdaptStrategy for ExpWindowDiagAdapt<F> {
             state
                 .grad
                 .iter()
-                .map(|&grad| grad.abs().recip().clamp(LOWER_LIMIT, UPPER_LIMIT))
+                .map(|&grad| grad.abs().clamp(LOWER_LIMIT, UPPER_LIMIT).recip())
                 .map(|var| if var.is_finite() { Some(var) } else { Some(1.) }),
         );
     }

src/cpu_potential.rs

Lines changed: 10 additions & 1 deletion
@@ -165,7 +165,16 @@ impl<F: CpuLogpFunc, M: MassMatrix> Hamiltonian for EuclideanPotential<F, M> {
         }
         self.update_potential_gradient(&mut state)
             .map_err(|e| NutsError::LogpFailure(Box::new(e)))?;
-        Ok(state)
+        if state
+            .grad
+            .iter()
+            .cloned()
+            .any(|val| (val == 0f64) | !val.is_finite())
+        {
+            Err(NutsError::BadInitGrad())
+        } else {
+            Ok(state)
+        }
     }

     fn randomize_momentum<R: rand::Rng + ?Sized>(&self, state: &mut Self::State, rng: &mut R) {

src/nuts.rs

Lines changed: 4 additions & 0 deletions
@@ -13,13 +13,17 @@ use crate::math::logaddexp;
 #[cfg(feature = "arrow")]
 use crate::SamplerArgs;

+#[non_exhaustive]
 #[derive(Error, Debug)]
 pub enum NutsError {
     #[error("Logp function returned error: {0}")]
     LogpFailure(Box<dyn std::error::Error + Send + Sync>),

     #[error("Could not serialize sample stats")]
     SerializeFailure(),
+
+    #[error("Could not initialize state because of bad initial gradient.")]
+    BadInitGrad(),
 }

 pub type Result<T> = std::result::Result<T, NutsError>;

0 commit comments

Comments
 (0)