Allow access to integer parameters in shape constraints #42

@uri-granta

Feature request

Allow access to integer parameters in shape constraints.

Motivation

Defining the shapes of synthetic functions in trieste, where the size of an input dimension is given by an integer argument. For example, something like:

@check_shapes(
    "x: [batch..., $d]",
    "return: [batch..., 1]",
)
def levy(x: TensorType, d: int) -> TensorType:
    ...

Proposal

There are at least two ways to handle this.

One is to allow references to int parameters (similar to the proposal in #6) using a syntax like the one above.
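To make the first option concrete, here is a minimal sketch of how `$d` references could be resolved at call time. It is only an illustration under stated assumptions: `resolve_int_refs` and `with_int_refs` are hypothetical names, not part of check_shapes, and the toy decorator only records the resolved specs instead of checking any shapes.

```python
import inspect
import re
from functools import wraps

# Hypothetical pattern (not check_shapes syntax today): matches "$name".
_INT_REF = re.compile(r"\$([A-Za-z_]\w*)")


def resolve_int_refs(spec, arguments):
    """Replace each "$name" in spec with the value of the int argument name."""

    def substitute(match):
        name = match.group(1)
        value = arguments[name]
        # bool is a subclass of int, so reject it explicitly.
        if isinstance(value, bool) or not isinstance(value, int):
            raise TypeError(f"${name} must refer to an int, got {value!r}")
        return str(value)

    return _INT_REF.sub(substitute, spec)


def with_int_refs(*specs):
    """Toy decorator: resolves "$name" on every call and records the result."""

    def decorator(func):
        signature = inspect.signature(func)

        @wraps(func)
        def wrapper(*args, **kwargs):
            bound = signature.bind(*args, **kwargs)
            bound.apply_defaults()
            # A real implementation would hand these specs to the shape checker.
            wrapper.resolved_specs = [
                resolve_int_refs(spec, bound.arguments) for spec in specs
            ]
            return func(*args, **kwargs)

        return wrapper

    return decorator


@with_int_refs("x: [batch..., $d]", "return: [batch..., 1]")
def levy(x, d):
    return None  # placeholder body


levy(x=[[0.0, 0.0, 0.0]], d=3)
print(levy.resolved_specs)  # ['x: [batch..., 3]', 'return: [batch..., 1]']
```

The point of the sketch is that the spec stays in one place in the decorator, while the integer is only substituted once the call arguments are known.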

Another is to support value constraints for ints, maybe something like:

@check_shapes(
    "x: [batch..., dim]",
    "d: dim",
    "return: [batch..., 1]",
)
def levy(x: TensorType, d: int) -> TensorType:
    ...
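The second option amounts to treating an int argument as one more source of bindings for a dimension variable. A rough sketch of that idea in plain Python follows. `bind_dimensions` is a hypothetical helper, not a check_shapes function; the specs are assumed to be already parsed, and only trailing named dimensions are handled, so the `batch...` part is simply skipped.

```python
def bind_dimensions(specs, values):
    """Bind dimension names consistently across tensor shapes and int values.

    specs maps an argument name to either a list of dimension names (matched
    against the trailing dimensions of that argument's shape) or a single
    dimension name (for an int argument, as in "d: dim").
    values maps an argument name to a shape tuple or to an int.
    """
    bindings = {}

    def bind(dim, size, source):
        # The first occurrence defines the size; later ones must agree.
        previous = bindings.setdefault(dim, size)
        if previous != size:
            raise ValueError(f"{source}: expected {dim}={previous}, got {size}")

    for name, spec in specs.items():
        value = values[name]
        if isinstance(spec, str):
            bind(spec, value, name)  # value constraint on an int
        else:
            shape = tuple(value)
            if len(shape) < len(spec):
                raise ValueError(f"{name}: shape {shape} has too few dimensions")
            trailing = shape[len(shape) - len(spec):]
            for dim, size in zip(spec, trailing):
                bind(dim, size, name)
    return bindings


# Parsed form of "x: [batch..., dim]" and "d: dim" (assumed representation).
specs = {"x": ["dim"], "d": "dim"}

print(bind_dimensions(specs, {"x": (10, 3), "d": 3}))  # {'dim': 3}

try:
    bind_dimensions(specs, {"x": (10, 3), "d": 4})
except ValueError as error:
    print(error)  # d: expected dim=3, got 4
```

Compared with the `$d` syntax, this keeps a single namespace of dimension variables, so the same name can also appear in other tensor arguments or in the return spec.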

Workarounds

One workaround is to move the dynamic check inside the function, but this splits the spec and doesn't support docstring rewriting.

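@check_shapes("return: [batch..., 1]")
def levy(x: TensorType, d: int) -> TensorType:
    # Interpolate the value of d; a bare "d" in the spec would only name a free dimension variable.
    check_shape(x, f"[batch..., {d}]")
    ...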

Labels: enhancement