
Commit aa56fe5

GhassenJ authored and edward-bot committed
Add a clipping option to KL divergence inputs for numerical stability.
PiperOrigin-RevId: 317889780
1 parent 990e3e7 commit aa56fe5

File tree

1 file changed (+8, -4 lines)

edward2/tensorflow/metrics.py

Lines changed: 8 additions & 4 deletions
@@ -171,12 +171,13 @@ def logit_kl_divergence(logits_1, logits_2):
   return tf.reduce_mean(vals)
 
 
-def kl_divergence(p, q):
+def kl_divergence(p, q, clip=True):
   """Generalized KL divergence [1] for unnormalized distributions.
 
   Args:
     p: tf.Tensor.
-    q: tf.Tensor
+    q: tf.Tensor.
+    clip: bool.
 
   Returns:
     tf.Tensor of the Kullback-Leibler divergences per example.
@@ -187,7 +188,10 @@ def kl_divergence(p, q):
     matrix factorization." Advances in neural information processing systems.
     2001.
   """
-  return tf.reduce_sum(p * tf.math.log(p / q) - p + q, axis=-1)
+  if clip:
+    p = tf.clip_by_value(p, tf.keras.backend.epsilon(), 1)
+    q = tf.clip_by_value(q, tf.keras.backend.epsilon(), 1)
+  return tf.reduce_sum(p * tf.math.log(p / q), axis=-1)
 
 
 def lp_distance(x, y, p=1):
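
To make the new clip flag concrete, here is a minimal sketch of the guard it adds, assuming TF 2.x eager execution; the input values are made up for illustration, and the two expressions mirror the function body with clip=False and clip=True:

import tensorflow as tf

# Made-up inputs: p and q each contain an exact zero, so p / q is
# 0 / 0 and the unclipped log term evaluates to NaN.
p = tf.constant([[0.5, 0.5, 0.0]])
q = tf.constant([[0.4, 0.6, 0.0]])

# Equivalent to kl_divergence(p, q, clip=False): the zero entries
# poison the whole sum.
unclipped = tf.reduce_sum(p * tf.math.log(p / q), axis=-1)
print(unclipped.numpy())  # [nan]

# Equivalent to kl_divergence(p, q, clip=True): both inputs are first
# bounded to [epsilon, 1], which keeps the log finite.
eps = tf.keras.backend.epsilon()  # 1e-07 by default
p_c = tf.clip_by_value(p, eps, 1.0)
q_c = tf.clip_by_value(q, eps, 1.0)
clipped = tf.reduce_sum(p_c * tf.math.log(p_c / q_c), axis=-1)
print(clipped.numpy())  # [~0.0204], finite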
@@ -229,7 +233,7 @@ def average_pairwise_diversity(probs, num_models, error=None):
   # TODO(ghassen): we could also return max and min pairwise metrics.
   average_disagreement = tf.reduce_mean(tf.stack(pairwise_disagreement))
   if error is not None:
-    average_disagreement /= (error + tf.keras.backend.epsilon())
+    average_disagreement /= (1 - error + tf.keras.backend.epsilon())
   average_kl_divergence = tf.reduce_mean(tf.stack(pairwise_kl_divergence))
   average_cosine_distance = tf.reduce_mean(tf.stack(pairwise_cosine_distance))
 
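
The last hunk also changes the disagreement normalizer to divide by accuracy, 1 - error, rather than by the error itself. A small sketch with made-up numbers, assuming TF 2.x, shows the difference:

import tensorflow as tf

# Made-up values: mean pairwise disagreement of 0.06 for an ensemble
# with an average error of 0.10.
average_disagreement = tf.constant(0.06)
error = tf.constant(0.10)
eps = tf.keras.backend.epsilon()

old = average_disagreement / (error + eps)      # 0.06 / 0.10 -> 0.6
new = average_disagreement / (1 - error + eps)  # 0.06 / 0.90 -> ~0.0667
print(old.numpy(), new.numpy())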
