Add state to logits processing #425
Conversation
jonatanklosko left a comment
Hey @xhr15 and @joelpaulkoch, thanks for the PR!
I dropped a few comments, but the main one is about the API. I know it's a bit more involved, but probably worth it. Let me know what you think, and if you have any concerns!
```elixir
context =
  put_in(
    context,
    [:logits_processor_state, :next_suppressed_token_id],
```
With the current API, the state is always initialized to %{} and then the first invocation of the processor adds a key, here %{next_suppressed_token_id: %Nx.Tensor{...}}.
This can be problematic in a defn while loop, which requires the accumulated state to always have the same shape. In other words, the initial state should already include :next_suppressed_token_id with a default tensor. It is possible that this didn't come up during your tests because, depending on the model/input, we do the first generation step outside of the while loop, and that first call would initialize the state. However, if we are going to support stateful processors, I would rather do it in a more robust way.
Given the above, a stateful logits processor would involve two steps (functions):
- Building an initial state.
- Performing logits processing, which receives logits and state, and returns updated logits and state.
This way we can call (1) when initializing the generation context, and for the actual processing we call (2).
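To make the fixed-shape requirement concrete, under this two-step scheme a stateful processor would return its full default state up front. A minimal sketch (the field name follows the example above; the -1 "no token" default is a hypothetical choice for illustration):

```elixir
# Hypothetical init for a stateful processor: the state already contains
# :next_suppressed_token_id with a default tensor, so the state keeps the
# same shape across all while-loop iterations.
def init(_logits_processor, _context) do
  %{next_suppressed_token_id: Nx.tensor(-1)}
end
```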
The behaviour can be similar to Bumblebee.Scheduler. Something like this:
```elixir
defmodule Bumblebee.LogitsProcessor do
  @moduledoc """
  An interface for configuring and using logits processors.

  Logits processors are used during autoregressive generation to modify
  predicted scores at each generation step. This allows for applying
  certain rules to the model output to control which tokens are picked
  at each generation step, and which are not.

  Every module implementing this behaviour is expected to also define
  a configuration struct.
  """

  @type t :: Bumblebee.Configurable.t()

  @type state :: Nx.Container.t()

  @doc """
  Initializes state for a new logits processor.

  Returns `state`, which is an opaque `Nx.Container`, and it is then
  passed to and returned from `process/4`.

  Oftentimes logits processors are stateless, in which case this
  function can return an empty container, such as `{}`.
  """
  @callback init(t(), context) :: state()
            when context: %{
                   prng_key: Nx.Tensor.t()
                 }

  @doc """
  Processes logits, applying specific rules.
  """
  @callback process(
              t(),
              state(),
              logits :: Nx.Tensor.t(),
              context :: context
            ) :: {state :: map(), logits :: Nx.Tensor.t()}
            when context: %{
                   sequence: Nx.Tensor.t(),
                   length: Nx.Tensor.t(),
                   input_length: Nx.Tensor.t()
                 }
end
```

Technically, the :logits_processors option is public API, but we can make it backward-compatible. For example, we can define %Bumblebee.Text.Generation.StatelessLogitsProcessor{fun: fun}, where the state is always empty and process just invokes the fun. I would even use that for the built-in processors, so that we don't need to define a bunch of new modules.
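As a sketch of how a processor might implement this proposed behaviour, here is a hypothetical stateless processor that suppresses a single configured token id on every step. The module name and the :token_id field are made up for illustration, and the callback shapes follow the snippet above:

```elixir
# Hypothetical implementation of the proposed Bumblebee.LogitsProcessor
# behaviour (the behaviour module itself is only a proposal, so no
# @behaviour attribute here). It is stateless: init/2 returns an empty
# container and process/4 masks one token id by setting its logit to
# negative infinity.
defmodule MyApp.SuppressTokenProcessor do
  import Nx.Defn

  defstruct [:token_id]

  def init(_processor, _context), do: {}

  def process(processor, state, logits, _context) do
    {state, suppress(logits, processor.token_id)}
  end

  defnp suppress(logits, token_id) do
    # Build [0, 1, ..., vocab_size - 1] and mask the matching position.
    token_ids = Nx.iota({Nx.axis_size(logits, -1)})
    Nx.select(token_ids == token_id, Nx.Constants.neg_infinity(Nx.type(logits)), logits)
  end
end
```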
@jonatanklosko Thank you very much for your comments! I think especially the two-step call makes sense. We'll move in that direction :)
@jonatanklosko
as an afterthought:
What is the use case for context here:
```elixir
@callback init(t(), context) :: state()
          when context: %{
                 prng_key: Nx.Tensor.t()
               }
```
Later in the loop, context holds:
```elixir
context = %{
  sequences: sequences,
  input_length: length,
  length: length
}
```
I am wondering how those would influence the initialisation of the logits processors?
Or are you planning on using additional keys, e.g. from the state as returned by the init sequence:

```elixir
%{
  sequences: sequences,
  input_length: length,
  length: length,
  finished_length: finished_length,
  ignored: Nx.broadcast(0, {batch_size})
}
```

If that were the case, we should probably rename the parameter to state or initial_state.
Wdyt?
What is the use case for context here:
I picked "context" in both functions as a generic name for state/metadata that may be relevant to the logits processor. You can see that in my snippet the context type is different for init and process. Technically all of the context fields could be separate arguments, but keeping it as a map makes the signature more manageable, and more importantly allows us to add more fields in the future without breaking compatibility.
Does that make sense?
@jonatanklosko Before we add more tests and do further refactoring: Do you think this goes in the right direction? Please let me know if you have concerns or anything could be improved.
We might not want to vectorize all of the logits processor state. For example, when we want to read from a shared state tensor while processing the vectorized logits, we would otherwise have to duplicate the shared state tensor across the vectorized axis, right? That's basically the reason for 2ba5e0a. I'm not entirely sure if this is alright or if it has negative implications for …
lib/bumblebee/text/generation.ex
Outdated
```elixir
Enum.reduce(processors, %{}, fn processor, state_acc ->
  state = Bumblebee.logits_processor_init(processor, context)
```
Currently each processor needs to be careful to namespace its state and avoid conflicts. Ideally the processor would not need to care about it, so in your example:
```diff
 def init(logits_processor, _context) do
   initial_enforced_token_ids =
     Enum.map(logits_processor.initial_enforced_token_ids, &List.wrap(&1))

   initial_enforced_batch_token_id =
     Nx.tensor(initial_enforced_token_ids)

-  %{
-    sfp_state: %{
-      next_enforced_token_id: initial_enforced_batch_token_id
-    }
-  }
+  %{
+    next_enforced_token_id: initial_enforced_batch_token_id
+  }
 end
```

To do this, instead of having a single map and Map.merge into it, we can instead have a list of processor states. We init them separately, and we zip processors with their states for updates. Something like this:
```elixir
init_fun = fn context ->
  processors
  |> Enum.map(fn processor ->
    Bumblebee.logits_processor_init(processor, context)
  end)
  |> List.to_tuple()
end

process_fun = fn logits, context, processor_states ->
  {processor_states, logits} =
    processors
    |> Enum.zip(Tuple.to_list(processor_states))
    |> Enum.map_reduce(logits, fn {processor, processor_state}, logits ->
      Bumblebee.logits_processor_process(processor, processor_state, logits, context)
    end)

  {List.to_tuple(processor_states), logits}
end
```

Note that we want to keep the states as a tuple instead of a list, so that it is a valid Nx container and can be passed to while and around defn calls.
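As a side note on why a tuple rather than a list: tuples are valid Nx containers, so they can flow through a defn while loop as the accumulator, the way the processor states would here. A minimal, self-contained sketch assuming only Nx is installed (unrelated to Bumblebee):

```elixir
# A tuple accumulator flowing through a defn while loop. Substituting a
# list for the tuple would raise, because lists are not Nx containers.
defmodule WhileExample do
  import Nx.Defn

  defn count_up(x) do
    {_i, acc} =
      while {i = 0, acc = x}, i < 3 do
        {i + 1, acc + 1}
      end

    acc
  end
end
```

`WhileExample.count_up(Nx.tensor(0))` runs the loop three times, returning a tensor containing 3.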
Thank you @jonatanklosko, we were wondering if having this kind of interface to reach into other processors would be a good or bad thing :)
We'll take it out!
```elixir
Bumblebee.Text.Generation.build_generate(model, spec, generation_config,
  logits_processors: [
    Bumblebee.configure(Bumblebee.Text.GenerationTest.StatefulLogitsProcessing,
      initial_enforced_token_ids: [78, 20]
```
This is an arbitrary example, but I don't think we would ever have a different value for each entry in the batch, because the batch entries should generally be interchangeable. So it should rather be initial_enforced_token_id: 78, and then you check that it is enforced in the same way for both batch entries.
Correct, Bumblebee should not call vectorize on the logits processor state. Ideally we want vectorization to happen automatically. For example, schedulers have a similar init, here's one of them: lib/bumblebee/diffusion/pndm_scheduler.ex, lines 97 to 108 at bc1b452.
For this to work automatically though, we need something to derive the state off of (like …
Sorry for the late reply, I was off last week :)
Let's just pass …
```
** (RuntimeError) unexpected vectorized axes in evaluator for operation :add: #Nx.Tensor<
vectorized[batch: 1]
s32[1]
Nx.Defn.Expr
tensor a s32[1]
b = reshape a s32[1][1]
```
Co-authored-by: Jonatan Kłosko <[email protected]>
Force-pushed 604b60e to 572b748, then 572b748 to ce92584.
…er all batches now
…ictly related to statefull processing
@jonatanklosko thank you for the late-night review. Please let me know what you think. I added two livebooks about logits processing in the last commit. They are not strictly related to state, but I found them useful to explain logits processing in talks. I could open a separate PR for them if you like, it was just too tempting to include them :)
Logits processing is a powerful tool, particularly for using smaller language models for tasks such as named entity recognition. @seanmor5 started work in this area with #354.
Whatever the approach, it will require some kind of state.
This pull request is a proposal to allow logits processors to be stateful.
This would enable the use of deterministic finite automata (DFAs) or pushdown automata (PDAs) for grammar-constrained logits processing. bitcrowd#6 shows how this would be used. We will follow up on this PR if this approach is favoured.