Skip to content

Commit 4c4571f

Browse files
committed
Skeleton ShieldGemma class
1 parent 5a7ecb6 commit 4c4571f

File tree

3 files changed

+80
-3
lines changed

3 files changed

+80
-3
lines changed

keras_hub/src/models/gemma/gemma_presets.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,7 @@
206206
"metadata": {
207207
"description": "2 billion parameter, 26-layer, ShieldGemma model.",
208208
"params": 2614341888,
209-
"official_name": "Gemma",
209+
"official_name": "ShieldGemma",
210210
"path": "gemma",
211211
"model_card": "https://www.kaggle.com/models/google/shieldgemma",
212212
},
@@ -216,7 +216,7 @@
216216
"metadata": {
217217
"description": "9 billion parameter, 42-layer, ShieldGemma model.",
218218
"params": 9241705984,
219-
"official_name": "Gemma",
219+
"official_name": "ShieldGemma",
220220
"path": "gemma",
221221
"model_card": "https://www.kaggle.com/models/google/shieldgemma",
222222
},
@@ -226,7 +226,7 @@
226226
"metadata": {
227227
"description": "27 billion parameter, 42-layer, ShieldGemma model.",
228228
"params": 27227128320,
229-
"official_name": "Gemma",
229+
"official_name": "ShieldGemma",
230230
"path": "gemma",
231231
"model_card": "https://www.kaggle.com/models/google/shieldgemma",
232232
},
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
import keras
2+
3+
from keras_hub.src.api_export import keras_hub_export
4+
from keras_hub.src.models.gemma import gemma_causal_lm
5+
from keras_hub.src.models.task import Task
6+
7+
8+
class ShieldGemmaViolationProbaility(keras.layers.Layer):
    """Relative probabilities for the 'Yes' (violating) and 'No' tokens.

    Reads the logits at the last unpadded position of each sequence, keeps
    only the entries for the "Yes" and "No" vocabulary ids, and normalizes
    that pair with a softmax.

    Args:
        yes_token_idx: Vocabulary id of the "Yes" token.
        no_token_idx: Vocabulary id of the "No" token.
        **kw: Forwarded unchanged to `keras.layers.Layer`.
    """

    def __init__(self, yes_token_idx, no_token_idx, **kw):
        super().__init__(**kw)
        self.yes_token_idx = yes_token_idx
        self.no_token_idx = no_token_idx

    def call(self, logits, padding_mask):
        """Return `[P(yes), P(no)]` for every sequence in the batch.

        Args:
            logits: Rank-3 tensor indexed as (batch, position, vocabulary).
            padding_mask: Rank-2 tensor indexed as (batch, position) whose
                entries sum to the number of real tokens per sequence.

        Returns:
            Tensor of shape (batch, 2); column 0 is the "Yes" probability and
            column 1 the "No" probability. Each row sums to 1.
        """
        # NOTE(review): `count - 1` is the last real position only when the
        # real tokens occupy a contiguous prefix of the sequence (right
        # padding) -- confirm against the preprocessor.
        token_counts = keras.ops.sum(
            keras.ops.cast(padding_mask, "int32"), axis=1
        )
        last_prompt_index = keras.ops.cast(token_counts - 1, "int32")

        # Gather one position *per row*. A plain `take` with a 1-D index
        # vector would produce a (batch, batch, vocab) tensor, and slicing
        # `[:, 0]` from that applies the first example's index to every row.
        gather_index = keras.ops.reshape(last_prompt_index, (-1, 1, 1))
        last_logits = keras.ops.take_along_axis(
            logits, gather_index, axis=1
        )[:, 0]

        yes_logits = last_logits[:, self.yes_token_idx]
        no_logits = last_logits[:, self.no_token_idx]
        yes_no_logits = keras.ops.stack((yes_logits, no_logits), axis=1)
        return keras.ops.softmax(yes_no_logits, axis=1)

    def get_config(self):
        """Return the layer config, including both token ids."""
        config = super().get_config()
        config.update(
            {
                "yes_token_idx": self.yes_token_idx,
                "no_token_idx": self.no_token_idx,
            }
        )
        return config
25+
26+
27+
@keras_hub_export("keras_hub.models.ShieldGemma")
class ShieldGemma(Task):
    """A ShieldGemma model for safety content moderation, built on Gemma 2.

    ShieldGemma is a Gemma 2 variant fine-tuned to detect and predict violations
    of four harm types—Harassment, Hate Speech, Dangerous Content, and
    Sexual Content—in text content from a user or model. Architecturally,
    the weights are the same as any other Gemma 2 class, but the prediction is
    augmented with a final layer that returns the probability that the
    provided content violates the harm type specified in the prompt.

    Links:

    * https://arxiv.org/abs/2407.21772
    * https://ai.google.dev/gemma/docs/shieldgemma/model_card
    * https://ai.google.dev/responsible/docs/safeguards/shieldgemma
    * https://www.kaggle.com/models/google/shieldgemma

    Args:
        gemma: A `keras_hub.models.GemmaCausalLM` initialized with ShieldGemma
            weights. Its preprocessor must be set, because the tokenizer is
            used to look up the ids of the "Yes" and "No" tokens.
        **kwargs: Forwarded to the `Task` constructor.

    Examples:

    Coming soon.
    """

    backbone_cls = gemma_causal_lm.GemmaCausalLM.backbone_cls
    preprocessor_cls = gemma_causal_lm.GemmaCausalLM.preprocessor_cls

    def __init__(self, gemma: gemma_causal_lm.GemmaCausalLM, **kwargs):
        # === Layers ===
        self.gemma = gemma
        tokenizer = self.gemma.preprocessor.tokenizer
        self.yes_no_layer = ShieldGemmaViolationProbaility(
            yes_token_idx=tokenizer.token_to_id("Yes"),
            no_token_idx=tokenizer.token_to_id("No"),
        )
        # Expose the wrapped model's backbone and preprocessor under the
        # attribute names used by this class's own `backbone_cls` /
        # `preprocessor_cls` declarations.
        self.backbone = self.gemma.backbone
        self.preprocessor = self.gemma.preprocessor

        # === Functional Model ===
        inputs = self.gemma.input
        # The wrapped causal LM is called on its own inputs; its output is
        # passed to the yes/no layer as `logits`.
        logits = self.gemma(inputs)
        outputs = self.yes_no_layer(logits, inputs["padding_mask"])
        super().__init__(inputs=inputs, outputs=outputs, **kwargs)

    @classmethod
    def from_preset(cls, preset, **kwargs):
        """Instantiate a `keras_hub.models.ShieldGemma` from a model preset.

        Args:
            preset: Preset identifier, passed through to
                `GemmaCausalLM.from_preset`. May be given positionally or
                as `preset=...`.
            **kwargs: Extra keyword arguments forwarded to
                `GemmaCausalLM.from_preset`.

        Returns:
            A `ShieldGemma` wrapping the loaded `GemmaCausalLM`.
        """
        gemma = gemma_causal_lm.GemmaCausalLM.from_preset(preset, **kwargs)
        return cls(gemma)

keras_hub/src/models/gemma/shieldgemma_test.py

Whitespace-only changes.

0 commit comments

Comments
 (0)