Closed
Description
Hi, I'm using aimet_torch version 2.7.0.
While performing QAT on sim.model, I encountered cases where min > max, which caused the computed scale to become negative.
To work around this, I performed only PTQ and saved the state_dict of sim.model after compute_encodings. I then modified the internal aimet_torch code as shown below and reloaded the model for QAT. After this change, the issue no longer occurred.
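For context, a minimal sketch (not aimet code; `affine_scale` is a hypothetical helper) of why min > max is a problem: with asymmetric affine quantization the scale spans the learned [min, max] range, so it flips sign as soon as min drifts above max.

```python
def affine_scale(qmin, qmax, min_val, max_val):
    # asymmetric affine quantization: the scale maps the real range
    # [min_val, max_val] onto the integer grid [qmin, qmax]
    return (max_val - min_val) / (qmax - qmin)

print(affine_scale(0, 255, -1.0, 1.0))  # healthy encoding: positive scale
print(affine_scale(0, 255, 0.5, 0.2))   # min > max: scale goes negative
```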
- In `aimet_torch/v2/quantization/affine/quantizer.py` -> `AffineQuantizerBase`:

```python
...
def _is_min_max_quantizer(self):
    # return "min" in self._parameters and "max" in self._parameters
    return "min" in self._parameters

@property
def max(self):
    return self.min + torch.nn.functional.softplus(self.delta)

@max.setter
def max(self, value):
    softplus_delta = value - self.min
    self.delta.data = softplus_delta + torch.log(-torch.expm1(-softplus_delta))
...

def _reparametrize_to_min_max(self):
    # pylint: disable=attribute-defined-outside-init
    if self._is_min_max_quantizer():
        return
    is_initialized = self.is_initialized()
    self.register_quantization_parameter(
        "min", nn.Parameter(-torch.ones(self.shape))
    )
    self.register_quantization_parameter(
        "delta", nn.Parameter(torch.ones(self.shape))
    )
    # self.register_quantization_parameter(
    #     "max", nn.Parameter(torch.ones(self.shape))
    # )
...
```
- Load PTQ `sim.model`:
```python
# create dummy sim instance
...
model_state_dict = torch.load(model_path, map_location=device)
sim.model.load_state_dict(model_state_dict, strict=False)

# "max" is now a property rather than a Parameter, so restore each saved
# ".max" entry through its setter, which rederives delta
for key, value in model_state_dict.items():
    if key.endswith(".max"):
        parts = key.split(".")
        target = sim.model
        for p in parts[:-1]:
            target = getattr(target, p)
        setattr(target, "max", value)
...
# perform QAT
```
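The dotted-key traversal above can be checked in isolation. Below is a self-contained sketch with a hypothetical `FakeQuantizer` (not the real aimet class) showing that `load_state_dict(strict=False)` skips the property-backed `max`, while `setattr` routes the saved tensor through the setter and rederives `delta`:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class FakeQuantizer(nn.Module):
    # stand-in for the modified quantizer: "max" is a property, not a
    # Parameter, so it must be restored via setattr
    def __init__(self):
        super().__init__()
        self.min = nn.Parameter(-torch.ones(1))
        self.delta = nn.Parameter(torch.ones(1))

    @property
    def max(self):
        return self.min + F.softplus(self.delta)

    @max.setter
    def max(self, value):
        with torch.no_grad():
            y = value - self.min
            self.delta.data = y + torch.log(-torch.expm1(-y))

model = nn.ModuleDict({"q": FakeQuantizer()})
state = {"q.max": torch.tensor([0.8])}

# same traversal as in the loading snippet
for key, value in state.items():
    if key.endswith(".max"):
        parts = key.split(".")
        target = model
        for p in parts[:-1]:
            target = getattr(target, p)
        setattr(target, "max", value)

assert torch.allclose(model["q"].max, torch.tensor([0.8]))
```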
Is there any official way to ensure min < max during QAT without modifying aimet_torch like this?
If not, are there any plans to handle this case in future versions?
Thanks!
bwkim71