Skip to content

Commit f148e62

Browse files
Improve comments
1 parent c4c46cf commit f148e62

File tree

1 file changed

+8
-2
lines changed

1 file changed

+8
-2
lines changed

prototree/train.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,8 +84,8 @@ def update_leaf_distributions(
8484
# y_true_range = torch.arange(0, batch_size)
8585
# y_true_indices = torch.stack((y_true_range, y_true))
8686
# y_true_one_hot = torch.sparse_coo_tensor(y_true_indices,
87-
# torch.ones_like(y_true, dtype=torch.bool), logits.shape) # Or other more suitable sparse format,
88-
# or even better,
87+
# torch.ones_like(y_true, dtype=torch.bool), logits.shape) # Might be better to use CSR or CSC
88+
# or better still,
8989
# y_true_one_hot = F.sparse_one_hot(y_true, num_classes=num_classes, dtype=torch.bool),
9090
# but PyTorch doesn't yet have sufficient sparse mask support for the logic in update_leaf to work.
9191
y_true_one_hot = F.one_hot(y_true, num_classes=num_classes).to(dtype=torch.bool)
@@ -113,6 +113,12 @@ def update_leaf(
113113
log_p_arrival = node_to_prob[leaf].log_p_arrival.unsqueeze(1)
114114
# shape (num_classes). Not the same as logits, which has (batch_size, num_classes)
115115
leaf_logits = leaf.logits()
116+
117+
# TODO If PyTorch had more support for sparse masks we might be able to do something like
118+
# masked_logits = logits.sparse_mask(y_true_one_hot),
119+
# and perhaps if necessary combine it with
120+
# masked_log_p_arrival = y_true_one_hot * log_p_arrival # sparse_mask can't broadcast
121+
# masked_leaf_logits = y_true_one_hot * leaf_logits # sparse_mask can't broadcast.
116122
masked_logits = masked_tensor(logits, y_true_one_hot)
117123

118124
masked_log_combined = log_p_arrival + (leaf_logits - masked_logits)

0 commit comments

Comments
 (0)