Skip to content

Commit aa4c0ca

Browse files
committed
add manual patching to mkl_fft
1 parent 142b483 commit aa4c0ca

File tree

2 files changed

+129
-0
lines changed

2 files changed

+129
-0
lines changed

mkl_fft/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,8 @@
4444

4545
import mkl_fft.interfaces # isort: skip
4646

47+
from ._patch_numpy import mkl_fft, patch_numpy_fft, restore_numpy_fft, is_patched
48+
4749
__all__ = [
4850
"fft",
4951
"ifft",
@@ -60,6 +62,10 @@
6062
"rfftn",
6163
"irfftn",
6264
"interfaces",
65+
"mkl_fft",
66+
"patch_numpy_fft"
67+
"restore_numpy_fft",
68+
"is_patched",
6369
]
6470

6571
del _init_helper

mkl_fft/_patch_numpy.py

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
#!/usr/bin/env python
2+
# Copyright (c) 2017, Intel Corporation
3+
#
4+
# Redistribution and use in source and binary forms, with or without
5+
# modification, are permitted provided that the following conditions are met:
6+
#
7+
# * Redistributions of source code must retain the above copyright notice,
8+
# this list of conditions and the following disclaimer.
9+
# * Redistributions in binary form must reproduce the above copyright
10+
# notice, this list of conditions and the following disclaimer in the
11+
# documentation and/or other materials provided with the distribution.
12+
# * Neither the name of Intel Corporation nor the names of its contributors
13+
# may be used to endorse or promote products derived from this software
14+
# without specific prior written permission.
15+
#
16+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
19+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE
20+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
21+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
22+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
23+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
24+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
25+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26+
27+
"""
28+
Define function for patching NumPy with MKL-based NumPy interface.
29+
"""
30+
31+
import numpy as np
32+
import mkl_fft.interfaces.numpy_fft as _nfft
33+
34+
from contextlib import ContextDecorator
35+
36+
from threading import local as threading_local
37+
38+
_tls = threading_local()
39+
40+
41+
class _Patch():
42+
"""
43+
Internal object for patching NumPy with mkl_fft interfaces.
44+
"""
45+
_is_patched = False
46+
__patched_functions__ = _nfft.__all__
47+
_restore_dict = {}
48+
49+
def _register_func(self, name, func):
50+
if name not in self.__patched_functions__:
51+
raise ValueError("%s not an mkl_fft function." % name)
52+
f = getattr(np.fft, name)
53+
self._restore_dict[name] = f
54+
setattr(np.fft, name, func)
55+
56+
def _restore_func(self, name):
57+
if name not in self.__patched_functions__:
58+
raise ValueError("%s not an mkl_fft function." % name)
59+
try:
60+
val = self._restore_dict[name]
61+
except KeyError:
62+
print("failed to restore")
63+
return
64+
else:
65+
print("found and restoring...")
66+
setattr(np.fft, name, val)
67+
68+
def restore(self):
69+
for name in self._restore_dict.keys():
70+
self._restore_func(name)
71+
self._is_patched = False
72+
73+
def do_patch(self):
74+
for f in self.__patched_functions__:
75+
self._register_func(f, getattr(_nfft, f))
76+
self._is_patched = True
77+
78+
def is_patched(self):
79+
return self._is_patched
80+
81+
82+
def _initialize_tls():
83+
_tls.patch = _Patch()
84+
_tls.initialized = True
85+
86+
87+
def _is_tls_initialized():
88+
return (getattr(_tls, "initialized", None) is not None) and (_tls.initialized == True)
89+
90+
91+
def patch_numpy_fft(verbose=False):
92+
if verbose:
93+
print(
94+
"Now patching NumPy FFT submodule with mkl_fft NumPy interface. "
95+
"Please direct bug reports to https://github.com/IntelPython/mkl_fft"
96+
)
97+
if not _is_tls_initialized():
98+
_initialize_tls()
99+
_tls.patch.do_patch()
100+
101+
102+
def restore_numpy_fft(verbose=False):
103+
if verbose:
104+
print("Now restoring original NumPy FFT submodule.")
105+
if not _is_tls_initialized():
106+
_initialize_tls()
107+
_tls.patch.restore()
108+
109+
110+
def is_patched():
111+
if not _is_tls_initialized():
112+
_initialize_tls()
113+
return _tls.patch.is_patched()
114+
115+
116+
class mkl_fft(ContextDecorator):
117+
def __enter__(self):
118+
patch_numpy_fft()
119+
return self
120+
121+
def __exit__(self, *exc):
122+
restore_numpy_fft()
123+
return False

0 commit comments

Comments
 (0)