diff --git a/src/ita_gelu.sv b/src/ita_gelu.sv index cace920..6760e17 100644 --- a/src/ita_gelu.sv +++ b/src/ita_gelu.sv @@ -24,40 +24,35 @@ module ita_gelu requant_t data_q1; always_comb begin : first_stage - if (calc_en_i) begin - data_sign_ext = {{GELU_CONSTANTS_WIDTH-WI{data_i[WI-1]}}, data_i}; + data_sign_ext = {{GELU_CONSTANTS_WIDTH-WI{data_i[WI-1]}}, data_i}; - erf_sgn_d = data_i < 0; - erf_abs = erf_sgn_d ? -data_sign_ext : data_sign_ext; - erf_clipped = erf_abs > -b_i ? -b_i : erf_abs; + erf_sgn_d = data_i < 0; + erf_abs = erf_sgn_d ? -data_sign_ext : data_sign_ext; + erf_clipped = erf_abs > -b_i ? -b_i : erf_abs; - poly_d = erf_clipped + b_i; - poly_sq_d = poly_d * poly_d; - end + poly_d = erf_clipped + b_i; + poly_sq_d = (calc_en_i) ? poly_d * poly_d : poly_sq_q1; end always_comb begin : second_stage - if (calc_en_q_i) begin - erf_L_q1 = poly_sq_q1 + c_q1; - - gelu_erf_q1 = erf_sgn_q1 ? -erf_L_q1 : erf_L_q1; - gelu_sum_q1 = gelu_erf_q1 + c_q1; - gelu_out_q1 = data_q1 * gelu_sum_q1; - end + erf_L_q1 = poly_sq_q1 + c_q1; + gelu_erf_q1 = erf_sgn_q1 ? -erf_L_q1 : erf_L_q1; + gelu_sum_q1 = gelu_erf_q1 + c_q1; + gelu_out_q1 = (calc_en_q_i) ? data_q1 * gelu_sum_q1 : gelu_out_q2; end always_ff @(posedge clk_i or negedge rst_ni) begin if (!rst_ni) begin - c_q1 <= '0; - data_q1 <= '0; - erf_sgn_q1 <= '0; - poly_sq_q1 <= '0; + c_q1 <= '0; + data_q1 <= '0; + erf_sgn_q1 <= '0; + poly_sq_q1 <= '0; gelu_out_q2 <= '0; end else begin - c_q1 <= c_i; - data_q1 <= data_i; - erf_sgn_q1 <= erf_sgn_d; - poly_sq_q1 <= poly_sq_d; + c_q1 <= c_i; + data_q1 <= data_i; + erf_sgn_q1 <= erf_sgn_d; + poly_sq_q1 <= poly_sq_d; gelu_out_q2 <= gelu_out_q1; end end