diff --git a/seq2seq/contrib/seq2seq/helper.py b/seq2seq/contrib/seq2seq/helper.py index 977d0ab9..044311f2 100644 --- a/seq2seq/contrib/seq2seq/helper.py +++ b/seq2seq/contrib/seq2seq/helper.py @@ -384,10 +384,8 @@ def initialize(self, name=None): def sample(self, time, outputs, state, name=None): with ops.name_scope(name, "ScheduledOutputTrainingHelperSample", [time, outputs, state]): - sampler = bernoulli.Bernoulli(probs=self._sampling_probability) - return math_ops.cast( - sampler.sample(sample_shape=self.batch_size, seed=self._seed), - dtypes.bool) + sampler = bernoulli.Bernoulli(probs=self._sampling_probability, dtype=dtypes.int32) + return sampler.sample(sample_shape=self.batch_size, seed=self._seed) def next_inputs(self, time, outputs, state, sample_ids, name=None): with ops.name_scope(name, "ScheduledOutputTrainingHelperNextInputs", @@ -419,7 +417,7 @@ def maybe_concatenate_auxiliary_inputs(outputs_, indices=None): if self._next_input_layer is None: return array_ops.where( - sample_ids, maybe_concatenate_auxiliary_inputs(outputs), + math_ops.cast(sample_ids, dtypes.bool), maybe_concatenate_auxiliary_inputs(outputs).cell_output, base_next_inputs) where_sampling = math_ops.cast( diff --git a/seq2seq/test/hooks_test.py b/seq2seq/test/hooks_test.py index dedc6594..1537990f 100644 --- a/seq2seq/test/hooks_test.py +++ b/seq2seq/test/hooks_test.py @@ -39,16 +39,16 @@ class TestPrintModelAnalysisHook(tf.test.TestCase): def test_begin(self): model_dir = tempfile.mkdtemp() outfile = tempfile.NamedTemporaryFile() - tf.get_variable("weigths", [128, 128]) + tf.get_variable("weights", [128, 128]) hook = hooks.PrintModelAnalysisHook( params={}, model_dir=model_dir, run_config=tf.contrib.learn.RunConfig()) hook.begin() with gfile.GFile(os.path.join(model_dir, "model_analysis.txt")) as file: - file_contents = file.read().strip() + file_contents = tf.compat.as_text(file.read()).strip() - self.assertEqual(file_contents.decode(), "_TFProfRoot (--/16.38k params)\n" - " weigths (128x128, 16.38k/16.38k params)") + self.assertEqual(file_contents, "_TFProfRoot (--/16.38k params)\n" + " weights (128x128, 16.38k/16.38k params)") outfile.close() @@ -94,7 +94,7 @@ def test_sampling(self): outfile = os.path.join(self.sample_dir, "samples_000000.txt") with open(outfile, "rb") as readfile: self.assertIn("Prediction followed by Target @ Step 0", - readfile.read().decode("utf-8")) + tf.compat.as_text(readfile.read())) # Should not trigger for step 9 sess.run(tf.assign(global_step, 9)) @@ -108,7 +108,7 @@ def test_sampling(self): outfile = os.path.join(self.sample_dir, "samples_000010.txt") with open(outfile, "rb") as readfile: self.assertIn("Prediction followed by Target @ Step 10", - readfile.read().decode("utf-8")) + tf.compat.as_text(readfile.read())) class TestMetadataCaptureHook(tf.test.TestCase): @@ -125,7 +125,7 @@ def tearDown(self): def test_capture(self): global_step = tf.contrib.framework.get_or_create_global_step() # Some test computation - some_weights = tf.get_variable("weigths", [2, 128]) + some_weights = tf.get_variable("weights", [2, 128]) computation = tf.nn.softmax(some_weights) hook = hooks.MetadataCaptureHook(