1 change: 1 addition & 0 deletions .gitignore
@@ -25,6 +25,7 @@ dist/
*.gz
*-ubyte
*.pth
*.pt
*.onnx
*.npz
onnx/*
45 changes: 5 additions & 40 deletions DeepQuant/CustomForwards/Activations.py
@@ -4,63 +4,28 @@
#
# Federico Brancasi <[email protected]>


import torch.nn as nn
from torch import Tensor
from brevitas.nn.quant_layer import QuantNonLinearActLayer
from torch import Tensor


class InnerForwardImplWrapperActivation(nn.Module):
"""
A small wrapper around the activation function of a Brevitas QuantActivation layer.

This wrapper exposes the original activation function as a standalone submodule
so that FX tracing can display it as a separate node.
"""
class WrapperActivation(nn.Module):
"""Expose inner activation so FX sees it as a leaf."""

def __init__(self, actImpl: nn.Module) -> None:
"""
Args:
act_impl: The original activation function module (e.g. an instance of nn.ReLU).
"""
super().__init__()
self.actImpl = actImpl

def forward(self, quantInput: Tensor) -> Tensor:
"""
Applies the wrapped activation function.

Args:
quant_input: Input tensor after input quantization.

Returns:
Output tensor after applying the activation.
"""
return self.actImpl(quantInput)


def quantActivationForward(self: QuantNonLinearActLayer, inp: Tensor) -> Tensor:
"""
Unrolled forward pass for a Brevitas QuantActivation layer.

Steps:
1) Apply self.input_quant to the input.
2) Apply the activation function via the wrapped activation implementation.
3) Apply self.act_quant to the activation output.

Args:
self: The QuantNonLinearActLayer instance.
inp: The input tensor.

Returns:
Output tensor after applying activation and output quantization.
"""
def activationForward(self: QuantNonLinearActLayer, inp: Tensor) -> Tensor:
"""Unroll input→act→output quant steps."""
quantInput = self.input_quant(inp) if self.input_quant is not None else inp
# Use the wrapped activation if available; otherwise pass through.
if hasattr(self, "wrappedActImpl"):
output = self.wrappedActImpl(quantInput)
else:
output = quantInput
quantOutput = self.act_quant(output) if self.act_quant is not None else output
return quantOutput
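
For context, a minimal sketch of how the new activation pieces might be wired up, assuming instance-level rebinding via types.MethodType (the repository's actual wiring is not shown in this diff) and using the wrappedActImpl attribute that activationForward looks up:

import types

import torch
import torch.nn as nn
from brevitas.nn import QuantReLU

from DeepQuant.CustomForwards.Activations import WrapperActivation, activationForward

act = QuantReLU(return_quant_tensor=False)

# Expose the inner activation as an FX leaf under the attribute activationForward expects.
act.wrappedActImpl = WrapperActivation(nn.ReLU())

# Assumed mechanism: rebind the unrolled forward on this instance only.
act.forward = types.MethodType(activationForward, act)

out = act(torch.randn(4, 8))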
75 changes: 0 additions & 75 deletions DeepQuant/CustomForwards/Linear.py

This file was deleted.

32 changes: 4 additions & 28 deletions DeepQuant/CustomForwards/MultiHeadAttention.py
@@ -4,43 +4,22 @@
#
# Federico Brancasi <[email protected]>


import math

import torch
import torch.nn.functional as F
from torch import Tensor
from brevitas.nn.quant_mha import QuantMultiheadAttention
from torch import Tensor


def unrolledQuantMhaForward(
def mhaForward(
self: QuantMultiheadAttention, query: Tensor, key: Tensor, value: Tensor
) -> Tensor:
"""
Export-friendly forward that explicitly unrolls the multi-head logic.

Steps:
1) Q, K, V projections
2) Reshapes & permutes for multi-head
3) Scales queries
4) Applies softmax and intermediate quantizations
5) Out projection

Args:
self: The QuantMultiheadAttention instance.
query: The query tensor of shape [sequence_len, batch_size, embed_dim].
key: The key tensor, same shape as query.
value: The value tensor, same shape as query.

Returns:
A torch.Tensor of shape [sequence_len, batch_size, embed_dim]
after the unrolled MHA steps.
"""
# 1) Q, K, V projections
"""Explicit, export-friendly MHA forward."""
qOut = self.q_proj(query)
kOut = self.k_proj(key)
vOut = self.v_proj(value)

# 2) Multi-head reshape
seqLen, batchSize, embedDim = qOut.shape
headDim = embedDim // self.num_heads

@@ -60,11 +39,9 @@ def unrolledQuantMhaForward(
.reshape(batchSize * self.num_heads, seqLen, headDim)
)

# 3) Scale queries, then quantize
qScaled = qOut / math.sqrt(headDim)
qScaled = self.q_scaled_quant(qScaled)

# 4) Transpose + quantize K, compute attention weights
k_t = kOut.transpose(-2, -1)
k_t = self.k_transposed_quant(k_t)

@@ -73,7 +50,6 @@ def unrolledQuantMhaForward(
attnWeights = F.softmax(attnWeights, dim=-1)
attnWeights = self.attn_output_weights_quant(attnWeights)

# 5) Quantize V, multiply, reshape back, and final out projection
vOut = self.v_quant(vOut)
attnOutput = torch.bmm(attnWeights, vOut)

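
Likewise, a minimal sketch for the renamed mhaForward, assuming unpacked projections (packed_in_proj=False, so q_proj/k_proj/v_proj exist), inputs shaped [seq_len, batch_size, embed_dim], and the same assumed per-instance rebinding:

import types

import torch
from brevitas.nn.quant_mha import QuantMultiheadAttention

from DeepQuant.CustomForwards.MultiHeadAttention import mhaForward

# Assumed configuration: separate q/k/v projection layers and sequence-first inputs.
mha = QuantMultiheadAttention(embed_dim=16, num_heads=4, packed_in_proj=False, batch_first=False)
mha.forward = types.MethodType(mhaForward, mha)

x = torch.randn(10, 2, 16)  # [seq_len, batch_size, embed_dim]
out = mha(x, x, x)          # self-attention: query = key = value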
36 changes: 36 additions & 0 deletions DeepQuant/CustomForwards/WBIOL.py
@@ -0,0 +1,36 @@
# Copyright 2025 ETH Zurich and University of Bologna.
# Licensed under the Apache License, Version 2.0, see LICENSE for details.
# SPDX-License-Identifier: Apache-2.0
#
# Federico Brancasi <[email protected]>

import torch.nn as nn
from brevitas.nn.quant_layer import QuantWeightBiasInputOutputLayer
from torch import Tensor


class WrapperWBIOL(nn.Module):
"""Expose `inner_forward_impl` as a standalone submodule."""

def __init__(self, innerForwardImpl: nn.Module) -> None:
super().__init__()
self.innerForwardImpl = innerForwardImpl

def forward(
self, quantInput: Tensor, quantWeight: Tensor, quantBias: Tensor
) -> Tensor:
return self.innerForwardImpl(quantInput, quantWeight, quantBias)


def WBIOLForward(self: QuantWeightBiasInputOutputLayer, inp: Tensor) -> Tensor:
"""Quant-in → quant-weight/bias → matmul → quant-out."""
quantInput = self.input_quant(inp)
quantWeight = self.weight_quant(self.weight)

quantBias = None
if self.bias is not None:
quantBias = self.bias_quant(self.bias, quantInput, quantWeight)

output = self.wrappedInnerForwardImpl(quantInput, quantWeight, quantBias)
quantOutput = self.output_quant(output)
return quantOutput
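
And a minimal sketch for WBIOLForward, assuming QuantLinear as the concrete QuantWeightBiasInputOutputLayer, that the layer's inner_forward_impl is what WrapperWBIOL is meant to wrap, and the same assumed per-instance rebinding under the wrappedInnerForwardImpl attribute that WBIOLForward expects:

import types

import torch
from brevitas.nn import QuantLinear

from DeepQuant.CustomForwards.WBIOL import WrapperWBIOL, WBIOLForward

fc = QuantLinear(8, 4, bias=True)

# Expose the inner matmul as a leaf under the attribute WBIOLForward looks up
# (assumption: wrapping the bound inner_forward_impl of the Brevitas layer).
fc.wrappedInnerForwardImpl = WrapperWBIOL(fc.inner_forward_impl)
fc.forward = types.MethodType(WBIOLForward, fc)

out = fc(torch.randn(2, 8))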
110 changes: 0 additions & 110 deletions DeepQuant/CustomTracer.py

This file was deleted.
