55
66from  keras .src  import  backend 
77from  keras .src  import  ops 
8+ from  keras .src  import  tree 
89from  keras .src .api_export  import  keras_export 
910from  keras .src .callbacks .monitor_callback  import  (
1011    MonitorCallback ,  # For metric monitoring logic 
1112)
13+ from  keras .src .distribution .distribution_lib  import  process_id 
1214from  keras .src .utils .io_utils  import  print_msg 
13- from  keras .src .utils .module_utils  import  LazyModule 
14- 
15- ocp  =  LazyModule (
16-     "orbax.checkpoint" ,
17-     pip_name = "orbax-checkpoint" ,
18-     import_error_msg = (
19-         "OrbaxCheckpoint requires the 'orbax-checkpoint' package. " 
20-         "Install it with: pip install orbax-checkpoint" 
21-     ),
22- )
23- 
24- # Note: Advanced Orbax functionality is available through the ocp LazyModule 
25- # Users can access it via: from keras.src.utils.module_utils import LazyModule 
26- # ocp = LazyModule("orbax.checkpoint"); ocp.CheckpointManager 
15+ from  keras .src .utils .module_utils  import  ocp 
2716
2817
2918def  _get_state_tree (model ):
@@ -38,68 +27,49 @@ def convert_scalars(obj):
3827        elif  isinstance (obj , np .generic ):
3928            # Convert numpy scalar types (like np.float32) to Python types 
4029            return  obj .item ()
41-         elif  isinstance (obj , dict ):
42-             return  {k : convert_scalars (v ) for  k , v  in  obj .items ()}
4330        else :
4431            return  obj 
4532
46-     return  convert_scalars ( state_tree )
33+     return  tree . map_structure ( convert_scalars ,  state_tree )
4734
4835
4936def  _flatten_state_tree_values (state_tree ):
5037    """Flatten nested state tree into a list of values in consistent order.""" 
51-     values  =  []
52- 
53-     def  _flatten (obj ):
54-         if  isinstance (obj , dict ):
55-             for  key  in  sorted (obj .keys ()):  # Sort for consistent ordering 
56-                 _flatten (obj [key ])
57-         else :
58-             # Save any non-dict value (numpy arrays, lists, scalars, etc.) 
59-             values .append (obj )
60- 
61-     _flatten (state_tree )
62-     return  values 
38+     return  tree .flatten (state_tree )
6339
6440
6541def  _reconstruct_state_tree_with_values (structure , values ):
6642    """Reconstruct state tree structure with provided values.""" 
6743    value_iter  =  iter (values )
6844
69-     def  _reconstruct (obj ):
70-         if  isinstance (obj , dict ):
71-             new_dict  =  {}
72-             for  key  in  sorted (obj .keys ()):
73-                 new_dict [key ] =  _reconstruct (obj [key ])
74-             return  new_dict 
75-         else :
76-             value  =  next (value_iter )
77-             # Handle different cases for value conversion 
78-             if  isinstance (obj , np .generic ):
79-                 # obj is a numpy scalar (0-dimensional) 
80-                 if  isinstance (value , (int , float )):
81-                     # Convert Python scalar to numpy scalar 
82-                     return  np .array (value , dtype = obj .dtype )
83-                 elif  isinstance (value , np .ndarray ):
84-                     # value is a numpy array, convert to scalar if needed 
85-                     if  value .ndim  ==  0 :
86-                         return  np .array (value .item (), dtype = obj .dtype )
87-                     elif  value .ndim  ==  1  and  value .size  ==  1 :
88-                         return  np .array (value .item (), dtype = obj .dtype )
89-                     else :
90-                         return  value .astype (obj .dtype ).reshape (obj .shape )
45+     def  _reconstruct_value (obj ):
46+         value  =  next (value_iter )
47+         # Handle different cases for value conversion 
48+         if  isinstance (obj , np .generic ):
49+             # obj is a numpy scalar (0-dimensional) 
50+             if  isinstance (value , (int , float )):
51+                 # Convert Python scalar to numpy scalar 
52+                 return  np .array (value , dtype = obj .dtype )
53+             elif  isinstance (value , np .ndarray ):
54+                 # value is a numpy array, convert to scalar if needed 
55+                 if  value .ndim  ==  0 :
56+                     return  np .array (value .item (), dtype = obj .dtype )
57+                 elif  value .ndim  ==  1  and  value .size  ==  1 :
58+                     return  np .array (value .item (), dtype = obj .dtype )
9159                else :
92-                     return  np .array (value , dtype = obj .dtype )
93-             elif  isinstance (obj , np .ndarray ):
94-                 # obj is a numpy array 
95-                 if  isinstance (value , np .ndarray ):
9660                    return  value .astype (obj .dtype ).reshape (obj .shape )
97-                 else :
98-                     return  np .array (value , dtype = obj .dtype ).reshape (obj .shape )
9961            else :
100-                 return  value 
62+                 return  np .array (value , dtype = obj .dtype )
63+         elif  isinstance (obj , np .ndarray ):
64+             # obj is a numpy array 
65+             if  isinstance (value , np .ndarray ):
66+                 return  value .astype (obj .dtype ).reshape (obj .shape )
67+             else :
68+                 return  np .array (value , dtype = obj .dtype ).reshape (obj .shape )
69+         else :
70+             return  value 
10171
102-     return  _reconstruct ( structure )
72+     return  tree . map_structure ( _reconstruct_value ,  structure )
10373
10474
10575def  _restore_legacy_format (
@@ -327,7 +297,7 @@ def __init__(
327297            save_decision_policy = save_decision_policy ,
328298        )
329299        # Ensure directory exists (only needed on one process in multi-host) 
330-         if  backend . get_process_index () ==  0 :
300+         if  process_id () ==  0 :
331301            os .makedirs (directory , exist_ok = True )
332302
333303        # Create the CheckpointManager 
@@ -380,38 +350,27 @@ def _save_checkpoint(self, step, logs=None):
380350        state_tree  =  _get_state_tree (self .model )
381351
382352        if  state_tree  is  None :
383-             if  self .verbose  >  0 :
384-                 print_msg (
385-                     "OrbaxCheckpoint: Skipping save due to state tree error" 
386-                 )
387-             return 
388- 
389-         # Flatten the trainable variables values for cross-model compatibility 
390-         trainable_values  =  _flatten_state_tree_values (
391-             state_tree ["trainable_variables" ]
392-         )
393- 
394-         # Save optimizer and metrics state if requested 
395-         optimizer_values  =  None 
396-         if  self .save_optimizer_state  and  "optimizer_variables"  in  state_tree :
397-             optimizer_values  =  _flatten_state_tree_values (
398-                 state_tree ["optimizer_variables" ]
399-             )
400- 
401-         metrics_values  =  None 
402-         if  self .save_metrics_state  and  "metrics_variables"  in  state_tree :
403-             metrics_values  =  _flatten_state_tree_values (
404-                 state_tree ["metrics_variables" ]
353+             raise  RuntimeError (
354+                 "OrbaxCheckpoint: Failed to get model state tree. " 
355+                 "The model may not be properly built or may have no " 
356+                 "savable state." 
405357            )
406358
359+         # Save the nested state structures directly (preserving layer 
360+         # names and structure) 
407361        composite_state  =  {
408-             "model_weights " : trainable_values ,
362+             "trainable_variables " : state_tree [ "trainable_variables" ] ,
409363        }
410364
411-         if  optimizer_values  is  not   None :
412-             composite_state ["optimizer_state" ] =  optimizer_values 
413-         if  metrics_values  is  not   None :
414-             composite_state ["metrics_variables" ] =  metrics_values 
365+         if  self .save_optimizer_state  and  "optimizer_variables"  in  state_tree :
366+             composite_state ["optimizer_variables" ] =  state_tree [
367+                 "optimizer_variables" 
368+             ]
369+ 
370+         if  self .save_metrics_state  and  "metrics_variables"  in  state_tree :
371+             composite_state ["metrics_variables" ] =  state_tree [
372+                 "metrics_variables" 
373+             ]
415374
416375        # Add metadata if specified 
417376        if  self .save_metadata  is  not   None :
@@ -435,7 +394,7 @@ def _save_checkpoint(self, step, logs=None):
435394
436395        # --- Save Logic --- 
437396        # Only save on the primary process (rank 0) in distributed setups 
438-         is_primary_host  =  backend . get_process_index () ==  0 
397+         is_primary_host  =  process_id () ==  0 
439398
440399        if  is_primary_host :
441400            if  self .verbose  >  0 :
@@ -540,7 +499,7 @@ def load_checkpoint(self, step, model=None):
540499            data iterator state dict if available, None otherwise. 
541500        """ 
542501        # In distributed training, only load on primary process 
543-         if  backend . get_process_index () !=  0 :
502+         if  process_id () !=  0 :
544503            return  True   # Return True to indicate no error, but no loading 
545504
546505        if  self .verbose  >  0 :
@@ -594,11 +553,18 @@ def _restore_model_state(self, checkpoint_data, model=None):
594553        """ 
595554        target_model  =  model  if  model  is  not   None  else  self .model 
596555
597-         # Check if this is the new flattened format 
598-         if  "model_weights"  in  checkpoint_data  and  isinstance (
556+         # Check if this is the new nested structure format 
557+         if  "trainable_variables"  in  checkpoint_data  and  isinstance (
558+             checkpoint_data ["trainable_variables" ], dict 
559+         ):
560+             # New format: nested structures 
561+             return  self ._restore_from_nested_structures (
562+                 checkpoint_data , target_model 
563+             )
564+         elif  "model_weights"  in  checkpoint_data  and  isinstance (
599565            checkpoint_data ["model_weights" ], list 
600566        ):
601-             # New  format: flattened values 
567+             # Old  format: flattened values (for backward compatibility)  
602568            return  self ._restore_from_flattened_values (
603569                checkpoint_data , target_model 
604570            )
@@ -617,8 +583,109 @@ def _restore_model_state(self, checkpoint_data, model=None):
617583            )
618584            return  True 
619585
586+     def  _restore_from_nested_structures (self , checkpoint_data , target_model ):
587+         """Restore from the new nested structures format.""" 
588+         # Ensure the target model is built so it has variables 
589+         if  len (target_model .trainable_variables ) ==  0 :
590+             try :
591+                 # Try to build the model by doing a dummy forward pass 
592+                 if  (
593+                     hasattr (target_model , "input_shape" )
594+                     and  target_model .input_shape  is  not   None 
595+                 ):
596+                     dummy_input_shape  =  target_model .input_shape 
597+                     if  dummy_input_shape [0 ] is  None :  # Batch dimension is None 
598+                         dummy_input  =  np .zeros ((1 ,) +  dummy_input_shape [1 :])
599+                     else :
600+                         dummy_input  =  np .zeros (dummy_input_shape )
601+                     target_model (dummy_input )
602+             except  Exception :
603+                 # If dummy forward pass fails, try build 
604+                 try :
605+                     if  (
606+                         hasattr (target_model , "input_shape" )
607+                         and  target_model .input_shape  is  not   None 
608+                     ):
609+                         build_shape  =  target_model .input_shape 
610+                         if  (
611+                             isinstance (build_shape , (list , tuple ))
612+                             and  len (build_shape ) >  1 
613+                             and  build_shape [0 ] is  None 
614+                         ):
615+                             build_shape  =  build_shape [1 :]
616+                         target_model .build (build_shape )
617+                 except  Exception :
618+                     # If building fails, continue anyway 
619+                     pass 
620+ 
621+         # Prepare the state tree to restore 
622+         reconstructed_state  =  {}
623+ 
624+         # Restore trainable variables 
625+         if  "trainable_variables"  in  checkpoint_data :
626+             reconstructed_state ["trainable_variables" ] =  checkpoint_data [
627+                 "trainable_variables" 
628+             ]
629+ 
630+         # Restore optimizer variables if available and model has optimizer 
631+         if  (
632+             "optimizer_variables"  in  checkpoint_data 
633+             and  self .save_optimizer_state 
634+             and  hasattr (target_model , "optimizer" )
635+             and  target_model .optimizer  is  not   None 
636+         ):
637+             reconstructed_state ["optimizer_variables" ] =  checkpoint_data [
638+                 "optimizer_variables" 
639+             ]
640+ 
641+         # Restore metrics variables if available 
642+         if  "metrics_variables"  in  checkpoint_data  and  self .save_metrics_state :
643+             reconstructed_state ["metrics_variables" ] =  checkpoint_data [
644+                 "metrics_variables" 
645+             ]
646+ 
647+         # Use set_state_tree to restore the state 
648+         target_model .set_state_tree (reconstructed_state )
649+ 
650+         if  self .verbose  >  0 :
651+             print_msg ("OrbaxCheckpoint: Successfully restored model state" )
652+         return  True 
653+ 
620654    def  _restore_from_flattened_values (self , checkpoint_data , target_model ):
621655        """Restore from the new flattened values format.""" 
656+         # Ensure the target model is built so it has variables 
657+         if  len (target_model .trainable_variables ) ==  0 :
658+             try :
659+                 # Try to build the model by doing a dummy forward pass 
660+                 if  (
661+                     hasattr (target_model , "input_shape" )
662+                     and  target_model .input_shape  is  not   None 
663+                 ):
664+                     dummy_input_shape  =  target_model .input_shape 
665+                     if  dummy_input_shape [0 ] is  None :  # Batch dimension is None 
666+                         dummy_input  =  np .zeros ((1 ,) +  dummy_input_shape [1 :])
667+                     else :
668+                         dummy_input  =  np .zeros (dummy_input_shape )
669+                     target_model (dummy_input )
670+             except  Exception :
671+                 # If dummy forward pass fails, try build 
672+                 try :
673+                     if  (
674+                         hasattr (target_model , "input_shape" )
675+                         and  target_model .input_shape  is  not   None 
676+                     ):
677+                         build_shape  =  target_model .input_shape 
678+                         if  (
679+                             isinstance (build_shape , (list , tuple ))
680+                             and  len (build_shape ) >  1 
681+                             and  build_shape [0 ] is  None 
682+                         ):
683+                             build_shape  =  build_shape [1 :]
684+                         target_model .build (build_shape )
685+                 except  Exception :
686+                     # If building fails, continue anyway 
687+                     pass 
688+ 
622689        # Get the target model's state tree structure (without convert_scalars) 
623690        target_state_tree  =  target_model .get_state_tree (
624691            value_format = "numpy_array" 
0 commit comments