diff --git a/include/stdexec/__detail/__when_all.hpp b/include/stdexec/__detail/__when_all.hpp index ab3b92fb5..845fd9e65 100644 --- a/include/stdexec/__detail/__when_all.hpp +++ b/include/stdexec/__detail/__when_all.hpp @@ -183,18 +183,42 @@ namespace stdexec { struct _INVALID_ARGUMENTS_TO_WHEN_ALL_ { }; - template + template + struct __forward_stop_request { + void operator()() const noexcept { + // Temporarily increment the count to avoid concurrent/recursive arrivals to + // pull the rug under our feet. Relaxed memory order is fine here. + __state_->__count_.fetch_add(1, __std::memory_order_relaxed); + + __state_t __expected = __started; + // Transition to the "stopped" state if and only if we're in the + // "started" state. (If this fails, it's because we're in an + // error state, which trumps cancellation.) + if (__state_->__state_.compare_exchange_strong(__expected, __stopped)) { + __state_->__stop_source_.request_stop(); + } + + // Arrive in order to decrement the count again and complete if needed. + __state_->__arrive(*__rcvr_); + } + + _State* __state_; + _Receiver* __rcvr_; + }; + + template struct __when_all_state { - using __stop_callback_t = stop_callback_for_t<_StopToken, __forward_stop_request>; + using __stop_callback_t = stop_callback_for_t< + stop_token_of_t>, + __forward_stop_request<__when_all_state, _Receiver> + >; - template void __arrive(_Receiver& __rcvr) noexcept { - if (1 == __count_.fetch_sub(1)) { + if (1 == __count_.fetch_sub(1, __std::memory_order_acq_rel)) { __complete(__rcvr); } } - template void __complete(_Receiver& __rcvr) noexcept { // Stop callback is no longer needed. Destroy it. __on_stop_.reset(); @@ -283,24 +307,25 @@ namespace stdexec { } }; - template - static auto __mk_state_fn(const _Env&) noexcept { - return []<__max1_sender<__env_t<_Env>>... _Child>(__ignore, __ignore, _Child&&...) { - using _Traits = __traits<_Env, _Child...>; + template + static auto __mk_state_fn(const _Receiver&) noexcept { + using __env_of_t = env_of_t<_Receiver>; + return []<__max1_sender<__env_t<__env_of_t>>... _Child>(__ignore, __ignore, _Child&&...) { + using _Traits = __traits<__env_of_t, _Child...>; using _ErrorsVariant = _Traits::__errors_variant; using _ValuesTuple = _Traits::__values_tuple; using _State = __when_all_state< _ErrorsVariant, _ValuesTuple, - stop_token_of_t<_Env>, - (sends_stopped<_Child, _Env> || ...) + _Receiver, + (sends_stopped<_Child, __env_of_t> || ...) >; return _State{sizeof...(_Child)}; }; } - template - using __mk_state_fn_t = decltype(__when_all::__mk_state_fn(__declval<_Env>())); + template + using __mk_state_fn_t = decltype(__when_all::__mk_state_fn(__declval<_Receiver>())); struct when_all_t { template @@ -340,9 +365,9 @@ namespace stdexec { static constexpr auto get_state = [](_Self&& __self, _Receiver& __rcvr) - -> __sexpr_apply_result_t<_Self, __mk_state_fn_t>> { + -> __sexpr_apply_result_t<_Self, __mk_state_fn_t<_Receiver>> { return __sexpr_apply( - static_cast<_Self&&>(__self), __when_all::__mk_state_fn(stdexec::get_env(__rcvr))); + static_cast<_Self&&>(__self), __when_all::__mk_state_fn(__rcvr)); }; static constexpr auto start = []( @@ -351,7 +376,7 @@ namespace stdexec { _Operations&... __child_ops) noexcept -> void { // register stop callback: __state.__on_stop_.emplace( - get_stop_token(stdexec::get_env(__rcvr)), __forward_stop_request{__state.__stop_source_}); + get_stop_token(stdexec::get_env(__rcvr)), __forward_stop_request<_State, _Receiver>{&__state, &__rcvr}); (stdexec::start(__child_ops), ...); if constexpr (sizeof...(__child_ops) == 0) { __state.__complete(__rcvr); diff --git a/test/stdexec/algos/adaptors/test_when_all.cpp b/test/stdexec/algos/adaptors/test_when_all.cpp index 8d40d54f5..d2a46a078 100644 --- a/test/stdexec/algos/adaptors/test_when_all.cpp +++ b/test/stdexec/algos/adaptors/test_when_all.cpp @@ -17,8 +17,10 @@ #include #include #include +#include #include #include +#include #include namespace ex = stdexec; @@ -367,4 +369,15 @@ namespace { auto op = ex::connect(snd, expect_void_receiver{}); ex::start(op); } + + + + TEST_CASE("when_all handles stop requests from the environment correctly", "[adaptors][when_all") { + auto snd = ex::when_all(completes_if(false), completes_if(false)); + + exec::async_scope scope; + scope.spawn(snd); + scope.request_stop(); + ex::sync_wait(scope.on_empty()); + } } // namespace