-
Notifications
You must be signed in to change notification settings - Fork 146
Handle slices in mlx_funcify_IncSubtensor
#1692
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull Request Overview
This PR fixes a bug in the MLX backend's IncSubtensor dispatch where it incorrectly assumed all indices would be integers. The fix adds logic to properly handle slice objects by converting their start/stop/step components from potentially symbolic values to actual integers while preserving None values.
Key Changes
- Added
get_slice_inthelper function to safely convert slice components to integers - Modified index processing to reconstruct slice objects with integer components
- Added comprehensive test coverage for various slice patterns (positive, negative, step-based, and full slices)
Reviewed Changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 1 comment.
| File | Description |
|---|---|
| pytensor/link/mlx/dispatch/subtensor.py | Added get_slice_int helper and slice reconstruction logic to handle both integer and slice indices |
| tests/link/mlx/test_subtensor.py | Added four test cases covering different slice patterns to verify the fix |
| return None | ||
| try: | ||
| return int(element) | ||
| except Exception: |
Copilot
AI
Oct 24, 2025
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Using a bare except Exception is too broad. This should catch specific exceptions like TypeError or ValueError that would occur when trying to convert a non-integer value. The current implementation could mask unexpected errors.
| except Exception: | |
| except (TypeError, ValueError): |
|
So MLX is okay with The helper definitely doesn't ignore slices, I think you misread the error message. And I guess this would also apply to the regular Subtensor not just IncSubtensor? And AdvancedSubtensor as well, unless we typify the constant slices with integers? Because all those use numpy integers internally. This may be worth opening an issue with them, even if we patch here. |
Neither of these cases work: import mlx.core as mx
import numpy as np
x = mx.array([1, 2, 3])
x[np.int64(1)]
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
Cell In[42], line 4
2 import numpy as np
3 x = mx.array([1, 2, 3])
----> 4 x[np.int64(1)]
ValueError: Cannot index mlx array using the given type.x[np.int64(1):]
ValueError Traceback (most recent call last)
Cell In[43], line 4
2 import numpy as np
3 x = mx.array([1, 2, 3])
----> 4 x[np.int64(1):]
ValueError: Slice indices must be integers or None. |
|
So we should open an issue with them and I'm surprised any indexing worked before. We're casting constant numpy integers to python int? If so we should do something for the tipefy of slices |
|
This is how the Subtensor is working around it:
Should do the same thing in both Ops for consistency. The Subtensor approach seems simpler. Perhaps put in a helper just so we can document why this is needed for future developers |
|
More don't wanted help here #1702 |
Introduces normalize_indices_for_mlx to convert NumPy integer and floating types, MLX scalar arrays, and slice components to Python int/float for MLX compatibility. Updates all MLX subtensor dispatch functions to use this normalization, resolving issues with MLX's strict indexing requirements. Adds comprehensive tests for np.int64 indices and slices in subtensor and inc_subtensor operations, including advanced indexing scenarios.
Appended a newline to the end of subtensor.py and test_subtensor.py to conform with POSIX standards and improve code consistency.
|
The math test is kinda flaky, works sometimes others fail. Strange... |
|
Yeah we know it. I'll take a look to stop it but it need not block us |
| mx = pytest.importorskip("mlx.core") | ||
|
|
||
|
|
||
| def test_mlx_python_int_indexing(): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why would we be testing mlx directly?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah not at all this was internal for me, let me delete this file. I didn't wanna to push!
|
Issue in MLX: ml-explore/mlx#2710 |
| else: | ||
| return element | ||
|
|
||
| indices = indices_from_subtensor(ilist, idx_list) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
move this outside the helper, as it is only relevant for basic Subtensor/IncSubtensor. The advanced methods don't have an idx_list and don't need indices_from_subtensor
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done!
| @mlx_funcify.register(AdvancedSubtensor1) | ||
| def mlx_funcify_AdvancedSubtensor(op, node, **kwargs): | ||
| """MLX implementation of AdvancedSubtensor.""" | ||
| idx_list = getattr(op, "idx_list", None) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not a thing for Advanced indexing
| idx_list = getattr(op, "idx_list", None) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done!
| @mlx_funcify.register(Subtensor) | ||
| def mlx_funcify_Subtensor(op, node, **kwargs): | ||
| """MLX implementation of Subtensor.""" | ||
| idx_list = getattr(op, "idx_list", None) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not optional. Better be explicit to reduce confusion. Apply this and the other suggestion in all dispatches
| idx_list = getattr(op, "idx_list", None) | |
| idx_list = op.idx_list |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done!
tests/link/mlx/test_subtensor.py
Outdated
| ) | ||
|
|
||
| # Advanced indexing set with array indices | ||
| indices = [np.int64(0), np.int64(2)] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this actually advanced indexing? To be sure make one of the indices a vector array [0, 1, 2, 3]
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You can probably reuse the same sort of indices from the test just above
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good catch corrected.
Simplifies index normalization logic in MLX subtensor dispatch functions by separating basic and advanced indexing cases. Updates the advanced incsubtensor test to use vector array indices and a matching value shape for improved coverage.
Codecov Report❌ Patch coverage is
❌ Your patch check has failed because the patch coverage (84.37%) is below the target coverage (100.00%). You can increase the patch coverage or adjust the target coverage. Additional details and impacted files@@ Coverage Diff @@
## main #1692 +/- ##
==========================================
+ Coverage 81.61% 81.69% +0.08%
==========================================
Files 242 246 +4
Lines 53537 53655 +118
Branches 9433 9443 +10
==========================================
+ Hits 43695 43836 +141
+ Misses 7366 7334 -32
- Partials 2476 2485 +9
🚀 New features to boost your workflow:
|
| ) | ||
|
|
||
| # Advanced indexing set with vector array indices | ||
| indices = np.array([0, 1, 2, 3], dtype=np.int64) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this is not testing the "issue of having a scalar np.int64 or slice with an np.int64 entry". My earlier suggestion was to have an array + one of those. The array is what forces it to be "Advanced"
Description
The MLX dispatch for
IncSubtensorwas assuming that the indexes would always be integers, but they can actually be either integers or slices. This PR adds logic to handle the slice case.Related Issue
Incsubtensorfails on slices #1690mlx#1350Checklist
Type of change
📚 Documentation preview 📚: https://pytensor--1692.org.readthedocs.build/en/1692/