# Patch summary (text before the "diff --git" header is ignored by patch tools).
#
# Hunk 1 (__init__): comments out the assignment of
#   tf.keras.losses.Huber(reduction=tf.keras.losses.Reduction.NONE)
#   to self.huber_loss.
# Hunk 2: adds a huber_loss(self, y_true, y_pred, clip_delta=1.0) method that
#   computes the Huber loss element-wise and applies no reduction:
#     error = y_true - y_pred
#     0.5 * error**2                            where |error| <  clip_delta
#     clip_delta * (|error| - 0.5 * clip_delta) otherwise
#   The two branches are selected with tf.where.
#
# NOTE(review): this patch was stored collapsed onto a single line. Line breaks
# and indentation below were reconstructed from the hunk line counts and from
# Python syntax (8-space method bodies, 4-space "def"); confirm the exact
# whitespace against the repository before applying.
diff --git a/QR-DQN/QR-DQN.py b/QR-DQN/QR-DQN.py
index e57d303..dcb146c 100644
--- a/QR-DQN/QR-DQN.py
+++ b/QR-DQN/QR-DQN.py
@@ -47,8 +47,8 @@ def __init__(self, state_dim, action_dim):
         self.action_dim = action_dim
         self.atoms = args.atoms
         self.tau = [(2*(i-1)+1)/(2*self.atoms) for i in range(1, self.atoms+1)]
-        self.huber_loss = tf.keras.losses.Huber(
-            reduction=tf.keras.losses.Reduction.NONE)
+        # self.huber_loss = tf.keras.losses.Huber(
+        #     reduction=tf.keras.losses.Reduction.NONE)
         self.opt = tf.keras.optimizers.Adam(args.lr)
         self.model = self.create_model()
 
@@ -61,6 +61,17 @@ def create_model(self):
             Reshape([self.action_dim, self.atoms])
         ])
 
+
+    def huber_loss(self, y_true, y_pred, clip_delta=1.0):
+        error = y_true - y_pred
+        cond = tf.keras.backend.abs(error) < clip_delta
+
+        squared_loss = 0.5 * tf.keras.backend.square(error)
+        linear_loss = clip_delta * (tf.keras.backend.abs(error) - 0.5 * clip_delta)
+
+        return tf.where(cond, squared_loss, linear_loss)
+
+
     def quantile_huber_loss(self, target, pred, actions):
         pred = tf.reduce_sum(pred * tf.expand_dims(actions, -1), axis=1)
         pred_tile = tf.tile(tf.expand_dims(pred, axis=2), [1, 1, self.atoms])