Skip to content

Commit 4a114c9

Browse files
committed
passing tags for use when loading TF saved_model
Some saved_models have more than one tag, e.g., MobileBERT SQuAD 1.1 checkpoints. Need a way to specify tag
1 parent aeb3e8e commit 4a114c9

File tree

2 files changed

+9
-8
lines changed
  • coremltools/converters/mil/frontend

2 files changed

+9
-8
lines changed

coremltools/converters/mil/frontend/tensorflow/load.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,8 @@ def load(self):
6262
logging.info("Loading TensorFlow model '{}'".format(self.model))
6363
outputs = self.kwargs.get("outputs", None)
6464
output_names = get_output_names(outputs)
65-
self._graph_def = self._graph_def_from_model(output_names)
65+
tags = self.kwargs.get("tags", None)
66+
self._graph_def = self._graph_def_from_model(output_names, tags)
6667

6768
if self._graph_def is not None and len(self._graph_def.node) == 0:
6869
msg = "tf.Graph should have at least 1 node, Got empty graph."
@@ -88,7 +89,7 @@ def load(self):
8889
return program
8990

9091
# @abstractmethod
91-
def _graph_def_from_model(self, output_names=None):
92+
def _graph_def_from_model(self, output_names=None, tags=None):
9293
"""Load TensorFlow model into GraphDef. Overwrite for different TF versions."""
9394
pass
9495

@@ -139,7 +140,7 @@ def __init__(self, model, debug=False, **kwargs):
139140
"""
140141
TFLoader.__init__(self, model, debug, **kwargs)
141142

142-
def _graph_def_from_model(self, output_names=None):
143+
def _graph_def_from_model(self, output_names=None, tags=None):
143144
"""Overwrites TFLoader._graph_def_from_model()"""
144145
msg = "Expected model format: [tf.Graph | .pb | SavedModel | tf.keras.Model | .h5], got {}"
145146
if isinstance(self.model, tf.Graph) and hasattr(self.model, "as_graph_def"):
@@ -170,7 +171,7 @@ def _graph_def_from_model(self, output_names=None):
170171
graph_def = self._from_tf_keras_model(self.model)
171172
return self.extract_sub_graph(graph_def, output_names)
172173
elif os.path.isdir(str(self.model)):
173-
graph_def = self._from_saved_model(self.model)
174+
graph_def = self._from_saved_model(self.model, tags=tags)
174175
return self.extract_sub_graph(graph_def, output_names)
175176
else:
176177
raise NotImplementedError(msg.format(self.model))

coremltools/converters/mil/frontend/tensorflow2/load.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ def __init__(self, model, debug=False, **kwargs):
9696
fuse_dilation_conv,
9797
]
9898

99-
def _get_concrete_functions_and_graph_def(self):
99+
def _get_concrete_functions_and_graph_def(self, tags=None):
100100
msg = (
101101
"Expected model format: [SavedModel | [concrete_function] | "
102102
"tf.keras.Model | .h5], got {}"
@@ -120,7 +120,7 @@ def _get_concrete_functions_and_graph_def(self):
120120
and (self.model.endswith(".h5") or self.model.endswith(".hdf5")):
121121
cfs = self._concrete_fn_from_tf_keras_or_h5(self.model)
122122
elif _os_path.isdir(self.model):
123-
saved_model = _tf.saved_model.load(self.model)
123+
saved_model = _tf.saved_model.load(self.model, tags=tags)
124124
sv = saved_model.signatures.values()
125125
cfs = sv if isinstance(sv, list) else list(sv)
126126
else:
@@ -132,9 +132,9 @@ def _get_concrete_functions_and_graph_def(self):
132132

133133
return cfs, graph_def
134134

135-
def _graph_def_from_model(self, output_names=None):
135+
def _graph_def_from_model(self, output_names=None, tags=None):
136136
"""Overwrites TFLoader._graph_def_from_model()"""
137-
_, graph_def = self._get_concrete_functions_and_graph_def()
137+
_, graph_def = self._get_concrete_functions_and_graph_def(tags=tags)
138138
return self.extract_sub_graph(graph_def, output_names)
139139

140140
def _tf_ssa_from_graph_def(self, fn_name="main"):

0 commit comments

Comments
 (0)