Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 41 additions & 16 deletions include/stdexec/__detail/__when_all.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -183,18 +183,42 @@ namespace stdexec {

struct _INVALID_ARGUMENTS_TO_WHEN_ALL_ { };

template <class _ErrorsVariant, class _ValuesTuple, class _StopToken, bool _SendsStopped>
template <class _State, class _Receiver>
struct __forward_stop_request {
void operator()() const noexcept {
// Temporarily increment the count to avoid concurrent/recursive arrivals to
// pull the rug under our feet. Relaxed memory order is fine here.
__state_->__count_.fetch_add(1, __std::memory_order_relaxed);

__state_t __expected = __started;
// Transition to the "stopped" state if and only if we're in the
// "started" state. (If this fails, it's because we're in an
// error state, which trumps cancellation.)
if (__state_->__state_.compare_exchange_strong(__expected, __stopped)) {
__state_->__stop_source_.request_stop();
}

// Arrive in order to decrement the count again and complete if needed.
__state_->__arrive(*__rcvr_);
}

_State* __state_;
_Receiver* __rcvr_;
};

template <class _ErrorsVariant, class _ValuesTuple, class _Receiver, bool _SendsStopped>
struct __when_all_state {
using __stop_callback_t = stop_callback_for_t<_StopToken, __forward_stop_request>;
using __stop_callback_t = stop_callback_for_t<
stop_token_of_t<env_of_t<_Receiver>>,
__forward_stop_request<__when_all_state, _Receiver>
>;

template <class _Receiver>
void __arrive(_Receiver& __rcvr) noexcept {
if (1 == __count_.fetch_sub(1)) {
if (1 == __count_.fetch_sub(1, __std::memory_order_acq_rel)) {
__complete(__rcvr);
}
}

template <class _Receiver>
void __complete(_Receiver& __rcvr) noexcept {
// Stop callback is no longer needed. Destroy it.
__on_stop_.reset();
Expand Down Expand Up @@ -283,24 +307,25 @@ namespace stdexec {
}
};

template <class _Env>
static auto __mk_state_fn(const _Env&) noexcept {
return []<__max1_sender<__env_t<_Env>>... _Child>(__ignore, __ignore, _Child&&...) {
using _Traits = __traits<_Env, _Child...>;
template <class _Receiver>
static auto __mk_state_fn(const _Receiver&) noexcept {
using __env_of_t = env_of_t<_Receiver>;
return []<__max1_sender<__env_t<__env_of_t>>... _Child>(__ignore, __ignore, _Child&&...) {
using _Traits = __traits<__env_of_t, _Child...>;
using _ErrorsVariant = _Traits::__errors_variant;
using _ValuesTuple = _Traits::__values_tuple;
using _State = __when_all_state<
_ErrorsVariant,
_ValuesTuple,
stop_token_of_t<_Env>,
(sends_stopped<_Child, _Env> || ...)
_Receiver,
(sends_stopped<_Child, __env_of_t> || ...)
>;
return _State{sizeof...(_Child)};
};
}

template <class _Env>
using __mk_state_fn_t = decltype(__when_all::__mk_state_fn(__declval<_Env>()));
template <class _Receiver>
using __mk_state_fn_t = decltype(__when_all::__mk_state_fn(__declval<_Receiver>()));

struct when_all_t {
template <sender... _Senders>
Expand Down Expand Up @@ -340,9 +365,9 @@ namespace stdexec {

static constexpr auto get_state =
[]<class _Self, class _Receiver>(_Self&& __self, _Receiver& __rcvr)
-> __sexpr_apply_result_t<_Self, __mk_state_fn_t<env_of_t<_Receiver>>> {
-> __sexpr_apply_result_t<_Self, __mk_state_fn_t<_Receiver>> {
return __sexpr_apply(
static_cast<_Self&&>(__self), __when_all::__mk_state_fn(stdexec::get_env(__rcvr)));
static_cast<_Self&&>(__self), __when_all::__mk_state_fn(__rcvr));
};

static constexpr auto start = []<class _State, class _Receiver, class... _Operations>(
Expand All @@ -351,7 +376,7 @@ namespace stdexec {
_Operations&... __child_ops) noexcept -> void {
// register stop callback:
__state.__on_stop_.emplace(
get_stop_token(stdexec::get_env(__rcvr)), __forward_stop_request{__state.__stop_source_});
get_stop_token(stdexec::get_env(__rcvr)), __forward_stop_request<_State, _Receiver>{&__state, &__rcvr});
(stdexec::start(__child_ops), ...);
if constexpr (sizeof...(__child_ops) == 0) {
__state.__complete(__rcvr);
Expand Down
13 changes: 13 additions & 0 deletions test/stdexec/algos/adaptors/test_when_all.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@
#include <catch2/catch.hpp>
#include <stdexec/execution.hpp>
#include <exec/env.hpp>
#include <exec/async_scope.hpp>
#include <test_common/schedulers.hpp>
#include <test_common/receivers.hpp>
#include <test_common/senders.hpp>
#include <test_common/type_helpers.hpp>

namespace ex = stdexec;
Expand Down Expand Up @@ -367,4 +369,15 @@ namespace {
auto op = ex::connect(snd, expect_void_receiver{});
ex::start(op);
}



TEST_CASE("when_all handles stop requests from the environment correctly", "[adaptors][when_all") {
auto snd = ex::when_all(completes_if(false), completes_if(false));

exec::async_scope scope;
scope.spawn(snd);
scope.request_stop();
ex::sync_wait(scope.on_empty());
}
} // namespace
Loading