From 6c3a528e4150e3d320809b3392c05d4136067fb5 Mon Sep 17 00:00:00 2001 From: Yubo Wang Date: Mon, 9 Mar 2026 06:31:58 +0000 Subject: [PATCH 1/7] [Bug Fix] Support last_turn_loss_only in dynamic_loss_mask path MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The last_turn_loss_only flag was silently ignored when defer_tokenization was enabled (dynamic_loss_mask path). Thread the flag through compute_assistant_loss_mask → DataCollatorWithPadding → Trainer so it takes effect for deferred-tokenization workloads. --- tests/test_loss_mask.py | 46 +++++++++++++++++++++++++++++++ torchspec/data/utils.py | 27 +++++++++++++++++- torchspec/models/ops/loss_mask.py | 10 +++++++ torchspec/training/trainer.py | 3 ++ 4 files changed, 85 insertions(+), 1 deletion(-) diff --git a/tests/test_loss_mask.py b/tests/test_loss_mask.py index 01460dc..1192c3f 100644 --- a/tests/test_loss_mask.py +++ b/tests/test_loss_mask.py @@ -217,3 +217,49 @@ def test_200k_trailing_open_turn(self): result = compute_assistant_loss_mask(ids_open, self.H, self.E) assert torch.equal(result, ref) assert result[-100:].sum() == 100 + + +# ── last_turn_only ─────────────────────────────────────────────────── + + +class TestLastTurnOnly: + H = [10, 20] + E = [30, 40] + + def test_single_turn_unchanged(self): + ids = torch.tensor([10, 20, 1, 2, 3, 30, 40], dtype=torch.long) + result = compute_assistant_loss_mask(ids, self.H, self.E, last_turn_only=True) + assert result.tolist() == [0, 0, 1, 1, 1, 0, 0] + + def test_two_turns_keeps_last(self): + ids = torch.tensor([10, 20, 1, 2, 30, 40, 10, 20, 3, 4, 30, 40], dtype=torch.long) + result = compute_assistant_loss_mask(ids, self.H, self.E, last_turn_only=True) + assert result.tolist() == [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0] + + def test_three_turns_keeps_last(self): + ids = torch.tensor( + [10, 20, 1, 30, 40, 10, 20, 2, 30, 40, 10, 20, 3, 4, 5, 30, 40], dtype=torch.long + ) + result = compute_assistant_loss_mask(ids, self.H, self.E, last_turn_only=True) + assert result.tolist() == [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0] + + def test_truncated_last_turn(self): + """Sequence cut mid-response — last turn has no end token.""" + ids = torch.tensor([10, 20, 1, 30, 40, 10, 20, 2, 3], dtype=torch.long) + result = compute_assistant_loss_mask(ids, self.H, self.E, last_turn_only=True) + assert result.tolist() == [0, 0, 0, 0, 0, 0, 0, 1, 1] + + def test_no_assistant_turns(self): + ids = torch.tensor([5, 6, 7, 8], dtype=torch.long) + result = compute_assistant_loss_mask(ids, self.H, self.E, last_turn_only=True) + assert result.tolist() == [0, 0, 0, 0] + + def test_empty_input(self): + ids = torch.tensor([], dtype=torch.long) + result = compute_assistant_loss_mask(ids, self.H, self.E, last_turn_only=True) + assert result.tolist() == [] + + def test_false_flag_returns_all_turns(self): + ids = torch.tensor([10, 20, 1, 2, 30, 40, 10, 20, 3, 4, 30, 40], dtype=torch.long) + result = compute_assistant_loss_mask(ids, self.H, self.E, last_turn_only=False) + assert result.tolist() == [0, 0, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0] diff --git a/torchspec/data/utils.py b/torchspec/data/utils.py index 4822f6f..5e34167 100644 --- a/torchspec/data/utils.py +++ b/torchspec/data/utils.py @@ -22,6 +22,7 @@ import os from pathlib import Path from typing import Any, Dict, List, Optional, Union +from urllib.parse import urlparse import torch from datasets import IterableDataset, load_dataset @@ -29,6 +30,25 @@ from torchspec.models.ops.loss_mask import compute_assistant_loss_mask +_IMAGE_CACHE_DIR = os.environ.get("TORCHSPEC_IMAGE_CACHE", "/data/ywang/image_cache") + + +def resolve_image_url(url: str, cache_dir: str = _IMAGE_CACHE_DIR) -> str: + """Return local file path if a cached copy exists, otherwise the original URL.""" + if not url or not url.startswith("http"): + return url + parsed = urlparse(url) + local_path = os.path.join(cache_dir, parsed.netloc, parsed.path.lstrip("/")) + if os.path.isfile(local_path): + return local_path + return url + + +def resolve_image_urls(urls: list[str], cache_dir: str = _IMAGE_CACHE_DIR) -> list[str]: + """Resolve a list of image URLs to local paths where cached copies exist.""" + return [resolve_image_url(u, cache_dir) for u in urls] + + _LOCAL_DATA_EXTS = frozenset({".json", ".jsonl", ".parquet", ".arrow", ".csv", ".tsv", ".txt"}) @@ -52,11 +72,13 @@ def __init__( assistant_header_ids: Optional[List[int]] = None, end_token_ids: Optional[List[int]] = None, dynamic_loss_mask: bool = False, + last_turn_loss_only: bool = False, ): self.sp_degree = 1 self.assistant_header_ids = assistant_header_ids self.end_token_ids = end_token_ids self.dynamic_loss_mask = dynamic_loss_mask + self.last_turn_loss_only = last_turn_loss_only def paddingtensor(self, intensors: torch.Tensor, N: int) -> torch.Tensor: B, n, S = intensors.shape @@ -89,7 +111,10 @@ def _get_loss_mask(self, item: Dict[str, Any]) -> torch.Tensor: if input_ids.dim() == 2: input_ids = input_ids.squeeze(0) mask = compute_assistant_loss_mask( - input_ids, self.assistant_header_ids, self.end_token_ids + input_ids, + self.assistant_header_ids, + self.end_token_ids, + last_turn_only=self.last_turn_loss_only, ) # Copy back to GPU. return mask[None, :].to(item["input_ids"].device) diff --git a/torchspec/models/ops/loss_mask.py b/torchspec/models/ops/loss_mask.py index 7f5e78f..dc555b2 100644 --- a/torchspec/models/ops/loss_mask.py +++ b/torchspec/models/ops/loss_mask.py @@ -69,6 +69,7 @@ def compute_assistant_loss_mask( input_ids: torch.Tensor, assistant_header_ids: list[int], end_token_ids: list[int], + last_turn_only: bool = False, ) -> torch.Tensor: """Compute loss mask where 1s mark assistant content tokens only. @@ -80,6 +81,7 @@ def compute_assistant_loss_mask( input_ids: 1-D tensor of token IDs (CPU or CUDA). assistant_header_ids: Token ID sequence marking the start of assistant content. end_token_ids: Token ID sequence marking the end of assistant content. + last_turn_only: If True, only the last assistant turn is marked. Returns: 1-D long tensor on the same device as input_ids, with 1s for assistant @@ -95,5 +97,13 @@ def compute_assistant_loss_mask( end_np = np.array(end_token_ids, dtype=np.int64) out = np.zeros(len(ids_np), dtype=np.int64) _numba_loss_mask(ids_np, header_np, len(header_np), end_np, len(end_np), out) + + if last_turn_only and out.any(): + last_one = len(out) - 1 - np.argmax(out[::-1]) + first_of_last = last_one + while first_of_last > 0 and out[first_of_last - 1] == 1: + first_of_last -= 1 + out[:first_of_last] = 0 + result = torch.from_numpy(out) return result.to(device) if device.type != "cpu" else result diff --git a/torchspec/training/trainer.py b/torchspec/training/trainer.py index 4358680..dfe8bfe 100644 --- a/torchspec/training/trainer.py +++ b/torchspec/training/trainer.py @@ -78,6 +78,7 @@ def __init__(self, args: Namespace): self.prof = TrainProfiler(args) self.dynamic_loss_mask = getattr(args, "dynamic_loss_mask", False) + self.last_turn_loss_only = getattr(args, "last_turn_loss_only", False) self.assistant_header_ids, self.end_token_ids = get_assistant_token_ids(self.args) self.save_debug_train_data = getattr(args, "save_debug_train_data", None) @@ -159,6 +160,7 @@ def set_train_queue( assistant_header_ids=self.assistant_header_ids, end_token_ids=self.end_token_ids, dynamic_loss_mask=self.dynamic_loss_mask, + last_turn_loss_only=self.last_turn_loss_only, ) self.data_fetcher = MooncakeDataFetcher( @@ -190,6 +192,7 @@ def set_eval_queue( assistant_header_ids=self.assistant_header_ids, end_token_ids=self.end_token_ids, dynamic_loss_mask=self.dynamic_loss_mask, + last_turn_loss_only=self.last_turn_loss_only, ) self._eval_data_fetcher = MooncakeDataFetcher( From 6cb125cc6205855f8a99070f5c233da46436c1bc Mon Sep 17 00:00:00 2001 From: Yubo Wang Date: Mon, 9 Mar 2026 17:03:28 -0700 Subject: [PATCH 2/7] loss calculation refactoring --- tests/test_dynamic_last_turn_loss.py | 222 ++++++++++++++++++++ tests/test_kimi_k25_integration.py | 2 +- torchspec/config/train_config.py | 2 +- torchspec/controller/inference_manager.py | 1 + torchspec/controller/training_controller.py | 3 + torchspec/data/dataset.py | 40 +++- torchspec/data/parse.py | 23 +- torchspec/data/utils.py | 47 +---- torchspec/training/data_fetcher.py | 93 +++++++- torchspec/training/trainer.py | 22 +- torchspec/utils/types.py | 1 + 11 files changed, 393 insertions(+), 63 deletions(-) create mode 100644 tests/test_dynamic_last_turn_loss.py diff --git a/tests/test_dynamic_last_turn_loss.py b/tests/test_dynamic_last_turn_loss.py new file mode 100644 index 0000000..7900e30 --- /dev/null +++ b/tests/test_dynamic_last_turn_loss.py @@ -0,0 +1,222 @@ +"""Tests for dynamic per-sample last_turn_loss_only based on thinking content detection.""" + +import torch + +from torchspec.data.parse import has_thinking_content +from torchspec.data.utils import DataCollatorWithPadding + + +# ── has_thinking_content detection ──────────────────────────────────── + + +class TestHasThinkingContent: + def test_think_tag_with_content(self): + conv = [ + {"role": "user", "content": "Hi"}, + {"role": "assistant", "content": "reasoning hereHello!"}, + ] + assert has_thinking_content(conv) is True + + def test_empty_think_tag(self): + conv = [ + {"role": "user", "content": "Hi"}, + {"role": "assistant", "content": "Hello!"}, + ] + assert has_thinking_content(conv) is False + + def test_think_tag_whitespace_only(self): + conv = [ + {"role": "user", "content": "Hi"}, + {"role": "assistant", "content": " \n\t Hello!"}, + ] + assert has_thinking_content(conv) is False + + def test_think_tag_with_single_char(self): + conv = [ + {"role": "user", "content": "Hi"}, + {"role": "assistant", "content": " xanswer"}, + ] + assert has_thinking_content(conv) is True + + def test_no_think_tag(self): + conv = [ + {"role": "user", "content": "Hi"}, + {"role": "assistant", "content": "Hello there!"}, + ] + assert has_thinking_content(conv) is False + + def test_thinking_field(self): + conv = [ + {"role": "user", "content": "Hi"}, + {"role": "assistant", "content": "Hello!", "thinking": "some reasoning"}, + ] + assert has_thinking_content(conv) is True + + def test_thinking_content_field(self): + conv = [ + {"role": "user", "content": "Hi"}, + {"role": "assistant", "content": "Hello!", "thinking_content": "reasoning"}, + ] + assert has_thinking_content(conv) is True + + def test_empty_thinking_field(self): + conv = [ + {"role": "user", "content": "Hi"}, + {"role": "assistant", "content": "Hello!", "thinking": ""}, + ] + assert has_thinking_content(conv) is False + + def test_none_thinking_field(self): + conv = [ + {"role": "user", "content": "Hi"}, + {"role": "assistant", "content": "Hello!", "thinking": None}, + ] + assert has_thinking_content(conv) is False + + def test_user_message_with_think_tag_ignored(self): + conv = [ + {"role": "user", "content": "user put this hereHi"}, + {"role": "assistant", "content": "Hello!"}, + ] + assert has_thinking_content(conv) is False + + def test_multi_turn_thinking_in_earlier_turn(self): + conv = [ + {"role": "user", "content": "Q1"}, + {"role": "assistant", "content": "thoughtA1"}, + {"role": "user", "content": "Q2"}, + {"role": "assistant", "content": "A2"}, + ] + assert has_thinking_content(conv) is True + + def test_multi_turn_no_thinking(self): + conv = [ + {"role": "user", "content": "Q1"}, + {"role": "assistant", "content": "A1"}, + {"role": "user", "content": "Q2"}, + {"role": "assistant", "content": "A2"}, + ] + assert has_thinking_content(conv) is False + + def test_empty_conversation(self): + assert has_thinking_content([]) is False + + def test_non_dict_messages(self): + assert has_thinking_content(["not a dict"]) is False + + def test_multimodal_content_no_thinking(self): + conv = [ + {"role": "user", "content": [{"type": "text", "text": "Describe"}]}, + {"role": "assistant", "content": "A cat."}, + ] + assert has_thinking_content(conv) is False + + def test_system_message_ignored(self): + conv = [ + {"role": "system", "content": "system thinking"}, + {"role": "user", "content": "Hi"}, + {"role": "assistant", "content": "Hello!"}, + ] + assert has_thinking_content(conv) is False + + +# ── Collator reads pre-computed loss_mask from MooncakeDataset ──────── + + +class TestCollatorPrecomputedMask: + def _make_collator(self): + return DataCollatorWithPadding() + + def test_precomputed_loss_mask_used(self): + collator = self._make_collator() + mask_tensor = torch.tensor([[0, 0, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0]], dtype=torch.long) + item = { + "input_ids": torch.zeros(1, 12, dtype=torch.long), + "loss_mask": mask_tensor, + } + result = collator._get_loss_mask(item) + assert torch.equal(result, mask_tensor) + + def test_missing_loss_mask_raises(self): + import pytest + + collator = self._make_collator() + item = {"input_ids": torch.zeros(1, 4, dtype=torch.long)} + with pytest.raises(KeyError): + collator._get_loss_mask(item) + + +# ── MooncakeDataset._compute_loss_mask (single source of truth) ────── + + +class TestComputeLossMask: + """Test that _compute_loss_mask computes, stores, and skips correctly.""" + + def _make_dataset(self, dynamic=True, last_turn_loss_only=False): + from torchspec.training.data_fetcher import MooncakeDataset + + ds = MooncakeDataset.__new__(MooncakeDataset) + ds.assistant_header_ids = [10, 20] + ds.end_token_ids = [30, 40] + ds.dynamic_loss_mask = dynamic + ds.last_turn_loss_only = last_turn_loss_only + return ds + + def test_packed_loss_mask_nonzero(self): + ds = self._make_dataset(dynamic=False) + data = {"packed_loss_mask": "2,3,2"} + mask = ds._compute_loss_mask(data) + assert mask is not None + assert mask.tolist() == [0, 0, 1, 1, 1, 0, 0] + assert "loss_mask" in data + + def test_packed_loss_mask_zero(self): + ds = self._make_dataset(dynamic=False) + data = {"packed_loss_mask": "10"} + assert ds._compute_loss_mask(data) is None + + def test_dynamic_mask_nonzero(self): + ds = self._make_dataset(dynamic=True) + data = {"input_ids": torch.tensor([10, 20, 1, 2, 30, 40], dtype=torch.long)} + mask = ds._compute_loss_mask(data) + assert mask is not None + assert mask.tolist() == [0, 0, 1, 1, 0, 0] + assert "loss_mask" in data + + def test_dynamic_mask_zero(self): + ds = self._make_dataset(dynamic=True) + data = {"input_ids": torch.tensor([5, 6, 7, 8], dtype=torch.long)} + assert ds._compute_loss_mask(data) is None + + def test_dynamic_mask_per_sample_last_turn_only(self): + ds = self._make_dataset(dynamic=True, last_turn_loss_only=False) + ids = [10, 20, 1, 30, 40, 10, 20, 2, 30, 40] + data = { + "input_ids": torch.tensor(ids, dtype=torch.long), + "last_turn_loss_only": True, + } + mask = ds._compute_loss_mask(data) + assert mask is not None + assert mask.tolist() == [0, 0, 0, 0, 0, 0, 0, 1, 0, 0] + assert "loss_mask" in data + + def test_dynamic_mask_per_sample_all_turns(self): + ds = self._make_dataset(dynamic=True, last_turn_loss_only=True) + ids = [10, 20, 1, 2, 30, 40, 10, 20, 3, 4, 30, 40] + data = { + "input_ids": torch.tensor(ids, dtype=torch.long), + "last_turn_loss_only": False, + } + mask = ds._compute_loss_mask(data) + assert mask is not None + assert mask.tolist() == [0, 0, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0] + + def test_no_params_defaults_nonzero(self): + ds = self._make_dataset(dynamic=False) + data = {"input_ids": torch.tensor([1, 2, 3], dtype=torch.long)} + assert ds._compute_loss_mask(data) is not None + + def test_empty_packed_loss_mask(self): + ds = self._make_dataset(dynamic=False) + data = {"packed_loss_mask": ""} + assert ds._compute_loss_mask(data) is None diff --git a/tests/test_kimi_k25_integration.py b/tests/test_kimi_k25_integration.py index 144574f..19ef396 100644 --- a/tests/test_kimi_k25_integration.py +++ b/tests/test_kimi_k25_integration.py @@ -280,7 +280,7 @@ def test_preprocess_to_collator(self, mock_tokenizer, kimi_template, sample_conv { "input_ids": result["input_ids"][0], "attention_mask": result["attention_mask"][0], - "packed_loss_mask": result["packed_loss_mask"][0], + "loss_mask": unpack_loss_mask(result["packed_loss_mask"][0])[None, :], } ] diff --git a/torchspec/config/train_config.py b/torchspec/config/train_config.py index 20da08b..b84df92 100644 --- a/torchspec/config/train_config.py +++ b/torchspec/config/train_config.py @@ -39,7 +39,7 @@ class DatasetConfig: eval_interval: int = 50 eval_micro_batch_size: Optional[int] = None eval_prompt_key: Optional[str] = None - last_turn_loss_only: bool = False + last_turn_loss_only: Any = "auto" # bool or "auto" prompt_key: str = "conversations" train_data_path: str = "" diff --git a/torchspec/controller/inference_manager.py b/torchspec/controller/inference_manager.py index 4e032dd..5178e7b 100644 --- a/torchspec/controller/inference_manager.py +++ b/torchspec/controller/inference_manager.py @@ -481,6 +481,7 @@ def _parse_engine_output(self, entry: InferenceInput, output: dict) -> Inference tensor_shapes=output.get("tensor_shapes", {}), tensor_dtypes=output.get("tensor_dtypes", {}), packed_loss_mask=output.get("packed_loss_mask", entry.packed_loss_mask), + metadata=entry.metadata, ) async def _forward_results(self, results: list[tuple[InferenceInput, Any | Exception]]) -> int: diff --git a/torchspec/controller/training_controller.py b/torchspec/controller/training_controller.py index c25f959..d3e1e09 100644 --- a/torchspec/controller/training_controller.py +++ b/torchspec/controller/training_controller.py @@ -446,12 +446,15 @@ def _dispatch_to_queues( partitioned = self._partition_results(batch_results) for dp_rank, results in enumerate(partitioned): for result in results: + metadata = getattr(result, "metadata", {}) or {} + last_turn_loss_only = metadata.get("has_thinking") queues[dp_rank].put( TrainSample( mooncake_key=result.mooncake_key, tensor_shapes=result.tensor_shapes, tensor_dtypes=result.tensor_dtypes, packed_loss_mask=result.packed_loss_mask, + last_turn_loss_only=last_turn_loss_only, ) ) diff --git a/torchspec/data/dataset.py b/torchspec/data/dataset.py index 8ee6a07..f54f610 100644 --- a/torchspec/data/dataset.py +++ b/torchspec/data/dataset.py @@ -55,6 +55,15 @@ def _init_tokenize_worker( _worker_state["last_turn_loss_only"] = last_turn_loss_only +def _resolve_last_turn_loss_only(messages): + ltlo = _worker_state.get("last_turn_loss_only", False) + if ltlo == "auto": + from torchspec.data.parse import has_thinking_content + + return has_thinking_content(messages) + return bool(ltlo) + + def _tokenize_single(args): """Worker function — tokenize one sample.""" messages, max_length, train_with_decode = args @@ -68,7 +77,7 @@ def _tokenize_single(args): use_packed_loss_mask=True, add_generation_prompt=train_with_decode, return_formatted_text=True, - last_turn_loss_only=_worker_state.get("last_turn_loss_only", False), + last_turn_loss_only=_resolve_last_turn_loss_only(messages), ) if not processed["input_ids"]: return None @@ -82,7 +91,9 @@ def _tokenize_single(args): } -def _init_format_worker(tokenizer_path, trust_remote_code, chat_template_name): +def _init_format_worker( + tokenizer_path, trust_remote_code, chat_template_name, last_turn_loss_only=False +): from torchspec.data.parse import create_parser from torchspec.utils.processing import load_tokenizer @@ -90,6 +101,7 @@ def _init_format_worker(tokenizer_path, trust_remote_code, chat_template_name): tokenizer = load_tokenizer(tokenizer_path, trust_remote_code=trust_remote_code) _worker_state["template"] = TEMPLATE_REGISTRY.get(chat_template_name) _worker_state["parser"] = create_parser(tokenizer, _worker_state["template"]) + _worker_state["last_turn_loss_only"] = last_turn_loss_only def _format_single(args): @@ -98,11 +110,20 @@ def _format_single(args): """ messages, _, train_with_decode = args messages = _normalize_conversation(messages) + + result = {} + ltlo = _worker_state.get("last_turn_loss_only", False) + if ltlo == "auto": + from torchspec.data.parse import has_thinking_content + + result["has_thinking"] = has_thinking_content(messages) + parser = _worker_state["parser"] formatted = parser.format(messages, add_generation_prompt=train_with_decode) if not formatted: return None - return {"formatted_prompt": formatted} + result["formatted_prompt"] = formatted + return result def load_conversation_dataset(args): @@ -185,16 +206,17 @@ def load_conversation_dataset(args): # Pass 2: process in parallel work_items = [(messages, max_length, train_with_decode) for _, messages, _ in raw_samples] + last_turn_loss_only = getattr(args, "last_turn_loss_only", False) if defer_tokenization: worker_init = _init_format_worker - worker_initargs = (args.target_model_path, True, chat_template_name) + worker_initargs = (args.target_model_path, True, chat_template_name, last_turn_loss_only) worker_fn = _format_single desc = "Formatting dataset" else: - last_turn_loss_only = getattr(args, "last_turn_loss_only", False) if last_turn_loss_only: logger.info( - "last_turn_loss_only=True: loss mask will only cover the last assistant turn" + f"last_turn_loss_only={last_turn_loss_only}: " + "loss mask will only cover the last assistant turn" ) worker_init = _init_tokenize_worker worker_initargs = (args.target_model_path, True, chat_template_name, last_turn_loss_only) @@ -221,9 +243,13 @@ def load_conversation_dataset(args): if result is None: skipped += 1 continue + metadata = {} + if "has_thinking" in result: + metadata["has_thinking"] = result["has_thinking"] + entry = { "data_id": data_id, - "metadata": {}, + "metadata": metadata, "multimodal_inputs": multimodal_inputs, "formatted_prompt": result["formatted_prompt"], } diff --git a/torchspec/data/parse.py b/torchspec/data/parse.py index 0e13875..e3c636b 100644 --- a/torchspec/data/parse.py +++ b/torchspec/data/parse.py @@ -33,7 +33,28 @@ Conversation = List[Dict[str, Any]] -__all__ = ["GeneralParser", "HarmonyParser", "KimiK25Parser", "create_parser"] +__all__ = ["GeneralParser", "HarmonyParser", "KimiK25Parser", "create_parser", "has_thinking_content"] + +_HAS_THINKING_RE = re.compile(r"(?!\s*)") + + +def has_thinking_content(conversation: list) -> bool: + """Detect whether any assistant message contains real thinking content. + + Checks for non-empty blocks in message content and for + separate thinking/thinking_content fields on the message dict. + Must be called on the raw conversation BEFORE formatting, since + formatters (e.g. KimiK25Parser) inject empty tags. + """ + for msg in conversation: + if not isinstance(msg, dict) or msg.get("role") != "assistant": + continue + content = msg.get("content", "") + if isinstance(content, str) and _HAS_THINKING_RE.search(content): + return True + if msg.get("thinking") or msg.get("thinking_content"): + return True + return False class Parser(ABC): diff --git a/torchspec/data/utils.py b/torchspec/data/utils.py index 5e34167..292069c 100644 --- a/torchspec/data/utils.py +++ b/torchspec/data/utils.py @@ -28,7 +28,6 @@ from datasets import IterableDataset, load_dataset from huggingface_hub import hf_hub_download, list_repo_files -from torchspec.models.ops.loss_mask import compute_assistant_loss_mask _IMAGE_CACHE_DIR = os.environ.get("TORCHSPEC_IMAGE_CACHE", "/data/ywang/image_cache") @@ -67,18 +66,8 @@ def is_local_data_path(path: str, base_dir: str | None = None) -> bool: class DataCollatorWithPadding: - def __init__( - self, - assistant_header_ids: Optional[List[int]] = None, - end_token_ids: Optional[List[int]] = None, - dynamic_loss_mask: bool = False, - last_turn_loss_only: bool = False, - ): + def __init__(self): self.sp_degree = 1 - self.assistant_header_ids = assistant_header_ids - self.end_token_ids = end_token_ids - self.dynamic_loss_mask = dynamic_loss_mask - self.last_turn_loss_only = last_turn_loss_only def paddingtensor(self, intensors: torch.Tensor, N: int) -> torch.Tensor: B, n, S = intensors.shape @@ -93,36 +82,14 @@ def paddingtensor2D(self, intensors: torch.Tensor, N: int) -> torch.Tensor: return outtensors def _get_loss_mask(self, item: Dict[str, Any]) -> torch.Tensor: - """Derive loss_mask for a single sample. + """Read the materialized loss_mask tensor from the item. - Priority: - 1. dynamic_loss_mask flag → compute from input_ids token boundaries - 2. packed_loss_mask string → unpack RLE-encoded mask - 3. fallback → all-ones mask + Callers (e.g. MooncakeDataset) are responsible for computing and + attaching loss_mask before items reach the collator. """ - if self.dynamic_loss_mask: - if self.assistant_header_ids is None or self.end_token_ids is None: - raise ValueError( - "dynamic_loss_mask requires assistant_header_ids and " - "end_token_ids to be set on the collator" - ) - # TCP's input_ids are on CPU, so we can use input_ids_cpu directly. - input_ids = item.get("input_ids_cpu", item["input_ids"]) - if input_ids.dim() == 2: - input_ids = input_ids.squeeze(0) - mask = compute_assistant_loss_mask( - input_ids, - self.assistant_header_ids, - self.end_token_ids, - last_turn_only=self.last_turn_loss_only, - ) - # Copy back to GPU. - return mask[None, :].to(item["input_ids"].device) - - if "packed_loss_mask" not in item or item["packed_loss_mask"] is None: - seq_len = item["input_ids"].shape[-1] - return torch.ones(1, seq_len, dtype=torch.long, device=item["input_ids"].device) - return unpack_loss_mask(item["packed_loss_mask"])[None, :] + if "loss_mask" in item and isinstance(item["loss_mask"], torch.Tensor): + return item["loss_mask"] + raise KeyError(f"loss_mask not found in item: {item}") def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]: max_length = max(item["input_ids"].shape[1] for item in features) diff --git a/torchspec/training/data_fetcher.py b/torchspec/training/data_fetcher.py index 18d3f5e..565f8b2 100644 --- a/torchspec/training/data_fetcher.py +++ b/torchspec/training/data_fetcher.py @@ -32,6 +32,8 @@ from ray.util.queue import Queue as RayQueue from torch.utils.data import DataLoader, IterableDataset +from torchspec.data.utils import unpack_loss_mask +from torchspec.models.ops.loss_mask import compute_assistant_loss_mask from torchspec.utils.logging import logger @@ -41,6 +43,7 @@ class TrainSample: tensor_shapes: Dict[str, Tuple[int, ...]] tensor_dtypes: Optional[Dict[str, torch.dtype]] = None packed_loss_mask: Optional[str] = None + last_turn_loss_only: Optional[bool] = None class MooncakeDataset(IterableDataset): @@ -57,6 +60,10 @@ def __init__( device: torch.device, prefetch_factor: int = 2, timeout: Optional[float] = None, + assistant_header_ids: Optional[List[int]] = None, + end_token_ids: Optional[List[int]] = None, + dynamic_loss_mask: bool = False, + last_turn_loss_only: bool = False, ): """ Args: @@ -65,12 +72,20 @@ def __init__( device: Target device for tensors. prefetch_factor: Number of samples to prefetch in background thread. timeout: Timeout in seconds for waiting on queue. None means wait forever. + assistant_header_ids: Token IDs for assistant header (for loss mask skip check). + end_token_ids: Token IDs for end of turn (for loss mask skip check). + dynamic_loss_mask: Whether loss mask is computed dynamically from input_ids. + last_turn_loss_only: Global fallback for last-turn-only loss masking. """ self.ray_queue = ray_queue self.mooncake_store = mooncake_store self.device = device self.prefetch_factor = prefetch_factor self.timeout = timeout + self.assistant_header_ids = assistant_header_ids + self.end_token_ids = end_token_ids + self.dynamic_loss_mask = dynamic_loss_mask + self.last_turn_loss_only = last_turn_loss_only def _load_from_mooncake(self, sample: TrainSample) -> Dict[str, Any]: """Load tensors from mooncake key into device memory.""" @@ -104,6 +119,8 @@ def _load_from_mooncake(self, sample: TrainSample) -> Dict[str, Any]: result = tensors.to_tensor_dict() if sample.packed_loss_mask is not None: result["packed_loss_mask"] = sample.packed_loss_mask + if sample.last_turn_loss_only is not None: + result["last_turn_loss_only"] = sample.last_turn_loss_only return result def _cleanup_mooncake_data(self, sample: TrainSample) -> None: @@ -118,12 +135,51 @@ def _cleanup_mooncake_data(self, sample: TrainSample) -> None: has_target=has_target, ) + def _compute_loss_mask(self, data: Dict[str, Any]) -> torch.Tensor | None: + """Compute the loss mask for a sample and store it on the data dict. + + This is the single place where loss masks are resolved for mooncake + samples, so the collator never needs to recompute. + + Returns the 1-D mask tensor, or None if the mask is all zeros. + """ + packed = data.get("packed_loss_mask") + if packed is not None: + mask = unpack_loss_mask(packed) + if not mask.any(): + return None + data["loss_mask"] = mask + return mask + + if self.dynamic_loss_mask and self.assistant_header_ids and self.end_token_ids: + input_ids = data.get("input_ids") + if input_ids is None: + return None + if input_ids.dim() == 2: + input_ids = input_ids.squeeze(0) + per_sample = data.get("last_turn_loss_only") + last_turn_only = per_sample if per_sample is not None else self.last_turn_loss_only + mask = compute_assistant_loss_mask( + input_ids, + self.assistant_header_ids, + self.end_token_ids, + last_turn_only=last_turn_only, + ) + if not mask.any(): + return None + data["loss_mask"] = mask + return mask + + return torch.ones(1) # non-None signals "don't skip" + def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]: """Iterate over samples synchronously. Blocks waiting for each item from the queue and loads from mooncake. + Skips samples whose loss mask is all zeros to avoid wasted compute. """ yield_count = 0 + skip_count = 0 while True: logger.debug(f"__iter__: waiting for item from ray_queue (yield_count={yield_count})") try: @@ -138,6 +194,15 @@ def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]: logger.debug(f"__iter__: got item, mooncake_key={item.mooncake_key}") data = self._load_from_mooncake(item) + + if self._compute_loss_mask(data) is None: + skip_count += 1 + logger.warning( + f"Skipping sample with all-zero loss mask " + f"(mooncake_key={item.mooncake_key}, total_skipped={skip_count})" + ) + continue + # Note: target is computed in the collator from last_hidden_states for sglang mode # Add batch dimension if missing (sglang stores without batch dim) @@ -174,6 +239,10 @@ def create_mooncake_dataloader( batch_size: int = 1, prefetch_factor: int = 2, timeout: Optional[float] = None, + assistant_header_ids: Optional[List[int]] = None, + end_token_ids: Optional[List[int]] = None, + dynamic_loss_mask: bool = False, + last_turn_loss_only: bool = False, ) -> DataLoader: """Create a DataLoader that fetches from mooncake via queue. @@ -193,11 +262,25 @@ def create_mooncake_dataloader( batch_size: Number of samples per batch (= per_dp_rank_batch_size). prefetch_factor: Unused, kept for API compatibility. timeout: Timeout in seconds for waiting on queue. None means wait forever. + assistant_header_ids: Token IDs for assistant header (for loss mask skip check). + end_token_ids: Token IDs for end of turn (for loss mask skip check). + dynamic_loss_mask: Whether loss mask is computed dynamically from input_ids. + last_turn_loss_only: Global fallback for last-turn-only loss masking. Returns: DataLoader instance. """ - dataset = MooncakeDataset(ray_queue, mooncake_store, device, prefetch_factor, timeout) + dataset = MooncakeDataset( + ray_queue, + mooncake_store, + device, + prefetch_factor, + timeout, + assistant_header_ids=assistant_header_ids, + end_token_ids=end_token_ids, + dynamic_loss_mask=dynamic_loss_mask, + last_turn_loss_only=last_turn_loss_only, + ) return DataLoader( dataset, @@ -232,6 +315,10 @@ def __init__( batch_size: int = 1, prefetch_factor: int = 2, timeout: Optional[float] = None, + assistant_header_ids: Optional[List[int]] = None, + end_token_ids: Optional[List[int]] = None, + dynamic_loss_mask: bool = False, + last_turn_loss_only: bool = False, ): self.batch_size = batch_size self._dataloader = create_mooncake_dataloader( @@ -242,6 +329,10 @@ def __init__( batch_size=batch_size, prefetch_factor=prefetch_factor, timeout=timeout, + assistant_header_ids=assistant_header_ids, + end_token_ids=end_token_ids, + dynamic_loss_mask=dynamic_loss_mask, + last_turn_loss_only=last_turn_loss_only, ) def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]: diff --git a/torchspec/training/trainer.py b/torchspec/training/trainer.py index dfe8bfe..e4f9da7 100644 --- a/torchspec/training/trainer.py +++ b/torchspec/training/trainer.py @@ -156,12 +156,7 @@ def set_train_queue( if mooncake_config is not None and self.mooncake_store is None: self.init_mooncake_store(mooncake_config) - collator = DataCollatorWithPadding( - assistant_header_ids=self.assistant_header_ids, - end_token_ids=self.end_token_ids, - dynamic_loss_mask=self.dynamic_loss_mask, - last_turn_loss_only=self.last_turn_loss_only, - ) + collator = DataCollatorWithPadding() self.data_fetcher = MooncakeDataFetcher( queue=self.train_queue, @@ -169,6 +164,10 @@ def set_train_queue( collator=collator, device=torch.cuda.current_device(), batch_size=per_dp_rank_batch_size, + assistant_header_ids=self.assistant_header_ids, + end_token_ids=self.end_token_ids, + dynamic_loss_mask=self.dynamic_loss_mask, + last_turn_loss_only=self.last_turn_loss_only, ) logger.info( @@ -188,12 +187,7 @@ def set_eval_queue( if mooncake_config is not None and self.mooncake_store is None: self.init_mooncake_store(mooncake_config) - collator = DataCollatorWithPadding( - assistant_header_ids=self.assistant_header_ids, - end_token_ids=self.end_token_ids, - dynamic_loss_mask=self.dynamic_loss_mask, - last_turn_loss_only=self.last_turn_loss_only, - ) + collator = DataCollatorWithPadding() self._eval_data_fetcher = MooncakeDataFetcher( queue=queue, @@ -201,6 +195,10 @@ def set_eval_queue( collator=collator, device=torch.cuda.current_device(), batch_size=per_dp_rank_batch_size, + assistant_header_ids=self.assistant_header_ids, + end_token_ids=self.end_token_ids, + dynamic_loss_mask=self.dynamic_loss_mask, + last_turn_loss_only=self.last_turn_loss_only, ) self._eval_collator = collator self._eval_cache: list[dict] = [] diff --git a/torchspec/utils/types.py b/torchspec/utils/types.py index ee4b329..9d9fa89 100644 --- a/torchspec/utils/types.py +++ b/torchspec/utils/types.py @@ -49,3 +49,4 @@ class InferenceOutput: tensor_shapes: dict[str, tuple[int, ...]] tensor_dtypes: dict[str, torch.dtype] | None = None packed_loss_mask: str | None = None + metadata: dict = field(default_factory=dict) From 1e926404c31e33f636850195afdf3c724d7bf730 Mon Sep 17 00:00:00 2001 From: Yubo Wang Date: Mon, 9 Mar 2026 17:08:02 -0700 Subject: [PATCH 3/7] fix ruff --- tests/test_dynamic_last_turn_loss.py | 1 - torchspec/data/parse.py | 8 +++++++- torchspec/data/utils.py | 3 +-- 3 files changed, 8 insertions(+), 4 deletions(-) diff --git a/tests/test_dynamic_last_turn_loss.py b/tests/test_dynamic_last_turn_loss.py index 7900e30..35dd45a 100644 --- a/tests/test_dynamic_last_turn_loss.py +++ b/tests/test_dynamic_last_turn_loss.py @@ -5,7 +5,6 @@ from torchspec.data.parse import has_thinking_content from torchspec.data.utils import DataCollatorWithPadding - # ── has_thinking_content detection ──────────────────────────────────── diff --git a/torchspec/data/parse.py b/torchspec/data/parse.py index e3c636b..3912431 100644 --- a/torchspec/data/parse.py +++ b/torchspec/data/parse.py @@ -33,7 +33,13 @@ Conversation = List[Dict[str, Any]] -__all__ = ["GeneralParser", "HarmonyParser", "KimiK25Parser", "create_parser", "has_thinking_content"] +__all__ = [ + "GeneralParser", + "HarmonyParser", + "KimiK25Parser", + "create_parser", + "has_thinking_content", +] _HAS_THINKING_RE = re.compile(r"(?!\s*)") diff --git a/torchspec/data/utils.py b/torchspec/data/utils.py index 292069c..a73acc7 100644 --- a/torchspec/data/utils.py +++ b/torchspec/data/utils.py @@ -21,14 +21,13 @@ import json import os from pathlib import Path -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Union from urllib.parse import urlparse import torch from datasets import IterableDataset, load_dataset from huggingface_hub import hf_hub_download, list_repo_files - _IMAGE_CACHE_DIR = os.environ.get("TORCHSPEC_IMAGE_CACHE", "/data/ywang/image_cache") From 0a83740b08dd9d8cf8a12c0d651b7de7f53a84fd Mon Sep 17 00:00:00 2001 From: Yubo Wang Date: Mon, 16 Mar 2026 09:06:41 +0000 Subject: [PATCH 4/7] fix: handle BPE merge in dynamic loss mask via skip_after_header The assistant header's trailing newline can merge with content tokens under BPE, breaking subsequence matching. Strip the newline from the header pattern and skip those token(s) when scanning for assistant content boundaries. Also fixes FSDP sync for batches with no valid loss-mask tokens in Eagle3 by using a dummy index through the compiled graph instead of an early-return path. --- tests/test_dynamic_last_turn_loss.py | 1 + tests/test_loss_mask_cross_validation.py | 185 +++++++++++++++++++++++ torchspec/models/ops/loss_mask.py | 10 +- torchspec/training/data_fetcher.py | 19 +-- torchspec/training/trainer.py | 6 +- torchspec/utils/processing.py | 33 +++- 6 files changed, 231 insertions(+), 23 deletions(-) create mode 100644 tests/test_loss_mask_cross_validation.py diff --git a/tests/test_dynamic_last_turn_loss.py b/tests/test_dynamic_last_turn_loss.py index 35dd45a..a89726e 100644 --- a/tests/test_dynamic_last_turn_loss.py +++ b/tests/test_dynamic_last_turn_loss.py @@ -159,6 +159,7 @@ def _make_dataset(self, dynamic=True, last_turn_loss_only=False): ds.end_token_ids = [30, 40] ds.dynamic_loss_mask = dynamic ds.last_turn_loss_only = last_turn_loss_only + ds.skip_after_header = 0 return ds def test_packed_loss_mask_nonzero(self): diff --git a/tests/test_loss_mask_cross_validation.py b/tests/test_loss_mask_cross_validation.py new file mode 100644 index 0000000..65a37af --- /dev/null +++ b/tests/test_loss_mask_cross_validation.py @@ -0,0 +1,185 @@ +"""Cross-validation: compute_assistant_loss_mask must match parser.parse for every template. + +CI gate — auto-parametrizes over all registered templates so adding a new +template without verifying loss mask correctness will fail the test. + +Requires tokenizer downloads (~10MB each). Templates whose tokenizer is +unavailable are skipped. +""" + +import pytest +import torch + +from torchspec.data.parse import create_parser +from torchspec.data.template import TEMPLATE_REGISTRY +from torchspec.models.ops.loss_mask import compute_assistant_loss_mask + +_END_TOKEN_BPE_XFAIL: set[str] = set() + +MESSAGES_MULTI_TURN = [ + {"role": "user", "content": "What is 2+2?"}, + {"role": "assistant", "content": "The answer is 4."}, + {"role": "user", "content": "And 3+3?"}, + {"role": "assistant", "content": "The answer is 6."}, +] + +MESSAGES_LEADING_NEWLINES = [ + {"role": "user", "content": "What is 2+2?"}, + {"role": "assistant", "content": "\n\nThe answer is 4."}, +] + +_tokenizer_cache: dict = {} + + +def _testable_templates(): + """Templates that support dynamic loss mask (need both header and end token).""" + for name in TEMPLATE_REGISTRY.get_all_template_names(): + template = TEMPLATE_REGISTRY.get(name) + if template.assistant_header and template.end_of_turn_token: + yield name + + +def _get_tokenizer(template_name): + template = TEMPLATE_REGISTRY.get(template_name) + model_path = template.reference_model + if model_path is None: + pytest.skip(f"No reference_model for template '{template_name}'") + + if model_path in _tokenizer_cache: + return _tokenizer_cache[model_path] + + from transformers import AutoTokenizer + + try: + tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) + except Exception as e: + pytest.skip(f"Tokenizer unavailable for {model_path}: {e}") + + _tokenizer_cache[model_path] = tokenizer + return tokenizer + + +def _get_header_ids_and_skip(tokenizer, template): + """Replicate get_assistant_token_ids() logic for the test.""" + full_ids = tokenizer.encode(template.assistant_header, add_special_tokens=False) + stripped = template.assistant_header.rstrip("\n") + stripped_ids = tokenizer.encode(stripped, add_special_tokens=False) + skip_after = len(full_ids) - len(stripped_ids) + end_ids = tokenizer.encode(template.end_of_turn_token, add_special_tokens=False) + return stripped_ids, end_ids, skip_after + + +def _first_diff(a, b): + for i, (x, y) in enumerate(zip(a, b)): + if x != y: + return i + if len(a) != len(b): + return min(len(a), len(b)) + return None + + +def _strip_end_tokens_from_mask(mask_list, ids_list, end_ids): + """Zero out end-of-turn token positions in a mask. + + The parser regex captures content *including* the end_of_turn_token, + but compute_assistant_loss_mask marks content *excluding* it. + This normalizes the parser mask to match the dynamic mask convention. + """ + result = list(mask_list) + end_len = len(end_ids) + i = 0 + while i <= len(ids_list) - end_len: + if ids_list[i : i + end_len] == end_ids: + for k in range(end_len): + result[i + k] = 0 + i += end_len + else: + i += 1 + return result + + +def _cross_validate(template_name, messages, last_turn_only=False): + """Core cross-validation: parser ground truth vs dynamic mask.""" + template = TEMPLATE_REGISTRY.get(template_name) + tokenizer = _get_tokenizer(template_name) + parser = create_parser(tokenizer, template) + header_ids, end_ids, skip_after = _get_header_ids_and_skip(tokenizer, template) + + formatted = parser.format(messages, add_generation_prompt=False, expand_media_tokens=False) + + gt_ids, gt_mask = parser.parse( + formatted, max_length=200000, preformatted=True, last_turn_only=last_turn_only + ) + gt_ids_list = gt_ids.squeeze().tolist() + gt_mask_list = gt_mask.squeeze().tolist() + + if sum(gt_mask_list) == 0: + pytest.skip( + f"Parser produced all-zero mask for '{template_name}' — " + f"template header may not match tokenizer's chat template" + ) + + engine_ids = tokenizer.encode(formatted, add_special_tokens=False) + + assert gt_ids_list == engine_ids, ( + f"[{template_name}] Tokenization mismatch: " + f"parser produced {len(gt_ids_list)} tokens, " + f"tokenizer.encode produced {len(engine_ids)} tokens" + ) + + # Parser includes end_of_turn_token in mask; dynamic mask does not. + normalized_gt = _strip_end_tokens_from_mask(gt_mask_list, engine_ids, end_ids) + + assert sum(normalized_gt) > 0, ( + f"[{template_name}] After stripping end tokens, parser mask is all zeros — " + f"this likely means the parser only matched end tokens as content" + ) + + dyn_mask = compute_assistant_loss_mask( + torch.tensor(engine_ids), + header_ids, + end_ids, + last_turn_only=last_turn_only, + skip_after_header=skip_after, + ) + + diff_idx = _first_diff(normalized_gt, dyn_mask.tolist()) + assert normalized_gt == dyn_mask.tolist(), ( + f"[{template_name}] Loss mask mismatch " + f"(last_turn_only={last_turn_only}):\n" + f" parser mask 1s (normalized): {sum(normalized_gt)}\n" + f" dynamic mask 1s: {dyn_mask.sum().item()}\n" + f" first diff at token {diff_idx}" + ) + + +# ── Parametrized tests ─────────────────────────────────────────────── + + +def _maybe_xfail(template_name): + if template_name in _END_TOKEN_BPE_XFAIL: + pytest.xfail( + f"'{template_name}' has a known end-token BPE merge issue — " + f"use precomputed masks (defer_tokenization=False) for this template" + ) + + +@pytest.mark.parametrize("template_name", list(_testable_templates())) +def test_multi_turn(template_name): + """Dynamic mask matches parser on a standard multi-turn conversation.""" + _maybe_xfail(template_name) + _cross_validate(template_name, MESSAGES_MULTI_TURN) + + +@pytest.mark.parametrize("template_name", list(_testable_templates())) +def test_multi_turn_last_turn_only(template_name): + """Dynamic mask matches parser with last_turn_only=True.""" + _maybe_xfail(template_name) + _cross_validate(template_name, MESSAGES_MULTI_TURN, last_turn_only=True) + + +@pytest.mark.parametrize("template_name", list(_testable_templates())) +def test_leading_newlines(template_name): + """Dynamic mask matches parser when content starts with newlines (BPE merge edge case).""" + _maybe_xfail(template_name) + _cross_validate(template_name, MESSAGES_LEADING_NEWLINES) diff --git a/torchspec/models/ops/loss_mask.py b/torchspec/models/ops/loss_mask.py index dc555b2..4da0abf 100644 --- a/torchspec/models/ops/loss_mask.py +++ b/torchspec/models/ops/loss_mask.py @@ -24,7 +24,7 @@ @numba.njit(cache=True) -def _numba_loss_mask(ids, header, header_len, end, end_len, out): +def _numba_loss_mask(ids, header, header_len, end, end_len, out, skip_after): n = len(ids) i = 0 while i <= n - header_len: @@ -36,7 +36,7 @@ def _numba_loss_mask(ids, header, header_len, end, end_len, out): if not match: i += 1 continue - j = i + header_len + j = i + header_len + skip_after found_end = False while j <= n - end_len: end_match = True @@ -70,6 +70,7 @@ def compute_assistant_loss_mask( assistant_header_ids: list[int], end_token_ids: list[int], last_turn_only: bool = False, + skip_after_header: int = 0, ) -> torch.Tensor: """Compute loss mask where 1s mark assistant content tokens only. @@ -82,6 +83,9 @@ def compute_assistant_loss_mask( assistant_header_ids: Token ID sequence marking the start of assistant content. end_token_ids: Token ID sequence marking the end of assistant content. last_turn_only: If True, only the last assistant turn is marked. + skip_after_header: Number of tokens to skip after header match before + marking content. Use 1 when the header excludes a trailing newline + that BPE may merge with subsequent content. Returns: 1-D long tensor on the same device as input_ids, with 1s for assistant @@ -96,7 +100,7 @@ def compute_assistant_loss_mask( header_np = np.array(assistant_header_ids, dtype=np.int64) end_np = np.array(end_token_ids, dtype=np.int64) out = np.zeros(len(ids_np), dtype=np.int64) - _numba_loss_mask(ids_np, header_np, len(header_np), end_np, len(end_np), out) + _numba_loss_mask(ids_np, header_np, len(header_np), end_np, len(end_np), out, skip_after_header) if last_turn_only and out.any(): last_one = len(out) - 1 - np.argmax(out[::-1]) diff --git a/torchspec/training/data_fetcher.py b/torchspec/training/data_fetcher.py index 565f8b2..6681d3a 100644 --- a/torchspec/training/data_fetcher.py +++ b/torchspec/training/data_fetcher.py @@ -64,19 +64,8 @@ def __init__( end_token_ids: Optional[List[int]] = None, dynamic_loss_mask: bool = False, last_turn_loss_only: bool = False, + skip_after_header: int = 0, ): - """ - Args: - ray_queue: Ray Queue to receive TrainSample from controller. - mooncake_store: Mooncake store client for loading tensors. - device: Target device for tensors. - prefetch_factor: Number of samples to prefetch in background thread. - timeout: Timeout in seconds for waiting on queue. None means wait forever. - assistant_header_ids: Token IDs for assistant header (for loss mask skip check). - end_token_ids: Token IDs for end of turn (for loss mask skip check). - dynamic_loss_mask: Whether loss mask is computed dynamically from input_ids. - last_turn_loss_only: Global fallback for last-turn-only loss masking. - """ self.ray_queue = ray_queue self.mooncake_store = mooncake_store self.device = device @@ -86,6 +75,7 @@ def __init__( self.end_token_ids = end_token_ids self.dynamic_loss_mask = dynamic_loss_mask self.last_turn_loss_only = last_turn_loss_only + self.skip_after_header = skip_after_header def _load_from_mooncake(self, sample: TrainSample) -> Dict[str, Any]: """Load tensors from mooncake key into device memory.""" @@ -164,6 +154,7 @@ def _compute_loss_mask(self, data: Dict[str, Any]) -> torch.Tensor | None: self.assistant_header_ids, self.end_token_ids, last_turn_only=last_turn_only, + skip_after_header=self.skip_after_header, ) if not mask.any(): return None @@ -243,6 +234,7 @@ def create_mooncake_dataloader( end_token_ids: Optional[List[int]] = None, dynamic_loss_mask: bool = False, last_turn_loss_only: bool = False, + skip_after_header: int = 0, ) -> DataLoader: """Create a DataLoader that fetches from mooncake via queue. @@ -280,6 +272,7 @@ def create_mooncake_dataloader( end_token_ids=end_token_ids, dynamic_loss_mask=dynamic_loss_mask, last_turn_loss_only=last_turn_loss_only, + skip_after_header=skip_after_header, ) return DataLoader( @@ -319,6 +312,7 @@ def __init__( end_token_ids: Optional[List[int]] = None, dynamic_loss_mask: bool = False, last_turn_loss_only: bool = False, + skip_after_header: int = 0, ): self.batch_size = batch_size self._dataloader = create_mooncake_dataloader( @@ -333,6 +327,7 @@ def __init__( end_token_ids=end_token_ids, dynamic_loss_mask=dynamic_loss_mask, last_turn_loss_only=last_turn_loss_only, + skip_after_header=skip_after_header, ) def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]: diff --git a/torchspec/training/trainer.py b/torchspec/training/trainer.py index e4f9da7..bd4aa71 100644 --- a/torchspec/training/trainer.py +++ b/torchspec/training/trainer.py @@ -79,7 +79,9 @@ def __init__(self, args: Namespace): self.dynamic_loss_mask = getattr(args, "dynamic_loss_mask", False) self.last_turn_loss_only = getattr(args, "last_turn_loss_only", False) - self.assistant_header_ids, self.end_token_ids = get_assistant_token_ids(self.args) + self.assistant_header_ids, self.end_token_ids, self.skip_after_header = ( + get_assistant_token_ids(self.args) + ) self.save_debug_train_data = getattr(args, "save_debug_train_data", None) self.max_dump_steps = getattr(args, "max_dump_steps", 5) @@ -168,6 +170,7 @@ def set_train_queue( end_token_ids=self.end_token_ids, dynamic_loss_mask=self.dynamic_loss_mask, last_turn_loss_only=self.last_turn_loss_only, + skip_after_header=self.skip_after_header, ) logger.info( @@ -199,6 +202,7 @@ def set_eval_queue( end_token_ids=self.end_token_ids, dynamic_loss_mask=self.dynamic_loss_mask, last_turn_loss_only=self.last_turn_loss_only, + skip_after_header=self.skip_after_header, ) self._eval_collator = collator self._eval_cache: list[dict] = [] diff --git a/torchspec/utils/processing.py b/torchspec/utils/processing.py index a06e315..974b63c 100644 --- a/torchspec/utils/processing.py +++ b/torchspec/utils/processing.py @@ -28,22 +28,41 @@ def load_tokenizer(name_or_path: str, **kwargs): return AutoTokenizer.from_pretrained(name_or_path, **kwargs) -def get_assistant_token_ids(args) -> tuple[list[int] | None, list[int] | None]: - """Derive assistant_header_ids and end_token_ids from chat_template config.""" +def get_assistant_token_ids( + args, +) -> tuple[list[int] | None, list[int] | None, int]: + """Derive assistant_header_ids, end_token_ids, and skip_after_header. + + Returns: + (header_ids, end_ids, skip_after_header) where skip_after_header is the + number of trailing-newline tokens stripped from the header to avoid BPE + merge issues. Pass this to ``compute_assistant_loss_mask`` so it skips + the formatting newline after the role name. + """ from torchspec.data.template import TEMPLATE_REGISTRY chat_template_name = getattr(args, "chat_template", None) if not chat_template_name: - return None, None + return None, None, 0 template = TEMPLATE_REGISTRY.get(chat_template_name) if not template.assistant_header or not template.end_of_turn_token: - return None, None + return None, None, 0 tokenizer = load_tokenizer(args.target_model_path, trust_remote_code=True) - assistant_header_ids = tokenizer.encode(template.assistant_header, add_special_tokens=False) + + full_ids = tokenizer.encode(template.assistant_header, add_special_tokens=False) + header_text = template.assistant_header.rstrip("\n") + stripped_ids = tokenizer.encode(header_text, add_special_tokens=False) + # BPE tokenizers may merge the header's trailing \n with content-leading + # newlines into a single multi-char token, breaking subsequence matching. + # We tokenize without the trailing \n and tell the mask function to skip + # the newline token(s) that follow. + skip_after_header = len(full_ids) - len(stripped_ids) + end_token_ids = tokenizer.encode(template.end_of_turn_token, add_special_tokens=False) logger.info( - f"Assistant loss mask token IDs: header={assistant_header_ids}, end={end_token_ids}" + f"Assistant loss mask token IDs: header={stripped_ids}, end={end_token_ids}, " + f"skip_after_header={skip_after_header}" ) - return assistant_header_ids, end_token_ids + return stripped_ids, end_token_ids, skip_after_header From 09aa71b115d498d144e87dde3ab9a6f1b9bb4137 Mon Sep 17 00:00:00 2001 From: Yubo Wang Date: Tue, 17 Mar 2026 08:36:17 +0000 Subject: [PATCH 5/7] refactor: extract resolve_loss_mask to utils, add image token budget filter, normalize reasoning fields - Move loss-mask resolution from MooncakeDataset into reusable resolve_loss_mask() in data/utils.py - Add pre-tokenization filtering to drop samples exceeding estimated token budget (image-aware) - Preserve reasoning_content/reasoning/thinking fields during conversation normalization - Move process-level imports to module top-level in dataset.py - Expand has_thinking_content tests for reasoning_content and reasoning fields --- tests/test_dynamic_last_turn_loss.py | 97 +++++++++++++++++++--------- torchspec/data/dataset.py | 58 +++++++++++++---- torchspec/data/preprocessing.py | 7 +- torchspec/data/utils.py | 68 +++++++++++++------ torchspec/training/data_fetcher.py | 47 +++----------- 5 files changed, 176 insertions(+), 101 deletions(-) diff --git a/tests/test_dynamic_last_turn_loss.py b/tests/test_dynamic_last_turn_loss.py index a89726e..4767f42 100644 --- a/tests/test_dynamic_last_turn_loss.py +++ b/tests/test_dynamic_last_turn_loss.py @@ -3,7 +3,7 @@ import torch from torchspec.data.parse import has_thinking_content -from torchspec.data.utils import DataCollatorWithPadding +from torchspec.data.utils import DataCollatorWithPadding, resolve_loss_mask # ── has_thinking_content detection ──────────────────────────────────── @@ -118,6 +118,34 @@ def test_system_message_ignored(self): ] assert has_thinking_content(conv) is False + def test_reasoning_content_field(self): + conv = [ + {"role": "user", "content": "Hi"}, + {"role": "assistant", "content": "Hello!", "reasoning_content": "Let me think..."}, + ] + assert has_thinking_content(conv) is True + + def test_reasoning_field(self): + conv = [ + {"role": "user", "content": "Hi"}, + {"role": "assistant", "content": "Hello!", "reasoning": "step by step"}, + ] + assert has_thinking_content(conv) is True + + def test_empty_reasoning_content_field(self): + conv = [ + {"role": "user", "content": "Hi"}, + {"role": "assistant", "content": "Hello!", "reasoning_content": ""}, + ] + assert has_thinking_content(conv) is False + + def test_none_reasoning_content_field(self): + conv = [ + {"role": "user", "content": "Hi"}, + {"role": "assistant", "content": "Hello!", "reasoning_content": None}, + ] + assert has_thinking_content(conv) is False + # ── Collator reads pre-computed loss_mask from MooncakeDataset ──────── @@ -145,78 +173,87 @@ def test_missing_loss_mask_raises(self): collator._get_loss_mask(item) -# ── MooncakeDataset._compute_loss_mask (single source of truth) ────── - +# ── resolve_loss_mask (single source of truth in data/utils.py) ─────── -class TestComputeLossMask: - """Test that _compute_loss_mask computes, stores, and skips correctly.""" +_HEADER = [10, 20] +_END = [30, 40] - def _make_dataset(self, dynamic=True, last_turn_loss_only=False): - from torchspec.training.data_fetcher import MooncakeDataset - ds = MooncakeDataset.__new__(MooncakeDataset) - ds.assistant_header_ids = [10, 20] - ds.end_token_ids = [30, 40] - ds.dynamic_loss_mask = dynamic - ds.last_turn_loss_only = last_turn_loss_only - ds.skip_after_header = 0 - return ds +class TestResolveLossMask: + """Test that resolve_loss_mask computes, stores, and skips correctly.""" def test_packed_loss_mask_nonzero(self): - ds = self._make_dataset(dynamic=False) data = {"packed_loss_mask": "2,3,2"} - mask = ds._compute_loss_mask(data) + mask = resolve_loss_mask(data) assert mask is not None assert mask.tolist() == [0, 0, 1, 1, 1, 0, 0] assert "loss_mask" in data def test_packed_loss_mask_zero(self): - ds = self._make_dataset(dynamic=False) data = {"packed_loss_mask": "10"} - assert ds._compute_loss_mask(data) is None + assert resolve_loss_mask(data) is None def test_dynamic_mask_nonzero(self): - ds = self._make_dataset(dynamic=True) data = {"input_ids": torch.tensor([10, 20, 1, 2, 30, 40], dtype=torch.long)} - mask = ds._compute_loss_mask(data) + mask = resolve_loss_mask( + data, + dynamic_loss_mask=True, + assistant_header_ids=_HEADER, + end_token_ids=_END, + ) assert mask is not None assert mask.tolist() == [0, 0, 1, 1, 0, 0] assert "loss_mask" in data def test_dynamic_mask_zero(self): - ds = self._make_dataset(dynamic=True) data = {"input_ids": torch.tensor([5, 6, 7, 8], dtype=torch.long)} - assert ds._compute_loss_mask(data) is None + assert ( + resolve_loss_mask( + data, + dynamic_loss_mask=True, + assistant_header_ids=_HEADER, + end_token_ids=_END, + ) + is None + ) def test_dynamic_mask_per_sample_last_turn_only(self): - ds = self._make_dataset(dynamic=True, last_turn_loss_only=False) ids = [10, 20, 1, 30, 40, 10, 20, 2, 30, 40] data = { "input_ids": torch.tensor(ids, dtype=torch.long), "last_turn_loss_only": True, } - mask = ds._compute_loss_mask(data) + mask = resolve_loss_mask( + data, + dynamic_loss_mask=True, + assistant_header_ids=_HEADER, + end_token_ids=_END, + last_turn_loss_only=False, + ) assert mask is not None assert mask.tolist() == [0, 0, 0, 0, 0, 0, 0, 1, 0, 0] assert "loss_mask" in data def test_dynamic_mask_per_sample_all_turns(self): - ds = self._make_dataset(dynamic=True, last_turn_loss_only=True) ids = [10, 20, 1, 2, 30, 40, 10, 20, 3, 4, 30, 40] data = { "input_ids": torch.tensor(ids, dtype=torch.long), "last_turn_loss_only": False, } - mask = ds._compute_loss_mask(data) + mask = resolve_loss_mask( + data, + dynamic_loss_mask=True, + assistant_header_ids=_HEADER, + end_token_ids=_END, + last_turn_loss_only=True, + ) assert mask is not None assert mask.tolist() == [0, 0, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0] def test_no_params_defaults_nonzero(self): - ds = self._make_dataset(dynamic=False) data = {"input_ids": torch.tensor([1, 2, 3], dtype=torch.long)} - assert ds._compute_loss_mask(data) is not None + assert resolve_loss_mask(data) is not None def test_empty_packed_loss_mask(self): - ds = self._make_dataset(dynamic=False) data = {"packed_loss_mask": ""} - assert ds._compute_loss_mask(data) is None + assert resolve_loss_mask(data) is None diff --git a/torchspec/data/dataset.py b/torchspec/data/dataset.py index f54f610..485d16b 100644 --- a/torchspec/data/dataset.py +++ b/torchspec/data/dataset.py @@ -26,7 +26,8 @@ import torch from tqdm import tqdm -from torchspec.data.preprocessing import _normalize_conversation +from torchspec.data.parse import create_parser, has_thinking_content +from torchspec.data.preprocessing import _normalize_conversation, preprocess_conversations from torchspec.data.template import TEMPLATE_REGISTRY from torchspec.data.utils import ( estimate_row_count, @@ -35,6 +36,7 @@ load_hf_dataset, ) from torchspec.utils.logging import logger +from torchspec.utils.processing import load_tokenizer _logging.getLogger("transformers_modules").setLevel(_logging.ERROR) @@ -45,9 +47,6 @@ def _init_tokenize_worker( tokenizer_path, trust_remote_code, chat_template_name, last_turn_loss_only=False ): """Initializer for each worker process — loads tokenizer once.""" - from torchspec.data.preprocessing import preprocess_conversations - from torchspec.utils.processing import load_tokenizer - _logging.getLogger("transformers_modules").setLevel(_logging.ERROR) _worker_state["tokenizer"] = load_tokenizer(tokenizer_path, trust_remote_code=trust_remote_code) _worker_state["template"] = TEMPLATE_REGISTRY.get(chat_template_name) @@ -58,8 +57,6 @@ def _init_tokenize_worker( def _resolve_last_turn_loss_only(messages): ltlo = _worker_state.get("last_turn_loss_only", False) if ltlo == "auto": - from torchspec.data.parse import has_thinking_content - return has_thinking_content(messages) return bool(ltlo) @@ -94,9 +91,6 @@ def _tokenize_single(args): def _init_format_worker( tokenizer_path, trust_remote_code, chat_template_name, last_turn_loss_only=False ): - from torchspec.data.parse import create_parser - from torchspec.utils.processing import load_tokenizer - _logging.getLogger("transformers_modules").setLevel(_logging.ERROR) tokenizer = load_tokenizer(tokenizer_path, trust_remote_code=trust_remote_code) _worker_state["template"] = TEMPLATE_REGISTRY.get(chat_template_name) @@ -114,12 +108,12 @@ def _format_single(args): result = {} ltlo = _worker_state.get("last_turn_loss_only", False) if ltlo == "auto": - from torchspec.data.parse import has_thinking_content - result["has_thinking"] = has_thinking_content(messages) parser = _worker_state["parser"] - formatted = parser.format(messages, add_generation_prompt=train_with_decode) + formatted = parser.format( + messages, add_generation_prompt=train_with_decode, expand_media_tokens=False + ) if not formatted: return None result["formatted_prompt"] = formatted @@ -162,10 +156,13 @@ def load_conversation_dataset(args): file_stat = f"-{st.st_size}-{st.st_mtime}" last_turn_loss_only_flag = getattr(args, "last_turn_loss_only", False) train_with_decode = getattr(args, "train_with_decode", False) + max_images_cfg = getattr(args, "max_images_per_sample", None) + tokens_per_image_cfg = getattr(args, "tokens_per_image", 3000) cache_params = ( f"{dataset_name}-{args.train_data_path}{file_stat}-{args.target_model_path}" f"-{max_length}-{chat_template_name}-ltlo={last_turn_loss_only_flag}" f"-defer={defer_tokenization}-decode={train_with_decode}" + f"-mimg={max_images_cfg}-tpi={tokens_per_image_cfg}" ) cache_key = hashlib.md5(cache_params.encode()).hexdigest() cache_dir = os.path.join(getattr(args, "cache_dir", "./cache"), "tokenized_dataset") @@ -199,6 +196,43 @@ def load_conversation_dataset(args): data_id = sample.get("id", f"sample_{idx}") raw_samples.append((data_id, messages, multimodal_inputs)) + # Filter samples that would exceed max_seq_length after image token expansion + max_images = getattr(args, "max_images_per_sample", None) + tokens_per_image = getattr(args, "tokens_per_image", 3000) + pre_filter_count = len(raw_samples) + filtered_samples = [] + for data_id, messages, multimodal_inputs in raw_samples: + num_images = 0 + if multimodal_inputs and multimodal_inputs.get("images"): + num_images = len(multimodal_inputs["images"]) + + if max_images is not None and num_images > max_images: + logger.debug( + f"Dropping sample {data_id}: {num_images} images > max_images_per_sample={max_images}" + ) + continue + + text_chars = sum( + len(m.get("content", "")) for m in messages if isinstance(m.get("content"), str) + ) + estimated_tokens = text_chars // 4 + num_images * tokens_per_image + if estimated_tokens > max_length: + logger.debug( + f"Dropping sample {data_id}: estimated {estimated_tokens} tokens " + f"({text_chars} text chars + {num_images} images) > max_seq_length={max_length}" + ) + continue + + filtered_samples.append((data_id, messages, multimodal_inputs)) + + raw_samples = filtered_samples + if pre_filter_count - len(raw_samples) > 0: + logger.info( + f"Filtered {pre_filter_count - len(raw_samples)}/{pre_filter_count} samples " + f"exceeding estimated token budget (max_seq_length={max_length}, " + f"tokens_per_image={tokens_per_image}, max_images_per_sample={max_images})" + ) + logger.info( f"Loaded {len(raw_samples)} samples, {mode_label.lower()} with {num_proc} workers..." ) diff --git a/torchspec/data/preprocessing.py b/torchspec/data/preprocessing.py index b0f8781..460e66f 100644 --- a/torchspec/data/preprocessing.py +++ b/torchspec/data/preprocessing.py @@ -76,7 +76,12 @@ def _normalize_conversation(conversation: Conversation) -> Conversation: normalized = [] for msg in conversation: role = ROLE_MAPPING.get(msg["from"], msg["from"]) - normalized.append({"role": role, "content": msg["value"]}) + entry = {"role": role, "content": msg["value"]} + for field in ("thinking", "thinking_content", "reasoning_content", "reasoning"): + if msg.get(field): + entry["reasoning_content"] = msg[field] + break + normalized.append(entry) return normalized return conversation diff --git a/torchspec/data/utils.py b/torchspec/data/utils.py index a73acc7..f7a220b 100644 --- a/torchspec/data/utils.py +++ b/torchspec/data/utils.py @@ -21,31 +21,13 @@ import json import os from pathlib import Path -from typing import Any, Dict, List, Union -from urllib.parse import urlparse +from typing import Any, Dict, List, Optional, Union import torch from datasets import IterableDataset, load_dataset from huggingface_hub import hf_hub_download, list_repo_files -_IMAGE_CACHE_DIR = os.environ.get("TORCHSPEC_IMAGE_CACHE", "/data/ywang/image_cache") - - -def resolve_image_url(url: str, cache_dir: str = _IMAGE_CACHE_DIR) -> str: - """Return local file path if a cached copy exists, otherwise the original URL.""" - if not url or not url.startswith("http"): - return url - parsed = urlparse(url) - local_path = os.path.join(cache_dir, parsed.netloc, parsed.path.lstrip("/")) - if os.path.isfile(local_path): - return local_path - return url - - -def resolve_image_urls(urls: list[str], cache_dir: str = _IMAGE_CACHE_DIR) -> list[str]: - """Resolve a list of image URLs to local paths where cached copies exist.""" - return [resolve_image_url(u, cache_dir) for u in urls] - +from torchspec.models.ops.loss_mask import compute_assistant_loss_mask _LOCAL_DATA_EXTS = frozenset({".json", ".jsonl", ".parquet", ".arrow", ".csv", ".tsv", ".txt"}) @@ -210,6 +192,52 @@ def unpack_loss_mask(packed: Union[List[int], str]) -> torch.Tensor: return loss_mask +def resolve_loss_mask( + data: Dict[str, Any], + *, + dynamic_loss_mask: bool = False, + assistant_header_ids: Optional[List[int]] = None, + end_token_ids: Optional[List[int]] = None, + last_turn_loss_only: bool = False, + skip_after_header: int = 0, +) -> torch.Tensor | None: + """ + Two strategies, tried in order: + 1. ``packed_loss_mask`` key present → unpack it. + 2. ``dynamic_loss_mask`` enabled with valid header/end ids → compute from + ``input_ids`` via :func:`compute_assistant_loss_mask`. + """ + packed = data.get("packed_loss_mask") + if packed is not None: + mask = unpack_loss_mask(packed) + if not mask.any(): + return None + data["loss_mask"] = mask + return mask + + if dynamic_loss_mask and assistant_header_ids and end_token_ids: + input_ids = data.get("input_ids") + if input_ids is None: + return None + if input_ids.dim() == 2: + input_ids = input_ids.squeeze(0) + per_sample = data.get("last_turn_loss_only") + last_turn_only = per_sample if per_sample is not None else last_turn_loss_only + mask = compute_assistant_loss_mask( + input_ids, + assistant_header_ids, + end_token_ids, + last_turn_only=last_turn_only, + skip_after_header=skip_after_header, + ) + if not mask.any(): + return None + data["loss_mask"] = mask + return mask + + return torch.ones(1) + + def serialize_packed_loss_mask(packed: List[int]) -> str: """ Serialize packed loss_mask to a comma-separated string. diff --git a/torchspec/training/data_fetcher.py b/torchspec/training/data_fetcher.py index 6681d3a..76fa62c 100644 --- a/torchspec/training/data_fetcher.py +++ b/torchspec/training/data_fetcher.py @@ -32,8 +32,7 @@ from ray.util.queue import Queue as RayQueue from torch.utils.data import DataLoader, IterableDataset -from torchspec.data.utils import unpack_loss_mask -from torchspec.models.ops.loss_mask import compute_assistant_loss_mask +from torchspec.data.utils import resolve_loss_mask from torchspec.utils.logging import logger @@ -126,42 +125,14 @@ def _cleanup_mooncake_data(self, sample: TrainSample) -> None: ) def _compute_loss_mask(self, data: Dict[str, Any]) -> torch.Tensor | None: - """Compute the loss mask for a sample and store it on the data dict. - - This is the single place where loss masks are resolved for mooncake - samples, so the collator never needs to recompute. - - Returns the 1-D mask tensor, or None if the mask is all zeros. - """ - packed = data.get("packed_loss_mask") - if packed is not None: - mask = unpack_loss_mask(packed) - if not mask.any(): - return None - data["loss_mask"] = mask - return mask - - if self.dynamic_loss_mask and self.assistant_header_ids and self.end_token_ids: - input_ids = data.get("input_ids") - if input_ids is None: - return None - if input_ids.dim() == 2: - input_ids = input_ids.squeeze(0) - per_sample = data.get("last_turn_loss_only") - last_turn_only = per_sample if per_sample is not None else self.last_turn_loss_only - mask = compute_assistant_loss_mask( - input_ids, - self.assistant_header_ids, - self.end_token_ids, - last_turn_only=last_turn_only, - skip_after_header=self.skip_after_header, - ) - if not mask.any(): - return None - data["loss_mask"] = mask - return mask - - return torch.ones(1) # non-None signals "don't skip" + return resolve_loss_mask( + data, + dynamic_loss_mask=self.dynamic_loss_mask, + assistant_header_ids=self.assistant_header_ids, + end_token_ids=self.end_token_ids, + last_turn_loss_only=self.last_turn_loss_only, + skip_after_header=self.skip_after_header, + ) def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]: """Iterate over samples synchronously. From f948815324a15546b9105687e7fccebdab0caa20 Mon Sep 17 00:00:00 2001 From: Yubo Wang Date: Tue, 17 Mar 2026 10:53:51 -0700 Subject: [PATCH 6/7] Potential fix for pull request finding Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> --- torchspec/data/parse.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/torchspec/data/parse.py b/torchspec/data/parse.py index 3912431..f91e0d0 100644 --- a/torchspec/data/parse.py +++ b/torchspec/data/parse.py @@ -48,9 +48,10 @@ def has_thinking_content(conversation: list) -> bool: """Detect whether any assistant message contains real thinking content. Checks for non-empty blocks in message content and for - separate thinking/thinking_content fields on the message dict. - Must be called on the raw conversation BEFORE formatting, since - formatters (e.g. KimiK25Parser) inject empty tags. + separate thinking/thinking_content or reasoning/reasoning_content + fields on the message dict. Must be called on the raw conversation + BEFORE formatting, since formatters (e.g. KimiK25Parser) inject + empty tags. """ for msg in conversation: if not isinstance(msg, dict) or msg.get("role") != "assistant": @@ -58,7 +59,12 @@ def has_thinking_content(conversation: list) -> bool: content = msg.get("content", "") if isinstance(content, str) and _HAS_THINKING_RE.search(content): return True - if msg.get("thinking") or msg.get("thinking_content"): + if ( + msg.get("thinking") + or msg.get("thinking_content") + or msg.get("reasoning") + or msg.get("reasoning_content") + ): return True return False From b590d5f252122462d908c8720a6cee528163761a Mon Sep 17 00:00:00 2001 From: Yubo Wang Date: Tue, 17 Mar 2026 18:05:45 +0000 Subject: [PATCH 7/7] revert: remove hacky image token budget filter from dataset loading The pre-tokenization filtering that estimated token budgets based on image count and text length was a rough heuristic. Remove it along with the related cache params. --- torchspec/data/dataset.py | 40 --------------------------------------- 1 file changed, 40 deletions(-) diff --git a/torchspec/data/dataset.py b/torchspec/data/dataset.py index 485d16b..264fbc0 100644 --- a/torchspec/data/dataset.py +++ b/torchspec/data/dataset.py @@ -156,13 +156,10 @@ def load_conversation_dataset(args): file_stat = f"-{st.st_size}-{st.st_mtime}" last_turn_loss_only_flag = getattr(args, "last_turn_loss_only", False) train_with_decode = getattr(args, "train_with_decode", False) - max_images_cfg = getattr(args, "max_images_per_sample", None) - tokens_per_image_cfg = getattr(args, "tokens_per_image", 3000) cache_params = ( f"{dataset_name}-{args.train_data_path}{file_stat}-{args.target_model_path}" f"-{max_length}-{chat_template_name}-ltlo={last_turn_loss_only_flag}" f"-defer={defer_tokenization}-decode={train_with_decode}" - f"-mimg={max_images_cfg}-tpi={tokens_per_image_cfg}" ) cache_key = hashlib.md5(cache_params.encode()).hexdigest() cache_dir = os.path.join(getattr(args, "cache_dir", "./cache"), "tokenized_dataset") @@ -196,43 +193,6 @@ def load_conversation_dataset(args): data_id = sample.get("id", f"sample_{idx}") raw_samples.append((data_id, messages, multimodal_inputs)) - # Filter samples that would exceed max_seq_length after image token expansion - max_images = getattr(args, "max_images_per_sample", None) - tokens_per_image = getattr(args, "tokens_per_image", 3000) - pre_filter_count = len(raw_samples) - filtered_samples = [] - for data_id, messages, multimodal_inputs in raw_samples: - num_images = 0 - if multimodal_inputs and multimodal_inputs.get("images"): - num_images = len(multimodal_inputs["images"]) - - if max_images is not None and num_images > max_images: - logger.debug( - f"Dropping sample {data_id}: {num_images} images > max_images_per_sample={max_images}" - ) - continue - - text_chars = sum( - len(m.get("content", "")) for m in messages if isinstance(m.get("content"), str) - ) - estimated_tokens = text_chars // 4 + num_images * tokens_per_image - if estimated_tokens > max_length: - logger.debug( - f"Dropping sample {data_id}: estimated {estimated_tokens} tokens " - f"({text_chars} text chars + {num_images} images) > max_seq_length={max_length}" - ) - continue - - filtered_samples.append((data_id, messages, multimodal_inputs)) - - raw_samples = filtered_samples - if pre_filter_count - len(raw_samples) > 0: - logger.info( - f"Filtered {pre_filter_count - len(raw_samples)}/{pre_filter_count} samples " - f"exceeding estimated token budget (max_seq_length={max_length}, " - f"tokens_per_image={tokens_per_image}, max_images_per_sample={max_images})" - ) - logger.info( f"Loaded {len(raw_samples)} samples, {mode_label.lower()} with {num_proc} workers..." )