@@ -84,8 +84,8 @@ def update_leaf_distributions(
84
84
# y_true_range = torch.arange(0, batch_size)
85
85
# y_true_indices = torch.stack((y_true_range, y_true))
86
86
# y_true_one_hot = torch.sparse_coo_tensor(y_true_indices,
87
- # torch.ones_like(y_true, dtype=torch.bool), logits.shape) # Or other more suitable sparse format,
88
- # or even better,
87
+ # torch.ones_like(y_true, dtype=torch.bool), logits.shape) # Might be better to use CSR or CSC
88
+ # or better still,
89
89
# y_true_one_hot = F.sparse_one_hot(y_true, num_classes=num_classes, dtype=torch.bool),
90
90
# but PyTorch doesn't yet have sufficient sparse mask support for the logic in update_leaf to work.
91
91
y_true_one_hot = F .one_hot (y_true , num_classes = num_classes ).to (dtype = torch .bool )
@@ -113,6 +113,12 @@ def update_leaf(
113
113
log_p_arrival = node_to_prob [leaf ].log_p_arrival .unsqueeze (1 )
114
114
# shape (num_classes). Not the same as logits, which has (batch_size, num_classes)
115
115
leaf_logits = leaf .logits ()
116
+
117
+ # TODO If PyTorch had more support for sparse masks we might be able to do something like
118
+ # masked_logits = logits.sparse_mask(y_true_one_hot),
119
+ # and perhaps if necessary combine it with
120
+ # masked_log_p_arrival = y_true_one_hot * log_p_arrival # sparse_mask can't broadcast
121
+ # masked_leaf_logits = y_true_one_hot * leaf_logits # sparse_mask can't broadcast.
116
122
masked_logits = masked_tensor (logits , y_true_one_hot )
117
123
118
124
masked_log_combined = log_p_arrival + (leaf_logits - masked_logits )
0 commit comments