diff --git a/langml/activations.py b/langml/activations.py index 8db02ed..5119058 100644 --- a/langml/activations.py +++ b/langml/activations.py @@ -5,7 +5,7 @@ import math -from langml import keras, K +from langml import keras, K, L from langml.tensor_typing import Tensors @@ -15,13 +15,22 @@ def gelu(x: Tensors) -> Tensors: $GELU(x) = 0.5x(1 + tanh[\sqrt(2 / \Pi) (x + 0.044715x^3)])$ """ - return 0.5 * x * (1.0 + K.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x**3))) def relu2(x: Tensors) -> Tensors: + """ ReLU Square + """ return K.pow(K.relu(x), 2) -custom_objects = {'gelu': gelu, 'relu2': relu2} +def swish(x: Tensors, beta: float = 1.0) -> Tensors: + return (x * K.sigmoid(beta * x)) + + +custom_objects = {} +custom_objects.update({'gelu': L.Activation(gelu)}) +custom_objects.update({'relu2': L.Activation(relu2)}) +custom_objects.update({'swish': L.Activation(swish)}) + keras.utils.get_custom_objects().update(custom_objects)