@@ -171,12 +171,13 @@ def logit_kl_divergence(logits_1, logits_2):
   return tf.reduce_mean(vals)


-def kl_divergence(p, q):
+def kl_divergence(p, q, clip=True):
   """Generalized KL divergence [1] for unnormalized distributions.

   Args:
     p: tf.Tensor.
-    q: tf.Tensor
+    q: tf.Tensor.
+    clip: bool.

   Returns:
     tf.Tensor of the Kullback-Leibler divergences per example.
@@ -187,7 +188,10 @@ def kl_divergence(p, q):
     matrix factorization." Advances in neural information processing systems.
     2001.
   """
-  return tf.reduce_sum(p * tf.math.log(p / q) - p + q, axis=-1)
+  if clip:
+    p = tf.clip_by_value(p, tf.keras.backend.epsilon(), 1)
+    q = tf.clip_by_value(q, tf.keras.backend.epsilon(), 1)
+  return tf.reduce_sum(p * tf.math.log(p / q), axis=-1)


 def lp_distance(x, y, p=1):
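For context, a minimal sketch of how the updated `kl_divergence` behaves; the tensor values are illustrative and not from this PR. Clipping both inputs to `[epsilon, 1]` guards the log against zeros, and the `- p + q` correction terms of the generalized (unnormalized) form are dropped, so the remaining expression is the standard per-example sum of `p * log(p / q)`.

```python
import tensorflow as tf


def kl_divergence(p, q, clip=True):
  """Per-example KL divergence; clips inputs to avoid log(0)."""
  if clip:
    p = tf.clip_by_value(p, tf.keras.backend.epsilon(), 1.)
    q = tf.clip_by_value(q, tf.keras.backend.epsilon(), 1.)
  return tf.reduce_sum(p * tf.math.log(p / q), axis=-1)


# Illustrative inputs (not from the PR): two normalized distributions.
p = tf.constant([[0.7, 0.3, 0.0]])  # the zero entry is clipped to epsilon
q = tf.constant([[0.5, 0.4, 0.1]])
print(kl_divergence(p, q).numpy())  # ~0.15 for this example
```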
@@ -229,7 +233,7 @@ def average_pairwise_diversity(probs, num_models, error=None):
   # TODO(ghassen): we could also return max and min pairwise metrics.
   average_disagreement = tf.reduce_mean(tf.stack(pairwise_disagreement))
   if error is not None:
-    average_disagreement /= (error + tf.keras.backend.epsilon())
+    average_disagreement /= (1 - error + tf.keras.backend.epsilon())
   average_kl_divergence = tf.reduce_mean(tf.stack(pairwise_kl_divergence))
   average_cosine_distance = tf.reduce_mean(tf.stack(pairwise_cosine_distance))

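A rough sketch of the normalization change in `average_pairwise_diversity`; the helper below is a simplified stand-in, not the repository's full function, and the numbers are illustrative. The mean pairwise disagreement is now divided by `1 - error` (plus epsilon), i.e. normalized by the ensemble's accuracy rather than its error.

```python
import tensorflow as tf


def normalized_disagreement(pairwise_disagreement, error=None):
  """Simplified stand-in: averages pairwise disagreement and, when an error
  rate is given, normalizes by accuracy (1 - error) as in the updated code."""
  avg = tf.reduce_mean(tf.stack(pairwise_disagreement))
  if error is not None:
    avg /= (1. - error + tf.keras.backend.epsilon())
  return avg


# Illustrative values (not from the PR).
disagreements = [tf.constant(0.10), tf.constant(0.14)]
print(normalized_disagreement(disagreements, error=0.2).numpy())  # ~0.15
```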