-
Notifications
You must be signed in to change notification settings - Fork 69
added sample_filter_outputs utility and accompanying simple tests #526
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 3 commits
fd87691
5b064d4
d142a91
9e78bae
46149ac
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -805,16 +805,16 @@ def _register_kalman_filter_outputs_with_pymc_model(outputs: tuple[pt.TensorVari | |
states, covs = outputs[:4], outputs[4:] | ||
|
||
state_names = [ | ||
"filtered_state", | ||
"predicted_state", | ||
"predicted_observed_state", | ||
"smoothed_state", | ||
"filtered_states", | ||
"predicted_states", | ||
"predicted_observed_states", | ||
"smoothed_states", | ||
] | ||
cov_names = [ | ||
"filtered_covariance", | ||
"predicted_covariance", | ||
"predicted_observed_covariance", | ||
"smoothed_covariance", | ||
"filtered_covariances", | ||
"predicted_covariances", | ||
"predicted_observed_covariances", | ||
"smoothed_covariances", | ||
] | ||
|
||
with mod: | ||
|
@@ -939,7 +939,7 @@ def build_statespace_graph( | |
all_kf_outputs = [*states, smooth_states, *covs, smooth_covariances] | ||
self._register_kalman_filter_outputs_with_pymc_model(all_kf_outputs) | ||
|
||
obs_dims = FILTER_OUTPUT_DIMS["predicted_observed_state"] | ||
obs_dims = FILTER_OUTPUT_DIMS["predicted_observed_states"] | ||
obs_dims = obs_dims if all([dim in pm_mod.coords.keys() for dim in obs_dims]) else None | ||
|
||
SequenceMvNormal( | ||
|
@@ -1678,6 +1678,93 @@ def sample_statespace_matrices( | |
|
||
return matrix_idata | ||
|
||
def sample_filter_outputs( | ||
self, idata, filter_output_names: str | list[str] | None, group: str = "posterior", **kwargs | ||
): | ||
if isinstance(filter_output_names, str): | ||
filter_output_names = [filter_output_names] | ||
|
||
compile_kwargs = kwargs.pop("compile_kwargs", {}) | ||
compile_kwargs.setdefault("mode", self.mode) | ||
|
||
with pm.Model(coords=self.coords) as m: | ||
self._build_dummy_graph() | ||
self._insert_random_variables() | ||
|
||
if self.data_names: | ||
for name in self.data_names: | ||
pm.Data(**self._exog_data_info[name]) | ||
|
||
self._insert_data_variables() | ||
|
||
x0, P0, c, d, T, Z, R, H, Q = self.unpack_statespace() | ||
data = self._fit_data | ||
|
||
obs_coords = m.coords.get(OBS_STATE_DIM, None) | ||
|
||
data, nan_mask = register_data_with_pymc( | ||
data, | ||
n_obs=self.ssm.k_endog, | ||
obs_coords=obs_coords, | ||
register_data=True, | ||
) | ||
|
||
filter_outputs = self.kalman_filter.build_graph( | ||
data, | ||
x0, | ||
P0, | ||
c, | ||
d, | ||
T, | ||
Z, | ||
R, | ||
H, | ||
Q, | ||
) | ||
|
||
smoother_outputs = self.kalman_smoother.build_graph( | ||
T, R, Q, filter_outputs[0], filter_outputs[3] | ||
) | ||
|
||
# Filter output names are singular in constants.py but are returned as plural from kalman_.build_graph() | ||
# filter_output_dims_mapping = {} | ||
# for k in FILTER_OUTPUT_DIMS.keys(): | ||
# filter_output_dims_mapping[k + "s"] = FILTER_OUTPUT_DIMS[k] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Remove There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. oops! Sorry about this. That was a careless oversight. I will clean that up right away! |
||
|
||
all_filter_outputs = filter_outputs[:-1] + list(smoother_outputs) | ||
# This excludes observed states and observed covariances from the filter outputs | ||
all_filter_outputs = [ | ||
output for output in all_filter_outputs if output.name in FILTER_OUTPUT_DIMS | ||
] | ||
|
||
if filter_output_names is None: | ||
filter_output_names = all_filter_outputs | ||
jessegrabowski marked this conversation as resolved.
Show resolved
Hide resolved
|
||
else: | ||
unknown_filter_output_names = np.setdiff1d( | ||
filter_output_names, [x.name for x in all_filter_outputs] | ||
) | ||
if unknown_filter_output_names.size > 0: | ||
raise ValueError( | ||
f"{unknown_filter_output_names} not a valid filter output name!" | ||
) | ||
filter_output_names = [ | ||
x for x in all_filter_outputs if x.name in filter_output_names | ||
] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Move the input validation up to the top, so we fail quickly without doing any work if the user passes invalid names |
||
|
||
for output in filter_output_names: | ||
dims = FILTER_OUTPUT_DIMS[output.name] | ||
pm.Deterministic(output.name, output, dims=dims) | ||
|
||
frozen_model = freeze_dims_and_data(m) | ||
with frozen_model: | ||
idata_filter = pm.sample_posterior_predictive( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: no need for an intermediate variable here, just directly return |
||
idata if group == "posterior" else idata.prior, | ||
var_names=[x.name for x in frozen_model.deterministics], | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. just use |
||
compile_kwargs=compile_kwargs, | ||
**kwargs, | ||
) | ||
return idata_filter | ||
|
||
@staticmethod | ||
def _validate_forecast_args( | ||
time_index: pd.RangeIndex | pd.DatetimeIndex, | ||
|
Uh oh!
There was an error while loading. Please reload this page.