Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 17 additions & 4 deletions examples/2_collect_teleop_data_with_neuracore.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,9 @@
DAMPING_COST,
FRAME_TASK_GAIN,
GRIPPER_FRAME_NAME,
GRIPPER_LOGGING_NAME,
IK_SOLVER_RATE,
JOINT_NAMES,
JOINT_STATE_STREAMING_RATE,
LM_DAMPING,
NEUTRAL_JOINT_ANGLES,
Expand Down Expand Up @@ -104,18 +106,25 @@ def neuracore_logging_worker(queue: Queue, worker_id: int) -> None:
if function_name == "log_joint_positions":
data_value = np.radians(data_value)
data_dict = {
f"joint{i+1}": angle for i, angle in enumerate(data_value)
joint_name: angle
for joint_name, angle in zip(JOINT_NAMES, data_value)
}
nc.log_joint_positions(data_dict, timestamp=timestamp)
elif function_name == "log_joint_target_positions":
data_value = np.radians(data_value)
data_dict = {
f"joint{i+1}": angle for i, angle in enumerate(data_value)
joint_name: angle
for joint_name, angle in zip(JOINT_NAMES, data_value)
}
nc.log_joint_target_positions(data_dict, timestamp=timestamp)
elif function_name == "log_parallel_gripper_open_amounts":
data_dict = {"gripper": data_value}
data_dict = {GRIPPER_LOGGING_NAME: data_value}
nc.log_parallel_gripper_open_amounts(data_dict, timestamp=timestamp)
elif function_name == "log_parallel_gripper_target_open_amounts":
data_dict = {GRIPPER_LOGGING_NAME: data_value}
nc.log_parallel_gripper_target_open_amounts(
data_dict, timestamp=timestamp
)
elif function_name == "log_rgb":
camera_name = "rgb"
image_array = data_value
Expand Down Expand Up @@ -327,7 +336,11 @@ def on_button_rj_pressed() -> None:

# Initialize Meta Quest reader
print("\n🎮 Initializing Meta Quest reader...")
quest_reader = MetaQuestReader(ip_address=args.ip_address, port=5555, run=True)
quest_reader = MetaQuestReader(
ip_address=args.ip_address,
port=5555,
run=True,
)

# Register button callbacks (after state and robot_controller are initialized)
quest_reader.on("button_a_pressed", on_button_a_pressed)
Expand Down
14 changes: 9 additions & 5 deletions examples/3_replay_neuracore_episodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def main() -> None:
"""Main function for replaying a Neuracore dataset on the Piper robot."""
parser = argparse.ArgumentParser()
parser.add_argument("--dataset-name", type=str, required=True)
parser.add_argument("--frequency", type=int, required=False, default=100)
parser.add_argument("--frequency", type=int, required=True)
parser.add_argument("--episode-index", type=int, required=False, default=0)
args = parser.parse_args()

Expand Down Expand Up @@ -57,8 +57,10 @@ def main() -> None:
print("\n🔁 Building robot data spec for synchronization...")
data_types_to_synchronize = [
DataType.JOINT_POSITIONS,
DataType.JOINT_TARGET_POSITIONS,
DataType.RGB_IMAGES,
DataType.PARALLEL_GRIPPER_OPEN_AMOUNTS,
DataType.PARALLEL_GRIPPER_TARGET_OPEN_AMOUNTS,
]
robot_data_spec: RobotDataSpec = {}
robot_ids_dataset = dataset.robot_ids
Expand Down Expand Up @@ -123,8 +125,8 @@ def main() -> None:

# Extract joint positions
joint_positions_dict = {}
if DataType.JOINT_POSITIONS in step.data:
joint_data = step.data[DataType.JOINT_POSITIONS]
if DataType.JOINT_TARGET_POSITIONS in step.data:
joint_data = step.data[DataType.JOINT_TARGET_POSITIONS]
for joint_name in JOINT_NAMES:
if joint_name in joint_data:
joint_positions_dict[joint_name] = joint_data[
Expand All @@ -134,8 +136,10 @@ def main() -> None:

# Extract gripper
gripper_value = 0.0
if DataType.PARALLEL_GRIPPER_OPEN_AMOUNTS in step.data:
gripper_data = step.data[DataType.PARALLEL_GRIPPER_OPEN_AMOUNTS]
if DataType.PARALLEL_GRIPPER_TARGET_OPEN_AMOUNTS in step.data:
gripper_data = step.data[
DataType.PARALLEL_GRIPPER_TARGET_OPEN_AMOUNTS
]
if GRIPPER_LOGGING_NAME in gripper_data:
gripper_value = gripper_data[GRIPPER_LOGGING_NAME].open_amount
parallel_gripper_open_amounts.append(gripper_value)
Expand Down
24 changes: 12 additions & 12 deletions examples/4_rollout_neuracore_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,8 @@ def convert_predictions_to_horizon_dict(predictions: dict) -> dict[str, list[flo
horizon[joint_name] = values

# Extract gripper open amounts
if DataType.PARALLEL_GRIPPER_OPEN_AMOUNTS in predictions:
gripper_data = predictions[DataType.PARALLEL_GRIPPER_OPEN_AMOUNTS]
if DataType.PARALLEL_GRIPPER_TARGET_OPEN_AMOUNTS in predictions:
gripper_data = predictions[DataType.PARALLEL_GRIPPER_TARGET_OPEN_AMOUNTS]
if GRIPPER_LOGGING_NAME in gripper_data:
batched = gripper_data[GRIPPER_LOGGING_NAME]
if isinstance(batched, BatchedParallelGripperOpenAmountData):
Expand Down Expand Up @@ -155,8 +155,8 @@ def run_policy(
print("⚠️ No current joint angles available")
return False

# Get target gripper open value because this is how the policy was trained
gripper_open_value = data_manager.get_target_gripper_open_value()
# Get current gripper open value
gripper_open_value = data_manager.get_current_gripper_open_value()
if gripper_open_value is None:
print("⚠️ No gripper open value available")
return False
Expand All @@ -170,7 +170,7 @@ def run_policy(
# Prepare data for NeuraCore logging
joint_angles_rad = np.radians(current_joint_angles)
joint_positions_dict = {
JOINT_NAMES[i]: angle for i, angle in enumerate(joint_angles_rad)
joint_name: angle for joint_name, angle in zip(JOINT_NAMES, joint_angles_rad)
}

# Log joint positions parallel gripper open amounts and RGB image to NeuraCore
Expand Down Expand Up @@ -455,10 +455,12 @@ def policy_execution_thread(

# Send current gripper open value to robot (if available)
if GRIPPER_LOGGING_NAME in locked_horizon:
current_gripper_open_value = locked_horizon[GRIPPER_LOGGING_NAME][
execution_index
]
robot_controller.set_gripper_open_value(current_gripper_open_value)
current_gripper_target_open_value = locked_horizon[
GRIPPER_LOGGING_NAME
][execution_index]
robot_controller.set_gripper_open_value(
current_gripper_target_open_value
)

# Update execution index
policy_state.increment_execution_action_index()
Expand Down Expand Up @@ -682,7 +684,7 @@ def update_visualization(
}
model_output_order = {
DataType.JOINT_TARGET_POSITIONS: JOINT_NAMES,
DataType.PARALLEL_GRIPPER_OPEN_AMOUNTS: [GRIPPER_LOGGING_NAME],
DataType.PARALLEL_GRIPPER_TARGET_OPEN_AMOUNTS: [GRIPPER_LOGGING_NAME],
}

print("\n📋 Model input order:")
Expand Down Expand Up @@ -721,8 +723,6 @@ def update_visualization(
CONTROLLER_BETA,
CONTROLLER_D_CUTOFF,
)
# Setting the target gripper so policy doesn't crash first time it runs
data_manager.set_target_gripper_open_value(1.0)

# Initialize robot controller
print("\n🤖 Initializing Piper robot controller...")
Expand Down
39 changes: 26 additions & 13 deletions examples/5_rollout_neuracore_policy_minimal.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,8 @@ def convert_predictions_to_horizon_dict(predictions: dict) -> dict[str, list[flo
horizon[joint_name] = values

# Extract gripper open amounts
if DataType.PARALLEL_GRIPPER_OPEN_AMOUNTS in predictions:
gripper_data = predictions[DataType.PARALLEL_GRIPPER_OPEN_AMOUNTS]
if DataType.PARALLEL_GRIPPER_TARGET_OPEN_AMOUNTS in predictions:
gripper_data = predictions[DataType.PARALLEL_GRIPPER_TARGET_OPEN_AMOUNTS]
if GRIPPER_LOGGING_NAME in gripper_data:
batched = gripper_data[GRIPPER_LOGGING_NAME]
if isinstance(batched, BatchedParallelGripperOpenAmountData):
Expand All @@ -79,8 +79,8 @@ def log_current_state(data_manager: DataManager) -> None:
print("⚠️ No joint angles available")
return

# Get target gripper open value because this is how the policy was trained
gripper_open_value = data_manager.get_target_gripper_open_value()
# Get current gripper open value
gripper_open_value = data_manager.get_current_gripper_open_value()
if gripper_open_value is None:
print("⚠️ No gripper open value available")
return
Expand All @@ -94,7 +94,7 @@ def log_current_state(data_manager: DataManager) -> None:
# Prepare data for NeuraCore logging
joint_angles_rad = np.radians(current_joint_angles)
joint_positions_dict = {
JOINT_NAMES[i]: angle for i, angle in enumerate(joint_angles_rad)
joint_name: angle for joint_name, angle in zip(JOINT_NAMES, joint_angles_rad)
}

# Log joint positions, parallel gripper open amounts, and RGB image to NeuraCore
Expand Down Expand Up @@ -123,8 +123,6 @@ def run_policy(
horizon_length = policy_state.get_prediction_horizon_length()
print(f"✓ Got {horizon_length} actions in {elapsed:.3f}s")

# Set execution ratio and save prediction horizon
policy_state.set_execution_ratio(PREDICTION_HORIZON_EXECUTION_RATIO)
policy_state.set_prediction_horizon(prediction_horizon)
return True

Expand All @@ -138,14 +136,15 @@ def execute_horizon(
data_manager: DataManager,
policy_state: PolicyState,
robot_controller: PiperController,
frequency: int,
) -> None:
"""Execute prediction horizon."""
policy_state.start_policy_execution()
data_manager.set_robot_activity_state(RobotActivityState.POLICY_CONTROLLED)

locked_horizon = policy_state.get_locked_prediction_horizon()
horizon_length = policy_state.get_locked_prediction_horizon_length()
dt = 1.0 / POLICY_EXECUTION_RATE
dt = 1.0 / frequency

for i in range(horizon_length):
start_time = time.time()
Expand All @@ -162,8 +161,8 @@ def execute_horizon(

# Send current gripper open value to robot (if available)
if GRIPPER_LOGGING_NAME in locked_horizon:
current_gripper_open_value = locked_horizon[GRIPPER_LOGGING_NAME][i]
robot_controller.set_gripper_open_value(current_gripper_open_value)
current_gripper_target_open_value = locked_horizon[GRIPPER_LOGGING_NAME][i]
robot_controller.set_gripper_open_value(current_gripper_target_open_value)

# Log current state for visualization
log_current_state(data_manager)
Expand Down Expand Up @@ -191,6 +190,18 @@ def execute_horizon(
default=None,
help="Path to local model file to load policy from. Mutually exclusive with --train-run-name.",
)
parser.add_argument(
"--frequency",
type=int,
default=POLICY_EXECUTION_RATE,
help="Frequency of policy execution",
)
parser.add_argument(
"--execution-ratio",
type=float,
default=PREDICTION_HORIZON_EXECUTION_RATIO,
help="Execution ratio of the policy",
)
args = parser.parse_args()

# Validate that exactly one of train-run-name or model-path is provided
Expand Down Expand Up @@ -223,7 +234,7 @@ def execute_horizon(
}
model_output_order = {
DataType.JOINT_TARGET_POSITIONS: JOINT_NAMES,
DataType.PARALLEL_GRIPPER_OPEN_AMOUNTS: [GRIPPER_LOGGING_NAME],
DataType.PARALLEL_GRIPPER_TARGET_OPEN_AMOUNTS: [GRIPPER_LOGGING_NAME],
}

print("\n📋 Model input order:")
Expand Down Expand Up @@ -252,8 +263,8 @@ def execute_horizon(

# Initialize state
data_manager = DataManager()
data_manager.set_target_gripper_open_value(1.0)
policy_state = PolicyState()
policy_state.set_execution_ratio(args.execution_ratio)

# Initialize robot controller
print("\n🤖 Initializing robot controller...")
Expand Down Expand Up @@ -318,7 +329,9 @@ def execute_horizon(
continue

# Execute horizon
execute_horizon(data_manager, policy_state, robot_controller)
execute_horizon(
data_manager, policy_state, robot_controller, args.frequency
)

except KeyboardInterrupt:
print("\n\n👋 Interrupt received - shutting down...")
Expand Down
19 changes: 13 additions & 6 deletions examples/6_visualize_policy_from_dataset.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
#!/usr/bin/env python3
"""Simple policy visualization from dataset - single script, no classes."""
"""Simple policy visualization from dataset.

Loads a policy and a dataset, and randomly selects a state
from the dataset to run the policy with and visualize the results.
"""

import argparse
import random
Expand Down Expand Up @@ -40,6 +44,9 @@
"--train-run-name", type=str, default=None, help="Training run name"
)
parser.add_argument("--model-path", type=str, default=None, help="Model file path")
parser.add_argument(
"--frequency", type=int, default=100, help="Frequency of visualization"
)
args = parser.parse_args()

if (args.train_run_name is None) == (args.model_path is None):
Expand All @@ -58,7 +65,7 @@
}
model_output_order = {
DataType.JOINT_TARGET_POSITIONS: JOINT_NAMES,
DataType.PARALLEL_GRIPPER_OPEN_AMOUNTS: [GRIPPER_LOGGING_NAME],
DataType.PARALLEL_GRIPPER_TARGET_OPEN_AMOUNTS: [GRIPPER_LOGGING_NAME],
}

if args.train_run_name:
Expand Down Expand Up @@ -90,7 +97,7 @@

print("🔁 Synchronizing dataset...")
synced_dataset = dataset.synchronize(
frequency=100,
frequency=args.frequency,
robot_data_spec=robot_data_spec,
prefetch_videos=True,
max_prefetch_workers=2,
Expand Down Expand Up @@ -126,8 +133,8 @@ def convert_predictions_to_horizon(
if isinstance(batched, BatchedJointData):
values = batched.value[0, :, 0].cpu().numpy().tolist()
horizon[joint_name] = values
if DataType.PARALLEL_GRIPPER_OPEN_AMOUNTS in predictions:
gripper_data = predictions[DataType.PARALLEL_GRIPPER_OPEN_AMOUNTS]
if DataType.PARALLEL_GRIPPER_TARGET_OPEN_AMOUNTS in predictions:
gripper_data = predictions[DataType.PARALLEL_GRIPPER_TARGET_OPEN_AMOUNTS]
if GRIPPER_LOGGING_NAME in gripper_data:
batched = gripper_data[GRIPPER_LOGGING_NAME]
if isinstance(batched, BatchedParallelGripperOpenAmountData):
Expand Down Expand Up @@ -221,7 +228,7 @@ def select_random_state() -> None:
# Add frequency control
frequency_handle = server.gui.add_number(
"Visualization Frequency (Hz)",
initial_value=100.0,
initial_value=args.frequency,
min=1.0,
max=500.0,
step=1.0,
Expand Down
2 changes: 1 addition & 1 deletion examples/common/configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
CAMERA_FRAME_STREAMING_RATE = 60.0 # Data collection rate for camera frame

# # Initial neutral pose for robot (degrees)
NEUTRAL_JOINT_ANGLES = [-5.251, 21.356, -41.386, -4.323, 53.374, 0.0]
NEUTRAL_JOINT_ANGLES = [-1.003, 80.167, -51.064, -4.127, 16.548, 2.619]

# Posture task cost vector (one weight per joint)
POSTURE_COST_VECTOR = [0.0, 0.0, 0.0, 0.05, 0.0, 0.0]
Expand Down
8 changes: 7 additions & 1 deletion examples/common/data_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,6 +406,12 @@ def set_current_gripper_open_value(self, value: float) -> None:
"""
with self._robot_state._lock:
self._robot_state.current_gripper_open_value = value
if self._on_change_callback:
self._on_change_callback(
"log_parallel_gripper_open_amounts",
value,
time.time(),
)

def get_target_gripper_open_value(self) -> float | None:
"""Get target gripper open value (thread-safe).
Expand All @@ -426,7 +432,7 @@ def set_target_gripper_open_value(self, value: float) -> None:
self._robot_state.target_gripper_open_value = value
if self._on_change_callback:
self._on_change_callback(
"log_parallel_gripper_open_amounts",
"log_parallel_gripper_target_open_amounts",
self._robot_state.target_gripper_open_value,
time.time(),
)
Expand Down
2 changes: 1 addition & 1 deletion meta_quest_teleop