@@ -453,19 +453,18 @@ public Output<?>[] whileLoop(
453453 synchronized SaverDef saverDef () {
454454 if (saverDef == null ) {
455455 // Check to see if this graph has a restore operation
456- if (operation ("save/restore_all" ) == null ) {
456+ if (operation (SAVER_DEF_SCOPE + "/" + SAVER_DEF_RESTORE_OP ) == null ) {
457457 // No saver, create one by mutating the graph
458458 saverDef = addVariableSaver (this );
459459 } else {
460460 // This graph already has saving/restoring operations,
461- // regenerate SaverDef without mutating. The names mirror
462- // the python implementation for compatibility.
463- // https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/training/saver.py
464- saverDef = SaverDef .newBuilder ()
465- .setFilenameTensorName ("save/filename" )
466- .setSaveTensorName ("save/control_dependency" )
467- .setRestoreOpName ("save/restore_all" )
468- .build ();
461+ // regenerate SaverDef without mutating.
462+ saverDef =
463+ SaverDef .newBuilder ()
464+ .setFilenameTensorName (SAVER_DEF_SCOPE + "/" + SAVER_DEF_FILENAME_OP + ":0" )
465+ .setSaveTensorName (SAVER_DEF_SCOPE + "/" + SAVER_DEF_SAVE_OP )
466+ .setRestoreOpName (SAVER_DEF_SCOPE + "/" + SAVER_DEF_RESTORE_OP )
467+ .build ();
469468 }
470469 }
471470 return saverDef ;
@@ -570,6 +569,13 @@ public void remove() {
570569 private int position ;
571570 }
572571
572+ // These names mirror the python implementation, to reduce the risk of incompatibility.
573+ // https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/training/saver.py
574+ private static final String SAVER_DEF_SCOPE = "save" ;
575+ private static final String SAVER_DEF_FILENAME_OP = "filename" ;
576+ private static final String SAVER_DEF_SAVE_OP = "control_dependency" ;
577+ private static final String SAVER_DEF_RESTORE_OP = "restore_all" ;
578+
573579 private static TF_Graph allocate () {
574580 return TF_NewGraph ();
575581 }
@@ -797,7 +803,7 @@ private static Object[] whileLoop(
797803 }
798804
799805 private static SaverDef addVariableSaver (Graph graph ) {
800- Ops tf = Ops .create (graph ).withSubScope ("save" );
806+ Ops tf = Ops .create (graph ).withSubScope (SAVER_DEF_SCOPE );
801807
802808 List <String > varNames = new ArrayList <>();
803809 List <Operand <?>> varOutputs = new ArrayList <>();
@@ -812,36 +818,35 @@ private static SaverDef addVariableSaver(Graph graph) {
812818 }
813819 }
814820
815- // FIXME Need an easier way to initialize an NdArray from a list
816- String [] tmp = new String [varNames .size ()];
817- Constant <TString > varNamesTensor = tf .constant (StdArrays .ndCopyOf (varNames .toArray (tmp )));
818- Operand <TString > varSlices = tf .zerosLike (varNamesTensor );
819-
820- Placeholder <TString > saveFilename = tf .withName ("filename" ).placeholder (TString .class );
821- Save saveVariables = tf .train .save (
822- saveFilename ,
823- varNamesTensor ,
824- varSlices ,
825- varOutputs
826- );
827- Identity <TString > id = tf .withControlDependencies (Arrays .asList (saveFilename ,saveVariables ))
828- .withName ("control_dependency" ).identity (saveFilename );
829- Restore restoreVariables = tf .train .restore (
830- saveFilename ,
831- varNamesTensor ,
832- varSlices ,
833- varTypes
834- );
835- List <Op > restoreOps = new ArrayList <>(varOutputs .size ());
836- for (int i = 0 ; i < varOutputs .size (); ++i ) {
837- restoreOps .add (tf .assign (varOutputs .get (i ), (Operand ) restoreVariables .tensors ().get (i )));
821+ Placeholder <TString > filename = tf .withName (SAVER_DEF_FILENAME_OP ).placeholder (TString .class );
822+ Identity <TString > save = null ;
823+ NoOp restore = null ;
824+
825+ if (varNames .isEmpty ()) {
826+ save = tf .withName (SAVER_DEF_SAVE_OP ).identity (filename );
827+ restore = tf .withName (SAVER_DEF_RESTORE_OP ).noOp ();
828+ } else {
829+ String [] tmp = new String [varNames .size ()];
830+ Constant <TString > varNamesTensor = tf .constant (StdArrays .ndCopyOf (varNames .toArray (tmp )));
831+ Operand <TString > varSlices = tf .zerosLike (varNamesTensor );
832+ Save saveVars = tf .train .save (filename , varNamesTensor , varSlices , varOutputs );
833+ List <Op > saveDeps = Arrays .asList (filename , saveVars );
834+ Restore restoreVars = tf .train .restore (filename , varNamesTensor , varSlices , varTypes );
835+ List <Op > restoreDeps = new ArrayList <>(varOutputs .size ());
836+ for (int i = 0 ; i < varOutputs .size (); ++i ) {
837+ restoreDeps .add (tf .assign (varOutputs .get (i ), (Operand ) restoreVars .tensors ().get (i )));
838+ }
839+ save = tf .withControlDependencies (saveDeps ).withName (SAVER_DEF_SAVE_OP ).identity (filename );
840+ restore = tf .withControlDependencies (restoreDeps ).withName (SAVER_DEF_RESTORE_OP ).noOp ();
838841 }
839- NoOp restoreAll = tf .withControlDependencies (restoreOps ).withName ("restore_all" ).noOp ();
840842
843+ // 'Filename' must be the name of a tensor (i.e. with output index)
844+ // 'Save' must be an operation name, even if the field name is confusing (see SaverDef doc)
845+ // 'Restore' must be an operation name
841846 return SaverDef .newBuilder ()
842- .setFilenameTensorName (saveFilename . op ().name ())
843- .setSaveTensorName (id .op ().name ())
844- .setRestoreOpName (restoreAll .op ().name ())
847+ .setFilenameTensorName (filename . output ().name ())
848+ .setSaveTensorName (save .op ().name ())
849+ .setRestoreOpName (restore .op ().name ())
845850 .build ();
846851 }
847852
0 commit comments