diff --git a/src/torch_onnx_models/onnx_passes/_fold_transpose.py b/src/torch_onnx_models/onnx_passes/_fold_transpose.py index dfecb73..82b65bf 100644 --- a/src/torch_onnx_models/onnx_passes/_fold_transpose.py +++ b/src/torch_onnx_models/onnx_passes/_fold_transpose.py @@ -32,11 +32,22 @@ def call(self, model: ir.Model) -> ir.passes.PassResult: perm = node.attributes.get_ints("perm", reversed(range(len(shape)))) # Create a lazy transposed tensor - torch_tensor = initializer.raw - if isinstance(torch_tensor, torch.Tensor): + if isinstance(initializer.raw, torch.Tensor): - def tensor_func(tensor=torch_tensor): + def tensor_func(tensor=initializer.raw): return tensor_adapters.TorchTensor(tensor.permute(*perm), name=name) + + elif isinstance(initializer, ir.LazyTensor): + # Unwrap the lazy tensor to get the underlying torch tensor + def tensor_func(tensor=initializer): + while isinstance(tensor, ir.LazyTensor): + tensor = tensor.raw() + torch_tensor = tensor.raw + if not isinstance(torch_tensor, torch.Tensor): + torch_tensor = torch.from_numpy(tensor.numpy()) + return tensor_adapters.TorchTensor( + torch_tensor.permute(*perm), name=name + ) else: def tensor_func(tensor=initializer):