@@ -83,25 +83,30 @@ def update_leaf_distributions(
    """
    batch_size, num_classes = logits.shape

-    y_true_one_hot = F.one_hot(y_true, num_classes=num_classes)
-    y_true_logits = torch.log(y_true_one_hot)
+    # Other sparse formats may be better than COO.
+    # TODO: This is a bit convoluted. Why is there no sparse version of torch.nn.functional.one_hot?
+    y_true_range = torch.arange(0, batch_size)
+    y_true_indices = torch.stack((y_true_range, y_true))
+    y_true_one_hot = torch.sparse_coo_tensor(
+        y_true_indices, torch.ones_like(y_true, dtype=torch.bool), logits.shape
+    )

    for leaf in root.leaves:
-        update_leaf(leaf, node_to_prob, logits, y_true_logits, smoothing_factor)
+        update_leaf(leaf, node_to_prob, logits, y_true_one_hot, smoothing_factor)


def update_leaf(
    leaf: Leaf,
    node_to_prob: dict[Node, NodeProbabilities],
    logits: torch.Tensor,
-    y_true_logits: torch.Tensor,
+    y_true_one_hot: torch.Tensor,
    smoothing_factor: float,
):
    """
    :param leaf:
    :param node_to_prob:
    :param logits: of shape (batch_size, num_classes)
-    :param y_true_logits: of shape (batch_size, num_classes)
+    :param y_true_one_hot: boolean tensor of shape (batch_size, num_classes)
    :param smoothing_factor:
    :return:
    """
@@ -110,15 +115,15 @@ def update_leaf(
    # shape (num_classes). Not the same as logits, which has (batch_size, num_classes)
    leaf_logits = leaf.y_logits()

-    # TODO: y_true_logits is mostly -Inf terms (the rest being 0s) that won't contribute to the total, and we are also
-    # summing together tensors of different shapes. We should be able to express this more clearly and efficiently by
-    # taking advantage of this sparsity.
-    log_dist_update = torch.logsumexp(
-        log_p_arrival + leaf_logits + y_true_logits - logits,
-        dim=0,
-    )
+    masked_logits = logits.sparse_mask(y_true_one_hot)
+    masked_log_p_arrival = y_true_one_hot * log_p_arrival
+    masked_leaf_logits = y_true_one_hot * leaf_logits
+    masked_log_combined = masked_log_p_arrival + masked_leaf_logits - masked_logits
+
+    # TODO: Can't use logsumexp because masked tensors don't support it.
+    masked_dist_update = torch.logsumexp(masked_log_combined, dim=0)

-    dist_update = torch.exp(log_dist_update)
+    dist_update = masked_dist_update.to_tensor(0.0)

    # This scaling (subtraction of `-1/n_batches * c` in the ProtoTree paper) seems to be a form of exponentially
    # weighted moving average, designed to ensure stability of the leaf class probability distributions (
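
Note (illustration only, outside the diff): the removed TODO observes that y_true_logits is 0 at the true class and -Inf everywhere else, so those -Inf entries drop out of the logsumexp; the sparse/masked rewrite above exploits exactly that. A small dense sketch of the equivalence, using hypothetical values and plain masking rather than the sparse/masked tensors in the change:

    import torch
    import torch.nn.functional as F

    torch.manual_seed(0)
    batch_size, num_classes = 4, 3
    y_true = torch.tensor([2, 0, 1, 2])
    # Stand-in for log_p_arrival + leaf_logits - logits.
    log_terms = torch.randn(batch_size, num_classes)

    one_hot = F.one_hot(y_true, num_classes=num_classes)

    # Old formulation: log(one_hot) is 0 at the true class and -inf elsewhere,
    # so only the true-class entries contribute to the logsumexp.
    old = torch.logsumexp(log_terms + torch.log(one_hot.float()), dim=0)

    # Same result by masking out everything except the true-class entries.
    masked = log_terms.masked_fill(one_hot == 0, float("-inf"))
    new = torch.logsumexp(masked, dim=0)

    assert torch.allclose(old, new)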