Commit e07fa9d

[https://nvbugs/5496960][fix] Fix Gemma model forward. (#7509)
Signed-off-by: Yukun He <[email protected]>
1 parent cabda24

File tree

1 file changed: +6 −3 lines changed


tensorrt_llm/models/gemma/model.py

Lines changed: 6 additions & 3 deletions
@@ -157,10 +157,13 @@ def forward(self,
                 if default_net().plugin_config.reduce_fusion else
                 AllReduceFusionOp.NONE,
                 residual=residual,
-                norm_weight=self.pre_feedforward_layernorm.weight.value,
-                norm_pre_residual_weight=self.post_layernorm.weight.value
+                norm_weight=self.pre_feedforward_layernorm.weight.value
                 if self.config.inter_layernorms else None,
-                eps=self.pre_feedforward_layernorm.eps))
+                norm_pre_residual_weight=self.post_layernorm.weight.value,
+                eps=self.pre_feedforward_layernorm.eps
+                if self.config.inter_layernorms else 1e-06,
+            ),
+        )

         if use_cache:
             attention_output, presents = attention_output
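What the change does: in the fused all-reduce call inside the Gemma decoder's forward pass, the `if self.config.inter_layernorms else None` guard previously sat on `norm_pre_residual_weight`, so `norm_weight` and `eps` dereferenced `self.pre_feedforward_layernorm` even for Gemma configs that do not carry that extra norm. The fix moves the guard onto `norm_weight` and gives `eps` a `1e-06` fallback. Below is a minimal, self-contained sketch of that guarded-kwargs pattern; `FakeNorm`, `FakeConfig`, and `GemmaDecoderSketch` are hypothetical stand-ins, not TensorRT-LLM's actual classes or the real signature of the all-reduce parameters.

```python
from dataclasses import dataclass


@dataclass
class FakeNorm:
    # Stand-in for a layernorm module with a weight and an epsilon.
    weight: float = 1.0
    eps: float = 1e-05


@dataclass
class FakeConfig:
    inter_layernorms: bool = False


class GemmaDecoderSketch:
    def __init__(self, config: FakeConfig):
        self.config = config
        # Gemma variants with inter-layernorms carry an extra norm before the
        # feed-forward block; plain Gemma does not, so the attribute may be absent.
        self.pre_feedforward_layernorm = FakeNorm() if config.inter_layernorms else None
        self.post_layernorm = FakeNorm()

    def fused_allreduce_kwargs(self) -> dict:
        # After the fix: dereference pre_feedforward_layernorm only when the
        # config says it exists, and fall back to eps=1e-06 otherwise.
        return dict(
            norm_weight=(self.pre_feedforward_layernorm.weight
                         if self.config.inter_layernorms else None),
            norm_pre_residual_weight=self.post_layernorm.weight,
            eps=(self.pre_feedforward_layernorm.eps
                 if self.config.inter_layernorms else 1e-06),
        )


if __name__ == "__main__":
    # Without inter-layernorms, the guarded accesses no longer touch a missing norm.
    print(GemmaDecoderSketch(FakeConfig(inter_layernorms=False)).fused_allreduce_kwargs())
    print(GemmaDecoderSketch(FakeConfig(inter_layernorms=True)).fused_allreduce_kwargs())
```

Design note: a Python conditional expression evaluates its condition first and then only the chosen branch, so putting the attribute access inside the guarded branch is what prevents the failure on configs without `pre_feedforward_layernorm`.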
