Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions src/tike/operators/torch/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
"""Wraps cupy operators in torch.autograd.Function"""

from .lamino import *

__all__ = [
LaminoFunction,
LaminoModule,
]
76 changes: 76 additions & 0 deletions src/tike/operators/torch/lamino.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
import cupy as cp
import torch

import tike.operators.cupy


class LaminoFunction(torch.autograd.Function):
"""The forward/adjoint laminography operations.

Parameters
----------
u : (N, N, N, 2) tensor float32
A (3 + 1)D tensor where the first dimensions are spatial dimensions and
the last dimension of len 2 is the real/imaginary components. Pytorch
doesn't presently have good complex-value support.
theta : (M, ) tensor float32
The rotation angles of the projections.
tilt : float
The laminography angle
output : (M, N, N, 2) float32
Projections through the volume at each rotation angle.

"""

@staticmethod
def forward(ctx, u, theta, tilt=cp.pi / 2):
ctx.n = u.shape[0]
ctx.tilt = tilt
ctx.save_for_backward(theta)
with tike.operators.cupy.Lamino(
n=ctx.n,
tilt=ctx.tilt,
eps=1e-6,
upsample=2,
) as operator:
output = operator.fwd(
u=cp.asarray(torch.view_as_complex(u).detach(),
dtype='complex64'),
theta=cp.asarray(theta, dtype='float32'),
)
output = torch.view_as_real(torch.as_tensor(output, device=u.device))
return output

@staticmethod
def backward(ctx, grad_output):
theta, = ctx.saved_tensors
with tike.operators.cupy.Lamino(
n=ctx.n,
tilt=ctx.tilt,
eps=1e-6,
upsample=2,
) as operator:
grad_u = operator.adj(
data=cp.asarray(torch.view_as_complex(grad_output),
dtype='complex64'),
theta=cp.asarray(theta, dtype='float32'),
) / grad_output.shape[0]
grad_u = torch.view_as_real(
torch.as_tensor(grad_u, device=grad_output.device))
grad_theta = grad_tilt = None
return grad_u, grad_theta, grad_tilt


class LaminoModule(torch.nn.Module):

def __init__(self, width):
super(LaminoModule, self).__init__()
self.width = width
self.weight = torch.nn.Parameter(
torch.zeros(width, width, width, 2, dtype=torch.float32))

def forward(self, theta, tilt=cp.pi / 2):
return LaminoFunction.apply(self.weight, theta, tilt)

def extra_repr(self):
return f'width={self.width}'
136 changes: 136 additions & 0 deletions tests/operators/torch/test_lamino.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
import lzma
import os
import pickle
import unittest

import cupy as cp
import numpy as np
import torch
from torch.nn.modules.loss import GaussianNLLLoss

from tike.operators.torch import LaminoFunction, LaminoModule


@unittest.skip('single precision is not enough to pass gradcheck')
def test_lamino_gradcheck(n=16, ntheta=8):

lamino = LaminoFunction.apply

# gradcheck takes a tuple of tensors as input, check if your gradient
# evaluated with these tensors are close enough to numerical
# approximations and returns True if they all verify this condition.
input = (
torch.randn(
n,
n,
n,
2,
dtype=torch.float32,
requires_grad=True,
device='cpu',
),
cp.pi * torch.randn(
ntheta,
dtype=torch.float32,
requires_grad=False,
device='cpu',
),
)
test = torch.autograd.gradcheck(
lamino,
input,
eps=1e-6,
atol=1e-4,
nondet_tol=1e-6,
)
print(test)


testdir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))


class L2Loss(torch.nn.Module):

def forward(self, input, target):
return torch.mean(torch.square(torch.abs(input - target)))


class TestLaminoModel(unittest.TestCase):

def setUp(self):
"""Load a dataset for reconstruction."""
dataset_file = os.path.join(testdir, 'data/lamino_setup.pickle.lzma')
if not os.path.isfile(dataset_file):
self.create_dataset(dataset_file)
with lzma.open(dataset_file, 'rb') as file:
[
self.data,
self.original,
self.theta,
self.tilt,
] = pickle.load(file)

def test_lamino_model(self, num_epoch=32, device=0):

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

theta = torch.from_numpy(self.theta).type(torch.float32).to(device)
data = torch.view_as_real(
torch.from_numpy(self.data).type(torch.complex64)).to(device)
var = torch.ones(data.shape, dtype=torch.float32,
requires_grad=True).to(device)

model = LaminoModule(data.shape[1]).to(device)
lossf = GaussianNLLLoss().to(device)
optimizer = torch.optim.Adam(model.parameters())

loss_log = []
for epoch in range(num_epoch):
pred = model(theta, self.tilt)
loss = lossf(pred, data, var)
optimizer.zero_grad()
loss.backward()
optimizer.step()
loss_log.append(loss.item())
print(f"loss: {loss_log[-1]:.3e} [{epoch:>5d}/{num_epoch:>5d}]")

obj = torch.view_as_complex(model.weight.cpu().detach()).numpy()

_save_lamino_result({'obj': obj, 'costs': loss_log}, 'torch')


def _save_lamino_result(result, algorithm):
try:
import matplotlib.pyplot as plt
fname = os.path.join(testdir, 'result', 'lamino', f'{algorithm}')
os.makedirs(fname, exist_ok=True)
plt.figure()
plt.title(algorithm)
plt.plot(result['costs'])
plt.semilogy()
plt.savefig(os.path.join(fname, 'convergence.svg'))
slice_id = int(35 / 128 * result['obj'].shape[0])
plt.imsave(
f'{fname}/{slice_id}-phase.png',
np.angle(result['obj'][slice_id]).astype('float32'),
# The output of np.angle is locked to (-pi, pi]
cmap=plt.cm.twilight,
vmin=-np.pi,
vmax=np.pi,
)
plt.imsave(
f'{fname}/{slice_id}-ampli.png',
np.abs(result['obj'][slice_id]).astype('float32'),
)
import skimage.io
skimage.io.imsave(
f'{fname}/phase.tiff',
np.angle(result['obj']).astype('float32'),
)
skimage.io.imsave(
f'{fname}/ampli.tiff',
np.abs(result['obj']).astype('float32'),
)

except ImportError:
pass
5 changes: 5 additions & 0 deletions tests/test_lamino.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,11 @@ def _save_lamino_result(result, algorithm):
import matplotlib.pyplot as plt
fname = os.path.join(testdir, 'result', 'lamino', f'{algorithm}')
os.makedirs(fname, exist_ok=True)
plt.figure()
plt.title(algorithm)
plt.plot(result['cost'])
plt.semilogy()
plt.savefig(os.path.join(fname, 'convergence.svg'))
slice_id = int(35 / 128 * result['obj'].shape[0])
plt.imsave(
f'{fname}/{slice_id}-phase.png',
Expand Down