
Commit e24bbc9

Copilot and lmangani committed
fix: proper SentencePiece unigram tokenizer (Viterbi + greedy longest-match)
Co-authored-by: lmangani <1423657+lmangani@users.noreply.github.com>
1 parent 0f38fb7 commit e24bbc9

4 files changed: +200 −79 lines changed

DEV.md

Lines changed: 27 additions & 12 deletions
@@ -613,18 +613,33 @@ from the VAE GGUF (`vae.temporal_scale`, `vae.spatial_scale`).

 ### Tokenizer

-The T5 tokenizer uses a SentencePiece vocabulary embedded in the GGUF as a
-string array (`tokenizer.ggml.tokens`). Tokenisation proceeds:
+The T5 tokenizer implements the **SentencePiece unigram** algorithm in pure
+C++ with no external library dependency. The vocabulary and optional
+log-probability scores are loaded from the GGUF metadata at model-load time:

-1. Split input on whitespace.
-2. For each word, prepend the `▁` (U+2581) sentinel.
-3. Look up the full token in the vocabulary map.
-4. On miss, fall back to individual character lookup; unknown chars map to
-   `unk_id=2`.
-5. Append EOS (`eos_id=1`), pad to `max_len=512`.
+| GGUF key | Type | Description |
+|----------|------|-------------|
+| `tokenizer.ggml.tokens` | string[] | id → piece (UTF-8, ▁-prefixed) |
+| `tokenizer.ggml.scores` | float32[] | id → unigram log-probability (optional) |

-This is intentionally simplified. For best prompt fidelity, replace the
-`T5Tokenizer::encode()` body with a proper BPE or unigram segmenter.
+**Preprocessing** (`T5Tokenizer::preprocess`):
+1. Collapse runs of whitespace to a single space; strip leading/trailing.
+2. Prepend `▁` (U+2581) to the beginning; replace each remaining space with `▁`.
+
+**Segmentation** — two modes depending on whether scores are in the GGUF:
+
+| Mode | Condition | Algorithm |
+|------|-----------|-----------|
+| Viterbi | `tokenizer.ggml.scores` present | DP over byte positions; maximises sum of log-probs; O(n × max_piece_len) |
+| Greedy | scores absent | Longest-match scan from left; O(n × max_piece_len) |
+
+In both modes an **unk fallback** advances one full UTF-8 character (not one
+byte) when no vocabulary piece covers the current position, preventing split
+multi-byte sequences from producing garbage tokens.
+
+Scores are written by `convert.py --tokenizer` (via
+`tok.sp_model.GetScore(i)`) and preserved through quantization by
+`ltx-quantize` (via `gguf_set_kv`).

 ---

@@ -661,8 +676,8 @@ and where contributions are most welcome.

 | 2 | **VAE encoder** | Only the first `conv_in` layer is used; pseudo-encoding fallback | Implement full encoder stack for accurate I2V latent inversion |
 | 3 | **AdaLN-single** | Timestep embedding is computed but per-block scale/shift is not fully applied | Apply `ada_params` chunks as scale/shift in each block's norms |
 | 4 | **3-D RoPE** | Positional embeddings are not yet applied | Add rotary embeddings along (t, h, w) axes to Q and K tensors |
-| 5 | **T5 tokenizer** | Whitespace-split + per-char fallback | Replace with a proper SentencePiece unigram/BPE tokenizer |
-| 6 | **`ltx-quantize` metadata** | String arrays (tokenizer vocab) are skipped during quantization | Copy `GGUF_TYPE_ARRAY` entries in the KV copy loop |
+| 5 | **T5 tokenizer** | ~~Whitespace-split + per-char fallback~~ **Fixed**: full SentencePiece unigram Viterbi DP (when scores in GGUF) or greedy longest-match | |
+| 6 | **`ltx-quantize` metadata** | ~~String arrays (tokenizer vocab) are skipped during quantization~~ **Fixed**: `gguf_set_kv` copies all KV pairs including arrays | |
 | 7 | **Persistent scratch** | DiT allocates 1 GB of ggml scratch per forward call | Pre-allocate a single scratch context and reset between calls |
 | 8 | **Batch size > 1** | Only batch=1 is implemented | Add batch dimension to enable parallel generation |
 | 9 | **CFG single-pass** | CFG requires two full forward passes | Implement single-pass CFG by duplicating the batch |
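The preprocessing and segmentation pipeline described above can be sketched in a few lines of Python (the toy vocabulary and ids here are invented for illustration; the real implementation is the C++ in `src/t5_encoder.hpp`):

```python
# Toy sketch of the pipeline described above: whitespace normalisation,
# U+2581 sentinel insertion, then greedy longest-match segmentation.
# The vocabulary and ids below are invented for illustration.
SP = "\u2581"  # the ▁ word-boundary sentinel

def preprocess(text: str) -> str:
    # Collapse whitespace runs and strip, then mark word starts with ▁.
    return SP + SP.join(text.split())

def greedy(text: str, vocab: dict, unk_id: int = 2) -> list:
    max_piece = max(len(p) for p in vocab)
    ids, pos = [], 0
    while pos < len(text):
        for length in range(min(max_piece, len(text) - pos), 0, -1):
            tid = vocab.get(text[pos:pos + length])
            if tid is not None:
                ids.append(tid)
                pos += length
                break
        else:  # no piece matched: emit unk and advance one character
            ids.append(unk_id)
            pos += 1
    return ids

vocab = {SP + "hello": 10, SP + "wor": 11, "ld": 12, SP: 13}
print(greedy(preprocess("hello   world"), vocab))  # → [10, 11, 12]
```

Note that Python strings iterate by code point, so `pos += 1` already advances one full character; the C++ version works on raw bytes and therefore needs `utf8_char_len` to achieve the same effect.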

convert.py

Lines changed: 7 additions & 1 deletion
@@ -209,7 +209,13 @@ def convert_t5(tensors: Dict[str, np.ndarray], output: str, tokenizer_path: Opti
         tok = HFT5Tok.from_pretrained(tokenizer_path)
         vocab = [tok.convert_ids_to_tokens(i) for i in range(tok.vocab_size)]
         w.writer.add_array("tokenizer.ggml.tokens", vocab)
-        print(f"  embedded tokenizer ({len(vocab)} tokens)")
+        # Write SentencePiece unigram log-probability scores.
+        # Presence of this key enables Viterbi-optimal segmentation in the
+        # C++ tokenizer (t5_encoder.hpp); without it, greedy longest-match
+        # is used, which is already a strong fallback.
+        scores = [tok.sp_model.GetScore(i) for i in range(len(vocab))]
+        w.writer.add_token_scores(scores)
+        print(f"  embedded tokenizer ({len(vocab)} tokens + unigram scores)")
     except Exception as e:
         print(f"  warning: could not embed tokenizer: {e}")
215221

src/ltx-quantize.cpp

Lines changed: 2 additions & 25 deletions
@@ -59,31 +59,8 @@ int main(int argc, char ** argv) {
     // Build output GGUF.
     struct gguf_context * out_ctx = gguf_init_empty();

-    // Copy all KV metadata.
-    int64_t n_kv = gguf_get_n_kv(src.gguf_ctx);
-    for (int64_t i = 0; i < n_kv; ++i) {
-        const char * key = gguf_get_key(src.gguf_ctx, i);
-        enum gguf_type vt = gguf_get_kv_type(src.gguf_ctx, i);
-        switch (vt) {
-            case GGUF_TYPE_STRING:
-                gguf_set_val_str(out_ctx, key, gguf_get_val_str(src.gguf_ctx, i));
-                break;
-            case GGUF_TYPE_UINT32:
-                gguf_set_val_u32(out_ctx, key, gguf_get_val_u32(src.gguf_ctx, i));
-                break;
-            case GGUF_TYPE_INT32:
-                gguf_set_val_i32(out_ctx, key, gguf_get_val_i32(src.gguf_ctx, i));
-                break;
-            case GGUF_TYPE_FLOAT32:
-                gguf_set_val_f32(out_ctx, key, gguf_get_val_f32(src.gguf_ctx, i));
-                break;
-            case GGUF_TYPE_BOOL:
-                gguf_set_val_bool(out_ctx, key, gguf_get_val_bool(src.gguf_ctx, i));
-                break;
-            default:
-                break; // arrays etc. – skip for now
-        }
-    }
+    // Copy all KV metadata (scalars and arrays alike).
+    gguf_set_kv(out_ctx, src.gguf_ctx);

     // Add quantized tensors.
     for (int ti = 0; ti < n_tensors; ++ti) {

src/t5_encoder.hpp

Lines changed: 164 additions & 41 deletions
@@ -15,6 +15,9 @@

 #include "ltx_common.hpp"

+#include <algorithm>
+#include <unordered_map>
+
 struct T5Config {
     int d_model = 4096; // hidden size
     int d_kv = 64;      // key/value dim per head
@@ -26,66 +29,186 @@ struct T5Config {
     float eps = 1e-6f;  // layer norm eps
 };

-// ── Simple BPE tokenizer (SentencePiece-compatible, loaded from GGUF) ────────
+// ── SentencePiece unigram tokenizer ──────────────────────────────────────────
+//
+// Implements the SentencePiece unigram algorithm used by T5:
+//   - Text preprocessing: whitespace normalisation + ▁ (U+2581) insertion.
+//   - Segmentation: Viterbi DP when unigram log-probability scores are present
+//     in the GGUF (key "tokenizer.ggml.scores"); greedy longest-match otherwise.
+//   - Fallback: characters with no vocabulary piece are emitted as unk_id,
+//     advancing one full UTF-8 character to avoid splitting multi-byte sequences.
+//
+// Vocabulary and optional scores are read from GGUF metadata:
+//   "tokenizer.ggml.tokens" – string array:  id → piece (UTF-8, ▁-prefixed)
+//   "tokenizer.ggml.scores" – float32 array: id → unigram log-probability

 struct T5Tokenizer {
-    std::vector<std::string> vocab; // id → token string
-    std::map<std::string, int> tok2id;
-    int unk_id = 2, pad_id = 0, eos_id = 1;
-
-    // Load vocabulary from GGUF KV array "tokenizer.ggml.tokens".
+    std::vector<std::string> vocab;              // id → piece
+    std::vector<float> scores;                   // id → log-prob (empty → greedy mode)
+    std::unordered_map<std::string, int> tok2id; // piece → id (O(1) lookup)
+    int unk_id = 2;
+    int pad_id = 0;
+    int eos_id = 1;
+    int max_piece_len = 0; // max byte-length of any vocabulary piece
+
+    // Load vocabulary and (optional) scores from GGUF metadata.
     bool load_from_gguf(struct gguf_context * gc) {
-        int64_t kid = gguf_find_key(gc, "tokenizer.ggml.tokens");
-        if (kid < 0) {
+        int64_t tokens_kid = gguf_find_key(gc, "tokenizer.ggml.tokens");
+        if (tokens_kid < 0) {
             LTX_ERR("T5 tokenizer: 'tokenizer.ggml.tokens' not found in GGUF");
             return false;
         }
-        int64_t n = gguf_get_arr_n(gc, kid);
+        size_t n = gguf_get_arr_n(gc, tokens_kid);
         vocab.resize(n);
-        for (int64_t i = 0; i < n; ++i) {
-            vocab[i] = gguf_get_arr_str(gc, kid, i);
+        for (size_t i = 0; i < n; ++i) {
+            vocab[i] = gguf_get_arr_str(gc, tokens_kid, i);
             tok2id[vocab[i]] = static_cast<int>(i);
+            int len = static_cast<int>(vocab[i].size());
+            if (len > max_piece_len) max_piece_len = len;
         }
-        LTX_LOG("T5 tokenizer: loaded %lld tokens", (long long)n);
+
+        // Optional: unigram log-probability scores → enables Viterbi mode.
+        int64_t scores_kid = gguf_find_key(gc, "tokenizer.ggml.scores");
+        if (scores_kid >= 0 &&
+            gguf_get_arr_type(gc, scores_kid) == GGUF_TYPE_FLOAT32) {
+            size_t ns = gguf_get_arr_n(gc, scores_kid);
+            const float * raw = reinterpret_cast<const float *>(
+                gguf_get_arr_data(gc, scores_kid));
+            if (raw) scores.assign(raw, raw + ns);
+        }
+
+        LTX_LOG("T5 tokenizer: loaded %zu tokens, max_piece=%d bytes, mode=%s",
+                n, max_piece_len, scores.empty() ? "greedy" : "Viterbi");
         return true;
     }

-    // Naïve whitespace + subword tokenisation (SentencePiece ▁ prefix).
-    // For production use, replace with a proper SentencePiece unigram tokenizer.
-    std::vector<int> encode(const std::string & text, int max_len) const {
-        std::vector<int> ids;
-
-        // Split on whitespace and look up each word (incl. subword fall-back).
-        std::string cur;
-        auto flush = [&]() {
-            if (cur.empty()) return;
-            // prepend ▁ (U+2581 = \xe2\x96\x81)
-            std::string tok = "\xe2\x96\x81" + cur;
-            auto it = tok2id.find(tok);
-            if (it != tok2id.end()) {
-                ids.push_back(it->second);
-            } else {
-                // fall back character by character
-                for (char c : cur) {
-                    std::string ct(1, c);
-                    auto it2 = tok2id.find(ct);
-                    ids.push_back(it2 != tok2id.end() ? it2->second : unk_id);
-                }
-            }
-            cur.clear();
-        };
-
-        for (char c : text) {
-            if (c == ' ') { flush(); }
-            else { cur += c; }
-        }
-        flush();
-
-        // Append EOS, pad / truncate to max_len.
-        ids.push_back(eos_id);
-        while ((int)ids.size() < max_len) ids.push_back(pad_id);
-        if ((int)ids.size() > max_len) ids.resize(max_len);
+    // SentencePiece text normalisation:
+    //   1. Collapse runs of whitespace to a single space; strip leading/trailing.
+    //   2. Prepend ▁ and replace each remaining space with ▁.
+    static std::string preprocess(const std::string & text) {
+        // Step 1: collapse and strip.
+        std::string stripped;
+        stripped.reserve(text.size());
+        bool prev_ws = true; // treat start as whitespace to drop leading ws
+        for (unsigned char c : text) {
+            bool is_ws = (c == ' ' || c == '\t' || c == '\n' || c == '\r');
+            if (is_ws) {
+                if (!prev_ws) stripped += ' ';
+            } else {
+                stripped += static_cast<char>(c);
+            }
+            prev_ws = is_ws;
+        }
+        while (!stripped.empty() && stripped.back() == ' ') stripped.pop_back();
+
+        // Step 2: insert ▁ (U+2581 = \xe2\x96\x81, 3 bytes).
+        static const char SPIECE[4] = "\xe2\x96\x81";
+        std::string out;
+        out.reserve(stripped.size() * 2);
+        out.append(SPIECE, 3); // always prepend ▁
+        for (char c : stripped) {
+            if (c == ' ') out.append(SPIECE, 3);
+            else          out += c;
+        }
+        return out;
+    }
+
+    // Return the byte-length of the UTF-8 character whose first byte is `b`.
+    static int utf8_char_len(unsigned char b) {
+        if (b < 0x80) return 1;           // 0xxxxxxx – ASCII
+        if ((b & 0xE0) == 0xC0) return 2; // 110xxxxx – 2-byte
+        if ((b & 0xF0) == 0xE0) return 3; // 1110xxxx – 3-byte (e.g. ▁)
+        if ((b & 0xF8) == 0xF0) return 4; // 11110xxx – 4-byte
+        return 1; // invalid continuation byte: skip
+    }
+
+    // Viterbi optimal segmentation maximising the sum of unigram log-probs.
+    std::vector<int> viterbi(const std::string & text) const {
+        int n = static_cast<int>(text.size());
+        if (n == 0) return {};
+
+        constexpr float NEG_INF = -1e38f;
+        // best[i]: best total score for text[0..i)
+        std::vector<float> best(n + 1, NEG_INF);
+        // from[i]: {prev_position, token_id} that achieves best[i]
+        std::vector<std::pair<int,int>> from(n + 1, {-1, -1});
+        best[0] = 0.0f;
+
+        for (int i = 0; i < n; ++i) {
+            if (best[i] <= NEG_INF / 2.0f) continue;
+            int max_len = std::min(max_piece_len, n - i);
+            bool any_match = false;
+            for (int len = 1; len <= max_len; ++len) {
+                auto it = tok2id.find(text.substr(i, len));
+                if (it == tok2id.end()) continue;
+                int tok = it->second;
+                float sc = (tok < static_cast<int>(scores.size()))
+                               ? scores[tok] : 0.0f;
+                float new_best = best[i] + sc;
+                if (new_best > best[i + len]) {
+                    best[i + len] = new_best;
+                    from[i + len] = {i, tok};
+                }
+                any_match = true;
+            }
+            // No vocabulary piece covers position i: skip one UTF-8 char as unk.
+            if (!any_match) {
+                int skip = std::min(utf8_char_len(
+                    static_cast<unsigned char>(text[i])), n - i);
+                constexpr float UNK_PENALTY = -10.0f;
+                if (best[i] + UNK_PENALTY > best[i + skip]) {
+                    best[i + skip] = best[i] + UNK_PENALTY;
+                    from[i + skip] = {i, unk_id};
+                }
+            }
+        }
+
+        // Backtrack from position n.
+        std::vector<int> ids;
+        for (int pos = n; pos > 0;) {
+            auto [prev, tok] = from[pos];
+            if (prev < 0) { ids.push_back(unk_id); break; }
+            ids.push_back(tok);
+            pos = prev;
+        }
+        std::reverse(ids.begin(), ids.end());
+        return ids;
+    }
+
+    // Greedy longest-match segmentation (fallback when scores are absent).
+    std::vector<int> greedy(const std::string & text) const {
+        std::vector<int> ids;
+        int n = static_cast<int>(text.size());
+        int pos = 0;
+        while (pos < n) {
+            int max_len = std::min(max_piece_len, n - pos);
+            bool found = false;
+            for (int len = max_len; len >= 1; --len) {
+                auto it = tok2id.find(text.substr(pos, len));
+                if (it != tok2id.end()) {
+                    ids.push_back(it->second);
+                    pos += len;
+                    found = true;
+                    break;
+                }
+            }
+            if (!found) {
+                ids.push_back(unk_id);
+                pos += std::min(utf8_char_len(
+                    static_cast<unsigned char>(text[pos])), n - pos);
+            }
+        }
+        return ids;
+    }
+
+    // Tokenise text; pad or truncate to max_len (EOS is appended before padding).
+    std::vector<int> encode(const std::string & text, int max_len) const {
+        std::string processed = preprocess(text);
+        std::vector<int> ids = scores.empty() ? greedy(processed)
+                                              : viterbi(processed);
+        ids.push_back(eos_id);
+        while (static_cast<int>(ids.size()) < max_len) ids.push_back(pad_id);
+        if (static_cast<int>(ids.size()) > max_len) ids.resize(max_len);
         return ids;
     }
 };
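The lead-byte classification in `utf8_char_len` can be checked with a small Python mirror (illustrative only; the C++ above is the actual implementation):

```python
# Python mirror of utf8_char_len(): classify a UTF-8 lead byte by its high
# bits to know how many bytes the unk fallback must skip at once.
def utf8_char_len(b: int) -> int:
    if b < 0x80:           return 1  # 0xxxxxxx – ASCII
    if (b & 0xE0) == 0xC0: return 2  # 110xxxxx – 2-byte
    if (b & 0xF0) == 0xE0: return 3  # 1110xxxx – 3-byte (e.g. ▁, U+2581)
    if (b & 0xF8) == 0xF0: return 4  # 11110xxx – 4-byte
    return 1                         # stray continuation byte: skip one

data = "\u2581a\u00e9\u20ac\U0001F3AC".encode("utf-8")  # ▁ a é € 🎬
pos, lens = 0, []
while pos < len(data):
    n = utf8_char_len(data[pos])
    lens.append(n)
    pos += n
print(lens)  # → [3, 1, 2, 3, 4]
```

Skipping by these lengths is what keeps the unk fallback from landing in the middle of a multi-byte sequence and emitting garbage tokens.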
