1212# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1313# See the License for the specific language governing permissions and
1414# limitations under the License.
15-
1615r"""Decode from trained T2T models.
1716
1817This binary performs inference using the Estimator API.
@@ -82,9 +81,13 @@ def create_decode_hparams():
8281
8382def decode (estimator , hparams , decode_hp ):
8483 if FLAGS .decode_interactive :
84+ if estimator .config .use_tpu :
85+ raise ValueError ("TPU can only decode from dataset." )
8586 decoding .decode_interactively (estimator , hparams , decode_hp ,
8687 checkpoint_path = FLAGS .checkpoint_path )
8788 elif FLAGS .decode_from_file :
89+ if estimator .config .use_tpu :
90+ raise ValueError ("TPU can only decode from dataset." )
8891 decoding .decode_from_file (estimator , FLAGS .decode_from_file , hparams ,
8992 decode_hp , FLAGS .decode_to_file ,
9093 checkpoint_path = FLAGS .checkpoint_path )
@@ -160,7 +163,6 @@ def main(_):
160163 tf .logging .set_verbosity (tf .logging .INFO )
161164 trainer_lib .set_random_seed (FLAGS .random_seed )
162165 usr_dir .import_usr_dir (FLAGS .t2t_usr_dir )
163- FLAGS .use_tpu = False # decoding not supported on TPU
164166
165167 if FLAGS .score_file :
166168 filename = os .path .expanduser (FLAGS .score_file )
@@ -183,7 +185,7 @@ def main(_):
183185 hp ,
184186 t2t_trainer .create_run_config (hp ),
185187 decode_hparams = decode_hp ,
186- use_tpu = False )
188+ use_tpu = FLAGS . use_tpu )
187189
188190 decode (estimator , hp , decode_hp )
189191
0 commit comments