Skip to content

Commit 2f1123f

Browse files
Merge pull request #766 from AayushSabharwal/as/fix-threading
fix: ensure equivalent expressions on different tasks are compared properly
2 parents 4621a4c + 127accf commit 2f1123f

File tree

4 files changed

+29
-18
lines changed

4 files changed

+29
-18
lines changed

src/cache.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ The key stored in the cache for a particular value. Returns a `SymbolicKey` for
2727
# can't dispatch because `BasicSymbolic` isn't defined here
2828
function get_cache_key(x)
2929
if x isa BasicSymbolic
30-
id = x.id
30+
id = x.id[2]
3131
if id === nothing
3232
return CacheSentinel()
3333
end

src/types.jl

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ const ROArgsT = ReadOnlyVector{Any, ArgsT}
3535
const ACDict{K, V} = Dict{K, V}
3636
const ShapeVecT = SmallV{UnitRange{Int}}
3737
const ShapeT = Union{Unknown, ShapeVecT}
38-
const IdentT = Union{IDType, Nothing}
38+
const IdentT = Union{Tuple{UInt, IDType}, Tuple{Nothing, Nothing}}
3939

4040
"""
4141
Enum used to differentiate between variants of `BasicSymbolicImpl.ACTerm`.
@@ -171,22 +171,22 @@ override_properties(obj::BSImpl.Type) = override_properties(MData.variant_type(o
171171

172172
function override_properties(obj::Type{<:BSImpl.Variant})
173173
@match obj begin
174-
::Type{<:BSImpl.Sym} => (; id = nothing, hash2 = 0)
175-
::Type{<:BSImpl.Term} => (; id = nothing, hash = 0, hash2 = 0)
176-
::Type{<:BSImpl.AddOrMul} => (; id = nothing, hash = 0, hash2 = 0)
177-
::Type{<:BSImpl.Div} => (; id = nothing, hash2 = 0)
178-
::Type{<:BSImpl.Pow} => (; id = nothing, hash2 = 0)
174+
::Type{<:BSImpl.Sym} => (; id = (nothing, nothing), hash2 = 0)
175+
::Type{<:BSImpl.Term} => (; id = (nothing, nothing), hash = 0, hash2 = 0)
176+
::Type{<:BSImpl.AddOrMul} => (; id = (nothing, nothing), hash = 0, hash2 = 0)
177+
::Type{<:BSImpl.Div} => (; id = (nothing, nothing), hash2 = 0)
178+
::Type{<:BSImpl.Pow} => (; id = (nothing, nothing), hash2 = 0)
179179
_ => throw(UnimplementedForVariantError(override_properties, obj))
180180
end
181181
end
182182

183183
function ordered_override_properties(obj::Type{<:BSImpl.Variant})
184184
@match obj begin
185-
::Type{<:BSImpl.Sym} => (0, nothing)
186-
::Type{<:BSImpl.Term} => (0, 0, nothing)
187-
::Type{<:BSImpl.AddOrMul} => (ArgsT(), 0, 0, nothing)
188-
::Type{<:BSImpl.Div} => (0, nothing)
189-
::Type{<:BSImpl.Pow} => (0, nothing)
185+
::Type{<:BSImpl.Sym} => (0, (nothing, nothing))
186+
::Type{<:BSImpl.Term} => (0, 0, (nothing, nothing))
187+
::Type{<:BSImpl.AddOrMul} => (ArgsT(), 0, 0, (nothing, nothing))
188+
::Type{<:BSImpl.Div} => (0, (nothing, nothing))
189+
::Type{<:BSImpl.Pow} => (0, (nothing, nothing))
190190
_ => throw(UnimplementedForVariantError(override_properties, obj))
191191
end
192192
end
@@ -410,8 +410,8 @@ end
410410

411411
function isequal_bsimpl(a::BSImpl.Type, b::BSImpl.Type, full)
412412
a === b && return true
413-
ida = a.id
414-
idb = b.id
413+
taskida, ida = a.id
414+
taskidb, idb = b.id
415415
ida === idb && ida !== nothing && return true
416416
typeof(a) === typeof(b) || return false
417417

@@ -420,7 +420,7 @@ function isequal_bsimpl(a::BSImpl.Type, b::BSImpl.Type, full)
420420
Ta === Tb || return false
421421

422422

423-
if full && ida !== idb && ida !== nothing && idb !== nothing
423+
if full && ida !== idb && ida !== nothing && idb !== nothing && taskida == taskidb
424424
return false
425425
end
426426

@@ -596,9 +596,10 @@ const ENABLE_HASHCONSING = Ref(true)
596596
const WKD = TaskLocalValue{WeakKeyDict{BSImpl.Type, Nothing}}(WeakKeyDict{BSImpl.Type, Nothing})
597597
const WVD = TaskLocalValue{WeakValueDict{UInt, BSImpl.Type}}(WeakValueDict{UInt, BSImpl.Type})
598598
const WCS = TaskLocalValue{WeakCacheSet{BSImpl.Type}}(WeakCacheSet{BSImpl.Type})
599+
const TASK_ID = TaskLocalValue{UInt}(() -> rand(UInt))
599600

600601
function generate_id()
601-
return IDType()
602+
return (TASK_ID[], IDType())
602603
end
603604

604605
const TOTAL = TaskLocalValue{Int}(Returns(0))
@@ -651,7 +652,7 @@ function hashcons(s::BSImpl.Type{T})::BSImpl.Type{T} where {T}
651652
# cache[h] = s
652653
# k = s
653654
# end
654-
if k.id === nothing
655+
if k.id === (nothing, nothing)
655656
k.id = generate_id()
656657
end
657658
return k

test/basics.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -448,3 +448,13 @@ end
448448
@test isequal(res, x)
449449
end
450450
end
451+
452+
@testset "Equivalent expressions across tasks are equal" begin
453+
@syms a
454+
task = Threads.@spawn @syms a
455+
a2 = only(fetch(task))
456+
@test isequal(a, a2)
457+
@test SymbolicUtils.@manually_scope SymbolicUtils.COMPARE_FULL => true begin
458+
isequal(a, a2)
459+
end
460+
end

test/hash_consing.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ end
129129
@syms a b
130130
x1 = a + b
131131
x2 = a + b
132-
@test x1.id === nothing === x2.id
132+
@test x1.id === (nothing, nothing) === x2.id
133133
SymbolicUtils.ENABLE_HASHCONSING[] = true
134134
end
135135

0 commit comments

Comments
 (0)