LIBERO posttrain problem #49
Open
Description
Question: Training performance issues on LIBERO-Spatial task with 60Hz recollected dataset
Hi, thank you for sharing your work!
I am trying to post-train on a single task from LIBERO-Spatial. I recollected the dataset at 60 Hz, and for the action I used the absolute EEF pose of the next frame. I didn't downsample the video before extracting the latents, since I noticed that you didn't downsample the video in the LIBERO dataset you use.
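The action construction described above can be sketched as follows. This is a minimal sketch, not the actual recollection script: the function name `next_frame_actions`, the `(T, D)` pose layout, and the choice to repeat the final pose for the last frame are all assumptions made for illustration.

```python
import numpy as np


def next_frame_actions(eef_poses: np.ndarray) -> np.ndarray:
    """Build per-frame actions as the absolute EEF pose of the next frame.

    eef_poses: (T, D) array of absolute end-effector poses, one row per frame.
    Returns a (T, D) array with action[t] = pose[t + 1]. The last frame has no
    successor, so its action repeats the final pose (an assumption).
    """
    if eef_poses.ndim != 2 or len(eef_poses) == 0:
        raise ValueError("expected a non-empty (T, D) array")
    return np.concatenate([eef_poses[1:], eef_poses[-1:]], axis=0)


# Toy trajectory: 5 frames, position only, moving 1 cm per frame along x.
poses = np.zeros((5, 3))
poses[:, 0] = np.arange(5) * 0.01
actions = next_frame_actions(poses)

# Largest per-frame change between the current pose and its action target.
max_step = np.abs(actions - poses).max()
```

At 60 Hz the per-frame change (`max_step` here) is small compared with a lower-rate recording of the same motion, so it may be worth checking how the action targets are distributed before training.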
I trained for 600 steps in this case. The curves are as follows:
The results are not good. I am wondering whether the problem is with my dataset or whether I simply didn't train for long enough. I uploaded my dataset here: https://huggingface.co/datasets/yunju-15/lingbot_libero
My training configs are as follows:
```python
# Copyright 2024-2025 The Robbyant Team Authors. All rights reserved.
from easydict import EasyDict
from .va_libero_cfg import va_libero_cfg
import os

va_libero_train_cfg = EasyDict(__name__='Config: VA libero train')
va_libero_train_cfg.update(va_libero_cfg)

# va_libero_train_cfg.resume_from = '/robby/share/Robotics/lilin1/code/Wan_VA_Release/train_out/checkpoints/checkpoint_step_10'
va_libero_train_cfg.dataset_path = '/data1/yunju/lingbot-va/dataset/libero_60'
va_libero_train_cfg.empty_emb_path = os.path.join(va_libero_train_cfg.dataset_path, 'empty_emb.pt')
va_libero_train_cfg.enable_wandb = True
va_libero_train_cfg.load_worker = 16
va_libero_train_cfg.save_interval = 100
va_libero_train_cfg.gc_interval = 50
va_libero_train_cfg.cfg_prob = 0.1

# Training parameters
va_libero_train_cfg.learning_rate = 1e-5
va_libero_train_cfg.beta1 = 0.9
va_libero_train_cfg.beta2 = 0.95
va_libero_train_cfg.weight_decay = 1e-1
va_libero_train_cfg.warmup_steps = 10
va_libero_train_cfg.batch_size = 1
va_libero_train_cfg.gradient_accumulation_steps = 8
va_libero_train_cfg.num_steps = 3000
```
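To put the 600 steps in context, here is a rough sketch of how many samples the model has seen under these settings. It assumes that one "step" is one optimizer update after gradient accumulation and that `batch_size` is per GPU; the config does not confirm either. `samples_seen` and `world_size` are hypothetical names, and the GPU count is not given here, so it defaults to 1.

```python
def samples_seen(steps: int, batch_size: int, grad_accum: int,
                 world_size: int = 1) -> int:
    """Samples consumed after `steps` optimizer updates.

    Assumes each update accumulates `grad_accum` micro-batches of
    `batch_size` samples on each of `world_size` GPUs.
    """
    return steps * batch_size * grad_accum * world_size


# Values taken from the training config above; world_size is an assumption.
effective_batch = samples_seen(1, batch_size=1, grad_accum=8)     # 8
seen_so_far = samples_seen(600, batch_size=1, grad_accum=8)       # 4800
seen_full_run = samples_seen(3000, batch_size=1, grad_accum=8)    # 24000
```

If the trainer instead counts micro-batches as steps, 600 steps would correspond to only 600 samples, so it matters which convention the training loop uses.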
libero_cfg:
```python
# Copyright 2024-2025 The Robbyant Team Authors. All rights reserved.
import torch
from easydict import EasyDict
from .shared_config import va_shared_cfg

va_libero_cfg = EasyDict(__name__='Config: VA libero')
va_libero_cfg.update(va_shared_cfg)
va_shared_cfg.infer_mode = 'server'

va_libero_cfg.wan22_pretrained_model_name_or_path = "/data1/yunju/lingbot-va/models/lingbot-va-base"
va_libero_cfg.attn_window = 30
va_libero_cfg.frame_chunk_size = 4
va_libero_cfg.env_type = 'libero_env'
va_libero_cfg.height = 256
va_libero_cfg.width = 256
va_libero_cfg.action_dim = 30
va_libero_cfg.action_per_frame = 4
va_libero_cfg.obs_cam_keys = [
    'observation.images.cam_high', 'observation.images.cam_wrist'
]
va_libero_cfg.guidance_scale = 5
va_libero_cfg.action_guidance_scale = 1
va_libero_cfg.num_inference_steps = 20
va_libero_cfg.video_exec_step = -1
va_libero_cfg.action_num_inference_steps = 50
va_libero_cfg.snr_shift = 5.0
va_libero_cfg.action_snr_shift = 1.0

va_libero_cfg.used_action_channel_ids = list(range(0, 7)) + list(range(28, 29))
inverse_used_action_channel_ids = [len(va_libero_cfg.used_action_channel_ids)
                                   ] * va_libero_cfg.action_dim
for i, j in enumerate(va_libero_cfg.used_action_channel_ids):
    inverse_used_action_channel_ids[j] = i
va_libero_cfg.inverse_used_action_channel_ids = inverse_used_action_channel_ids

va_libero_cfg.action_norm_method = 'quantiles'
va_libero_cfg.norm_stat = {
    "q01": [
        0.0,
        0.00031695778481662275,
        -0.263985458612442,
        -0.11008505668947856,
        -0.07417451706048174,
        -0.06287607050280382,
        0.8795063851028853
    ] + [0.0] * 21 + [-1.0, 0.0],
    "q99": [
        0.31398141406476493,
        0.19865495219826695,
        0.014632310867309525,
        0.11908656035632535,
        0.2048846753995732,
        0.45741868402168506,
        0.9999982499824769,
    ] + [0.0] * 21 + [1.0, 0.0],
}
```
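Since `action_norm_method` is `'quantiles'` and `norm_stat` is hard-coded, one thing worth checking is whether those quantiles match the recollected actions. The sketch below computes per-channel 1%/99% quantiles from an action array and applies a common quantile normalization that maps `[q01, q99]` to `[-1, 1]`. The exact formula used by the training code is not shown in this issue, so that mapping is an assumption; `quantile_stats`, `quantile_normalize`, and the random toy data are hypothetical.

```python
import numpy as np


def quantile_stats(actions: np.ndarray, lo: float = 0.01, hi: float = 0.99):
    """Per-channel lower/upper quantiles of a (N, D) action array."""
    return np.quantile(actions, lo, axis=0), np.quantile(actions, hi, axis=0)


def quantile_normalize(actions: np.ndarray, q01: np.ndarray, q99: np.ndarray,
                       eps: float = 1e-8) -> np.ndarray:
    """Map [q01, q99] to [-1, 1] per channel (assumed formula).

    Channels where q99 == q01 (e.g. unused, all-zero channels) are guarded
    by `eps` to avoid division by zero.
    """
    scale = np.maximum(q99 - q01, eps)
    return 2.0 * (actions - q01) / scale - 1.0


# Toy stand-in for the 8 used action channels of a recollected dataset.
rng = np.random.default_rng(0)
toy_actions = rng.normal(size=(1000, 8))

q01, q99 = quantile_stats(toy_actions)
normalized = quantile_normalize(toy_actions, q01, q99)

# Fraction of values outside [-1, 1]; about 2% when the stats match the data.
frac_outside = float(np.mean(np.abs(normalized) > 1.0))
```

Running the same check with the `q01`/`q99` values from `norm_stat` on the real actions would show whether a large share of normalized values falls far outside `[-1, 1]`, which would suggest the statistics were computed for a different action definition.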