diff --git a/examples/2_collect_teleop_data_with_neuracore.py b/examples/2_collect_teleop_data_with_neuracore.py index c6ddcc1..460de21 100644 --- a/examples/2_collect_teleop_data_with_neuracore.py +++ b/examples/2_collect_teleop_data_with_neuracore.py @@ -35,7 +35,9 @@ DAMPING_COST, FRAME_TASK_GAIN, GRIPPER_FRAME_NAME, + GRIPPER_LOGGING_NAME, IK_SOLVER_RATE, + JOINT_NAMES, JOINT_STATE_STREAMING_RATE, LM_DAMPING, NEUTRAL_JOINT_ANGLES, @@ -104,18 +106,25 @@ def neuracore_logging_worker(queue: Queue, worker_id: int) -> None: if function_name == "log_joint_positions": data_value = np.radians(data_value) data_dict = { - f"joint{i+1}": angle for i, angle in enumerate(data_value) + joint_name: angle + for joint_name, angle in zip(JOINT_NAMES, data_value) } nc.log_joint_positions(data_dict, timestamp=timestamp) elif function_name == "log_joint_target_positions": data_value = np.radians(data_value) data_dict = { - f"joint{i+1}": angle for i, angle in enumerate(data_value) + joint_name: angle + for joint_name, angle in zip(JOINT_NAMES, data_value) } nc.log_joint_target_positions(data_dict, timestamp=timestamp) elif function_name == "log_parallel_gripper_open_amounts": - data_dict = {"gripper": data_value} + data_dict = {GRIPPER_LOGGING_NAME: data_value} nc.log_parallel_gripper_open_amounts(data_dict, timestamp=timestamp) + elif function_name == "log_parallel_gripper_target_open_amounts": + data_dict = {GRIPPER_LOGGING_NAME: data_value} + nc.log_parallel_gripper_target_open_amounts( + data_dict, timestamp=timestamp + ) elif function_name == "log_rgb": camera_name = "rgb" image_array = data_value @@ -327,7 +336,11 @@ def on_button_rj_pressed() -> None: # Initialize Meta Quest reader print("\nšŸŽ® Initializing Meta Quest reader...") - quest_reader = MetaQuestReader(ip_address=args.ip_address, port=5555, run=True) + quest_reader = MetaQuestReader( + ip_address=args.ip_address, + port=5555, + run=True, + ) # Register button callbacks (after state and robot_controller are initialized) quest_reader.on("button_a_pressed", on_button_a_pressed) diff --git a/examples/3_replay_neuracore_episodes.py b/examples/3_replay_neuracore_episodes.py index 16e6950..86a6b5c 100644 --- a/examples/3_replay_neuracore_episodes.py +++ b/examples/3_replay_neuracore_episodes.py @@ -28,7 +28,7 @@ def main() -> None: """Main function for replaying a Neuracore dataset on the Piper robot.""" parser = argparse.ArgumentParser() parser.add_argument("--dataset-name", type=str, required=True) - parser.add_argument("--frequency", type=int, required=False, default=100) + parser.add_argument("--frequency", type=int, required=True) parser.add_argument("--episode-index", type=int, required=False, default=0) args = parser.parse_args() @@ -57,8 +57,10 @@ def main() -> None: print("\nšŸ” Building robot data spec for synchronization...") data_types_to_synchronize = [ DataType.JOINT_POSITIONS, + DataType.JOINT_TARGET_POSITIONS, DataType.RGB_IMAGES, DataType.PARALLEL_GRIPPER_OPEN_AMOUNTS, + DataType.PARALLEL_GRIPPER_TARGET_OPEN_AMOUNTS, ] robot_data_spec: RobotDataSpec = {} robot_ids_dataset = dataset.robot_ids @@ -123,8 +125,8 @@ def main() -> None: # Extract joint positions joint_positions_dict = {} - if DataType.JOINT_POSITIONS in step.data: - joint_data = step.data[DataType.JOINT_POSITIONS] + if DataType.JOINT_TARGET_POSITIONS in step.data: + joint_data = step.data[DataType.JOINT_TARGET_POSITIONS] for joint_name in JOINT_NAMES: if joint_name in joint_data: joint_positions_dict[joint_name] = joint_data[ @@ -134,8 +136,10 @@ def main() -> None: # Extract gripper gripper_value = 0.0 - if DataType.PARALLEL_GRIPPER_OPEN_AMOUNTS in step.data: - gripper_data = step.data[DataType.PARALLEL_GRIPPER_OPEN_AMOUNTS] + if DataType.PARALLEL_GRIPPER_TARGET_OPEN_AMOUNTS in step.data: + gripper_data = step.data[ + DataType.PARALLEL_GRIPPER_TARGET_OPEN_AMOUNTS + ] if GRIPPER_LOGGING_NAME in gripper_data: gripper_value = gripper_data[GRIPPER_LOGGING_NAME].open_amount parallel_gripper_open_amounts.append(gripper_value) diff --git a/examples/4_rollout_neuracore_policy.py b/examples/4_rollout_neuracore_policy.py index 51c7a5f..17dd759 100644 --- a/examples/4_rollout_neuracore_policy.py +++ b/examples/4_rollout_neuracore_policy.py @@ -86,8 +86,8 @@ def convert_predictions_to_horizon_dict(predictions: dict) -> dict[str, list[flo horizon[joint_name] = values # Extract gripper open amounts - if DataType.PARALLEL_GRIPPER_OPEN_AMOUNTS in predictions: - gripper_data = predictions[DataType.PARALLEL_GRIPPER_OPEN_AMOUNTS] + if DataType.PARALLEL_GRIPPER_TARGET_OPEN_AMOUNTS in predictions: + gripper_data = predictions[DataType.PARALLEL_GRIPPER_TARGET_OPEN_AMOUNTS] if GRIPPER_LOGGING_NAME in gripper_data: batched = gripper_data[GRIPPER_LOGGING_NAME] if isinstance(batched, BatchedParallelGripperOpenAmountData): @@ -155,8 +155,8 @@ def run_policy( print("āš ļø No current joint angles available") return False - # Get target gripper open value because this is how the policy was trained - gripper_open_value = data_manager.get_target_gripper_open_value() + # Get current gripper open value + gripper_open_value = data_manager.get_current_gripper_open_value() if gripper_open_value is None: print("āš ļø No gripper open value available") return False @@ -170,7 +170,7 @@ def run_policy( # Prepare data for NeuraCore logging joint_angles_rad = np.radians(current_joint_angles) joint_positions_dict = { - JOINT_NAMES[i]: angle for i, angle in enumerate(joint_angles_rad) + joint_name: angle for joint_name, angle in zip(JOINT_NAMES, joint_angles_rad) } # Log joint positions parallel gripper open amounts and RGB image to NeuraCore @@ -455,10 +455,12 @@ def policy_execution_thread( # Send current gripper open value to robot (if available) if GRIPPER_LOGGING_NAME in locked_horizon: - current_gripper_open_value = locked_horizon[GRIPPER_LOGGING_NAME][ - execution_index - ] - robot_controller.set_gripper_open_value(current_gripper_open_value) + current_gripper_target_open_value = locked_horizon[ + GRIPPER_LOGGING_NAME + ][execution_index] + robot_controller.set_gripper_open_value( + current_gripper_target_open_value + ) # Update execution index policy_state.increment_execution_action_index() @@ -682,7 +684,7 @@ def update_visualization( } model_output_order = { DataType.JOINT_TARGET_POSITIONS: JOINT_NAMES, - DataType.PARALLEL_GRIPPER_OPEN_AMOUNTS: [GRIPPER_LOGGING_NAME], + DataType.PARALLEL_GRIPPER_TARGET_OPEN_AMOUNTS: [GRIPPER_LOGGING_NAME], } print("\nšŸ“‹ Model input order:") @@ -721,8 +723,6 @@ def update_visualization( CONTROLLER_BETA, CONTROLLER_D_CUTOFF, ) - # Setting the target gripper so policy doesn't crash first time it runs - data_manager.set_target_gripper_open_value(1.0) # Initialize robot controller print("\nšŸ¤– Initializing Piper robot controller...") diff --git a/examples/5_rollout_neuracore_policy_minimal.py b/examples/5_rollout_neuracore_policy_minimal.py index b5498c7..b9b321b 100644 --- a/examples/5_rollout_neuracore_policy_minimal.py +++ b/examples/5_rollout_neuracore_policy_minimal.py @@ -60,8 +60,8 @@ def convert_predictions_to_horizon_dict(predictions: dict) -> dict[str, list[flo horizon[joint_name] = values # Extract gripper open amounts - if DataType.PARALLEL_GRIPPER_OPEN_AMOUNTS in predictions: - gripper_data = predictions[DataType.PARALLEL_GRIPPER_OPEN_AMOUNTS] + if DataType.PARALLEL_GRIPPER_TARGET_OPEN_AMOUNTS in predictions: + gripper_data = predictions[DataType.PARALLEL_GRIPPER_TARGET_OPEN_AMOUNTS] if GRIPPER_LOGGING_NAME in gripper_data: batched = gripper_data[GRIPPER_LOGGING_NAME] if isinstance(batched, BatchedParallelGripperOpenAmountData): @@ -79,8 +79,8 @@ def log_current_state(data_manager: DataManager) -> None: print("āš ļø No joint angles available") return - # Get target gripper open value because this is how the policy was trained - gripper_open_value = data_manager.get_target_gripper_open_value() + # Get current gripper open value + gripper_open_value = data_manager.get_current_gripper_open_value() if gripper_open_value is None: print("āš ļø No gripper open value available") return @@ -94,7 +94,7 @@ def log_current_state(data_manager: DataManager) -> None: # Prepare data for NeuraCore logging joint_angles_rad = np.radians(current_joint_angles) joint_positions_dict = { - JOINT_NAMES[i]: angle for i, angle in enumerate(joint_angles_rad) + joint_name: angle for joint_name, angle in zip(JOINT_NAMES, joint_angles_rad) } # Log joint positions, parallel gripper open amounts, and RGB image to NeuraCore @@ -123,8 +123,6 @@ def run_policy( horizon_length = policy_state.get_prediction_horizon_length() print(f"āœ“ Got {horizon_length} actions in {elapsed:.3f}s") - # Set execution ratio and save prediction horizon - policy_state.set_execution_ratio(PREDICTION_HORIZON_EXECUTION_RATIO) policy_state.set_prediction_horizon(prediction_horizon) return True @@ -138,6 +136,7 @@ def execute_horizon( data_manager: DataManager, policy_state: PolicyState, robot_controller: PiperController, + frequency: int, ) -> None: """Execute prediction horizon.""" policy_state.start_policy_execution() @@ -145,7 +144,7 @@ def execute_horizon( locked_horizon = policy_state.get_locked_prediction_horizon() horizon_length = policy_state.get_locked_prediction_horizon_length() - dt = 1.0 / POLICY_EXECUTION_RATE + dt = 1.0 / frequency for i in range(horizon_length): start_time = time.time() @@ -162,8 +161,8 @@ def execute_horizon( # Send current gripper open value to robot (if available) if GRIPPER_LOGGING_NAME in locked_horizon: - current_gripper_open_value = locked_horizon[GRIPPER_LOGGING_NAME][i] - robot_controller.set_gripper_open_value(current_gripper_open_value) + current_gripper_target_open_value = locked_horizon[GRIPPER_LOGGING_NAME][i] + robot_controller.set_gripper_open_value(current_gripper_target_open_value) # Log current state for visualization log_current_state(data_manager) @@ -191,6 +190,18 @@ def execute_horizon( default=None, help="Path to local model file to load policy from. Mutually exclusive with --train-run-name.", ) + parser.add_argument( + "--frequency", + type=int, + default=POLICY_EXECUTION_RATE, + help="Frequency of policy execution", + ) + parser.add_argument( + "--execution-ratio", + type=float, + default=PREDICTION_HORIZON_EXECUTION_RATIO, + help="Execution ratio of the policy", + ) args = parser.parse_args() # Validate that exactly one of train-run-name or model-path is provided @@ -223,7 +234,7 @@ def execute_horizon( } model_output_order = { DataType.JOINT_TARGET_POSITIONS: JOINT_NAMES, - DataType.PARALLEL_GRIPPER_OPEN_AMOUNTS: [GRIPPER_LOGGING_NAME], + DataType.PARALLEL_GRIPPER_TARGET_OPEN_AMOUNTS: [GRIPPER_LOGGING_NAME], } print("\nšŸ“‹ Model input order:") @@ -252,8 +263,8 @@ def execute_horizon( # Initialize state data_manager = DataManager() - data_manager.set_target_gripper_open_value(1.0) policy_state = PolicyState() + policy_state.set_execution_ratio(args.execution_ratio) # Initialize robot controller print("\nšŸ¤– Initializing robot controller...") @@ -318,7 +329,9 @@ def execute_horizon( continue # Execute horizon - execute_horizon(data_manager, policy_state, robot_controller) + execute_horizon( + data_manager, policy_state, robot_controller, args.frequency + ) except KeyboardInterrupt: print("\n\nšŸ‘‹ Interrupt received - shutting down...") diff --git a/examples/6_visualize_policy_from_dataset.py b/examples/6_visualize_policy_from_dataset.py index 29a9048..5dde405 100644 --- a/examples/6_visualize_policy_from_dataset.py +++ b/examples/6_visualize_policy_from_dataset.py @@ -1,5 +1,9 @@ #!/usr/bin/env python3 -"""Simple policy visualization from dataset - single script, no classes.""" +"""Simple policy visualization from dataset. + +Loads a policy and a dataset, and randomly selects a state +from the dataset to run the policy with and visualize the results. +""" import argparse import random @@ -40,6 +44,9 @@ "--train-run-name", type=str, default=None, help="Training run name" ) parser.add_argument("--model-path", type=str, default=None, help="Model file path") +parser.add_argument( + "--frequency", type=int, default=100, help="Frequency of visualization" +) args = parser.parse_args() if (args.train_run_name is None) == (args.model_path is None): @@ -58,7 +65,7 @@ } model_output_order = { DataType.JOINT_TARGET_POSITIONS: JOINT_NAMES, - DataType.PARALLEL_GRIPPER_OPEN_AMOUNTS: [GRIPPER_LOGGING_NAME], + DataType.PARALLEL_GRIPPER_TARGET_OPEN_AMOUNTS: [GRIPPER_LOGGING_NAME], } if args.train_run_name: @@ -90,7 +97,7 @@ print("šŸ” Synchronizing dataset...") synced_dataset = dataset.synchronize( - frequency=100, + frequency=args.frequency, robot_data_spec=robot_data_spec, prefetch_videos=True, max_prefetch_workers=2, @@ -126,8 +133,8 @@ def convert_predictions_to_horizon( if isinstance(batched, BatchedJointData): values = batched.value[0, :, 0].cpu().numpy().tolist() horizon[joint_name] = values - if DataType.PARALLEL_GRIPPER_OPEN_AMOUNTS in predictions: - gripper_data = predictions[DataType.PARALLEL_GRIPPER_OPEN_AMOUNTS] + if DataType.PARALLEL_GRIPPER_TARGET_OPEN_AMOUNTS in predictions: + gripper_data = predictions[DataType.PARALLEL_GRIPPER_TARGET_OPEN_AMOUNTS] if GRIPPER_LOGGING_NAME in gripper_data: batched = gripper_data[GRIPPER_LOGGING_NAME] if isinstance(batched, BatchedParallelGripperOpenAmountData): @@ -221,7 +228,7 @@ def select_random_state() -> None: # Add frequency control frequency_handle = server.gui.add_number( "Visualization Frequency (Hz)", - initial_value=100.0, + initial_value=args.frequency, min=1.0, max=500.0, step=1.0, diff --git a/examples/common/configs.py b/examples/common/configs.py index 836d0dc..bb3883c 100644 --- a/examples/common/configs.py +++ b/examples/common/configs.py @@ -44,7 +44,7 @@ CAMERA_FRAME_STREAMING_RATE = 60.0 # Data collection rate for camera frame # # Initial neutral pose for robot (degrees) -NEUTRAL_JOINT_ANGLES = [-5.251, 21.356, -41.386, -4.323, 53.374, 0.0] +NEUTRAL_JOINT_ANGLES = [-1.003, 80.167, -51.064, -4.127, 16.548, 2.619] # Posture task cost vector (one weight per joint) POSTURE_COST_VECTOR = [0.0, 0.0, 0.0, 0.05, 0.0, 0.0] diff --git a/examples/common/data_manager.py b/examples/common/data_manager.py index 5bb9a6b..bd2ff5f 100644 --- a/examples/common/data_manager.py +++ b/examples/common/data_manager.py @@ -406,6 +406,12 @@ def set_current_gripper_open_value(self, value: float) -> None: """ with self._robot_state._lock: self._robot_state.current_gripper_open_value = value + if self._on_change_callback: + self._on_change_callback( + "log_parallel_gripper_open_amounts", + value, + time.time(), + ) def get_target_gripper_open_value(self) -> float | None: """Get target gripper open value (thread-safe). @@ -426,7 +432,7 @@ def set_target_gripper_open_value(self, value: float) -> None: self._robot_state.target_gripper_open_value = value if self._on_change_callback: self._on_change_callback( - "log_parallel_gripper_open_amounts", + "log_parallel_gripper_target_open_amounts", self._robot_state.target_gripper_open_value, time.time(), ) diff --git a/meta_quest_teleop b/meta_quest_teleop index 81b8dad..46f268c 160000 --- a/meta_quest_teleop +++ b/meta_quest_teleop @@ -1 +1 @@ -Subproject commit 81b8dad6dfb4f724077a68d1587a6d06a2f03a5e +Subproject commit 46f268ced706e0b56110fb6554ad3c4fec44a4be