From 4235a58b8a8656834f322b109e23356ca8d98de7 Mon Sep 17 00:00:00 2001
From: sdas-neuro
Date: Thu, 5 Feb 2026 11:17:50 +0000
Subject: [PATCH] fix: importer logging

---
 neuracore/importer/core/base.py        | 105 ++++++++++++-------------
 neuracore/importer/importer.py         |  58 ++++++++------
 neuracore/importer/lerobot_importer.py |  10 ---
 neuracore/importer/rlds_importer.py    |  17 +---
 4 files changed, 88 insertions(+), 102 deletions(-)

diff --git a/neuracore/importer/core/base.py b/neuracore/importer/core/base.py
index a9a4831c..f2fe7b76 100644
--- a/neuracore/importer/core/base.py
+++ b/neuracore/importer/core/base.py
@@ -17,6 +17,7 @@
 from neuracore_types.importer.data_config import DataFormat
 from neuracore_types.nc_data import DatasetImportConfig, DataType
 from neuracore_types.nc_data.nc_data import MappingItem
+from rich.console import Console
 from rich.progress import (
     BarColumn,
     MofNCompleteColumn,
@@ -76,6 +77,14 @@ class ProgressUpdate:
     episode_label: str | None = None
 
 
+_RICH_CONSOLE = Console(stderr=True, force_terminal=True)
+
+
+def get_shared_console() -> Console:
+    """Return the shared console used by logging and progress bars."""
+    return _RICH_CONSOLE
+
+
 class NeuracoreDatasetImporter(ABC):
     """Uploader workflow that manages workers and Neuracore session setup."""
 
@@ -101,13 +110,14 @@ def __init__(
         self.frequency = dataset_config.frequency
         self.joint_info = joint_info
 
-        self.max_workers = 1
+        self.max_workers = max_workers
         self.min_workers = min_workers
         self.continue_on_error = continue_on_error
         self.progress_interval = max(1, progress_interval)
         self.dry_run = dry_run
         self.suppress_warnings = suppress_warnings
         self.worker_errors: list[WorkerError] = []
+        self._logged_error_keys: set[tuple[int | None, int | None, str]] = set()
         self.logger = logging.getLogger(
             f"{self.__class__.__module__}.{self.__class__.__name__}"
         )
@@ -365,33 +375,22 @@ def prepare_worker(
         nc.get_dataset(self.output_dataset_name)
 
     def upload_all(self) -> None:
-        """Run uploads across workers while aggregating errors."""
+        """Run uploads across workers while aggregating errors.
+
+        High-level flow:
+        1) Build the list of work items (episodes).
+        2) Decide how many worker processes to spawn.
+        3) Spin up workers and a progress queue.
+        4) Listen for progress updates while workers run.
+        5) Collect and summarize any errors.
+        """
         items = list(self.build_work_items())
         if not items:
             self.logger.info("No upload items found; nothing to do.")
             return
 
-        self.logger.info(
-            "Preparing import -> dataset=%s | source_dir=%s | items=%s | "
-            "continue_on_error=%s",
-            self.output_dataset_name,
-            self.dataset_dir,
-            len(items),
-            self.continue_on_error,
-        )
-
         worker_count = self._resolve_worker_count(len(items))
-        cpu_count = os.cpu_count()
-        self.logger.info(
-            "Scheduling %s items across %s worker(s) "
-            "(min=%s max=%s cpu=%s progress_interval=%s)",
-            len(items),
-            worker_count,
-            self.min_workers,
-            self.max_workers if self.max_workers is not None else "auto",
-            cpu_count if cpu_count is not None else "unknown",
-            self.progress_interval,
-        )
+        os.cpu_count()
 
         ctx = mp.get_context("spawn")
         error_queue: mp.Queue[WorkerError] = ctx.Queue()
@@ -468,16 +467,7 @@ def _worker_entry(
                 self.prepare_worker(worker_id, chunk)
             else:
                 self.prepare_worker(worker_id)  # type: ignore[misc]
-            if chunk:
-                self.logger.info(
-                    "[worker %s] Starting chunk (%s items): %s → %s",
-                    worker_id,
-                    len(chunk),
-                    chunk[0].index,
-                    chunk[-1].index,
-                )
-            else:
-                self.logger.info("[worker %s] Starting with empty chunk.", worker_id)
+            # Progress bar will reflect work; keep startup quiet to reduce noise.
         except Exception as exc:  # noqa: BLE001 - propagate unexpected worker failures
             if error_queue:
                 tb = traceback.format_exc()
@@ -522,7 +512,6 @@ def _step(
             self.upload(item)
         except Exception as exc:  # noqa: BLE001 - keep traceback for summary
             tb = traceback.format_exc()
-            self._log_worker_error(worker_id, item.index, str(exc))
             if self.continue_on_error:
                 error_queue.put(
                     WorkerError(
@@ -532,25 +521,13 @@ def _step(
                         traceback=tb,
                     )
                 )
-                self.logger.warning(
-                    "[worker %s] Continuing after failure on item %s "
-                    "(continue_on_error=True).",
-                    worker_id,
-                    item.index,
-                )
+                # Defer logging to the post-run summary to avoid flickering
+                # and duplicate error lines while the progress bar is live.
                 return
+            self._log_worker_error(worker_id, item.index, str(exc))
             raise
 
-        if (local_index + 1) % self.progress_interval == 0 or (
-            local_index + 1 == chunk_length
-        ):
-            self.logger.info(
-                "[worker %s] processed %s/%s (item index=%s)",
-                worker_id,
-                local_index + 1,
-                chunk_length,
-                item.index,
-            )
+        # Progress bar already shows ongoing status; skip per-interval info logs.
 
     def _collect_errors(self, error_queue: mp.Queue) -> list[WorkerError]:
         """Drain the error queue after workers complete."""
@@ -580,15 +557,25 @@ def _report_errors(self, errors: list[WorkerError]) -> None:
             self.logger.info("All workers completed without reported errors.")
             return
 
-        self.logger.error("Completed with %s worker error(s).", len(errors))
+        deduped: dict[tuple[int | None, int | None, str], int] = {}
         for err in errors:
-            prefix = f"[worker {err.worker_id}"
-            if err.item_index is not None:
-                prefix += f" item {err.item_index}"
+            key = (err.worker_id, err.item_index, err.message)
+            deduped[key] = deduped.get(key, 0) + 1
+
+        self.logger.error(
+            "Completed with %s worker error event(s) (%s unique).",
+            len(errors),
+            len(deduped),
+        )
+
+        for (worker_id, item_index, message), count in deduped.items():
+            prefix = f"[worker {worker_id}"
+            if item_index is not None:
+                prefix += f" item {item_index}"
             prefix += "]"
-            self.logger.error("%s %s", prefix, err.message)
-            if err.traceback:
-                self.logger.debug(err.traceback)
+            suffix = f" (x{count})" if count > 1 else ""
+            self.logger.error("%s %s%s", prefix, message, suffix)
+
         self.logger.error(
             "Import finished with errors. "
Re-run with DEBUG logging for tracebacks " "or fix the reported issues above." @@ -598,6 +585,11 @@ def _log_worker_error( self, worker_id: int, item_index: int | None, message: str ) -> None: """Log a worker error immediately while the process is running.""" + key = (worker_id, item_index, message) + if key in self._logged_error_keys: + return + self._logged_error_keys.add(key) + prefix = f"[worker {worker_id}" if item_index is not None: prefix += f" item {item_index}" @@ -649,6 +641,7 @@ def _monitor_progress( TimeRemainingColumn(), refresh_per_second=10, transient=True, + console=get_shared_console(), ) as progress: while True: any_alive = any(proc.is_alive() for proc in processes) diff --git a/neuracore/importer/importer.py b/neuracore/importer/importer.py index 389050cc..196fc0e3 100644 --- a/neuracore/importer/importer.py +++ b/neuracore/importer/importer.py @@ -9,9 +9,11 @@ from neuracore_types.importer.config import DatasetTypeConfig from neuracore_types.nc_data import DatasetImportConfig +from rich.logging import RichHandler import neuracore as nc from neuracore.core.data.dataset import Dataset +from neuracore.importer.core.base import get_shared_console from neuracore.importer.core.dataset_detector import ( DatasetDetector, iter_first_two_levels, @@ -30,7 +32,7 @@ from neuracore.importer.lerobot_importer import LeRobotDatasetImporter from neuracore.importer.rlds_importer import RLDSDatasetImporter -LOG_FORMAT = "%(asctime)s | %(levelname)s | %(name)s | %(message)s" +LOG_FORMAT = "%(message)s" logger = logging.getLogger(__name__) @@ -39,11 +41,24 @@ def configure_logging(level: int = logging.INFO) -> None: if logging.getLogger().handlers: logging.getLogger().setLevel(level) return - handler = logging.StreamHandler() + handler = RichHandler( + rich_tracebacks=True, + markup=True, + show_path=False, + console=get_shared_console(), + ) handler.setFormatter(logging.Formatter(LOG_FORMAT)) logging.basicConfig(level=level, handlers=[handler]) +def load_dataset_config(path: Path) -> DatasetImportConfig: + """Read the user-provided YAML/JSON into a strongly typed config.""" + try: + return DatasetImportConfig.from_file(path) + except Exception as exc: # noqa: BLE001 - show root cause to user + raise ConfigLoadError(f"Failed to load dataset config '{path}': {exc}") from exc + + def parse_args() -> argparse.Namespace: """Parse command-line arguments for dataset import. 
@@ -98,6 +113,21 @@ def parse_args() -> argparse.Namespace:
     return args
 
 
+def load_or_detect_dataset_type(
+    dataconfig: DatasetImportConfig, dataset_dir: Path
+) -> DatasetTypeConfig:
+    """Prefer the explicit dataset type in config, otherwise auto-detect."""
+    if dataconfig.dataset_type:
+        return dataconfig.dataset_type
+
+    try:
+        detected = detect_dataset_type(dataset_dir)
+        logger.info("Detected dataset type: %s", detected.value.upper())
+        return detected
+    except Exception as exc:  # noqa: BLE001 - surface detection failure
+        raise DatasetDetectionError(str(exc)) from exc
+
+
 def cli_args_validation(args: argparse.Namespace) -> None:
     """Validate the provided arguments."""
     for path in [args.dataset_config, args.dataset_dir]:
@@ -122,7 +152,7 @@ def _resolve_robot_descriptions(
     config_mjcf_path: str | None,
     robot_dir: Path | None,
 ) -> tuple[str | None, str | None]:
-    """Pick the first matching URDF and MJCF files by extension."""
+    """Find URDF/MJCF files either from config paths or by scanning a folder."""
     urdf_path: str | None = None
     mjcf_path: str | None = None
     suffix_to_target = {".urdf": "urdf", ".xml": "mjcf", ".mjcf": "mjcf"}
@@ -173,30 +203,15 @@ def main() -> None:
         sys.exit(1)
 
     logger.info(
-        "Starting dataset import | dataset_config=%s | dataset_dir=%s | robot_dir=%s",
+        "Starting dataset import\n config: %s\n data: %s\n robot: %s",
         args.dataset_config,
         args.dataset_dir,
         args.robot_dir,
     )
 
-    try:
-        dataconfig = DatasetImportConfig.from_file(args.dataset_config)
-    except Exception as exc:  # noqa: BLE001 - show root cause to user
-        raise ConfigLoadError(
-            f"Failed to load dataset config '{args.dataset_config}': {exc}"
-        ) from exc
-
-    logger.info("Dataset config loaded.")
+    dataconfig = load_dataset_config(args.dataset_config)
 
-    if dataconfig.dataset_type:
-        dataset_type = dataconfig.dataset_type
-        logger.info("Using dataset type from config: %s", dataset_type.value.upper())
-    else:
-        try:
-            dataset_type = detect_dataset_type(args.dataset_dir)
-            logger.info("Detected dataset type: %s", dataset_type.value.upper())
-        except Exception as exc:  # noqa: BLE001 - surface detection failure
-            raise DatasetDetectionError(str(exc)) from exc
+    dataset_type = load_or_detect_dataset_type(dataconfig, args.dataset_dir)
 
     output_dataset = dataconfig.output_dataset
     if not output_dataset or not output_dataset.name:
@@ -232,7 +247,6 @@ def main() -> None:
         description=dataconfig.output_dataset.description,
         tags=dataconfig.output_dataset.tags,
     )
-    logger.info("Output dataset ready: %s (id=%s)", dataset.name, dataset.id)
 
     robot_config = dataconfig.robot
     urdf_path, mjcf_path = _resolve_robot_descriptions(
diff --git a/neuracore/importer/lerobot_importer.py b/neuracore/importer/lerobot_importer.py
index 20a20708..351cfc17 100644
--- a/neuracore/importer/lerobot_importer.py
+++ b/neuracore/importer/lerobot_importer.py
@@ -62,16 +62,6 @@ def __init__(
         self._dataset: LeRobotDataset | None = None
         self._episode_iter: Iterator[int] | None = None
 
-        self.logger.info(
-            "Initialized LeRobot importer for '%s' "
-            "(episodes=%s, cameras=%s, fps=%s, root=%s)",
-            self.dataset_name,
-            self.num_episodes,
-            self.camera_keys,
-            self.frequency,
-            self.dataset_root,
-        )
-
     def __getstate__(self) -> dict:
         """Drop worker-local handles when pickling for multiprocessing."""
         state = self.__dict__.copy()
diff --git a/neuracore/importer/rlds_importer.py b/neuracore/importer/rlds_importer.py
index ed22629b..d39848de 100644
--- a/neuracore/importer/rlds_importer.py
+++ b/neuracore/importer/rlds_importer.py
@@ -67,8 +67,7 @@ def __init__(
         self._episode_iter = None
 
         self.logger.info(
-            "Initialized RLDS importer for '%s' "
-            "(split=%s, episodes=%s, freq=%s, dir=%s)",
+            "Dataset ready: name=%s split=%s episodes=%s freq=%s dir=%s",
             self.dataset_name,
             self.split,
             self.num_episodes,
@@ -98,15 +97,6 @@ def prepare_worker(
         chunk_start = chunk[0].index if chunk else 0
         chunk_length = len(chunk) if chunk else None
 
-        self.logger.info(
-            "[worker %s] Loading split=%s (start=%s count=%s) from %s",
-            worker_id,
-            self.split,
-            chunk_start,
-            chunk_length if chunk_length is not None else "remainder",
-            self.builder_dir,
-        )
-
         dataset = self._load_dataset(self._builder, self.split)
         if chunk_start:
             dataset = dataset.skip(chunk_start)
@@ -142,12 +132,11 @@ def upload(self, item: ImportItem) -> None:
             f"worker {self._worker_id}" if self._worker_id is not None else "worker 0"
         )
         self.logger.info(
-            "[%s] Importing %s (%s/%s, steps=%s)",
+            "[%s] Importing episode %s (%s/%s)",
             worker_label,
             episode_label,
             item.index + 1,
             self.num_episodes,
-            total_steps if total_steps is not None else "unknown",
         )
         self._emit_progress(
             item.index, step=0, total_steps=total_steps, episode_label=episode_label
@@ -162,7 +151,7 @@ def upload(self, item: ImportItem) -> None:
             episode_label=episode_label,
         )
         nc.stop_recording(wait=True)
-        self.logger.info("[%s] Completed %s", worker_label, episode_label)
+        self.logger.info("[%s] Completed episode %s", worker_label, episode_label)
 
     def _resolve_builder_dir(self) -> Path:
         """Find the dataset version directory that contains dataset_info.json."""
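
Note (not part of the patch): the core idea above is that the RichHandler used for logging and the Progress bar used for monitoring share one Console, so log records are rendered above the live bar instead of tearing it. Below is a minimal, self-contained sketch of that pattern under that assumption; the logger name, task label, and the sleep-based work loop are illustrative, only the public Rich APIs (Console, RichHandler, Progress) are taken from the diff.

import logging
import time

from rich.console import Console
from rich.logging import RichHandler
from rich.progress import Progress

# One Console instance shared by the log handler and the progress bar.
console = Console(stderr=True)
logging.basicConfig(
    level=logging.INFO,
    format="%(message)s",
    handlers=[RichHandler(console=console, show_path=False)],
)
log = logging.getLogger("importer-demo")

with Progress(console=console, transient=True) as progress:
    task = progress.add_task("importing", total=5)
    for i in range(5):
        time.sleep(0.1)  # stand-in for per-episode upload work
        if i == 3:
            # Routed through the same console, so the live bar is not corrupted.
            log.warning("episode %s needed a retry", i)
        progress.advance(task)
log.info("done")

Running the sketch shows warnings printed cleanly while the transient bar keeps updating, which is the behavior the patch aims for by replacing the per-item info logs with a single shared console.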