-
Notifications
You must be signed in to change notification settings - Fork 29
Rmsnorm #136
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: devel
Are you sure you want to change the base?
Rmsnorm #136
Changes from all commits
9b74876
5d499d4
3499a64
30cfabb
360519c
fee8470
8f90620
4834c2a
e1785a8
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -24,6 +24,8 @@ package-lock.json | |
| .mypy_cache | ||
| node_modules | ||
|
|
||
| .venv/* | ||
|
|
||
| compile_commands.json | ||
|
|
||
| docs/_autosummary | ||
|
|
||
diaconuccalin marked this conversation as resolved.
Show resolved
Hide resolved
|
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,59 @@ | ||
| # SPDX-FileCopyrightText: 2025 ETH Zurich and University of Bologna | ||
| # | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
| from typing import Dict, List, Tuple | ||
|
|
||
| import numpy as np | ||
|
|
||
| from Deeploy.DeeployTypes import NetworkContext, NodeTemplate, OperatorRepresentation | ||
|
|
||
|
|
||
| class _PowTemplate(NodeTemplate): | ||
|
|
||
| def alignToContext(self, ctxt: NetworkContext, | ||
| operatorRepresentation: OperatorRepresentation) -> Tuple[NetworkContext, Dict, List[str]]: | ||
| # Get input and output tensors | ||
| data_in = ctxt.lookup(operatorRepresentation['data_in']) | ||
| exponent = ctxt.lookup(operatorRepresentation['exponent']) | ||
| data_out = ctxt.lookup(operatorRepresentation['data_out']) | ||
|
|
||
| # Get data type (fp32) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Sorry for not noticing this earlier, but please remove the hard-coded data type from the template (not only in this comment, but in the template as well). The conv template has a good example of how to make the data type specification dynamic - see line 19. Please change it for the sqrt template as well. |
||
| data_type = data_in._type.typeName | ||
| operatorRepresentation['data_type'] = data_type | ||
|
|
||
| # Get type width dynamically (e.g., 32, 64) | ||
| type_width = data_in._type.referencedType.typeWidth | ||
| operatorRepresentation['type_width'] = type_width | ||
|
|
||
| # Calculate size | ||
| input_size = int(np.prod(data_in.shape)) | ||
| exponent_size = int(np.prod(exponent.shape)) | ||
| operatorRepresentation['size'] = input_size | ||
|
|
||
| # Check if exponent is scalar (broadcasting) | ||
| if exponent_size == 1: | ||
| operatorRepresentation['is_scalar'] = True | ||
| # Get the full variable name with prefix | ||
| exponent_name = operatorRepresentation['exponent'] | ||
| operatorRepresentation['exponent_scalar'] = f"DeeployNetwork_{exponent_name}[0]" | ||
| else: | ||
| # The kernel currently only supports base and exponent data of equal size, | ||
| # so for non-scalar exponents the length of data_in must equal the exponent length. | ||
| if input_size != exponent_size: | ||
| raise ValueError(f"Pow operator mismatch: input size ({input_size}) " | ||
| f"must equal exponent size ({exponent_size}) for non-scalar exponents.") | ||
|
|
||
| operatorRepresentation['is_scalar'] = False | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Since the kernel currently only supports base and exponent data of equal size, let's add a size check here for the non-scalar case (the length of data_in should be equal to the exponent length). |
||
| operatorRepresentation['exponent_scalar'] = "NULL" | ||
|
|
||
| return ctxt, operatorRepresentation, [] | ||
|
|
||
|
|
||
| referenceTemplate = _PowTemplate(""" | ||
| // Pow (Name: ${nodeName}, Op: ${nodeOp}) | ||
| % if is_scalar: | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Great that you kept two versions, one for scalar and one for vector exponents. It's not full broadcasting yet, but it still gives us a little future-proofing. |
||
| Pow_fp${type_width}_scalar_fp${type_width}(${data_in}, ${exponent_scalar}, ${data_out}, ${size}); | ||
| % else: | ||
| Pow_fp${type_width}_fp${type_width}_fp${type_width}(${data_in}, ${exponent}, ${data_out}, ${size}); | ||
| % endif | ||
| """) | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,35 @@ | ||
| # SPDX-FileCopyrightText: 2025 ETH Zurich and University of Bologna | ||
| # | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
| from typing import Dict, List, Tuple | ||
|
|
||
| import numpy as np | ||
|
|
||
| from Deeploy.DeeployTypes import NetworkContext, NodeTemplate, OperatorRepresentation | ||
|
|
||
|
|
||
| class _SqrtTemplate(NodeTemplate): | ||
|
|
||
| def alignToContext(self, ctxt: NetworkContext, | ||
| operatorRepresentation: OperatorRepresentation) -> Tuple[NetworkContext, Dict, List[str]]: | ||
| # Get input and output tensors | ||
| data_in = ctxt.lookup(operatorRepresentation['data_in']) | ||
| data_out = ctxt.lookup(operatorRepresentation['data_out']) | ||
|
|
||
| # Get data type (fp32) | ||
| data_type = data_in._type.typeName | ||
| operatorRepresentation['data_type'] = data_type | ||
|
|
||
| type_width = data_in._type.referencedType.typeWidth | ||
| operatorRepresentation['type_width'] = type_width | ||
|
|
||
| # Calculate size | ||
| operatorRepresentation['size'] = int(np.prod(data_in.shape)) | ||
|
|
||
| return ctxt, operatorRepresentation, [] | ||
|
|
||
|
|
||
| referenceTemplate = _SqrtTemplate(""" | ||
| // Sqrt (Name: ${nodeName}, Op: ${nodeOp}) | ||
| Sqrt_fp${type_width}_fp${type_width}(${data_in}, ${data_out}, ${size}); | ||
| """) |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,23 @@ | ||
| deeploy_test_generator:� | ||
| 3 | ||
| data_in | ||
| exponentdata_outPow_Vector_Test"Powtest_float_pow_vectorZ! | ||
| data_in | ||
| Z" | ||
| exponent | ||
| b" | ||
| data_out | ||
| B |
diaconuccalin marked this conversation as resolved.
Show resolved
Hide resolved
|
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,24 @@ | ||
| /* | ||
| * SPDX-FileCopyrightText: 2025 ETH Zurich and University of Bologna | ||
| * | ||
| * SPDX-License-Identifier: Apache-2.0 | ||
| */ | ||
|
|
||
| /* | ||
| * This file declares the kernels for the element-wise binary power operation. | ||
| */ | ||
|
|
||
| #ifndef __DEEPLOY_MATH_POW_KERNEL_HEADER_ | ||
| #define __DEEPLOY_MATH_POW_KERNEL_HEADER_ | ||
|
|
||
| #include "DeeployBasicMath.h" | ||
|
|
||
| void Pow_fp32_fp32_fp32(const float32_t *__restrict__ data_in, | ||
| const float32_t *__restrict__ exponent, | ||
| float32_t *__restrict__ data_out, int32_t size); | ||
|
|
||
| void Pow_fp32_scalar_fp32(const float32_t *__restrict__ data_in, | ||
| float32_t exponent, float32_t *__restrict__ data_out, | ||
| int32_t size); | ||
|
|
||
| #endif |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,22 @@ | ||
| /* | ||
| * SPDX-FileCopyrightText: 2020 ETH Zurich and University of Bologna | ||
| * | ||
| * SPDX-License-Identifier: Apache-2.0 | ||
| */ | ||
|
|
||
| #ifndef __DEEPLOY_BASIC_MATH_SQRT_KERNEL_HEADER_ | ||
| #define __DEEPLOY_BASIC_MATH_SQRT_KERNEL_HEADER_ | ||
|
|
||
| #include "DeeployBasicMath.h" | ||
|
|
||
| /* | ||
| * Square root operation - computes sqrt for each element | ||
| */ | ||
|
|
||
| /******************************************************************************/ | ||
| /* Sqrt */ | ||
| /******************************************************************************/ | ||
|
|
||
| void Sqrt_fp32_fp32(float32_t *data_in, float32_t *data_out, int32_t size); | ||
|
|
||
| #endif //__DEEPLOY_BASIC_MATH_SQRT_KERNEL_HEADER_ |
Uh oh!
There was an error while loading. Please reload this page.