diff --git a/bayesflow/experimental/stable_consistency_model/stable_consistency_model.py b/bayesflow/experimental/stable_consistency_model/stable_consistency_model.py
index 6ce27fdf4..dc092ab4e 100644
--- a/bayesflow/experimental/stable_consistency_model/stable_consistency_model.py
+++ b/bayesflow/experimental/stable_consistency_model/stable_consistency_model.py
@@ -307,7 +307,7 @@ def compute_metrics(
         r = 1.0  # TODO: if consistency distillation training (not supported yet) is unstable, add schedule here
 
         def f_teacher(x, t):
-            o = self._apply_subnet(x / self.sigma, self.time_emb(t), conditions, training=stage == "training")
+            o = self._apply_subnet(x, self.time_emb(t), conditions, training=stage == "training")
             return self.subnet_projector(o)
 
         primals = (xt / self.sigma, t)
@@ -321,7 +321,7 @@ def f_teacher(x, t):
         cos_sin_dFdt = ops.stop_gradient(cos_sin_dFdt)
 
         # calculate output of the network
-        subnet_out = self._apply_subnet(x / self.sigma, self.time_emb(t), conditions, training=stage == "training")
+        subnet_out = self._apply_subnet(xt / self.sigma, self.time_emb(t), conditions, training=stage == "training")
         student_out = self.subnet_projector(subnet_out)
 
         # calculate the tangent
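
A minimal sketch of the scaling issue the first hunk addresses, using made-up placeholder values (sigma and xt here are plain floats, not the model's actual tensors or API): the JVP primal in the diff is already xt / self.sigma, so rescaling again inside f_teacher divides by sigma twice, while the fixed version consumes the pre-scaled input as-is.

# Toy illustration with assumed values; not library code.
sigma = 2.0
xt = 6.0

def f_teacher_buggy(x):
    return x / sigma   # receives xt / sigma and divides again (double scaling)

def f_teacher_fixed(x):
    return x           # uses the pre-scaled input directly

primal = xt / sigma    # scaling happens once, at the call site
assert f_teacher_fixed(primal) == xt / sigma       # 3.0, intended input scale
assert f_teacher_buggy(primal) == xt / sigma ** 2  # 1.5, scaled twice

The second hunk makes the student call consistent with the same convention: the network sees the noised sample xt scaled by sigma exactly once.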