File tree Expand file tree Collapse file tree 1 file changed +15
-1
lines changed
src/compressed_tensors/utils Expand file tree Collapse file tree 1 file changed +15
-1
lines changed Original file line number Diff line number Diff line change 8686 "offloaded_dispatch" ,
8787 "disable_offloading" ,
8888 "remove_dispatch" ,
89+ "cast_to_device" ,
8990]
9091
9192
@@ -169,6 +170,19 @@ def update_parameter_data(
169170""" Candidates for Upstreaming """
170171
171172
173+ def cast_to_device (device_spec : Union [int , torch .device ]) -> torch .device :
174+ """
175+ Convert an integer device index or torch.device into a torch.device object.
176+
177+ :param device_spec: Device index (int) or torch.device object.
178+ Negative integers map to CPU.
179+ :return: torch.device corresponding to the given device specification.
180+ """
181+ if isinstance (device_spec , int ):
182+ return torch .device (f"cuda:{ device_spec } " if device_spec >= 0 else "cpu" )
183+ return device_spec
184+
185+
172186def get_execution_device (module : torch .nn .Module ) -> torch .device :
173187 """
174188 Get the device which inputs should be moved to before module execution.
@@ -179,7 +193,7 @@ def get_execution_device(module: torch.nn.Module) -> torch.device:
179193 """
180194 for submodule in module .modules ():
181195 if has_offloaded_params (submodule ):
182- return submodule ._hf_hook .execution_device
196+ return cast_to_device ( submodule ._hf_hook .execution_device )
183197
184198 param = next (submodule .parameters (recurse = False ), None )
185199 if param is not None :
You can’t perform that action at this time.
0 commit comments