diff --git a/lib/bumblebee.ex b/lib/bumblebee.ex index 51f2330f..62010053 100644 --- a/lib/bumblebee.ex +++ b/lib/bumblebee.ex @@ -1083,6 +1083,36 @@ defmodule Bumblebee do end end + @doc """ + Initializes state for a new logits processor. + + Returns `state`, which is an opaque `Nx.Container`, and it is then + passed to and returned from `process/4`. + """ + @doc type: :logits_processor + @spec logits_processor_init( + Bumblebee.LogitsProcessor.t(), + context :: Bumblebee.LogitsProcessor.init_context() + ) :: Bumblebee.LogitsProcessor.state() + def logits_processor_init(%module{} = logits_processor, context) do + module.init(logits_processor, context) + end + + @doc """ + Processes logits, applying specific rules. Receives context, state and + logits, and returns updated logits and state. + """ + @doc type: :logits_processor + @spec logits_processor_process( + Bumblebee.LogitsProcessor.t(), + Bumblebee.LogitsProcessor.state(), + logits :: Nx.Tensor.t(), + context :: Bumblebee.LogitsProcessor.process_context() + ) :: {Bumblebee.LogitsProcessor.state(), logits :: Nx.Tensor.t()} + def logits_processor_process(%module{} = logits_processor, state, logits, context) do + module.process(logits_processor, state, logits, context) + end + @doc """ Initializes state for a new scheduler loop. diff --git a/lib/bumblebee/logits_processor.ex b/lib/bumblebee/logits_processor.ex new file mode 100644 index 00000000..40c74daf --- /dev/null +++ b/lib/bumblebee/logits_processor.ex @@ -0,0 +1,46 @@ +defmodule Bumblebee.LogitsProcessor do + @moduledoc """ + An interface for configuring and using logits processors. + + Logits processors are used during autoregressive generation to modify + predicted scores at each generation step. This allows for applying + certain rules to the model output to control which tokens are picked + at each generation step, and which are not. + + Every module implementing this behaviour is expected to also define + a configuration struct. 
+ """ + + @type t :: Bumblebee.Configurable.t() + + @type state :: Nx.Container.t() + + @type process_context :: %{ + sequence: Nx.Tensor.t(), + length: Nx.Tensor.t(), + input_length: Nx.Tensor.t() + } + + @type init_context :: %{} + + @doc """ + Initializes state for a new logits processor. + + Returns `state`, which is an opaque `Nx.Container`, and it is then + passed to and returned from `process/4`. + + Oftentimes logits processors are stateless, in which case this + function can return an empty container, such as `{}`. + """ + @callback init(t(), init_context()) :: state() + + @doc """ + Processes logits, applying specific rules. + """ + @callback process( + t(), + state(), + logits :: Nx.Tensor.t(), + context :: process_context() + ) :: {state :: map(), logits :: Nx.Tensor.t()} +end diff --git a/lib/bumblebee/text/generation.ex b/lib/bumblebee/text/generation.ex index 935c4921..669c1b7e 100644 --- a/lib/bumblebee/text/generation.ex +++ b/lib/bumblebee/text/generation.ex @@ -164,13 +164,15 @@ defmodule Bumblebee.Text.Generation do {_init_fun, predict_fun} = Axon.build(model, global_layer_options: global_layer_options) - logits_processor_fun = get_logits_processor(min_length_fun, config, opts[:logits_processors]) + {logits_processor_init_fun, logits_processor_process_fun} = + get_logits_processor(min_length_fun, config, opts[:logits_processors]) &generate_impl( &2, predict_fun, &1, - logits_processor_fun, + logits_processor_init_fun, + logits_processor_process_fun, prepare_inputs_fun, update_inputs_fun, traverse_cache_fun, @@ -386,18 +388,45 @@ defmodule Bumblebee.Text.Generation do [] end ++ logits_processors - fn logits, context -> - for processor <- processors, processor, reduce: logits do - logits -> processor.(logits, context) - end + processors = + processors + |> Enum.filter(fn processor -> processor != nil end) + |> Enum.map(fn processor -> + if is_function(processor, 2) do + %Bumblebee.Text.Generation.StatelessLogitsProcessor{fun: processor} + else + 
processor + end + end) + + init_fun = fn context -> + processors + |> Enum.map(fn processor -> + Bumblebee.logits_processor_init(processor, context) + end) + |> List.to_tuple() end + + process_fun = fn logits, context, processor_states -> + {processor_states, logits} = + processors + |> Enum.zip(Tuple.to_list(processor_states)) + |> Enum.map_reduce(logits, fn {processor, processor_state}, logits -> + Bumblebee.logits_processor_process(processor, processor_state, logits, context) + end) + + {List.to_tuple(processor_states), logits} + end + + {init_fun, process_fun} end defnp generate_impl( inputs, predict_fun, params, - logits_processor_fun, + logits_processor_init_fun, + logits_processor_process_fun, prepare_inputs_fun, update_inputs_fun, traverse_cache_fun, @@ -427,7 +456,8 @@ defmodule Bumblebee.Text.Generation do padded_batch_item?, predict_fun, params, - logits_processor_fun, + logits_processor_init_fun, + logits_processor_process_fun, update_inputs_fun, merge_options([max_length: max_length], opts) ) @@ -439,7 +469,8 @@ defmodule Bumblebee.Text.Generation do padded_batch_item?, predict_fun, params, - logits_processor_fun, + logits_processor_init_fun, + logits_processor_process_fun, update_inputs_fun, traverse_cache_fun, merge_options( @@ -456,7 +487,8 @@ defmodule Bumblebee.Text.Generation do predict_fun, params, seed, - logits_processor_fun, + logits_processor_init_fun, + logits_processor_process_fun, update_inputs_fun, merge_options([max_length: max_length], opts) ) @@ -485,7 +517,8 @@ defmodule Bumblebee.Text.Generation do padded_batch_item?, predict_fun, params, - logits_processor_fun, + logits_processor_init_fun, + logits_processor_process_fun, update_inputs_fun, opts \\ [] ) do @@ -493,7 +526,14 @@ defmodule Bumblebee.Text.Generation do pad_token_id = opts[:pad_token_id] eos_token_id = opts[:eos_token_id] - state = init_sequences(decoder_input_ids, padded_batch_item?, max_length, pad_token_id) + state = + init_sequences( + decoder_input_ids, + 
padded_batch_item?, + max_length, + pad_token_id, + logits_processor_init_fun + ) # The loop works with inputs of length 1, so if the initial input # is longer, we make the initial pass outside @@ -504,7 +544,7 @@ defmodule Bumblebee.Text.Generation do inputs, predict_fun, params, - logits_processor_fun, + logits_processor_process_fun, update_inputs_fun, pad_token_id: pad_token_id, eos_token_id: eos_token_id @@ -521,7 +561,7 @@ defmodule Bumblebee.Text.Generation do inputs, predict_fun, params, - logits_processor_fun, + logits_processor_process_fun, update_inputs_fun, pad_token_id: pad_token_id, eos_token_id: eos_token_id @@ -533,7 +573,13 @@ defmodule Bumblebee.Text.Generation do state end - defnp init_sequences(decoder_input_ids, padded_batch_item?, max_length, pad_token_id) do + defnp init_sequences( + decoder_input_ids, + padded_batch_item?, + max_length, + pad_token_id, + logits_processor_init_fun + ) do {batch_size, length} = Nx.shape(decoder_input_ids) sequences = Nx.broadcast(pad_token_id, {batch_size, max_length}) @@ -545,13 +591,20 @@ defmodule Bumblebee.Text.Generation do # they could produce arbitrary tokens until we reach max length. 
finished_length = Nx.select(padded_batch_item?, 1, 0) + context = %{ + sequence: Nx.vectorize(sequences, :batch), + input_length: length, + length: length + } + %{ sequences: sequences, input_length: length, length: length, finished_length: finished_length, # The ignored return value that we attach all hooks to - ignored: Nx.broadcast(0, {batch_size}) + ignored: Nx.broadcast(0, {batch_size}), + logits_processor_states: logits_processor_init_fun.(context) } end @@ -564,7 +617,7 @@ defmodule Bumblebee.Text.Generation do inputs, predict_fun, params, - logits_processor_fun, + logits_processor_process_fun, update_inputs_fun, opts ) do @@ -574,7 +627,7 @@ defmodule Bumblebee.Text.Generation do outputs = predict_fun.(params, inputs) logits = outputs.logits[[.., -1]] - logits = batch_process_logits(logits_processor_fun, logits, state) + {logits, state} = batch_process_logits(logits_processor_process_fun, logits, state) token_id = Nx.argmax(logits, axis: -1) state = update_sequences(state, token_id, pad_token_id, eos_token_id) @@ -631,15 +684,25 @@ defmodule Bumblebee.Text.Generation do end end - defnp batch_process_logits(logits_processor_fun, logits, state) do - logits - |> Nx.vectorize(:batch) - |> logits_processor_fun.(%{ + defnp batch_process_logits(logits_processor_process_fun, logits, state) do + logits = Nx.vectorize(logits, :batch) + + context = %{ sequence: Nx.vectorize(state.sequences, :batch), length: state.length, input_length: state.input_length - }) - |> Nx.devectorize(keep_names: false) + } + + {logits_processor_states, logits} = + logits_processor_process_fun.( + logits, + context, + state.logits_processor_states + ) + + logits = Nx.devectorize(logits, keep_names: false) + + {logits, %{state | logits_processor_states: logits_processor_states}} end # Contrastive search @@ -650,7 +713,8 @@ defmodule Bumblebee.Text.Generation do padded_batch_item?, predict_fun, params, - logits_processor_fun, + logits_processor_init_fun, + logits_processor_process_fun, 
update_inputs_fun, traverse_cache_fun, opts \\ [] @@ -661,7 +725,14 @@ defmodule Bumblebee.Text.Generation do top_k = opts[:top_k] penalty_alpha = opts[:penalty_alpha] - state = init_sequences(decoder_input_ids, padded_batch_item?, max_length, pad_token_id) + state = + init_sequences( + decoder_input_ids, + padded_batch_item?, + max_length, + pad_token_id, + logits_processor_init_fun + ) # Step (1) # Initial pass to obtain hidden state and expand inputs to top-k @@ -684,7 +755,7 @@ defmodule Bumblebee.Text.Generation do joint_hidden_state = Nx.put_slice(joint_hidden_state, [0, 0, 0], initial_hidden_state) logits = outputs.logits[[.., -1]] - logits = batch_process_logits(logits_processor_fun, logits, state) + {logits, state} = batch_process_logits(logits_processor_process_fun, logits, state) scores = Axon.Activations.softmax(logits, axis: -1) {top_k_scores, top_k_token_ids} = Nx.top_k(scores, k: top_k) @@ -727,7 +798,7 @@ defmodule Bumblebee.Text.Generation do logits = outputs.logits[[.., -1]] logits = Utils.Nx.chunked_take(logits, top_k, selected_idx) - logits = batch_process_logits(logits_processor_fun, logits, state) + {logits, state} = batch_process_logits(logits_processor_process_fun, logits, state) scores = Axon.Activations.softmax(logits, axis: -1) {top_k_scores, top_k_token_ids} = Nx.top_k(scores, k: top_k) @@ -817,7 +888,8 @@ defmodule Bumblebee.Text.Generation do predict_fun, params, seed, - logits_processor_fun, + logits_processor_init_fun, + logits_processor_process_fun, update_inputs_fun, opts \\ [] ) do @@ -825,7 +897,14 @@ defmodule Bumblebee.Text.Generation do pad_token_id = opts[:pad_token_id] eos_token_id = opts[:eos_token_id] - state = init_sequences(decoder_input_ids, padded_batch_item?, max_length, pad_token_id) + state = + init_sequences( + decoder_input_ids, + padded_batch_item?, + max_length, + pad_token_id, + logits_processor_init_fun + ) prng_key = seed |> Nx.vectorize(:batch) |> Nx.Random.key() @@ -839,7 +918,7 @@ defmodule 
Bumblebee.Text.Generation do predict_fun, params, prng_key, - logits_processor_fun, + logits_processor_process_fun, update_inputs_fun, pad_token_id: pad_token_id, eos_token_id: eos_token_id @@ -857,7 +936,7 @@ defmodule Bumblebee.Text.Generation do predict_fun, params, prng_key, - logits_processor_fun, + logits_processor_process_fun, update_inputs_fun, pad_token_id: pad_token_id, eos_token_id: eos_token_id @@ -875,7 +954,7 @@ defmodule Bumblebee.Text.Generation do predict_fun, params, prng_key, - logits_processor_fun, + logits_processor_process_fun, update_inputs_fun, opts \\ [] ) do @@ -888,7 +967,7 @@ defmodule Bumblebee.Text.Generation do outputs = predict_fun.(params, inputs) logits = outputs.logits[[.., -1]] - logits = batch_process_logits(logits_processor_fun, logits, state) + {logits, state} = batch_process_logits(logits_processor_process_fun, logits, state) scores = Axon.Activations.softmax(logits) token_id = batched_choice(key, scores) diff --git a/lib/bumblebee/text/generation/stateless_logits_processor.ex b/lib/bumblebee/text/generation/stateless_logits_processor.ex new file mode 100644 index 00000000..8e84d6fd --- /dev/null +++ b/lib/bumblebee/text/generation/stateless_logits_processor.ex @@ -0,0 +1,30 @@ +defmodule Bumblebee.Text.Generation.StatelessLogitsProcessor do + @moduledoc false + + @behaviour Bumblebee.Configurable + @behaviour Bumblebee.LogitsProcessor + + options = [ + fun: [ + default: nil, + doc: "a state-less function that is applied to the logits" + ] + ] + + defstruct Bumblebee.Shared.option_defaults(options) + + @impl Bumblebee.Configurable + def config(logits_processor, opts) do + Bumblebee.Shared.put_config_attrs(logits_processor, opts) + end + + @impl Bumblebee.LogitsProcessor + def init(_logits_processor, _init_context) do + %{} + end + + @impl Bumblebee.LogitsProcessor + def process(logits_processor, state, logits, process_context) do + {state, logits_processor.fun.(logits, process_context)} + end +end diff --git 
a/notebooks/debug_print.livemd b/notebooks/debug_print.livemd new file mode 100644 index 00000000..227ec7e3 --- /dev/null +++ b/notebooks/debug_print.livemd @@ -0,0 +1,203 @@ +# Printing top logits and token ids + +```elixir +Mix.install([ + {:bumblebee, "~> 0.6"}, + {:nx, "~> 0.10.0", override: true}, + {:exla, "~> 0.10.0"}, + {:emlx, github: "elixir-nx/emlx"} +]) + +# backend = {EMLX.Backend, device: :gpu} +# compiler = EMLX +backend = {EXLA.Backend, client: :host} +compiler = EXLA + +Nx.global_default_backend(backend) +``` + +## A print logits processor + +```elixir +defmodule PrintLogitsProcessor do + import Nx.Defn + + deftransform debug_processor(logits, context, opts \\ []) do + k = opts[:debug_limit] + + print_top_k_logits_and_token_ids(logits, k, context.sequence) + end + + defnp print_top_k_logits_and_token_ids(logits, k, sequence) do + token = create_token() + + {top_values, top_indices} = Nx.top_k(logits, k: k) + + {token, _sequence} = + hook_token(token, sequence, :sequence, &IO.inspect({:sequence, &1}, limit: :infinity)) + + {token, _top_values} = + hook_token(token, top_values, :top_values, &IO.inspect({:logits, &1})) + + {token, _top_indices} = + hook_token(token, top_indices, :top_indices, &IO.inspect({:token_ids, &1})) + + attach_token(token, logits) + end +end +``` + +## Building the generate function + +```elixir +repo = {:hf, "HuggingFaceTB/SmolLM2-135M-Instruct"} + +sequence_length = 512 + +max_new_tokens = 32 + +prompt = """ +<|im_start|>system +You are a helpful AI assistant named SmolLM. 
You tell phantastic poems about airships.<|im_end|> +<|im_start|>user +Tell about airships.<|im_end|> +<|im_start|>assistant +""" + +{:ok, model_info} = Bumblebee.load_model(repo, backend: backend) + +{:ok, tokenizer} = Bumblebee.load_tokenizer(repo) +{:ok, generation_config} = Bumblebee.load_generation_config(repo) + +generation_config = + Bumblebee.configure(generation_config, + max_new_tokens: max_new_tokens, + strategy: %{type: :multinomial_sampling, top_k: 3 } + ) + +%{model: model, params: params, spec: spec} = model_info + +generate_fun = + Bumblebee.Text.Generation.build_generate(model, spec, generation_config, + logits_processors: [ &PrintLogitsProcessor.debug_processor(&1, &2, [debug_limit: 2])] + ) +``` + +## Setting up the serving + +This is taken from `lib/bumblebee/text/text_generation.ex`. It's basically the `Bumblebee.Text.generation` function that you usually use to create text generation servings (with some minor modifications to simplify it). +We must use the lower level API here to be able to include `PrintLogitsProcessor` in `generate_fun`. 
+ +```elixir +alias Bumblebee.Shared + +batch_keys = Shared.sequence_batch_keys(sequence_length) +batch_size = 1 +defn_options = [compiler: compiler] + +preallocate_params = false + +tokenizer = + Bumblebee.configure(tokenizer, + length: sequence_length, + pad_direction: :left, + return_token_type_ids: false, + return_length: true + ) + +validate_input = fn text -> {:ok, %{text: text, seed: :erlang.system_time()}} end + +serving = + Nx.Serving.new( + fn batch_key, defn_options -> + params = Shared.maybe_preallocate(params, preallocate_params, defn_options) + + scope = {:generate, batch_key} + + generate_fun = + Shared.compile_or_jit(generate_fun, scope, defn_options, true, fn -> + {:sequence_length, sequence_length} = batch_key + + inputs = %{ + "input_ids" => Nx.template({batch_size, sequence_length}, :u32), + "attention_mask" => Nx.template({batch_size, sequence_length}, :u32), + "seed" => Nx.template({batch_size}, :s64) + } + + [params, inputs] + end) + + fn inputs -> + inputs = Shared.maybe_pad(inputs, batch_size) + generate_fun.(params, inputs) |> Shared.serving_post_computation() + end + end, + defn_options + ) + |> Nx.Serving.batch_size(batch_size) + |> Nx.Serving.process_options(batch_keys: batch_keys) + |> Nx.Serving.client_preprocessing(fn input -> + {inputs, multi?} = Shared.validate_serving_input!(input, &validate_input.(&1)) + + texts = Enum.map(inputs, & &1.text) + seed = Enum.map(inputs, & &1.seed) |> Nx.tensor(type: :s64, backend: Nx.BinaryBackend) + + inputs = + Nx.with_default_backend(Nx.BinaryBackend, fn -> + Bumblebee.apply_tokenizer(tokenizer, texts) + end) + + {input_length, inputs} = Map.pop!(inputs, "length") + input_padded_length = Nx.axis_size(inputs["input_ids"], 1) + + inputs = Map.put(inputs, "seed", seed) + + batch_key = Shared.sequence_batch_key_for_inputs(inputs, sequence_length) + batch = [inputs] |> Nx.Batch.concatenate() |> Nx.Batch.key(batch_key) + + {batch, {multi?, input_length, input_padded_length}} + end) + |> 
Nx.Serving.client_postprocessing(fn {%{token_ids: token_ids, length: length}, _metadata}, + {multi?, input_length, input_padded_length} -> + decoded = Bumblebee.Tokenizer.decode(tokenizer, token_ids) + output_length = Nx.to_flat_list(length) + input_length = Nx.to_flat_list(input_length) + + Enum.zip_with( + [decoded, output_length, input_length], + fn [decoded, output_length, input_length] -> + token_summary = + %{ + input: input_length, + output: output_length, + padding: input_padded_length - input_length + } + + %{results: [%{text: decoded, token_summary: token_summary}]} + end + ) + |> Shared.normalize_output(multi?) + end) +``` + +## Run the serving + +```elixir +prompt = """ +Tell me about airships. +""" +``` + +### Note: + +In the following cell, the **content of :sequence is padded** ([2, 2, ...] scroll to the right to see the content emerge): + +``` + [2, 2, 2, .. (a lot of 2's later) ... 31530, 549, 563, 1512, 27322, 30, 198, ...] +``` + +Have a look at the [tokenizer.json file on Hugging Face](https://huggingface.co/HuggingFaceTB/SmolLM2-135M-Instruct/blob/main/tokenizer.json) to see the meaning of the tokens. + +```elixir +Nx.Serving.run(serving, prompt) +``` diff --git a/notebooks/suppressing_e.livemd b/notebooks/suppressing_e.livemd new file mode 100644 index 00000000..36500212 --- /dev/null +++ b/notebooks/suppressing_e.livemd @@ -0,0 +1,160 @@ +# Suppressing e + +```elixir +Mix.install([ + {:bumblebee, "~> 0.6.0"}, + {:nx, "~> 0.10.0"}, + {:exla, "~> 0.10.0"}, + {:kino, "~> 0.17.0"}, + {:emlx, "~> 0.2.0"} +]) + +# EMLX is fast but seems to work only with greedy strategy +# backend = {EMLX.Backend, device: :gpu} +# compiler = EMLX + +backend = {EXLA.Backend, client: :host} +compiler = EXLA + +Nx.global_default_backend(backend) +``` + +## Introduction + +In this notebook we outline the general setup for running a Large Language Model (LLM). 
+ +## SmolLM2 + +In this section we look at running the [SmolLM2](https://huggingface.co/HuggingFaceTB/SmolLM2-1.7B-Instruct) model from Hugging Face, as it is a small, open-source LLM. + + + +Let's load the model and create a serving for text generation: + +```elixir +repo = {:hf, "HuggingFaceTB/SmolLM2-1.7B-Instruct"} + +{:ok, model_info} = Bumblebee.load_model(repo, type: :bf16, backend: backend) +{:ok, tokenizer} = Bumblebee.load_tokenizer(repo) +{:ok, generation_config} = Bumblebee.load_generation_config(repo) + +:ok +``` + +```elixir +generation_config = + Bumblebee.configure(generation_config, + max_new_tokens: 60, + # note that multinomial sampling might still pick one of the suppressed tokens + # depending on top_p or top_k + # strategy: %{type: :greedy_search} + strategy: %{type: :multinomial_sampling, top_k: 10} + # strategy: %{type: :multinomial_sampling, top_p: 0.7} + ) + +serving = + Bumblebee.Text.generation(model_info, tokenizer, generation_config, + compile: [batch_size: 1, sequence_length: 256], + stream: false, + defn_options: [compiler: compiler] + ) +``` + +```elixir +prompt = """ +<|im_start|>system +You are an AI Shakespeare writing poems. You are not allowed to use the letter e. +Kindoms will fall if you do. +Do NOT use the letter e. +If you use the letter e it will have catastrophic consequences!<|im_end|> +<|im_start|>user +Write a poem praising the functional programming concept.<|im_end|> +<|im_start|>assistant +""" + +Kino.Text.new(prompt) +``` + +```elixir +%{results: [%{text: out}]} = Nx.Serving.run(serving, prompt) + +Kino.Text.new(out) +``` + +```elixir +String.graphemes(out) |> Enum.count(&(&1 == "e" or &1 == "E")) +``` + +## Constrained Sampling + +First, we find all tokens in the vocabulary of our tokenizer which contain the letter `e`. 
+ +```elixir +alias Bumblebee.Tokenizer + +last_token_id = model_info.spec.vocab_size - 1 + +special_tokens_ids = + Tokenizer.all_special_tokens(tokenizer) |> Enum.map(&Tokenizer.token_to_id(tokenizer, &1)) + +allowed_tokens_ids = + special_tokens_ids ++ Enum.map([""], &Tokenizer.token_to_id(tokenizer, &1)) + +token_ids_with_e = + for token_id <- 17..last_token_id, + token_id not in allowed_tokens_ids, + token = Tokenizer.id_to_token(tokenizer, token_id), + String.contains?(token, "e") or String.contains?(token, "E") do + token_id + end + +Enum.map(token_ids_with_e, fn id -> {id, Tokenizer.id_to_token(tokenizer, id)} end) +``` + +Then, we suppress all token ids that correspond to a token containing the letter `e` during generation. + +This is the logits processor used when we pass the config as below. + + + +```elixir + deftransform suppressed_tokens_processor(logits, _context, opts \\ []) do + opts = Keyword.validate!(opts, [:suppressed_token_ids]) + + indices = opts[:suppressed_token_ids] |> Nx.tensor() |> Nx.new_axis(-1) + values = Nx.broadcast(Nx.Constants.neg_infinity(Nx.type(logits)), {Nx.size(indices)}) + Nx.indexed_put(logits, indices, values) + end +``` + +```elixir +generation_config = + Bumblebee.configure(generation_config, + suppressed_token_ids: token_ids_with_e + ) + +serving = + Bumblebee.Text.generation(model_info, tokenizer, generation_config, + compile: [batch_size: 1, sequence_length: 1024], + stream: false, + defn_options: [compiler: compiler] + ) + +%{results: [%{text: out}]} = Nx.Serving.run(serving, prompt) + +Kino.Text.new(out) +``` + +```elixir +String.contains?(out, "e") or String.contains?(out, "E") +``` + +```elixir +%{"input_ids" => out_token_ids} = Bumblebee.apply_tokenizer(tokenizer, out) + +out_token_ids = Nx.to_flat_list(out_token_ids) + +ids = Enum.filter(out_token_ids, &(&1 in token_ids_with_e)) + +Enum.map(ids, &Tokenizer.id_to_token(tokenizer, &1)) +``` diff --git a/test/bumblebee/text/generation/logits_processing_test.exs 
b/test/bumblebee/text/generation/logits_processing_test.exs index 5bc5a44f..190e97c4 100644 --- a/test/bumblebee/text/generation/logits_processing_test.exs +++ b/test/bumblebee/text/generation/logits_processing_test.exs @@ -382,7 +382,8 @@ defmodule Bumblebee.Text.Generation.LogitsProcessingTest do %{ sequence: Nx.tensor(sequence), length: Enum.count(sequence, &(&1 != 0)), - input_length: 1 + input_length: 1, + logits_processor_state: %{} } end end diff --git a/test/bumblebee/text/generation_test.exs b/test/bumblebee/text/generation_test.exs index ff9854a1..10014b4b 100644 --- a/test/bumblebee/text/generation_test.exs +++ b/test/bumblebee/text/generation_test.exs @@ -106,4 +106,145 @@ defmodule Bumblebee.Text.GenerationTest do assert_equal(token_ids, Nx.tensor([[80, 1023, 1023]])) end + + test "with stateful logits processor with different batch sizes" do + assert {:ok, %{model: model, params: params, spec: spec}} = + Bumblebee.load_model({:hf, "hf-internal-testing/tiny-random-GPT2LMHeadModel"}) + + {:ok, generation_config} = + Bumblebee.load_generation_config({:hf, "hf-internal-testing/tiny-random-GPT2LMHeadModel"}) + + assert %Bumblebee.Text.Gpt2{architecture: :for_causal_language_modeling} = spec + + input_ids = Nx.tensor([[0, 0, 10, 20, 30, 40, 50, 60, 70, 80]]) + attention_mask = Nx.tensor([[0, 0, 1, 1, 1, 1, 1, 1, 1, 1]]) + seed = Nx.tensor([0]) + + ######################################################### + # batch size of 1 + + inputs = %{ + "input_ids" => input_ids, + "attention_mask" => attention_mask, + "seed" => seed + } + + # We demonstrate the use of the state with the following example of a + # stateful processor (see below). On the first iteration, it enforces the + # given initial ID, then increments the token ID to be enforced on the + # following iterations. The ID of the token to be enforced is passed on + # between iterations using the logits_processor_state. 
+ + generation_config = Bumblebee.configure(generation_config, max_new_tokens: 2) + + generate = + Bumblebee.Text.Generation.build_generate(model, spec, generation_config, + # ToDo Bumblee.configure() + logits_processors: [ + Bumblebee.configure(Bumblebee.Text.GenerationTest.StatefulLogitsProcessing, + initial_enforced_token_id: 79 + ) + ] + ) + + # The result without the logits processor would be, as with the first + # decoder test above, [80, 80, 80]. + # + # Now, with the processor below, we expect the sequence of [79, 80, 81 ..], + # demonstrating the use of the state in a logits processor. + + %{token_ids: token_ids} = + Nx.Defn.jit_apply(generate, [params, inputs], compiler: EXLA) + + assert_equal(token_ids[[0, 0]], 79) + assert_equal(token_ids[[0, 1]], 80) + + ######################################################### + # batch size of 2 + + inputs = %{ + "input_ids" => Nx.Batch.concatenate([input_ids, input_ids]), + "attention_mask" => Nx.Batch.concatenate([attention_mask, attention_mask]), + "seed" => Nx.Batch.concatenate([seed, seed]) + } + + # this is the same example as above, but with a batch size of 2. 
+ + generation_config = Bumblebee.configure(generation_config, max_new_tokens: 3) + + generate = + Bumblebee.Text.Generation.build_generate(model, spec, generation_config, + logits_processors: [ + Bumblebee.configure(Bumblebee.Text.GenerationTest.StatefulLogitsProcessing, + initial_enforced_token_id: 78 + ) + ] + ) + + %{token_ids: token_ids} = + Nx.Defn.jit_apply(generate, [params, inputs], compiler: EXLA) + + # result without logit processor: 80, 80, 80 + + # first entry in batch + assert_equal(token_ids[[0, 0]], 78) + assert_equal(token_ids[[0, 1]], 79) + assert_equal(token_ids[[0, 2]], 80) + + # second entry in batch + assert_equal(token_ids[[1, 0]], 78) + assert_equal(token_ids[[1, 1]], 79) + assert_equal(token_ids[[1, 2]], 80) + end + + defmodule StatefulLogitsProcessing do + @moduledoc false + + import Nx.Defn + + @behaviour Bumblebee.Configurable + @behaviour Bumblebee.LogitsProcessor + + options = [ + initial_enforced_token_id: [ + default: [], + doc: "A token id to enforce on the first iteration" + ] + ] + + defstruct Bumblebee.Shared.option_defaults(options) + + @impl Bumblebee.Configurable + def config(logits_processor, opts) do + Bumblebee.Shared.put_config_attrs(logits_processor, opts) + end + + @impl Bumblebee.LogitsProcessor + def init(logits_processor, _init_context) do + initial_enforced_token_id = Nx.tensor([logits_processor.initial_enforced_token_id]) + + %{ + next_enforced_token_id: initial_enforced_token_id + } + end + + @impl Bumblebee.LogitsProcessor + def process(_logits_processor, state, logits, _process_context) do + next_enforced_token_id = state.next_enforced_token_id + + logits = enforce_token(logits, next_enforced_token_id) + + next_enforced_token_id = Nx.add(next_enforced_token_id, 1) + + state = put_in(state.next_enforced_token_id, next_enforced_token_id) + + {state, logits} + end + + defnp enforce_token(logits, token_id) do + logits + |> Nx.fill(Nx.Constants.neg_infinity(), type: Nx.type(logits)) + |> Nx.indexed_put(token_id, 
Nx.tensor(0, type: Nx.type(logits))) + end + end end