From 8f1e4cac16df187004bd6d9fb5c8ff969d9f47ae Mon Sep 17 00:00:00 2001 From: Howard Beard-Marlowe Date: Thu, 8 Oct 2020 20:49:22 +0300 Subject: [PATCH] Adds full support for handle_continue/2 to gen_stage * {:continue, _term} instructions can now be returned as one would expect from gen_server. * :hibernate is now supported on init similar to gen_server. --- lib/gen_stage.ex | 191 ++++++++++++++++++++++++++++++++----- test/gen_stage_test.exs | 203 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 372 insertions(+), 22 deletions(-) diff --git a/lib/gen_stage.ex b/lib/gen_stage.ex index 6cd25d8..c8890b8 100644 --- a/lib/gen_stage.ex +++ b/lib/gen_stage.ex @@ -885,11 +885,18 @@ defmodule GenStage do @callback init(args :: term) :: {:producer, state} + | {:producer, state, {:continue, term} | :hibernate} | {:producer, state, [producer_option]} + | {:producer, state, {:continue, term} | :hibernate, [producer_option]} | {:producer_consumer, state} + | {:producer_consumer, state, {:continue, term} | :hibernate} | {:producer_consumer, state, [producer_consumer_option]} + | {:producer_consumer, state, {:continue, term} | :hibernate, + [producer_consumer_option]} | {:consumer, state} + | {:consumer, state, {:continue, term} | :hibernate} | {:consumer, state, [consumer_option]} + | {:consumer, state, {:continue, term} | :hibernate, [consumer_option]} | :ignore | {:stop, reason :: any} when state: any @@ -925,6 +932,7 @@ defmodule GenStage do @callback handle_demand(demand :: pos_integer, state :: term) :: {:noreply, [event], new_state} | {:noreply, [event], new_state, :hibernate} + | {:noreply, [event], new_state, {:continue, term}} | {:stop, reason, new_state} when new_state: term, reason: term, event: term @@ -1004,6 +1012,7 @@ defmodule GenStage do ) :: {:noreply, [event], new_state} | {:noreply, [event], new_state, :hibernate} + | {:noreply, [event], new_state, {:continue, term}} | {:stop, reason, new_state} when event: term, new_state: term, reason: term @@ -1017,6 +1026,7 @@ defmodule GenStage do @callback handle_events(events :: [event], from, state :: term) :: {:noreply, [event], new_state} | {:noreply, [event], new_state, :hibernate} + | {:noreply, [event], new_state, {:continue, term}} | {:stop, reason, new_state} when new_state: term, reason: term, event: term @@ -1056,8 +1066,10 @@ defmodule GenStage do @callback handle_call(request :: term, from :: GenServer.from(), state :: term) :: {:reply, reply, [event], new_state} | {:reply, reply, [event], new_state, :hibernate} + | {:reply, reply, [event], new_state, {:continue, term}} | {:noreply, [event], new_state} | {:noreply, [event], new_state, :hibernate} + | {:noreply, [event], new_state, {:continue, term}} | {:stop, reason, reply, new_state} | {:stop, reason, new_state} when reply: term, new_state: term, reason: term, event: term @@ -1086,6 +1098,7 @@ defmodule GenStage do @callback handle_cast(request :: term, state :: term) :: {:noreply, [event], new_state} | {:noreply, [event], new_state, :hibernate} + | {:noreply, [event], new_state, {:continue, term}} | {:stop, reason :: term, new_state} when new_state: term, event: term @@ -1103,6 +1116,27 @@ defmodule GenStage do @callback handle_info(message :: term, state :: term) :: {:noreply, [event], new_state} | {:noreply, [event], new_state, :hibernate} + | {:noreply, [event], new_state, {:continue, term}} + | {:stop, reason :: term, new_state} + when new_state: term, event: term + + @doc """ + Invoked to handle `continue` instructions. + + It is useful for performing work after initialization or for splitting the work + in a callback in multiple steps, updating the process state along the way. + + Return values are the same as `c:handle_cast/2`. + + This callback is optional. If one is not implemented, the server will fail + if a continue instruction is used. + + This callback is only supported on Erlang/OTP 21+. + """ + @callback handle_continue(continue :: term, state :: term) :: + {:noreply, [event], new_state} + | {:noreply, [event], new_state, :hibernate} + | {:noreply, [event], new_state, {:continue, term}} | {:stop, reason :: term, new_state} when new_state: term, event: term @@ -1139,6 +1173,7 @@ defmodule GenStage do format_status: 2, handle_call: 3, handle_cast: 2, + handle_continue: 2, handle_info: 2, terminate: 2 ] @@ -1722,22 +1757,58 @@ defmodule GenStage do def init({mod, args}) do case mod.init(args) do {:producer, state} -> - init_producer(mod, [], state) + init_producer(mod, [], state, nil) + + {:producer, state, {:continue, _term} = continue} -> + init_producer(mod, [], state, continue) + + {:producer, state, :hibernate} -> + init_producer(mod, [], state, :hibernate) {:producer, state, opts} when is_list(opts) -> - init_producer(mod, opts, state) + init_producer(mod, opts, state, nil) + + {:producer, state, {:continue, _term} = continue, opts} when is_list(opts) -> + init_producer(mod, opts, state, continue) + + {:producer, state, :hibernate, opts} when is_list(opts) -> + init_producer(mod, opts, state, :hibernate) {:producer_consumer, state} -> - init_producer_consumer(mod, [], state) + init_producer_consumer(mod, [], state, nil) + + {:producer_consumer, state, {:continue, _term} = continue} -> + init_producer_consumer(mod, [], state, continue) + + {:producer_consumer, state, :hibernate} -> + init_producer_consumer(mod, [], state, :hibernate) {:producer_consumer, state, opts} when is_list(opts) -> - init_producer_consumer(mod, opts, state) + init_producer_consumer(mod, opts, state, nil) + + {:producer_consumer, state, {:continue, _term} = continue, opts} when is_list(opts) -> + init_producer_consumer(mod, opts, state, continue) + + {:producer_consumer, state, :hibernate, opts} when is_list(opts) -> + init_producer_consumer(mod, opts, state, :hibernate) {:consumer, state} -> - init_consumer(mod, [], state) + init_consumer(mod, [], state, nil) + + {:consumer, state, {:continue, _term} = continue} -> + init_consumer(mod, [], state, continue) + + {:consumer, state, :hibernate} -> + init_consumer(mod, [], state, :hibernate) {:consumer, state, opts} when is_list(opts) -> - init_consumer(mod, opts, state) + init_consumer(mod, opts, state, nil) + + {:consumer, state, {:continue, _term} = continue, opts} when is_list(opts) -> + init_consumer(mod, opts, state, continue) + + {:consumer, state, :hibernate, opts} when is_list(opts) -> + init_consumer(mod, opts, state, :hibernate) {:stop, _} = stop -> stop @@ -1750,7 +1821,7 @@ defmodule GenStage do end end - defp init_producer(mod, opts, state) do + defp init_producer(mod, opts, state, continue_or_hibernate) do with {:ok, dispatcher_mod, dispatcher_state, opts} <- init_dispatcher(opts), {:ok, buffer_size, opts} <- Utils.validate_integer(opts, :buffer_size, 10000, 0, :infinity, true), @@ -1770,7 +1841,7 @@ defmodule GenStage do dispatcher_state: dispatcher_state } - {:ok, stage} + if continue_or_hibernate, do: {:ok, stage, continue_or_hibernate}, else: {:ok, stage} else {:error, message} -> {:stop, {:bad_opts, message}} end @@ -1792,7 +1863,7 @@ defmodule GenStage do end end - defp init_producer_consumer(mod, opts, state) do + defp init_producer_consumer(mod, opts, state, continue_or_hibernate) do with {:ok, dispatcher_mod, dispatcher_state, opts} <- init_dispatcher(opts), {:ok, subscribe_to, opts} <- Utils.validate_list(opts, :subscribe_to, []), {:ok, buffer_size, opts} <- @@ -1811,22 +1882,68 @@ defmodule GenStage do dispatcher_state: dispatcher_state } - consumer_init_subscribe(subscribe_to, stage) + case handle_gen_server_init_args(continue_or_hibernate, stage) do + {:ok, stage} -> + consumer_init_subscribe(subscribe_to, stage) + + {:ok, stage, args} -> + {:ok, stage} = consumer_init_subscribe(subscribe_to, stage) + {:ok, stage, args} + + {:stop, _, _} = error -> + error + end else {:error, message} -> {:stop, {:bad_opts, message}} end end - defp init_consumer(mod, opts, state) do + defp init_consumer(mod, opts, state, continue_or_hibernate) do with {:ok, subscribe_to, opts} <- Utils.validate_list(opts, :subscribe_to, []), :ok <- Utils.validate_no_opts(opts) do stage = %GenStage{mod: mod, state: state, type: :consumer} - consumer_init_subscribe(subscribe_to, stage) + + case handle_gen_server_init_args(continue_or_hibernate, stage) do + {:ok, stage} -> + consumer_init_subscribe(subscribe_to, stage) + + {:ok, stage, args} -> + {:ok, stage} = consumer_init_subscribe(subscribe_to, stage) + {:ok, stage, args} + + {:stop, _, _} = error -> + error + end else {:error, message} -> {:stop, {:bad_opts, message}} end end + defp handle_gen_server_init_args({:continue, _term} = continue, stage) do + case handle_continue(continue, stage) do + {:noreply, stage} -> + {:ok, stage} + + {:noreply, stage, :hibernate} -> + {:ok, stage, :hibernate} + + {:noreply, stage, {:continue, _term} = continue} -> + {:ok, stage, continue} + + {:stop, reason, stage} -> + {:stop, reason, stage} + end + end + + defp handle_gen_server_init_args(:hibernate, stage), do: {:ok, stage, :hibernate} + defp handle_gen_server_init_args(nil, stage), do: {:ok, stage} + + @doc false + + def handle_continue(continue, %{state: state} = stage) do + noreply_callback(:handle_continue, [continue, state], stage) + end + @doc false def handle_call({:"$info", msg}, _from, stage) do @@ -1855,6 +1972,10 @@ defmodule GenStage do stage = dispatch_events(events, length(events), %{stage | state: state}) {:reply, reply, stage, :hibernate} + {:reply, reply, events, state, {:continue, _term} = continue} -> + stage = dispatch_events(events, length(events), %{stage | state: state}) + {:reply, reply, stage, continue} + {:stop, reason, reply, state} -> {:stop, reason, reply, %{stage | state: state}} @@ -1995,7 +2116,7 @@ defmodule GenStage do case producers do %{^ref => entry} -> {batches, stage} = consumer_receive(from, entry, events, stage) - consumer_dispatch(batches, from, mod, state, stage, false) + consumer_dispatch(batches, from, mod, state, stage, nil) _ -> msg = {:"$gen_producer", {self(), ref}, {:cancel, :unknown_subscription}} @@ -2122,6 +2243,14 @@ defmodule GenStage do end end + defp noreply_callback(:handle_continue, [continue, state], %{mod: mod} = stage) do + if function_exported?(mod, :handle_continue, 2) do + handle_noreply_callback(mod.handle_continue(continue, state), stage) + else + :error_handler.raise_undef_exception(mod, :handle_continue, [continue, state]) + end + end + defp noreply_callback(callback, args, %{mod: mod} = stage) do handle_noreply_callback(apply(mod, callback, args), stage) end @@ -2136,6 +2265,10 @@ defmodule GenStage do stage = dispatch_events(events, length(events), %{stage | state: state}) {:noreply, stage, :hibernate} + {:noreply, events, state, {:continue, _term} = continue} when is_list(events) -> + stage = dispatch_events(events, length(events), %{stage | state: state}) + {:noreply, stage, continue} + {:stop, reason, state} -> {:stop, reason, %{stage | state: state}} @@ -2259,6 +2392,9 @@ defmodule GenStage do # main module must know the consumer is no longer subscribed. dispatcher_callback(:cancel, [{pid, ref}, dispatcher_state], stage) + {:noreply, %{dispatcher_state: dispatcher_state} = stage, _hibernate_or_continue} -> + dispatcher_callback(:cancel, [{pid, ref}, dispatcher_state], stage) + {:stop, _, _} = stop -> stop end @@ -2459,17 +2595,22 @@ defmodule GenStage do {[{events, 0}], stage} end - defp consumer_dispatch([{batch, ask} | batches], from, mod, state, stage, _hibernate?) do + defp consumer_dispatch([{batch, ask} | batches], from, mod, state, stage, _gen_opts) do case mod.handle_events(batch, from, state) do {:noreply, events, state} when is_list(events) -> stage = dispatch_events(events, length(events), stage) ask(from, ask, [:noconnect]) - consumer_dispatch(batches, from, mod, state, stage, false) + consumer_dispatch(batches, from, mod, state, stage, nil) - {:noreply, events, state, :hibernate} when is_list(events) -> + {:noreply, events, state, :hibernate} -> stage = dispatch_events(events, length(events), stage) ask(from, ask, [:noconnect]) - consumer_dispatch(batches, from, mod, state, stage, true) + consumer_dispatch(batches, from, mod, state, stage, :hibernate) + + {:noreply, events, state, {:continue, _} = continue} -> + stage = dispatch_events(events, length(events), stage) + ask(from, ask, [:noconnect]) + consumer_dispatch(batches, from, mod, state, stage, continue) {:stop, reason, state} -> {:stop, reason, %{stage | state: state}} @@ -2479,12 +2620,12 @@ defmodule GenStage do end end - defp consumer_dispatch([], _from, _mod, state, stage, false) do + defp consumer_dispatch([], _from, _mod, state, stage, nil) do {:noreply, %{stage | state: state}} end - defp consumer_dispatch([], _from, _mod, state, stage, true) do - {:noreply, %{stage | state: state}, :hibernate} + defp consumer_dispatch([], _from, _mod, state, stage, gen_opts) do + {:noreply, %{stage | state: state}, gen_opts} end defp consumer_subscribe({to, opts}, stage) when is_list(opts), @@ -2613,11 +2754,11 @@ defmodule GenStage do {producer_id, _, _} = entry from = {producer_id, ref} {batches, stage} = consumer_receive(from, entry, events, stage) - consumer_dispatch(batches, from, mod, state, stage, false) + consumer_dispatch(batches, from, mod, state, stage, nil) %{} -> # We queued but producer was removed - consumer_dispatch([{events, 0}], {:pid, ref}, mod, state, stage, false) + consumer_dispatch([{events, 0}], {:pid, ref}, mod, state, stage, nil) end end @@ -2634,6 +2775,9 @@ defmodule GenStage do {:noreply, stage, :hibernate} -> take_pc_events(queue, counter, stage) + {:noreply, stage, {:continue, _term}} -> + take_pc_events(queue, counter, stage) + {:stop, _, _} = stop -> stop end @@ -2646,6 +2790,9 @@ defmodule GenStage do {:noreply, %{events: {queue, counter}} = stage, :hibernate} -> take_pc_events(queue, counter, stage) + {:noreply, %{events: {queue, counter}} = stage, {:continue, _term}} -> + take_pc_events(queue, counter, stage) + {:stop, _, _} = stop -> stop end diff --git a/test/gen_stage_test.exs b/test/gen_stage_test.exs index c05c62b..94fef2e 100644 --- a/test/gen_stage_test.exs +++ b/test/gen_stage_test.exs @@ -81,6 +81,109 @@ defmodule GenStageTest do events = Enum.to_list(counter..(counter + demand - 1)) {:noreply, events, counter + demand} end + + # Use continue instructions to modify the counter, this + # can be reached from any gen_server callback by supplying + # a continue instruction with an integer term. + def handle_continue(new_counter, _counter) when is_integer(new_counter) do + {:noreply, [], new_counter} + end + end + + defmodule CounterNestedContinue do + @moduledoc """ + A producer that works as a counter in batches. + It also supports events to be queued via sync + and async calls. A negative counter disables + the counting behaviour. + + This counter uses a nested handle_continue on init. + """ + + use GenStage + + def start_link(init, opts \\ []) do + GenStage.start_link(__MODULE__, init, opts) + end + + def sync_queue(stage, events) do + GenStage.call(stage, {:queue, events}) + end + + def async_queue(stage, events) do + GenStage.cast(stage, {:queue, events}) + end + + def stop(stage) do + GenStage.call(stage, :stop) + end + + ## Callbacks + + def init(init) do + init + end + + def handle_call(:stop, _from, state) do + {:stop, :shutdown, :ok, state} + end + + def handle_call({:early_reply_queue, events}, from, state) do + GenStage.reply(from, state) + {:noreply, events, state} + end + + def handle_call({:queue, events}, _from, state) do + {:reply, state, events, state} + end + + def handle_cast({:queue, events}, state) do + {:noreply, events, state} + end + + def handle_info({:queue, events}, state) do + {:noreply, events, state} + end + + def handle_info(other, state) do + is_pid(state) && send(state, other) + {:noreply, [], state} + end + + def handle_subscribe(:consumer, opts, from, state) do + is_pid(state) && send(state, {:producer_subscribed, from}) + {Keyword.get(opts, :producer_demand, :automatic), state} + end + + def handle_cancel(reason, from, state) do + is_pid(state) && send(state, {:producer_cancelled, from, reason}) + {:noreply, [], state} + end + + def handle_demand(demand, pid) when is_pid(pid) and demand > 0 do + {:noreply, [], pid} + end + + def handle_demand(demand, counter) when demand > 0 do + # If the counter is 3 and we ask for 2 items, we will + # emit the items 3 and 4, and set the state to 5. + events = Enum.to_list(counter..(counter + demand - 1)) + {:noreply, events, counter + demand} + end + + # Use continue instructions to modify the counter, this + # can be reached from any gen_server callback by supplying + # a continue instruction with an integer term. + # + # This particular handle_continue returns another continue instruction + # testing that we handle nested continues properly. + def handle_continue(500, _counter) do + {:noreply, [], 500, {:continue, 2000}} + end + + def handle_continue(2000, _counter) do + {:noreply, [], 2000} + end end defmodule DemandProducer do @@ -139,6 +242,11 @@ defmodule GenStageTest do is_pid(state) && send(state, other) {:noreply, [], state} end + + def handle_continue({:continue, term}, recipient) do + send(recipient, term) + {:noreply, [], recipient} + end end defmodule Postponer do @@ -258,6 +366,11 @@ defmodule GenStageTest do {:noreply, [], recipient} end + def handle_continue({:continue, term}, recipient) do + send(recipient, term) + {:noreply, [], recipient} + end + def terminate(reason, state) do send(state, {:terminated, reason}) end @@ -324,6 +437,96 @@ defmodule GenStageTest do } end + {otp_version, ""} = :otp_release |> :erlang.system_info() |> to_string() |> Integer.parse() + + if otp_version >= 21 do + describe "handle_continue tests" do + test "producing_init with continue instruction setting counter start position" do + {:ok, producer} = Counter.start_link({:producer, 0, {:continue, 500}}) + {:ok, _} = Forwarder.start_link({:consumer, self(), subscribe_to: [producer]}) + + batch = Enum.to_list(0..499) + refute_receive {:consumed, ^batch} + batch = Enum.to_list(500..999) + assert_receive {:consumed, ^batch} + batch = Enum.to_list(1000..1499) + assert_receive {:consumed, ^batch} + end + + test "producer_init with nested continue instruction setting counter start position" do + {:ok, producer} = CounterNestedContinue.start_link({:producer, 0, {:continue, 500}}) + {:ok, _} = Forwarder.start_link({:consumer, self(), subscribe_to: [producer]}) + + # The nested continue sets the counter to 2000 + batch = Enum.to_list(0..499) + refute_receive {:consumed, ^batch} + batch = Enum.to_list(500..999) + refute_receive {:consumed, ^batch} + batch = Enum.to_list(1000..1499) + refute_receive {:consumed, ^batch} + batch = Enum.to_list(1500..1999) + refute_receive {:consumed, ^batch} + batch = Enum.to_list(2000..2499) + assert_receive {:consumed, ^batch} + end + + test "consumer_init with continue instruction" do + {:ok, producer} = Counter.start_link({:producer, 0, {:continue, 500}}) + + {:ok, _} = + Forwarder.start_link( + {:consumer, self(), {:continue, :continue_reached}, subscribe_to: [producer]} + ) + + assert_receive :continue_reached + end + + test "producer_consumer with continue instruction" do + {:ok, producer} = Counter.start_link({:producer, 0}) + + {:ok, _doubler} = + Doubler.start_link( + {:producer_consumer, self(), {:continue, :continue_reached}, + subscribe_to: [{producer, max_demand: 100, min_demand: 80}]} + ) + + assert_receive :continue_reached + end + end + end + + describe "hibernate tests" do + test "producer_init with hibernate instruction" do + {:ok, producer} = Counter.start_link({:producer, 0, :hibernate}) + + assert :erlang.process_info(producer, :current_function) == + {:current_function, {:erlang, :hibernate, 3}} + end + + test "consumer_init with hibernate instruction" do + {:ok, producer} = Counter.start_link({:producer, 0}) + + {:ok, consumer} = + Forwarder.start_link({:consumer, self(), :hibernate, subscribe_to: [producer]}) + + assert :erlang.process_info(consumer, :current_function) == + {:current_function, {:erlang, :hibernate, 3}} + end + + test "producer_consumer with hibernate instruction" do + {:ok, producer} = Counter.start_link({:producer, 0}) + + {:ok, doubler} = + Doubler.start_link( + {:producer_consumer, self(), :hibernate, + subscribe_to: [{producer, max_demand: 100, min_demand: 80}]} + ) + + assert :erlang.process_info(doubler, :current_function) == + {:current_function, {:erlang, :hibernate, 3}} + end + end + describe "producer-to-consumer demand" do test "with default max and min demand" do {:ok, producer} = Counter.start_link({:producer, 0})