diff --git a/src/tike/ptycho/probe.py b/src/tike/ptycho/probe.py index c5304e05..6d437f28 100644 --- a/src/tike/ptycho/probe.py +++ b/src/tike/ptycho/probe.py @@ -41,6 +41,7 @@ import cupy as cp import cupyx.scipy.ndimage import numpy as np +import scipy.signal import tike.linalg import tike.random @@ -85,6 +86,12 @@ class ProbeOptions: probe_support_degree: float = 2.5 """Degree of the supergaussian defining the probe support; zero or greater.""" + weights_smooth: float = 0.5 + """Relative weight of the polynomial term in variable probe smoothing. [0.0, 1.0]""" + + weights_smooth_order: int = 0 + """Highest degree of variable probe smoothing polynomial.""" + def copy_to_device(self): """Copy to the current GPU memory.""" if self.v is not None: @@ -136,7 +143,7 @@ def get_varying_probe(shared_probe, eigen_probe=None, weights=None): return shared_probe.copy() -def _constrain_variable_probe1(variable_probe, weights): +def _constrain_variable_probe1(probe, variable_probe, weights): """Help use the thread pool with constrain_variable_probe""" # Normalize variable probes @@ -146,11 +153,18 @@ def _constrain_variable_probe1(variable_probe, weights): weights[..., 1:, :probes_with_modes] *= vnorm[..., 0, 0] # Orthogonalize variable probes + + # Make variable probes orthogonal to main probe too + all_probes = cp.concatenate( + [probe[..., :variable_probe.shape[-3], :, :], variable_probe], + axis=-4, + ) + variable_probe = tike.linalg.orthogonalize_gs( - variable_probe, + all_probes, axis=(-2, -1), N=-4, - ) + )[..., 1:, :, :, :] # Compute probe energy in order to sort probes by energy power = tike.linalg.norm( @@ -187,7 +201,12 @@ def _constrain_variable_probe2(variable_probe, weights, power): return variable_probe, weights -def constrain_variable_probe(comm, variable_probe, weights): +def _split(m, x, dtype): + return cp.asarray(x[m], dtype=dtype) + + +def constrain_variable_probe(comm, probe, variable_probe, weights, + probe_options): """Add the following constraints to variable probe weights 1. Remove outliars from weights @@ -196,13 +215,35 @@ def constrain_variable_probe(comm, variable_probe, weights): 4. Normalize the variable probes so the energy is contained in the weight """ - # TODO: No smoothing of variable probe weights yet because the weights are - # not stored consecutively in device memory. Smoothing would require either - # sorting and synchronizing the weights with the host OR implementing - # smoothing of non-gridded data with splines using device-local data only. + reorder = np.argsort(np.concatenate(comm.order)) + hweights = comm.pool.gather( + weights, + axis=-3, + )[reorder].get() + + if probe_options.weights_smooth_order > 0: + logger.info("Smoothing variable probe weights with " + f"{probe_options.weights_smooth_order} order polynomials.") + # Fit the 1st order and higher variable probe weights to a polynomial + modelx = np.arange(len(hweights)) + for i in range(1, hweights.shape[-2]): + fitted_weights = np.polynomial.Polynomial.fit( + x=modelx, + y=hweights[..., i, 0], + deg=probe_options.weights_smooth_order, + )(modelx) + hweights[ + ..., i, 0] = (1.0 - probe_options.weights_smooth) * hweights[ + ..., i, 0] + probe_options.weights_smooth * fitted_weights + + # Force weights to have zero mean + # hweights[..., 1:, :] -= np.mean(hweights[..., 1:, :], axis=-3, keepdims=True) + + weights = comm.pool.map(_split, comm.order, x=hweights, dtype='float32') variable_probe, weights, power = zip(*comm.pool.map( _constrain_variable_probe1, + probe, variable_probe, weights, )) @@ -238,7 +279,7 @@ def _get_update(R, eigen_probe, weights, *, c, m): ) -def _get_d(patches, diff, eigen_probe, update, *, β, c, m): +def _get_d(patches, diff, eigen_probe, update, weights, *, β, c, m): eigen_probe[..., c - 1:c, m:m + 1, :, :] += β * update / tike.linalg.mnorm( update, axis=(-2, -1), @@ -354,6 +395,7 @@ def update_eigen_probe( diff, eigen_probe, update, + weights, β=β, c=c, m=m, @@ -438,6 +480,14 @@ def add_modes_random_phase(probe, nmodes): return all_modes +def _pyramid(): + rank = 0 + while True: + for i in zip(range(rank+1), range(rank, -1, -1)): + yield i + rank = rank + 1 + + def add_modes_cartesian_hermite(probe, nmodes: int): """Create more probes from a 2D Cartesian Hermite basis functions. @@ -469,9 +519,6 @@ def add_modes_cartesian_hermite(probe, nmodes: int): raise ValueError(f"probe is incorrect shape is should be " " (..., 1, W, new_probes) not {probe.shape}.") - M = int(np.ceil(np.sqrt(nmodes))) - N = int(np.ceil(nmodes / M)) - X, Y = np.meshgrid( np.arange(probe.shape[-2]) - (probe.shape[-2] // 2 - 1), np.arange(probe.shape[-1]) - (probe.shape[-2] // 2 - 1), @@ -518,8 +565,7 @@ def add_modes_cartesian_hermite(probe, nmodes: int): # Create basis new_probes = list() - for nii in range(N): - for mii in range(M): + for nii, mii in _pyramid(): basis = ((X - cenx)**mii) * ((Y - ceny)**nii) * probe @@ -664,10 +710,10 @@ def orthogonalize_eig(x): # descending order. vectors = vectors[..., ::-1].swapaxes(-1, -2) result = (vectors @ x.reshape(*x.shape[:-2], -1)).reshape(*x.shape) - assert np.all( - np.diff(tike.linalg.norm(result, axis=(-2, -1), keepdims=False), - axis=-1) <= 0 - ), f"Power of the orthogonalized probes should be monotonically decreasing! {val}" + # assert np.all( + # np.diff(tike.linalg.norm(result, axis=(-2, -1), keepdims=False), + # axis=-1) <= 0 + # ), f"Power of the orthogonalized probes should be monotonically decreasing! {val}" return result diff --git a/src/tike/ptycho/ptycho.py b/src/tike/ptycho/ptycho.py index 1a4cb592..f0a3796c 100644 --- a/src/tike/ptycho/ptycho.py +++ b/src/tike/ptycho/ptycho.py @@ -404,7 +404,7 @@ def iterate(self, num_iter: int) -> None: parameters=self._device_parameters, ) - if self._device_parameters.object_options.clip_magnitude: + if self._device_parameters.object_options and self._device_parameters.object_options.clip_magnitude: self._device_parameters.psi = self.comm.pool.map( _clip_magnitude, self._device_parameters.psi, @@ -558,11 +558,13 @@ def append_new_data( new_scan, axis=0, ) + assert len(self._device_parameters.scan[0]) == len(self.data[0]) self.comm.order = self.comm.pool.map( _order_join, self.comm.order, order, ) + assert len(self.comm.order[0]) == len(self.data[0]) # Rebatch on each device self.batches = self.comm.pool.map( @@ -574,15 +576,12 @@ def append_new_data( if self._device_parameters.eigen_weights is not None: self._device_parameters.eigen_weights = self.comm.pool.map( - cp.pad, + _pad_weights, self._device_parameters.eigen_weights, - pad_width=( - (0, len(new_scan)), # position - (0, 0), # eigen - (0, 0), # shared - ), - mode='mean', + new_scan, ) + assert len(self._device_parameters.eigen_weights[0]) == len( + self.data[0]) if self._device_parameters.position_options is not None: self._device_parameters.position_options = self.comm.pool.map( @@ -590,6 +589,20 @@ def append_new_data( self._device_parameters.position_options, new_scan, ) + assert len(self._device_parameters.position_options[0].initial_scan + == len(self.data[0])) + + +def _pad_weights(weights, new_scan): + return cp.pad( + weights, + pad_width=( + (0, len(new_scan)), # position + (0, 0), # eigen + (0, 0), # shared + ), + mode='mean', + ) def _order_join(a, b): diff --git a/src/tike/ptycho/solvers/lstsq.py b/src/tike/ptycho/solvers/lstsq.py index 3929e194..3feebeab 100644 --- a/src/tike/ptycho/solvers/lstsq.py +++ b/src/tike/ptycho/solvers/lstsq.py @@ -69,13 +69,34 @@ def lstsq_grad( eigen_probe = parameters.eigen_probe eigen_weights = parameters.eigen_weights + if probe_options and probe_options.orthogonality_constraint: + probe = comm.pool.map(tike.ptycho.probe.orthogonalize_eig, probe) + + if object_options: + psi = comm.pool.map(positivity_constraint, + psi, + r=object_options.positivity_constraint) + + psi = comm.pool.map(smoothness_constraint, + psi, + a=object_options.smoothness_constraint) + + if eigen_probe is not None: + eigen_probe, eigen_weights = tike.ptycho.probe.constrain_variable_probe( + comm, + probe, + eigen_probe, + eigen_weights, + probe_options, + ) + if eigen_probe is None: beigen_probe = [None] * comm.pool.num_workers else: beigen_probe = eigen_probe + preconditioner = [None] * comm.pool.num_workers if object_options is not None: - preconditioner = [None] * comm.pool.num_workers for n in range(len(batches[0])): bscan = comm.pool.map(tike.opt.get_batch, scan, batches, n=n) preconditioner = comm.pool.map( @@ -92,19 +113,18 @@ def lstsq_grad( preconditioner = comm.pool.allreduce(preconditioner) # Use a rolling average of this preconditioner and the previous # preconditioner - if object_options.preconditioner is None: - object_options.preconditioner = preconditioner - else: - object_options.preconditioner = comm.pool.map( + if object_options.preconditioner is not None: + preconditioner = comm.pool.map( cp.add, object_options.preconditioner, preconditioner, ) - object_options.preconditioner = comm.pool.map( + preconditioner = comm.pool.map( cp.divide, - object_options.preconditioner, + preconditioner, [2] * comm.pool.num_workers, ) + object_options.preconditioner = preconditioner if algorithm_options.batch_method == 'cluster_compact': object_options.combined_update = cp.zeros_like(psi[0]) @@ -176,7 +196,7 @@ def lstsq_grad( probe_options is not None, bposition_options, num_batch=algorithm_options.num_batch, - psi_update_denominator=object_options.preconditioner, + psi_update_denominator=preconditioner, object_options=object_options, probe_options=probe_options, algorithm_options=algorithm_options, @@ -225,25 +245,6 @@ def lstsq_grad( psi[0] = psi[0] + dpsi psi = comm.pool.bcast([psi[0]]) - if probe_options and probe_options.orthogonality_constraint: - probe = comm.pool.map(tike.ptycho.probe.orthogonalize_eig, probe) - - if object_options: - psi = comm.pool.map(positivity_constraint, - psi, - r=object_options.positivity_constraint) - - psi = comm.pool.map(smoothness_constraint, - psi, - a=object_options.smoothness_constraint) - - if eigen_probe is not None: - eigen_probe, eigen_weights = tike.ptycho.probe.constrain_variable_probe( - comm, - beigen_probe, - eigen_weights, - ) - algorithm_options.costs.append(batch_cost) parameters.probe = probe parameters.psi = psi @@ -418,23 +419,27 @@ def _update_nearplane( recover_probe=recover_probe, ))) if comm.use_mpi: - weighted_step_psi[0] = comm.Allreduce_mean( - weighted_step_psi, - axis=-5, - )[..., 0, 0, 0] - weighted_step_probe[0] = comm.Allreduce_mean( - weighted_step_probe, - axis=-5, - ) + if recover_psi: + weighted_step_psi[0] = comm.Allreduce_mean( + weighted_step_psi, + axis=-5, + )[..., 0, 0, 0] + if recover_probe: + weighted_step_probe[0] = comm.Allreduce_mean( + weighted_step_probe, + axis=-5, + ) else: - weighted_step_psi[0] = comm.pool.reduce_mean( - weighted_step_psi, - axis=-5, - )[..., 0, 0, 0] - weighted_step_probe[0] = comm.pool.reduce_mean( - weighted_step_probe, - axis=-5, - ) + if recover_psi: + weighted_step_psi[0] = comm.pool.reduce_mean( + weighted_step_psi, + axis=-5, + )[..., 0, 0, 0] + if recover_probe: + weighted_step_probe[0] = comm.pool.reduce_mean( + weighted_step_probe, + axis=-5, + ) if m == 0 and recover_probe and eigen_weights[0] is not None: logger.info('Updating eigen probes') @@ -732,6 +737,8 @@ def _get_nearplane_steps(diff, dOP, dPO, A1, A4, recover_psi, recover_probe): # (27b) Object update weighted_step_psi = cp.mean(step, keepdims=True, axis=-5) + else: + weighted_step_psi = None if recover_probe: step = 0.9 * cp.maximum(0, x2[..., None, None].real) diff --git a/src/tike/view.py b/src/tike/view.py index 1350ec3d..918c8615 100644 --- a/src/tike/view.py +++ b/src/tike/view.py @@ -525,10 +525,19 @@ def plot_cost_convergence(costs, times): ax2 = ax1.twiny() color = 'red' - ax2.set_xlabel('wall-time [s]', color=color) + ax2.loglog() ax2.plot(np.cumsum(times), costs, color=color, alpha=alpha) ax2.tick_params(axis='x', labelcolor=color) + try: + import humanize + def humanize_time(tick_val, tick_pos): + return humanize.naturaldelta(tick_val) + ax2.xaxis.set_major_formatter(humanize_time) + ax2.set_xlabel('wall-time', color=color) + except ImportError: + ax2.set_xlabel('wall-time [s]', color=color) + return ax1, ax2 diff --git a/tests/matlab/variable_intensity_input0.mat b/tests/matlab/variable_intensity_input0.mat new file mode 100644 index 00000000..e38d472e Binary files /dev/null and b/tests/matlab/variable_intensity_input0.mat differ diff --git a/tests/matlab/variable_intensity_input1.mat b/tests/matlab/variable_intensity_input1.mat new file mode 100644 index 00000000..c9af3fad Binary files /dev/null and b/tests/matlab/variable_intensity_input1.mat differ diff --git a/tests/matlab/variable_intensity_output.mat b/tests/matlab/variable_intensity_output.mat new file mode 100644 index 00000000..068ef7c9 Binary files /dev/null and b/tests/matlab/variable_intensity_output.mat differ diff --git a/tests/matlab/variable_intensity_test.m b/tests/matlab/variable_intensity_test.m new file mode 100644 index 00000000..8dbf0b57 --- /dev/null +++ b/tests/matlab/variable_intensity_test.m @@ -0,0 +1,32 @@ +import engines.GPU.GPU_wrapper.* + +pd = makedist('Normal'); + +probe = cast(random(pd, [236, 236]) + 1i * random(pd, [236, 236]), 'single'); +obj_proj_tmp = cast(random(pd, [236, 236, 120]) + 1i * random(pd, [236, 236, 120]), 'single'); +chi_tmp = cast(random(pd, [236, 236, 120]) + 1i * random(pd, [236, 236, 120]), 'single'); +g_ind_tmp = cast(1:120, 'int32'); +probe_evolution = cast(random(pd, [120, 1]), 'single'); +kk = 1; + +save('variable_intensity_input0.mat', 'probe', 'obj_proj_tmp', '-v7.3'); +save('variable_intensity_input1.mat', 'chi_tmp', 'g_ind_tmp', 'probe_evolution', '-v7.3'); + +% correction to account for variable intensity +mean_probe = probe(:,:,kk,1); +% compare P*0 and chi to estimate best update of the intensity +[nom, denom] = get_coefs_intensity(chi_tmp, mean_probe, obj_proj_tmp); + +probe_evolution(g_ind_tmp,1) = probe_evolution(g_ind_tmp,1) + 0.1 * squeeze(Ggather(sum2(nom)./ sum2(denom))); + +save('variable_intensity_output.mat', 'probe_evolution', 'nom', 'denom', '-v7.3'); + +function [nom1, denom1] = get_coefs_intensity(xi, P, O) + OP = O.*P; + nom1 = real(conj(OP) .* xi); + denom1 = abs(OP).^2; +end + +function x = sum2(x) + x = sum(sum(x,1),2); +end \ No newline at end of file diff --git a/tests/matlab/variable_intensity_test.py b/tests/matlab/variable_intensity_test.py new file mode 100644 index 00000000..6174aeac --- /dev/null +++ b/tests/matlab/variable_intensity_test.py @@ -0,0 +1,54 @@ +import os.path +import cupy as cp +import numpy as np +import matplotlib.pyplot as plt +import h5py + +_dir = os.path.dirname(__file__) + +def test_variable_intensity(): + """Test that the variable intensity coefficient update is consistent.""" + + from tike.ptycho.solvers.lstsq import _get_coefs_intensity + + + with ( + h5py.File(os.path.join(_dir, 'variable_intensity_input0.mat'), 'r') as input0, + h5py.File(os.path.join(_dir, 'variable_intensity_input1.mat'), 'r') as input1, + h5py.File(os.path.join(_dir, 'variable_intensity_output.mat'), 'r') as output, + ): + + P = input0['probe'][...][None, None, None].view('complex64') + O = input0['obj_proj_tmp'][...][:, None, None].view('complex64') + xi = input1['chi_tmp'][...][:, None, None].view('complex64') + weights = input1['probe_evolution'][...].transpose()[..., None] + m = 0 + + ref_weights = output['probe_evolution'][...].transpose()[..., None] + + assert weights.shape == (120, 1, 1) + assert ref_weights.shape == (120, 1, 1) + assert xi.shape == (120, 1, 1, 236, 236) + assert P.shape == (1, 1, 1, 236, 236) + assert O.shape == (120, 1, 1, 236, 236) + + new_weights = cp.asnumpy(_get_coefs_intensity( + cp.asarray(weights), + cp.asarray(xi), + cp.asarray(P), + cp.asarray(O), + m, + )) + + plt.figure() + width = 0.5 + x = np.arange(120) + plt.bar(x - width, ref_weights[:,0,0], width) + plt.bar(x, new_weights[:,0,0], width) + plt.legend(['ptychoshelves', 'tike']) + plt.savefig(os.path.join(_dir, 'variable_intensity.svg')) + + np.testing.assert_allclose( + ref_weights, + new_weights, + ) diff --git a/tests/ptycho/test_ptycho.py b/tests/ptycho/test_ptycho.py index fd354d6e..8ab47c66 100644 --- a/tests/ptycho/test_ptycho.py +++ b/tests/ptycho/test_ptycho.py @@ -635,7 +635,6 @@ def _save_ptycho_result(result, algorithm): result.algorithm_options.costs, result.algorithm_options.times, ) - ax2.set_xlim(0, 20) fig.suptitle(algorithm) fig.tight_layout() plt.savefig(os.path.join(fname, 'convergence.svg'))