diff --git a/wavemix/__init__.py b/wavemix/__init__.py
index 414262b..ce86895 100644
--- a/wavemix/__init__.py
+++ b/wavemix/__init__.py
@@ -8,8 +8,6 @@
 from einops.layers.torch import Rearrange
 
 
-device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-
 def sfb1d(lo, hi, g0, g1, mode='zero', dim=-1):
     """ 1D synthesis filter bank of an image tensor
     """
@@ -57,6 +55,7 @@ def sfb1d(lo, hi, g0, g1, mode='zero', dim=-1):
 
     return y
 
+
 def reflect(x, minx, maxx):
     """Reflect the values in matrix *x* about the scalar values *minx* and
     *maxx*. Hence a vector *x* containing a long linearly increasing series is
@@ -74,6 +73,7 @@ def reflect(x, minx, maxx):
     out = np.where(normed_mod >= rng, rng_by_2 - normed_mod, normed_mod) + minx
     return np.array(out, dtype=x.dtype)
 
+
 def mode_to_int(mode):
     if mode == 'zero':
         return 0
@@ -92,6 +92,7 @@ def mode_to_int(mode):
     else:
         raise ValueError("Unkown pad type: {}".format(mode))
 
+
 def int_to_mode(mode):
     if mode == 0:
         return 'zero'
@@ -110,6 +111,7 @@ def int_to_mode(mode):
     else:
         raise ValueError("Unkown pad type: {}".format(mode))
 
+
 def afb1d(x, h0, h1, mode='zero', dim=-1):
     """ 1D analysis filter bank (along one dimension only) of an image
     Inputs:
@@ -160,7 +162,7 @@ def afb1d(x, h0, h1, mode='zero', dim=-1):
             N += 1
         x = roll(x, -L2, dim=d)
         pad = (L-1, 0) if d == 2 else (0, L-1)
-        lohi = F.conv2d(x, h, padding=pad, stride=s, groups=C)
+        lohi = F.conv2d(x, h.to(x.device), padding=pad, stride=s, groups=C)
         N2 = N//2
         if d == 2:
             lohi[:,:,:L2] = lohi[:,:,:L2] + lohi[:,:,N2:N2+L2]
@@ -181,18 +183,17 @@ def afb1d(x, h0, h1, mode='zero', dim=-1):
                 x = F.pad(x, pad)
             pad = (p//2, 0) if d == 2 else (0, p//2)
             # Calculate the high and lowpass
-            lohi = F.conv2d(x, h, padding=pad, stride=s, groups=C)
+            lohi = F.conv2d(x, h.to(x.device), padding=pad, stride=s, groups=C)
         elif mode == 'symmetric' or mode == 'reflect' or mode == 'periodic':
             pad = (0, 0, p//2, (p+1)//2) if d == 2 else (p//2, (p+1)//2, 0, 0)
             x = mypad(x, pad=pad, mode=mode)
-            lohi = F.conv2d(x, h, stride=s, groups=C)
+            lohi = F.conv2d(x, h.to(x.device), stride=s, groups=C)
         else:
             raise ValueError("Unkown pad type: {}".format(mode))
 
     return lohi
 
 
-
 class AFB2D(Function):
     """ Does a single level 2d wavelet decomposition of an input. Does separate
     row and column filtering by two calls to
@@ -245,7 +246,7 @@ def backward(ctx, low, highs):
         return dx, None, None, None, None, None
 
 
-def prep_filt_afb2d(h0_col, h1_col, h0_row=None, h1_row=None, device=device):
+def prep_filt_afb2d(h0_col, h1_col, h0_row=None, h1_row=None, device='cpu'):
     """
     Prepares the filters to be of the right form for the afb2d function. In
     particular, makes the tensors the right shape. It takes mirror images of
@@ -274,7 +275,7 @@ def prep_filt_afb2d(h0_col, h1_col, h0_row=None, h1_row=None, device=device):
     return h0_col, h1_col, h0_row, h1_row
 
 
-def prep_filt_afb1d(h0, h1, device=device):
+def prep_filt_afb1d(h0, h1, device='cpu'):
     """
     Prepares the filters to be of the right form for the afb2d function. In
     particular, makes the tensors the right shape. It takes mirror images of
@@ -293,6 +294,7 @@ def prep_filt_afb1d(h0, h1, device=device):
     h1 = torch.tensor(h1, device=device, dtype=t).reshape((1, 1, -1))
     return h0, h1
 
+
 class DWTForward(nn.Module):
     """ Performs a 2d DWT Forward decomposition of an image
     Args:
@@ -358,12 +360,20 @@ def forward(self, x):
 
         return ll, yh
 
+
 from numpy.lib.function_base import hamming
-
-xf1 = DWTForward(J=1, mode='zero', wave='db1').to(device)
-xf2 = DWTForward(J=2, mode='zero', wave='db1').to(device)
-xf3 = DWTForward(J=3, mode='zero', wave='db1').to(device)
-xf4 = DWTForward(J=4, mode='zero', wave='db1').to(device)
+
+
+def get_dwt_filters(level, mode='zero', wave='db1'):
+    xf = []
+    for j in range(1,level+1,1):
+        xf.append(DWTForward(J=j, mode=mode, wave=wave))
+
+    if level == 1:
+        xf = xf[0]
+
+    return xf
+
 
 class Level1Waveblock(nn.Module):
     def __init__(
@@ -388,14 +398,15 @@ def __init__(
         )
 
         self.reduction = nn.Conv2d(final_dim, int(final_dim/4), 1)
-        
+        self.xf1 = get_dwt_filters(level=1)
+        
 
     def forward(self, x):
 
         b, c, h, w = x.shape
 
         x = self.reduction(x)
 
-        Y1, Yh = xf1(x)
+        Y1, Yh = self.xf1(x)
 
         x = torch.reshape(Yh[0], (b, int(c*3/4), int(h/2), int(w/2)))
@@ -405,6 +416,7 @@ def forward(self, x):
 
         return x
 
+
 class Level2Waveblock(nn.Module):
     def __init__(
         self,
@@ -435,15 +447,16 @@ def __init__(
         )
 
         self.reduction = nn.Conv2d(final_dim, int(final_dim/4), 1)
-        
+        self.xf1, self.xf2 = get_dwt_filters(level=2)
+        
 
     def forward(self, x):
 
         b, c, h, w = x.shape
 
         x = self.reduction(x)
 
-        Y1, Yh = xf1(x)
-        Y2, Yh = xf2(x)
+        Y1, Yh = self.xf1(x)
+        Y2, Yh = self.xf2(x)
 
         x1 = torch.reshape(Yh[0], (b, int(c*3/4), int(h/2), int(w/2)))
@@ -499,16 +512,17 @@ def __init__(
         )
 
         self.reduction = nn.Conv2d(final_dim, int(final_dim/4), 1)
-        
+        self.xf1, self.xf2, self.xf3 = get_dwt_filters(level=3)
+        
 
     def forward(self, x):
 
         b, c, h, w = x.shape
 
         x = self.reduction(x)
 
-        Y1, Yh = xf1(x)
-        Y2, Yh = xf2(x)
-        Y3, Yh = xf3(x)
+        Y1, Yh = self.xf1(x)
+        Y2, Yh = self.xf2(x)
+        Y3, Yh = self.xf3(x)
 
         x1 = torch.reshape(Yh[0], (b, int(c*3/4), int(h/2), int(w/2)))
@@ -582,17 +596,18 @@ def __init__(
         )
 
         self.reduction = nn.Conv2d(final_dim, int(final_dim/4), 1)
-        
+        self.xf1, self.xf2, self.xf3, self.xf4 = get_dwt_filters(level=4)
+        
 
     def forward(self, x):
 
         b, c, h, w = x.shape
 
         x = self.reduction(x)
-        Y1, Yh = xf1(x)
-        Y2, Yh = xf2(x)
-        Y3, Yh = xf3(x)
-        Y4, Yh = xf4(x)
+        Y1, Yh = self.xf1(x)
+        Y2, Yh = self.xf2(x)
+        Y3, Yh = self.xf3(x)
+        Y4, Yh = self.xf4(x)
 
         x1 = torch.reshape(Yh[0], (b, int(c*3/4), int(h/2), int(w/2)))
         x2 = torch.reshape(Yh[1], (b, int(c*3/4), int(h/4), int(w/4)))
diff --git a/wavemix/classification.py b/wavemix/classification.py
index 4353f0e..cb864e7 100644
--- a/wavemix/classification.py
+++ b/wavemix/classification.py
@@ -2,6 +2,7 @@
 import torch.nn as nn
 from einops.layers.torch import Rearrange
 
+
 class WaveMix(nn.Module):
     def __init__(
         self,
@@ -22,14 +23,14 @@ def __init__(
 
         self.layers = nn.ModuleList([])
         for _ in range(depth):
-                if level == 4:
-                    self.layers.append(Level4Waveblock(mult = mult, ff_channel = ff_channel, final_dim = final_dim, dropout = dropout))
-                elif level == 3:
-                    self.layers.append(Level3Waveblock(mult = mult, ff_channel = ff_channel, final_dim = final_dim, dropout = dropout))
-                elif level == 2:
-                    self.layers.append(Level2Waveblock(mult = mult, ff_channel = ff_channel, final_dim = final_dim, dropout = dropout))
-                else:
-                    self.layers.append(Level1Waveblock(mult = mult, ff_channel = ff_channel, final_dim = final_dim, dropout = dropout))
+            if level == 4:
+                self.layers.append(Level4Waveblock(mult = mult, ff_channel = ff_channel, final_dim = final_dim, dropout = dropout))
+            elif level == 3:
+                self.layers.append(Level3Waveblock(mult = mult, ff_channel = ff_channel, final_dim = final_dim, dropout = dropout))
+            elif level == 2:
+                self.layers.append(Level2Waveblock(mult = mult, ff_channel = ff_channel, final_dim = final_dim, dropout = dropout))
+            else:
+                self.layers.append(Level1Waveblock(mult = mult, ff_channel = ff_channel, final_dim = final_dim, dropout = dropout))
 
         self.pool = nn.Sequential(
             nn.AdaptiveAvgPool2d(1),
@@ -49,8 +50,7 @@ def __init__(
             nn.Conv2d(int(final_dim/2), final_dim, patch_size, patch_size),
             nn.GELU(),
             nn.BatchNorm2d(final_dim)
-            )
-        
+        )
 
     def forward(self, img):
         x = self.conv(img)
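Not part of the patch, but a minimal sketch of the behaviour this refactor is aiming for: the DWT filter modules are now built inside each waveblock via `get_dwt_filters`, and `afb1d` moves the analysis filters onto the input's device at call time, so importing the package no longer assumes `cuda:0` exists. The hyperparameter values and tensor shape below are illustrative assumptions, not values taken from the repository; only the keyword names (`mult`, `ff_channel`, `final_dim`, `dropout`) and the requirement that the input channel count equal `final_dim` come from the diff itself.

```python
# Sketch only: hyperparameter values and the input shape are assumptions.
import torch
from wavemix import Level1Waveblock

block = Level1Waveblock(mult=2, ff_channel=16, final_dim=16, dropout=0.5)

# Use whichever device is available; nothing is pinned to cuda:0 anymore.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
block = block.to(device)

# The input channel count must match final_dim (see self.reduction).
x = torch.randn(1, 16, 64, 64, device=device)
out = block(x)   # afb1d calls h.to(x.device) before F.conv2d, so the filters follow x
print(out.shape)
```

Keeping `device='cpu'` as the default in `prep_filt_afb2d`/`prep_filt_afb1d` and deferring placement to `h.to(x.device)` means the blocks behave like ordinary `nn.Module`s: wherever the model and its input are moved, the wavelet filters follow.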