Skip to content

Commit 81d7be9

Browse files
committed
Jit some structure factor functions
1 parent 6d0172a commit 81d7be9

File tree

1 file changed

+27
-5
lines changed

1 file changed

+27
-5
lines changed

diffsims/utils/sim_utils.py

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
import collections
2020
import math
2121

22-
import diffpy.structure
22+
from numba import njit
2323
import numpy as np
2424
from scipy.constants import h, m_e, e, c, pi, mu_0
2525

@@ -215,7 +215,7 @@ def get_vectorized_list_for_atomic_scattering_factors(
215215

216216
return coeffs, fcoords, occus, dwfactors
217217

218-
218+
@njit(fastmath=True)
219219
def get_atomic_scattering_factors(g_hkl_sq, coeffs, scattering_params):
220220
"""Calculate atomic scattering factors for n atoms.
221221
@@ -277,9 +277,32 @@ def _get_kinematical_structure_factor(
277277
# Set all atomic scattering factors to 1
278278
atomic_scattering_factor = np.ones((gspacing_squared.shape[0], coeffs.shape[0]))
279279

280+
return _numba_get_kinematical_structure_factor(
281+
structure.lattice.stdbase,
282+
structure.lattice.recbase,
283+
xyz,
284+
atomic_scattering_factor,
285+
occupancy,
286+
g_indices,
287+
gspacing_squared,
288+
dwfactors,
289+
)
290+
291+
292+
@njit(fastmath=True)
293+
def _numba_get_kinematical_structure_factor(
294+
stdbase: np.ndarray,
295+
recbase: np.ndarray,
296+
xyz: np.ndarray,
297+
atomic_scattering_factor: np.ndarray,
298+
occupancy: np.ndarray,
299+
g_indices: np.ndarray,
300+
g_spacing_squared: np.ndarray,
301+
dwfactors: np.ndarray,
302+
) -> np.ndarray:
280303
# Express the atom positions in the same reference frame as the
281304
# Miller indices
282-
mat = np.linalg.inv(np.dot(structure.lattice.stdbase, structure.lattice.recbase))
305+
mat = np.linalg.inv(np.dot(stdbase, recbase))
283306
xyz = np.dot(xyz, mat)
284307

285308
# Calculate the complex structure factor
@@ -288,11 +311,10 @@ def _get_kinematical_structure_factor(
288311
* occupancy
289312
* np.exp(
290313
2j * np.pi * np.dot(g_indices, xyz.T)
291-
- 0.25 * np.outer(gspacing_squared, dwfactors)
314+
- 0.25 * np.outer(g_spacing_squared, dwfactors)
292315
),
293316
axis=-1,
294317
)
295-
296318
return structure_factor
297319

298320

0 commit comments

Comments
 (0)