import tensorflow as tf
from gpflow.experimental.check_shapes import check_shapes
class A:
    """Pair of methods with identical bodies; only ``g`` is wrapped by ``check_shapes``.

    Because the bodies are the same, the decorator is the only difference
    between ``f`` and ``g``.
    """

    def f(self, x):
        """Return ``x`` shifted by two (undecorated reference version)."""
        shifted = x + 2
        return shifted

    @check_shapes("x: [batch...]", "return: [batch...]")
    def g(self, x):
        """Return ``x`` shifted by two.

        The ``check_shapes`` specs declare ``x`` and the result with the same
        variable-rank shape ``[batch...]``.
        """
        shifted = x + 2
        return shifted
# Reproduction: wrap a plain method (A.f) and an identical check_shapes-decorated
# method (A.g) in tf.function, each without and with an input_signature, then
# call all four on the same tensor.
a = A()
# One int32 spec with shape=None, so the signature leaves the argument's shape
# unconstrained.
specs = [tf.TensorSpec(shape=None, dtype=tf.int32)]
# Undecorated method: without / with a fixed input_signature.
f = tf.function(a.f)
f2 = tf.function(a.f, input_signature=specs)
# check_shapes-decorated method: same two variants.
g = tf.function(a.g)
g2 = tf.function(a.g, input_signature=specs)
# Scalar tensor; tf.constant infers int32 for a Python int, matching `specs`.
x = tf.constant(7)
f(x)  # Good: runs without error.
f2(x)  # Good: runs without error.
g(x)  # Good: decorator alone is fine without an input_signature.
# Breaks: only the decorated method combined with input_signature fails.
# NOTE(review): presumably the check_shapes wrapper does not expose the
# original (self, x) signature that tf.function binds `specs` against —
# confirm against the actual traceback.
g2(x)  # Breaks...