TensorFlow cannot compile shape-checked method with explicit input signature. #1

@jesnie

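Wrapping a `check_shapes`-decorated method in `tf.function` works without an `input_signature`, but fails when an explicit `input_signature` is passed. Minimal reproduction:
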
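```python
import tensorflow as tf
from gpflow.experimental.check_shapes import check_shapes


class A:

    def f(self, x):
        return x + 2

    @check_shapes(
        "x: [batch...]",
        "return: [batch...]",
    )
    def g(self, x):
        return x + 2


a = A()
# Unknown-rank int32 spec for the single argument `x`.
specs = [tf.TensorSpec(shape=None, dtype=tf.int32)]
f = tf.function(a.f)
f2 = tf.function(a.f, input_signature=specs)
g = tf.function(a.g)
g2 = tf.function(a.g, input_signature=specs)
x = tf.constant(7)

f(x)   # Works: plain method, no input signature.
f2(x)  # Works: plain method, explicit input signature.
g(x)   # Works: shape-checked method, no input signature.
g2(x)  # Fails: shape-checked method, explicit input signature.
```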
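One thing worth ruling out first (an assumption; this report does not say what the cause is) is whether the decorator hides the method's parameter list, since `tf.function` has to match each entry of `input_signature` to a positional parameter. The sketch below uses only the standard library; `wrap_with_wraps` and `wrap_without_wraps` are hypothetical stand-ins for a wrapper like `check_shapes`, not its actual implementation. It shows what `inspect.signature` reports for a plain method, a wrapper built with `functools.wraps`, and a wrapper without it:

```python
import functools
import inspect


def wrap_with_wraps(func):
    """Hypothetical shape-checking decorator that preserves metadata via functools.wraps."""
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # A real decorator would check argument and return shapes here.
        return func(*args, **kwargs)
    return wrapper


def wrap_without_wraps(func):
    """Same hypothetical wrapper, but without functools.wraps, so the signature is lost."""
    def wrapper(*args, **kwargs):
        return func(*args, **kwargs)
    return wrapper


class A:
    def f(self, x):
        return x + 2

    @wrap_with_wraps
    def g(self, x):
        return x + 2

    @wrap_without_wraps
    def h(self, x):
        return x + 2


def visible_signatures(obj):
    # inspect.signature follows __wrapped__ and drops `self` for bound methods.
    return {name: str(inspect.signature(getattr(obj, name))) for name in ("f", "g", "h")}


print(visible_signatures(A()))
```

Here `f` and `g` both report `(x)`, while `h` reports `(*args, **kwargs)`. If `inspect.signature(a.g)` on the real decorated method already reports `(x)`, the parameter list is probably not the problem and the failure likely lies elsewhere.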

Labels: bug (Something isn't working)
