
Commit 9e90311

WindQAQ authored and seanpmorgan committed
unify activations and tests (#551)
* clean up activation/test
* test general properties for activations
1 parent 8c94e2f commit 9e90311

17 files changed, +113 −154 lines changed


tensorflow_addons/activations/BUILD

Lines changed: 13 additions & 0 deletions
@@ -19,6 +19,19 @@ py_library(
     srcs_version = "PY2AND3",
 )
 
+py_test(
+    name = "activations_test",
+    size = "small",
+    srcs = [
+        "activations_test.py",
+    ],
+    main = "activations_test.py",
+    srcs_version = "PY2AND3",
+    deps = [
+        ":activations",
+    ],
+)
+
 py_test(
     name = "sparsemax_test",
     size = "small",

tensorflow_addons/activations/README.md

Lines changed: 1 addition & 0 deletions
@@ -35,6 +35,7 @@ must:
   or `run_all_in_graph_and_eager_modes` (for TestCase subclass)
   decorator.
 * Add a `py_test` to this sub-package's BUILD file.
+* Add activation name to [activations_test.py](https://github.com/tensorflow/addons/tree/master/tensorflow_addons/activations/activations_test.py) to test serialization.
 
 #### Documentation Requirements
 * Update the table of contents in this sub-package's README.
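
For illustration only (not part of this diff): registering a hypothetical new activation named "softshrink" with the unified test would simply mean appending its name to the `ALL_ACTIVATIONS` list in `activations_test.py`, shown in full in the next file:

ALL_ACTIVATIONS = [
    "gelu", "hardshrink", "lisht", "softshrink", "sparsemax", "tanhshrink"
]
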
tensorflow_addons/activations/activations_test.py

Lines changed: 49 additions & 0 deletions
@@ -0,0 +1,49 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+from tensorflow_addons import activations
+from tensorflow_addons.utils import test_utils
+
+
+@test_utils.run_all_in_graph_and_eager_modes
+class ActivationsTest(tf.test.TestCase):
+
+    ALL_ACTIVATIONS = [
+        "gelu", "hardshrink", "lisht", "sparsemax", "tanhshrink"
+    ]
+
+    def test_serialization(self):
+        for name in self.ALL_ACTIVATIONS:
+            fn = tf.keras.activations.get(name)
+            ref_fn = getattr(activations, name)
+            self.assertEqual(fn, ref_fn)
+            config = tf.keras.activations.serialize(fn)
+            fn = tf.keras.activations.deserialize(config)
+            self.assertEqual(fn, ref_fn)
+
+    def test_serialization_with_layers(self):
+        for name in self.ALL_ACTIVATIONS:
+            layer = tf.keras.layers.Dense(
+                3, activation=getattr(activations, name))
+            config = tf.keras.layers.serialize(layer)
+            deserialized_layer = tf.keras.layers.deserialize(config)
+            self.assertEqual(deserialized_layer.__class__.__name__,
+                             layer.__class__.__name__)
+            self.assertEqual(deserialized_layer.activation.__name__, name)
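
A minimal usage sketch (not part of the commit) of the property this test pins down: because each addon activation is registered as a Keras custom object, the name and the function round-trip through Keras serialization once `tensorflow_addons` is imported.

import tensorflow as tf
import tensorflow_addons as tfa  # importing registers the custom activations

# Look the addon activation up by name, as the test does.
fn = tf.keras.activations.get("gelu")
assert fn == tfa.activations.gelu

# A layer built with it survives a serialize/deserialize round trip.
layer = tf.keras.layers.Dense(3, activation=fn)
restored = tf.keras.layers.deserialize(tf.keras.layers.serialize(layer))
assert restored.activation.__name__ == "gelu"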

tensorflow_addons/activations/gelu_test.py

Lines changed: 8 additions & 47 deletions
@@ -19,58 +19,33 @@
 
 from absl.testing import parameterized
 
-import math
-
 import numpy as np
 import tensorflow as tf
 from tensorflow_addons.activations import gelu
 from tensorflow_addons.utils import test_utils
 
 
-def _ref_gelu(x, approximate=True):
-    x = tf.convert_to_tensor(x)
-    if approximate:
-        pi = tf.cast(math.pi, x.dtype)
-        coeff = tf.cast(0.044715, x.dtype)
-        return 0.5 * x * (
-            1.0 + tf.tanh(tf.sqrt(2.0 / pi) * (x + coeff * tf.pow(x, 3))))
-    else:
-        return 0.5 * x * (
-            1.0 + tf.math.erf(x / tf.cast(tf.sqrt(2.0), x.dtype)))
-
-
 @test_utils.run_all_in_graph_and_eager_modes
 class GeluTest(tf.test.TestCase, parameterized.TestCase):
     @parameterized.named_parameters(("float16", np.float16),
                                     ("float32", np.float32),
                                     ("float64", np.float64))
     def test_gelu(self, dtype):
-        x = np.random.rand(2, 3, 4).astype(dtype)
-        self.assertAllCloseAccordingToType(gelu(x), _ref_gelu(x))
-        self.assertAllCloseAccordingToType(gelu(x, False), _ref_gelu(x, False))
+        x = tf.constant([-2.0, -1.0, 0.0, 1.0, 2.0], dtype=dtype)
+        expected_result = tf.constant(
+            [-0.04540229, -0.158808, 0.0, 0.841192, 1.9545977], dtype=dtype)
+        self.assertAllCloseAccordingToType(gelu(x), expected_result)
 
-    @parameterized.named_parameters(("float16", np.float16),
-                                    ("float32", np.float32),
-                                    ("float64", np.float64))
-    def test_gradients(self, dtype):
-        x = tf.constant([1.0, 2.0, 3.0], dtype=dtype)
-
-        for approximate in [True, False]:
-            with self.subTest(approximate=approximate):
-                with tf.GradientTape(persistent=True) as tape:
-                    tape.watch(x)
-                    y_ref = _ref_gelu(x, approximate)
-                    y = gelu(x, approximate)
-                grad_ref = tape.gradient(y_ref, x)
-                grad = tape.gradient(y, x)
-                self.assertAllCloseAccordingToType(grad, grad_ref)
+        expected_result = tf.constant(
+            [-0.04550028, -0.15865526, 0.0, 0.8413447, 1.9544997], dtype=dtype)
+        self.assertAllCloseAccordingToType(gelu(x, False), expected_result)
 
     @parameterized.named_parameters(("float32", np.float32),
                                     ("float64", np.float64))
     def test_theoretical_gradients(self, dtype):
         # Only test theoretical gradients for float32 and float64
         # because of the instability of float16 while computing jacobian
-        x = tf.constant([1.0, 2.0, 3.0], dtype=dtype)
+        x = tf.constant([-2.0, -1.0, 0.0, 1.0, 2.0], dtype=dtype)
 
         for approximate in [True, False]:
             with self.subTest(approximate=approximate):
@@ -87,20 +62,6 @@ def test_unknown_shape(self):
         x = tf.ones(shape=shape, dtype=tf.float32)
         self.assertAllClose(fn(x), gelu(x))
 
-    def test_serialization(self):
-        ref_fn = gelu
-        config = tf.keras.activations.serialize(ref_fn)
-        fn = tf.keras.activations.deserialize(config)
-        self.assertEqual(fn, ref_fn)
-
-    def test_serialization_with_layers(self):
-        layer = tf.keras.layers.Dense(3, activation=gelu)
-        config = tf.keras.layers.serialize(layer)
-        deserialized_layer = tf.keras.layers.deserialize(config)
-        self.assertEqual(deserialized_layer.__class__.__name__,
-                         layer.__class__.__name__)
-        self.assertEqual(deserialized_layer.activation.__name__, "gelu")
-
 
 if __name__ == "__main__":
     tf.test.main()
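
The hard-coded expectations above can be reproduced from the two GELU formulas that the deleted `_ref_gelu` helper implemented, the tanh approximation and the exact erf form; a quick NumPy check (not part of the commit):

import math
import numpy as np

x = np.array([-2.0, -1.0, 0.0, 1.0, 2.0])

# Approximate (tanh-based) GELU, matching gelu(x) above.
approx = 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))

# Exact (erf-based) GELU, matching gelu(x, False) above.
exact = np.array([0.5 * v * (1.0 + math.erf(v / math.sqrt(2.0))) for v in x])

print(approx)  # ~ [-0.04540229, -0.158808, 0., 0.841192, 1.9545977]
print(exact)   # ~ [-0.04550028, -0.15865526, 0., 0.8413447, 1.9544997]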

tensorflow_addons/activations/hardshrink_test.py

Lines changed: 11 additions & 39 deletions
@@ -25,11 +25,6 @@
 from tensorflow_addons.utils import test_utils
 
 
-def _ref_hardshrink(x, lower=-1.0, upper=1.0):
-    x = tf.convert_to_tensor(x)
-    return tf.where(tf.math.logical_or(x < lower, x > upper), x, 0.0)
-
-
 @test_utils.run_all_in_graph_and_eager_modes
 class HardshrinkTest(tf.test.TestCase, parameterized.TestCase):
     def test_invalid(self):
@@ -42,34 +37,25 @@ def test_invalid(self):
                                     ("float32", np.float32),
                                     ("float64", np.float64))
     def test_hardshrink(self, dtype):
-        x = (np.random.rand(2, 3, 4) * 2.0 - 1.0).astype(dtype)
-        self.assertAllCloseAccordingToType(hardshrink(x), _ref_hardshrink(x))
-        self.assertAllCloseAccordingToType(
-            hardshrink(x, -2.0, 2.0), _ref_hardshrink(x, -2.0, 2.0))
+        x = tf.constant([-2.0, -1.0, 0.0, 1.0, 2.0], dtype=dtype)
+        expected_result = tf.constant([-2.0, 0.0, 0.0, 0.0, 2.0], dtype=dtype)
+        self.assertAllCloseAccordingToType(hardshrink(x), expected_result)
 
-    @parameterized.named_parameters(("float16", np.float16),
-                                    ("float32", np.float32),
-                                    ("float64", np.float64))
-    def test_gradients(self, dtype):
-        x = tf.constant([-1.5, -0.5, 0.5, 1.5], dtype=dtype)
-
-        with tf.GradientTape(persistent=True) as tape:
-            tape.watch(x)
-            y_ref = _ref_hardshrink(x)
-            y = hardshrink(x)
-        grad_ref = tape.gradient(y_ref, x)
-        grad = tape.gradient(y, x)
-        self.assertAllCloseAccordingToType(grad, grad_ref)
+        expected_result = tf.constant([-2.0, -1.0, 0.0, 1.0, 2.0], dtype=dtype)
+        self.assertAllCloseAccordingToType(
+            hardshrink(x, lower=-0.5, upper=0.5), expected_result)
 
     @parameterized.named_parameters(("float32", np.float32),
                                     ("float64", np.float64))
     def test_theoretical_gradients(self, dtype):
         # Only test theoretical gradients for float32 and float64
         # because of the instability of float16 while computing jacobian
-        x = tf.constant([-1.5, -0.5, 0.5, 1.5], dtype=dtype)
 
-        theoretical, numerical = tf.test.compute_gradient(
-            lambda x: hardshrink(x), [x])
+        # Hardshrink is not continuous at `lower` and `upper`.
+        # Avoid these two points to make gradients smooth.
+        x = tf.constant([-2.0, -1.5, 0.0, 1.5, 2.0], dtype=dtype)
+
+        theoretical, numerical = tf.test.compute_gradient(hardshrink, [x])
         self.assertAllCloseAccordingToType(theoretical, numerical, atol=1e-4)
 
     def test_unknown_shape(self):
@@ -80,20 +66,6 @@ def test_unknown_shape(self):
         x = tf.ones(shape=shape, dtype=tf.float32)
         self.assertAllClose(fn(x), hardshrink(x))
 
-    def test_serialization(self):
-        ref_fn = hardshrink
-        config = tf.keras.activations.serialize(ref_fn)
-        fn = tf.keras.activations.deserialize(config)
-        self.assertEqual(fn, ref_fn)
-
-    def test_serialization_with_layers(self):
-        layer = tf.keras.layers.Dense(3, activation=hardshrink)
-        config = tf.keras.layers.serialize(layer)
-        deserialized_layer = tf.keras.layers.deserialize(config)
-        self.assertEqual(deserialized_layer.__class__.__name__,
-                         layer.__class__.__name__)
-        self.assertEqual(deserialized_layer.activation.__name__, "hardshrink")
-
 
 if __name__ == "__main__":
     tf.test.main()
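
For reference (not part of the commit), the expected vectors above follow directly from the definition the removed `_ref_hardshrink` helper encoded: values inside `[lower, upper]` are zeroed, everything else passes through.

import numpy as np

def ref_hardshrink(x, lower=-1.0, upper=1.0):
    # Zero values inside [lower, upper]; pass the rest through unchanged.
    x = np.asarray(x)
    return np.where((x < lower) | (x > upper), x, 0.0)

x = [-2.0, -1.0, 0.0, 1.0, 2.0]
print(ref_hardshrink(x))                         # [-2.  0.  0.  0.  2.]
print(ref_hardshrink(x, lower=-0.5, upper=0.5))  # [-2. -1.  0.  1.  2.]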

tensorflow_addons/activations/lisht_test.py

Lines changed: 0 additions & 13 deletions
@@ -55,19 +55,6 @@ def test_unknown_shape(self):
         x = tf.ones(shape=shape, dtype=tf.float32)
         self.assertAllClose(fn(x), lisht(x))
 
-    def test_serialization(self):
-        config = tf.keras.activations.serialize(lisht)
-        fn = tf.keras.activations.deserialize(config)
-        self.assertEqual(fn, lisht)
-
-    def test_serialization_with_layers(self):
-        layer = tf.keras.layers.Dense(3, activation=lisht)
-        config = tf.keras.layers.serialize(layer)
-        deserialized_layer = tf.keras.layers.deserialize(config)
-        self.assertEqual(deserialized_layer.__class__.__name__,
-                         layer.__class__.__name__)
-        self.assertEqual(deserialized_layer.activation.__name__, "lisht")
-
 
 if __name__ == "__main__":
     tf.test.main()

tensorflow_addons/activations/sparsemax.py

Lines changed: 5 additions & 7 deletions
@@ -24,7 +24,7 @@
 
 @keras_utils.register_keras_custom_object
 @tf.function
-def sparsemax(logits, axis=-1, name=None):
+def sparsemax(logits, axis=-1):
     """Sparsemax activation function [1].
 
     For each batch `i` and class `j` we have
@@ -35,7 +35,6 @@ def sparsemax(logits, axis=-1, name=None):
     Args:
         logits: Input tensor.
        axis: Integer, axis along which the sparsemax operation is applied.
-        name: A name for the operation (optional).
     Returns:
        Tensor, output of sparsemax transformation. Has the same type and
        shape as `logits`.
@@ -50,7 +49,7 @@
     is_last_axis = (axis == -1) or (axis == rank - 1)
 
     if is_last_axis:
-        output = _compute_2d_sparsemax(logits, name=name)
+        output = _compute_2d_sparsemax(logits)
         output.set_shape(shape)
         return output
 
@@ -64,8 +63,7 @@
 
     # Do the actual softmax on its last dimension.
     output = _compute_2d_sparsemax(logits)
-    output = _swap_axis(
-        output, axis_norm, tf.math.subtract(rank_op, 1), name=name)
+    output = _swap_axis(output, axis_norm, tf.math.subtract(rank_op, 1))
 
     # Make shape inference work since transpose may erase its static shape.
     output.set_shape(shape)
@@ -82,7 +80,7 @@ def _swap_axis(logits, dim_index, last_index, **kwargs):
 
 
 @tf.function
-def _compute_2d_sparsemax(logits, name=None):
+def _compute_2d_sparsemax(logits):
     """Performs the sparsemax operation when axis=-1."""
     shape_op = tf.shape(logits)
     obs = tf.math.reduce_prod(shape_op[:-1])
@@ -134,5 +132,5 @@ def _compute_2d_sparsemax(logits, name=None):
                      logits.dtype)), p)
 
     # Reshape back to original size
-    p_safe = tf.reshape(p_safe, shape_op, name=name)
+    p_safe = tf.reshape(p_safe, shape_op)
     return p_safe
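
Typical calls pass only `logits` (and optionally `axis`), so they are unaffected by the removal of the `name` argument; a minimal usage sketch, not from the commit:

import tensorflow as tf
from tensorflow_addons.activations import sparsemax

logits = tf.constant([[1.0, 2.0, 3.0],
                      [2.0, 1.5, 0.5]])
probs = sparsemax(logits)  # axis defaults to -1
# Each row is a probability distribution: non-negative, sums to 1,
# and typically has some entries that are exactly zero.
print(probs)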

tensorflow_addons/activations/sparsemax_test.py

Lines changed: 0 additions & 14 deletions
@@ -274,20 +274,6 @@ def test_gradient_against_estimate(self, dtype=None):
             lambda logits: sparsemax(logits), [z], delta=1e-6)
         self.assertAllCloseAccordingToType(jacob_sym, jacob_num)
 
-    def test_serialization(self, dtype=None):
-        ref_fn = sparsemax
-        config = tf.keras.activations.serialize(ref_fn)
-        fn = tf.keras.activations.deserialize(config)
-        self.assertEqual(fn, ref_fn)
-
-    def test_serialization_with_layers(self, dtype=None):
-        layer = tf.keras.layers.Dense(3, activation=sparsemax)
-        config = tf.keras.layers.serialize(layer)
-        deserialized_layer = tf.keras.layers.deserialize(config)
-        self.assertEqual(deserialized_layer.__class__.__name__,
-                         layer.__class__.__name__)
-        self.assertEqual(deserialized_layer.activation.__name__, "sparsemax")
-
 
 if __name__ == '__main__':
     tf.test.main()

tensorflow_addons/activations/tanhshrink_test.py

Lines changed: 13 additions & 23 deletions
@@ -25,37 +25,27 @@
 from tensorflow_addons.utils import test_utils
 
 
-def _ref_tanhshrink(x):
-    return x - tf.tanh(x)
-
-
 @test_utils.run_all_in_graph_and_eager_modes
 class TanhshrinkTest(tf.test.TestCase, parameterized.TestCase):
     @parameterized.named_parameters(("float16", np.float16),
                                     ("float32", np.float32),
                                     ("float64", np.float64))
     def test_tanhshrink(self, dtype):
-        x = tf.constant([1.0, 2.0, 3.0], dtype=dtype)
-        self.assertAllCloseAccordingToType(tanhshrink(x), _ref_tanhshrink(x))
+        x = tf.constant([-2.0, -1.0, 0.0, 1.0, 2.0], dtype=dtype)
+        expected_result = tf.constant(
+            [-1.0359724, -0.23840582, 0.0, 0.23840582, 1.0359724], dtype=dtype)
 
-    @parameterized.named_parameters(("float16", np.float16),
-                                    ("float32", np.float32),
+        self.assertAllCloseAccordingToType(tanhshrink(x), expected_result)
+
+    @parameterized.named_parameters(("float32", np.float32),
                                     ("float64", np.float64))
-    def test_gradients(self, dtype):
-        x = tf.constant([1.0, 2.0, 3.0], dtype=dtype)
-        with tf.GradientTape(persistent=True) as tape:
-            tape.watch(x)
-            y_ref = _ref_tanhshrink(x)
-            y = tanhshrink(x)
-        grad_ref = tape.gradient(y_ref, x)
-        grad = tape.gradient(y, x)
-        self.assertAllCloseAccordingToType(grad, grad_ref)
-
-    def test_serialization(self):
-        ref_fn = tanhshrink
-        config = tf.keras.activations.serialize(ref_fn)
-        fn = tf.keras.activations.deserialize(config)
-        self.assertEqual(fn, ref_fn)
+    def test_theoretical_gradients(self, dtype):
+        # Only test theoretical gradients for float32 and float64
+        # because of the instability of float16 while computing jacobian
+        x = tf.constant([-2.0, -1.0, 0.0, 1.0, 2.0], dtype=dtype)
+
+        theoretical, numerical = tf.test.compute_gradient(tanhshrink, [x])
+        self.assertAllCloseAccordingToType(theoretical, numerical, atol=1e-4)
 
 
 if __name__ == "__main__":
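
The expected vector matches the identity the removed `_ref_tanhshrink` helper used, tanhshrink(x) = x - tanh(x); a quick check (not part of the commit):

import numpy as np

x = np.array([-2.0, -1.0, 0.0, 1.0, 2.0])
# tanhshrink(x) = x - tanh(x), per the removed reference helper.
print(x - np.tanh(x))  # ~ [-1.0359724, -0.23840582, 0., 0.23840582, 1.0359724]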

tensorflow_addons/custom_ops/activations/cc/kernels/gelu_op.cc

Lines changed: 2 additions & 2 deletions
@@ -75,5 +75,5 @@ TF_CALL_GPU_NUMBER_TYPES(REGISTER_GELU_GPU_KERNELS);
 
 #endif  // GOOGLE_CUDA
 
-}  // end namespace addons
-}  // namespace tensorflow
+}  // namespace addons
+}  // namespace tensorflow
