diff --git a/base/client.jl b/base/client.jl index 8ff97328a74cb..59376fa43497c 100644 --- a/base/client.jl +++ b/base/client.jl @@ -373,6 +373,7 @@ function load_julia_startup() end const repl_hooks = [] +const _repl_hooks_lock = ReentrantLock() """ atreplinit(f) @@ -382,10 +383,11 @@ interactive sessions; this is useful to customize the interface. The argument of REPL object. This function should be called from within the `.julia/config/startup.jl` initialization file. """ -atreplinit(f::Function) = (pushfirst!(repl_hooks, f); nothing) +atreplinit(f::Function) = @lock _repl_hooks_lock (pushfirst!(repl_hooks, f); nothing) function __atreplinit(repl) - for f in repl_hooks + hooks = @lock _repl_hooks_lock copy(repl_hooks) + for f in hooks try f(repl) catch err @@ -557,7 +559,7 @@ function _start() empty!(ARGS) append!(ARGS, Core.ARGS) # clear any postoutput hooks that were saved in the sysimage - empty!(Base.postoutput_hooks) + @lock Base._postoutput_hooks_lock empty!(Base.postoutput_hooks) local ret = 0 try repl_was_requested = exec_options(JLOptions()) diff --git a/base/errorshow.jl b/base/errorshow.jl index 1ae98378ff542..f9ba1fae7bcb7 100644 --- a/base/errorshow.jl +++ b/base/errorshow.jl @@ -654,6 +654,7 @@ end # Vector{Any} storing tuples (sf::StackFrame, nrepetitions::Int), and the updater should # replace `sf` as needed. const update_stackframes_callback = Ref{Function}(identity) +const _update_stackframes_callback_lock = ReentrantLock() const STACKTRACE_MODULECOLORS = Iterators.Stateful(Iterators.cycle([:magenta, :cyan, :green, :yellow])) const STACKTRACE_FIXEDCOLORS = IdDict(Base => :light_black, Core => :light_black) @@ -938,7 +939,8 @@ function show_backtrace(io::IO, t::Vector; prefix = nothing) end # Allow external code to edit information in the frames (e.g. line numbers with Revise) - try invokelatest(update_stackframes_callback[], filtered) catch end + callback = @lock _update_stackframes_callback_lock update_stackframes_callback[] + try invokelatest(callback, filtered) catch end show_processed_backtrace(IOContext(io, :backtrace => true), filtered, nframes, repeated_cycles, max_nested_cycles; print_linebreaks = stacktrace_linebreaks(), prefix) nothing diff --git a/base/initdefs.jl b/base/initdefs.jl index 14b3d5d921083..dc93293337622 100644 --- a/base/initdefs.jl +++ b/base/initdefs.jl @@ -199,6 +199,7 @@ const ACTIVE_PROJECT = Ref{Union{String,Nothing}}(nothing) # Modify this only vi # Each should be a thunk, i.e., `f()`. To determine the current active project, # the thunk can query `Base.active_project()`. const active_project_callbacks = [] +const _active_project_callbacks_lock = ReentrantLock() function current_project(dir::AbstractString) # look for project file in current dir and parents @@ -362,7 +363,8 @@ Set the active `Project.toml` file to `projfile`. See also [`Base.active_project """ function set_active_project(projfile::Union{AbstractString,Nothing}) ACTIVE_PROJECT[] = projfile - for f in active_project_callbacks + callbacks = @lock _active_project_callbacks_lock copy(active_project_callbacks) + for f in callbacks try Base.invokelatest(f) catch @@ -475,12 +477,17 @@ end ## like atexit but runs after any requested output. ## any hooks saved in the sysimage are cleared in Base._start const postoutput_hooks = Callable[] +const _postoutput_hooks_lock = ReentrantLock() -postoutput(f::Function) = (pushfirst!(postoutput_hooks, f); nothing) +postoutput(f::Function) = @lock _postoutput_hooks_lock (pushfirst!(postoutput_hooks, f); nothing) function _postoutput() - while !isempty(postoutput_hooks) - f = popfirst!(postoutput_hooks) + while true + local f + @lock _postoutput_hooks_lock begin + isempty(postoutput_hooks) && return + f = popfirst!(postoutput_hooks) + end try f() catch ex diff --git a/base/loading.jl b/base/loading.jl index 2da72b7828ea0..6392968cdbfa6 100644 --- a/base/loading.jl +++ b/base/loading.jl @@ -1433,9 +1433,10 @@ end function run_package_callbacks(modkey::PkgId) run_extension_callbacks(modkey) assert_havelock(require_lock) + callbacks = @lock _package_callbacks_lock copy(package_callbacks) unlock(require_lock) try - for callback in package_callbacks + for callback in callbacks invokelatest(callback, modkey) end catch @@ -2228,10 +2229,12 @@ end # Callbacks take the form (mod::Base.PkgId) -> nothing. # WARNING: This is an experimental feature and might change later, without deprecation. const package_callbacks = Any[] +const _package_callbacks_lock = ReentrantLock() # to notify downstream consumers that a file has been included into a particular module # Callbacks take the form (mod::Module, filename::String) -> nothing # WARNING: This is an experimental feature and might change later, without deprecation. const include_callbacks = Any[] +const _include_callbacks_lock = ReentrantLock() # used to optionally track dependencies when requiring a module: const _concrete_dependencies = Pair{PkgId,UInt128}[] # these dependency versions are "set in stone", because they are explicitly loaded, and the process should try to avoid invalidating them @@ -2915,7 +2918,8 @@ Base.include # defined in Base.jl function _include(mapexpr::Function, mod::Module, _path::AbstractString) @noinline # Workaround for module availability in _simplify_include_frames path, prev = _include_dependency(mod, _path) - for callback in include_callbacks # to preserve order, must come before eval in include_string + callbacks = @lock _include_callbacks_lock copy(include_callbacks) + for callback in callbacks # to preserve order, must come before eval in include_string invokelatest(callback, mod, path) end code = read(path, String) diff --git a/base/methodshow.jl b/base/methodshow.jl index 1470303a01bbc..0627bc71eb341 100644 --- a/base/methodshow.jl +++ b/base/methodshow.jl @@ -126,6 +126,7 @@ end # Any function `f` stored here must be consistent with the signature # f(m::Method)::Tuple{Union{Symbol,String}, Union{Int32,Int64}} const methodloc_callback = Ref{Union{Function, Nothing}}(nothing) +const _methodloc_callback_lock = ReentrantLock() function fixup_stdlib_path(path::String) # The file defining Base.Sys gets included after this file is included so make sure @@ -150,9 +151,10 @@ end # This function does the method location updating function updated_methodloc(m::Method)::Tuple{String, Int32} file, line = m.file, m.line - if methodloc_callback[] !== nothing + callback = @lock _methodloc_callback_lock methodloc_callback[] + if callback !== nothing try - file, line = invokelatest(methodloc_callback[], m)::Tuple{Union{Symbol,String}, Union{Int32,Int64}} + file, line = invokelatest(callback, m)::Tuple{Union{Symbol,String}, Union{Int32,Int64}} catch end end diff --git a/test/threads.jl b/test/threads.jl index fa0b33a6352f3..7c03a3a0341f6 100644 --- a/test/threads.jl +++ b/test/threads.jl @@ -585,3 +585,251 @@ let once = OncePerTask{Int}(() -> error("expected")) @test_throws ErrorException("expected") once() @test_throws ErrorException("expected") once() end + +# Test thread safety of callback hooks +# NOTE: This test suite asserts thread-safety guarantees for callback registration +# and execution. Base.postoutput and Base.atreplinit are public APIs, but +# Base._postoutput and Base.__atreplinit are internal. We're testing that concurrent +# operations don't corrupt internal data structures (safety property), not specific +# visibility or ordering guarantees (correctness properties). +if Threads.nthreads() == 1 + @testset "Thread-safe callback hooks" skip=true begin end +else +@testset "Thread-safe callback hooks" begin + + # Test concurrent postoutput_hooks operations + @test begin + # Snapshot current state to restore later + saved_hooks = copy(Base.postoutput_hooks) + errors = Threads.Atomic{Int}(0) + n_tasks = min(100, Threads.nthreads() * 20) + success = true + + try + # Run multiple iterations to increase confidence + for iteration in 1:3 + gate = Threads.Event() + + tasks = [Threads.@spawn begin + wait(gate) # Wait for synchronized start + + # Race to add and execute callbacks + for op in 1:20 + try + if op % 2 == 0 + Base.postoutput(() -> nothing) + else + Base._postoutput() # Internal API + end + catch e + if isa(e, InterruptException) + rethrow() # Don't hide Ctrl-C + else + # Any exception indicates a problem + Threads.atomic_add!(errors, 1) + break + end + end + end + end for i in 1:n_tasks] + + notify(gate) # Release all tasks simultaneously + # Re-throw task exceptions so they count as failures + for t in tasks + try + fetch(t) + catch + Threads.atomic_add!(errors, 1) + end + end + + if errors[] > 0 + success = false + break + end + end + finally + # Restore original hooks + empty!(Base.postoutput_hooks) + append!(Base.postoutput_hooks, saved_hooks) + end + + # Contract: no corruption errors with proper locking + success && errors[] == 0 + end + + # Test concurrent repl_hooks operations + @test begin + # Snapshot current state + saved_hooks = copy(Base.repl_hooks) + errors = Threads.Atomic{Int}(0) + n_tasks = min(100, Threads.nthreads() * 20) + success = true + + try + for iteration in 1:3 + gate = Threads.Event() + + tasks = [Threads.@spawn begin + wait(gate) # Synchronized start + + # Race between adding and executing callbacks + for op in 1:20 + try + if op % 2 == 0 + Base.atreplinit(x -> nothing) + else + Base.__atreplinit(nothing) # Internal API + end + catch e + if isa(e, InterruptException) + rethrow() + else + Threads.atomic_add!(errors, 1) + break + end + end + end + end for i in 1:n_tasks] + + notify(gate) + for t in tasks + try + fetch(t) + catch + Threads.atomic_add!(errors, 1) + end + end + + if errors[] > 0 + success = false + break + end + end + finally + # Restore original hooks + empty!(Base.repl_hooks) + append!(Base.repl_hooks, saved_hooks) + end + + success && errors[] == 0 + end + + # Test mixed concurrent operations with correctness check + @test begin + # Snapshot state + saved_post = copy(Base.postoutput_hooks) + saved_repl = copy(Base.repl_hooks) + errors = Threads.Atomic{Int}(0) + callback_calls = Threads.Atomic{Int}(0) + n_tasks = min(100, Threads.nthreads() * 20) + + try + # Add a counting callback to verify execution + Base.postoutput(() -> Threads.atomic_add!(callback_calls, 1)) + + gate = Threads.Event() + + tasks = [Threads.@spawn begin + wait(gate) + + # Mix different callback operations + for op in 1:30 + try + if op % 4 == 0 + Base.postoutput(() -> nothing) + elseif op % 4 == 1 + Base._postoutput() # Internal API + elseif op % 4 == 2 + Base.atreplinit(x -> nothing) + else + Base.__atreplinit(nothing) # Internal API + end + catch e + if isa(e, InterruptException) + rethrow() + else + Threads.atomic_add!(errors, 1) + break + end + end + end + end for i in 1:n_tasks] + + notify(gate) + for t in tasks + try + fetch(t) + catch + Threads.atomic_add!(errors, 1) + end + end + + # Execute remaining callbacks to finalize count (bounded) + local max_runs = 5 * (length(Base.postoutput_hooks) + 1) + for _ in 1:max_runs + isempty(Base.postoutput_hooks) && break + Base._postoutput() + end + finally + # Restore original state + empty!(Base.postoutput_hooks) + append!(Base.postoutput_hooks, saved_post) + empty!(Base.repl_hooks) + append!(Base.repl_hooks, saved_repl) + end + + # Contract: no corruption, and callback was executed at least once + errors[] == 0 && callback_calls[] > 0 + end + + # Test aggressive concurrent modifications + @test begin + saved_hooks = copy(Base.postoutput_hooks) + errors = Threads.Atomic{Int}(0) + n_tasks = min(100, Threads.nthreads() * 20) + + try + gate = Threads.Event() + + tasks = [Threads.@spawn begin + wait(gate) + + # Rapid operations without yielding to maximize contention + for op in 1:20 + try + if op % 2 == 0 + Base.postoutput(() -> nothing) + else + Base._postoutput() + end + catch e + if isa(e, InterruptException) + rethrow() + else + Threads.atomic_add!(errors, 1) + break + end + end + end + end for i in 1:n_tasks] + + notify(gate) + for t in tasks + try + fetch(t) + catch + Threads.atomic_add!(errors, 1) + end + end + finally + # Restore original hooks + empty!(Base.postoutput_hooks) + append!(Base.postoutput_hooks, saved_hooks) + end + + # Should have no corruption errors with proper locking + errors[] == 0 + end +end +end