Skip to content

Commit a9cda47

Browse files
committed
correcting the int case
1 parent ce56ca6 commit a9cda47

File tree

1 file changed

+11
-3
lines changed
  • py/torch_tensorrt/dynamo/conversion/impl

1 file changed

+11
-3
lines changed

py/torch_tensorrt/dynamo/conversion/impl/cat.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,9 +46,17 @@ def unify_and_concat_trt_tensors(
4646
elif isinstance(x, int) and not has_dynamic and not force_trt_output:
4747
t = x # pure static path
4848
else:
49-
t = ctx.net.add_constant((x.numel(),), np.array([x], dtype=np.int32))
50-
set_layer_name(t, target, f"{name}_dim{i}_const")
51-
t = t.get_output(0)
49+
if isinstance(x, int):
50+
# wrap int into 1-element np array
51+
const_arr = np.array([x], dtype=np.int32)
52+
shape = (1,)
53+
else:
54+
const_arr = np.array(x, dtype=np.int32)
55+
shape = (x.numel(),)
56+
57+
layer = ctx.net.add_constant(shape, const_arr)
58+
set_layer_name(layer, target, f"{name}_dim{i}_const")
59+
t = layer.get_output(0)
5260

5361
# optional cast
5462
if cast_dtype and isinstance(t, TRTTensor):

0 commit comments

Comments
 (0)