Skip to content

Commit c4696d2

Browse files
authored
Merge pull request #185 from IntelPython/revisit_overwrite_x
Revisit usage of `overwrite_x` parameter
2 parents 2f7f485 + d9f9725 commit c4696d2

File tree

10 files changed

+201
-578
lines changed

10 files changed

+201
-578
lines changed

CHANGELOG.md

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1111

1212
### Changed
1313
* Replaced `fwd_scale` parameter with `norm` in `mkl_fft` [gh-189](https://github.com/IntelPython/mkl_fft/pull/189)
14+
* Dropped support for `scipy.fftpack` interface [gh-185](https://github.com/IntelPython/mkl_fft/pull/185)
15+
* Dropped support for `overwrite_x` parameter in `mkl_fft` [gh-185](https://github.com/IntelPython/mkl_fft/pull/185)
1416

1517
### Fixed
18+
* Fixed a bug for N-D FFTs when both `s` and `out` are given [gh-185](https://github.com/IntelPython/mkl_fft/pull/185)
1619

1720
## [2.0.0] - 2025-06-03
1821

@@ -27,8 +30,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
2730
* SciPy interface `mkl_fft.interfaces.scipy_fft` uses the same function from SciPy for handling `s` and `axes` for N-D FFTs [gh-181](https://github.com/IntelPython/mkl_fft/pull/181)
2831

2932
### Fixed
30-
* Fixed an issue for calling `mkl_fft.interfaces.numpy.fftn` with an empty axes [gh-139](https://github.com/IntelPython/mkl_fft/pull/139)
31-
* Fixed an issue for calling `mkl_fft.interfaces.numpy.fftn` with a zero-size array [gh-139](https://github.com/IntelPython/mkl_fft/pull/139)
33+
* Fixed a bug in `mkl_fft.interfaces.numpy.fftn` when an empty tuple is passed for `axes` [gh-139](https://github.com/IntelPython/mkl_fft/pull/139)
34+
* Fixed a bug for a case when a zero-size array is passed to `mkl_fft.interfaces.numpy.fftn` [gh-139](https://github.com/IntelPython/mkl_fft/pull/139)
3235
* Fixed inconsistency of input and output arrays dtype for `irfft` function [gh-180](https://github.com/IntelPython/mkl_fft/pull/180)
3336
* Fixed issues with `set_workers` function in SciPy interface `mkl_fft.interfaces.scipy_fft` [gh-183](https://github.com/IntelPython/mkl_fft/pull/183)
3437

README.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,11 +51,11 @@ While using the interfaces module is the recommended way to leverage `mk_fft`, o
5151

5252
### complex-to-complex (c2c) transforms:
5353

54-
`fft(x, n=None, axis=-1, overwrite_x=False, norm=None, out=None)` - 1D FFT, similar to `scipy.fft.fft`
54+
`fft(x, n=None, axis=-1, norm=None, out=None)` - 1D FFT, similar to `scipy.fft.fft`
5555

56-
`fft2(x, s=None, axes=(-2, -1), overwrite_x=False, norm=None, out=None)` - 2D FFT, similar to `scipy.fft.fft2`
56+
`fft2(x, s=None, axes=(-2, -1), norm=None, out=None)` - 2D FFT, similar to `scipy.fft.fft2`
5757

58-
`fftn(x, s=None, axes=None, overwrite_x=False, norm=None, out=None)` - ND FFT, similar to `scipy.fft.fftn`
58+
`fftn(x, s=None, axes=None, norm=None, out=None)` - ND FFT, similar to `scipy.fft.fftn`
5959

6060
and similar inverse FFT (`ifft*`) functions.
6161

mkl_fft/__init__.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,6 @@
3939
rfft2,
4040
rfftn,
4141
)
42-
from ._pydfti import irfftpack, rfftpack # pylint: disable=no-name-in-module
4342
from ._version import __version__
4443

4544
import mkl_fft.interfaces # isort: skip
@@ -51,8 +50,6 @@
5150
"ifft2",
5251
"fftn",
5352
"ifftn",
54-
"rfftpack",
55-
"irfftpack",
5653
"rfft",
5754
"irfft",
5855
"rfft2",

mkl_fft/_fft_utils.py

Lines changed: 94 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -258,23 +258,43 @@ def _iter_fftnd(
258258
axes=None,
259259
out=None,
260260
direction=+1,
261-
overwrite_x=False,
262-
scale_function=lambda n, ind: 1.0,
261+
scale_function=lambda ind: 1.0,
263262
):
264263
a = np.asarray(a)
265264
s, axes = _init_nd_shape_and_axes(a, s, axes)
266-
ovwr = overwrite_x
267-
for ii in reversed(range(len(axes))):
265+
266+
# Combine the two, but in reverse, to end with the first axis given.
267+
axes_and_s = list(zip(axes, s))[::-1]
268+
# We try to use in-place calculations where possible, which is
269+
# everywhere except when the size changes after the first FFT.
270+
size_changes = [axis for axis, n in axes_and_s[1:] if a.shape[axis] != n]
271+
272+
# If there are any size changes, we cannot use out
273+
res = None if size_changes else out
274+
for ind, (axis, n) in enumerate(axes_and_s):
275+
if axis in size_changes:
276+
if axis == size_changes[-1]:
277+
# Last size change, so any output should now be OK
278+
# (an error will be raised if not), and if no output is
279+
# required, we want a freshly allocated array of the right size.
280+
res = out
281+
elif res is not None and n < res.shape[axis]:
282+
# For an intermediate step where we return fewer elements, we
283+
# can use a smaller view of the previous array.
284+
res = res[(slice(None),) * axis + (slice(n),)]
285+
else:
286+
# If we need more elements, we cannot use res.
287+
res = None
268288
a = _c2c_fft1d_impl(
269289
a,
270-
n=s[ii],
271-
axis=axes[ii],
272-
overwrite_x=ovwr,
290+
n=n,
291+
axis=axis,
273292
direction=direction,
274-
fsc=scale_function(s[ii], ii),
275-
out=out,
293+
fsc=scale_function(ind),
294+
out=res,
276295
)
277-
ovwr = True
296+
# Default output for next iteration.
297+
res = a
278298
return a
279299

280300

@@ -356,7 +376,6 @@ def _c2c_fftnd_impl(
356376
x,
357377
s=None,
358378
axes=None,
359-
overwrite_x=False,
360379
direction=+1,
361380
fsc=1.0,
362381
out=None,
@@ -381,7 +400,6 @@ def _c2c_fftnd_impl(
381400
if _direct:
382401
return _direct_fftnd(
383402
x,
384-
overwrite_x=overwrite_x,
385403
direction=direction,
386404
fsc=fsc,
387405
out=out,
@@ -399,11 +417,7 @@ def _c2c_fftnd_impl(
399417
x,
400418
axes,
401419
_direct_fftnd,
402-
{
403-
"overwrite_x": overwrite_x,
404-
"direction": direction,
405-
"fsc": fsc,
406-
},
420+
{"direction": direction, "fsc": fsc},
407421
res,
408422
)
409423
else:
@@ -414,97 +428,122 @@ def _c2c_fftnd_impl(
414428
axes=axes,
415429
out=out,
416430
direction=direction,
417-
overwrite_x=overwrite_x,
418-
scale_function=lambda n, i: fsc if i == 0 else 1.0,
431+
scale_function=lambda i: fsc if i == 0 else 1.0,
419432
)
420433

421434

422435
def _r2c_fftnd_impl(x, s=None, axes=None, fsc=1.0, out=None):
423436
a = np.asarray(x)
424437
no_trim = (s is None) and (axes is None)
425438
s, axes = _cook_nd_args(a, s, axes)
439+
axes = [ax + a.ndim if ax < 0 else ax for ax in axes]
426440
la = axes[-1]
441+
427442
# trim array, so that rfft avoids doing unnecessary computations
428443
if not no_trim:
429444
a = _trim_array(a, s, axes)
445+
446+
# last axis is not included since we calculate r2c FFT separately
447+
# and not in the loop
448+
axes_and_s = list(zip(axes, s))[-2::-1]
449+
size_changes = [axis for axis, n in axes_and_s if a.shape[axis] != n]
450+
res = None if size_changes else out
451+
430452
# r2c along last axis
431-
a = _r2c_fft1d_impl(a, n=s[-1], axis=la, fsc=fsc, out=out)
453+
a = _r2c_fft1d_impl(a, n=s[-1], axis=la, fsc=fsc, out=res)
454+
res = a
432455
if len(s) > 1:
433-
if not no_trim:
434-
ss = list(s)
435-
ss[-1] = a.shape[la]
436-
a = _pad_array(a, tuple(ss), axes)
456+
437457
len_axes = len(axes)
438458
if len(set(axes)) == len_axes and len_axes == a.ndim and len_axes > 2:
459+
if not no_trim:
460+
ss = list(s)
461+
ss[-1] = a.shape[la]
462+
a = _pad_array(a, tuple(ss), axes)
439463
# a series of ND c2c FFTs along last axis
440464
ss, aa = _remove_axis(s, axes, -1)
441-
ind = [
442-
slice(None, None, 1),
443-
] * len(s)
465+
ind = [slice(None, None, 1)] * len(s)
444466
for ii in range(a.shape[la]):
445467
ind[la] = ii
446468
tind = tuple(ind)
447469
a_inp = a[tind]
448-
res = out[tind] if out is not None else None
449-
a_res = _c2c_fftnd_impl(
450-
a_inp, s=ss, axes=aa, overwrite_x=True, direction=1, out=res
451-
)
452-
if a_res is not a_inp:
453-
a[tind] = a_res # copy in place
470+
res = out[tind] if out is not None else a_inp
471+
_ = _c2c_fftnd_impl(a_inp, s=ss, axes=aa, direction=1, out=res)
472+
if out is not None:
473+
a = out
454474
else:
475+
# another size_changes check is needed if there are repeated axes
476+
# of last axis, since since FFT changes the shape along last axis
477+
size_changes = [
478+
axis for axis, n in axes_and_s if a.shape[axis] != n
479+
]
480+
455481
# a series of 1D c2c FFTs along all axes except last
456-
for ii in range(len(axes) - 2, -1, -1):
457-
a = _c2c_fft1d_impl(a, s[ii], axes[ii], overwrite_x=True)
482+
for axis, n in axes_and_s:
483+
if axis in size_changes:
484+
if axis == size_changes[-1]:
485+
res = out
486+
elif res is not None and n < res.shape[axis]:
487+
res = res[(slice(None),) * axis + (slice(n),)]
488+
else:
489+
res = None
490+
a = _c2c_fft1d_impl(a, n, axis, out=res)
491+
res = a
458492
return a
459493

460494

461495
def _c2r_fftnd_impl(x, s=None, axes=None, fsc=1.0, out=None):
462496
a = np.asarray(x)
463497
no_trim = (s is None) and (axes is None)
464498
s, axes = _cook_nd_args(a, s, axes, invreal=True)
499+
axes = [ax + a.ndim if ax < 0 else ax for ax in axes]
465500
la = axes[-1]
466501
if not no_trim:
467502
a = _trim_array(a, s, axes)
468503
if len(s) > 1:
469-
if not no_trim:
470-
a = _pad_array(a, s, axes)
471-
ovr_x = True if _datacopied(a, x) else False
472504
len_axes = len(axes)
473505
if len(set(axes)) == len_axes and len_axes == a.ndim and len_axes > 2:
506+
if not no_trim:
507+
a = _pad_array(a, s, axes)
474508
# a series of ND c2c FFTs along last axis
475509
# due to need to write into a, we must copy
476-
if not ovr_x:
477-
a = a.copy()
478-
ovr_x = True
510+
a = a if _datacopied(a, x) else a.copy()
479511
if not np.issubdtype(a.dtype, np.complexfloating):
480512
# complex output will be copied to input, copy is needed
481513
if a.dtype == np.float32:
482514
a = a.astype(np.complex64)
483515
else:
484516
a = a.astype(np.complex128)
485-
ovr_x = True
486517
ss, aa = _remove_axis(s, axes, -1)
487-
ind = [
488-
slice(None, None, 1),
489-
] * len(s)
518+
ind = [slice(None, None, 1)] * len(s)
490519
for ii in range(a.shape[la]):
491520
ind[la] = ii
492521
tind = tuple(ind)
493522
a_inp = a[tind]
494523
# out has real dtype and cannot be used in intermediate steps
495-
a_res = _c2c_fftnd_impl(
496-
a_inp, s=ss, axes=aa, overwrite_x=True, direction=-1
524+
# ss and aa are reversed since np.irfftn uses forward order but
525+
# np.ifftn uses reverse order see numpy-gh-28950
526+
_ = _c2c_fftnd_impl(
527+
a_inp, s=ss[::-1], axes=aa[::-1], out=a_inp, direction=-1
497528
)
498-
if a_res is not a_inp:
499-
a[tind] = a_res # copy in place
500529
else:
501530
# a series of 1D c2c FFTs along all axes except last
502-
for ii in range(len(axes) - 1):
503-
# out has real dtype and cannot be used in intermediate steps
504-
a = _c2c_fft1d_impl(
505-
a, s[ii], axes[ii], overwrite_x=ovr_x, direction=-1
506-
)
507-
ovr_x = True
531+
# forward order, see numpy-gh-28950
532+
axes_and_s = list(zip(axes, s))[:-1]
533+
size_changes = [
534+
axis for axis, n in axes_and_s[1:] if a.shape[axis] != n
535+
]
536+
# out has real dtype cannot be used for intermediate steps
537+
res = None
538+
for axis, n in axes_and_s:
539+
if axis in size_changes:
540+
if res is not None and n < res.shape[axis]:
541+
# pylint: disable=unsubscriptable-object
542+
res = res[(slice(None),) * axis + (slice(n),)]
543+
else:
544+
res = None
545+
a = _c2c_fft1d_impl(a, n, axis, out=res, direction=-1)
546+
res = a
508547
# c2r along last axis
509548
a = _c2r_fft1d_impl(a, n=s[-1], axis=la, fsc=fsc, out=out)
510549
return a

mkl_fft/_mkl_fft.py

Lines changed: 22 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -50,78 +50,32 @@
5050
]
5151

5252

53-
def fft(x, n=None, axis=-1, norm=None, out=None, overwrite_x=False):
53+
def fft(x, n=None, axis=-1, norm=None, out=None):
5454
fsc = _compute_fwd_scale(norm, n, x.shape[axis])
55-
return _c2c_fft1d_impl(
56-
x,
57-
n=n,
58-
axis=axis,
59-
out=out,
60-
overwrite_x=overwrite_x,
61-
direction=+1,
62-
fsc=fsc,
63-
)
64-
65-
66-
def ifft(x, n=None, axis=-1, norm=None, out=None, overwrite_x=False):
55+
return _c2c_fft1d_impl(x, n=n, axis=axis, out=out, direction=+1, fsc=fsc)
56+
57+
58+
def ifft(x, n=None, axis=-1, norm=None, out=None):
6759
fsc = _compute_fwd_scale(norm, n, x.shape[axis])
68-
return _c2c_fft1d_impl(
69-
x,
70-
n=n,
71-
axis=axis,
72-
out=out,
73-
overwrite_x=overwrite_x,
74-
direction=-1,
75-
fsc=fsc,
76-
)
77-
78-
79-
def fft2(x, s=None, axes=(-2, -1), norm=None, out=None, overwrite_x=False):
80-
return fftn(
81-
x,
82-
s=s,
83-
axes=axes,
84-
norm=norm,
85-
out=out,
86-
overwrite_x=overwrite_x,
87-
)
88-
89-
90-
def ifft2(x, s=None, axes=(-2, -1), norm=None, out=None, overwrite_x=False):
91-
return ifftn(
92-
x,
93-
s=s,
94-
axes=axes,
95-
norm=norm,
96-
out=out,
97-
overwrite_x=overwrite_x,
98-
)
99-
100-
101-
def fftn(x, s=None, axes=None, norm=None, out=None, overwrite_x=False):
60+
return _c2c_fft1d_impl(x, n=n, axis=axis, out=out, direction=-1, fsc=fsc)
61+
62+
63+
def fft2(x, s=None, axes=(-2, -1), norm=None, out=None):
64+
return fftn(x, s=s, axes=axes, norm=norm, out=out)
65+
66+
67+
def ifft2(x, s=None, axes=(-2, -1), norm=None, out=None):
68+
return ifftn(x, s=s, axes=axes, norm=norm, out=out)
69+
70+
71+
def fftn(x, s=None, axes=None, norm=None, out=None):
10272
fsc = _compute_fwd_scale(norm, s, x.shape)
103-
return _c2c_fftnd_impl(
104-
x,
105-
s=s,
106-
axes=axes,
107-
out=out,
108-
overwrite_x=overwrite_x,
109-
direction=+1,
110-
fsc=fsc,
111-
)
112-
113-
114-
def ifftn(x, s=None, axes=None, norm=None, out=None, overwrite_x=False):
73+
return _c2c_fftnd_impl(x, s=s, axes=axes, out=out, direction=+1, fsc=fsc)
74+
75+
76+
def ifftn(x, s=None, axes=None, norm=None, out=None):
11577
fsc = _compute_fwd_scale(norm, s, x.shape)
116-
return _c2c_fftnd_impl(
117-
x,
118-
s=s,
119-
axes=axes,
120-
out=out,
121-
overwrite_x=overwrite_x,
122-
direction=-1,
123-
fsc=fsc,
124-
)
78+
return _c2c_fftnd_impl(x, s=s, axes=axes, out=out, direction=-1, fsc=fsc)
12579

12680

12781
def rfft(x, n=None, axis=-1, norm=None, out=None):

0 commit comments

Comments
 (0)