diff --git a/dynamax/linear_gaussian_ssm/inference.py b/dynamax/linear_gaussian_ssm/inference.py index 03174d0f..f27280e2 100644 --- a/dynamax/linear_gaussian_ssm/inference.py +++ b/dynamax/linear_gaussian_ssm/inference.py @@ -118,6 +118,8 @@ class PosteriorGSSMFiltered(NamedTuple): :param marginal_loglik: marginal log likelihood, $p(y_{1:T} \mid u_{1:T})$ :param filtered_means: array of filtered means $\mathbb{E}[z_t \mid y_{1:t}, u_{1:t}]$ :param filtered_covariances: array of filtered covariances $\mathrm{Cov}[z_t \mid y_{1:t}, u_{1:t}]$ + :param predicted_means: array of predicted means $\mathbb{E}[z_t \mid y_{1:t-1}, u_{1:t-1}]$ + :param predicted_covariances: array of predicted covariances $\mathrm{Cov}[z_t \mid y_{1:t-1}, u_{1:t-1}]$ """ marginal_loglik: Union[Scalar, Float[Array, " ntime"]] @@ -504,12 +506,12 @@ def _step(carry, t): # Predict the next state pred_mean, pred_cov = _predict(filtered_mean, filtered_cov, F, B, b, Q, u) - return (ll, pred_mean, pred_cov), (filtered_mean, filtered_cov) + return (ll, pred_mean, pred_cov), (filtered_mean, filtered_cov, carry[1], carry[2]) # Run the Kalman filter carry = (0.0, params.initial.mean, params.initial.cov) - (ll, _, _), (filtered_means, filtered_covs) = lax.scan(_step, carry, jnp.arange(num_timesteps)) - return PosteriorGSSMFiltered(marginal_loglik=ll, filtered_means=filtered_means, filtered_covariances=filtered_covs) + (ll, _, _), (filtered_means, filtered_covs, predicted_means, predicted_covs) = lax.scan(_step, carry, jnp.arange(num_timesteps)) + return PosteriorGSSMFiltered(marginal_loglik=ll, filtered_means=filtered_means, filtered_covariances=filtered_covs, predicted_means=predicted_means, predicted_covariances=predicted_covs) @preprocess_args