Skip to content
82 changes: 64 additions & 18 deletions src/tike/ptycho/probe.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
import cupy as cp
import cupyx.scipy.ndimage
import numpy as np
import scipy.signal

import tike.linalg
import tike.random
Expand Down Expand Up @@ -85,6 +86,12 @@ class ProbeOptions:
probe_support_degree: float = 2.5
"""Degree of the supergaussian defining the probe support; zero or greater."""

weights_smooth: float = 0.5
"""Relative weight of the polynomial term in variable probe smoothing. [0.0, 1.0]"""

weights_smooth_order: int = 0
"""Highest degree of variable probe smoothing polynomial."""

def copy_to_device(self):
"""Copy to the current GPU memory."""
if self.v is not None:
Expand Down Expand Up @@ -136,7 +143,7 @@ def get_varying_probe(shared_probe, eigen_probe=None, weights=None):
return shared_probe.copy()


def _constrain_variable_probe1(variable_probe, weights):
def _constrain_variable_probe1(probe, variable_probe, weights):
"""Help use the thread pool with constrain_variable_probe"""

# Normalize variable probes
Expand All @@ -146,11 +153,18 @@ def _constrain_variable_probe1(variable_probe, weights):
weights[..., 1:, :probes_with_modes] *= vnorm[..., 0, 0]

# Orthogonalize variable probes

# Make variable probes orthogonal to main probe too
all_probes = cp.concatenate(
[probe[..., :variable_probe.shape[-3], :, :], variable_probe],
axis=-4,
)

variable_probe = tike.linalg.orthogonalize_gs(
variable_probe,
all_probes,
axis=(-2, -1),
N=-4,
)
)[..., 1:, :, :, :]

# Compute probe energy in order to sort probes by energy
power = tike.linalg.norm(
Expand Down Expand Up @@ -187,7 +201,12 @@ def _constrain_variable_probe2(variable_probe, weights, power):
return variable_probe, weights


def constrain_variable_probe(comm, variable_probe, weights):
def _split(m, x, dtype):
    """Return the rows of *x* selected by indices *m* as a device array.

    Helper for distributing host-side weight arrays back across the thread
    pool: each worker receives only its own slice, converted to *dtype*.
    """
    selected = x[m]
    return cp.asarray(selected, dtype=dtype)


def constrain_variable_probe(comm, probe, variable_probe, weights,
probe_options):
"""Add the following constraints to variable probe weights

1. Remove outliers from weights
Expand All @@ -196,13 +215,35 @@ def constrain_variable_probe(comm, variable_probe, weights):
4. Normalize the variable probes so the energy is contained in the weight

"""
# TODO: No smoothing of variable probe weights yet because the weights are
# not stored consecutively in device memory. Smoothing would require either
# sorting and synchronizing the weights with the host OR implementing
# smoothing of non-gridded data with splines using device-local data only.
reorder = np.argsort(np.concatenate(comm.order))
hweights = comm.pool.gather(
weights,
axis=-3,
)[reorder].get()

if probe_options.weights_smooth_order > 0:
logger.info("Smoothing variable probe weights with "
f"{probe_options.weights_smooth_order} order polynomials.")
# Fit the 1st order and higher variable probe weights to a polynomial
modelx = np.arange(len(hweights))
for i in range(1, hweights.shape[-2]):
fitted_weights = np.polynomial.Polynomial.fit(
x=modelx,
y=hweights[..., i, 0],
deg=probe_options.weights_smooth_order,
)(modelx)
hweights[
..., i, 0] = (1.0 - probe_options.weights_smooth) * hweights[
..., i, 0] + probe_options.weights_smooth * fitted_weights

# Force weights to have zero mean
# hweights[..., 1:, :] -= np.mean(hweights[..., 1:, :], axis=-3, keepdims=True)

weights = comm.pool.map(_split, comm.order, x=hweights, dtype='float32')

variable_probe, weights, power = zip(*comm.pool.map(
_constrain_variable_probe1,
probe,
variable_probe,
weights,
))
Expand Down Expand Up @@ -238,7 +279,7 @@ def _get_update(R, eigen_probe, weights, *, c, m):
)


def _get_d(patches, diff, eigen_probe, update, *, β, c, m):
def _get_d(patches, diff, eigen_probe, update, weights, *, β, c, m):
eigen_probe[..., c - 1:c, m:m + 1, :, :] += β * update / tike.linalg.mnorm(
update,
axis=(-2, -1),
Expand Down Expand Up @@ -354,6 +395,7 @@ def update_eigen_probe(
diff,
eigen_probe,
update,
weights,
β=β,
c=c,
m=m,
Expand Down Expand Up @@ -438,6 +480,14 @@ def add_modes_random_phase(probe, nmodes):
return all_modes


def _pyramid():
rank = 0
while True:
for i in zip(range(rank+1), range(rank, -1, -1)):
yield i
rank = rank + 1


def add_modes_cartesian_hermite(probe, nmodes: int):
"""Create more probes from a 2D Cartesian Hermite basis functions.

Expand Down Expand Up @@ -469,9 +519,6 @@ def add_modes_cartesian_hermite(probe, nmodes: int):
raise ValueError(f"probe is incorrect shape is should be "
" (..., 1, W, new_probes) not {probe.shape}.")

M = int(np.ceil(np.sqrt(nmodes)))
N = int(np.ceil(nmodes / M))

X, Y = np.meshgrid(
np.arange(probe.shape[-2]) - (probe.shape[-2] // 2 - 1),
np.arange(probe.shape[-1]) - (probe.shape[-2] // 2 - 1),
Expand Down Expand Up @@ -518,8 +565,7 @@ def add_modes_cartesian_hermite(probe, nmodes: int):

# Create basis
new_probes = list()
for nii in range(N):
for mii in range(M):
for nii, mii in _pyramid():

basis = ((X - cenx)**mii) * ((Y - ceny)**nii) * probe

Expand Down Expand Up @@ -664,10 +710,10 @@ def orthogonalize_eig(x):
# descending order.
vectors = vectors[..., ::-1].swapaxes(-1, -2)
result = (vectors @ x.reshape(*x.shape[:-2], -1)).reshape(*x.shape)
assert np.all(
np.diff(tike.linalg.norm(result, axis=(-2, -1), keepdims=False),
axis=-1) <= 0
), f"Power of the orthogonalized probes should be monotonically decreasing! {val}"
# assert np.all(
# np.diff(tike.linalg.norm(result, axis=(-2, -1), keepdims=False),
# axis=-1) <= 0
# ), f"Power of the orthogonalized probes should be monotonically decreasing! {val}"
return result


Expand Down
29 changes: 21 additions & 8 deletions src/tike/ptycho/ptycho.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,7 +404,7 @@ def iterate(self, num_iter: int) -> None:
parameters=self._device_parameters,
)

if self._device_parameters.object_options.clip_magnitude:
if self._device_parameters.object_options and self._device_parameters.object_options.clip_magnitude:
self._device_parameters.psi = self.comm.pool.map(
_clip_magnitude,
self._device_parameters.psi,
Expand Down Expand Up @@ -558,11 +558,13 @@ def append_new_data(
new_scan,
axis=0,
)
assert len(self._device_parameters.scan[0]) == len(self.data[0])
self.comm.order = self.comm.pool.map(
_order_join,
self.comm.order,
order,
)
assert len(self.comm.order[0]) == len(self.data[0])

# Rebatch on each device
self.batches = self.comm.pool.map(
Expand All @@ -574,22 +576,33 @@ def append_new_data(

if self._device_parameters.eigen_weights is not None:
self._device_parameters.eigen_weights = self.comm.pool.map(
cp.pad,
_pad_weights,
self._device_parameters.eigen_weights,
pad_width=(
(0, len(new_scan)), # position
(0, 0), # eigen
(0, 0), # shared
),
mode='mean',
new_scan,
)
assert len(self._device_parameters.eigen_weights[0]) == len(
self.data[0])

if self._device_parameters.position_options is not None:
self._device_parameters.position_options = self.comm.pool.map(
PositionOptions.append,
self._device_parameters.position_options,
new_scan,
)
assert len(self._device_parameters.position_options[0].initial_scan
== len(self.data[0]))


def _pad_weights(weights, new_scan):
    """Extend the position axis of *weights* to cover newly appended scans.

    One new row per entry of *new_scan* is appended along the leading
    (position) axis; new rows are filled with the mean of the existing
    values (``mode='mean'``). The eigen and shared axes are left unchanged.
    """
    num_new = len(new_scan)
    padding = (
        (0, num_new),  # position
        (0, 0),  # eigen
        (0, 0),  # shared
    )
    return cp.pad(weights, pad_width=padding, mode='mean')


def _order_join(a, b):
Expand Down
93 changes: 50 additions & 43 deletions src/tike/ptycho/solvers/lstsq.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,13 +69,34 @@ def lstsq_grad(
eigen_probe = parameters.eigen_probe
eigen_weights = parameters.eigen_weights

if probe_options and probe_options.orthogonality_constraint:
probe = comm.pool.map(tike.ptycho.probe.orthogonalize_eig, probe)

if object_options:
psi = comm.pool.map(positivity_constraint,
psi,
r=object_options.positivity_constraint)

psi = comm.pool.map(smoothness_constraint,
psi,
a=object_options.smoothness_constraint)

if eigen_probe is not None:
eigen_probe, eigen_weights = tike.ptycho.probe.constrain_variable_probe(
comm,
probe,
eigen_probe,
eigen_weights,
probe_options,
)

if eigen_probe is None:
beigen_probe = [None] * comm.pool.num_workers
else:
beigen_probe = eigen_probe

preconditioner = [None] * comm.pool.num_workers
if object_options is not None:
preconditioner = [None] * comm.pool.num_workers
for n in range(len(batches[0])):
bscan = comm.pool.map(tike.opt.get_batch, scan, batches, n=n)
preconditioner = comm.pool.map(
Expand All @@ -92,19 +113,18 @@ def lstsq_grad(
preconditioner = comm.pool.allreduce(preconditioner)
# Use a rolling average of this preconditioner and the previous
# preconditioner
if object_options.preconditioner is None:
object_options.preconditioner = preconditioner
else:
object_options.preconditioner = comm.pool.map(
if object_options.preconditioner is not None:
preconditioner = comm.pool.map(
cp.add,
object_options.preconditioner,
preconditioner,
)
object_options.preconditioner = comm.pool.map(
preconditioner = comm.pool.map(
cp.divide,
object_options.preconditioner,
preconditioner,
[2] * comm.pool.num_workers,
)
object_options.preconditioner = preconditioner

if algorithm_options.batch_method == 'cluster_compact':
object_options.combined_update = cp.zeros_like(psi[0])
Expand Down Expand Up @@ -176,7 +196,7 @@ def lstsq_grad(
probe_options is not None,
bposition_options,
num_batch=algorithm_options.num_batch,
psi_update_denominator=object_options.preconditioner,
psi_update_denominator=preconditioner,
object_options=object_options,
probe_options=probe_options,
algorithm_options=algorithm_options,
Expand Down Expand Up @@ -225,25 +245,6 @@ def lstsq_grad(
psi[0] = psi[0] + dpsi
psi = comm.pool.bcast([psi[0]])

if probe_options and probe_options.orthogonality_constraint:
probe = comm.pool.map(tike.ptycho.probe.orthogonalize_eig, probe)

if object_options:
psi = comm.pool.map(positivity_constraint,
psi,
r=object_options.positivity_constraint)

psi = comm.pool.map(smoothness_constraint,
psi,
a=object_options.smoothness_constraint)

if eigen_probe is not None:
eigen_probe, eigen_weights = tike.ptycho.probe.constrain_variable_probe(
comm,
beigen_probe,
eigen_weights,
)

algorithm_options.costs.append(batch_cost)
parameters.probe = probe
parameters.psi = psi
Expand Down Expand Up @@ -418,23 +419,27 @@ def _update_nearplane(
recover_probe=recover_probe,
)))
if comm.use_mpi:
weighted_step_psi[0] = comm.Allreduce_mean(
weighted_step_psi,
axis=-5,
)[..., 0, 0, 0]
weighted_step_probe[0] = comm.Allreduce_mean(
weighted_step_probe,
axis=-5,
)
if recover_psi:
weighted_step_psi[0] = comm.Allreduce_mean(
weighted_step_psi,
axis=-5,
)[..., 0, 0, 0]
if recover_probe:
weighted_step_probe[0] = comm.Allreduce_mean(
weighted_step_probe,
axis=-5,
)
else:
weighted_step_psi[0] = comm.pool.reduce_mean(
weighted_step_psi,
axis=-5,
)[..., 0, 0, 0]
weighted_step_probe[0] = comm.pool.reduce_mean(
weighted_step_probe,
axis=-5,
)
if recover_psi:
weighted_step_psi[0] = comm.pool.reduce_mean(
weighted_step_psi,
axis=-5,
)[..., 0, 0, 0]
if recover_probe:
weighted_step_probe[0] = comm.pool.reduce_mean(
weighted_step_probe,
axis=-5,
)

if m == 0 and recover_probe and eigen_weights[0] is not None:
logger.info('Updating eigen probes')
Expand Down Expand Up @@ -732,6 +737,8 @@ def _get_nearplane_steps(diff, dOP, dPO, A1, A4, recover_psi, recover_probe):

# (27b) Object update
weighted_step_psi = cp.mean(step, keepdims=True, axis=-5)
else:
weighted_step_psi = None

if recover_probe:
step = 0.9 * cp.maximum(0, x2[..., None, None].real)
Expand Down
Loading