Description
JAX does not play well with fork, which is the multiprocessing start method we currently default to on Linux and on ARM-based macOS.
import pymc as pm

N_OBSERVATIONS = 50

with pm.Model() as model:
    mu = pm.Normal("mu")
    sigma = pm.HalfNormal("sigma", sigma=0.5)
    y = pm.Normal("y", mu=mu, sigma=sigma, shape=N_OBSERVATIONS)
    prior_trace = pm.sample_prior_predictive(random_seed=100)
    data = prior_trace.prior.y.isel(chain=0, draw=0)

with pm.observe(model, {y: data}):
    pm.sample(compile_kwargs=dict(mode="JAX"), mp_ctx="forkserver")  # fine
    pm.sample(compile_kwargs=dict(mode="JAX"))  # hangs forever
Wherever we're defaulting to fork, we should switch to forkserver or spawn instead (whichever the platform supports).
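A minimal sketch of what "whichever is supported" could look like, using only the standard library. The helper name `pick_non_fork_start_method` is hypothetical and not part of PyMC; it just prefers forkserver when the platform lists it and falls back to spawn otherwise.

```python
import multiprocessing


def pick_non_fork_start_method() -> str:
    """Return "forkserver" when the platform offers it, otherwise "spawn".

    Hypothetical helper for illustration only; not an existing PyMC function.
    """
    available = multiprocessing.get_all_start_methods()
    # "forkserver" exists only on some Unix platforms; "spawn" is available everywhere.
    return "forkserver" if "forkserver" in available else "spawn"


if __name__ == "__main__":
    method = pick_non_fork_start_method()
    # get_context raises ValueError for unsupported methods, so this also validates the choice.
    ctx = multiprocessing.get_context(method)
    print(method, ctx.get_start_method())
```

Preferring forkserver over spawn here is a judgment call: both avoid forking a process that already has JAX threads running.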
Relevant code:
pymc/pymc/sampling/parallel.py
Lines 437 to 450 in 268e13b
if mp_ctx is None or isinstance(mp_ctx, str):
    # Closes issue https://github.com/pymc-devs/pymc/issues/3849
    # Related issue https://github.com/pymc-devs/pymc/issues/5339
    if mp_ctx is None and platform.system() == "Darwin":
        if platform.processor() == "arm":
            mp_ctx = "fork"
            logger.debug(
                "mp_ctx is set to 'fork' for MacOS with ARM architecture. "
                + "This might cause unexpected behavior with JAX, which is inherently multithreaded."
            )
        else:
            mp_ctx = "forkserver"
    mp_ctx = multiprocessing.get_context(mp_ctx)
To detect which backend is in use, something like this could work:
from pytensor.compile.mode import get_mode
from pytensor.link.jax import JAXLinker
...
# Somewhere inside/downstream of pm.sample
mode = compile_kwargs.get("mode", None)
using_jax = isinstance(get_mode(mode).linker, JAXLinker)
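One way the `using_jax` flag could feed into the context selection is sketched below. This is an assumption about how the fix might be wired, not the actual change: `resolve_mp_ctx` is a hypothetical function, `using_jax` is passed in as a plain boolean so the sketch runs without PyTensor, and an explicitly passed `mp_ctx` is always left as the caller chose.

```python
import multiprocessing


def resolve_mp_ctx(mp_ctx=None, using_jax=False):
    """Turn `mp_ctx` (None, a string, or a context object) into a context object.

    Hypothetical sketch: when no start method was requested and JAX is in use,
    avoid the platform default (which may be fork) and pick forkserver or spawn.
    """
    if mp_ctx is None and using_jax:
        available = multiprocessing.get_all_start_methods()
        mp_ctx = "forkserver" if "forkserver" in available else "spawn"
    if mp_ctx is None or isinstance(mp_ctx, str):
        # get_context(None) returns the platform's default context.
        mp_ctx = multiprocessing.get_context(mp_ctx)
    return mp_ctx


if __name__ == "__main__":
    print(resolve_mp_ctx(using_jax=True).get_start_method())
    print(resolve_mp_ctx("spawn", using_jax=True).get_start_method())
```

Gating on `using_jax` keeps the current defaults for other backends; dropping the gate would instead match the broader suggestion of never defaulting to fork.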