
Conversation

kahfizulkifli

In transformers-neuronx, optimized machine learning pipelines with frequent reshape and transpose sequences must keep the same meaning as the baseline pipeline (semantic equivalence). However, we found cases where these layout-transformation operators do not preserve the sharding specification, so the distributed tensors diverge from the baseline.

In the GQA shard-over-batch feature, the reshape operator at the end of the QKV calculation splits these matrices from 2 dimensions into 4. Symbolically, the tensors start with axes (a, b); after the split they have axes (c, d, e, f), where (c, d) come from splitting a and (e, f) come from splitting b. The sharding dimension is initially a, and after the split the code places it on d.

However, when we concatenate the tensors from multiple devices along this sharding dimension d, the result differs from the baseline. We propose to fix this by moving the sharding dimension from d to c, which retains the sharding information, and then transposing the first and second dimensions.
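
As a minimal sketch of this argument (with small made-up sizes rather than the real model dimensions: a=8 split into c=2 and d=4, b=6 split into e=3 and f=2, tp=2), each device holds a contiguous block of rows of the (a, b) tensor, and under a row-major reshape those rows form a contiguous slice along c, not d, so only concatenation along c reproduces the baseline reshape:

import torch

a, b = 8, 6                  # made-up global shape (a, b)
c, d, e, f = 2, 4, 3, 2      # a -> (c, d), b -> (e, f)
tp = 2                       # tensor-parallel degree

x = torch.arange(a * b).reshape(a, b)
baseline = x.reshape(c, d, e, f)                       # baseline (tp=1) reshape

# Shard along a: each device holds a contiguous block of rows.
shards = [x[k * a // tp:(k + 1) * a // tp] for k in range(tp)]

# Current behaviour: treat the sharding dimension as d and concatenate along it.
along_d = torch.cat([s.reshape(c, d // tp, e, f) for s in shards], dim=1)
print(torch.equal(along_d, baseline))                  # False: rows end up in the wrong order

# Proposed fix: treat the sharding dimension as c and concatenate along it.
along_c = torch.cat([s.reshape(c // tp, d, e, f) for s in shards], dim=0)
print(torch.equal(along_c, baseline))                  # True

The small sizes above are only for readability; in the dumps below the corresponding full shape is torch.Size([64, 4, 4, 1024]).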

Steps to reproduce the bug:

  1. Set up the transformers-neuronx library on an AWS Neuron Trainium machine
  2. Download the Llama-3 pretrained model from Hugging Face: https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct/tree/main
  3. Edit config.json in the Llama-3 folder to use a single layer ("num_attention_heads": 4, "num_hidden_layers": 1, "num_key_value_heads": 1); a sketch of this edit follows the list
  4. Run the test script provided in two modes, baseline and distributed, with collectives layout BSH:
# Baseline mode
python llama_driver.py run <model_path_folder> --tp_degree=1 --gqa=shard-over-batch --debug > tp_1.txt

# Distributed mode
python llama_driver.py run <model_path_folder> --tp_degree=2 --gqa=shard-over-batch --debug > tp_2.txt
  5. Check that the results of the two runs differ from one another
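
For step 3, one way to apply the config edit (the keys and values are the ones listed in step 3; <model_path_folder> is the same placeholder used in the commands above):

import json

config_path = "<model_path_folder>/config.json"   # same placeholder as in the commands
with open(config_path) as fp:
    config = json.load(fp)

# Use a single layer and the reduced head counts, as described in step 3.
config["num_attention_heads"] = 4
config["num_hidden_layers"] = 1
config["num_key_value_heads"] = 1

with open(config_path, "w") as fp:
    json.dump(config, fp, indent=2)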

Here are some of the generated outputs for the baseline and distributed runs.

Baseline:

reshape.56
torch.Size([64, 4, 4, 1024])
tensor([[[-0.0974,  0.2617,  0.2940,  ..., -0.0316,  0.1201,  0.3106],
         [ 0.0730,  0.1343,  0.0262,  ...,  0.4203, -0.0501,  0.0040],
         [-0.3880,  0.1703,  0.4177,  ...,  0.2850, -0.3323,  0.3105],
         [ 0.7463,  0.2157, -0.2138,  ...,  0.4804,  0.4542,  0.2629]],

        [[-0.0974,  0.2617,  0.2940,  ..., -0.0316,  0.1201,  0.3106],
         [ 0.0730,  0.1343,  0.0262,  ...,  0.4203, -0.0501,  0.0040],
         [-0.3880,  0.1703,  0.4177,  ...,  0.2850, -0.3323,  0.3105],
         [ 0.7463,  0.2157, -0.2138,  ...,  0.4804,  0.4542,  0.2629]],

        [[-0.0974,  0.2617,  0.2940,  ..., -0.0316,  0.1201,  0.3106],
         [ 0.0730,  0.1343,  0.0262,  ...,  0.4203, -0.0501,  0.0040],
         [-0.3880,  0.1703,  0.4177,  ...,  0.2850, -0.3323,  0.3105],
         [ 0.7463,  0.2157, -0.2138,  ...,  0.4804,  0.4542,  0.2629]],

        [[-0.0974,  0.2617,  0.2940,  ..., -0.0316,  0.1201,  0.3106],
         [ 0.0730,  0.1343,  0.0262,  ...,  0.4203, -0.0501,  0.0040],
         [-0.3880,  0.1703,  0.4177,  ...,  0.2850, -0.3323,  0.3105],
         [ 0.7463,  0.2157, -0.2138,  ...,  0.4804,  0.4542,  0.2629]]])

Distributed:

reshape.71
torch.Size([64, 4, 4, 1024])
tensor([[[-0.0974,  0.2617,  0.2940,  ..., -0.0316,  0.1201,  0.3106],
         [ 0.0730,  0.1343,  0.0262,  ...,  0.4203, -0.0501,  0.0040],
         [-0.3880,  0.1703,  0.4177,  ...,  0.2850, -0.3323,  0.3105],
         [ 0.7463,  0.2157, -0.2138,  ...,  0.4804,  0.4542,  0.2629]],

        [[-0.0974,  0.2617,  0.2940,  ..., -0.0316,  0.1201,  0.3106],
         [ 0.0730,  0.1343,  0.0262,  ...,  0.4203, -0.0501,  0.0040],
         [-0.3880,  0.1703,  0.4177,  ...,  0.2850, -0.3323,  0.3105],
         [ 0.7463,  0.2157, -0.2138,  ...,  0.4804,  0.4542,  0.2629]],

        [[ 0.5616,  1.7017,  1.3909,  ..., -0.2786,  0.6289,  0.8052],
         [-0.2743,  0.9822,  0.2499,  ...,  0.9142, -0.3162,  0.1421],
         [-2.7863, -0.0689,  1.4579,  ..., -0.0177, -0.1165,  0.0734],
         [-0.6148, -0.3177,  0.5596,  ...,  1.0244,  0.9916,  0.7313]],

        [[ 0.5616,  1.7017,  1.3909,  ..., -0.2786,  0.6289,  0.8052],
         [-0.2743,  0.9822,  0.2499,  ...,  0.9142, -0.3162,  0.1421],
         [-2.7863, -0.0689,  1.4579,  ..., -0.0177, -0.1165,  0.0734],
         [-0.6148, -0.3177,  0.5596,  ...,  1.0244,  0.9916,  0.7313]]])

This bug is related to the bug we found in the all-to-all operation (#108), although this splitting of axes does not have to be preceded by an all-to-all operation. After we introduce the fix, reshaping the tensors onto the correct sharding dimension, the tensors are equal to the baseline.

A problem with this fix is that the resulting tensors differ from the old version, since it is impossible to force the sharding dimension to be the second dimension. However, the original intent seems to be a split across the second dimension, because the QKV matrices are later multiplied with the KV caches, which are split along the second dimension from the start. In general, it might be impossible to support this feature because of this difference in sharding dimension between the KV caches and the QKV tensors.
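
As a shape-only sketch of the transpose step (the sizes are taken from the reshape dumps above, the tp degree of 2 matches the repro commands, and which logical axis is batch versus heads is an assumption left open here): after the fix the local shard carries the sharded axis on the first dimension (the c part of the split), and swapping the first and second dimensions moves it to the second dimension, where the KV caches are already split.

import torch

tp = 2                        # tp degree from the repro commands
c, d, e, f = 64, 4, 4, 1024   # full sizes from the reshape dumps above

# After the fix, the local shard is split along c (the first dimension).
local_qkv = torch.zeros(c // tp, d, e, f)

# Swapping the first and second dimensions puts the sharded axis on the
# second dimension, matching how the KV caches are split.
local_qkv = local_qkv.transpose(0, 1)
print(local_qkv.shape)        # torch.Size([4, 32, 4, 1024])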

Baseline:

transpose.57
torch.Size([64, 4, 4, 1024])
tensor([[[-0.0974,  0.2617,  0.2940,  ..., -0.0316,  0.1201,  0.3106],
         [ 0.0730,  0.1343,  0.0262,  ...,  0.4203, -0.0501,  0.0040],
         [-0.3880,  0.1703,  0.4177,  ...,  0.2850, -0.3323,  0.3105],
         [ 0.7463,  0.2157, -0.2138,  ...,  0.4804,  0.4542,  0.2629]],

        [[ 0.5616,  1.7017,  1.3909,  ..., -0.2786,  0.6289,  0.8052],
         [-0.2743,  0.9822,  0.2499,  ...,  0.9142, -0.3162,  0.1421],
         [-2.7863, -0.0689,  1.4579,  ..., -0.0177, -0.1165,  0.0734],
         [-0.6148, -0.3177,  0.5596,  ...,  1.0244,  0.9916,  0.7313]],

        [[ 0.5616,  1.7017,  1.3909,  ..., -0.2786,  0.6289,  0.8052],
         [-0.2743,  0.9822,  0.2499,  ...,  0.9142, -0.3162,  0.1421],
         [-2.7863, -0.0689,  1.4579,  ..., -0.0177, -0.1165,  0.0734],
         [-0.6148, -0.3177,  0.5596,  ...,  1.0244,  0.9916,  0.7313]],

        [[ 0.5616,  1.7017,  1.3909,  ..., -0.2786,  0.6289,  0.8052],
         [-0.2743,  0.9822,  0.2499,  ...,  0.9142, -0.3162,  0.1421],
         [-2.7863, -0.0689,  1.4579,  ..., -0.0177, -0.1165,  0.0734],
         [-0.6148, -0.3177,  0.5596,  ...,  1.0244,  0.9916,  0.7313]]])

Distributed:

transpose.72
torch.Size([64, 4, 4, 1024])
tensor([[[-0.0974,  0.2617,  0.2940,  ..., -0.0316,  0.1201,  0.3106],
         [ 0.0730,  0.1343,  0.0262,  ...,  0.4203, -0.0501,  0.0040],
         [-0.3880,  0.1703,  0.4177,  ...,  0.2850, -0.3323,  0.3105],
         [ 0.7463,  0.2157, -0.2138,  ...,  0.4804,  0.4542,  0.2629]],

        [[ 0.5616,  1.7017,  1.3909,  ..., -0.2786,  0.6289,  0.8052],
         [-0.2743,  0.9822,  0.2499,  ...,  0.9142, -0.3162,  0.1421],
         [-2.7863, -0.0689,  1.4579,  ..., -0.0177, -0.1165,  0.0734],
         [-0.6148, -0.3177,  0.5596,  ...,  1.0244,  0.9916,  0.7313]],

        [[ 0.5616,  1.7017,  1.3909,  ..., -0.2786,  0.6289,  0.8052],
         [-0.2743,  0.9822,  0.2499,  ...,  0.9142, -0.3162,  0.1421],
         [-2.7863, -0.0689,  1.4579,  ..., -0.0177, -0.1165,  0.0734],
         [-0.6148, -0.3177,  0.5596,  ...,  1.0244,  0.9916,  0.7313]],

        [[ 0.5616,  1.7017,  1.3909,  ..., -0.2786,  0.6289,  0.8052],
         [-0.2743,  0.9822,  0.2499,  ...,  0.9142, -0.3162,  0.1421],
         [-2.7863, -0.0689,  1.4579,  ..., -0.0177, -0.1165,  0.0734],
         [-0.6148, -0.3177,  0.5596,  ...,  1.0244,  0.9916,  0.7313]]])
