Rmsnorm #136
base: devel
Conversation
📝 Walkthrough

Summary by CodeRabbit
The pull request introduces Pow and Sqrt operation support to the generic platform. Changes include adding new parser classes, layer definitions, bindings, templates with kernel selection logic, and C kernel implementations for float32 operations. CI tests for these new operations are also added.
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25 minutes
Pre-merge checks and finishing touches: ❌ Failed checks (2 warnings), ✅ Passed checks (1 passed)
Actionable comments posted: 3
🧹 Nitpick comments (6)
Deeploy/DeeployTypes.py (1)
327-349: Docstring matches behavior; minor cleanup possible in visited set
The updated "live ancestors" wording matches the breadth-first walk over the alias graph and better explains what's being checked. One small implementation nit: `visited = set(self.name)` builds a set of characters rather than a set of buffer names; using `{self.name}` would make the intent clearer and avoid mixing types in `visited`, even though it doesn't currently break correctness.

TargetLibraries/Generic/src/Sqrt_fp32.c (1)
1-13: Elementwise fp32 sqrt kernel looks correct
The `Sqrt_fp32_fp32` implementation is straightforward and type-consistent with `float32_t`/`int32_t`, doing an elementwise `sqrtf` over the input range. Assuming `sqrtf` is declared via the transitive includes from `DeeployBasicMath.h`, there are no correctness issues here.

TargetLibraries/Generic/src/Pow_fp16.c (1)
1-26: Pow_fp16 implementation is correct for integer exponents; consider faster exponentiation
The kernel correctly handles zero and negative integer exponents and writes elementwise `base^exponent` into `data_out`. For typical small exponents this is fine, but the linear `for (j = 0; j < exp; j++)` loop makes runtime proportional to |exponent|. If you expect larger exponents or care about worst-case latency, consider switching to exponentiation-by-squaring on a promoted `float` accumulator for better performance and numerical behavior, while preserving the `float16_t` I/O interface.

Deeploy/Targets/Generic/Layers.py (1)
230-240: PowLayer/SqrtLayer wiring is minimal and consistent with existing layers
The new `PowLayer` and `SqrtLayer` classes correctly follow the existing pattern of thin `ONNXLayer` wrappers around mappers. For current usage this is sufficient. If accurate op-count reporting or explicit broadcasting for Pow becomes important, you may later want to override `computeOps` (e.g., proportional to tensor size) and, if needed, `computeShapes` similar to `AddLayer`/`MulLayer`.

Deeploy/Targets/Generic/Parsers.py (1)
1967-2001: Duplicate PowParser/SqrtParser definitions and mismatched exponent field
There are two separate definitions of `PowParser` and `SqrtParser` in this file: one here and another at lines 2813–2869. The later definitions override these ones at import time, so this block is effectively dead code and also:
- Triggers lints (`PowParser`/`SqrtParser` redefinition, undefined `ConstantBuffer` on Line 1990).
- Uses `exponent_value` instead of `exponent`, which doesn't match `FloatPowTemplate.alignToContext`, where `exponent` is expected and `exponent_value` is derived there.

To avoid confusion and static-analysis noise, I'd consolidate to a single implementation (the newer one) and delete this earlier block entirely. A minimal fix would look like:

```diff
-class PowParser(NodeParser):
-    ...
-
-
-class SqrtParser(NodeParser):
-    ...
```

leaving only the final `PowParser`/`SqrtParser` definitions at the bottom of the file.
Also applies to: 2003-2023
Deeploy/Targets/Generic/Templates/FloatSqrtTemplate.py (1)
14-28: FloatSqrtTemplate matches kernels; consider removing unused data_out
The template and `alignToContext` correctly:
- Infer `data_type` from `data_in`, and
- Dispatch to `Sqrt_fp32_fp32`/`Sqrt_fp16_fp16` with the right arguments.

The only nit is that `data_out = ctxt.lookup(operatorRepresentation['data_out'])` is never used in `alignToContext`; you can safely drop that line to quiet Ruff and keep the function minimal.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (16)
- .github/workflows/ci-platform-generic.yml (1 hunks)
- Deeploy/DeeployTypes.py (3 hunks)
- Deeploy/Targets/Generic/Bindings.py (2 hunks)
- Deeploy/Targets/Generic/Layers.py (1 hunks)
- Deeploy/Targets/Generic/Parsers.py (2 hunks)
- Deeploy/Targets/Generic/Platform.py (3 hunks)
- Deeploy/Targets/Generic/Templates/FloatPowTemplate.py (1 hunks)
- Deeploy/Targets/Generic/Templates/FloatSqrtTemplate.py (1 hunks)
- TargetLibraries/Generic/inc/DeeployBasicMath.h (1 hunks)
- TargetLibraries/Generic/inc/kernel/Pow.h (1 hunks)
- TargetLibraries/Generic/inc/kernel/Sqrt.h (1 hunks)
- TargetLibraries/Generic/inc/types.h (1 hunks)
- TargetLibraries/Generic/src/Pow_fp16.c (1 hunks)
- TargetLibraries/Generic/src/Pow_fp32.c (1 hunks)
- TargetLibraries/Generic/src/Sqrt_fp16.c (1 hunks)
- TargetLibraries/Generic/src/Sqrt_fp32.c (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (7)
TargetLibraries/Generic/inc/kernel/Sqrt.h (3)
TargetLibraries/Generic/src/Sqrt_fp32.c (1)
- Sqrt_fp32_fp32 (9-13)

DeeployTest/testUtils/dmaUtils.py (1)
- size (72-73)

TargetLibraries/Generic/src/Sqrt_fp16.c (1)
- Sqrt_fp16_fp16 (9-13)
TargetLibraries/Generic/inc/kernel/Pow.h (2)
TargetLibraries/Generic/src/Pow_fp32.c (1)
- Pow_fp32_int32_fp32 (9-27)

TargetLibraries/Generic/src/Pow_fp16.c (1)
- Pow_fp16_int32_fp16 (8-26)
Deeploy/Targets/Generic/Layers.py (1)
Deeploy/DeeployTypes.py (2)
- ONNXLayer (1819-2147)
- NodeMapper (1660-1816)
Deeploy/Targets/Generic/Parsers.py (1)
Deeploy/Targets/Snitch/Parsers.py (3)
- parseNode (15-26)
- parseNodeCtxt (28-42)
- parseNodeCtxt (60-74)
Deeploy/Targets/Generic/Templates/FloatPowTemplate.py (3)
Deeploy/DeeployTypes.py (2)
- NetworkContext (508-1020)
- NodeTemplate (87-229)

Deeploy/Targets/Generic/Templates/FloatSqrtTemplate.py (1)
- alignToContext (14-28)

Deeploy/AbstractDataTypes.py (1)
- typeName (312-313)
Deeploy/Targets/Generic/Platform.py (3)
Deeploy/Targets/Generic/Layers.py (2)
- PowLayer (230-233)
- SqrtLayer (236-239)

Deeploy/Targets/Generic/Parsers.py (4)
- PowParser (1967-2000)
- PowParser (2814-2846)
- SqrtParser (2003-2023)
- SqrtParser (2849-2869)

Deeploy/DeeployTypes.py (1)
- NodeMapper (1660-1816)
Deeploy/Targets/Generic/Bindings.py (2)
Deeploy/CommonExtensions/DataTypes.py (2)
- float16_t (67-71)
- float32_t (74-78)

Deeploy/DeeployTypes.py (2)
- CodeTransformation (2290-2324)
- NodeBinding (1512-1657)
🪛 Ruff (0.14.5)
Deeploy/Targets/Generic/Parsers.py
1978-1978: Unused method argument: channels_first
(ARG002)
1990-1990: Undefined name ConstantBuffer
(F821)
1995-1996: Prefer TypeError exception for invalid type
(TRY004)
1995-1996: Avoid specifying long messages outside the exception class
(TRY003)
2014-2014: Unused method argument: channels_first
(ARG002)
2814-2814: Redefinition of unused PowParser from line 1967
(F811)
2825-2825: Unused method argument: channels_first
(ARG002)
2849-2849: Redefinition of unused SqrtParser from line 2003
(F811)
2860-2860: Unused method argument: channels_first
(ARG002)
Deeploy/Targets/Generic/Templates/FloatSqrtTemplate.py
19-19: Local variable data_out is assigned to but never used
Remove assignment to unused variable data_out
(F841)
Deeploy/Targets/Generic/Templates/FloatPowTemplate.py
19-19: Local variable data_out is assigned to but never used
Remove assignment to unused variable data_out
(F841)
30-30: Avoid specifying long messages outside the exception class
(TRY003)
🔇 Additional comments (9)
.github/workflows/ci-platform-generic.yml (1)
76-81: New generic-kernels tests are wired correctly in CI list
The added Pow/Sqrt/RMSNorm tests fit the existing naming pattern and placement in the float test block; no workflow logic changes. Please just confirm the test targets are defined under those exact names so CI doesn't fail on a typo.
Deeploy/DeeployTypes.py (1)
2565-2589: _mapNode refactor cleanly decouples engine selection from layer construction
The new `_mapNode` that iterates engines and calls `engine.Mapping[node.op](node)` looks correct and keeps `NetworkContainer._bindLayers` simpler by returning either an `ONNXLayer` or side-effect-only mappings (e.g., Constant hoisting). The `isinstance(layer, ONNXLayer)` guard in `_bindLayers` ensures only real layers are scheduled, while Constant mappings still run via their lambda. This preserves multi-engine support and fails fast with a clear `RuntimeError` when no mapping exists.

TargetLibraries/Generic/inc/DeeployBasicMath.h (1)
47-55: Pow/Sqrt kernel includes are correctly integrated
Including `kernel/Pow.h` and `kernel/Sqrt.h` alongside the other basic math kernels is consistent and sufficient to expose the new operations to users of `DeeployBasicMath.h`.

TargetLibraries/Generic/src/Sqrt_fp16.c (1)
1-13: fp16 sqrt kernel is consistent with the fp32 path
`Sqrt_fp16_fp16` mirrors the fp32 implementation, applying `sqrtf` elementwise and relying on the `float16_t` typedef for the actual storage type. This is a reasonable, simple implementation for FP16 support and aligns with the new type definition.

TargetLibraries/Generic/inc/types.h (1)
13-21: float16_t typedef is sensible and keeps non-FP16 platforms building
Defining `float16_t` as `_Float16` when compiler support is detected, and otherwise aliasing it to `float`, gives the new Pow/Sqrt FP16 kernels a consistent type while preserving buildability on targets without native FP16. The surrounding comments clearly document this fallback behavior.
TargetLibraries/Generic/inc/kernel/Sqrt.h (1)

1-24: Sqrt kernel header matches implementations
The include guard, DeeployBasicMath dependency, and fp32/fp16 prototypes are consistent with the corresponding C kernels; no issues from a correctness or integration perspective.
TargetLibraries/Generic/inc/kernel/Pow.h (1)
1-25: Pow kernel header is consistent with C implementations
Prototypes and guard are well-formed and match the Pow_fp32/Pow_fp16 C kernels; nothing blocking here.
Deeploy/Targets/Generic/Bindings.py (1)
10-11: Pow/Sqrt bindings are wired consistently with templates and types
The new
`BasicPowBindings`/`BasicSqrtBindings` correctly:
- Use
`float32_t`/`float16_t` pointer types for inputs/outputs.
- Bind to
`FloatPowTemplate.referenceTemplate` and `FloatSqrtTemplate.referenceTemplate`.
- Reuse
`DummyChecker` and `BasicTransformer` in line with nearby float ops.

Once the Pow parser/template exponent checks are tightened as discussed, these bindings look sound.
Also applies to: 18-22, 121-133
Deeploy/Targets/Generic/Platform.py (1)
10-17: Pow/Sqrt integration into Generic platform is coherent
The new imports,
`PowMapper`/`SqrtMapper` definitions, and `'Pow'`/`'Sqrt'` entries in `GenericMapping` line up correctly with:
- `BasicPowBindings`/`BasicSqrtBindings`, `PowLayer`/`SqrtLayer`, and
- The Pow/Sqrt kernels exposed via
`DeeployBasicMath.h`.

Assuming
`DeeployBasicMath.h` now includes the new `kernel/Pow.h` and `kernel/Sqrt.h`, the end-to-end wiring looks correct. Please double-check that
`DeeployBasicMath.h` actually includes the new Pow/Sqrt kernel headers so generated code has the necessary prototypes.
Also applies to: 20-22, 27-29, 56-57, 104-105
Deeploy/Targets/Generic/Parsers.py
Outdated
```python
############################


class PowParser(NodeParser):

    def __init__(self):
        super().__init__()

    def parseNode(self, node: gs.Node) -> bool:
        return node.op == 'Pow' and len(node.inputs) == 2 and len(node.outputs) == 1

    def parseNodeCtxt(self,
                      ctxt: NetworkContext,
                      node: gs.Node,
                      channels_first: bool = True) -> Tuple[NetworkContext, bool]:

        data_in = ctxt.lookup(node.inputs[0].name)
        exponent = node.inputs[1]
        data_out = ctxt.lookup(node.outputs[0].name)

        self.operatorRepresentation['data_in'] = data_in.name
        self.operatorRepresentation['data_out'] = data_out.name

        # Check if exponent is a constant
        if isinstance(exponent, gs.Constant):
            exp_value = float(exponent.values)
            self.operatorRepresentation['exponent'] = exp_value
            self.operatorRepresentation['is_constant_exp'] = True
        else:
            exp_tensor = ctxt.lookup(exponent.name)
            self.operatorRepresentation['exponent'] = exp_tensor.name
            self.operatorRepresentation['is_constant_exp'] = False

        self.operatorRepresentation['size'] = int(np.prod(data_in.shape))

        return ctxt, True


class SqrtParser(NodeParser):

    def __init__(self):
        super().__init__()

    def parseNode(self, node: gs.Node) -> bool:
        return node.op == 'Sqrt' and len(node.inputs) == 1 and len(node.outputs) == 1

    def parseNodeCtxt(self,
                      ctxt: NetworkContext,
                      node: gs.Node,
                      channels_first: bool = True) -> Tuple[NetworkContext, bool]:

        data_in = ctxt.lookup(node.inputs[0].name)
        data_out = ctxt.lookup(node.outputs[0].name)

        self.operatorRepresentation['data_in'] = data_in.name
        self.operatorRepresentation['data_out'] = data_out.name
        self.operatorRepresentation['size'] = int(np.prod(data_in.shape))

        return ctxt, True
```
🛠️ Refactor suggestion | 🟠 Major
Tighten PowParser/SqrtParser to the supported “scalar constant integer exponent” subset
With the current wiring:
- `PowParser.parseNodeCtxt` stores a float exponent for `gs.Constant` inputs, or a tensor name otherwise.
- `FloatPowTemplate.alignToContext` then casts that float to `int`, only rejecting tensor exponents if the value is a string.
Given that the C kernels (`Pow_fp32_int32_fp32`/`Pow_fp16_int32_fp16`) support only integer exponents, any non-integer constant exponent will be silently truncated before codegen, which is a functional divergence from general ONNX Pow.
I’d recommend tightening PowParser here to enforce what the backend actually supports:
- Require the exponent input to be a scalar `gs.Constant`.
- Extract its scalar value and store it as a Python number in `operatorRepresentation['exponent']`.
- Reject any tensor/broadcast exponents up front.
For example:
- data_in = ctxt.lookup(node.inputs[0].name)
- exponent = node.inputs[1]
+ data_in = ctxt.lookup(node.inputs[0].name)
+ exponent = node.inputs[1]
data_out = ctxt.lookup(node.outputs[0].name)
@@
- self.operatorRepresentation['data_in'] = data_in.name
- self.operatorRepresentation['data_out'] = data_out.name
-
- # Check if exponent is a constant
- if isinstance(exponent, gs.Constant):
- exp_value = float(exponent.values)
- self.operatorRepresentation['exponent'] = exp_value
- self.operatorRepresentation['is_constant_exp'] = True
- else:
- exp_tensor = ctxt.lookup(exponent.name)
- self.operatorRepresentation['exponent'] = exp_tensor.name
- self.operatorRepresentation['is_constant_exp'] = False
+ self.operatorRepresentation['data_in'] = data_in.name
+ self.operatorRepresentation['data_out'] = data_out.name
+
+ # Only scalar constant exponents are supported
+ if not isinstance(exponent, gs.Constant):
+ raise TypeError(f"Pow: exponent input for node {node.name} must be a scalar Constant")
+
+ exp_vals = np.asarray(exponent.values).reshape(-1)
+ if exp_vals.size != 1:
+ raise ValueError(f"Pow: only scalar exponents are supported, got shape {exp_vals.shape}")
+
+ self.operatorRepresentation['exponent'] = float(exp_vals[0])
@@
- self.operatorRepresentation['size'] = int(np.prod(data_in.shape))
-
- return ctxt, True
+ self.operatorRepresentation['size'] = int(np.prod(data_in.shape))
+ return ctxt, True

You can keep `SqrtParser` here and drop the earlier duplicate, as mentioned in the previous comment.
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```python
############################


class PowParser(NodeParser):

    def __init__(self):
        super().__init__()

    def parseNode(self, node: gs.Node) -> bool:
        return node.op == 'Pow' and len(node.inputs) == 2 and len(node.outputs) == 1

    def parseNodeCtxt(self,
                      ctxt: NetworkContext,
                      node: gs.Node,
                      channels_first: bool = True) -> Tuple[NetworkContext, bool]:

        data_in = ctxt.lookup(node.inputs[0].name)
        exponent = node.inputs[1]
        data_out = ctxt.lookup(node.outputs[0].name)

        self.operatorRepresentation['data_in'] = data_in.name
        self.operatorRepresentation['data_out'] = data_out.name

        # Only scalar constant exponents are supported
        if not isinstance(exponent, gs.Constant):
            raise TypeError(f"Pow: exponent input for node {node.name} must be a scalar Constant")

        exp_vals = np.asarray(exponent.values).reshape(-1)
        if exp_vals.size != 1:
            raise ValueError(f"Pow: only scalar exponents are supported, got shape {exp_vals.shape}")

        self.operatorRepresentation['exponent'] = float(exp_vals[0])
        self.operatorRepresentation['size'] = int(np.prod(data_in.shape))
        return ctxt, True


class SqrtParser(NodeParser):

    def __init__(self):
        super().__init__()

    def parseNode(self, node: gs.Node) -> bool:
        return node.op == 'Sqrt' and len(node.inputs) == 1 and len(node.outputs) == 1

    def parseNodeCtxt(self,
                      ctxt: NetworkContext,
                      node: gs.Node,
                      channels_first: bool = True) -> Tuple[NetworkContext, bool]:

        data_in = ctxt.lookup(node.inputs[0].name)
        data_out = ctxt.lookup(node.outputs[0].name)

        self.operatorRepresentation['data_in'] = data_in.name
        self.operatorRepresentation['data_out'] = data_out.name
        self.operatorRepresentation['size'] = int(np.prod(data_in.shape))

        return ctxt, True
```
🧰 Tools
🪛 Ruff (0.14.5)
2814-2814: Redefinition of unused PowParser from line 1967
(F811)
2825-2825: Unused method argument: channels_first
(ARG002)
2849-2849: Redefinition of unused SqrtParser from line 2003
(F811)
2860-2860: Unused method argument: channels_first
(ARG002)
```python
    def alignToContext(self, ctxt: NetworkContext,
                       operatorRepresentation: OperatorRepresentation) -> Tuple[NetworkContext, Dict, List[str]]:

        # Get input and output tensors
        data_in = ctxt.lookup(operatorRepresentation['data_in'])
        data_out = ctxt.lookup(operatorRepresentation['data_out'])

        # Get data type (fp32 or fp16)
        data_type = data_in._type.typeName
        operatorRepresentation['data_type'] = data_type

        # Exponent must be a constant integer
        if 'exponent' in operatorRepresentation:
            exponent_input = operatorRepresentation['exponent']
            if isinstance(exponent_input, str):
                # It's a tensor name - not supported for integer exponent version
                raise ValueError("Tensor exponent not supported. Use constant integer exponent.")
            else:
                # Convert to integer
                operatorRepresentation['exponent_value'] = int(exponent_input)

        # Calculate size
        operatorRepresentation['size'] = int(np.prod(data_in.shape))

        return ctxt, operatorRepresentation, []
```
Enforce scalar integer exponent and remove unused local in FloatPowTemplate
Two things here:
- Exponent semantics / silent truncation
alignToContext currently does:
- Accepts any numeric `operatorRepresentation['exponent']`.
- Casts it directly to `int` without checking integer-ness.
- Rejects only string exponents (tensors) at runtime.
Combined with the C kernels’ int32_t exponent parameter, this means a constant exponent of 2.7 will be silently interpreted as 2, which is a correctness bug for general Pow.
To make the limitation explicit and safe, I’d strongly suggest:
- Requiring `exponent` to be present and non-string (i.e., constant, not a tensor).
- Validating that it's integer-valued before casting.
- Raising a clear error otherwise.
For example:
- # Exponent must be a constant integer
- if 'exponent' in operatorRepresentation:
- exponent_input = operatorRepresentation['exponent']
- if isinstance(exponent_input, str):
- # It's a tensor name - not supported for integer exponent version
- raise ValueError("Tensor exponent not supported. Use constant integer exponent.")
- else:
- # Convert to integer
- operatorRepresentation['exponent_value'] = int(exponent_input)
+ # Exponent must be a scalar constant integer
+ exponent_input = operatorRepresentation.get('exponent')
+ if exponent_input is None:
+ raise ValueError("Pow: missing 'exponent' in operatorRepresentation")
+ if isinstance(exponent_input, str):
+ # Tensor exponents are not supported by the integer-exponent kernel
+ raise TypeError("Pow: tensor exponents are not supported; exponent must be a scalar constant")
+
+ exp_float = float(exponent_input)
+ if not exp_float.is_integer():
+ raise ValueError(f"Pow: only integer exponents are supported, got {exponent_input!r}")
+
+ operatorRepresentation['exponent_value'] = int(exp_float)

This aligns codegen behavior with the actual kernel capabilities and avoids silent truncation.
- Minor: unused `data_out` local
`data_out = ctxt.lookup(operatorRepresentation['data_out'])` is never used. You can safely drop that line to satisfy Ruff and keep `alignToContext` minimal.
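The validation logic suggested above can also be exercised standalone; this sketch (the helper name is hypothetical) mirrors the checks without any Deeploy imports:

```python
def validate_scalar_int_exponent(exponent) -> int:
    """Accept only a numeric, integer-valued exponent; reject tensor names."""
    if exponent is None:
        raise ValueError("Pow: missing 'exponent' in operatorRepresentation")
    if isinstance(exponent, str):
        # A string would be a tensor name, which the int-exponent kernel can't handle.
        raise TypeError("Pow: tensor exponents are not supported; "
                        "exponent must be a scalar constant")
    exp_float = float(exponent)
    if not exp_float.is_integer():
        raise ValueError(f"Pow: only integer exponents are supported, got {exponent!r}")
    return int(exp_float)


print(validate_scalar_int_exponent(2.0))  # 2
try:
    validate_scalar_int_exponent(2.7)
except ValueError as err:
    print("rejected:", err)
```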
Committable suggestion skipped: line range outside the PR's diff.
🧰 Tools
🪛 Ruff (0.14.5)
19-19: Local variable data_out is assigned to but never used
Remove assignment to unused variable data_out
(F841)
30-30: Avoid specifying long messages outside the exception class
(TRY003)
🤖 Prompt for AI Agents
In Deeploy/Targets/Generic/Templates/FloatPowTemplate.py around lines 14–38,
alignToContext currently looks up an unused data_out and silently truncates
non-integer numeric exponents; remove the unused data_out = ctxt.lookup(...)
line, require that 'exponent' exists and is not a tensor (string), validate that
the provided constant exponent is integer-valued (e.g., reject floats that are
not whole numbers) before casting to int, store the integer in
operatorRepresentation['exponent_value'], and raise a clear ValueError if the
exponent is missing, a tensor, or not integer-valued.
```c
#include "DeeployBasicMath.h"

void Pow_fp32_int32_fp32(float32_t *data_in, int32_t exponent,
                         float32_t *data_out, int32_t size) {
  for (int i = 0; i < size; i++) {
    float32_t result = 1.0f;
    int32_t exp = exponent;
    float32_t base = data_in[i];

    if (exp < 0) {
      base = 1.0f / base;
      exp = -exp;
    }

    for (int32_t j = 0; j < exp; j++) {
      result *= base;
    }

    data_out[i] = result;
  }
}
```
Kernel assumes integer scalar exponents; ensure front‑end enforces this
This implementation is correct for integer exponents (including negatives), but it fundamentally implements base**int(exponent) via repeated multiplication. When wired through the current Python template, any non-integer constant exponent will be truncated to an integer before reaching this kernel, silently diverging from ONNX Pow semantics.
Please make sure the Python side (parser/template) explicitly:
- Restricts Pow to scalar, constant, integer exponents, and
- Raises a clear error if the exponent is non-integer or a tensor,
so users don’t get silent wrong results for general Pow nodes.
🤖 Prompt for AI Agents
In TargetLibraries/Generic/src/Pow_fp32.c around lines 7 to 27, the C kernel
implements power by repeated multiplication and thus only supports integer
scalar exponents; update the Python front-end (parser/template) to validate Pow
nodes before emitting this kernel: require the exponent to be a scalar constant
integer (not a tensor and not a non-integer), and if that check fails raise a
clear, descriptive exception (e.g., "Pow operator only supported for scalar
integer constant exponents in this backend") so users are not silently given
truncated/non-ONNX-compliant results.
The current implementation does not support true FP16 arithmetic. Instead, data is cast to FP32 internally. These tests are being removed to avoid misleading results until native half-precision support is implemented.
diaconuccalin left a comment
Good job overall, but there are some changes that we need to address.
Most of them concern the following:
- Remove integer exponent enforcement for Pow
- Remove constant exponent enforcement for Pow
- Remove all traces of FP16 version for Generic, since the compiler for this platform doesn't support this format (as we talked privately, we will use it directly in Snitch, since here it would only help us create the proper infrastructure, like binding and parser, but we've already done it with FP32)
You have 2 versions each of PowParser and SqrtParser. Please only keep one for each. I see the Sqrt ones are identical, and for the Pow operation, the first one (above the #### line) looks cleaner
Deeploy/Targets/Generic/Parsers.py
Outdated
```python
        return ctxt, False


############################
```
Remove this comment line as well
All the changes on this page should be reverted, it looks like a leftover from a rebase
```c
void Pow_fp32_int32_fp32(float32_t *data_in, int32_t exponent,
                         float32_t *data_out, int32_t size);

void Pow_fp16_int32_fp16(float16_t *data_in, int32_t exponent,
```
As we discussed, I think it's ok to remove the fp16 version from the generic platform, since there is no compiler support for this data format.
```c
void Sqrt_fp32_fp32(float32_t *data_in, float32_t *data_out, int32_t size);

void Sqrt_fp16_fp16(float16_t *data_in, float16_t *data_out, int32_t size);
```
Same here, we can remove the fp16 for generic because of the lack of compiler support.
```python
        operatorRepresentation['data_type'] = data_type

        # Exponent must be a constant integer
        if 'exponent' in operatorRepresentation:
```
At this point, we should be certain that the exponent is in the opRep, no need for this check.
```python
        self.operatorRepresentation['data_out'] = data_out.name

        # Extract exponent value from the constant tensor
        if isinstance(exponent_tensor, ConstantBuffer):
```
Same here, let's remove the constant enforcement
```c
#include "DeeployBasicMath.h"

void Pow_fp32_int32_fp32(float32_t *data_in, int32_t exponent,
```
For removing the enforcement of constant and int values, we will need to change the data type of `exponent` to `const float *__restrict__`.
```c
                         float32_t *data_out, int32_t size) {
  for (int i = 0; i < size; i++) {
    float32_t result = 1.0f;
    int32_t exp = exponent;
```
We will also have to update the kernel for float exponent support.
Please create a new test to better test the RMSNorm function. Right now, pretty much all values go to either -1 or 1 because the function normalizes on the last axis and we only have a single element on that one (the input dimension is 1024x1). Let's do a test with something like a 128x128 input, so it's not too big, but has more than one element on the last axis.
Let's also use a value a little more interesting than 1 for the weight, to check that operation as well.
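A quick numpy sketch of the reference computation such a test would check against; the epsilon and the 0.5 weight are illustrative values, not the test's actual parameters:

```python
import numpy as np


def rmsnorm(x: np.ndarray, weight: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    # Normalize over the last axis, then scale by the weight vector.
    rms = np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps)
    return (x / rms) * weight


rng = np.random.default_rng(0)
x = rng.standard_normal((128, 128)).astype(np.float32)
weight = np.full(128, 0.5, dtype=np.float32)  # something other than 1

y = rmsnorm(x, weight)
print(y.shape)  # (128, 128)
# With 128 elements on the last axis the outputs take many distinct values,
# unlike the degenerate 1024x1 case where every element collapses to +/- weight.
```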
This commit addresses code review feedback: - Refactor Pow kernel to use 'powf' from math.h to support floating-point exponents. - Update PowParser to allow tensor exponents instead of forcing constants. - Remove Generic FP16 support and revert types.h changes. - Remove duplicate PowParser/SqrtParser classes. - Enhance RMSNorm tests with larger shapes and non-trivial weights.
Actionable comments posted: 0
🧹 Nitpick comments (3)
Deeploy/Targets/Generic/Templates/FloatSqrtTemplate.py (1)
12-13: Remove unused variable lookup. The `data_out` variable is looked up but never used in `alignToContext`. Since the template only needs `data_in` to infer the data type and compute the size, you can safely remove this line.

Apply this diff:

```diff
 # Get input and output tensors
 data_in = ctxt.lookup(operatorRepresentation['data_in'])
-data_out = ctxt.lookup(operatorRepresentation['data_out'])

 # Get data type (fp32)
```

Deeploy/Targets/Generic/Templates/FloatPowTemplate.py (1)
12-14: Remove unused variable lookup. The `data_out` variable is looked up but never used in `alignToContext`. The method only needs `data_in` and `exponent` to infer the data type and compute sizes.

Apply this diff:

```diff
 # Get input and output tensors
 data_in = ctxt.lookup(operatorRepresentation['data_in'])
 exponent = ctxt.lookup(operatorRepresentation['exponent'])
-data_out = ctxt.lookup(operatorRepresentation['data_out'])

 # Get data type (fp32)
```

Deeploy/Targets/Generic/Bindings.py (1)
121-129: Consider more specific type checkers for Pow and Sqrt. The bindings use `DummyChecker`, which provides minimal type validation. While this may be intentional for flexibility, you might want to define dedicated `PowChecker` and `SqrtChecker` classes (similar to `AddChecker`, `MulChecker`, etc.) to provide more specific type validation for these operations. This can be deferred if the current approach aligns with the project's type-checking strategy. The bindings are otherwise correctly structured.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (8)
- `.github/workflows/ci-platform-generic.yml` (1 hunks)
- `Deeploy/Targets/Generic/Bindings.py` (2 hunks)
- `Deeploy/Targets/Generic/Parsers.py` (3 hunks)
- `Deeploy/Targets/Generic/Templates/FloatPowTemplate.py` (1 hunks)
- `Deeploy/Targets/Generic/Templates/FloatSqrtTemplate.py` (1 hunks)
- `TargetLibraries/Generic/inc/kernel/Pow.h` (1 hunks)
- `TargetLibraries/Generic/inc/kernel/Sqrt.h` (1 hunks)
- `TargetLibraries/Generic/src/Pow_fp32.c` (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- .github/workflows/ci-platform-generic.yml
🧰 Additional context used
🧬 Code graph analysis (5)
TargetLibraries/Generic/inc/kernel/Sqrt.h (2)
TargetLibraries/Generic/src/Sqrt_fp32.c (1)
- `Sqrt_fp32_fp32` (9-13)

DeeployTest/testUtils/dmaUtils.py (1)
- `size` (72-73)
Deeploy/Targets/Generic/Bindings.py (3)
Deeploy/CommonExtensions/DataTypes.py (1)
- `float32_t` (74-78)

Deeploy/DeeployTypes.py (2)
- `CodeTransformation` (2290-2324)
- `NodeBinding` (1512-1657)

Deeploy/AbstractDataTypes.py (1)
- `PointerClass` (536-559)
TargetLibraries/Generic/inc/kernel/Pow.h (2)
TargetLibraries/Generic/src/Pow_fp32.c (2)
- `Pow_fp32_fp32_fp32` (10-17)
- `Pow_fp32_scalar_fp32` (19-26)

DeeployTest/testUtils/dmaUtils.py (1)
- `size` (72-73)
Deeploy/Targets/Generic/Parsers.py (1)
Deeploy/DeeployTypes.py (7)
- `NetworkContext` (508-1020)
- `NodeParser` (1023-1198)
- `VariableBuffer` (232-360)
- `ConstantBuffer` (393-430)
- `parseNode` (1033-1048)
- `inputs` (2503-2520)
- `parseNodeCtxt` (1051-1076)
Deeploy/Targets/Generic/Templates/FloatSqrtTemplate.py (1)
Deeploy/Targets/Generic/Templates/FloatPowTemplate.py (1)
- `alignToContext` (9-34)
🪛 Ruff (0.14.6)
Deeploy/Targets/Generic/Templates/FloatPowTemplate.py
14-14: Local variable data_out is assigned to but never used
Remove assignment to unused variable data_out
(F841)
Deeploy/Targets/Generic/Parsers.py
1978-1978: Unused method argument: channels_first
(ARG002)
1995-1996: Prefer TypeError exception for invalid type
(TRY004)
1995-1996: Avoid specifying long messages outside the exception class
(TRY003)
2799-2799: Unused method argument: channels_first
(ARG002)
Deeploy/Targets/Generic/Templates/FloatSqrtTemplate.py
13-13: Local variable data_out is assigned to but never used
Remove assignment to unused variable data_out
(F841)
🔇 Additional comments (6)
TargetLibraries/Generic/inc/kernel/Sqrt.h (1)
20-20: LGTM! The function signature is correct for an element-wise square root operation. The naming convention follows the pattern seen in other kernels and the parameters are appropriate.
TargetLibraries/Generic/inc/kernel/Pow.h (1)
16-24: LGTM! Both function signatures correctly use `float32_t` for the exponent parameter(s), which allows the kernels to support general floating-point exponents via `powf`. The `const` and `restrict` qualifiers are appropriate.

TargetLibraries/Generic/src/Pow_fp32.c (1)
10-26: LGTM! Both kernel implementations correctly use `powf`, which supports general floating-point exponents. The array-based and scalar-based variants are implemented appropriately for broadcasting scenarios.

Deeploy/Targets/Generic/Parsers.py (2)
2788-2808: LGTM! The `SqrtParser` implementation is straightforward and correct for a unary square root operation. It properly extracts the input/output tensors and computes the size.

Note: The `channels_first` parameter is unused (flagged by static analysis), but this is likely required by the `NodeParser` interface.
1990-1996: Incorrect exponent handling: casting to int loses precision and enforcing constants limits functionality. There are two critical issues here:

Integer casting loses precision: Line 1991 casts the exponent to `int`, but the C kernels `Pow_fp32_fp32_fp32` and `Pow_fp32_scalar_fp32` use `powf`, which supports floating-point exponents. For example, an exponent of `2.5` would be silently truncated to `2`, producing incorrect results.

Constant enforcement is too restrictive: Lines 1994-1996 reject non-constant (variable tensor) exponents, but this unnecessarily limits the operator's functionality. Per past review feedback and the ONNX Pow specification, variable exponents should be supported.

Apply this diff to support float exponents and remove constant enforcement:

```diff
 # Extract exponent value from the constant tensor
 if isinstance(exponent_tensor, ConstantBuffer):
-    exp_value = int(exponent_tensor.values.flatten()[0])
-    self.operatorRepresentation['exponent_value'] = exp_value
-else:
-    # Tensor exponent not supported
-    raise ValueError(f"Node {node.name}: Exponent must be a constant. "
-                     f"Variable tensor exponents are not supported.")
+    exp_value = float(exponent_tensor.values.flatten()[0])
+    self.operatorRepresentation['exponent_value'] = exp_value
+# Variable tensor exponents are now supported via the array-based kernel
```

Based on learnings from past reviews requesting float exponent support and removal of constant enforcement.
Likely an incorrect or invalid review comment.
Deeploy/Targets/Generic/Templates/FloatPowTemplate.py (1)
25-34: LGTM! The scalar broadcasting logic is well-implemented. The template correctly distinguishes between scalar and array exponents, selecting the appropriate kernel (`Pow_fp32_scalar_fp32` vs `Pow_fp32_fp32_fp32`) and constructing the proper variable reference for scalar exponents.
This PR adds support for RMSNorm (Root Mean Square Normalization) operation to the Deeploy framework's Generic platform. RMSNorm is a critical normalization technique used in modern Transformer architectures and large language models. To enable RMSNorm deployment on embedded systems, this PR implements the necessary mathematical primitives (Pow and Sqrt operations) and integrates them into Deeploy's compilation pipeline.
The implementation follows Deeploy's operator decomposition approach, where RMSNorm is constructed from basic mathematical operations rather than as a monolithic kernel. This design provides flexibility and maintainability while supporting both float32 and float16 precision for resource-constrained embedded devices.
Added
Pow (Power) operation support
- `FloatPowTemplate.py`: Mako template for C code generation
- `Pow_fp32.c`: Kernel implementations for both precisions
- `kernel/Pow.h`: Kernel interface definitions

Sqrt (Square Root) operation support
- `FloatSqrtTemplate.py`: Mako template for C code generation
- `Sqrt_fp32.c`: Kernel implementations
- `kernel/Sqrt.h`: Kernel interface definitions

Comprehensive test suites
- `testFloatPow`: Pow operator tests with ONNX models and reference data
- `testFloatSqrt`: Sqrt operator tests
- `testFloatRMSNorm`: End-to-end RMSNorm tests demonstrating operator composition

Changed
Framework integration files
- `Deeploy/Targets/Generic/Parsers.py`: Added PowParser and SqrtParser for ONNX graph parsing
- `Deeploy/Targets/Generic/Layers.py`: Added corresponding Layer classes for both operations
- `Deeploy/Targets/Generic/Bindings.py`: Added type checking and binding registration
- `Deeploy/Targets/Generic/Platform.py`: Registered new operations in platform mapping

Runtime library headers
- `TargetLibraries/Generic/inc/DeeployBasicMath.h`: Extended with Pow and Sqrt function declarations
- `TargetLibraries/Generic/inc/types.h`: Updated type definitions for consistency

CI/CD configuration
- `.github/workflows/ci-platform-generic.yml`: Updated to include new test cases in automated testing pipeline

Fixed
PR Merge Checklist
- Your PR is rebased on the latest `devel` commit and pointing to `devel`.
- The `CHANGELOG.md` file has been updated.