Skip to content

Conversation

@jessegrabowski
Copy link
Member

@jessegrabowski jessegrabowski commented Oct 24, 2025

Description

The MLX dispatch for IncSubtensor was assuming that the indexes would always be integers, but they can actually be either integers or slices. This PR adds logic to handle the slice case.

Related Issue

Checklist

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

📚 Documentation preview 📚: https://pytensor--1692.org.readthedocs.build/en/1692/

@jessegrabowski jessegrabowski added bug Something isn't working mlx labels Oct 24, 2025
Copy link

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull Request Overview

This PR fixes a bug in the MLX backend's IncSubtensor dispatch where it incorrectly assumed all indices would be integers. The fix adds logic to properly handle slice objects by converting their start/stop/step components from potentially symbolic values to actual integers while preserving None values.

Key Changes

  • Added get_slice_int helper function to safely convert slice components to integers
  • Modified index processing to reconstruct slice objects with integer components
  • Added comprehensive test coverage for various slice patterns (positive, negative, step-based, and full slices)

Reviewed Changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 1 comment.

File Description
pytensor/link/mlx/dispatch/subtensor.py Added get_slice_int helper and slice reconstruction logic to handle both integer and slice indices
tests/link/mlx/test_subtensor.py Added four test cases covering different slice patterns to verify the fix

return None
try:
return int(element)
except Exception:
Copy link

Copilot AI Oct 24, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using a bare except Exception is too broad. This should catch specific exceptions like TypeError or ValueError that would occur when trying to convert a non-integer value. The current implementation could mask unexpected errors.

Suggested change
except Exception:
except (TypeError, ValueError):

Copilot uses AI. Check for mistakes.
@ricardoV94
Copy link
Member

So MLX is okay with x[np.int64(1)], but not x[np.int64(1):]?

The helper definitely doesn't ignore slices, I think you misread the error message.

And I guess this would also apply to the regular Subtensor not just IncSubtensor? And AdvancedSubtensor as well, unless we typify the constant slices with integers? Because all those use numpy integers internally.

This may be worth opening an issue with them, even if we patch here.

@jessegrabowski
Copy link
Member Author

So MLX is okay with x[np.int64(1)], but not x[np.int64(1):]?

Neither of these cases work:

import mlx.core as mx
import numpy as np
x = mx.array([1, 2, 3])
x[np.int64(1)]

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[42], line 4
      2 import numpy as np
      3 x = mx.array([1, 2, 3])
----> 4 x[np.int64(1)]

ValueError: Cannot index mlx array using the given type.
x[np.int64(1):]

ValueError                                Traceback (most recent call last)
Cell In[43], line 4
      2 import numpy as np
      3 x = mx.array([1, 2, 3])
----> 4 x[np.int64(1):]

ValueError: Slice indices must be integers or None.

@ricardoV94
Copy link
Member

So we should open an issue with them and I'm surprised any indexing worked before. We're casting constant numpy integers to python int? If so we should do something for the tipefy of slices

@ricardoV94
Copy link
Member

ricardoV94 commented Oct 27, 2025

This is how the Subtensor is working around it:

indices = indices_from_subtensor([int(element) for element in ilists], idx_list)

Should do the same thing in both Ops for consistency. The Subtensor approach seems simpler.

Perhaps put in a helper just so we can document why this is needed for future developers

@cetagostini
Copy link
Contributor

More don't wanted help here #1702

@jessegrabowski @ricardoV94

Introduces normalize_indices_for_mlx to convert NumPy integer and floating types, MLX scalar arrays, and slice components to Python int/float for MLX compatibility. Updates all MLX subtensor dispatch functions to use this normalization, resolving issues with MLX's strict indexing requirements. Adds comprehensive tests for np.int64 indices and slices in subtensor and inc_subtensor operations, including advanced indexing scenarios.
Appended a newline to the end of subtensor.py and test_subtensor.py to conform with POSIX standards and improve code consistency.
@cetagostini
Copy link
Contributor

The math test is kinda flaky, works sometimes others fail. Strange...

@ricardoV94
Copy link
Member

Yeah we know it. I'll take a look to stop it but it need not block us

mx = pytest.importorskip("mlx.core")


def test_mlx_python_int_indexing():
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why would we be testing mlx directly?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah not at all this was internal for me, let me delete this file. I didn't wanna to push!

@cetagostini
Copy link
Contributor

Issue in MLX: ml-explore/mlx#2710

else:
return element

indices = indices_from_subtensor(ilist, idx_list)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

move this outside the helper, as it is only relevant for basic Subtensor/IncSubtensor. The advanced methods don't have an idx_list and don't need indices_from_subtensor

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done!

@mlx_funcify.register(AdvancedSubtensor1)
def mlx_funcify_AdvancedSubtensor(op, node, **kwargs):
"""MLX implementation of AdvancedSubtensor."""
idx_list = getattr(op, "idx_list", None)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not a thing for Advanced indexing

Suggested change
idx_list = getattr(op, "idx_list", None)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done!

@mlx_funcify.register(Subtensor)
def mlx_funcify_Subtensor(op, node, **kwargs):
"""MLX implementation of Subtensor."""
idx_list = getattr(op, "idx_list", None)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not optional. Better be explicit to reduce confusion. Apply this and the other suggestion in all dispatches

Suggested change
idx_list = getattr(op, "idx_list", None)
idx_list = op.idx_list

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done!

)

# Advanced indexing set with array indices
indices = [np.int64(0), np.int64(2)]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this actually advanced indexing? To be sure make one of the indices a vector array [0, 1, 2, 3]

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can probably reuse the same sort of indices from the test just above

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch corrected.

Simplifies index normalization logic in MLX subtensor dispatch functions by separating basic and advanced indexing cases. Updates the advanced incsubtensor test to use vector array indices and a matching value shape for improved coverage.
@codecov
Copy link

codecov bot commented Oct 30, 2025

Codecov Report

❌ Patch coverage is 84.37500% with 5 lines in your changes missing coverage. Please review.
✅ Project coverage is 81.69%. Comparing base (17c675a) to head (449c2df).
⚠️ Report is 11 commits behind head on main.

Files with missing lines Patch % Lines
pytensor/link/mlx/dispatch/subtensor.py 84.37% 3 Missing and 2 partials ⚠️

❌ Your patch check has failed because the patch coverage (84.37%) is below the target coverage (100.00%). You can increase the patch coverage or adjust the target coverage.

Additional details and impacted files

Impacted file tree graph

@@            Coverage Diff             @@
##             main    #1692      +/-   ##
==========================================
+ Coverage   81.61%   81.69%   +0.08%     
==========================================
  Files         242      246       +4     
  Lines       53537    53655     +118     
  Branches     9433     9443      +10     
==========================================
+ Hits        43695    43836     +141     
+ Misses       7366     7334      -32     
- Partials     2476     2485       +9     
Files with missing lines Coverage Δ
pytensor/link/mlx/dispatch/subtensor.py 87.50% <84.37%> (-6.35%) ⬇️

... and 16 files with indirect coverage changes

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

)

# Advanced indexing set with vector array indices
indices = np.array([0, 1, 2, 3], dtype=np.int64)
Copy link
Member

@ricardoV94 ricardoV94 Oct 31, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is not testing the "issue of having a scalar np.int64 or slice with an np.int64 entry". My earlier suggestion was to have an array + one of those. The array is what forces it to be "Advanced"

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

bug Something isn't working mlx

Projects

None yet

Development

Successfully merging this pull request may close these issues.

MLX Incsubtensor fails on slices

3 participants