File tree Expand file tree Collapse file tree 1 file changed +10
-3
lines changed
torchao/prototype/tensor_conversion Expand file tree Collapse file tree 1 file changed +10
-3
lines changed Original file line number Diff line number Diff line change @@ -124,9 +124,16 @@ def _find_tied_params(model):
124
124
125
125
126
126
def _convert_model_for_aarch64 (
127
- model , * , tensor_type = "auto" , intx_packing_format = "opaque_torchao_auto"
127
+ model ,
128
+ * ,
129
+ tensor_type = "auto" ,
130
+ intx_packing_format = "opaque_torchao_auto" ,
131
+ convert_tied_embedding = True ,
132
+ convert_linear = True ,
128
133
):
129
- module_name_to_tied_param = _find_tied_params (model )
134
+ module_name_to_tied_param = (
135
+ _find_tied_params (model ) if convert_tied_embedding else {}
136
+ )
130
137
131
138
# Iterate through modules in model and convert IntxUnpackedToInt8Tensor tensors to Int8LutTensor
132
139
for name , module in model .named_modules ():
@@ -138,7 +145,7 @@ def _convert_model_for_aarch64(
138
145
print ("Skipping converting nn.Embedding {name} because it is not tied" )
139
146
continue
140
147
141
- if not isinstance (module , nn .Linear ):
148
+ if not ( convert_linear and isinstance (module , nn .Linear ) ):
142
149
continue
143
150
144
151
weight = module .weight
You can’t perform that action at this time.
0 commit comments