-
Notifications
You must be signed in to change notification settings - Fork 79
feat: add a stacked option to onehot #2421
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
|
Your PR no longer requires formatting changes. Thank you for your contribution! |
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #2421 +/- ##
==========================================
+ Coverage 67.50% 75.18% +7.67%
==========================================
Files 31 56 +25
Lines 12668 16971 +4303
==========================================
+ Hits 8552 12759 +4207
- Misses 4116 4212 +96 ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
1c09896 to
6cbf2c0
Compare
|
cc @wsmoses for a review |
| return res | ||
| end | ||
| stacked isa Val{false} && return ret | ||
| return stack(ret) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
we shold use a different implementation here. allocate a zero array of the whole size then just selectively set the one's in the identity (alternatively, is there already an eye or similar function we could use)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
something like mapslices(Base.Fix2(copyto!, I), x; dims=(1, 2)) (with inpalce and stuff) and then reshape. I am just unsure what that whole onehot_internal function above does.
If we make it a proper arg I can leave these ones unchanged
| @allowscalar @inbounds res[i + start - 1] = 1 | ||
| return res | ||
| end | ||
| stacked isa Val{false} && return ret |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
similar comment
|
|
||
| @inline function Enzyme.onehot(x::StaticArrays.SArray{S, T, N, L}) where {S, T, N, L} | ||
| ntuple(Val(L)) do i | ||
| @inline function Enzyme.onehot(x::StaticArrays.SArray{S, T, N, L}; stacked::Union{Val{true}, Val{false}} = Val(false)) where {S, T, N, L} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm also debating if we should make this a proper non kw argument (essentially because while the outer function is marked inline, I don't know a away to make the inner kwfunc inline [and we really do want that here]
No description provided.