@@ -62,7 +62,8 @@ def load(self):
6262 logging .info ("Loading TensorFlow model '{}'" .format (self .model ))
6363 outputs = self .kwargs .get ("outputs" , None )
6464 output_names = get_output_names (outputs )
65- self ._graph_def = self ._graph_def_from_model (output_names )
65+ tags = self .kwargs .get ("tags" , None )
66+ self ._graph_def = self ._graph_def_from_model (output_names , tags )
6667
6768 if self ._graph_def is not None and len (self ._graph_def .node ) == 0 :
6869 msg = "tf.Graph should have at least 1 node, Got empty graph."
@@ -88,7 +89,7 @@ def load(self):
8889 return program
8990
9091 # @abstractmethod
91- def _graph_def_from_model (self , output_names = None ):
92+ def _graph_def_from_model (self , output_names = None , tags = None ):
9293 """Load TensorFlow model into GraphDef. Overwrite for different TF versions."""
9394 pass
9495
@@ -139,7 +140,7 @@ def __init__(self, model, debug=False, **kwargs):
139140 """
140141 TFLoader .__init__ (self , model , debug , ** kwargs )
141142
142- def _graph_def_from_model (self , output_names = None ):
143+ def _graph_def_from_model (self , output_names = None , tags = None ):
143144 """Overwrites TFLoader._graph_def_from_model()"""
144145 msg = "Expected model format: [tf.Graph | .pb | SavedModel | tf.keras.Model | .h5], got {}"
145146 if isinstance (self .model , tf .Graph ) and hasattr (self .model , "as_graph_def" ):
@@ -170,7 +171,7 @@ def _graph_def_from_model(self, output_names=None):
170171 graph_def = self ._from_tf_keras_model (self .model )
171172 return self .extract_sub_graph (graph_def , output_names )
172173 elif os .path .isdir (str (self .model )):
173- graph_def = self ._from_saved_model (self .model )
174+ graph_def = self ._from_saved_model (self .model , tags = tags )
174175 return self .extract_sub_graph (graph_def , output_names )
175176 else :
176177 raise NotImplementedError (msg .format (self .model ))
0 commit comments