77
88from fastprogress .fastprogress import progress_bar
99from functools import partial
10- from jax import jit , vmap
10+ from jax import jit , tree , vmap
1111from jax .tree_util import tree_map
1212from jaxtyping import Array , Float
1313from tensorflow_probability .substrates .jax .distributions import MultivariateNormalFullCovariance as MVN
1414from typing import Any , Optional , Tuple , Union , runtime_checkable
15- from typing_extensions import Protocol
15+ from typing_extensions import Protocol
1616
1717from dynamax .ssm import SSM
1818from dynamax .linear_gaussian_ssm .inference import lgssm_joint_sample , lgssm_filter , lgssm_smoother , lgssm_posterior_sample
@@ -206,7 +206,7 @@ def sample(self,
206206 key : PRNGKeyT ,
207207 num_timesteps : int ,
208208 inputs : Optional [Float [Array , "num_timesteps input_dim" ]] = None ) \
209- -> Tuple [Float [Array , "num_timesteps state_dim" ],
209+ -> Tuple [Float [Array , "num_timesteps state_dim" ],
210210 Float [Array , "num_timesteps emission_dim" ]]:
211211 """Sample from the model.
212212
@@ -357,7 +357,7 @@ def forecast(self,
357357 input_weights = params .emissions .input_weights ,
358358 cov = 1e8 * jnp .ones (self .emission_dim )) # ignore dummy observations
359359 )
360-
360+
361361 dummy_emissions = jnp .zeros ((num_forecast_timesteps , self .emission_dim ))
362362 forecast_inputs = forecast_inputs if forecast_inputs is not None else \
363363 jnp .zeros ((num_forecast_timesteps , 0 ))
@@ -367,7 +367,7 @@ def forecast(self,
367367 H = params .emissions .weights
368368 b = params .emissions .bias
369369 R = params .emissions .cov if params .emissions .cov .ndim == 2 else jnp .diag (params .emissions .cov )
370-
370+
371371 forecast_emissions = forecast_states .filtered_means @ H .T + b
372372 forecast_emissions_cov = H @ forecast_states .filtered_covariances @ H .T + R
373373 return forecast_states .filtered_means , \
@@ -643,6 +643,47 @@ def m_step(self,
643643 )
644644 return params , m_step_state
645645
def _check_params(self, params: ParamsLGSSM, num_timesteps: int) -> ParamsLGSSM:
    """Substitute zero arrays for any bias / input-weight entries that are None.

    The model is treated as inhomogeneous when the dynamics weights have
    three axes. In that case the substituted zeros receive a leading time
    axis: ``num_timesteps - 1`` entries for the dynamics parameters and
    ``num_timesteps`` entries for the emission parameters.

    Args:
        params: model parameters whose dynamics/emissions ``bias`` and
            ``input_weights`` fields may be ``None``.
        num_timesteps: sequence length; only used to size the leading
            axis of the zero arrays in the inhomogeneous case.

    Returns:
        A new ``ParamsLGSSM`` with the ``initial`` block, the weights and
        the covariances passed through unchanged, and with zeros in place
        of every missing bias / input-weight entry.
    """
    dyn = params.dynamics
    ems = params.emissions

    # A third axis on the dynamics weights marks a time-varying model.
    time_varying = dyn.weights.ndim == 3

    # Transitions exist between consecutive steps (T - 1 of them), whereas
    # there is one emission per step (T of them).
    dyn_lead = (num_timesteps - 1,) if time_varying else ()
    ems_lead = (num_timesteps,) if time_varying else ()

    def _or_zeros(value, shape):
        # Keep a supplied value as-is; only fabricate zeros when absent.
        return jnp.zeros(shape) if value is None else value

    dyn_bias = _or_zeros(dyn.bias, dyn_lead + (self.state_dim,))
    dyn_input_weights = _or_zeros(
        dyn.input_weights, dyn_lead + (self.state_dim, self.input_dim)
    )
    filled_dynamics = ParamsLGSSMDynamics(
        weights=dyn.weights,
        bias=dyn_bias,
        input_weights=dyn_input_weights,
        cov=dyn.cov,
    )

    ems_bias = _or_zeros(ems.bias, ems_lead + (self.emission_dim,))
    ems_input_weights = _or_zeros(
        ems.input_weights, ems_lead + (self.emission_dim, self.input_dim)
    )
    filled_emissions = ParamsLGSSMEmissions(
        weights=ems.weights,
        bias=ems_bias,
        input_weights=ems_input_weights,
        cov=ems.cov,
    )

    return ParamsLGSSM(
        initial=params.initial,
        dynamics=filled_dynamics,
        emissions=filled_emissions,
    )
685+
686+
646687 def fit_blocked_gibbs (self ,
647688 key : PRNGKeyT ,
648689 initial_params : ParamsLGSSM ,
@@ -654,7 +695,8 @@ def fit_blocked_gibbs(self,
654695
655696 Args:
656697 key: random number key.
657- initial_params: starting parameters.
698+ initial_params: starting parameters. Include a leading time axis for
699+ the dynamics and emissions parameters in inhomogeneous models.
658700 sample_size: how many samples to draw.
659701 emissions: set of observation sequences.
660702 inputs: optional set of input sequences.
@@ -664,67 +706,97 @@ def fit_blocked_gibbs(self,
664706 """
665707 num_timesteps = len (emissions )
666708
709+ # Inhomogeneous models have a leading time dimension.
710+ is_inhomogeneous = initial_params .dynamics .weights .ndim == 3
711+
667712 if inputs is None :
668713 inputs = jnp .zeros ((num_timesteps , 0 ))
669714
715+ initial_params = self ._check_params (initial_params , num_timesteps )
716+
670717 def sufficient_stats_from_sample (states ):
671718 """Convert samples of states to sufficient statistics."""
672719 inputs_joint = jnp .concatenate ((inputs , jnp .ones ((num_timesteps , 1 ))), axis = 1 )
673720 # Let xn[t] = x[t+1] for t = 0...T-2
674- x , xp , xn = states , states [:- 1 ], states [1 :]
675- u , up = inputs_joint , inputs_joint [:- 1 ]
721+ x , xn = states , states [1 :]
722+ u = inputs_joint
723+ # Let z[t] = [x[t], u[t]] for t = 0...T-1
724+ z = jnp .concatenate ([x , u ], axis = - 1 )
725+ # Let zp[t] = [x[t], u[t]] for t = 0...T-2
726+ zp = z [:- 1 ]
676727 y = emissions
677728
678729 init_stats = (x [0 ], jnp .outer (x [0 ], x [0 ]), 1 )
679730
680731 # Quantities for the dynamics distribution
681- # Let zp[t] = [x[t], u[t]] for t = 0...T-2
682- sum_zpzpT = jnp .block ([[xp .T @ xp , xp .T @ up ], [up .T @ xp , up .T @ up ]])
683- sum_zpxnT = jnp .block ([[xp .T @ xn ], [up .T @ xn ]])
684- sum_xnxnT = xn .T @ xn
685- dynamics_stats = (sum_zpzpT , sum_zpxnT , sum_xnxnT , num_timesteps - 1 )
732+ sum_zpzpT = jnp .einsum ('ti,tj->tij' , zp , zp )
733+ sum_zpxnT = jnp .einsum ('ti,tj->tij' , zp , xn )
734+ sum_xnxnT = jnp .einsum ('ti,tj->tij' , xn , xn )
735+ z_is_observed = jnp .ones (num_timesteps - 1 )
736+ # The dynamics stats have a leading time dimension.
737+ dynamics_stats = (sum_zpzpT , sum_zpxnT , sum_xnxnT , z_is_observed )
686738 if not self .has_dynamics_bias :
687- dynamics_stats = (sum_zpzpT [:- 1 , :- 1 ], sum_zpxnT [:- 1 , :], sum_xnxnT ,
688- num_timesteps - 1 )
739+ dynamics_stats = (sum_zpzpT [:, : - 1 , :- 1 ], sum_zpxnT [:, :- 1 , :], sum_xnxnT ,
740+ z_is_observed )
689741
690742 # Quantities for the emissions
691- # Let z[t] = [x[t], u[t]] for t = 0...T-1
692- sum_zzT = jnp .block ([[x .T @ x , x .T @ u ], [u .T @ x , u .T @ u ]])
693- sum_zyT = jnp .block ([[x .T @ y ], [u .T @ y ]])
694- sum_yyT = y .T @ y
695- emission_stats = (sum_zzT , sum_zyT , sum_yyT , num_timesteps )
743+ sum_zzT = jnp .einsum ('ti,tj->tij' , z , z )
744+ sum_zyT = jnp .einsum ('ti,tj->tij' , z , y )
745+ sum_yyT = jnp .einsum ('ti,tj->tij' , y , y )
746+ y_is_observed = jnp .ones (num_timesteps )
747+ # The emissions stats have a leading time dimension.
748+ emission_stats = (sum_zzT , sum_zyT , sum_yyT , y_is_observed )
696749 if not self .has_emissions_bias :
697- emission_stats = (sum_zzT [:- 1 , :- 1 ], sum_zyT [:- 1 , :], sum_yyT , num_timesteps )
750+ emission_stats = (sum_zzT [:, : - 1 , :- 1 ], sum_zyT [:, : - 1 , :], sum_yyT , y_is_observed )
698751
699752 return init_stats , dynamics_stats , emission_stats
700753
701- def lgssm_params_sample (rng , stats ):
702- """Sample parameters of the model given sufficient statistics from observed states and emissions."""
703- init_stats , dynamics_stats , emission_stats = stats
704- rngs = iter (jr .split (rng , 3 ))
705-
706- # Sample the initial params
754+ def _sample_initial_params (rng , init_stats ):
707755 initial_posterior = niw_posterior_update (self .initial_prior , init_stats )
708- S , m = initial_posterior .sample (seed = next (rngs ))
756+ S , m = initial_posterior .sample (seed = rng )
757+ return ParamsLGSSMInitial (mean = m , cov = S )
709758
710- # Sample the dynamics params
759+ def _sample_dynamics_params ( rng , dynamics_stats ):
711760 dynamics_posterior = mniw_posterior_update (self .dynamics_prior , dynamics_stats )
712- Q , FB = dynamics_posterior .sample (seed = next ( rngs ) )
761+ Q , FB = dynamics_posterior .sample (seed = rng )
713762 F = FB [:, :self .state_dim ]
714763 B , b = (FB [:, self .state_dim :- 1 ], FB [:, - 1 ]) if self .has_dynamics_bias \
715764 else (FB [:, self .state_dim :], jnp .zeros (self .state_dim ))
765+ return ParamsLGSSMDynamics (weights = F , bias = b , input_weights = B , cov = Q )
716766
717- # Sample the emission params
767+ def _sample_emission_params ( rng , emission_stats ):
718768 emission_posterior = mniw_posterior_update (self .emission_prior , emission_stats )
719- R , HD = emission_posterior .sample (seed = next ( rngs ) )
769+ R , HD = emission_posterior .sample (seed = rng )
720770 H = HD [:, :self .state_dim ]
721771 D , d = (HD [:, self .state_dim :- 1 ], HD [:, - 1 ]) if self .has_emissions_bias \
722772 else (HD [:, self .state_dim :], jnp .zeros (self .emission_dim ))
773+ return ParamsLGSSMEmissions (weights = H , bias = d , input_weights = D , cov = R )
774+
775+ def lgssm_params_sample (rng , stats ):
776+ """Sample parameters of the model given sufficient statistics from observed states and emissions."""
777+ init_stats , dynamics_stats , emission_stats = stats
778+ rngs = iter (jr .split (rng , 3 ))
779+
780+ # Sample the initial params
781+ initial_params = _sample_initial_params (next (rngs ), init_stats )
782+
783+ # Sample the dynamics and emission params.
784+ if not is_inhomogeneous :
785+ # Aggregate summary statistics across time for homogeneous model.
786+ dynamics_stats = tree .map (lambda x : jnp .sum (x , axis = 0 ), dynamics_stats )
787+ emission_stats = tree .map (lambda x : jnp .sum (x , axis = 0 ), emission_stats )
788+ dynamics_params = _sample_dynamics_params (next (rngs ), dynamics_stats )
789+ emission_params = _sample_emission_params (next (rngs ), emission_stats )
790+ else :
791+ keys_dynamics = jr .split (next (rngs ), num_timesteps - 1 )
792+ keys_emission = jr .split (next (rngs ), num_timesteps )
793+ dynamics_params = vmap (_sample_dynamics_params )(keys_dynamics , dynamics_stats )
794+ emission_params = vmap (_sample_emission_params )(keys_emission , emission_stats )
723795
724796 params = ParamsLGSSM (
725- initial = ParamsLGSSMInitial ( mean = m , cov = S ) ,
726- dynamics = ParamsLGSSMDynamics ( weights = F , bias = b , input_weights = B , cov = Q ) ,
727- emissions = ParamsLGSSMEmissions ( weights = H , bias = d , input_weights = D , cov = R )
797+ initial = initial_params ,
798+ dynamics = dynamics_params ,
799+ emissions = emission_params ,
728800 )
729801 return params
730802
0 commit comments