diff --git a/examples/4_rollout_neuracore_policy.py b/examples/4_rollout_neuracore_policy.py
index a73629a..0d0a36e 100644
--- a/examples/4_rollout_neuracore_policy.py
+++ b/examples/4_rollout_neuracore_policy.py
@@ -145,41 +145,59 @@ def run_policy(
     policy: nc.policy,
     policy_state: PolicyState,
     visualizer: RobotVisualizer,
+    model_input_order: dict[DataType, list[str]],
 ) -> bool:
     """Handle Run Policy button press to capture state and get policy prediction."""
     print("Running policy...")

-    # Get current joint positions
-    current_joint_angles = data_manager.get_current_joint_angles()
-    if current_joint_angles is None:
-        print("⚠️ No current joint angles available")
-        return False
+    # Get available data from data_manager (only log what the model expects)
+    current_joint_angles = None
+    gripper_open_value = None
+    rgb_image = None
+
+    # Only log data types that are in model_input_order
+    if DataType.JOINT_POSITIONS in model_input_order:
+        current_joint_angles = data_manager.get_current_joint_angles()
+        if current_joint_angles is not None:
+            joint_angles_rad = np.radians(current_joint_angles)
+            joint_positions_dict = {
+                joint_name: angle
+                for joint_name, angle in zip(JOINT_NAMES, joint_angles_rad)
+            }
+            nc.log_joint_positions(joint_positions_dict)
+            print(" ✓ Logged joint positions")
+        else:
+            print(" ⚠️ No current joint angles available")

-    # Get current gripper open value
-    gripper_open_value = data_manager.get_current_gripper_open_value()
-    if gripper_open_value is None:
-        print("⚠️ No gripper open value available")
-        return False
+    if DataType.PARALLEL_GRIPPER_OPEN_AMOUNTS in model_input_order:
+        gripper_open_value = data_manager.get_current_gripper_open_value()
+        if gripper_open_value is not None:
+            nc.log_parallel_gripper_open_amount(
+                GRIPPER_LOGGING_NAME, gripper_open_value
+            )
+            print(" ✓ Logged gripper open amount")
+        else:
+            print(" ⚠️ No gripper open value available")

-    # Get current RGB image
-    rgb_image = data_manager.get_rgb_image()
-    if rgb_image is None:
-        print("⚠️ No RGB image available")
-        return False
+    if DataType.RGB_IMAGES in model_input_order:
+        rgb_image = data_manager.get_rgb_image()
+        if rgb_image is not None:
+            nc.log_rgb(CAMERA_LOGGING_NAME, rgb_image)
+            print(" ✓ Logged RGB image")
+        else:
+            print(" ⚠️ No RGB image available")

-    # Prepare data for NeuraCore logging
-    joint_angles_rad = np.radians(current_joint_angles)
-    joint_positions_dict = {
-        joint_name: angle for joint_name, angle in zip(JOINT_NAMES, joint_angles_rad)
-    }
+    # Check if we have at least some data to run the policy
+    if (
+        current_joint_angles is None
+        and gripper_open_value is None
+        and rgb_image is None
+    ):
+        print("✗ No data available to run policy")
+        return False

-    # Log joint positions parallel gripper open amounts and RGB image to NeuraCore
+    # Get policy prediction
     try:
-        nc.log_joint_positions(joint_positions_dict)
-        nc.log_parallel_gripper_open_amount(GRIPPER_LOGGING_NAME, gripper_open_value)
-        nc.log_rgb(CAMERA_LOGGING_NAME, rgb_image)
-
-        # Get policy prediction
         start_time = time.time()
        predictions = policy.predict(timeout=5)
         prediction_horizon = convert_predictions_to_horizon_dict(predictions)
@@ -192,9 +210,11 @@ def run_policy(
         prediction_ratio = visualizer.get_prediction_ratio()
         policy_state.set_execution_ratio(prediction_ratio)

-        # Set policy inputs
-        policy_state.set_policy_rgb_image_input(rgb_image)
-        policy_state.set_policy_state_input(current_joint_angles)
+        # Set policy inputs (only if available)
+        if rgb_image is not None:
+            policy_state.set_policy_rgb_image_input(rgb_image)
+        if current_joint_angles is not None:
+            policy_state.set_policy_state_input(current_joint_angles)

         # Store prediction horizon actions in policy state
         policy_state.set_prediction_horizon(prediction_horizon)
@@ -290,10 +310,11 @@ def run_and_start_policy_execution(
     policy: nc.policy,
     policy_state: PolicyState,
     visualizer: RobotVisualizer,
+    model_input_order: dict[DataType, list[str]],
 ) -> None:
     """Handle Run and Execute Policy button press to capture state, get policy prediction, and immediately execute it."""
     print("Run and Execute Policy for one prediction horizon")
-    run_policy(data_manager, policy, policy_state, visualizer)
+    run_policy(data_manager, policy, policy_state, visualizer, model_input_order)

     start_policy_execution(data_manager, policy_state)

@@ -318,6 +339,7 @@ def play_policy(
     policy: nc.policy,
     policy_state: PolicyState,
     visualizer: RobotVisualizer,
+    model_input_order: dict[DataType, list[str]],
 ) -> None:
     """Handle Play Policy button press to start/stop continuous policy execution."""
     if not policy_state.get_continuous_play_active():
@@ -325,7 +347,9 @@
         print("▶️ Play Policy button pressed - Starting continuous policy execution...")

         # Run policy to get prediction horizon
-        success = run_policy(data_manager, policy, policy_state, visualizer)
+        success = run_policy(
+            data_manager, policy, policy_state, visualizer, model_input_order
+        )
         if not success:
             print("⚠️ Failed to run policy")
             end_policy_play(
@@ -368,6 +392,7 @@ def policy_execution_thread(
     policy_state: PolicyState,
     robot_controller: PiperController,
     visualizer: RobotVisualizer,
+    model_input_order: dict[DataType, list[str]],
 ) -> None:
     """Policy execution thread."""
     dt_execution = 1.0 / POLICY_EXECUTION_RATE
@@ -476,7 +501,13 @@
                 # End policy execution to clear input lock
                 policy_state.end_policy_execution()
                 # Run policy to get prediction horizon
-                success = run_policy(data_manager, policy, policy_state, visualizer)
+                success = run_policy(
+                    data_manager,
+                    policy,
+                    policy_state,
+                    visualizer,
+                    model_input_order,
+                )
                 if not success:
                     print("⚠️ Failed to run policy")
                     end_policy_play(
@@ -813,18 +844,25 @@ def update_visualization(
     )
     visualizer.set_go_home_callback(lambda: home_robot(data_manager, robot_controller))
     visualizer.set_run_policy_callback(
-        lambda: (run_policy(data_manager, policy, policy_state, visualizer), None)[1]
+        lambda: (
+            run_policy(
+                data_manager, policy, policy_state, visualizer, model_input_order
+            ),
+            None,
+        )[1]
     )
     visualizer.set_start_policy_execution_callback(
         lambda: (start_policy_execution(data_manager, policy_state), None)[1]
     )
     visualizer.set_run_and_start_policy_execution_callback(
         lambda: run_and_start_policy_execution(
-            data_manager, policy, policy_state, visualizer
+            data_manager, policy, policy_state, visualizer, model_input_order
         )
     )
     visualizer.set_play_policy_callback(
-        lambda: play_policy(data_manager, policy, policy_state, visualizer)
+        lambda: play_policy(
+            data_manager, policy, policy_state, visualizer, model_input_order
+        )
     )
     # Set up execution mode dropdown callback to sync with PolicyState
     visualizer.set_execution_mode_callback(
@@ -846,7 +884,14 @@
     print("\n🤖 Starting policy execution thread...")
     policy_execution_thread_obj = threading.Thread(
         target=policy_execution_thread,
-        args=(policy, data_manager, policy_state, robot_controller, visualizer),
+        args=(
+            policy,
+            data_manager,
+            policy_state,
+            robot_controller,
+            visualizer,
+            model_input_order,
+        ),
         daemon=True,
     )
     policy_execution_thread_obj.start()
diff --git a/examples/6_visualize_policy_from_dataset.py b/examples/6_visualize_policy_from_dataset.py
index 5dde405..0b6a34e 100644
--- a/examples/6_visualize_policy_from_dataset.py
+++ b/examples/6_visualize_policy_from_dataset.py
@@ -91,8 +91,17 @@
 dataset = nc.get_dataset(args.dataset_name)
 print(f" ✓ Dataset loaded: {len(dataset)} episodes")

+# Get data types from model input and output
+required_data_types = set(model_input_order.keys()) | set(model_output_order.keys())
+
+# Filter data spec to only include required data types
 robot_data_spec: RobotDataSpec = {
-    robot_id: dataset.get_full_data_spec(robot_id) for robot_id in dataset.robot_ids
+    robot_id: {
+        data_type: names
+        for data_type, names in dataset.get_full_data_spec(robot_id).items()
+        if data_type in required_data_types
+    }
+    for robot_id in dataset.robot_ids
 }

 print("🔁 Synchronizing dataset...")
@@ -165,6 +174,8 @@ def select_random_state() -> None:
     for joint_name in JOINT_NAMES:
         if joint_name in joint_data:
             joint_positions_dict[joint_name] = joint_data[joint_name].value
+    # Log to NeuraCore for visualization
+    nc.log_joint_positions(joint_positions_dict)

     # Extract gripper
     gripper_value = 1.0
@@ -172,6 +183,8 @@ def select_random_state() -> None:
         gripper_data = step.data[DataType.PARALLEL_GRIPPER_OPEN_AMOUNTS]
         if GRIPPER_LOGGING_NAME in gripper_data:
             gripper_value = gripper_data[GRIPPER_LOGGING_NAME].open_amount
+    # Log to NeuraCore for visualization
+    nc.log_parallel_gripper_open_amount(GRIPPER_LOGGING_NAME, gripper_value)

     # Extract RGB image
     rgb_image = None
@@ -179,15 +192,12 @@ def select_random_state() -> None:
         rgb_data = step.data[DataType.RGB_IMAGES]
         if CAMERA_LOGGING_NAME in rgb_data:
             rgb_image = np.array(rgb_data[CAMERA_LOGGING_NAME].frame)
-
-    if rgb_image is None:
-        print("⚠️ No RGB image found")
-        return
-
-    # Log to NeuraCore
-    nc.log_joint_positions(joint_positions_dict)
-    nc.log_parallel_gripper_open_amount(GRIPPER_LOGGING_NAME, gripper_value)
-    nc.log_rgb(CAMERA_LOGGING_NAME, rgb_image)
+            # Save image to file for visualization
+            image_pil = Image.fromarray(rgb_image)
+            image_pil.save("current_image.png")
+            print("💾 Saved image to current_image.png")
+            # Log to NeuraCore for visualization
+            nc.log_rgb(CAMERA_LOGGING_NAME, rgb_image)

     # Get policy prediction
     print("🎯 Getting policy prediction...")
@@ -197,13 +207,9 @@ def select_random_state() -> None:
     playing = True
     print("FINISHED PREDICTION")

-    # Save image to file
-    image_pil = Image.fromarray(rgb_image)
-    image_pil.save("current_image.png")
-    print("💾 Saved image to current_image.png")
+    # Update robot to initial pose from first step in the horizon

-    # Update robot to initial pose
-    joint_positions = np.array([joint_positions_dict[jn] for jn in JOINT_NAMES])
+    joint_positions = np.array([current_horizon[jn][0] for jn in JOINT_NAMES])
     urdf_vis.update_cfg(joint_positions)

     print(
@@ -258,6 +264,14 @@ def select_random_state() -> None:
             )
             urdf_vis.update_cfg(joint_config)

+            # Log to NeuraCore for visualization
+            # NOTE: we log joint positions instead of joint target positions
+            # because the latter is not visualized by NeuraCore
+            joint_config_dict = {
+                jn: joint_config[i] for i, jn in enumerate(JOINT_NAMES)
+            }
+            nc.log_joint_positions(joint_config_dict)
+
             # Update gripper value
             gripper_value = current_horizon[GRIPPER_LOGGING_NAME][
                 current_action_idx
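Note for reviewers: the input-gating pattern that this change threads through `run_policy` can be summarized in isolation. The sketch below is illustrative, not code from the diff: `data_manager` stands for any object exposing the three getters used above, the `DataType` import path is an assumption (adjust to your neuracore version), and `JOINT_NAMES`, `GRIPPER_LOGGING_NAME`, and `CAMERA_LOGGING_NAME` are placeholder values standing in for the examples' constants. Only `nc.log_*` calls that already appear in the diff are used.

```python
# Minimal sketch of the gating added to run_policy(): log only the data
# types the model was trained on, and report whether anything was logged.
import numpy as np
import neuracore as nc
from neuracore import DataType  # assumed import path

JOINT_NAMES = ["joint1", "joint2"]  # placeholder; use the example's list
GRIPPER_LOGGING_NAME = "gripper"    # placeholder
CAMERA_LOGGING_NAME = "camera"      # placeholder


def log_expected_inputs(data_manager, model_input_order: dict) -> bool:
    """Log each DataType in model_input_order that data_manager can provide.

    Returns True if at least one input was logged, mirroring run_policy()'s
    "at least some data available" check before calling policy.predict().
    """
    logged_any = False

    if DataType.JOINT_POSITIONS in model_input_order:
        angles_deg = data_manager.get_current_joint_angles()
        if angles_deg is not None:
            # Joint positions are logged in radians, keyed by joint name.
            nc.log_joint_positions(dict(zip(JOINT_NAMES, np.radians(angles_deg))))
            logged_any = True

    if DataType.PARALLEL_GRIPPER_OPEN_AMOUNTS in model_input_order:
        open_amount = data_manager.get_current_gripper_open_value()
        if open_amount is not None:
            nc.log_parallel_gripper_open_amount(GRIPPER_LOGGING_NAME, open_amount)
            logged_any = True

    if DataType.RGB_IMAGES in model_input_order:
        rgb = data_manager.get_rgb_image()
        if rgb is not None:
            nc.log_rgb(CAMERA_LOGGING_NAME, rgb)
            logged_any = True

    return logged_any
```

The design choice this captures: a missing sensor stream is now a per-input warning rather than an early return, and only "nothing logged at all" aborts the prediction, which is what the `logged_any` return value mirrors.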