From 878c426345abfa18bc9be3273ea98bc501f9999b Mon Sep 17 00:00:00 2001 From: chuc92man Date: Thu, 30 May 2024 09:12:56 +0900 Subject: [PATCH] Update models_mamba.py While trying to train vision mamba with bidirectional mode in masked autoencoder network, I experienced nan loss. Though switch training from mixed precision to full precision fixed the problem but significantly increased training time (almost twice). Looking at the code, the adding of forward and backward hidden_states/residuals does increased the magnitude of both twice as compared to the original hidden states (after patch embedding). By dividing by 2, nan loss was resolved and mixed precision training can continue. --- vim/models_mamba.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/vim/models_mamba.py b/vim/models_mamba.py index cb10774f..db860fc3 100644 --- a/vim/models_mamba.py +++ b/vim/models_mamba.py @@ -492,8 +492,8 @@ def forward_features(self, x, inference_params=None, if_random_cls_token_positio hidden_states_b, residual_b = self.layers[i * 2 + 1]( hidden_states.flip([1]), None if residual == None else residual.flip([1]), inference_params=inference_params ) - hidden_states = hidden_states_f + hidden_states_b.flip([1]) - residual = residual_f + residual_b.flip([1]) + hidden_states = (hidden_states_f + hidden_states_b.flip([1])) / 2 + residual = (residual_f + residual_b.flip([1])) / 2 if not self.fused_add_norm: if residual is None: @@ -597,4 +597,4 @@ def vim_small_patch16_stride8_224_bimambav2_final_pool_mean_abs_pos_embed_with_m map_location="cpu", check_hash=True ) model.load_state_dict(checkpoint["model"]) - return model \ No newline at end of file + return model