diff --git a/rsl_rl/modules/student_teacher.py b/rsl_rl/modules/student_teacher.py
index 82b3ef62..ec6beecc 100644
--- a/rsl_rl/modules/student_teacher.py
+++ b/rsl_rl/modules/student_teacher.py
@@ -64,7 +64,6 @@ def __init__(
 
         # Teacher
         self.teacher = MLP(num_teacher_obs, num_actions, teacher_hidden_dims, activation)
-        self.teacher.eval()
         print(f"Teacher MLP: {self.teacher}")
 
         # Teacher observation normalization
diff --git a/rsl_rl/runners/distillation_runner.py b/rsl_rl/runners/distillation_runner.py
index 6f9c502b..02aeb95b 100644
--- a/rsl_rl/runners/distillation_runner.py
+++ b/rsl_rl/runners/distillation_runner.py
@@ -58,6 +58,7 @@ def __init__(self, env: VecEnv, train_cfg: dict, log_dir: str | None = None, dev
     def learn(self, num_learning_iterations: int, init_at_random_ep_len: bool = False) -> None:
         # Initialize writer
         self._prepare_logging_writer()
+        # Check if teacher is loaded
        if not self.alg.policy.loaded_teacher:
            raise ValueError("Teacher model parameters not loaded. Please load a teacher model to distill.")