105 changes: 49 additions & 56 deletions neuracore/importer/core/base.py
@@ -17,6 +17,7 @@
from neuracore_types.importer.data_config import DataFormat
from neuracore_types.nc_data import DatasetImportConfig, DataType
from neuracore_types.nc_data.nc_data import MappingItem
from rich.console import Console
from rich.progress import (
BarColumn,
MofNCompleteColumn,
@@ -76,6 +77,14 @@ class ProgressUpdate:
episode_label: str | None = None


_RICH_CONSOLE = Console(stderr=True, force_terminal=True)


def get_shared_console() -> Console:
"""Return the shared console used by logging and progress bars."""
return _RICH_CONSOLE


class NeuracoreDatasetImporter(ABC):
"""Uploader workflow that manages workers and Neuracore session setup."""

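Note on the shared console added in this hunk: Rich only keeps log lines and a live progress bar from clobbering each other when both write through the same Console. Below is a minimal, self-contained sketch of that pattern (the demo logger name, task label, and timings are illustrative, not the importer's code); the diff wires the same console into RichHandler in importer.py and into Progress in _monitor_progress.

import logging
import time

from rich.console import Console
from rich.logging import RichHandler
from rich.progress import Progress

console = Console(stderr=True)  # one console shared by logging and the progress bar

logging.basicConfig(
    level=logging.INFO,
    format="%(message)s",
    handlers=[RichHandler(console=console, show_path=False)],
)
log = logging.getLogger("demo")

with Progress(console=console, transient=True) as progress:
    task = progress.add_task("importing", total=5)
    for i in range(5):
        time.sleep(0.1)
        log.info("processed item %s", i)  # rendered above the live bar, no flicker
        progress.advance(task)
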
@@ -101,13 +110,14 @@ def __init__(
self.frequency = dataset_config.frequency
self.joint_info = joint_info

self.max_workers = 1
self.max_workers = max_workers
self.min_workers = min_workers
self.continue_on_error = continue_on_error
self.progress_interval = max(1, progress_interval)
self.dry_run = dry_run
self.suppress_warnings = suppress_warnings
self.worker_errors: list[WorkerError] = []
self._logged_error_keys: set[tuple[int | None, int | None, str]] = set()
self.logger = logging.getLogger(
f"{self.__class__.__module__}.{self.__class__.__name__}"
)
@@ -365,33 +375,22 @@ def prepare_worker(
nc.get_dataset(self.output_dataset_name)

def upload_all(self) -> None:
"""Run uploads across workers while aggregating errors."""
"""Run uploads across workers while aggregating errors.

High-level flow:
1) Build the list of work items (episodes).
2) Decide how many worker processes to spawn.
3) Spin up workers and a progress queue.
4) Listen for progress updates while workers run.
5) Collect and summarize any errors.
"""
items = list(self.build_work_items())
if not items:
self.logger.info("No upload items found; nothing to do.")
return

self.logger.info(
"Preparing import -> dataset=%s | source_dir=%s | items=%s | "
"continue_on_error=%s",
self.output_dataset_name,
self.dataset_dir,
len(items),
self.continue_on_error,
)

worker_count = self._resolve_worker_count(len(items))
cpu_count = os.cpu_count()
self.logger.info(
"Scheduling %s items across %s worker(s) "
"(min=%s max=%s cpu=%s progress_interval=%s)",
len(items),
worker_count,
self.min_workers,
self.max_workers if self.max_workers is not None else "auto",
cpu_count if cpu_count is not None else "unknown",
self.progress_interval,
)
os.cpu_count()

ctx = mp.get_context("spawn")
error_queue: mp.Queue[WorkerError] = ctx.Queue()
@@ -468,16 +467,7 @@ def _worker_entry(
self.prepare_worker(worker_id, chunk)
else:
self.prepare_worker(worker_id) # type: ignore[misc]
if chunk:
self.logger.info(
"[worker %s] Starting chunk (%s items): %s → %s",
worker_id,
len(chunk),
chunk[0].index,
chunk[-1].index,
)
else:
self.logger.info("[worker %s] Starting with empty chunk.", worker_id)
# Progress bar will reflect work; keep startup quiet to reduce noise.
except Exception as exc: # noqa: BLE001 - propagate unexpected worker failures
if error_queue:
tb = traceback.format_exc()
@@ -522,7 +512,6 @@ def _step(
self.upload(item)
except Exception as exc: # noqa: BLE001 - keep traceback for summary
tb = traceback.format_exc()
self._log_worker_error(worker_id, item.index, str(exc))
if self.continue_on_error:
error_queue.put(
WorkerError(
@@ -532,25 +521,13 @@
traceback=tb,
)
)
self.logger.warning(
"[worker %s] Continuing after failure on item %s "
"(continue_on_error=True).",
worker_id,
item.index,
)
# Defer logging to the post-run summary to avoid flickering
# and duplicate error lines while the progress bar is live.
return
self._log_worker_error(worker_id, item.index, str(exc))
raise

if (local_index + 1) % self.progress_interval == 0 or (
local_index + 1 == chunk_length
):
self.logger.info(
"[worker %s] processed %s/%s (item index=%s)",
worker_id,
local_index + 1,
chunk_length,
item.index,
)
# Progress bar already shows ongoing status; skip per-interval info logs.

def _collect_errors(self, error_queue: mp.Queue) -> list[WorkerError]:
"""Drain the error queue after workers complete."""
@@ -580,15 +557,25 @@ def _report_errors(self, errors: list[WorkerError]) -> None:
self.logger.info("All workers completed without reported errors.")
return

self.logger.error("Completed with %s worker error(s).", len(errors))
deduped: dict[tuple[int | None, int | None, str], int] = {}
for err in errors:
prefix = f"[worker {err.worker_id}"
if err.item_index is not None:
prefix += f" item {err.item_index}"
key = (err.worker_id, err.item_index, err.message)
deduped[key] = deduped.get(key, 0) + 1

self.logger.error(
"Completed with %s worker error event(s) (%s unique).",
len(errors),
len(deduped),
)

for (worker_id, item_index, message), count in deduped.items():
prefix = f"[worker {worker_id}"
if item_index is not None:
prefix += f" item {item_index}"
prefix += "]"
self.logger.error("%s %s", prefix, err.message)
if err.traceback:
self.logger.debug(err.traceback)
suffix = f" (x{count})" if count > 1 else ""
self.logger.error("%s %s%s", prefix, message, suffix)

self.logger.error(
"Import finished with errors. Re-run with DEBUG logging for tracebacks "
"or fix the reported issues above."
@@ -598,6 +585,11 @@ def _log_worker_error(
self, worker_id: int, item_index: int | None, message: str
) -> None:
"""Log a worker error immediately while the process is running."""
key = (worker_id, item_index, message)
if key in self._logged_error_keys:
return
self._logged_error_keys.add(key)

prefix = f"[worker {worker_id}"
if item_index is not None:
prefix += f" item {item_index}"
@@ -649,6 +641,7 @@ def _monitor_progress(
TimeRemainingColumn(),
refresh_per_second=10,
transient=True,
console=get_shared_console(),
) as progress:
while True:
any_alive = any(proc.is_alive() for proc in processes)
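The base.py changes above shift error handling to a collect-then-summarize flow: workers put WorkerError records on a multiprocessing queue, the parent drains it after the workers finish, and duplicate (worker_id, item_index, message) entries collapse into one summary line with a count. A rough, self-contained sketch of that shape under the spawn context follows; the failing-item condition and chunk assignment are invented for the demo, and only the WorkerError fields visible in the diff are used.

import multiprocessing as mp
import queue
from dataclasses import dataclass


@dataclass
class WorkerError:
    worker_id: int
    item_index: int | None
    message: str
    traceback: str | None = None


def worker(worker_id: int, items: list[int], errors: "mp.Queue[WorkerError]") -> None:
    for idx in items:
        try:
            if idx % 3 == 0:  # stand-in for an upload that fails on some items
                raise ValueError(f"bad item {idx}")
        except Exception as exc:  # noqa: BLE001 - keep going, report later
            errors.put(WorkerError(worker_id, idx, str(exc)))


def drain(q: "mp.Queue[WorkerError]") -> list[WorkerError]:
    out: list[WorkerError] = []
    while True:
        try:
            out.append(q.get_nowait())
        except queue.Empty:
            return out


if __name__ == "__main__":
    ctx = mp.get_context("spawn")
    errs: "mp.Queue[WorkerError]" = ctx.Queue()
    procs = [
        ctx.Process(target=worker, args=(i, list(range(i, 10, 2)), errs))
        for i in range(2)
    ]
    for p in procs:
        p.start()
    for p in procs:
        p.join()
    # Deduplicated post-run summary, mirroring _report_errors above.
    counts: dict[tuple[int, int | None, str], int] = {}
    for e in drain(errs):
        key = (e.worker_id, e.item_index, e.message)
        counts[key] = counts.get(key, 0) + 1
    for (wid, idx, msg), n in counts.items():
        print(f"[worker {wid} item {idx}] {msg}" + (f" (x{n})" if n > 1 else ""))
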
58 changes: 36 additions & 22 deletions neuracore/importer/importer.py
@@ -9,9 +9,11 @@

from neuracore_types.importer.config import DatasetTypeConfig
from neuracore_types.nc_data import DatasetImportConfig
from rich.logging import RichHandler

import neuracore as nc
from neuracore.core.data.dataset import Dataset
from neuracore.importer.core.base import get_shared_console
from neuracore.importer.core.dataset_detector import (
DatasetDetector,
iter_first_two_levels,
@@ -30,7 +32,7 @@
from neuracore.importer.lerobot_importer import LeRobotDatasetImporter
from neuracore.importer.rlds_importer import RLDSDatasetImporter

LOG_FORMAT = "%(asctime)s | %(levelname)s | %(name)s | %(message)s"
LOG_FORMAT = "%(message)s"
logger = logging.getLogger(__name__)


Expand All @@ -39,11 +41,24 @@ def configure_logging(level: int = logging.INFO) -> None:
if logging.getLogger().handlers:
logging.getLogger().setLevel(level)
return
handler = logging.StreamHandler()
handler = RichHandler(
rich_tracebacks=True,
markup=True,
show_path=False,
console=get_shared_console(),
)
handler.setFormatter(logging.Formatter(LOG_FORMAT))
logging.basicConfig(level=level, handlers=[handler])


def load_dataset_config(path: Path) -> DatasetImportConfig:
"""Read the user-provided YAML/JSON into a strongly typed config."""
try:
return DatasetImportConfig.from_file(path)
except Exception as exc: # noqa: BLE001 - show root cause to user
raise ConfigLoadError(f"Failed to load dataset config '{path}': {exc}") from exc


def parse_args() -> argparse.Namespace:
"""Parse command-line arguments for dataset import.

@@ -98,6 +113,21 @@ def parse_args() -> argparse.Namespace:
return args


def load_or_detect_dataset_type(
dataconfig: DatasetImportConfig, dataset_dir: Path
) -> DatasetTypeConfig:
"""Prefer the explicit dataset type in config, otherwise auto-detect."""
if dataconfig.dataset_type:
return dataconfig.dataset_type

try:
detected = detect_dataset_type(dataset_dir)
logger.info("Detected dataset type: %s", detected.value.upper())
return detected
except Exception as exc: # noqa: BLE001 - surface detection failure
raise DatasetDetectionError(str(exc)) from exc


def cli_args_validation(args: argparse.Namespace) -> None:
"""Validate the provided arguments."""
for path in [args.dataset_config, args.dataset_dir]:
@@ -122,7 +152,7 @@ def _resolve_robot_descriptions(
config_mjcf_path: str | None,
robot_dir: Path | None,
) -> tuple[str | None, str | None]:
"""Pick the first matching URDF and MJCF files by extension."""
"""Find URDF/MJCF files either from config paths or by scanning a folder."""
urdf_path: str | None = None
mjcf_path: str | None = None
suffix_to_target = {".urdf": "urdf", ".xml": "mjcf", ".mjcf": "mjcf"}
@@ -173,30 +203,15 @@ def main() -> None:
sys.exit(1)

logger.info(
"Starting dataset import | dataset_config=%s | dataset_dir=%s | robot_dir=%s",
"Starting dataset import\n config: %s\n data: %s\n robot: %s",
args.dataset_config,
args.dataset_dir,
args.robot_dir,
)

try:
dataconfig = DatasetImportConfig.from_file(args.dataset_config)
except Exception as exc: # noqa: BLE001 - show root cause to user
raise ConfigLoadError(
f"Failed to load dataset config '{args.dataset_config}': {exc}"
) from exc

logger.info("Dataset config loaded.")
dataconfig = load_dataset_config(args.dataset_config)

if dataconfig.dataset_type:
dataset_type = dataconfig.dataset_type
logger.info("Using dataset type from config: %s", dataset_type.value.upper())
else:
try:
dataset_type = detect_dataset_type(args.dataset_dir)
logger.info("Detected dataset type: %s", dataset_type.value.upper())
except Exception as exc: # noqa: BLE001 - surface detection failure
raise DatasetDetectionError(str(exc)) from exc
dataset_type = load_or_detect_dataset_type(dataconfig, args.dataset_dir)

output_dataset = dataconfig.output_dataset
if not output_dataset or not output_dataset.name:
@@ -232,7 +247,6 @@
description=dataconfig.output_dataset.description,
tags=dataconfig.output_dataset.tags,
)
logger.info("Output dataset ready: %s (id=%s)", dataset.name, dataset.id)

robot_config = dataconfig.robot
urdf_path, mjcf_path = _resolve_robot_descriptions(
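importer.py now factors config loading and dataset-type resolution into small helpers: the loader chains the parser's root cause into ConfigLoadError, and the resolver prefers an explicit dataset_type over auto-detection. A simplified sketch of those two shapes follows, using hypothetical stand-ins (plain JSON parsing and a dataset_info.json probe) rather than the real DatasetImportConfig / detect_dataset_type APIs.

import json
from pathlib import Path


class ConfigLoadError(RuntimeError):
    """Raised when the config file cannot be read or parsed."""


def load_config(path: Path) -> dict:
    try:
        return json.loads(path.read_text())
    except Exception as exc:  # noqa: BLE001 - surface the root cause to the user
        raise ConfigLoadError(f"Failed to load config '{path}': {exc}") from exc


def resolve_dataset_type(config: dict, dataset_dir: Path) -> str:
    # An explicit setting always wins; detection only runs as a fallback.
    if config.get("dataset_type"):
        return str(config["dataset_type"])
    # Stand-in heuristic: RLDS/TFDS builders ship a dataset_info.json.
    return "rlds" if any(dataset_dir.rglob("dataset_info.json")) else "lerobot"
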
10 changes: 0 additions & 10 deletions neuracore/importer/lerobot_importer.py
@@ -62,16 +62,6 @@ def __init__(
self._dataset: LeRobotDataset | None = None
self._episode_iter: Iterator[int] | None = None

self.logger.info(
"Initialized LeRobot importer for '%s' "
"(episodes=%s, cameras=%s, fps=%s, root=%s)",
self.dataset_name,
self.num_episodes,
self.camera_keys,
self.frequency,
self.dataset_root,
)

def __getstate__(self) -> dict:
"""Drop worker-local handles when pickling for multiprocessing."""
state = self.__dict__.copy()
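The __getstate__ kept as context above exists because spawn-based workers receive a pickled copy of the importer: worker-local handles such as _dataset and _episode_iter cannot cross the process boundary, so they are dropped before pickling and recreated lazily in the child. A generic sketch of that pattern, where ChunkedImporter and open_dataset are illustrative names rather than LeRobot APIs:

def open_dataset(root: str) -> object:
    """Stand-in for an expensive loader (e.g. opening files under `root`)."""
    return object()


class ChunkedImporter:
    """Toy example of stripping unpicklable, worker-local handles."""

    def __init__(self, root: str) -> None:
        self.root = root
        self._dataset: object | None = None       # expensive, worker-local handle
        self._episode_iter: object | None = None  # worker-local iterator

    def __getstate__(self) -> dict:
        # Called by pickle when the importer is shipped to a spawn worker.
        state = self.__dict__.copy()
        state["_dataset"] = None
        state["_episode_iter"] = None
        return state

    def __setstate__(self, state: dict) -> None:
        self.__dict__.update(state)

    def _ensure_dataset(self) -> object:
        # Each worker re-opens its own handle on first use.
        if self._dataset is None:
            self._dataset = open_dataset(self.root)
        return self._dataset
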
17 changes: 3 additions & 14 deletions neuracore/importer/rlds_importer.py
@@ -67,8 +67,7 @@ def __init__(
self._episode_iter = None

self.logger.info(
"Initialized RLDS importer for '%s' "
"(split=%s, episodes=%s, freq=%s, dir=%s)",
"Dataset ready: name=%s split=%s episodes=%s freq=%s dir=%s",
self.dataset_name,
self.split,
self.num_episodes,
@@ -98,15 +97,6 @@ def prepare_worker(
chunk_start = chunk[0].index if chunk else 0
chunk_length = len(chunk) if chunk else None

self.logger.info(
"[worker %s] Loading split=%s (start=%s count=%s) from %s",
worker_id,
self.split,
chunk_start,
chunk_length if chunk_length is not None else "remainder",
self.builder_dir,
)

dataset = self._load_dataset(self._builder, self.split)
if chunk_start:
dataset = dataset.skip(chunk_start)
@@ -142,12 +132,11 @@ def upload(self, item: ImportItem) -> None:
f"worker {self._worker_id}" if self._worker_id is not None else "worker 0"
)
self.logger.info(
"[%s] Importing %s (%s/%s, steps=%s)",
"[%s] Importing episode %s (%s/%s)",
worker_label,
episode_label,
item.index + 1,
self.num_episodes,
total_steps if total_steps is not None else "unknown",
)
self._emit_progress(
item.index, step=0, total_steps=total_steps, episode_label=episode_label
Expand All @@ -162,7 +151,7 @@ def upload(self, item: ImportItem) -> None:
episode_label=episode_label,
)
nc.stop_recording(wait=True)
self.logger.info("[%s] Completed %s", worker_label, episode_label)
self.logger.info("[%s] Completed episode %s", worker_label, episode_label)

def _resolve_builder_dir(self) -> Path:
"""Find the dataset version directory that contains dataset_info.json."""
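prepare_worker above skips the RLDS stream to its chunk's starting index so each worker only materialises its own episodes. A short sketch of that slicing with tf.data's skip/take combination; the range dataset and chunk numbers are placeholders, and the take() call is an assumption about how the rest of the chunk is bounded, since only skip() is visible in this diff.

import tensorflow as tf

dataset = tf.data.Dataset.range(100)   # stand-in for the decoded episode stream

chunk_start, chunk_length = 40, 20     # this worker's assigned slice
chunk = dataset.skip(chunk_start)
if chunk_length is not None:
    chunk = chunk.take(chunk_length)

for episode in chunk:                  # the worker iterates only its own episodes
    pass
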