4 changes: 2 additions & 2 deletions applications/llama_3.2_1b/src/block/gqa.py
@@ -84,8 +84,8 @@ def __init__(
         self.aie_softmax = AIESoftmax(
             num_aie_columns=1,
             num_channels=1,
-            size=prompt_length * prompt_length,
-            last_dim=prompt_length,
+            rows=prompt_length,
+            cols=prompt_length,
         )
         M_for_gemm = prompt_length + num_tokens
         self.aie_mha_gemm_qk = AIEGEMM(
4 changes: 2 additions & 2 deletions applications/llama_3.2_1b/test.py
@@ -11,8 +11,8 @@


 def generate_test_params():
-    prompt_lengths = [2048]
-    num_tokens_list = [40]
+    prompt_lengths = [2048, 13]
+    num_tokens_list = [40, 1]

     params = []
     names = []
57 changes: 40 additions & 17 deletions operators/common/aie_base.py
@@ -167,23 +167,46 @@ def _move_artifact_paths(self):
             todo.extend(artifact.depends)

     def run_runlist(self):
-        bos = set(
-            self.buffer_bos[buffer_arg]
-            for _, *buffer_args in self.runlist
-            for buffer_arg in buffer_args
-        )
-        insts_bos = set(
-            self.xrt_kernels[kernel_name][2] for (kernel_name, *_) in self.runlist
-        )
-        for bo in bos | insts_bos:
-            bo.sync(pyxrt.xclBOSyncDirection.XCL_BO_SYNC_BO_TO_DEVICE)
-        start = time.perf_counter()
-        self.xrt_runlist.execute()
-        self.xrt_runlist.wait()
-        stop = time.perf_counter()
-        for bo in bos:
-            bo.sync(pyxrt.xclBOSyncDirection.XCL_BO_SYNC_BO_FROM_DEVICE)
-        return stop - start
+        elapsed = 0.0
+        if self.xrt_runlist is None:
+            # Execute as separate xclbin kernel invocations
+            for i, (kernel_name, *buffer_args) in enumerate(self.runlist):
+                context, xrt_kernel, insts_bo, insts_len = self.xrt_kernels[kernel_name]
+                insts_bo.sync(pyxrt.xclBOSyncDirection.XCL_BO_SYNC_BO_TO_DEVICE)
+                bos = [self.buffer_bos[buffer_arg] for buffer_arg in buffer_args]
+                for bo in bos:
+                    bo.sync(pyxrt.xclBOSyncDirection.XCL_BO_SYNC_BO_TO_DEVICE)
+                opcode = 3
+                start = time.perf_counter()
+                run = xrt_kernel(opcode, insts_bo, insts_len, *bos)
+                result = run.wait()
+                stop = time.perf_counter()
+                elapsed += stop - start
+                if result != pyxrt.ert_cmd_state.ERT_CMD_STATE_COMPLETED:
+                    raise RuntimeError(
+                        f"Kernel {kernel_name} did not complete correctly: {result}"
+                    )
+                for bo in bos:
+                    bo.sync(pyxrt.xclBOSyncDirection.XCL_BO_SYNC_BO_FROM_DEVICE)
+        else:
+            bos = set(
+                self.buffer_bos[buffer_arg]
+                for _, *buffer_args in self.runlist
+                for buffer_arg in buffer_args
+            )
+            insts_bos = set(
+                self.xrt_kernels[kernel_name][2] for (kernel_name, *_) in self.runlist
+            )
+            for bo in bos | insts_bos:
+                bo.sync(pyxrt.xclBOSyncDirection.XCL_BO_SYNC_BO_TO_DEVICE)
+            start = time.perf_counter()
+            self.xrt_runlist.execute()
+            self.xrt_runlist.wait()
+            stop = time.perf_counter()
+            for bo in bos:
+                bo.sync(pyxrt.xclBOSyncDirection.XCL_BO_SYNC_BO_FROM_DEVICE)
+            elapsed = stop - start
+        return elapsed


 class AIEOperatorConstraintError(RuntimeError):
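The fallback branch above relies on XRT's standard NPU dispatch convention, which also appears verbatim in the diff. As a hedged, standalone sketch of that convention (the helper name is illustrative, not from this repo):

# Minimal sketch of a single xclbin kernel invocation, mirroring the
# fallback path above: opcode 3 tells the kernel to execute the
# instruction stream in insts_bo against the data buffer objects.
import pyxrt

def run_single_kernel(xrt_kernel, insts_bo, insts_len, data_bos):
    run = xrt_kernel(3, insts_bo, insts_len, *data_bos)  # opcode 3
    state = run.wait()  # blocks until the NPU reports completion
    if state != pyxrt.ert_cmd_state.ERT_CMD_STATE_COMPLETED:
        raise RuntimeError(f"kernel failed with state {state}")

Because each kernel is waited on individually, a failure is attributed to a specific kernel name, which is exactly what the runlist path cannot do.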
35 changes: 20 additions & 15 deletions operators/common/aie_context.py
@@ -14,14 +14,16 @@
 class AIEContext:
     """Context for managing AIE operator compilation and runtime state"""

-    def __init__(self):
+    def __init__(self, use_runlist=True):
         self.operators = []
         self.static_data_pool = {}
         self.device_manager = AIEDeviceManager()
         self.base_dir = Path(__file__).parent.parent.parent
         self.build_dir = Path(os.getcwd()) / "build"
         self.mlir_aie_dir = Path(aie.utils.config.root_path())
         self.peano_dir = Path(aie.utils.config.peano_install_dir())
+        # Disabling the XRT runlist sacrifices performance by executing kernels individually as separate xclbin invocations, but makes debugging easier (you can tell which part of the runlist execution failed)
+        self.use_runlist = use_runlist
         self._runtime_prepared = False

     def register_operator(self, operator):
@@ -146,20 +148,23 @@ def prepare_runtime(self):
            context, _ = self.device_manager.get_context_and_kernel(
                str(first_xclbin.path), first_xclbin_kernel_name
            )
-            op.xrt_runlist = pyxrt.runlist(context)
-            for i, (kernel_name, *buffer_args) in enumerate(op.runlist):
-                this_context, xrt_kernel, insts_bo, insts_len = op.xrt_kernels[
-                    kernel_name
-                ]
-                assert this_context == context
-                opcode = 3
-                run = pyxrt.run(xrt_kernel)
-                run.set_arg(0, opcode)
-                run.set_arg(1, insts_bo)
-                run.set_arg(2, insts_len)
-                for j, buffer_arg in enumerate(buffer_args):
-                    run.set_arg(j + 3, op.buffer_bos[buffer_arg])
-                op.xrt_runlist.add(run)
+            if self.use_runlist:
+                op.xrt_runlist = pyxrt.runlist(context)
+                for i, (kernel_name, *buffer_args) in enumerate(op.runlist):
+                    this_context, xrt_kernel, insts_bo, insts_len = op.xrt_kernels[
+                        kernel_name
+                    ]
+                    assert this_context == context
+                    opcode = 3
+                    run = pyxrt.run(xrt_kernel)
+                    run.set_arg(0, opcode)
+                    run.set_arg(1, insts_bo)
+                    run.set_arg(2, insts_len)
+                    for j, buffer_arg in enumerate(buffer_args):
+                        run.set_arg(j + 3, op.buffer_bos[buffer_arg])
+                    op.xrt_runlist.add(run)
+            else:
+                op.xrt_runlist = None

        # Log allocation info
        bo_count = sum(len(pool) for pool in bo_pools.values())
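Taken together with the run_runlist change above, the new flag is used roughly like this. A hedged sketch: the import paths and operator arguments are assumed from the file layout in this PR, not stated by it:

# Hypothetical debugging session: disable the runlist so a hang or error
# is attributed to a specific kernel rather than to the whole runlist.
from operators.common.aie_context import AIEContext
from operators.softmax.op import AIESoftmax

ctx = AIEContext(use_runlist=False)  # slower: one xclbin invocation per kernel
softmax = AIESoftmax(rows=64, cols=64, context=ctx)
# ... prepare and run as usual; a failing kernel now raises
# RuntimeError("Kernel <name> did not complete correctly: ...")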
17 changes: 5 additions & 12 deletions operators/softmax/op.py
@@ -21,16 +21,9 @@
 class AIESoftmax(AIEOperatorBase):

     def __init__(
-        self,
-        rows: int,
-        cols: int,
-        num_aie_columns=1,
-        num_channels=1,
-        tile_size=None,
-        context=None,
+        self, rows: int, cols: int, num_aie_columns=1, num_channels=1, context=None
     ):
         self.size = rows * cols
-        self.tile_size = tile_size if tile_size is not None else cols
         self.rows = rows
         self.cols = cols

Review thread on `self.size = rows * cols`:
  Contributor: Are the normalization operators structured like this now too?
  Author: I think only softmax is. The others still would need to be updated.
@@ -46,19 +39,19 @@ def __init__(
     def set_up_artifacts(self):
         # Compilation artifacts
         operator_dir = Path(__file__).parent
-        file_name_base = f"softmax_{self.num_columns}c_{self.num_channels}ch_{self.size}_{self.tile_size}t"
+        file_name_base = f"softmax_{self.num_columns}c_{self.num_channels}ch_{self.size}_{self.cols}t"

         mlir_artifact = PythonGeneratedMLIRArtifact.new(
             f"{file_name_base}.mlir",
             import_path=operator_dir / "design.py",
             callback_fn="softmax",
             callback_args=[
                 self.context.device_manager.device_type,
-                self.size,
+                self.rows * self.cols,
                 self.num_columns,
                 self.num_channels,
                 0,
-                self.tile_size,
+                self.cols,
             ],
         )

@@ -105,7 +98,7 @@ def set_up_runtime(self):
     def forward(self, x):
         applicable = (
             x.shape[-1] * x.shape[-2] == self.size
-            and x.shape[-1] == self.tile_size
+            and x.shape[-1] == self.cols
             and x.shape[-1] % 16 == 0
             and x.shape[-2] % 16 == 0
         )
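To make the interface change concrete, a hedged before/after construction sketch (the argument values are illustrative, not from the PR):

# Before: tile_size was a separate knob, defaulting to cols when omitted.
softmax = AIESoftmax(rows=2048, cols=2048, tile_size=2048)

# After: the tile width is always cols, so forward() requires the input's
# last dimension to equal cols (and both trailing dims to be multiples of 16).
softmax = AIESoftmax(rows=2048, cols=2048)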
1 change: 0 additions & 1 deletion operators/softmax/test.py
@@ -84,7 +84,6 @@ def test_softmax(input_length, num_aie_columns, num_channels, tile_size, aie_con
         cols=cols,
         num_aie_columns=num_aie_columns,
         num_channels=num_channels,
-        tile_size=tile_size,
         context=aie_context,
     )

26 changes: 13 additions & 13 deletions operators/swiglu_decode/op.py
@@ -46,8 +46,6 @@ def __init__(self, embedding_dim, hidden_dim, prio_accuracy=False, context=None)
         super().__init__(context=context)

     def set_up_artifacts(self):
-        # Artifact setup
-        # ---
         artifacts = []
         device_str = self.context.device_manager.device_str()

@@ -57,6 +55,7 @@ def set_up_artifacts(self):
             num_aie_columns=8,
             tile_size=1,
         )
+        self.gemv_1 = gemv_1
         gemv_1_xclbin, gemv_1_insts = gemv_1.get_artifacts(
             prefix="swiglu_decode_gemv_1_"
         )
@@ -75,6 +74,8 @@ def set_up_artifacts(self):
             num_channels=2,
             tile_size=self.hidden_dim // 16,
         )
+        self.silu = silu
+        self.hidden_dim_padded = silu.size
         silu_xclbin, silu_insts = silu.get_artifacts(prefix="swiglu_decode_silu_")
         silu_xclbin.xclbin_input = gemv_1_xclbin
         silu_xclbin.extra_flags += [
@@ -91,6 +92,8 @@ def set_up_artifacts(self):
             num_channels=2,
             tile_size=self.hidden_dim // 8,
         )
+        self.eltwise_mul = eltwise_mul
+        assert self.hidden_dim <= eltwise_mul.size <= self.hidden_dim_padded
         eltwise_mul_xclbin, eltwise_mul_insts = eltwise_mul.get_artifacts(
             prefix="swiglu_decode_eltwise_mul_"
         )
@@ -109,6 +112,7 @@ def set_up_artifacts(self):
             num_aie_columns=8,
             tile_size=1,
         )
+        self.gemv_2 = gemv_2
         gemv_2_xclbin, gemv_2_insts = gemv_2.get_artifacts(
             prefix="swiglu_decode_gemv_2_"
         )
@@ -135,28 +139,26 @@
         self.add_artifacts(artifacts)

     def set_up_runtime(self):
-        # Runtime setup
-        # ---
         self.add_buffer("input", self.embedding_dim)
         self.add_buffer(
             "weights_1",
-            self.embedding_dim * self.hidden_dim,
+            self.embedding_dim * self.hidden_dim_padded,
             static_data=torch_to_numpy(self.weights_1),
         )
         self.add_buffer(
             "weights_2",
-            self.embedding_dim * self.hidden_dim,
+            self.embedding_dim * self.hidden_dim_padded,
             static_data=torch_to_numpy(self.weights_2),
         )
         self.add_buffer(
             "weights_3",
-            self.hidden_dim * self.embedding_dim,
+            self.hidden_dim_padded * self.embedding_dim,
             static_data=torch_to_numpy(self.weights_3),
         )
-        self.add_buffer("left", self.hidden_dim)
-        self.add_buffer("left_swished", self.hidden_dim)
-        self.add_buffer("right", self.hidden_dim)
-        self.add_buffer("intermediate", self.hidden_dim)
+        self.add_buffer("left", self.hidden_dim_padded)
+        self.add_buffer("left_swished", self.hidden_dim_padded)
+        self.add_buffer("right", self.hidden_dim_padded)
+        self.add_buffer("intermediate", self.hidden_dim_padded)
         self.add_buffer("output", self.embedding_dim)
         self.add_kernel(
             "swiglu_gemv_1",
@@ -191,9 +193,7 @@ def set_up_runtime(self):
         self.add_to_runlist("swiglu_gemv_2", "weights_3", "intermediate", "output")

     def forward(self, x):
-        # Turn into a numpy vector and drop the batch and other higher dimensions, if any; will error if batch or other higher dimensions > 1
         x_flat = x.reshape(x.shape[-1])
-
         assert x_flat.shape[0] == self.embedding_dim

         self.write_buffer("input", x_flat)
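The hidden_dim_padded plumbing suggests the SiLU operator may round its element count up to some internal granularity, with silu.size reporting the padded count; sizing the weights and intermediate buffers to that count keeps every stage in the runlist operating on identically sized buffers. A minimal sketch of such a rule, assuming round-up-to-multiple padding (the real rule comes from AIESiLU's size attribute, which this PR does not show):

def round_up(n: int, multiple: int) -> int:
    """Round n up to the next multiple (assumed padding rule)."""
    return ((n + multiple - 1) // multiple) * multiple

# e.g. if AIESiLU pads to a multiple of 16 elements per channel:
hidden_dim = 8190
hidden_dim_padded = round_up(hidden_dim, 16)
assert hidden_dim <= hidden_dim_padded  # mirrors the assert in the diff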