Description
Hi there,
I noticed that negative indices cannot be used with permute_dims when using the jax.numpy namespace.
The documentation for permute_dims does state:
"axes (Tuple[int, ...]) – tuple containing a permutation of (0, 1, ..., N-1) where N is the number of axes (dimensions) of x"
However, the indexing page says that the API standard is compatible with negative indexing.
For now, I normalize the indices before calling permute_dims. However, I hope your team can make this clearer in some way. For example, if this behavior is by design, the documentation could explicitly note that axes must be non-negative, since many people will reach for negative indices out of habit.
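A minimal sketch of this kind of normalization is shown here in plain Python. The helper name `normalize_axes` is hypothetical and is not part of jax.numpy or the array API standard. It assumes the usual convention that a negative axis `-k` refers to axis `ndim - k`.

```python
def normalize_axes(axes, ndim):
    """Map possibly negative axes to their non-negative equivalents.

    Hypothetical helper for illustration only; not a jax.numpy function.
    """
    normalized = []
    for axis in axes:
        # Valid axes lie in the range [-ndim, ndim).
        if not -ndim <= axis < ndim:
            raise ValueError(f"axis {axis} is out of bounds for {ndim} dimensions")
        # Python's modulo wraps negative values into [0, ndim).
        normalized.append(axis % ndim)
    # A permutation must not name the same axis twice, e.g. (-1, 2) for ndim=3.
    if len(set(normalized)) != len(normalized):
        raise ValueError(f"repeated axis in {tuple(axes)}")
    return tuple(normalized)


# For a 3-dimensional array, (-1, 0, 1) becomes (2, 0, 1).
axes = normalize_axes((-1, 0, 1), ndim=3)
```

The resulting tuple can then be passed as the axes argument of permute_dims.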
By the way, I would appreciate it if you could share your design philosophy, since I am working on a project that involves array manipulation. Is this because permute_dims is more of an internal function?