diff --git a/include/stdexec/__detail/__let.hpp b/include/stdexec/__detail/__let.hpp index 7ffd3e890..3a7d36701 100644 --- a/include/stdexec/__detail/__let.hpp +++ b/include/stdexec/__detail/__let.hpp @@ -33,6 +33,8 @@ #include "__variant.hpp" #include +#include +#include namespace stdexec { ////////////////////////////////////////////////////////////////////////////// @@ -282,32 +284,90 @@ namespace stdexec { //! The core of the operation state for `let_*`. //! This gets bundled up into a larger operation state (`__detail::__op_state<...>`). - template + template struct __let_state { - using __fun_t = _Fun; - using __env2_t = _Env2; - using __env_t = __join_env_t<_Env2, env_of_t<_Receiver>>; - using __rcvr_t = __receiver_with_env_t<_Receiver, _Env2>; + using __env2_t = __let::__env2_t<_SetTag, env_of_t, env_of_t>; + using __second_rcvr_t = __receiver_with_env_t<_Receiver, __env2_t>; + struct __first_rcvr_t { + using receiver_concept = ::stdexec::receiver_t; + __let_state& __state; + _Receiver& __rcvr; + template + constexpr void __impl(_Tag __tag, _Args&&... __args) noexcept { + if constexpr (std::is_same_v<_SetTag, _Tag>) { + constexpr bool __nothrow_store = (__nothrow_decay_copyable<_Args> && ...); + constexpr bool __nothrow_invoke = __nothrow_callable<_Fun, __decay_t<_Args>&...>; + using __sender_t = __call_result_t<_Fun, __decay_t<_Args>&...>; + using __submit_t = __submit_result<__sender_t, __env2_t, _Receiver>; + constexpr bool __nothrow_submit = noexcept( + __state.__storage_.template emplace<__submit_t>( + std::declval<__sender_t>(), std::declval<__second_rcvr_t>())); + constexpr bool __nothrow = __nothrow_store && __nothrow_invoke && __nothrow_submit; + const auto __impl = [&]() noexcept(__nothrow) { + auto& __tuple = + __state.__args_.emplace_from(__tup::__mktuple, static_cast<_Args&&>(__args)...); + auto&& __sender = __tuple.apply(static_cast<_Fun&&>(__state.__fun_), __tuple); + __state.__storage_.template emplace<__monostate>(); + __second_rcvr_t __r{__rcvr, static_cast<__env2_t&&>(__state.__env2_)}; + auto& __op = __state.__storage_.template emplace<__submit_t>( + static_cast<__sender_t&&>(__sender), static_cast<__second_rcvr_t&&>(__r)); + __op.submit(static_cast<__sender_t&&>(__sender), static_cast<__second_rcvr_t&&>(__r)); + }; + if constexpr (__nothrow) { + __impl(); + } else { + STDEXEC_TRY { + __impl(); + } + STDEXEC_CATCH_ALL { + ::stdexec::set_error(static_cast<_Receiver&&>(__rcvr), std::current_exception()); + } + } + } else { + __tag(static_cast<_Receiver&&>(__rcvr), static_cast<_Args&&>(__args)...); + } + } + template + constexpr void set_value(_Args&&... __args) noexcept { + __impl(::stdexec::set_value, static_cast<_Args&&>(__args)...); + } + template + constexpr void set_error(_Args&&... __args) noexcept { + __impl(::stdexec::set_error, static_cast<_Args&&>(__args)...); + } + template + constexpr void set_stopped(_Args&&... __args) noexcept { + __impl(::stdexec::set_stopped, static_cast<_Args&&>(__args)...); + } + constexpr decltype(auto) get_env() const noexcept { + return ::stdexec::get_env(__rcvr); + } + }; + using __result_variant = __variant_for<__monostate, _Tuples...>; - using __submit_variant = __variant_for< + using __op_state_variant = __variant_for< __monostate, - __mapply<__submit_datum_for<_Receiver, _Fun, _SetTag, _Env2>, _Tuples>... - >; - - template - auto __get_result_receiver(const _ResultSender&, _OpState& __op_state) -> decltype(auto) { - return __rcvr_t{__op_state.__rcvr_, __env2_}; + ::stdexec::connect_result_t<_Sender, __first_rcvr_t>, + __mapply<__submit_datum_for<_Receiver, _Fun, _SetTag, __env2_t>, _Tuples>...>; + + constexpr explicit __let_state(_Sender&& __sender, _Fun __fun, _Receiver& __r) noexcept( + __nothrow_connectable<_Sender, __first_rcvr_t> + && std::is_nothrow_move_constructible_v<_Fun>) + : __fun_(static_cast<_Fun&&>(__fun)) + , __env2_( + __let::__mk_env2<_SetTag>(::stdexec::get_env(__sender), ::stdexec::get_env(__r))) { + __storage_.emplace_from( + ::stdexec::connect, static_cast<_Sender&&>(__sender), __first_rcvr_t{*this, __r}); } STDEXEC_IMMOVABLE_NO_UNIQUE_ADDRESS _Fun __fun_; STDEXEC_IMMOVABLE_NO_UNIQUE_ADDRESS - _Env2 __env2_; + __env2_t __env2_; //! Variant to hold the results passed from upstream before passing them to the function: __result_variant __args_{}; - //! Variant type for holding the operation state from connecting - //! the function result to the downstream receiver: - __submit_variant __storage_{}; + //! Variant type for holding the operation state of the currently in flight operation + __op_state_variant __storage_{}; }; // The set_value completions of: @@ -504,6 +564,17 @@ namespace stdexec { } }; + template + struct __data_t { + _Sender __sndr; + _Fun __fun; + }; + + template + using __sender_of = decltype((std::declval<__data_of<_Sender>>().__sndr)); + template + using __fun_of = decltype((std::declval<__data_of<_Sender>>().__fun)); + //! Implementation of the `let_*_t` types, where `_SetTag` is, e.g., `set_value_t` for `let_value`. template struct __let_t { @@ -512,7 +583,7 @@ namespace stdexec { template auto operator()(_Sender&& __sndr, _Fun __fun) const -> __well_formed_sender auto { return __make_sexpr<__let_t<_SetTag>>( - static_cast<_Fun&&>(__fun), static_cast<_Sender&&>(__sndr)); + __data_t{static_cast<_Sender&&>(__sndr), static_cast<_Fun&&>(__fun)}); } template @@ -524,21 +595,22 @@ namespace stdexec { template struct __let_impl : __sexpr_defaults { - static constexpr auto get_attrs = []( - const _Fun&, - [[maybe_unused]] - const _Child& __child) noexcept { - // BUGBUG: - return stdexec::get_env(__child); - //return __attrs<__let_t<_SetTag>, _Child, _Fun>{}; - }; + static constexpr auto get_attrs = + [](const __data_t<_Child, _Fun>& __data) noexcept { + // BUGBUG: + return stdexec::get_env(__data.__sndr); + //return __attrs<__let_t<_SetTag>, _Child, _Fun>{}; + }; static constexpr auto get_completion_signatures = [](_Self&&, _Env&&...) noexcept { static_assert(sender_expr_for<_Self, __let_t<_SetTag>>); if constexpr (__decay_copyable<_Self>) { - using __result_t = - __completions_t<__let_t<_SetTag>, __data_of<_Self>, __child_of<_Self>, _Env>; + using __result_t = __completions_t< + __let_t<_SetTag>, + std::remove_cvref_t<__fun_of<_Self>>, + __sender_of<_Self>, + _Env>; return __result_t{}; } else { return __mexception<_SENDER_TYPE_IS_NOT_COPYABLE_, _WITH_SENDER_<_Self>>{}; @@ -546,92 +618,27 @@ namespace stdexec { }; static constexpr auto get_state = - [](_Sender&& __sndr, const _Receiver& __rcvr) - requires sender_in<__child_of<_Sender>, env_of_t<_Receiver>> + [](_Sender&& __sndr, _Receiver& __rcvr) + requires sender_in<__sender_of<_Sender>, env_of_t<_Receiver>> { - static_assert(sender_expr_for<_Sender, __let_t<_SetTag>>); - using _Fun = __decay_t<__data_of<_Sender>>; - using _Child = __child_of<_Sender>; - using _Env2 = __env2_t<_SetTag, env_of_t, env_of_t>; - using __mk_let_state = __mbind_front_q<__let_state, _Receiver, _Fun, _SetTag, _Env2>; - + using _Child = __sender_of<_Sender>; + using _Fun = __decay_t<__fun_of<_Sender>>; + using __mk_let_state = __mbind_front_q<__let_state, _SetTag, _Child, _Fun, _Receiver>; using __let_state_t = __gather_completions_of< _SetTag, _Child, env_of_t<_Receiver>, __q<__decayed_tuple>, - __mk_let_state - >; - - return __sndr.apply( - static_cast<_Sender&&>(__sndr), - [&](__ignore, _Fn&& __fn, _Child&& __child) { - // TODO(ericniebler): this needs a fallback - _Env2 __env2 = - __let::__mk_env2<_SetTag>(stdexec::get_env(__child), stdexec::get_env(__rcvr)); - return __let_state_t{static_cast<_Fn&&>(__fn), static_cast<_Env2&&>(__env2)}; - }); + __mk_let_state>; + auto&& [__tag, __data] = static_cast<_Sender&&>(__sndr); + return __let_state_t( + __forward_like<_Sender>(__data).__sndr, __forward_like<_Sender>(__data).__fun, __rcvr); }; - //! Helper function to actually invoke the function to produce `let_*`'s sender, - //! connect it to the downstream receiver, and start it. This is the heart of - //! `let_*`. - template - static void __bind_(_State& __state, _OpState& __op_state, _As&&... __as) { - // Store the passed-in (received) args: - auto& __args = __state.__args_.emplace_from(__tup::__mktuple, static_cast<_As&&>(__as)...); - // Apply the function to the args to get the sender: - auto __sndr2 = __args.apply(std::move(__state.__fun_), __args); - // Create a receiver based on the state, the computed sender, and the operation state: - auto __rcvr2 = __state.__get_result_receiver(__sndr2, __op_state); - // Connect the sender to the receiver and start it: - using __result_t = decltype(submit_result{std::move(__sndr2), std::move(__rcvr2)}); - auto& __op = __state.__storage_ - .template emplace<__result_t>(std::move(__sndr2), std::move(__rcvr2)); - __op.submit(std::move(__sndr2), std::move(__rcvr2)); - } - - template - static void __bind(_OpState& __op_state, _As&&... __as) noexcept { - using _State = decltype(__op_state.__state_); - using _Receiver = decltype(__op_state.__rcvr_); - using _Fun = _State::__fun_t; - using _Env2 = _State::__env2_t; - using _JoinEnv2 = __join_env_t<_Env2, env_of_t<_Receiver>>; - using _ResultSender = __mcall<__result_sender_fn<_SetTag, _Fun, _JoinEnv2>, _As...>; - - _State& __state = __op_state.__state_; - _Receiver& __rcvr = __op_state.__rcvr_; - - if constexpr ( - (__nothrow_decay_copyable<_As> && ...) && __nothrow_callable<_Fun, __decay_t<_As>&...> - && __nothrow_connectable<_ResultSender, __result_receiver_t<_Receiver, _Env2>>) { - __bind_(__state, __op_state, static_cast<_As&&>(__as)...); - } else { - STDEXEC_TRY { - __bind_(__state, __op_state, static_cast<_As&&>(__as)...); - } - STDEXEC_CATCH_ALL { - using _Receiver = decltype(__op_state.__rcvr_); - stdexec::set_error(static_cast<_Receiver&&>(__rcvr), std::current_exception()); - } - } - } - - static constexpr auto complete = []( - __ignore, - _OpState& __op_state, - _Tag, - _As&&... __as) noexcept -> void { - if constexpr (__same_as<_Tag, _SetTag>) { - // Intercept the channel of interest to compute the sender and connect it: - __bind(__op_state, static_cast<_As&&>(__as)...); - } else { - // Forward the other channels downstream: - using _Receiver = decltype(__op_state.__rcvr_); - _Tag()(static_cast<_Receiver&&>(__op_state.__rcvr_), static_cast<_As&&>(__as)...); - } - }; + static constexpr auto start = + [](_State& __state, _Receiver&) noexcept { + ::stdexec::start(__state.__storage_.template get<1>()); + }; }; } // namespace __let diff --git a/test/stdexec/algos/adaptors/test_let_value.cpp b/test/stdexec/algos/adaptors/test_let_value.cpp index 37d9be995..31cb21371 100644 --- a/test/stdexec/algos/adaptors/test_let_value.cpp +++ b/test/stdexec/algos/adaptors/test_let_value.cpp @@ -396,4 +396,52 @@ namespace { ex::start(op); CHECK(*ptr == 5); } + + TEST_CASE( + "let_value destroys the first operation state before invoking the sender factory", + "[adaptors][let_value]") { + const auto ptr = std::make_shared(5); + CHECK(ptr.use_count() == 1); + auto first = ex::just() | ex::then([ptr = ptr]() { }); + CHECK(ptr.use_count() == 2); + auto sender = ex::let_value(std::move(first), [&]() { + CHECK(ptr.use_count() == 2); + return ex::just(); + }); + CHECK(ptr.use_count() == 2); + auto op = ex::connect(std::move(sender), expect_void_receiver{}); + CHECK(ptr.use_count() == 2); + ex::start(op); + CHECK(ptr.use_count() == 1); + } + + struct immovable_sender { + using sender_concept = ::stdexec::sender_t; + template + consteval auto get_completion_signatures(const Args&...) const & noexcept { + return ::stdexec::completion_signatures_of_t{}; + } + template + auto connect(Receiver r) const & noexcept { + return ::stdexec::connect(::stdexec::just(), std::move(r)); + } + immovable_sender() = default; + immovable_sender(const immovable_sender&) { + throw std::logic_error("Unexpected copy"); + } + }; + static_assert(::stdexec::sender); + static_assert(::stdexec::sender); + static_assert(::stdexec::sender_in>); + static_assert(::stdexec::sender_in>); + + TEST_CASE( + "If the sender factory returns a reference to a sender that reference is passed to connect", + "[adaptors][let_value]") { + const immovable_sender s; + auto just = ex::just(); + auto sender = ex::let_value(just, [&]() -> decltype(auto) { return (s); }); + auto op = ex::connect(sender, expect_void_receiver{}); + ex::start(op); + } } // namespace