Skip to content

Remove uses of OrderedDict #2046

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: master
Choose a base branch
from

Conversation

sid-kap
Copy link

@sid-kap sid-kap commented Jul 16, 2025

Since python 3.6 dict has the same behavior as OrderedDict. Numpyro requires python version >=3.9, so it should be fine to use dict instead of OrderedDict.

@fehiepsi
Copy link
Member

Thanks @sid-kap! Do you think jax-ml/jax#24398 is still relevant for the changes?

@sid-kap
Copy link
Author

sid-kap commented Jul 16, 2025

I think it should be fine because dict now keeps the keys ordered, like OrderedDict

@fehiepsi
Copy link
Member

yes, dict in python keeps key order, but not under jit compiling

import jax

x = {"b":1,"a":2}
jax.jit(lambda x: x)(x)

returns a new order.

@sid-kap
Copy link
Author

sid-kap commented Jul 17, 2025

Ah you're right, my bad! Thanks for the clarification! It looks like jax handles OrderedDict correctly:

import jax
x = OrderedDict([("b", 1), ("a", 2)])
jax.jit(lambda x: x)(x)

returns

OrderedDict([('b', Array(1, dtype=int32, weak_type=True)), ('a', Array(2, dtype=int32, weak_type=True))])

So any function that may be jitted should be kept as OrderedDict.

I'll read through the code again and see if there are any functions that are guaranteed to not be jitted, but if not it probably makes sense to close this.

@fehiepsi
Copy link
Member

Looking at the changes, I think most are valid. We just need to keep TraceT as-is because order in trace is important for inference algorithms.

@@ -15,7 +14,7 @@
ModelT: TypeAlias = Callable[P, Any]

Message: TypeAlias = dict[str, Any]
TraceT: TypeAlias = OrderedDict[str, Message]
TraceT: TypeAlias = dict[str, Message]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's keep TraceT as OrderedDict

@@ -61,7 +61,7 @@ def __init__(self, model: Callable, num_slp_samples: int, max_slps: int) -> None

def _find_slps(
self, rng_key: jax.Array, *args: Any, **kwargs: Any
) -> dict[str, OrderedDictType]:
) -> dict[str, dict]:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure but it seems that we need to use OrderedDict in this module.

@@ -98,7 +98,7 @@ def __init__(
def _run_inference(
self,
rng_key: jax.Array,
branching_trace: OrderedDictType,
branching_trace: dict,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

like above, maybe we need to use OrderedDict in this module.

@@ -155,19 +154,19 @@ class trace(Messenger):

>>> exec_trace = trace(seed(model, random.PRNGKey(0))).get_trace()
>>> pp.pprint(exec_trace) # doctest: +SKIP
OrderedDict([('a',
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

revert this

@@ -174,7 +173,7 @@ def identity(x, *args, **kwargs):
def cached_by(outer_fn, *keys):
# Restrict cache size to prevent ref cycles.
max_size = 8
outer_fn._cache = getattr(outer_fn, "_cache", OrderedDict())
outer_fn._cache = getattr(outer_fn, "_cache", {})
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's use OrderedDict here I guess.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants