diff --git a/.gitignore b/.gitignore
index 01a52fbd..5293aadb 100644
--- a/.gitignore
+++ b/.gitignore
@@ -297,3 +297,5 @@ cython_debug/
 
 # Version files made by setuptools_scm
 **/version.py
+
+.vscode/
diff --git a/src/ldp/alg/rollout.py b/src/ldp/alg/rollout.py
index ff06f266..e6645049 100644
--- a/src/ldp/alg/rollout.py
+++ b/src/ldp/alg/rollout.py
@@ -2,11 +2,13 @@
 import itertools
 import logging
 import uuid
+from collections import Counter
 from collections.abc import Callable, Iterator, Sequence
 from contextlib import contextmanager, nullcontext
 from typing import Any, TypeVar, overload
 
 from aviary.core import Environment, Message
+from tqdm.asyncio import tqdm
 
 from ldp.agent import Agent
 from ldp.data_structures import Trajectory, Transition
@@ -24,6 +26,7 @@ class CaughtError(Exception):
     """Base class for reraised exceptions when catching is enabled."""
 
     def __init__(self, original_exc: Exception):
+        super().__init__(str(original_exc))
         self.original_exc = original_exc
 
     exc_type = "undefined"
@@ -39,12 +42,12 @@ class EnvError(CaughtError):
 
 @contextmanager
 def reraise_exc_as(reraise: type[CaughtError], enabled: bool) -> Iterator[None]:
+    """Context manager that reraises exceptions as a custom CaughtError type if enabled."""
     try:
         yield
     except Exception as e:
         if enabled:
-            error_details = format_error_details(e)
-            logger.exception(f"Caught {reraise.exc_type} exception:\n{error_details}")
+            logger.debug(f"Reraising {reraise.exc_type} exception.")
             raise reraise(e) from None
         raise
@@ -106,6 +109,9 @@ async def sample_trajectories(  # noqa: D418
             environments: A list of environments to run rollouts on.
             max_steps: Max steps per rollout. Defaults to None, in which case the
                 rollouts are run until environment returns done.
+            log_exceptions_immediately: Whether to log exceptions immediately to
+                the console. Defaults to True. If False, a progress bar is shown
+                and a summary of exceptions is logged after all rollouts complete.
         """
 
     async def sample_trajectories(self, **kwargs):
@@ -118,6 +124,9 @@ async def sample_trajectories(self, **kwargs):
                 kwargs["environment_factory"],
                 kwargs.get("batch_size", 1),
                 kwargs.get("max_steps"),
+                log_exceptions_immediately=kwargs.get(
+                    "log_exceptions_immediately", True
+                ),
             )
 
         if "environments" in kwargs:
@@ -125,7 +134,11 @@ async def sample_trajectories(self, **kwargs):
                     "Cannot use environments with environment_factory"
                 )
             return await self._sample_trajectories_from_envs(
-                kwargs["environments"], kwargs.get("max_steps")
+                kwargs["environments"],
+                kwargs.get("max_steps"),
+                log_exceptions_immediately=kwargs.get(
+                    "log_exceptions_immediately", True
+                ),
             )
 
         raise TypeError(
@@ -138,13 +151,18 @@ async def _sample_trajectories_from_env_factory(
         environment_factory: Callable[[], Environment],
         batch_size: int = 1,
         max_steps: int | None = None,
+        *,
+        log_exceptions_immediately: bool = True,
    ) -> list[tuple[Trajectory, Environment]]:
         self.traj_buffer.clear()
+        exception_counter: Counter = Counter()
 
         async def rollout_with_args(idx: int, **rollout_kwargs):
             return idx, await self._rollout(**rollout_kwargs), rollout_kwargs
 
         accumulated_steps = [0] * batch_size
+        total_trajectories = 0  # Counter for completed trajectories
+
         # submit initial batch of tasks
         tasks = [
             asyncio.create_task(
@@ -153,37 +171,67 @@ async def rollout_with_args(idx: int, **rollout_kwargs):
                 rollout_with_args(
                     idx,
                     traj_id=uuid.uuid4().hex,
                     env=environment_factory(),
                     max_steps=max_steps,
+                    log_exceptions_immediately=log_exceptions_immediately,
                 )
             )
             for idx in range(batch_size)
         ]
 
         results = []
-        while tasks:
-            done, pending = await asyncio.wait(
-                tasks, return_when=asyncio.FIRST_COMPLETED
-            )
-            new_tasks = []
-            for task in done:
-                idx, traj, kwargs = await task
-                results.append((traj, kwargs["env"]))
-                accumulated_steps[idx] += len(traj.steps)
-                if (
-                    max_steps is not None
-                    and (remaining_steps := max_steps - accumulated_steps[idx]) > 0
-                ):
-                    # submit another task if we haven't reached max_steps
-                    new_task = asyncio.create_task(
-                        rollout_with_args(
-                            idx,
-                            traj_id=uuid.uuid4().hex,
-                            env=environment_factory(),
-                            max_steps=remaining_steps,
-                        )
-                    )
-                    new_tasks.append(new_task)
-
-            tasks = list(pending) + new_tasks
+        with tqdm(
+            desc="Rollouts",
+            unit="rollout",
+            ncols=0,
+            disable=log_exceptions_immediately,
+        ) as pbar:
+            while tasks:
+                done, pending = await asyncio.wait(
+                    tasks, return_when=asyncio.FIRST_COMPLETED
+                )
+                new_tasks = []
+                for task in done:
+                    idx, traj, kwargs = await task
+                    results.append((traj, kwargs["env"]))
+                    total_trajectories += 1
+                    pbar.update(1)
+
+                    steps_in_traj = len(traj.steps)
+                    accumulated_steps[idx] += steps_in_traj
+
+                    # Check for exceptions in this trajectory
+                    if traj.steps and traj.steps[-1].metadata.get("exception"):
+                        exc_str: str = str(traj.steps[-1].metadata["exception"])[
+                            :500
+                        ].replace('"', "'")
+                        exception_counter[exc_str] += 1
+                        num_exceptions = sum(exception_counter.values())
+                        pbar.set_postfix({"num_exceptions": num_exceptions})
+
+                    if (
+                        max_steps is not None
+                        and (remaining_steps := max_steps - accumulated_steps[idx]) > 0
+                    ):
+                        # submit another task if we haven't reached max_steps
+                        new_task = asyncio.create_task(
+                            rollout_with_args(
+                                idx,
+                                traj_id=uuid.uuid4().hex,
+                                env=environment_factory(),
+                                max_steps=remaining_steps,
+                                log_exceptions_immediately=log_exceptions_immediately,
+                            )
+                        )
+                        new_tasks.append(new_task)
+
+                tasks = list(pending) + new_tasks
+
+        # Final summary of exceptions (if any)
+        if exception_counter and not log_exceptions_immediately:
+            summary = ["Caught exceptions:", "Count Exception"]
+            summary.extend(
+                f"{count:<6d} {exc:<50s}" for exc, count in exception_counter.items()
+            )
+            logger.info("\n".join(summary))
 
         return results
@@ -191,16 +239,57 @@ async def _sample_trajectories_from_envs(
         self,
         environments: Sequence[Environment],
         max_steps: int | None = None,
+        *,
+        log_exceptions_immediately: bool = True,
     ) -> list[Trajectory]:
         self.traj_buffer.clear()
+        exception_counter: Counter = Counter()
+
+        traj_ids = [uuid.uuid4().hex for _ in environments]
 
-        traj_ids = [uuid.uuid4().hex for _ in range(len(environments))]
-        await asyncio.gather(
-            *(
-                self._rollout(*args, max_steps=max_steps)
-                for args in zip(traj_ids, environments, strict=True)
-            )
-        )
+        # Create all tasks first
+        tasks = [
+            asyncio.create_task(
+                self._rollout(
+                    traj_id,
+                    env,
+                    max_steps=max_steps,
+                    log_exceptions_immediately=log_exceptions_immediately,
+                )
+            )
+            for traj_id, env in zip(traj_ids, environments, strict=True)
+        ]
+
+        with tqdm(
+            total=len(tasks),
+            desc="Rollouts",
+            unit="rollout",
+            ncols=0,
+            disable=log_exceptions_immediately,
+        ) as pbar:
+            for task in asyncio.as_completed(tasks):
+                trajectory = await task
+                pbar.update(1)
+                # Check if this trajectory ended with an exception
+                if trajectory.steps:
+                    last_step = trajectory.steps[-1]
+                    if last_step.metadata.get("exception"):
+                        # We'll keep it short but still have something to categorize
+                        exc_str: str = str(last_step.metadata["exception"])[
+                            :500
+                        ].replace('"', "'")
+                        exception_counter[exc_str] += 1
+                        num_exceptions = sum(exception_counter.values())
+                        pbar.set_postfix({"num_exceptions": num_exceptions})
+
+        # Final summary of exceptions (if any)
+        if exception_counter and not log_exceptions_immediately:
+            summary = ["Caught exceptions:", "Count Exception"]
+            summary.extend(
+                f"{count:<6d} {exc:<50s}" for exc, count in exception_counter.items()
+            )
+            logger.info("\n".join(summary))
 
         return [self.traj_buffer[traj_id] for traj_id in traj_ids]
@@ -208,6 +297,8 @@ async def _rollout(
         self,
         traj_id: str,
         env: Environment,
         max_steps: int | None,
+        *,
+        log_exceptions_immediately: bool = True,
     ) -> Trajectory:
         trajectory = Trajectory(traj_id=traj_id)
@@ -260,6 +351,10 @@ async def store_step(step: Transition):
         except CaughtError as e:
             # NOTE: This trajectory should not be used for regular training.
             # We save the last transition here for debugging, etc.
+            if log_exceptions_immediately:
+                error_details = format_error_details(e.original_exc)
+                logger.exception(f"Exception in rollout {traj_id}:\n{error_details}")
+
             await store_step(
                 Transition(
                     timestep=len(trajectory.steps),
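For context, a minimal sketch of how the new `log_exceptions_immediately` flag might be exercised; `RolloutManager`, `SimpleAgent`, and `DummyEnv` are assumed entry points and constructor signatures, not part of this diff:

```python
import asyncio

from aviary.core import DummyEnv  # assumed toy environment
from ldp.agent import SimpleAgent
from ldp.alg.rollout import RolloutManager  # assumed home of sample_trajectories


async def main() -> None:
    manager = RolloutManager(agent=SimpleAgent())
    # With log_exceptions_immediately=False, tracebacks are not printed as they
    # happen; a tqdm progress bar tracks rollouts, and a Counter-based summary
    # of exception strings is logged once at the end.
    trajectories = await manager.sample_trajectories(
        environments=[DummyEnv() for _ in range(8)],
        max_steps=10,
        log_exceptions_immediately=False,
    )
    print(f"Completed {len(trajectories)} trajectories")


if __name__ == "__main__":
    asyncio.run(main())
```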
""" async def sample_trajectories(self, **kwargs): @@ -118,6 +124,9 @@ async def sample_trajectories(self, **kwargs): kwargs["environment_factory"], kwargs.get("batch_size", 1), kwargs.get("max_steps"), + log_exceptions_immediately=kwargs.get( + "log_exceptions_immediately", True + ), ) if "environments" in kwargs: @@ -125,7 +134,11 @@ async def sample_trajectories(self, **kwargs): "Cannot use environments with environment_factory" ) return await self._sample_trajectories_from_envs( - kwargs["environments"], kwargs.get("max_steps") + kwargs["environments"], + kwargs.get("max_steps"), + log_exceptions_immediately=kwargs.get( + "log_exceptions_immediately", True + ), ) raise TypeError( @@ -138,13 +151,18 @@ async def _sample_trajectories_from_env_factory( environment_factory: Callable[[], Environment], batch_size: int = 1, max_steps: int | None = None, + *, + log_exceptions_immediately: bool = True, ) -> list[tuple[Trajectory, Environment]]: self.traj_buffer.clear() + exception_counter: Counter = Counter() async def rollout_with_args(idx: int, **rollout_kwargs): return idx, await self._rollout(**rollout_kwargs), rollout_kwargs accumulated_steps = [0] * batch_size + total_trajectories = 0 # Counter for completed trajectories + # submit initial batch of tasks tasks = [ asyncio.create_task( @@ -153,37 +171,67 @@ async def rollout_with_args(idx: int, **rollout_kwargs): traj_id=uuid.uuid4().hex, env=environment_factory(), max_steps=max_steps, + log_exceptions_immediately=log_exceptions_immediately, ) ) for idx in range(batch_size) ] results = [] - while tasks: - done, pending = await asyncio.wait( - tasks, return_when=asyncio.FIRST_COMPLETED - ) - new_tasks = [] - for task in done: - idx, traj, kwargs = await task - results.append((traj, kwargs["env"])) - accumulated_steps[idx] += len(traj.steps) - if ( - max_steps is not None - and (remaining_steps := max_steps - accumulated_steps[idx]) > 0 - ): - # submit another task if we haven't reached max_steps - new_task = asyncio.create_task( - rollout_with_args( - idx, - traj_id=uuid.uuid4().hex, - env=environment_factory(), - max_steps=remaining_steps, + with tqdm( + desc="Rollouts", + unit="rollout", + ncols=0, + disable=log_exceptions_immediately, + ) as pbar: + while tasks: + done, pending = await asyncio.wait( + tasks, return_when=asyncio.FIRST_COMPLETED + ) + new_tasks = [] + for task in done: + idx, traj, kwargs = await task + results.append((traj, kwargs["env"])) + total_trajectories += 1 + pbar.update(1) + + steps_in_traj = len(traj.steps) + accumulated_steps[idx] += steps_in_traj + + # Check for exceptions in this trajectory + if traj.steps and traj.steps[-1].metadata.get("exception"): + exc_str: str = str(traj.steps[-1].metadata["exception"])[ + :500 + ].replace('"', "'") + exception_counter[exc_str] += 1 + num_exceptions = sum(exception_counter.values()) + pbar.set_postfix({"num_exceptions": num_exceptions}) + + if ( + max_steps is not None + and (remaining_steps := max_steps - accumulated_steps[idx]) > 0 + ): + # submit another task if we haven't reached max_steps + new_task = asyncio.create_task( + rollout_with_args( + idx, + traj_id=uuid.uuid4().hex, + env=environment_factory(), + max_steps=remaining_steps, + log_exceptions_immediately=log_exceptions_immediately, + ) ) - ) - new_tasks.append(new_task) + new_tasks.append(new_task) - tasks = list(pending) + new_tasks + tasks = list(pending) + new_tasks + + # Final summary of exceptions (if any) + if exception_counter and not log_exceptions_immediately: + summary = ["Caught exceptions:", 
"Count Exception"] + summary.extend( + f"{count:<6d} {exc:<50s}" for exc, count in exception_counter.items() + ) + logger.info("\n".join(summary)) return results @@ -191,16 +239,57 @@ async def _sample_trajectories_from_envs( self, environments: Sequence[Environment], max_steps: int | None = None, + *, + log_exceptions_immediately: bool = True, ) -> list[Trajectory]: self.traj_buffer.clear() + exception_counter: Counter = Counter() + + traj_ids = [uuid.uuid4().hex for _ in environments] - traj_ids = [uuid.uuid4().hex for _ in range(len(environments))] - await asyncio.gather( - *( - self._rollout(*args, max_steps=max_steps) - for args in zip(traj_ids, environments, strict=True) + # Create all tasks first + tasks = [ + asyncio.create_task( + self._rollout( + traj_id, + env, + max_steps=max_steps, + log_exceptions_immediately=log_exceptions_immediately, + ) ) - ) + for traj_id, env in zip(traj_ids, environments, strict=True) + ] + + with tqdm( + total=len(tasks), + desc="Rollouts", + unit="rollout", + ncols=0, + disable=log_exceptions_immediately, + ) as pbar: + for task in asyncio.as_completed(tasks): + trajectory = await task + pbar.update(1) + # Check if this trajectory ended with an exception + if trajectory.steps: + last_step = trajectory.steps[-1] + if last_step.metadata.get("exception"): + # We'll keep it short but still have something to categorize + exc_str: str = str(last_step.metadata["exception"])[ + :500 + ].replace('"', "'") + exception_counter[exc_str] += 1 + num_exceptions = sum(exception_counter.values()) + pbar.set_postfix({"num_exceptions": num_exceptions}) + + # Final summary of exceptions (if any) + if exception_counter and not log_exceptions_immediately: + summary = ["Caught exceptions:", "Count Exception"] + summary.extend( + f"{count:<6d} {exc:<50s}" for exc, count in exception_counter.items() + ) + logger.info("\n".join(summary)) + return [self.traj_buffer[traj_id] for traj_id in traj_ids] async def _rollout( @@ -208,6 +297,8 @@ async def _rollout( traj_id: str, env: Environment, max_steps: int | None, + *, + log_exceptions_immediately: bool = True, ) -> Trajectory: trajectory = Trajectory(traj_id=traj_id) @@ -260,6 +351,10 @@ async def store_step(step: Transition): except CaughtError as e: # NOTE: This trajectory should not be used for regular training. # We save the last transition here for debugging, etc. + if log_exceptions_immediately: + error_details = format_error_details(e.original_exc) + logger.exception(f"Exception in rollout {traj_id}:\n{error_details}") + await store_step( Transition( timestep=len(trajectory.steps), diff --git a/src/ldp/graph/async_torch.py b/src/ldp/graph/async_torch.py index d0aa8274..612e0dad 100644 --- a/src/ldp/graph/async_torch.py +++ b/src/ldp/graph/async_torch.py @@ -1,6 +1,7 @@ __all__ = ["AsyncTorchModule", "async_protect_torch_call"] import asyncio +import logging import operator import time from abc import ABC, abstractmethod @@ -19,6 +20,9 @@ "Please run `pip install ldp[nn]`." 
diff --git a/src/ldp/nn/handlers/chunking.py b/src/ldp/nn/handlers/chunking.py
index 38fdbe96..9369d6d3 100644
--- a/src/ldp/nn/handlers/chunking.py
+++ b/src/ldp/nn/handlers/chunking.py
@@ -9,9 +9,8 @@
 class TensorChunker:
     """Splits tensors into chunks and adds dummy chunks as needed for parallel processing frameworks like FSDP."""
 
-    def __init__(self, num_chunks: int, dummy_value: int = 0):
+    def __init__(self, num_chunks: int):
         self.num_chunks = num_chunks
-        self.dummy_value = dummy_value
 
     def chunkify(self, *args, **kwargs) -> tuple[list[tuple], list[dict], list[bool]]:
         """Splits the args into self.num_chunks chunks, adding dummy chunks as needed.
@@ -159,8 +158,10 @@ def _split_value(self, value):
         for i in range(self.num_chunks):
             if i >= len(chunks):
                 # Chunk 0 will always exist, and we need only a batch of one ([:1])
-                # to activate the model
-                chunks.append(torch.full_like(chunks[0][:1], self.dummy_value))
+                # to activate the model.
+                # We use the first element of the existing chunks as real data to avoid
+                # errors in the model that may expect a specific token structure.
+                chunks.append(chunks[0][:1])
                 dummy_chunk_flags.append(True)
             else:
                 dummy_chunk_flags.append(False)
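To make the new dummy-chunk behavior concrete, a small sketch mirroring the updated tests (the batch shape is arbitrary):

```python
import torch

import ldp.nn

chunker = ldp.nn.TensorChunker(num_chunks=4)
batch = torch.arange(30).reshape(3, 10)  # 3 rows split across 4 workers

split_args, split_kwargs, dummy_flags = chunker.chunkify(batch)
# Three real chunks plus one dummy; the dummy now reuses batch[:1] rather than
# a tensor of zeros, so models expecting well-formed token ids don't choke.
assert dummy_flags == [False, False, False, True]
assert torch.equal(split_args[3][0], batch[:1])
```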
diff --git a/src/ldp/nn/handlers/transformer_handler.py b/src/ldp/nn/handlers/transformer_handler.py
index eb104e87..7ca01b74 100644
--- a/src/ldp/nn/handlers/transformer_handler.py
+++ b/src/ldp/nn/handlers/transformer_handler.py
@@ -1,5 +1,7 @@
 from __future__ import annotations
 
+import asyncio
+import atexit
 import logging
 import os
 import socket
@@ -9,14 +11,15 @@
 from enum import StrEnum, auto
 from functools import cache, partial, wraps
 from pathlib import Path
-from typing import Any, Concatenate, ParamSpec, Self, TypeVar, assert_never, cast
+from typing import Any, Concatenate, ParamSpec, Self, TypeVar, assert_never
 
 import accelerate
 import torch
 import torch.distributed as dist
 import tree
 from dask import config
-from dask.distributed import Client
+from dask.distributed import Actor, ActorFuture, Client
+from distributed.utils import sync
 from pydantic import BaseModel, ConfigDict, Field, field_validator
 from torch import nn
 from torch.cuda import nccl
@@ -45,6 +48,7 @@
 else:
     from typing_extensions import overload  # noqa: UP035
 
+logger = logging.getLogger(__name__)
 
 config.set({
     # We have no use for rebooting workers in aviary for now, and rebooting workers
@@ -55,11 +59,9 @@
     # Gives us more time to debug a downed worker. TODO: see if there are negative consequences
     # of having this always enabled
     "distributed.comm.timeouts.connect": "300s",
-    "distributed.comm.timeouts.tcp": "300s",
+    "distributed.comm.timeouts.tcp": "1200s",
 })
 
-logger = logging.getLogger(__name__)
-
 TReturn = TypeVar("TReturn")
 TParams = ParamSpec("TParams")
@@ -192,6 +194,14 @@ async def __call__(  # type: ignore[override]
     @staticmethod
     def model_generate(model: PreTrainedModel, *args, **kwargs):
         """A method that can be used as module_call_fn to sample from an LLM."""
+        if int(os.environ.get("WORLD_SIZE", "1")) > 1:
+            synced_gpus = kwargs.pop("synced_gpus", None)
+            if synced_gpus is None:
+                logger.debug("synced_gpus not defined, defaulting to True.")
+                kwargs["synced_gpus"] = True
+            elif not synced_gpus:
+                raise ValueError("synced_gpus must be True when using FSDP.")
+
         # Summoning params per https://github.com/pytorch/pytorch/issues/100069
         # If model is not FSDP, this context manager is a no-op.
         with FullyShardedDataParallel.summon_full_params(model, recurse=False):
@@ -229,6 +239,7 @@ def __init__(self, config: TransformerHandlerConfig):
             assert_never(config.lm_type)
         super().__init__(model)
         self.tokenizer = tokenizer
+
         maybe_set_tokenizer_chat_template(
             self.tokenizer, self.config.lm_config.chat_template
         )
@@ -417,22 +428,29 @@ def _exec_func(
         args = tree.map_structure(to_device, args)
         kwargs = tree.map_structure(to_device, kwargs)
 
-        with torch.autocast(
-            device_type=self.module.device.type, dtype=self.module.dtype
-        ):
-            res = (
-                getattr(self, func)(*args, **kwargs)
-                if isinstance(func, str)
-                else func(self, *args, **kwargs)
-            )
+        try:
+            with torch.autocast(
+                device_type=self.module.device.type, dtype=self.module.dtype
+            ):
+                res = (
+                    getattr(self, func)(*args, **kwargs)
+                    if isinstance(func, str)
+                    else func(self, *args, **kwargs)
+                )
 
-        # Needed to prevent GPU memory leak to the main process scheduling the workers
-        if isinstance(res, GenerateDecoderOnlyOutput):
-            res.past_key_values = None
-            res["past_key_values"] = None
+            # Needed to prevent GPU memory leak to the main process scheduling the workers
+            if isinstance(res, GenerateDecoderOnlyOutput):
+                res.past_key_values = None
+                res["past_key_values"] = None
 
-        to_cpu = partial(_move_tensor, device=torch.device("cpu"))
-        return tree.map_structure(to_cpu, res)
+            to_cpu = partial(_move_tensor, device=torch.device("cpu"))
+            return tree.map_structure(to_cpu, res)
+        except Exception as e:
+            # Re-raise with the traceback preserved. For some exception types, Dask
+            # modifies or loses the original traceback when crossing process boundaries;
+            # a RuntimeError built with with_traceback() keeps the original stack.
+            raise RuntimeError(str(e)).with_traceback(e.__traceback__)  # noqa: B904
 
     def __del__(self) -> None:
         dist.destroy_process_group()
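The re-raise in `_exec_func` relies on a generic Python idiom: graft the original traceback onto a plain `RuntimeError` so the stack survives even when the original exception type serializes poorly across process boundaries. A self-contained sketch:

```python
import traceback


def flaky() -> None:
    raise KeyError("missing weight shard")


try:
    flaky()
except Exception as e:
    # Wrap in RuntimeError but keep the original traceback, so the printed
    # stack still points at flaky() even though the exception type changed.
    wrapped = RuntimeError(str(e)).with_traceback(e.__traceback__)
    print("".join(traceback.format_exception(wrapped)))
```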
@@ -463,6 +481,8 @@ def __init__(self, config: TransformerHandlerConfig):
 
         self._initialized = True
 
+        atexit.register(self.teardown)
+
         # don't call AsyncTorchModule.__init__ because we don't need to set up module[_call_fn]
         AsyncBufferedWorker.__init__(
             self,
@@ -484,6 +504,10 @@ def _init_local_cluster(
         # lazy import since dask-cuda only works on Linux machines
         from dask_cuda import LocalCUDACluster
 
+        # Use NVIDIA's NVML layer instead of native CUDA, which is more robust for GPU
+        # detection after initialization. This prevents forked processes from wrongly
+        # detecting the default GPU as cuda:0.
+        os.environ["PYTORCH_NVML_BASED_CUDA_CHECK"] = "1"
         self.cluster = LocalCUDACluster(
             n_workers=parallel_mode_config.num_workers,
             threads_per_worker=parallel_mode_config.num_cpus_per_worker,
@@ -568,7 +592,7 @@ def get_cuda_visible_devices() -> int | None:
             futures.append(future_op)
             worker_ids.append(worker_id)
 
-        self.handlers = self.client.gather(futures)
+        self.actors: list[Actor] = self._client_gather(futures)
         self.worker_ids = worker_ids
 
     async def __call__(
@@ -619,28 +643,24 @@ def _submit_and_gather(
         """
         if split_data:
             chunker = TensorChunker(
-                num_chunks=len(self.handlers),
+                num_chunks=len(self.actors),
             )
             split_args, split_kwargs, dummy_flags = chunker.chunkify(*args, **kwargs)
         else:
-            split_args = [args] * len(self.handlers)
-            split_kwargs = [kwargs] * len(self.handlers)
 
+            split_args = [args] * len(self.actors)
+            split_kwargs = [kwargs] * len(self.actors)
 
         futures = [
-            self.client.submit(
-                handler._exec_func,
+            handler._exec_func(
                 func,
                 *args_i,
-                workers=[worker_id],
-                actor=True,
                 **kwargs_i,
             )
             for handler, worker_id, args_i, kwargs_i in zip(
-                self.handlers, self.worker_ids, split_args, split_kwargs, strict=True
+                self.actors, self.worker_ids, split_args, split_kwargs, strict=True
             )
         ]
-        results = self.client.gather(futures)
-        results = cast("list[TReturn]", [res.result().result() for res in results])
+        results: list[TReturn] = self._client_gather(futures)
 
         if split_data:
             return chunker.dechunkify(results, dummy_flags)
@@ -751,13 +771,78 @@ def save_checkpoint(self, ckpt: os.PathLike | str, **kwargs) -> None:
 
     def teardown(self) -> None:
         if self._initialized:
-            self.client.close()
+            self.client.shutdown()
             self.cluster.close()
+            del self.client
+            del self.cluster
             self._initialized = False
 
     def __del__(self) -> None:
         self.teardown()
 
+    @staticmethod
+    def _wrap_dask_future(dask_future: ActorFuture):
+        """Converts a Dask ActorFuture into an awaitable asyncio.Future."""
+        loop = asyncio.get_running_loop()
+        return asyncio.ensure_future(loop.run_in_executor(None, dask_future.result))
+
+    @staticmethod
+    def _raise_exceptions(done, pending, wrapped_futures):
+        exceptions = []
+        for future in done:
+            exc = future.exception()
+            if exc:
+                exceptions.append(exc)
+        if exceptions:
+            if len(exceptions) == 1:
+                raise exceptions[0]
+            raise ExceptionGroup("Multiple actor exceptions", exceptions)
+
+        if pending:
+            pending_indices = sorted(wrapped_futures.index(p) for p in pending)
+            raise TimeoutError(
+                f"Tasks didn't complete within timeout. {len(pending)} out of {len(wrapped_futures)} "
+                f"still pending. Pending task indices: {pending_indices}"
+            )
+
+    async def _client_gather_async(self, futures):
+        """Gather results from futures, propagating exceptions as they arrive.
+
+        Unlike client.gather(), which waits for all futures to complete before raising
+        any exceptions, this method processes futures as they complete and raises
+        exceptions immediately. This is crucial when using FSDP, where workers may be
+        left waiting for one another after a single worker crashes, causing long hangs.
+
+        Note: Dask Actors currently do not work properly with dask.gather() and can
+        cause blocking issues or hide worker errors. This implementation works around
+        those limitations.
+        """
+        try:
+            wrapped_futures = [self._wrap_dask_future(f) for f in futures]
+
+            # Use asyncio.wait with FIRST_EXCEPTION instead of gather
+            done, pending = await asyncio.wait(
+                wrapped_futures, timeout=1200, return_when=asyncio.FIRST_EXCEPTION
+            )
+
+            self._raise_exceptions(done, pending, wrapped_futures)
+
+            return await asyncio.gather(*wrapped_futures)
+        except Exception:
+            logger.exception("Error in dask workers.")
+            for future in wrapped_futures:
+                future.cancel()
+            self.teardown()
+            # sys.exit(1) would wait for dask to finish, which can cause hanging
+            # when workers are in a deadlock. Use os._exit to force immediate termination.
+            # TODO: this is more of a hack; we should propagate a special exception that
+            # is not caught by the rollout manager.
+            os._exit(1)
+
+    def _client_gather(self, futures: list[ActorFuture]) -> list[Any]:
+        # Use distributed.utils.sync to run the async function in the current thread
+        return sync(self.client.loop, self._client_gather_async, futures)  # type: ignore[arg-type]
+
 
 # Helpers
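The gather replacement boils down to two standard-library moves: run each blocking `.result()` call in the default executor, then use `asyncio.wait(..., return_when=FIRST_EXCEPTION)` so a single crashed worker surfaces immediately. A minimal sketch with `concurrent.futures.Future` standing in for Dask `ActorFuture`s:

```python
import asyncio
from concurrent.futures import Future
from typing import Any


def wrap_blocking_future(fut: Future) -> asyncio.Future:
    # Same trick as _wrap_dask_future: .result() blocks, so run it on the
    # default thread-pool executor and get back an awaitable.
    loop = asyncio.get_running_loop()
    return asyncio.ensure_future(loop.run_in_executor(None, fut.result))


async def gather_fail_fast(futures: list[Future], timeout: float = 1200) -> list[Any]:
    wrapped = [wrap_blocking_future(f) for f in futures]
    done, pending = await asyncio.wait(
        wrapped, timeout=timeout, return_when=asyncio.FIRST_EXCEPTION
    )
    for fut in done:
        if (exc := fut.exception()) is not None:
            for p in pending:
                p.cancel()
            raise exc
    if pending:
        raise TimeoutError(f"{len(pending)} tasks still pending after {timeout}s")
    return await asyncio.gather(*wrapped)
```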
diff --git a/tests/test_nn_models.py b/tests/test_nn_models.py
index aa717b48..55917ce3 100644
--- a/tests/test_nn_models.py
+++ b/tests/test_nn_models.py
@@ -27,11 +27,10 @@ class TestTensorChunker:
     def test_chunkify_add_dummy_chunks(self):
         batch_size = 3
         num_chunks = 5
-        dummy_value = 0
 
         sample_tensor = torch.arange(1, batch_size * 10 + 1).reshape(batch_size, 10)
 
-        chunker = ldp.nn.TensorChunker(num_chunks=num_chunks, dummy_value=dummy_value)
+        chunker = ldp.nn.TensorChunker(num_chunks=num_chunks)
         split_args, split_kwargs, dummy_chunk_flags = chunker.chunkify(sample_tensor)
 
         assert len(split_args) == num_chunks
@@ -40,21 +39,16 @@ def test_chunkify_add_dummy_chunks(self):
         assert torch.equal(split_args[0][0], sample_tensor[:1])
         assert torch.equal(split_args[1][0], sample_tensor[1:2])
         assert torch.equal(split_args[2][0], sample_tensor[2:3])
-        assert torch.equal(
-            split_args[3][0], torch.full_like(sample_tensor[:1], dummy_value)
-        )
-        assert torch.equal(
-            split_args[4][0], torch.full_like(sample_tensor[:1], dummy_value)
-        )
+        assert torch.equal(split_args[3][0], sample_tensor[:1])
+        assert torch.equal(split_args[4][0], sample_tensor[:1])
 
     def test_chunkify_no_dummy_chunks(self):
         batch_size = 9
         num_chunks = 5
-        dummy_value = 0
 
         sample_tensor = torch.arange(1, batch_size * 10 + 1).reshape(batch_size, 10)
 
-        chunker = ldp.nn.TensorChunker(num_chunks=num_chunks, dummy_value=dummy_value)
+        chunker = ldp.nn.TensorChunker(num_chunks=num_chunks)
         split_args, split_kwargs, dummy_chunk_flags = chunker.chunkify(sample_tensor)
 
         assert len(split_args) == num_chunks
@@ -69,7 +63,6 @@ def test_chunkify_with_args_and_kwargs(self):
         batch_size = 2
         num_chunks = 3
-        dummy_value = 0
 
         sample_tensor = torch.arange(1, batch_size * 10 + 1).reshape(batch_size, 10)
         sample_tensor_kwarg = torch.arange(1, batch_size * 5 + 1).reshape(batch_size, 5)
@@ -78,7 +71,7 @@ def test_chunkify_with_args_and_kwargs(self):
             "key2": "Not split",
         }
 
-        chunker = ldp.nn.TensorChunker(num_chunks=num_chunks, dummy_value=dummy_value)
+        chunker = ldp.nn.TensorChunker(num_chunks=num_chunks)
         split_args, split_kwargs, dummy_chunk_flags = chunker.chunkify(
             sample_tensor, **sample_kwargs
         )
@@ -88,15 +81,10 @@ def test_chunkify_with_args_and_kwargs(self):
         assert dummy_chunk_flags == [False, False, True]
         assert torch.equal(split_args[0][0], sample_tensor[:1])
         assert torch.equal(split_args[1][0], sample_tensor[1:2])
-        assert torch.equal(
-            split_args[2][0], torch.full_like(sample_tensor[:1], dummy_value)
-        )
+        assert torch.equal(split_args[2][0], sample_tensor[:1])
         assert torch.equal(split_kwargs[0]["key1"], sample_tensor_kwarg[:1])
         assert torch.equal(split_kwargs[1]["key1"], sample_tensor_kwarg[1:2])
-        assert torch.equal(
-            split_kwargs[2]["key1"],
-            torch.full_like(sample_tensor_kwarg[:1], dummy_value),
-        )
+        assert torch.equal(split_kwargs[2]["key1"], sample_tensor_kwarg[:1])
         assert all(split_kwargs[i]["key2"] == "Not split" for i in range(num_chunks))
 
     def test_dechunkify(self):