CANN: fix RoPE cache issue on multi-device #15629
Conversation
RoPE cache only needs to be computed once per token. However, in multi-device scenarios, not every device starts computation from layer 0, which may lead to unallocated memory issues and precision errors. This commit records the first layer of each device to avoid the above issues.
ggml/src/ggml-cann/aclnn_ops.cpp
Outdated
```cpp
// get first layer in current device.
int layer = 0;
const char* dash = std::strchr(dst->name, '-');
if (dash) {
    layer = std::strtol(dash + 1, nullptr, 10);
}

// remember the first layer.
if (ctx.rope_cache.first_layer == -1)
    ctx.rope_cache.first_layer = layer;

int64_t theta_scale_length = ne00 / 2;
// only init the cache when freq_factors is not null or this is the first layer.
// dash == nullptr means we are in test-backend-ops.
if (dash != nullptr && src2 == nullptr && layer != ctx.rope_cache.first_layer) {
    // use cache.
    return;
}
```
This is really hacky. Can you improve without making assumptions about the tensor names? Maybe create the cache based on the input parameters?
I’ve tried, but during the decode stage it’s not possible to distinguish layers from the shape of position: all position lengths are the same, and the position tensor and its position->data pointer are identical across layers. The only difference is the data inside position, and copying that data from the device to the host is not a good approach. Do you have any suggestions for this?
I think I’ve come up with a method: during the forward computation of the ggml_cgraph, set a marker when the first RoPE operator is encountered and perform the cache calculation there; the subsequent RoPE operators then skip the computation. This way we can avoid parsing the tensor’s name. I will try this approach. A minimal sketch of the idea follows.
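A rough sketch of that per-graph marker, with invented names (`rope_cache`, `computed_this_graph`, and the two hooks are stand-ins for illustration, not the actual CANN backend symbols):

```cpp
#include <vector>

// Hypothetical sketch of the proposed per-graph marker; all names here are
// invented for illustration and are not the actual CANN backend symbols.
struct rope_cache {
    bool computed_this_graph = false; // valid only for the current graph
    std::vector<float> sin_cache;
    std::vector<float> cos_cache;
};

// Called once at the start of every graph evaluation (i.e. per token batch):
// invalidate the cache so the first RoPE node recomputes it.
void on_graph_compute_begin(rope_cache & rc) {
    rc.computed_this_graph = false;
}

// Called for every RoPE node encountered while walking the graph.
void on_rope_node(rope_cache & rc) {
    if (!rc.computed_this_graph) {
        // first RoPE operator in this graph: (re)build the sin/cos cache here
        rc.computed_this_graph = true;
    }
    // this and all later RoPE nodes in the graph reuse sin_cache/cos_cache
}
```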
The current scenario is that when computing the sine/cosine for RoPE, the values are calculated only for the first layer on each device and reused by the other layers, so it is necessary to identify which layer is the first on the current device. However, there is currently no way to obtain the layer number inside the device backend, so it can only be inferred from the name. Would it be possible to store the layer number directly in the tensor?
LGTM!
This is an excellent refactor that removes the previous hacky implementation! We’ll do another refactor later for the case where …
AFAIU the cache will be computed on every graph_compute() call by the first rope operation and then reused for all remaining rope operations in the current graph. I think this assumes that all ropes are the same in all layers. In general this is not guaranteed and will likely cause problems in the future.
As far as I know, apart from freq_factors, the RoPE cache used for each token only depends on the position and some hyperparameters. So far, I haven’t observed any cases where the cache differs across layers. Could you clarify in what situations the RoPE cache would be layer-dependent? Thanks!
For example, Gemma3n uses different RoPE parameters for some layers (see lines 10635 to 10639 in f15d515).
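For illustration only (this is not the referenced llama.cpp snippet, and the layer pattern and frequency bases are invented): a model can select a different RoPE frequency base per layer, so a sin/cos table built from the first layer’s parameters would be wrong for the others:

```cpp
#include <cmath>
#include <cstdio>

// Hypothetical per-layer RoPE parameter selection: if some layers use a
// different freq_base, the first layer's sin/cos table is wrong for them.
float rope_theta(int layer, int dim_pair, int n_dims, int pos) {
    const bool  local     = (layer % 2 == 0);          // invented layer pattern
    const float freq_base = local ? 10000.0f : 1000000.0f;
    return pos * std::pow(freq_base, -2.0f * dim_pair / n_dims);
}

int main() {
    // same position and dimension pair, different layers -> different angles
    std::printf("layer 0: %f\n", rope_theta(0, 4, 64, 42));
    std::printf("layer 1: %f\n", rope_theta(1, 4, 64, 42));
}
```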
Also, this simple test program would likely not run correctly if it were offloaded to the CANN backend (currently it always runs on the CPU); see lines 156 to 177 in f15d515.
Generally, we want to keep the code generic and not make assumptions about the application.
@ggerganov Thank you for the reminder. All unsafe caches have been removed; only the parts that can be determined from the parameters to remain unchanged are still cached.
LGTM. However, for transformer models, removing the sin/cos cache in RoPE leads to a performance drop compared to before. We’ll need to explore a more elegant way to determine positional information in the future so that the cache check remains precise.
* CANN: fix RoPE cache issue on multi-device

  RoPE cache only needs to be computed once per token. However, in multi-device scenarios, not every device starts computation from layer 0, which may lead to unallocated memory issues and precision errors. This commit records the first layer of each device to avoid the above issues.

* CANN: Optimize first-layer detection method
* CANN: Remove trailing whitespace
* CANN: Only cache the data that can be determined as unchanged through the parameters.
* CANN: Update function comment
Update
To avoid coupling the RoPE cache to a specific model, we now cache only those entries that the input parameters determine will not undergo any transformation.
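A minimal sketch of this parameter-gated caching, with invented names (`theta_scale_cache` and its members are illustrative, not the actual CANN struct): the power table depends only on its length and base, so it can be reused whenever those two parameters match, while anything derived from runtime tensor data must be recomputed every call:

```cpp
#include <cstdint>
#include <vector>

// Hypothetical sketch: cache only what the parameters fully determine.
struct theta_scale_cache {
    int64_t length      = 0;
    float   theta_scale = 0.0f;
    std::vector<float> powers; // theta_scale^0 .. theta_scale^(length-1)

    const std::vector<float> & get(int64_t len, float scale) {
        if (len != length || scale != theta_scale) {
            // parameters changed: the cached table is invalid, rebuild it
            length      = len;
            theta_scale = scale;
            powers.assign(static_cast<size_t>(len), 0.0f);
            float p = 1.0f;
            for (int64_t i = 0; i < len; ++i) {
                powers[static_cast<size_t>(i)] = p;
                p *= scale;
            }
        }
        // safe to reuse: the contents are determined solely by (len, scale);
        // sin/cos values, which depend on runtime position data, are NOT cached
        return powers;
    }
};
```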
```
./bin/test-backend-ops test -b CANN0 -o ROPE
Testing 3 devices

Backend 1/3: CANN0
  Device description: Ascend910B4
  Device memory: 30196 MB (29802 MB free)
  11837/11837 tests passed
  Backend CANN0: OK

Backend 2/3: CANN1
  Skipping
Backend 3/3: CPU
  Skipping

3/3 backends passed
OK
```