Skip to content

Commit 94022d4

Browse files
author
GitHub Actions
committed
Add int8/uint8/int16/uint16 support for Max and Min operators
Fixes #26382 This commit adds support for int8, uint8, int16, and uint16 data types to the Max and Min operators for opset 12 and later, bringing the implementation into compliance with the ONNX specification. Changes: - Updated type registration for Max and Min operators (opset 12+) - Updated type dispatchers in Min_8::Compute and Max_8::Compute - Added comprehensive unit tests for all new data types The existing Eigen-based implementation already handles all numeric types generically, so only type registration and dispatching needed updates. This change is fully backward compatible.
1 parent 949c0de commit 94022d4

File tree

2 files changed

+144
-4
lines changed

2 files changed

+144
-4
lines changed

onnxruntime/core/providers/cpu/math/element_wise_ops.cc

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,16 @@ namespace op_kernel_type_control {
2121
ORT_SPECIFY_OP_KERNEL_ARG_DEFAULT_TYPES(kCpuExecutionProvider, kOnnxDomain, Max, 8, Input, 0, float, double);
2222

2323
ORT_SPECIFY_OP_KERNEL_ARG_DEFAULT_TYPES(kCpuExecutionProvider, kOnnxDomain, Max, 12, Input, 0,
24-
float, double, MLFloat16, int32_t, uint32_t, int64_t, uint64_t);
24+
float, double, MLFloat16, int8_t, int16_t, int32_t, uint32_t,
25+
int64_t, uint8_t, uint16_t, uint64_t);
2526
ORT_SPECIFY_OP_KERNEL_ARG_REQUIRED_TYPES(kCpuExecutionProvider, kOnnxDomain, Max, 12, Input, 0,
2627
int32_t, int64_t);
2728

2829
// Min
2930
ORT_SPECIFY_OP_KERNEL_ARG_DEFAULT_TYPES(kCpuExecutionProvider, kOnnxDomain, Min, 8, Input, 0, float, double);
3031
ORT_SPECIFY_OP_KERNEL_ARG_DEFAULT_TYPES(kCpuExecutionProvider, kOnnxDomain, Min, 12, Input, 0,
31-
float, double, MLFloat16, int32_t, uint32_t, int64_t, uint64_t);
32+
float, double, MLFloat16, int8_t, int16_t, int32_t, uint32_t,
33+
int64_t, uint8_t, uint16_t, uint64_t);
3234
ORT_SPECIFY_OP_KERNEL_ARG_REQUIRED_TYPES(kCpuExecutionProvider, kOnnxDomain, Min, 12, Input, 0,
3335
int32_t, int64_t);
3436

@@ -922,7 +924,8 @@ Status Min_8::Compute(OpKernelContext* context) const {
922924
return MinMaxMLFloat16<true>(*this, context);
923925
break;
924926
default:
925-
utils::MLTypeCallDispatcher<float, double, int32_t, uint32_t, int64_t, uint64_t>
927+
utils::MLTypeCallDispatcher<float, double, int8_t, int16_t, int32_t, uint32_t,
928+
int64_t, uint8_t, uint16_t, uint64_t>
926929
t_disp(dt_type);
927930
return t_disp.InvokeRet<Status, ComputeImpl>(*this, context);
928931
}
@@ -988,7 +991,8 @@ Status Max_8::Compute(OpKernelContext* context) const {
988991
return MinMaxMLFloat16<false>(*this, context);
989992
break;
990993
default:
991-
utils::MLTypeCallDispatcher<float, double, int32_t, uint32_t, int64_t, uint64_t>
994+
utils::MLTypeCallDispatcher<float, double, int8_t, int16_t, int32_t, uint32_t,
995+
int64_t, uint8_t, uint16_t, uint64_t>
992996
t_disp(dt_type);
993997
return t_disp.InvokeRet<Status, ComputeImpl>(*this, context);
994998
}

onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc

Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3699,6 +3699,142 @@ TEST(MathOpTest, Equal_multidirectional_broadcastAB_bool) {
36993699
test.Run();
37003700
}
37013701

3702+
TEST(MathOpTest, Max_12_Int8) {
3703+
OpTester test("Max", 12);
3704+
test.AddInput<int8_t>("data_0", {1, 3},
3705+
{1, 2, 3});
3706+
test.AddInput<int8_t>("data_2", {3, 3},
3707+
{10, 20, 30,
3708+
40, 50, 60,
3709+
70, 80, 90});
3710+
test.AddInput<int8_t>("data_1", {3, 1},
3711+
{-1, -2, 127});
3712+
test.AddOutput<int8_t>("max", {3, 3},
3713+
{10, 20, 30,
3714+
40, 50, 60,
3715+
127, 127, 127});
3716+
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider});
3717+
}
3718+
3719+
TEST(MathOpTest, Max_12_UInt8) {
3720+
OpTester test("Max", 12);
3721+
test.AddInput<uint8_t>("data_0", {1, 3},
3722+
{1, 20, 30});
3723+
test.AddInput<uint8_t>("data_2", {3, 3},
3724+
{10, 20, 30,
3725+
40, 50, 60,
3726+
70, 80, 90});
3727+
test.AddInput<uint8_t>("data_1", {3, 1},
3728+
{100, 20, 30});
3729+
test.AddOutput<uint8_t>("max", {3, 3},
3730+
{100, 100, 100,
3731+
40, 50, 60,
3732+
70, 80, 90});
3733+
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});
3734+
}
3735+
3736+
TEST(MathOpTest, Min_12_Int8) {
3737+
OpTester test("Min", 12);
3738+
test.AddInput<int8_t>("data_0", {1, 3},
3739+
{1, 2, 3});
3740+
test.AddInput<int8_t>("data_2", {3, 3},
3741+
{10, 20, 30,
3742+
40, 50, 60,
3743+
-70, -80, -90});
3744+
test.AddInput<int8_t>("data_1", {3, 1},
3745+
{-1, 20, 127});
3746+
test.AddOutput<int8_t>("min", {3, 3},
3747+
{-1, -1, -1,
3748+
1, 2, 3,
3749+
-70, -80, -90});
3750+
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider});
3751+
}
3752+
3753+
TEST(MathOpTest, Min_12_UInt8) {
3754+
OpTester test("Min", 12);
3755+
test.AddInput<uint8_t>("data_0", {1, 3},
3756+
{1, 20, 30});
3757+
test.AddInput<uint8_t>("data_2", {3, 3},
3758+
{10, 20, 30,
3759+
40, 50, 60,
3760+
70, 80, 90});
3761+
test.AddInput<uint8_t>("data_1", {3, 1},
3762+
{1, 20, 30});
3763+
test.AddOutput<uint8_t>("min", {3, 3},
3764+
{1, 1, 1,
3765+
1, 20, 20,
3766+
1, 20, 30});
3767+
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});
3768+
}
3769+
3770+
TEST(MathOpTest, Max_12_Int16) {
3771+
OpTester test("Max", 12);
3772+
test.AddInput<int16_t>("data_0", {1, 3},
3773+
{1, 2, 3});
3774+
test.AddInput<int16_t>("data_2", {3, 3},
3775+
{10, 20, 30,
3776+
40, 50, 60,
3777+
70, 80, 90});
3778+
test.AddInput<int16_t>("data_1", {3, 1},
3779+
{-1, -2, 300});
3780+
test.AddOutput<int16_t>("max", {3, 3},
3781+
{10, 20, 30,
3782+
40, 50, 60,
3783+
300, 300, 300});
3784+
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider});
3785+
}
3786+
3787+
TEST(MathOpTest, Max_12_UInt16) {
3788+
OpTester test("Max", 12);
3789+
test.AddInput<uint16_t>("data_0", {1, 3},
3790+
{1, 20, 30});
3791+
test.AddInput<uint16_t>("data_2", {3, 3},
3792+
{10, 20, 30,
3793+
40, 50, 60,
3794+
70, 80, 90});
3795+
test.AddInput<uint16_t>("data_1", {3, 1},
3796+
{100, 20, 30});
3797+
test.AddOutput<uint16_t>("max", {3, 3},
3798+
{100, 100, 100,
3799+
40, 50, 60,
3800+
70, 80, 90});
3801+
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});
3802+
}
3803+
3804+
TEST(MathOpTest, Min_12_Int16) {
3805+
OpTester test("Min", 12);
3806+
test.AddInput<int16_t>("data_0", {1, 3},
3807+
{1, 2, 3});
3808+
test.AddInput<int16_t>("data_2", {3, 3},
3809+
{10, 20, 30,
3810+
40, 50, 60,
3811+
-70, -80, -90});
3812+
test.AddInput<int16_t>("data_1", {3, 1},
3813+
{-1, 20, 300});
3814+
test.AddOutput<int16_t>("min", {3, 3},
3815+
{-1, -1, -1,
3816+
1, 2, 3,
3817+
-70, -80, -90});
3818+
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider});
3819+
}
3820+
3821+
TEST(MathOpTest, Min_12_UInt16) {
3822+
OpTester test("Min", 12);
3823+
test.AddInput<uint16_t>("data_0", {1, 3},
3824+
{1, 20, 30});
3825+
test.AddInput<uint16_t>("data_2", {3, 3},
3826+
{10, 20, 30,
3827+
40, 50, 60,
3828+
70, 80, 90});
3829+
test.AddInput<uint16_t>("data_1", {3, 1},
3830+
{1, 20, 30});
3831+
test.AddOutput<uint16_t>("min", {3, 3},
3832+
{1, 1, 1,
3833+
1, 20, 20,
3834+
1, 20, 30});
3835+
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});
3836+
}
3837+
37023838
TEST(MathOpTest, Mean_6) {
37033839
OpTester test("Mean", 6);
37043840
std::vector<int64_t> dims{3, 3};

0 commit comments

Comments
 (0)