Skip to content

Commit b47f1a3

Browse files
authored
Add conversion options
Differential Revision: D83087813 Pull Request resolved: #3051
1 parent f92b898 commit b47f1a3

File tree

1 file changed

+10
-3
lines changed
  • torchao/prototype/tensor_conversion

1 file changed

+10
-3
lines changed

torchao/prototype/tensor_conversion/api.py

Lines changed: 10 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -124,9 +124,16 @@ def _find_tied_params(model):
124124

125125

126126
def _convert_model_for_aarch64(
127-
model, *, tensor_type="auto", intx_packing_format="opaque_torchao_auto"
127+
model,
128+
*,
129+
tensor_type="auto",
130+
intx_packing_format="opaque_torchao_auto",
131+
convert_tied_embedding=True,
132+
convert_linear=True,
128133
):
129-
module_name_to_tied_param = _find_tied_params(model)
134+
module_name_to_tied_param = (
135+
_find_tied_params(model) if convert_tied_embedding else {}
136+
)
130137

131138
# Iterate through modules in model and convert IntxUnpackedToInt8Tensor tensors to Int8LutTensor
132139
for name, module in model.named_modules():
@@ -138,7 +145,7 @@ def _convert_model_for_aarch64(
138145
print("Skipping converting nn.Embedding {name} because it is not tied")
139146
continue
140147

141-
if not isinstance(module, nn.Linear):
148+
if not (convert_linear and isinstance(module, nn.Linear)):
142149
continue
143150

144151
weight = module.weight

0 commit comments

Comments (0)