Add support for torch.export models #1498

@tolleybot

Description

TorchSharp currently lacks support for loading and executing PyTorch models exported via torch.export (ExportedProgram format, .pt2 files). While TorchSharp supports TorchScript (JIT) models via torch.jit.load(), it cannot load the newer torch.export format introduced in PyTorch 2.x. This is becoming critical as the PyTorch ecosystem transitions from ONNX to torch.export for model deployment.

torch.export is PyTorch 2.x's modern, recommended approach to model export: it uses TorchDynamo for symbolic tracing and produces an optimized single-graph representation containing only ATen-level operations.

import torch
import torch.nn as nn

# Define a simple model
class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 5)

    def forward(self, x):
        return self.linear(x)

# Export the model
model = MyModule()
example_inputs = (torch.randn(1, 10),)
exported_program = torch.export.export(model, example_inputs)

# Save to .pt2 file
torch.export.save(exported_program, "model.pt2")

# Load and run the exported model
loaded_program = torch.export.load("model.pt2")
result = loaded_program.module()(torch.randn(1, 10))
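
To see the single-graph, ATen-level claim in practice, the captured graph can be printed; a small optional addition to the example above (the exact ops shown vary by PyTorch version):

# Every node in the exported graph is an ATen-level op (here, the linear
# layer or its addmm decomposition); no Python control flow remains.
print(exported_program.graph_module.graph)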

Describe the solution you'd like

Add support for loading and executing .pt2 ExportedProgram models in TorchSharp. The solution should include:

  1. Loading API - ability to load AOTInductor-compiled .pt2 files:

// New namespace for torch.export
using var exportedProgram = torch.export.load("model.pt2");

  2. Execution API - ability to run a forward pass (inference only):

var inputTensor = torch.randn(1, 10);
var result = exportedProgram.run(inputTensor);

  3. Implementation approach:
    • Use the PyTorch C++ API's torch::inductor::AOTIModelPackageLoader (requires LibTorch 2.9+)
    • Leverage existing TorchSharp infrastructure (three-layer architecture: C# → P/Invoke → C++ wrapper → LibTorch)
    • Inference-only API (no training, parameter updates, or device movement)
    • Models must be compiled with torch._inductor.aoti_compile_and_package() in Python (see the Python sketch after this list)
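
To make that compile step concrete, here is a minimal Python sketch using the same toy model as the example above; torch._inductor.aoti_compile_and_package() and torch._inductor.aoti_load_package() are the relevant PyTorch 2.x entry points, and the file name is only a placeholder:

import torch
import torch.nn as nn

class MyModule(nn.Module):  # same toy model as in the example above
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 5)

    def forward(self, x):
        return self.linear(x)

exported_program = torch.export.export(MyModule(), (torch.randn(1, 10),))

# Produce a .pt2 *package* (compiled kernels plus metadata) - the format a
# C++ host such as TorchSharp's native layer would load.
torch._inductor.aoti_compile_and_package(
    exported_program, package_path="model_aoti.pt2"
)

# Sanity-check the package from Python before consuming it from C++/.NET.
compiled = torch._inductor.aoti_load_package("model_aoti.pt2")
result = compiled(torch.randn(1, 10))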

Technical considerations:

  • LibTorch 2.9+ includes torch::inductor::AOTIModelPackageLoader for loading AOTInductor-compiled models
  • .pt2 files from torch.export.save() are Python-only and cannot be loaded in C++
  • Only .pt2 files from torch._inductor.aoti_compile_and_package() work in C++
  • Models are compiled for a specific device (CPU/CUDA) at build time (illustrated below)
  • The implementation can follow a similar pattern to ScriptModule (see src/TorchSharp/JIT/ScriptModule.cs)
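
On the device-specific point above, a hedged sketch (continuing the Python example) of how a CUDA-targeted package would be produced, assuming a CUDA build of PyTorch; the device of the exported module and inputs determines what the package can run on:

device = "cuda"
model = MyModule().to(device)
example_inputs = (torch.randn(1, 10, device=device),)

# Kernels are compiled for this device when the package is built; the
# resulting .pt2 will not run on CPU-only hosts.
ep = torch.export.export(model, example_inputs)
torch._inductor.aoti_compile_and_package(ep, package_path="model_cuda.pt2")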

Environment Information:

  • OS: All platforms (Windows, Linux, macOS)
  • Package Type: All packages (torchsharp-cpu, torchsharp-cuda-windows, torchsharp-cuda-linux)
  • Version: Latest (currently 0.105.x) and future versions

Additional context

Why this matters:

  • torch.export is the recommended path forward for PyTorch model deployment (PyTorch 2.x+)
  • Without this support, .NET developers cannot consume the latest PyTorch model formats
  • TorchScript (JIT) is becoming legacy; torch.export is the future

Comparison: TorchScript vs torch.export:

| Feature              | TorchScript (JIT) | torch.export           |
|----------------------|-------------------|------------------------|
| File Format          | .pt               | .pt2                   |
| PyTorch Era          | 1.x               | 2.x+ (current)         |
| Status in TorchSharp | ✅ Supported      | ❌ Not supported       |
| Use Case             | Research          | Production deployment  |
