
pmean inside objax.Parallel causes multithreading deadlock for more than 2 GPUs #245

@a1302z


Hi,
I've noticed a problem where I'd like to ask for your expertise. I'm not entirely sure whether it is an Objax problem or a JAX problem under the hood, but since it is triggered by Objax calls I'll post it here.

Description

In particular, when combining objax.Parallel and objax.functional.parallel.pmean (as done in this tutorial), I run into problems with more than 2 GPUs (with 2 GPUs it works fine). The result is a deadlock in which nothing happens anymore. If I understand the tutorial correctly, the pmean is necessary to average the gradients across all cards.

Minimal reproducible example

import objax
import numpy as np
from objax.zoo.resnet_v2 import ResNet18
from jax import numpy as jnp, device_count
from tqdm import tqdm


if __name__ == "__main__":
    print(f"Num devices: {device_count()}")
    model = ResNet18(3, 1)
    opt = objax.optimizer.SGD(model.vars())

    @objax.Function.with_vars(model.vars())
    def loss(x, label):
        return objax.functional.loss.mean_squared_error(
            model(x, training=True), label
        ).mean()

    gv = objax.GradValues(loss, model.vars())

    train_vars = model.vars() + gv.vars() + opt.vars()

    @objax.Function.with_vars(train_vars)
    def train_op(image_batch, label_batch):
        grads, loss = gv(image_batch, label_batch)
        # grads = objax.functional.parallel.pmean(grads)  # this line
        # loss = objax.functional.parallel.pmean(loss)    # and this line
        loss = loss[0]
        opt(1e-3, grads)
        return loss, grads

    # Each call splits the batch across all devices; per-device outputs are reduced with jnp.mean.
    train_op = objax.Parallel(train_op, reduce=jnp.mean, vc=train_vars)

    with train_vars.replicate():
        for _ in tqdm(range(10), total=10):
            data = jnp.array(np.random.randn(512, 3, 224, 224))
            label = jnp.zeros((512, 1))
            loss, grads = train_op(data, label)

Whenever you uncomment the two lines with pmean, the program gets stuck. However, if I understand it correctly, these calls are necessary to average the gradients over all cards.
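To help narrow down whether this is an Objax or a JAX problem, here is a plain-JAX sketch of what I assume objax.Parallel combined with pmean does under the hood (the tiny linear model, the shapes, and the learning rate are placeholders, not taken from the tutorial). If this also hangs on 4 GPUs, the issue presumably lies in jax/XLA rather than in Objax:

import jax
import jax.numpy as jnp
import numpy as np


def loss_fn(w, x, y):
    # Tiny linear model standing in for ResNet18.
    return jnp.mean((x @ w - y) ** 2)


def step(w, x, y):
    loss, grads = jax.value_and_grad(loss_fn)(w, x, y)
    # Cross-device collectives, analogous to objax.functional.parallel.pmean.
    grads = jax.lax.pmean(grads, axis_name="devices")
    loss = jax.lax.pmean(loss, axis_name="devices")
    return w - 1e-3 * grads, loss


p_step = jax.pmap(step, axis_name="devices")

n_dev = jax.device_count()
w = jnp.zeros((n_dev, 16, 1))                    # parameters, replicated per device
x = jnp.array(np.random.randn(n_dev, 128, 16))   # one batch shard per device
y = jnp.zeros((n_dev, 128, 1))
for _ in range(10):
    w, loss = p_step(w, x, y)
    print(loss[0])

If this plain-JAX version runs fine on 4 GPUs, the deadlock would be specific to how objax.Parallel sets up the collective.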

Error traces

As with most deadlock bugs, there is no error stack trace. However, I have found two clues so far. One is that when the pmean lines are uncommented, the following warning appears:

2022-08-22 14:55:46.462557: E external/org_tensorflow/tensorflow/compiler/xla/service/rendezvous.cc:31] This thread has been waiting for 10 seconds and may be stuck:
2022-08-22 14:55:48.543291: E external/org_tensorflow/tensorflow/compiler/xla/service/rendezvous.cc:36] Thread is unstuck! Warning above was a false-positive. Perhaps the timeout is too short.

The other is that if I manually interrupt the program with Ctrl+C, I get this lengthy stack trace.
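In case it helps, a way to get more information out of the hang without Ctrl+C is the stdlib faulthandler module, which can periodically dump the Python stacks of all threads (a sketch; it only shows Python-level frames, so if the hang is inside XLA it will just show the blocking Python call):

import faulthandler
import sys

# Dump all thread stacks to stderr every 30 seconds while the training loop runs.
faulthandler.dump_traceback_later(30, repeat=True, file=sys.stderr)
# ... training loop from the example above ...
faulthandler.cancel_dump_traceback_later()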

Setup

We use 4 NVIDIA A40 GPUs with CUDA 11.7 (driver version 515.65.01), cuDNN 8.2.1.32, JAX 0.3.15, and Objax 1.6.0.
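For completeness, the version numbers above can be reproduced with something like this (importlib.metadata is stdlib; the package names are the ones installed from PyPI):

import importlib.metadata as md
import jax

print("jax:    ", md.version("jax"))
print("jaxlib: ", md.version("jaxlib"))
print("objax:  ", md.version("objax"))
print("devices:", jax.devices())  # should list the 4 A40s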
