Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
fc0825a
[#SAMPLE-6] Add state to logits processing
joelpaulkoch Oct 17, 2025
01ab3af
stateful logits processors
joelpaulkoch Oct 16, 2025
5413662
adding another test
joelpaulkoch Oct 16, 2025
9d4ef39
fix test so compilation works
joelpaulkoch Oct 20, 2025
4ce01cc
demonstrate stateful logits processor through test assertions
joelpaulkoch Oct 20, 2025
2161b77
independent state for batch entries
joelpaulkoch Oct 20, 2025
fefc9fd
renamed initial_suppressed_token_index for clarity
xhr15 Oct 21, 2025
6e8612a
renamend next_suppressed_index -> :next_suppressed_token_index
xhr15 Oct 21, 2025
e43254a
logits_processor_states -> logits_processor_state in batch tests
xhr15 Oct 21, 2025
a2f0015
added a test with batch size 1 for clarity
xhr15 Oct 21, 2025
0cdc0ad
renaming suppressed_id -> suppressed_token_id
xhr15 Oct 21, 2025
cc6d6e3
more comments
xhr15 Oct 21, 2025
3816e7c
changed to to demonstrate stack functionality
xhr15 Oct 23, 2025
fe58712
merged tests
xhr15 Oct 23, 2025
c97890a
removed test for processor only used in test
xhr15 Oct 23, 2025
fbf5ef3
introduces LogitsProcessor module
xhr15 Oct 24, 2025
dfa223c
clean up
joelpaulkoch Oct 27, 2025
9098bda
mix format
joelpaulkoch Oct 27, 2025
544d80f
vectorized sequences are called sequence in context
joelpaulkoch Oct 27, 2025
2ba5e0a
don't vectorize all the logits processor state
joelpaulkoch Oct 27, 2025
196c8f0
swap {logits, state} to {state, logits}
joelpaulkoch Nov 5, 2025
ee2a01e
rename logits_processor_state to logits_processor_states
joelpaulkoch Nov 5, 2025
3563ff0
states as tuples
joelpaulkoch Nov 5, 2025
6db771e
update test
joelpaulkoch Nov 5, 2025
c8442e0
single initial state for all batch entries
joelpaulkoch Nov 5, 2025
41dd2ad
vectorize sequence for init, derive vectorized state
joelpaulkoch Nov 5, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions lib/bumblebee.ex
Original file line number Diff line number Diff line change
Expand Up @@ -1083,6 +1083,36 @@ defmodule Bumblebee do
end
end

@doc """
Initializes state for a new logits processor.

Returns `state`, which is an opaque `Nx.Container`, and it is then
passed to and returned from `process/4`.
"""
@doc type: :logits_processor
@spec logits_processor_init(
Bumblebee.LogitsProcessor.t(),
context :: term()
) :: Bumblebee.LogitsProcessor.state()
def logits_processor_init(%module{} = logits_processor, context) do
module.init(logits_processor, context)
end

@doc """
Processes logits, applying specific rules. Receives context, state and
logits, and returns updated logits and state.
"""
@doc type: :logits_processor
@spec logits_processor_process(
Bumblebee.LogitsProcessor.t(),
Bumblebee.LogitsProcessor.state(),
logits :: Nx.Tensor.t(),
context :: term()
) :: {Bumblebee.LogitsProcessor.state(), logits :: Nx.Tensor.t()}
def logits_processor_process(%module{} = logits_processor, state, logits, context) do
module.process(logits_processor, state, logits, context)
end

@doc """
Initializes state for a new scheduler loop.

Expand Down
38 changes: 38 additions & 0 deletions lib/bumblebee/logits_processor.ex
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
defmodule Bumblebee.LogitsProcessor do
@moduledoc """
An interface for configuring and using logits processors.

Logits processors are used during autoregressive generation to modify
predicted scores at each generation step. This allows for applying
certain rules to the model output to control which tokens are picked
at each generation step, and which are not.

Every module implementing this behaviour is expected to also define
a configuration struct.
"""

@type t :: Bumblebee.Configurable.t()

@type state :: Nx.Container.t()

@doc """
Initializes state for a new logits processor.

Returns `state`, which is an opaque `Nx.Container`, and it is then
passed to and returned from `process/2`.

Oftentimes logits processors are stateless, in which case this
function can return an empty container, such as `{}`.
"""
@callback init(t(), any()) :: state()

@doc """
Processes logits, applying specific rules.
"""
@callback process(
t(),
state(),
logits :: Nx.Tensor.t(),
context :: term()
) :: {state :: map(), logits :: Nx.Tensor.t()}
end
Loading