3333import contextvars
3434import operator
3535import os
36+ from numbers import Number
3637
3738import mkl
3839import numpy as np
@@ -156,30 +157,65 @@ def _check_plan(plan):
156157 )
157158
158159
159- def _check_overwrite_x ( overwrite_x ):
160- if overwrite_x :
161- raise NotImplementedError (
162- "Overwriting the content of `x` is currently not supported"
163- )
160+ # copied from scipy.fft._pocketfft.helper
161+ # https://github.com/scipy/scipy/blob/main/scipy/fft/_pocketfft/helper.py
162+ def _iterable_of_int ( x , name = None ):
163+ if isinstance ( x , Number ):
164+ x = ( x , )
164165
166+ try :
167+ x = [operator .index (a ) for a in x ]
168+ except TypeError as e :
169+ name = name or "value"
170+ raise ValueError (
171+ f"{ name } must be a scalar or iterable of integers"
172+ ) from e
165173
166- def _cook_nd_args (x , s = None , axes = None , invreal = False ):
167- if s is None :
168- shapeless = True
169- if axes is None :
170- s = list (x .shape )
171- else :
172- s = np .take (x .shape , axes )
174+ return x
175+
176+
177+ # copied and modified from scipy.fft._pocketfft.helper
178+ # https://github.com/scipy/scipy/blob/main/scipy/fft/_pocketfft/helper.py
179+ def _init_nd_shape_and_axes (x , shape , axes , invreal = False ):
180+ noshape = shape is None
181+ noaxes = axes is None
182+
183+ if not noaxes :
184+ axes = _iterable_of_int (axes , "axes" )
185+ axes = [a + x .ndim if a < 0 else a for a in axes ]
186+
187+ if any (a >= x .ndim or a < 0 for a in axes ):
188+ raise ValueError ("axes exceeds dimensionality of input" )
189+ if len (set (axes )) != len (axes ):
190+ raise ValueError ("all axes must be unique" )
191+
192+ if not noshape :
193+ shape = _iterable_of_int (shape , "shape" )
194+
195+ if axes and len (axes ) != len (shape ):
196+ raise ValueError (
197+ "when given, axes and shape arguments"
198+ " have to be of the same length"
199+ )
200+ if noaxes :
201+ if len (shape ) > x .ndim :
202+ raise ValueError ("shape requires more axes than are present" )
203+ axes = range (x .ndim - len (shape ), x .ndim )
204+
205+ shape = [x .shape [a ] if s == - 1 else s for s , a in zip (shape , axes )]
206+ elif noaxes :
207+ shape = list (x .shape )
208+ axes = range (x .ndim )
173209 else :
174- shapeless = False
175- s = list ( s )
176- if axes is None :
177- axes = list ( range ( - len ( s ), 0 ))
178- if len ( s ) != len ( axes ):
179- raise ValueError ( "Shape and axes have different lengths." )
180- if invreal and shapeless :
181- s [ - 1 ] = ( x . shape [ axes [ - 1 ]] - 1 ) * 2
182- return s , axes
210+ shape = [ x . shape [ a ] for a in axes ]
211+
212+ if noshape and invreal :
213+ shape [ - 1 ] = ( x . shape [ axes [ - 1 ]] - 1 ) * 2
214+
215+ if any ( s < 1 for s in shape ):
216+ raise ValueError ( f"invalid number of data points ( { shape } ) specified" )
217+
218+ return tuple ( shape ), list ( axes )
183219
184220
185221def _validate_input (x ):
@@ -301,7 +337,7 @@ def fftn(
301337 """
302338 _check_plan (plan )
303339 x = _validate_input (x )
304- s , axes = _cook_nd_args (x , s , axes )
340+ s , axes = _init_nd_shape_and_axes (x , s , axes )
305341 fsc = _compute_fwd_scale (norm , s , x .shape )
306342
307343 with _Workers (workers ):
@@ -328,7 +364,7 @@ def ifftn(
328364 """
329365 _check_plan (plan )
330366 x = _validate_input (x )
331- s , axes = _cook_nd_args (x , s , axes )
367+ s , axes = _init_nd_shape_and_axes (x , s , axes )
332368 fsc = _compute_fwd_scale (norm , s , x .shape )
333369
334370 with _Workers (workers ):
@@ -345,17 +381,13 @@ def rfft(
345381
346382 For full documentation refer to `scipy.fft.rfft`.
347383
348- Limitation
349- -----------
350- The kwarg `overwrite_x` is only supported with its default value.
351-
352384 """
353385 _check_plan (plan )
354- _check_overwrite_x (overwrite_x )
355386 x = _validate_input (x )
356387 fsc = _compute_fwd_scale (norm , n , x .shape [axis ])
357388
358389 with _Workers (workers ):
390+ # Note: overwrite_x is not utilized
359391 return mkl_fft .rfft (x , n = n , axis = axis , fwd_scale = fsc )
360392
361393
@@ -367,17 +399,13 @@ def irfft(
367399
368400 For full documentation refer to `scipy.fft.irfft`.
369401
370- Limitation
371- -----------
372- The kwarg `overwrite_x` is only supported with its default value.
373-
374402 """
375403 _check_plan (plan )
376- _check_overwrite_x (overwrite_x )
377404 x = _validate_input (x )
378405 fsc = _compute_fwd_scale (norm , n , 2 * (x .shape [axis ] - 1 ))
379406
380407 with _Workers (workers ):
408+ # Note: overwrite_x is not utilized
381409 return mkl_fft .irfft (x , n = n , axis = axis , fwd_scale = fsc )
382410
383411
@@ -396,10 +424,6 @@ def rfft2(
396424
397425 For full documentation refer to `scipy.fft.rfft2`.
398426
399- Limitation
400- -----------
401- The kwarg `overwrite_x` is only supported with its default value.
402-
403427 """
404428 return rfftn (
405429 x ,
@@ -427,10 +451,6 @@ def irfft2(
427451
428452 For full documentation refer to `scipy.fft.irfft2`.
429453
430- Limitation
431- -----------
432- The kwarg `overwrite_x` is only supported with its default value.
433-
434454 """
435455 return irfftn (
436456 x ,
@@ -458,18 +478,14 @@ def rfftn(
458478
459479 For full documentation refer to `scipy.fft.rfftn`.
460480
461- Limitation
462- -----------
463- The kwarg `overwrite_x` is only supported with its default value.
464-
465481 """
466482 _check_plan (plan )
467- _check_overwrite_x (overwrite_x )
468483 x = _validate_input (x )
469- s , axes = _cook_nd_args (x , s , axes )
484+ s , axes = _init_nd_shape_and_axes (x , s , axes )
470485 fsc = _compute_fwd_scale (norm , s , x .shape )
471486
472487 with _Workers (workers ):
488+ # Note: overwrite_x is not utilized
473489 return mkl_fft .rfftn (x , s , axes , fwd_scale = fsc )
474490
475491
@@ -488,18 +504,14 @@ def irfftn(
488504
489505 For full documentation refer to `scipy.fft.irfftn`.
490506
491- Limitation
492- -----------
493- The kwarg `overwrite_x` is only supported with its default value.
494-
495507 """
496508 _check_plan (plan )
497- _check_overwrite_x (overwrite_x )
498509 x = _validate_input (x )
499- s , axes = _cook_nd_args (x , s , axes , invreal = True )
510+ s , axes = _init_nd_shape_and_axes (x , s , axes , invreal = True )
500511 fsc = _compute_fwd_scale (norm , s , x .shape )
501512
502513 with _Workers (workers ):
514+ # Note: overwrite_x is not utilized
503515 return mkl_fft .irfftn (x , s , axes , fwd_scale = fsc )
504516
505517
@@ -512,20 +524,16 @@ def hfft(
512524
513525 For full documentation refer to `scipy.fft.hfft`.
514526
515- Limitation
516- -----------
517- The kwarg `overwrite_x` is only supported with its default value.
518-
519527 """
520528 _check_plan (plan )
521- _check_overwrite_x (overwrite_x )
522529 x = _validate_input (x )
523530 norm = _swap_direction (norm )
524531 x = np .array (x , copy = True )
525532 np .conjugate (x , out = x )
526533 fsc = _compute_fwd_scale (norm , n , 2 * (x .shape [axis ] - 1 ))
527534
528535 with _Workers (workers ):
536+ # Note: overwrite_x is not utilized
529537 return mkl_fft .irfft (x , n = n , axis = axis , fwd_scale = fsc )
530538
531539
@@ -537,18 +545,14 @@ def ihfft(
537545
538546 For full documentation refer to `scipy.fft.ihfft`.
539547
540- Limitation
541- -----------
542- The kwarg `overwrite_x` is only supported with its default value.
543-
544548 """
545549 _check_plan (plan )
546- _check_overwrite_x (overwrite_x )
547550 x = _validate_input (x )
548551 norm = _swap_direction (norm )
549552 fsc = _compute_fwd_scale (norm , n , x .shape [axis ])
550553
551554 with _Workers (workers ):
555+ # Note: overwrite_x is not utilized
552556 result = mkl_fft .rfft (x , n = n , axis = axis , fwd_scale = fsc )
553557
554558 np .conjugate (result , out = result )
@@ -570,10 +574,6 @@ def hfft2(
570574
571575 For full documentation refer to `scipy.fft.hfft2`.
572576
573- Limitation
574- -----------
575- The kwarg `overwrite_x` is only supported with its default value.
576-
577577 """
578578 return hfftn (
579579 x ,
@@ -601,10 +601,6 @@ def ihfft2(
601601
602602 For full documentation refer to `scipy.fft.ihfft2`.
603603
604- Limitation
605- -----------
606- The kwarg `overwrite_x` is only supported with its default value.
607-
608604 """
609605 return ihfftn (
610606 x ,
@@ -633,21 +629,17 @@ def hfftn(
633629
634630 For full documentation refer to `scipy.fft.hfftn`.
635631
636- Limitation
637- -----------
638- The kwarg `overwrite_x` is only supported with its default value.
639-
640632 """
641633 _check_plan (plan )
642- _check_overwrite_x (overwrite_x )
643634 x = _validate_input (x )
644635 norm = _swap_direction (norm )
645636 x = np .array (x , copy = True )
646637 np .conjugate (x , out = x )
647- s , axes = _cook_nd_args (x , s , axes , invreal = True )
638+ s , axes = _init_nd_shape_and_axes (x , s , axes , invreal = True )
648639 fsc = _compute_fwd_scale (norm , s , x .shape )
649640
650641 with _Workers (workers ):
642+ # Note: overwrite_x is not utilized
651643 return mkl_fft .irfftn (x , s , axes , fwd_scale = fsc )
652644
653645
@@ -666,19 +658,15 @@ def ihfftn(
666658
667659 For full documentation refer to `scipy.fft.ihfftn`.
668660
669- Limitation
670- -----------
671- The kwarg `overwrite_x` is only supported with its default value.
672-
673661 """
674662 _check_plan (plan )
675- _check_overwrite_x (overwrite_x )
676663 x = _validate_input (x )
677664 norm = _swap_direction (norm )
678- s , axes = _cook_nd_args (x , s , axes )
665+ s , axes = _init_nd_shape_and_axes (x , s , axes )
679666 fsc = _compute_fwd_scale (norm , s , x .shape )
680667
681668 with _Workers (workers ):
669+ # Note: overwrite_x is not utilized
682670 result = mkl_fft .rfftn (x , s , axes , fwd_scale = fsc )
683671
684672 np .conjugate (result , out = result )
0 commit comments