@@ -867,18 +867,18 @@ def dict_pack(example):
 
   def _standardize(self, dataset, keys):
     """Force dataset structure into a tuple of Tensors."""
-    shapes = tf.compat.v1.data.get_output_shapes(dataset)
+    shapes = tf.data.get_output_shapes(dataset)
 
     if isinstance(shapes, dict):
       keys = keys or tuple(shapes.keys())
       dataset = dataset.map(lambda x: tuple(x[k] for k in keys))
-      shapes = tf.compat.v1.data.get_output_shapes(dataset)
+      shapes = tf.data.get_output_shapes(dataset)
 
     if not all(isinstance(i, tf.TensorShape) for i in shapes):
       # Internally this class expects tuples of Tensors, even for the degenerate
       # case of a single sequence.
       dataset = dataset.map(lambda x: (x,))
-      shapes = tf.compat.v1.data.get_output_shapes(dataset)
+      shapes = tf.data.get_output_shapes(dataset)
 
     for s in shapes:
       if not s.is_compatible_with(tf.TensorShape([None])):
@@ -890,7 +890,7 @@ def _standardize(self, dataset, keys):
     if self._chop_long_sequences and len(shapes) != 1:
       raise ValueError("chop_long_sequences expects a single sequence dataset.")
 
-    token_types = tf.compat.v1.data.get_output_types(dataset)
+    token_types = tf.data.get_output_types(dataset)
     if len(set(token_types)) > 1:
       raise ValueError("Inconsistent dtypes: {}".format(token_types))
 
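For context, the `_standardize` method touched by both hunks coerces whatever element structure the input `tf.data.Dataset` has (a dict, a bare tensor, or a tuple) into a tuple of rank-1 Tensors that share a single dtype. Below is a minimal sketch of that pattern, not the repository's code: it assumes TF 2.x and uses `Dataset.element_spec` in place of the `get_output_shapes`/`get_output_types` helpers the diff switches between, and the `standardize` name plus the toy dataset are made up for illustration.

# Minimal sketch (not the repository's code) of the standardization pattern in
# these hunks, written against TF 2.x: Dataset.element_spec stands in for the
# get_output_shapes / get_output_types helpers shown in the diff.
import tensorflow as tf

def standardize(dataset, keys=None):
  """Coerce dataset elements into a tuple of rank-1 Tensors with one dtype."""
  spec = dataset.element_spec

  # Dict elements: select the requested keys and re-emit them as a tuple.
  if isinstance(spec, dict):
    keys = keys or tuple(spec.keys())
    dataset = dataset.map(lambda x: tuple(x[k] for k in keys))
    spec = dataset.element_spec

  # Degenerate case: a single bare sequence becomes a one-element tuple.
  if isinstance(spec, tf.TensorSpec):
    dataset = dataset.map(lambda x: (x,))
    spec = dataset.element_spec

  # Every component must be a rank-1 (variable-length) sequence ...
  for s in spec:
    if not s.shape.is_compatible_with(tf.TensorShape([None])):
      raise ValueError("Expected rank-1 sequences, got shape %s" % s.shape)

  # ... and all components must share a dtype.
  if len({s.dtype for s in spec}) > 1:
    raise ValueError("Inconsistent dtypes: %s" % [s.dtype for s in spec])

  return dataset, keys

# Toy usage: a dataset of {"inputs": ..., "targets": ...} dicts.
ds = tf.data.Dataset.from_tensor_slices(
    {"inputs": [[1, 2, 3], [4, 5, 6]], "targets": [[7, 8, 9], [0, 1, 2]]})
ds, keys = standardize(ds)
print(keys, ds.element_spec)  # ('inputs', 'targets') and a tuple of TensorSpecs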