Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions base/client.jl
Original file line number Diff line number Diff line change
Expand Up @@ -373,6 +373,7 @@ function load_julia_startup()
end

const repl_hooks = []
const _repl_hooks_lock = ReentrantLock()

"""
atreplinit(f)
Expand All @@ -382,10 +383,11 @@ interactive sessions; this is useful to customize the interface. The argument of
REPL object. This function should be called from within the `.julia/config/startup.jl`
initialization file.
"""
atreplinit(f::Function) = (pushfirst!(repl_hooks, f); nothing)
atreplinit(f::Function) = @lock _repl_hooks_lock (pushfirst!(repl_hooks, f); nothing)

function __atreplinit(repl)
for f in repl_hooks
hooks = @lock _repl_hooks_lock copy(repl_hooks)
for f in hooks
try
f(repl)
catch err
Expand Down Expand Up @@ -557,7 +559,7 @@ function _start()
empty!(ARGS)
append!(ARGS, Core.ARGS)
# clear any postoutput hooks that were saved in the sysimage
empty!(Base.postoutput_hooks)
@lock Base._postoutput_hooks_lock empty!(Base.postoutput_hooks)
local ret = 0
try
repl_was_requested = exec_options(JLOptions())
Expand Down
4 changes: 3 additions & 1 deletion base/errorshow.jl
Original file line number Diff line number Diff line change
Expand Up @@ -654,6 +654,7 @@ end
# Vector{Any} storing tuples (sf::StackFrame, nrepetitions::Int), and the updater should
# replace `sf` as needed.
const update_stackframes_callback = Ref{Function}(identity)
const _update_stackframes_callback_lock = ReentrantLock()

const STACKTRACE_MODULECOLORS = Iterators.Stateful(Iterators.cycle([:magenta, :cyan, :green, :yellow]))
const STACKTRACE_FIXEDCOLORS = IdDict(Base => :light_black, Core => :light_black)
Expand Down Expand Up @@ -938,7 +939,8 @@ function show_backtrace(io::IO, t::Vector; prefix = nothing)
end

# Allow external code to edit information in the frames (e.g. line numbers with Revise)
try invokelatest(update_stackframes_callback[], filtered) catch end
callback = @lock _update_stackframes_callback_lock update_stackframes_callback[]
try invokelatest(callback, filtered) catch end

show_processed_backtrace(IOContext(io, :backtrace => true), filtered, nframes, repeated_cycles, max_nested_cycles; print_linebreaks = stacktrace_linebreaks(), prefix)
nothing
Expand Down
15 changes: 11 additions & 4 deletions base/initdefs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,7 @@ const ACTIVE_PROJECT = Ref{Union{String,Nothing}}(nothing) # Modify this only vi
# Each should be a thunk, i.e., `f()`. To determine the current active project,
# the thunk can query `Base.active_project()`.
const active_project_callbacks = []
const _active_project_callbacks_lock = ReentrantLock()

function current_project(dir::AbstractString)
# look for project file in current dir and parents
Expand Down Expand Up @@ -362,7 +363,8 @@ Set the active `Project.toml` file to `projfile`. See also [`Base.active_project
"""
function set_active_project(projfile::Union{AbstractString,Nothing})
ACTIVE_PROJECT[] = projfile
for f in active_project_callbacks
callbacks = @lock _active_project_callbacks_lock copy(active_project_callbacks)
for f in callbacks
try
Base.invokelatest(f)
catch
Expand Down Expand Up @@ -475,12 +477,17 @@ end
## like atexit but runs after any requested output.
## any hooks saved in the sysimage are cleared in Base._start
const postoutput_hooks = Callable[]
const _postoutput_hooks_lock = ReentrantLock()

postoutput(f::Function) = (pushfirst!(postoutput_hooks, f); nothing)
postoutput(f::Function) = @lock _postoutput_hooks_lock (pushfirst!(postoutput_hooks, f); nothing)

function _postoutput()
while !isempty(postoutput_hooks)
f = popfirst!(postoutput_hooks)
while true
local f
@lock _postoutput_hooks_lock begin
isempty(postoutput_hooks) && return
f = popfirst!(postoutput_hooks)
end
try
f()
catch ex
Expand Down
8 changes: 6 additions & 2 deletions base/loading.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1433,9 +1433,10 @@ end
function run_package_callbacks(modkey::PkgId)
run_extension_callbacks(modkey)
assert_havelock(require_lock)
callbacks = @lock _package_callbacks_lock copy(package_callbacks)
unlock(require_lock)
try
for callback in package_callbacks
for callback in callbacks
invokelatest(callback, modkey)
end
catch
Expand Down Expand Up @@ -2228,10 +2229,12 @@ end
# Callbacks take the form (mod::Base.PkgId) -> nothing.
# WARNING: This is an experimental feature and might change later, without deprecation.
const package_callbacks = Any[]
const _package_callbacks_lock = ReentrantLock()
# to notify downstream consumers that a file has been included into a particular module
# Callbacks take the form (mod::Module, filename::String) -> nothing
# WARNING: This is an experimental feature and might change later, without deprecation.
const include_callbacks = Any[]
const _include_callbacks_lock = ReentrantLock()

# used to optionally track dependencies when requiring a module:
const _concrete_dependencies = Pair{PkgId,UInt128}[] # these dependency versions are "set in stone", because they are explicitly loaded, and the process should try to avoid invalidating them
Expand Down Expand Up @@ -2915,7 +2918,8 @@ Base.include # defined in Base.jl
function _include(mapexpr::Function, mod::Module, _path::AbstractString)
@noinline # Workaround for module availability in _simplify_include_frames
path, prev = _include_dependency(mod, _path)
for callback in include_callbacks # to preserve order, must come before eval in include_string
callbacks = @lock _include_callbacks_lock copy(include_callbacks)
for callback in callbacks # to preserve order, must come before eval in include_string
invokelatest(callback, mod, path)
end
code = read(path, String)
Expand Down
6 changes: 4 additions & 2 deletions base/methodshow.jl
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ end
# Any function `f` stored here must be consistent with the signature
# f(m::Method)::Tuple{Union{Symbol,String}, Union{Int32,Int64}}
const methodloc_callback = Ref{Union{Function, Nothing}}(nothing)
const _methodloc_callback_lock = ReentrantLock()

function fixup_stdlib_path(path::String)
# The file defining Base.Sys gets included after this file is included so make sure
Expand All @@ -150,9 +151,10 @@ end
# This function does the method location updating
function updated_methodloc(m::Method)::Tuple{String, Int32}
file, line = m.file, m.line
if methodloc_callback[] !== nothing
callback = @lock _methodloc_callback_lock methodloc_callback[]
if callback !== nothing
try
file, line = invokelatest(methodloc_callback[], m)::Tuple{Union{Symbol,String}, Union{Int32,Int64}}
file, line = invokelatest(callback, m)::Tuple{Union{Symbol,String}, Union{Int32,Int64}}
catch
end
end
Expand Down
248 changes: 248 additions & 0 deletions test/threads.jl
Original file line number Diff line number Diff line change
Expand Up @@ -585,3 +585,251 @@
@test_throws ErrorException("expected") once()
@test_throws ErrorException("expected") once()
end

# Test thread safety of callback hooks
# NOTE: This test suite asserts thread-safety guarantees for callback registration
# and execution. Base.postoutput and Base.atreplinit are public APIs, but
# Base._postoutput and Base.__atreplinit are internal. We're testing that concurrent
# operations don't corrupt internal data structures (safety property), not specific
# visibility or ordering guarantees (correctness properties).
if Threads.nthreads() == 1
@testset "Thread-safe callback hooks" skip=true begin end
else
@testset "Thread-safe callback hooks" begin

Check warning on line 599 in test/threads.jl

View workflow job for this annotation

GitHub Actions / Check whitespace

Whitespace check

trailing whitespace
# Test concurrent postoutput_hooks operations
@test begin
# Snapshot current state to restore later
saved_hooks = copy(Base.postoutput_hooks)
errors = Threads.Atomic{Int}(0)
n_tasks = min(100, Threads.nthreads() * 20)
success = true

Check warning on line 607 in test/threads.jl

View workflow job for this annotation

GitHub Actions / Check whitespace

Whitespace check

trailing whitespace
try
# Run multiple iterations to increase confidence
for iteration in 1:3
gate = Threads.Event()

Check warning on line 612 in test/threads.jl

View workflow job for this annotation

GitHub Actions / Check whitespace

Whitespace check

trailing whitespace
tasks = [Threads.@spawn begin
wait(gate) # Wait for synchronized start

Check warning on line 615 in test/threads.jl

View workflow job for this annotation

GitHub Actions / Check whitespace

Whitespace check

trailing whitespace
# Race to add and execute callbacks
for op in 1:20
try
if op % 2 == 0
Base.postoutput(() -> nothing)
else
Base._postoutput() # Internal API
end
catch e
if isa(e, InterruptException)
rethrow() # Don't hide Ctrl-C
else
# Any exception indicates a problem
Threads.atomic_add!(errors, 1)
break
end
end
end
end for i in 1:n_tasks]

Check warning on line 635 in test/threads.jl

View workflow job for this annotation

GitHub Actions / Check whitespace

Whitespace check

trailing whitespace
notify(gate) # Release all tasks simultaneously
# Re-throw task exceptions so they count as failures
for t in tasks
try
fetch(t)
catch
Threads.atomic_add!(errors, 1)
end
end

Check warning on line 645 in test/threads.jl

View workflow job for this annotation

GitHub Actions / Check whitespace

Whitespace check

trailing whitespace
if errors[] > 0
success = false
break
end
end
finally
# Restore original hooks
empty!(Base.postoutput_hooks)
append!(Base.postoutput_hooks, saved_hooks)
end

Check warning on line 656 in test/threads.jl

View workflow job for this annotation

GitHub Actions / Check whitespace

Whitespace check

trailing whitespace
# Contract: no corruption errors with proper locking
success && errors[] == 0
end

Check warning on line 660 in test/threads.jl

View workflow job for this annotation

GitHub Actions / Check whitespace

Whitespace check

trailing whitespace
# Test concurrent repl_hooks operations
@test begin
# Snapshot current state
saved_hooks = copy(Base.repl_hooks)
errors = Threads.Atomic{Int}(0)
n_tasks = min(100, Threads.nthreads() * 20)
success = true

Check warning on line 668 in test/threads.jl

View workflow job for this annotation

GitHub Actions / Check whitespace

Whitespace check

trailing whitespace
try
for iteration in 1:3
gate = Threads.Event()

Check warning on line 672 in test/threads.jl

View workflow job for this annotation

GitHub Actions / Check whitespace

Whitespace check

trailing whitespace
tasks = [Threads.@spawn begin
wait(gate) # Synchronized start

# Race between adding and executing callbacks
for op in 1:20
try
if op % 2 == 0
Base.atreplinit(x -> nothing)
else
Base.__atreplinit(nothing) # Internal API
end
catch e
if isa(e, InterruptException)
rethrow()
else
Threads.atomic_add!(errors, 1)
break
end
end
end
end for i in 1:n_tasks]

notify(gate)
for t in tasks
try
fetch(t)
catch
Threads.atomic_add!(errors, 1)
end
end

if errors[] > 0
success = false
break
end
end
finally
# Restore original hooks
empty!(Base.repl_hooks)
append!(Base.repl_hooks, saved_hooks)
end

success && errors[] == 0
end

# Test mixed concurrent operations with correctness check
@test begin
# Snapshot state
saved_post = copy(Base.postoutput_hooks)
saved_repl = copy(Base.repl_hooks)
errors = Threads.Atomic{Int}(0)
callback_calls = Threads.Atomic{Int}(0)
n_tasks = min(100, Threads.nthreads() * 20)

try
# Add a counting callback to verify execution
Base.postoutput(() -> Threads.atomic_add!(callback_calls, 1))

gate = Threads.Event()

tasks = [Threads.@spawn begin
wait(gate)

# Mix different callback operations
for op in 1:30
try
if op % 4 == 0
Base.postoutput(() -> nothing)
elseif op % 4 == 1
Base._postoutput() # Internal API
elseif op % 4 == 2
Base.atreplinit(x -> nothing)
else
Base.__atreplinit(nothing) # Internal API
end
catch e
if isa(e, InterruptException)
rethrow()
else
Threads.atomic_add!(errors, 1)
break
end
end
end
end for i in 1:n_tasks]

notify(gate)
for t in tasks
try
fetch(t)
catch
Threads.atomic_add!(errors, 1)
end
end

# Execute remaining callbacks to finalize count (bounded)
local max_runs = 5 * (length(Base.postoutput_hooks) + 1)
for _ in 1:max_runs
isempty(Base.postoutput_hooks) && break
Base._postoutput()
end
finally
# Restore original state
empty!(Base.postoutput_hooks)
append!(Base.postoutput_hooks, saved_post)
empty!(Base.repl_hooks)
append!(Base.repl_hooks, saved_repl)
end

# Contract: no corruption, and callback was executed at least once
errors[] == 0 && callback_calls[] > 0
end

# Test aggressive concurrent modifications
@test begin
saved_hooks = copy(Base.postoutput_hooks)
errors = Threads.Atomic{Int}(0)
n_tasks = min(100, Threads.nthreads() * 20)

try
gate = Threads.Event()

tasks = [Threads.@spawn begin
wait(gate)

# Rapid operations without yielding to maximize contention
for op in 1:20
try
if op % 2 == 0
Base.postoutput(() -> nothing)
else
Base._postoutput()
end
catch e
if isa(e, InterruptException)
rethrow()
else
Threads.atomic_add!(errors, 1)
break
end
end
end
end for i in 1:n_tasks]

notify(gate)
for t in tasks
try
fetch(t)
catch
Threads.atomic_add!(errors, 1)
end
end
finally
# Restore original hooks
empty!(Base.postoutput_hooks)
append!(Base.postoutput_hooks, saved_hooks)
end

# Should have no corruption errors with proper locking
errors[] == 0
end
end
end
Loading