Skip to content

Commit 49446f4

Browse files
YingboMa authored and mbauman committed
Inline the whole broadcast expression to avoid allocation (#30986)
* Inline the whole broadcast expression to avoid allocation

  `map` has an inline limit of 16. To make sure that the whole broadcast tree gets inlined properly, I added the `_inlined_map` function. I am not sure if it is a good idea, but worth trying.

  This PR solves the issue which I have mentioned in 2693778#issuecomment-461248258

  ```julia
  julia> @allocated foo(tmp, uprev, k1, k2, k3, k4, k5, k6, k7, k8, k9, k10, k11, k12, k13, k14, k15, k16, k17, k18, k19, k20, k21, k22, k23, k24, k25, k26, k27, k28, k29, k30, k31, k32, k33, k34)
  0
  ```

* Fix CI failure
* Stricter test (vectorization & no allocation for a 9-array bc)
* rm `_inlined_map`
1 parent 4261ac8 commit 49446f4

File tree

2 files changed

+11
-9
lines changed

2 files changed

+11
-9
lines changed

base/broadcast.jl

Lines changed: 4 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -464,7 +464,8 @@ julia> Broadcast.combine_axes(1, 1, 1)
464464
()
465465
```
466466
"""
467-
@inline combine_axes(A, B...) = broadcast_shape(axes(A), combine_axes(B...))
467+
@inline combine_axes(A, B, C...) = broadcast_shape(axes(A), combine_axes(B, C...))
468+
@inline combine_axes(A, B) = broadcast_shape(axes(A), axes(B))
468469
combine_axes(A) = axes(A)
469470

470471
# shape (i.e., tuple-of-indices) inputs
@@ -502,7 +503,7 @@ function check_broadcast_shape(shp, Ashp::Tuple)
502503
_bcsm(shp[1], Ashp[1]) || throw(DimensionMismatch("array could not be broadcast to match destination"))
503504
check_broadcast_shape(tail(shp), tail(Ashp))
504505
end
505-
check_broadcast_axes(shp, A) = check_broadcast_shape(shp, axes(A))
506+
@inline check_broadcast_axes(shp, A) = check_broadcast_shape(shp, axes(A))
506507
# comparing many inputs
507508
@inline function check_broadcast_axes(shp, A, As...)
508509
check_broadcast_axes(shp, A)
@@ -911,7 +912,7 @@ _is_static_broadcast_28126(dest::AbstractArray, x::AbstractArray{<:Any,1}) = axe
911912
_is_static_broadcast_28126(dest::AbstractArray, x::AbstractArray) = axes(dest) == axes(x) # This can be better with other missing dimensions
912913

913914
@inline _is_static_broadcast_28126_args(dest, args::Tuple) = _is_static_broadcast_28126(dest, args[1]) && _is_static_broadcast_28126_args(dest, tail(args))
914-
_is_static_broadcast_28126_args(dest, args::Tuple{Any}) = _is_static_broadcast_28126(dest, args[1])
915+
@inline _is_static_broadcast_28126_args(dest, args::Tuple{Any}) = _is_static_broadcast_28126(dest, args[1])
915916
_is_static_broadcast_28126_args(dest, args::Tuple{}) = true
916917

917918
struct _NonExtruded28126{T}

test/broadcast.jl

Lines changed: 7 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -791,14 +791,15 @@ let
791791
end
792792

793793
@testset "large fusions vectorize and don't allocate (#28126)" begin
794-
u, k1, k2, k3, k4, k5, k6, k7 = (ones(1000) for i in 1:8)
795-
function goo(u, k1, k2, k3, k4, k5, k6, k7)
796-
@. u = 0.1*(0.1*k1 + 0.2*k2 + 0.3*k3 + 0.4*k4 + 0.5*k5 + 0.6*k6 + 0.7*k7)
794+
using InteractiveUtils: code_llvm
795+
u, uprev, k1, k2, k3, k4, k5, k6, k7 = (ones(1000) for i in 1:9)
796+
function goo(u, uprev, k1, k2, k3, k4, k5, k6, k7)
797+
@. u = uprev + 0.1*(0.1*k1 + 0.2*k2 + 0.3*k3 + 0.4*k4 + 0.5*k5 + 0.6*k6 + 0.7*k7)
797798
nothing
798799
end
799-
@allocated goo(u, k1, k2, k3, k4, k5, k6, k7)
800-
@test @allocated(goo(u, k1, k2, k3, k4, k5, k6, k7)) == 0
801-
@test occursin("vector.body", sprint(code_llvm, goo, NTuple{8, Vector{Float32}}))
800+
@allocated goo(u, uprev, k1, k2, k3, k4, k5, k6, k7)
801+
@test @allocated(goo(u, uprev, k1, k2, k3, k4, k5, k6, k7)) == 0
802+
@test occursin("vector.body", sprint(code_llvm, goo, NTuple{9, Vector{Float32}}))
802803
end
803804

804805
# Broadcasted iterable/indexable APIs

0 commit comments

Comments
 (0)