@@ -10,4 +10,62 @@ def __init__(self, device="cpu", torch_dtype=torch.float32):
 
     def load(self, model: torch.nn.Module, state_dict_lora, alpha=1.0):
         lora_prefix, model_resource = self.loader.match(model, state_dict_lora)
-        self.loader.load(model, state_dict_lora, lora_prefix, alpha=alpha, model_resource=model_resource)
+        self.loader.load(model, state_dict_lora, lora_prefix, alpha=alpha, model_resource=model_resource)
+
+class LoraMerger(torch.nn.Module):
+    def __init__(self, dim):
+        super().__init__()
+        self.weight_base = torch.nn.Parameter(torch.randn((dim,)))
+        self.weight_lora = torch.nn.Parameter(torch.randn((dim,)))
+        self.weight_cross = torch.nn.Parameter(torch.randn((dim,)))
+        self.weight_out = torch.nn.Parameter(torch.ones((dim,)))
+        self.bias = torch.nn.Parameter(torch.randn((dim,)))
+        self.activation = torch.nn.Sigmoid()
+        self.norm_base = torch.nn.LayerNorm(dim, eps=1e-5)
+        self.norm_lora = torch.nn.LayerNorm(dim, eps=1e-5)
+
+    def forward(self, base_output, lora_outputs):
+        norm_base_output = self.norm_base(base_output)
+        norm_lora_outputs = self.norm_lora(lora_outputs)
+        gate = self.activation(
+            norm_base_output * self.weight_base
+            + norm_lora_outputs * self.weight_lora
+            + norm_base_output * norm_lora_outputs * self.weight_cross + self.bias
+        )
+        output = base_output + (self.weight_out * gate * lora_outputs).sum(dim=0)
+        return output
+
+class LoraPatcher(torch.nn.Module):
+    def __init__(self, lora_patterns=None):
+        super().__init__()
+        if lora_patterns is None:
+            lora_patterns = self.default_lora_patterns()
+        model_dict = {}
+        for lora_pattern in lora_patterns:
+            name, dim = lora_pattern["name"], lora_pattern["dim"]
+            model_dict[name.replace(".", "___")] = LoraMerger(dim)
+        self.model_dict = torch.nn.ModuleDict(model_dict)
+
+    def default_lora_patterns(self):
+        lora_patterns = []
+        lora_dict = {
+            "attn.a_to_qkv": 9216, "attn.a_to_out": 3072, "ff_a.0": 12288, "ff_a.2": 3072, "norm1_a.linear": 18432,
+            "attn.b_to_qkv": 9216, "attn.b_to_out": 3072, "ff_b.0": 12288, "ff_b.2": 3072, "norm1_b.linear": 18432,
+        }
+        for i in range(19):
+            for suffix in lora_dict:
+                lora_patterns.append({
+                    "name": f"blocks.{i}.{suffix}",
+                    "dim": lora_dict[suffix]
+                })
+        lora_dict = {"to_qkv_mlp": 21504, "proj_out": 3072, "norm.linear": 9216}
+        for i in range(38):
+            for suffix in lora_dict:
+                lora_patterns.append({
+                    "name": f"single_blocks.{i}.{suffix}",
+                    "dim": lora_dict[suffix]
+                })
+        return lora_patterns
+
+    def forward(self, base_output, lora_outputs, name):
+        return self.model_dict[name.replace(".", "___")](base_output, lora_outputs)