From 4ebf3664e45177605c7d97f06ee35d15d56dd17f Mon Sep 17 00:00:00 2001
From: Rahul Solanki
Date: Fri, 3 May 2024 01:02:20 +0000
Subject: [PATCH] handle weight sharing with init_on_device

---
 src/accelerate/big_modeling.py | 9 ++++++++-
 1 file changed, 8 insertions(+), 1 deletion(-)

diff --git a/src/accelerate/big_modeling.py b/src/accelerate/big_modeling.py
index 94febb5d3dd..e27c79c9674 100644
--- a/src/accelerate/big_modeling.py
+++ b/src/accelerate/big_modeling.py
@@ -128,7 +128,14 @@ def register_empty_parameter(module, name, param):
             param_cls = type(module._parameters[name])
             kwargs = module._parameters[name].__dict__
             kwargs["requires_grad"] = param.requires_grad
-            module._parameters[name] = param_cls(module._parameters[name].to(device), **kwargs)
+            # In the tied-weight case (e.g. module.weight_b = module.weight_a), assigning the
+            # parameter goes through nn.Module.__setattr__, which calls register_parameter
+            # again. The incoming param is then already on the target device, since it was
+            # moved there when it was first registered. In that case, assign the tensor
+            # directly instead of re-creating the parameter; re-creating it would break the
+            # sharing relationship between the tied parameters.
+            module._parameters[name] = param if param.device == device else \
+                param_cls(module._parameters[name].to(device), **kwargs)
 
     def register_empty_buffer(module, name, buffer, persistent=True):
         old_register_buffer(module, name, buffer, persistent=persistent)
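
Not part of the patch, for context: a minimal sketch of the tied-weight scenario this change
addresses. The toy module TiedLM is hypothetical; init_empty_weights is accelerate's wrapper
around init_on_device with the meta device.

    import torch
    from torch import nn
    from accelerate import init_empty_weights

    class TiedLM(nn.Module):
        def __init__(self):
            super().__init__()
            self.embed = nn.Embedding(10, 4)
            self.head = nn.Linear(4, 10, bias=False)
            # Re-assigning an existing Parameter re-enters register_parameter
            # while the weight already lives on the meta device.
            self.head.weight = self.embed.weight

    with init_empty_weights():
        model = TiedLM()

    # Before this patch, the re-registration wraps the tensor in a new Parameter
    # and the tie is lost; with it, the shared tensor is assigned as-is.
    print(model.head.weight is model.embed.weight)  # True with the patch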