Skip to content
Open
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 30 additions & 15 deletions lib/bumblebee/text/generation.ex
Original file line number Diff line number Diff line change
Expand Up @@ -387,8 +387,12 @@ defmodule Bumblebee.Text.Generation do
end ++ logits_processors

fn logits, context ->
for processor <- processors, processor, reduce: logits do
logits -> processor.(logits, context)
for processor <- processors, processor, reduce: {logits, context} do
{logits, context} ->
case processor.(logits, context) do
{logits, new_context} -> {logits, new_context}
logits -> {logits, context}
end
end
end
end
Expand Down Expand Up @@ -551,7 +555,8 @@ defmodule Bumblebee.Text.Generation do
length: length,
finished_length: finished_length,
# The ignored return value that we attach all hooks to
ignored: Nx.broadcast(0, {batch_size})
ignored: Nx.broadcast(0, {batch_size}),
logits_processor_states: %{}
}
end

Expand All @@ -574,7 +579,7 @@ defmodule Bumblebee.Text.Generation do
outputs = predict_fun.(params, inputs)

logits = outputs.logits[[.., -1]]
logits = batch_process_logits(logits_processor_fun, logits, state)
{logits, state} = batch_process_logits(logits_processor_fun, logits, state)
token_id = Nx.argmax(logits, axis: -1)

state = update_sequences(state, token_id, pad_token_id, eos_token_id)
Expand Down Expand Up @@ -632,14 +637,24 @@ defmodule Bumblebee.Text.Generation do
end

# NOTE(review): this span is a unified-diff overlay — the removed (old) and
# added (new) bodies of batch_process_logits appear back to back below.
# Comments mark which lines belong to which version.
defnp batch_process_logits(logits_processor_fun, logits, state) do
# Old implementation (removed in this diff): vectorize logits over the
# batch, apply the processor, and devectorize the resulting logits only.
logits
|> Nx.vectorize(:batch)
|> logits_processor_fun.(%{
sequence: Nx.vectorize(state.sequences, :batch),
length: state.length,
input_length: state.input_length
})
|> Nx.devectorize(keep_names: false)
# New implementation (added in this diff): also thread the per-batch
# logits-processor state through the call and return the updated state
# alongside the processed logits.
logits = Nx.vectorize(logits, :batch)

# The processor receives the state under :logits_processor_state in the
# context map and returns an updated context alongside the logits.
{logits, new_context} =
logits_processor_fun.(logits, %{
sequence: Nx.vectorize(state.sequences, :batch),
length: state.length,
input_length: state.input_length,
logits_processor_state: Nx.vectorize(state.logits_processor_states, :batch)
})

logits = Nx.devectorize(logits, keep_names: false)

# Devectorize the updated state and sequences so they can be written back
# into the plain (non-vectorized) generation state.
logits_processor_states =
Nx.devectorize(new_context.logits_processor_state, keep_names: false)

sequences = Nx.devectorize(new_context.sequence, keep_names: false)

# Returns both the processed logits and the state with the processor
# states (and possibly modified sequences) carried over.
{logits, %{state | sequences: sequences, logits_processor_states: logits_processor_states}}
end

# Contrastive search
Expand Down Expand Up @@ -684,7 +699,7 @@ defmodule Bumblebee.Text.Generation do
joint_hidden_state = Nx.put_slice(joint_hidden_state, [0, 0, 0], initial_hidden_state)

logits = outputs.logits[[.., -1]]
logits = batch_process_logits(logits_processor_fun, logits, state)
{logits, state} = batch_process_logits(logits_processor_fun, logits, state)
scores = Axon.Activations.softmax(logits, axis: -1)
{top_k_scores, top_k_token_ids} = Nx.top_k(scores, k: top_k)

Expand Down Expand Up @@ -727,7 +742,7 @@ defmodule Bumblebee.Text.Generation do

logits = outputs.logits[[.., -1]]
logits = Utils.Nx.chunked_take(logits, top_k, selected_idx)
logits = batch_process_logits(logits_processor_fun, logits, state)
{logits, state} = batch_process_logits(logits_processor_fun, logits, state)

scores = Axon.Activations.softmax(logits, axis: -1)
{top_k_scores, top_k_token_ids} = Nx.top_k(scores, k: top_k)
Expand Down Expand Up @@ -888,7 +903,7 @@ defmodule Bumblebee.Text.Generation do
outputs = predict_fun.(params, inputs)

logits = outputs.logits[[.., -1]]
logits = batch_process_logits(logits_processor_fun, logits, state)
{logits, state} = batch_process_logits(logits_processor_fun, logits, state)
scores = Axon.Activations.softmax(logits)
token_id = batched_choice(key, scores)

Expand Down
50 changes: 49 additions & 1 deletion test/bumblebee/text/generation/logits_processing_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,53 @@ defmodule Bumblebee.Text.Generation.LogitsProcessingTest do

alias Bumblebee.Text.Generation.LogitsProcessing

describe "stateful logits processors" do
# Example stateful logits processor used by the tests below: each call
# suppresses one token index (sets its logit to negative infinity) and
# stores the next index to suppress in the shared processor state.
defmodule StatefulLogitsProcessing do
import Nx.Defn

# Suppresses the current index and advances it by one for the next call.
# Returns {logits, context} so the caller threads the state forward.
deftransform stateful_processor(logits, context, opts) do
initial_suppressed_token_index = Nx.tensor([opts[:initial_suppressed_token_index]])

# On the first invocation the state map has no entry yet, so the initial
# index from the options is used; afterwards the stored index wins.
suppressed_index =
context.logits_processor_state[:next_suppressed_token_index] || initial_suppressed_token_index

# NOTE(review): Nx.size/1 returns an integer, while Nx.broadcast/2
# expects a shape tuple (e.g. {Nx.size(suppressed_index)}) — confirm
# this call produces the intended update tensor.
values =
Nx.broadcast(Nx.Constants.neg_infinity(Nx.type(logits)), Nx.size(suppressed_index))

logits = Nx.indexed_put(logits, suppressed_index, values)

next_suppressed_token_index = Nx.add(suppressed_index, Nx.tensor([1]))

# Write the incremented index back into the context so the next
# generation step sees it.
context =
put_in(
context,
[:logits_processor_state, :next_suppressed_token_index],
next_suppressed_token_index
)

{logits, context}
end
end

test "can register and modify state" do
  opts = [initial_suppressed_token_index: 0]

  logits = Nx.tensor([1.0, 2.0, 3.0, 4.0])
  context = context([1, 0, 0, 0])

  # First call: the state is empty, so the initial index (0) is suppressed
  # and the incremented index is stored for the next call.
  {logits, context} =
    StatefulLogitsProcessing.stateful_processor(logits, context, opts)

  assert_equal(logits, Nx.tensor([:neg_infinity, 2.0, 3.0, 4.0]))
  assert_equal(context.logits_processor_state.next_suppressed_token_index, Nx.tensor([1]))

  # Second call: the stored index (1) takes precedence over the initial
  # option, so the next token is suppressed and the state advances again.
  {logits, context} =
    StatefulLogitsProcessing.stateful_processor(logits, context, opts)

  assert_equal(logits, Nx.tensor([:neg_infinity, :neg_infinity, 3.0, 4.0]))
  assert_equal(context.logits_processor_state.next_suppressed_token_index, Nx.tensor([2]))
end
end

describe "suppressed_tokens_processor/3" do
test "ignores the given tokens" do
logits = Nx.tensor([1.0, 2.0, 3.0, 4.0])
Expand Down Expand Up @@ -382,7 +429,8 @@ defmodule Bumblebee.Text.Generation.LogitsProcessingTest do
%{
sequence: Nx.tensor(sequence),
length: Enum.count(sequence, &(&1 != 0)),
input_length: 1
input_length: 1,
logits_processor_state: %{}
}
end
end
145 changes: 145 additions & 0 deletions test/bumblebee/text/generation_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -106,4 +106,149 @@ defmodule Bumblebee.Text.GenerationTest do

assert_equal(token_ids, Nx.tensor([[80, 1023, 1023]]))
end


test "with stateful logits processor with batch size of 1" do
  repo = {:hf, "hf-internal-testing/tiny-random-GPT2LMHeadModel"}

  assert {:ok, %{model: model, params: params, spec: spec}} = Bumblebee.load_model(repo)
  {:ok, generation_config} = Bumblebee.load_generation_config(repo)

  assert %Bumblebee.Text.Gpt2{architecture: :for_causal_language_modeling} = spec

  inputs = %{
    "input_ids" => Nx.tensor([[0, 0, 10, 20, 30, 40, 50, 60, 70, 80]]),
    "attention_mask" => Nx.tensor([[0, 0, 1, 1, 1, 1, 1, 1, 1, 1]]),
    "seed" => Nx.tensor([0])
  }

  # The stateful processor (defined below) suppresses the given initial
  # token id on the first iteration, then increments the suppressed id on
  # every following iteration, carrying it between iterations via
  # logits_processor_state. Starting from 79 it suppresses 79, 80, 81, ...
  generation_config = Bumblebee.configure(generation_config, max_new_tokens: 2)

  generate =
    Bumblebee.Text.Generation.build_generate(model, spec, generation_config,
      logits_processors: [
        &Bumblebee.Text.GenerationTest.StatefulLogitsProcessing.stateful_processor(&1, &2,
          initial_suppressed_token_id: [79]
        )
      ]
    )

  %{token_ids: token_ids} = generate.(params, inputs)

  # Without the processor the result would be 80, 80 (as in the first
  # decoder test above). Suppressing 79 leaves the first pick unchanged.
  assert_equal(token_ids[[0, 0]], 80)

  # The suppressed id is then incremented from 79 to 80, disallowing the
  # most likely token, so the pick falls through to the next most likely
  # token in the logits (176).
  assert_equal(token_ids[[0, 1]], 176)
end

test "with stateful logits processor with batch size of 2" do
  repo = {:hf, "hf-internal-testing/tiny-random-GPT2LMHeadModel"}

  assert {:ok, %{model: model, params: params, spec: spec}} = Bumblebee.load_model(repo)
  {:ok, generation_config} = Bumblebee.load_generation_config(repo)

  assert %Bumblebee.Text.Gpt2{architecture: :for_causal_language_modeling} = spec

  input_ids = Nx.tensor([[0, 0, 10, 20, 30, 40, 50, 60, 70, 80]])
  attention_mask = Nx.tensor([[0, 0, 1, 1, 1, 1, 1, 1, 1, 1]])
  seed = Nx.tensor([0])

  # Same setup as the batch-size-1 test above, duplicated into a batch of
  # two, with a different initial suppressed token id per batch entry.
  inputs = %{
    "input_ids" => Nx.Batch.concatenate([input_ids, input_ids]),
    "attention_mask" => Nx.Batch.concatenate([attention_mask, attention_mask]),
    "seed" => Nx.Batch.concatenate([seed, seed])
  }

  generation_config = Bumblebee.configure(generation_config, max_new_tokens: 3)

  generate =
    Bumblebee.Text.Generation.build_generate(model, spec, generation_config,
      logits_processors: [
        &Bumblebee.Text.GenerationTest.StatefulLogitsProcessing.stateful_processor(&1, &2,
          initial_suppressed_token_id: [78, 79]
        )
      ]
    )

  %{token_ids: token_ids} = generate.(params, inputs)

  # Without the logits processor the result would be 80, 80, 80.

  # First batch entry starts suppressing at 78: the first two steps
  # suppress 78 and then 79, so both picks stay at 80; the third step
  # suppresses 80 itself and the pick falls through to 1016.
  assert_equal(token_ids[[0, 0]], 80)
  assert_equal(token_ids[[0, 1]], 80)
  assert_equal(token_ids[[0, 2]], 1016)

  # Second batch entry starts suppressing at 79: the first pick stays at
  # 80, then 80 is suppressed and the next most likely token (176) wins.
  assert_equal(token_ids[[1, 0]], 80)
  assert_equal(token_ids[[1, 1]], 176)
end

defmodule StatefulLogitsProcessing do
import Nx.Defn

deftransform stateful_processor(logits, context, opts \\ []) do
initial_suppressed_token_ids = Enum.map(opts[:initial_suppressed_token_id], &List.wrap(&1))
initial_suppressed_token_id = Nx.tensor(initial_suppressed_token_ids) |> Nx.vectorize(:batch)

suppressed_id =
context.logits_processor_state[:next_suppressed_token_id] || initial_suppressed_token_id

logits = suppress_id(logits, suppressed_id)

next_suppressed_token_id = Nx.add(suppressed_id, 1)

context =
put_in(
context,
[:logits_processor_state, :next_suppressed_token_id],
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

With the current API, the state is always initialized to %{} and then first invocation of the processor adds a key, here %{next_suppressed_token_id: %Nx.Tensor{...}}.

This can be problematic in a defn while loop, which requires the accumulation state to always have the same shape. In other words, the initial state should already include :next_suppressed_token_id with the default tensor. It is possible that this didn't come up during your tests, because depending on the model/input, we do the first generation step outside of the while loop, and the first call would initialize the state. However, if we are going to support stateful processors, I would rather do it in a more robust way.

Given the above, a stateful logits processor would involve two steps (functions):

  1. Building an initial state.
  2. Performing logits processing, which receives logits and state, and returns update logits and state.

This way we can call (1) when initializing the generation context, and for the actual processing we call (2).

The behaviour can be similar to Bumblebee.Scheduler. Something like this:

defmodule Bumblebee.LogitsProcessor do
  @moduledoc """
  An interface for configuring and using logits processors.

  Logits processors are used during autoregressive generation to modify
  predicted scores at each generation step. This allows for applying
  certain rules to the model output to control which tokens are picked
  at each generation step, and which are not.

  Every module implementing this behaviour is expected to also define
  a configuration struct.
  """

  @type t :: Bumblebee.Configurable.t()

  @type state :: Nx.Container.t()

  @doc """
  Initializes state for a new logits processor.

  Returns `state`, which is an opaque `Nx.Container`, and it is then
  passed to and returned from `process/2`.

  Oftentimes logits processors are stateless, in which case this
  function can return an empty container, such as `{}`.
  """
  @callback init(t(), context) :: state()
            when context: %{
                   prng_key: Nx.Tensor.t()
                 }

  @doc """
  Processes logits, applying specific rules.
  """
  @callback process(
              t(),
              state(),
              logits :: Nx.Tensor.t(),
              context :: context
            ) :: {state :: map(), logits :: Nx.Tensor.t()}
            when context: %{
                   sequence: Nx.Tensor.t(),
                   length: Nx.Tensor.t(),
                   input_length: Nx.Tensor.t()
                 }
end

Technically, the :logits_processors options is public API, but we can make it backward-compatible. For example, we can define %Bumblebee.Text.Generation.StatelessLogitsProcessor{fun: fun}, where the state is always empty and process just invokes the fun. I would even use that for the built-in processors, so that we don't need to define a bunch of new modules.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jonatanklosko Thank you very much for your comments! I think esp. the two step call makes sense. We'll move in that direction :)

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jonatanklosko
as an afterthought:

What is the use case for context here:

@callback init(t(), context) :: state()
            when context: %{
                   prng_key: Nx.Tensor.t()
                 }

Later in the loop, context holds:

context = %{
      sequences: sequences,
      input_length: length,
      length: length,
    }

I am wondering how those would influence the initialisation of the logits processors?

Or are you planning on using additional keys? E.g. from the state as returned by the init sequence:

%{
      sequences: sequences,
      input_length: length,
      length: length,
      finished_length: finished_length,
      ignored: Nx.broadcast(0, {batch_size})
    }

If that was the case, we should probably rename the parameter to state or initial_state.

Wdyt?

next_suppressed_token_id
)

{logits, context}
end

defnp suppress_id(logits, id) do
Nx.indexed_put(
logits,
id,
Nx.Constants.neg_infinity(Nx.type(logits))
)
end
end
end