[Bug Fix] Support last_turn_loss_only in dynamic_loss_mask path #37
yubofredwang merged 7 commits into main
Conversation
The last_turn_loss_only flag was silently ignored when defer_tokenization was enabled (dynamic_loss_mask path). Thread the flag through compute_assistant_loss_mask → DataCollatorWithPadding → Trainer so it takes effect for deferred-tokenization workloads.
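A minimal sketch of what honoring the flag means, independent of the repo's actual `compute_assistant_loss_mask` signature: given per-token assistant spans, "last turn only" zeroes the loss mask on every assistant span except the final one. The function name and span representation here are illustrative assumptions.

```python
def apply_last_turn_only(loss_mask: list[int], spans: list[tuple[int, int]]) -> list[int]:
    """Keep loss only on the final assistant span.

    loss_mask: 1 where loss is computed, 0 elsewhere.
    spans: (start, end) token ranges of assistant turns, in order.
    """
    if not spans:
        return loss_mask
    out = list(loss_mask)
    for start, end in spans[:-1]:  # zero every span except the last
        for i in range(start, end):
            out[i] = 0
    return out

mask = [0, 1, 1, 0, 1, 1, 1]
spans = [(1, 3), (4, 7)]
apply_last_turn_only(mask, spans)  # → [0, 0, 0, 0, 1, 1, 1]
```

The bug was that this step simply never ran on the deferred-tokenization path, so multi-turn samples trained with loss on every assistant turn regardless of the config.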
Force-pushed from c9d86dd to ce508aa
The assistant header's trailing newline can merge with content tokens under BPE, breaking subsequence matching. Strip the newline from the header pattern and skip those token(s) when scanning for assistant content boundaries. Also fixes FSDP sync for batches with no valid loss-mask tokens in Eagle3 by using a dummy index through the compiled graph instead of an early-return path.
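A hedged sketch of the subsequence-matching idea: scan the token ids for the assistant-header pattern (with its trailing newline stripped), then skip `skip_after_header` tokens whose identity is unstable under BPE because the newline can merge into the first content token. Names and the tiny token vocabulary below are illustrative, not the repo's actual API.

```python
def find_content_starts(token_ids: list[int], header_ids: list[int],
                        skip_after_header: int = 1) -> list[int]:
    """Return indices where assistant content begins after each header match."""
    starts = []
    n, m = len(token_ids), len(header_ids)
    i = 0
    while i <= n - m:
        if token_ids[i:i + m] == header_ids:
            # content begins after the header plus the BPE-unstable token(s)
            starts.append(i + m + skip_after_header)
            i += m
        else:
            i += 1
    return starts

# toy ids: [9, 9] is the (newline-stripped) header pattern
ids = [5, 9, 9, 42, 7, 7, 9, 9, 42, 8]
find_content_starts(ids, [9, 9], skip_after_header=1)  # → [4, 9]
```

Matching the header without its newline keeps the pattern stable across tokenizers; the skip count then absorbs whatever the newline merged into.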
Force-pushed from ce508aa to 0a83740
Pull request overview
This PR fixes last_turn_loss_only being ignored when defer_tokenization=True (dynamic loss mask path) by threading the flag and related metadata through the inference → controller → mooncake dataset → collator pipeline, and extends loss-mask computation to handle template header newline/BPE edge cases.
Changes:
- Thread last_turn_loss_only through Mooncake fetching so dynamic loss masks can respect it (including per-sample overrides via metadata).
- Add skip_after_header support to avoid BPE merge issues around assistant-header trailing newlines.
- Introduce "auto" mode (based on thinking-content detection) and add/extend tests for mask correctness.
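The "auto" mode described above can be sketched as a per-sample decision: if any assistant message carries non-empty thinking/reasoning content, restrict loss to the last turn. The field names follow the PR description (`reasoning_content`, `reasoning`, `thinking`); the helper shapes are assumptions, not the repo's actual signatures.

```python
THINKING_FIELDS = ("reasoning_content", "reasoning", "thinking")

def has_thinking_content(messages: list[dict]) -> bool:
    """True if any assistant message has a non-empty thinking/reasoning field."""
    return any(
        msg.get("role") == "assistant"
        and any(isinstance(msg.get(f), str) and msg[f].strip() for f in THINKING_FIELDS)
        for msg in messages
    )

def resolve_last_turn_loss_only(config_value, messages: list[dict]) -> bool:
    # "auto" → per-sample decision; otherwise the config bool wins
    if config_value == "auto":
        return has_thinking_content(messages)
    return bool(config_value)

msgs = [{"role": "user", "content": "hi"},
        {"role": "assistant", "content": "ok", "reasoning_content": "step 1..."}]
resolve_last_turn_loss_only("auto", msgs)  # → True
```

The rationale: when earlier turns contain thinking traces that the template drops on re-render, only the final turn's tokens are trustworthy training targets.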
Reviewed changes
Copilot reviewed 15 out of 15 changed files in this pull request and generated 5 comments.
Show a summary per file
| File | Description |
|---|---|
| torchspec/utils/types.py | Adds metadata to InferenceOutput so dataset-derived metadata can propagate through inference outputs. |
| torchspec/utils/processing.py | Enhances assistant header token derivation and returns skip_after_header for dynamic mask stability. |
| torchspec/training/trainer.py | Threads dynamic-mask parameters (incl. last_turn_loss_only, skip_after_header) into the mooncake fetcher. |
| torchspec/training/data_fetcher.py | Materializes/derives loss masks in the dataset (packed or dynamic) and supports per-sample overrides. |
| torchspec/models/ops/loss_mask.py | Adds last_turn_only and skip_after_header support to the dynamic loss mask implementation. |
| torchspec/data/utils.py | Collator now expects precomputed loss_mask; adds image URL cache resolution helpers. |
| torchspec/data/parse.py | Adds has_thinking_content() utility and exports it for auto mode. |
| torchspec/data/dataset.py | Implements "auto" last-turn masking logic and emits has_thinking metadata in deferred mode. |
| torchspec/controller/training_controller.py | Converts metadata.has_thinking into a per-sample last_turn_loss_only flag for training. |
| torchspec/controller/inference_manager.py | Preserves InferenceInput.metadata into InferenceOutput for downstream use. |
| torchspec/config/train_config.py | Changes last_turn_loss_only to allow "auto". |
| tests/test_loss_mask_cross_validation.py | Cross-validates dynamic loss mask vs parser parse across templates (with skips on missing tokenizers). |
| tests/test_loss_mask.py | Adds unit tests for last_turn_only behavior in compute_assistant_loss_mask. |
| tests/test_kimi_k25_integration.py | Updates integration test to pass loss_mask (since collator no longer unpacks packed_loss_mask). |
| tests/test_dynamic_last_turn_loss.py | Adds tests for thinking detection + collator precomputed-mask contract + dataset mask computation. |
…filter, normalize reasoning fields
- Move loss-mask resolution from MooncakeDataset into reusable resolve_loss_mask() in data/utils.py
- Add pre-tokenization filtering to drop samples exceeding estimated token budget (image-aware)
- Preserve reasoning_content/reasoning/thinking fields during conversation normalization
- Move process-level imports to module top-level in dataset.py
- Expand has_thinking_content tests for reasoning_content and reasoning fields
Force-pushed from 87b2122 to 09aa71b
Pull request overview
Fixes deferred-tokenization (“dynamic_loss_mask”) training so last_turn_loss_only is honored by threading per-sample state from dataset → inference/controller → mooncake fetcher, and extends the dynamic loss-mask implementation to better match parser behavior across templates.
Changes:
- Propagate per-sample last_turn_loss_only (including "auto" via metadata["has_thinking"]) through controller dispatch into the Mooncake data pipeline.
- Refactor loss-mask handling so the collator consumes a pre-materialized loss_mask, with centralized computation in resolve_loss_mask() and updated numba masking support (last_turn_only, skip_after_header).
- Add/extend tests validating dynamic masking vs parser ground truth, newline/BPE edge cases, and per-sample "auto" behavior.
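The collator contract after this refactor can be sketched as follows: the dataset materializes `loss_mask` up front, and the collator only validates and pads, deriving nothing itself. This is a minimal illustration of that division of labor; the real `DataCollatorWithPadding` handles more fields.

```python
import numpy as np

def collate(samples: list[dict], pad_id: int = 0):
    """Pad input_ids and the precomputed loss_mask to the batch max length."""
    for s in samples:
        # contract: masks are resolved upstream (resolve_loss_mask), never here
        assert "loss_mask" in s, "collator requires a precomputed loss_mask"
    max_len = max(len(s["input_ids"]) for s in samples)
    ids = np.full((len(samples), max_len), pad_id, dtype=np.int64)
    mask = np.zeros((len(samples), max_len), dtype=np.int64)  # padding gets no loss
    for row, s in enumerate(samples):
        n = len(s["input_ids"])
        ids[row, :n] = s["input_ids"]
        mask[row, :n] = s["loss_mask"]
    return ids, mask

batch = [{"input_ids": [1, 2, 3], "loss_mask": [0, 1, 1]},
         {"input_ids": [4, 5], "loss_mask": [0, 1]}]
ids, mask = collate(batch)
# ids  → [[1, 2, 3], [4, 5, 0]]
# mask → [[0, 1, 1], [0, 1, 0]]
```

Centralizing mask computation in one resolver avoids the original bug class: two code paths (packed vs. dynamic) silently disagreeing about which flags apply.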
Reviewed changes
Copilot reviewed 16 out of 16 changed files in this pull request and generated 7 comments.
Show a summary per file
| File | Description |
|---|---|
| torchspec/utils/types.py | Adds metadata to inference outputs so per-sample flags can be forwarded. |
| torchspec/utils/processing.py | Enhances assistant header token derivation with newline-stripping + skip_after_header. |
| torchspec/training/trainer.py | Threads dynamic masking parameters (including last-turn config) into Mooncake fetching. |
| torchspec/training/data_fetcher.py | Adds per-sample last-turn flag transport and resolves loss masks before collation; skips all-zero masks. |
| torchspec/models/ops/loss_mask.py | Extends dynamic mask computation with last_turn_only and newline/BPE skip support. |
| torchspec/data/utils.py | Moves loss-mask materialization to resolve_loss_mask(); collator now requires loss_mask. |
| torchspec/data/preprocessing.py | Normalizes thinking/reasoning fields into reasoning_content. |
| torchspec/data/parse.py | Adds has_thinking_content() helper for "auto" last-turn behavior. |
| torchspec/data/dataset.py | Implements "auto" resolution and threads has_thinking metadata; adds extra multimodal filtering + cache key params. |
| torchspec/controller/training_controller.py | Maps metadata["has_thinking"] into TrainSample.last_turn_loss_only. |
| torchspec/controller/inference_manager.py | Forwards input metadata into inference outputs. |
| torchspec/config/train_config.py | Changes last_turn_loss_only config to support "auto". |
| tests/test_loss_mask_cross_validation.py | Cross-validates dynamic loss masks vs parser masks across templates. |
| tests/test_loss_mask.py | Adds unit tests for last_turn_only behavior. |
| tests/test_kimi_k25_integration.py | Updates integration test to supply loss_mask directly to collator. |
| tests/test_dynamic_last_turn_loss.py | Adds tests for has_thinking_content, collator expectations, and resolve_loss_mask() behavior. |
Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
Pull request overview
This PR extends TorchSpec’s training data pipeline to support per-sample loss masking decisions (including an “auto” mode driven by detecting thinking/reasoning content) and improves dynamic assistant-span masking robustness around chat-template header newlines.
Changes:
- Add "auto" mode for last_turn_loss_only, with per-sample detection via has_thinking_content() and propagation through metadata into training samples.
- Refactor loss-mask materialization so the collator consumes a precomputed loss_mask, and add a dynamic mask skip mechanism for header-newline BPE merge edge cases.
- Add/expand tests for last-turn-only masking, cross-validation against parser masks across templates, and thinking-content detection.
Reviewed changes
Copilot reviewed 16 out of 16 changed files in this pull request and generated 3 comments.
Show a summary per file
| File | Description |
|---|---|
| torchspec/utils/types.py | Extend InferenceOutput to carry metadata through the pipeline. |
| torchspec/utils/processing.py | Return (header_ids, end_ids, skip_after_header) to mitigate header newline BPE merge issues. |
| torchspec/training/trainer.py | Thread dynamic mask parameters (incl. last-turn-only + skip count) into MooncakeDataFetcher. |
| torchspec/training/data_fetcher.py | Compute/attach loss_mask inside MooncakeDataset and optionally skip all-zero-mask samples. |
| torchspec/models/ops/loss_mask.py | Add last_turn_only and skip_after_header support to dynamic assistant-span masking. |
| torchspec/data/utils.py | Make collator require a pre-materialized loss_mask; add resolve_loss_mask() as a single source of truth. |
| torchspec/data/preprocessing.py | Normalize additional reasoning/thinking fields into reasoning_content. |
| torchspec/data/parse.py | Add has_thinking_content() helper for detecting non-empty thinking/reasoning. |
| torchspec/data/dataset.py | Implement “auto” per-sample last-turn-only decision and propagate has_thinking metadata; add multimodal token-budget filtering. |
| torchspec/controller/training_controller.py | Forward per-sample last-turn-only decision into TrainSample. |
| torchspec/controller/inference_manager.py | Preserve input metadata when building InferenceOutput. |
| torchspec/config/train_config.py | Change last_turn_loss_only default to "auto" (bool or "auto"). |
| tests/test_loss_mask.py | Add unit tests for last_turn_only behavior. |
| tests/test_loss_mask_cross_validation.py | Add template-wide cross-validation of dynamic mask vs parser mask. |
| tests/test_kimi_k25_integration.py | Update integration test to pass loss_mask (collator no longer unpacks). |
| tests/test_dynamic_last_turn_loss.py | Add tests for has_thinking_content() and resolve_loss_mask() behavior. |
The pre-tokenization filtering that estimated token budgets based on image count and text length was a rough heuristic. Remove it along with the related cache params.
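For context, the removed heuristic had roughly this shape: estimate a sample's token cost from text length plus a flat per-image budget, and drop samples over the limit before tokenizing. The constants and names below are illustrative guesses at what "rough" meant, not the deleted code.

```python
CHARS_PER_TOKEN = 4       # crude text-length → token-count estimate
TOKENS_PER_IMAGE = 1024   # assumed flat per-image token cost

def estimated_tokens(text: str, num_images: int) -> int:
    """Cheap pre-tokenization estimate; ignores tokenizer specifics entirely."""
    return len(text) // CHARS_PER_TOKEN + num_images * TOKENS_PER_IMAGE

def within_budget(text: str, num_images: int, max_tokens: int) -> bool:
    return estimated_tokens(text, num_images) <= max_tokens

within_budget("x" * 4000, num_images=1, max_tokens=2048)  # → True (1000 + 1024)
```

Both constants vary wildly across tokenizers and vision encoders, which is presumably why the estimate was judged too rough to keep.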