@@ -91,6 +91,22 @@ def __call__(
9191 # prepare intermediates cache
9292 activations = IntermediatesCache .from_dataloader (dataloader , model_device )
9393
94+ # Define helper function to materialize meta tensors once
95+ # Fixes "Tensor.item() on meta tensors" error when using device offloading
def _materialize_meta_tensors(obj):
    """Recursively replace meta tensors in ``obj`` with real zero tensors.

    Works around the "Tensor.item() on meta tensors" error seen when
    device offloading is used: meta tensors carry no data, so each one is
    swapped for a zero-filled tensor of the same shape and dtype allocated
    on ``model_device`` (a variable of the enclosing scope).

    Args:
        obj: A tensor, a dict, a list/tuple (including namedtuples), or any
            other object. Containers are traversed recursively.

    Returns:
        ``obj`` itself when it is neither a meta tensor nor a supported
        container; otherwise a rebuilt value in which every meta tensor has
        been materialized. Dicts are rebuilt as plain ``dict``; lists and
        tuples keep their concrete type.
    """
    if isinstance(obj, torch.Tensor):
        if obj.is_meta:
            return torch.zeros_like(obj, device=model_device)
        return obj

    if isinstance(obj, dict):
        return {
            key: _materialize_meta_tensors(value)
            for key, value in obj.items()
        }

    if isinstance(obj, (list, tuple)):
        items = [_materialize_meta_tensors(item) for item in obj]
        if isinstance(obj, tuple) and hasattr(obj, "_fields"):
            # namedtuple constructors take one positional argument per
            # field, not a single iterable, so the items must be unpacked.
            return type(obj)(*items)
        return type(obj)(items)

    return obj
109+
94110 for subgraph_index , subgraph in enumerate (subgraphs ):
95111 # prepare tqdm description texts
96112 calib_desc = f"({ subgraph_index + 1 } /{ num_subgraphs } ): Calibrating"
@@ -101,26 +117,6 @@ def __call__(
101117 # do a preliminary pass to trigger modifier hooks
102118 for batch_idx in tqdm (range (len (dataloader )), desc = calib_desc ):
103119 inputs = activations .fetch (batch_idx , subgraph .input_names )
104-
105- # PATCH: Materialize meta tensors before traced code
106- # Fixes "Tensor.item() on meta tensors" error
def _materialize_meta_tensors(obj):
    """Recursively replace meta tensors in ``obj`` with real zero tensors.

    Fixes the "Tensor.item() on meta tensors" error raised by traced code.
    Each meta tensor is swapped for a zero-filled tensor of the same shape
    and dtype. The replacement is allocated on ``model_device`` (taken
    from the enclosing scope) rather than a hard-coded ``cuda:0``/``cpu``
    choice, so it lands on the same device the intermediates cache was
    built for.

    Args:
        obj: A tensor, a dict, a list/tuple (including namedtuples), or any
            other object. Containers are traversed recursively.

    Returns:
        ``obj`` unchanged when it is neither a meta tensor nor a supported
        container; otherwise a rebuilt value with all meta tensors
        materialized. Dicts come back as plain ``dict``; lists and tuples
        keep their concrete type.
    """
    if isinstance(obj, torch.Tensor) and obj.is_meta:
        return torch.zeros_like(obj, device=model_device)

    if isinstance(obj, dict):
        return {
            name: _materialize_meta_tensors(entry)
            for name, entry in obj.items()
        }

    if isinstance(obj, (list, tuple)):
        converted = [_materialize_meta_tensors(entry) for entry in obj]
        if isinstance(obj, tuple) and hasattr(obj, "_fields"):
            # A namedtuple must be rebuilt from positional fields; passing
            # a single list to its constructor raises TypeError.
            return type(obj)(*converted)
        return type(obj)(converted)

    return obj
123-
124120 inputs = _materialize_meta_tensors (inputs )
125121 subgraph .forward (model , ** inputs )
126122
@@ -131,25 +127,6 @@ def _materialize_meta_tensors(obj):
131127 with HooksMixin .disable_hooks ():
132128 for batch_idx in tqdm (range (len (dataloader )), desc = prop_desc ):
133129 inputs = activations .fetch (batch_idx , subgraph .input_names )
134-
135- # PATCH: Materialize meta tensors (same as above)
def _materialize_meta_tensors(obj):
    """Return ``obj`` with every nested meta tensor turned into real zeros.

    Meta tensors hold no data, so each one is replaced by
    ``torch.zeros_like`` output with the same shape and dtype. The target
    device is ``model_device`` from the enclosing scope instead of a
    hard-coded ``cuda:0``/``cpu`` pick, which keeps the materialized inputs
    on the device used for the rest of the pipeline.

    Args:
        obj: A tensor, a dict, a list/tuple (including namedtuples), or any
            other object. Dicts, lists and tuples are walked recursively.

    Returns:
        The same object when nothing applies, otherwise a rebuilt
        container (plain ``dict`` for dicts, the original concrete type for
        lists and tuples) whose meta tensors have been materialized.
    """
    if isinstance(obj, torch.Tensor):
        # Only meta tensors are replaced; real tensors pass through as-is.
        if not obj.is_meta:
            return obj
        return torch.zeros_like(obj, device=model_device)

    if isinstance(obj, dict):
        rebuilt = {}
        for key, value in obj.items():
            rebuilt[key] = _materialize_meta_tensors(value)
        return rebuilt

    if isinstance(obj, (list, tuple)):
        elements = [_materialize_meta_tensors(element) for element in obj]
        container_type = type(obj)
        if isinstance(obj, tuple) and hasattr(obj, "_fields"):
            # namedtuple: one positional argument per field is required.
            return container_type(*elements)
        return container_type(elements)

    return obj
152-
153130 inputs = _materialize_meta_tensors (inputs )
154131 output = subgraph .forward (model , ** inputs )
155132
0 commit comments