Skip to content

Commit f51ef7f

Browse files
committed
Fix device hardcoding and duplicate function in pipeline.py
- Use model_device instead of hardcoded cuda:0 for multi-GPU compatibility
- Define _materialize_meta_tensors once before loop to avoid duplication
- Improves maintainability and correctness in multi-device environments

Signed-off-by: ronantakizawa <[email protected]>
1 parent c6c8549 commit f51ef7f

File tree

1 file changed

+16
-39
lines changed

1 file changed

+16
-39
lines changed

src/llmcompressor/pipelines/sequential/pipeline.py

Lines changed: 16 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,22 @@ def __call__(
9191
# prepare intermediates cache
9292
activations = IntermediatesCache.from_dataloader(dataloader, model_device)
9393

94+
# Define helper function to materialize meta tensors once
95+
# Fixes "Tensor.item() on meta tensors" error when using device offloading
96+
def _materialize_meta_tensors(obj):
97+
if isinstance(obj, torch.Tensor) and obj.is_meta:
98+
return torch.zeros_like(obj, device=model_device)
99+
elif isinstance(obj, dict):
100+
return {
101+
k: _materialize_meta_tensors(v)
102+
for k, v in obj.items()
103+
}
104+
elif isinstance(obj, (list, tuple)):
105+
return type(obj)(
106+
[_materialize_meta_tensors(x) for x in obj]
107+
)
108+
return obj
109+
94110
for subgraph_index, subgraph in enumerate(subgraphs):
95111
# prepare tqdm description texts
96112
calib_desc = f"({subgraph_index + 1}/{num_subgraphs}): Calibrating"
@@ -101,26 +117,6 @@ def __call__(
101117
# do a preliminary pass to trigger modifier hooks
102118
for batch_idx in tqdm(range(len(dataloader)), desc=calib_desc):
103119
inputs = activations.fetch(batch_idx, subgraph.input_names)
104-
105-
# PATCH: Materialize meta tensors before traced code
106-
# Fixes "Tensor.item() on meta tensors" error
107-
def _materialize_meta_tensors(obj):
108-
if isinstance(obj, torch.Tensor) and obj.is_meta:
109-
device = torch.device(
110-
"cuda:0" if torch.cuda.is_available() else "cpu"
111-
)
112-
return torch.zeros_like(obj, device=device)
113-
elif isinstance(obj, dict):
114-
return {
115-
k: _materialize_meta_tensors(v)
116-
for k, v in obj.items()
117-
}
118-
elif isinstance(obj, (list, tuple)):
119-
return type(obj)(
120-
[_materialize_meta_tensors(x) for x in obj]
121-
)
122-
return obj
123-
124120
inputs = _materialize_meta_tensors(inputs)
125121
subgraph.forward(model, **inputs)
126122

@@ -131,25 +127,6 @@ def _materialize_meta_tensors(obj):
131127
with HooksMixin.disable_hooks():
132128
for batch_idx in tqdm(range(len(dataloader)), desc=prop_desc):
133129
inputs = activations.fetch(batch_idx, subgraph.input_names)
134-
135-
# PATCH: Materialize meta tensors (same as above)
136-
def _materialize_meta_tensors(obj):
137-
if isinstance(obj, torch.Tensor) and obj.is_meta:
138-
device = torch.device(
139-
"cuda:0" if torch.cuda.is_available() else "cpu"
140-
)
141-
return torch.zeros_like(obj, device=device)
142-
elif isinstance(obj, dict):
143-
return {
144-
k: _materialize_meta_tensors(v)
145-
for k, v in obj.items()
146-
}
147-
elif isinstance(obj, (list, tuple)):
148-
return type(obj)(
149-
[_materialize_meta_tensors(x) for x in obj]
150-
)
151-
return obj
152-
153130
inputs = _materialize_meta_tensors(inputs)
154131
output = subgraph.forward(model, **inputs)
155132

0 commit comments

Comments (0)