From f83e268289499be910f722840eac7c4b69d21260 Mon Sep 17 00:00:00 2001 From: HuayiL <442488254@qq.com> Date: Mon, 18 Nov 2024 15:57:35 +0800 Subject: [PATCH 1/4] add new interfaces and tests --- diopi_test/python/configs/diopi_configs.py | 28 +++++++++++++++++++ .../python/conformance/customized_test.py | 16 +++++++++++ .../python/conformance/diopi_functions.py | 8 ++++++ impl/torch/functions/functions_ext.cpp | 18 ++++++++++++ proto/include/diopi/functions_ext.h | 18 ++++++++++++ 5 files changed, 88 insertions(+) diff --git a/diopi_test/python/configs/diopi_configs.py b/diopi_test/python/configs/diopi_configs.py index 0ad81385e..9afc473cf 100755 --- a/diopi_test/python/configs/diopi_configs.py +++ b/diopi_test/python/configs/diopi_configs.py @@ -9076,6 +9076,34 @@ ], ), ), + 'apply_rotary': dict( + name=['apply_rotary'], + interface=['CustomizedTest'], + dtype=[np.float64, np.float32, np.float16], + para=dict( + conj=[True, False, False, False, True, True, False, True, False, True], + interleaved=[True, False, True, False, True, False, False, True, True, False] + ), + tensor_para=dict( + gen_fn='Genfunc.randn', args=[ + { + "ins": ['input1'], + "shape": ((6,), (32,), (2, 64), (3, 8, 128), (1, 32), (3, 5, 6), (1, 125, 16, 256), (1, 125, 16, 256), (2, 64, 16, 16), (3, 100, 8, 32)), + }, + { + "ins": ['input2'], + "shape": ((6,), (32,), (2, 64), (3, 8, 128), (1, 32), (3, 5, 6), (1, 125, 16, 256), (1, 125, 16, 256), (2, 64, 16, 16), (3, 100, 8, 32)), + { + "ins": ['cos'], + "shape": ((6,), (32,), (2, 64), (3, 1, 128), (1, 32), (3, 5, 6), (125, 1, 256), (125, 1, 256), (64, 1, 16), (100, 1, 32)), + }, + { + "ins": ['sin'], + "shape": ((6,), (32,), (2, 64), (3, 1, 128), (1, 32), (3, 5, 6), (125, 1, 256), (125, 1, 256), (64, 1, 16), (100, 1, 32)), + }, + ], + ), + ), 'rotary_emb_empty_tensor': dict( name=['rotary_emb'], diff --git a/diopi_test/python/conformance/customized_test.py b/diopi_test/python/conformance/customized_test.py index 3f351e27c..52a5332ad 100644 --- 
a/diopi_test/python/conformance/customized_test.py +++ b/diopi_test/python/conformance/customized_test.py @@ -389,6 +389,22 @@ def rotary_emb(input, cos, sin, conj, interleaved): out2 = out2.to(data_type) out = torch.cat((out1, out2), dim=-1) return out + + def apply_rotary(input1, input2, cos, sin, conj, interleaved): + data_type = input1.dtype + input1 = input1.to(torch.float32) + input2 = input2.to(torch.float32) + cos = cos.to(torch.float32) + sin = sin.to(torch.float32) + if not conj: + out1 = input1 * cos - input2 * sin + out2 = input1 * sin + input2 * cos + else: + out1 = input1 * cos + input2 * sin + out2 = -input1 * sin + input2 * cos + out1 = out1.to(data_type) + out2 = out2.to(data_type) + return (out1, out2) def rms_norm(input, normalized_shape, weight, bias, eps): if normalized_shape is not None: diff --git a/diopi_test/python/conformance/diopi_functions.py b/diopi_test/python/conformance/diopi_functions.py index 3b35e6fc9..14f6c750f 100644 --- a/diopi_test/python/conformance/diopi_functions.py +++ b/diopi_test/python/conformance/diopi_functions.py @@ -7036,6 +7036,14 @@ def rotary_emb(input, cos, sin, conj, interleaved): check_returncode(ret) return out +def apply_rotary(input1, input2, cos, sin, conj, interleaved): + call = "diopiApplyRotary" + func = check_function(call) + out1 = Tensor(list(input1.size().data), input1.get_dtype()) + out1 = Tensor(list(input2.size().data), input2.get_dtype()) + ret = func(input.context(), out1, out2, input1, input2, cos, sin, conj, interleaved) + check_returncode(ret) + return (out1, out2) def rms_norm(input, normalized_shape, weight, bias, eps): if bias is not None: diff --git a/impl/torch/functions/functions_ext.cpp b/impl/torch/functions/functions_ext.cpp index 3a4736fe9..20495d13f 100644 --- a/impl/torch/functions/functions_ext.cpp +++ b/impl/torch/functions/functions_ext.cpp @@ -63,6 +63,24 @@ diopiError_t diopiRotaryEmbedding(diopiContextHandle_t ctx, diopiTensorHandle_t return diopiSuccess; } +diopiError_t 
diopiApplyRotary(diopiContextHandle_t ctx, diopiTensorHandle_t out1, diopiTensorHandle_t out2, diopiConstTensorHandle_t x1, diopiConstTensorHandle_t x2, diopiConstTensorHandle_t cos, + diopiConstTensorHandle_t sin, const bool conj, const bool interleaved=false) { + if (interleaved) { + set_last_error_string("interleaved rotary embedding is not supported yet"); + return diopiNoImplement; + } + impl::aten::setCurStream(ctx); + auto atX1 = impl::aten::buildATen(x1); + auto atX2 = impl::aten::buildATen(x2); + auto atCos = impl::aten::buildATen(cos); + auto atSin = impl::aten::buildATen(sin); + auto atOut1 = impl::aten::buildATen(out1); + auto atOut2 = impl::aten::buildATen(out2); + ext::ops::apply_rotary_cuda(atX1, atX2, atCos, atSin, atOut1, atOut2, conj); + + return diopiSuccess; +} + diopiError_t diopiRMSNorm(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiTensorHandle_t invRMS, diopiConstTensorHandle_t input, diopiSize_t normalized_shape, diopiConstTensorHandle_t weight, diopiConstTensorHandle_t bias, double eps) { impl::aten::setCurStream(ctx); diff --git a/proto/include/diopi/functions_ext.h b/proto/include/diopi/functions_ext.h index 7dc436933..28f49212c 100644 --- a/proto/include/diopi/functions_ext.h +++ b/proto/include/diopi/functions_ext.h @@ -31,6 +31,24 @@ extern "C" { DIOPI_API diopiError_t diopiRotaryEmbedding(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t x, diopiConstTensorHandle_t cos, diopiConstTensorHandle_t sin, const bool conj, const bool interleaved); +/** + * @brief Apply rotary embedding operation to an input tensor. + * @param[in] ctx The diopi context. + * @param[out] out1 The output tensor containing the rotary embeddings. type = [bfloat16, float16, float32, float64]. + * @param[out] out2 The output tensor containing the rotary embeddings. type = [bfloat16, float16, float32, float64]. + * @param[in] x1 The input tensor which rotary embedding will be applied. type = [bfloat16, float16, float32, float64]. 
+ * @param[in] x2 The input tensor which rotary embedding will be applied. type = [bfloat16, float16, float32, float64]. + * @param[in] cos The cosine values. type = [bfloat16, float16, float32, float64]. + * @param[in] sin The sine values. type = [bfloat16, float16, float32, float64]. + * @param[in] conj bool: If `false`, compute rotary embeddings for forward. If `true`, computes the backward of rotary embeddings according to the conjugate of + * the rotary matrix. + * @param[in] interleaved bool: + * - When set to `false`, rotary embedding is applied by splitting 'x' in half and separately applying sine and cosine to each half. + * - When set to `true`, rotary embedding is applied by pairing every two elements in 'x' and applying sine and cosine to each pair. + */ +DIOPI_API diopiError_t diopiApplyRotary(diopiContextHandle_t ctx, diopiTensorHandle_t out1, diopiTensorHandle_t out2, diopiConstTensorHandle_t x1, diopiConstTensorHandle_t x2, diopiConstTensorHandle_t cos, + diopiConstTensorHandle_t sin, const bool conj, const bool interleaved); + /** * @brief Apply Root Mean Square (RMS) Normalization to the input tensor. * @param[in] ctx The diopi context. 
From 248da96b1e88e95c8e064222c41c52c11960cbe7 Mon Sep 17 00:00:00 2001 From: HuayiL <442488254@qq.com> Date: Mon, 18 Nov 2024 16:12:36 +0800 Subject: [PATCH 2/4] fix a little bug --- diopi_test/python/configs/diopi_configs.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/diopi_test/python/configs/diopi_configs.py b/diopi_test/python/configs/diopi_configs.py index 9afc473cf..ef288434f 100755 --- a/diopi_test/python/configs/diopi_configs.py +++ b/diopi_test/python/configs/diopi_configs.py @@ -9076,6 +9076,7 @@ ], ), ), + 'apply_rotary': dict( name=['apply_rotary'], interface=['CustomizedTest'], @@ -9085,7 +9086,8 @@ interleaved=[True, False, True, False, True, False, False, True, True, False] ), tensor_para=dict( - gen_fn='Genfunc.randn', args=[ + gen_fn='Genfunc.randn', + args=[ { "ins": ['input1'], "shape": ((6,), (32,), (2, 64), (3, 8, 128), (1, 32), (3, 5, 6), (1, 125, 16, 256), (1, 125, 16, 256), (2, 64, 16, 16), (3, 100, 8, 32)), @@ -9093,6 +9095,7 @@ { "ins": ['input2'], "shape": ((6,), (32,), (2, 64), (3, 8, 128), (1, 32), (3, 5, 6), (1, 125, 16, 256), (1, 125, 16, 256), (2, 64, 16, 16), (3, 100, 8, 32)), + }, { "ins": ['cos'], "shape": ((6,), (32,), (2, 64), (3, 1, 128), (1, 32), (3, 5, 6), (125, 1, 256), (125, 1, 256), (64, 1, 16), (100, 1, 32)), From 52de8ebcee44573ee752d3bbd7856707f2a4714f Mon Sep 17 00:00:00 2001 From: HuayiL <442488254@qq.com> Date: Mon, 18 Nov 2024 16:45:34 +0800 Subject: [PATCH 3/4] fix bug in tests --- diopi_test/python/conformance/diopi_functions.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/diopi_test/python/conformance/diopi_functions.py b/diopi_test/python/conformance/diopi_functions.py index 14f6c750f..a73932ca0 100644 --- a/diopi_test/python/conformance/diopi_functions.py +++ b/diopi_test/python/conformance/diopi_functions.py @@ -7040,8 +7040,8 @@ def apply_rotary(input1, input2, cos, sin, conj, interleaved): call = "diopiApplyRotary" func = check_function(call) out1 = 
Tensor(list(input1.size().data), input1.get_dtype()) - out1 = Tensor(list(input2.size().data), input2.get_dtype()) - ret = func(input.context(), out1, out2, input1, input2, cos, sin, conj, interleaved) + out2 = Tensor(list(input2.size().data), input2.get_dtype()) + ret = func(input1.context(), out1, out2, input1, input2, cos, sin, conj, interleaved) check_returncode(ret) return (out1, out2) From 164ff291eb2c9dcb3b95efba2b9735234fad6a02 Mon Sep 17 00:00:00 2001 From: HuayiL <442488254@qq.com> Date: Mon, 18 Nov 2024 18:05:32 +0800 Subject: [PATCH 4/4] clang-format the code --- impl/torch/functions/functions_ext.cpp | 5 +++-- proto/include/diopi/functions_ext.h | 5 +++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/impl/torch/functions/functions_ext.cpp b/impl/torch/functions/functions_ext.cpp index 20495d13f..c36a4f755 100644 --- a/impl/torch/functions/functions_ext.cpp +++ b/impl/torch/functions/functions_ext.cpp @@ -63,8 +63,9 @@ diopiError_t diopiRotaryEmbedding(diopiContextHandle_t ctx, diopiTensorHandle_t return diopiSuccess; } -diopiError_t diopiApplyRotary(diopiContextHandle_t ctx, diopiTensorHandle_t out1, diopiTensorHandle_t out2, diopiConstTensorHandle_t x1, diopiConstTensorHandle_t x2, diopiConstTensorHandle_t cos, - diopiConstTensorHandle_t sin, const bool conj, const bool interleaved=false) { +diopiError_t diopiApplyRotary(diopiContextHandle_t ctx, diopiTensorHandle_t out1, diopiTensorHandle_t out2, diopiConstTensorHandle_t x1, + diopiConstTensorHandle_t x2, diopiConstTensorHandle_t cos, diopiConstTensorHandle_t sin, const bool conj, + const bool interleaved = false) { if (interleaved) { set_last_error_string("interleaved rotary embedding is not supported yet"); return diopiNoImplement; diff --git a/proto/include/diopi/functions_ext.h b/proto/include/diopi/functions_ext.h index 28f49212c..738e40858 100644 --- a/proto/include/diopi/functions_ext.h +++ b/proto/include/diopi/functions_ext.h @@ -46,8 +46,9 @@ DIOPI_API diopiError_t 
diopiRotaryEmbedding(diopiContextHandle_t ctx, diopiTenso * - When set to `false`, rotary embedding is applied by splitting 'x' in half and separately applying sine and cosine to each half. * - When set to `true`, rotary embedding is applied by pairing every two elements in 'x' and applying sine and cosine to each pair. */ -DIOPI_API diopiError_t diopiApplyRotary(diopiContextHandle_t ctx, diopiTensorHandle_t out1, diopiTensorHandle_t out2, diopiConstTensorHandle_t x1, diopiConstTensorHandle_t x2, diopiConstTensorHandle_t cos, - diopiConstTensorHandle_t sin, const bool conj, const bool interleaved); +DIOPI_API diopiError_t diopiApplyRotary(diopiContextHandle_t ctx, diopiTensorHandle_t out1, diopiTensorHandle_t out2, diopiConstTensorHandle_t x1, + diopiConstTensorHandle_t x2, diopiConstTensorHandle_t cos, diopiConstTensorHandle_t sin, const bool conj, + const bool interleaved); /** * @brief Apply Root Mean Square (RMS) Normalization to the input tensor.