
Commit 4450b61

Initial example
1 parent f93d950 commit 4450b61

1 file changed (+18, -13 lines)


src/prototree/train.py

Lines changed: 18 additions & 13 deletions
@@ -83,25 +83,30 @@ def update_leaf_distributions(
     """
     batch_size, num_classes = logits.shape
 
-    y_true_one_hot = F.one_hot(y_true, num_classes=num_classes)
-    y_true_logits = torch.log(y_true_one_hot)
+    # Other sparse formats may be better than COO.
+    # TODO: This is a bit convoluted. Why is there no sparse version of torch.nn.functional.one_hot ?
+    y_true_range = torch.arange(0, batch_size)
+    y_true_indices = torch.stack((y_true_range, y_true))
+    y_true_one_hot = torch.sparse_coo_tensor(
+        y_true_indices, torch.ones_like(y_true, dtype=torch.bool), logits.shape
+    )
 
     for leaf in root.leaves:
-        update_leaf(leaf, node_to_prob, logits, y_true_logits, smoothing_factor)
+        update_leaf(leaf, node_to_prob, logits, y_true_one_hot, smoothing_factor)
 
 
 def update_leaf(
     leaf: Leaf,
     node_to_prob: dict[Node, NodeProbabilities],
     logits: torch.Tensor,
-    y_true_logits: torch.Tensor,
+    y_true_one_hot: torch.Tensor,
     smoothing_factor: float,
 ):
     """
     :param leaf:
     :param node_to_prob:
     :param logits: of shape (batch_size, num_classes)
-    :param y_true_logits: of shape (batch_size, num_classes)
+    :param y_true_one_hot: boolean tensor of shape (batch_size, num_classes)
     :param smoothing_factor:
     :return:
     """
@@ -110,15 +115,15 @@ def update_leaf(
     # shape (num_classes). Not the same as logits, which has (batch_size, num_classes)
     leaf_logits = leaf.y_logits()
 
-    # TODO: y_true_logits is mostly -Inf terms (the rest being 0s) that won't contribute to the total, and we are also
-    # summing together tensors of different shapes. We should be able to express this more clearly and efficiently by
-    # taking advantage of this sparsity.
-    log_dist_update = torch.logsumexp(
-        log_p_arrival + leaf_logits + y_true_logits - logits,
-        dim=0,
-    )
+    masked_logits = logits.sparse_mask(y_true_one_hot)
+    masked_log_p_arrival = y_true_one_hot * log_p_arrival
+    masked_leaf_logits = y_true_one_hot * leaf_logits
+    masked_log_combined = masked_log_p_arrival + masked_leaf_logits - masked_logits
+
+    # TODO: Can't use logsumexp because masked tensors don't support it.
+    masked_dist_update = torch.logsumexp(masked_log_combined, dim=0)
 
-    dist_update = torch.exp(log_dist_update)
+    dist_update = masked_dist_update.to_tensor(0.0)
 
     # This scaling (subtraction of `-1/n_batches * c` in the ProtoTree paper) seems to be a form of exponentially
     # weighted moving average, designed to ensure stability of the leaf class probability distributions (
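
The TODO removed in this hunk notes that y_true_logits is almost entirely -Inf entries, so for each class only the samples actually labelled with that class survive the logsumexp. A reference sketch of the quantity both the old and the new code compute, written as an explicit per-class loop purely for clarity (illustrative only; it assumes log_p_arrival has shape (batch_size, 1), which this diff does not show being computed):

import torch

def dense_log_dist_update(log_p_arrival, leaf_logits, logits, y_true):
    # Assumed shapes: log_p_arrival (batch_size, 1), leaf_logits (num_classes,),
    # logits (batch_size, num_classes), y_true (batch_size,).
    batch_size, num_classes = logits.shape
    log_dist_update = torch.full((num_classes,), float("-inf"))
    for c in range(num_classes):
        picked = y_true == c  # only samples whose true class is c contribute
        if picked.any():
            terms = log_p_arrival[picked, 0] + leaf_logits[c] - logits[picked, c]
            log_dist_update[c] = torch.logsumexp(terms, dim=0)
    return torch.exp(log_dist_update)  # 0.0 for classes absent from the batch

Only samples whose true label is c contribute to class c, which is the sparsity the masked-tensor version above tries to exploit.
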
