diff --git a/tests/test_dynamic_last_turn_loss.py b/tests/test_dynamic_last_turn_loss.py
new file mode 100644
index 0000000..4767f42
--- /dev/null
+++ b/tests/test_dynamic_last_turn_loss.py
@@ -0,0 +1,259 @@
+"""Tests for dynamic per-sample last_turn_loss_only based on thinking content detection."""
+
+import torch
+
+from torchspec.data.parse import has_thinking_content
+from torchspec.data.utils import DataCollatorWithPadding, resolve_loss_mask
+
+# ── has_thinking_content detection ────────────────────────────────────
+
+
+class TestHasThinkingContent:
+ def test_think_tag_with_content(self):
+ conv = [
+ {"role": "user", "content": "Hi"},
+            {"role": "assistant", "content": "<think>reasoning here</think>Hello!"},
+ ]
+ assert has_thinking_content(conv) is True
+
+ def test_empty_think_tag(self):
+ conv = [
+ {"role": "user", "content": "Hi"},
+            {"role": "assistant", "content": "<think></think>Hello!"},
+ ]
+ assert has_thinking_content(conv) is False
+
+ def test_think_tag_whitespace_only(self):
+ conv = [
+ {"role": "user", "content": "Hi"},
+            {"role": "assistant", "content": "<think> \n\t </think>Hello!"},
+ ]
+ assert has_thinking_content(conv) is False
+
+ def test_think_tag_with_single_char(self):
+ conv = [
+ {"role": "user", "content": "Hi"},
+            {"role": "assistant", "content": "<think> x</think>answer"},
+ ]
+ assert has_thinking_content(conv) is True
+
+ def test_no_think_tag(self):
+ conv = [
+ {"role": "user", "content": "Hi"},
+ {"role": "assistant", "content": "Hello there!"},
+ ]
+ assert has_thinking_content(conv) is False
+
+ def test_thinking_field(self):
+ conv = [
+ {"role": "user", "content": "Hi"},
+ {"role": "assistant", "content": "Hello!", "thinking": "some reasoning"},
+ ]
+ assert has_thinking_content(conv) is True
+
+ def test_thinking_content_field(self):
+ conv = [
+ {"role": "user", "content": "Hi"},
+ {"role": "assistant", "content": "Hello!", "thinking_content": "reasoning"},
+ ]
+ assert has_thinking_content(conv) is True
+
+ def test_empty_thinking_field(self):
+ conv = [
+ {"role": "user", "content": "Hi"},
+ {"role": "assistant", "content": "Hello!", "thinking": ""},
+ ]
+ assert has_thinking_content(conv) is False
+
+ def test_none_thinking_field(self):
+ conv = [
+ {"role": "user", "content": "Hi"},
+ {"role": "assistant", "content": "Hello!", "thinking": None},
+ ]
+ assert has_thinking_content(conv) is False
+
+ def test_user_message_with_think_tag_ignored(self):
+ conv = [
+            {"role": "user", "content": "<think>user put this here</think>Hi"},
+ {"role": "assistant", "content": "Hello!"},
+ ]
+ assert has_thinking_content(conv) is False
+
+ def test_multi_turn_thinking_in_earlier_turn(self):
+ conv = [
+ {"role": "user", "content": "Q1"},
+            {"role": "assistant", "content": "<think>thought</think>A1"},
+ {"role": "user", "content": "Q2"},
+ {"role": "assistant", "content": "A2"},
+ ]
+ assert has_thinking_content(conv) is True
+
+ def test_multi_turn_no_thinking(self):
+ conv = [
+ {"role": "user", "content": "Q1"},
+ {"role": "assistant", "content": "A1"},
+ {"role": "user", "content": "Q2"},
+ {"role": "assistant", "content": "A2"},
+ ]
+ assert has_thinking_content(conv) is False
+
+ def test_empty_conversation(self):
+ assert has_thinking_content([]) is False
+
+ def test_non_dict_messages(self):
+ assert has_thinking_content(["not a dict"]) is False
+
+ def test_multimodal_content_no_thinking(self):
+ conv = [
+ {"role": "user", "content": [{"type": "text", "text": "Describe"}]},
+ {"role": "assistant", "content": "A cat."},
+ ]
+ assert has_thinking_content(conv) is False
+
+ def test_system_message_ignored(self):
+ conv = [
+            {"role": "system", "content": "<think>system thinking</think>"},
+ {"role": "user", "content": "Hi"},
+ {"role": "assistant", "content": "Hello!"},
+ ]
+ assert has_thinking_content(conv) is False
+
+ def test_reasoning_content_field(self):
+ conv = [
+ {"role": "user", "content": "Hi"},
+ {"role": "assistant", "content": "Hello!", "reasoning_content": "Let me think..."},
+ ]
+ assert has_thinking_content(conv) is True
+
+ def test_reasoning_field(self):
+ conv = [
+ {"role": "user", "content": "Hi"},
+ {"role": "assistant", "content": "Hello!", "reasoning": "step by step"},
+ ]
+ assert has_thinking_content(conv) is True
+
+ def test_empty_reasoning_content_field(self):
+ conv = [
+ {"role": "user", "content": "Hi"},
+ {"role": "assistant", "content": "Hello!", "reasoning_content": ""},
+ ]
+ assert has_thinking_content(conv) is False
+
+ def test_none_reasoning_content_field(self):
+ conv = [
+ {"role": "user", "content": "Hi"},
+ {"role": "assistant", "content": "Hello!", "reasoning_content": None},
+ ]
+ assert has_thinking_content(conv) is False
+
+
+# ── Collator reads pre-computed loss_mask from MooncakeDataset ────────
+
+
+class TestCollatorPrecomputedMask:
+ def _make_collator(self):
+ return DataCollatorWithPadding()
+
+ def test_precomputed_loss_mask_used(self):
+ collator = self._make_collator()
+ mask_tensor = torch.tensor([[0, 0, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0]], dtype=torch.long)
+ item = {
+ "input_ids": torch.zeros(1, 12, dtype=torch.long),
+ "loss_mask": mask_tensor,
+ }
+ result = collator._get_loss_mask(item)
+ assert torch.equal(result, mask_tensor)
+
+ def test_missing_loss_mask_raises(self):
+ import pytest
+
+ collator = self._make_collator()
+ item = {"input_ids": torch.zeros(1, 4, dtype=torch.long)}
+ with pytest.raises(KeyError):
+ collator._get_loss_mask(item)
+
+
+# ── resolve_loss_mask (single source of truth in data/utils.py) ───────
+
+_HEADER = [10, 20]
+_END = [30, 40]
+
+
+class TestResolveLossMask:
+ """Test that resolve_loss_mask computes, stores, and skips correctly."""
+
+ def test_packed_loss_mask_nonzero(self):
+ data = {"packed_loss_mask": "2,3,2"}
+ mask = resolve_loss_mask(data)
+ assert mask is not None
+ assert mask.tolist() == [0, 0, 1, 1, 1, 0, 0]
+ assert "loss_mask" in data
+
+ def test_packed_loss_mask_zero(self):
+ data = {"packed_loss_mask": "10"}
+ assert resolve_loss_mask(data) is None
+
+ def test_dynamic_mask_nonzero(self):
+ data = {"input_ids": torch.tensor([10, 20, 1, 2, 30, 40], dtype=torch.long)}
+ mask = resolve_loss_mask(
+ data,
+ dynamic_loss_mask=True,
+ assistant_header_ids=_HEADER,
+ end_token_ids=_END,
+ )
+ assert mask is not None
+ assert mask.tolist() == [0, 0, 1, 1, 0, 0]
+ assert "loss_mask" in data
+
+ def test_dynamic_mask_zero(self):
+ data = {"input_ids": torch.tensor([5, 6, 7, 8], dtype=torch.long)}
+ assert (
+ resolve_loss_mask(
+ data,
+ dynamic_loss_mask=True,
+ assistant_header_ids=_HEADER,
+ end_token_ids=_END,
+ )
+ is None
+ )
+
+ def test_dynamic_mask_per_sample_last_turn_only(self):
+ ids = [10, 20, 1, 30, 40, 10, 20, 2, 30, 40]
+ data = {
+ "input_ids": torch.tensor(ids, dtype=torch.long),
+ "last_turn_loss_only": True,
+ }
+ mask = resolve_loss_mask(
+ data,
+ dynamic_loss_mask=True,
+ assistant_header_ids=_HEADER,
+ end_token_ids=_END,
+ last_turn_loss_only=False,
+ )
+ assert mask is not None
+ assert mask.tolist() == [0, 0, 0, 0, 0, 0, 0, 1, 0, 0]
+ assert "loss_mask" in data
+
+ def test_dynamic_mask_per_sample_all_turns(self):
+ ids = [10, 20, 1, 2, 30, 40, 10, 20, 3, 4, 30, 40]
+ data = {
+ "input_ids": torch.tensor(ids, dtype=torch.long),
+ "last_turn_loss_only": False,
+ }
+ mask = resolve_loss_mask(
+ data,
+ dynamic_loss_mask=True,
+ assistant_header_ids=_HEADER,
+ end_token_ids=_END,
+ last_turn_loss_only=True,
+ )
+ assert mask is not None
+ assert mask.tolist() == [0, 0, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0]
+
+ def test_no_params_defaults_nonzero(self):
+ data = {"input_ids": torch.tensor([1, 2, 3], dtype=torch.long)}
+ assert resolve_loss_mask(data) is not None
+
+ def test_empty_packed_loss_mask(self):
+ data = {"packed_loss_mask": ""}
+ assert resolve_loss_mask(data) is None
diff --git a/tests/test_kimi_k25_integration.py b/tests/test_kimi_k25_integration.py
index 144574f..19ef396 100644
--- a/tests/test_kimi_k25_integration.py
+++ b/tests/test_kimi_k25_integration.py
@@ -280,7 +280,7 @@ def test_preprocess_to_collator(self, mock_tokenizer, kimi_template, sample_conv
{
"input_ids": result["input_ids"][0],
"attention_mask": result["attention_mask"][0],
- "packed_loss_mask": result["packed_loss_mask"][0],
+ "loss_mask": unpack_loss_mask(result["packed_loss_mask"][0])[None, :],
}
]
diff --git a/tests/test_loss_mask.py b/tests/test_loss_mask.py
index 01460dc..1192c3f 100644
--- a/tests/test_loss_mask.py
+++ b/tests/test_loss_mask.py
@@ -217,3 +217,49 @@ def test_200k_trailing_open_turn(self):
result = compute_assistant_loss_mask(ids_open, self.H, self.E)
assert torch.equal(result, ref)
assert result[-100:].sum() == 100
+
+
+# ── last_turn_only ───────────────────────────────────────────────────
+
+
+class TestLastTurnOnly:
+ H = [10, 20]
+ E = [30, 40]
+
+ def test_single_turn_unchanged(self):
+ ids = torch.tensor([10, 20, 1, 2, 3, 30, 40], dtype=torch.long)
+ result = compute_assistant_loss_mask(ids, self.H, self.E, last_turn_only=True)
+ assert result.tolist() == [0, 0, 1, 1, 1, 0, 0]
+
+ def test_two_turns_keeps_last(self):
+ ids = torch.tensor([10, 20, 1, 2, 30, 40, 10, 20, 3, 4, 30, 40], dtype=torch.long)
+ result = compute_assistant_loss_mask(ids, self.H, self.E, last_turn_only=True)
+ assert result.tolist() == [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0]
+
+ def test_three_turns_keeps_last(self):
+ ids = torch.tensor(
+ [10, 20, 1, 30, 40, 10, 20, 2, 30, 40, 10, 20, 3, 4, 5, 30, 40], dtype=torch.long
+ )
+ result = compute_assistant_loss_mask(ids, self.H, self.E, last_turn_only=True)
+ assert result.tolist() == [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0]
+
+ def test_truncated_last_turn(self):
+ """Sequence cut mid-response — last turn has no end token."""
+ ids = torch.tensor([10, 20, 1, 30, 40, 10, 20, 2, 3], dtype=torch.long)
+ result = compute_assistant_loss_mask(ids, self.H, self.E, last_turn_only=True)
+ assert result.tolist() == [0, 0, 0, 0, 0, 0, 0, 1, 1]
+
+ def test_no_assistant_turns(self):
+ ids = torch.tensor([5, 6, 7, 8], dtype=torch.long)
+ result = compute_assistant_loss_mask(ids, self.H, self.E, last_turn_only=True)
+ assert result.tolist() == [0, 0, 0, 0]
+
+ def test_empty_input(self):
+ ids = torch.tensor([], dtype=torch.long)
+ result = compute_assistant_loss_mask(ids, self.H, self.E, last_turn_only=True)
+ assert result.tolist() == []
+
+ def test_false_flag_returns_all_turns(self):
+ ids = torch.tensor([10, 20, 1, 2, 30, 40, 10, 20, 3, 4, 30, 40], dtype=torch.long)
+ result = compute_assistant_loss_mask(ids, self.H, self.E, last_turn_only=False)
+ assert result.tolist() == [0, 0, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0]
diff --git a/tests/test_loss_mask_cross_validation.py b/tests/test_loss_mask_cross_validation.py
new file mode 100644
index 0000000..65a37af
--- /dev/null
+++ b/tests/test_loss_mask_cross_validation.py
@@ -0,0 +1,185 @@
+"""Cross-validation: compute_assistant_loss_mask must match parser.parse for every template.
+
+CI gate — auto-parametrizes over all registered templates so adding a new
+template without verifying loss mask correctness will fail the test.
+
+Requires tokenizer downloads (~10MB each). Templates whose tokenizer is
+unavailable are skipped.
+"""
+
+import pytest
+import torch
+
+from torchspec.data.parse import create_parser
+from torchspec.data.template import TEMPLATE_REGISTRY
+from torchspec.models.ops.loss_mask import compute_assistant_loss_mask
+
+_END_TOKEN_BPE_XFAIL: set[str] = set()
+
+MESSAGES_MULTI_TURN = [
+ {"role": "user", "content": "What is 2+2?"},
+ {"role": "assistant", "content": "The answer is 4."},
+ {"role": "user", "content": "And 3+3?"},
+ {"role": "assistant", "content": "The answer is 6."},
+]
+
+MESSAGES_LEADING_NEWLINES = [
+ {"role": "user", "content": "What is 2+2?"},
+ {"role": "assistant", "content": "\n\nThe answer is 4."},
+]
+
+_tokenizer_cache: dict = {}
+
+
+def _testable_templates():
+ """Templates that support dynamic loss mask (need both header and end token)."""
+ for name in TEMPLATE_REGISTRY.get_all_template_names():
+ template = TEMPLATE_REGISTRY.get(name)
+ if template.assistant_header and template.end_of_turn_token:
+ yield name
+
+
+def _get_tokenizer(template_name):
+ template = TEMPLATE_REGISTRY.get(template_name)
+ model_path = template.reference_model
+ if model_path is None:
+ pytest.skip(f"No reference_model for template '{template_name}'")
+
+ if model_path in _tokenizer_cache:
+ return _tokenizer_cache[model_path]
+
+ from transformers import AutoTokenizer
+
+ try:
+ tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+ except Exception as e:
+ pytest.skip(f"Tokenizer unavailable for {model_path}: {e}")
+
+ _tokenizer_cache[model_path] = tokenizer
+ return tokenizer
+
+
+def _get_header_ids_and_skip(tokenizer, template):
+ """Replicate get_assistant_token_ids() logic for the test."""
+ full_ids = tokenizer.encode(template.assistant_header, add_special_tokens=False)
+ stripped = template.assistant_header.rstrip("\n")
+ stripped_ids = tokenizer.encode(stripped, add_special_tokens=False)
+ skip_after = len(full_ids) - len(stripped_ids)
+ end_ids = tokenizer.encode(template.end_of_turn_token, add_special_tokens=False)
+ return stripped_ids, end_ids, skip_after
+
+
+def _first_diff(a, b):
+ for i, (x, y) in enumerate(zip(a, b)):
+ if x != y:
+ return i
+ if len(a) != len(b):
+ return min(len(a), len(b))
+ return None
+
+
+def _strip_end_tokens_from_mask(mask_list, ids_list, end_ids):
+ """Zero out end-of-turn token positions in a mask.
+
+ The parser regex captures content *including* the end_of_turn_token,
+ but compute_assistant_loss_mask marks content *excluding* it.
+ This normalizes the parser mask to match the dynamic mask convention.
+ """
+ result = list(mask_list)
+ end_len = len(end_ids)
+ i = 0
+ while i <= len(ids_list) - end_len:
+ if ids_list[i : i + end_len] == end_ids:
+ for k in range(end_len):
+ result[i + k] = 0
+ i += end_len
+ else:
+ i += 1
+ return result
+
+
+def _cross_validate(template_name, messages, last_turn_only=False):
+ """Core cross-validation: parser ground truth vs dynamic mask."""
+ template = TEMPLATE_REGISTRY.get(template_name)
+ tokenizer = _get_tokenizer(template_name)
+ parser = create_parser(tokenizer, template)
+ header_ids, end_ids, skip_after = _get_header_ids_and_skip(tokenizer, template)
+
+ formatted = parser.format(messages, add_generation_prompt=False, expand_media_tokens=False)
+
+ gt_ids, gt_mask = parser.parse(
+ formatted, max_length=200000, preformatted=True, last_turn_only=last_turn_only
+ )
+ gt_ids_list = gt_ids.squeeze().tolist()
+ gt_mask_list = gt_mask.squeeze().tolist()
+
+ if sum(gt_mask_list) == 0:
+ pytest.skip(
+ f"Parser produced all-zero mask for '{template_name}' — "
+ f"template header may not match tokenizer's chat template"
+ )
+
+ engine_ids = tokenizer.encode(formatted, add_special_tokens=False)
+
+ assert gt_ids_list == engine_ids, (
+ f"[{template_name}] Tokenization mismatch: "
+ f"parser produced {len(gt_ids_list)} tokens, "
+ f"tokenizer.encode produced {len(engine_ids)} tokens"
+ )
+
+ # Parser includes end_of_turn_token in mask; dynamic mask does not.
+ normalized_gt = _strip_end_tokens_from_mask(gt_mask_list, engine_ids, end_ids)
+
+ assert sum(normalized_gt) > 0, (
+ f"[{template_name}] After stripping end tokens, parser mask is all zeros — "
+ f"this likely means the parser only matched end tokens as content"
+ )
+
+ dyn_mask = compute_assistant_loss_mask(
+ torch.tensor(engine_ids),
+ header_ids,
+ end_ids,
+ last_turn_only=last_turn_only,
+ skip_after_header=skip_after,
+ )
+
+ diff_idx = _first_diff(normalized_gt, dyn_mask.tolist())
+ assert normalized_gt == dyn_mask.tolist(), (
+ f"[{template_name}] Loss mask mismatch "
+ f"(last_turn_only={last_turn_only}):\n"
+ f" parser mask 1s (normalized): {sum(normalized_gt)}\n"
+ f" dynamic mask 1s: {dyn_mask.sum().item()}\n"
+ f" first diff at token {diff_idx}"
+ )
+
+
+# ── Parametrized tests ───────────────────────────────────────────────
+
+
+def _maybe_xfail(template_name):
+ if template_name in _END_TOKEN_BPE_XFAIL:
+ pytest.xfail(
+ f"'{template_name}' has a known end-token BPE merge issue — "
+ f"use precomputed masks (defer_tokenization=False) for this template"
+ )
+
+
+@pytest.mark.parametrize("template_name", list(_testable_templates()))
+def test_multi_turn(template_name):
+ """Dynamic mask matches parser on a standard multi-turn conversation."""
+ _maybe_xfail(template_name)
+ _cross_validate(template_name, MESSAGES_MULTI_TURN)
+
+
+@pytest.mark.parametrize("template_name", list(_testable_templates()))
+def test_multi_turn_last_turn_only(template_name):
+ """Dynamic mask matches parser with last_turn_only=True."""
+ _maybe_xfail(template_name)
+ _cross_validate(template_name, MESSAGES_MULTI_TURN, last_turn_only=True)
+
+
+@pytest.mark.parametrize("template_name", list(_testable_templates()))
+def test_leading_newlines(template_name):
+ """Dynamic mask matches parser when content starts with newlines (BPE merge edge case)."""
+ _maybe_xfail(template_name)
+ _cross_validate(template_name, MESSAGES_LEADING_NEWLINES)
diff --git a/torchspec/config/train_config.py b/torchspec/config/train_config.py
index 20da08b..b84df92 100644
--- a/torchspec/config/train_config.py
+++ b/torchspec/config/train_config.py
@@ -39,7 +39,7 @@ class DatasetConfig:
eval_interval: int = 50
eval_micro_batch_size: Optional[int] = None
eval_prompt_key: Optional[str] = None
- last_turn_loss_only: bool = False
+ last_turn_loss_only: Any = "auto" # bool or "auto"
prompt_key: str = "conversations"
train_data_path: str = ""
diff --git a/torchspec/controller/inference_manager.py b/torchspec/controller/inference_manager.py
index 4e032dd..5178e7b 100644
--- a/torchspec/controller/inference_manager.py
+++ b/torchspec/controller/inference_manager.py
@@ -481,6 +481,7 @@ def _parse_engine_output(self, entry: InferenceInput, output: dict) -> Inference
tensor_shapes=output.get("tensor_shapes", {}),
tensor_dtypes=output.get("tensor_dtypes", {}),
packed_loss_mask=output.get("packed_loss_mask", entry.packed_loss_mask),
+ metadata=entry.metadata,
)
async def _forward_results(self, results: list[tuple[InferenceInput, Any | Exception]]) -> int:
diff --git a/torchspec/controller/training_controller.py b/torchspec/controller/training_controller.py
index c25f959..d3e1e09 100644
--- a/torchspec/controller/training_controller.py
+++ b/torchspec/controller/training_controller.py
@@ -446,12 +446,15 @@ def _dispatch_to_queues(
partitioned = self._partition_results(batch_results)
for dp_rank, results in enumerate(partitioned):
for result in results:
+ metadata = getattr(result, "metadata", {}) or {}
+ last_turn_loss_only = metadata.get("has_thinking")
queues[dp_rank].put(
TrainSample(
mooncake_key=result.mooncake_key,
tensor_shapes=result.tensor_shapes,
tensor_dtypes=result.tensor_dtypes,
packed_loss_mask=result.packed_loss_mask,
+ last_turn_loss_only=last_turn_loss_only,
)
)
diff --git a/torchspec/data/dataset.py b/torchspec/data/dataset.py
index 8ee6a07..264fbc0 100644
--- a/torchspec/data/dataset.py
+++ b/torchspec/data/dataset.py
@@ -26,7 +26,8 @@
import torch
from tqdm import tqdm
-from torchspec.data.preprocessing import _normalize_conversation
+from torchspec.data.parse import create_parser, has_thinking_content
+from torchspec.data.preprocessing import _normalize_conversation, preprocess_conversations
from torchspec.data.template import TEMPLATE_REGISTRY
from torchspec.data.utils import (
estimate_row_count,
@@ -35,6 +36,7 @@
load_hf_dataset,
)
from torchspec.utils.logging import logger
+from torchspec.utils.processing import load_tokenizer
_logging.getLogger("transformers_modules").setLevel(_logging.ERROR)
@@ -45,9 +47,6 @@ def _init_tokenize_worker(
tokenizer_path, trust_remote_code, chat_template_name, last_turn_loss_only=False
):
"""Initializer for each worker process — loads tokenizer once."""
- from torchspec.data.preprocessing import preprocess_conversations
- from torchspec.utils.processing import load_tokenizer
-
_logging.getLogger("transformers_modules").setLevel(_logging.ERROR)
_worker_state["tokenizer"] = load_tokenizer(tokenizer_path, trust_remote_code=trust_remote_code)
_worker_state["template"] = TEMPLATE_REGISTRY.get(chat_template_name)
@@ -55,6 +54,13 @@ def _init_tokenize_worker(
_worker_state["last_turn_loss_only"] = last_turn_loss_only
+def _resolve_last_turn_loss_only(messages):
+ ltlo = _worker_state.get("last_turn_loss_only", False)
+ if ltlo == "auto":
+ return has_thinking_content(messages)
+ return bool(ltlo)
+
+
def _tokenize_single(args):
"""Worker function — tokenize one sample."""
messages, max_length, train_with_decode = args
@@ -68,7 +74,7 @@ def _tokenize_single(args):
use_packed_loss_mask=True,
add_generation_prompt=train_with_decode,
return_formatted_text=True,
- last_turn_loss_only=_worker_state.get("last_turn_loss_only", False),
+ last_turn_loss_only=_resolve_last_turn_loss_only(messages),
)
if not processed["input_ids"]:
return None
@@ -82,14 +88,14 @@ def _tokenize_single(args):
}
-def _init_format_worker(tokenizer_path, trust_remote_code, chat_template_name):
- from torchspec.data.parse import create_parser
- from torchspec.utils.processing import load_tokenizer
-
+def _init_format_worker(
+ tokenizer_path, trust_remote_code, chat_template_name, last_turn_loss_only=False
+):
_logging.getLogger("transformers_modules").setLevel(_logging.ERROR)
tokenizer = load_tokenizer(tokenizer_path, trust_remote_code=trust_remote_code)
_worker_state["template"] = TEMPLATE_REGISTRY.get(chat_template_name)
_worker_state["parser"] = create_parser(tokenizer, _worker_state["template"])
+ _worker_state["last_turn_loss_only"] = last_turn_loss_only
def _format_single(args):
@@ -98,11 +104,20 @@ def _format_single(args):
"""
messages, _, train_with_decode = args
messages = _normalize_conversation(messages)
+
+ result = {}
+ ltlo = _worker_state.get("last_turn_loss_only", False)
+ if ltlo == "auto":
+ result["has_thinking"] = has_thinking_content(messages)
+
parser = _worker_state["parser"]
- formatted = parser.format(messages, add_generation_prompt=train_with_decode)
+ formatted = parser.format(
+ messages, add_generation_prompt=train_with_decode, expand_media_tokens=False
+ )
if not formatted:
return None
- return {"formatted_prompt": formatted}
+ result["formatted_prompt"] = formatted
+ return result
def load_conversation_dataset(args):
@@ -185,16 +200,17 @@ def load_conversation_dataset(args):
# Pass 2: process in parallel
work_items = [(messages, max_length, train_with_decode) for _, messages, _ in raw_samples]
+ last_turn_loss_only = getattr(args, "last_turn_loss_only", False)
if defer_tokenization:
worker_init = _init_format_worker
- worker_initargs = (args.target_model_path, True, chat_template_name)
+ worker_initargs = (args.target_model_path, True, chat_template_name, last_turn_loss_only)
worker_fn = _format_single
desc = "Formatting dataset"
else:
- last_turn_loss_only = getattr(args, "last_turn_loss_only", False)
if last_turn_loss_only:
logger.info(
- "last_turn_loss_only=True: loss mask will only cover the last assistant turn"
+ f"last_turn_loss_only={last_turn_loss_only}: "
+ "loss mask will only cover the last assistant turn"
)
worker_init = _init_tokenize_worker
worker_initargs = (args.target_model_path, True, chat_template_name, last_turn_loss_only)
@@ -221,9 +237,13 @@ def load_conversation_dataset(args):
if result is None:
skipped += 1
continue
+ metadata = {}
+ if "has_thinking" in result:
+ metadata["has_thinking"] = result["has_thinking"]
+
entry = {
"data_id": data_id,
- "metadata": {},
+ "metadata": metadata,
"multimodal_inputs": multimodal_inputs,
"formatted_prompt": result["formatted_prompt"],
}
diff --git a/torchspec/data/parse.py b/torchspec/data/parse.py
index 0e13875..f91e0d0 100644
--- a/torchspec/data/parse.py
+++ b/torchspec/data/parse.py
@@ -33,7 +33,40 @@
Conversation = List[Dict[str, Any]]
-__all__ = ["GeneralParser", "HarmonyParser", "KimiK25Parser", "create_parser"]
+__all__ = [
+ "GeneralParser",
+ "HarmonyParser",
+ "KimiK25Parser",
+ "create_parser",
+ "has_thinking_content",
+]
+
+_HAS_THINKING_RE = re.compile(r"<think>(?!\s*</think>)")
+
+
+def has_thinking_content(conversation: list) -> bool:
+ """Detect whether any assistant message contains real thinking content.
+
+ Checks for non-empty blocks in message content and for
+ separate thinking/thinking_content or reasoning/reasoning_content
+ fields on the message dict. Must be called on the raw conversation
+ BEFORE formatting, since formatters (e.g. KimiK25Parser) inject
+ empty tags.
+ """
+ for msg in conversation:
+ if not isinstance(msg, dict) or msg.get("role") != "assistant":
+ continue
+ content = msg.get("content", "")
+ if isinstance(content, str) and _HAS_THINKING_RE.search(content):
+ return True
+ if (
+ msg.get("thinking")
+ or msg.get("thinking_content")
+ or msg.get("reasoning")
+ or msg.get("reasoning_content")
+ ):
+ return True
+ return False
class Parser(ABC):
diff --git a/torchspec/data/preprocessing.py b/torchspec/data/preprocessing.py
index b0f8781..460e66f 100644
--- a/torchspec/data/preprocessing.py
+++ b/torchspec/data/preprocessing.py
@@ -76,7 +76,12 @@ def _normalize_conversation(conversation: Conversation) -> Conversation:
normalized = []
for msg in conversation:
role = ROLE_MAPPING.get(msg["from"], msg["from"])
- normalized.append({"role": role, "content": msg["value"]})
+ entry = {"role": role, "content": msg["value"]}
+ for field in ("thinking", "thinking_content", "reasoning_content", "reasoning"):
+ if msg.get(field):
+ entry["reasoning_content"] = msg[field]
+ break
+ normalized.append(entry)
return normalized
return conversation
diff --git a/torchspec/data/utils.py b/torchspec/data/utils.py
index 4822f6f..f7a220b 100644
--- a/torchspec/data/utils.py
+++ b/torchspec/data/utils.py
@@ -47,16 +47,8 @@ def is_local_data_path(path: str, base_dir: str | None = None) -> bool:
class DataCollatorWithPadding:
- def __init__(
- self,
- assistant_header_ids: Optional[List[int]] = None,
- end_token_ids: Optional[List[int]] = None,
- dynamic_loss_mask: bool = False,
- ):
+ def __init__(self):
self.sp_degree = 1
- self.assistant_header_ids = assistant_header_ids
- self.end_token_ids = end_token_ids
- self.dynamic_loss_mask = dynamic_loss_mask
def paddingtensor(self, intensors: torch.Tensor, N: int) -> torch.Tensor:
B, n, S = intensors.shape
@@ -71,33 +63,14 @@ def paddingtensor2D(self, intensors: torch.Tensor, N: int) -> torch.Tensor:
return outtensors
def _get_loss_mask(self, item: Dict[str, Any]) -> torch.Tensor:
- """Derive loss_mask for a single sample.
+ """Read the materialized loss_mask tensor from the item.
- Priority:
- 1. dynamic_loss_mask flag → compute from input_ids token boundaries
- 2. packed_loss_mask string → unpack RLE-encoded mask
- 3. fallback → all-ones mask
+ Callers (e.g. MooncakeDataset) are responsible for computing and
+ attaching loss_mask before items reach the collator.
"""
- if self.dynamic_loss_mask:
- if self.assistant_header_ids is None or self.end_token_ids is None:
- raise ValueError(
- "dynamic_loss_mask requires assistant_header_ids and "
- "end_token_ids to be set on the collator"
- )
- # TCP's input_ids are on CPU, so we can use input_ids_cpu directly.
- input_ids = item.get("input_ids_cpu", item["input_ids"])
- if input_ids.dim() == 2:
- input_ids = input_ids.squeeze(0)
- mask = compute_assistant_loss_mask(
- input_ids, self.assistant_header_ids, self.end_token_ids
- )
- # Copy back to GPU.
- return mask[None, :].to(item["input_ids"].device)
-
- if "packed_loss_mask" not in item or item["packed_loss_mask"] is None:
- seq_len = item["input_ids"].shape[-1]
- return torch.ones(1, seq_len, dtype=torch.long, device=item["input_ids"].device)
- return unpack_loss_mask(item["packed_loss_mask"])[None, :]
+ if "loss_mask" in item and isinstance(item["loss_mask"], torch.Tensor):
+ return item["loss_mask"]
+ raise KeyError(f"loss_mask not found in item: {item}")
def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
max_length = max(item["input_ids"].shape[1] for item in features)
@@ -219,6 +192,52 @@ def unpack_loss_mask(packed: Union[List[int], str]) -> torch.Tensor:
return loss_mask
+def resolve_loss_mask(
+ data: Dict[str, Any],
+ *,
+ dynamic_loss_mask: bool = False,
+ assistant_header_ids: Optional[List[int]] = None,
+ end_token_ids: Optional[List[int]] = None,
+ last_turn_loss_only: bool = False,
+ skip_after_header: int = 0,
+) -> torch.Tensor | None:
+ """
+ Two strategies, tried in order:
+ 1. ``packed_loss_mask`` key present → unpack it.
+ 2. ``dynamic_loss_mask`` enabled with valid header/end ids → compute from
+ ``input_ids`` via :func:`compute_assistant_loss_mask`.
+ """
+ packed = data.get("packed_loss_mask")
+ if packed is not None:
+ mask = unpack_loss_mask(packed)
+ if not mask.any():
+ return None
+ data["loss_mask"] = mask
+ return mask
+
+ if dynamic_loss_mask and assistant_header_ids and end_token_ids:
+ input_ids = data.get("input_ids")
+ if input_ids is None:
+ return None
+ if input_ids.dim() == 2:
+ input_ids = input_ids.squeeze(0)
+ per_sample = data.get("last_turn_loss_only")
+ last_turn_only = per_sample if per_sample is not None else last_turn_loss_only
+ mask = compute_assistant_loss_mask(
+ input_ids,
+ assistant_header_ids,
+ end_token_ids,
+ last_turn_only=last_turn_only,
+ skip_after_header=skip_after_header,
+ )
+ if not mask.any():
+ return None
+ data["loss_mask"] = mask
+ return mask
+
+ return torch.ones(1)
+
+
def serialize_packed_loss_mask(packed: List[int]) -> str:
"""
Serialize packed loss_mask to a comma-separated string.
diff --git a/torchspec/models/ops/loss_mask.py b/torchspec/models/ops/loss_mask.py
index 7f5e78f..4da0abf 100644
--- a/torchspec/models/ops/loss_mask.py
+++ b/torchspec/models/ops/loss_mask.py
@@ -24,7 +24,7 @@
@numba.njit(cache=True)
-def _numba_loss_mask(ids, header, header_len, end, end_len, out):
+def _numba_loss_mask(ids, header, header_len, end, end_len, out, skip_after):
n = len(ids)
i = 0
while i <= n - header_len:
@@ -36,7 +36,7 @@ def _numba_loss_mask(ids, header, header_len, end, end_len, out):
if not match:
i += 1
continue
- j = i + header_len
+ j = i + header_len + skip_after
found_end = False
while j <= n - end_len:
end_match = True
@@ -69,6 +69,8 @@ def compute_assistant_loss_mask(
input_ids: torch.Tensor,
assistant_header_ids: list[int],
end_token_ids: list[int],
+ last_turn_only: bool = False,
+ skip_after_header: int = 0,
) -> torch.Tensor:
"""Compute loss mask where 1s mark assistant content tokens only.
@@ -80,6 +82,10 @@ def compute_assistant_loss_mask(
input_ids: 1-D tensor of token IDs (CPU or CUDA).
assistant_header_ids: Token ID sequence marking the start of assistant content.
end_token_ids: Token ID sequence marking the end of assistant content.
+ last_turn_only: If True, only the last assistant turn is marked.
+ skip_after_header: Number of tokens to skip after header match before
+ marking content. Use 1 when the header excludes a trailing newline
+ that BPE may merge with subsequent content.
Returns:
1-D long tensor on the same device as input_ids, with 1s for assistant
@@ -94,6 +100,14 @@ def compute_assistant_loss_mask(
header_np = np.array(assistant_header_ids, dtype=np.int64)
end_np = np.array(end_token_ids, dtype=np.int64)
out = np.zeros(len(ids_np), dtype=np.int64)
- _numba_loss_mask(ids_np, header_np, len(header_np), end_np, len(end_np), out)
+ _numba_loss_mask(ids_np, header_np, len(header_np), end_np, len(end_np), out, skip_after_header)
+
+ if last_turn_only and out.any():
+ last_one = len(out) - 1 - np.argmax(out[::-1])
+ first_of_last = last_one
+ while first_of_last > 0 and out[first_of_last - 1] == 1:
+ first_of_last -= 1
+ out[:first_of_last] = 0
+
result = torch.from_numpy(out)
return result.to(device) if device.type != "cpu" else result
diff --git a/torchspec/training/data_fetcher.py b/torchspec/training/data_fetcher.py
index 18d3f5e..76fa62c 100644
--- a/torchspec/training/data_fetcher.py
+++ b/torchspec/training/data_fetcher.py
@@ -32,6 +32,7 @@
from ray.util.queue import Queue as RayQueue
from torch.utils.data import DataLoader, IterableDataset
+from torchspec.data.utils import resolve_loss_mask
from torchspec.utils.logging import logger
@@ -41,6 +42,7 @@ class TrainSample:
tensor_shapes: Dict[str, Tuple[int, ...]]
tensor_dtypes: Optional[Dict[str, torch.dtype]] = None
packed_loss_mask: Optional[str] = None
+ last_turn_loss_only: Optional[bool] = None
class MooncakeDataset(IterableDataset):
@@ -57,20 +59,22 @@ def __init__(
device: torch.device,
prefetch_factor: int = 2,
timeout: Optional[float] = None,
+ assistant_header_ids: Optional[List[int]] = None,
+ end_token_ids: Optional[List[int]] = None,
+ dynamic_loss_mask: bool = False,
+ last_turn_loss_only: bool = False,
+ skip_after_header: int = 0,
):
- """
- Args:
- ray_queue: Ray Queue to receive TrainSample from controller.
- mooncake_store: Mooncake store client for loading tensors.
- device: Target device for tensors.
- prefetch_factor: Number of samples to prefetch in background thread.
- timeout: Timeout in seconds for waiting on queue. None means wait forever.
- """
self.ray_queue = ray_queue
self.mooncake_store = mooncake_store
self.device = device
self.prefetch_factor = prefetch_factor
self.timeout = timeout
+ self.assistant_header_ids = assistant_header_ids
+ self.end_token_ids = end_token_ids
+ self.dynamic_loss_mask = dynamic_loss_mask
+ self.last_turn_loss_only = last_turn_loss_only
+ self.skip_after_header = skip_after_header
def _load_from_mooncake(self, sample: TrainSample) -> Dict[str, Any]:
"""Load tensors from mooncake key into device memory."""
@@ -104,6 +108,8 @@ def _load_from_mooncake(self, sample: TrainSample) -> Dict[str, Any]:
result = tensors.to_tensor_dict()
if sample.packed_loss_mask is not None:
result["packed_loss_mask"] = sample.packed_loss_mask
+ if sample.last_turn_loss_only is not None:
+ result["last_turn_loss_only"] = sample.last_turn_loss_only
return result
def _cleanup_mooncake_data(self, sample: TrainSample) -> None:
@@ -118,12 +124,24 @@ def _cleanup_mooncake_data(self, sample: TrainSample) -> None:
has_target=has_target,
)
+ def _compute_loss_mask(self, data: Dict[str, Any]) -> torch.Tensor | None:
+ return resolve_loss_mask(
+ data,
+ dynamic_loss_mask=self.dynamic_loss_mask,
+ assistant_header_ids=self.assistant_header_ids,
+ end_token_ids=self.end_token_ids,
+ last_turn_loss_only=self.last_turn_loss_only,
+ skip_after_header=self.skip_after_header,
+ )
+
def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]:
"""Iterate over samples synchronously.
Blocks waiting for each item from the queue and loads from mooncake.
+ Skips samples whose loss mask is all zeros to avoid wasted compute.
"""
yield_count = 0
+ skip_count = 0
while True:
logger.debug(f"__iter__: waiting for item from ray_queue (yield_count={yield_count})")
try:
@@ -138,6 +156,15 @@ def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]:
logger.debug(f"__iter__: got item, mooncake_key={item.mooncake_key}")
data = self._load_from_mooncake(item)
+
+ if self._compute_loss_mask(data) is None:
+ skip_count += 1
+ logger.warning(
+ f"Skipping sample with all-zero loss mask "
+ f"(mooncake_key={item.mooncake_key}, total_skipped={skip_count})"
+ )
+ continue
+
# Note: target is computed in the collator from last_hidden_states for sglang mode
# Add batch dimension if missing (sglang stores without batch dim)
@@ -174,6 +201,11 @@ def create_mooncake_dataloader(
batch_size: int = 1,
prefetch_factor: int = 2,
timeout: Optional[float] = None,
+ assistant_header_ids: Optional[List[int]] = None,
+ end_token_ids: Optional[List[int]] = None,
+ dynamic_loss_mask: bool = False,
+ last_turn_loss_only: bool = False,
+ skip_after_header: int = 0,
) -> DataLoader:
"""Create a DataLoader that fetches from mooncake via queue.
@@ -193,11 +225,26 @@ def create_mooncake_dataloader(
batch_size: Number of samples per batch (= per_dp_rank_batch_size).
prefetch_factor: Unused, kept for API compatibility.
timeout: Timeout in seconds for waiting on queue. None means wait forever.
+ assistant_header_ids: Token IDs for assistant header (for loss mask skip check).
+ end_token_ids: Token IDs for end of turn (for loss mask skip check).
+ dynamic_loss_mask: Whether loss mask is computed dynamically from input_ids.
+        last_turn_loss_only: Global fallback for last-turn-only loss masking; skip_after_header: number of tokens to skip after the header match (see compute_assistant_loss_mask).
Returns:
DataLoader instance.
"""
- dataset = MooncakeDataset(ray_queue, mooncake_store, device, prefetch_factor, timeout)
+ dataset = MooncakeDataset(
+ ray_queue,
+ mooncake_store,
+ device,
+ prefetch_factor,
+ timeout,
+ assistant_header_ids=assistant_header_ids,
+ end_token_ids=end_token_ids,
+ dynamic_loss_mask=dynamic_loss_mask,
+ last_turn_loss_only=last_turn_loss_only,
+ skip_after_header=skip_after_header,
+ )
return DataLoader(
dataset,
@@ -232,6 +279,11 @@ def __init__(
batch_size: int = 1,
prefetch_factor: int = 2,
timeout: Optional[float] = None,
+ assistant_header_ids: Optional[List[int]] = None,
+ end_token_ids: Optional[List[int]] = None,
+ dynamic_loss_mask: bool = False,
+ last_turn_loss_only: bool = False,
+ skip_after_header: int = 0,
):
self.batch_size = batch_size
self._dataloader = create_mooncake_dataloader(
@@ -242,6 +294,11 @@ def __init__(
batch_size=batch_size,
prefetch_factor=prefetch_factor,
timeout=timeout,
+ assistant_header_ids=assistant_header_ids,
+ end_token_ids=end_token_ids,
+ dynamic_loss_mask=dynamic_loss_mask,
+ last_turn_loss_only=last_turn_loss_only,
+ skip_after_header=skip_after_header,
)
def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]:
diff --git a/torchspec/training/trainer.py b/torchspec/training/trainer.py
index 4358680..bd4aa71 100644
--- a/torchspec/training/trainer.py
+++ b/torchspec/training/trainer.py
@@ -78,7 +78,10 @@ def __init__(self, args: Namespace):
self.prof = TrainProfiler(args)
self.dynamic_loss_mask = getattr(args, "dynamic_loss_mask", False)
- self.assistant_header_ids, self.end_token_ids = get_assistant_token_ids(self.args)
+ self.last_turn_loss_only = getattr(args, "last_turn_loss_only", False)
+ self.assistant_header_ids, self.end_token_ids, self.skip_after_header = (
+ get_assistant_token_ids(self.args)
+ )
self.save_debug_train_data = getattr(args, "save_debug_train_data", None)
self.max_dump_steps = getattr(args, "max_dump_steps", 5)
@@ -155,11 +158,7 @@ def set_train_queue(
if mooncake_config is not None and self.mooncake_store is None:
self.init_mooncake_store(mooncake_config)
- collator = DataCollatorWithPadding(
- assistant_header_ids=self.assistant_header_ids,
- end_token_ids=self.end_token_ids,
- dynamic_loss_mask=self.dynamic_loss_mask,
- )
+ collator = DataCollatorWithPadding()
self.data_fetcher = MooncakeDataFetcher(
queue=self.train_queue,
@@ -167,6 +166,11 @@ def set_train_queue(
collator=collator,
device=torch.cuda.current_device(),
batch_size=per_dp_rank_batch_size,
+ assistant_header_ids=self.assistant_header_ids,
+ end_token_ids=self.end_token_ids,
+ dynamic_loss_mask=self.dynamic_loss_mask,
+ last_turn_loss_only=self.last_turn_loss_only,
+ skip_after_header=self.skip_after_header,
)
logger.info(
@@ -186,11 +190,7 @@ def set_eval_queue(
if mooncake_config is not None and self.mooncake_store is None:
self.init_mooncake_store(mooncake_config)
- collator = DataCollatorWithPadding(
- assistant_header_ids=self.assistant_header_ids,
- end_token_ids=self.end_token_ids,
- dynamic_loss_mask=self.dynamic_loss_mask,
- )
+ collator = DataCollatorWithPadding()
self._eval_data_fetcher = MooncakeDataFetcher(
queue=queue,
@@ -198,6 +198,11 @@ def set_eval_queue(
collator=collator,
device=torch.cuda.current_device(),
batch_size=per_dp_rank_batch_size,
+ assistant_header_ids=self.assistant_header_ids,
+ end_token_ids=self.end_token_ids,
+ dynamic_loss_mask=self.dynamic_loss_mask,
+ last_turn_loss_only=self.last_turn_loss_only,
+ skip_after_header=self.skip_after_header,
)
self._eval_collator = collator
self._eval_cache: list[dict] = []
diff --git a/torchspec/utils/processing.py b/torchspec/utils/processing.py
index a06e315..974b63c 100644
--- a/torchspec/utils/processing.py
+++ b/torchspec/utils/processing.py
@@ -28,22 +28,41 @@ def load_tokenizer(name_or_path: str, **kwargs):
return AutoTokenizer.from_pretrained(name_or_path, **kwargs)
-def get_assistant_token_ids(args) -> tuple[list[int] | None, list[int] | None]:
- """Derive assistant_header_ids and end_token_ids from chat_template config."""
+def get_assistant_token_ids(
+ args,
+) -> tuple[list[int] | None, list[int] | None, int]:
+ """Derive assistant_header_ids, end_token_ids, and skip_after_header.
+
+ Returns:
+ (header_ids, end_ids, skip_after_header) where skip_after_header is the
+ number of trailing-newline tokens stripped from the header to avoid BPE
+ merge issues. Pass this to ``compute_assistant_loss_mask`` so it skips
+ the formatting newline after the role name.
+ """
from torchspec.data.template import TEMPLATE_REGISTRY
chat_template_name = getattr(args, "chat_template", None)
if not chat_template_name:
- return None, None
+ return None, None, 0
template = TEMPLATE_REGISTRY.get(chat_template_name)
if not template.assistant_header or not template.end_of_turn_token:
- return None, None
+ return None, None, 0
tokenizer = load_tokenizer(args.target_model_path, trust_remote_code=True)
- assistant_header_ids = tokenizer.encode(template.assistant_header, add_special_tokens=False)
+
+ full_ids = tokenizer.encode(template.assistant_header, add_special_tokens=False)
+ header_text = template.assistant_header.rstrip("\n")
+ stripped_ids = tokenizer.encode(header_text, add_special_tokens=False)
+ # BPE tokenizers may merge the header's trailing \n with content-leading
+ # newlines into a single multi-char token, breaking subsequence matching.
+ # We tokenize without the trailing \n and tell the mask function to skip
+ # the newline token(s) that follow.
+ skip_after_header = len(full_ids) - len(stripped_ids)
+
end_token_ids = tokenizer.encode(template.end_of_turn_token, add_special_tokens=False)
logger.info(
- f"Assistant loss mask token IDs: header={assistant_header_ids}, end={end_token_ids}"
+ f"Assistant loss mask token IDs: header={stripped_ids}, end={end_token_ids}, "
+ f"skip_after_header={skip_after_header}"
)
- return assistant_header_ids, end_token_ids
+ return stripped_ids, end_token_ids, skip_after_header
diff --git a/torchspec/utils/types.py b/torchspec/utils/types.py
index ee4b329..9d9fa89 100644
--- a/torchspec/utils/types.py
+++ b/torchspec/utils/types.py
@@ -49,3 +49,4 @@ class InferenceOutput:
tensor_shapes: dict[str, tuple[int, ...]]
tensor_dtypes: dict[str, torch.dtype] | None = None
packed_loss_mask: str | None = None
+ metadata: dict = field(default_factory=dict)