Incorrect splitting of axes to retain sharding of tensors #109
In transformers-neuronx, it is crucial that optimized machine learning pipelines, which contain frequent reshape and transpose sequences, remain semantically equivalent to the baseline pipeline. However, we found cases where these layout-transformation operators cause the sharding specification to diverge between the baseline and distributed tensors.
In the GQA shard-over-batch feature for the QKV calculation, the reshape operator at the end of the calculation splits the axes of these matrices from 2 dimensions to 4. Symbolically, the matrices initially have axes (a, b); after the split they become (c, d, e, f), where (c, d) comes from splitting a and (e, f) comes from splitting b. Before the reshape the sharding dimension is a; after the split the sharding dimension becomes d.
However, when we concatenate the per-device tensors along this sharding dimension d, the resulting tensor differs from the baseline. We propose to fix this by switching the sharding dimension from d to c, which retains the sharding information, and then transposing the first and second dimensions.
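To make the two layouts concrete, here is a minimal NumPy sketch. The sizes (a = 8, b = 4, c = 2, d = 4, e = 2, f = 2) and the two-device setup are hypothetical and only illustrate the reshape semantics; they are not the actual dimensions used in transformers-neuronx.

```python
import numpy as np

# Hypothetical sizes; the real QKV dimensions in transformers-neuronx differ.
a, b = 8, 4            # baseline axes (a, b)
c, d = 2, 4            # a is split into (c, d)
e, f = 2, 2            # b is split into (e, f)
n_dev = 2              # number of devices, sharding the baseline along a

x = np.arange(a * b).reshape(a, b)          # baseline tensor
baseline = x.reshape(c, d, e, f)            # baseline reshape to 4 axes

# Each device holds a contiguous slice of x along a.
shards = np.split(x, n_dev, axis=0)         # local shape: (a // n_dev, b)

# Current behaviour: reshape each shard so the sharded axis lands on d,
# then concatenate along d across devices.
shard_on_d = [s.reshape(c, d // n_dev, e, f) for s in shards]
wrong = np.concatenate(shard_on_d, axis=1)
print(np.array_equal(wrong, baseline))      # False: sharding on d breaks equivalence

# Proposed fix: keep the shard on c (the leading axis of the split of a),
# which is consistent with a row-major reshape of a contiguous slice.
shard_on_c = [s.reshape(c // n_dev, d, e, f) for s in shards]
fixed = np.concatenate(shard_on_c, axis=0)
print(np.array_equal(fixed, baseline))      # True: sharding on c is retained

# The per-device tensors are then transposed to swap the first two axes,
# so the retained sharding dimension ends up in the second position.
local_fixed = [t.transpose(1, 0, 2, 3) for t in shard_on_c]
print(local_fixed[0].shape)                 # (d, c // n_dev, e, f)
```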
Steps to reproduce the bug:
Here are some generated outputs for the baseline and distributed runs.
Baseline:
Distributed:
This bug is related to the one we found in the all-to-all operation (#108), although this splitting of axes does not have to be preceded by an all-to-all operation. After we introduce the fix, reshaping the tensors along the correct sharding dimension, the tensors are equal to the baseline.
A problem with this fix is that the resulting tensors differ from the old version, since it is impossible to force the sharding dimension onto the second dimension. The original code, however, appears to intend a split along the second dimension, because the QKV matrices are later multiplied by the KV caches, which are already split along the second dimension from the beginning. In general, it might be impossible to support this feature due to the mismatch in sharding dimension between the KV caches and the QKV tensors (see the sketch at the end of this report).
Baseline:
Distributed:
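To illustrate the sharding-dimension mismatch mentioned above, here is a small sketch, again with hypothetical sizes reused from the earlier example, showing which slice of the global (c, d, e, f) tensor each device ends up holding for the fixed QKV tensor versus a KV cache that is split along the second dimension from the beginning. The sizes and device count are assumptions for illustration only.

```python
# Hypothetical sizes, reused from the sketch above; not the real model dimensions.
c, d, e, f = 2, 4, 2, 2
n_dev = 2

# With the fix, each device holds a slice of the QKV tensor along c, i.e.
# device k covers c indices [k * c // n_dev, (k + 1) * c // n_dev) and all of d.
qkv_slices = [(k * c // n_dev, (k + 1) * c // n_dev) for k in range(n_dev)]

# The KV caches are split along the second dimension (d) from the beginning, i.e.
# device k covers all of c and d indices [k * d // n_dev, (k + 1) * d // n_dev).
cache_slices = [(k * d // n_dev, (k + 1) * d // n_dev) for k in range(n_dev)]

for k in range(n_dev):
    print(f"device {k}: QKV holds c in {qkv_slices[k]}, all of d; "
          f"KV cache holds all of c, d in {cache_slices[k]}")

# The per-device QKV shard and KV-cache shard cover different regions of the
# global (c, d, e, f) tensor, so they cannot be combined device-locally
# without an extra collective to re-shard one of them.
```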