diff --git a/diopi_test/python/configs/diopi_configs.py b/diopi_test/python/configs/diopi_configs.py index 0ad81385e..ef288434f 100755 --- a/diopi_test/python/configs/diopi_configs.py +++ b/diopi_test/python/configs/diopi_configs.py @@ -9076,6 +9076,37 @@ ], ), ), + + 'apply_rotary': dict( + name=['apply_rotary'], + interface=['CustomizedTest'], + dtype=[np.float64, np.float32, np.float16], + para=dict( + conj=[True, False, False, False, True, True, False, True, False, True], + interleaved=[True, False, True, False, True, False, False, True, True, False] + ), + tensor_para=dict( + gen_fn='Genfunc.randn', + args=[ + { + "ins": ['input1'], + "shape": ((6,), (32,), (2, 64), (3, 8, 128), (1, 32), (3, 5, 6), (1, 125, 16, 256), (1, 125, 16, 256), (2, 64, 16, 16), (3, 100, 8, 32)), + }, + { + "ins": ['input2'], + "shape": ((6,), (32,), (2, 64), (3, 8, 128), (1, 32), (3, 5, 6), (1, 125, 16, 256), (1, 125, 16, 256), (2, 64, 16, 16), (3, 100, 8, 32)), + }, + { + "ins": ['cos'], + "shape": ((6,), (32,), (2, 64), (3, 1, 128), (1, 32), (3, 5, 6), (125, 1, 256), (125, 1, 256), (64, 1, 16), (100, 1, 32)), + }, + { + "ins": ['sin'], + "shape": ((6,), (32,), (2, 64), (3, 1, 128), (1, 32), (3, 5, 6), (125, 1, 256), (125, 1, 256), (64, 1, 16), (100, 1, 32)), + }, + ], + ), + ), 'rotary_emb_empty_tensor': dict( name=['rotary_emb'], diff --git a/diopi_test/python/conformance/customized_test.py b/diopi_test/python/conformance/customized_test.py index 3f351e27c..52a5332ad 100644 --- a/diopi_test/python/conformance/customized_test.py +++ b/diopi_test/python/conformance/customized_test.py @@ -389,6 +389,22 @@ def rotary_emb(input, cos, sin, conj, interleaved): out2 = out2.to(data_type) out = torch.cat((out1, out2), dim=-1) return out + + def apply_rotary(input1, input2, cos, sin, conj, interleaved): + data_type = input1.dtype + input1 = input1.to(torch.float32) + ipnut2 = input2.to(torch.float32) + cos = cos.to(torch.float32) + sin = sin.to(torch.float32) + if not conj: + out1 = input1 * cos - input2 * sin + out2 = input1 * sin + input2 * cos + else: + out1 = input1 * cos + input2 * sin + out2 = -input1 * sin + input2 * cos + out1 = out1.to(data_type) + out2 = out2.to(data_type) + return (out1, out2) def rms_norm(input, normalized_shape, weight, bias, eps): if normalized_shape is not None: diff --git a/diopi_test/python/conformance/diopi_functions.py b/diopi_test/python/conformance/diopi_functions.py index 3b35e6fc9..a73932ca0 100644 --- a/diopi_test/python/conformance/diopi_functions.py +++ b/diopi_test/python/conformance/diopi_functions.py @@ -7036,6 +7036,14 @@ def rotary_emb(input, cos, sin, conj, interleaved): check_returncode(ret) return out +def apply_rotary(input1, input2, cos, sin, conj, interleaved): + call = "diopiApplyRotary" + func = check_function(call) + out1 = Tensor(list(input1.size().data), input1.get_dtype()) + out2 = Tensor(list(input2.size().data), input2.get_dtype()) + ret = func(input1.context(), out1, out2, input1, input2, cos, sin, conj, interleaved) + check_returncode(ret) + return (out1, out2) def rms_norm(input, normalized_shape, weight, bias, eps): if bias is not None: diff --git a/impl/torch/functions/functions_ext.cpp b/impl/torch/functions/functions_ext.cpp index 3a4736fe9..c36a4f755 100644 --- a/impl/torch/functions/functions_ext.cpp +++ b/impl/torch/functions/functions_ext.cpp @@ -63,6 +63,25 @@ diopiError_t diopiRotaryEmbedding(diopiContextHandle_t ctx, diopiTensorHandle_t return diopiSuccess; } +diopiError_t diopiApplyRotary(diopiContextHandle_t ctx, diopiTensorHandle_t out1, diopiTensorHandle_t out2, diopiConstTensorHandle_t x1, + diopiConstTensorHandle_t x2, diopiConstTensorHandle_t cos, diopiConstTensorHandle_t sin, const bool conj, + const bool interleaved = false) { + if (interleaved) { + set_last_error_string("interleaved rotary embedding is not supported yet"); + return diopiNoImplement; + } + impl::aten::setCurStream(ctx); + auto atX1 = impl::aten::buildATen(x1); + auto atX2 = impl::aten::buildATen(x2); + auto atCos = impl::aten::buildATen(cos); + auto atSin = impl::aten::buildATen(sin); + auto atOut1 = impl::aten::buildATen(out1); + auto atOut2 = impl::aten::buildATen(out2); + ext::ops::apply_rotary_cuda(atX1, atX2, atCos, atSin, atOut1, atOut2, conj); + + return diopiSuccess; +} + diopiError_t diopiRMSNorm(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiTensorHandle_t invRMS, diopiConstTensorHandle_t input, diopiSize_t normalized_shape, diopiConstTensorHandle_t weight, diopiConstTensorHandle_t bias, double eps) { impl::aten::setCurStream(ctx); diff --git a/proto/include/diopi/functions_ext.h b/proto/include/diopi/functions_ext.h index 7dc436933..738e40858 100644 --- a/proto/include/diopi/functions_ext.h +++ b/proto/include/diopi/functions_ext.h @@ -31,6 +31,25 @@ extern "C" { DIOPI_API diopiError_t diopiRotaryEmbedding(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t x, diopiConstTensorHandle_t cos, diopiConstTensorHandle_t sin, const bool conj, const bool interleaved); +/** + * @brief Apply rotary embedding operation to an input tensor. + * @param[in] ctx The diopi context. + * @param[out] out1 The output tensor containing the rotary embeddings. type = [bfloat16, float16, float32, float64]. + * @param[out] out2 The output tensor containing the rotary embeddings. type = [bfloat16, float16, float32, float64]. + * @param[in] x1 The input tensor which rotary embedding will be applied. type = [bfloat16, float16, float32, float64]. + * @param[in] x2 The input tensor which rotary embedding will be applied. type = [bfloat16, float16, float32, float64]. + * @param[in] cos The cosine values. type = [bfloat16, float16, float32, float64]. + * @param[in] sin The sine values. type = [bfloat16, float16, float32, float64]. + * @param[in] conj bool: If `false`, compute rotary embeddings for forward. If `true`, computes the backward of rotary embeddings according to the conjugate of + * the rotary matrix. + * @param[in] interleaved bool: + * - When set to `false`, rotary embedding is applied by splitting 'x' in half and separately applying sine and cosine to each half. + * - When set to `true`, rotary embedding is applied by pairing every two elements in 'x' and applying sine and cosine to each pair. + */ +DIOPI_API diopiError_t diopiApplyRotary(diopiContextHandle_t ctx, diopiTensorHandle_t out1, diopiTensorHandle_t out2, diopiConstTensorHandle_t x1, + diopiConstTensorHandle_t x2, diopiConstTensorHandle_t cos, diopiConstTensorHandle_t sin, const bool conj, + const bool interleaved); + /** * @brief Apply Root Mean Square (RMS) Normalization to the input tensor. * @param[in] ctx The diopi context.