29 changes: 21 additions & 8 deletions operators/gemm/design.py
@@ -12,11 +12,11 @@
     Kernel,
     ObjectFifo,
     Program,
-    GlobalBuffer,
+    Buffer,
     Runtime,
     Worker,
     WorkerRuntimeBarrier,
-    LocalBuffer,
+    Buffer,
     str_to_dtype,
 )
 from aie.iron.placers import SequentialPlacer
@@ -330,7 +330,7 @@ def my_matmul(
     # Runtime parameters
     rtps = [
         [
-            GlobalBuffer(
+            Buffer(
                 np.ndarray[(2,), np.dtype[np.int32]],
                 name=f"rtp{row}_{col}",
                 initial_value=np.array([0, 0], dtype=np.int32),
@@ -429,11 +429,17 @@ def my_matmul(
             C_l1l2_fifos[j][col] = c_tmp_fifos[j]
 
     # Tasks for each worker to perform
-    def core_fn(in_a, in_b, out_c, zero, matmul, convert_copy, my_rtp, barrier):
-        if use_larger_internal_buffer:
-            elem_out_internal = LocalBuffer(
-                type=C_l1_ty_internal,
-            )
+    def core_fn(
+        in_a,
+        in_b,
+        out_c,
+        zero,
+        matmul,
+        convert_copy,
+        my_rtp,
+        barrier,
+        elem_out_internal,
+    ):
         barrier.wait_for_value(1)
         rtp_K_div_k = my_rtp[0]
         rtp_n_tiles_per_core = my_rtp[1]
@@ -464,6 +470,12 @@ def core_fn(in_a, in_b, out_c, zero, matmul, convert_copy, my_rtp, barrier):
     for row in range(n_aie_rows):
         for col in range(n_aie_cols):
             tile_col, tile_row = core_tiles[row][col]
+            acc_buffer = None
+            if use_larger_internal_buffer:
+                acc_buffer = Buffer(
+                    type=C_l1_ty_internal, name=f"acc_buffer_{row}_{col}"
+                )
Comment on lines +475 to +477

Collaborator: Can buffers not be defined inside the core function? I guess there's no required MLIR context there?

Collaborator: The Buffer is added to the context when it is passed to the function, but this is transparent to the user. But yes, instead of having a GlobalBuffer and a LocalBuffer there is now just Buffer, which must be declared globally.
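
For illustration, a minimal sketch of the before/after pattern (names and types here are hypothetical, not taken from this PR):

```python
import numpy as np
from aie.iron import Buffer

ty = np.ndarray[(16,), np.dtype[np.int32]]  # placeholder buffer type

# Before: the buffer was created inside the core function with LocalBuffer.
# After: it is declared once at design level with a unique name; it is
# registered with the MLIR context at the point it is passed to the Worker.
scratch = Buffer(type=ty, name="scratch_buffer")

def core_fn(of_in, of_out, scratch):
    ...  # kernel body; scratch behaves like a local array here

# worker = Worker(core_fn, [of_in, of_out, scratch], placement=...)
```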


             workers.append(
                 Worker(
                     core_fn,
@@ -476,6 +488,7 @@ def core_fn(in_a, in_b, out_c, zero, matmul, convert_copy, my_rtp, barrier):
                         convert_copy_kernel if use_larger_internal_buffer else None,
                         rtps[row][col],
                         workerBarriers[row][col],
+                        acc_buffer,
                     ],
                     placement=Tile(tile_col, tile_row),
                     stack_size=0xD00,
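
Taken together, the gemm changes instantiate one pattern: a small, named Buffer is declared at design level, handed to the worker through its argument list, and read by core_fn once the runtime barrier is released. A condensed sketch of that flow (identifiers are hypothetical and the surrounding design context is elided):

```python
import numpy as np
from aie.iron import Buffer

# Runtime-parameter buffer for one core, mirroring the rtps table above;
# each tile gets its own uniquely named instance.
rtp = Buffer(
    np.ndarray[(2,), np.dtype[np.int32]],
    name="rtp0_0",
    initial_value=np.array([0, 0], dtype=np.int32),
)

def core_fn(in_a, in_b, out_c, my_rtp, barrier, elem_out_internal):
    # Block until the host has written the runtime parameters, then read them.
    barrier.wait_for_value(1)
    K_div_k = my_rtp[0]
    n_tiles_per_core = my_rtp[1]
    # elem_out_internal is None unless use_larger_internal_buffer is set,
    # so the body is expected to branch on it (as design.py does).
    ...
```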
49 changes: 36 additions & 13 deletions operators/mha/design.py
@@ -16,8 +16,8 @@
     Program,
     Runtime,
     Worker,
-    LocalBuffer,
-    GlobalBuffer,
+    Buffer,
+    Buffer,
     WorkerRuntimeBarrier,
 )
 from aie.iron.placers import SequentialPlacer
@@ -391,11 +391,17 @@ def fused_mha(
     )
 
     def batched_matmul_qk(
-        of_q, of_k, of_a_out, zero, matmul_QK, q_block_bias, mha_rtps, barrier
+        of_q,
+        of_k,
+        of_a_out,
+        zero,
+        matmul_QK,
+        q_block_bias,
+        mha_rtps,
+        barrier,
+        idx_buffer,
     ):
-
-        idx_buffer = LocalBuffer(initial_value=np.zeros(shape=(2,), dtype=np.int32))
-
         barrier.wait_for_value(1)
 
         loop_idx_q = mha_rtps[0]
@@ -437,14 +437,12 @@ def softmax(
         q_block_bias,
         mha_rtps,
         barrier,
+        idx_buffer,
+        scale_buffer,
     ):
 
         # VJUNG: The index buffer counts how many Q and KV blocks this worker has processed
         # From this info we can infer the position in A and P
-        idx_buffer = LocalBuffer(initial_value=np.zeros(shape=(2,), dtype=np.int32))
-        scale_buffer = LocalBuffer(
-            initial_value=np.zeros(shape=(4 * B_q,), dtype=dtype)
-        )
 
         barrier.wait_for_value(1)
 
@@ -502,10 +506,9 @@ def batched_matmul_pv(
         q_block_bias,
         mha_rtps,
         barrier,
+        idx_buffer,
     ):
 
-        idx_buffer = LocalBuffer(initial_value=np.zeros(shape=(2,), dtype=np.int32))
-
         barrier.wait_for_value(1)
 
         loop_idx_q = mha_rtps[0]
@@ -601,10 +604,10 @@ def batched_matmul_pv(
             of_o_out.release(1)
 
     # Runtime parameter for workers loop index
-    # VJUNG: We need one GlobalBuffer per worker since they need to be placed
+    # VJUNG: We need one Buffer per worker since they need to be placed
    mha_rtps_list = [
         [
-            GlobalBuffer(
+            Buffer(
                 np.ndarray[(4,), np.dtype[np.int32]],
                 name=f"mha_rtpss_{i}_stage{j}",
                 initial_value=None,
@@ -625,6 +628,10 @@ def batched_matmul_pv(
     softmax_workers = []
     matmul_pv_workers = []
     for i in range(number_of_pipelines):
+        idx_buffer_qk = Buffer(
+            initial_value=np.zeros(shape=(2,), dtype=np.int32),
+            name=f"idx_buffer_qk_{i}",
+        )
         matmul_workers.append(
             Worker(
                 batched_matmul_qk,
@@ -637,12 +644,21 @@ def batched_matmul_pv(
                     i,
                     mha_rtps_list[0][i],
                     worker_barrier_list[0][i],
+                    idx_buffer_qk,
                 ],
                 stack_size=0xD00,
                 placement=Tile(col=i, row=2),
                 while_true=False,
             )
         )
+        idx_buffer_softmax = Buffer(
+            initial_value=np.zeros(shape=(2,), dtype=np.int32),
+            name=f"idx_buffer_softmax_{i}",
+        )
+        scale_buffer_softmax = Buffer(
+            initial_value=np.zeros(shape=(4 * B_q,), dtype=dtype),
+            name=f"scale_buffer_softmax_{i}",
+        )
         softmax_workers.append(
             Worker(
                 softmax,
@@ -656,12 +672,18 @@ def batched_matmul_pv(
                     i,
                     mha_rtps_list[1][i],
                     worker_barrier_list[1][i],
+                    idx_buffer_softmax,
+                    scale_buffer_softmax,
                 ],
                 stack_size=0xD00,
                 placement=Tile(col=i, row=3),
                 while_true=False,
             )
         )
+        idx_buffer_pv = Buffer(
+            initial_value=np.zeros(shape=(2,), dtype=np.int32),
+            name=f"idx_buffer_pv_{i}",
+        )
         matmul_pv_workers.append(
             Worker(
                 batched_matmul_pv,
@@ -676,6 +698,7 @@ def batched_matmul_pv(
                     i,
                     mha_rtps_list[2][i],
                     worker_barrier_list[2][i],
+                    idx_buffer_pv,
                 ],
                 stack_size=0xD00,
                 placement=Tile(col=i, row=4),
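
The mha changes follow from the constraint stated in the VJUNG comment above: a Buffer is a placed, named allocation, so each worker in the pipeline needs its own uniquely named instance instead of a LocalBuffer created in the function body. A sketch of how a stage function can then keep per-worker state across invocations (hypothetical, inferred from the idx_buffer usage described above):

```python
import numpy as np
from aie.iron import Buffer

# One uniquely named state buffer per pipeline-stage instance, as in the
# for-i loop above (the trailing 0 stands in for the pipeline index).
idx_buffer = Buffer(
    initial_value=np.zeros(shape=(2,), dtype=np.int32),
    name="idx_buffer_qk_0",
)

def stage_fn(of_in, of_out, mha_rtps, barrier, idx_buffer):
    barrier.wait_for_value(1)
    # The index buffer counts how many Q and KV blocks this worker has
    # processed; the positions in A and P are inferred from these counts.
    q_blocks_done = idx_buffer[0]
    kv_blocks_done = idx_buffer[1]
    ...
    # Because the buffer is placed on the tile rather than local to the
    # function, the updated count persists into the next invocation of a
    # while_true=False worker.
    idx_buffer[0] = q_blocks_done + 1
```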
4 changes: 2 additions & 2 deletions requirements.txt
@@ -6,11 +6,11 @@
 # version of torch (don't need CUDA), so we give this index precedence over the
 # main PyPI. These indices are consulted in order of precedence by pip.
 --index-url https://download.pytorch.org/whl/cpu
---extra-index-url https://github.com/Xilinx/mlir-aie/releases/expanded_assets/v1.1.3
+--extra-index-url https://github.com/Xilinx/mlir-aie/releases/expanded_assets/v1.1.4
 --extra-index-url https://github.com/Xilinx/llvm-aie/releases/expanded_assets/nightly
 --extra-index-url https://pypi.org/simple
 
-mlir_aie==1.1.3
+mlir_aie==1.1.4
 llvm-aie
 
 black