117 changes: 81 additions & 36 deletions examples/4_rollout_neuracore_policy.py
@@ -145,41 +145,59 @@ def run_policy(
policy: nc.policy,
policy_state: PolicyState,
visualizer: RobotVisualizer,
model_input_order: dict[DataType, list[str]],
) -> bool:
"""Handle Run Policy button press to capture state and get policy prediction."""
print("Running policy...")

# Get current joint positions
current_joint_angles = data_manager.get_current_joint_angles()
if current_joint_angles is None:
print("⚠️ No current joint angles available")
return False
# Get available data from data_manager (only log what the model expects)
current_joint_angles = None
gripper_open_value = None
rgb_image = None

# Only log data types that are in model_input_order
if DataType.JOINT_POSITIONS in model_input_order:
current_joint_angles = data_manager.get_current_joint_angles()
if current_joint_angles is not None:
joint_angles_rad = np.radians(current_joint_angles)
joint_positions_dict = {
joint_name: angle
for joint_name, angle in zip(JOINT_NAMES, joint_angles_rad)
}
nc.log_joint_positions(joint_positions_dict)
print(" ✓ Logged joint positions")
else:
print(" ⚠️ No current joint angles available")

# Get current gripper open value
gripper_open_value = data_manager.get_current_gripper_open_value()
if gripper_open_value is None:
print("⚠️ No gripper open value available")
return False
if DataType.PARALLEL_GRIPPER_OPEN_AMOUNTS in model_input_order:
gripper_open_value = data_manager.get_current_gripper_open_value()
if gripper_open_value is not None:
nc.log_parallel_gripper_open_amount(
GRIPPER_LOGGING_NAME, gripper_open_value
)
print(" ✓ Logged gripper open amount")
else:
print(" ⚠️ No gripper open value available")

# Get current RGB image
rgb_image = data_manager.get_rgb_image()
if rgb_image is None:
print("⚠️ No RGB image available")
return False
if DataType.RGB_IMAGES in model_input_order:
rgb_image = data_manager.get_rgb_image()
if rgb_image is not None:
nc.log_rgb(CAMERA_LOGGING_NAME, rgb_image)
print(" ✓ Logged RGB image")
else:
print(" ⚠️ No RGB image available")

# Prepare data for NeuraCore logging
joint_angles_rad = np.radians(current_joint_angles)
joint_positions_dict = {
joint_name: angle for joint_name, angle in zip(JOINT_NAMES, joint_angles_rad)
}
# Check if we have at least some data to run the policy
if (
current_joint_angles is None
and gripper_open_value is None
and rgb_image is None
):
print("✗ No data available to run policy")
return False

# Log joint positions, parallel gripper open amounts, and RGB image to NeuraCore
# Get policy prediction
try:
nc.log_joint_positions(joint_positions_dict)
nc.log_parallel_gripper_open_amount(GRIPPER_LOGGING_NAME, gripper_open_value)
nc.log_rgb(CAMERA_LOGGING_NAME, rgb_image)

# Get policy prediction
start_time = time.time()
predictions = policy.predict(timeout=5)
prediction_horizon = convert_predictions_to_horizon_dict(predictions)
@@ -192,9 +192,11 @@ def run_policy(
prediction_ratio = visualizer.get_prediction_ratio()
policy_state.set_execution_ratio(prediction_ratio)

# Set policy inputs
policy_state.set_policy_rgb_image_input(rgb_image)
policy_state.set_policy_state_input(current_joint_angles)
# Set policy inputs (only if available)
if rgb_image is not None:
policy_state.set_policy_rgb_image_input(rgb_image)
if current_joint_angles is not None:
policy_state.set_policy_state_input(current_joint_angles)

# Store prediction horizon actions in policy state
policy_state.set_prediction_horizon(prediction_horizon)
@@ -290,10 +310,11 @@ def run_and_start_policy_execution(
policy: nc.policy,
policy_state: PolicyState,
visualizer: RobotVisualizer,
model_input_order: dict[DataType, list[str]],
) -> None:
"""Handle Run and Execute Policy button press to capture state, get policy prediction, and immediately execute it."""
print("Run and Execute Policy for one prediction horizon")
run_policy(data_manager, policy, policy_state, visualizer)
run_policy(data_manager, policy, policy_state, visualizer, model_input_order)
start_policy_execution(data_manager, policy_state)


@@ -318,14 +339,17 @@ def play_policy(
policy: nc.policy,
policy_state: PolicyState,
visualizer: RobotVisualizer,
model_input_order: dict[DataType, list[str]],
) -> None:
"""Handle Play Policy button press to start/stop continuous policy execution."""
if not policy_state.get_continuous_play_active():
# Start continuous play
print("▶️ Play Policy button pressed - Starting continuous policy execution...")

# Run policy to get prediction horizon
success = run_policy(data_manager, policy, policy_state, visualizer)
success = run_policy(
data_manager, policy, policy_state, visualizer, model_input_order
)
if not success:
print("⚠️ Failed to run policy")
end_policy_play(
@@ -368,6 +392,7 @@ def policy_execution_thread(
policy_state: PolicyState,
robot_controller: PiperController,
visualizer: RobotVisualizer,
model_input_order: dict[DataType, list[str]],
) -> None:
"""Policy execution thread."""
dt_execution = 1.0 / POLICY_EXECUTION_RATE
@@ -476,7 +501,13 @@ def policy_execution_thread(
# End policy execution to clear input lock
policy_state.end_policy_execution()
# Run policy to get prediction horizon
success = run_policy(data_manager, policy, policy_state, visualizer)
success = run_policy(
data_manager,
policy,
policy_state,
visualizer,
model_input_order,
)
if not success:
print("⚠️ Failed to run policy")
end_policy_play(
@@ -813,18 +844,25 @@ def update_visualization(
)
visualizer.set_go_home_callback(lambda: home_robot(data_manager, robot_controller))
visualizer.set_run_policy_callback(
lambda: (run_policy(data_manager, policy, policy_state, visualizer), None)[1]
lambda: (
run_policy(
data_manager, policy, policy_state, visualizer, model_input_order
),
None,
)[1]
)
visualizer.set_start_policy_execution_callback(
lambda: (start_policy_execution(data_manager, policy_state), None)[1]
)
visualizer.set_run_and_start_policy_execution_callback(
lambda: run_and_start_policy_execution(
data_manager, policy, policy_state, visualizer
data_manager, policy, policy_state, visualizer, model_input_order
)
)
visualizer.set_play_policy_callback(
lambda: play_policy(data_manager, policy, policy_state, visualizer)
lambda: play_policy(
data_manager, policy, policy_state, visualizer, model_input_order
)
)
# Set up execution mode dropdown callback to sync with PolicyState
visualizer.set_execution_mode_callback(
@@ -846,7 +884,14 @@ def update_visualization(
print("\n🤖 Starting policy execution thread...")
policy_execution_thread_obj = threading.Thread(
target=policy_execution_thread,
args=(policy, data_manager, policy_state, robot_controller, visualizer),
args=(
policy,
data_manager,
policy_state,
robot_controller,
visualizer,
model_input_order,
),
daemon=True,
)
policy_execution_thread_obj.start()
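The pattern added in run_policy repeats for each data type: check model_input_order, fetch from the data manager, log if present, and warn otherwise. Below is a minimal standalone sketch of that gating, assuming only the nc.log_* helpers and DataType members already used in the diff; log_available_inputs and fetchers are hypothetical names, not part of the example script.

from typing import Any, Callable, Optional

def log_available_inputs(
    model_input_order: dict[Any, list[str]],
    fetchers: dict[Any, tuple[Callable[[], Optional[Any]], Callable[[Any], None]]],
) -> bool:
    """Log every model-expected input that is currently available.

    Mirrors the "at least some data" check in run_policy: returns True
    if at least one input was fetched and logged.
    """
    logged_any = False
    for data_type, (fetch, log) in fetchers.items():
        if data_type not in model_input_order:
            continue  # the model does not expect this input; skip logging
        value = fetch()
        if value is None:
            print(f" ⚠️ No {data_type} available")
            continue
        log(value)
        logged_any = True
    return logged_any

run_policy could then build fetchers once, for example {DataType.RGB_IMAGES: (data_manager.get_rgb_image, lambda img: nc.log_rgb(CAMERA_LOGGING_NAME, img))}, and return False when log_available_inputs reports that nothing was logged.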
46 changes: 30 additions & 16 deletions examples/6_visualize_policy_from_dataset.py
@@ -91,8 +91,17 @@
dataset = nc.get_dataset(args.dataset_name)
print(f" ✓ Dataset loaded: {len(dataset)} episodes")

# Get data types from model input and output
required_data_types = set(model_input_order.keys()) | set(model_output_order.keys())

# Filter data spec to only include required data types
robot_data_spec: RobotDataSpec = {
robot_id: dataset.get_full_data_spec(robot_id) for robot_id in dataset.robot_ids
robot_id: {
data_type: names
for data_type, names in dataset.get_full_data_spec(robot_id).items()
if data_type in required_data_types
}
for robot_id in dataset.robot_ids
}

print("🔁 Synchronizing dataset...")
@@ -165,29 +174,30 @@ def select_random_state() -> None:
for joint_name in JOINT_NAMES:
if joint_name in joint_data:
joint_positions_dict[joint_name] = joint_data[joint_name].value
# Log to NeuraCore for visualization
nc.log_joint_positions(joint_positions_dict)

# Extract gripper
gripper_value = 1.0
if DataType.PARALLEL_GRIPPER_OPEN_AMOUNTS in step.data:
gripper_data = step.data[DataType.PARALLEL_GRIPPER_OPEN_AMOUNTS]
if GRIPPER_LOGGING_NAME in gripper_data:
gripper_value = gripper_data[GRIPPER_LOGGING_NAME].open_amount
# Log to NeuraCore for visualization
nc.log_parallel_gripper_open_amount(GRIPPER_LOGGING_NAME, gripper_value)

# Extract RGB image
rgb_image = None
if DataType.RGB_IMAGES in step.data:
rgb_data = step.data[DataType.RGB_IMAGES]
if CAMERA_LOGGING_NAME in rgb_data:
rgb_image = np.array(rgb_data[CAMERA_LOGGING_NAME].frame)

if rgb_image is None:
print("⚠️ No RGB image found")
return

# Log to NeuraCore
nc.log_joint_positions(joint_positions_dict)
nc.log_parallel_gripper_open_amount(GRIPPER_LOGGING_NAME, gripper_value)
nc.log_rgb(CAMERA_LOGGING_NAME, rgb_image)
# Save image to file for visualization
image_pil = Image.fromarray(rgb_image)
image_pil.save("current_image.png")
print("💾 Saved image to current_image.png")
# Log to NeuraCore for visualization
nc.log_rgb(CAMERA_LOGGING_NAME, rgb_image)

# Get policy prediction
print("🎯 Getting policy prediction...")
@@ -197,13 +207,9 @@ def select_random_state() -> None:
playing = True
print("FINISHED PREDICTION")

# Save image to file
image_pil = Image.fromarray(rgb_image)
image_pil.save("current_image.png")
print("💾 Saved image to current_image.png")
# Update robot to initial pose from first step in the horizon

# Update robot to initial pose
joint_positions = np.array([joint_positions_dict[jn] for jn in JOINT_NAMES])
joint_positions = np.array([current_horizon[jn][0] for jn in JOINT_NAMES])
urdf_vis.update_cfg(joint_positions)

print(
@@ -258,6 +264,14 @@ def select_random_state() -> None:
)
urdf_vis.update_cfg(joint_config)

# Log to NeuraCore for visualization
# NOTE: we log joint positions instead of joint target positions
# because the latter are not visualized by NeuraCore
joint_config_dict = {
jn: joint_config[i] for i, jn in enumerate(JOINT_NAMES)
}
nc.log_joint_positions(joint_config_dict)

# Update gripper value
gripper_value = current_horizon[GRIPPER_LOGGING_NAME][
current_action_idx
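The spec filtering added at the top of this file is easy to check in isolation: union the model's input and output data types, then keep only the matching spec entries. A toy sketch of those two steps, using plain strings as stand-ins for DataType members and a literal dict in place of dataset.get_full_data_spec(robot_id):

# Stand-ins for the model's input/output orderings.
model_input_order = {"JOINT_POSITIONS": ["joint1"], "RGB_IMAGES": ["front"]}
model_output_order = {"JOINT_TARGET_POSITIONS": ["joint1"]}
required_data_types = set(model_input_order) | set(model_output_order)

# Stand-in for dataset.get_full_data_spec(robot_id).
full_spec = {
    "JOINT_POSITIONS": ["joint1"],
    "PARALLEL_GRIPPER_OPEN_AMOUNTS": ["gripper"],
    "RGB_IMAGES": ["front"],
    "JOINT_TARGET_POSITIONS": ["joint1"],
}
filtered = {
    data_type: names
    for data_type, names in full_spec.items()
    if data_type in required_data_types
}

# The gripper entry drops out: neither the model input nor output uses it.
assert "PARALLEL_GRIPPER_OPEN_AMOUNTS" not in filtered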