|
32 | 32 | from tensor2tensor.utils import mlperf_log |
33 | 33 |
|
34 | 34 | import tensorflow as tf |
35 | | -import tf_slim as slim |
36 | | -from tensorflow.contrib.tpu.python.tpu import tpu_config |
| 35 | +# pylint: disable=g-import-not-at-top |
| 36 | +try: |
| 37 | + from tensorflow.contrib.tpu.python.tpu import tpu_config |
| 38 | +except ImportError: |
| 39 | + # TF 2.0 doesn't ship with contrib. |
| 40 | + tpu_config = None |
| 41 | +# pylint: enable=g-import-not-at-top |
37 | 42 |
|
38 | 43 |
|
39 | 44 |
|
@@ -199,7 +204,7 @@ class Problem(object): |
199 | 204 | - Mutate defaults as needed |
200 | 205 | * example_reading_spec |
201 | 206 | - Specify the names and types of the features on disk. |
202 | | - - Specify slim.tfexample_decoder |
| 207 | + - Specify tf.contrib.slim.tfexample_decoder |
203 | 208 | * preprocess_example(example, mode, hparams) |
204 | 209 | - Preprocess the example feature dict from feature name to Tensor or |
205 | 210 | SparseTensor. |
@@ -643,7 +648,7 @@ def dataset(self, |
643 | 648 |
|
644 | 649 | data_filepattern = self.filepattern(data_dir, dataset_split, shard=shard) |
645 | 650 | tf.logging.info("Reading data files from %s", data_filepattern) |
646 | | - data_files = sorted(slim.parallel_reader.get_data_files( |
| 651 | + data_files = sorted(tf.contrib.slim.parallel_reader.get_data_files( |
647 | 652 | data_filepattern)) |
648 | 653 |
|
649 | 654 | # Functions used in dataset transforms below. `filenames` can be either a |
@@ -706,12 +711,12 @@ def decode_example(self, serialized_example): |
706 | 711 | data_fields["batch_prediction_key"] = tf.FixedLenFeature([1], tf.int64, 0) |
707 | 712 | if data_items_to_decoders is None: |
708 | 713 | data_items_to_decoders = { |
709 | | - field: slim.tfexample_decoder.Tensor(field) |
| 714 | + field: tf.contrib.slim.tfexample_decoder.Tensor(field) |
710 | 715 | for field in data_fields |
711 | 716 | } |
712 | 717 |
|
713 | | - decoder = slim.tfexample_decoder.TFExampleDecoder(data_fields, |
714 | | - data_items_to_decoders) |
| 718 | + decoder = tf.contrib.slim.tfexample_decoder.TFExampleDecoder( |
| 719 | + data_fields, data_items_to_decoders) |
715 | 720 |
|
716 | 721 | decode_items = list(sorted(data_items_to_decoders)) |
717 | 722 | decoded = decoder.decode(serialized_example, items=decode_items) |
|
0 commit comments