 import torch.nn.functional as F
 from einops import rearrange, repeat
 
-from dalle_pytorch.cache import Cached
-
 # helpers
 
 def exists(val):
@@ -41,6 +39,8 @@ def apply_rotary_emb(freqs, t):
     return torch.cat((t, t_right), dim = -1)
 
 def apply_pos_emb(pos_emb, qkv):
+    n = qkv[0].shape[-2]
+    pos_emb = pos_emb[..., :n, :]
     return tuple(map(lambda t: apply_rotary_emb(pos_emb, t), qkv))
 
 # classes
@@ -65,30 +65,24 @@ def __init__(self, dim, seq_len, causal = True, heads = 8, dim_head = 64, dropou
     def forward(self, x, mask = None, rotary_pos_emb = None, cache = None, cache_key = None):
         b, n, _, h, device = *x.shape, self.heads, x.device
         softmax = torch.softmax if not self.stable else stable_softmax
+        using_cache = exists(cache) and cache_key in cache
 
-        qkv_key = f'{cache_key}_qkv'
-        if exists(cache) and qkv_key in cache:
-            qkv = self.to_qkv(x).chunk(3, dim = -1)
-            q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
+        qkv = self.to_qkv(x).chunk(3, dim = -1)
+        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
 
-            if exists(rotary_pos_emb):
-                q, k, v = apply_pos_emb(rotary_pos_emb[..., n - 1:n, :], (q, k, v)) # FIXME: Fix rotary index here
+        if exists(rotary_pos_emb):
+            if using_cache:
+                rotary_pos_emb = rotary_pos_emb[..., n - 1:, :] # FIXME: Fix rotary index here
+            q, k, v = apply_pos_emb(rotary_pos_emb, (q, k, v))
 
-            q *= self.scale
+        q = q * self.scale
 
-            k_top, v_top = cache[qkv_key]
+        if using_cache:
+            k_top, v_top = cache[cache_key]
             k = torch.cat([k_top, k], dim = -2)
             v = torch.cat([v_top, v], dim = -2)
-        else:
-            qkv = self.to_qkv(x).chunk(3, dim = -1)
-            q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
-
-            if exists(rotary_pos_emb):
-                q, k, v = apply_pos_emb(rotary_pos_emb[..., :n, :], (q, k, v))
-
-            q *= self.scale
         if exists(cache):
-            cache[qkv_key] = (k, v)
+            cache[cache_key] = k, v
 
         dots = q @ k.swapaxes(-1, -2)
         mask_value = max_neg_value(dots)
@@ -98,17 +92,16 @@ def forward(self, x, mask = None, rotary_pos_emb = None, cache = None, cache_key
             dots.masked_fill_(~mask, mask_value)
             del mask
 
-        # if self.causal: # TODO:
-        # i, j = dots.shape[-2:]
-        # mask = torch.ones(i, j, device = device).triu_(j - i + 1).bool()
-        # dots.masked_fill_(mask, mask_value)
+        if self.causal and not using_cache: # causality is naturally enforced if we run the cached inference
+            i, j = dots.shape[-2:]
+            mask = torch.ones(i, j, device = device).triu_(j - i + 1).bool()
+            dots.masked_fill_(mask, mask_value)
 
         attn = softmax(dots, dim = -1)
 
         out = attn @ v
         out = rearrange(out, 'b h n d -> b n (h d)')
         out = self.to_out(out)
-
         return out
 
 # sparse attention with convolutional pattern, as mentioned in the blog post. customizable kernel size and dilation
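
Note on the cached path above (not part of the diff): a minimal sketch of how the reworked per-layer cache is meant to be driven during incremental decoding. The shared dict, the 'layer_0' key, and the tensor shapes are illustrative assumptions, not code from this repository; only the forward signature shown in the hunk comes from the diff, and the import assumes the class is the full Attention edited above.

import torch
from dalle_pytorch.attention import Attention  # the full-attention class edited above

attn = Attention(dim = 512, seq_len = 256)      # remaining kwargs left at their defaults (causal = True)
cache = {}                                      # shared dict, one entry per layer

# first call: cache_key is not yet in cache, so the causal mask is applied
# and the full (k, v) tensors are stored under cache['layer_0']
x = torch.randn(1, 10, 512)                     # hypothetical prompt embeddings
out = attn(x, cache = cache, cache_key = 'layer_0')

# later call: only the newest token is fed; its (k, v) are concatenated onto
# the cached ones, and the causal mask is skipped (using_cache is True)
x_next = torch.randn(1, 1, 512)
out_next = attn(x_next, cache = cache, cache_key = 'layer_0')
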
@@ -128,14 +121,14 @@ def __init__(self, dim, seq_len, image_size = 32, kernel_size = 5, dilation = 1,
 
         self.stable = stable
 
-        self.to_qkv = Cached(nn.Linear(dim, inner_dim * 3, bias = False))
+        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
 
-        self.to_out = Cached(nn.Sequential(
+        self.to_out = nn.Sequential(
             nn.Linear(inner_dim, dim),
             nn.Dropout(dropout)
-        ))
+        )
 
-    def forward(self, x, mask = None, rotary_pos_emb = None, cache = None, cache_key = None):
+    def forward(self, x, mask = None, rotary_pos_emb = None):
         b, n, _, h, img_size, kernel_size, dilation, seq_len, device = *x.shape, self.heads, self.image_size, self.kernel_size, self.dilation, self.seq_len, x.device
         softmax = torch.softmax if not self.stable else stable_softmax
 
@@ -152,7 +145,7 @@ def forward(self, x, mask = None, rotary_pos_emb = None, cache = None, cache_key
 
         # derive query / keys / values
 
-        qkv = self.to_qkv(x, cache = cache, cache_key = f'{cache_key}_qkv').chunk(3, dim = -1)
+        qkv = self.to_qkv(x).chunk(3, dim = -1)
         q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = h), qkv)
 
         if exists(rotary_pos_emb):
@@ -229,7 +222,7 @@ def forward(self, x, mask = None, rotary_pos_emb = None, cache = None, cache_key
         out = torch.cat((out_text, out_image), dim = 1)
 
         out = rearrange(out, '(b h) n d -> b n (h d)', h = h)
-        out = self.to_out(out, cache = cache, cache_key = f'{cache_key}_out')
+        out = self.to_out(out)
         return out[:, :n]
 
 # sparse axial causal attention
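
Aside (not part of the diff): both sparse attention classes drop the Cached wrapper and the cache / cache_key arguments, reverting them to stateless modules; only the dense Attention above keeps cache support. For orientation, here is a rough sketch of what a wrapper used the way the removed call sites use Cached could look like. This is an assumption inferred purely from those call sites, not the actual dalle_pytorch.cache.Cached implementation.

import torch
from torch import nn

class CachedSketch(nn.Module):
    # hypothetical stand-in: run the wrapped module on the incoming (new) tokens
    # and prepend whatever output was cached for this key on earlier steps
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, *, cache = None, cache_key = None):
        out = self.fn(x)
        if cache is None:
            return out
        if cache_key in cache:
            out = torch.cat([cache[cache_key], out], dim = 1)  # concat along the sequence axis
        cache[cache_key] = out
        return out
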
@@ -248,14 +241,14 @@ def __init__(self, dim, seq_len, image_size = 32, axis = 0, heads = 8, dim_head
 
         self.stable = stable
 
-        self.to_qkv = Cached(nn.Linear(dim, inner_dim * 3, bias = False))
+        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
 
-        self.to_out = Cached(nn.Sequential(
+        self.to_out = nn.Sequential(
             nn.Linear(inner_dim, dim),
             nn.Dropout(dropout)
-        ))
+        )
 
-    def forward(self, x, mask = None, rotary_pos_emb = None, cache = None, cache_key = None):
+    def forward(self, x, mask = None, rotary_pos_emb = None):
         b, n, _, h, img_size, axis, seq_len, device = *x.shape, self.heads, self.image_size, self.axis, self.seq_len, x.device
         softmax = torch.softmax if not self.stable else stable_softmax
 
@@ -272,7 +265,7 @@ def forward(self, x, mask = None, rotary_pos_emb = None, cache = None, cache_key
 
         # derive queries / keys / values
 
-        qkv = self.to_qkv(x, cache = cache, cache_key = f'{cache_key}_qkv').chunk(3, dim = -1)
+        qkv = self.to_qkv(x).chunk(3, dim = -1)
         q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = h), qkv)
 
         if exists(rotary_pos_emb):
@@ -284,15 +277,15 @@ def forward(self, x, mask = None, rotary_pos_emb = None, cache = None, cache_key
 
         # text attention
 
-        dots_text = q_text @ k_text.swapaxes(-1, -2)
+        dots_text = einsum('b i d, b j d -> b i j', q_text, k_text)
         mask_value = max_neg_value(dots_text)
 
         i, j = dots_text.shape[-2:]
         text_causal_mask = torch.ones(i, j, device = device).triu_(j - i + 1).bool()
         dots_text.masked_fill_(text_causal_mask, mask_value)
 
         attn_text = softmax(dots_text, dim = -1)
-        out_text = attn_text @ v_text
+        out_text = einsum('b i j, b j d -> b i d', attn_text, v_text)
 
         # image attention
 
@@ -305,8 +298,8 @@ def forward(self, x, mask = None, rotary_pos_emb = None, cache = None, cache_key
 
         # similarity
 
-        dots_image_to_image = q_img @ k_img.swapaxes(-1, -2)
-        dots_image_to_text = q_img @ k_text[:, None].swapaxes(-1, -2)
+        dots_image_to_image = einsum('b x i d, b x j d -> b x i j', q_img, k_img)
+        dots_image_to_text = einsum('b x i d, b j d -> b x i j', q_img, k_text)
 
         dots = torch.cat((dots_image_to_text, dots_image_to_image), dim = -1)
 
@@ -329,8 +322,8 @@ def forward(self, x, mask = None, rotary_pos_emb = None, cache = None, cache_key
 
         attn_image_to_text, attn_image_to_image = attn[..., :text_len], attn[..., text_len:]
 
-        out_image_to_image = attn_image_to_image @ v_img
-        out_image_to_text = attn_image_to_text @ v_text[:, None]
+        out_image_to_image = einsum('b x i j, b x j d -> b x i d', attn_image_to_image, v_img)
+        out_image_to_text = einsum('b x i j, b j d -> b x i d', attn_image_to_text, v_text)
 
         out_image = out_image_to_image + out_image_to_text
 
@@ -343,7 +336,7 @@ def forward(self, x, mask = None, rotary_pos_emb = None, cache = None, cache_key
         out = torch.cat((out_text, out_image), dim = 1)
 
         out = rearrange(out, '(b h) n d -> b n (h d)', h = h)
-        out = self.to_out(out, cache = cache, cache_key = f'{cache_key}_out')
+        out = self.to_out(out)
         return out[:, :n]
 
 # microsoft sparse attention CUDA kernel
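
Footnote on the einsum changes in the axial attention hunks (not part of the diff): the new einsum calls compute the same contractions as the @ / swapaxes expressions they replace; the broadcasting over the text keys and values previously done with [:, None] is now expressed in the subscripts. A quick self-contained check with made-up shapes:

import torch
from torch import einsum

q_text = torch.randn(2, 5, 64)        # (batch * heads, text_len, dim_head)
k_text = torch.randn(2, 5, 64)
q_img  = torch.randn(2, 32, 7, 64)    # (batch * heads, axis, img_len, dim_head)

# text-to-text similarity: einsum form matches matmul with transposed keys
assert torch.allclose(
    einsum('b i d, b j d -> b i j', q_text, k_text),
    q_text @ k_text.swapaxes(-1, -2), atol = 1e-4)

# image-to-text similarity: the 'x' axis broadcasts over the text keys,
# replacing the explicit k_text[:, None] expansion
assert torch.allclose(
    einsum('b x i d, b j d -> b x i j', q_img, k_text),
    q_img @ k_text[:, None].swapaxes(-1, -2), atol = 1e-4)
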