@@ -109,9 +109,11 @@ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
109109 if inputs .images :
110110 input_ids = encoded ['input_ids' ]
111111 labels = encoded ['labels' ]
112+ loss_scale = encoded .get ('loss_scale' , None )
112113 idx_list = findall (input_ids , self .boi_token_id )
113114 img_tokens = self ._tokenize (self .processor .full_image_sequence )
114115 input_ids , labels = self ._extend_tokens (input_ids , labels , idx_list , lambda _ : img_tokens )
116+ loss_scale = self ._extend_loss_scale (loss_scale , idx_list , lambda _ : img_tokens )
115117
116118 # TODO: customize
117119 processor_kwargs = Gemma3ProcessorKwargs ._defaults ['images_kwargs' ]
@@ -126,6 +128,7 @@ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
126128 encoded ['input_ids' ] = input_ids
127129 encoded ['pixel_values' ] = image_inputs ['pixel_values' ]
128130 encoded ['labels' ] = labels
131+ encoded ['loss_scale' ] = loss_scale
129132 return encoded
130133
131134
@@ -158,6 +161,7 @@ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
158161 processor = self .processor
159162 input_ids = encoded ['input_ids' ]
160163 labels = encoded ['labels' ]
164+ loss_scale = encoded .get ('loss_scale' , None )
161165
162166 # Initialize token_type_ids and other outputs
163167 array_ids = np .array (input_ids )
@@ -168,6 +172,7 @@ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
168172 idx_list = findall (input_ids , self .boi_token_id )
169173 img_tokens = self ._tokenize (processor .full_image_sequence )
170174 input_ids , labels = self ._extend_tokens (input_ids , labels , idx_list , lambda _ : img_tokens )
175+ loss_scale = self ._extend_loss_scale (loss_scale , idx_list , lambda _ : img_tokens )
171176
172177 # Process images
173178 processor_kwargs = Gemma3nProcessorKwargs ._defaults .get ('images_kwargs' , {})
@@ -184,6 +189,7 @@ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
184189 # Get audio token sequence from processor
185190 audio_tokens = self ._tokenize (processor .full_audio_sequence )
186191 input_ids , labels = self ._extend_tokens (input_ids , labels , audio_idx_list , lambda _ : audio_tokens )
192+ loss_scale = self ._extend_loss_scale (loss_scale , audio_idx_list , lambda _ : audio_tokens )
187193
188194 # Process audios
189195 processor_kwargs = Gemma3nProcessorKwargs ._defaults .get ('audio_kwargs' , {})
@@ -209,7 +215,7 @@ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
209215 encoded ['token_type_ids' ] = mm_token_type_ids .tolist ()
210216 encoded ['input_ids' ] = input_ids
211217 encoded ['labels' ] = labels
212-
218+ encoded [ 'loss_scale' ] = loss_scale
213219 return encoded
214220
215221 def _data_collator_mm_data (self , batch : List [Dict [str , Any ]]) -> Dict [str , Any ]:
0 commit comments