Remove uses of OrderedDict #2046

Status: Open. Wants to merge 2 commits into master.
3 changes: 1 addition & 2 deletions numpyro/_typing.py
@@ -2,7 +2,6 @@
# SPDX-License-Identifier: Apache-2.0


from collections import OrderedDict
from collections.abc import Callable
from typing import Any, Protocol, runtime_checkable

@@ -15,7 +14,7 @@
ModelT: TypeAlias = Callable[P, Any]

Message: TypeAlias = dict[str, Any]
TraceT: TypeAlias = OrderedDict[str, Message]
TraceT: TypeAlias = dict[str, Message]
Review comment (Member): Let's keep TraceT as OrderedDict.



@runtime_checkable
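A note on the review comment above: since Python 3.7, plain dict preserves insertion order, so the practical differences left between the two types are order-sensitive equality and explicit reordering methods. A minimal sketch of those differences (illustration only, not part of the PR):

```python
from collections import OrderedDict

# Plain dicts compare equal regardless of insertion order.
assert {"a": 1, "b": 2} == {"b": 2, "a": 1}

# OrderedDict equality is order-sensitive when both sides are OrderedDicts.
assert OrderedDict(a=1, b=2) != OrderedDict(b=2, a=1)

# OrderedDict also supports explicit reordering, which plain dict lacks.
trace = OrderedDict(x={"type": "sample"}, y={"type": "sample"})
trace.move_to_end("x")
assert list(trace) == ["y", "x"]
```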
5 changes: 1 addition & 4 deletions numpyro/contrib/control_flow/scan.py
@@ -1,7 +1,6 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

from collections import OrderedDict
from functools import partial
from typing import Callable, Optional

@@ -493,9 +492,7 @@ def g(*args, **kwargs):
dim_to_name = msg["infer"].get("dim_to_name")
to_funsor(
msg["value"],
dim_to_name=OrderedDict(
[(k, dim_to_name[k]) for k in sorted(dim_to_name)]
),
dim_to_name={k: dim_to_name[k] for k in sorted(dim_to_name)},
)
apply_stack(msg)

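The scan.py change relies on dict comprehensions preserving insertion order, so iterating over sorted(dim_to_name) yields the same key order that the OrderedDict construction produced. A quick check with an invented dim_to_name (not taken from the PR):

```python
dim_to_name = {-1: "_time_steps", -3: "batch", -2: "plate"}

old_style = dict((k, dim_to_name[k]) for k in sorted(dim_to_name))  # what the old OrderedDict(...) built
new_style = {k: dim_to_name[k] for k in sorted(dim_to_name)}

assert old_style == new_style
assert list(new_style) == [-3, -2, -1]  # keys come out sorted either way
```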
4 changes: 2 additions & 2 deletions numpyro/contrib/funsor/discrete.py
@@ -1,7 +1,7 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

from collections import OrderedDict, defaultdict
from collections import defaultdict
import functools

from jax import random
@@ -35,7 +35,7 @@ def _get_support_value_contraction(funsor_dist, name, **kwargs):
@_get_support_value.register(funsor.delta.Delta)
def _get_support_value_delta(funsor_dist, name, **kwargs):
assert name in funsor_dist.fresh
return OrderedDict(funsor_dist.terms)[name][0]
return dict(funsor_dist.terms)[name][0]


def _sample_posterior(
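In _get_support_value_delta only a keyed lookup into the (name, value) pairs is needed, so converting terms to a plain dict is sufficient. A schematic illustration (the terms tuple below is invented and not funsor's actual data layout):

```python
# Hypothetical stand-in for funsor_dist.terms: (name, (point, log_density)) pairs.
terms = (("x", ("point_x", 0.0)), ("y", ("point_y", -1.2)))

assert dict(terms)["x"][0] == "point_x"
```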
36 changes: 15 additions & 21 deletions numpyro/contrib/funsor/enum_messenger.py
@@ -1,7 +1,7 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

from collections import OrderedDict, namedtuple
from collections import namedtuple
from contextlib import ExitStack # python 3
from enum import Enum

@@ -71,8 +71,8 @@ class DimStack:
def __init__(self):
self._stack = [
StackFrame(
name_to_dim=OrderedDict(),
dim_to_name=OrderedDict(),
name_to_dim={},
dim_to_name={},
parents=(),
iter_parents=(),
keep=False,
@@ -238,7 +238,7 @@ def process_message(self, msg):

@staticmethod
def _get_name_to_dim(batch_names, name_to_dim=None, dim_type=DimType.LOCAL):
name_to_dim = OrderedDict() if name_to_dim is None else name_to_dim.copy()
name_to_dim = {} if name_to_dim is None else name_to_dim.copy()

# interpret all names/dims as requests since we only run this function once
for name in batch_names:
@@ -256,7 +256,7 @@ def _get_name_to_dim(batch_names, name_to_dim=None, dim_type=DimType.LOCAL):
@classmethod # only depends on the global _DIM_STACK state, not self
def _pyro_to_data(cls, msg):
(funsor_value,) = msg["args"]
name_to_dim = msg["kwargs"].setdefault("name_to_dim", OrderedDict())
name_to_dim = msg["kwargs"].setdefault("name_to_dim", {})
dim_type = msg["kwargs"].setdefault("dim_type", DimType.LOCAL)

batch_names = tuple(funsor_value.inputs.keys())
@@ -270,7 +270,7 @@

@staticmethod
def _get_dim_to_name(batch_shape, dim_to_name=None, dim_type=DimType.LOCAL):
dim_to_name = OrderedDict() if dim_to_name is None else dim_to_name.copy()
dim_to_name = {} if dim_to_name is None else dim_to_name.copy()
batch_dim = len(batch_shape)

# interpret all names/dims as requests since we only run this function once
@@ -296,7 +296,7 @@ def _pyro_to_funsor(cls, msg):
else:
raw_value = msg["args"][0]
output = msg["kwargs"].setdefault("output", None)
dim_to_name = msg["kwargs"].setdefault("dim_to_name", OrderedDict())
dim_to_name = msg["kwargs"].setdefault("dim_to_name", {})
dim_type = msg["kwargs"].setdefault("dim_type", DimType.LOCAL)

event_dim = len(output.shape) if output else 0
@@ -359,7 +359,7 @@ def __enter__(self):
saved_frame = self._saved_frames.pop()
name_to_dim, dim_to_name = saved_frame.name_to_dim, saved_frame.dim_to_name
else:
name_to_dim, dim_to_name = OrderedDict(), OrderedDict()
name_to_dim, dim_to_name = {}, {}

frame = StackFrame(
name_to_dim=name_to_dim,
@@ -490,18 +490,14 @@ def __init__(self, name, size, subsample_size=None, dim=None):
self.subsample_size = indices.shape[0]
self._indices = funsor.Tensor(
indices,
OrderedDict([(self.name, funsor.Bint[self.subsample_size])]),
{self.name: funsor.Bint[self.subsample_size]},
self.subsample_size,
)
super(plate, self).__init__(None)

def __enter__(self):
super().__enter__() # do this first to take care of globals recycling
name_to_dim = (
OrderedDict([(self.name, self.dim)])
if self.dim is not None
else OrderedDict()
)
name_to_dim = {self.name: self.dim} if self.dim is not None else {}
indices = to_data(
self._indices, name_to_dim=name_to_dim, dim_type=DimType.VISIBLE
)
@@ -594,9 +590,7 @@ def process_message(self, msg):

size = msg["fn"].enumerate_support(expand=False).shape[0]
raw_value = jnp.arange(0, size)
funsor_value = funsor.Tensor(
raw_value, OrderedDict([(msg["name"], funsor.Bint[size])]), size
)
funsor_value = funsor.Tensor(raw_value, {msg["name"]: funsor.Bint[size]}, size)

msg["value"] = to_data(funsor_value)
msg["done"] = True
@@ -661,15 +655,15 @@ def to_funsor(x, output=None, dim_to_name=None, dim_type=DimType.LOCAL):
:param x: An object.
:param funsor.domains.Domain output: An optional output hint to uniquely
convert a data to a Funsor (e.g. when `x` is a string).
:param OrderedDict dim_to_name: An optional mapping from negative
:param dict dim_to_name: An optional mapping from negative
batch dimensions to name strings.
:param int dim_type: Either 0, 1, or 2. This optional argument indicates
a dimension should be treated as 'local', 'global', or 'visible',
which can be used to interact with the global :class:`DimStack`.
:return: A Funsor equivalent to `x`.
:rtype: funsor.terms.Funsor
"""
dim_to_name = OrderedDict() if dim_to_name is None else dim_to_name
dim_to_name = {} if dim_to_name is None else dim_to_name

initial_msg = {
"type": "to_funsor",
@@ -691,14 +685,14 @@ def to_data(x, name_to_dim=None, dim_type=DimType.LOCAL):
A primitive to extract a python object from a :class:`~funsor.terms.Funsor`.

:param ~funsor.terms.Funsor x: A funsor object
:param OrderedDict name_to_dim: An optional inputs hint which maps
:param dict name_to_dim: An optional inputs hint which maps
dimension names from `x` to dimension positions of the returned value.
:param int dim_type: Either 0, 1, or 2. This optional argument indicates
a dimension should be treated as 'local', 'global', or 'visible',
which can be used to interact with the global :class:`DimStack`.
:return: A non-funsor equivalent to `x`.
"""
name_to_dim = OrderedDict() if name_to_dim is None else name_to_dim
name_to_dim = {} if name_to_dim is None else name_to_dim

initial_msg = {
"type": "to_data",
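Several helpers in enum_messenger.py defensively copy the caller's mapping before filling in dimension requests. Plain dicts support the same idiom: .copy() returns a shallow, order-preserving copy, so the caller's argument is never mutated. A simplified stand-in (get_name_to_dim below is hypothetical, not numpyro's implementation):

```python
def get_name_to_dim(batch_names, name_to_dim=None):
    # Same defensive-copy pattern as the PR: never mutate the caller's mapping.
    name_to_dim = {} if name_to_dim is None else name_to_dim.copy()
    for name in batch_names:
        name_to_dim.setdefault(name, None)  # None marks "allocate a dim later"
    return name_to_dim

original = {"plate_a": -1}
result = get_name_to_dim(("plate_a", "plate_b"), original)
assert original == {"plate_a": -1}             # caller's dict untouched
assert list(result) == ["plate_a", "plate_b"]  # insertion order preserved
```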
4 changes: 2 additions & 2 deletions numpyro/contrib/funsor/infer_util.py
@@ -1,7 +1,7 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

from collections import OrderedDict, defaultdict
from collections import defaultdict
from contextlib import contextmanager
import functools
import re
@@ -221,7 +221,7 @@ def _enum_log_density(model, model_args, model_kwargs, params, sum_op, prod_op):

dim_to_name = site["infer"]["dim_to_name"]

if all(dim == 1 for dim in log_prob.shape) and dim_to_name == OrderedDict():
if all(dim == 1 for dim in log_prob.shape) and dim_to_name == {}:
log_prob = log_prob.squeeze()

log_prob_factor = funsor.to_funsor(
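The rewritten comparison in _enum_log_density is behavior-preserving: dict and OrderedDict compare by contents when one side is a plain dict, and an empty OrderedDict equals {}. A quick check:

```python
from collections import OrderedDict

assert OrderedDict() == {}     # the old `== OrderedDict()` and new `== {}` agree
assert OrderedDict(a=1) != {}  # non-empty mappings still fail the check
```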
20 changes: 10 additions & 10 deletions numpyro/contrib/stochastic_support/dcc.py
@@ -2,8 +2,8 @@
# SPDX-License-Identifier: Apache-2.0

from abc import ABC, abstractmethod
from collections import OrderedDict, namedtuple
from typing import Any, Callable, OrderedDict as OrderedDictType, Union
from collections import defaultdict
from typing import Any, Callable, Union

import jax
from jax import random
@@ -61,7 +61,7 @@ def __init__(self, model: Callable, num_slp_samples: int, max_slps: int) -> None

def _find_slps(
self, rng_key: jax.Array, *args: Any, **kwargs: Any
) -> dict[str, OrderedDictType]:
) -> dict[str, dict]:
Review comment (Member): I'm not sure but it seems that we need to use OrderedDict in this module.

"""
Discover the straight-line programs (SLPs) in the model by sampling from the prior.
This implementation assumes that all branching is done via discrete sampling sites
@@ -80,11 +80,11 @@ def _get_branching_trace(self, tr: dict[str, Any]) -> OrderedDictType:

return branching_traces

def _get_branching_trace(self, tr: dict[str, Any]) -> OrderedDictType:
def _get_branching_trace(self, tr: dict[str, Any]) -> dict:
"""
Extract the sites from the trace that are annotated with `infer={"branching": True}`.
"""
branching_trace = OrderedDict()
branching_trace = {}
for site in tr.values():
if (
site["type"] == "sample"
@@ -109,7 +109,7 @@ def _get_branching_trace(self, tr: dict[str, Any]) -> OrderedDictType:
def _run_inference(
self,
rng_key: jax.Array,
branching_trace: OrderedDictType,
branching_trace: dict,
*args: Any,
**kwargs: Any,
) -> RunInferenceResult:
@@ -120,7 +120,7 @@ def _combine_inferences(
self,
rng_key: jax.Array,
inferences: dict[str, Any],
branching_traces: dict[str, OrderedDictType],
branching_traces: dict[str, dict],
*args: Any,
**kwargs: Any,
) -> Union[DCCResult, SDVIResult]:
@@ -139,7 +139,7 @@ def run(
rng_key, subkey = random.split(rng_key)
branching_traces = self._find_slps(subkey, *args, **kwargs)

inferences = dict()
inferences = {}
for key, bt in branching_traces.items():
rng_key, subkey = random.split(rng_key)
inferences[key] = self._run_inference(subkey, bt, *args, **kwargs)
@@ -209,7 +209,7 @@ def __init__(
def _run_inference(
self,
rng_key: jax.Array,
branching_trace: OrderedDictType,
branching_trace: dict,
*args: Any,
**kwargs: Any,
) -> RunInferenceResult:
@@ -227,7 +227,7 @@ def _combine_inferences( # type: ignore[override]
self,
rng_key: jax.Array,
samples: dict[str, Any],
branching_traces: dict[str, OrderedDictType],
branching_traces: dict[str, dict],
*args: Any,
**kwargs: Any,
) -> DCCResult:
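Regarding the review concern above: the annotation change is at least not restrictive, because OrderedDict is a subclass of dict, so callers that still pass OrderedDict branching traces satisfy dict[str, dict]. A small check (takes_branching_traces is a hypothetical helper, not numpyro code):

```python
from collections import OrderedDict

def takes_branching_traces(traces: dict[str, dict]) -> list[str]:
    return list(traces)

assert issubclass(OrderedDict, dict)
assert takes_branching_traces({"slp_0": OrderedDict(cluster={"value": 1})}) == ["slp_0"]
```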
6 changes: 3 additions & 3 deletions numpyro/contrib/stochastic_support/sdvi.py
@@ -1,7 +1,7 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

from typing import Any, Callable, OrderedDict as OrderedDictType
from typing import Any, Callable

import jax
import jax.numpy as jnp
@@ -98,7 +98,7 @@ def __init__(
def _run_inference(
self,
rng_key: jax.Array,
branching_trace: OrderedDictType,
branching_trace: dict,
Review comment (Member): like above, maybe we need to use OrderedDict in this module.

*args: Any,
**kwargs: Any,
) -> RunInferenceResult:
@@ -121,7 +121,7 @@ def _combine_inferences( # type: ignore[override]
self,
rng_key: jax.Array,
guides: dict[str, tuple[AutoGuide, dict[str, Any]]],
branching_traces: dict[str, OrderedDictType],
branching_traces: dict[str, dict],
*args: Any,
**kwargs: Any,
) -> SDVIResult:
21 changes: 9 additions & 12 deletions numpyro/diagnostics.py
@@ -5,7 +5,6 @@
This provides a small set of utilities in NumPyro that are used to diagnose posterior samples.
"""

from collections import OrderedDict
from itertools import product
from typing import Union

@@ -271,17 +270,15 @@ def summary(
r_hat = split_gelman_rubin(value)
hpd_lower = "{:.1f}%".format(50 * (1 - prob))
hpd_upper = "{:.1f}%".format(50 * (1 + prob))
summary_dict[name] = OrderedDict(
[
("mean", mean),
("std", std),
("median", median),
(hpd_lower, hpd[0]),
(hpd_upper, hpd[1]),
("n_eff", n_eff),
("r_hat", r_hat),
]
)
summary_dict[name] = {
"mean": mean,
"std": std,
"median": median,
hpd_lower: hpd[0],
hpd_upper: hpd[1],
"n_eff": n_eff,
"r_hat": r_hat,
}
return summary_dict


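In diagnostics.summary the dict literal keeps the statistics in the same order the OrderedDict listed them, so the printed table columns are unchanged. A reduced illustration with made-up numbers:

```python
summary_row = {
    "mean": 0.12,
    "std": 0.93,
    "median": 0.08,
    "2.5%": -1.71,
    "97.5%": 1.95,
    "n_eff": 950.0,
    "r_hat": 1.0,
}

# Iteration follows insertion order, so the column layout is stable.
assert list(summary_row) == ["mean", "std", "median", "2.5%", "97.5%", "n_eff", "r_hat"]
```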
7 changes: 3 additions & 4 deletions numpyro/distributions/distribution.py
@@ -26,7 +26,6 @@
# POSSIBILITY OF SUCH DAMAGE.


from collections import OrderedDict
from contextlib import contextmanager
import functools
import inspect
@@ -620,7 +619,7 @@ def __init__(
@staticmethod
def _broadcast_shape(
existing_shape: tuple[int, ...], new_shape: tuple[int, ...]
) -> tuple[tuple[int, ...], OrderedDict, OrderedDict]:
) -> tuple[tuple[int, ...], dict, dict]:
if len(new_shape) < len(existing_shape):
raise ValueError(
"Cannot broadcast distribution of shape {} to shape {}".format(
@@ -645,8 +644,8 @@ def _broadcast_shape(
)
return (
tuple(reversed(reversed_shape)),
OrderedDict(reversed(expanded_sizes)),
OrderedDict(interstitial_sizes),
dict(reversed(expanded_sizes)),
dict(interstitial_sizes),
)

@property
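In _broadcast_shape, dict(reversed(expanded_sizes)) keeps the reversed insertion order exactly as OrderedDict(reversed(...)) did. A toy check with illustrative (dim, size) pairs, not taken from a real distribution:

```python
expanded_sizes = [(-1, 2), (-2, 3)]  # (batch dim, size) pairs accumulated left to right

reordered = dict(reversed(expanded_sizes))
assert list(reordered) == [-2, -1]   # same key order an OrderedDict would have kept
assert reordered == {-2: 3, -1: 2}
```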