B & S transpose issue in BSH collectives_layout MLP output #106
Description of PR
In our testing with the latest transformers-neuronx version 1ade6d7, we found a case similar to 69d039d in the MLP output calculation: running the Llama-3 model, the outputs produced with --tp_degree=1 and with --tp_degree=2 --collectives_layout="BSH" differ significantly from each other.
Steps to reproduce the bug:
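A minimal sketch of the comparison we ran, assuming the standard transformers-neuronx sampling API; the checkpoint path, prompt, and sequence length here are placeholders rather than the exact values from our tests:

```python
import torch
from transformers import AutoTokenizer
from transformers_neuronx.config import NeuronConfig
from transformers_neuronx.llama.model import LlamaForSampling

tokenizer = AutoTokenizer.from_pretrained("path/to/llama-3")
input_ids = tokenizer("Hello, my name is", return_tensors="pt").input_ids

# Baseline: no tensor parallelism (--tp_degree=1)
baseline = LlamaForSampling.from_pretrained("path/to/llama-3", tp_degree=1, amp="f16")
baseline.to_neuron()

# TP 2 with the BSH collectives layout (--tp_degree=2 --collectives_layout="BSH")
tp2 = LlamaForSampling.from_pretrained(
    "path/to/llama-3",
    tp_degree=2,
    amp="f16",
    neuron_config=NeuronConfig(collectives_layout="BSH"),
)
tp2.to_neuron()

# Greedy decoding (top_k=1) so both runs are deterministic and should
# produce identical token sequences; before the fix they diverge.
with torch.inference_mode():
    out_baseline = baseline.sample(input_ids, sequence_length=64, top_k=1)
    out_tp2 = tp2.sample(input_ids, sequence_length=64, top_k=1)

print(torch.equal(out_baseline, out_tp2))
```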
Here are some sample generated outputs and logits from running the two modes:
Output and logits for the baseline
Output and logits for TP 2 with the BSH collectives layout
In 69d039d the output of the dot is first reshaped into (s, b, h) and then converted into (b, s, h) through a (1, 0, 2) transpose.
We propose a similar modification to the BSH collectives_layout feature: there, the initial output of the dot is laid out as (s * b, h), not (b * s, h), so it must be converted (s * b, h) => (s, b, h) => (b, s, h), as sketched below.
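To illustrate the layout mismatch, here is a standalone numpy sketch (shapes and values are made up for illustration): each row of the (s * b, h) buffer is tagged with its (sequence, batch) origin, so reading the buffer directly as (b, s, h) mixes rows from different batch entries, while reshaping to (s, b, h) and transposing with (1, 0, 2) recovers the correct layout:

```python
import numpy as np

b, s, h = 2, 3, 4  # batch, sequence, hidden (toy sizes)

# The dot output is sequence-major: rows are ordered
# (s0,b0), (s0,b1), (s1,b0), (s1,b1), ... i.e. a flattened (s * b, h).
rows = np.stack([np.full(h, 10 * si + bi, dtype=float)
                 for si in range(s) for bi in range(b)])

# Incorrect: treating the (s * b, h) buffer as if it were (b * s, h)
wrong = rows.reshape(b, s, h)

# Correct: recover (s, b, h) first, then transpose (1, 0, 2) to (b, s, h)
right = rows.reshape(s, b, h).transpose(1, 0, 2)

print(wrong[0, :, 0])  # [ 0.  1. 10.]: mixes batch 0 and batch 1 rows
print(right[0, :, 0])  # [ 0. 10. 20.]: batch 0 across positions 0..2
```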
We tested the fix, and the outputs of the --tp_degree=2 --collectives_layout="BSH" mode are now equivalent to those of the --tp_degree=1 mode.
Your insights are very much appreciated. We will continue following up on this issue until it is resolved.
Credit to @wenboqian for providing the initial direction for detecting and fixing the bug.