Problem
The leaf optimization implements the following pseudocode for a given leaf (with a few extra things like exponential smoothing and softmaxes):
leaf_probabilities[j] = sum[over i] ( p_arrival[i] * leaf_probabilities[j] * y_true[i, j] / y_predicted[i, j] ),

where `i` indexes the `B` datapoints in the minibatch and `j` indexes the `C` classes. In log probabilities (which is what is currently in the code) this is

leaf_log_probabilities[j] = logsumexp[over i] ( log_p_arrival[i] + leaf_logits[j] + y_true_logits[i, j] - y_predicted_logits[i, j] ).
Note that `y_true` is a one-hot encoding, so for most `j` there is only a small number of `i`s for which `y_true[i, j] == 1` rather than `0`, and hence only those few terms contribute anything to the `sum[over i]`. Similarly, for the log probabilities, most values of `y_true_logits[i, j]` are non-contributing `-Inf`s, and only a few `0`s contribute to the `logsumexp[over i]`.
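For concreteness, here is a minimal PyTorch sketch of the dense log-space update described above (the shapes and variable names are assumptions for illustration, not the actual code):

```python
import torch

B, C = 32, 10  # minibatch size and number of classes (assumed for illustration)

log_p_arrival = torch.randn(B)                          # log-probability that datapoint i arrives at this leaf
leaf_logits = torch.log_softmax(torch.randn(C), dim=0)  # current leaf log-probabilities
y_predicted_logits = torch.log_softmax(torch.randn(B, C), dim=1)
labels = torch.randint(0, C, (B,))

# One-hot labels in log space: 0 at the true class, -Inf everywhere else
y_true_logits = torch.full((B, C), float("-inf")).scatter(1, labels.unsqueeze(1), 0.0)

# Broadcast everything up to (B, C), then reduce over the batch dimension
new_leaf_logits = torch.logsumexp(
    log_p_arrival.unsqueeze(1)   # (B, 1)
    + leaf_logits.unsqueeze(0)   # (1, C)
    + y_true_logits              # (B, C), mostly -Inf
    - y_predicted_logits,        # (B, C)
    dim=0,
)
```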
The current code does not take advantage of this structure: `y_true_logits` is a dense float64 `(B, C)` tensor in which most entries are `-Inf`, and the overall result is computed by broadcasting everything up to `(B, C)`. This is not efficient; `y_true_logits` has only `B` entries equal to `0`, and `y_predicted_logits` is already calculated and shared between all the leaves (we only need to access the entries `[i, j]` where `y_true_logits == 0`), so we should only need `O(B + C)` calculations. These redundant calculations and tensors full of `-Inf`s also make the code a bit confusing.
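Under the same hypothetical names as the sketch above, the `O(B + C)` version could look roughly like this: since `y_true` is one-hot, only the entry `j = labels[i]` contributes for each datapoint, so we can gather one value per datapoint and then do a grouped logsumexp per class. PyTorch has no grouped `logsumexp`, so this sketch stabilises an `exp`/`scatter_add`/`log` reduction with a per-class maximum:

```python
# Gather the single contributing entry per datapoint: shape (B,)
contrib = log_p_arrival - y_predicted_logits.gather(1, labels.unsqueeze(1)).squeeze(1)

# Grouped logsumexp over the datapoints sharing each label: subtract a
# per-class max for stability, exponentiate, scatter-add, take logs again.
max_per_class = torch.full((C,), float("-inf")).scatter_reduce(0, labels, contrib, reduce="amax")
shifted = (contrib - max_per_class[labels]).exp()
sum_per_class = torch.zeros(C).scatter_add(0, labels, shifted)
grouped_lse = max_per_class + sum_per_class.log()  # classes with no datapoints stay -Inf

new_leaf_logits = leaf_logits + grouped_lse  # O(B + C) work in total
```

Note that this still goes through `exp` and `sum` rather than a single fused `logsumexp`, which is one symptom of the missing library support discussed below.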
Possible Solutions
There are a few ways we could write the efficient/cleaner algorithm.
- With Python for loops: This code would be slow (due both to pure-Python overhead and the lack of GPU acceleration), which seems to defeat the point of trying to optimize this.
- With PyTorch (or other library) sparse masks: Unfortunately, PyTorch (and other libraries) don't seem to have sufficient support for sparse tensors or dense masks, let alone sparse masks; most of the calculations we want to do are very convoluted or infeasible right now. Even a PyTorch implementation with dense masks runs into problems because there is no masked `logsumexp` implementation (so we have to use `exp` and `sum` and risk floating point inaccuracies; a small example of this risk is sketched after this list), and due to the lower efficiency of the masked dense tensors it's roughly 5x slower than the current non-masked dense version. An example of what we're trying to achieve is in Use sparse masks for derivative free leaf optimization #14. We could of course try adding the desired functionality to PyTorch, but this seems like it would require a huge amount of work. Sparse and masked tensors are still in beta and prototype status in PyTorch, respectively, so it's possible that this option will become much easier in the future as the PyTorch developers add more functionality.
- With custom C++/Rust/CUDA code for our specific algorithm: This seems feasible, but could still require a fair bit of work. We'd need to come up with a fairly efficient implementation for both CPU and GPU devices.
Overall, since the current code is still moderately readable and the leaf optimization accounts for only 1-2% of training time, it doesn't seem worthwhile to spend time trying to optimize this code now given all of the issues with these approaches.
Out of scope
Note: Any further optimization or vectorization between leaves is not considered here, i.e. each leaf has its own separate tensor.