From 767687343cdd698ce53585f8e678a58e058588f9 Mon Sep 17 00:00:00 2001 From: Max Mustermann Date: Mon, 13 Jan 2025 23:42:27 +0100 Subject: [PATCH] PatternGen: experimental fp32/fp64 support --- core_descs/ExampleFP32.core_desc | 25 +++ core_descs/ExampleFP32.ll | 35 ++++ core_descs/ExampleFP32.td | 15 ++ core_descs/ExampleFP32InstrFormat.td | 22 ++ core_descs/ExampleFP64.core_desc | 25 +++ core_descs/ExampleFP64.ll | 35 ++++ core_descs/ExampleFP64.td | 15 ++ core_descs/ExampleFP64InstrFormat.td | 22 ++ core_descs/ExampleXCV.ll | 222 +++++++++++++++++++++ core_descs/ExampleXCV.td | 36 ++++ core_descs/ExampleXCVInstrFormat.td | 78 ++++++++ llvm/lib/CodeGen/GlobalISel/PatternGen.cpp | 96 ++++++--- llvm/tools/pattern-gen/Main.cpp | 7 +- llvm/tools/pattern-gen/PatternGen.hpp | 1 + llvm/tools/pattern-gen/lib/InstrInfo.hpp | 1 + llvm/tools/pattern-gen/lib/Parser.cpp | 92 ++++++++- llvm/tools/pattern-gen/lib/Parser.hpp | 2 +- 17 files changed, 694 insertions(+), 35 deletions(-) create mode 100644 core_descs/ExampleFP32.core_desc create mode 100644 core_descs/ExampleFP32.ll create mode 100644 core_descs/ExampleFP32.td create mode 100644 core_descs/ExampleFP32InstrFormat.td create mode 100644 core_descs/ExampleFP64.core_desc create mode 100644 core_descs/ExampleFP64.ll create mode 100644 core_descs/ExampleFP64.td create mode 100644 core_descs/ExampleFP64InstrFormat.td diff --git a/core_descs/ExampleFP32.core_desc b/core_descs/ExampleFP32.core_desc new file mode 100644 index 000000000000..62fa00e12035 --- /dev/null +++ b/core_descs/ExampleFP32.core_desc @@ -0,0 +1,25 @@ +// RUN: pattern-gen %s -O 3 --mattr=+m,+f --riscv-xlen 32 --riscv-flen 32 | FileCheck --check-prefixes=CHECK-RV32,CHECK-RV32-EXTEND -allow-unused-prefixes %s +// RUN: pattern-gen %s -O 3 --no-extend --mattr=+m,+f --riscv-xlen 32 --riscv-flen 32 | FileCheck --check-prefixes=CHECK-RV32,CHECK-RV32-NOEXTED -allow-unused-prefixes %s + +// CHECK-RV32: Pattern for FMAC: (any_fma FPR32:$rd, FPR32:$rs1, FPR32:$rs2) +FMAC { + encoding: 7'b0101000 :: rs2[4:0] :: rs1[4:0] :: 3'b011 :: rd[4:0] :: 7'b0101011; + assembly: "{name(rd)}, {name(rs1)}, {name(rs2)}"; + behavior: { + F[rd] = llvm_fmuladd_f32(F[rs1], F[rs2], F[rd]); + } +} + +// CHECK-RV32-NEXT: Pattern for FMEAN: (fmul (fadd FPR32:$rs1, FPR32:$rs2), (f32 0.500000)) +FMEAN { + operands: { + unsigned<5> rd [[out]] [[is_freg]]; + unsigned<5> rs1 [[in]] [[is_freg]]; + unsigned<5> rs2 [[in]] [[is_freg]]; + } + encoding: 7'b0101000 :: rs2[4:0] :: rs1[4:0] :: 3'b011 :: rd[4:0] :: 7'b0101011; + assembly: "{name(rd)}, {name(rs1)}, {name(rs2)}"; + behavior: { + F[rd] = llvm_fdiv_fp32(llvm_fadd_fp32(F[rs1], F[rs2]), llvm_uitofp_fp32(2)); + } +} diff --git a/core_descs/ExampleFP32.ll b/core_descs/ExampleFP32.ll new file mode 100644 index 000000000000..af6ca56261ee --- /dev/null +++ b/core_descs/ExampleFP32.ll @@ -0,0 +1,35 @@ +; ModuleID = 'mod' +source_filename = "mod" + +define void @implFMAC(ptr %rs2, ptr %rs1, ptr noalias %rd) { + %rs1.v = load i32, ptr %rs1, align 4 + %rs2.v = load i32, ptr %rs2, align 4 + %rd.v = load i32, ptr %rd, align 4 + %1 = bitcast i32 %rs1.v to float + %2 = bitcast i32 %rs2.v to float + %3 = bitcast i32 %rd.v to float + %4 = call float @llvm.fmuladd.f32(float %3, float %1, float %2) + %5 = bitcast float %4 to i32 + store i32 %5, ptr %rd, align 4 + ret void +} + +; Function Attrs: nocallback nofree nosync nounwind speculatable willreturn memory(none) +declare float @llvm.fmuladd.f32(float, float, float) #0 + +define void @implFMEAN(ptr noalias %rd, ptr %rs1, ptr %rs2) { + %rs1.v = load i32, ptr %rs1, align 4 + %rs2.v = load i32, ptr %rs2, align 4 + %1 = bitcast i32 %rs1.v to float + %2 = bitcast i32 %rs2.v to float + %3 = fadd float %1, %2 + %4 = bitcast float %3 to i32 + %5 = bitcast i32 %4 to float + %6 = fdiv float %5, 2.000000e+00 + %7 = bitcast float %6 to i32 + store i32 %7, ptr %rd, align 4 + ret void +} + +attributes #0 = { nocallback nofree nosync nounwind speculatable willreturn memory(none) } + diff --git a/core_descs/ExampleFP32.td b/core_descs/ExampleFP32.td new file mode 100644 index 000000000000..b78927fc7e8d --- /dev/null +++ b/core_descs/ExampleFP32.td @@ -0,0 +1,15 @@ +let Predicates = [HasVendorXCValu] in { + +let hasSideEffects = 0, mayLoad = 0, mayStore = 0, isCodeGenOnly = 1, Constraints = "$rd = $rd_wb" in def FMAC_ : RVInst_FMAC<(outs GPR:$rd_wb), (ins GPR:$rs2, GPR:$rs1, GPR:$rd)>; + +def : Pat< + (i32 (any_fma FPR32:$rd, FPR32:$rs1, FPR32:$rs2)), + (FMAC_ GPR:$rs2, GPR:$rs1, GPR:$rd)>; + +let hasSideEffects = 0, mayLoad = 0, mayStore = 0, isCodeGenOnly = 1, Constraints = "" in def FMEAN_ : RVInst_FMEAN<(outs GPR:$rd), (ins GPR:$rs1, GPR:$rs2)>; + +def : Pat< + (i32 (fmul (fadd FPR32:$rs1, FPR32:$rs2), (f32 0.500000))), + (FMEAN_ GPR:$rs1, GPR:$rs2)>; + +} diff --git a/core_descs/ExampleFP32InstrFormat.td b/core_descs/ExampleFP32InstrFormat.td new file mode 100644 index 000000000000..236a100ac1d3 --- /dev/null +++ b/core_descs/ExampleFP32InstrFormat.td @@ -0,0 +1,22 @@ +class RVInst_FMAC : RVInst { + bits<5> rs2; + bits<5> rs1; + bits<5> rd; + let Inst{31-25} = 0x0; + let Inst{24-20} = rs2{4-0}; + let Inst{19-15} = rs1{4-0}; + let Inst{14-12} = 0x0; + let Inst{11-7} = rd{4-0}; + let Inst{6-0} = 0x0; +} +class RVInst_FMEAN : RVInst { + bits<5> rd; + bits<5> rs1; + bits<5> rs2; + let Inst{31-25} = 0x0; + let Inst{24-20} = rs2{4-0}; + let Inst{19-15} = rs1{4-0}; + let Inst{14-12} = 0x0; + let Inst{11-7} = rd{4-0}; + let Inst{6-0} = 0x0; +} diff --git a/core_descs/ExampleFP64.core_desc b/core_descs/ExampleFP64.core_desc new file mode 100644 index 000000000000..eedc6ac0c2e4 --- /dev/null +++ b/core_descs/ExampleFP64.core_desc @@ -0,0 +1,25 @@ +// RUN: pattern-gen %s -O 3 --mattr=+m,+f,+d --riscv-xlen 64 --riscv-flen 64 | FileCheck --check-prefixes=CHECK-RV64,CHECK-RV64-EXTEND -allow-unused-prefixes %s +// RUN: pattern-gen %s -O 3 --no-extend --mattr=+m,+f,+d --riscv-xlen 64 --riscv-flen 64 | FileCheck --check-prefixes=CHECK-RV64,CHECK-RV64-NOEXTED -allow-unused-prefixes %s + +// CHECK-RV64: Pattern for FMAC: (any_fma FPR64:$rd, FPR64:$rs1, FPR64:$rs2) +FMAC { + encoding: 7'b0101000 :: rs2[4:0] :: rs1[4:0] :: 3'b011 :: rd[4:0] :: 7'b0101011; + assembly: "{name(rd)}, {name(rs1)}, {name(rs2)}"; + behavior: { + F[rd] = llvm_fmuladd_f64(F[rs1], F[rs2], F[rd]); + } +} + +// CHECK-RV64-NEXT: Pattern for FMEAN: (fmul (fadd FPR64:$rs1, FPR64:$rs2), (f64 0.500000)) +FMEAN { + operands: { + unsigned<5> rd [[out]] [[is_freg]]; + unsigned<5> rs1 [[in]] [[is_freg]]; + unsigned<5> rs2 [[in]] [[is_freg]]; + } + encoding: 7'b0101000 :: rs2[4:0] :: rs1[4:0] :: 3'b011 :: rd[4:0] :: 7'b0101011; + assembly: "{name(rd)}, {name(rs1)}, {name(rs2)}"; + behavior: { + F[rd] = llvm_fdiv_fp64(llvm_fadd_fp64(F[rs1], F[rs2]), llvm_uitofp_fp64(2)); + } +} diff --git a/core_descs/ExampleFP64.ll b/core_descs/ExampleFP64.ll new file mode 100644 index 000000000000..7255b80fe150 --- /dev/null +++ b/core_descs/ExampleFP64.ll @@ -0,0 +1,35 @@ +; ModuleID = 'mod' +source_filename = "mod" + +define void @implFMAC(ptr %rs2, ptr %rs1, ptr noalias %rd) { + %rs1.v = load i64, ptr %rs1, align 8 + %rs2.v = load i64, ptr %rs2, align 8 + %rd.v = load i64, ptr %rd, align 8 + %1 = bitcast i64 %rs1.v to double + %2 = bitcast i64 %rs2.v to double + %3 = bitcast i64 %rd.v to double + %4 = call double @llvm.fmuladd.f64(double %3, double %1, double %2) + %5 = bitcast double %4 to i64 + store i64 %5, ptr %rd, align 8 + ret void +} + +; Function Attrs: nocallback nofree nosync nounwind speculatable willreturn memory(none) +declare double @llvm.fmuladd.f64(double, double, double) #0 + +define void @implFMEAN(ptr noalias %rd, ptr %rs1, ptr %rs2) { + %rs1.v = load i64, ptr %rs1, align 8 + %rs2.v = load i64, ptr %rs2, align 8 + %1 = bitcast i64 %rs1.v to double + %2 = bitcast i64 %rs2.v to double + %3 = fadd double %1, %2 + %4 = bitcast double %3 to i64 + %5 = bitcast i64 %4 to double + %6 = fdiv double %5, 2.000000e+00 + %7 = bitcast double %6 to i64 + store i64 %7, ptr %rd, align 8 + ret void +} + +attributes #0 = { nocallback nofree nosync nounwind speculatable willreturn memory(none) } + diff --git a/core_descs/ExampleFP64.td b/core_descs/ExampleFP64.td new file mode 100644 index 000000000000..eca9f83e3884 --- /dev/null +++ b/core_descs/ExampleFP64.td @@ -0,0 +1,15 @@ +let Predicates = [HasVendorXCValu] in { + +let hasSideEffects = 0, mayLoad = 0, mayStore = 0, isCodeGenOnly = 1, Constraints = "$rd = $rd_wb" in def FMAC_ : RVInst_FMAC<(outs GPR:$rd_wb), (ins GPR:$rs2, GPR:$rs1, GPR:$rd)>; + +def : Pat< + (i64 (any_fma FPR64:$rd, FPR64:$rs1, FPR64:$rs2)), + (FMAC_ GPR:$rs2, GPR:$rs1, GPR:$rd)>; + +let hasSideEffects = 0, mayLoad = 0, mayStore = 0, isCodeGenOnly = 1, Constraints = "" in def FMEAN_ : RVInst_FMEAN<(outs GPR:$rd), (ins GPR:$rs1, GPR:$rs2)>; + +def : Pat< + (i64 (fmul (fadd FPR64:$rs1, FPR64:$rs2), (f64 0.500000))), + (FMEAN_ GPR:$rs1, GPR:$rs2)>; + +} diff --git a/core_descs/ExampleFP64InstrFormat.td b/core_descs/ExampleFP64InstrFormat.td new file mode 100644 index 000000000000..236a100ac1d3 --- /dev/null +++ b/core_descs/ExampleFP64InstrFormat.td @@ -0,0 +1,22 @@ +class RVInst_FMAC : RVInst { + bits<5> rs2; + bits<5> rs1; + bits<5> rd; + let Inst{31-25} = 0x0; + let Inst{24-20} = rs2{4-0}; + let Inst{19-15} = rs1{4-0}; + let Inst{14-12} = 0x0; + let Inst{11-7} = rd{4-0}; + let Inst{6-0} = 0x0; +} +class RVInst_FMEAN : RVInst { + bits<5> rd; + bits<5> rs1; + bits<5> rs2; + let Inst{31-25} = 0x0; + let Inst{24-20} = rs2{4-0}; + let Inst{19-15} = rs1{4-0}; + let Inst{14-12} = 0x0; + let Inst{11-7} = rd{4-0}; + let Inst{6-0} = 0x0; +} diff --git a/core_descs/ExampleXCV.ll b/core_descs/ExampleXCV.ll index fa03b6b5a627..a9e18dadeb10 100644 --- a/core_descs/ExampleXCV.ll +++ b/core_descs/ExampleXCV.ll @@ -65,5 +65,227 @@ define void @implCV_ADDN(i32 %Luimm5, ptr %rs2, ptr %rs1, ptr noalias %rd) { ; Function Attrs: nocallback nofree nosync nounwind willreturn memory(inaccessiblemem: write) declare void @llvm.assume(i1 noundef) #0 +define void @implCV_ADD_H(ptr %rs2, ptr %rs1, ptr noalias %rd) { + br i1 true, label %1, label %14 + +1: ; preds = %0 + %2 = getelementptr i16, ptr %rd, i32 0 + %3 = getelementptr i16, ptr %rs1, i32 0 + %4 = getelementptr i16, ptr %rs2, i32 0 + %.v = load i16, ptr %3, align 2 + %.v1 = load i16, ptr %4, align 2 + %5 = add i16 %.v, %.v1 + %6 = lshr i16 %5, 0 + %7 = and i16 %6, -1 + store i16 %7, ptr %2, align 2 + %8 = getelementptr i16, ptr %rd, i32 1 + %9 = getelementptr i16, ptr %rs1, i32 1 + %10 = getelementptr i16, ptr %rs2, i32 1 + %.v2 = load i16, ptr %9, align 2 + %.v3 = load i16, ptr %10, align 2 + %11 = add i16 %.v2, %.v3 + %12 = lshr i16 %11, 0 + %13 = and i16 %12, -1 + store i16 %13, ptr %8, align 2 + br label %14 + +14: ; preds = %1, %0 + ret void +} + +define void @implCV_ADD_SC_H(ptr %rs2, ptr %rs1, ptr noalias %rd) { + br i1 true, label %1, label %14 + +1: ; preds = %0 + %2 = getelementptr i16, ptr %rd, i32 0 + %3 = getelementptr i16, ptr %rs1, i32 0 + %4 = getelementptr i16, ptr %rs2, i32 0 + %.v = load i16, ptr %3, align 2 + %.v1 = load i16, ptr %4, align 2 + %5 = add i16 %.v, %.v1 + %6 = lshr i16 %5, 0 + %7 = and i16 %6, -1 + store i16 %7, ptr %2, align 2 + %8 = getelementptr i16, ptr %rd, i32 1 + %9 = getelementptr i16, ptr %rs1, i32 1 + %10 = getelementptr i16, ptr %rs2, i32 0 + %.v2 = load i16, ptr %9, align 2 + %.v3 = load i16, ptr %10, align 2 + %11 = add i16 %.v2, %.v3 + %12 = lshr i16 %11, 0 + %13 = and i16 %12, -1 + store i16 %13, ptr %8, align 2 + br label %14 + +14: ; preds = %1, %0 + ret void +} + +define void @implCV_ADD_SCI_H(i32 %Imm6, ptr %rs1, ptr noalias %rd) { + %1 = and i32 %Imm6, 63 + %2 = icmp eq i32 %Imm6, %1 + call void @llvm.assume(i1 %2) + br i1 true, label %3, label %16 + +3: ; preds = %0 + %4 = getelementptr i16, ptr %rd, i32 0 + %5 = getelementptr i16, ptr %rs1, i32 0 + %6 = trunc i32 %Imm6 to i16 + %.v = load i16, ptr %5, align 2 + %7 = add i16 %.v, %6 + %8 = lshr i16 %7, 0 + %9 = and i16 %8, -1 + store i16 %9, ptr %4, align 2 + %10 = getelementptr i16, ptr %rd, i32 1 + %11 = getelementptr i16, ptr %rs1, i32 1 + %12 = trunc i32 %Imm6 to i16 + %.v1 = load i16, ptr %11, align 2 + %13 = add i16 %.v1, %12 + %14 = lshr i16 %13, 0 + %15 = and i16 %14, -1 + store i16 %15, ptr %10, align 2 + br label %16 + +16: ; preds = %3, %0 + ret void +} + +define void @implCV_ADD_B(ptr %rs2, ptr %rs1, ptr noalias %rd) { + br i1 true, label %1, label %26 + +1: ; preds = %0 + %2 = getelementptr i8, ptr %rd, i32 0 + %3 = getelementptr i8, ptr %rs1, i32 0 + %4 = getelementptr i8, ptr %rs2, i32 0 + %.v = load i8, ptr %3, align 1 + %.v1 = load i8, ptr %4, align 1 + %5 = add i8 %.v, %.v1 + %6 = lshr i8 %5, 0 + %7 = and i8 %6, -1 + store i8 %7, ptr %2, align 1 + %8 = getelementptr i8, ptr %rd, i32 1 + %9 = getelementptr i8, ptr %rs1, i32 1 + %10 = getelementptr i8, ptr %rs2, i32 1 + %.v2 = load i8, ptr %9, align 1 + %.v3 = load i8, ptr %10, align 1 + %11 = add i8 %.v2, %.v3 + %12 = lshr i8 %11, 0 + %13 = and i8 %12, -1 + store i8 %13, ptr %8, align 1 + %14 = getelementptr i8, ptr %rd, i32 2 + %15 = getelementptr i8, ptr %rs1, i32 2 + %16 = getelementptr i8, ptr %rs2, i32 2 + %.v4 = load i8, ptr %15, align 1 + %.v5 = load i8, ptr %16, align 1 + %17 = add i8 %.v4, %.v5 + %18 = lshr i8 %17, 0 + %19 = and i8 %18, -1 + store i8 %19, ptr %14, align 1 + %20 = getelementptr i8, ptr %rd, i32 3 + %21 = getelementptr i8, ptr %rs1, i32 3 + %22 = getelementptr i8, ptr %rs2, i32 3 + %.v6 = load i8, ptr %21, align 1 + %.v7 = load i8, ptr %22, align 1 + %23 = add i8 %.v6, %.v7 + %24 = lshr i8 %23, 0 + %25 = and i8 %24, -1 + store i8 %25, ptr %20, align 1 + br label %26 + +26: ; preds = %1, %0 + ret void +} + +define void @implCV_ADD_SC_B(ptr %rs2, ptr %rs1, ptr noalias %rd) { + br i1 true, label %1, label %26 + +1: ; preds = %0 + %2 = getelementptr i8, ptr %rd, i32 0 + %3 = getelementptr i8, ptr %rs1, i32 0 + %4 = getelementptr i8, ptr %rs2, i32 0 + %.v = load i8, ptr %3, align 1 + %.v1 = load i8, ptr %4, align 1 + %5 = add i8 %.v, %.v1 + %6 = lshr i8 %5, 0 + %7 = and i8 %6, -1 + store i8 %7, ptr %2, align 1 + %8 = getelementptr i8, ptr %rd, i32 1 + %9 = getelementptr i8, ptr %rs1, i32 1 + %10 = getelementptr i8, ptr %rs2, i32 0 + %.v2 = load i8, ptr %9, align 1 + %.v3 = load i8, ptr %10, align 1 + %11 = add i8 %.v2, %.v3 + %12 = lshr i8 %11, 0 + %13 = and i8 %12, -1 + store i8 %13, ptr %8, align 1 + %14 = getelementptr i8, ptr %rd, i32 2 + %15 = getelementptr i8, ptr %rs1, i32 2 + %16 = getelementptr i8, ptr %rs2, i32 0 + %.v4 = load i8, ptr %15, align 1 + %.v5 = load i8, ptr %16, align 1 + %17 = add i8 %.v4, %.v5 + %18 = lshr i8 %17, 0 + %19 = and i8 %18, -1 + store i8 %19, ptr %14, align 1 + %20 = getelementptr i8, ptr %rd, i32 3 + %21 = getelementptr i8, ptr %rs1, i32 3 + %22 = getelementptr i8, ptr %rs2, i32 0 + %.v6 = load i8, ptr %21, align 1 + %.v7 = load i8, ptr %22, align 1 + %23 = add i8 %.v6, %.v7 + %24 = lshr i8 %23, 0 + %25 = and i8 %24, -1 + store i8 %25, ptr %20, align 1 + br label %26 + +26: ; preds = %1, %0 + ret void +} + +define void @implCV_ADD_SCI_B(i32 %Imm6, ptr %rs1, ptr noalias %rd) { + %1 = and i32 %Imm6, 63 + %2 = icmp eq i32 %Imm6, %1 + call void @llvm.assume(i1 %2) + br i1 true, label %3, label %28 + +3: ; preds = %0 + %4 = getelementptr i8, ptr %rd, i32 0 + %5 = getelementptr i8, ptr %rs1, i32 0 + %6 = trunc i32 %Imm6 to i8 + %.v = load i8, ptr %5, align 1 + %7 = add i8 %.v, %6 + %8 = lshr i8 %7, 0 + %9 = and i8 %8, -1 + store i8 %9, ptr %4, align 1 + %10 = getelementptr i8, ptr %rd, i32 1 + %11 = getelementptr i8, ptr %rs1, i32 1 + %12 = trunc i32 %Imm6 to i8 + %.v1 = load i8, ptr %11, align 1 + %13 = add i8 %.v1, %12 + %14 = lshr i8 %13, 0 + %15 = and i8 %14, -1 + store i8 %15, ptr %10, align 1 + %16 = getelementptr i8, ptr %rd, i32 2 + %17 = getelementptr i8, ptr %rs1, i32 2 + %18 = trunc i32 %Imm6 to i8 + %.v2 = load i8, ptr %17, align 1 + %19 = add i8 %.v2, %18 + %20 = lshr i8 %19, 0 + %21 = and i8 %20, -1 + store i8 %21, ptr %16, align 1 + %22 = getelementptr i8, ptr %rd, i32 3 + %23 = getelementptr i8, ptr %rs1, i32 3 + %24 = trunc i32 %Imm6 to i8 + %.v3 = load i8, ptr %23, align 1 + %25 = add i8 %.v3, %24 + %26 = lshr i8 %25, 0 + %27 = and i8 %26, -1 + store i8 %27, ptr %22, align 1 + br label %28 + +28: ; preds = %3, %0 + ret void +} + attributes #0 = { nocallback nofree nosync nounwind willreturn memory(inaccessiblemem: write) } diff --git a/core_descs/ExampleXCV.td b/core_descs/ExampleXCV.td index a9e2eb33ddbd..855b86fc61c4 100644 --- a/core_descs/ExampleXCV.td +++ b/core_descs/ExampleXCV.td @@ -18,4 +18,40 @@ def : Pat< (i32 (i32 (sra (i32 (add GPR:$rs2, GPR:$rs1)), (i32 (i32 uimm5:$Luimm5))))), (CV_ADDN_ uimm5:$Luimm5, GPR:$rs2, GPR:$rs1)>; +let hasSideEffects = 0, mayLoad = 0, mayStore = 1, isCodeGenOnly = 1, Constraints = "" in def CV_ADD_H_ : RVInst_CV_ADD_H<(outs ), (ins GPR:$rs2, GPR:$rs1, uimm5:$rd)>; + +def : Pat< + (truncstorei16 (XLenVT (add (i32 (srl GPR:$rs2, (i32 16))), (i32 (srl GPR:$rs1, (i32 16))))), (iPTR (ptradd (iPTR (i32 uimm5:$rd)), (i32 (i32 2))))), + (CV_ADD_H_ GPR:$rs2, GPR:$rs1, uimm5:$rd)>; + +let hasSideEffects = 0, mayLoad = 0, mayStore = 1, isCodeGenOnly = 1, Constraints = "" in def CV_ADD_SC_H_ : RVInst_CV_ADD_SC_H<(outs ), (ins GPR:$rs2, GPR:$rs1, uimm5:$rd)>; + +def : Pat< + (truncstorei16 (XLenVT (add (i32 (srl GPR:$rs1, (i32 16))), GPR:$rs2)), (iPTR (ptradd (iPTR (i32 uimm5:$rd)), (i32 (i32 2))))), + (CV_ADD_SC_H_ GPR:$rs2, GPR:$rs1, uimm5:$rd)>; + +let hasSideEffects = 0, mayLoad = 0, mayStore = 1, isCodeGenOnly = 1, Constraints = "" in def CV_ADD_SCI_H_ : RVInst_CV_ADD_SCI_H<(outs ), (ins uimm6:$Imm6, GPR:$rs1, uimm5:$rd)>; + +def : Pat< + (truncstorei16 (XLenVT (add (i32 (srl GPR:$rs1, (i32 16))), (i32 uimm6:$Imm6))), (iPTR (ptradd (iPTR (i32 uimm5:$rd)), (i32 (i32 2))))), + (CV_ADD_SCI_H_ uimm6:$Imm6, GPR:$rs1, uimm5:$rd)>; + +let hasSideEffects = 0, mayLoad = 0, mayStore = 1, isCodeGenOnly = 1, Constraints = "" in def CV_ADD_B_ : RVInst_CV_ADD_B<(outs ), (ins GPR:$rs2, GPR:$rs1, uimm5:$rd)>; + +def : Pat< + (truncstorei8 (XLenVT (add (i32 (srl GPR:$rs2, (i32 24))), (i32 (srl GPR:$rs1, (i32 24))))), (iPTR (ptradd (iPTR (i32 uimm5:$rd)), (i32 (i32 3))))), + (CV_ADD_B_ GPR:$rs2, GPR:$rs1, uimm5:$rd)>; + +let hasSideEffects = 0, mayLoad = 0, mayStore = 1, isCodeGenOnly = 1, Constraints = "" in def CV_ADD_SC_B_ : RVInst_CV_ADD_SC_B<(outs ), (ins GPR:$rs2, GPR:$rs1, uimm5:$rd)>; + +def : Pat< + (truncstorei8 (XLenVT (add (i32 (srl GPR:$rs1, (i32 24))), GPR:$rs2)), (iPTR (ptradd (iPTR (i32 uimm5:$rd)), (i32 (i32 3))))), + (CV_ADD_SC_B_ GPR:$rs2, GPR:$rs1, uimm5:$rd)>; + +let hasSideEffects = 0, mayLoad = 0, mayStore = 1, isCodeGenOnly = 1, Constraints = "" in def CV_ADD_SCI_B_ : RVInst_CV_ADD_SCI_B<(outs ), (ins uimm6:$Imm6, GPR:$rs1, uimm5:$rd)>; + +def : Pat< + (truncstorei8 (XLenVT (add (i32 (srl GPR:$rs1, (i32 24))), (i32 uimm6:$Imm6))), (iPTR (ptradd (iPTR (i32 uimm5:$rd)), (i32 (i32 3))))), + (CV_ADD_SCI_B_ uimm6:$Imm6, GPR:$rs1, uimm5:$rd)>; + } diff --git a/core_descs/ExampleXCVInstrFormat.td b/core_descs/ExampleXCVInstrFormat.td index ef62b854685f..20ceab2eae4e 100644 --- a/core_descs/ExampleXCVInstrFormat.td +++ b/core_descs/ExampleXCVInstrFormat.td @@ -32,3 +32,81 @@ class RVInst_CV_ADDN : RVInst : RVInst { + bits<5> rs2; + bits<5> rs1; + bits<5> rd; + let Inst{31-27} = 0x0; + let Inst{26-26} = 0x0; + let Inst{25-25} = 0x0; + let Inst{24-20} = rs2{4-0}; + let Inst{19-15} = rs1{4-0}; + let Inst{14-12} = 0x0; + let Inst{11-7} = rd{4-0}; + let Inst{6-0} = 0x0; +} +class RVInst_CV_ADD_SC_H : RVInst { + bits<5> rs2; + bits<5> rs1; + bits<5> rd; + let Inst{31-27} = 0x0; + let Inst{26-26} = 0x0; + let Inst{25-25} = 0x0; + let Inst{24-20} = rs2{4-0}; + let Inst{19-15} = rs1{4-0}; + let Inst{14-12} = 0x0; + let Inst{11-7} = rd{4-0}; + let Inst{6-0} = 0x0; +} +class RVInst_CV_ADD_SCI_H : RVInst { + bits<6> Imm6; + bits<5> rs1; + bits<5> rd; + let Inst{31-27} = 0x0; + let Inst{26-26} = 0x0; + let Inst{25-25} = Imm6{0-0}; + let Inst{24-20} = Imm6{5-1}; + let Inst{19-15} = rs1{4-0}; + let Inst{14-12} = 0x0; + let Inst{11-7} = rd{4-0}; + let Inst{6-0} = 0x0; +} +class RVInst_CV_ADD_B : RVInst { + bits<5> rs2; + bits<5> rs1; + bits<5> rd; + let Inst{31-27} = 0x0; + let Inst{26-26} = 0x0; + let Inst{25-25} = 0x0; + let Inst{24-20} = rs2{4-0}; + let Inst{19-15} = rs1{4-0}; + let Inst{14-12} = 0x0; + let Inst{11-7} = rd{4-0}; + let Inst{6-0} = 0x0; +} +class RVInst_CV_ADD_SC_B : RVInst { + bits<5> rs2; + bits<5> rs1; + bits<5> rd; + let Inst{31-27} = 0x0; + let Inst{26-26} = 0x0; + let Inst{25-25} = 0x0; + let Inst{24-20} = rs2{4-0}; + let Inst{19-15} = rs1{4-0}; + let Inst{14-12} = 0x0; + let Inst{11-7} = rd{4-0}; + let Inst{6-0} = 0x0; +} +class RVInst_CV_ADD_SCI_B : RVInst { + bits<6> Imm6; + bits<5> rs1; + bits<5> rd; + let Inst{31-27} = 0x0; + let Inst{26-26} = 0x0; + let Inst{25-25} = Imm6{0-0}; + let Inst{24-20} = Imm6{5-1}; + let Inst{19-15} = rs1{4-0}; + let Inst{14-12} = 0x0; + let Inst{11-7} = rd{4-0}; + let Inst{6-0} = 0x0; +} diff --git a/llvm/lib/CodeGen/GlobalISel/PatternGen.cpp b/llvm/lib/CodeGen/GlobalISel/PatternGen.cpp index 9438cd62b5c8..454fdf96d06e 100644 --- a/llvm/lib/CodeGen/GlobalISel/PatternGen.cpp +++ b/llvm/lib/CodeGen/GlobalISel/PatternGen.cpp @@ -177,19 +177,19 @@ static const std::unordered_map CmpStr = { {CmpInst::Predicate::ICMP_UGE, "SETUGE"}, }; -std::string lltToString(LLT Llt) { +std::string lltToString(LLT Llt, bool IsFloat = false) { if (Llt.isFixedVector()) return "v" + std::to_string(Llt.getElementCount().getFixedValue()) + lltToString(Llt.getElementType()); if (Llt.isScalar()) - return "i" + std::to_string(Llt.getSizeInBits()); + return (IsFloat ? "f" : "i") + std::to_string(Llt.getSizeInBits()); if (Llt.isPointer()) return "iPTR"; assert(0 && "invalid type"); return "invalid"; } -std::string lltToRegTypeStr(LLT Type) { +std::string lltToRegTypeStr(LLT Type, bool IsFloat) { if (Type.isValid()) { if (Type.isFixedVector() && Type.getElementType().isScalar() && Type.getSizeInBits() == 32) { @@ -199,7 +199,7 @@ std::string lltToRegTypeStr(LLT Type) { return "GPR32V2"; abort(); } else - return "GPR"; + return IsFloat ? ("FPR" + std::to_string(PatternGenArgs::Args.FLen)) : "GPR"; } assert(0 && "invalid type"); return "invalid"; @@ -218,6 +218,7 @@ struct PatternNode { PN_Compare, PN_Unop, PN_Constant, + PN_ConstantFP, PN_Register, PN_Load, PN_Select, @@ -293,7 +294,7 @@ struct ShuffleNode : public PatternNode { Second(std::move(Second)), Mask(std::move(Mask)) {} std::string patternString() override { - std::string TypeStr = lltToString(Type); + std::string TypeStr = lltToString(Type, false); std::string MaskStr = ""; for (size_t I = 0; I < Mask.size(); I++) { @@ -345,10 +346,11 @@ struct TernopNode : public PatternNode { static const std::unordered_map TernopStr = { {TargetOpcode::G_FSHL, "fshl"}, {TargetOpcode::G_FSHR, "fshr"}, + {TargetOpcode::G_FMA, "any_fma"}, {TargetOpcode::G_INSERT_VECTOR_ELT, "vector_insert"}, {TargetOpcode::G_SELECT, "select"}}; - std::string TypeStr = lltToString(Type); + std::string TypeStr = lltToString(Type, false); std::string OpString = "(" + std::string(TernopStr.at(Op)) + " " + First->patternString() + ", " + Second->patternString() + ", " + Third->patternString() + ")"; @@ -387,15 +389,18 @@ struct BinopNode : public PatternNode { std::string patternString() override { static const std::unordered_map BinopStr = { {TargetOpcode::G_ADD, "add"}, + {TargetOpcode::G_FADD, "fadd"}, {TargetOpcode::G_PTR_ADD, "ptradd"}, {TargetOpcode::G_SUB, "sub"}, {TargetOpcode::G_MUL, "mul"}, + {TargetOpcode::G_FMUL, "fmul"}, {TargetOpcode::G_UMULH, "mulhu"}, {TargetOpcode::G_SMULH, "mulhs"}, {TargetOpcode::G_UDIV, "udiv"}, {TargetOpcode::G_SREM, "srem"}, {TargetOpcode::G_UREM, "urem"}, {TargetOpcode::G_SDIV, "sdiv"}, + {TargetOpcode::G_FDIV, "fdiv"}, {TargetOpcode::G_SADDSAT, "saddsat"}, {TargetOpcode::G_UADDSAT, "uaddsat"}, {TargetOpcode::G_SSUBSAT, "ssubsat"}, @@ -438,9 +443,10 @@ struct BinopNode : public PatternNode { bool LeftImm = Left->IsImm; bool RightImm = Right->IsImm; bool DoSwap = IsCommutable && LeftImm && !RightImm; - std::string TypeStr = lltToString(Type); - std::string LhsTypeStr = lltToString(Left->Type); - std::string RhsTypeStr = lltToString(Right->Type); + bool IsFloat = isPreISelGenericFloatingPointOpcode(Op); + std::string TypeStr = lltToString(Type, IsFloat); + std::string LhsTypeStr = lltToString(Left->Type, IsFloat); + std::string RhsTypeStr = lltToString(Right->Type, IsFloat); // Explicitly specifying types for all ops increases pattern compile time // significantly, so we only do for ops where deduction fails otherwise. @@ -497,9 +503,9 @@ struct CompareNode : public BinopNode { Cond(Cond) {} std::string patternString() override { - std::string TypeStr = lltToString(Type); - std::string LhsTypeStr = lltToString(Left->Type); - std::string RhsTypeStr = lltToString(Right->Type); + std::string TypeStr = lltToString(Type, false); + std::string LhsTypeStr = lltToString(Left->Type, false); + std::string RhsTypeStr = lltToString(Right->Type, false); return "(" + TypeStr + " (setcc (" + LhsTypeStr + " " + Left->patternString() + "), (" + RhsTypeStr + " " + @@ -522,7 +528,7 @@ struct SelectNode : public PatternNode { Right(std::move(Right)), Tval(std::move(Tval)), Fval(std::move(Fval)) {} std::string patternString() override { - std::string TypeStr = lltToString(Type); + std::string TypeStr = lltToString(Type, false); return "(" + TypeStr + " (riscv_selectcc " + Left->patternString() + ", " + Right->patternString() + ", " + CmpStr.at(Cond) + ", " + @@ -571,7 +577,7 @@ struct UnopNode : public PatternNode { {TargetOpcode::G_CTPOP, "ctpop"}, {TargetOpcode::G_ABS, "abs"}}; - std::string TypeStr = lltToString(Type); + std::string TypeStr = lltToString(Type, false); // ignore bitcast ops for now if (Op == TargetOpcode::G_BITCAST) @@ -602,10 +608,10 @@ struct ConstantNode : public PatternNode { : std::to_string((int32_t)Constant); if (Type.isFixedVector()) { - std::string TypeStr = lltToString(Type); + std::string TypeStr = lltToString(Type, false); return "(" + TypeStr + " (" + RegT + " " + ConstantStr + "))"; } - return "(" + lltToString(Type) + " " + ConstantStr + ")"; + return "(" + lltToString(Type, false) + " " + ConstantStr + ")"; } static bool classof(const PatternNode *Pat) { @@ -613,6 +619,26 @@ struct ConstantNode : public PatternNode { } }; +struct ConstantFPNode : public PatternNode { + double Constant; + ConstantFPNode(LLT Type, double Const) + : PatternNode(PN_ConstantFP, Type, true), Constant(Const) {} + + std::string patternString() override { + std::string ConstantStr = std::to_string(Constant); + if (Type.isFixedVector()) { + + std::string TypeStr = lltToString(Type, true); + return "(" + TypeStr + " (" + RegT + " " + ConstantStr + "))"; + } + return "(" + lltToString(Type, true) + " " + ConstantStr + ")"; + } + + static bool classof(const PatternNode *Pat) { + return Pat->getKind() == PN_ConstantFP; + } +}; + struct RegisterNode : public PatternNode { StringRef Name; @@ -620,18 +646,19 @@ struct RegisterNode : public PatternNode { int Offset; int Size; bool Sext; + bool IsFloat; bool VectorExtract = false; // TODO: set based on type of this register in other uses size_t RegIdx; RegisterNode(LLT Type, StringRef Name, size_t RegIdx, bool IsImm, int Offset, - int Size, bool Sext) + int Size, bool Sext, bool IsFloat) : PatternNode(PN_Register, Type, IsImm), Name(Name), Offset(Offset), - Size(Size), Sext(Sext), RegIdx(RegIdx) {} + Size(Size), Sext(Sext), RegIdx(RegIdx), IsFloat(IsFloat) {} std::string patternString() override { - std::string TypeStr = lltToString(Type); + std::string TypeStr = lltToString(Type, false); // TODO bool PrintType = Type.isPointer(); if (IsImm) { @@ -645,7 +672,7 @@ struct RegisterNode : public PatternNode { if ((uint64_t)Size == XLen) { std::string Str; if ((Type.isScalar() && Type.getSizeInBits() == XLen) || Type.isPointer()) - Str = "GPR:$" + std::string(Name); + Str = (IsFloat ? ("FPR" + std::to_string(PatternGenArgs::Args.FLen) + ":$") : "GPR:$") + std::string(Name); if (PrintType) return "(" + TypeStr + " " + Str + ")"; return Str; @@ -735,7 +762,7 @@ struct CastNode : public PatternNode { : PatternNode(PN_Cast, Type, false), Value(std::move(Value)) {} std::string patternString() override { - auto LLTString = lltToString(Type); + auto LLTString = lltToString(Type, false); return "(" + LLTString + " " + Value->patternString() + ")"; } @@ -877,18 +904,15 @@ traverseNOpOperands(MachineRegisterInfo &MRI, MachineInstr &Cur, size_t N, int Start = 1) { std::vector> Operands(N); for (size_t I = 0; I < N; I++) { - // llvm::outs() << "i=" << i << '\n'; assert(Cur.getOperand(Start + I).isReg() && "expected register"); auto *Node = MRI.getOneDef(Cur.getOperand(Start + I).getReg()); if (!Node) { - // llvm::outs() << "Err" << '\n'; return std::make_tuple(PatternError(FORMAT, &Cur), std::vector>()); } auto [Err_, Node_] = traverse(MRI, *Node->getParent()); if (Err_) { - // llvm::outs() << "Err2" << '\n'; return std::make_tuple(Err_, std::vector>()); } // return std::make_tuple(SUCCESS, std::move(NodeR)); @@ -977,13 +1001,13 @@ static PatternOrError traverseRegLoad(MachineRegisterInfo &MRI, return pError(FORMAT_LOAD, AddrI); PatternArgs[Idx].Llt = MRI.getType(Cur.getOperand(0).getReg()); - PatternArgs[Idx].ArgTypeStr = lltToRegTypeStr(PatternArgs[Idx].Llt); + PatternArgs[Idx].ArgTypeStr = lltToRegTypeStr(PatternArgs[Idx].Llt, false); // TODO PatternArgs[Idx].In = true; assert(Cur.getOperand(0).isReg() && "expected register"); auto Node = std::make_unique( MRI.getType(Cur.getOperand(0).getReg()), Field->ident, Idx, false, - ReadOffset, ReadSize, false); + ReadOffset, ReadSize, false, Field->type & CDSLInstr::FREG); return PPattern(std::move(Node)); } @@ -992,13 +1016,16 @@ static PatternOrError traverse(MachineRegisterInfo &MRI, MachineInstr &Cur) { switch (Cur.getOpcode()) { case TargetOpcode::G_ADD: + case TargetOpcode::G_FADD: case TargetOpcode::G_PTR_ADD: case TargetOpcode::G_SUB: case TargetOpcode::G_MUL: + case TargetOpcode::G_FMUL: case TargetOpcode::G_UMULH: case TargetOpcode::G_SMULH: case TargetOpcode::G_SDIV: case TargetOpcode::G_UDIV: + case TargetOpcode::G_FDIV: case TargetOpcode::G_SREM: case TargetOpcode::G_UREM: case TargetOpcode::G_SADDSAT: @@ -1081,7 +1108,7 @@ static PatternOrError traverse(MachineRegisterInfo &MRI, MachineInstr &Cur) { assert(Cur.getOperand(0).isReg() && "expected register"); AsRegNode->Type = MRI.getType(Cur.getOperand(0).getReg()); PatternArgs[AsRegNode->RegIdx].ArgTypeStr = - lltToRegTypeStr(AsRegNode->Type); + lltToRegTypeStr(AsRegNode->Type, false); } return std::make_pair(SUCCESS, std::move(Node)); @@ -1110,6 +1137,14 @@ static PatternOrError traverse(MachineRegisterInfo &MRI, MachineInstr &Cur) { MRI.getType(Cur.getOperand(0).getReg()), Imm->getLimitedValue())); } + case TargetOpcode::G_FCONSTANT: { + // auto *Imm = Cur.getOperand(1).getCImm(); + auto *Imm = Cur.getOperand(1).getFPImm(); + assert(Cur.getOperand(0).isReg() && "expected register"); + return std::make_pair(SUCCESS, std::make_unique( + MRI.getType(Cur.getOperand(0).getReg()), + Imm->getValueAPF().convertToDouble())); + } case TargetOpcode::G_IMPLICIT_DEF: { assert(Cur.getOperand(0).isReg() && "expected register"); return std::make_pair(SUCCESS, @@ -1151,7 +1186,7 @@ static PatternOrError traverse(MachineRegisterInfo &MRI, MachineInstr &Cur) { std::make_unique( MRI.getType(Cur.getOperand(0).getReg()), Field->ident, Idx, true, 0, Field->len, - Field->type & CDSLInstr::SIGNED)); + Field->type & CDSLInstr::SIGNED, Field->type & CDSLInstr::FREG)); } // Else COPY is just a pass-through. @@ -1182,6 +1217,7 @@ static PatternOrError traverse(MachineRegisterInfo &MRI, MachineInstr &Cur) { } case TargetOpcode::G_FSHL: case TargetOpcode::G_FSHR: + case TargetOpcode::G_FMA: case TargetOpcode::G_SELECT: case TargetOpcode::G_INSERT_VECTOR_ELT: { auto [Err, NodeFirst, NodeSecond, NodeThird] = @@ -1238,7 +1274,7 @@ static PatternOrError traverseRegStore(size_t Idx, MachineRegisterInfo &MRI, PatternArgs[Idx].Out = true; PatternArgs[Idx].Llt = Type; - PatternArgs[Idx].ArgTypeStr = lltToRegTypeStr(Type); + PatternArgs[Idx].ArgTypeStr = lltToRegTypeStr(Type, false); // TODO return traverse(MRI, Root); } @@ -1404,7 +1440,7 @@ bool PatternGen::runOnMachineFunction(MachineFunction &MF) { std::string Code = "def : Pat<\n\t"; if (OutType.isValid()) - Code += "(" + lltToString(OutType) + " " + PatternStr + "),\n\t(" + + Code += "(" + lltToString(OutType, false) + " " + PatternStr + "),\n\t(" + InstName + "_ "; else Code += PatternStr + ",\n\t(" + InstName + "_ "; diff --git a/llvm/tools/pattern-gen/Main.cpp b/llvm/tools/pattern-gen/Main.cpp index 62eea631137a..333f389e97ec 100644 --- a/llvm/tools/pattern-gen/Main.cpp +++ b/llvm/tools/pattern-gen/Main.cpp @@ -62,6 +62,8 @@ static cl::opt NoExtend( static cl::opt XLen("riscv-xlen", cl::desc("RISC-V XLEN (32 or 64 bit)"), cl::init(32), cl::cat(ToolOptions)); +static cl::opt FLen("riscv-flen", cl::desc("RISC-V FLEN (32 or 64 bit)"), + cl::init(32), cl::cat(ToolOptions)); // Determine optimization level. static cl::opt @@ -126,7 +128,7 @@ int main(int argc, char **argv) { TokenStream Ts(InputFilename.c_str()); LLVMContext Ctx; auto Mod = std::make_unique("mod", Ctx); - auto Instrs = ParseCoreDSL2(Ts, (XLen == 64), Mod.get(), NoExtend); + auto Instrs = ParseCoreDSL2(Ts, (XLen == 64), Mod.get(), NoExtend, FLen); if (irOut) { std::string Str; @@ -165,7 +167,8 @@ int main(int argc, char **argv) { PGArgsStruct Args{.Mattr = "", .OptLevel = Opt, .Predicates = Predicates, - .Is64Bit = (XLen == 64)}; + .Is64Bit = (XLen == 64), + .FLen = FLen}; optimizeBehavior(Mod.get(), Instrs, irOut, Args); if (PrintIR) diff --git a/llvm/tools/pattern-gen/PatternGen.hpp b/llvm/tools/pattern-gen/PatternGen.hpp index 18e8d8e2f526..23318892489b 100644 --- a/llvm/tools/pattern-gen/PatternGen.hpp +++ b/llvm/tools/pattern-gen/PatternGen.hpp @@ -10,6 +10,7 @@ struct PGArgsStruct llvm::CodeGenOptLevel OptLevel; std::string Predicates; bool Is64Bit; + int FLen; }; int optimizeBehavior(llvm::Module* M, std::vector const& Instrs, std::ostream& OstreamIR, PGArgsStruct Args); diff --git a/llvm/tools/pattern-gen/lib/InstrInfo.hpp b/llvm/tools/pattern-gen/lib/InstrInfo.hpp index ba741d6fb628..4fbdfcb81010 100644 --- a/llvm/tools/pattern-gen/lib/InstrInfo.hpp +++ b/llvm/tools/pattern-gen/lib/InstrInfo.hpp @@ -28,6 +28,7 @@ struct CDSLInstr IN = 32, OUT = 64, IS_32_BIT = 128, + FREG = 256, }; struct Field diff --git a/llvm/tools/pattern-gen/lib/Parser.cpp b/llvm/tools/pattern-gen/lib/Parser.cpp index 53ab0295c55c..575e61d43a20 100644 --- a/llvm/tools/pattern-gen/lib/Parser.cpp +++ b/llvm/tools/pattern-gen/lib/Parser.cpp @@ -43,6 +43,7 @@ struct Value { llvm::Value *ll; int bitWidth; bool isSigned; + bool isFloat; bool isLValue; Value(llvm::Value *llvalue, bool isSigned = false) @@ -50,11 +51,13 @@ struct Value { assert(!llvm::isa(llvalue->getType())); bitWidth = llvalue->getType()->getIntegerBitWidth(); isLValue = false; + isFloat = false; } Value(llvm::Value *llvalue, int bitWidth, bool isSigned = false) : ll(llvalue), bitWidth(bitWidth), isSigned(isSigned) { isLValue = true; + isFloat = false; } Value() {} @@ -69,8 +72,11 @@ static llvm::BasicBlock *entry; static CDSLInstr *curInstr; static int xlen; +static int flen; static bool NoExtend_; static llvm::Type *regT; +static llvm::Type *regT2; +static llvm::Type *fregT; static void reset_globals() { variables.clear(); @@ -779,6 +785,73 @@ static auto find_var(uint32_t identIdx) { [identIdx](CDSLInstr::Field &f) { return f.identIdx == identIdx; }); } +std::vector ParseFuncCallArgs(TokenStream &ts, llvm::Function *func, llvm::IRBuilder<> &build) { + std::vector args; + pop_cur(ts, RBrOpen); + if (ts.Peek().type != RBrClose) { + do { + auto expr = ParseExpression(ts, func, build); + promote_lvalue(build, expr); + args.push_back(expr); + } while (pop_cur_if(ts, Comma)); + } + pop_cur(ts, RBrClose); + return args; +} + +Value ParseLLVMFuncCall(TokenStream &ts, llvm::Function *func, + llvm::IRBuilder<> &build, std::string func_name) { + auto args = ParseFuncCallArgs(ts, func, build); + if (func_name == "llvm_fmuladd_f32" || func_name == "llvm_fmuladd_f64") { + assert(flen == std::stoi(func_name.substr(func_name.size() - 2))); + assert(args.size() == 3); + auto A_ = build.CreateBitCast(args[0].ll, fregT); + auto B_ = build.CreateBitCast(args[1].ll, fregT); + auto M_ = build.CreateBitCast(args[2].ll, fregT); + auto temp = build.CreateIntrinsic(fregT, llvm::Intrinsic::fmuladd, llvm::ArrayRef{M_, A_, B_}, nullptr); + auto temp_ = build.CreateBitCast(temp, regT2); + Value v = {temp_, false}; + v.bitWidth = xlen; // TODO + return v; + } else if (func_name == "llvm_fdiv_fp32" || func_name == "llvm_fdiv_fp64") { + assert(flen == std::stoi(func_name.substr(func_name.size() - 2))); + assert(args.size() == 2); + auto A_ = build.CreateBitCast(args[0].ll, fregT); + auto B_ = build.CreateBitCast(args[1].ll, fregT); + auto temp = build.CreateFDiv(A_, B_); + temp = build.CreateBitCast(temp, regT2); + Value v = {temp, false}; + v.bitWidth = xlen; // TODO + return v; + } else if (func_name == "llvm_fadd_fp32" || func_name == "llvm_fadd_fp64") { + assert(flen == std::stoi(func_name.substr(func_name.size() - 2))); + assert(args.size() == 2); + auto A_ = build.CreateBitCast(args[0].ll, fregT); + auto B_ = build.CreateBitCast(args[1].ll, fregT); + auto temp = build.CreateFAdd(A_, B_); + temp = build.CreateBitCast(temp, regT2); + Value v = {temp, false}; + v.bitWidth = xlen; // TODO + return v; + } else if (func_name == "llvm_uitofp_fp32" || func_name == "llvm_uitofp_fp64") { + assert(flen == std::stoi(func_name.substr(func_name.size() - 2))); + assert(args.size() == 1); + auto A = args[0]; + if (A.bitWidth != xlen) { + A.isSigned = false; + A.bitWidth = xlen; + fit_to_size(A, build); + } + auto A_ = A.ll; + auto temp = build.CreateUIToFP(A_, fregT); + temp = build.CreateBitCast(temp, regT2); + Value v = {temp, false}; + v.bitWidth = xlen; // TODO + return v; + } + error(("undefined llvm function: " + func_name).c_str(), ts); +} + Value ParseExpressionTerminal(TokenStream &ts, llvm::Function *func, llvm::IRBuilder<> &build) { auto &ctx = func->getContext(); @@ -805,8 +878,9 @@ Value ParseExpressionTerminal(TokenStream &ts, llvm::Function *func, len = xlen; return Value{addrPtr, len, false}; } - if (t.ident.str == "X" || t.ident.str == "XW") { + if (t.ident.str == "X" || t.ident.str == "XW" || t.ident.str == "F") { bool sizeIs32 = t.ident.str == "XW"; + bool isFloat = t.ident.str == "F"; pop_cur(ts, ABrOpen); auto ident = pop_cur(ts, Identifier).ident; pop_cur(ts, ABrClose); @@ -819,6 +893,8 @@ Value ParseExpressionTerminal(TokenStream &ts, llvm::Function *func, .c_str(), ts); sizeIs32 |= (match->type & CDSLInstr::IS_32_BIT); + if (isFloat) // TODO: assert not REG + match->type = (CDSLInstr::FieldType)(match->type | CDSLInstr::FieldType::FREG); return Value{func->getArg(match - curInstr->fields.begin()), sizeIs32 ? 32 : xlen, (bool)(match->type & CDSLInstr::SIGNED_REG)}; @@ -840,11 +916,17 @@ Value ParseExpressionTerminal(TokenStream &ts, llvm::Function *func, return v; } } + // TODO: check if float reg and get flen + // TODO: parse funtion call util + if (t.ident.str.rfind("llvm_", 0) == 0) { + return ParseLLVMFuncCall(ts, func, build, std::string(t.ident.str)); + } auto iter = variables.find(t.ident.idx); if (iter != variables.end()) return iter->getSecond().back().val; + error(("undefined symbol: " + std::string(t.ident.str)).c_str(), ts); } case IntLiteral: { @@ -1181,6 +1263,7 @@ void ParseOperands(TokenStream &ts, CDSLInstr &instr) { {"is_signed", {FieldType::SIGNED_REG, 0}}, {"is_imm", {FieldType::IMM, 0}}, {"is_reg", {FieldType::REG, 0}}, + {"is_freg", {FieldType::FREG, 0}}, {"in", {FieldType::IN, 0}}, {"out", {FieldType::OUT, 0}}, {"inout", {(FieldType::IN | FieldType::OUT), 0}}, @@ -1438,11 +1521,14 @@ void ParseBehaviour(TokenStream &ts, CDSLInstr &instr, llvm::Module *mod, } std::vector ParseCoreDSL2(TokenStream &ts, bool is64Bit, - llvm::Module *mod, bool NoExtend) { + llvm::Module *mod, bool NoExtend, int Flen) { std::vector instrs; xlen = is64Bit ? 64 : 32; + flen = Flen; // TODO: allow 0? NoExtend_ = NoExtend; regT = llvm::Type::getIntNTy(mod->getContext(), xlen); + regT2 = llvm::Type::getIntNTy(mod->getContext(), flen); + fregT = flen == 32 ? llvm::Type::getFloatTy(mod->getContext()) : llvm::Type::getDoubleTy(mod->getContext()); while (ts.Peek().type != None) { bool parseBoilerplate = @@ -1464,6 +1550,8 @@ std::vector ParseCoreDSL2(TokenStream &ts, bool is64Bit, // add XLEN and RFS as constants for now. add_variable(ts, ts.GetIdentIdx("XLEN"), Value{llvm::ConstantInt::get(regT, xlen)}); + add_variable(ts, ts.GetIdentIdx("FLEN"), + Value{llvm::ConstantInt::get(regT, flen)}); add_variable(ts, ts.GetIdentIdx("RFS"), Value{llvm::ConstantInt::get(regT, 32)}); ++PatternGenNumInstructionsParsed; diff --git a/llvm/tools/pattern-gen/lib/Parser.hpp b/llvm/tools/pattern-gen/lib/Parser.hpp index 19cbcd62ca98..c3d64ed7da7e 100644 --- a/llvm/tools/pattern-gen/lib/Parser.hpp +++ b/llvm/tools/pattern-gen/lib/Parser.hpp @@ -4,4 +4,4 @@ #include std::vector ParseCoreDSL2(TokenStream &ts, bool is64Bit, - llvm::Module *mod, bool NoExtend); + llvm::Module *mod, bool NoExtend, int FLen);