-
Notifications
You must be signed in to change notification settings - Fork 267
Remove uses of OrderedDict #2046
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Conversation
Thanks @sid-kap! Do you think jax-ml/jax#24398 is still relevant for the changes? |
I think it should be fine because |
yes, dict in python keeps key order, but not under jit compiling
returns a new order. |
Ah you're right, my bad! Thanks for the clarification! It looks like jax handles OrderedDict correctly: import jax
x = OrderedDict([("b", 1), ("a", 2)])
jax.jit(lambda x: x)(x) returns
So any function that may be jitted should be kept as OrderedDict. I'll read through the code again and see if there are any functions that are guaranteed to not be jitted, but if not it probably makes sense to close this. |
Looking at the changes, I think most are valid. We just need to keep TraceT as-is because order in trace is important for inference algorithms. |
@@ -15,7 +14,7 @@ | |||
ModelT: TypeAlias = Callable[P, Any] | |||
|
|||
Message: TypeAlias = dict[str, Any] | |||
TraceT: TypeAlias = OrderedDict[str, Message] | |||
TraceT: TypeAlias = dict[str, Message] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's keep TraceT as OrderedDict
@@ -61,7 +61,7 @@ def __init__(self, model: Callable, num_slp_samples: int, max_slps: int) -> None | |||
|
|||
def _find_slps( | |||
self, rng_key: jax.Array, *args: Any, **kwargs: Any | |||
) -> dict[str, OrderedDictType]: | |||
) -> dict[str, dict]: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not sure but it seems that we need to use OrderedDict in this module.
@@ -98,7 +98,7 @@ def __init__( | |||
def _run_inference( | |||
self, | |||
rng_key: jax.Array, | |||
branching_trace: OrderedDictType, | |||
branching_trace: dict, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
like above, maybe we need to use OrderedDict in this module.
@@ -155,19 +154,19 @@ class trace(Messenger): | |||
|
|||
>>> exec_trace = trace(seed(model, random.PRNGKey(0))).get_trace() | |||
>>> pp.pprint(exec_trace) # doctest: +SKIP | |||
OrderedDict([('a', |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
revert this
@@ -174,7 +173,7 @@ def identity(x, *args, **kwargs): | |||
def cached_by(outer_fn, *keys): | |||
# Restrict cache size to prevent ref cycles. | |||
max_size = 8 | |||
outer_fn._cache = getattr(outer_fn, "_cache", OrderedDict()) | |||
outer_fn._cache = getattr(outer_fn, "_cache", {}) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's use OrderedDict here I guess.
Since python 3.6 dict has the same behavior as OrderedDict. Numpyro requires python version >=3.9, so it should be fine to use dict instead of OrderedDict.