diff --git a/tests/test_dynamic_last_turn_loss.py b/tests/test_dynamic_last_turn_loss.py new file mode 100644 index 0000000..4767f42 --- /dev/null +++ b/tests/test_dynamic_last_turn_loss.py @@ -0,0 +1,259 @@ +"""Tests for dynamic per-sample last_turn_loss_only based on thinking content detection.""" + +import torch + +from torchspec.data.parse import has_thinking_content +from torchspec.data.utils import DataCollatorWithPadding, resolve_loss_mask + +# ── has_thinking_content detection ──────────────────────────────────── + + +class TestHasThinkingContent: + def test_think_tag_with_content(self): + conv = [ + {"role": "user", "content": "Hi"}, + {"role": "assistant", "content": "reasoning hereHello!"}, + ] + assert has_thinking_content(conv) is True + + def test_empty_think_tag(self): + conv = [ + {"role": "user", "content": "Hi"}, + {"role": "assistant", "content": "Hello!"}, + ] + assert has_thinking_content(conv) is False + + def test_think_tag_whitespace_only(self): + conv = [ + {"role": "user", "content": "Hi"}, + {"role": "assistant", "content": " \n\t Hello!"}, + ] + assert has_thinking_content(conv) is False + + def test_think_tag_with_single_char(self): + conv = [ + {"role": "user", "content": "Hi"}, + {"role": "assistant", "content": " xanswer"}, + ] + assert has_thinking_content(conv) is True + + def test_no_think_tag(self): + conv = [ + {"role": "user", "content": "Hi"}, + {"role": "assistant", "content": "Hello there!"}, + ] + assert has_thinking_content(conv) is False + + def test_thinking_field(self): + conv = [ + {"role": "user", "content": "Hi"}, + {"role": "assistant", "content": "Hello!", "thinking": "some reasoning"}, + ] + assert has_thinking_content(conv) is True + + def test_thinking_content_field(self): + conv = [ + {"role": "user", "content": "Hi"}, + {"role": "assistant", "content": "Hello!", "thinking_content": "reasoning"}, + ] + assert has_thinking_content(conv) is True + + def test_empty_thinking_field(self): + conv = [ + {"role": "user", "content": "Hi"}, + {"role": "assistant", "content": "Hello!", "thinking": ""}, + ] + assert has_thinking_content(conv) is False + + def test_none_thinking_field(self): + conv = [ + {"role": "user", "content": "Hi"}, + {"role": "assistant", "content": "Hello!", "thinking": None}, + ] + assert has_thinking_content(conv) is False + + def test_user_message_with_think_tag_ignored(self): + conv = [ + {"role": "user", "content": "user put this hereHi"}, + {"role": "assistant", "content": "Hello!"}, + ] + assert has_thinking_content(conv) is False + + def test_multi_turn_thinking_in_earlier_turn(self): + conv = [ + {"role": "user", "content": "Q1"}, + {"role": "assistant", "content": "thoughtA1"}, + {"role": "user", "content": "Q2"}, + {"role": "assistant", "content": "A2"}, + ] + assert has_thinking_content(conv) is True + + def test_multi_turn_no_thinking(self): + conv = [ + {"role": "user", "content": "Q1"}, + {"role": "assistant", "content": "A1"}, + {"role": "user", "content": "Q2"}, + {"role": "assistant", "content": "A2"}, + ] + assert has_thinking_content(conv) is False + + def test_empty_conversation(self): + assert has_thinking_content([]) is False + + def test_non_dict_messages(self): + assert has_thinking_content(["not a dict"]) is False + + def test_multimodal_content_no_thinking(self): + conv = [ + {"role": "user", "content": [{"type": "text", "text": "Describe"}]}, + {"role": "assistant", "content": "A cat."}, + ] + assert has_thinking_content(conv) is False + + def test_system_message_ignored(self): + conv = [ + {"role": "system", "content": "system thinking"}, + {"role": "user", "content": "Hi"}, + {"role": "assistant", "content": "Hello!"}, + ] + assert has_thinking_content(conv) is False + + def test_reasoning_content_field(self): + conv = [ + {"role": "user", "content": "Hi"}, + {"role": "assistant", "content": "Hello!", "reasoning_content": "Let me think..."}, + ] + assert has_thinking_content(conv) is True + + def test_reasoning_field(self): + conv = [ + {"role": "user", "content": "Hi"}, + {"role": "assistant", "content": "Hello!", "reasoning": "step by step"}, + ] + assert has_thinking_content(conv) is True + + def test_empty_reasoning_content_field(self): + conv = [ + {"role": "user", "content": "Hi"}, + {"role": "assistant", "content": "Hello!", "reasoning_content": ""}, + ] + assert has_thinking_content(conv) is False + + def test_none_reasoning_content_field(self): + conv = [ + {"role": "user", "content": "Hi"}, + {"role": "assistant", "content": "Hello!", "reasoning_content": None}, + ] + assert has_thinking_content(conv) is False + + +# ── Collator reads pre-computed loss_mask from MooncakeDataset ──────── + + +class TestCollatorPrecomputedMask: + def _make_collator(self): + return DataCollatorWithPadding() + + def test_precomputed_loss_mask_used(self): + collator = self._make_collator() + mask_tensor = torch.tensor([[0, 0, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0]], dtype=torch.long) + item = { + "input_ids": torch.zeros(1, 12, dtype=torch.long), + "loss_mask": mask_tensor, + } + result = collator._get_loss_mask(item) + assert torch.equal(result, mask_tensor) + + def test_missing_loss_mask_raises(self): + import pytest + + collator = self._make_collator() + item = {"input_ids": torch.zeros(1, 4, dtype=torch.long)} + with pytest.raises(KeyError): + collator._get_loss_mask(item) + + +# ── resolve_loss_mask (single source of truth in data/utils.py) ─────── + +_HEADER = [10, 20] +_END = [30, 40] + + +class TestResolveLossMask: + """Test that resolve_loss_mask computes, stores, and skips correctly.""" + + def test_packed_loss_mask_nonzero(self): + data = {"packed_loss_mask": "2,3,2"} + mask = resolve_loss_mask(data) + assert mask is not None + assert mask.tolist() == [0, 0, 1, 1, 1, 0, 0] + assert "loss_mask" in data + + def test_packed_loss_mask_zero(self): + data = {"packed_loss_mask": "10"} + assert resolve_loss_mask(data) is None + + def test_dynamic_mask_nonzero(self): + data = {"input_ids": torch.tensor([10, 20, 1, 2, 30, 40], dtype=torch.long)} + mask = resolve_loss_mask( + data, + dynamic_loss_mask=True, + assistant_header_ids=_HEADER, + end_token_ids=_END, + ) + assert mask is not None + assert mask.tolist() == [0, 0, 1, 1, 0, 0] + assert "loss_mask" in data + + def test_dynamic_mask_zero(self): + data = {"input_ids": torch.tensor([5, 6, 7, 8], dtype=torch.long)} + assert ( + resolve_loss_mask( + data, + dynamic_loss_mask=True, + assistant_header_ids=_HEADER, + end_token_ids=_END, + ) + is None + ) + + def test_dynamic_mask_per_sample_last_turn_only(self): + ids = [10, 20, 1, 30, 40, 10, 20, 2, 30, 40] + data = { + "input_ids": torch.tensor(ids, dtype=torch.long), + "last_turn_loss_only": True, + } + mask = resolve_loss_mask( + data, + dynamic_loss_mask=True, + assistant_header_ids=_HEADER, + end_token_ids=_END, + last_turn_loss_only=False, + ) + assert mask is not None + assert mask.tolist() == [0, 0, 0, 0, 0, 0, 0, 1, 0, 0] + assert "loss_mask" in data + + def test_dynamic_mask_per_sample_all_turns(self): + ids = [10, 20, 1, 2, 30, 40, 10, 20, 3, 4, 30, 40] + data = { + "input_ids": torch.tensor(ids, dtype=torch.long), + "last_turn_loss_only": False, + } + mask = resolve_loss_mask( + data, + dynamic_loss_mask=True, + assistant_header_ids=_HEADER, + end_token_ids=_END, + last_turn_loss_only=True, + ) + assert mask is not None + assert mask.tolist() == [0, 0, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0] + + def test_no_params_defaults_nonzero(self): + data = {"input_ids": torch.tensor([1, 2, 3], dtype=torch.long)} + assert resolve_loss_mask(data) is not None + + def test_empty_packed_loss_mask(self): + data = {"packed_loss_mask": ""} + assert resolve_loss_mask(data) is None diff --git a/tests/test_kimi_k25_integration.py b/tests/test_kimi_k25_integration.py index 144574f..19ef396 100644 --- a/tests/test_kimi_k25_integration.py +++ b/tests/test_kimi_k25_integration.py @@ -280,7 +280,7 @@ def test_preprocess_to_collator(self, mock_tokenizer, kimi_template, sample_conv { "input_ids": result["input_ids"][0], "attention_mask": result["attention_mask"][0], - "packed_loss_mask": result["packed_loss_mask"][0], + "loss_mask": unpack_loss_mask(result["packed_loss_mask"][0])[None, :], } ] diff --git a/tests/test_loss_mask.py b/tests/test_loss_mask.py index 01460dc..1192c3f 100644 --- a/tests/test_loss_mask.py +++ b/tests/test_loss_mask.py @@ -217,3 +217,49 @@ def test_200k_trailing_open_turn(self): result = compute_assistant_loss_mask(ids_open, self.H, self.E) assert torch.equal(result, ref) assert result[-100:].sum() == 100 + + +# ── last_turn_only ─────────────────────────────────────────────────── + + +class TestLastTurnOnly: + H = [10, 20] + E = [30, 40] + + def test_single_turn_unchanged(self): + ids = torch.tensor([10, 20, 1, 2, 3, 30, 40], dtype=torch.long) + result = compute_assistant_loss_mask(ids, self.H, self.E, last_turn_only=True) + assert result.tolist() == [0, 0, 1, 1, 1, 0, 0] + + def test_two_turns_keeps_last(self): + ids = torch.tensor([10, 20, 1, 2, 30, 40, 10, 20, 3, 4, 30, 40], dtype=torch.long) + result = compute_assistant_loss_mask(ids, self.H, self.E, last_turn_only=True) + assert result.tolist() == [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0] + + def test_three_turns_keeps_last(self): + ids = torch.tensor( + [10, 20, 1, 30, 40, 10, 20, 2, 30, 40, 10, 20, 3, 4, 5, 30, 40], dtype=torch.long + ) + result = compute_assistant_loss_mask(ids, self.H, self.E, last_turn_only=True) + assert result.tolist() == [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0] + + def test_truncated_last_turn(self): + """Sequence cut mid-response — last turn has no end token.""" + ids = torch.tensor([10, 20, 1, 30, 40, 10, 20, 2, 3], dtype=torch.long) + result = compute_assistant_loss_mask(ids, self.H, self.E, last_turn_only=True) + assert result.tolist() == [0, 0, 0, 0, 0, 0, 0, 1, 1] + + def test_no_assistant_turns(self): + ids = torch.tensor([5, 6, 7, 8], dtype=torch.long) + result = compute_assistant_loss_mask(ids, self.H, self.E, last_turn_only=True) + assert result.tolist() == [0, 0, 0, 0] + + def test_empty_input(self): + ids = torch.tensor([], dtype=torch.long) + result = compute_assistant_loss_mask(ids, self.H, self.E, last_turn_only=True) + assert result.tolist() == [] + + def test_false_flag_returns_all_turns(self): + ids = torch.tensor([10, 20, 1, 2, 30, 40, 10, 20, 3, 4, 30, 40], dtype=torch.long) + result = compute_assistant_loss_mask(ids, self.H, self.E, last_turn_only=False) + assert result.tolist() == [0, 0, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0] diff --git a/tests/test_loss_mask_cross_validation.py b/tests/test_loss_mask_cross_validation.py new file mode 100644 index 0000000..65a37af --- /dev/null +++ b/tests/test_loss_mask_cross_validation.py @@ -0,0 +1,185 @@ +"""Cross-validation: compute_assistant_loss_mask must match parser.parse for every template. + +CI gate — auto-parametrizes over all registered templates so adding a new +template without verifying loss mask correctness will fail the test. + +Requires tokenizer downloads (~10MB each). Templates whose tokenizer is +unavailable are skipped. +""" + +import pytest +import torch + +from torchspec.data.parse import create_parser +from torchspec.data.template import TEMPLATE_REGISTRY +from torchspec.models.ops.loss_mask import compute_assistant_loss_mask + +_END_TOKEN_BPE_XFAIL: set[str] = set() + +MESSAGES_MULTI_TURN = [ + {"role": "user", "content": "What is 2+2?"}, + {"role": "assistant", "content": "The answer is 4."}, + {"role": "user", "content": "And 3+3?"}, + {"role": "assistant", "content": "The answer is 6."}, +] + +MESSAGES_LEADING_NEWLINES = [ + {"role": "user", "content": "What is 2+2?"}, + {"role": "assistant", "content": "\n\nThe answer is 4."}, +] + +_tokenizer_cache: dict = {} + + +def _testable_templates(): + """Templates that support dynamic loss mask (need both header and end token).""" + for name in TEMPLATE_REGISTRY.get_all_template_names(): + template = TEMPLATE_REGISTRY.get(name) + if template.assistant_header and template.end_of_turn_token: + yield name + + +def _get_tokenizer(template_name): + template = TEMPLATE_REGISTRY.get(template_name) + model_path = template.reference_model + if model_path is None: + pytest.skip(f"No reference_model for template '{template_name}'") + + if model_path in _tokenizer_cache: + return _tokenizer_cache[model_path] + + from transformers import AutoTokenizer + + try: + tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) + except Exception as e: + pytest.skip(f"Tokenizer unavailable for {model_path}: {e}") + + _tokenizer_cache[model_path] = tokenizer + return tokenizer + + +def _get_header_ids_and_skip(tokenizer, template): + """Replicate get_assistant_token_ids() logic for the test.""" + full_ids = tokenizer.encode(template.assistant_header, add_special_tokens=False) + stripped = template.assistant_header.rstrip("\n") + stripped_ids = tokenizer.encode(stripped, add_special_tokens=False) + skip_after = len(full_ids) - len(stripped_ids) + end_ids = tokenizer.encode(template.end_of_turn_token, add_special_tokens=False) + return stripped_ids, end_ids, skip_after + + +def _first_diff(a, b): + for i, (x, y) in enumerate(zip(a, b)): + if x != y: + return i + if len(a) != len(b): + return min(len(a), len(b)) + return None + + +def _strip_end_tokens_from_mask(mask_list, ids_list, end_ids): + """Zero out end-of-turn token positions in a mask. + + The parser regex captures content *including* the end_of_turn_token, + but compute_assistant_loss_mask marks content *excluding* it. + This normalizes the parser mask to match the dynamic mask convention. + """ + result = list(mask_list) + end_len = len(end_ids) + i = 0 + while i <= len(ids_list) - end_len: + if ids_list[i : i + end_len] == end_ids: + for k in range(end_len): + result[i + k] = 0 + i += end_len + else: + i += 1 + return result + + +def _cross_validate(template_name, messages, last_turn_only=False): + """Core cross-validation: parser ground truth vs dynamic mask.""" + template = TEMPLATE_REGISTRY.get(template_name) + tokenizer = _get_tokenizer(template_name) + parser = create_parser(tokenizer, template) + header_ids, end_ids, skip_after = _get_header_ids_and_skip(tokenizer, template) + + formatted = parser.format(messages, add_generation_prompt=False, expand_media_tokens=False) + + gt_ids, gt_mask = parser.parse( + formatted, max_length=200000, preformatted=True, last_turn_only=last_turn_only + ) + gt_ids_list = gt_ids.squeeze().tolist() + gt_mask_list = gt_mask.squeeze().tolist() + + if sum(gt_mask_list) == 0: + pytest.skip( + f"Parser produced all-zero mask for '{template_name}' — " + f"template header may not match tokenizer's chat template" + ) + + engine_ids = tokenizer.encode(formatted, add_special_tokens=False) + + assert gt_ids_list == engine_ids, ( + f"[{template_name}] Tokenization mismatch: " + f"parser produced {len(gt_ids_list)} tokens, " + f"tokenizer.encode produced {len(engine_ids)} tokens" + ) + + # Parser includes end_of_turn_token in mask; dynamic mask does not. + normalized_gt = _strip_end_tokens_from_mask(gt_mask_list, engine_ids, end_ids) + + assert sum(normalized_gt) > 0, ( + f"[{template_name}] After stripping end tokens, parser mask is all zeros — " + f"this likely means the parser only matched end tokens as content" + ) + + dyn_mask = compute_assistant_loss_mask( + torch.tensor(engine_ids), + header_ids, + end_ids, + last_turn_only=last_turn_only, + skip_after_header=skip_after, + ) + + diff_idx = _first_diff(normalized_gt, dyn_mask.tolist()) + assert normalized_gt == dyn_mask.tolist(), ( + f"[{template_name}] Loss mask mismatch " + f"(last_turn_only={last_turn_only}):\n" + f" parser mask 1s (normalized): {sum(normalized_gt)}\n" + f" dynamic mask 1s: {dyn_mask.sum().item()}\n" + f" first diff at token {diff_idx}" + ) + + +# ── Parametrized tests ─────────────────────────────────────────────── + + +def _maybe_xfail(template_name): + if template_name in _END_TOKEN_BPE_XFAIL: + pytest.xfail( + f"'{template_name}' has a known end-token BPE merge issue — " + f"use precomputed masks (defer_tokenization=False) for this template" + ) + + +@pytest.mark.parametrize("template_name", list(_testable_templates())) +def test_multi_turn(template_name): + """Dynamic mask matches parser on a standard multi-turn conversation.""" + _maybe_xfail(template_name) + _cross_validate(template_name, MESSAGES_MULTI_TURN) + + +@pytest.mark.parametrize("template_name", list(_testable_templates())) +def test_multi_turn_last_turn_only(template_name): + """Dynamic mask matches parser with last_turn_only=True.""" + _maybe_xfail(template_name) + _cross_validate(template_name, MESSAGES_MULTI_TURN, last_turn_only=True) + + +@pytest.mark.parametrize("template_name", list(_testable_templates())) +def test_leading_newlines(template_name): + """Dynamic mask matches parser when content starts with newlines (BPE merge edge case).""" + _maybe_xfail(template_name) + _cross_validate(template_name, MESSAGES_LEADING_NEWLINES) diff --git a/torchspec/config/train_config.py b/torchspec/config/train_config.py index 20da08b..b84df92 100644 --- a/torchspec/config/train_config.py +++ b/torchspec/config/train_config.py @@ -39,7 +39,7 @@ class DatasetConfig: eval_interval: int = 50 eval_micro_batch_size: Optional[int] = None eval_prompt_key: Optional[str] = None - last_turn_loss_only: bool = False + last_turn_loss_only: Any = "auto" # bool or "auto" prompt_key: str = "conversations" train_data_path: str = "" diff --git a/torchspec/controller/inference_manager.py b/torchspec/controller/inference_manager.py index 4e032dd..5178e7b 100644 --- a/torchspec/controller/inference_manager.py +++ b/torchspec/controller/inference_manager.py @@ -481,6 +481,7 @@ def _parse_engine_output(self, entry: InferenceInput, output: dict) -> Inference tensor_shapes=output.get("tensor_shapes", {}), tensor_dtypes=output.get("tensor_dtypes", {}), packed_loss_mask=output.get("packed_loss_mask", entry.packed_loss_mask), + metadata=entry.metadata, ) async def _forward_results(self, results: list[tuple[InferenceInput, Any | Exception]]) -> int: diff --git a/torchspec/controller/training_controller.py b/torchspec/controller/training_controller.py index c25f959..d3e1e09 100644 --- a/torchspec/controller/training_controller.py +++ b/torchspec/controller/training_controller.py @@ -446,12 +446,15 @@ def _dispatch_to_queues( partitioned = self._partition_results(batch_results) for dp_rank, results in enumerate(partitioned): for result in results: + metadata = getattr(result, "metadata", {}) or {} + last_turn_loss_only = metadata.get("has_thinking") queues[dp_rank].put( TrainSample( mooncake_key=result.mooncake_key, tensor_shapes=result.tensor_shapes, tensor_dtypes=result.tensor_dtypes, packed_loss_mask=result.packed_loss_mask, + last_turn_loss_only=last_turn_loss_only, ) ) diff --git a/torchspec/data/dataset.py b/torchspec/data/dataset.py index 8ee6a07..264fbc0 100644 --- a/torchspec/data/dataset.py +++ b/torchspec/data/dataset.py @@ -26,7 +26,8 @@ import torch from tqdm import tqdm -from torchspec.data.preprocessing import _normalize_conversation +from torchspec.data.parse import create_parser, has_thinking_content +from torchspec.data.preprocessing import _normalize_conversation, preprocess_conversations from torchspec.data.template import TEMPLATE_REGISTRY from torchspec.data.utils import ( estimate_row_count, @@ -35,6 +36,7 @@ load_hf_dataset, ) from torchspec.utils.logging import logger +from torchspec.utils.processing import load_tokenizer _logging.getLogger("transformers_modules").setLevel(_logging.ERROR) @@ -45,9 +47,6 @@ def _init_tokenize_worker( tokenizer_path, trust_remote_code, chat_template_name, last_turn_loss_only=False ): """Initializer for each worker process — loads tokenizer once.""" - from torchspec.data.preprocessing import preprocess_conversations - from torchspec.utils.processing import load_tokenizer - _logging.getLogger("transformers_modules").setLevel(_logging.ERROR) _worker_state["tokenizer"] = load_tokenizer(tokenizer_path, trust_remote_code=trust_remote_code) _worker_state["template"] = TEMPLATE_REGISTRY.get(chat_template_name) @@ -55,6 +54,13 @@ def _init_tokenize_worker( _worker_state["last_turn_loss_only"] = last_turn_loss_only +def _resolve_last_turn_loss_only(messages): + ltlo = _worker_state.get("last_turn_loss_only", False) + if ltlo == "auto": + return has_thinking_content(messages) + return bool(ltlo) + + def _tokenize_single(args): """Worker function — tokenize one sample.""" messages, max_length, train_with_decode = args @@ -68,7 +74,7 @@ def _tokenize_single(args): use_packed_loss_mask=True, add_generation_prompt=train_with_decode, return_formatted_text=True, - last_turn_loss_only=_worker_state.get("last_turn_loss_only", False), + last_turn_loss_only=_resolve_last_turn_loss_only(messages), ) if not processed["input_ids"]: return None @@ -82,14 +88,14 @@ def _tokenize_single(args): } -def _init_format_worker(tokenizer_path, trust_remote_code, chat_template_name): - from torchspec.data.parse import create_parser - from torchspec.utils.processing import load_tokenizer - +def _init_format_worker( + tokenizer_path, trust_remote_code, chat_template_name, last_turn_loss_only=False +): _logging.getLogger("transformers_modules").setLevel(_logging.ERROR) tokenizer = load_tokenizer(tokenizer_path, trust_remote_code=trust_remote_code) _worker_state["template"] = TEMPLATE_REGISTRY.get(chat_template_name) _worker_state["parser"] = create_parser(tokenizer, _worker_state["template"]) + _worker_state["last_turn_loss_only"] = last_turn_loss_only def _format_single(args): @@ -98,11 +104,20 @@ def _format_single(args): """ messages, _, train_with_decode = args messages = _normalize_conversation(messages) + + result = {} + ltlo = _worker_state.get("last_turn_loss_only", False) + if ltlo == "auto": + result["has_thinking"] = has_thinking_content(messages) + parser = _worker_state["parser"] - formatted = parser.format(messages, add_generation_prompt=train_with_decode) + formatted = parser.format( + messages, add_generation_prompt=train_with_decode, expand_media_tokens=False + ) if not formatted: return None - return {"formatted_prompt": formatted} + result["formatted_prompt"] = formatted + return result def load_conversation_dataset(args): @@ -185,16 +200,17 @@ def load_conversation_dataset(args): # Pass 2: process in parallel work_items = [(messages, max_length, train_with_decode) for _, messages, _ in raw_samples] + last_turn_loss_only = getattr(args, "last_turn_loss_only", False) if defer_tokenization: worker_init = _init_format_worker - worker_initargs = (args.target_model_path, True, chat_template_name) + worker_initargs = (args.target_model_path, True, chat_template_name, last_turn_loss_only) worker_fn = _format_single desc = "Formatting dataset" else: - last_turn_loss_only = getattr(args, "last_turn_loss_only", False) if last_turn_loss_only: logger.info( - "last_turn_loss_only=True: loss mask will only cover the last assistant turn" + f"last_turn_loss_only={last_turn_loss_only}: " + "loss mask will only cover the last assistant turn" ) worker_init = _init_tokenize_worker worker_initargs = (args.target_model_path, True, chat_template_name, last_turn_loss_only) @@ -221,9 +237,13 @@ def load_conversation_dataset(args): if result is None: skipped += 1 continue + metadata = {} + if "has_thinking" in result: + metadata["has_thinking"] = result["has_thinking"] + entry = { "data_id": data_id, - "metadata": {}, + "metadata": metadata, "multimodal_inputs": multimodal_inputs, "formatted_prompt": result["formatted_prompt"], } diff --git a/torchspec/data/parse.py b/torchspec/data/parse.py index 0e13875..f91e0d0 100644 --- a/torchspec/data/parse.py +++ b/torchspec/data/parse.py @@ -33,7 +33,40 @@ Conversation = List[Dict[str, Any]] -__all__ = ["GeneralParser", "HarmonyParser", "KimiK25Parser", "create_parser"] +__all__ = [ + "GeneralParser", + "HarmonyParser", + "KimiK25Parser", + "create_parser", + "has_thinking_content", +] + +_HAS_THINKING_RE = re.compile(r"(?!\s*)") + + +def has_thinking_content(conversation: list) -> bool: + """Detect whether any assistant message contains real thinking content. + + Checks for non-empty blocks in message content and for + separate thinking/thinking_content or reasoning/reasoning_content + fields on the message dict. Must be called on the raw conversation + BEFORE formatting, since formatters (e.g. KimiK25Parser) inject + empty tags. + """ + for msg in conversation: + if not isinstance(msg, dict) or msg.get("role") != "assistant": + continue + content = msg.get("content", "") + if isinstance(content, str) and _HAS_THINKING_RE.search(content): + return True + if ( + msg.get("thinking") + or msg.get("thinking_content") + or msg.get("reasoning") + or msg.get("reasoning_content") + ): + return True + return False class Parser(ABC): diff --git a/torchspec/data/preprocessing.py b/torchspec/data/preprocessing.py index b0f8781..460e66f 100644 --- a/torchspec/data/preprocessing.py +++ b/torchspec/data/preprocessing.py @@ -76,7 +76,12 @@ def _normalize_conversation(conversation: Conversation) -> Conversation: normalized = [] for msg in conversation: role = ROLE_MAPPING.get(msg["from"], msg["from"]) - normalized.append({"role": role, "content": msg["value"]}) + entry = {"role": role, "content": msg["value"]} + for field in ("thinking", "thinking_content", "reasoning_content", "reasoning"): + if msg.get(field): + entry["reasoning_content"] = msg[field] + break + normalized.append(entry) return normalized return conversation diff --git a/torchspec/data/utils.py b/torchspec/data/utils.py index 4822f6f..f7a220b 100644 --- a/torchspec/data/utils.py +++ b/torchspec/data/utils.py @@ -47,16 +47,8 @@ def is_local_data_path(path: str, base_dir: str | None = None) -> bool: class DataCollatorWithPadding: - def __init__( - self, - assistant_header_ids: Optional[List[int]] = None, - end_token_ids: Optional[List[int]] = None, - dynamic_loss_mask: bool = False, - ): + def __init__(self): self.sp_degree = 1 - self.assistant_header_ids = assistant_header_ids - self.end_token_ids = end_token_ids - self.dynamic_loss_mask = dynamic_loss_mask def paddingtensor(self, intensors: torch.Tensor, N: int) -> torch.Tensor: B, n, S = intensors.shape @@ -71,33 +63,14 @@ def paddingtensor2D(self, intensors: torch.Tensor, N: int) -> torch.Tensor: return outtensors def _get_loss_mask(self, item: Dict[str, Any]) -> torch.Tensor: - """Derive loss_mask for a single sample. + """Read the materialized loss_mask tensor from the item. - Priority: - 1. dynamic_loss_mask flag → compute from input_ids token boundaries - 2. packed_loss_mask string → unpack RLE-encoded mask - 3. fallback → all-ones mask + Callers (e.g. MooncakeDataset) are responsible for computing and + attaching loss_mask before items reach the collator. """ - if self.dynamic_loss_mask: - if self.assistant_header_ids is None or self.end_token_ids is None: - raise ValueError( - "dynamic_loss_mask requires assistant_header_ids and " - "end_token_ids to be set on the collator" - ) - # TCP's input_ids are on CPU, so we can use input_ids_cpu directly. - input_ids = item.get("input_ids_cpu", item["input_ids"]) - if input_ids.dim() == 2: - input_ids = input_ids.squeeze(0) - mask = compute_assistant_loss_mask( - input_ids, self.assistant_header_ids, self.end_token_ids - ) - # Copy back to GPU. - return mask[None, :].to(item["input_ids"].device) - - if "packed_loss_mask" not in item or item["packed_loss_mask"] is None: - seq_len = item["input_ids"].shape[-1] - return torch.ones(1, seq_len, dtype=torch.long, device=item["input_ids"].device) - return unpack_loss_mask(item["packed_loss_mask"])[None, :] + if "loss_mask" in item and isinstance(item["loss_mask"], torch.Tensor): + return item["loss_mask"] + raise KeyError(f"loss_mask not found in item: {item}") def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]: max_length = max(item["input_ids"].shape[1] for item in features) @@ -219,6 +192,52 @@ def unpack_loss_mask(packed: Union[List[int], str]) -> torch.Tensor: return loss_mask +def resolve_loss_mask( + data: Dict[str, Any], + *, + dynamic_loss_mask: bool = False, + assistant_header_ids: Optional[List[int]] = None, + end_token_ids: Optional[List[int]] = None, + last_turn_loss_only: bool = False, + skip_after_header: int = 0, +) -> torch.Tensor | None: + """ + Two strategies, tried in order: + 1. ``packed_loss_mask`` key present → unpack it. + 2. ``dynamic_loss_mask`` enabled with valid header/end ids → compute from + ``input_ids`` via :func:`compute_assistant_loss_mask`. + """ + packed = data.get("packed_loss_mask") + if packed is not None: + mask = unpack_loss_mask(packed) + if not mask.any(): + return None + data["loss_mask"] = mask + return mask + + if dynamic_loss_mask and assistant_header_ids and end_token_ids: + input_ids = data.get("input_ids") + if input_ids is None: + return None + if input_ids.dim() == 2: + input_ids = input_ids.squeeze(0) + per_sample = data.get("last_turn_loss_only") + last_turn_only = per_sample if per_sample is not None else last_turn_loss_only + mask = compute_assistant_loss_mask( + input_ids, + assistant_header_ids, + end_token_ids, + last_turn_only=last_turn_only, + skip_after_header=skip_after_header, + ) + if not mask.any(): + return None + data["loss_mask"] = mask + return mask + + return torch.ones(1) + + def serialize_packed_loss_mask(packed: List[int]) -> str: """ Serialize packed loss_mask to a comma-separated string. diff --git a/torchspec/models/ops/loss_mask.py b/torchspec/models/ops/loss_mask.py index 7f5e78f..4da0abf 100644 --- a/torchspec/models/ops/loss_mask.py +++ b/torchspec/models/ops/loss_mask.py @@ -24,7 +24,7 @@ @numba.njit(cache=True) -def _numba_loss_mask(ids, header, header_len, end, end_len, out): +def _numba_loss_mask(ids, header, header_len, end, end_len, out, skip_after): n = len(ids) i = 0 while i <= n - header_len: @@ -36,7 +36,7 @@ def _numba_loss_mask(ids, header, header_len, end, end_len, out): if not match: i += 1 continue - j = i + header_len + j = i + header_len + skip_after found_end = False while j <= n - end_len: end_match = True @@ -69,6 +69,8 @@ def compute_assistant_loss_mask( input_ids: torch.Tensor, assistant_header_ids: list[int], end_token_ids: list[int], + last_turn_only: bool = False, + skip_after_header: int = 0, ) -> torch.Tensor: """Compute loss mask where 1s mark assistant content tokens only. @@ -80,6 +82,10 @@ def compute_assistant_loss_mask( input_ids: 1-D tensor of token IDs (CPU or CUDA). assistant_header_ids: Token ID sequence marking the start of assistant content. end_token_ids: Token ID sequence marking the end of assistant content. + last_turn_only: If True, only the last assistant turn is marked. + skip_after_header: Number of tokens to skip after header match before + marking content. Use 1 when the header excludes a trailing newline + that BPE may merge with subsequent content. Returns: 1-D long tensor on the same device as input_ids, with 1s for assistant @@ -94,6 +100,14 @@ def compute_assistant_loss_mask( header_np = np.array(assistant_header_ids, dtype=np.int64) end_np = np.array(end_token_ids, dtype=np.int64) out = np.zeros(len(ids_np), dtype=np.int64) - _numba_loss_mask(ids_np, header_np, len(header_np), end_np, len(end_np), out) + _numba_loss_mask(ids_np, header_np, len(header_np), end_np, len(end_np), out, skip_after_header) + + if last_turn_only and out.any(): + last_one = len(out) - 1 - np.argmax(out[::-1]) + first_of_last = last_one + while first_of_last > 0 and out[first_of_last - 1] == 1: + first_of_last -= 1 + out[:first_of_last] = 0 + result = torch.from_numpy(out) return result.to(device) if device.type != "cpu" else result diff --git a/torchspec/training/data_fetcher.py b/torchspec/training/data_fetcher.py index 18d3f5e..76fa62c 100644 --- a/torchspec/training/data_fetcher.py +++ b/torchspec/training/data_fetcher.py @@ -32,6 +32,7 @@ from ray.util.queue import Queue as RayQueue from torch.utils.data import DataLoader, IterableDataset +from torchspec.data.utils import resolve_loss_mask from torchspec.utils.logging import logger @@ -41,6 +42,7 @@ class TrainSample: tensor_shapes: Dict[str, Tuple[int, ...]] tensor_dtypes: Optional[Dict[str, torch.dtype]] = None packed_loss_mask: Optional[str] = None + last_turn_loss_only: Optional[bool] = None class MooncakeDataset(IterableDataset): @@ -57,20 +59,22 @@ def __init__( device: torch.device, prefetch_factor: int = 2, timeout: Optional[float] = None, + assistant_header_ids: Optional[List[int]] = None, + end_token_ids: Optional[List[int]] = None, + dynamic_loss_mask: bool = False, + last_turn_loss_only: bool = False, + skip_after_header: int = 0, ): - """ - Args: - ray_queue: Ray Queue to receive TrainSample from controller. - mooncake_store: Mooncake store client for loading tensors. - device: Target device for tensors. - prefetch_factor: Number of samples to prefetch in background thread. - timeout: Timeout in seconds for waiting on queue. None means wait forever. - """ self.ray_queue = ray_queue self.mooncake_store = mooncake_store self.device = device self.prefetch_factor = prefetch_factor self.timeout = timeout + self.assistant_header_ids = assistant_header_ids + self.end_token_ids = end_token_ids + self.dynamic_loss_mask = dynamic_loss_mask + self.last_turn_loss_only = last_turn_loss_only + self.skip_after_header = skip_after_header def _load_from_mooncake(self, sample: TrainSample) -> Dict[str, Any]: """Load tensors from mooncake key into device memory.""" @@ -104,6 +108,8 @@ def _load_from_mooncake(self, sample: TrainSample) -> Dict[str, Any]: result = tensors.to_tensor_dict() if sample.packed_loss_mask is not None: result["packed_loss_mask"] = sample.packed_loss_mask + if sample.last_turn_loss_only is not None: + result["last_turn_loss_only"] = sample.last_turn_loss_only return result def _cleanup_mooncake_data(self, sample: TrainSample) -> None: @@ -118,12 +124,24 @@ def _cleanup_mooncake_data(self, sample: TrainSample) -> None: has_target=has_target, ) + def _compute_loss_mask(self, data: Dict[str, Any]) -> torch.Tensor | None: + return resolve_loss_mask( + data, + dynamic_loss_mask=self.dynamic_loss_mask, + assistant_header_ids=self.assistant_header_ids, + end_token_ids=self.end_token_ids, + last_turn_loss_only=self.last_turn_loss_only, + skip_after_header=self.skip_after_header, + ) + def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]: """Iterate over samples synchronously. Blocks waiting for each item from the queue and loads from mooncake. + Skips samples whose loss mask is all zeros to avoid wasted compute. """ yield_count = 0 + skip_count = 0 while True: logger.debug(f"__iter__: waiting for item from ray_queue (yield_count={yield_count})") try: @@ -138,6 +156,15 @@ def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]: logger.debug(f"__iter__: got item, mooncake_key={item.mooncake_key}") data = self._load_from_mooncake(item) + + if self._compute_loss_mask(data) is None: + skip_count += 1 + logger.warning( + f"Skipping sample with all-zero loss mask " + f"(mooncake_key={item.mooncake_key}, total_skipped={skip_count})" + ) + continue + # Note: target is computed in the collator from last_hidden_states for sglang mode # Add batch dimension if missing (sglang stores without batch dim) @@ -174,6 +201,11 @@ def create_mooncake_dataloader( batch_size: int = 1, prefetch_factor: int = 2, timeout: Optional[float] = None, + assistant_header_ids: Optional[List[int]] = None, + end_token_ids: Optional[List[int]] = None, + dynamic_loss_mask: bool = False, + last_turn_loss_only: bool = False, + skip_after_header: int = 0, ) -> DataLoader: """Create a DataLoader that fetches from mooncake via queue. @@ -193,11 +225,26 @@ def create_mooncake_dataloader( batch_size: Number of samples per batch (= per_dp_rank_batch_size). prefetch_factor: Unused, kept for API compatibility. timeout: Timeout in seconds for waiting on queue. None means wait forever. + assistant_header_ids: Token IDs for assistant header (for loss mask skip check). + end_token_ids: Token IDs for end of turn (for loss mask skip check). + dynamic_loss_mask: Whether loss mask is computed dynamically from input_ids. + last_turn_loss_only: Global fallback for last-turn-only loss masking. Returns: DataLoader instance. """ - dataset = MooncakeDataset(ray_queue, mooncake_store, device, prefetch_factor, timeout) + dataset = MooncakeDataset( + ray_queue, + mooncake_store, + device, + prefetch_factor, + timeout, + assistant_header_ids=assistant_header_ids, + end_token_ids=end_token_ids, + dynamic_loss_mask=dynamic_loss_mask, + last_turn_loss_only=last_turn_loss_only, + skip_after_header=skip_after_header, + ) return DataLoader( dataset, @@ -232,6 +279,11 @@ def __init__( batch_size: int = 1, prefetch_factor: int = 2, timeout: Optional[float] = None, + assistant_header_ids: Optional[List[int]] = None, + end_token_ids: Optional[List[int]] = None, + dynamic_loss_mask: bool = False, + last_turn_loss_only: bool = False, + skip_after_header: int = 0, ): self.batch_size = batch_size self._dataloader = create_mooncake_dataloader( @@ -242,6 +294,11 @@ def __init__( batch_size=batch_size, prefetch_factor=prefetch_factor, timeout=timeout, + assistant_header_ids=assistant_header_ids, + end_token_ids=end_token_ids, + dynamic_loss_mask=dynamic_loss_mask, + last_turn_loss_only=last_turn_loss_only, + skip_after_header=skip_after_header, ) def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]: diff --git a/torchspec/training/trainer.py b/torchspec/training/trainer.py index 4358680..bd4aa71 100644 --- a/torchspec/training/trainer.py +++ b/torchspec/training/trainer.py @@ -78,7 +78,10 @@ def __init__(self, args: Namespace): self.prof = TrainProfiler(args) self.dynamic_loss_mask = getattr(args, "dynamic_loss_mask", False) - self.assistant_header_ids, self.end_token_ids = get_assistant_token_ids(self.args) + self.last_turn_loss_only = getattr(args, "last_turn_loss_only", False) + self.assistant_header_ids, self.end_token_ids, self.skip_after_header = ( + get_assistant_token_ids(self.args) + ) self.save_debug_train_data = getattr(args, "save_debug_train_data", None) self.max_dump_steps = getattr(args, "max_dump_steps", 5) @@ -155,11 +158,7 @@ def set_train_queue( if mooncake_config is not None and self.mooncake_store is None: self.init_mooncake_store(mooncake_config) - collator = DataCollatorWithPadding( - assistant_header_ids=self.assistant_header_ids, - end_token_ids=self.end_token_ids, - dynamic_loss_mask=self.dynamic_loss_mask, - ) + collator = DataCollatorWithPadding() self.data_fetcher = MooncakeDataFetcher( queue=self.train_queue, @@ -167,6 +166,11 @@ def set_train_queue( collator=collator, device=torch.cuda.current_device(), batch_size=per_dp_rank_batch_size, + assistant_header_ids=self.assistant_header_ids, + end_token_ids=self.end_token_ids, + dynamic_loss_mask=self.dynamic_loss_mask, + last_turn_loss_only=self.last_turn_loss_only, + skip_after_header=self.skip_after_header, ) logger.info( @@ -186,11 +190,7 @@ def set_eval_queue( if mooncake_config is not None and self.mooncake_store is None: self.init_mooncake_store(mooncake_config) - collator = DataCollatorWithPadding( - assistant_header_ids=self.assistant_header_ids, - end_token_ids=self.end_token_ids, - dynamic_loss_mask=self.dynamic_loss_mask, - ) + collator = DataCollatorWithPadding() self._eval_data_fetcher = MooncakeDataFetcher( queue=queue, @@ -198,6 +198,11 @@ def set_eval_queue( collator=collator, device=torch.cuda.current_device(), batch_size=per_dp_rank_batch_size, + assistant_header_ids=self.assistant_header_ids, + end_token_ids=self.end_token_ids, + dynamic_loss_mask=self.dynamic_loss_mask, + last_turn_loss_only=self.last_turn_loss_only, + skip_after_header=self.skip_after_header, ) self._eval_collator = collator self._eval_cache: list[dict] = [] diff --git a/torchspec/utils/processing.py b/torchspec/utils/processing.py index a06e315..974b63c 100644 --- a/torchspec/utils/processing.py +++ b/torchspec/utils/processing.py @@ -28,22 +28,41 @@ def load_tokenizer(name_or_path: str, **kwargs): return AutoTokenizer.from_pretrained(name_or_path, **kwargs) -def get_assistant_token_ids(args) -> tuple[list[int] | None, list[int] | None]: - """Derive assistant_header_ids and end_token_ids from chat_template config.""" +def get_assistant_token_ids( + args, +) -> tuple[list[int] | None, list[int] | None, int]: + """Derive assistant_header_ids, end_token_ids, and skip_after_header. + + Returns: + (header_ids, end_ids, skip_after_header) where skip_after_header is the + number of trailing-newline tokens stripped from the header to avoid BPE + merge issues. Pass this to ``compute_assistant_loss_mask`` so it skips + the formatting newline after the role name. + """ from torchspec.data.template import TEMPLATE_REGISTRY chat_template_name = getattr(args, "chat_template", None) if not chat_template_name: - return None, None + return None, None, 0 template = TEMPLATE_REGISTRY.get(chat_template_name) if not template.assistant_header or not template.end_of_turn_token: - return None, None + return None, None, 0 tokenizer = load_tokenizer(args.target_model_path, trust_remote_code=True) - assistant_header_ids = tokenizer.encode(template.assistant_header, add_special_tokens=False) + + full_ids = tokenizer.encode(template.assistant_header, add_special_tokens=False) + header_text = template.assistant_header.rstrip("\n") + stripped_ids = tokenizer.encode(header_text, add_special_tokens=False) + # BPE tokenizers may merge the header's trailing \n with content-leading + # newlines into a single multi-char token, breaking subsequence matching. + # We tokenize without the trailing \n and tell the mask function to skip + # the newline token(s) that follow. + skip_after_header = len(full_ids) - len(stripped_ids) + end_token_ids = tokenizer.encode(template.end_of_turn_token, add_special_tokens=False) logger.info( - f"Assistant loss mask token IDs: header={assistant_header_ids}, end={end_token_ids}" + f"Assistant loss mask token IDs: header={stripped_ids}, end={end_token_ids}, " + f"skip_after_header={skip_after_header}" ) - return assistant_header_ids, end_token_ids + return stripped_ids, end_token_ids, skip_after_header diff --git a/torchspec/utils/types.py b/torchspec/utils/types.py index ee4b329..9d9fa89 100644 --- a/torchspec/utils/types.py +++ b/torchspec/utils/types.py @@ -49,3 +49,4 @@ class InferenceOutput: tensor_shapes: dict[str, tuple[int, ...]] tensor_dtypes: dict[str, torch.dtype] | None = None packed_loss_mask: str | None = None + metadata: dict = field(default_factory=dict)