-
Notifications
You must be signed in to change notification settings - Fork 109
Enhance symbolic arithmetics #2779
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
27f7b50 to
5ec8a66
Compare
kiya00
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thank you @shino16
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull request overview
This PR enhances symbolic arithmetic support by implementing rounding operations (ceil, floor, trunc, round), bitwise shift operations, and improving where to support all-NumberProxy inputs. The main motivation is to enable math.ceil support for clang.arange. The implementation includes proper type promotion handling to ensure rounding operations return integers for number-only inputs while preserving dtypes for tensor inputs.
Key changes:
- Added rounding and bitwise operations with proper type promotion semantics
- Extended prologue to test both
intandfloattypes of NumberProxy values via symbolic value caching - Changed
__pos__methods to returnselffor both NumberProxy and TensorProxy, aligning with Python/PyTorch behavior
Reviewed changes
Copilot reviewed 9 out of 9 changed files in this pull request and generated 1 comment.
Show a summary per file
| File | Description |
|---|---|
| thunder/tests/test_elementwise.py | Refactored tests to use parametrization for binary/unary ops and added symbolic value caching tests to verify type-based cache behavior |
| thunder/executors/pythonex.py | Registered rounding and bitwise operations with appropriate backing functions, added where implementation for all-NumberProxy inputs |
| thunder/executors/nvfuserex_impl.py | Added decorator to ensure rounding operations return int type for number-only inputs in nvFuser execution |
| thunder/core/utils.py | Changed type promotion for NUMBER_TO_INT kind to use promoted type instead of always returning int |
| thunder/core/proxies.py | Updated NumberProxy and TensorProxy __pos__ to return self, implemented bitwise shift operations, fixed type annotations |
| thunder/core/prims.py | Updated rounding operations to use INT_FOR_NUMBER output dtype kind, enabled where for all-NumberProxy inputs |
| thunder/core/jit_ext.py | Added type constraint check for symbolic value caching |
| thunder/clang/init.py | Updated rounding operations to use NUMBER_TO_INT type promotion, added method_name to bitwise shift operations |
Comments suppressed due to low confidence (1)
thunder/core/prims.py:1
- Lines 1950-1952 duplicate the description of 'operations, like abs, map complex numbers to floats (COMPLEX_TO_FLOAT)'. Remove the duplicate sentence on lines 1951-1952 and update line 1953 to reference 'four behaviors' without the redundant explanation.
from enum import auto, Enum
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
kshitij12345
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, just few minor comments, thanks @shino16
Closes #2736. The main purpose is to implement
math.ceil, which is needed for supportingclang.arange(ref).To test it against both
intandfloat, I added a logic to test the type of NumberProxy's underlying values in prologue.I took a quick side step to implement all ops tested in
test_elementwise.py. This involved the following patches.Rounding ops:
ceil,floor,trunc,roundpythonex, backed withmath.ceil/floor/trunc,builtins.round.Bitwise ops:
bitwise_not,bitwise_left_shift,bitwise_right_shiftpythonex, backed withoperator.inv/lshift/rshiftNumberProxy.__lshift__to the correspondingclangop.floor_dividewhere, whichfloor_dividedepends on, to support all-NumberProxy inputspos(as opposed toneg)pos. I changedNumberProxy.__pos__andTensorProxy.__pos__to justreturn self, which aligns with Python and PyTorch (tryfor x in (1000, torch.randn(3)): assert +x is x).