|
| 1 | +import keras |
| 2 | + |
| 3 | +from keras_hub.src.api_export import keras_hub_export |
| 4 | +from keras_hub.src.models.gemma import gemma_causal_lm |
| 5 | +from keras_hub.src.models.task import Task |
| 6 | + |
| 7 | + |
class ShieldGemmaViolationProbaility(keras.layers.Layer):
    """Relative probabilities for the 'Yes' (violating) and 'No' tokens.

    For every sequence in the batch, the logits at the last unmasked position
    are read, the entries for the "Yes" and "No" tokens are selected, and a
    softmax over just those two values is returned.

    NOTE: the class name keeps its original spelling ("Probaility") so that
    existing references to it continue to resolve.

    Args:
        yes_token_idx: Vocabulary index of the "Yes" token.
        no_token_idx: Vocabulary index of the "No" token.
        **kw: Forwarded to `keras.layers.Layer`.
    """

    def __init__(self, yes_token_idx, no_token_idx, **kw):
        super().__init__(**kw)
        self.yes_token_idx = yes_token_idx
        self.no_token_idx = no_token_idx

    def call(self, logits, padding_mask):
        """Compute the two-way Yes/No probabilities.

        Args:
            logits: Tensor indexed as `(batch, sequence, vocabulary)`.
            padding_mask: Tensor indexed as `(batch, sequence)`; its sum
                along axis 1 is used as the prompt length. This assumes the
                unmasked tokens are contiguous from position 0 (right
                padding) -- TODO confirm against the preprocessor.

        Returns:
            Tensor of shape `(batch, 2)`; column 0 is the "Yes" probability
            and column 1 is the "No" probability.
        """
        last_prompt_index = keras.ops.cast(
            keras.ops.sum(padding_mask, axis=1) - 1, "int32"
        )
        # Gather one position *per row*. `keras.ops.take` with a 1-D index
        # vector would return (batch, batch, vocab) and a following `[:, 0]`
        # would apply the first sample's index to every row.
        gather_index = keras.ops.reshape(last_prompt_index, (-1, 1, 1))
        last_logits = keras.ops.take_along_axis(
            logits, gather_index, axis=1
        )[:, 0]
        yes_logits = last_logits[:, self.yes_token_idx]
        no_logits = last_logits[:, self.no_token_idx]
        yes_no_logits = keras.ops.stack((yes_logits, no_logits), axis=1)
        return keras.ops.softmax(yes_no_logits, axis=1)

    def get_config(self):
        """Return the layer config, including both token indices."""
        config = super().get_config()
        config.update(
            {
                "yes_token_idx": self.yes_token_idx,
                "no_token_idx": self.no_token_idx,
            }
        )
        return config
| 25 | + |
| 26 | + |
@keras_hub_export("keras_hub.models.ShieldGemma")
class ShieldGemma(Task):
    """A ShieldGemma model for safety content moderation, built on Gemma 2.

    ShieldGemma is a Gemma 2 variant fine-tuned to detect and predict violations
    of four harm types—Harassment, Hate Speech, Dangerous Content, and
    Sexual Content—in text content from a user or model. Architecturally,
    the weights are the same as any other Gemma 2 class, but the prediction is
    augmented with a final layer that returns the probability that the
    provided content violates the harm type specified in the prompt.

    Links:

    * https://arxiv.org/abs/2407.21772
    * https://ai.google.dev/gemma/docs/shieldgemma/model_card
    * https://ai.google.dev/responsible/docs/safeguards/shieldgemma
    * https://www.kaggle.com/models/google/shieldgemma

    Args:
        gemma: A `keras_hub.models.GemmaCausalLM` initialized with ShieldGemma
            weights.

    Examples:

    Coming soon.
    """

    backbone_cls = gemma_causal_lm.GemmaCausalLM.backbone_cls
    preprocessor_cls = gemma_causal_lm.GemmaCausalLM.preprocessor_cls

    def __init__(self, gemma: gemma_causal_lm.GemmaCausalLM, **kwargs):
        # === Layers ===
        self.gemma = gemma
        # The tokenizer maps the literal "Yes"/"No" strings to the vocabulary
        # ids whose logits are compared by the final layer.
        tokenizer = self.gemma.preprocessor.tokenizer
        self.yes_no_layer = ShieldGemmaViolationProbaility(
            yes_token_idx=tokenizer.token_to_id("Yes"),
            no_token_idx=tokenizer.token_to_id("No"),
        )
        self.backbone = self.gemma.backbone
        self.preprocessor = self.gemma.preprocessor

        # === Functional Model ===
        inputs = self.gemma.input
        # The wrapped causal LM's output is passed to the Yes/No layer as its
        # `logits` argument.
        logits = self.gemma(inputs)
        outputs = self.yes_no_layer(logits, inputs["padding_mask"])
        super().__init__(inputs=inputs, outputs=outputs, **kwargs)

    @classmethod
    def from_preset(cls, preset, **kwargs):
        """Instantiate a `keras_hub.models.ShieldGemma` from a model preset.

        Args:
            preset: Preset identifier, forwarded unchanged to
                `GemmaCausalLM.from_preset`. It may be given positionally or
                as the `preset=` keyword.
            **kwargs: Additional keyword arguments forwarded to
                `GemmaCausalLM.from_preset`.

        Returns:
            A `ShieldGemma` wrapping the loaded `GemmaCausalLM`.
        """
        gemma = gemma_causal_lm.GemmaCausalLM.from_preset(preset, **kwargs)
        return cls(gemma)
0 commit comments