diff --git a/README.md b/README.md
index f2df2d6..4f72e90 100644
--- a/README.md
+++ b/README.md
@@ -229,6 +229,44 @@ We provide an example hardware code in [this file](scripts/agilex_inference.py)
 
 Note: If you want to deploy on the Mobile ALOHA robot, don't forget to install the hardware prerequisites (see [this repo](https://github.com/MarkFzp/mobile-aloha)).
 
+## Deployment on SimplerEnv
+
+1. Set the required parameters in `scripts/encode_lang.py`:
+
+   ```python
+   # ...
+
+   GPU = 0
+   MODEL_PATH = "google/t5-v1_1-xxl"
+   CONFIG_PATH = "configs/base.yaml"
+   SAVE_DIR = "outs/"  # output directory
+
+   # Modify this to your task name and instruction
+   TASK_NAME = "widowx_pick_up_the_spoon"
+   INSTRUCTION = "put the spoon on the towel."
+
+   # Note: if your GPU VRAM is less than 24GB,
+   # it is recommended to enable offloading by specifying an offload directory.
+   OFFLOAD_DIR = None  # Specify your offload directory here, ensuring the directory exists.
+
+   # ...
+   ```
+
+2. Run the script:
+
+   ```
+   python -m scripts.encode_lang
+   ```
+
+3. Run the inference script:
+
+   ```bash
+   # Use robotics-diffusion-transformer/rdt-170m for the smaller model;
+   # --lang_embeddings_path is the file saved in step 2.
+   python -m scripts.simplerenv_inference \
+       --config_path configs/base.yaml \
+       --pretrained_model_name_or_path robotics-diffusion-transformer/rdt-1b \
+       --lang_embeddings_path outs/widowx_pick_up_the_spoon.pt \
+       --ctrl_freq 5 \
+       --chunk_size 64 \
+       --env_name widowx_spoon_on_towel
+   ```
+
 ## Citation
 
 If you find our work helpful, please cite us:
diff --git a/inference.sh b/inference.sh
index 0e7f8c8..b5622f7 100644
--- a/inference.sh
+++ b/inference.sh
@@ -3,3 +3,12 @@ python -m scripts.agilex_inference \
     --pretrained_model_name_or_path="checkpoints/your_finetuned_ckpt.pt" \ # your finetuned checkpoint: e.g., checkpoints/rdt-finetune-1b/checkpoint-, checkpoints/rdt-finetune-1b/checkpoint-/pytorch_model/mp_rank_00_model_states.pt,
     --lang_embeddings_path="outs/lang_embeddings/your_instr.pt" \
     --ctrl_freq=25    # your control frequency
+
+# inference on SimplerEnv
+python -m scripts.simplerenv_inference \
+    --config_path configs/base.yaml \
+    --pretrained_model_name_or_path robotics-diffusion-transformer/rdt-1b \
+    --lang_embeddings_path outs/widowx_pick_up_the_spoon.pt \
+    --ctrl_freq 5 \
+    --chunk_size 64 \
+    --env_name widowx_spoon_on_towel
\ No newline at end of file
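The inference script added below reads the `.pt` file produced by `scripts/encode_lang.py` and relies on the keys `name`, `instruction`, and `embeddings`. As a quick sanity check you can load the file and inspect it (a minimal sketch; the path is the README example above, and the exact embedding shape depends on your `encode_lang` settings):

```python
import torch

# Inspect a pre-encoded language embedding file (path from the README example above).
lang_dict = torch.load("outs/widowx_pick_up_the_spoon.pt")

# scripts/simplerenv_inference.py below expects exactly these three keys.
print(lang_dict["name"])              # e.g. "widowx_pick_up_the_spoon"
print(lang_dict["instruction"])       # e.g. "put the spoon on the towel."
print(lang_dict["embeddings"].shape)  # batch x tokens x hidden, as produced by the T5 encoder
```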
diff --git a/scripts/simplerenv_inference.py b/scripts/simplerenv_inference.py
new file mode 100644
index 0000000..6fc2302
--- /dev/null
+++ b/scripts/simplerenv_inference.py
@@ -0,0 +1,164 @@
+import argparse
+import torch
+from PIL import Image
+import numpy as np
+import simpler_env
+from simpler_env.utils.env.observation_utils import get_image_from_maniskill2_obs_dict
+from widowx_model import create_model
+import yaml
+import mediapy
+from scipy.spatial.transform import Rotation as R
+
+CAMERA_NAMES = ['cam_high', 'cam_right_wrist', 'cam_left_wrist']
+
+def get_arguments():
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--config_path', type=str, default="configs/base.yaml",
+                        help='Path to the config file')
+    parser.add_argument('--pretrained_model_name_or_path', type=str, required=True,
+                        help='Name or path to the pretrained model')
+    parser.add_argument('--lang_embeddings_path', type=str, required=True,
+                        help='Path to the pre-encoded language instruction embeddings')
+    parser.add_argument('--env_name', type=str, default="widowx_spoon_on_towel",
+                        help='Name of the SimplerEnv environment')
+    parser.add_argument('--ctrl_freq', type=int, default=25,
+                        help='The control frequency of the robot')
+    parser.add_argument('--chunk_size', type=int, default=64,
+                        help='Action chunk size')
+    return parser.parse_args()
+
+def make_policy(args):
+    pretrained_vision_encoder_name_or_path = "google/siglip-so400m-patch14-384"
+    with open(args.config_path, "r") as fp:
+        config = yaml.safe_load(fp)
+    model = create_model(
+        args=config,
+        dtype=torch.bfloat16,
+        pretrained=args.pretrained_model_name_or_path,
+        pretrained_vision_encoder_name_or_path=pretrained_vision_encoder_name_or_path,
+        control_frequency=args.ctrl_freq,
+    )
+    return model
+
+def rotation_matrix_to_ortho6d_1d(matrix: np.array) -> np.array:
+    """
+    The ortho6d representation consists of the first two column vectors
+    a1 and a2 of the rotation matrix:
+        [ | , | , | ]
+        [ a1, a2, a3]
+        [ | , | , | ]
+    Input: (3, 3)
+    Output: (6,)
+    """
+    ortho6d = matrix[:, :2]
+    ortho6d = ortho6d.T
+    ortho6d = ortho6d.reshape([6])
+    return ortho6d
+
+def ortho6d_to_rotation_matrix(ortho6d):
+    """Convert an ortho6d vector (shape (6,)) back to a rotation matrix (3, 3).
+
+    Gram-Schmidt inverse of `rotation_matrix_to_ortho6d_1d`: orthonormalize
+    the two 3-vectors, use them as the first two columns, and recover the
+    third column as their cross product.
+    """
+    a1, a2 = ortho6d[:3], ortho6d[3:]
+    b1 = a1 / np.linalg.norm(a1)
+    b2 = a2 - np.dot(b1, a2) * b1
+    b2 = b2 / np.linalg.norm(b2)
+    b3 = np.cross(b1, b2)
+    return np.stack([b1, b2, b3], axis=1)
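The two helpers above are intended to be inverses of each other: the encoder flattens the first two columns of a rotation matrix, and the decoder re-orthonormalizes them and restores the third column via a cross product. A tiny round-trip self-check (illustration only, assuming the two functions above are in scope):

```python
import numpy as np
from scipy.spatial.transform import Rotation as R

rot = R.random().as_matrix()                 # random ground-truth rotation (3, 3)
vec6 = rotation_matrix_to_ortho6d_1d(rot)    # first two columns, flattened to (6,)
rot_back = ortho6d_to_rotation_matrix(vec6)  # Gram-Schmidt reconstruction
assert np.allclose(rot, rot_back, atol=1e-6)
```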
+def model_inference(args, env):
+    frames = []
+    policy = make_policy(args)
+
+    lang_dict = torch.load(args.lang_embeddings_path)
+    print(f"Running with instruction: \"{lang_dict['instruction']}\" from \"{lang_dict['name']}\"")
+    args.lang_instruction = lang_dict['instruction']
+    args.lang_name = lang_dict['name']
+    lang_embeddings = lang_dict["embeddings"]
+
+    obs, reset_info = env.reset()
+    print("Reset info", reset_info)
+
+    done, truncated = False, False
+    t = 0
+    action_buffer = np.zeros([args.chunk_size, env.action_space.shape[0]])
+    """
+    obs.keys(): ['agent', 'extra', 'camera_param', 'image']
+    obs['agent'].keys(): ['qpos', 'qvel', 'controller', 'base_pose']
+    tcp_pose: 3 pos + 4 quat (wxyz)
+    """
+
+    # Switch the arm controller to absolute control in the base frame
+    env.env.env.env.agent.controller.controllers['arm'].config.use_delta = False
+    env.env.env.env.agent.controller.controllers['arm'].config.frame = 'base'
+
+    while not (done or truncated):
+        if t % args.chunk_size == 0:
+            images = [None, None, None, get_image_from_maniskill2_obs_dict(env, obs), None, None]
+            images = [Image.fromarray(img) if img is not None else None for img in images]
+            proprio = torch.from_numpy(obs['agent']['qpos']).float().cuda().unsqueeze(0)
+            proprio = torch.cat([proprio, torch.from_numpy(obs['agent']['qvel']).float().cuda().unsqueeze(0)], dim=1)
+            tcp_pose = env.env.env.env.tcp.pose  # by default, the tcp (eef) pose is in world coordinates
+            tcp_pose = env.env.env.env.agent.robot.pose.inv() * tcp_pose  # eef pose in the base frame, see https://github.com/simpler-env/SimplerEnv/blob/d55e19162be86794875839725fd484b768e25873/tools/sysid/sysid.py#L51
+            eef_xyz = tcp_pose.p  # see https://github.com/haosulab/ManiSkill/blob/main/mani_skill/utils/structs/pose.py#L94
+
+            proprio = torch.cat([proprio, torch.from_numpy(eef_xyz).float().cuda().unsqueeze(0)], dim=1)
+            # quat = tcp_pose[3:]
+            quat = tcp_pose.q
+
+            rr = R.from_quat(quat, scalar_first=True)
+            rotmat = rr.as_matrix()
+            ortho6d = rotation_matrix_to_ortho6d_1d(rotmat)
+
+            proprio = torch.cat([proprio, torch.from_numpy(ortho6d).float().cuda().unsqueeze(0)], dim=1)
+
+            action_buffer = policy.step(
+                proprio=proprio,
+                images=images,
+                text_embeds=lang_embeddings
+            ).squeeze(0).cpu().numpy()
+
+        # absolute control
+        action = action_buffer[t % args.chunk_size]
+        gripper_action = action[-1]
+        out_eef_xyz = action[:3]
+        out_ortho6d = action[3:-1]
+        out_rot_matrix = ortho6d_to_rotation_matrix(out_ortho6d)
+        out_r = R.from_matrix(out_rot_matrix)
+        out_axis_angle = out_r.as_rotvec()
+
+        action = np.concatenate([out_eef_xyz, out_axis_angle, [gripper_action]])
+        # action[3:] = action[3:] * 0.01
+        # action[:3] = action[:3] * 0.1
+        print(f"action={action}")
+
+        obs, reward, done, truncated, info = env.step(action)
+
+        frames.append(get_image_from_maniskill2_obs_dict(env, obs))
+        new_instruction = env.get_language_instruction()
+        if new_instruction != args.lang_instruction:
+            args.lang_instruction = new_instruction
+            print("New Instruction", args.lang_instruction)
+
+        t += 1
+        # print("Step", t)
+
+    episode_stats = info.get('episode_stats', {})
+    print("Episode stats", episode_stats)
+    return frames
+
+def main():
+    args = get_arguments()
+    env_name = args.env_name
+    env = simpler_env.make(env_name)
+
+    # Other available environments:
+    # env = simpler_env.make('widowx_spoon_on_towel')
+    # env = simpler_env.make('widowx_carrot_on_plate')
+    # env = simpler_env.make('widowx_stack_cube')
+    # env = simpler_env.make('widowx_put_eggplant_in_basket')
+
+    frames = model_inference(args, env)
+
+    model_name = args.pretrained_model_name_or_path.split("/")[-1]
+    lang_name = args.lang_name
+    save_path = f"outs/{model_name}_{lang_name}_{args.ctrl_freq}_chunk_{args.chunk_size}.mp4"
+    mediapy.write_video(save_path, frames, fps=10)
+    print("Saved video to", save_path)
+
+if __name__ == '__main__':
+    main()
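`model_inference` above stacks a 25-dimensional proprioception vector (joint positions and velocities plus the end-effector pose) that `scripts/widowx_model.py` below scatters into RDT's 128-dimensional unified state space. A schematic of that layout with dummy values (the real numbers come from the ManiSkill observation):

```python
import numpy as np
import torch

qpos = np.zeros(8)      # arm + gripper joint positions
qvel = np.zeros(8)      # arm + gripper joint velocities
eef_xyz = np.zeros(3)   # end-effector position in the robot base frame
ortho6d = np.zeros(6)   # end-effector orientation in the 6D rotation representation

proprio = torch.from_numpy(
    np.concatenate([qpos, qvel, eef_xyz, ortho6d])
).float().unsqueeze(0)  # (1, 25), matching WIDOWX_STATE_INDICES below
print(proprio.shape)
```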
diff --git a/scripts/widowx_model.py b/scripts/widowx_model.py
new file mode 100644
index 0000000..9c05e7a
--- /dev/null
+++ b/scripts/widowx_model.py
@@ -0,0 +1,310 @@
+import os
+
+import numpy as np
+import torch
+from PIL import Image
+from torchvision import transforms
+
+from configs.state_vec import STATE_VEC_IDX_MAPPING
+from models.multimodal_encoder.siglip_encoder import SiglipVisionTower
+from models.multimodal_encoder.t5_encoder import T5Embedder
+from models.rdt_runner import RDTRunner
+
+# The indices that the raw vector should be mapped to in the unified action vector
+WIDOWX_STATE_INDICES = [
+    STATE_VEC_IDX_MAPPING[f"right_arm_joint_{i}_pos"] for i in range(7)
+] + [
+    STATE_VEC_IDX_MAPPING["right_gripper_open"]
+] + [
+    STATE_VEC_IDX_MAPPING[f"right_arm_joint_{i}_vel"] for i in range(7)
+] + [
+    STATE_VEC_IDX_MAPPING["right_gripper_open_vel"]
+] + [
+    STATE_VEC_IDX_MAPPING["right_eef_pos_x"],
+    STATE_VEC_IDX_MAPPING["right_eef_pos_y"],
+    STATE_VEC_IDX_MAPPING["right_eef_pos_z"],
+    STATE_VEC_IDX_MAPPING["right_eef_angle_0"],
+    STATE_VEC_IDX_MAPPING["right_eef_angle_1"],
+    STATE_VEC_IDX_MAPPING["right_eef_angle_2"],
+    STATE_VEC_IDX_MAPPING["right_eef_angle_3"],
+    STATE_VEC_IDX_MAPPING["right_eef_angle_4"],
+    STATE_VEC_IDX_MAPPING["right_eef_angle_5"],
+]
+
+def create_model(args, **kwargs):
+    model = RoboticDiffusionTransformerModel(args, **kwargs)
+    pretrained = kwargs.get("pretrained", None)
+    if (
+        pretrained is not None
+        and os.path.isfile(pretrained)
+    ):
+        model.load_pretrained_weights(pretrained)
+    return model
+
+class RoboticDiffusionTransformerModel(object):
+    """A wrapper for the RDT model, which handles:
+    1. Model initialization
+    2. Encoding of instructions
+    3. Model inference
+    """
+    def __init__(
+        self, args,
+        device='cuda',
+        dtype=torch.bfloat16,
+        image_size=None,
+        control_frequency=25,
+        pretrained=None,
+        pretrained_vision_encoder_name_or_path=None,
+    ):
+        self.args = args
+        self.dtype = dtype
+        self.image_size = image_size
+        self.device = device
+        self.control_frequency = control_frequency
+        # We do not use the text encoder due to limited GPU memory
+        # self.text_tokenizer, self.text_model = self.get_text_encoder(pretrained_text_encoder_name_or_path)
+        self.image_processor, self.vision_model = self.get_vision_encoder(pretrained_vision_encoder_name_or_path)
+        self.policy = self.get_policy(pretrained)
+
+        self.reset()
+
+    def get_policy(self, pretrained):
+        """Initialize the model."""
+        # Initialize model with arguments
+        if (
+            pretrained is None
+            or os.path.isfile(pretrained)
+        ):
+            img_cond_len = (self.args["common"]["img_history_size"]
+                            * self.args["common"]["num_cameras"]
+                            * self.vision_model.num_patches)
+
+            _model = RDTRunner(
+                action_dim=self.args["common"]["state_dim"],
+                pred_horizon=self.args["common"]["action_chunk_size"],
+                config=self.args["model"],
+                lang_token_dim=self.args["model"]["lang_token_dim"],
+                img_token_dim=self.args["model"]["img_token_dim"],
+                state_token_dim=self.args["model"]["state_token_dim"],
+                max_lang_cond_len=self.args["dataset"]["tokenizer_max_length"],
+                img_cond_len=img_cond_len,
+                img_pos_embed_config=[
+                    # No initial pos embed in the last grid size
+                    # since we've already done it in ViT
+                    ("image", (self.args["common"]["img_history_size"],
+                               self.args["common"]["num_cameras"],
+                               -self.vision_model.num_patches)),
+                ],
+                lang_pos_embed_config=[
+                    # Similarly, no initial pos embed for language
+                    ("lang", -self.args["dataset"]["tokenizer_max_length"]),
+                ],
+                dtype=self.dtype,
+            )
+        else:
+            _model = RDTRunner.from_pretrained(pretrained)
+
+        return _model
+
+    def get_text_encoder(self, pretrained_text_encoder_name_or_path):
+        text_embedder = T5Embedder(from_pretrained=pretrained_text_encoder_name_or_path,
+                                   model_max_length=self.args["dataset"]["tokenizer_max_length"],
+                                   device=self.device)
+        tokenizer, text_encoder = text_embedder.tokenizer, text_embedder.model
+        return tokenizer, text_encoder
+
+    def get_vision_encoder(self, pretrained_vision_encoder_name_or_path):
+        vision_encoder = SiglipVisionTower(vision_tower=pretrained_vision_encoder_name_or_path, args=None)
+        image_processor = vision_encoder.image_processor
+        return image_processor, vision_encoder
+
+    def reset(self):
+        """Set the model to evaluation mode."""
+        device = self.device
+        weight_dtype = self.dtype
+        self.policy.eval()
+        # self.text_model.eval()
+        self.vision_model.eval()
+
+        self.policy = self.policy.to(device, dtype=weight_dtype)
+        # self.text_model = self.text_model.to(device, dtype=weight_dtype)
+        self.vision_model = self.vision_model.to(device, dtype=weight_dtype)
+
+    def load_pretrained_weights(self, pretrained=None):
+        if pretrained is None:
+            return
+        print(f'Loading weights from {pretrained}')
+        filename = os.path.basename(pretrained)
+        if filename.endswith('.pt'):
+            checkpoint = torch.load(pretrained)
+            self.policy.load_state_dict(checkpoint["module"])
+        elif filename.endswith('.safetensors'):
+            from safetensors.torch import load_model
+            load_model(self.policy, pretrained)
+        else:
+            raise NotImplementedError(f"Unknown checkpoint format: {pretrained}")
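`load_pretrained_weights` accepts either a DeepSpeed-style `.pt` checkpoint with the weights stored under the `"module"` key, or a `.safetensors` file; anything that is not a local file is treated as a Hugging Face model id by `get_policy`. A sketch of repackaging a plain state dict into the expected `.pt` layout (both paths are hypothetical):

```python
import torch

# Repackage a plain state dict into the DeepSpeed-style layout that
# load_pretrained_weights expects for ".pt" files (hypothetical paths).
state_dict = torch.load("checkpoints/raw_rdt_state_dict.pt")
torch.save({"module": state_dict}, "checkpoints/my_finetuned_ckpt.pt")
```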
+    def encode_instruction(self, instruction, device="cuda"):
+        """Encode a string instruction into latent embeddings.
+
+        Args:
+            instruction: a string of instruction
+            device: a string of device
+
+        Returns:
+            pred: a tensor of latent embeddings of shape (text_max_length, 512)
+        """
+        tokens = self.text_tokenizer(
+            instruction, return_tensors="pt",
+            padding="longest",
+            truncation=True
+        )["input_ids"].to(device)
+
+        tokens = tokens.view(1, -1)
+        with torch.no_grad():
+            pred = self.text_model(tokens).last_hidden_state.detach()
+
+        return pred
+
+    def _format_joint_to_state(self, joints):
+        """
+        Format the raw proprioception into the unified state vector.
+
+        Args:
+            joints (torch.Tensor): The proprioception to be formatted
+                ([B, N, 25]: qpos, qvel, EEF position, EEF 6D rotation).
+
+        Returns:
+            state (torch.Tensor): The formatted vector for RDT ([B, N, 128]).
+        """
+        # Rescale the gripper joint position and velocity (elements 7 and 15)
+        joints = joints / torch.tensor(
+            [[[1, 1, 1, 1, 1, 1, 1, 4.7908, 1, 1, 1, 1, 1, 1, 1, 4.7908, 1, 1, 1, 1, 1, 1, 1, 1, 1]]],
+            device=joints.device, dtype=joints.dtype
+        )
+
+        B, N, _ = joints.shape
+        state = torch.zeros(
+            (B, N, self.args["model"]["state_token_dim"]),
+            device=joints.device, dtype=joints.dtype
+        )
+        # Fill into the unified state vector
+        state[:, :, WIDOWX_STATE_INDICES] = joints
+        # Assemble the mask indicating each dimension's availability
+        state_elem_mask = torch.zeros(
+            (B, self.args["model"]["state_token_dim"]),
+            device=joints.device, dtype=joints.dtype
+        )
+        state_elem_mask[:, WIDOWX_STATE_INDICES] = 1
+        return state, state_elem_mask
+
+    def _unformat_action_to_joint(self, action):
+        """
+        Unformat the unified action vector into the action to be executed.
+
+        Args:
+            action (torch.Tensor): The unified action vector to be unformatted
+                ([B, N, 128]).
+
+        Returns:
+            joints (torch.Tensor): The unformatted robot action
+                ([B, N, 10]: EEF position, EEF 6D rotation, gripper).
+        """
+        action_indices = WIDOWX_STATE_INDICES[-9:] + [STATE_VEC_IDX_MAPPING["right_gripper_open"]]
+        joints = action[:, :, action_indices]
+
+        # Rescale the gripper command (last element)
+        joints = joints * torch.tensor(
+            [[[1, 1, 1, 1, 1, 1, 1, 1, 1, 11.8997]]],
+            device=joints.device, dtype=joints.dtype
+        )
+
+        return joints
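The two helpers above pack the 25-dimensional WidowX vector into RDT's 128-dimensional unified state space and pull the 10-dimensional end-effector action back out. The packing pattern, illustrated with hypothetical indices instead of the real `STATE_VEC_IDX_MAPPING` entries:

```python
import torch

state_token_dim = 128
demo_indices = list(range(25))        # hypothetical stand-in for WIDOWX_STATE_INDICES
joints = torch.randn(1, 1, 25)        # (B, N, 25) proprioception

state = torch.zeros(1, 1, state_token_dim)
state[:, :, demo_indices] = joints    # scatter into the unified state vector

state_elem_mask = torch.zeros(1, state_token_dim)
state_elem_mask[:, demo_indices] = 1  # mark which dimensions are actually provided

print(state.shape, int(state_elem_mask.sum()))  # torch.Size([1, 1, 128]) 25
```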
+    @torch.no_grad()
+    def step(self, proprio, images, text_embeds):
+        """
+        Predict the next action chunk given the
+        proprioceptive states, images, and instruction embeddings.
+
+        Args:
+            proprio: proprioceptive states: qpos (8), qvel (8),
+                EEF position (3), and EEF 6D rotation (6)
+            images: RGB images, the order should be
+                [ext_{t-1}, right_wrist_{t-1}, left_wrist_{t-1},
+                ext_{t}, right_wrist_{t}, left_wrist_{t}]
+            text_embeds: instruction embeddings
+
+        Returns:
+            action: the predicted action chunk ([B, N, 10])
+        """
+        device = self.device
+        dtype = self.dtype
+        # The background image used for padding
+        background_color = np.array([
+            int(x*255) for x in self.image_processor.image_mean
+        ], dtype=np.uint8).reshape(1, 1, 3)
+        background_image = np.ones((
+            self.image_processor.size["height"],
+            self.image_processor.size["width"], 3), dtype=np.uint8
+        ) * background_color
+
+        # Preprocess the images by order and encode them
+        image_tensor_list = []
+        for image in images:
+            if image is None:
+                # Replace it with the background image
+                image = Image.fromarray(background_image)
+
+            if self.image_size is not None:
+                image = transforms.Resize(self.image_size)(image)
+
+            if self.args["dataset"].get("auto_adjust_image_brightness", False):
+                pixel_values = list(image.getdata())
+                average_brightness = sum(sum(pixel) for pixel in pixel_values) / (len(pixel_values) * 255.0 * 3)
+                if average_brightness <= 0.15:
+                    image = transforms.ColorJitter(brightness=(1.75, 1.75))(image)
+
+            if self.args["dataset"].get("image_aspect_ratio", "pad") == 'pad':
+                def expand2square(pil_img, background_color):
+                    width, height = pil_img.size
+                    if width == height:
+                        return pil_img
+                    elif width > height:
+                        result = Image.new(pil_img.mode, (width, width), background_color)
+                        result.paste(pil_img, (0, (width - height) // 2))
+                        return result
+                    else:
+                        result = Image.new(pil_img.mode, (height, height), background_color)
+                        result.paste(pil_img, ((height - width) // 2, 0))
+                        return result
+                image = expand2square(image, tuple(int(x*255) for x in self.image_processor.image_mean))
+            image = self.image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
+            image_tensor_list.append(image)
+
+        image_tensor = torch.stack(image_tensor_list, dim=0).to(device, dtype=dtype)
+
+        image_embeds = self.vision_model(image_tensor).detach()
+        image_embeds = image_embeds.reshape(-1, self.vision_model.hidden_size).unsqueeze(0)
+
+        # Prepare the proprioception states and the control frequency
+        joints = proprio.to(device).unsqueeze(0)   # (1, 1, 25)
+        states, state_elem_mask = self._format_joint_to_state(joints)    # (1, 1, 128), (1, 128)
+        states, state_elem_mask = states.to(device, dtype=dtype), state_elem_mask.to(device, dtype=dtype)
+        states = states[:, -1:, :]  # (1, 1, 128)
+        ctrl_freqs = torch.tensor([self.control_frequency]).to(device)
+
+        text_embeds = text_embeds.to(device, dtype=dtype)
+
+        # Predict the next action chunk given the inputs
+        trajectory = self.policy.predict_action(
+            lang_tokens=text_embeds,
+            lang_attn_mask=torch.ones(
+                text_embeds.shape[:2], dtype=torch.bool,
+                device=text_embeds.device),
+            img_tokens=image_embeds,
+            state_tokens=states,
+            action_mask=state_elem_mask.unsqueeze(1),
+            ctrl_freqs=ctrl_freqs
+        )
+        trajectory = self._unformat_action_to_joint(trajectory).to(torch.float32)
+
+        return trajectory
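For reference, a minimal end-to-end sketch of how `scripts/simplerenv_inference.py` drives this wrapper, using dummy proprioception and images and the example paths from the README (the real observation processing lives in `model_inference` above):

```python
import torch
import yaml
from PIL import Image

from widowx_model import create_model  # same import path as scripts/simplerenv_inference.py

with open("configs/base.yaml", "r") as fp:
    config = yaml.safe_load(fp)

policy = create_model(
    args=config,
    dtype=torch.bfloat16,
    pretrained="robotics-diffusion-transformer/rdt-1b",
    pretrained_vision_encoder_name_or_path="google/siglip-so400m-patch14-384",
    control_frequency=5,
)

proprio = torch.zeros(1, 25)  # qpos (8) + qvel (8) + EEF xyz (3) + EEF 6D rotation (6)
images = [None, None, None, Image.new("RGB", (640, 480)), None, None]  # only the current exterior view
text_embeds = torch.load("outs/widowx_pick_up_the_spoon.pt")["embeddings"]

actions = policy.step(proprio=proprio, images=images, text_embeds=text_embeds)
print(actions.shape)  # (1, action_chunk_size, 10): EEF xyz, 6D rotation, gripper
```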