Conversation

@xhr15 xhr15 commented Oct 21, 2025

Logits processing is a powerful tool, particularly for using smaller language models for tasks such as named entity recognition. @seanmor5 started work in this area with #354.

Whatever the approach, it will require some kind of state.

This pull request is a proposal to allow logits processors to be stateful.

This would enable the use of deterministic finite automata (DFAs) or pushdown automata (PDAs) for enforcing constrained grammars during logits processing. bitcrowd#6 shows how this would be used. We will follow up on this PR if this approach is favoured.

Member

@jonatanklosko jonatanklosko left a comment

Hey @xhr15 and @joelpaulkoch, thanks for the PR!

I dropped a few comments, but the main one is about the API. I know it's a bit more involved, but probably worth it. Let me know what you think, and if you have any concerns!

context =
  put_in(
    context,
    [:logits_processor_state, :next_suppressed_token_id],
Member

With the current API, the state is always initialized to %{}, and the first invocation of the processor then adds a key, here %{next_suppressed_token_id: %Nx.Tensor{...}}.

This can be problematic in a defn while loop, which requires the accumulated state to always have the same shape. In other words, the initial state should already include :next_suppressed_token_id with the default tensor. It is possible that this didn't come up during your tests because, depending on the model/input, we do the first generation step outside of the while loop, and that first call would initialize the state. However, if we are going to support stateful processors, I would rather do it in a more robust way.
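
A minimal, out-of-context illustration of that constraint (a hypothetical module, not part of this PR): the while accumulator has to keep the same structure and shape on entry and on every iteration, so a state map that only gains its key on the first call would not compile.

```elixir
defmodule WhileShapeDemo do
  import Nx.Defn

  defn sum_below(limit) do
    # The accumulator {i, acc} keeps the same shape/type on every iteration;
    # starting from %{} and adding a key later would be rejected here.
    {_, acc} =
      while {i = 0, acc = 0}, i < limit do
        {i + 1, acc + i}
      end

    acc
  end
end
```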

Given the above, a stateful logits processor would involve two steps (functions):

  1. Building an initial state.
  2. Performing logits processing, which receives logits and state, and returns updated logits and state.

This way we can call (1) when initializing the generation context, and for the actual processing we call (2).

The behaviour can be similar to Bumblebee.Scheduler. Something like this:

defmodule Bumblebee.LogitsProcessor do
  @moduledoc """
  An interface for configuring and using logits processors.

  Logits processors are used during autoregressive generation to modify
  predicted scores at each generation step. This allows for applying
  certain rules to the model output to control which tokens are picked
  at each generation step, and which are not.

  Every module implementing this behaviour is expected to also define
  a configuration struct.
  """

  @type t :: Bumblebee.Configurable.t()

  @type state :: Nx.Container.t()

  @doc """
  Initializes state for a new logits processor.

  Returns `state`, which is an opaque `Nx.Container`, and it is then
  passed to and returned from `process/4`.

  Oftentimes logits processors are stateless, in which case this
  function can return an empty container, such as `{}`.
  """
  @callback init(t(), context) :: state()
            when context: %{
                   prng_key: Nx.Tensor.t()
                 }

  @doc """
  Processes logits, applying specific rules.
  """
  @callback process(
              t(),
              state(),
              logits :: Nx.Tensor.t(),
              context :: context
            ) :: {state(), logits :: Nx.Tensor.t()}
            when context: %{
                   sequence: Nx.Tensor.t(),
                   length: Nx.Tensor.t(),
                   input_length: Nx.Tensor.t()
                 }
end

Technically, the :logits_processors option is public API, but we can make it backward-compatible. For example, we can define %Bumblebee.Text.Generation.StatelessLogitsProcessor{fun: fun}, where the state is always empty and process just invokes the fun. I would even use that for the built-in processors, so that we don't need to define a bunch of new modules.
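
A rough sketch of what that backward-compatible wrapper could look like, assuming the Bumblebee.LogitsProcessor behaviour proposed above; the module name and :fun field follow the suggestion, the rest is illustrative rather than final:

```elixir
defmodule Bumblebee.Text.Generation.StatelessLogitsProcessor do
  @moduledoc false

  # Wraps a plain logits-transforming function in the proposed behaviour.
  defstruct [:fun]

  @behaviour Bumblebee.LogitsProcessor

  @impl true
  def init(_processor, _context), do: {}

  @impl true
  def process(%__MODULE__{fun: fun}, state, logits, context) do
    # The wrapped function only transforms logits; the empty state is
    # passed through unchanged.
    {state, fun.(logits, context)}
  end
end
```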

Author

@jonatanklosko Thank you very much for your comments! I think especially the two-step call makes sense. We'll move in that direction :)

Author

@jonatanklosko
as an afterthought:

What is the use case for context here:

@callback init(t(), context) :: state()
            when context: %{
                   prng_key: Nx.Tensor.t()
                 }

Later in the loop, context holds:

context = %{
  sequences: sequences,
  input_length: length,
  length: length
}

I am wondering how those would influence the initialisation of the logits processors?

Or are you planning on using additional keys? E.g. from the state as returned when initializing the sequences:

%{
  sequences: sequences,
  input_length: length,
  length: length,
  finished_length: finished_length,
  ignored: Nx.broadcast(0, {batch_size})
}

If that was the case, we should probably rename the parameter to state or initial_state.

Wdyt?

Member

> What is the use case for context here:

I picked "context" in both functions as a generic name for state/metadata that may be relevant to the logits processor. You can see that in my snippet the context type is different for init and process. Technically all of the context fields could be separate arguments, but keeping it as a map makes the signature more manageable, and more importantly allows us to add more fields in the future without breaking compatibility.

Does that make sense?

xhr15 commented Oct 24, 2025

@jonatanklosko Before we add more tests and do further refactoring: do you think this goes in the right direction? Please let me know if you have concerns or anything could be improved.

@joelpaulkoch
Contributor

We might not want to vectorize all of the logits processors' state. E.g. when we want to read from a shared state tensor while processing the vectorized logits, we would otherwise have to duplicate the shared state tensor across the vectorized axis, right?
We can instead vectorize only the state that needs vectorization inside the logits processor.

That's basically the reason for 2ba5e0a. I'm not entirely sure if this is alright or if it has negative implications for defn.
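
A tiny illustration of that split outside defn (hypothetical values): the shared table stays flat, only the per-sequence position carries the vectorized axis, and Nx broadcasts the flat tensor across it:

```elixir
# The table is shared across the batch; the position is per-sequence.
table = Nx.tensor([[1, 0, 1], [0, 1, 0]])
position = Nx.vectorize(Nx.tensor([0, 1, 0]), :batch)

# Nx broadcasts the non-vectorized table across the :batch axis, yielding
# one allowed-token row per sequence without duplicating the table.
allowed = Nx.take(table, position)
```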

@xhr15 xhr15 requested a review from jonatanklosko November 3, 2025 22:18
Comment on lines 403 to 404
Enum.reduce(processors, %{}, fn processor, state_acc ->
  state = Bumblebee.logits_processor_init(processor, context)
Member

Currently each processor needs to be careful to namespace its state and avoid conflicts. Ideally the processor would not need to care about it, so in your example:

def init(logits_processor, _context) do
  initial_enforced_token_ids =
    Enum.map(logits_processor.initial_enforced_token_ids, &List.wrap(&1))

  initial_enforced_batch_token_id =
    Nx.tensor(initial_enforced_token_ids)

-  %{
-    sfp_state: %{
-      next_enforced_token_id: initial_enforced_batch_token_id
-    }
-  }
+  %{
+    next_enforced_token_id: initial_enforced_batch_token_id
+  }
end

To do this, instead of having a single map and Map.merge-ing into it, we can have a list of processor states. We init them separately, and zip processors with their states for updates. Something like this:

init_fun = fn context ->
  processors
  |> Enum.map(fn processor ->
    Bumblebee.logits_processor_init(processor, context)
  end)
  |> List.to_tuple()
end

process_fun = fn logits, context, processor_states ->
  {processor_states, logits} =
    processors
    |> Enum.zip(Tuple.to_list(processor_states))
    |> Enum.map_reduce(logits, fn {processor, processor_state}, logits ->
      Bumblebee.logits_processor_process(processor, processor_state, logits, context)
    end)
    
  {List.to_tuple(processor_states), logits}
end

Note that we want to keep the states as a tuple instead of a list, so that it is a valid Nx container and can be passed to while and across defn calls.
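
As a small illustration (hypothetical values) of why the tuple matters: a tuple of maps is a valid Nx container, so the combined processor states can be passed through jit/defn, and hence a while loop, as a single value:

```elixir
states = {%{pos: Nx.tensor(0)}, %{next_id: Nx.tensor(7)}}

# jit accepts the tuple as one container argument and returns a container
# with the same structure; a plain list would not work here.
bump = Nx.Defn.jit(fn {a, b} -> {%{pos: Nx.add(a.pos, 1)}, b} end)
bump.(states)
```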

Author

Thank you @jonatanklosko, we were wondering if having this kind of interface to reach into other processors would be a good or bad thing :)

We'll take it out!

Bumblebee.Text.Generation.build_generate(model, spec, generation_config,
  logits_processors: [
    Bumblebee.configure(Bumblebee.Text.GenerationTest.StatefulLogitsProcessing,
      initial_enforced_token_ids: [78, 20]
Member

This is an arbitrary example, but I don't think we would ever have a different value for each entry in the batch, because the batch entries should generally be interchangeable. So it should rather be initial_enforced_token_id: 78, and then you check that it is enforced in the same way for both batch entries.
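
For reference, a sketch of the adjusted test configuration following that suggestion (hypothetical, not the final test code):

```elixir
Bumblebee.configure(Bumblebee.Text.GenerationTest.StatefulLogitsProcessing,
  initial_enforced_token_id: 78
)
```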

@jonatanklosko
Member

> We might not want to vectorize all of the logits processors' state. E.g. when we want to read from a shared state tensor while processing the vectorized logits, we would otherwise have to duplicate the shared state tensor across the vectorized axis, right? We can instead vectorize only the state that needs vectorization inside the logits processor.
>
> That's basically the reason for 2ba5e0a. I'm not entirely sure if this is alright or if it has negative implications for defn.

Correct, Bumblebee should not call vectorize on the logits processor state. Ideally we want vectorization to happen automatically.

For example, schedulers have a similar init; here's one of them:

def init(scheduler, num_steps, sample_template, _prng_key) do
  timesteps =
    timesteps(
      scheduler.num_train_steps,
      num_steps,
      scheduler.timesteps_offset,
      scheduler.reduce_warmup
    )

  alpha_bars = init_parameters(scheduler: scheduler)

  empty = Nx.fill(sample_template, 0)
alpha_bars is generated as a flat tensor and is shared state (not duplicated across the batch). On the other hand, the caller (Bumblebee) can pass sample_template with a vectorized axis, and then empty = Nx.fill(sample_template, 0) would be vectorized state. What's nice is that the scheduler is not aware of the vectorization, and a non-vectorized input works just fine.
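
As a hypothetical usage sketch (shapes illustrative): the caller decides about vectorization through the template it passes in, and the scheduler code above never has to mention a batch axis:

```elixir
# A batch of 4 samples, vectorized on :batch by the caller (Bumblebee).
sample_template = Nx.vectorize(Nx.broadcast(0.0, {4, 64, 64, 3}), :batch)

# The same call as in init/4 above now yields vectorized per-sample state,
# while alpha_bars (a flat tensor) remains shared across the batch.
empty = Nx.fill(sample_template, 0)
```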

For this to work automatically though, we need something to derive the state off of (like sample_template), so that it gets automatically vectorized. I'm not yet sure how it would look for the processor; I need to think more about this.

@jonatanklosko
Member

Sorry for the late reply, I was off last week :)

@jonatanklosko
Member

> For this to work automatically though, we need something to derive the state off of (like sample_template), so that it gets automatically vectorized. I'm not yet sure how it would look for the processor; I need to think more about this.

Let's just pass sequence: Nx.vectorize(state.sequences, :batch) in the init context too. Depending on what per-sequence state the user creates, they may need to take special care to make it vectorization-friendly (e.g. Nx.iota({2, 2}, vectorized_axes: sequence.vectorized_axes), or using Nx.broadcast_vectors), but I think it's fine.
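
As a rough sketch (assuming a :sequence key in the init context as suggested above, and the single initial_enforced_token_id option from the earlier review comment), an init could derive per-sequence state like this:

```elixir
def init(logits_processor, context) do
  # context.sequence is assumed to carry the vectorized :batch axis;
  # broadcast_vectors gives the scalar state the same vectorized axes,
  # making it per-sequence without hard-coding a batch dimension.
  [next_enforced_token_id, _] =
    Nx.broadcast_vectors([
      Nx.tensor(logits_processor.initial_enforced_token_id),
      context.sequence
    ])

  %{next_enforced_token_id: next_enforced_token_id}
end
```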

xhr15 and others added 4 commits November 14, 2025 23:26
```
** (RuntimeError) unexpected vectorized axes in evaluator for operation :add: #Nx.Tensor<
       vectorized[batch: 1]
       s32[1]

       Nx.Defn.Expr
       tensor a        s32[1]
       b = reshape a   s32[1][1]
```
@xhr15 xhr15 force-pushed the task/sample-6-add-state-to-logits-processing branch from 604b60e to 572b748 on November 14, 2025 22:54
@xhr15 xhr15 force-pushed the task/sample-6-add-state-to-logits-processing branch from 572b748 to ce92584 on November 14, 2025 22:57

xhr15 commented Nov 14, 2025

@jonatanklosko thank you for the late-night review. Please let me know what you think. I added two livebooks about logits processing in the last commit. They are not strictly related to state, but I found them useful for explaining logits processing in talks. I could open up a separate PR for them if you like; it was just too tempting to include them :)

@xhr15 xhr15 requested a review from jonatanklosko November 14, 2025 23:48