From 4de8f1fbfc37c71ca8cbc0792c6e1620277b1ff3 Mon Sep 17 00:00:00 2001 From: Flavio Bergamaschi Date: Thu, 22 May 2025 18:47:04 -0400 Subject: [PATCH 01/12] Reference HERACLES Assembler tools Signed-off-by: Flavio Bergamaschi --- assembler_tools/empty_file.txt | 0 .../hec-assembler-tools/.gitignore | 205 ++ .../.pre-commit-config.yaml | 17 + .../hec-assembler-tools/CODEOWNERS | 2 + assembler_tools/hec-assembler-tools/README.md | 139 + .../assembler/common/__init__.py | 13 + .../assembler/common/config.py | 30 + .../assembler/common/constants.py | 440 +++ .../assembler/common/counter.py | 101 + .../assembler/common/cycle_tracking.py | 293 ++ .../assembler/common/decorators.py | 28 + .../assembler/common/priority_queue.py | 333 ++ .../assembler/common/queue_dict.py | 141 + .../assembler/common/run_config.py | 126 + .../assembler/common/utilities.py | 21 + .../assembler/instructions/__init__.py | 28 + .../assembler/instructions/cinst/__init__.py | 18 + .../assembler/instructions/cinst/bload.py | 181 ++ .../assembler/instructions/cinst/bones.py | 164 + .../assembler/instructions/cinst/cexit.py | 102 + .../instructions/cinst/cinstruction.py | 73 + .../assembler/instructions/cinst/cload.py | 165 + .../assembler/instructions/cinst/cnop.py | 105 + .../assembler/instructions/cinst/cstore.py | 195 ++ .../assembler/instructions/cinst/csyncm.py | 143 + .../assembler/instructions/cinst/ifetch.py | 141 + .../assembler/instructions/cinst/kgload.py | 204 ++ .../assembler/instructions/cinst/kgseed.py | 190 ++ .../assembler/instructions/cinst/kgstart.py | 112 + .../assembler/instructions/cinst/nload.py | 185 ++ .../instructions/cinst/xinstfetch.py | 155 + .../assembler/instructions/instruction.py | 687 ++++ .../assembler/instructions/minst/__init__.py | 8 + .../instructions/minst/minstruction.py | 77 + .../assembler/instructions/minst/mload.py | 209 ++ .../assembler/instructions/minst/mstore.py | 230 ++ .../assembler/instructions/minst/msyncc.py | 145 + .../assembler/instructions/xinst/__init__.py | 140 + .../assembler/instructions/xinst/add.py | 225 ++ .../assembler/instructions/xinst/copy.py | 244 ++ .../assembler/instructions/xinst/exit.py | 116 + .../assembler/instructions/xinst/intt.py | 249 ++ .../assembler/instructions/xinst/irshuffle.py | 401 +++ .../assembler/instructions/xinst/mac.py | 251 ++ .../assembler/instructions/xinst/maci.py | 267 ++ .../assembler/instructions/xinst/move.py | 222 ++ .../assembler/instructions/xinst/mul.py | 225 ++ .../assembler/instructions/xinst/muli.py | 246 ++ .../assembler/instructions/xinst/nop.py | 112 + .../assembler/instructions/xinst/ntt.py | 243 ++ .../instructions/xinst/parse_xntt.py | 267 ++ .../assembler/instructions/xinst/rshuffle.py | 375 +++ .../assembler/instructions/xinst/sub.py | 232 ++ .../assembler/instructions/xinst/twintt.py | 288 ++ .../assembler/instructions/xinst/twntt.py | 298 ++ .../instructions/xinst/xinstruction.py | 255 ++ .../assembler/instructions/xinst/xstore.py | 283 ++ .../assembler/isa_spec/__init__.py | 144 + .../assembler/isa_spec/cinst/__init__.py | 638 ++++ .../assembler/isa_spec/minst/__init__.py | 152 + .../assembler/isa_spec/xinst/__init__.py | 919 ++++++ .../assembler/memory_model/__init__.py | 455 +++ .../assembler/memory_model/hbm.py | 160 + .../assembler/memory_model/mem_info.py | 670 ++++ .../assembler/memory_model/mem_utilities.py | 134 + .../assembler/memory_model/memory_bank.py | 172 + .../assembler/memory_model/register_file.py | 393 +++ .../assembler/memory_model/spad.py | 322 ++ .../assembler/memory_model/variable.py | 431 +++ .../assembler/stages/__init__.py | 29 + .../assembler/stages/asm_scheduler.py | 2794 +++++++++++++++++ .../assembler/stages/preprocessor.py | 290 ++ .../assembler/stages/scheduler.py | 440 +++ .../hec-assembler-tools/atomic_tester.py | 261 ++ .../hec-assembler-tools/config/isa_spec.json | 225 ++ .../hec-assembler-tools/debug_tools/README.md | 108 + .../debug_tools/deadlock_test.py | 160 + .../debug_tools/isolation_test.py | 136 + .../hec-assembler-tools/debug_tools/main.py | 426 +++ .../debug_tools/order_test.py | 100 + .../xinst_timing_check/inject_bundles.py | 274 ++ .../xinst_timing_check/spec_config.py | 62 + .../xinst_timing_check/xinst/__init__.py | 24 + .../xinst_timing_check/xinst/add.py | 77 + .../xinst_timing_check/xinst/exit.py | 75 + .../xinst_timing_check/xinst/intt.py | 76 + .../xinst_timing_check/xinst/mac.py | 77 + .../xinst_timing_check/xinst/maci.py | 79 + .../xinst_timing_check/xinst/move.py | 79 + .../xinst_timing_check/xinst/mul.py | 79 + .../xinst_timing_check/xinst/muli.py | 79 + .../xinst_timing_check/xinst/nop.py | 81 + .../xinst_timing_check/xinst/ntt.py | 80 + .../xinst_timing_check/xinst/rshuffle.py | 140 + .../xinst_timing_check/xinst/sub.py | 79 + .../xinst_timing_check/xinst/twintt.py | 79 + .../xinst_timing_check/xinst/twntt.py | 79 + .../xinst_timing_check/xinst/xinstruction.py | 178 ++ .../xinst_timing_check/xinst/xstore.py | 83 + .../xinst_timing_check/xtiming_check.py | 371 +++ .../hec-assembler-tools/docsrc/changelog.md | 28 + .../docsrc/inst_spec/cinst/cinst_bload.md | 46 + .../docsrc/inst_spec/cinst/cinst_bones.md | 32 + .../docsrc/inst_spec/cinst/cinst_cexit.md | 25 + .../docsrc/inst_spec/cinst/cinst_cload.md | 39 + .../docsrc/inst_spec/cinst/cinst_cstore.md | 37 + .../docsrc/inst_spec/cinst/cinst_csyncm.md | 31 + .../docsrc/inst_spec/cinst/cinst_ifetch.md | 43 + .../docsrc/inst_spec/cinst/cinst_nload.md | 32 + .../docsrc/inst_spec/cinst/cinst_nop.md | 35 + .../inst_spec/cinst/cinst_xinstfetch.md | 38 + .../docsrc/inst_spec/minst/minst_mload.md | 28 + .../docsrc/inst_spec/minst/minst_mstore.md | 28 + .../docsrc/inst_spec/minst/minst_msyncc.md | 29 + .../docsrc/inst_spec/xinst/xinst_add.md | 30 + .../docsrc/inst_spec/xinst/xinst_exit.md | 27 + .../docsrc/inst_spec/xinst/xinst_intt.md | 37 + .../docsrc/inst_spec/xinst/xinst_mac.md | 31 + .../docsrc/inst_spec/xinst/xinst_maci.md | 31 + .../docsrc/inst_spec/xinst/xinst_move.md | 28 + .../docsrc/inst_spec/xinst/xinst_mul.md | 30 + .../docsrc/inst_spec/xinst/xinst_muli.md | 30 + .../docsrc/inst_spec/xinst/xinst_nop.md | 33 + .../docsrc/inst_spec/xinst/xinst_ntt.md | 37 + .../docsrc/inst_spec/xinst/xinst_rshuffle.md | 76 + .../docsrc/inst_spec/xinst/xinst_sub.md | 30 + .../docsrc/inst_spec/xinst/xinst_twintt.md | 35 + .../docsrc/inst_spec/xinst/xinst_twntt.md | 35 + .../docsrc/inst_spec/xinst/xinst_xstore.md | 36 + .../hec-assembler-tools/docsrc/specs.md | 278 ++ .../hec-assembler-tools/gen_he_ops.py | 264 ++ assembler_tools/hec-assembler-tools/he_as.py | 411 +++ .../hec-assembler-tools/he_link.py | 378 +++ .../hec-assembler-tools/he_prep.py | 155 + .../hec-assembler-tools/linker/__init__.py | 246 ++ .../linker/instructions/__init__.py | 25 + .../linker/instructions/cinst/__init__.py | 40 + .../linker/instructions/cinst/bload.py | 64 + .../linker/instructions/cinst/bones.py | 63 + .../linker/instructions/cinst/cexit.py | 47 + .../linker/instructions/cinst/cinstruction.py | 41 + .../linker/instructions/cinst/cload.py | 63 + .../linker/instructions/cinst/cnop.py | 75 + .../linker/instructions/cinst/cstore.py | 63 + .../linker/instructions/cinst/csyncm.py | 76 + .../linker/instructions/cinst/ifetch.py | 75 + .../linker/instructions/cinst/kgload.py | 42 + .../linker/instructions/cinst/kgseed.py | 42 + .../linker/instructions/cinst/kgstart.py | 42 + .../linker/instructions/cinst/nload.py | 64 + .../linker/instructions/cinst/xinstfetch.py | 103 + .../linker/instructions/instruction.py | 175 ++ .../linker/instructions/minst/__init__.py | 19 + .../linker/instructions/minst/minstruction.py | 42 + .../linker/instructions/minst/mload.py | 72 + .../linker/instructions/minst/mstore.py | 72 + .../linker/instructions/minst/msyncc.py | 76 + .../linker/instructions/xinst/__init__.py | 46 + .../linker/instructions/xinst/add.py | 48 + .../linker/instructions/xinst/exit.py | 47 + .../linker/instructions/xinst/intt.py | 47 + .../linker/instructions/xinst/mac.py | 47 + .../linker/instructions/xinst/maci.py | 47 + .../linker/instructions/xinst/move.py | 47 + .../linker/instructions/xinst/mul.py | 47 + .../linker/instructions/xinst/muli.py | 47 + .../linker/instructions/xinst/nop.py | 47 + .../linker/instructions/xinst/ntt.py | 46 + .../linker/instructions/xinst/rshuffle.py | 42 + .../linker/instructions/xinst/sub.py | 47 + .../linker/instructions/xinst/twintt.py | 47 + .../linker/instructions/xinst/twntt.py | 47 + .../linker/instructions/xinst/xinstruction.py | 64 + .../linker/instructions/xinst/xstore.py | 47 + .../hec-assembler-tools/linker/loader.py | 124 + .../linker/steps/__init__.py | 1 + .../linker/steps/program_linker.py | 352 +++ .../linker/steps/variable_discovery.py | 65 + .../hec-assembler-tools/requirements.txt | 14 + 179 files changed, 27531 insertions(+) delete mode 100644 assembler_tools/empty_file.txt create mode 100644 assembler_tools/hec-assembler-tools/.gitignore create mode 100644 assembler_tools/hec-assembler-tools/.pre-commit-config.yaml create mode 100644 assembler_tools/hec-assembler-tools/CODEOWNERS create mode 100644 assembler_tools/hec-assembler-tools/README.md create mode 100644 assembler_tools/hec-assembler-tools/assembler/common/__init__.py create mode 100644 assembler_tools/hec-assembler-tools/assembler/common/config.py create mode 100644 assembler_tools/hec-assembler-tools/assembler/common/constants.py create mode 100644 assembler_tools/hec-assembler-tools/assembler/common/counter.py create mode 100644 assembler_tools/hec-assembler-tools/assembler/common/cycle_tracking.py create mode 100644 assembler_tools/hec-assembler-tools/assembler/common/decorators.py create mode 100644 assembler_tools/hec-assembler-tools/assembler/common/priority_queue.py create mode 100644 assembler_tools/hec-assembler-tools/assembler/common/queue_dict.py create mode 100644 assembler_tools/hec-assembler-tools/assembler/common/run_config.py create mode 100644 assembler_tools/hec-assembler-tools/assembler/common/utilities.py create mode 100644 assembler_tools/hec-assembler-tools/assembler/instructions/__init__.py create mode 100644 assembler_tools/hec-assembler-tools/assembler/instructions/cinst/__init__.py create mode 100644 assembler_tools/hec-assembler-tools/assembler/instructions/cinst/bload.py create mode 100644 assembler_tools/hec-assembler-tools/assembler/instructions/cinst/bones.py create mode 100644 assembler_tools/hec-assembler-tools/assembler/instructions/cinst/cexit.py create mode 100644 assembler_tools/hec-assembler-tools/assembler/instructions/cinst/cinstruction.py create mode 100644 assembler_tools/hec-assembler-tools/assembler/instructions/cinst/cload.py create mode 100644 assembler_tools/hec-assembler-tools/assembler/instructions/cinst/cnop.py create mode 100644 assembler_tools/hec-assembler-tools/assembler/instructions/cinst/cstore.py create mode 100644 assembler_tools/hec-assembler-tools/assembler/instructions/cinst/csyncm.py create mode 100644 assembler_tools/hec-assembler-tools/assembler/instructions/cinst/ifetch.py create mode 100644 assembler_tools/hec-assembler-tools/assembler/instructions/cinst/kgload.py create mode 100644 assembler_tools/hec-assembler-tools/assembler/instructions/cinst/kgseed.py create mode 100644 assembler_tools/hec-assembler-tools/assembler/instructions/cinst/kgstart.py create mode 100644 assembler_tools/hec-assembler-tools/assembler/instructions/cinst/nload.py create mode 100644 assembler_tools/hec-assembler-tools/assembler/instructions/cinst/xinstfetch.py create mode 100644 assembler_tools/hec-assembler-tools/assembler/instructions/instruction.py create mode 100644 assembler_tools/hec-assembler-tools/assembler/instructions/minst/__init__.py create mode 100644 assembler_tools/hec-assembler-tools/assembler/instructions/minst/minstruction.py create mode 100644 assembler_tools/hec-assembler-tools/assembler/instructions/minst/mload.py create mode 100644 assembler_tools/hec-assembler-tools/assembler/instructions/minst/mstore.py create mode 100644 assembler_tools/hec-assembler-tools/assembler/instructions/minst/msyncc.py create mode 100644 assembler_tools/hec-assembler-tools/assembler/instructions/xinst/__init__.py create mode 100644 assembler_tools/hec-assembler-tools/assembler/instructions/xinst/add.py create mode 100644 assembler_tools/hec-assembler-tools/assembler/instructions/xinst/copy.py create mode 100644 assembler_tools/hec-assembler-tools/assembler/instructions/xinst/exit.py create mode 100644 assembler_tools/hec-assembler-tools/assembler/instructions/xinst/intt.py create mode 100644 assembler_tools/hec-assembler-tools/assembler/instructions/xinst/irshuffle.py create mode 100644 assembler_tools/hec-assembler-tools/assembler/instructions/xinst/mac.py create mode 100644 assembler_tools/hec-assembler-tools/assembler/instructions/xinst/maci.py create mode 100644 assembler_tools/hec-assembler-tools/assembler/instructions/xinst/move.py create mode 100644 assembler_tools/hec-assembler-tools/assembler/instructions/xinst/mul.py create mode 100644 assembler_tools/hec-assembler-tools/assembler/instructions/xinst/muli.py create mode 100644 assembler_tools/hec-assembler-tools/assembler/instructions/xinst/nop.py create mode 100644 assembler_tools/hec-assembler-tools/assembler/instructions/xinst/ntt.py create mode 100644 assembler_tools/hec-assembler-tools/assembler/instructions/xinst/parse_xntt.py create mode 100644 assembler_tools/hec-assembler-tools/assembler/instructions/xinst/rshuffle.py create mode 100644 assembler_tools/hec-assembler-tools/assembler/instructions/xinst/sub.py create mode 100644 assembler_tools/hec-assembler-tools/assembler/instructions/xinst/twintt.py create mode 100644 assembler_tools/hec-assembler-tools/assembler/instructions/xinst/twntt.py create mode 100644 assembler_tools/hec-assembler-tools/assembler/instructions/xinst/xinstruction.py create mode 100644 assembler_tools/hec-assembler-tools/assembler/instructions/xinst/xstore.py create mode 100644 assembler_tools/hec-assembler-tools/assembler/isa_spec/__init__.py create mode 100644 assembler_tools/hec-assembler-tools/assembler/isa_spec/cinst/__init__.py create mode 100644 assembler_tools/hec-assembler-tools/assembler/isa_spec/minst/__init__.py create mode 100644 assembler_tools/hec-assembler-tools/assembler/isa_spec/xinst/__init__.py create mode 100644 assembler_tools/hec-assembler-tools/assembler/memory_model/__init__.py create mode 100644 assembler_tools/hec-assembler-tools/assembler/memory_model/hbm.py create mode 100644 assembler_tools/hec-assembler-tools/assembler/memory_model/mem_info.py create mode 100644 assembler_tools/hec-assembler-tools/assembler/memory_model/mem_utilities.py create mode 100644 assembler_tools/hec-assembler-tools/assembler/memory_model/memory_bank.py create mode 100644 assembler_tools/hec-assembler-tools/assembler/memory_model/register_file.py create mode 100644 assembler_tools/hec-assembler-tools/assembler/memory_model/spad.py create mode 100644 assembler_tools/hec-assembler-tools/assembler/memory_model/variable.py create mode 100644 assembler_tools/hec-assembler-tools/assembler/stages/__init__.py create mode 100644 assembler_tools/hec-assembler-tools/assembler/stages/asm_scheduler.py create mode 100644 assembler_tools/hec-assembler-tools/assembler/stages/preprocessor.py create mode 100644 assembler_tools/hec-assembler-tools/assembler/stages/scheduler.py create mode 100644 assembler_tools/hec-assembler-tools/atomic_tester.py create mode 100644 assembler_tools/hec-assembler-tools/config/isa_spec.json create mode 100644 assembler_tools/hec-assembler-tools/debug_tools/README.md create mode 100644 assembler_tools/hec-assembler-tools/debug_tools/deadlock_test.py create mode 100644 assembler_tools/hec-assembler-tools/debug_tools/isolation_test.py create mode 100644 assembler_tools/hec-assembler-tools/debug_tools/main.py create mode 100644 assembler_tools/hec-assembler-tools/debug_tools/order_test.py create mode 100644 assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/inject_bundles.py create mode 100644 assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/spec_config.py create mode 100644 assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/xinst/__init__.py create mode 100644 assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/xinst/add.py create mode 100644 assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/xinst/exit.py create mode 100644 assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/xinst/intt.py create mode 100644 assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/xinst/mac.py create mode 100644 assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/xinst/maci.py create mode 100644 assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/xinst/move.py create mode 100644 assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/xinst/mul.py create mode 100644 assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/xinst/muli.py create mode 100644 assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/xinst/nop.py create mode 100644 assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/xinst/ntt.py create mode 100644 assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/xinst/rshuffle.py create mode 100644 assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/xinst/sub.py create mode 100644 assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/xinst/twintt.py create mode 100644 assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/xinst/twntt.py create mode 100644 assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/xinst/xinstruction.py create mode 100644 assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/xinst/xstore.py create mode 100644 assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/xtiming_check.py create mode 100644 assembler_tools/hec-assembler-tools/docsrc/changelog.md create mode 100644 assembler_tools/hec-assembler-tools/docsrc/inst_spec/cinst/cinst_bload.md create mode 100644 assembler_tools/hec-assembler-tools/docsrc/inst_spec/cinst/cinst_bones.md create mode 100644 assembler_tools/hec-assembler-tools/docsrc/inst_spec/cinst/cinst_cexit.md create mode 100644 assembler_tools/hec-assembler-tools/docsrc/inst_spec/cinst/cinst_cload.md create mode 100644 assembler_tools/hec-assembler-tools/docsrc/inst_spec/cinst/cinst_cstore.md create mode 100644 assembler_tools/hec-assembler-tools/docsrc/inst_spec/cinst/cinst_csyncm.md create mode 100644 assembler_tools/hec-assembler-tools/docsrc/inst_spec/cinst/cinst_ifetch.md create mode 100644 assembler_tools/hec-assembler-tools/docsrc/inst_spec/cinst/cinst_nload.md create mode 100644 assembler_tools/hec-assembler-tools/docsrc/inst_spec/cinst/cinst_nop.md create mode 100644 assembler_tools/hec-assembler-tools/docsrc/inst_spec/cinst/cinst_xinstfetch.md create mode 100644 assembler_tools/hec-assembler-tools/docsrc/inst_spec/minst/minst_mload.md create mode 100644 assembler_tools/hec-assembler-tools/docsrc/inst_spec/minst/minst_mstore.md create mode 100644 assembler_tools/hec-assembler-tools/docsrc/inst_spec/minst/minst_msyncc.md create mode 100644 assembler_tools/hec-assembler-tools/docsrc/inst_spec/xinst/xinst_add.md create mode 100644 assembler_tools/hec-assembler-tools/docsrc/inst_spec/xinst/xinst_exit.md create mode 100644 assembler_tools/hec-assembler-tools/docsrc/inst_spec/xinst/xinst_intt.md create mode 100644 assembler_tools/hec-assembler-tools/docsrc/inst_spec/xinst/xinst_mac.md create mode 100644 assembler_tools/hec-assembler-tools/docsrc/inst_spec/xinst/xinst_maci.md create mode 100644 assembler_tools/hec-assembler-tools/docsrc/inst_spec/xinst/xinst_move.md create mode 100644 assembler_tools/hec-assembler-tools/docsrc/inst_spec/xinst/xinst_mul.md create mode 100644 assembler_tools/hec-assembler-tools/docsrc/inst_spec/xinst/xinst_muli.md create mode 100644 assembler_tools/hec-assembler-tools/docsrc/inst_spec/xinst/xinst_nop.md create mode 100644 assembler_tools/hec-assembler-tools/docsrc/inst_spec/xinst/xinst_ntt.md create mode 100644 assembler_tools/hec-assembler-tools/docsrc/inst_spec/xinst/xinst_rshuffle.md create mode 100644 assembler_tools/hec-assembler-tools/docsrc/inst_spec/xinst/xinst_sub.md create mode 100644 assembler_tools/hec-assembler-tools/docsrc/inst_spec/xinst/xinst_twintt.md create mode 100644 assembler_tools/hec-assembler-tools/docsrc/inst_spec/xinst/xinst_twntt.md create mode 100644 assembler_tools/hec-assembler-tools/docsrc/inst_spec/xinst/xinst_xstore.md create mode 100644 assembler_tools/hec-assembler-tools/docsrc/specs.md create mode 100644 assembler_tools/hec-assembler-tools/gen_he_ops.py create mode 100644 assembler_tools/hec-assembler-tools/he_as.py create mode 100644 assembler_tools/hec-assembler-tools/he_link.py create mode 100644 assembler_tools/hec-assembler-tools/he_prep.py create mode 100644 assembler_tools/hec-assembler-tools/linker/__init__.py create mode 100644 assembler_tools/hec-assembler-tools/linker/instructions/__init__.py create mode 100644 assembler_tools/hec-assembler-tools/linker/instructions/cinst/__init__.py create mode 100644 assembler_tools/hec-assembler-tools/linker/instructions/cinst/bload.py create mode 100644 assembler_tools/hec-assembler-tools/linker/instructions/cinst/bones.py create mode 100644 assembler_tools/hec-assembler-tools/linker/instructions/cinst/cexit.py create mode 100644 assembler_tools/hec-assembler-tools/linker/instructions/cinst/cinstruction.py create mode 100644 assembler_tools/hec-assembler-tools/linker/instructions/cinst/cload.py create mode 100644 assembler_tools/hec-assembler-tools/linker/instructions/cinst/cnop.py create mode 100644 assembler_tools/hec-assembler-tools/linker/instructions/cinst/cstore.py create mode 100644 assembler_tools/hec-assembler-tools/linker/instructions/cinst/csyncm.py create mode 100644 assembler_tools/hec-assembler-tools/linker/instructions/cinst/ifetch.py create mode 100644 assembler_tools/hec-assembler-tools/linker/instructions/cinst/kgload.py create mode 100644 assembler_tools/hec-assembler-tools/linker/instructions/cinst/kgseed.py create mode 100644 assembler_tools/hec-assembler-tools/linker/instructions/cinst/kgstart.py create mode 100644 assembler_tools/hec-assembler-tools/linker/instructions/cinst/nload.py create mode 100644 assembler_tools/hec-assembler-tools/linker/instructions/cinst/xinstfetch.py create mode 100644 assembler_tools/hec-assembler-tools/linker/instructions/instruction.py create mode 100644 assembler_tools/hec-assembler-tools/linker/instructions/minst/__init__.py create mode 100644 assembler_tools/hec-assembler-tools/linker/instructions/minst/minstruction.py create mode 100644 assembler_tools/hec-assembler-tools/linker/instructions/minst/mload.py create mode 100644 assembler_tools/hec-assembler-tools/linker/instructions/minst/mstore.py create mode 100644 assembler_tools/hec-assembler-tools/linker/instructions/minst/msyncc.py create mode 100644 assembler_tools/hec-assembler-tools/linker/instructions/xinst/__init__.py create mode 100644 assembler_tools/hec-assembler-tools/linker/instructions/xinst/add.py create mode 100644 assembler_tools/hec-assembler-tools/linker/instructions/xinst/exit.py create mode 100644 assembler_tools/hec-assembler-tools/linker/instructions/xinst/intt.py create mode 100644 assembler_tools/hec-assembler-tools/linker/instructions/xinst/mac.py create mode 100644 assembler_tools/hec-assembler-tools/linker/instructions/xinst/maci.py create mode 100644 assembler_tools/hec-assembler-tools/linker/instructions/xinst/move.py create mode 100644 assembler_tools/hec-assembler-tools/linker/instructions/xinst/mul.py create mode 100644 assembler_tools/hec-assembler-tools/linker/instructions/xinst/muli.py create mode 100644 assembler_tools/hec-assembler-tools/linker/instructions/xinst/nop.py create mode 100644 assembler_tools/hec-assembler-tools/linker/instructions/xinst/ntt.py create mode 100644 assembler_tools/hec-assembler-tools/linker/instructions/xinst/rshuffle.py create mode 100644 assembler_tools/hec-assembler-tools/linker/instructions/xinst/sub.py create mode 100644 assembler_tools/hec-assembler-tools/linker/instructions/xinst/twintt.py create mode 100644 assembler_tools/hec-assembler-tools/linker/instructions/xinst/twntt.py create mode 100644 assembler_tools/hec-assembler-tools/linker/instructions/xinst/xinstruction.py create mode 100644 assembler_tools/hec-assembler-tools/linker/instructions/xinst/xstore.py create mode 100644 assembler_tools/hec-assembler-tools/linker/loader.py create mode 100644 assembler_tools/hec-assembler-tools/linker/steps/__init__.py create mode 100644 assembler_tools/hec-assembler-tools/linker/steps/program_linker.py create mode 100644 assembler_tools/hec-assembler-tools/linker/steps/variable_discovery.py create mode 100644 assembler_tools/hec-assembler-tools/requirements.txt diff --git a/assembler_tools/empty_file.txt b/assembler_tools/empty_file.txt deleted file mode 100644 index e69de29b..00000000 diff --git a/assembler_tools/hec-assembler-tools/.gitignore b/assembler_tools/hec-assembler-tools/.gitignore new file mode 100644 index 00000000..5daf1d05 --- /dev/null +++ b/assembler_tools/hec-assembler-tools/.gitignore @@ -0,0 +1,205 @@ +#======================== +# Intermediate and Output files +#======================== + +*.tmp +*.temp +*.mem +*.minst +*.cinst +*.xinst +*.csv +*.out +#*# +*~ +tmp/ + + +# Local files +#======================== + +*.yml +*.pyc +*.bak +*.pkl +*.lock +*.swp +tfedlrn.egg-info/ +bin/out + +#======================== +# Generated docs +#======================== + +*.htm +*.html +*.pdf +html +latex +[Dd]ocs/ + + +#======================== +# Eclipse & PyDev intermediate files +#======================== + +.metadata/ +RemoteSystemsTempFiles/ +.settings +.project +.pydevproject + +#======================== +# Visual Studio +#======================== + +# User-specific files +*.rsuser +*.suo +*.user +*.userosscache +*.sln.docstates + +# User-specific files (MonoDevelop/Xamarin Studio) +*.userprefs + +# Mono auto generated files +mono_crash.* + +# Visual Studio cache/options directory +.vs/ +.vscode/ + +#======================== +# Python general stuff +#======================== + +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +pip-wheel-metadata/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +.python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ diff --git a/assembler_tools/hec-assembler-tools/.pre-commit-config.yaml b/assembler_tools/hec-assembler-tools/.pre-commit-config.yaml new file mode 100644 index 00000000..6e479482 --- /dev/null +++ b/assembler_tools/hec-assembler-tools/.pre-commit-config.yaml @@ -0,0 +1,17 @@ +# Copyright (C) 2023 Intel Corporation + +repos: + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.4.0 # Updated 2023/02 + hooks: + - id: trailing-whitespace + - id: end-of-file-fixer + - id: check-merge-conflict + - id: mixed-line-ending + - id: check-yaml + - repo: https://github.com/Lucas-C/pre-commit-hooks + rev: v1.1.15 + hooks: + - id: remove-tabs + files: \.(py)$ + args: [--whitespaces-count, '4'] diff --git a/assembler_tools/hec-assembler-tools/CODEOWNERS b/assembler_tools/hec-assembler-tools/CODEOWNERS new file mode 100644 index 00000000..ffa15bdf --- /dev/null +++ b/assembler_tools/hec-assembler-tools/CODEOWNERS @@ -0,0 +1,2 @@ +# Default codeowners for all files +* @faberga @ChrisWilkerson @sidezrw @jlhcrawford @hamishun @kylanerace @jobottle diff --git a/assembler_tools/hec-assembler-tools/README.md b/assembler_tools/hec-assembler-tools/README.md new file mode 100644 index 00000000..90d45b38 --- /dev/null +++ b/assembler_tools/hec-assembler-tools/README.md @@ -0,0 +1,139 @@ +# HERACLES Code Generation Framework User Guide + +This tool, also known as the "assembler", takes a pre-generated Polynomial Instruction Set Architecture (P-ISA) kernel containing instructions that use an abstract, flat memory model for polynomial operations, such as those applied in homomorphic encryption (HE), and maps them to a corresponding set of instructions compatible with the HERACLES architecture, accounting for hardware restrictions, including memory management for the HERACLES memory model. + +## Table of Contents +1. [Dependencies](#dependencies) +2. [Inputs](#as_inputs) +3. [Outputs](#as_outputs) + 1. [Assembler Instruction Specs](#asm_specs) +4. [Executing the Assembler](#executing_asm) + 1. [Running for a Pre-Generated Kernel](#executing_single) + 2. [Running for a Batch of Operations](#executing_batch) +5. [Debug Tools](./debug_tools/README.md) + +## Dependencies + +This project is Python based. Dependencies for the project are as follows: + +- Python 3.10 or newer (Tested with 3.10) +- Python pip 22.0 or newer +- Requirements for the Python environment listed in file `requirements.txt`. + + Install with `pip`: + + ```bash + pip3 install -r requirements.txt + ``` + +It is recommended to install Python dependencies and run inside a [Python virtual environment](https://virtualenv.pypa.io/en/stable/index.html) to avoid polluting the global Python. + +## Inputs + +The assembler framework requires two inputs: + +- **Abstract P-ISA kernel**: a kernel of instructions in abstract P-ISA where the memory model is abstracted with a flat, infinite structure. + +- **Memory mapping metadata**: metadata information indicating the location where variable identifiers from the kernel are stored in the HERACLES HBM. + +Kernels and metadata are structured in comma-separated value (csv) files. + +P-ISA kernels, along with corresponding memory metadata required as input to the assembler, are generated by Python script `HERACLES-SEAL-isa-mapping/kernels/run_he_op.py` in the repo [HERACLES-SEAL-isa-mapping](https://github.com/IntelLabs/HERACLES-SEAL-isa-mapping) + +## Outputs + +On a successful run, given a P-ISA kernel in file `filename.csv` (and corresponding memory metadata file), the assmebler generates three files: + +- `filename.minst`: contains the list of instructions for the MINST queue. +- `filename.cinst`: contains the list of instructions for the CINST queue. +- `filename.xinst`: contains the list of instructions for the XINST queue. + +### Assembler Output Instruction Specs + +The format for the output files and instruction set can be found at [HCGF Instruction Specification](docsrc/specs.md). + +## Executing the Assembler + +There are two ways to execute the assembler: + +- [Running on a pre-generated kernel](#executing_single): uses the main interface of the assembler to assemble a single pre-existing kernel. + + This method is intended for a production chain. + +or + +- [Running for a batch of kernels](#executing_batch): uses a provided script wrapper to generate a collection of kernels and runs them through the assembler. + + This method is intended for testing purposes as it generates test kernels using external tools before assembling. + +### Running for a Pre-Generated Kernel + +Given a P-ISA kernel (`filename.csv`) and corresponding memory mapping file (`filename.mem`), there are three steps to assemble them into HERACLES code. + +1. Pre-process the P-ISA input kernel using `he_prep.py`. + +```bash +# pre-process kernel: outputs filename.tw.csv +python3 he_prep.py filename.csv +``` + +2. Assemble the pre-processed result using `he_as.py`. + +```bash +# use output from pre-processing as input to asm: +# outputs filename.tw.minst, filename.tw.cinst, filename.tw.xinst +python3 he_as.py filename.tw.csv --input_mem_file filename.mem +``` + +3. Link the assembler output into a HERACLES program using `he_link.py`. + +```bash +# link assembled output (input prefix: filename.tw) +# outputs filename.minst, filename.cinst, filename.xinst +python3 he_link.py filename.tw --input_mem_file filename.mem --output_prefix filename +``` + +This will generate the main three output files in the same directory as the input file: + +- `filename.minst`: contains the list of instructions for the MINST queue. +- `filename.cinst`: contains the list of instructions for the CINST queue. +- `filename.xinst`: contains the list of instructions for the XINST queue. + +Intermediate files, if any, are kept as well. + +The linker program is able to link several assembled kernels into a single HERACLES program, given a correct memory mapping for the resulting program. + +This version of executing is intended for the assembler to be usable as part of a compilation pipeline. + +Use commands below for more configuration information, including changing output directories, input and output filenames, etc. + +```bash +python3 he_prep.py -h +python3 he_as.py -h +python3 he_link.py -h +``` + +### Running for a Batch of Operations + +This project provides script `gen_he_ops.py` that allows for assembling a batch of P-ISA kernels generated for HE operations. It calls the generator script internally to generate a batch of kernels, and then runs them through the assembler. + +Since the script to generate P-ISA kernels resides in another repo (HERACLES-SEAL-isa-mapping), we must specify the location of the cloned external repo using the environment variable `HERACLES_MAPPING_PATH`. Correctly setting this variable should result in the following path being valid: `$HERACLES_MAPPING_PATH/kernels/run_he_ops.py` . + +Provided script, `gen_he_ops.py`, takes in a YAML configuration file that specifies parameters for operations to assemble. To obtain a template for the configuration file, use the script itself with the `--dump` command line flag. Use `-h` flag for more information. + +```bash +# save template for configuration file to ./config.yaml +python3 gen_he_ops.py config.yaml --dump +``` + +Set the parameters in the configuration file to match your needs and then execute the script as shown below (code for Linux terminal). + +```bash +# env variable pointing to HERACLES-SEAL-isa-mapping +export HERACLES_MAPPING_PATH=/path/to/HERACLES-SEAL-isa-mapping +python3 gen_he_ops.py config.yaml +``` + +Based on your chosen configuration, this will generate kernels, run them through the assembler and place all the outputs (and intermediate files) in the output directory specified in the configuration file. + +This way of executing is mostly intended for testing purposes. diff --git a/assembler_tools/hec-assembler-tools/assembler/common/__init__.py b/assembler_tools/hec-assembler-tools/assembler/common/__init__.py new file mode 100644 index 00000000..535ed76a --- /dev/null +++ b/assembler_tools/hec-assembler-tools/assembler/common/__init__.py @@ -0,0 +1,13 @@ +import os + +def makeUniquePath(path: str) -> str: + """ + Returns a unique, normalized, and absolute version of the given file path. + + Args: + path (str): The file path to be processed. + + Returns: + str: A unique, normalized, and absolute version of the input path. + """ + return os.path.normcase(os.path.realpath(os.path.expanduser(path))) diff --git a/assembler_tools/hec-assembler-tools/assembler/common/config.py b/assembler_tools/hec-assembler-tools/assembler/common/config.py new file mode 100644 index 00000000..4e31682e --- /dev/null +++ b/assembler_tools/hec-assembler-tools/assembler/common/config.py @@ -0,0 +1,30 @@ + +class GlobalConfig: + """ + A configuration class for controlling various aspects of the assembler's behavior. + + Attributes: + suppressComments (bool): + If True, no comments will be emitted in the output generated by the assembler. + + useHBMPlaceHolders (bool): + Specifies whether to use placeholders (names) for variable locations in HBM + or the actual variable locations. + + useXInstFetch (bool): + Specifies whether `xinstfetch` instructions should be added into CInstQ or not. + When no `xinstfetch` instructions are added, it is assumed that the HERACLES + automated mechanism for `xinstfetch` will be activated. + + debugVerbose (int): + If greater than 0, verbose prints will occur. Its value indicates how often to + print within loops (every `debugVerbose` iterations). This is used for internal + debugging purposes. + hashHBM (bool): Specifies whether the target architecture has HBM or not. + """ + + suppressComments = False + useHBMPlaceHolders = True + useXInstFetch = True + debugVerbose: int = 0 + hasHBM = True diff --git a/assembler_tools/hec-assembler-tools/assembler/common/constants.py b/assembler_tools/hec-assembler-tools/assembler/common/constants.py new file mode 100644 index 00000000..9c11090c --- /dev/null +++ b/assembler_tools/hec-assembler-tools/assembler/common/constants.py @@ -0,0 +1,440 @@ + +from .decorators import * + +class Constants: + """ + Contains project level and global constants that won't fit logically into any other category. + + Attributes: + KILOBYTE (int): Number of bytes in a kilobyte (2^10). + MEGABYTE (int): Number of bytes in a megabyte (2^20). + GIGABYTE (int): Number of bytes in a gigabyte (2^30). + WORD_SIZE (int): Word size in bytes (32KB/word). + + REPLACEMENT_POLICY_FTBU (str): Identifier for the "furthest used" replacement policy. + REPLACEMENT_POLICY_LRU (str): Identifier for the "least recently used" replacement policy. + REPLACEMENT_POLICIES (tuple): Tuple containing all replacement policy identifiers. + + XINSTRUCTION_SIZE_BYTES (int): Size of an x-instruction in bytes. + MAX_BUNDLE_SIZE (int): Maximum number of instructions in a bundle. + MAX_BUNDLE_SIZE_BYTES (int): Maximum bundle size in bytes. + + TW_GRAMMAR_SEPARATOR (str): Separator for twiddle arguments used in grammar parsing. + OPERATIONS (list): List of high-level operations supported by the system. + """ + + # Data Constants + # -------------- + + @classproperty + def KILOBYTE(cls) -> int: + """Number of bytes in a kilobyte (2^10).""" + return 2**10 + + @classproperty + def MEGABYTE(csl) -> int: + """Number of bytes in a megabyte (2^20).""" + return 2**20 + + @classproperty + def GIGABYTE(cls) -> int: + """Number of bytes in a gigabyte (2^30).""" + return 2**30 + + @classproperty + def WORD_SIZE(cls) -> int: + """Word size in bytes (32KB/word).""" + return 32 * cls.KILOBYTE + + # Replacement Policies Constants + # ------------------------------ + + @classproperty + def REPLACEMENT_POLICY_FTBU(cls) -> str: + """Identifier for the "furthest used" replacement policy.""" + return "ftbu" + + @classproperty + def REPLACEMENT_POLICY_LRU(cls) -> str: + """Identifier for the "least recently used" replacement policy.""" + return "lru" + + @classproperty + def REPLACEMENT_POLICIES(cls) -> tuple: + """Tuple containing all replacement policy identifiers.""" + return ( cls.REPLACEMENT_POLICY_FTBU, cls.REPLACEMENT_POLICY_LRU ) + + # Misc Constants + # -------------- + + @classproperty + def XINSTRUCTION_SIZE_BYTES(cls) -> int: + """Size of an x-instruction in bytes.""" + return 8 + + @classproperty + def MAX_BUNDLE_SIZE(cls) -> int: + """Maximum number of instructions in a bundle.""" + return 64 + + @classproperty + def MAX_BUNDLE_SIZE_BYTES(cls) -> int: + """Maximum bundle size in bytes.""" + return cls.XINSTRUCTION_SIZE_BYTES * cls.MAX_BUNDLE_SIZE + + @classproperty + def TW_GRAMMAR_SEPARATOR(cls) -> str: + """ + Separator for twiddle arguments. + + Used in the grammar to parse the twiddle argument of an xntt kernel operation. + """ + return "_" + + @classproperty + def OPERATIONS(cls) -> list: + """List of high-level operations supported by the system.""" + return [ "add", "mul", "ntt", "intt", "relin", "mod_switch", "rotate", + "square", "add_plain", "add_corrected", "mul_plain", "rescale", + "boot_dot_prod", "boot_mod_drop_scale", "boot_mul_const", "boot_galois_plain" ] + +def convertBytes2Words(bytes: int) -> int: + """ + Converts a size in bytes to the equivalent number of words. + + Args: + bytes (int): The size in bytes to be converted. + + Returns: + int: The equivalent size in words. + """ + return int(bytes / Constants.WORD_SIZE) + +def convertWords2Bytes(words: int) -> int: + """ + Converts a size in words to the equivalent number of bytes. + + Args: + words (int): The size in words to be converted. + + Returns: + int: The equivalent size in bytes. + """ + return words * Constants.WORD_SIZE + +class MemInfo: + """ + Constants related to memory information, read from the P-ISA kernel memory file. + + This class provides a structured way to access various constants and keywords + used in the P-ISA kernel memory file, including keywords for loading and storing + data, metadata fields, and metadata targets. + """ + + class Keyword: + """ + Keywords for loading memory information from the P-ISA kernel memory file. + + These keywords are used to identify different operations and data types + within the memory file. + """ + @classproperty + def KEYGEN(cls): + """Keyword for key generation.""" + return "keygen" + + @classproperty + def LOAD(cls): + """Keyword for data load operation.""" + return "dload" + + @classproperty + def LOAD_INPUT(cls): + """Keyword for loading input polynomial.""" + return "poly" + + @classproperty + def LOAD_KEYGEN_SEED(cls): + """Keyword for loading key generation seed.""" + return "keygen_seed" + + @classproperty + def LOAD_ONES(cls): + """Keyword for loading ones.""" + return "ones" + + @classproperty + def LOAD_NTT_AUX_TABLE(cls): + """Keyword for loading NTT auxiliary table.""" + return "ntt_auxiliary_table" + + @classproperty + def LOAD_NTT_ROUTING_TABLE(cls): + """Keyword for loading NTT routing table.""" + return "ntt_routing_table" + + @classproperty + def LOAD_iNTT_AUX_TABLE(cls): + """Keyword for loading iNTT auxiliary table.""" + return "intt_auxiliary_table" + + @classproperty + def LOAD_iNTT_ROUTING_TABLE(cls): + """Keyword for loading iNTT routing table.""" + return "intt_routing_table" + + @classproperty + def LOAD_TWIDDLE(cls): + """Keyword for loading twiddle factors.""" + return "twid" + + @classproperty + def STORE(cls): + """Keyword for data store operation.""" + return "dstore" + + class MetaFields: + """ + Names of different metadata fields. + """ + @classproperty + def FIELD_KEYGEN_SEED(cls): + return MemInfo.Keyword.LOAD_KEYGEN_SEED + + @classproperty + def FIELD_ONES(cls): + return MemInfo.Keyword.LOAD_ONES + + @classproperty + def FIELD_NTT_AUX_TABLE(cls): + return MemInfo.Keyword.LOAD_NTT_AUX_TABLE + + @classproperty + def FIELD_NTT_ROUTING_TABLE(cls): + return MemInfo.Keyword.LOAD_NTT_ROUTING_TABLE + + @classproperty + def FIELD_iNTT_AUX_TABLE(cls): + return MemInfo.Keyword.LOAD_iNTT_AUX_TABLE + + @classproperty + def FIELD_iNTT_ROUTING_TABLE(cls): + return MemInfo.Keyword.LOAD_iNTT_ROUTING_TABLE + + @classproperty + def FIELD_TWIDDLE(cls): + return MemInfo.Keyword.LOAD_TWIDDLE + + @classproperty + def FIELD_KEYGENS(cls): + return "keygens" + + @classproperty + def FIELD_INPUTS(cls): + return "inputs" + + @classproperty + def FIELD_OUTPUTS(cls): + return "outputs" + + @classproperty + def FIELD_METADATA(cls): + return "metadata" + + @classproperty + def FIELD_METADATA_SUBFIELDS(cls): + """Tuple of subfield names for metadata.""" + return ( cls.MetaFields.FIELD_KEYGEN_SEED, + cls.MetaFields.FIELD_TWIDDLE, + cls.MetaFields.FIELD_ONES, + cls.MetaFields.FIELD_NTT_AUX_TABLE, + cls.MetaFields.FIELD_NTT_ROUTING_TABLE, + cls.MetaFields.FIELD_iNTT_AUX_TABLE, + cls.MetaFields.FIELD_iNTT_ROUTING_TABLE ) + + class MetaTargets: + """ + Targets for different metadata. + """ + @classproperty + def TARGET_ONES(cls): + """Special target register for Ones.""" + return 0 + + @classproperty + def TARGET_NTT_AUX_TABLE(cls): + """Special target register for rshuffle NTT auxiliary table.""" + return 0 + + @classproperty + def TARGET_NTT_ROUTING_TABLE(cls): + """Special target register for rshuffle NTT routing table.""" + return 1 + + @classproperty + def TARGET_iNTT_AUX_TABLE(cls): + """Special target register for rshuffle iNTT auxiliary table.""" + return 2 + + @classproperty + def TARGET_iNTT_ROUTING_TABLE(cls): + """Special target register for rshuffle iNTT routing table.""" + return 3 + +class MemoryModel: + """ + Constants related to memory model. + + This class defines a hierarchical structure for different parts of the memory model, + including queue capacities, metadata registers, and specific memory components like + HBM and SPAD. + """ + + __XINST_QUEUE_MAX_CAPACITY = 1 * Constants.MEGABYTE + __XINST_QUEUE_MAX_CAPACITY_WORDS = convertBytes2Words(__XINST_QUEUE_MAX_CAPACITY) + __CINST_QUEUE_MAX_CAPACITY = 128 * Constants.KILOBYTE + __CINST_QUEUE_MAX_CAPACITY_WORDS = convertBytes2Words(__CINST_QUEUE_MAX_CAPACITY) + __MINST_QUEUE_MAX_CAPACITY = 128 * Constants.KILOBYTE + __MINST_QUEUE_MAX_CAPACITY_WORDS = convertBytes2Words(__MINST_QUEUE_MAX_CAPACITY) + __STORE_BUFFER_MAX_CAPACITY = 128 * Constants.KILOBYTE + __STORE_BUFFER_MAX_CAPACITY_WORDS = convertBytes2Words(__STORE_BUFFER_MAX_CAPACITY) + + @classproperty + def XINST_QUEUE_MAX_CAPACITY(cls): + """Maximum capacity of the XINST queue in bytes.""" + return cls.__XINST_QUEUE_MAX_CAPACITY + @classproperty + def XINST_QUEUE_MAX_CAPACITY_WORDS(cls): + """Maximum capacity of the XINST queue in words.""" + return cls.__XINST_QUEUE_MAX_CAPACITY_WORDS + @classproperty + def CINST_QUEUE_MAX_CAPACITY(cls): + """Maximum capacity of the CINST queue in bytes.""" + return cls.__CINST_QUEUE_MAX_CAPACITY + @classproperty + def CINST_QUEUE_MAX_CAPACITY_WORDS(cls): + """Maximum capacity of the CINST queue in words.""" + return cls.__CINST_QUEUE_MAX_CAPACITY_WORDS + @classproperty + def MINST_QUEUE_MAX_CAPACITY(cls): + """Maximum capacity of the MINST queue in bytes.""" + return cls.__MINST_QUEUE_MAX_CAPACITY + @classproperty + def MINST_QUEUE_MAX_CAPACITY_WORDS(cls): + """Maximum capacity of the MINST queue in words.""" + return cls.__MINST_QUEUE_MAX_CAPACITY_WORDS + @classproperty + def STORE_BUFFER_MAX_CAPACITY(cls): + """Maximum capacity of the store buffer in bytes.""" + return cls.__STORE_BUFFER_MAX_CAPACITY + @classproperty + def STORE_BUFFER_MAX_CAPACITY_WORDS(cls): + """Maximum capacity of the store buffer in words.""" + return cls.__STORE_BUFFER_MAX_CAPACITY_WORDS + + @classproperty + def NUM_BLOCKS_PER_TWID_META_WORD(cls) -> int: + """Number of blocks per twiddle metadata word.""" + return 4 + + @classproperty + def NUM_BLOCKS_PER_KGSEED_META_WORD(cls) -> int: + """Number of blocks per key generation seed metadata word.""" + return 4 + + @classproperty + def NUM_ROUTING_TABLE_REGISTERS(cls) -> int: + """ + Number of routing table registers. + + This affects how many rshuffle of different types can be performed + at the same time, since rshuffle instructions will pick a routing table + to use to compute the shuffled result. + """ + return 1 + + @classproperty + def NUM_ONES_META_REGISTERS(cls) -> int: + """ + Number of registers to hold identity metadata. + + This directly affects the maximum number of residuals that can be + processed in the CE without needing to load new metadata. + """ + return 1 + + @classproperty + def NUM_TWIDDLE_META_REGISTERS(cls) -> int: + """ + Number of registers to hold twiddle factor metadata. + + This directly affects the maximum number of residuals that can be + processed in the CE without needing to load new metadata. + """ + return 32 * cls.NUM_ONES_META_REGISTERS + + @classproperty + def TWIDDLE_META_REGISTER_SIZE_BYTES(cls) -> int: + """ + Size, in bytes, of a twiddle factor metadata register. + """ + return 8 * Constants.KILOBYTE + + @classproperty + def MAX_RESIDUALS(cls) -> int: + """ + Maximum number of residuals that can be processed in the CE without + needing to load new metadata. + """ + return cls.NUM_TWIDDLE_META_REGISTERS * 2 + + @classproperty + def NUM_REGISTER_BANKS(cls) -> int: + """Number of register banks in the CE""" + return 4 + + @classproperty + def NUM_REGISTER_PER_BANKS(cls) -> int: + """Number of register per register banks in the CE""" + return 72 + + class HBM: + """ + Constants related to High Bandwidth Memory (HBM). + + This class defines the maximum capacity of HBM in both bytes and words. + """ + __MAX_CAPACITY = 64 * Constants.GIGABYTE + __MAX_CAPACITY_WORDS = convertBytes2Words(__MAX_CAPACITY) + + @classproperty + def MAX_CAPACITY(cls) -> int: + """Total capacity of HBM in Bytes""" + return cls.__MAX_CAPACITY + + @classproperty + def MAX_CAPACITY_WORDS(cls) -> int: + """Total capacity of HBM in Words""" + return cls.__MAX_CAPACITY_WORDS + + class SPAD: + """ + Constants related to Scratchpad Memory (SPAD). + + This class defines the maximum capacity of SPAD in both bytes and words. + """ + __MAX_CAPACITY = 64 * Constants.MEGABYTE + __MAX_CAPACITY_WORDS = convertBytes2Words(__MAX_CAPACITY) + + # Class methods and properties + # ---------------------------- + + @classproperty + def MAX_CAPACITY(cls) -> int: + """Total capacity of SPAD in Bytes""" + return cls.__MAX_CAPACITY + + @classproperty + def MAX_CAPACITY_WORDS(cls) -> int: + """Total capacity of SPAD in Words""" + return cls.__MAX_CAPACITY_WORDS diff --git a/assembler_tools/hec-assembler-tools/assembler/common/counter.py b/assembler_tools/hec-assembler-tools/assembler/common/counter.py new file mode 100644 index 00000000..65ca61a7 --- /dev/null +++ b/assembler_tools/hec-assembler-tools/assembler/common/counter.py @@ -0,0 +1,101 @@ + +import itertools + +class Counter: + """ + Provides counters that can be globally reset. + + This class allows for the creation of counters that can be iterated over and reset + to their initial start values. It supports creating multiple counters with different + start and step values, and provides functionality to reset individual counters or all + counters at once. + """ + + class CounterIter: + """ + An iterator for generating evenly spaced values. + + This iterator starts at a specified value and increments by a specified step. + It can be reset to start over from its initial start value. + """ + def __init__(self, start = 0, step = 1): + """ + Initializes a new CounterIter object. + + Args: + start (int, optional): The starting value of the counter. Defaults to 0. + step (int, optional): The step value for the counter. Defaults to 1. + """ + self.__start = start + self.__step = step + self.__counter = None # itertools.counter + self.reset() + + def __next__(self): + """ + Returns the next value in the counter sequence. + + Returns: + int: The next value. + """ + return next(self.__counter) + + @property + def start(self) -> int: + """ + Gets the start value for this counter. + + Returns: + int: The start value. + """ + return self.__start + + @property + def step(self) -> int: + """ + Gets the step value for this counter. + + Returns: + int: The step value. + """ + return self.__step + + def reset(self): + """ + Resets this counter to start from its `start` value. + """ + self.__counter = itertools.count(self.start, self.step) + + __counters = set() + + @classmethod + def count(cls, start = 0, step = 1) -> CounterIter: + """ + Creates a new counter iterator that returns evenly spaced values. + + Args: + start (int, optional): The starting value of the counter. Defaults to 0. + step (int, optional): The step value for the counter. Defaults to 1. + + Returns: + CounterIter: An iterator that generates evenly spaced values starting from `start`. + """ + retval = cls.CounterIter(start, step) + cls.__counters.add(retval) + return retval + + @classmethod + def reset(cls, counter: CounterIter = None): + """ + Reset the specified counter, or all counters if none is specified. + + This method resets the specified counter, or all counters, to start + over from their respective `start` values. + + Args: + counter (CounterIter, optional): The counter to reset. + If None, all counters are reset. + """ + counters_to_reset = cls.__counters if counter is None else { counter } + for c in counters_to_reset: + c.reset() diff --git a/assembler_tools/hec-assembler-tools/assembler/common/cycle_tracking.py b/assembler_tools/hec-assembler-tools/assembler/common/cycle_tracking.py new file mode 100644 index 00000000..4e91213c --- /dev/null +++ b/assembler_tools/hec-assembler-tools/assembler/common/cycle_tracking.py @@ -0,0 +1,293 @@ +import numbers +from typing import NamedTuple + +class PrioritizedPlaceholder: + """ + Base class for priority queue items. + + This class provides a framework for items that can be used in a priority queue, + where each item has a priority that can be dynamically adjusted. Derived classes + can override the `_get_priority()` method to provide their own logic for determining + priority, allowing for priorities that change on the fly. + + Priorities are expected to be tuples, and the class supports comparison operations + based on these priorities. + + Properties: + priority (tuple): The current priority of the item, calculated as the sum of + the base priority and the priority delta. + priority_delta (tuple): The current priority delta. + + Methods: + _get_priority(): Returns the base priority of the item. + _get_priority_delta(): Returns the priority delta of the item. + """ + def __init__(self, + priority = (0, 0), + priority_delta = (0, 0)): + """ + Initializes a new PrioritizedPlaceholder object. + + Args: + priority (tuple, optional): The base priority of the item. Defaults to (0, 0). + priority_delta (tuple, optional): The delta to be applied to the base priority. Defaults to (0, 0). + """ + self._priority = priority + self._priority_delta = priority_delta + + @property + def priority(self): + """ + Calculates and returns the current priority of the item. + + The current priority is the sum of the base priority and the priority delta. + + Returns: + tuple: The current priority. + """ + return tuple([sum(x) for x in zip(self._get_priority(), self.priority_delta)]) + + @property + def priority_delta(self): + """ + Returns the current priority delta. + + Returns: + tuple: The current delta. + """ + return self._get_priority_delta() + + def _get_priority(self): + """ + Returns the base priority of the item. + + Returns: + tuple: The base priority. + """ + return self._priority + + def _get_priority_delta(self): + """ + Returns the priority delta of the item. + + This method can be overridden by derived classes to provide custom priority delta logic. + + Returns: + tuple: The priority delta. + """ + return self._priority_delta + + def __lt__(self, other): + """ + Compares this item with another item for less-than ordering based on priority. + + Args: + other (PrioritizedPlaceholder): The other item to compare against. + + Returns: + bool: True if this item's priority is less than the other item's priority, False otherwise. + """ + return self.priority < other.priority + + def __eq__(self, other): + """ + Compares this item with another item for equality based on priority. + + Args: + other (PrioritizedPlaceholder): The other item to compare against. + + Returns: + bool: True if this item's priority is equal to the other item's priority, False otherwise. + """ + return self.priority == other.priority + + def __gt__(self, other): + """ + Compares this item with another item for greater-than ordering based on priority. + + Args: + other (PrioritizedPlaceholder): The other item to compare against. + + Returns: + bool: True if this item's priority is greater than the other item's priority, False otherwise. + """ + return self.priority > other.priority + +class CycleType(NamedTuple): + """ + Named tuple to add structure to a cycle type. + + CycleType is a structured representation of a cycle, consisting of a bundle + identifier and a cycle count within that bundle. It supports arithmetic operations + for adding and subtracting cycles or tuples. + + Attributes: + bundle (int): Bundle identifier or index. + cycle (int): Clock cycle inside the specified bundle. + + Operators: + __add__(self, other: Union[tuple, int]) -> CycleType: + Adds a tuple or an integer to the CycleType and returns the resulting CycleType. + If other is a tuple, only the first two elements are used for addition. + If other is an integer, it is added to the cycle component. + + __sub__(self, other: Union[tuple, int]) -> CycleType: + Subtracts a tuple or an integer from the CycleType and returns the resulting CycleType. + If other is a tuple, only the first two elements are used for subtraction. + If other is an integer, it is subtracted from the cycle component. + """ + + bundle: int + cycle: int + + def __add__(self, other): + """ + Adds a tuple or an integer to the `CycleType`. + + Args: + other (Union[tuple, int]): The value to add. Can be a tuple or an integer. + + Returns: + CycleType: The resulting `CycleType` after addition. + + Raises: + TypeError: If `other` is not a tuple or an integer. + """ + if isinstance(other, int): + return self.__binaryop_cycles(other, lambda m, n: m + n) + elif isinstance(other, tuple): + return self.__binaryop_tuple(other, lambda m, n: m + n) + else: + raise TypeError('`other`: expected type `int` or `tuple`.') + + def __sub__(self, other): + """ + Subtracts a tuple or an integer from the `CycleType`. + + Args: + other (Union[tuple, int]): The value to subtract. Can be a tuple or an integer. + + Returns: + CycleType: The resulting `CycleType` after subtraction. + + Raises: + TypeError: If `other` is not a tuple or an integer. + """ + if isinstance(other, int): + return self.__binaryop_cycles(other, lambda m, n: m - n) + elif isinstance(other, tuple): + return self.__binaryop_tuple(other, lambda m, n: m - n) + else: + raise TypeError('`other`: expected type `int` or `tuple`.') + + def __binaryop_cycles(self, cycles, binaryop_callable): + """ + Performs a binary operation on the cycle component with an integer. + + Args: + cycles (int): The integer to operate with. + binaryop_callable (callable): The binary operation to perform. + + Returns: + CycleType: The resulting `CycleType` after the operation. + """ + assert(isinstance(cycles, int)) + return CycleType(self.bundle, binaryop_callable(self.cycle, cycles)) + + def __binaryop_tuple(self, other, binaryop_callable): + """ + Performs a binary operation on the `CycleType` with a tuple. + + Args: + other (tuple): The tuple to operate with. + binaryop_callable (callable): The binary operation to perform. + + Returns: + CycleType: The resulting `CycleType` after the operation. + """ + return CycleType(binaryop_callable(self.bundle, int(other[0]) if len(other) > 0 else 0), + binaryop_callable(self.cycle, int(other[1]) if len(other) > 1 else 0)) + +class CycleTracker: + """ + Base class for tracking the clock cycle when an object is ready to be used. + + The cycle ready value is interpreted as a tuple (bundle: int, cycle: int). If the bundle + is not used, it can always be set to `0`. + + Attributes: + tag (Any): User-defined tag to hold any kind of extra information related to the object. + + Properties: + cycle_ready (CycleType): Clock cycle where this object is ready to use. Uses + :func:`~cycle_tracking.CycleTracker._get_cycle_ready` and + :func:`~cycle_tracking.CycleTracker._set_cycle_ready`. + + Methods: + _get_cycle_ready(): Returns the current value for the ready cycle. Derived classes can override this method + to add their own logic to compute this value. + + _set_cycle_ready(value): Sets the current value for the ready cycle (only if the specified value is greater than + the current `CycleTracker.cycle_ready`). Derived classes can override this method to add their own logic to compute this value. + """ + + def __init__(self, cycle_ready: CycleType): + """ + Initializes a new CycleTracker object. + + Args: + cycle_ready (CycleType): The initial cycle when the object is ready to be used. Must be a tuple with at least + two elements (bundle, cycle). + """ + assert(len(cycle_ready) > 1) + self.__cycle_ready = CycleType(*cycle_ready) + self.tag = 0 # User-defined tag + + @property + def cycle_ready(self): + """ + Gets the current cycle ready value. + + Returns: + CycleType: The value. + """ + return self._get_cycle_ready() + + @cycle_ready.setter + def cycle_ready(self, value: CycleType): + """ + Set a new cycle ready value. + + Args: + value (CycleType): The new cycle ready value to set. + """ + return self._set_cycle_ready(value) + + def _get_cycle_ready(self) -> CycleType: + """ + Return the current value for the ready cycle. + + This method is called by the `cycle_ready` property getter to retrieve the value. + Derived classes can override this method to add their own logic to compute this value. + + Returns: + CycleType: The current value for the ready cycle. + """ + return self.__cycle_ready + + def _set_cycle_ready(self, value: CycleType): + """ + Set the current value for the ready cycle, only if the specified value is greater than + the current `CycleTracker.cycle_ready`. + + This method is called by the `cycle_ready` property setter to set the new value. + Derived classes can override this method to add their own logic to compute this value. + + Args: + value (CycleType or tuple): New clock cycle when this object will be ready for use. + The tuple should be in the form (bundle: int, cycle: int). + """ + assert(len(value) > 1) + #if self.cycle_ready < value: + # self.__cycle_ready = CycleType(*value) + self.__cycle_ready = CycleType(*value) diff --git a/assembler_tools/hec-assembler-tools/assembler/common/decorators.py b/assembler_tools/hec-assembler-tools/assembler/common/decorators.py new file mode 100644 index 00000000..09beaa98 --- /dev/null +++ b/assembler_tools/hec-assembler-tools/assembler/common/decorators.py @@ -0,0 +1,28 @@ + +class classproperty(object): + """ + A decorator that allows a method to be accessed as a class-level property + rather than on instances of the class. + """ + + def __init__(self, f): + """ + Initializes the classproperty with the given function. + + Args: + f (function): The function to be used as a class-level property. + """ + self.f = f + + def __get__(self, obj, owner): + """ + Retrieves the value of the class-level property. + + Args: + obj: The instance of the class (ignored in this context). + owner: The class that owns the property. + + Returns: + The result of calling the decorated function with the class as an argument. + """ + return self.f(owner) diff --git a/assembler_tools/hec-assembler-tools/assembler/common/priority_queue.py b/assembler_tools/hec-assembler-tools/assembler/common/priority_queue.py new file mode 100644 index 00000000..9d845d22 --- /dev/null +++ b/assembler_tools/hec-assembler-tools/assembler/common/priority_queue.py @@ -0,0 +1,333 @@ +import heapq +import bisect +import itertools + +class PriorityQueue: + """ + A priority queue implementation that supports task prioritization and ordering. + + This class allows tasks to be added with a specified priority, and supports + operations to update, remove, and retrieve tasks based on their priority. + """ + + class __PriorityQueueIter: + """ + An iterator for the PriorityQueue class. + + This iterator allows for iterating over the tasks in the priority queue + while ensuring that the queue's size does not change during iteration. + """ + def __init__(self, pq, removed): + """ + Initializes the iterator with the priority queue and removed marker. + + Args: + pq: The priority queue to iterate over. + removed: The marker for removed tasks. + """ + self.__pq = pq if pq else [] + self.__initial_len = len(self.__pq) + self.__removed = removed + self.__current = 0 + + def __next__(self): + """ + Returns the next task in the priority queue. + + Returns: + tuple: The (priority, task) pair of the next task. + + Raises: + RuntimeError: If the priority queue changes size during iteration. + StopIteration: If there are no more tasks to iterate over. + """ + if len(self.__pq) != self.__initial_len: + raise RuntimeError("PriorityQueue changed size during iteration.") + + # Skip all removed tasks + while self.__current < len(self.__pq) \ + and self.__pq[self.__current][-1] is self.__removed: + self.__current += 1 + if self.__current >= len(self.__pq): + raise StopIteration + priority, _, task = self.__pq[self.__current] + self.__current += 1 # point to nex element + return (priority, task) + + + class __PriorityTracker: + """ + A helper class to track tasks by their priority. + + This class maintains a mapping of priorities to tasks and supports + operations to add, find, and remove tasks based on their priority. + """ + def __init__(self): + """ + Initializes the priority tracker with empty mappings. + """ + self.__priority_dict = {} # dict(int, SortedList(task)): maps priority to unordered set of tasks with same priority + self.__priority_dict_set = {} # dict(int, set(task)): maps priority to unordered set of tasks with same priority + + def find(self, priority: int) -> object: + """ + Finds a task with the specified priority. + + Args: + priority (int): The priority to search for. + + Returns: + object: A task with the specified priority, or None if not found. + """ + return next(iter(self.__priority_dict[priority]))[1] if priority in self.__priority_dict else None + + def push(self, priority: int, tie_breaker: tuple, task: object): + """ + Adds a task with the specified priority and tie breaker. + + Args: + priority (int): The priority of the task. + tie_breaker (tuple): A tuple used to break ties between tasks with the same priority. + task (object): The task to add. + + Raises: + ValueError: If the task is None. + """ + if task is None: + raise ValueError('`task` cannot be `None`.') + + if priority not in self.__priority_dict: + self.__priority_dict[priority] = [] + assert priority not in self.__priority_dict_set + self.__priority_dict_set[priority] = set() + if task not in self.__priority_dict_set[priority]: + bisect.insort_right(self.__priority_dict[priority], (tie_breaker, task)) + self.__priority_dict_set[priority].add(task) + + def pop(self, priority: int, task = None) -> object: + """ + Removes a task with the specified priority. + + Args: + priority (int): The priority of the task to remove. + task (object, optional): The specific task to remove. If None, the first task is removed. + + Raises: + KeyError: If the priority is not found. + ValueError: If the specified task is not found in the priority. + + Returns: + object: The task that was removed. + """ + if priority not in self.__priority_dict: + raise KeyError(str(priority)) + + retval = None + assert priority in self.__priority_dict_set + if task: + # Find index for task + idx = next((i for i, (_, contained_task) in enumerate(self.__priority_dict[priority]) if contained_task == task), + len(self.__priority_dict[priority])) + if idx >= len(self.__priority_dict[priority]): + raise ValueError('`task` not found in priority.') + _, retval = self.__priority_dict[priority].pop(idx) + assert(retval == task) + else: + # Remove first task + _, retval = self.__priority_dict[priority].pop(0) + self.__priority_dict_set[priority].remove(retval) + + if len(self.__priority_dict[priority]) <= 0: + # Remove priority from dictionary if empty (we do not want to keep too many of these around) + self.__priority_dict.pop(priority) + assert len(self.__priority_dict_set[priority]) <= 0 + self.__priority_dict_set.pop(priority) + return retval + + __REMOVED = object() # Placeholder for a removed task + + def __init__(self, queue: list = None): + """ + Creates a new PriorityQueue object. + + Args: + queue (list, optional): A list of (priority, task) tuples to initialize the queue. + This is an O(len(queue)) operation. + + Raises: + ValueError: If any task in the queue is None. + """ + # entry: [priority: int, nonce: int, task: hashable_object] + self.__pq = [] # list(entry) - List of entries arranged in a heap + self.__entry_finder = {} # dictionary(task: Hashable_object, entry) - mapping of tasks to entries + self.__priority_tracker = PriorityQueue.__PriorityTracker() # Tracks tasks by priority + self.__counter: int = itertools.count(1) # Unique sequence count + + if queue: + for priority, task in queue: + if task is None: + raise ValueError('`queue`: tasks cannot be `None`.') + count = next(self.__counter) + entry = [priority, ((0, ), count), task] + self.__entry_finder[task] = entry + self.__priority_tracker.push(*entry)#priority, task) + self.__pq.append() + heapq.heapify(self.__pq) + + def __bool__(self): + """ + Returns True if the priority queue is not empty, False otherwise. + + Returns: + bool: True if it is not empty, False otherwise. + """ + return len(self) > 0 + + def __contains__(self, task: object): + """ + Checks if a task is in the priority queue. + + Args: + task (object): The task to check for. + + Returns: + bool: True if it is in the queue, False otherwise. + """ + return task in self.__entry_finder + + def __iter__(self): + """ + Returns an iterator over the tasks in the priority queue. + + Returns: + __PriorityQueueIter: An iterator over the tasks in the queue. + """ + return PriorityQueue.__PriorityQueueIter(self.__pq, PriorityQueue.__REMOVED) + + def __len__(self): + """ + Returns the number of tasks in the priority queue. + + Returns: + int: The number of tasks. + """ + return len(self.__entry_finder) + + def __repr__(self): + """ + Returns a string representation of the priority queue. + + Returns: + str: A string representation of the queue. + """ + return '<{} object at {}>(len={}, pq={})'.format(type(self).__name__, + hex(id(self)), + len(self), + self.__pq) + + def push(self, priority: int, task: object, tie_breaker: tuple = None): #ahead: bool = None): + """ + Adds a new task or update the priority of an existing task. + + Args: + priority (int): The priority of the task. + task (object): The task to add or update. + tie_breaker (tuple, optional): A tuple of ints to use as a tie breaker for tasks + of the same priority. Defaults to (0,) if None. + + Raises: + ValueError: If the task is None. + TypeError: If the tie_breaker is not a tuple of ints or None. + """ + if task is None: + raise ValueError('`task` cannot be `None`.') + if tie_breaker is not None \ + and not all(isinstance(x, int) for x in tie_breaker): + raise TypeError('`tie_breaker` expected tuple of `int`s, or `None`.') + b_add_needed = True + if task in self.__entry_finder: + old_priority, (old_tie_breaker, _), _ = self.__entry_finder[task] + if tie_breaker is None: + tie_breaker = old_tie_breaker + if old_priority != priority \ + or tie_breaker != old_tie_breaker: + self.remove(task) + else: + # same task without priority change detected: no need to add + b_add_needed = False + + if tie_breaker is None: + tie_breaker = (0,) + + if b_add_needed: + if len(self.__pq) == 0: + self.__counter: int = itertools.count(1) # restart sequence count when queue is empty + count = next(self.__counter) + entry = [priority, (tie_breaker, count), task] + self.__entry_finder[task] = entry + self.__priority_tracker.push(*entry)#priority, task) + heapq.heappush(self.__pq, entry) + + def remove(self, task: object): + """ + Removes an existing task from the priority queue. + + Args: + task (object): The task to remove from the queue. It must exist. + + Raises: + KeyError: If the task is not found in the queue. + """ + # mark an existing task as PriorityQueue.__REMOVED. + entry = self.__entry_finder.pop(task) + priority, *_ = entry + self.__priority_tracker.pop(priority, task) # remove it from the priority tracker + entry[-1] = PriorityQueue.__REMOVED + + def peek(self) -> tuple: + """ + Returns the task with the lowest priority without removing it from the queue. + + Returns: + tuple: The (priority, task) pair of the task with the lowest priority, + or None if the queue is empty. + """ + # make sure head is not a removed task + while self.__pq and self.__pq[0][-1] is PriorityQueue.__REMOVED: + heapq.heappop(self.__pq) + retval = None + if self.__pq: + priority, _, task = self.__pq[0] + retval = (priority, task) + return retval + + def find(self, priority: int) -> object: + """ + Returns a task with the specified priority, if there is one. + + The returned task is not removed from the priority queue. + + Args: + priority (int): The priority of the task to find. + + Returns: + object: The task with the specified priority, or None if no such task exists. + """ + return self.__priority_tracker.find(priority) + + def pop(self) -> tuple: + """ + Removes and return the task with the lowest priority. + + Returns: + tuple: The (priority, task) pair of the task that was removed. + + Raises: + IndexError: If the queue is empty. + """ + task = PriorityQueue.__REMOVED + while task is PriorityQueue.__REMOVED: # make sure head is not a removed task + priority, _, task = heapq.heappop(self.__pq) + self.__entry_finder.pop(task) + self.__priority_tracker.pop(priority, task) + return (priority, task) diff --git a/assembler_tools/hec-assembler-tools/assembler/common/queue_dict.py b/assembler_tools/hec-assembler-tools/assembler/common/queue_dict.py new file mode 100644 index 00000000..d2ecefb9 --- /dev/null +++ b/assembler_tools/hec-assembler-tools/assembler/common/queue_dict.py @@ -0,0 +1,141 @@ +from collections import deque + +class QueueDict: + """ + A dictionary that keeps its elements in a FIFO (queue) order. + + This class allows adding new items to the dictionary, but they will always be + added at the end of the queue structure. Modifying the value of an existing + key will not change the order of the item in the queue structure. + + Read/write access to contained items via their keys is allowed, but removal + of items occur at the start of the queue structure only. No removals + are allowed on any other items of the structure. + """ + def __init__(self): + """ + Initializes a new, empty QueueDict object. + """ + self.__q = deque() + self.__lookup = {} + + def __len__(self) -> int: + """ + Returns the number of items in the QueueDict. + + Returns: + int: Number of items. + """ + return len(self.__lookup) + + def __iter__(self): + """ + Returns an iterator over the keys of the QueueDict. + + Yields: + The next key in the QueueDict. + + Raises: + RuntimeError: If the QueueDict changes size during iteration. + """ + q = self.__q.copy() + initial_len = len(self.__lookup) + while q: + if len(self.__lookup) != initial_len: + raise RuntimeError("QueueDict changed size during iteration.") + key = q.popleft() + yield key + + def __contains__(self, key) -> bool: + """ + Checks if a key is in the QueueDict. + + Args: + key: The key to check for. + + Returns: + bool: True if the key is in the QueueDict, False otherwise. + """ + return key in self.__lookup + + def __getitem__(self, key) -> object: + """ + Gets the value associated with a key in the QueueDict. + + Args: + key: The key whose value is to be retrieved. + + Returns: + object: The value associated with the key. + """ + return self.__lookup[key] + + def __setitem__(self, key, value: object): + """ + Sets the value associated with a key in the QueueDict. + + Args: + key: The key to set the value for. + value: The value to associate with the key. + """ + self.push(key, value) + + def clear(self): + """ + Empties the QueueDict, removing all items. + """ + self.__q.clear() + self.__lookup = {} + + def copy(self) -> object: # QueueDict + """ + Returns a shallow copy of the QueueDict. + + Returns: + QueueDict: The shallow copy. + """ + retval = QueueDict() + retval.__q = self.__q.copy() + retval.__lookup = self.__lookup.copy() + return retval + + def peek(self) -> tuple: + """ + Returns the (key, value) pair item at the start of the QueueDict, but does + not modify the QueueDict. + + This is the next item that would be removed on the next call to `QueueDict.pop()`. + + Returns: + tuple: The (key, value) pair. + """ + key = self.__q[0] + value = self.__lookup[key] + return (key, value) + + def pop(self) -> tuple: + """ + Removes and returns the (key, value) pair item at the start of the QueueDict. + + Returns: + tuple: The (key, value) pair that was removed. + """ + key = self.__q.popleft() + value = self.__lookup.pop(key) + return (key, value) + + def push(self, key, value: object): + """ + Adds a new (key, value) pair item at the end of the QueueDict if `key` does not + exists in the QueueDict. Otherwise, the value of the existing item with specified + key is changed to the new `value`. + + This method is equivalent to assigning a value to the key: `QueueDict[key] = value`. + + Args: + key: The key to add or update. + value: The value to associate with the key. + """ + if key not in self.__lookup: + self.__q.append(key) + self.__lookup[key] = value diff --git a/assembler_tools/hec-assembler-tools/assembler/common/run_config.py b/assembler_tools/hec-assembler-tools/assembler/common/run_config.py new file mode 100644 index 00000000..83e9d06e --- /dev/null +++ b/assembler_tools/hec-assembler-tools/assembler/common/run_config.py @@ -0,0 +1,126 @@ +import io + +from . import constants +from .config import GlobalConfig + +def static_initializer(cls): + """ + Decorator to initialize static members of a class. + + This decorator calls the `init_static` method of the class to initialize + any static members or configurations. + + Args: + cls: The class to be initialized. + + Returns: + The class with initialized static members. + """ + cls.init_static() + return cls + +@static_initializer +class RunConfig: + """ + Configuration class for running the assembler with specific settings. + + This class manages configuration settings such as memory sizes, replacement + policies, and other options that affect the behavior of the assembler. + """ + + __initialized = False # Specifies whether static members have been initialized + __default_config = {} # Dictionary of all configuration items supported and their default values + + # Config defaults + DEFAULT_HBM_SIZE_KB = int(constants.MemoryModel.HBM.MAX_CAPACITY / constants.Constants.KILOBYTE) + DEFAULT_SPAD_SIZE_KB = int(constants.MemoryModel.SPAD.MAX_CAPACITY / constants.Constants.KILOBYTE) + DEFAULT_REPL_POLICY = constants.Constants.REPLACEMENT_POLICY_FTBU + + def __init__(self, + **kwargs): + """ + Constructs a new RunConfig object from input parameters. + + Args: + hbm_size (int, optional): + Optional HBM size in KB. Defaults to `RunConfig.DEFAULT_HBM_SIZE_KB`. + + spad_size (int, optional): + Optional scratchpad size in KB. Defaults to `RunConfig.DEFAULT_SPAD_SIZE_KB`. + + repl_policy (str, optional): + Optional replacement policy. This should be one of `constants.Constants.REPLACEMENT_POLICIES`. + Defaults to `RunConfig.DEFAULT_REPL_POLICY`. + + suppress_comments (bool, optional): + If true, no comments will be emitted in the output generated by the assembler. + Defaults to GlobalConfig.suppressComments (`False`). + + use_hbm_placeholders (bool, optional): + [DEPRECATED]/[UNUSED] Specifies whether to use placeholders (names) for variable locations in HBM (`True`) + or the actual variable locations (`False`). Defaults to GlobalConfig.useHBMPlaceHolders (`True`). + + use_xinstfetch (bool, optional): + Specifies whether `xinstfetch` instructions should be generated in the CInstQ (`True`) or not (`False`). + When no `xinstfetch` instructions are added, it is assumed that the HERACLES automated mechanism for `xinstfetch` will be activated. + Defaults to GlobalConfig.useXInstFetch (`True`). + + debug_verbose (int, optional): + If greater than 0, debug prints will occur. Its value indicates how often to print within loops + (every `debugVerbose` iterations). Defaults to GlobalConfig.debugVerbose (`0`). + + Raises: + ValueError: If at least one of the arguments passed is invalid. + """ + + # Initialize class members + for config_name, default_value in self.__default_config.items(): + setattr(self, config_name, kwargs.get(config_name, default_value)) + + # Validate inputs + if self.repl_policy not in constants.Constants.REPLACEMENT_POLICIES: + raise ValueError('Invalid `repl_policy`. "{}" not in {}'.format(self.repl_policy, + constants.Constants.REPLACEMENT_POLICIES)) + + @classmethod + def init_static(cls): + """ + Initializes static members of the RunConfig class. + + This method sets up default configuration values for the class, ensuring + that they are only initialized once. + """ + if not cls.__initialized: + cls.__default_config["hbm_size"] = cls.DEFAULT_HBM_SIZE_KB + cls.__default_config["spad_size"] = cls.DEFAULT_SPAD_SIZE_KB + cls.__default_config["repl_policy"] = cls.DEFAULT_REPL_POLICY + cls.__default_config["suppress_comments"] = GlobalConfig.suppressComments + #cls.__default_config["use_hbm_placeholders"] = GlobalConfig.useHBMPlaceHolders + cls.__default_config["use_xinstfetch"] = GlobalConfig.useXInstFetch + cls.__default_config["debug_verbose"] = GlobalConfig.debugVerbose + + cls.__initialized = True + + def __str__(self): + """ + Returns a string representation of the configuration. + + This method provides a human-readable format of the current configuration + settings, listing each configuration item and its value. + """ + self_dict = self.as_dict() + with io.StringIO() as retval_f: + for key, value in self_dict.items(): + print("{}: {}".format(key, value), file=retval_f) + retval = retval_f.getvalue() + return retval + + def as_dict(self) -> dict: + """ + Converts the configuration to a dictionary. + + Returns: + dict: A dictionary representation of the current configuration settings. + """ + tmp_self_dict = vars(self) + return { config_name: tmp_self_dict[config_name] for config_name in self.__default_config } diff --git a/assembler_tools/hec-assembler-tools/assembler/common/utilities.py b/assembler_tools/hec-assembler-tools/assembler/common/utilities.py new file mode 100644 index 00000000..55b30d65 --- /dev/null +++ b/assembler_tools/hec-assembler-tools/assembler/common/utilities.py @@ -0,0 +1,21 @@ + +def clamp(x, minimum = float("-inf"), maximum = float("inf")): + """ + Clamp a value between a specified minimum and maximum. + + This function ensures that a given value `x` is constrained within the + bounds defined by `minimum` and `maximum`. + + Args: + x: The value to be clamped. + minimum (float, optional): The lower bound to clamp `x` to. + maximum (float, optional): The upper bound to clamp `x` to. + + Returns: + The clamped value. + """ + if x < minimum: + return minimum + if x > maximum: + return maximum + return x diff --git a/assembler_tools/hec-assembler-tools/assembler/instructions/__init__.py b/assembler_tools/hec-assembler-tools/assembler/instructions/__init__.py new file mode 100644 index 00000000..a401ed56 --- /dev/null +++ b/assembler_tools/hec-assembler-tools/assembler/instructions/__init__.py @@ -0,0 +1,28 @@ + +def tokenizeFromLine(line: str) -> list: + """ + Tokenizes a line of text and extracts any comment present. + + This function processes a line of text, removing line breaks and splitting the line + into tokens based on commas. It also identifies and extracts comments, which are + denoted by the pound symbol `#`. + + Args: + line (str): Line of text to tokenize. + + Returns: + tuple: A tuple containing the tokens and the comment. The `tokens` are a tuple of strings, + and `comment` is a string. The `comment` is an empty string if no comment is found in the line. + """ + tokens = tuple() + comment = "" + if line: + line = ''.join(line.splitlines()) # remove line breaks + comment_idx = line.find('#') + if comment_idx >= 0: + # Found a comment + comment = line[comment_idx + 1:] + line = line[:comment_idx] + tokens = tuple(map(lambda s: s.strip(), line.split(','))) + retval = (tokens, comment) + return retval diff --git a/assembler_tools/hec-assembler-tools/assembler/instructions/cinst/__init__.py b/assembler_tools/hec-assembler-tools/assembler/instructions/cinst/__init__.py new file mode 100644 index 00000000..837b3ae6 --- /dev/null +++ b/assembler_tools/hec-assembler-tools/assembler/instructions/cinst/__init__.py @@ -0,0 +1,18 @@ + +from . import bload, bones, cexit, cload, cnop, cstore, csyncm, ifetch, kgload, kgseed, kgstart, nload, xinstfetch + +# MInst aliases + +BLoad = bload.Instruction +BOnes = bones.Instruction +CExit = cexit.Instruction +CLoad = cload.Instruction +CNop = cnop.Instruction +CStore = cstore.Instruction +CSyncm = csyncm.Instruction +IFetch = ifetch.Instruction +KGLoad = kgload.Instruction +KGSeed = kgseed.Instruction +KGStart = kgstart.Instruction +NLoad = nload.Instruction +XInstFetch = xinstfetch.Instruction diff --git a/assembler_tools/hec-assembler-tools/assembler/instructions/cinst/bload.py b/assembler_tools/hec-assembler-tools/assembler/instructions/cinst/bload.py new file mode 100644 index 00000000..44526533 --- /dev/null +++ b/assembler_tools/hec-assembler-tools/assembler/instructions/cinst/bload.py @@ -0,0 +1,181 @@ +from assembler.common.cycle_tracking import CycleType +from .cinstruction import CInstruction +from assembler.memory_model.variable import Variable + +class Instruction(CInstruction): + """ + Encapsulates the `bload` CInstruction. + + The `bload` instruction loads metadata from the scratchpad to special registers in the register file. + + For more information, check the `bload` Specification: + https://github.com/IntelLabs/hec-assembler-tools/blob/master/docsrc/inst_spec/cinst/cinst_bload.md + + Attributes: + col_num (int): Block index inside the metadata source word. See documentation for details. + m_idx (int): Target metadata register index. See documentation for details. + spad_src (int): SPAD address of the metadata word to load. + """ + + @classmethod + def _get_OP_NAME_ASM(cls) -> str: + """ + Returns the ASM name for the operation. + + Returns: + str: The ASM name for the operation, which is "bload". + """ + return "bload" + + def __init__(self, + id: int, + col_num: int, + m_idx: int, + src: Variable, + mem_model, + throughput : int = None, + latency : int = None, + comment: str = ""): + """ + Constructs a new `bload` CInstruction. + + Parameters: + id (int): User-defined ID for the instruction. It will be bundled with a nonce to form a unique ID. + col_num (int): Metadata register column number. See documentation for details. + m_idx (int): Metadata register index. See documentation for details. + src (Variable): Metadata variable to load from SPAD. + mem_model: The memory model associated with the instruction. + throughput (int, optional): The throughput of the instruction. Defaults to the class-defined throughput. + latency (int, optional): The latency of the instruction. Defaults to the class-defined latency. + comment (str, optional): An optional comment for the instruction. + + Raises: + ValueError: If `mem_model` is None. + """ + if not mem_model: + raise ValueError('`mem_model` cannot be `None`.') + if not throughput: + throughput = Instruction._OP_DEFAULT_THROUGHPUT + if not latency: + latency = Instruction._OP_DEFAULT_LATENCY + super().__init__(id, throughput, latency, comment=comment) + self.col_num = col_num + self.m_idx = m_idx + self.__mem_model = mem_model + self._set_sources( [ src ] ) + + def __repr__(self): + """ + Returns a string representation of the Instruction object. + + Returns: + str: A string representation. + """ + assert(len(self.sources) > 0) + retval=('<{}({}) object at {}>(id={}[0], ' + 'col_num={}, m_idx={}, src={}, ' + 'mem_model, ' + 'throughput={}, latency={})').format(type(self).__name__, + self.name, + hex(id(self)), + self.id, + self.col_num, + self.m_idx, + self.sources[0], + self.throughput, + self.latency) + return retval + + def _set_dests(self, value): + """ + Raises an error as the `bload` instruction does not have destination parameters. + + Parameters: + value: The value to set as destinations. + + Raises: + RuntimeError: Always, as `bload` does not have destination parameters. + """ + raise RuntimeError(f"Instruction `{self.name}` does not have destination parameters.") + + def _set_sources(self, value): + """ + Validates and sets the list of source objects. + + Parameters: + value (list): The list of source objects to set. + + Raises: + ValueError: If the value is not a list of the expected number of `Variable` objects. + """ + if len(value) != Instruction._OP_NUM_SOURCES: + raise ValueError(("`value`: Expected list of {} `Variable` objects, " + "but list with {} elements received.".format(Instruction._OP_NUM_SOURCES, + len(value)))) + if not all(isinstance(x, Variable) for x in value): + raise ValueError("`value`: Expected list of `Variable` objects.") + super()._set_sources(value) + + def _schedule(self, cycle_count: CycleType, schedule_id: int) -> int: + """ + Schedules the instruction, simulating timings of executing this instruction. + + Parameters: + cycle_count (CycleType): Current cycle of execution. + schedule_id (int): The schedule ID for the instruction. + + Raises: + RuntimeError: If the SPAD address is invalid or if the column number is out of range. + + Returns: + int: The throughput for this instruction, i.e., the number of cycles by which to advance + the current cycle counter. + """ + assert(Instruction._OP_NUM_SOURCES > 0 and len(self.sources) == Instruction._OP_NUM_SOURCES) + + variable: Variable = self.sources[0] # expected sources to contain a Variable + if variable.spad_address < 0: + raise RuntimeError(f'Null Access Violation: Variable "{variable}" not allocated in SPAD.') + if self.m_idx < 0: + raise RuntimeError(f"Invalid negative index `m_idx`.") + if self.col_num not in range(4): + raise RuntimeError(f"Invalid `col_num`: {self.col_num}. Must be in range [0, 4).") + + retval = super()._schedule(cycle_count, schedule_id) + # Track last access to SPAD address + spad_access_tracking = self.__mem_model.spad.getAccessTracking(variable.spad_address) + spad_access_tracking.last_cload = self + # No need to sync to any previous MLoads after bload + spad_access_tracking.last_mload = None + return retval + + def _toCASMISAFormat(self, *extra_args) -> str: + """ + Converts the instruction to ASM format. + + Parameters: + extra_args: Additional arguments for the conversion. + + Raises: + ValueError: If `extra_args` are provided. + + Returns: + str: The ASM format string of the instruction. + """ + assert(len(self.dests) == Instruction._OP_NUM_DESTS) + assert(len(self.sources) == Instruction._OP_NUM_SOURCES) + + if extra_args: + raise ValueError('`extra_args` not supported.') + + # `op, target_idx, spad_src [# comment]` + preamble = [] + # Instruction sources + extra_args = (self.col_num, ) + extra_args = tuple(src.toCASMISAFormat() for src in self.sources) + extra_args + # Instruction destinations + extra_args = tuple(dst.toCASMISAFormat() for dst in self.dests) + extra_args + extra_args = (self.m_idx, ) + extra_args + return self.toStringFormat(preamble, + self.OP_NAME_ASM, + *extra_args) diff --git a/assembler_tools/hec-assembler-tools/assembler/instructions/cinst/bones.py b/assembler_tools/hec-assembler-tools/assembler/instructions/cinst/bones.py new file mode 100644 index 00000000..e958ac60 --- /dev/null +++ b/assembler_tools/hec-assembler-tools/assembler/instructions/cinst/bones.py @@ -0,0 +1,164 @@ +from assembler.common.cycle_tracking import CycleType +from .cinstruction import CInstruction +from assembler.memory_model.variable import Variable + +class Instruction(CInstruction): + """ + Encapsulates a `bones` CInstruction. + + The `bones` instruction loads metadata of identity (one) from the scratchpad to the register file. + + For more information, check the `bones` Specification: + https://github.com/IntelLabs/hec-assembler-tools/blob/master/docsrc/inst_spec/cinst/cinst_bones.md + + Attributes: + spad_src (int): SPAD address of the metadata variable to load. + """ + + @classmethod + def _get_OP_NAME_ASM(cls) -> str: + """ + Returns the ASM name for the operation. + + Returns: + str: The ASM name for the operation, which is "bones". + """ + return "bones" + + def __init__(self, + id: int, + src_col_num: int, + src: Variable, + mem_model, + throughput : int = None, + latency : int = None, + comment: str = ""): + """ + Constructs a new `bones` CInstruction. + + Parameters: + id (int): User-defined ID for the instruction. It will be bundled with a nonce to form a unique ID. + src_col_num (int): Column number of the source metadata. + src (Variable): Metadata variable to load from SPAD. + mem_model: The memory model associated with the instruction. + throughput (int, optional): The throughput of the instruction. Defaults to the class-defined throughput. + latency (int, optional): The latency of the instruction. Defaults to the class-defined latency. + comment (str, optional): An optional comment for the instruction. + + Raises: + ValueError: If `mem_model` is None. + """ + if not mem_model: + raise ValueError('`mem_model` cannot be `None`.') + if not throughput: + throughput = Instruction._OP_DEFAULT_THROUGHPUT + if not latency: + latency = Instruction._OP_DEFAULT_LATENCY + super().__init__(id, throughput, latency, comment=comment) + self.src_col_num = src_col_num + self.__mem_model = mem_model + self._set_sources( [ src ] ) + + def __repr__(self): + """ + Returns a string representation of the Instruction object. + + Returns: + str: A string representation of the Instruction object. + """ + assert(len(self.sources) > 0) + retval=('<{}({}) object at {}>(id={}[0], ' + 'src_col_num={}, src={}, ' + 'mem_model, ' + 'throughput={}, latency={})').format(type(self).__name__, + self.name, + hex(id(self)), + self.id, + self.src_col_num, + self.sources[0], + self.throughput, + self.latency) + return retval + + def _set_dests(self, value): + """ + Raises an error as the `bones` instruction does not have destination parameters. + + Parameters: + value: The value to set as destinations. + + Raises: + RuntimeError: Always, as `bones` does not have destination parameters. + """ + raise RuntimeError(f"Instruction `{self.name}` does not have parameters.") + + def _set_sources(self, value): + """ + Validates and sets the list of source objects. + + Parameters: + value (list): The list of source objects to set. + + Raises: + ValueError: If the value is not a list of the expected number of `Variable` objects. + """ + if len(value) != Instruction._OP_NUM_SOURCES: + raise ValueError(("`value`: Expected list of {} `Variable` objects, " + "but list with {} elements received.".format(Instruction._OP_NUM_SOURCES, + len(value)))) + if not all(isinstance(x, Variable) for x in value): + raise ValueError("`value`: Expected list of `Variable` objects.") + super()._set_sources(value) + + def _schedule(self, cycle_count: CycleType, schedule_id: int) -> int: + """ + Schedules the instruction, simulating timings of executing this instruction. + + Parameters: + cycle_count (CycleType): Current cycle of execution. + schedule_id (int): The schedule ID for the instruction. + + Raises: + RuntimeError: If the SPAD address is invalid or if the source column number is negative. + + Returns: + int: The throughput for this instruction, i.e., the number of cycles by which to advance + the current cycle counter. + """ + assert(Instruction._OP_NUM_SOURCES > 0 and len(self.sources) == Instruction._OP_NUM_SOURCES) + + variable: Variable = self.sources[0] # Expected sources to contain a Variable. + if variable.spad_address < 0: + raise RuntimeError(f"Null Access Violation: Variable `{variable}` not allocated in SPAD.") + if self.src_col_num < 0: + raise RuntimeError("Invalid `src_col_num` negative `Ones` target index.") + + retval = super()._schedule(cycle_count, schedule_id) + # Track last access to SPAD address. + spad_access_tracking = self.__mem_model.spad.getAccessTracking(variable.spad_address) + spad_access_tracking.last_cload = self + # No need to sync to any previous MLoads after bones. + spad_access_tracking.last_mload = None + return retval + + def _toCASMISAFormat(self, *extra_args) -> str: + """ + Converts the instruction to ASM format. + + Parameters: + extra_args: Additional arguments for the conversion. + + Raises: + ValueError: If `extra_args` are provided. + + Returns: + str: The ASM format string of the instruction. + """ + assert(len(self.dests) == Instruction._OP_NUM_DESTS) + assert(len(self.sources) == Instruction._OP_NUM_SOURCES) + + if extra_args: + raise ValueError('`extra_args` not supported.') + + # `op, spad_src, src_col_num [# comment]` + return super()._toCASMISAFormat(self.src_col_num) diff --git a/assembler_tools/hec-assembler-tools/assembler/instructions/cinst/cexit.py b/assembler_tools/hec-assembler-tools/assembler/instructions/cinst/cexit.py new file mode 100644 index 00000000..529f94cd --- /dev/null +++ b/assembler_tools/hec-assembler-tools/assembler/instructions/cinst/cexit.py @@ -0,0 +1,102 @@ +from .cinstruction import CInstruction + +class Instruction(CInstruction): + """ + Represents a `cexit` CInstruction. + + This instruction terminates execution of a HERACLES program. + + For more information, check the `cexit` Specification: + https://github.com/IntelLabs/hec-assembler-tools/blob/master/docsrc/inst_spec/cinst/cinst_cexit.md + """ + + @classmethod + def _get_OP_NAME_ASM(cls) -> str: + """ + Returns the ASM name for the operation. + + Returns: + str: The ASM name for the operation, which is "cexit". + """ + return "cexit" + + def __init__(self, + id: int, + throughput : int = None, + latency : int = None, + comment: str = ""): + """ + Constructs a new `cexit` CInstruction. + + Parameters: + id (int): User-defined ID for the instruction. It will be bundled with a nonce to form a unique ID. + throughput (int, optional): The throughput of the instruction. Defaults to the class-defined throughput. + latency (int, optional): The latency of the instruction. Defaults to the class-defined latency. + comment (str, optional): An optional comment for the instruction. + """ + if not throughput: + throughput = Instruction._OP_DEFAULT_THROUGHPUT + if not latency: + latency = Instruction._OP_DEFAULT_LATENCY + super().__init__(id, throughput, latency, comment=comment) + + def __repr__(self): + """ + Returns a string representation of the Instruction object. + + Returns: + str: A string representation of the Instruction object. + """ + retval=('<{}({}) object at {}>(id={}[0], ' + 'throughput={}, latency={})').format(type(self).__name__, + self.name, + hex(id(self)), + self.id, + self.throughput, + self.latency) + return retval + + def _set_dests(self, value): + """ + Raises an error as the `cexit` instruction does not have destination parameters. + + Parameters: + value: The value to set as destinations. + + Raises: + RuntimeError: Always, as `cexit` does not have destination parameters. + """ + raise RuntimeError(f"Instruction `{self.name}` does not have parameters.") + + def _set_sources(self, value): + """ + Raises an error as the `cexit` instruction does not have source parameters. + + Parameters: + value: The value to set as sources. + + Raises: + RuntimeError: Always, as `cexit` does not have source parameters. + """ + raise RuntimeError(f"Instruction `{self.name}` does not have parameters.") + + def _toCASMISAFormat(self, *extra_args) -> str: + """ + Converts the instruction to ASM format. + + Parameters: + extra_args: Additional arguments for the conversion. + + Raises: + ValueError: If `extra_args` are provided. + + Returns: + str: The ASM format string of the instruction. + """ + assert(len(self.dests) == Instruction._OP_NUM_DESTS) + assert(len(self.sources) == Instruction._OP_NUM_SOURCES) + + if extra_args: + raise ValueError('`extra_args` not supported.') + + return super()._toCASMISAFormat() diff --git a/assembler_tools/hec-assembler-tools/assembler/instructions/cinst/cinstruction.py b/assembler_tools/hec-assembler-tools/assembler/instructions/cinst/cinstruction.py new file mode 100644 index 00000000..79eed14b --- /dev/null +++ b/assembler_tools/hec-assembler-tools/assembler/instructions/cinst/cinstruction.py @@ -0,0 +1,73 @@ +from assembler.common.cycle_tracking import CycleType +from ..instruction import BaseInstruction + +class CInstruction(BaseInstruction): + """ + Represents a CInstruction, which is a type of BaseInstruction. + + This class provides the basic structure and functionality for CInstructions, including + methods for converting to CInst ASM-ISA format. + + Attributes: + id (int): User-defined ID for the instruction. + throughput (int): The throughput of the instruction. + latency (int): The latency of the instruction. + comment (str): An optional comment for the instruction. + """ + + # Constructor + # ----------- + + def __init__(self, + id: int, + throughput : int, + latency : int, + comment: str = ""): + """ + Constructs a new CInstruction. + + Parameters: + id (int): User-defined ID for the instruction. It will be bundled with a nonce to form a unique ID. + throughput (int): The throughput of the instruction. + latency (int): The latency of the instruction. + comment (str, optional): An optional comment for the instruction. + """ + super().__init__(id, throughput, latency, comment=comment) + + + # Methods and properties + # ---------------------- + + def _get_cycle_ready(self): + """ + Returns the cycle ready value for the instruction. + + This method overrides the base method to provide a specific cycle ready value for CInstructions. + + Returns: + CycleType: A CycleType object with bundle and cycle set to 0. + """ + return CycleType(bundle = 0, cycle = 0) + + def _toCASMISAFormat(self, *extra_args) -> str: + """ + Converts the instruction to CInst ASM-ISA format. + + This method constructs the ASM-ISA format string for the instruction by combining + the instruction's sources and destinations with any additional arguments. + + Parameters: + extra_args: Additional arguments for the conversion. + + Returns: + str: The CInst ASM-ISA format string of the instruction. + """ + + preamble = [] + # instruction sources + extra_args = tuple(src.toCASMISAFormat() for src in self.sources) + extra_args + # instruction destinations + extra_args = tuple(dst.toCASMISAFormat() for dst in self.dests) + extra_args + return self.toStringFormat(preamble, + self.OP_NAME_ASM, + *extra_args) diff --git a/assembler_tools/hec-assembler-tools/assembler/instructions/cinst/cload.py b/assembler_tools/hec-assembler-tools/assembler/instructions/cinst/cload.py new file mode 100644 index 00000000..d5d2aa7c --- /dev/null +++ b/assembler_tools/hec-assembler-tools/assembler/instructions/cinst/cload.py @@ -0,0 +1,165 @@ + +from assembler.common.cycle_tracking import CycleType +from .cinstruction import CInstruction +from assembler.memory_model import MemoryModel +from assembler.memory_model.variable import Variable +from assembler.memory_model.register_file import Register + +class Instruction(CInstruction): + """ + Encapsulates a `cload` CInstruction. + + A `cload` instruction loads a word, corresponding to a single polynomial residue, + from scratchpad memory into the register file memory. + + For more information, check the `cload` Specification: + https://github.com/IntelLabs/hec-assembler-tools/blob/master/docsrc/inst_spec/cinst/cinst_cload.md + """ + + @classmethod + def _get_OP_NAME_ASM(cls) -> str: + """ + Returns the ASM name for the operation. + + Returns: + str: The ASM name for the operation, which is "cload". + """ + return "cload" + + def __init__(self, + id: int, + dst: Register, + src: list, + mem_model: MemoryModel, + throughput : int = None, + latency : int = None, + comment: str = ""): + """ + Constructs a new `cload` CInstruction. + + Parameters: + id (int): User-defined ID for the instruction. It will be bundled with a nonce to form a unique ID. + dst (Register): The destination register where to load the variable in `src`. + src (list): A list containing a single Variable object indicating the source variable to store from + register into SPAD. + mem_model (MemoryModel): The memory model containing the SPAD where to store the source variable. + throughput (int, optional): The throughput of the instruction. Defaults to the class-defined throughput. + latency (int, optional): The latency of the instruction. Defaults to the class-defined latency. + comment (str, optional): An optional comment for the instruction. + + Raises: + AssertionError: If the destination register bank index is not 0. + """ + assert(dst.bank.bank_index == 0) # We must be following convention of loading from SPAD into bank 0 + if not throughput: + throughput = Instruction._OP_DEFAULT_THROUGHPUT + if not latency: + latency = Instruction._OP_DEFAULT_LATENCY + super().__init__(id, throughput, latency, comment=comment) + self.__mem_model = mem_model + self._set_dests([ dst ]) + self._set_sources(src) + + def __repr__(self): + """ + Returns a string representation of the Instruction object. + + Returns: + str: A string representation of the Instruction object. + """ + assert(len(self.dests) > 0) + retval=('<{}({}) object at {}>(id={}[0], ' + 'dst={}, src={},' + 'throughput={}, latency={})').format(type(self).__name__, + self.name, + hex(id(self)), + self.id, + self.dests[0], + self.sources, + self.throughput, + self.latency) + return retval + + def _set_dests(self, value): + """ + Validates and sets the list of destination registers. + + Parameters: + value (list): The list of destination registers to set. + + Raises: + ValueError: If the value is not a list of the expected number of `Register` objects. + TypeError: If the value is not a list of `Register` objects. + """ + if len(value) != Instruction._OP_NUM_DESTS: + raise ValueError(("`value`: Expected list of {} `Register` objects, " + "but list with {} elements received.".format(Instruction._OP_NUM_DESTS, + len(value)))) + if not all(isinstance(x, Register) for x in value): + raise TypeError("`value`: Expected list of `Register` objects.") + super()._set_dests(value) + + def _set_sources(self, value): + """ + Validates and sets the list of source variables. + + Parameters: + value (list): The list of source variables to set. + + Raises: + ValueError: If the value is not a list of the expected number of `Variable` objects. + """ + if len(value) != Instruction._OP_NUM_SOURCES: + raise ValueError(("`value`: Expected list of {} `Variable` objects, " + "but list with {} elements received.".format(Instruction._OP_NUM_SOURCES, + len(value)))) + if not all(isinstance(x, Variable) for x in value): + raise ValueError("`value`: Expected list of `Variable` objects.") + super()._set_sources(value) + + def _schedule(self, cycle_count: CycleType, schedule_id: int) -> int: + """ + Schedules the instruction, simulating timings of executing this instruction. + + Source Variable and destination Register will be updated to reflect the load. + + Parameters: + cycle_count (CycleType): Current cycle of execution. + schedule_id (int): The schedule ID for the instruction. + + Raises: + RuntimeError: If the variable or register is already allocated, or if other exceptions occur. + + Returns: + int: The throughput for this instruction, i.e., the number of cycles by which to advance + the current cycle counter. + """ + assert(Instruction._OP_NUM_DESTS > 0 and len(self.dests) == Instruction._OP_NUM_DESTS) + assert(Instruction._OP_NUM_SOURCES > 0 and len(self.sources) == Instruction._OP_NUM_SOURCES) + + variable: Variable = self.sources[0] # Expected sources to contain a Variable + target_register: Register = self.dests[0] + + if variable.spad_address < 0: + raise RuntimeError(f"Null Access Violation: Variable `{variable}` not allocated in SPAD.") + # Cannot allocate variable to more than one register (memory coherence) + # and must not overrite a register that already contains a variable. + if variable.register: + raise RuntimeError(f"Variable `{variable}` already allocated in register `{variable.register}`.") + if target_register.contained_variable: + raise RuntimeError(f"Register `{target_register}` already contains a Variable object.") + + retval = super()._schedule(cycle_count, schedule_id) + # Perform the load + target_register.allocateVariable(variable) + # Track last access to SPAD address + spad_access_tracking = self.__mem_model.spad.getAccessTracking(variable.spad_address) + spad_access_tracking.last_cload = self + # No need to sync to any previous MLoads after cload + spad_access_tracking.last_mload = None + + if self.comment: + self.comment += ';' + self.comment += f' {variable.name}' + + return retval diff --git a/assembler_tools/hec-assembler-tools/assembler/instructions/cinst/cnop.py b/assembler_tools/hec-assembler-tools/assembler/instructions/cinst/cnop.py new file mode 100644 index 00000000..61e05f0d --- /dev/null +++ b/assembler_tools/hec-assembler-tools/assembler/instructions/cinst/cnop.py @@ -0,0 +1,105 @@ + +from .cinstruction import CInstruction + +class Instruction(CInstruction): + """ + Represents a 'cnop' CInstruction from the ASM ISA specification. + + This class is used to create a 'cnop' instruction, which is a type of + no-operation (NOP) instruction that inserts a specified number of idle + cycles during its execution. The instruction does not have any destination + or source operands. + + For more information, check the `cnop` Specification: + https://github.com/IntelLabs/hec-assembler-tools/blob/master/docsrc/inst_spec/cinst/cinst_nop.md + """ + + @classmethod + def _get_OP_NAME_ASM(cls) -> str: + """ + Returns the ASM name of the operation. + + Returns: + str: The name of the operation in ASM format, which is 'cnop'. + """ + return "cnop" + + def __init__(self, + id: int, + idle_cycles: int, + comment: str = ""): + """ + Constructs a new 'cnop' CInstruction. + + Parameters: + id (int): User-defined ID for the instruction. It will be bundled + with a nonce to form a unique ID. + idle_cycles (int): Number of idle cycles to insert in the CInst execution. + comment (str, optional): A comment for the instruction. Defaults to an empty string. + """ + # Throughput and latency for 'nop' is the number of idle cycles + super().__init__(id, idle_cycles, idle_cycles, comment=comment) + + def __repr__(self): + """ + Returns a string representation of the Instruction object. + + Returns: + str: A string representation of the Instruction object, including + its type, name, memory address, ID, and throughput. + """ + retval=('<{}({}) object at {}>(id={}[0], ' + 'idle_cycles={})').format(type(self).__name__, + self.name, + hex(id(self)), + self.id, + self.throughput) + return retval + + def _set_dests(self, value): + """ + Raises an error as the instruction does not have destination operands. + + Parameters: + value: The value to set as destination, which is not applicable. + + Raises: + RuntimeError: Always raised as the instruction does not have parameters. + """ + raise RuntimeError(f"Instruction `{self.name}` does not have parameters.") + + def _set_sources(self, value): + """ + Raises an error as the instruction does not have source operands. + + Parameters: + value: The value to set as source, which is not applicable. + + Raises: + RuntimeError: Always raised as the instruction does not have parameters. + """ + raise RuntimeError(f"Instruction `{self.name}` does not have parameters.") + + def _toCASMISAFormat(self, *extra_args) -> str: + """ + Converts the instruction to ASM format. + + Parameters: + extra_args: Additional arguments, which are not supported. + + Returns: + str: The instruction in ASM format. + + Raises: + AssertionError: If the number of destinations or sources is incorrect. + ValueError: If extra arguments are provided. + """ + assert(len(self.dests) == Instruction._OP_NUM_DESTS) + assert(len(self.sources) == Instruction._OP_NUM_SOURCES) + + if extra_args: + raise ValueError('`extra_args` not supported.') + + # The idle cycles in the ASM ISA for 'nop' must be one less because decoding/scheduling + # the instruction counts as a cycle. + return super()._toCASMISAFormat(self.throughput - 1) \ No newline at end of file diff --git a/assembler_tools/hec-assembler-tools/assembler/instructions/cinst/cstore.py b/assembler_tools/hec-assembler-tools/assembler/instructions/cinst/cstore.py new file mode 100644 index 00000000..a6908427 --- /dev/null +++ b/assembler_tools/hec-assembler-tools/assembler/instructions/cinst/cstore.py @@ -0,0 +1,195 @@ + +from assembler.common.config import GlobalConfig +from assembler.common.cycle_tracking import CycleType +from .cinstruction import CInstruction +from assembler.memory_model import MemoryModel +from assembler.memory_model.variable import Variable, DummyVariable +from assembler.memory_model.register_file import Register + +class Instruction(CInstruction): + """ + Encapsulates a `cstore` CInstruction. + + A `cstore` instruction pops the top word from the intermediate data buffer queue + and stores it in SPAD. To accomplish this in scheduling, a `cstore` should + be scheduled immediately after the `ifetch` for the bundle containing the matching + `xstore`. + + For more information, check the `cstore` Specification: + https://github.com/IntelLabs/hec-assembler-tools/blob/master/docsrc/inst_spec/cinst/cinst_cstore.md + """ + + @classmethod + def _get_OP_NAME_ASM(cls) -> str: + """ + Returns the ASM name of the operation. + + Returns: + str: The name of the operation in ASM format, which is 'cstore'. + """ + return "cstore" + + def __init__(self, + id: int, + mem_model: MemoryModel, + throughput: int = None, + latency: int = None, + comment: str = ""): + """ + Constructs a new `cstore` CInstruction. + + Parameters: + id (int): User-defined ID for the instruction. It will be bundled with a nonce to form a unique ID. + + mem_model (MemoryModel): The memory model containing the SPAD where to store the source variable. + + throughput (int, optional): The throughput of the instruction. Defaults to the class's default throughput. + + latency (int, optional): The latency of the instruction. Defaults to the class's default latency. + + comment (str, optional): A comment for the instruction. Defaults to an empty string. + + Raises: + ValueError: If `mem_model` is not an instance of `MemoryModel`. + """ + if not isinstance(mem_model, MemoryModel): + raise ValueError('`mem_model` must be an instance of `MemoryModel`.') + if not throughput: + throughput = Instruction._OP_DEFAULT_THROUGHPUT + if not latency: + latency = Instruction._OP_DEFAULT_LATENCY + super().__init__(id, throughput, latency, comment=comment) + self.__mem_model = mem_model + self.__spad_addr = -1 + + def __repr__(self): + """ + Returns a string representation of the Instruction object. + + Returns: + str: A string representation of the Instruction object, including + its type, name, memory address, ID, throughput, and latency. + """ + retval=('<{}({}) object at {}>(id={}[0], ' + 'mem_model, ' + 'throughput={}, latency={})').format(type(self).__name__, + self.name, + hex(id(self)), + self.id, + self.throughput, + self.latency) + return retval + + def _set_dests(self, value): + """ + Sets the destination variables for the instruction. + + Parameters: + value (list): A list of `Variable` objects to set as destinations. + + Raises: + ValueError: If the number of destinations is incorrect or if the list does not contain `Variable` objects. + """ + if len(value) != Instruction._OP_NUM_DESTS: + raise ValueError(("`value`: Expected list of {} `Variable` objects, " + "but list with {} elements received.".format(Instruction._OP_NUM_SOURCES, + len(value)))) + if not all(isinstance(x, Variable) for x in value): + raise ValueError("`value`: Expected list of `Variable` objects.") + super()._set_dests(value) + + def _set_sources(self, value): + """ + Sets the source registers for the instruction. + + Parameters: + value (list): A list of `Register` objects to set as sources. + + Raises: + ValueError: If the number of sources is incorrect. + TypeError: If the list does not contain `Register` objects. + """ + if len(value) != Instruction._OP_NUM_SOURCES: + raise ValueError(("`value`: Expected list of {} `Register` objects, " + "but list with {} elements received.".format(Instruction._OP_NUM_DESTS, + len(value)))) + if not all(isinstance(x, Register) for x in value): + raise TypeError("`value`: Expected list of `Register` objects.") + super()._set_sources(value) + + def _schedule(self, cycle_count: CycleType, schedule_id: int) -> int: + """ + Schedules the instruction, simulating timings of executing this instruction. + + Source Variable and its Register will be updated to reflect the store. + + Parameters: + cycle_count (CycleType): Current cycle of execution. + + schedule_id (int): The schedule ID for the instruction. + + Raises: + RuntimeError: When one of the following happens: + - Source and destination are not the same variable. + - Source is not on a register. + See inherited for more exceptions. + + ValueError: Invalid arguments or either double or conflicting allocations. + + Returns: + int: The throughput for this instruction. i.e. the number of cycles by which to advance + the current cycle counter. + """ + spad = self.__mem_model.spad + + var_name, (variable, self.__spad_addr) = self.__mem_model.store_buffer.pop() # Will raise IndexError if popping from empty queue + assert(var_name == variable.name) + assert self.__spad_addr >= 0 and (variable.spad_address < 0 or variable.spad_address == self.__spad_addr), \ + f'self.__spad_addr = {self.__spad_addr}; {variable.name}.spad_address = {variable.spad_address}' + + retval = super()._schedule(cycle_count, schedule_id) + # Perform the cstore + if spad.buffer[self.__spad_addr] and spad.buffer[self.__spad_addr] != variable: + if not isinstance(spad.buffer[self.__spad_addr], DummyVariable): + raise RuntimeError(f'SPAD location {self.__spad_addr} for instruction (`{self.name}`, id {self.id}) is occupied by variable {spad.buffer[self.__spad_addr]}.') + spad.deallocate(self.__spad_addr) + spad.allocateForce(self.__spad_addr, variable) # Allocate in SPAD + # Track last access to SPAD address + spad_access_tracking = spad.getAccessTracking(self.__spad_addr) + spad_access_tracking.last_cstore = self + spad_access_tracking.last_mload = None # Last mload is now obsolete + variable.spad_dirty = True # Variable has new value in SPAD + + if not GlobalConfig.hasHBM: + # Used to track the variable name going into spad at the moment of cstore. + # This is used to output var name instead of spad address when requested. + # remove when we have spad and HBM back + self.__spad_addr = variable.toCASMISAFormat() + + if self.comment: + self.comment += ';' + self.comment += f' {variable.name}' + return retval + + def _toCASMISAFormat(self, *extra_args) -> str: + """ + Converts the instruction to CInst ASM-ISA format. + + See inherited for more information. + + Parameters: + extra_args (tuple): Additional arguments, which are not supported. + + Returns: + str: The instruction in CInst ASM-ISA format. + + Raises: + ValueError: If extra arguments are provided. + """ + assert(len(self.dests) == Instruction._OP_NUM_DESTS) + assert(len(self.sources) == Instruction._OP_NUM_SOURCES) + + if extra_args: + raise ValueError('`extra_args` not supported.') + + return super()._toCASMISAFormat(self.__spad_addr) \ No newline at end of file diff --git a/assembler_tools/hec-assembler-tools/assembler/instructions/cinst/csyncm.py b/assembler_tools/hec-assembler-tools/assembler/instructions/cinst/csyncm.py new file mode 100644 index 00000000..18485c88 --- /dev/null +++ b/assembler_tools/hec-assembler-tools/assembler/instructions/cinst/csyncm.py @@ -0,0 +1,143 @@ +import warnings + +from assembler.common.cycle_tracking import CycleType +from .cinstruction import CInstruction + +class Instruction(CInstruction): + """ + Encapsulates a `csyncm` CInstruction. + + This instruction is used to synchronize with a specific instruction from the MINST queue. + + For more information, check the `csyncm` Specification: + https://github.com/IntelLabs/hec-assembler-tools/blob/master/docsrc/inst_spec/cinst/cinst_csyncm.md + """ + + @classmethod + def _get_OP_NAME_ASM(cls) -> str: + """ + Returns the ASM name of the operation. + + Returns: + str: The name of the operation in ASM format, which is 'csyncm'. + """ + return "csyncm" + + def __init__(self, + id: int, + minstr, + throughput: int = None, + latency: int = None, + comment: str = ""): + """ + Constructs a new `csyncm` CInstruction. + + Parameters: + id (int): User-defined ID for the instruction. It will be bundled with a nonce to form a unique ID. + + minstr: MInstruction + Instruction from the MINST queue for which to wait. + + throughput (int, optional): The throughput of the instruction. Defaults to the class's default throughput. + + latency (int, optional): The latency of the instruction. Defaults to the class's default latency. + + comment (str, optional): A comment for the instruction. Defaults to an empty string. + """ + if not throughput: + throughput = Instruction._OP_DEFAULT_THROUGHPUT + if not latency: + latency = Instruction._OP_DEFAULT_LATENCY + super().__init__(id, throughput, latency, comment=comment) + self.minstr = minstr # Instruction from the MINST queue for which to wait + + def __repr__(self): + """ + Returns a string representation of the Instruction object. + + Returns: + str: A string representation of the Instruction object, including + its type, name, memory address, ID, minstr, throughput, and latency. + """ + retval=('<{}({}) object at {}>(id={}[0], ' + 'minstr={}, ' + 'throughput={}, latency={})').format(type(self).__name__, + self.name, + hex(id(self)), + self.id, + repr(self.minstr), + self.throughput, + self.latency) + return retval + + def _set_dests(self, value): + """ + Raises an error as the instruction does not have destination operands. + + Parameters: + value: The value to set as destination, which is not applicable. + + Raises: + RuntimeError: Always raised as the instruction does not have parameters. + """ + raise RuntimeError(f"Instruction `{self.name}` does not have parameters.") + + def _set_sources(self, value): + """ + Raises an error as the instruction does not have source operands. + + Parameters: + value: The value to set as source, which is not applicable. + + Raises: + RuntimeError: Always raised as the instruction does not have parameters. + """ + raise RuntimeError(f"Instruction `{self.name}` does not have parameters.") + + def _schedule(self, cycle_count: CycleType, schedule_id: int) -> int: + """ + Schedules the instruction, simulating timings of executing this instruction. + + Parameters: + cycle_count (CycleType): Current cycle of execution. + + schedule_id (int): 1-based index for this instruction in its schedule listing. + + Raises: + RuntimeError: MInstruction to sync is invalid or has not been scheduled. + See inherited for more exceptions. + + Returns: + int: The throughput for this instruction. i.e. the number of cycles by which to advance + the current cycle counter. + """ + if not self.minstr: + raise RuntimeError("Invalid empty MInstruction.") + if not self.minstr.is_scheduled: + raise RuntimeError("MInstruction to sync is not scheduled yet.") + + retval = super()._schedule(cycle_count, schedule_id) + return retval + + def _toCASMISAFormat(self, *extra_args) -> str: + """ + Converts the instruction to ASM format. + + Parameters: + extra_args: Additional arguments, which are not supported. + + Returns: + str: The instruction in ASM format. + + Raises: + ValueError: If extra arguments are provided. + """ + assert(len(self.dests) == Instruction._OP_NUM_DESTS) + assert(len(self.sources) == Instruction._OP_NUM_SOURCES) + assert(self.minstr.is_scheduled) + + if extra_args: + raise ValueError('`extra_args` not supported.') + + # warnings.warn("`csyncm` instruction requires second pass to set correct instruction number.") + return super()._toCASMISAFormat(self.minstr.schedule_timing.index) \ No newline at end of file diff --git a/assembler_tools/hec-assembler-tools/assembler/instructions/cinst/ifetch.py b/assembler_tools/hec-assembler-tools/assembler/instructions/cinst/ifetch.py new file mode 100644 index 00000000..c4919186 --- /dev/null +++ b/assembler_tools/hec-assembler-tools/assembler/instructions/cinst/ifetch.py @@ -0,0 +1,141 @@ + +from assembler.common.cycle_tracking import CycleType +from .cinstruction import CInstruction + +class Instruction(CInstruction): + """ + Encapsulates an `ifetch` CInstruction. + + This instruction is used to fetch a bundle of instructions from the instruction memory. + + For more information, check the `ifetch` Specification: + https://github.com/IntelLabs/hec-assembler-tools/blob/master/docsrc/inst_spec/cinst/cinst_ifetch.md + + Attributes: + OP_DEFAULT_LATENCY (int): The default latency as per ASM ISA spec. + bundle_id (int): Zero-based index for the bundle of instructions to fetch. + """ + + @classmethod + def _get_OP_NAME_ASM(cls) -> str: + """ + Returns the ASM name of the operation. + + Returns: + str: The name of the operation in ASM format, which is 'ifetch'. + """ + return "ifetch" + + def __init__(self, + id: int, + bundle_id: int, + throughput: int = None, + latency: int = None, + comment: str = ""): + """ + Constructs a new `ifetch` CInstruction. + + Parameters: + id (int): User-defined ID for the instruction. It will be bundled with a nonce to form a unique ID. + + bundle_id (int): Zero-based index for the bundle of instructions to fetch. + + throughput (int, optional): The throughput of the instruction. Defaults to the class's default throughput. + + latency (int, optional): The latency of the instruction. Defaults to the class's default latency. + + comment (str, optional): A comment for the instruction. Defaults to an empty string. + """ + if not throughput: + throughput = Instruction._OP_DEFAULT_THROUGHPUT + if not latency: + latency = Instruction._OP_DEFAULT_LATENCY + super().__init__(id, throughput, latency, comment=comment) + self.bundle_id = bundle_id # Instruction number from the MINST queue for which to wait + + def __repr__(self): + """ + Returns a string representation of the Instruction object. + + Returns: + str: A string representation of the Instruction object, including + its type, name, memory address, ID, bundle_id, throughput, and latency. + """ + retval=('<{}({}) object at {}>(id={}[0], ' + 'bundle_id={}, ' + 'throughput={}, latency={})').format(type(self).__name__, + self.name, + hex(id(self)), + self.id, + self.bundle_id, + self.throughput, + self.latency) + return retval + + def _set_dests(self, value): + """ + Raises an error as the instruction does not have destination operands. + + Parameters: + value: The value to set as destination, which is not applicable. + + Raises: + RuntimeError: Always raised as the instruction does not have parameters. + """ + raise RuntimeError(f"Instruction `{self.name}` does not have parameters.") + + def _set_sources(self, value): + """ + Raises an error as the instruction does not have source operands. + + Parameters: + value: The value to set as source, which is not applicable. + + Raises: + RuntimeError: Always raised as the instruction does not have parameters. + """ + raise RuntimeError(f"Instruction `{self.name}` does not have parameters.") + + def _schedule(self, cycle_count: CycleType, schedule_id: int) -> int: + """ + Schedules the instruction, simulating timings of executing this instruction. + + Parameters: + cycle_count (CycleType): Current cycle of execution. + + schedule_id (int): 1-based index for this instruction in its schedule listing. + + Raises: + RuntimeError: If the bundle ID is invalid (less than zero). + See inherited for more exceptions. + + Returns: + int: The throughput for this instruction. i.e. the number of cycles by which to advance + the current cycle counter. + """ + if self.bundle_id < 0: + raise RuntimeError("Invalid bundle ID. Expected zero or greater.") + + retval = super()._schedule(cycle_count, schedule_id) + return retval + + def _toCASMISAFormat(self, *extra_args) -> str: + """ + Converts the instruction to ASM format. + + Parameters: + extra_args: Additional arguments, which are not supported. + + Returns: + str: The instruction in ASM format. + + Raises: + ValueError: If extra arguments are provided. + """ + assert(len(self.dests) == Instruction._OP_NUM_DESTS) + assert(len(self.sources) == Instruction._OP_NUM_SOURCES) + + if extra_args: + raise ValueError('`extra_args` not supported.') + + return super()._toCASMISAFormat(self.bundle_id) \ No newline at end of file diff --git a/assembler_tools/hec-assembler-tools/assembler/instructions/cinst/kgload.py b/assembler_tools/hec-assembler-tools/assembler/instructions/cinst/kgload.py new file mode 100644 index 00000000..3bf97a36 --- /dev/null +++ b/assembler_tools/hec-assembler-tools/assembler/instructions/cinst/kgload.py @@ -0,0 +1,204 @@ +from assembler.common.cycle_tracking import CycleType +from .cinstruction import CInstruction +from assembler.memory_model.variable import Variable +from assembler.memory_model.register_file import Register + +class Instruction(CInstruction): + """ + Encapsulates `kg_load` CInstruction. + + `kg_load` instruction loads HW-generated key material from the keygen engine + into a CE data register. + + To start the keygen engine, a seed should be loaded followed by a kg_start + instruction to start the key material generation. + + kg_load, dst_register + + Rules: + 1. `kg_load`s and `kg_start`s must be `latency` cycles apart from any other + `kg_load` and `kg_start`. It takes between 10 to `latency` cycles for the key generation + resource to generate the next key material, possibly causing contention if + the key material is requested by any `kg_load` within `latency` cycles. + """ + + @classmethod + def SetNumSources(cls, val): + cls._OP_NUM_SOURCES = val + 1 # Adding the keygen variable (since the actual instruction needs no sources) + + @classmethod + def _get_OP_NAME_ASM(cls) -> str: + """ + Returns the ASM name of the operation. + + Returns: + str: The name of the operation in ASM format, which is 'kg_load'. + """ + return "kg_load" + + def __init__(self, + id: int, + dst: Register, + src: list, + throughput: int = None, + latency: int = None, + comment: str = ""): + """ + Constructs a new `kg_load` CInstruction. + + Parameters: + id (int): User-defined ID for the instruction. It will be bundled with a nonce to form a unique ID. + + dst (Register): Register to contain the key material generated. Associated keygen variable will + be set to this register when scheduled. + + src (list of Variable): Contains the keygen variable to be loaded. The variable register will be set + to the specified destination register when scheduled. + + throughput (int, optional): The throughput of the instruction. Defaults to the class's default throughput. + + latency (int, optional): The latency of the instruction. Defaults to the class's default latency. + + comment (str, optional): A comment for the instruction. Defaults to an empty string. + """ + if not throughput: + throughput = Instruction._OP_DEFAULT_THROUGHPUT + if not latency: + latency = Instruction._OP_DEFAULT_LATENCY + super().__init__(id, throughput, latency, comment=comment) + self._set_sources(src) + self._set_dests([dst]) + + def __repr__(self): + """ + Returns a string representation of the Instruction object. + + Returns: + str: A string representation of the Instruction object, including + its type, name, memory address, ID, column number, memory index, source, throughput, and latency. + """ + assert(len(self.sources) > 0) + retval=('<{}({}) object at {}>(id={}[0], ' + 'col_num={}, m_idx={}, src={}, ' + 'mem_model, ' + 'throughput={}, latency={})').format(type(self).__name__, + self.name, + hex(id(self)), + self.id, + self.col_num, + self.m_idx, + self.sources[0], + self.throughput, + self.latency) + return retval + + def _set_dests(self, value): + """ + Sets the destination registers for the instruction. + + Parameters: + value (list): A list of `Register` objects to set as destinations. + + Raises: + ValueError: If the number of destinations is incorrect. + TypeError: If the list does not contain `Register` objects. + """ + if len(value) != Instruction._OP_NUM_DESTS: + raise ValueError(("`value`: Expected list of {} `Register` objects, " + "but list with {} elements received.".format(Instruction._OP_NUM_DESTS, + len(value)))) + if not all(isinstance(x, Register) for x in value): + raise TypeError("`value`: Expected list of `Register` objects.") + super()._set_dests(value) + + def _set_sources(self, value): + """ + Sets the source variables for the instruction. + + Parameters: + value (list): A list of `Variable` objects to set as sources. + + Raises: + ValueError: If the number of sources is incorrect. + TypeError: If the list does not contain `Variable` objects. + """ + if len(value) != Instruction._OP_NUM_SOURCES: + raise ValueError(("`value`: Expected list of {} `Variable` objects, " + "but list with {} elements received.".format(Instruction._OP_NUM_SOURCES, + len(value)))) + if not all(isinstance(x, Variable) for x in value): + raise ValueError("`value`: Expected list of `Variable` objects.") + super()._set_sources(value) + + def _schedule(self, cycle_count: CycleType, schedule_id: int) -> int: + """ + Schedules the instruction, simulating timings of executing this instruction. + + Source Variable and destination Register will be updated to reflect the load. + + Parameters: + cycle_count (CycleType): Current cycle of execution. + + schedule_id (int): 1-based index for this instruction in its schedule listing. + + Raises: + RuntimeError: Variable or Register already allocated. See inherited for other exceptions. + + Returns: + int: The throughput for this instruction. i.e. the number of cycles by which to advance + the current cycle counter. + """ + assert(Instruction._OP_NUM_DESTS > 0 and len(self.dests) == Instruction._OP_NUM_DESTS) + assert(Instruction._OP_NUM_SOURCES > 0 and len(self.sources) == Instruction._OP_NUM_SOURCES) + + variable: Variable = self.sources[0] # Expected sources to contain a Variable + target_register: Register = self.dests[0] + + if variable.spad_address >= 0 or variable.hbm_address >= 0: + raise RuntimeError(f"Variable `{variable}` already generated.") + # Cannot allocate variable to more than one register (memory coherence) + # and must not overwrite a register that already contains a variable. + if variable.register: + raise RuntimeError(f"Variable `{variable}` already allocated in register `{variable.register}`.") + if target_register.contained_variable: + raise RuntimeError(f"Register `{target_register}` already contains a Variable object.") + + retval = super()._schedule(cycle_count, schedule_id) + # Variable generated, reflect the load + target_register.allocateVariable(variable) + + if self.comment: + self.comment += ';' + self.comment += f' {variable.name}' + + return retval + + def _toCASMISAFormat(self, *extra_args) -> str: + """ + Converts the instruction to ASM format. + + Parameters: + extra_args: Additional arguments, which are not supported. + + Returns: + str: The instruction in ASM format. + + Raises: + ValueError: If extra arguments are provided. + """ + assert(len(self.dests) == Instruction._OP_NUM_DESTS) + assert(len(self.sources) == Instruction._OP_NUM_SOURCES) + + if extra_args: + raise ValueError('`extra_args` not supported.') + + # `op, dest_reg [# comment]` + preamble = [] + # Instruction sources + # kg_load has no sources + + # Instruction destinations + extra_args = tuple(dst.toCASMISAFormat() for dst in self.dests) + extra_args + return self.toStringFormat(preamble, + self.OP_NAME_ASM, + *extra_args) \ No newline at end of file diff --git a/assembler_tools/hec-assembler-tools/assembler/instructions/cinst/kgseed.py b/assembler_tools/hec-assembler-tools/assembler/instructions/cinst/kgseed.py new file mode 100644 index 00000000..1d2478cb --- /dev/null +++ b/assembler_tools/hec-assembler-tools/assembler/instructions/cinst/kgseed.py @@ -0,0 +1,190 @@ +from assembler.common.cycle_tracking import CycleType +from .cinstruction import CInstruction +from assembler.memory_model.variable import Variable + +class Instruction(CInstruction): + """ + Encapsulates `kg_seed` CInstruction. + + `kg_seed` instruction loads a seed value into the keygen engine. The engine + is reset and prepared to start generating key material. + + A word holds up to 4 seed values (2KB each), so, kg_seed has extra parameters + to select the seed inside the word. + + kg_seed, , + + spad_address: SPAD address where the word containing the seed resides. + block_index: Index of data block inside the word for the seed to load. + A word is 8KB, containing up to 4 seeds of 2KB each. `block_index` + indicates the index of the block to load. + + Attributes: + block_index (int): Index of data block inside the word for the seed to load. + """ + + @classmethod + def _get_OP_NAME_ASM(cls) -> str: + """ + Returns the ASM name of the operation. + + Returns: + str: The name of the operation in ASM format, which is 'kg_seed'. + """ + return "kg_seed" + + def __init__(self, + id: int, + block_index: int, + src: Variable, + mem_model, + throughput: int = None, + latency: int = None, + comment: str = ""): + """ + Constructs a new `kg_seed` CInstruction. + + Parameters: + id (int): User-defined ID for the instruction. It will be bundled with a nonce to form a unique ID. + + block_index (int): Index of data block inside the word for the seed to load. See docs. + + src (Variable): Variable containing the seed to load from SPAD. + + mem_model: The memory model used for tracking SPAD access. + + throughput (int, optional): The throughput of the instruction. Defaults to the class's default throughput. + + latency (int, optional): The latency of the instruction. Defaults to the class's default latency. + + comment (str, optional): A comment for the instruction. Defaults to an empty string. + + Raises: + ValueError: If `mem_model` is `None`. + """ + if not mem_model: + raise ValueError('`mem_model` cannot be `None`.') + if not throughput: + throughput = Instruction._OP_DEFAULT_THROUGHPUT + if not latency: + latency = Instruction._OP_DEFAULT_LATENCY + super().__init__(id, throughput, latency, comment=comment) + self.block_index = block_index + self.__mem_model = mem_model + self._set_sources([src]) + + def __repr__(self): + """ + Returns a string representation of the Instruction object. + + Returns: + str: A string representation of the Instruction object, including + its type, name, memory address, ID, block index, source, throughput, and latency. + """ + assert(len(self.sources) > 0) + retval=('<{}({}) object at {}>(id={}[0], ' + 'col_num={}, m_idx={}, src={}, ' + 'mem_model, ' + 'throughput={}, latency={})').format(type(self).__name__, + self.name, + hex(id(self)), + self.id, + self.block_index, + self.sources[0], + self.throughput, + self.latency) + return retval + + def _set_dests(self, value): + """ + Raises an error as the instruction does not have destination parameters. + + Parameters: + value: The value to set as destination, which is not applicable. + + Raises: + RuntimeError: Always raised as the instruction does not have destination parameters. + """ + raise RuntimeError(f"Instruction `{self.name}` does not have destination parameters.") + + def _set_sources(self, value): + """ + Sets the source variables for the instruction. + + Parameters: + value (list): A list of `Variable` objects to set as sources. + + Raises: + ValueError: If the number of sources is incorrect. + TypeError: If the list does not contain `Variable` objects. + """ + if len(value) != Instruction._OP_NUM_SOURCES: + raise ValueError(("`value`: Expected list of {} `Variable` objects, " + "but list with {} elements received.".format(Instruction._OP_NUM_SOURCES, + len(value)))) + if not all(isinstance(x, Variable) for x in value): + raise ValueError("`value`: Expected list of `Variable` objects.") + super()._set_sources(value) + + def _schedule(self, cycle_count: CycleType, schedule_id: int) -> int: + """ + Schedules the instruction, simulating timings of executing this instruction. + + Parameters: + cycle_count (CycleType): Current cycle of execution. + + schedule_id (int): 1-based index for this instruction in its schedule listing. + + Raises: + RuntimeError: If the variable is not allocated in SPAD or if the block index is invalid. + See inherited for more exceptions. + + Returns: + int: The throughput for this instruction. i.e. the number of cycles by which to advance + the current cycle counter. + """ + assert(Instruction._OP_NUM_SOURCES > 0 and len(self.sources) == Instruction._OP_NUM_SOURCES) + + variable: Variable = self.sources[0] # Expected sources to contain a Variable + if variable.spad_address < 0: + raise RuntimeError(f'Null Access Violation: Variable "{variable}" not allocated in SPAD.') + if self.block_index not in range(4): + raise RuntimeError(f"Invalid `block_index`: {self.block_index}. Must be in range [0, 4).") + + retval = super()._schedule(cycle_count, schedule_id) + # Track last access to SPAD address + spad_access_tracking = self.__mem_model.spad.getAccessTracking(variable.spad_address) + spad_access_tracking.last_cload = self + # No need to sync to any previous MLoads after kg_seed + spad_access_tracking.last_mload = None + return retval + + def _toCASMISAFormat(self, *extra_args) -> str: + """ + Converts the instruction to ASM format. + + Parameters: + extra_args: Additional arguments, which are not supported. + + Returns: + str: The instruction in ASM format. + + Raises: + ValueError: If extra arguments are provided. + """ + assert(len(self.dests) == Instruction._OP_NUM_DESTS) + assert(len(self.sources) == Instruction._OP_NUM_SOURCES) + + if extra_args: + raise ValueError('`extra_args` not supported.') + + # `op, spad_src, block_index [# comment]` + preamble = [] + # Instruction sources + extra_args = (self.block_index, ) + extra_args = tuple(src.toCASMISAFormat() for src in self.sources) + extra_args + # Instruction destinations + extra_args = tuple(dst.toCASMISAFormat() for dst in self.dests) + extra_args + return self.toStringFormat(preamble, + self.OP_NAME_ASM, + *extra_args) \ No newline at end of file diff --git a/assembler_tools/hec-assembler-tools/assembler/instructions/cinst/kgstart.py b/assembler_tools/hec-assembler-tools/assembler/instructions/cinst/kgstart.py new file mode 100644 index 00000000..7de5ae12 --- /dev/null +++ b/assembler_tools/hec-assembler-tools/assembler/instructions/cinst/kgstart.py @@ -0,0 +1,112 @@ +from assembler.common.cycle_tracking import CycleType +from .cinstruction import CInstruction +from assembler.memory_model.variable import Variable + +class Instruction(CInstruction): + """ + Encapsulates `kg_start` CInstruction. + + `kg_start` instruction signals the keygen engine to start producing key material + using the currently loaded seed. + + Rules: + 1. `kg_load`s and `kg_start`s must be `latency` cycles apart from any other + `kg_load` and `kg_start`. It takes between 10 to `latency` cycles for the key generation + resource to generate the next key material, possibly causing contention if + the key material is requested by any `kg_load` within `latency` cycles. + """ + + @classmethod + def _get_OP_NAME_ASM(cls) -> str: + """ + Returns the ASM name of the operation. + + Returns: + str: The name of the operation in ASM format, which is 'kg_start'. + """ + return "kg_start" + + def __init__(self, + id: int, + throughput: int = None, + latency: int = None, + comment: str = ""): + """ + Constructs a new `kg_start` CInstruction. + + Parameters: + id (int): User-defined ID for the instruction. It will be bundled with a nonce to form a unique ID. + + throughput (int, optional): The throughput of the instruction. Defaults to the class's default throughput. + + latency (int, optional): The latency of the instruction. Defaults to the class's default latency. + + comment (str, optional): A comment for the instruction. Defaults to an empty string. + """ + if not throughput: + throughput = Instruction._OP_DEFAULT_THROUGHPUT + if not latency: + latency = Instruction._OP_DEFAULT_LATENCY + super().__init__(id, throughput, latency, comment=comment) + + def __repr__(self): + """ + Returns a string representation of the Instruction object. + + Returns: + str: A string representation of the Instruction object, including + its type, name, memory address, ID, throughput, and latency. + """ + retval=('<{}({}) object at {}>(id={}[0], ' + 'throughput={}, latency={})').format(type(self).__name__, + self.name, + hex(id(self)), + self.id, + self.throughput, + self.latency) + return retval + + def _set_dests(self, value): + """ + Raises an error as the instruction does not have destination parameters. + + Parameters: + value: The value to set as destination, which is not applicable. + + Raises: + RuntimeError: Always raised as the instruction does not have destination parameters. + """ + raise RuntimeError(f"Instruction `{self.name}` does not have parameters.") + + def _set_sources(self, value): + """ + Raises an error as the instruction does not have source parameters. + + Parameters: + value: The value to set as source, which is not applicable. + + Raises: + RuntimeError: Always raised as the instruction does not have source parameters. + """ + raise RuntimeError(f"Instruction `{self.name}` does not have parameters.") + + def _toCASMISAFormat(self, *extra_args) -> str: + """ + Converts the instruction to ASM format. + + Parameters: + extra_args: Additional arguments, which are not supported. + + Returns: + str: The instruction in ASM format. + + Raises: + ValueError: If extra arguments are provided. + """ + assert(len(self.dests) == Instruction._OP_NUM_DESTS) + assert(len(self.sources) == Instruction._OP_NUM_SOURCES) + + if extra_args: + raise ValueError('`extra_args` not supported.') + + return super()._toCASMISAFormat() \ No newline at end of file diff --git a/assembler_tools/hec-assembler-tools/assembler/instructions/cinst/nload.py b/assembler_tools/hec-assembler-tools/assembler/instructions/cinst/nload.py new file mode 100644 index 00000000..ffe1111e --- /dev/null +++ b/assembler_tools/hec-assembler-tools/assembler/instructions/cinst/nload.py @@ -0,0 +1,185 @@ + +from assembler.common.cycle_tracking import CycleType +from .cinstruction import CInstruction +from assembler.memory_model.variable import Variable + +class Instruction(CInstruction): + """ + Encapsulates an `nload` CInstruction. + + `nload` instruction loads metadata (for NTT/iNTT routing mapping) from scratchpad + into special routing table registers. + + For more information, check the `nload` Specification: + https://github.com/IntelLabs/hec-assembler-tools/blob/master/docsrc/inst_spec/cinst/cinst_nload.md + + Attributes: + table_idx (int): Index for destination routing table. See docs. + spad_src (int): SPAD address of metadata variable to load. + """ + + @classmethod + def _get_OP_NAME_ASM(cls) -> str: + """ + Returns the ASM name of the operation. + + Returns: + str: The name of the operation in ASM format, which is 'nload'. + """ + return "nload" + + def __init__(self, + id: int, + table_idx: int, + src: Variable, + mem_model, + throughput: int = None, + latency: int = None, + comment: str = ""): + """ + Constructs a new `nload` CInstruction. + + Parameters: + id (int): User-defined ID for the instruction. It will be bundled with a nonce to form a unique ID. + + table_idx (int): Index for destination routing table. See docs. + + src (Variable): Variable containing the metadata to load from SPAD. + + mem_model: The memory model used for tracking SPAD access. + + throughput (int, optional): The throughput of the instruction. Defaults to the class's default throughput. + + latency (int, optional): The latency of the instruction. Defaults to the class's default latency. + + comment (str, optional): A comment for the instruction. Defaults to an empty string. + + Raises: + ValueError: If `mem_model` is `None`. + """ + if not mem_model: + raise ValueError('`mem_model` cannot be `None`.') + if not throughput: + throughput = Instruction._OP_DEFAULT_THROUGHPUT + if not latency: + latency = Instruction._OP_DEFAULT_LATENCY + super().__init__(id, throughput, latency, comment=comment) + self.table_idx = table_idx + self.__mem_model = mem_model + self._set_sources([src]) + + def __repr__(self): + """ + Returns a string representation of the Instruction object. + + Returns: + str: A string representation of the Instruction object, including + its type, name, memory address, ID, table index, source, throughput, and latency. + """ + assert(len(self.sources) > 0) + retval=('<{}({}) object at {}>(id={}[0], ' + 'table_idx={}, src={}, ' + 'mem_model, ' + 'throughput={}, latency={})').format(type(self).__name__, + self.name, + hex(id(self)), + self.id, + self.table_idx, + self.sources[0], + self.throughput, + self.latency) + return retval + + def _set_dests(self, value): + """ + Raises an error as the instruction does not have destination parameters. + + Parameters: + value: The value to set as destination, which is not applicable. + + Raises: + RuntimeError: Always raised as the instruction does not have destination parameters. + """ + raise RuntimeError(f"Instruction `{self.name}` does not have parameters.") + + def _set_sources(self, value): + """ + Sets the source variables for the instruction. + + Parameters: + value (list): A list of `Variable` objects to set as sources. + + Raises: + ValueError: If the number of sources is incorrect. + TypeError: If the list does not contain `Variable` objects. + """ + if len(value) != Instruction._OP_NUM_SOURCES: + raise ValueError(("`value`: Expected list of {} `Variable` objects, " + "but list with {} elements received.".format(Instruction._OP_NUM_SOURCES, + len(value)))) + if not all(isinstance(x, Variable) for x in value): + raise ValueError("`value`: Expected list of `Variable` objects.") + super()._set_sources(value) + + def _schedule(self, cycle_count: CycleType, schedule_id: int) -> int: + """ + Schedules the instruction, simulating timings of executing this instruction. + + Parameters: + cycle_count (CycleType): Current cycle of execution. + + schedule_id (int): 1-based index for this instruction in its schedule listing. + + Raises: + RuntimeError: If the variable is not allocated in SPAD or if the table index is invalid. + See inherited for more exceptions. + + Returns: + int: The throughput for this instruction. i.e. the number of cycles by which to advance + the current cycle counter. + """ + assert(Instruction._OP_NUM_SOURCES > 0 and len(self.sources) == Instruction._OP_NUM_SOURCES) + + variable: Variable = self.sources[0] # Expected sources to contain a Variable + if variable.spad_address < 0: + raise RuntimeError(f"Null Access Violation: Variable `{variable}` not allocated in SPAD.") + if self.table_idx < 0: + raise RuntimeError("Invalid `table_idx` negative routing table index.") + + retval = super()._schedule(cycle_count, schedule_id) + # Track last access to SPAD address + spad_access_tracking = self.__mem_model.spad.getAccessTracking(variable.spad_address) + spad_access_tracking.last_cload = self + # No need to sync to any previous MLoads after bones + spad_access_tracking.last_mload = None + return retval + + def _toCASMISAFormat(self, *extra_args) -> str: + """ + Converts the instruction to ASM format. + + Parameters: + extra_args: Additional arguments, which are not supported. + + Returns: + str: The instruction in ASM format. + + Raises: + ValueError: If extra arguments are provided. + """ + assert(len(self.dests) == Instruction._OP_NUM_DESTS) + assert(len(self.sources) == Instruction._OP_NUM_SOURCES) + + if extra_args: + raise ValueError('`extra_args` not supported.') + + # `op, target_idx, spad_src [# comment]` + preamble = [] + # Instruction sources + extra_args = tuple(src.toCASMISAFormat() for src in self.sources) + extra_args + # Instruction destinations + extra_args = tuple(dst.toCASMISAFormat() for dst in self.dests) + extra_args + extra_args = (self.table_idx, ) + extra_args + return self.toStringFormat(preamble, + self.OP_NAME_ASM, + *extra_args) \ No newline at end of file diff --git a/assembler_tools/hec-assembler-tools/assembler/instructions/cinst/xinstfetch.py b/assembler_tools/hec-assembler-tools/assembler/instructions/cinst/xinstfetch.py new file mode 100644 index 00000000..3f9d1804 --- /dev/null +++ b/assembler_tools/hec-assembler-tools/assembler/instructions/cinst/xinstfetch.py @@ -0,0 +1,155 @@ +from assembler.common import constants +from assembler.common.cycle_tracking import CycleType +from .cinstruction import CInstruction + +class Instruction(CInstruction): + """ + Encapsulates an `xinstfetch` CInstruction. + + `xinstfetch` fetches 1 word (32KB) worth of instructions from the HBM XInst + region and sends it to the XINST queue. + + For more information, check the `xinstfetch` Specification: + https://github.com/IntelLabs/hec-assembler-tools/blob/master/docsrc/inst_spec/cinst/cinst_xinstfetch.md + + Attributes: + xq_dst (int): Destination word address in XINST queue in the range + [0, constants.MemoryModel.XINST_QUEUE_MAX_CAPACITY_WORDS). + + hbm_src (int): Address of the word worth of instructions in HBM XInst region to copy into XINST queue. + """ + + @classmethod + def _get_OP_NAME_ASM(cls) -> str: + """ + Returns the ASM name of the operation. + + Returns: + str: The name of the operation in ASM format, which is 'xinstfetch'. + """ + return "xinstfetch" + + def __init__(self, + id: int, + xq_dst: int, + hbm_src: int, + throughput: int = None, + latency: int = None, + comment: str = ""): + """ + Constructs a new `xinstfetch` CInstruction. + + Parameters: + id (int): User-defined ID for the instruction. It will be bundled with a nonce to form a unique ID. + + xq_dst (int): Destination word address in XINST queue in the range [0, 32). XINST queue capacity is + 32 words (1MB). + + hbm_src (int): Address of the word worth of instructions in HBM XInst region to copy into XINST queue. + + throughput (int, optional): The throughput of the instruction. Defaults to the class's default throughput. + + latency (int, optional): The latency of the instruction. Defaults to the class's default latency. + + comment (str, optional): A comment for the instruction. Defaults to an empty string. + """ + if not throughput: + throughput = Instruction._OP_DEFAULT_THROUGHPUT + if not latency: + latency = Instruction._OP_DEFAULT_LATENCY + super().__init__(id, throughput, latency, comment=comment) + self.xq_dst = xq_dst + self.hbm_src = hbm_src + + def __repr__(self): + """ + Returns a string representation of the Instruction object. + + Returns: + str: A string representation of the Instruction object, including + its type, name, memory address, ID, xq_dst, hbm_src, throughput, and latency. + """ + assert(len(self.dests) > 0) + retval=('<{}({}) object at {}>(id={}[0], ' + 'xq_dst={}, hbm_src={},' + 'throughput={}, latency={})').format(type(self).__name__, + self.name, + hex(id(self)), + self.id, + self.xq_dst, + self.hbm_src, + self.throughput, + self.latency) + return retval + + def _set_dests(self, value): + """ + Raises an error as the instruction does not have destination parameters. + + Parameters: + value: The value to set as destination, which is not applicable. + + Raises: + RuntimeError: Always raised as the instruction does not have destination parameters. + """ + raise RuntimeError(f"Instruction `{self.name}` does not have parameters.") + + def _set_sources(self, value): + """ + Raises an error as the instruction does not have source parameters. + + Parameters: + value: The value to set as source, which is not applicable. + + Raises: + RuntimeError: Always raised as the instruction does not have source parameters. + """ + raise RuntimeError(f"Instruction `{self.name}` does not have parameters.") + + def _schedule(self, cycle_count: CycleType, schedule_id: int) -> int: + """ + Schedules the instruction, simulating timings of executing this instruction. + + Parameters: + cycle_count (CycleType): Current cycle of execution. + + schedule_id (int): 1-based index for this instruction in its schedule listing. + + Raises: + RuntimeError: If the xq_dst is out of range or if the hbm_src is negative. + See inherited for more exceptions. + + Returns: + int: The throughput for this instruction. i.e. the number of cycles by which to advance + the current cycle counter. + """ + if self.xq_dst < 0 or self.xq_dst >= constants.MemoryModel.XINST_QUEUE_MAX_CAPACITY_WORDS: + raise RuntimeError(('Invalid `xq_dst` XINST queue destination address. Expected value in range ' + '[0, {}), but received {}.'. format(constants.MemoryModel.XINST_QUEUE_MAX_CAPACITY_WORDS, + self.xq_dst))) + if self.hbm_src < 0: + raise RuntimeError("Invalid `hbm_src` negative HBM address.") + + retval = super()._schedule(cycle_count, schedule_id) + return retval + + def _toCASMISAFormat(self, *extra_args) -> str: + """ + Converts the instruction to ASM format. + + Parameters: + extra_args: Additional arguments, which are not supported. + + Returns: + str: The instruction in ASM format. + + Raises: + ValueError: If extra arguments are provided. + """ + assert(len(self.dests) == Instruction._OP_NUM_DESTS) + assert(len(self.sources) == Instruction._OP_NUM_SOURCES) + + if extra_args: + raise ValueError('`extra_args` not supported.') + + return super()._toCASMISAFormat(self.xq_dst, self.hbm_src) \ No newline at end of file diff --git a/assembler_tools/hec-assembler-tools/assembler/instructions/instruction.py b/assembler_tools/hec-assembler-tools/assembler/instructions/instruction.py new file mode 100644 index 00000000..5ed888c7 --- /dev/null +++ b/assembler_tools/hec-assembler-tools/assembler/instructions/instruction.py @@ -0,0 +1,687 @@ +from typing import final +from typing import NamedTuple + +from assembler.common.config import GlobalConfig +from assembler.common.counter import Counter +from assembler.common.cycle_tracking import CycleTracker, CycleType +from assembler.common.decorators import * + +class ScheduleTiming(NamedTuple): + """ + A named tuple to add structure to schedule timing. + + Attributes: + cycle (CycleType): The cycle in which the instruction was scheduled. + index (int): The index for the instruction in its schedule listing. + """ + cycle: CycleType + index: int + +class BaseInstruction(CycleTracker): + """ + The base class for all instructions. + + This class encapsulates data regarding an instruction, as well as scheduling + logic and functionality. It inherits members from the CycleTracker class. + + Class Properties: + name (str): Returns the name of the represented operation. + OP_NAME_ASM (str): ASM-ISA name for the instruction. + OP_NAME_PISA (str): P-ISA name for the instruction. + + Class Methods: + _get_name(cls) -> str: Derived classes should implement this method and return the correct + name for the instruction. Defaults to the ASM-ISA name. + _get_OP_NAME_ASM(cls) -> str: Derived classes should implement this method and return the correct + ASM name for the operation. Default throws not implemented. + _get_OP_NAME_PISA(cls) -> str: Derived classes should implement this method and return the correct + P-ISA name for the operation. Defaults to the ASM-ISA name. + + Constructors: + __init__(self, id: int, throughput: int, latency: int, comment: str = ""): + Initializes a new BaseInstruction object. + + Attributes: + _dests (list[CycleTracker]): List of destination objects. Derived classes can override + _set_dests to validate this attribute. + _frozen_cisa (str): Contains frozen CInst in ASM ISA format after scheduling. Empty string if not frozen. + _frozen_misa (str): Contains frozen MInst in ASM ISA format after scheduling. Empty string if not frozen. + _frozen_pisa (str): Contains frozen P-ISA format after scheduling. Empty string if not frozen. + _frozen_xisa (str): Contains frozen XInst in ASM ISA format after scheduling. Empty string if not frozen. + _sources (list[CycleTracker]): List of source objects. Derived classes can override + _set_sources to validate this attribute. + comment (str): Comment for the instruction. + + Properties: + dests (list): Gets or sets the list of destination objects. The elements of the list are derived dependent. + Calls _set_dests to set value. + id (tuple): Gets the unique instruction ID. This is a combination of the client ID specified during + construction and a unique nonce per instruction. + is_scheduled (bool): Returns whether the instruction has been scheduled (True) or not (False). + latency (int): Returns the latency of the represented operation. This is the number + of clock cycles before the results of the operation are ready in the destination. + schedule_timing (ScheduleTiming): Gets the cycle and index in which this instruction was scheduled or + None if not scheduled yet. Index is subject to change and it is not final until the second pass of scheduling. + sources (list): Gets or sets the list of source objects. The elements of the list are derived dependent. + Calls _set_sources to set value. + throughput (int): Returns the throughput of the represented operation. Number of clock cycles + before a new instruction can be decoded/queued for execution. + + Magic Methods: + __eq__(self, other): Checks equality between two BaseInstruction objects. + __hash__(self): Returns the hash of the BaseInstruction object. + __repr__(self): Returns a string representation of the BaseInstruction object. + __str__(self): Returns a string representation of the BaseInstruction object. + + Methods: + _schedule(self, cycle_count: CycleType, schedule_idx: int) -> int: + Schedules the instruction, simulating timings of executing this instruction. Derived + classes should override with their scheduling functionality. + _toCASMISAFormat(self, *extra_args) -> str: Converts the instruction to CInst ASM-ISA format. + Derived classes should override with their functionality. + _toMASMISAFormat(self, *extra_args) -> str: Converts the instruction to MInst ASM-ISA format. + Derived classes should override with their functionality. + _toPISAFormat(self, *extra_args) -> str: Converts the instruction to P-ISA kernel format. + Derived classes should override with their functionality. + _toXASMISAFormat(self, *extra_args) -> str: Converts the instruction to XInst ASM-ISA format. + Derived classes should override with their functionality. + freeze(self): Called immediately after _schedule() to freeze the instruction after scheduling + to preserve the instruction string representation to output into the listing. + Changes made to the instruction and its components after freezing are ignored. + schedule(self, cycle_count: CycleType, schedule_idx: int) -> int: + Schedules and freezes the instruction, simulating timings of executing this instruction. + toStringFormat(self, preamble, op_name: str, *extra_args) -> str: + Converts the instruction to a string format. + toPISAFormat(self) -> str: Converts the instruction to P-ISA kernel format. + toXASMISAFormat(self) -> str: Converts the instruction to ASM-ISA format. + toCASMISAFormat(self) -> str: Converts the instruction to CInst ASM-ISA format. + toMASMISAFormat(self) -> str: Converts the instruction to MInst ASM-ISA format. + """ + # To be initialized from ASM ISA spec + _OP_NUM_DESTS : int + _OP_NUM_SOURCES : int + _OP_DEFAULT_THROUGHPUT : int + _OP_DEFAULT_LATENCY : int + + __id_count = Counter.count(0) # internal unique sequence counter to generate unique IDs + + # Class methods and properties + # ---------------------------- + @classmethod + def isa_spec_as_dict(cls) -> dict: + """ + Returns attributes as dictionary. + """ + dict = {"num_dests": cls._OP_NUM_DESTS, + "num_sources": cls._OP_NUM_SOURCES, + "default_throughput": cls._OP_DEFAULT_THROUGHPUT, + "default_latency": cls._OP_DEFAULT_LATENCY} + return dict + + @classmethod + def SetNumDests(cls, val): + cls._OP_NUM_DESTS = val + + @classmethod + def SetNumSources(cls, val): + cls._OP_NUM_SOURCES = val + + @classmethod + def SetDefaultThroughput(cls, val): + cls._OP_DEFAULT_THROUGHPUT = val + + @classmethod + def SetDefaultLatency(cls, val): + cls._OP_DEFAULT_LATENCY = val + + @classproperty + def name(cls) -> str: + """ + Name for the instruction. + """ + return cls._get_name() + + @classmethod + def _get_name(cls) -> str: + """ + Derived classes should implement this method and return correct + name for the instruction. Defaults to the ASM-ISA name. + """ + return cls.OP_NAME_ASM + + @classproperty + def OP_NAME_PISA(cls) -> str: + """ + P-ISA name for the instruction. + """ + return cls._get_OP_NAME_PISA() + + @classmethod + def _get_OP_NAME_PISA(cls) -> str: + """ + Derived classes should implement this method and return correct + P-ISA name for the operation. Defaults to the ASM-ISA name. + """ + return cls.OP_NAME_ASM + + @classproperty + def OP_NAME_ASM(cls) -> str: + """ + ASM-ISA name for instruction. + + Will throw if no ASM-ISA name for instruction. + """ + return cls._get_OP_NAME_ASM() + + @classmethod + def _get_OP_NAME_ASM(cls) -> str: + """ + Derived classes should implement this method and return correct + ASM name for the operation. + """ + raise NotImplementedError('Abstract method not implemented.') + + # Constructor + # ----------- + + def __init__(self, + id: int, + throughput : int, + latency : int, + comment: str = ""): + """ + Initializes a new BaseInstruction object. + + Parameters: + id (int): User-defined ID for the instruction. It will be bundled with a nonce to form a unique ID. + throughput (int): Number of clock cycles that it takes after this instruction starts executing before the + execution engine can start executing a new instruction. Instructions are pipelined, so, + another instruction can be started in the clock cycle after this instruction's throughput + has elapsed, even if this instruction latency hasn't elapsed yet. + latency (int): Number of clock cycles it takes for the instruction to complete and its outputs to be ready. + Outputs are ready in the clock cycle after this instruction's latency has elapsed. Must be + greater than or equal to throughput. + comment (str): Optional comment for the instruction. + + Raises: + ValueError: If throughput is less than 1 or latency is less than throughput. + """ + # validate inputs + if throughput < 1: + raise ValueError(("`throughput`: must be a positive number, " + "but {} received.".format(throughput))) + if latency < throughput: + raise ValueError(("`latency`: cannot be less than throughput. " + "Expected, at least, {}, but {} received.".format(throughput, latency))) + + super().__init__((0, 0)) + + self.__id = (id, next(BaseInstruction.__id_count)) # Mix with unique sequence counter + self.__throughput = throughput # read_only throughput of the operation + self.__latency = latency # read_only latency of the operation + self._dests = [] + self._sources = [] + self.comment = " id: {}{}{}".format(self.__id, + "; " if comment.strip() else "", + comment) + self.__schedule_timing: ScheduleTiming = None # Tracks when was this instruction scheduled, or None if not scheduled yet + + self._frozen_pisa = "" # To contain frozen P-ISA format after scheduling + self._frozen_xisa = "" # To contain frozen XInst in ASM ISA format after scheduling + self._frozen_cisa = "" # To contain frozen CInst in ASM ISA format after scheduling + self._frozen_misa = "" # To contain frozen MInst in ASM ISA format after scheduling + + def __repr__(self): + """ + Returns a string representation of the BaseInstruction object. + """ + retval = ('<{}({}) object at {}>(id={}[0], ' + 'dst={}, src={}, ' + 'throughput={}, latency={})').format(type(self).__name__, + self.OP_NAME_PISA, + hex(id(self)), + self.id, + self.dests, + self.sources, + self.throughput, + self.latency) + return retval + + def __eq__(self, other): + """ + Checks equality between two BaseInstruction objects. + """ + return self is other #other.id == self.id + + def __hash__(self): + """ + Returns the hash of the BaseInstruction object. + """ + return hash(self.id) + + def __str__(self): + """ + Returns a string representation of the BaseInstruction object. + """ + return f'{self.name} {self.id}' + + # Methods and properties + # ---------------------------- + + @property + def id(self) -> tuple: + """ + Gets the unique ID for the instruction. + + This is a combination of the client ID specified during construction and a unique nonce per instruction. + + Returns: + tuple: (client_id: int, nonce: int) where client_id is the id specified at construction. + """ + return self.__id + + @property + def schedule_timing(self) -> ScheduleTiming: + """ + Retrieves the 1-based index for this instruction in its schedule listing, + or less than 1 if not scheduled yet. + """ + return self.__schedule_timing + + def set_schedule_timing_index(self, value: int): + """ + Sets the schedule timing index. + + Parameters: + value (int): The index value to set. + + Raises: + ValueError: If the value is less than 0. + """ + if value < 0: + raise ValueError("`value`: expected a value of `0` or greater for `schedule_timing.index`.") + self.__schedule_timing = ScheduleTiming(cycle = self.__schedule_timing.cycle, + index=value) + + @property + def is_scheduled(self) -> bool: + """ + Checks if the instruction is scheduled. + + Returns: + bool: True if the instruction is scheduled, False otherwise. + """ + return True if self.schedule_timing else False + + @property + def throughput(self) -> int: + """ + Gets the throughput of the instruction. + + Returns: + int: The throughput. + """ + return self.__throughput + + @property + def latency(self) -> int: + """ + Gets the latency of the instruction. + + Returns: + int: The latency. + """ + return self.__latency + + @property + def dests(self) -> list: + """ + Gets the list of destination objects. + + Returns: + list: The list of destination objects. + """ + return self._dests + + @dests.setter + def dests(self, value): + """ + Sets the list of destination objects. + + Parameters: + value (list): The list of destination objects to set. + """ + self._set_dests(value) + + def _set_dests(self, value): + """ + Validates and sets the list of destination objects. + + Parameters: + value (list): The list of destination objects to set. + + Raises: + ValueError: If the value is not a list of CycleTracker objects. + """ + if not all(isinstance(x, CycleTracker) for x in value): + raise ValueError("`value`: Expected list of `CycleTracker` objects.") + self._dests = [ x for x in value ] + + @property + def sources(self) -> list: + """ + Gets the list of source objects. + + Returns: + list: The list of source objects. + """ + return self._sources + + @sources.setter + def sources(self, value): + """ + Sets the list of source objects. + + Parameters: + value (list): The list of source objects to set. + """ + self._set_sources(value) + + def _set_sources(self, value): + """ + Validates and sets the list of source objects. + + Parameters: + value (list): The list of source objects to set. + + Raises: + ValueError: If the value is not a list of CycleTracker objects. + """ + if not all(isinstance(x, CycleTracker) for x in value): + raise ValueError("`value`: Expected list of `CycleTracker` objects.") + self._sources = [ x for x in value ] + + def _get_cycle_ready(self): + """ + Returns the current value for ready cycle. + + This method is called by property cycle_ready getter to retrieve the value. + An instruction cycle ready value is the maximum among its own and all the + sources ready cycles, and destinations (special case). + + Cycles are measured as tuples: (bundle: int, clock_cycle: int) + + Overrides `CycleTracker._get_cycle_ready`. + + Returns: + CycleType: The current value for ready cycle. + """ + + # we have to be careful that `max` won't iterate on our CycleType tuples' inner values + retval = super()._get_cycle_ready() + if self.sources: + retval = max(retval, *(src.cycle_ready for src in self.sources)) + if self.dests: + # dests cycle ready is a special case: + # dests are ready to be read or writen to at their cycle_ready, but instructions can + # start the following cycle when their dests are ready minus the latency of + # the instruction because the dests will be writen to in the last cycle of + # the instruction: + # Cycle decode_phase write_phase dests_ready latency + # 1 INST1 5 + # 2 INST2 5 + # 3 INST3 5 + # 4 INST4 5 + # 5 INST6 INST1 5 + # 6 INST7 INST2 INST1 5 + # 7 INST8 INST3 INST2 5 + # INST1's dests are ready in cycle 6 and they are writen to in cycle 5. + # If INST2 uses any INST1 dest as its dest, INST2 can start the cycle + # following INST1, 2, because INST2 will write to the same dest in cycle 6. + retval = max(retval, *(dst.cycle_ready - self.latency + 1 for dst in self.dests)) + return retval + + def freeze(self): + """ + Called immediately after `_schedule()` to freeze the instruction after scheduling + to preserve the instruction string representation to output into the listing. + Changes made to the instruction and its components after freezing are ignored. + + Freezing is necessary because content of instruction sources and destinations + may change by further instructions as they get scheduled. + + Clients may call this method stand alone if they need to refresh the frozen + instruction. However, refreezing may result in incorrect string representation + depending on the instruction. + + This method ensures that the instruction can be frozen. + + Derived classes should override to correctly freeze the instruction. + When overriding, this base method must be called as part of the override. + + Raises: + RuntimeError: If the instruction has not been scheduled yet. + """ + if not self.is_scheduled: + raise RuntimeError(f"Instruction `{self.name}` (id = {self.id}) is not yet scheduled.") + + self._frozen_pisa = self._toPISAFormat() + self._frozen_xisa = self._toXASMISAFormat() + self._frozen_cisa = self._toCASMISAFormat() + self._frozen_misa = self._toMASMISAFormat() + + def _schedule(self, cycle_count: CycleType, schedule_idx: int) -> int: + """ + Schedules the instruction, simulating timings of executing this instruction. + + Ensures that this instruction is ready to be scheduled (dependencies and states + are ready). + + Derived classes can override to add their own simulation rules. When overriding, + this base method must be called, at some point, as part of the override. + + Parameters: + cycle_count (CycleType): Current cycle of execution. + schedule_idx (int): 1-based index for this instruction in its schedule listing. + + Raises: + ValueError: If invalid arguments are provided. + RuntimeError: If the instruction is not ready to be scheduled yet or if the instruction is already scheduled. + + Returns: + int: The throughput for this instruction. i.e. the number of cycles by which to advance + the current cycle counter. + """ + if self.is_scheduled: + raise RuntimeError(f"Instruction `{self.name}` (id = {self.id}) is already scheduled.") + if schedule_idx < 1: + raise ValueError("`schedule_idx`: expected a value of `1` or greater.") + if len(cycle_count) < 2: + raise ValueError("`cycle_count`: expected a pair/tuple with two components.") + if cycle_count < self.cycle_ready: + raise RuntimeError(("Instruction {}, id: {}, not ready to schedule. " + "Ready cycle is {}, but current cycle is {}.").format(self.name, + self.id, + self.cycle_ready, + cycle_count)) + self.__schedule_timing = ScheduleTiming(cycle_count, schedule_idx) + return self.throughput + + @final + def schedule(self, cycle_count: CycleType, schedule_idx: int) -> int: + """ + Schedules and freezes the instruction, simulating timings of executing this instruction. + + Ensures that this instruction is ready to be scheduled (dependencies and states + are ready). + + Derived classes can override the protected methods `_schedule()` and `_freeze()` to add their + own simulation and freezing rules. + + Parameters: + cycle_count (CycleType): Current cycle of execution. + schedule_idx (int): 1-based index for this instruction in its schedule listing. + + Raises: + ValueError: If invalid arguments are provided. + RuntimeError: If the instruction is not ready to be scheduled yet or if the instruction is already scheduled. + + Returns: + int: The throughput for this instruction. i.e. the number of cycles by which to advance + the current cycle counter. + """ + retval = self._schedule(cycle_count, schedule_idx) + self.freeze() + return retval + + def toStringFormat(self, + preamble, + op_name: str, + *extra_args) -> str: + """ + Converts the instruction to a string format. + + Parameters: + preamble (iterable): List of arguments prefacing the instruction name `op_name`. Can be None if no preamble. + op_name (str): Name of the operation for the instruction. Cannot be empty. + extra_args: Variable number of arguments. Extra arguments to add at the end of the instruction. + + Returns: + str: A string representing the instruction. The string has the form: + [preamble0, preamble1, ..., preamble_p,] op [, extra0, extra1, ..., extra_e] [# comment] + """ + # op, dst0 (bank), dst1 (bank), ..., dst_d (bank), src0 (bank), src1 (bank), ..., src_s (bank) [, extra], res # comment + if not op_name: + raise ValueError("`op_name` cannot be empty.") + retval = op_name + if preamble: + retval = ('{}, '.format(', '.join(str(x) for x in preamble))) + retval + if extra_args: + retval += ', {}'.format(', '.join([str(extra) for extra in extra_args])) + if not GlobalConfig.suppressComments: + if self.comment: + retval += ' #{}'.format(self.comment) + return retval + + @final + def toPISAFormat(self) -> str: + """ + Converts the instruction to P-ISA kernel format. + + Returns: + str: String representation of the instruction in P-ISA kernel format. The string has the form: + `N, op, dst0 (bank), dst1 (bank), ..., dst_d (bank), src0 (bank), src1 (bank), ..., src_s (bank) [, extra0, extra1, ..., extra_e] [, res] [# comment]` + where `extra_e` are instruction specific extra arguments. + """ + return self._frozen_pisa if self._frozen_pisa else self._toPISAFormat() + + @final + def toXASMISAFormat(self) -> str: + """ + Converts the instruction to ASM-ISA format. + + If instruction is frozen, this returns the frozen result, otherwise, it attempts to + generate the string representation on the fly. + + Internally calls method `_toXASMISAFormat()`. + + Derived classes can override method `_toXASMISAFormat()` to provide their own conversion. + + Returns: + str: A string representation of the instruction in ASM-ISA format. The string has the form: + `id[0], N, op, dst_register0, dst_register1, ..., dst_register_d, src_register0, src_register1, ..., src_register_s [, extra0, extra1, ..., extra_e], res [# comment]` + where `extra_e` are instruction specific extra arguments. + Since the residual is mandatory in the format, it is set to `0` in the output if the + instruction does not support residual. + """ + return self._frozen_xisa if self._frozen_xisa else self._toXASMISAFormat() + + @final + def toCASMISAFormat(self) -> str: + """ + Converts the instruction to CInst ASM-ISA format. + + If instruction is frozen, this returns the frozen result, otherwise, it attempts to + generate the string representation on the fly. + + Internally calls method `_toCASMISAFormat()`. + + Derived classes can override method `_toCASMISAFormat()` to provide their own conversion. + + Returns: + str: A string representation of the instruction in ASM-ISA format. The string has the form: + `N, op, dst0, dst1, ..., dst_d, src0, src1, ..., src_s [, extra0, extra1, ..., extra_e], [# comment]` + where `extra_e` are instruction specific extra arguments. + Since the ring size is mandatory in the format, it is set to `0` in the output if the + instruction does not support it. + """ + return self._frozen_cisa if self._frozen_cisa else self._toCASMISAFormat() + + @final + def toMASMISAFormat(self) -> str: + """ + Converts the instruction to MInst ASM-ISA format. + + If instruction is frozen, this returns the frozen result, otherwise, it attempts to + generate the string representation on the fly. + + Internally calls method `_toMASMISAFormat()`. + + Derived classes can override method `_toMASMISAFormat()` to provide their own conversion. + + Returns: + str: A string representation of the instruction in ASM-ISA format. The string has the form: + `op, dst0, dst1, ..., dst_d, src0, src1, ..., src_s [, extra0, extra1, ..., extra_e], [# comment]` + where `extra_e` are instruction specific extra arguments. + """ + return self._frozen_misa if self._frozen_misa else self._toMASMISAFormat() + + def _toPISAFormat(self, *extra_args) -> str: + """ + Converts the instruction to P-ISA kernel format. + + Derived classes should override with their functionality. Overrides do not need to call + this base method. + + Returns: + str: Empty string ("") to indicate that this instruction does not have a P-ISA equivalent. + """ + return "" + + def _toXASMISAFormat(self, *extra_args) -> str: + """ + Converts the instruction to XInst ASM-ISA format. + + This base method returns an empty string. + + Derived classes should override with their functionality. Overrides do not need to call + this base method. + + Returns: + str: Empty string ("") to indicate that this instruction does not have an XInst equivalent. + """ + return "" + + def _toCASMISAFormat(self, *extra_args) -> str: + """ + Converts the instruction to CInst ASM-ISA format. + + Derived classes should override with their functionality. Overrides do not need to call + this base method. + + Returns: + str: Empty string ("") to indicate that this instruction does not have a CInst equivalent. + """ + return "" + + def _toMASMISAFormat(self, *extra_args) -> str: + """ + Converts the instruction to MInst ASM-ISA format. + + Derived classes should override with their functionality. Overrides do not need to call + this base method. + + Returns: + str: Empty string ("") to indicate that this instruction does not have an MInst equivalent. + """ + return "" diff --git a/assembler_tools/hec-assembler-tools/assembler/instructions/minst/__init__.py b/assembler_tools/hec-assembler-tools/assembler/instructions/minst/__init__.py new file mode 100644 index 00000000..7b485ccb --- /dev/null +++ b/assembler_tools/hec-assembler-tools/assembler/instructions/minst/__init__.py @@ -0,0 +1,8 @@ + +from . import mload, mstore, msyncc + +# MInst aliases + +MLoad = mload.Instruction +MStore = mstore.Instruction +MSyncc = msyncc.Instruction diff --git a/assembler_tools/hec-assembler-tools/assembler/instructions/minst/minstruction.py b/assembler_tools/hec-assembler-tools/assembler/instructions/minst/minstruction.py new file mode 100644 index 00000000..6a2afe4b --- /dev/null +++ b/assembler_tools/hec-assembler-tools/assembler/instructions/minst/minstruction.py @@ -0,0 +1,77 @@ +from assembler.common.cycle_tracking import CycleType +from ..instruction import BaseInstruction + +class MInstruction(BaseInstruction): + """ + Represents a memory-level instruction (MInstruction). + + This class is used to encapsulate the properties and behaviors of a memory-level instruction, + including its throughput, latency, and a unique counter value that increases with each + MInstruction created. + + Methods: + count: Returns the MInstruction counter value for this instruction. + """ + + __minst_count = 0 # Internal Minst counter + + def __init__(self, + id: int, + throughput: int, + latency: int, + comment: str = ""): + """ + Constructs a new MInstruction. + + Parameters: + id (int): User-defined ID for the instruction. It will be bundled with a nonce to form a unique ID. + + throughput (int): The throughput of the instruction. + + latency (int): The latency of the instruction. + + comment (str, optional): A comment for the instruction. Defaults to an empty string. + """ + super().__init__(id, throughput, latency, comment=comment) + self.__count = MInstruction.__minst_count + + @property + def count(self): + """ + Returns the MInstruction counter value for this instruction. + + This value monotonically increases per MInstruction created. + + Returns: + int: The counter value for this MInstruction. + """ + return self.__count + + def _get_cycle_ready(self): + """ + Returns a CycleType object indicating when the instruction is ready. + + Returns: + CycleType: A CycleType object with bundle and cycle set to 0. + """ + return CycleType(bundle=0, cycle=0) + + def _toMASMISAFormat(self, *extra_args) -> str: + """ + Converts the instruction to MInst ASM-ISA format. + + See inherited for more information. + + Parameters: + extra_args: Additional arguments for formatting. + + Returns: + str: The instruction in MInst ASM-ISA format. + """ + # Instruction sources + extra_args = tuple(src.toMASMISAFormat() for src in self.sources) + extra_args + # Instruction destinations + extra_args = tuple(dst.toMASMISAFormat() for dst in self.dests) + extra_args + return self.toStringFormat(None, + self.OP_NAME_ASM, + *extra_args) \ No newline at end of file diff --git a/assembler_tools/hec-assembler-tools/assembler/instructions/minst/mload.py b/assembler_tools/hec-assembler-tools/assembler/instructions/minst/mload.py new file mode 100644 index 00000000..2618403a --- /dev/null +++ b/assembler_tools/hec-assembler-tools/assembler/instructions/minst/mload.py @@ -0,0 +1,209 @@ + +from assembler.common.config import GlobalConfig +from assembler.common.cycle_tracking import CycleType +from .minstruction import MInstruction +from assembler.memory_model import MemoryModel +from assembler.memory_model.variable import Variable + +class Instruction(MInstruction): + """ + Encapsulates an `mload` MInstruction. + + Instruction `mload` loads a word, corresponding to a single polynomial residue, + from HBM data region into the SPAD memory. + + CINST queue should use `csyncm` matching this instruction before using the address. + + For more information, check the specification: + https://github.com/IntelLabs/hec-assembler-tools/blob/master/docsrc/inst_spec/minst/minst_mload.md + + Attributes: + dst_spad_addr (int): SPAD address where to load the source variable. + """ + + @classmethod + def _get_OP_NAME_ASM(cls) -> str: + """ + Returns the ASM name of the operation. + + Returns: + str: The name of the operation in ASM format, which is 'mload'. + """ + return "mload" + + def __init__(self, + id: int, + src: list, + mem_model: MemoryModel, + dst_spad_addr: int, + throughput: int = None, + latency: int = None, + comment: str = ""): + """ + Constructs a new `mload` MInstruction. + + Parameters: + id (int): User-defined ID for the instruction. It will be bundled with a nonce to form a unique ID. + + src (list of Variable): A list containing a single Variable object indicating the source variable to load from + HBM into SPAD. + + mem_model (MemoryModel): The memory model containing the SPAD where to store the source variable. + + dst_spad_addr (int): SPAD address where to load the source variable. + + throughput (int, optional): The throughput of the instruction. Defaults to the class's default throughput. + + latency (int, optional): The latency of the instruction. Defaults to the class's default latency. + + comment (str, optional): A comment for the instruction. Defaults to an empty string. + """ + if not throughput: + throughput = Instruction._OP_DEFAULT_THROUGHPUT + if not latency: + latency = Instruction._OP_DEFAULT_LATENCY + + if not GlobalConfig.useHBMPlaceHolders: + for variable in src: + if comment: + comment += "; " + comment += f'variable "{variable.name}"' + + super().__init__(id, throughput, latency, comment=comment) + self.__mem_model = mem_model + self.dst_spad_addr = dst_spad_addr + self.__internal_set_dests(src) + self._set_sources(src) + + def __repr__(self): + """ + Returns a string representation of the Instruction object. + + Returns: + str: A string representation of the Instruction object, including + its type, name, memory address, ID, source, destination SPAD address, throughput, and latency. + """ + assert(len(self.dests) > 0) + retval=('<{}({}) object at {}>(id={}[0], ' + 'src={}, dst_spad_addr={}, mem_model, ' + 'throughput={}, latency={})').format(type(self).__name__, + self.name, + hex(id(self)), + self.id, + self.sources, + self.dst_spad_addr, + self.throughput, + self.latency) + return retval + + def _set_dests(self, value): + """ + Raises an error as the instruction only supports setting sources. + + Parameters: + value: The value to set as destination, which is not applicable. + + Raises: + RuntimeError: Always raised as the instruction only supports setting sources. + """ + raise RuntimeError(f"Instruction `{self.name}` only supports setting sources.") + + def __internal_set_dests(self, value): + """ + Sets the destination variables for the instruction. + + Parameters: + value (list): A list of `Variable` objects to set as destinations. + + Raises: + ValueError: If the number of destinations is incorrect. + TypeError: If the list does not contain `Variable` objects. + """ + if len(value) != Instruction._OP_NUM_DESTS: + raise ValueError(("`value`: Expected list of {} `Variable` objects, " + "but list with {} elements received.".format(Instruction._OP_NUM_SOURCES, + len(value)))) + if not all(isinstance(x, Variable) for x in value): + raise ValueError("`value`: Expected list of `Variable` objects.") + super()._set_dests(value) + + def _set_sources(self, value): + """ + Sets the source variables for the instruction. + + Parameters: + value (list): A list of `Variable` objects to set as sources. + + Raises: + ValueError: If the number of sources is incorrect. + TypeError: If the list does not contain `Variable` objects. + """ + if len(value) != Instruction._OP_NUM_SOURCES: + raise ValueError(("`value`: Expected list of {} `Variable` objects, " + "but list with {} elements received.".format(Instruction._OP_NUM_SOURCES, + len(value)))) + if not all(isinstance(x, Variable) for x in value): + raise ValueError("`value`: Expected list of `Variable` objects.") + super()._set_sources(value) + + def _schedule(self, cycle_count: CycleType, schedule_id: int) -> int: + """ + Schedules the instruction, simulating timings of executing this instruction. + + Source Variable will be updated to reflect the load. + + Parameters: + cycle_count (CycleType): Current cycle of execution. + + schedule_id (int): 1-based index for this instruction in its schedule listing. + + Raises: + RuntimeError: Multiple SPAD allocation or source is not allocated to HBM. + See inherited method for other exceptions. + + ValueError: Invalid SPAD address. + + Returns: + int: The throughput for this instruction. i.e. the number of cycles by which to advance + the current cycle counter. + """ + assert(Instruction._OP_NUM_SOURCES > 0 and len(self.sources) == Instruction._OP_NUM_SOURCES) + assert(Instruction._OP_NUM_DESTS > 0 and len(self.dests) == Instruction._OP_NUM_DESTS) + assert(all(src == dst for src, dst in zip(self.sources, self.dests))) + + hbm = self.__mem_model.hbm + spad = self.__mem_model.spad + + variable: Variable = self.sources[0] + + if variable.spad_address >= 0: + raise RuntimeError("Source variable is already in SPAD. Cannot load a variable into SPAD more than once.") + if variable.hbm_address < 0: + raise RuntimeError("Null reference exception: source variable is not in HBM.") + + retval = super()._schedule(cycle_count, schedule_id) + # Perform the load + spad.allocateForce(self.dst_spad_addr, variable) + # Track SPAD access + spad.getAccessTracking(self.dst_spad_addr).last_mload = self + return retval + + def _toMASMISAFormat(self, *extra_args) -> str: + """ + Converts the instruction to MInst ASM-ISA format. + + See inherited for more information. + + Parameters: + extra_args: Additional arguments for formatting. + + Returns: + str: The instruction in MInst ASM-ISA format. + """ + # Instruction sources + extra_args = tuple(src.toMASMISAFormat() for src in self.sources) + extra_args + # Instruction destinations + extra_args = tuple(dst.toCASMISAFormat() for dst in self.dests) + extra_args + return self.toStringFormat(None, + self.OP_NAME_ASM, + *extra_args) \ No newline at end of file diff --git a/assembler_tools/hec-assembler-tools/assembler/instructions/minst/mstore.py b/assembler_tools/hec-assembler-tools/assembler/instructions/minst/mstore.py new file mode 100644 index 00000000..a5228193 --- /dev/null +++ b/assembler_tools/hec-assembler-tools/assembler/instructions/minst/mstore.py @@ -0,0 +1,230 @@ +from assembler.common.cycle_tracking import CycleType +from .minstruction import MInstruction +from assembler.memory_model import MemoryModel +from assembler.memory_model.variable import Variable + +class Instruction(MInstruction): + """ + Encapsulates an `mstore` MInstruction. + + Instruction `mstore` stores a word, corresponding to a single polynomial residue, + from SPAD memory into HBM data region. + + MINST queue should use `msyncc` before scheduling this instruction to ensure source + SPAD address is ready. + + For more information, check the specification: + https://github.com/IntelLabs/hec-assembler-tools/blob/master/docsrc/inst_spec/minst/minst_mstore.md + + Attributes: + dst_hbm_addr (int): HBM address where to store the source variable. + """ + + @classmethod + def _get_OP_NAME_ASM(cls) -> str: + """ + Returns the ASM name of the operation. + + Returns: + str: The name of the operation in ASM format, which is 'mstore'. + """ + return "mstore" + + def __init__(self, + id: int, + src: list, + mem_model: MemoryModel, + dst_hbm_addr: int, + throughput: int = None, + latency: int = None, + comment: str = ""): + """ + Constructs a new `mstore` MInstruction. + + SPAD should use `csyncm` matching this instruction before using the address. + + Parameters: + id (int): User-defined ID for the instruction. It will be bundled with a nonce to form a unique ID. + + src (list of Variable): A list containing a single Variable object indicating the source variable to store from + SPAD into HBM. + + mem_model (MemoryModel): The memory model containing the SPAD where to store the source variable. + + dst_hbm_addr (int): HBM address where to store the source variable. + + throughput (int, optional): The throughput of the instruction. Defaults to the class's default throughput. + + latency (int, optional): The latency of the instruction. Defaults to the class's default latency. + + comment (str, optional): A comment for the instruction. Defaults to an empty string. + + Raises: + ValueError: If `dst_hbm_addr` is negative. + """ + if dst_hbm_addr < 0: + raise ValueError('`dst_hbm_addr`: cannot be null address (negative).') + if not throughput: + throughput = Instruction._OP_DEFAULT_THROUGHPUT + if not latency: + latency = Instruction._OP_DEFAULT_LATENCY + + super().__init__(id, throughput, latency, comment=comment) + self.__mem_model = mem_model + self.dst_hbm_addr = dst_hbm_addr + self.__internal_set_dests(src) + self._set_sources(src) + self.__source_spad_address = src[0].spad_address + + def __repr__(self): + """ + Returns a string representation of the Instruction object. + + Returns: + str: A string representation of the Instruction object, including + its type, name, memory address, ID, source, destination HBM address, throughput, and latency. + """ + assert(len(self.dests) > 0) + retval=('<{}({}) object at {}>(id={}[0], ' + 'src={}, dst_hbm_addr={}, mem_model, ' + 'throughput={}, latency={})').format(type(self).__name__, + self.name, + hex(id(self)), + self.id, + self.sources, + self.dst_hbm_addr, + # repr(self.__mem_model), + self.throughput, + self.latency) + return retval + + def _set_dests(self, value): + """ + Raises an error as the instruction only supports setting sources. + + Parameters: + value: The value to set as destination, which is not applicable. + + Raises: + RuntimeError: Always raised as the instruction only supports setting sources. + """ + raise RuntimeError(f"Instruction `{self.name}` only supports setting sources.") + + def __internal_set_dests(self, value): + """ + Sets the destination variables for the instruction. + + Parameters: + value (list): A list of `Variable` objects to set as destinations. + + Raises: + ValueError: If the number of destinations is incorrect. + TypeError: If the list does not contain `Variable` objects. + """ + if len(value) != Instruction._OP_NUM_DESTS: + raise ValueError(("`value`: Expected list of {} `Variable` objects, " + "but list with {} elements received.".format(Instruction._OP_NUM_SOURCES, + len(value)))) + if not all(isinstance(x, Variable) for x in value): + raise ValueError("`value`: Expected list of `Variable` objects.") + super()._set_dests(value) + + def _set_sources(self, value): + """ + Sets the source variables for the instruction. + + Parameters: + value (list): A list of `Variable` objects to set as sources. + + Raises: + ValueError: If the number of sources is incorrect. + TypeError: If the list does not contain `Variable` objects. + """ + if len(value) != Instruction._OP_NUM_SOURCES: + raise ValueError(("`value`: Expected list of {} `Variable` objects, " + "but list with {} elements received.".format(Instruction._OP_NUM_SOURCES, + len(value)))) + if not all(isinstance(x, Variable) for x in value): + raise ValueError("`value`: Expected list of `Variable` objects.") + super()._set_sources(value) + + def _schedule(self, cycle_count: CycleType, schedule_id: int) -> int: + """ + Schedules the instruction, simulating timings of executing this instruction. + + Source Variable will be updated to reflect the load. + + Parameters: + cycle_count (CycleType): Current cycle of execution. + + schedule_id (int): 1-based index for this instruction in its schedule listing. + + Raises: + RuntimeError: Multiple SPAD allocation or source is not allocated to HBM. + See inherited method for other exceptions. + + ValueError: Invalid SPAD address. + + Returns: + int: The throughput for this instruction. i.e. the number of cycles by which to advance + the current cycle counter. + """ + assert(Instruction._OP_NUM_SOURCES > 0 and len(self.sources) == Instruction._OP_NUM_SOURCES) + assert(Instruction._OP_NUM_DESTS > 0 and len(self.dests) == Instruction._OP_NUM_DESTS) + assert(all(src == dst for src, dst in zip(self.sources, self.dests))) + + hbm = self.__mem_model.hbm + spad = self.__mem_model.spad + + variable: Variable = self.sources[0] + if self.__source_spad_address < 0: + self.__source_spad_address = self.sources[0].spad_address + + if variable.hbm_address >= 0: + if self.dst_hbm_addr != variable.hbm_address: + raise RuntimeError("Source variable is already in different HBM location. Cannot store a variable into HBM more than once.") + assert(hbm.buffer[variable.hbm_address] == variable) + if self.__source_spad_address < 0: + raise RuntimeError("Null reference exception: source variable is not in SPAD.") + + if self.comment: + self.comment += ';' + # self.comment += ' variable "{}": HBM({}) <- SPAD({})'.format(variable.name, + # self.dst_hbm_addr, + # variable.spad_address) + self.comment += ' variable "{}" <- SPAD({})'.format(variable.name, + variable.spad_address) + + retval = super()._schedule(cycle_count, schedule_id) + # Perform the store + if variable.hbm_address < 0: # Variable new to HBM + hbm.allocateForce(self.dst_hbm_addr, variable) + spad.deallocate(self.__source_spad_address) # Deallocate variable from SPAD + # Track SPAD access + spad_access_tracking = spad.getAccessTracking(self.__source_spad_address) + spad_access_tracking.last_mstore = self + # No need to track last CInst access after a `mstore` + spad_access_tracking.last_cload = None + spad_access_tracking.last_cstore = None + + return retval + + def _toMASMISAFormat(self, *extra_args) -> str: + """ + Converts the instruction to MInst ASM-ISA format. + + See inherited for more information. + + Parameters: + extra_args: Additional arguments for formatting. + + Returns: + str: The instruction in MInst ASM-ISA format. + """ + # Instruction sources + extra_args = (self.__source_spad_address, ) + extra_args + # Instruction destinations + extra_args = tuple(dst.toMASMISAFormat() for dst in self.dests) + extra_args + return self.toStringFormat(None, + self.OP_NAME_ASM, + *extra_args) \ No newline at end of file diff --git a/assembler_tools/hec-assembler-tools/assembler/instructions/minst/msyncc.py b/assembler_tools/hec-assembler-tools/assembler/instructions/minst/msyncc.py new file mode 100644 index 00000000..d8b3718d --- /dev/null +++ b/assembler_tools/hec-assembler-tools/assembler/instructions/minst/msyncc.py @@ -0,0 +1,145 @@ +from assembler.common.cycle_tracking import CycleType +from .minstruction import MInstruction + +class Instruction(MInstruction): + """ + Encapsulates an `msyncc` MInstruction. + + This instruction is used to synchronize with a specific instruction from the CINST queue. + + For more information, check the specification: + https://github.com/IntelLabs/hec-assembler-tools/blob/master/docsrc/inst_spec/minst/minst_msyncc.md + + Attributes: + cinstr: The instruction from the CINST queue for which to wait. + """ + + @classmethod + def _get_OP_NAME_ASM(cls) -> str: + """ + Returns the ASM name of the operation. + + Returns: + str: The name of the operation in ASM format, which is 'msyncc'. + """ + return "msyncc" + + def __init__(self, + id: int, + cinstr, + throughput: int = None, + latency: int = None, + comment: str = ""): + """ + Constructs a new `msyncc` CInstruction. + + Parameters: + id (int): User-defined ID for the instruction. It will be bundled with a nonce to form a unique ID. + + cinstr: CInstruction + Instruction from the CINST queue for which to wait. + + throughput (int, optional): The throughput of the instruction. Defaults to the class's default throughput. + + latency (int, optional): The latency of the instruction. Defaults to the class's default latency. + + comment (str, optional): A comment for the instruction. Defaults to an empty string. + """ + if not throughput: + throughput = Instruction._OP_DEFAULT_THROUGHPUT + if not latency: + latency = Instruction._OP_DEFAULT_LATENCY + super().__init__(id, throughput, latency, comment=comment) + self.cinstr = cinstr # Instruction number from the MINST queue for which to wait + + def __repr__(self): + """ + Returns a string representation of the Instruction object. + + Returns: + str: A string representation of the Instruction object, including + its type, name, memory address, ID, cinstr, throughput, and latency. + """ + assert(len(self.dests) > 0) + retval=('<{}({}) object at {}>(id={}[0], ' + 'cinstr={}, ' + 'throughput={}, latency={})').format(type(self).__name__, + self.OP_NAME_PISA, + hex(id(self)), + self.id, + repr(self.cinstr), + self.throughput, + self.latency) + return retval + + def _set_dests(self, value): + """ + Raises an error as the instruction does not have destination parameters. + + Parameters: + value: The value to set as destination, which is not applicable. + + Raises: + RuntimeError: Always raised as the instruction does not have destination parameters. + """ + raise RuntimeError(f"Instruction `{self.name}` does not have parameters.") + + def _set_sources(self, value): + """ + Raises an error as the instruction does not have source parameters. + + Parameters: + value: The value to set as source, which is not applicable. + + Raises: + RuntimeError: Always raised as the instruction does not have source parameters. + """ + raise RuntimeError(f"Instruction `{self.name}` does not have parameters.") + + def _schedule(self, cycle_count: CycleType, schedule_id: int) -> int: + """ + Schedules the instruction, simulating timings of executing this instruction. + + Parameters: + cycle_count (CycleType): Current cycle of execution. + + schedule_id (int): 1-based index for this instruction in its schedule listing. + + Raises: + RuntimeError: CInstruction to sync is invalid or has not been scheduled. + See inherited for more exceptions. + + Returns: + int: The throughput for this instruction. i.e. the number of cycles by which to advance + the current cycle counter. + """ + if not self.cinstr: + raise RuntimeError("Invalid empty CInstruction.") + if not self.cinstr.is_scheduled: + raise RuntimeError("CInstruction to sync is not scheduled yet.") + + retval = super()._schedule(cycle_count, schedule_id) + return retval + + def _toMASMISAFormat(self, *extra_args) -> str: + """ + Converts the instruction to ASM format. + + Parameters: + extra_args: Additional arguments, which are not supported. + + Returns: + str: The instruction in ASM format. + + Raises: + ValueError: If extra arguments are provided. + """ + assert(len(self.dests) == Instruction._OP_NUM_DESTS) + assert(len(self.sources) == Instruction._OP_NUM_SOURCES) + assert(self.cinstr.is_scheduled) + + if extra_args: + raise ValueError('`extra_args` not supported.') + + # warnings.warn("`msyncc` instruction requires second pass to set correct instruction number.") + return super()._toMASMISAFormat(self.cinstr.schedule_timing.index) \ No newline at end of file diff --git a/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/__init__.py b/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/__init__.py new file mode 100644 index 00000000..02f91124 --- /dev/null +++ b/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/__init__.py @@ -0,0 +1,140 @@ +from assembler.memory_model import MemoryModel +from .xinstruction import XInstruction +from . import add, sub, mul, muli, mac, maci, ntt, intt, twntt, twintt, rshuffle, irshuffle, move, xstore, nop +from . import exit as exit_mod +from . import copy as copy_mod + +# XInst aliases + +# XInsts with P-ISA equivalent +Add = add.Instruction +Sub = sub.Instruction +Mul = mul.Instruction +Muli = muli.Instruction +Mac = mac.Instruction +Maci = maci.Instruction +NTT = ntt.Instruction +iNTT = intt.Instruction +twNTT = twntt.Instruction +twiNTT = twintt.Instruction +rShuffle = rshuffle.Instruction +Copy = copy_mod.Instruction +irShuffle = irshuffle.Instruction +# All other XInsts +Move = move.Instruction +XStore = xstore.Instruction +Exit = exit_mod.Instruction +Nop = nop.Instruction + +# Collection of XInstructions with P-ISA or intermediate P-ISA equivalents +__PISA_INSTRUCTIONS = ( Add, Sub, Mul, Muli, Mac, Maci, NTT, iNTT, twNTT, twiNTT, rShuffle, irShuffle, Copy ) + +# Collection of XInstructions with global cycle tracking +GLOBAL_CYCLE_TRACKING_INSTRUCTIONS = ( rShuffle, irShuffle, XStore ) + +def createFromParsedObj(mem_model: MemoryModel, + inst_type, + parsed_op, + new_id: int = 0) -> XInstruction: + """ + Creates an XInstruction object XInst from the specified namespace data. + + Variables are extracted from the memory model (or created if not existing) and + added as destinations and sources to the instruction. + + Parameters: + mem_model (MemoryModel): + The MemoryModel object, where all variables are kept. Variables parsed from the + input string will be automatically added to the memory model if they do not already + exist. The represented object may be modified if addition is needed. + inst_type (type): + Type of the instruction to create. Constructor must be compatible with namespace `parsed_op`. + This type must be a class derived from `XInstruction`. + parsed_op (Namespace): + A namespace that is compatible with the instruction of type `inst_type` to create. + new_id (int): + Optional ID number for the instruction. Defaults to 0. + + Returns: + XInstruction: A XInstruction derived object encapsulating the XInst. + + Raises: + ValueError: If `inst_type` is not a class derived from `XInstruction`. + """ + + if not issubclass(inst_type, XInstruction): + raise ValueError('`inst_type`: expected a class derived from `XInstruction`.') + + # Convert variable names into actual variable objects. + + # Find the variables for dst. + dsts = [] + for var_name, bank in parsed_op.dst: + # Retrieve variable from global list (or create new one if it doesn't exist). + var = mem_model.retrieveVarAdd(var_name, bank) + dsts.append(var) + + # Find the variables for src. + srcs = [] + for var_name, bank in parsed_op.src: + # Retrieve variable from global list (or create new one if it doesn't exist). + var = mem_model.retrieveVarAdd(var_name, bank) + srcs.append(var) + + # Prepare parsed object to add as arguments to instruction constructor. + parsed_op.dst = dsts + parsed_op.src = srcs + assert(parsed_op.op_name == inst_type.OP_NAME_PISA) + parsed_op = vars(parsed_op) + parsed_op.pop("op_name") # op name not needed: inst_type knows its name already + return inst_type(new_id, **parsed_op) + +def createFromPISALine(mem_model: MemoryModel, + line: str, + line_no: int = 0) -> XInstruction: + """ + Parses an XInst from the specified string (in P-ISA kernel input format) and returns a + XInstruction object encapsulating the resulting instruction. + + Note that this function will not decompose P-ISA instructions that require multiple + XInsts. This function will only match instructions that have a 1:1 equivalent between + P-ISA and XInst. + + Parameters: + mem_model (MemoryModel): + The MemoryModel object, where all variables are kept. Variables parsed from the + input string will be automatically added to the memory model if they do not already + exist. The represented object may be modified if addition is needed. + line (str): + Line of text containing the instruction to parse in P-ISA kernel input format. + line_no (int): + Optional line number for the line. This will be used as ID for the parsed instruction. + Defaults to 0. + + Returns: + XInstruction: A XInstruction derived object encapsulating the XInst equivalent to the parsed P-ISA + instruction or None if line could not be parsed. + + Raises: + Exception: If an error occurs during parsing, with the line number and content included in the message. + """ + + retval = None + + try: + + for inst_type in __PISA_INSTRUCTIONS: + parsed_op = inst_type.parseFromPISALine(line) + if parsed_op: + assert(inst_type.OP_NAME_PISA == parsed_op.op_name) + + # Convert parsed instruction into an actual instruction object. + retval = createFromParsedObj(mem_model, inst_type, parsed_op, line_no) + + # Line parsed by an instruction: stop searching. + break + + except Exception as ex: + raise Exception(f'line {line_no}: {line}.') from ex + + return retval diff --git a/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/add.py b/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/add.py new file mode 100644 index 00000000..fb06d226 --- /dev/null +++ b/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/add.py @@ -0,0 +1,225 @@ +import warnings + +from argparse import Namespace + +from .xinstruction import XInstruction +from assembler.memory_model.variable import Variable + +class Instruction(XInstruction): + """ + Represents an `add` instruction in the assembler with specific properties and methods for parsing, + scheduling, and formatting. + + This instructions adds two polynomials stored in the register file and + store the result in a register. + + For more information, check the specification: + https://github.com/IntelLabs/hec-assembler-tools/blob/master/docsrc/inst_spec/xinst/xinst_add.md + + Methods: + parseFromPISALine(line: str) -> list: + Parses an `add` instruction from a Kernel instruction string. + """ + + # To be initialized from ASM ISA spec + _OP_NUM_TOKENS: int + + @classmethod + def isa_spec_as_dict(cls) -> dict: + """ + Returns isa_spec attributes as dictionary. + """ + dict = super().isa_spec_as_dict() + dict.update({"num_tokens": cls._OP_NUM_TOKENS}) + return dict + + @classmethod + def SetNumTokens(cls, val): + cls._OP_NUM_TOKENS = val + + @classmethod + def parseFromPISALine(cls, line: str) -> list: + """ + Parses an `add` instruction from a Kernel instruction string. + + Parameters: + line (str): String containing the instruction to parse. + Instruction format: N, add, dst (bank), src0 (bank), src1 (bank), res # comment + Comment is optional. + + Example line: + "13, add , output_0_1_3 (2), c_0_1_3 (0), d_0_1_3 (1), 1" + + Returns: + Namespace: A namespace with the following attributes: + N (int): Ring size = Log_2(PMD) + op_name (str): Operation name ("add") + dst (list[(str, int)]): List of destinations of the form (variable_name, suggested_bank). + This list has a single element for `add`. + src (list[(str, int)]): List of sources of the form (variable_name, suggested_bank). + This list has two elements for `add`. + res (int): Residual for the operation. + comment (str): String with the comment attached to the line (empty string if no comment). + """ + retval = None + tokens = XInstruction.tokenizeFromPISALine(cls.OP_NAME_PISA, line) + if tokens: + retval = { "comment": tokens[1] } + instr_tokens = tokens[0] + if len(instr_tokens) > cls._OP_NUM_TOKENS: + warnings.warn(f'Extra tokens detected for instruction "{cls.OP_NAME_PISA}"', SyntaxWarning) + + retval["N"] = int(instr_tokens[0]) + retval["op_name"] = instr_tokens[1] + params_start = 2 + params_end = params_start + cls._OP_NUM_DESTS + cls._OP_NUM_SOURCES + dst_src = cls.parsePISASourceDestsFromTokens(instr_tokens, + cls._OP_NUM_DESTS, + cls._OP_NUM_SOURCES, + params_start) + retval.update(dst_src) + retval["res"] = int(instr_tokens[params_end]) + + retval = Namespace(**retval) + assert(retval.op_name == cls.OP_NAME_PISA) + return retval + + @classmethod + def _get_OP_NAME_ASM(cls) -> str: + """ + Returns the operation name in ASM format. + + Returns: + str: ASM format operation. + """ + return "add" + + def __init__(self, + id: int, + N: int, + dst: list, + src: list, + res: int, + throughput : int = None, + latency : int = None, + comment: str = ""): + """ + Initializes an Instruction object with the given parameters. + + Parameters: + id (int): The unique identifier for the instruction. + N (int): The ring size, typically Log_2(PMD). + dst (list): List of destination variables. + src (list): List of source variables. + res (int): The residual for the operation. + throughput (int, optional): The throughput of the instruction. Defaults to None. + latency (int, optional): The latency of the instruction. Defaults to None. + comment (str, optional): A comment associated with the instruction. Defaults to an empty string. + """ + if not throughput: + throughput = Instruction._OP_DEFAULT_THROUGHPUT + if not latency: + latency = Instruction._OP_DEFAULT_LATENCY + + super().__init__(id, N, throughput, latency, res=res, comment=comment) + + self._set_dests(dst) + self._set_sources(src) + + def __repr__(self): + """ + Returns a string representation of the Instruction object. + + Returns: + str: A string representation of object. + """ + retval=('<{}({}) object at {}>(id={}[0], res={}, ' + 'dst={}, src={}, ' + 'throughput={}, latency={})').format(type(self).__name__, + self.name, + hex(id(self)), + self.id, + self.res, + self.dests, + self.sources, + self.throughput, + self.latency) + return retval + + def _set_dests(self, value): + """ + Sets the destination variables for the instruction. + + Parameters: + value (list): List of destination variables. + + Raises: + ValueError: If the list does not contain the expected number of `Variable` objects. + """ + if len(value) != Instruction._OP_NUM_DESTS: + raise ValueError(("`value`: Expected list of {} Variable objects, " + "but list with {} elements received.".format(Instruction._OP_NUM_DESTS, + len(value)))) + if not all(isinstance(x, Variable) for x in value): + raise ValueError("`value`: Expected list of Variable objects.") + super()._set_dests(value) + + def _set_sources(self, value): + """ + Sets the source variables for the instruction. + + Parameters: + value (list): List of source variables. + + Raises: + ValueError: If the list does not contain the expected number of `Variable` objects. + """ + if len(value) != Instruction._OP_NUM_SOURCES: + raise ValueError(("`value`: Expected list of {} Variable objects, " + "but list with {} elements received.".format(Instruction._OP_NUM_SOURCES, + len(value)))) + if not all(isinstance(x, Variable) for x in value): + raise ValueError("`value`: Expected list of Variable objects.") + super()._set_sources(value) + + def _toPISAFormat(self, *extra_args) -> str: + """ + Converts the instruction to kernel format. + + Parameters: + *extra_args: Additional arguments (not supported). + + Raises: + ValueError: If extra arguments are provided. + + Returns: + str: Kernel format instruction. + """ + assert(len(self.dests) == Instruction._OP_NUM_DESTS) + assert(len(self.sources) == Instruction._OP_NUM_SOURCES) + + if extra_args: + raise ValueError('`extra_args` not supported.') + + return super()._toPISAFormat() + + def _toXASMISAFormat(self, *extra_args) -> str: + """ + Converts the instruction to ASM format. + + Parameters: + *extra_args: Additional arguments (not supported). + + Raises: + ValueError: If extra arguments are provided. + + Returns: + str: ASM format instruction. + """ + assert(len(self.dests) == Instruction._OP_NUM_DESTS) + assert(len(self.sources) == Instruction._OP_NUM_SOURCES) + + if extra_args: + raise ValueError('`extra_args` not supported.') + + return super()._toXASMISAFormat() \ No newline at end of file diff --git a/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/copy.py b/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/copy.py new file mode 100644 index 00000000..74022b32 --- /dev/null +++ b/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/copy.py @@ -0,0 +1,244 @@ +import warnings + +from argparse import Namespace + +from .xinstruction import XInstruction +from assembler.memory_model.variable import Variable + +class Instruction(XInstruction): + """ + Encapsulates a `move` instruction when used to copy + a variable into another variable through registers. + """ + + # To be initialized from ASM ISA spec + _OP_NUM_TOKENS: int + + @classmethod + def isa_spec_as_dict(cls) -> dict: + """ + Returns isa_spec attributes as dictionary. + """ + dict = super().isa_spec_as_dict() + dict.update({"num_tokens": cls._OP_NUM_TOKENS}) + return dict + + @classmethod + def SetNumTokens(cls, val): + cls._OP_NUM_TOKENS = val + + @classmethod + def _get_name(cls) -> str: + """ + Returns the operation name in PISA format. + + Returns: + str: PISA operation name. + """ + return cls.OP_NAME_PISA + + @classmethod + def _get_OP_NAME_PISA(cls) -> str: + """ + Returns the operation name in PISA format. + + Returns: + str: PISA operation name. + """ + return "copy" + + @classmethod + def _get_OP_NAME_ASM(cls) -> str: + """ + Returns the operation name in ASM format. + + Returns: + str: ASM operation name. + """ + return "move" + + @classmethod + def parseFromPISALine(cls, line: str) -> list: + """ + Parses a `copy` instruction from a P-ISA Kernel instruction string. + + Parameters: + line (str): String containing the instruction to parse. + Instruction format: N, copy, dst (bank), src (bank), res=0 # comment + Comment is optional. + + Example line: + "13, copy, output_0_1_3 (2), c_0_1_3 (0), 0" + + Returns: + Namespace: A namespace with the following attributes: + N (int): Ring size = Log_2(PMD) + op_name (str): Operation name ("copy") + dst (list[(str, int)]): List of destinations of the form (variable_name, suggested_bank). + This list has a single element for `copy`. + src (list[(str, int)]): List of sources of the form (variable_name, suggested_bank). + This list has two elements for `copy`. + res: Residual for the operation. Ignored for copy/move + comment (str): String with the comment attached to the line (empty string if no comment). + """ + retval = None + tokens = XInstruction.tokenizeFromPISALine(cls.OP_NAME_PISA, line) + if tokens: + retval = { "comment": tokens[1] } + instr_tokens = tokens[0] + if len(instr_tokens) > cls._OP_NUM_TOKENS: + warnings.warn(f'Extra tokens detected for instruction "{cls.OP_NAME_PISA}"', SyntaxWarning) + + retval["N"] = int(instr_tokens[0]) + retval["op_name"] = instr_tokens[1] + params_start = 2 + params_end = params_start + cls._OP_NUM_DESTS + cls._OP_NUM_SOURCES + dst_src = cls.parsePISASourceDestsFromTokens(instr_tokens, + cls._OP_NUM_DESTS, + cls._OP_NUM_SOURCES, + params_start) + retval.update(dst_src) + if len(instr_tokens) < cls._OP_NUM_TOKENS: + # temporary warning to avoid syntax error during testing + # REMOVE WARNING AND TURN IT TO ERROR DURING PRODUCTION + #--------------------------- + warnings.warn(f'Not enough tokens detected for instruction "{cls.OP_NAME_PISA}"', SyntaxWarning) + pass + else: + # ignore "res", but make sure it exists (syntax) + assert(instr_tokens[params_end] is not None) + + retval = Namespace(**retval) + assert(retval.op_name == cls.OP_NAME_PISA) + return retval + + def __init__(self, + id: int, + N: int, + dst: list, + src: list, + throughput : int = None, + latency : int = None, + comment: str = ""): + """ + Initializes an Instruction object with the given parameters. + + Parameters: + id (int): The unique identifier for the instruction. + N (int): The ring size, typically Log_2(PMD). + dst (list): List of destination variables. + src (list): List of source variables. + throughput (int, optional): The throughput of the instruction. Defaults to None. + latency (int, optional): The latency of the instruction. Defaults to None. + comment (str, optional): A comment associated with the instruction. Defaults to an empty string. + + Raises: + ValueError: If the source and destination are the same. + """ + if not throughput: + throughput = Instruction._OP_DEFAULT_THROUGHPUT + if not latency: + latency = Instruction._OP_DEFAULT_LATENCY + N = 0 # does not require ring-size + super().__init__(id, N, throughput, latency, comment=comment) + if dst[0].name == src[0].name: + raise ValueError(f'`dst`: Source and destination cannot be the same for instruction "{self.name}".') + self._set_dests(dst) + self._set_sources(src) + + def __repr__(self): + """ + Returns a string representation of the Instruction object. + + Returns: + str: A string representation of object. + """ + retval=('<{}({}) object at {}>(id={}[0], ' + 'dst={}, src={}, ' + 'throughput={}, latency={})').format(type(self).__name__, + self.name, + hex(id(self)), + self.id, + self.dests, + self.sources, + self.throughput, + self.latency) + return retval + + def _set_dests(self, value): + """ + Sets the destination variables for the instruction. + + Parameters: + value (list): List of destination variables. + + Raises: + ValueError: If the list does not contain the expected number of `Variable` objects. + """ + if len(value) != Instruction._OP_NUM_DESTS: + raise ValueError(("`value`: Expected list of {} `Variable` objects, " + "but list with {} elements received.".format(Instruction._OP_NUM_SOURCES, + len(value)))) + if not all(isinstance(x, Variable) for x in value): + raise ValueError("`value`: Expected list of `Variable` objects.") + super()._set_dests(value) + + def _set_sources(self, value): + """ + Sets the source variables for the instruction. + + Parameters: + value (list): List of source variables. + + Raises: + ValueError: If the list does not contain the expected number of `Variable` objects. + """ + if len(value) != Instruction._OP_NUM_SOURCES: + raise ValueError(("`value`: Expected list of {} `Variable` objects, " + "but list with {} elements received.".format(Instruction._OP_NUM_SOURCES, + len(value)))) + if not all(isinstance(x, Variable) for x in value): + raise ValueError("`value`: Expected list of `Variable` objects.") + super()._set_sources(value) + + def _toPISAFormat(self, *extra_args) -> str: + """ + Converts the instruction to P-ISA kernel format. + + Parameters: + *extra_args: Additional arguments (not supported). + + Raises: + ValueError: If extra arguments are provided. + + Returns: + str: P-ISA kernel format instruction. + """ + assert(len(self.dests) == Instruction._OP_NUM_DESTS) + assert(len(self.sources) == Instruction._OP_NUM_SOURCES) + + if extra_args: + raise ValueError('`extra_args` not supported.') + + return super()._toPISAFormat() + + def _toXASMISAFormat(self, *extra_args) -> str: + """ + Converts the instruction to ASM format. + + Parameters: + *extra_args: Additional arguments (not supported). + + Raises: + ValueError: If extra arguments are provided. + + Returns: + str: ASM format instruction. + """ + assert(len(self.dests) == Instruction._OP_NUM_DESTS) + assert(len(self.sources) == Instruction._OP_NUM_SOURCES) + + if extra_args: + raise ValueError('`extra_args` not supported.') + + return super()._toXASMISAFormat() \ No newline at end of file diff --git a/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/exit.py b/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/exit.py new file mode 100644 index 00000000..7bce54c1 --- /dev/null +++ b/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/exit.py @@ -0,0 +1,116 @@ +from .xinstruction import XInstruction + +class Instruction(XInstruction): + """ + Represents a `bexit` instruction in the assembler with specific properties and methods for + scheduling and formatting. + + This instruction terminates execution of an instruction bundle. + + For more information, check the specificationn: + https://github.com/IntelLabs/hec-assembler-tools/blob/master/docsrc/inst_spec/xinst/xinst_exit.md + """ + + @classmethod + def _get_OP_NAME_ASM(cls) -> str: + """ + Returns the operation name in ASM format. + + Returns: + str: The operation name in ASM format. + """ + return "bexit" + + def __init__(self, + id: int, + throughput : int = None, + latency : int = None, + comment: str = ""): + """ + Initializes an Instruction object with the given parameters. + + Parameters: + id (int): The unique identifier for the instruction. + throughput (int, optional): The throughput of the instruction. Defaults to None. + latency (int, optional): The latency of the instruction. Defaults to None. + comment (str, optional): A comment associated with the instruction. Defaults to an empty string. + """ + if not throughput: + throughput = Instruction._OP_DEFAULT_THROUGHPUT + if not latency: + latency = Instruction._OP_DEFAULT_LATENCY + N = 0 + super().__init__(id, N, throughput, latency, comment=comment) + + def __repr__(self): + """ + Returns a string representation of the Instruction object. + + Returns: + str: A string representation. + """ + retval=('<{}({}) object at {}>(id={}[0], ' + 'throughput={}, latency={})').format(type(self).__name__, + self.name, + hex(id(self)), + self.id, + self.throughput, + self.latency) + return retval + + def _set_dests(self, value): + """ + Raises an error as `bexit` does not have destination parameters. + + Parameters: + value: The value to set as destinations. + + Raises: + RuntimeError: Always raised as `bexit` does not have parameters. + """ + raise RuntimeError(f"Instruction `{self.OP_NAME_PISA}` does not have parameters.") + + def _set_sources(self, value): + """ + Raises an error as `bexit` does not have source parameters. + + Parameters: + value: The value to set as sources. + + Raises: + RuntimeError: Always raised as `bexit` does not have parameters. + """ + raise RuntimeError(f"Instruction `{self.OP_NAME_PISA}` does not have parameters.") + + def _toPISAFormat(self, *extra_args) -> str: + """ + This instruction has no PISA equivalent. + + Parameters: + *extra_args: Additional arguments (not supported). + + Returns: + None: As this instruction has no PISA equivalent. + """ + return None + + def _toXASMISAFormat(self, *extra_args) -> str: + """ + Converts the instruction to ASM format. + + Parameters: + *extra_args: Additional arguments (not supported). + + Raises: + ValueError: If extra arguments are provided. + + Returns: + str: The instruction in ASM format. + """ + assert(len(self.dests) == Instruction._OP_NUM_DESTS) + assert(len(self.sources) == Instruction._OP_NUM_SOURCES) + + if extra_args: + raise ValueError('`extra_args` not supported.') + + return super()._toXASMISAFormat() \ No newline at end of file diff --git a/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/intt.py b/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/intt.py new file mode 100644 index 00000000..c5bae42a --- /dev/null +++ b/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/intt.py @@ -0,0 +1,249 @@ +import warnings + +from argparse import Namespace + +from .xinstruction import XInstruction +from assembler.memory_model.variable import Variable + +class Instruction(XInstruction): + """ + Represents an `intt` instruction in the assembler with specific properties and methods for parsing, + scheduling, and formatting. + + The Inverse Number Theoretic Transform (iNTT), converts NTT form to positional form. + + For more information, check the specification: + https://github.com/IntelLabs/hec-assembler-tools/blob/master/docsrc/inst_spec/xinst/xinst_intt.md + + Attributes: + stage (int): The stage number of the current NTT instruction. + + Methods: + parseFromPISALine(line: str) -> object: + Parses an `intt` instruction from a pre-processed Kernel instruction string. + """ + + # To be initialized from ASM ISA spec + _OP_NUM_TOKENS: int + + @classmethod + def isa_spec_as_dict(cls) -> dict: + """ + Returns isa_spec attributes as dictionary. + """ + dict = super().isa_spec_as_dict() + dict.update({"num_tokens": cls._OP_NUM_TOKENS}) + return dict + + @classmethod + def SetNumTokens(cls, val): + cls._OP_NUM_TOKENS = val + + @classmethod + def parseFromPISALine(cls, line: str) -> object: + """ + Parses an `intt` instruction from a pre-processed Kernel instruction string. + + A preprocessed kernel contains the split implementation of original iNTT into + HERACLES equivalent irshuffle, intt, twintt. + + Parameters: + line (str): String containing the instruction to parse. + Instruction format: N, intt, dst_top (bank), dest_bot (bank), src_top (bank), src_bot (bank), src_tw (bank), stage, res # comment + Comment is optional. + + Example line: + "15, intt, outtmp_9_0 (2), outtmp_9_2 (3), output_9_0 (2), output_9_1 (3), w_gen_17_1 (1), 1, 9" + + Returns: + Namespace: A namespace with the following attributes: + N (int): Ring size = Log_2(PMD) + op_name (str): Operation name ("intt") + dst (list[(str, int)]): List of destinations of the form (variable_name, suggested_bank). + This list has two elements for `intt`. + src (list[(str, int)]): List of sources of the form (variable_name, suggested_bank). + This list has three elements for `intt`. + stage (int): Stage number of the current NTT instruction. + res (int): Residual for the operation. + comment (str): String with the comment attached to the line (empty string if no comment). + + Returns None if an `intt` could not be parsed from the input. + """ + retval = None + tokens = XInstruction.tokenizeFromPISALine(cls.OP_NAME_PISA, line) + if tokens: + retval = { "comment": tokens[1] } + instr_tokens = tokens[0] + if len(instr_tokens) > cls._OP_NUM_TOKENS: + warnings.warn(f'Extra tokens detected for instruction "{cls.OP_NAME_PISA}"', SyntaxWarning) + + retval["N"] = int(instr_tokens[0]) + retval["op_name"] = instr_tokens[1] + params_start = 2 + params_end = params_start + cls._OP_NUM_DESTS + cls._OP_NUM_SOURCES + dst_src = cls.parsePISASourceDestsFromTokens(instr_tokens, + cls._OP_NUM_DESTS, + cls._OP_NUM_SOURCES, + params_start) + retval.update(dst_src) + retval["stage"] = int(instr_tokens[params_end]) + retval["res"] = int(instr_tokens[params_end + 1]) + + retval = Namespace(**retval) + assert(retval.op_name == cls.OP_NAME_PISA) + return retval + + @classmethod + def _get_OP_NAME_ASM(cls) -> str: + """ + Returns the operation name in ASM format. + + Returns: + str: The operation name in ASM format. + """ + return "intt" + + def __init__(self, + id: int, + N: int, + dst: list, + src: list, + stage: int, + res: int, + comment: str = "", + throughput : int = None, + latency : int = None): + """ + Initializes an Instruction object with the given parameters. + + Parameters: + id (int): The unique identifier for the instruction. + N (int): The ring size, typically Log_2(PMD). + dst (list): List of destination variables. + src (list): List of source variables. + stage (int): The stage number of the current NTT instruction. + res (int): The residual for the operation. + comment (str, optional): A comment associated with the instruction. Defaults to an empty string. + throughput (int, optional): The throughput of the instruction. Defaults to None. + latency (int, optional): The latency of the instruction. Defaults to None. + """ + if not throughput: + throughput = Instruction._OP_DEFAULT_THROUGHPUT + if not latency: + latency = Instruction._OP_DEFAULT_LATENCY + + super().__init__(id, N, throughput, latency, res=res, comment=comment) + + self.__stage = stage # (read-only) stage + self._set_dests(dst) + self._set_sources(src) + + def __repr__(self): + """ + Returns a string representation of the Instruction object. + + Returns: + str: A string representation of the Instruction object. + """ + retval=('<{}({}) object at {}>(id={}[0], res={}, ' + 'dst={}, src={}, ' + 'throughput={}, latency={})').format(type(self).__name__, + self.name, + hex(id(self)), + self.id, + self.res, + self.dests, + self.sources, + self.throughput, + self.latency) + return retval + + @property + def stage(self): + """ + The stage number of the current NTT instruction. + + Returns: + int: The stage number. + """ + return self.__stage + + def _set_dests(self, value): + """ + Sets the destination variables for the instruction. + + Parameters: + value (list): List of destination variables. + + Raises: + ValueError: If the list does not contain the expected number of Variable objects. + """ + if len(value) != Instruction._OP_NUM_DESTS: + raise ValueError(("`value`: Expected list of {} Variable objects, " + "but list with {} elements received.".format(Instruction._OP_NUM_DESTS, + len(value)))) + if not all(isinstance(x, Variable) for x in value): + raise ValueError("`value`: Expected list of Variable objects.") + super()._set_dests(value) + + def _set_sources(self, value): + """ + Sets the source variables for the instruction. + + Parameters: + value (list): List of source variables. + + Raises: + ValueError: If the list does not contain the expected number of Variable objects. + """ + if len(value) != Instruction._OP_NUM_SOURCES: + raise ValueError(("`value`: Expected list of {} Variable objects, " + "but list with {} elements received.".format(Instruction._OP_NUM_SOURCES, + len(value)))) + if not all(isinstance(x, Variable) for x in value): + raise ValueError("`value`: Expected list of Variable objects.") + super()._set_sources(value) + + def _toPISAFormat(self, *extra_args) -> str: + """ + Converts the instruction to kernel format. + + Parameters: + *extra_args: Additional arguments (not supported). + + Raises: + ValueError: If extra arguments are provided. + + Returns: + str: The instruction in kernel format. + """ + if extra_args: + raise ValueError('`extra_args` not supported.') + + assert(len(self.dests) == Instruction._OP_NUM_DESTS) + assert(len(self.sources) == Instruction._OP_NUM_SOURCES) + # N, intt, dst_top (bank), dest_bot (bank), src_top (bank), src_bot (bank), src_tw (bank), stage, res # comment + retval = super()._toPISAFormat(self.stage) + + return retval + + def _toXASMISAFormat(self, *extra_args) -> str: + """ + Converts the instruction to ASM format. + + Parameters: + *extra_args: Additional arguments (not supported). + + Raises: + ValueError: If extra arguments are provided. + + Returns: + str: The instruction in ASM format. + """ + assert(len(self.dests) == Instruction._OP_NUM_DESTS) + assert(len(self.sources) == Instruction._OP_NUM_SOURCES) + + if extra_args: + raise ValueError('`extra_args` not supported.') + + return super()._toXASMISAFormat(self.stage) \ No newline at end of file diff --git a/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/irshuffle.py b/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/irshuffle.py new file mode 100644 index 00000000..7176a5cb --- /dev/null +++ b/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/irshuffle.py @@ -0,0 +1,401 @@ +import warnings + +from argparse import Namespace + +from assembler.common.cycle_tracking import CycleType +from assembler.common.decorators import * +from .xinstruction import XInstruction +from assembler.memory_model.variable import Variable +from . import rshuffle + +class Instruction(XInstruction): + """ + Represents an instruction in the assembler with specific properties and methods for parsing, + scheduling, and formatting. This class is specifically designed to handle `irshuffle` + instruction within the assembler's instruction set architecture (ISA). + + Attributes: + SpecialLatency (int): Special latency indicating the first increment at which another rshuffle instruction + can be scheduled within `SpecialLatencyMax` latency. + SpecialLatencyMax (int): Special latency maximum, indicating that no other rshuffle instruction can be enqueued + within this latency unless it is in `SpecialLatencyIncrement`. + SpecialLatencyIncrement (int): Special latency increment, allowing enqueuing of other rshuffle instructions within + `SpecialLatencyMax` only in increments of this value. + RSHUFFLE_DATA_TYPE (str): Data type used for irshuffle operations, default is "intt". + + Methods: + parseFromPISALine(line: str) -> object: + Parses an `irshuffle` instruction from a pre-processed P-ISA Kernel instruction string. + """ + + # To be initialized from ASM ISA spec + _OP_NUM_TOKENS : int + _OP_IRMOVE_LATENCY : int + _OP_IRMOVE_LATENCY_MAX: int + _OP_IRMOVE_LATENCY_INC: int + + __irshuffle_global_cycle_ready = CycleType(0, 0) # private class attribute to track cycle ready among irshuffles + __rshuffle_global_cycle_ready = CycleType(0, 0) # private class attribute to track the cycle ready based on last rshuffle + + @classmethod + def isa_spec_as_dict(cls) -> dict: + """ + Returns isa_spec attributes as dictionary. + """ + dict = super().isa_spec_as_dict() + dict.update({"num_tokens": cls._OP_NUM_TOKENS, + "special_latency_max": cls._OP_IRMOVE_LATENCY_MAX, + "special_latency_increment": cls._OP_IRMOVE_LATENCY_INC}) + return dict + + @classmethod + def SetNumTokens(cls, val): + cls._OP_NUM_TOKENS = val + + @classmethod + def SetSpecialLatencyMax(cls, val): + cls._OP_IRMOVE_LATENCY_MAX = val + + @classmethod + def SetSpecialLatencyIncrement(cls, val): + cls._OP_IRMOVE_LATENCY_INC = val + cls._OP_IRMOVE_LATENCY = cls._OP_IRMOVE_LATENCY_INC + + @classproperty + def SpecialLatency(cls): + """ + Special latency (indicates the first increment at which another rshuffle instruction + can be scheduled within `SpecialLatencyMax` latency). + + Returns: + int: The special latency value. + """ + return cls._OP_IRMOVE_LATENCY + + @classproperty + def SpecialLatencyMax(cls): + """ + Special latency maximum (cannot enqueue any other rshuffle instruction within this latency + unless it is in `SpecialLatencyIncrement`). + + Returns: + int: The special latency maximum value. + """ + return cls._OP_IRMOVE_LATENCY_MAX + + @classproperty + def SpecialLatencyIncrement(cls): + """ + Special latency increment (can only enqueue any other rshuffle instruction # TCHECK for rshuffle + within `SpecialLatencyMax` only in increments of this value). + + Returns: + int: The special latency increment value. + """ + return cls._OP_IRMOVE_LATENCY_INC + + @classproperty + def RSHUFFLE_DATA_TYPE(cls): + """ + Data type used for rshuffle operations. + + Returns: + str: The data type, default is "intt". + """ + return "intt" + + @classmethod + def parseFromPISALine(cls, line: str) -> object: + """ + Parses an `irshuffle` instruction from a pre-processed P-ISA Kernel instruction string. + + A preprocessed kernel contains the split implementation of original iNTT into + HERACLES equivalent irshuffle, intt, twintt. + + Parameters: + line (str): String containing the instruction to parse. + Instruction format: N, irshuffle, dst0, dst1, src0, src1, res # comment + Comment is optional. + + Example line: + "13, irshuffle, outtmp_9_0 (2), outtmp_9_2 (3), outtmp_9_0 (2), outtmp_9_2 (3), 0" + + Returns: + Namespace: A namespace with the following attributes: + - N (int): Ring size = Log_2(PMD) + - op_name (str): Operation name ("irshuffle") + - dst (list[(str, int)]): List of destinations of the form (variable_name, suggested_bank). + This list has two elements for `irshuffle`. + - src (list[(str, int)]): List of sources of the form (variable_name, suggested_bank). + This list has two elements for `irshuffle`. + - comment (str): String with the comment attached to the line (empty string if no comment). + + Returns None if an `irshuffle` could not be parsed from the input. + """ + retval = None + tokens = XInstruction.tokenizeFromPISALine(cls.OP_NAME_PISA, line) + if tokens: + retval = { "comment": tokens[1] } + instr_tokens = tokens[0] + if len(instr_tokens) > cls._OP_NUM_TOKENS: + warnings.warn(f'Extra tokens detected for instruction "{cls.OP_NAME_PISA}"', SyntaxWarning) + + retval["N"] = int(instr_tokens[0]) + retval["op_name"] = instr_tokens[1] + params_start = 2 + params_end = params_start + cls._OP_NUM_DESTS + cls._OP_NUM_SOURCES + dst_src = cls.parsePISASourceDestsFromTokens(instr_tokens, + cls._OP_NUM_DESTS, + cls._OP_NUM_SOURCES, + params_start) + retval.update(dst_src) + # ignore "res", but make sure it exists (syntax) + assert(instr_tokens[params_end] is not None) + + retval = Namespace(**retval) + assert(retval.op_name == cls.OP_NAME_PISA) + return retval + + @classmethod + def _get_name(cls) -> str: + """ + Returns the operation name in PISA format. + + Returns: + str: The operation name in PISA format. + """ + return cls.OP_NAME_PISA + + @classmethod + def _get_OP_NAME_PISA(cls) -> str: + """ + Returns the operation name in PISA format. + + Returns: + str: The operation name in PISA format. + """ + return "irshuffle" + + @classmethod + def _get_OP_NAME_ASM(cls) -> str: + """ + Returns the operation name in ASM format. + + Returns: + str: The operation name in ASM format. + """ + return "rshuffle" + + def __init__(self, + id: int, + N: int, + dst: list, + src: list, + wait_cyc: int = 0, + throughput : int = None, + latency : int = None, + comment: str = ""): + """ + Initializes an Instruction object with the given parameters. + + Parameters: + id (int): The unique identifier for the instruction. + N (int): The ring size, typically Log_2(PMD). + dst (list): List of destination variables. + src (list): List of source variables. + wait_cyc (int, optional): The wait cycle for the instruction. Defaults to 0. + throughput (int, optional): The throughput of the instruction. Defaults to None. + latency (int, optional): The latency of the instruction. Defaults to None. + comment (str, optional): A comment associated with the instruction. Defaults to an empty string. + + Raises: + ValueError: If the latency is less than the special latency. + """ + if not throughput: + throughput = Instruction._OP_DEFAULT_THROUGHPUT + if not latency: + latency = Instruction._OP_DEFAULT_LATENCY + if latency < Instruction._OP_IRMOVE_LATENCY: + raise ValueError((f'`latency`: expected a value greater than or equal to ' + '{Instruction._OP_IRMOVE_LATENCY}, but {latency} received.')) + + super().__init__(id, N, throughput, latency, comment=comment) + + self.wait_cyc = wait_cyc + self._set_dests(dst) + self._set_sources(src) + + def __repr__(self): + """ + Returns a string representation of the Instruction object. + + Returns: + str: A string representation of the Instruction object. + """ + retval=('<{}({}) object at {}>(id={}[0], ' + 'dst={}, src={}, ' + 'wait_cyc={}, res={})').format(type(self).__name__, + self.name, + hex(id(self)), + self.id, + self.dests, + self.sources, + self.wait_cyc, + self.res) + return retval + + @classmethod + def __set_irshuffleGlobalCycleReady(cls, value: CycleType): + """ + Sets the global cycle ready for irshuffle instructions. + + Parameters: + value (CycleType): The cycle type value to set. + """ + if (value > cls.__irshuffle_global_cycle_ready): + cls.__irshuffle_global_cycle_ready = value + + @classmethod + def set_rshuffleGlobalCycleReady(cls, value: CycleType): + """ + Sets the global cycle ready for rshuffle instructions. + + Parameters: + value (CycleType): The cycle type value to set. + """ + if (value > cls.__rshuffle_global_cycle_ready): + cls.__rshuffle_global_cycle_ready = value + + @classmethod + def reset_GlobalCycleReady(cls, value = CycleType(0, 0)): + """ + Resets the global cycle ready for both irshuffle and rshuffle instructions. + + Parameters: + value (CycleType, optional): The cycle type value to reset to. Defaults to CycleType(0, 0). + """ + cls.__rshuffle_global_cycle_ready = value + cls.__irshuffle_global_cycle_ready = value + + def _set_dests(self, value): + """ + Sets the destination variables for the instruction. + + Parameters: + value (list): List of destination variables. + + Raises: + ValueError: If the list does not contain the expected number of Variable objects. + """ + if len(value) != Instruction._OP_NUM_DESTS: + raise ValueError((f"`value`: Expected list of {Instruction._OP_NUM_DESTS} Variable objects, " + "but list with {len(value)} elements received.")) + if not all(isinstance(x, Variable) for x in value): + raise ValueError("`value`: Expected list of Variable objects.") + super()._set_dests(value) + + def _set_sources(self, value): + """ + Sets the source variables for the instruction. + + Parameters: + value (list): List of source variables. + + Raises: + ValueError: If the list does not contain the expected number of Variable objects. + """ + if len(value) != Instruction._OP_NUM_SOURCES: + raise ValueError((f"`value`: Expected list of {Instruction._OP_NUM_SOURCES} Variable objects, " + "but list with {len(value)} elements received.")) + if not all(isinstance(x, Variable) for x in value): + raise ValueError("`value`: Expected list of Variable objects.") + super()._set_sources(value) + + def _get_cycle_ready(self): + """ + Returns the current value for ready cycle. + + Overrides :func:`BaseInstruction._get_cycle_ready`. + + Returns: + CycleType: The current cycle ready value. + """ + # This will return the maximum cycle ready among this instruction + # sources and the global cycles-ready for other rshuffles and other irshuffles. + # An irshuffle cannot be within _OP_IRMOVE_LATENCY cycles from another irshuffle, + # nor within _OP_DEFAULT_LATENCY cycles from another rshuffle. + return max(super()._get_cycle_ready(), + Instruction.__irshuffle_global_cycle_ready, + Instruction.__rshuffle_global_cycle_ready) + + def _schedule(self, cycle_count: CycleType, schedule_id: int) -> int: + """ + Schedules the instruction, simulating timings of executing this instruction. + + The ready cycle for all destinations is updated based on input `cycle_count` and + this instruction latency. The global `xrshuffle` ready cycles is also updated. + + Parameters: + cycle_count (CycleType): Current cycle of execution. + schedule_id (int): The schedule identifier. + + Raises: + RuntimeError: If the instruction is not ready to execute yet. Based on current cycle, + the instruction is ready to execute if its cycle_ready value is less than or + equal to `cycle_count`. + + Returns: + int: The throughput for this instruction, i.e., the number of cycles by which to advance + the current cycle counter. + """ + original_throughput = super()._schedule(cycle_count, schedule_id) + retval = self.throughput + self.wait_cyc + assert(original_throughput <= retval) + Instruction.__set_irshuffleGlobalCycleReady(CycleType(cycle_count.bundle, cycle_count.cycle + Instruction._OP_IRMOVE_LATENCY)) + # Avoid rshuffles and irshuffles in the same bundle + rshuffle.Instruction.set_irshuffleGlobalCycleReady(CycleType(cycle_count.bundle + 1, 0)) + return retval + + def _toPISAFormat(self, *extra_args) -> str: + """ + Converts the instruction to kernel format. + + Parameters: + *extra_args: Additional arguments (not supported). + + Raises: + ValueError: If extra arguments are provided. + + Returns: + str: The instruction in kernel format. + """ + assert(len(self.dests) == Instruction._OP_NUM_DESTS) + assert(len(self.sources) == Instruction._OP_NUM_SOURCES) + + if extra_args: + raise ValueError('`extra_args` not supported.') + + # N, irshuffle, dst0, dst1, src0, src1, res=0 # comment + return super()._toPISAFormat(0) + + def _toXASMISAFormat(self, *extra_args) -> str: + """ + Converts the instruction to ASM format. + + Parameters: + *extra_args: Additional arguments (not supported). + + Raises: + ValueError: If extra arguments are provided. + + Returns: + str: The instruction in ASM format. + """ + assert(len(self.dests) == Instruction._OP_NUM_DESTS) + assert(len(self.sources) == Instruction._OP_NUM_SOURCES) + + if extra_args: + raise ValueError('`extra_args` not supported.') + + # id[0], N, op, dst_register0, dst_register1, src_register0, src_register1, wait_cycle, data_type="intt", res=0 [# comment] + return super()._toXASMISAFormat(self.wait_cyc, self.RSHUFFLE_DATA_TYPE) \ No newline at end of file diff --git a/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/mac.py b/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/mac.py new file mode 100644 index 00000000..1d6d817a --- /dev/null +++ b/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/mac.py @@ -0,0 +1,251 @@ +import warnings + +from argparse import Namespace + +from .xinstruction import XInstruction +from assembler.memory_model.variable import Variable + +class Instruction(XInstruction): + """ + Represents a `mac` (multiply-accumulate) instruction in an assembly language. + + This class is responsible for parsing, representing, and converting `mac` instructions + according to a specific instruction set architecture (ISA) specification. + + For more information, check the specification: + https://github.com/IntelLabs/hec-assembler-tools/blob/master/docsrc/inst_spec/xinst/xinst_mac.md + + Methods: + parseFromPISALine: Parses a `mac` instruction from a Kernel instruction string. + """ + + # To be initialized from ASM ISA spec + _OP_NUM_TOKENS: int + + @classmethod + def isa_spec_as_dict(cls) -> dict: + """ + Returns isa_spec attributes as dictionary. + """ + dict = super().isa_spec_as_dict() + dict.update({"num_tokens": cls._OP_NUM_TOKENS}) + return dict + + @classmethod + def SetNumTokens(cls, val): + cls._OP_NUM_TOKENS = val + + @classmethod + def SetNumSources(cls, val): + cls._OP_NUM_SOURCES = val + # In ASM ISA spec there are 3 sources, but src[0] == dst + cls._OP_NUM_PISA_SOURCES = cls._OP_NUM_SOURCES - 1 + + @classmethod + def parseFromPISALine(cls, line: str) -> list: + """ + Parses a 'mac' instruction from a Kernel instruction string. + + Parameters: + line (str): + String containing the instruction to parse. + Instruction format: N, mac, dst (bank), src0 (bank), src1 (bank), res # comment + Comment is optional. + + Example line: + "13, mac , c2_rlk_0_10_0 (3), coeff_0_0_0 (2), rlk_0_2_10_0 (0), 10" + + Returns: + list: + A list of tuples with a single element representing the parsed information, + or an empty list if a 'mac' could not be parsed from the input. + + Element `Instruction` is this class. + + Element `parsed_op` is a namespace with the following attributes: + - N (int): Ring size = Log_2(PMD) + - op_name (str): Operation name ("mac") + - dst (list of tuples): List of destinations of the form (variable_name, suggested_bank). + This list has a single element for 'mac': dst[0] = dst[0] + src[0] * src[1] + - src (list of tuples): List of sources of the form (variable_name, suggested_bank). + This list has two elements for 'mac'. + - res (int): Residual for the operation. + - comment (str): String with the comment attached to the line (empty string if no comment). + """ + retval = None + tokens = XInstruction.tokenizeFromPISALine(cls.OP_NAME_PISA, line) + if tokens: + retval = {"comment": tokens[1]} + instr_tokens = tokens[0] + if len(instr_tokens) > cls._OP_NUM_TOKENS: + warnings.warn(f'Extra tokens detected for instruction "{cls.OP_NAME_PISA}"', SyntaxWarning) + + retval["N"] = int(instr_tokens[0]) + retval["op_name"] = instr_tokens[1] + params_start = 2 + params_end = params_start + cls._OP_NUM_DESTS + cls._OP_NUM_PISA_SOURCES + dst_src = cls.parsePISASourceDestsFromTokens(instr_tokens, + cls._OP_NUM_DESTS, + cls._OP_NUM_PISA_SOURCES, + params_start) + retval.update(dst_src) + retval["res"] = int(instr_tokens[params_end]) + + retval = Namespace(**retval) + assert(retval.op_name == cls.OP_NAME_PISA) + return retval + + @classmethod + def _get_OP_NAME_ASM(cls) -> str: + """ + Returns the operation name in ASM format. + + Returns: + str: The operation name as a string. + """ + return "mac" + + def __init__(self, + id: int, + N: int, + dst: list, + src: list, + res: int, + throughput: int = None, + latency: int = None, + comment: str = ""): + """ + Initializes an Instruction object for a 'mac' operation. + + Parameters: + id (int): + The unique identifier for the instruction. + N (int): + The ring size. + dst (list): + List of destination variables. + src (list): + List of source variables. + res (int): + The residual for the operation. + throughput (int, optional): + The throughput of the instruction. Defaults to the class-level default if not provided. + latency (int, optional): + The latency of the instruction. Defaults to the class-level default if not provided. + comment (str, optional): + An optional comment for the instruction. + """ + if not throughput: + throughput = Instruction._OP_DEFAULT_THROUGHPUT + if not latency: + latency = Instruction._OP_DEFAULT_LATENCY + + super().__init__(id, N, throughput, latency, res=res, comment=comment) + + self._set_dests(dst) + self._set_sources([self.dests[0]] + src) + + def __repr__(self): + """ + Returns a string representation of the Instruction object. + + Returns: + str: A string representation of the object. + """ + retval = ('<{}({}) object at {}>(id={}[0], res={}, ' + 'dst={}, src={}, ' + 'throughput={}, latency={})').format(type(self).__name__, + self.name, + hex(id(self)), + self.id, + self.res, + self.dests, + self.sources, + self.throughput, + self.latency) + return retval + + def _set_dests(self, value): + """ + Sets the destination variables for the instruction. + + Parameters: + value (list): A list of Variable objects representing the destinations. + + Raises: + ValueError: If the list does not contain the expected number of Variable objects. + """ + if len(value) != Instruction._OP_NUM_DESTS: + raise ValueError(("`value`: Expected list of {} Variable objects, " + "but list with {} elements received.".format(Instruction._OP_NUM_DESTS, + len(value)))) + if not all(isinstance(x, Variable) for x in value): + raise ValueError("`value`: Expected list of Variable objects.") + super()._set_dests(value) + + def _set_sources(self, value): + """ + Sets the source variables for the instruction. + + Parameters: + value (list): A list of Variable objects representing the sources. + + Raises: + ValueError: If the list does not contain the expected number of Variable objects. + """ + if len(value) != Instruction._OP_NUM_SOURCES: + raise ValueError(("`value`: Expected list of {} Variable objects, " + "but list with {} elements received.".format(Instruction._OP_NUM_SOURCES, + len(value)))) + if not all(isinstance(x, Variable) for x in value): + raise ValueError("`value`: Expected list of Variable objects.") + super()._set_sources(value) + + def _toPISAFormat(self, *extra_args) -> str: + """ + Converts the instruction to kernel format. + + Parameters: + extra_args: Additional arguments (not supported). + + Raises: + ValueError: If extra arguments are provided. + + Returns: + str: The instruction in kernel format as a string. + """ + assert(len(self.dests) == Instruction._OP_NUM_DESTS) + assert(len(self.sources) == Instruction._OP_NUM_SOURCES) + + if extra_args: + raise ValueError('`extra_args` not supported.') + + preamble = (self.N,) + extra_args = tuple(src.toPISAFormat() for src in self.sources[1:]) + extra_args + extra_args = tuple(dst.toPISAFormat() for dst in self.dests) + extra_args + if self.res is not None: + extra_args += (self.res,) + return self.toStringFormat(preamble, + self.OP_NAME_PISA, + *extra_args) + + def _toXASMISAFormat(self, *extra_args) -> str: + """ + Converts the instruction to ASM format. + + Parameters: + extra_args: Additional arguments (not supported). + + Raises: + ValueError: If extra arguments are provided. + + Returns: + str: The instruction in ASM format as a string. + """ + assert(len(self.dests) == Instruction._OP_NUM_DESTS) + assert(len(self.sources) == Instruction._OP_NUM_SOURCES) + + if extra_args: + raise ValueError('`extra_args` not supported.') + + return super()._toXASMISAFormat() \ No newline at end of file diff --git a/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/maci.py b/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/maci.py new file mode 100644 index 00000000..455387f2 --- /dev/null +++ b/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/maci.py @@ -0,0 +1,267 @@ +import warnings + +from argparse import Namespace + +from .xinstruction import XInstruction +from assembler.memory_model.variable import Variable + +class Instruction(XInstruction): + """ + Represents a `maci` (multiply-accumulate immediate) instruction in an assembly language. + + This class is responsible for parsing, representing, and converting 'maci' instructions + according to a specific instruction set architecture (ISA) specification. + + For more information, check the specification: + https://github.com/IntelLabs/hec-assembler-tools/blob/master/docsrc/inst_spec/xinst/xinst_maci.md + """ + + # To be initialized from ASM ISA spec + _OP_NUM_TOKENS: int + + @classmethod + def isa_spec_as_dict(cls) -> dict: + """ + Returns isa_spec attributes as dictionary. + """ + dict = super().isa_spec_as_dict() + dict.update({"num_tokens": cls._OP_NUM_TOKENS}) + return dict + + @classmethod + def SetNumTokens(cls, val): + cls._OP_NUM_TOKENS = val + + @classmethod + def SetNumSources(cls, val): + cls._OP_NUM_SOURCES = val + # In ASM ISA spec there are 2 sources, but src[0] == dst + cls._OP_NUM_PISA_SOURCES = cls._OP_NUM_SOURCES - 1 + + @classmethod + def parseFromPISALine(cls, line: str) -> list: + """ + Parses a 'maci' instruction from a Kernel instruction string. + + Parameters: + line (str): + String containing the instruction to parse. + Instruction format: N, maci, dst (bank), src (bank), imm, res # comment + Comment is optional. + + Example line: + "13, maci, coeff_0_1_3 (2), c2_4_3 (3), Qqr_extend_2_13_4, 13" + + Returns: + list: + A list of tuples with a single element representing the parsed information, + or an empty list if a 'maci' could not be parsed from the input. + + Element `Instruction` is this class. + + Element `parsed_op` is a namespace with the following attributes: + - N (int): Ring size = Log_2(PMD) + - op_name (str): Operation name ("maci") + - dst (list of tuples): List of destinations of the form (variable_name, suggested_bank). + This list has a single element for 'maci': dst[0] = dst[0] + src[0] * imm + - src (list of tuples): List of sources of the form (variable_name, suggested_bank). + This list has a single element for 'maci'. + - imm (str): Name of immediate identifier. This will be replaced by a literal value during linkage. + - res (int): Residual for the operation. + - comment (str): String with the comment attached to the line (empty string if no comment). + """ + retval = None + tokens = XInstruction.tokenizeFromPISALine(cls.OP_NAME_PISA, line) + if tokens: + retval = {"comment": tokens[1]} + instr_tokens = tokens[0] + if len(instr_tokens) > cls._OP_NUM_TOKENS: + warnings.warn(f'Extra tokens detected for instruction "{cls.OP_NAME_PISA}"', SyntaxWarning) + + retval["N"] = int(instr_tokens[0]) + retval["op_name"] = instr_tokens[1] + params_start = 2 + params_end = params_start + cls._OP_NUM_DESTS + cls._OP_NUM_PISA_SOURCES + dst_src = cls.parsePISASourceDestsFromTokens(instr_tokens, + cls._OP_NUM_DESTS, + cls._OP_NUM_PISA_SOURCES, + params_start) + retval.update(dst_src) + retval["imm"] = instr_tokens[params_end] + retval["res"] = int(instr_tokens[params_end + 1]) + + retval = Namespace(**retval) + assert(retval.op_name == cls.OP_NAME_PISA) + return retval + + @classmethod + def _get_OP_NAME_ASM(cls) -> str: + """ + Returns the operation name in ASM format. + + Returns: + str: The operation name as a string. + """ + return "maci" + + def __init__(self, + id: int, + N: int, + dst: list, + src: list, + imm: str, + res: int, + comment: str = "", + throughput: int = None, + latency: int = None): + """ + Initializes an Instruction object for a 'maci' operation. + + Parameters: + id (int): + The unique identifier for the instruction. + N (int): + The ring size. + dst (list): + List of destination variables. + src (list): + List of source variables. + imm (str): + The immediate value identifier. + res (int): + The residual for the operation. + comment (str, optional): + An optional comment for the instruction. + throughput (int, optional): + The throughput of the instruction. Defaults to the class-level default if not provided. + latency (int, optional): + The latency of the instruction. Defaults to the class-level default if not provided. + """ + if not throughput: + throughput = Instruction._OP_DEFAULT_THROUGHPUT + if not latency: + latency = Instruction._OP_DEFAULT_LATENCY + + super().__init__(id, N, throughput, latency, res=res, comment=comment) + + self.__imm = imm # (Read-only) immediate + self._set_dests(dst) + self._set_sources([self.dests[0]] + src) + + def __repr__(self): + """ + Returns a string representation of the Instruction object. + + Returns: + str: A string representation of the object. + """ + retval = ('<{}({}) object at {}>(id={}[0], res={}, imm={}, ' + 'dst={}, src={}, ' + 'throughput={}, latency={})').format(type(self).__name__, + self.name, + hex(id(self)), + self.id, + self.res, + self.imm, + self.dests, + self.sources, + self.throughput, + self.latency) + return retval + + @property + def imm(self): + """ + Returns the immediate value identifier. + + Returns: + str: The immediate value identifier. + """ + return self.__imm + + def _set_dests(self, value): + """ + Sets the destination variables for the instruction. + + Parameters: + value (list): A list of Variable objects representing the destinations. + + Raises: + ValueError: If the list does not contain the expected number of Variable objects. + """ + if len(value) != Instruction._OP_NUM_DESTS: + raise ValueError(("`value`: Expected list of {} Variable objects, " + "but list with {} elements received.".format(Instruction._OP_NUM_DESTS, + len(value)))) + if not all(isinstance(x, Variable) for x in value): + raise ValueError("`value`: Expected list of Variable objects.") + super()._set_dests(value) + + def _set_sources(self, value): + """ + Sets the source variables for the instruction. + + Parameters: + value (list): A list of Variable objects representing the sources. + + Raises: + ValueError: If the list does not contain the expected number of Variable objects. + """ + if len(value) != Instruction._OP_NUM_SOURCES: + raise ValueError(("`value`: Expected list of {} Variable objects, " + "but list with {} elements received.".format(Instruction._OP_NUM_SOURCES, + len(value)))) + if not all(isinstance(x, Variable) for x in value): + raise ValueError("`value`: Expected list of Variable objects.") + super()._set_sources(value) + + def _toPISAFormat(self, *extra_args) -> str: + """ + Converts the instruction to kernel format. + + Parameters: + extra_args: Additional arguments (not supported). + + Raises: + ValueError: If extra arguments are provided. + + Returns: + str: The instruction in kernel format as a string. + """ + assert(len(self.dests) == Instruction._OP_NUM_DESTS) + assert(len(self.sources) == Instruction._OP_NUM_SOURCES) + + if extra_args: + raise ValueError('`extra_args` not supported.') + + # N, muli, dst (bank), src0 (bank), imm, res # comment + preamble = (self.N,) + extra_args = (self.imm,) + extra_args = tuple(src.toPISAFormat() for src in self.sources[1:]) + extra_args + extra_args = tuple(dst.toPISAFormat() for dst in self.dests) + extra_args + if self.res is not None: + extra_args += (self.res,) + return self.toStringFormat(preamble, + self.OP_NAME_PISA, + *extra_args) + + def _toXASMISAFormat(self, *extra_args) -> str: + """ + Converts the instruction to ASM format. + + Parameters: + extra_args: Additional arguments (not supported). + + Raises: + ValueError: If extra arguments are provided. + + Returns: + str: The instruction in ASM format as a string. + """ + assert(len(self.dests) == Instruction._OP_NUM_DESTS) + assert(len(self.sources) == Instruction._OP_NUM_SOURCES) + + if extra_args: + raise ValueError('`extra_args` not supported.') + + return super()._toXASMISAFormat(self.imm) \ No newline at end of file diff --git a/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/move.py b/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/move.py new file mode 100644 index 00000000..80a47f3c --- /dev/null +++ b/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/move.py @@ -0,0 +1,222 @@ +from assembler.common.cycle_tracking import CycleType +from .xinstruction import XInstruction +from assembler.memory_model.variable import Variable, DummyVariable +from assembler.memory_model.register_file import Register + +class Instruction(XInstruction): + """ + Encapsulates a `move` instruction used to copy data from one register to a different one. + + This class is responsible for managing the movement of variables between registers + in accordance with a specific instruction set architecture (ISA) specification. + + For more information, check the specification: + https://github.com/IntelLabs/hec-assembler-tools/blob/master/docsrc/inst_spec/xinst/xinst_move.md + """ + + @classmethod + def _get_OP_NAME_ASM(cls) -> str: + """ + Returns the operation name in ASM format. + + Returns: + str: The operation name as a string. + """ + return "move" + + def __init__(self, + id: int, + dst: Register, + src: list, + dummy_var: DummyVariable = None, + throughput: int = None, + latency: int = None, + comment: str = ""): + """ + Constructs a new `move` CInstruction. + + Parameters: + id (int): + User-defined ID for the instruction. It will be bundled with a nonce to form a unique ID. + dst (Register): + The destination register where to load the variable in `src`. + src (list of Variable): + A list containing a single Variable object indicating the source variable to move from + its current register to `dst` register. + dummy_var (DummyVariable, optional): + A dummy variable used for marking registers as free. + throughput (int, optional): + The throughput of the instruction. Defaults to the class-level default if not provided. + latency (int, optional): + The latency of the instruction. Defaults to the class-level default if not provided. + comment (str, optional): + An optional comment for the instruction. + + Raises: + ValueError: If a dummy variable is used as a source or if the destination register is not empty. + """ + if not throughput: + throughput = Instruction._OP_DEFAULT_THROUGHPUT + if not latency: + latency = Instruction._OP_DEFAULT_LATENCY + if any(isinstance(v, DummyVariable) or not v.name for v in src): + raise ValueError(f"{Instruction.OP_NAME_ASM} cannot have dummy variable as source.") + if dst.contained_variable \ + and not isinstance(dst.contained_variable, DummyVariable): + raise ValueError("{}: destination register must be empty, but variable {}.{} found.".format(Instruction.OP_NAME_ASM, + dst.contained_variable.name, + dst.contained_variable.tag)) + N = 0 # Does not require ring-size + super().__init__(id, N, throughput, latency, comment=comment) + self.__dummy_var = dummy_var + self._set_dests([dst]) + self._set_sources(src) + + def __repr__(self): + """ + Returns a string representation of the Instruction object. + + Returns: + str: A string representation of the object. + """ + retval = ('<{}({}) object at {}>(id={}[0], ' + 'dst={}, src={}, ' + 'throughput={}, latency={})').format(type(self).__name__, + self.name, + hex(id(self)), + self.id, + self.dests, + self.sources, + self.throughput, + self.latency) + return retval + + def _set_dests(self, value): + """ + Sets the destination register for the instruction. + + Parameters: + value (list): A list of Register objects representing the destination. + + Raises: + ValueError: If the list does not contain the expected number of Register objects. + TypeError: If the list contains non-Register objects. + """ + if len(value) != Instruction._OP_NUM_DESTS: + raise ValueError(("`value`: Expected list of {} `Register` objects, " + "but list with {} elements received.".format(Instruction._OP_NUM_DESTS, + len(value)))) + if not all(isinstance(x, Register) for x in value): + raise TypeError("`value`: Expected list of `Register` objects.") + super()._set_dests(value) + + def _set_sources(self, value): + """ + Sets the source variable for the instruction. + + Parameters: + value (list): A list of Variable objects representing the source. + + Raises: + ValueError: If the list does not contain the expected number of Variable objects. + """ + if len(value) != Instruction._OP_NUM_SOURCES: + raise ValueError(("`value`: Expected list of {} `Variable` objects, " + "but list with {} elements received.".format(Instruction._OP_NUM_SOURCES, + len(value)))) + if not all(isinstance(x, Variable) for x in value): + raise ValueError("`value`: Expected list of `Variable` objects.") + super()._set_sources(value) + + def _schedule(self, cycle_count: CycleType, schedule_id: int) -> int: + """ + Schedules the instruction, simulating timings of executing this instruction. + + Scheduling `move` XInst will cause the involved registers and variables to be + updated. The source register for the variable will be freed, and the variable + will be allocated to the destination register. + + Parameters: + cycle_count (CycleType): Current cycle of execution. + schedule_id (int): The schedule identifier. + + Raises: + RuntimeError: If the instruction is not ready to execute yet or if the target register is not empty. + + Returns: + int: The throughput for this instruction, i.e., the number of cycles by which to advance + the current cycle counter. + """ + assert(Instruction._OP_NUM_DESTS > 0 and len(self.dests) == Instruction._OP_NUM_DESTS) + assert(Instruction._OP_NUM_SOURCES > 0 and len(self.sources) == Instruction._OP_NUM_SOURCES) + + variable = self.sources[0] # Expected sources to contain a Variable + target_register = self.dests[0] + if isinstance(variable, Register): + # Source and target types are swapped after scheduling + # Instruction already scheduled: can only schedule once + assert(isinstance(target_register, Variable)) + raise RuntimeError(f'Instruction `{self.name}` (id = {self.id}) already scheduled.') + + if target_register.contained_variable \ + and not isinstance(target_register.contained_variable, DummyVariable): + raise RuntimeError(('Instruction `{}` (id = {}) ' + 'cannot be scheduled because target register `{}` is not empty: ' + 'contains variable "{}".').format(self.name, + self.id, + target_register.name, + target_register.contained_variable.name)) + + assert not target_register.contained_variable or self.__dummy_var == target_register.contained_variable + # Perform the move + register_dirty = variable.register_dirty + source_register = variable.register + target_register.allocateVariable(variable) + source_register.allocateVariable(self.__dummy_var) # Mark source register as free for next bundle + assert source_register.bank.bank_index == 0 + # Swap source and dest to keep the output format of the string instruction consistent + self.sources[0] = source_register + self.dests[0] = variable + + retval = super()._schedule(cycle_count, schedule_id) + # We only moved the variable, we didn't change its value + variable.register_dirty = register_dirty # Preserve register dirty state + + if self.comment: + self.comment += ';' + self.comment += ' variable "{}"'.format(variable.name) + + return retval + + def _toPISAFormat(self, *extra_args) -> str: + """ + This instruction has no PISA equivalent. + + Parameters: + extra_args: Additional arguments (not used). + + Returns: + None: As there is no PISA equivalent for this instruction. + """ + return None + + def _toXASMISAFormat(self, *extra_args) -> str: + """ + Converts the instruction to ASM format. + + Parameters: + extra_args: Additional arguments (not supported). + + Raises: + ValueError: If extra arguments are provided. + + Returns: + str: The instruction in ASM format as a string. + """ + assert(len(self.dests) == Instruction._OP_NUM_DESTS) + assert(len(self.sources) == Instruction._OP_NUM_SOURCES) + + if extra_args: + raise ValueError('`extra_args` not supported.') + + return super()._toXASMISAFormat() \ No newline at end of file diff --git a/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/mul.py b/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/mul.py new file mode 100644 index 00000000..a17e2673 --- /dev/null +++ b/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/mul.py @@ -0,0 +1,225 @@ +import warnings + +from argparse import Namespace + +from .xinstruction import XInstruction +from assembler.memory_model.variable import Variable + +class Instruction(XInstruction): + """ + Represents a `mul` (multiply) instruction in an assembly language. + + This class is responsible for parsing, representing, and converting `mul` instructions + according to a specific instruction set architecture (ISA) specification. + + For more information, check the specification: + https://github.com/IntelLabs/hec-assembler-tools/blob/master/docsrc/inst_spec/xinst/xinst_mul.md + + Methods: + parseFromPISALine: Parses a `mul` instruction from a Kernel instruction string. + """ + + # To be initialized from ASM ISA spec + _OP_NUM_TOKENS: int + + @classmethod + def isa_spec_as_dict(cls) -> dict: + """ + Returns isa_spec attributes as dictionary. + """ + dict = super().isa_spec_as_dict() + dict.update({"num_tokens": cls._OP_NUM_TOKENS}) + return dict + + @classmethod + def SetNumTokens(cls, val): + cls._OP_NUM_TOKENS = val + + @classmethod + def parseFromPISALine(cls, line: str) -> Namespace: + """ + Parses a 'mul' instruction from a Kernel instruction string. + + Parameters: + line (str): String containing the instruction to parse. + Instruction format: N, mul, dst (bank), src0 (bank), src1 (bank), res # comment + Comment is optional. + + Example line: + "13, mul , output_0_1_3 (2), c_0_1_3 (0), d_0_1_3 (1), 1" + + Returns: + Namespace: Namespace representing the parsed information, + or None if a 'mul' could not be parsed from the input. + + Element `parsed_op` is a namespace with the following attributes: + - N (int): Ring size = Log_2(PMD) + - op_name (str): Operation name ("mul") + - dst (list of tuples): List of destinations of the form (variable_name, suggested_bank). + This list has a single element for 'mul'. + - src (list of tuples): List of sources of the form (variable_name, suggested_bank). + This list has two elements for 'mul'. + - res (int): Residual for the operation. + - comment (str): String with the comment attached to the line (empty string if no comment). + """ + retval = None + tokens = XInstruction.tokenizeFromPISALine(cls.OP_NAME_PISA, line) + if tokens: + retval = {"comment": tokens[1]} + instr_tokens = tokens[0] + if len(instr_tokens) > cls._OP_NUM_TOKENS: + warnings.warn(f'Extra tokens detected for instruction "{cls.OP_NAME_PISA}"', SyntaxWarning) + + retval["N"] = int(instr_tokens[0]) + retval["op_name"] = instr_tokens[1] + params_start = 2 + params_end = params_start + cls._OP_NUM_DESTS + cls._OP_NUM_SOURCES + dst_src = cls.parsePISASourceDestsFromTokens(instr_tokens, + cls._OP_NUM_DESTS, + cls._OP_NUM_SOURCES, + params_start) + retval.update(dst_src) + retval["res"] = int(instr_tokens[params_end]) + + retval = Namespace(**retval) + assert(retval.op_name == cls.OP_NAME_PISA) + return retval + + @classmethod + def _get_OP_NAME_ASM(cls) -> str: + """ + Returns the operation name in ASM format. + + Returns: + str: The operation name as a string. + """ + return "mul" + + def __init__(self, + id: int, + N: int, + dst: list, + src: list, + res: int, + throughput: int = None, + latency: int = None, + comment: str = ""): + """ + Initializes an Instruction object for a 'mul' operation. + + Parameters: + id (int): The unique identifier for the instruction. + N (int): The ring size. + dst (list): List of destination variables. + src (list): List of source variables. + res (int): The residual for the operation. + throughput (int, optional): The throughput of the instruction. Defaults to the class-level default if not provided. + latency (int, optional): The latency of the instruction. Defaults to the class-level default if not provided. + comment (str, optional): An optional comment for the instruction. + """ + if not throughput: + throughput = Instruction._OP_DEFAULT_THROUGHPUT + if not latency: + latency = Instruction._OP_DEFAULT_LATENCY + + super().__init__(id, N, throughput, latency, res=res, comment=comment) + self._set_dests(dst) + self._set_sources(src) + + def __repr__(self): + """ + Returns a string representation of the Instruction object. + + Returns: + str: A string representation of the object. + """ + retval = ('<{}({}) object at {}>(id={}[0], res={}, ' + 'dst={}, src={}, ' + 'throughput={}, latency={})').format(type(self).__name__, + self.name, + hex(id(self)), + self.id, + self.res, + self.dests, + self.sources, + self.throughput, + self.latency) + return retval + + def _set_dests(self, value): + """ + Sets the destination variables for the instruction. + + Parameters: + value (list): A list of Variable objects representing the destinations. + + Raises: + ValueError: If the list does not contain the expected number of Variable objects. + """ + if len(value) != Instruction._OP_NUM_DESTS: + raise ValueError(("`value`: Expected list of {} Variable objects, " + "but list with {} elements received.".format(Instruction._OP_NUM_DESTS, + len(value)))) + if not all(isinstance(x, Variable) for x in value): + raise ValueError("`value`: Expected list of Variable objects.") + super()._set_dests(value) + + def _set_sources(self, value): + """ + Sets the source variables for the instruction. + + Parameters: + value (list): A list of Variable objects representing the sources. + + Raises: + ValueError: If the list does not contain the expected number of Variable objects. + """ + if len(value) != Instruction._OP_NUM_SOURCES: + raise ValueError(("`value`: Expected list of {} Variable objects, " + "but list with {} elements received.".format(Instruction._OP_NUM_SOURCES, + len(value)))) + if not all(isinstance(x, Variable) for x in value): + raise ValueError("`value`: Expected list of Variable objects.") + super()._set_sources(value) + + def _toPISAFormat(self, *extra_args) -> str: + """ + Converts the instruction to kernel format. + + Parameters: + extra_args: Additional arguments (not supported). + + Raises: + ValueError: If extra arguments are provided. + + Returns: + str: The instruction in kernel format as a string. + """ + assert(len(self.dests) == Instruction._OP_NUM_DESTS) + assert(len(self.sources) == Instruction._OP_NUM_SOURCES) + + if extra_args: + raise ValueError('`extra_args` not supported.') + + return super()._toPISAFormat() + + def _toXASMISAFormat(self, *extra_args) -> str: + """ + Converts the instruction to ASM format. + + Parameters: + extra_args: Additional arguments (not supported). + + Raises: + ValueError: If extra arguments are provided. + + Returns: + str: The instruction in ASM format as a string. + """ + assert(len(self.dests) == Instruction._OP_NUM_DESTS) + assert(len(self.sources) == Instruction._OP_NUM_SOURCES) + + if extra_args: + raise ValueError('`extra_args` not supported.') + + return super()._toXASMISAFormat() \ No newline at end of file diff --git a/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/muli.py b/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/muli.py new file mode 100644 index 00000000..eb3ce83c --- /dev/null +++ b/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/muli.py @@ -0,0 +1,246 @@ +import warnings + +from argparse import Namespace + +from .xinstruction import XInstruction +from assembler.memory_model.variable import Variable + +class Instruction(XInstruction): + """ + Represents a `muli` (multiply immediate) instruction in an assembly language. + + This class is responsible for parsing, representing, and converting `muli` instructions + according to a specific instruction set architecture (ISA) specification. + + For more information, check the specification: + https://github.com/IntelLabs/hec-assembler-tools/blob/master/docsrc/inst_spec/xinst/xinst_muli.md + + Methods: + parseFromPISALine: Parses a `muli` instruction from a Kernel instruction string. + imm: Property to get the immediate value identifier. + """ + + # To be initialized from ASM ISA spec + _OP_NUM_TOKENS: int + + @classmethod + def isa_spec_as_dict(cls) -> dict: + """ + Returns isa_spec attributes as dictionary. + """ + dict = super().isa_spec_as_dict() + dict.update({"num_tokens": cls._OP_NUM_TOKENS}) + return dict + + @classmethod + def SetNumTokens(cls, val): + cls._OP_NUM_TOKENS = val + + @classmethod + def parseFromPISALine(cls, line: str) -> list: + """ + Parses a 'muli' instruction from a Kernel instruction string. + + Parameters: + line (str): String containing the instruction to parse. + Instruction format: N, muli, dst (bank), src0 (bank), imm, res # comment + Comment is optional. + + Example line: + "13, muli , output_0_1_3 (2), c_0_1_3 (0), immediate, 1" + + Returns: + list: A list of tuples with a single element representing the parsed information, + or an empty list if a 'muli' could not be parsed from the input. + + Element `Instruction` is this class. + + Element `parsed_op` is a namespace with the following attributes: + - N (int): Ring size = Log_2(PMD) + - op_name (str): Operation name ("muli") + - dst (list of tuples): List of destinations of the form (variable_name, suggested_bank). + This list has a single element for 'muli'. + - src (list of tuples): List of sources of the form (variable_name, suggested_bank). + This list has a single element for 'muli'. + - imm (str): Name of immediate identifier. This will be replaced by a literal value during linkage. + - res (int): Residual for the operation. + - comment (str): String with the comment attached to the line (empty string if no comment). + """ + retval = None + tokens = XInstruction.tokenizeFromPISALine(cls.OP_NAME_PISA, line) + if tokens: + retval = {"comment": tokens[1]} + instr_tokens = tokens[0] + if len(instr_tokens) > cls._OP_NUM_TOKENS: + warnings.warn(f'Extra tokens detected for instruction "{cls.OP_NAME_PISA}"', SyntaxWarning) + + retval["N"] = int(instr_tokens[0]) + retval["op_name"] = instr_tokens[1] + params_start = 2 + params_end = params_start + cls._OP_NUM_DESTS + cls._OP_NUM_SOURCES + dst_src = cls.parsePISASourceDestsFromTokens(instr_tokens, + cls._OP_NUM_DESTS, + cls._OP_NUM_SOURCES, + params_start) + retval.update(dst_src) + retval["imm"] = instr_tokens[params_end] + retval["res"] = int(instr_tokens[params_end + 1]) + + retval = Namespace(**retval) + assert(retval.op_name == cls.OP_NAME_PISA) + return retval + + @classmethod + def _get_OP_NAME_ASM(cls) -> str: + """ + Returns the operation name in ASM format. + + Returns: + str: The operation name as a string. + """ + return "muli" + + def __init__(self, + id: int, + N: int, + dst: list, + src: list, + imm: str, + res: int, + throughput: int = None, + latency: int = None, + comment: str = ""): + """ + Initializes an Instruction object for a 'muli' operation. + + Parameters: + id (int): The unique identifier for the instruction. + N (int): The ring size. + dst (list): List of destination variables. + src (list): List of source variables. + imm (str): The immediate value identifier. + res (int): The residual for the operation. + throughput (int, optional): The throughput of the instruction. Defaults to the class-level default if not provided. + latency (int, optional): The latency of the instruction. Defaults to the class-level default if not provided. + comment (str, optional): An optional comment for the instruction. + """ + if not throughput: + throughput = Instruction._OP_DEFAULT_THROUGHPUT + if not latency: + latency = Instruction._OP_DEFAULT_LATENCY + + super().__init__(id, N, throughput, latency, res=res, comment=comment) + + self.__imm = imm # (Read-only) immediate + self._set_dests(dst) + self._set_sources(src) + + def __repr__(self): + """ + Returns a string representation of the Instruction object. + + Returns: + str: A string representation of the object. + """ + retval = ('<{}({}) object at {}>(id={}[0], res={}, imm={}, ' + 'dst={}, src={}, ' + 'throughput={}, latency={})').format(type(self).__name__, + self.name, + hex(id(self)), + self.id, + self.res, + self.imm, + self.dests, + self.sources, + self.throughput, + self.latency) + return retval + + @property + def imm(self): + """ + Returns the immediate value identifier. + + Returns: + str: The immediate value identifier. + """ + return self.__imm + + def _set_dests(self, value): + """ + Sets the destination variables for the instruction. + + Parameters: + value (list): A list of Variable objects representing the destinations. + + Raises: + ValueError: If the list does not contain the expected number of Variable objects. + """ + if len(value) != Instruction._OP_NUM_DESTS: + raise ValueError(("`value`: Expected list of {} Variable objects, " + "but list with {} elements received.".format(Instruction._OP_NUM_DESTS, + len(value)))) + if not all(isinstance(x, Variable) for x in value): + raise ValueError("`value`: Expected list of Variable objects.") + super()._set_dests(value) + + def _set_sources(self, value): + """ + Sets the source variables for the instruction. + + Parameters: + value (list): A list of Variable objects representing the sources. + + Raises: + ValueError: If the list does not contain the expected number of Variable objects. + """ + if len(value) != Instruction._OP_NUM_SOURCES: + raise ValueError(("`value`: Expected list of {} Variable objects, " + "but list with {} elements received.".format(Instruction._OP_NUM_SOURCES, + len(value)))) + if not all(isinstance(x, Variable) for x in value): + raise ValueError("`value`: Expected list of Variable objects.") + super()._set_sources(value) + + def _toPISAFormat(self, *extra_args) -> str: + """ + Converts the instruction to kernel format. + + Parameters: + extra_args: Additional arguments (not supported). + + Raises: + ValueError: If extra arguments are provided. + + Returns: + str: The instruction in kernel format as a string. + """ + assert(len(self.dests) == Instruction._OP_NUM_DESTS) + assert(len(self.sources) == Instruction._OP_NUM_SOURCES) + + if extra_args: + raise ValueError('`extra_args` not supported.') + + # N, muli, dst (bank), src0 (bank), imm, res # comment + return super()._toPISAFormat(self.imm) + + def _toXASMISAFormat(self, *extra_args) -> str: + """ + Converts the instruction to ASM format. + + Parameters: + extra_args: Additional arguments (not supported). + + Raises: + ValueError: If extra arguments are provided. + + Returns: + str: The instruction in ASM format as a string. + """ + assert(len(self.dests) == Instruction._OP_NUM_DESTS) + assert(len(self.sources) == Instruction._OP_NUM_SOURCES) + + if extra_args: + raise ValueError('`extra_args` not supported.') + + return super()._toXASMISAFormat(self.imm) \ No newline at end of file diff --git a/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/nop.py b/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/nop.py new file mode 100644 index 00000000..5c55c7f3 --- /dev/null +++ b/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/nop.py @@ -0,0 +1,112 @@ +from .xinstruction import XInstruction + +class Instruction(XInstruction): + """ + Represents a `nop` (no operation) instruction in an assembly language. + + This class handles the representation and conversion of `nop` instructions, + which are used to introduce idle cycles in the execution pipeline. + + For more information, check the specification: + https://github.com/IntelLabs/hec-assembler-tools/blob/master/docsrc/inst_spec/xinst/xinst_nop.md + """ + + @classmethod + def _get_OP_NAME_ASM(cls) -> str: + """ + Returns the operation name in ASM format. + + Returns: + str: The operation name as a string. + """ + return "nop" + + def __init__(self, + id: int, + idle_cycles: int, + comment: str = ""): + """ + Initializes an Instruction object for a 'nop' operation. + + Parameters: + id (int): The unique identifier for the instruction. + idle_cycles (int): The number of idle cycles for the 'nop' operation. + comment (str, optional): An optional comment for the instruction. + """ + N = 0 + # Throughput and latency for `nop` is the number of idle cycles + super().__init__(id, N, idle_cycles, idle_cycles, comment=comment) + + def __repr__(self): + """ + Returns a string representation of the Instruction object. + + Returns: + str: A string representation of the object. + """ + retval = ('<{}({}) object at {}>(id={}[0], ' + 'idle_cycles={})').format(type(self).__name__, + self.name, + hex(id(self)), + self.id, + self.throughput) + return retval + + def _set_dests(self, value): + """ + Raises an error as 'nop' instruction does not have destination parameters. + + Parameters: + value: The value to set as destinations (not used). + + Raises: + RuntimeError: Always raised as 'nop' has no parameters. + """ + raise RuntimeError(f"Instruction `{self.name}` does not have parameters.") + + def _set_sources(self, value): + """ + Raises an error as 'nop' instruction does not have source parameters. + + Parameters: + value: The value to set as sources (not used). + + Raises: + RuntimeError: Always raised as 'nop' has no parameters. + """ + raise RuntimeError(f"Instruction `{self.name}` does not have parameters.") + + def _toPISAFormat(self, *extra_args) -> str: + """ + Indicates that this instruction has no PISA equivalent. + + Parameters: + extra_args: Additional arguments (not used). + + Returns: + None: As there is no PISA equivalent for 'nop'. + """ + return None + + def _toXASMISAFormat(self, *extra_args) -> str: + """ + Converts the instruction to ASM format. + + Parameters: + extra_args: Additional arguments (not supported). + + Raises: + ValueError: If extra arguments are provided. + + Returns: + str: The instruction in ASM format as a string. + """ + assert(len(self.dests) == Instruction._OP_NUM_DESTS) + assert(len(self.sources) == Instruction._OP_NUM_SOURCES) + + if extra_args: + raise ValueError('`extra_args` not supported.') + + # The idle cycles in the ASM ISA for `nop` must be one less because decoding/scheduling + # the instruction counts as a cycle. + return super()._toXASMISAFormat(self.throughput - 1) \ No newline at end of file diff --git a/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/ntt.py b/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/ntt.py new file mode 100644 index 00000000..5b68c9d8 --- /dev/null +++ b/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/ntt.py @@ -0,0 +1,243 @@ +import warnings + +from argparse import Namespace + +from .xinstruction import XInstruction +from assembler.memory_model.variable import Variable + +class Instruction(XInstruction): + """ + Represents an `ntt` (Number Theoretic Transform) instruction in an assembly language. + + This class is responsible for parsing, representing, and converting `ntt` instructions + according to a specific instruction set architecture (ISA) specification. + + For more information, check the `ntt` specification: + https://github.com/IntelLabs/hec-assembler-tools/blob/master/docsrc/inst_spec/xinst/xinst_ntt.md + + """ + + # To be initialized from ASM ISA spec + _OP_NUM_TOKENS: int + + @classmethod + def isa_spec_as_dict(cls) -> dict: + """ + Returns isa_spec attributes as dictionary. + """ + dict = super().isa_spec_as_dict() + dict.update({"num_tokens": cls._OP_NUM_TOKENS}) + return dict + + @classmethod + def SetNumTokens(cls, val): + cls._OP_NUM_TOKENS = val + + @classmethod + def parseFromPISALine(cls, line: str) -> object: + """ + Parses an 'ntt' instruction from a pre-processed Kernel instruction string. + + A preprocessed kernel contains the split implementation of the original NTT into + HERACLES equivalent rshuffle, ntt, twntt. + + Parameters: + line (str): String containing the instruction to parse. + Instruction format: N, ntt, dst_top, dest_bot, src_top, src_bot, src_tw, stage, res # comment + Comment is optional. + + Example line: + "15, ntt, outtmp_9_0 (2), outtmp_9_2 (3), output_9_0 (2), output_9_1 (3), w_gen_17_1 (1), 1, 9" + + Returns: + Namespace: A namespace with the following attributes: + - N (int): Ring size = Log_2(PMD) + - op_name (str): Operation name ("ntt") + - dst (list of tuples): List of destinations of the form (variable_name, suggested_bank). + This list has two elements for 'ntt'. + - src (list of tuples): List of sources of the form (variable_name, suggested_bank). + This list has three elements for 'ntt'. + - stage (int): Stage number of the current NTT instruction. + - res (int): Residual for the operation. + - comment (str): String with the comment attached to the line (empty string if no comment). + + Returns None if an 'ntt' could not be parsed from the input. + """ + retval = None + tokens = XInstruction.tokenizeFromPISALine(cls.OP_NAME_PISA, line) + if tokens: + retval = {"comment": tokens[1]} + instr_tokens = tokens[0] + if len(instr_tokens) > cls._OP_NUM_TOKENS: + warnings.warn(f'Extra tokens detected for instruction "{cls.OP_NAME_PISA}"', SyntaxWarning) + + retval["N"] = int(instr_tokens[0]) + retval["op_name"] = instr_tokens[1] + params_start = 2 + params_end = params_start + cls._OP_NUM_DESTS + cls._OP_NUM_SOURCES + dst_src = cls.parsePISASourceDestsFromTokens(instr_tokens, + cls._OP_NUM_DESTS, + cls._OP_NUM_SOURCES, + params_start) + retval.update(dst_src) + retval["stage"] = int(instr_tokens[params_end]) + retval["res"] = int(instr_tokens[params_end + 1]) + + retval = Namespace(**retval) + assert(retval.op_name == cls.OP_NAME_PISA) + return retval + + @classmethod + def _get_OP_NAME_ASM(cls) -> str: + """ + Returns the operation name in ASM format. + + Returns: + str: The operation name as a string. + """ + return "ntt" + + def __init__(self, + id: int, + N: int, + dst: list, + src: list, + stage: int, + res: int, + throughput: int = None, + latency: int = None, + comment: str = ""): + """ + Initializes an Instruction object. + + Parameters: + id (int): The unique identifier for the instruction. + N (int): The ring size. + dst (list): List of destination variables. + src (list): List of source variables. + stage (int): The stage number of the instruction. + res (int): The residual for the operation. + throughput (int, optional): The throughput of the instruction. Defaults to the class-level default if not provided. + latency (int, optional): The latency of the instruction. Defaults to the class-level default if not provided. + comment (str, optional): An optional comment for the instruction. + """ + if not throughput: + throughput = Instruction._OP_DEFAULT_THROUGHPUT + if not latency: + latency = Instruction._OP_DEFAULT_LATENCY + + super().__init__(id, N, throughput, latency, res=res, comment=comment) + + self.__stage = stage # (Read-only) stage + self._set_dests(dst) + self._set_sources(src) + + def __repr__(self): + """ + Returns a string representation of the Instruction object. + + Returns: + str: A string representation of the object. + """ + retval = ('<{}({}) object at {}>(id={}[0], res={}, ' + 'dst={}, src={}, ' + 'throughput={}, latency={})').format(type(self).__name__, + self.name, + hex(id(self)), + self.id, + self.res, + self.dests, + self.sources, + self.throughput, + self.latency) + return retval + + @property + def stage(self): + """ + Returns the stage of the instruction. + + Returns: + int: The stage as an integer. + """ + return self.__stage + + def _set_dests(self, value): + """ + Sets the destination variables for the instruction. + + Parameters: + value (list): A list of Variable objects representing the destinations. + + Raises: + ValueError: If the list does not contain the expected number of Variable objects. + """ + if len(value) != Instruction._OP_NUM_DESTS: + raise ValueError(("`value`: Expected list of {} Variable objects, " + "but list with {} elements received.".format(Instruction._OP_NUM_DESTS, + len(value)))) + if not all(isinstance(x, Variable) for x in value): + raise ValueError("`value`: Expected list of Variable objects.") + super()._set_dests(value) + + def _set_sources(self, value): + """ + Sets the source variables for the instruction. + + Parameters: + value (list): A list of Variable objects representing the sources. + + Raises: + ValueError: If the list does not contain the expected number of Variable objects. + """ + if len(value) != Instruction._OP_NUM_SOURCES: + raise ValueError(("`value`: Expected list of {} Variable objects, " + "but list with {} elements received.".format(Instruction._OP_NUM_SOURCES, + len(value)))) + if not all(isinstance(x, Variable) for x in value): + raise ValueError("`value`: Expected list of Variable objects.") + super()._set_sources(value) + + def _toPISAFormat(self, *extra_args) -> str: + """ + Converts the instruction to kernel format. + + Parameters: + extra_args: Additional arguments (not supported). + + Raises: + ValueError: If extra arguments are provided. + + Returns: + str: The instruction in kernel format as a string. + """ + if extra_args: + raise ValueError('`extra_args` not supported.') + + assert(len(self.dests) == Instruction._OP_NUM_DESTS) + assert(len(self.sources) == Instruction._OP_NUM_SOURCES) + # N, ntt, dst_top (bank), dest_bot (bank), src_top (bank), src_bot (bank), src_tw (bank), stage, res # comment + retval = super()._toPISAFormat(self.stage) + + return retval + + def _toXASMISAFormat(self, *extra_args) -> str: + """ + Converts the instruction to ASM format. + + Parameters: + extra_args: Additional arguments (not supported). + + Raises: + ValueError: If extra arguments are provided. + + Returns: + str: The instruction in ASM format as a string. + """ + assert(len(self.dests) == Instruction._OP_NUM_DESTS) + assert(len(self.sources) == Instruction._OP_NUM_SOURCES) + + if extra_args: + raise ValueError('`extra_args` not supported.') + + return super()._toXASMISAFormat(self.stage) \ No newline at end of file diff --git a/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/parse_xntt.py b/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/parse_xntt.py new file mode 100644 index 00000000..89e53617 --- /dev/null +++ b/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/parse_xntt.py @@ -0,0 +1,267 @@ +import warnings + +from argparse import Namespace + +from assembler.common import constants +from assembler.instructions import xinst +from assembler.memory_model import MemoryModel + +__xntt_id = 0 + +def parseXNTTKernelLine(line: str, + op_name: str, + tw_separator: str) -> Namespace: + """ + Parses an `xntt` instruction from a P-ISA kernel instruction string. + + Parameters: + line (str): The line containing the instruction to parse. + Instruction format: N, op_name, dst0, dst1, src0, src1, twiddle, res # comment + Comment is optional. + + op_name (str): The operation name that should be contained in the line. + + tw_separator (str): The separator used in the twiddle information. + + Returns: + Namespace: A namespace with the following attributes: + N (int): Ring size = Log_2(PMD) + op_name (str): Operation name + dst (list of tuple): List of destinations of the form (variable_name, suggested_bank). + src (list of tuple): List of sources of the form (variable_name, suggested_bank). + res (int): Residual for the operation. + stage (int): Stage number of the current NTT instruction. + block (int): Index of current word in the 2-words (16KB) polynomial. + comment (str): String with the comment attached to the line (empty string if no comment). + + None: If an `xntt` could not be parsed from the input. + """ + + OP_NUM_DESTS = 2 + OP_NUM_SOURCES = 2 + OP_NUM_TOKENS = 8 + + retval = None + tokens = xinst.XInstruction.tokenizeFromPISALine(op_name, line) + if tokens: + retval = {"comment": tokens[1]} + instr_tokens = tokens[0] + + if len(instr_tokens) > OP_NUM_TOKENS: + warnings.warn(f'Extra tokens detected for instruction "{op_name}"', SyntaxWarning) + + retval["N"] = int(instr_tokens[0]) + retval["op_name"] = instr_tokens[1] + params_start = 2 + params_end = params_start + OP_NUM_DESTS + OP_NUM_SOURCES + dst_src = xinst.XInstruction.parsePISASourceDestsFromTokens(instr_tokens, + OP_NUM_DESTS, + OP_NUM_SOURCES, + params_start) + retval.update(dst_src) + twiddle = instr_tokens[params_end] + retval["res"] = int(instr_tokens[params_end + 1]) + + # Parse twiddle (w___, where "_" is the `tw_separator`) + twiddle_tokens = list(map(lambda s: s.strip(), twiddle.split(tw_separator))) + if len(twiddle_tokens) != 4: + raise ValueError(f'Error parsing twiddle information for "{op_name}" in line "{line}".') + if twiddle_tokens[0] != "w": + raise ValueError(f'Invalid twiddle detected for "{op_name}" in line "{line}".') + if int(twiddle_tokens[1]) != retval["res"]: + raise ValueError(f'Invalid "residual" component detected in twiddle information for "{op_name}" in line "{line}".') + retval["stage"] = int(twiddle_tokens[2]) + retval["block"] = int(twiddle_tokens[3]) + + retval = Namespace(**retval) + assert(retval.op_name == op_name) + return retval + +def __generateRMoveParsedOp(kntt_parsed_op: Namespace) -> (type, Namespace): + """ + Generates a namespace compatible with xrshuffle XInst constructor. + + Parameters: + kntt_parsed_op (Namespace): Parsed xntt object (Namespace). + + Returns: + tuple: A tuple containing the xrshuffle type and a Namespace with the parsed operation. + """ + xrshuffle_type = None + parsed_op = {} + parsed_op["N"] = kntt_parsed_op.N + parsed_op["op_name"] = "" + parsed_op["wait_cyc"] = 0 + parsed_op["dst"] = [] + parsed_op["src"] = [] + parsed_op["comment"] = "" + + if kntt_parsed_op.op_name == xinst.NTT.OP_NAME_PISA: + xrshuffle_type = xinst.rShuffle + parsed_op["dst"] = [d for d in kntt_parsed_op.dst] + elif kntt_parsed_op.op_name == xinst.iNTT.OP_NAME_PISA: + xrshuffle_type = xinst.irShuffle + parsed_op["dst"] = [s for s in kntt_parsed_op.src] + else: + raise ValueError('`kntt_parsed_op`: cannot process operation with name "{}".'.format(kntt_parsed_op.op_name)) + + assert(xrshuffle_type) + + parsed_op["src"] = parsed_op["dst"] + parsed_op["op_name"] = xrshuffle_type.OP_NAME_PISA + + # rshuffle goes above corresponding intt or below corresponding ntt + return xrshuffle_type, Namespace(**parsed_op) + +def __generateTWNTTParsedOp(xntt_parsed_op: Namespace) -> Namespace: + """ + Generates a namespace compatible with twxntt XInst constructor. + + Parameters: + xntt_parsed_op (Namespace): Parsed kernel xntt object (Namespace). + + Returns: + tuple: A tuple containing the twxntt type, a Namespace with the parsed operation, and a tuple with the twiddle variable name and suggested bank. + The twxntt type is None if a twxntt is not needed for the specified xntt. + """ + global __xntt_id # TODO: replace by unique ID once it gets integrated into the P-ISA kernel. + + retval = None + + parsed_op = {} + parsed_op["N"] = xntt_parsed_op.N + parsed_op["op_name"] = 'tw' + str(xntt_parsed_op.op_name) + parsed_op["res"] = xntt_parsed_op.res + parsed_op["stage"] = xntt_parsed_op.stage + parsed_op["block"] = xntt_parsed_op.block + parsed_op["dst"] = [] + parsed_op["src"] = [] + parsed_op["tw_meta"] = 0 + parsed_op["comment"] = "" + + # Find types depending on whether we are doing ntt or intt + twxntt_type = next((t for t in (xinst.twNTT, xinst.twiNTT) if t.OP_NAME_PISA == parsed_op["op_name"]), None) + assert(twxntt_type) + + # Adapted from legacy code add_tw_xntt + #------------------------------------- + + ringsize = int(parsed_op["N"]) + rminustwo = ringsize - 2 + rns_term = int(parsed_op["res"]) + stage = int(parsed_op["stage"]) + + # Generate meta data look-up + meta_rns_term = rns_term % constants.MemoryModel.MAX_RESIDUALS + mdata_word_sel = meta_rns_term >> 1 # 5bit word select + mdata_inword_res_sel = meta_rns_term & 1 + mdata_inword_stage_sel = rminustwo - stage + if twxntt_type == xinst.twiNTT: + mdata_inword_ntt_sel = 1 # Select intt field + else: # xinst.twNTT + mdata_inword_ntt_sel = 0 # Select ntt field + mdata_ptr = (mdata_word_sel << 6) + mdata_ptr |= (mdata_inword_res_sel << 5) + mdata_ptr |= (mdata_inword_ntt_sel << 4) + mdata_ptr |= mdata_inword_stage_sel + + block = int(parsed_op["block"]) + + if rns_term == 0 and stage == 0 and block == 0: + __xntt_id += 1 + + # Generate twiddle variable name + tw_var_name_bank = ("w_gen_{}_{}_{}_{}".format(mdata_inword_ntt_sel, __xntt_id, rns_term, block), 1) + + meta_data_comment = "{} {} ".format(mdata_word_sel, mdata_inword_res_sel) + meta_data_comment += "{} {} w_{}_{}_{}".format(mdata_inword_ntt_sel, mdata_inword_stage_sel, + # hop_list[6] + parsed_op["res"], parsed_op["stage"], parsed_op["block"]) + + parsed_op["dst"] = [tw_var_name_bank] + parsed_op["src"] = [tw_var_name_bank] + parsed_op["tw_meta"] = mdata_ptr + parsed_op["comment"] = meta_data_comment + + if twxntt_type == xinst.twNTT and mdata_ptr >= 0: + # NTT + retval = twxntt_type + elif twxntt_type == xinst.twiNTT and stage <= rminustwo: + # iNTT + # Only add twiddle inst in lower stages + retval = twxntt_type + # else None + + return retval, Namespace(**parsed_op), tw_var_name_bank + +def generateXNTT(mem_model: MemoryModel, + xntt_parsed_op: Namespace, + new_id: int = 0) -> list: + """ + Parses an `xntt` instruction from a P-ISA kernel instruction string. + + Parameters: + mem_model (MemoryModel): The MemoryModel object, where all variables are kept. Variables parsed from the + input string will be automatically added to the memory model if they do not already + exist. The represented object may be modified if addition is needed. + + xntt_parsed_op (Namespace): Namespace of parsed xntt from P-ISA. + + new_id (int, optional): A new ID for the instruction. Defaults to 0. + + Returns: + list: A list of `xinstruction.XInstruction` representing the instructions needed to compute the + parsed xntt. + """ + retval = [] + + # Find xntt type depending on whether we are doing ntt or intt + xntt_type = next((t for t in (xinst.NTT, xinst.iNTT) if t.OP_NAME_PISA == xntt_parsed_op.op_name), None) + if not xntt_type: + raise ValueError('`xntt_parsed_op`: cannot process parsed kernel operation with name "{}".'.format(xntt_parsed_op.op_name)) + + # Generate twiddle instruction + #----------------------------- + + twxntt_type, twxntt_parsed_op, last_twxinput_name = __generateTWNTTParsedOp(xntt_parsed_op) + # print(twxntt_parsed_op) + twxntt_inst = None + if twxntt_type: + twxntt_inst = xinst.createFromParsedObj(mem_model, twxntt_type, twxntt_parsed_op, new_id) + + # Generate corresponding rshuffle + #----------------------------- + + rshuffle_type, rshuffle_parsed_op = __generateRMoveParsedOp(xntt_parsed_op) + rshuffle_parsed_op.comment += (" " + twxntt_parsed_op.comment) if twxntt_parsed_op else "" + rshuffle_inst = xinst.createFromParsedObj(mem_model, rshuffle_type, rshuffle_parsed_op, new_id) + + # Generate xntt instruction + #-------------------------- + + # Prepare arguments for ASM ntt instruction object construction + if twxntt_parsed_op: + assert(twxntt_parsed_op.stage == xntt_parsed_op.stage) + delattr(xntt_parsed_op, "block") + xntt_parsed_op.src.append(last_twxinput_name) + xntt_parsed_op.comment += twxntt_parsed_op.comment if twxntt_parsed_op else "" + + # Create instruction + xntt_inst = xinst.createFromParsedObj(mem_model, xntt_type, xntt_parsed_op, new_id) + + # Add instructions to return list + #-------------------------------- + + retval = [xntt_inst] # xntt + + if xntt_type == xinst.iNTT: # rshuffle + # rshuffle goes above corresponding intt + retval = [rshuffle_inst] + retval + else: + # rshuffle goes below corresponding ntt + retval.append(rshuffle_inst) + + if twxntt_inst: # twiddle + retval.append(twxntt_inst) + + return retval \ No newline at end of file diff --git a/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/rshuffle.py b/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/rshuffle.py new file mode 100644 index 00000000..ff3b7090 --- /dev/null +++ b/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/rshuffle.py @@ -0,0 +1,375 @@ +import warnings +from argparse import Namespace + +from assembler.common.cycle_tracking import CycleType +from assembler.common.decorators import * +from .xinstruction import XInstruction +from assembler.memory_model.variable import Variable +from . import irshuffle + +class Instruction(XInstruction): + """ + Encapsulates an `rshuffle` XInstruction. + + Methods: + SpecialLatency: Returns the special latency for rshuffle instructions. + SpecialLatencyMax: Returns the maximum special latency for rshuffle instructions. + SpecialLatencyIncrement: Returns the increment for special latency for rshuffle instructions. + RSHUFFLE_DATA_TYPE: Returns the data type for rshuffle instructions. + parseFromPISALine: Parses an `rshuffle` instruction from a pre-processed Kernel instruction string. + set_irshuffleGlobalCycleReady: Sets the global cycle ready based on the last irshuffle. + reset_GlobalCycleReady: Resets the global cycle ready for rshuffle and irshuffle instructions. + """ + + # To be initialized from ASM ISA spec + _OP_NUM_TOKENS : int + _OP_RMOVE_LATENCY : int + _OP_RMOVE_LATENCY_MAX: int + _OP_RMOVE_LATENCY_INC: int + + __rshuffle_global_cycle_ready = CycleType(0, 0) # Private class attribute to track cycle ready among rshuffles + __irshuffle_global_cycle_ready = CycleType(0, 0) # Private class attribute to track the cycle ready based on last irshuffle + + @classmethod + def isa_spec_as_dict(cls) -> dict: + """ + Returns isa_spec attributes as dictionary. + """ + dict = super().isa_spec_as_dict() + dict.update({"num_tokens": cls._OP_NUM_TOKENS, + "special_latency_max": cls._OP_RMOVE_LATENCY_MAX, + "special_latency_increment": cls._OP_RMOVE_LATENCY_INC}) + return dict + + @classmethod + def SetNumTokens(cls, val): + cls._OP_NUM_TOKENS = val + + @classmethod + def SetSpecialLatencyMax(cls, val): + cls._OP_RMOVE_LATENCY_MAX = val + + @classmethod + def SetSpecialLatencyIncrement(cls, val): + cls._OP_RMOVE_LATENCY_INC = val + cls._OP_RMOVE_LATENCY = cls._OP_RMOVE_LATENCY_INC + + @classproperty + def SpecialLatency(cls): + """ + Special latency (indicates the first increment at which another rshuffle instruction + can be scheduled within `SpecialLatencyMax` latency). + + Returns: + int: The special latency for rshuffle instructions. + """ + return cls._OP_RMOVE_LATENCY + + @classproperty + def SpecialLatencyMax(cls): + """ + Special latency maximum (cannot enqueue any other rshuffle instruction within this latency + unless it is in `SpecialLatencyIncrement`). + + Returns: + int: The maximum special latency for rshuffle instructions. + """ + return cls._OP_RMOVE_LATENCY_MAX + + @classproperty + def SpecialLatencyIncrement(cls): + """ + Special latency increment (can only enqueue any other rshuffle instruction + within `SpecialLatencyMax` only in increments of this value). + + Returns: + int: The increment for special latency for rshuffle instructions. + """ + return cls._OP_RMOVE_LATENCY_INC + + @classproperty + def RSHUFFLE_DATA_TYPE(cls): + """ + Returns the data type for rshuffle instructions. + + Returns: + str: The data type for rshuffle instructions, which is "ntt". + """ + return "ntt" + + @classmethod + def parseFromPISALine(cls, line: str) -> object: + """ + Parses an `rshuffle` instruction from a pre-processed Kernel instruction string. + + A preprocessed kernel contains the split implementation of original NTT into + HERACLES equivalent rshuffle, ntt, twntt. + + Parameters: + line (str): String containing the instruction to parse. + Instruction format: N, rshuffle, dst0, dst1, src0, src1, wait_cyc # comment + Comment is optional. + + Example line: + "13, rshuffle, outtmp_9_0 (2), outtmp_9_2 (3), outtmp_9_0 (2), outtmp_9_2 (3), 0" + + Returns: + Namespace: A namespace with the following attributes: + N (int): Ring size = Log_2(PMD) + op_name (str): Operation name ("rshuffle") + dst (list of tuple): List of destinations of the form (variable_name, suggested_bank). + This list has two elements for `rshuffle`. + src (list of tuple): List of sources of the form (variable_name, suggested_bank). + This list has two elements for `rshuffle`. + comment (str): String with the comment attached to the line (empty string if no comment). + + None: If an `rshuffle` could not be parsed from the input. + """ + retval = None + tokens = XInstruction.tokenizeFromPISALine(cls.OP_NAME_PISA, line) + if tokens: + retval = {"comment": tokens[1]} + instr_tokens = tokens[0] + if len(instr_tokens) > cls._OP_NUM_TOKENS: + warnings.warn(f'Extra tokens detected for instruction "{cls.OP_NAME_PISA}"', SyntaxWarning) + + retval["N"] = int(instr_tokens[0]) + retval["op_name"] = instr_tokens[1] + params_start = 2 + params_end = params_start + cls._OP_NUM_DESTS + cls._OP_NUM_SOURCES + dst_src = cls.parsePISASourceDestsFromTokens(instr_tokens, + cls._OP_NUM_DESTS, + cls._OP_NUM_SOURCES, + params_start) + retval.update(dst_src) + # Ignore "res", but make sure it exists (syntax) + assert(instr_tokens[params_end] is not None) + + retval = Namespace(**retval) + assert(retval.op_name == cls.OP_NAME_PISA) + return retval + + @classmethod + def _get_OP_NAME_ASM(cls) -> str: + """ + Returns the ASM name of the operation. + + Returns: + str: The name of the operation in ASM format, which is 'rshuffle'. + """ + return "rshuffle" + + def __init__(self, + id: int, + N: int, + dst: list, + src: list, + wait_cyc: int = 0, + throughput: int = None, + latency: int = None, + comment: str = ""): + """ + Constructs a new `rshuffle` XInstruction. + + Parameters: + id (int): User-defined ID for the instruction. It will be bundled with a nonce to form a unique ID. + N (int): Ring size for the operation, Log_2(PMD). + dst (list of Variable): List of destination variables. + src (list of Variable): List of source variables. + wait_cyc (int, optional): The number of wait cycles. Defaults to 0. + throughput (int, optional): The throughput of the instruction. Defaults to the class's default throughput. + latency (int, optional): The latency of the instruction. Defaults to the class's default latency. + comment (str, optional): A comment for the instruction. Defaults to an empty string. + + Raises: + ValueError: If `latency` is less than the special latency for rshuffle instructions. + """ + if not throughput: + throughput = Instruction._OP_DEFAULT_THROUGHPUT + if not latency: + latency = Instruction._OP_DEFAULT_LATENCY + if latency < Instruction._OP_RMOVE_LATENCY: + raise ValueError((f'`latency`: expected a value greater than or equal to ' + '{Instruction._OP_RMOVE_LATENCY}, but {latency} received.')) + + super().__init__(id, N, throughput, latency, comment=comment) + + self.wait_cyc = wait_cyc + self._set_dests(dst) + self._set_sources(src) + + def __repr__(self): + """ + Returns a string representation of the Instruction object. + + Returns: + str: A string representation of the Instruction object, including + its type, name, memory address, ID, destinations, sources, and wait cycles. + """ + retval=('<{}({}) object at {}>(id={}[0], ' + 'dst={}, src={}, ' + 'wait_cyc={})').format(type(self).__name__, + self.name, + hex(id(self)), + self.id, + self.dests, + self.sources, + self.wait_cyc) + return retval + + @classmethod + def __set_rshuffleGlobalCycleReady(cls, value: CycleType): + """ + Sets the global cycle ready for rshuffle instructions. + + Parameters: + value (CycleType): The cycle type value to set. + """ + if (value > cls.__rshuffle_global_cycle_ready): + cls.__rshuffle_global_cycle_ready = value + + @classmethod + def set_irshuffleGlobalCycleReady(cls, value: CycleType): + """ + Sets the global cycle ready based on the last irshuffle. + + Parameters: + value (CycleType): The cycle type value to set. + """ + if (value > cls.__irshuffle_global_cycle_ready): + cls.__irshuffle_global_cycle_ready = value + + @classmethod + def reset_GlobalCycleReady(cls, value=CycleType(0, 0)): + """ + Resets the global cycle ready for rshuffle and irshuffle instructions. + + Parameters: + value (CycleType, optional): The cycle type value to reset to. Defaults to CycleType(0, 0). + """ + cls.__rshuffle_global_cycle_ready = value + cls.__irshuffle_global_cycle_ready = value + + def _set_dests(self, value): + """ + Sets the destination variables for the instruction. + + Parameters: + value (list): A list of `Variable` objects to set as destinations. + + Raises: + ValueError: If the number of destinations is incorrect or if the list does not contain `Variable` objects. + """ + if len(value) != Instruction._OP_NUM_DESTS: + raise ValueError((f"`value`: Expected list of {Instruction._OP_NUM_DESTS} Variable objects, " + "but list with {len(value)} elements received.")) + if not all(isinstance(x, Variable) for x in value): + raise ValueError("`value`: Expected list of Variable objects.") + super()._set_dests(value) + + def _set_sources(self, value): + """ + Sets the source variables for the instruction. + + Parameters: + value (list): A list of `Variable` objects to set as sources. + + Raises: + ValueError: If the number of sources is incorrect or if the list does not contain `Variable` objects. + """ + if len(value) != Instruction._OP_NUM_SOURCES: + raise ValueError((f"`value`: Expected list of {Instruction._OP_NUM_SOURCES} Variable objects, " + "but list with {len(value)} elements received.")) + if not all(isinstance(x, Variable) for x in value): + raise ValueError("`value`: Expected list of Variable objects.") + super()._set_sources(value) + + def _get_cycle_ready(self): + """ + Returns the current value for ready cycle. + + Overrides :func:`BaseInstruction._get_cycle_ready`. + + Returns: + CycleType: The maximum cycle ready among this instruction's sources and the global cycles-ready for other rshuffles and irshuffles. + """ + # This will return the maximum cycle ready among this instruction + # sources and the global cycles-ready for other rshuffles and other irshuffles. + # An rshuffle cannot be within _OP_RMOVE_LATENCY cycles from another rshuffle, + # nor within _OP_DEFAULT_LATENCY cycles from another irshuffle. + return max(super()._get_cycle_ready(), + Instruction.__irshuffle_global_cycle_ready, + Instruction.__rshuffle_global_cycle_ready) + + def _schedule(self, cycle_count: CycleType, schedule_id: int) -> int: + """ + Schedules the instruction, simulating timings of executing this instruction. + + The ready cycle for all destinations is updated based on input `cycle_count` and + this instruction latency. The global `rshuffle` ready cycle is also updated. + + Parameters: + cycle_count (CycleType): Current cycle of execution. + + schedule_id (int): 1-based index for this instruction in its schedule listing. + + Raises: + RuntimeError: The instruction is not ready to execute yet. Based on current cycle, + the instruction is ready to execute if its cycle_ready value is less than or + equal to `cycle_count`. + + Returns: + int: The throughput for this instruction. i.e. the number of cycles by which to advance + the current cycle counter. + """ + original_throughput = super()._schedule(cycle_count, schedule_id) + retval = self.throughput + self.wait_cyc + assert(original_throughput <= retval) + Instruction.__set_rshuffleGlobalCycleReady(CycleType(cycle_count.bundle, cycle_count.cycle + Instruction._OP_RMOVE_LATENCY)) + # Avoid rshuffles and irshuffles in the same bundle + irshuffle.Instruction.set_rshuffleGlobalCycleReady(CycleType(cycle_count.bundle + 1, 0)) + return retval + + def _toPISAFormat(self, *extra_args) -> str: + """ + Converts the instruction to kernel format. + + Parameters: + extra_args: Additional arguments, which are not supported. + + Returns: + str: The instruction in kernel format. + + Raises: + ValueError: If extra arguments are provided. + """ + assert(len(self.dests) == Instruction._OP_NUM_DESTS) + assert(len(self.sources) == Instruction._OP_NUM_SOURCES) + + if extra_args: + raise ValueError('`extra_args` not supported.') + + # N, rshuffle, dst0, dst1, src0, src1, res=0 # comment + return super()._toPISAFormat(0) + + def _toXASMISAFormat(self, *extra_args) -> str: + """ + Converts the instruction to ASM format. + + Parameters: + extra_args: Additional arguments, which are not supported. + + Returns: + str: The instruction in ASM format. + + Raises: + ValueError: If extra arguments are provided. + """ + assert(len(self.dests) == Instruction._OP_NUM_DESTS) + assert(len(self.sources) == Instruction._OP_NUM_SOURCES) + + if extra_args: + raise ValueError('`extra_args` not supported.') + + # id[0], N, op, dst_register0, dst_register1, src_register0, src_register1, wait_cycle, data_type="ntt", res=0 [# comment] + return super()._toXASMISAFormat(self.wait_cyc, self.RSHUFFLE_DATA_TYPE) \ No newline at end of file diff --git a/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/sub.py b/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/sub.py new file mode 100644 index 00000000..3b7bcce3 --- /dev/null +++ b/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/sub.py @@ -0,0 +1,232 @@ +import warnings + +from argparse import Namespace + +from .xinstruction import XInstruction +from assembler.memory_model.variable import Variable + +class Instruction(XInstruction): + """ + Encapsulates a `sub` XInstruction. + + This instruction performs element-wise polynomial subtraction. + + For more information, check the specification: + https://github.com/IntelLabs/hec-assembler-tools/blob/master/docsrc/inst_spec/xinst/xinst_sub.md + + Methods: + parseFromPISALine: Parses a `sub` instruction from a Kernel instruction string. + """ + + # To be initialized from ASM ISA spec + _OP_NUM_TOKENS: int + + @classmethod + def isa_spec_as_dict(cls) -> dict: + """ + Returns isa_spec attributes as dictionary. + """ + dict = super().isa_spec_as_dict() + dict.update({"num_tokens": cls._OP_NUM_TOKENS}) + return dict + + @classmethod + def SetNumTokens(cls, val): + cls._OP_NUM_TOKENS = val + + @classmethod + def parseFromPISALine(cls, line: str) -> list: + """ + Parses a `sub` instruction from a Kernel instruction string. + + Parameters: + line (str): String containing the instruction to parse. + Instruction format: N, sub, dst (bank), src0 (bank), src1 (bank), res # comment + Comment is optional. + + Example line: + "13, sub , output_0_1_3 (2), c_0_1_3 (0), d_0_1_3 (1), 1" + + Returns: + Namespace: A namespace with the following attributes: + N (int): Ring size = Log_2(PMD) + op_name (str): Operation name ("sub") + dst (list of tuple): List of destinations of the form (variable_name, suggested_bank). + This list has a single element for `sub`. + src (list of tuple): List of sources of the form (variable_name, suggested_bank). + This list has two elements for `sub`. + res (int): Residual for the operation. + comment (str): String with the comment attached to the line (empty string if no comment). + + None: If a `sub` could not be parsed from the input. + """ + retval = None + tokens = XInstruction.tokenizeFromPISALine(cls.OP_NAME_PISA, line) + if tokens: + retval = {"comment": tokens[1]} + instr_tokens = tokens[0] + if len(instr_tokens) > cls._OP_NUM_TOKENS: + warnings.warn(f'Extra tokens detected for instruction "{cls.OP_NAME_PISA}"', SyntaxWarning) + + retval["N"] = int(instr_tokens[0]) + retval["op_name"] = instr_tokens[1] + params_start = 2 + params_end = params_start + cls._OP_NUM_DESTS + cls._OP_NUM_SOURCES + dst_src = cls.parsePISASourceDestsFromTokens(instr_tokens, + cls._OP_NUM_DESTS, + cls._OP_NUM_SOURCES, + params_start) + retval.update(dst_src) + retval["res"] = int(instr_tokens[params_end]) + + retval = Namespace(**retval) + assert(retval.op_name == cls.OP_NAME_PISA) + return retval + + @classmethod + def _get_OP_NAME_ASM(cls) -> str: + """ + Returns the ASM name of the operation. + + Returns: + str: The name of the operation in ASM format, which is 'sub'. + """ + return "sub" + + def __init__(self, + id: int, + N: int, + dst: list, + src: list, + res: int, + throughput: int = None, + latency: int = None, + comment: str = ""): + """ + Constructs a new `sub` XInstruction. + + Parameters: + id (int): User-defined ID for the instruction. It will be bundled with a nonce to form a unique ID. + + N (int): Ring size for the operation, Log_2(PMD). + + dst (list of Variable): List of destination variables. + + src (list of Variable): List of source variables. + + res (int): Residual for the operation. + + throughput (int, optional): The throughput of the instruction. Defaults to the class's default throughput. + + latency (int, optional): The latency of the instruction. Defaults to the class's default latency. + + comment (str, optional): A comment for the instruction. Defaults to an empty string. + """ + if not throughput: + throughput = Instruction._OP_DEFAULT_THROUGHPUT + if not latency: + latency = Instruction._OP_DEFAULT_LATENCY + + super().__init__(id, N, throughput, latency, res=res, comment=comment) + + self._set_dests(dst) + self._set_sources(src) + + def __repr__(self): + """ + Returns a string representation of the Instruction object. + + Returns: + str: A string representation of the Instruction object, including + its type, name, memory address, ID, residual, destinations, sources, throughput, and latency. + """ + retval=('<{}({}) object at {}>(id={}[0], res={}, ' + 'dst={}, src={}, ' + 'throughput={}, latency={})').format(type(self).__name__, + self.name, + hex(id(self)), + self.id, + self.res, + self.dests, + self.sources, + self.throughput, + self.latency) + return retval + + def _set_dests(self, value): + """ + Sets the destination variables for the instruction. + + Parameters: + value (list): A list of `Variable` objects to set as destinations. + + Raises: + ValueError: If the number of destinations is incorrect or if the list does not contain `Variable` objects. + """ + if len(value) != Instruction._OP_NUM_DESTS: + raise ValueError(("`value`: Expected list of {} Variable objects, " + "but list with {} elements received.".format(Instruction._OP_NUM_DESTS, + len(value)))) + if not all(isinstance(x, Variable) for x in value): + raise ValueError("`value`: Expected list of Variable objects.") + super()._set_dests(value) + + def _set_sources(self, value): + """ + Sets the source variables for the instruction. + + Parameters: + value (list): A list of `Variable` objects to set as sources. + + Raises: + ValueError: If the number of sources is incorrect or if the list does not contain `Variable` objects. + """ + if len(value) != Instruction._OP_NUM_SOURCES: + raise ValueError(("`value`: Expected list of {} Variable objects, " + "but list with {} elements received.".format(Instruction._OP_NUM_SOURCES, + len(value)))) + if not all(isinstance(x, Variable) for x in value): + raise ValueError("`value`: Expected list of Variable objects.") + super()._set_sources(value) + + def _toPISAFormat(self, *extra_args) -> str: + """ + Converts the instruction to kernel format. + + Parameters: + extra_args: Additional arguments, which are not supported. + + Returns: + str: The instruction in kernel format. + + Raises: + ValueError: If extra arguments are provided. + """ + assert(len(self.dests) == Instruction._OP_NUM_DESTS) + assert(len(self.sources) == Instruction._OP_NUM_SOURCES) + + if extra_args: + raise ValueError('`extra_args` not supported.') + + return super()._toPISAFormat() + + def _toXASMISAFormat(self, *extra_args) -> str: + """ + Converts the instruction to ASM format. + + Parameters: + extra_args: Additional arguments, which are not supported. + + Returns: + str: The instruction in ASM format. + + Raises: + ValueError: If extra arguments are provided. + """ + assert(len(self.dests) == Instruction._OP_NUM_DESTS) + assert(len(self.sources) == Instruction._OP_NUM_SOURCES) + + if extra_args: + raise ValueError('`extra_args` not supported.') + + return super()._toXASMISAFormat() \ No newline at end of file diff --git a/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/twintt.py b/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/twintt.py new file mode 100644 index 00000000..6b14a5fe --- /dev/null +++ b/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/twintt.py @@ -0,0 +1,288 @@ +import warnings + +from argparse import Namespace + +from .xinstruction import XInstruction +from assembler.memory_model.variable import Variable + +class Instruction(XInstruction): + """ + Encapsulates a `twintt` XInstruction. + + This instruction performs on-die generation of twiddle factors for the next stage of iNTT. + + For more information, check the specification: + https://github.com/IntelLabs/hec-assembler-tools/blob/master/docsrc/inst_spec/xinst/xinst_twintt.md + + Attributes: + tw_meta (int): Indexing information of the twiddle metadata. + stage (int): Stage number of the current NTT instruction. + block (int): Index of the current word in the 2-words (16KB) polynomial. + + Methods: + parseFromPISALine: Parses a `twintt` instruction from a pre-processed Kernel instruction string. + """ + + # To be initialized from ASM ISA spec + _OP_NUM_TOKENS: int + + @classmethod + def isa_spec_as_dict(cls) -> dict: + """ + Returns isa_spec attributes as dictionary. + """ + dict = super().isa_spec_as_dict() + dict.update({"num_tokens": cls._OP_NUM_TOKENS}) + return dict + + @classmethod + def SetNumTokens(cls, val): + cls._OP_NUM_TOKENS = val + + @classmethod + def parseFromPISALine(cls, line: str) -> object: + """ + Parses a `twintt` instruction from a pre-processed Kernel instruction string. + + A preprocessed kernel contains the split implementation of original NTT into + HERACLES equivalent irshuffle, intt, twintt. + + Parameters: + line (str): String containing the instruction to parse. + Instruction format: N, twintt, dst_tw, src_tw, tw_meta, stage, block, res # comment + Comment is optional. + + Example line: + "15, twintt, w_gen_17_1 (1), w_gen_17_1 (1), 9, 300, 1, 0" + + Returns: + Namespace: A namespace with the following attributes: + N (int): Ring size = Log_2(PMD) + op_name (str): Operation name ("twintt") + dst (list of tuple): List of destinations of the form (variable_name, suggested_bank). + This list has a single element for `twintt`. + src (list of tuple): List of sources of the form (variable_name, suggested_bank). + This list has a single element for `twintt`. + tw_meta (int): Indexing information of the twiddle metadata. + stage (int): Stage number of the current NTT instruction. + block (int): Index of current word in the 2-words (16KB) polynomial. + res (int): Residual for the operation. + comment (str): String with the comment attached to the line (empty string if no comment). + + None: If a `twintt` could not be parsed from the input. + """ + retval = None + tokens = XInstruction.tokenizeFromPISALine(cls.OP_NAME_PISA, line) + if tokens: + retval = {"comment": tokens[1]} + instr_tokens = tokens[0] + if len(instr_tokens) > cls._OP_NUM_TOKENS: + warnings.warn(f'Extra tokens detected for instruction "{cls.OP_NAME_PISA}"', SyntaxWarning) + + retval["N"] = int(instr_tokens[0]) + retval["op_name"] = instr_tokens[1] + params_start = 2 + params_end = params_start + cls._OP_NUM_DESTS + cls._OP_NUM_SOURCES + dst_src = cls.parsePISASourceDestsFromTokens(instr_tokens, + cls._OP_NUM_DESTS, + cls._OP_NUM_SOURCES, + params_start) + retval.update(dst_src) + retval["tw_meta"] = int(instr_tokens[params_end]) + retval["stage"] = int(instr_tokens[params_end + 1]) + retval["block"] = int(instr_tokens[params_end + 2]) + retval["res"] = int(instr_tokens[params_end + 3]) + + retval = Namespace(**retval) + assert(retval.op_name == cls.OP_NAME_PISA) + return retval + + @classmethod + def _get_OP_NAME_ASM(cls) -> str: + """ + Returns the ASM name of the operation. + + Returns: + str: The name of the operation in ASM format, which is 'twintt'. + """ + return "twintt" + + def __init__(self, + id: int, + N: int, + dst: list, + src: list, + tw_meta: int, + stage: int, + block: int, + res: int, + throughput: int = None, + latency: int = None, + comment: str = ""): + """ + Constructs a new `twintt` XInstruction. + + Parameters: + id (int): User-defined ID for the instruction. It will be bundled with a nonce to form a unique ID. + N (int): Ring size for the operation, Log_2(PMD). + dst (list of Variable): List of destination variables. + src (list of Variable): List of source variables. + tw_meta (int): Indexing information of the twiddle metadata. + stage (int): Stage number of the current NTT instruction. + block (int): Index of the current word in the 2-words (16KB) polynomial. + res (int): Residual for the operation. + throughput (int, optional): The throughput of the instruction. Defaults to the class's default throughput. + latency (int, optional): The latency of the instruction. Defaults to the class's default latency. + comment (str, optional): A comment for the instruction. Defaults to an empty string. + """ + if not throughput: + throughput = Instruction._OP_DEFAULT_THROUGHPUT + if not latency: + latency = Instruction._OP_DEFAULT_LATENCY + + super().__init__(id, N, throughput, latency, res=res, comment=comment) + + self.__tw_meta = tw_meta # (Read-only) tw_meta + self.__stage = stage # (Read-only) stage + self.__block = block # (Read-only) block + self._set_dests(dst) + self._set_sources(src) + + def __repr__(self): + """ + Returns a string representation of the Instruction object. + + Returns: + str: A string representation of the Instruction object, including + its type, name, memory address, ID, residual, tw_meta, stage, block, destinations, sources, throughput, and latency. + """ + retval=('<{}({}) object at {}>(id={}[0], res={}, tw_meta={}, stage={}, block={}, ' + 'dst={}, src={}, ' + 'throughput={}, latency={})').format(type(self).__name__, + self.name, + hex(id(self)), + self.id, + self.res, + self.tw_meta, + self.stage, + self.block, + self.dests, + self.sources, + self.throughput, + self.latency) + return retval + + @property + def tw_meta(self): + """ + Returns the twiddle metadata index. + + Returns: + int: The twiddle metadata index. + """ + return self.__tw_meta + + @property + def stage(self): + """ + Returns the stage number of the current NTT instruction. + + Returns: + int: The stage number. + """ + return self.__stage + + @property + def block(self): + """ + Returns the index of the current word in the 2-words (16KB) polynomial. + + Returns: + int: The block index. + """ + return self.__block + + def _set_dests(self, value): + """ + Sets the destination variables for the instruction. + + Parameters: + value (list): A list of `Variable` objects to set as destinations. + + Raises: + ValueError: If the number of destinations is incorrect or if the list does not contain `Variable` objects. + """ + if len(value) != Instruction._OP_NUM_DESTS: + raise ValueError(("`value`: Expected list of {} Variable objects, " + "but list with {} elements received.".format(Instruction._OP_NUM_DESTS, + len(value)))) + if not all(isinstance(x, Variable) for x in value): + raise ValueError("`value`: Expected list of Variable objects.") + super()._set_dests(value) + + def _set_sources(self, value): + """ + Sets the source variables for the instruction. + + Parameters: + value (list): A list of `Variable` objects to set as sources. + + Raises: + ValueError: If the number of sources is incorrect or if the list does not contain `Variable` objects. + """ + if len(value) != Instruction._OP_NUM_SOURCES: + raise ValueError(("`value`: Expected list of {} Variable objects, " + "but list with {} elements received.".format(Instruction._OP_NUM_SOURCES, + len(value)))) + if not all(isinstance(x, Variable) for x in value): + raise ValueError("`value`: Expected list of Variable objects.") + super()._set_sources(value) + + def _toPISAFormat(self, *extra_args) -> str: + """ + Converts the instruction to kernel format. + + Parameters: + extra_args: Additional arguments, which are not supported. + + Returns: + str: The instruction in kernel format. + + Raises: + ValueError: If extra arguments are provided. + """ + assert(len(self.dests) == Instruction._OP_NUM_DESTS) + assert(len(self.sources) == Instruction._OP_NUM_SOURCES) + + if extra_args: + raise ValueError('`extra_args` not supported.') + + # N, twintt, dst_tw, src_tw, tw_meta, stage, block, res # comment + retval = super()._toPISAFormat(self.tw_meta, + self.stage, + self.block) + return retval + + def _toXASMISAFormat(self, *extra_args) -> str: + """ + Converts the instruction to ASM format. + + Parameters: + extra_args: Additional arguments, which are not supported. + + Returns: + str: The instruction in ASM format. + + Raises: + ValueError: If extra arguments are provided. + """ + assert(len(self.dests) == Instruction._OP_NUM_DESTS) + assert(len(self.sources) == Instruction._OP_NUM_SOURCES) + + if extra_args: + raise ValueError('`extra_args` not supported.') + + return super()._toXASMISAFormat(self.tw_meta, + self.stage, + self.block, + self.N) \ No newline at end of file diff --git a/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/twntt.py b/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/twntt.py new file mode 100644 index 00000000..3494e5b5 --- /dev/null +++ b/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/twntt.py @@ -0,0 +1,298 @@ +import warnings + +from argparse import Namespace + +from .xinstruction import XInstruction +from assembler.memory_model.variable import Variable + +class Instruction(XInstruction): + """ + Encapsulates a `twntt` XInstruction. + + This instruction performs on-die generation of twiddle factors for the next stage of NTT. + + For more information, check the specification: + https://github.com/IntelLabs/hec-assembler-tools/blob/master/docsrc/inst_spec/xinst/xinst_twntt.md + + Attributes: + tw_meta (int): Indexing information of the twiddle metadata. + stage (int): Stage number of the current NTT instruction. + block (int): Index of the current word in the 2-words (16KB) polynomial. + + Methods: + parseFromPISALine: Parses a `twntt` instruction from a pre-processed Kernel instruction string. + """ + + # To be initialized from ASM ISA spec + _OP_NUM_TOKENS: int + + @classmethod + def isa_spec_as_dict(cls) -> dict: + """ + Returns isa_spec attributes as dictionary. + """ + dict = super().isa_spec_as_dict() + dict.update({"num_tokens": cls._OP_NUM_TOKENS}) + return dict + + @classmethod + def SetNumTokens(cls, val): + cls._OP_NUM_TOKENS = val + + @classmethod + def parseFromPISALine(cls, line: str) -> object: + """ + Parses a `twntt` instruction from a pre-processed Kernel instruction string. + + A preprocessed kernel contains the split implementation of original NTT into + HERACLES equivalent rshuffle, ntt, twntt. + + Parameters: + line (str): String containing the instruction to parse. + Instruction format: N, twntt, dst_tw, src_tw, tw_meta, stage, block, res # comment + Comment is optional. + + Example line: + "15, twntt, w_gen_17_1 (1), w_gen_17_1 (1), 9, 300, 1, 0" + + Returns: + Namespace: A namespace with the following attributes: + N (int): Ring size = Log_2(PMD) + op_name (str): Operation name ("twntt") + dst (list of tuple): List of destinations of the form (variable_name, suggested_bank). + This list has a single element for `twntt`. + src (list of tuple): List of sources of the form (variable_name, suggested_bank). + This list has a single element for `twntt`. + tw_meta (int): Indexing information of the twiddle metadata. + stage (int): Stage number of the current NTT instruction. + block (int): Index of current word in the 2-words (16KB) polynomial. + res (int): Residual for the operation. + comment (str): String with the comment attached to the line (empty string if no comment). + + None: If a `twntt` could not be parsed from the input. + """ + retval = None + tokens = XInstruction.tokenizeFromPISALine(cls.OP_NAME_PISA, line) + if tokens: + retval = {"comment": tokens[1]} + instr_tokens = tokens[0] + if len(instr_tokens) > cls._OP_NUM_TOKENS: + warnings.warn(f'Extra tokens detected for instruction "{cls.OP_NAME_PISA}"', SyntaxWarning) + + retval["N"] = int(instr_tokens[0]) + retval["op_name"] = instr_tokens[1] + params_start = 2 + params_end = params_start + cls._OP_NUM_DESTS + cls._OP_NUM_SOURCES + dst_src = cls.parsePISASourceDestsFromTokens(instr_tokens, + cls._OP_NUM_DESTS, + cls._OP_NUM_SOURCES, + params_start) + retval.update(dst_src) + retval["tw_meta"] = int(instr_tokens[params_end]) + retval["stage"] = int(instr_tokens[params_end + 1]) + retval["block"] = int(instr_tokens[params_end + 2]) + retval["res"] = int(instr_tokens[params_end + 3]) + + retval = Namespace(**retval) + assert(retval.op_name == cls.OP_NAME_PISA) + return retval + + @classmethod + def _get_OP_NAME_ASM(cls) -> str: + """ + Returns the ASM name of the operation. + + Returns: + str: The name of the operation in ASM format, which is 'twntt'. + """ + return "twntt" + + def __init__(self, + id: int, + N: int, + dst: list, + src: list, + tw_meta: int, + stage: int, + block: int, + res: int, + throughput: int = None, + latency: int = None, + comment: str = ""): + """ + Constructs a new `twntt` XInstruction. + + Parameters: + id (int): User-defined ID for the instruction. It will be bundled with a nonce to form a unique ID. + + N (int): Ring size for the operation, Log_2(PMD). + + dst (list of Variable): List of destination variables. + + src (list of Variable): List of source variables. + + tw_meta (int): Indexing information of the twiddle metadata. + + stage (int): Stage number of the current NTT instruction. + + block (int): Index of the current word in the 2-words (16KB) polynomial. + + res (int): Residual for the operation. + + throughput (int, optional): The throughput of the instruction. Defaults to the class's default throughput. + + latency (int, optional): The latency of the instruction. Defaults to the class's default latency. + + comment (str, optional): A comment for the instruction. Defaults to an empty string. + """ + if not throughput: + throughput = Instruction._OP_DEFAULT_THROUGHPUT + if not latency: + latency = Instruction._OP_DEFAULT_LATENCY + + super().__init__(id, N, throughput, latency, res=res, comment=comment) + + self.__tw_meta = tw_meta # (Read-only) tw_meta + self.__stage = stage # (Read-only) stage + self.__block = block # (Read-only) block + self._set_dests(dst) + self._set_sources(src) + + def __repr__(self): + """ + Returns a string representation of the Instruction object. + + Returns: + str: A string representation of the Instruction object, including + its type, name, memory address, ID, residual, tw_meta, stage, block, destinations, sources, throughput, and latency. + """ + retval=('<{}({}) object at {}>(id={}[0], res={}, tw_meta={}, stage={}, block={}, ' + 'dst={}, src={}, ' + 'throughput={}, latency={})').format(type(self).__name__, + self.name, + hex(id(self)), + self.id, + self.res, + self.tw_meta, + self.stage, + self.block, + self.dests, + self.sources, + self.throughput, + self.latency) + return retval + + @property + def tw_meta(self): + """ + Returns the twiddle metadata index. + + Returns: + int: The twiddle metadata index. + """ + return self.__tw_meta + + @property + def stage(self): + """ + Returns the stage number of the current NTT instruction. + + Returns: + int: The stage number. + """ + return self.__stage + + @property + def block(self): + """ + Returns the index of the current word in the 2-words (16KB) polynomial. + + Returns: + int: The block index. + """ + return self.__block + + def _set_dests(self, value): + """ + Sets the destination variables for the instruction. + + Parameters: + value (list): A list of `Variable` objects to set as destinations. + + Raises: + ValueError: If the number of destinations is incorrect or if the list does not contain `Variable` objects. + """ + if len(value) != Instruction._OP_NUM_DESTS: + raise ValueError(("`value`: Expected list of {} Variable objects, " + "but list with {} elements received.".format(Instruction._OP_NUM_DESTS, + len(value)))) + if not all(isinstance(x, Variable) for x in value): + raise ValueError("`value`: Expected list of Variable objects.") + super()._set_dests(value) + + def _set_sources(self, value): + """ + Sets the source variables for the instruction. + + Parameters: + value (list): A list of `Variable` objects to set as sources. + + Raises: + ValueError: If the number of sources is incorrect or if the list does not contain `Variable` objects. + """ + if len(value) != Instruction._OP_NUM_SOURCES: + raise ValueError(("`value`: Expected list of {} Variable objects, " + "but list with {} elements received.".format(Instruction._OP_NUM_SOURCES, + len(value)))) + if not all(isinstance(x, Variable) for x in value): + raise ValueError("`value`: Expected list of Variable objects.") + super()._set_sources(value) + + def _toPISAFormat(self, *extra_args) -> str: + """ + Converts the instruction to kernel format. + + Parameters: + extra_args: Additional arguments, which are not supported. + + Returns: + str: The instruction in kernel format. + + Raises: + ValueError: If extra arguments are provided. + """ + assert(len(self.dests) == Instruction._OP_NUM_DESTS) + assert(len(self.sources) == Instruction._OP_NUM_SOURCES) + + if extra_args: + raise ValueError('`extra_args` not supported.') + + # N, twntt, dst_tw, src_tw, tw_meta, stage, block, res # comment + retval = super()._toPISAFormat(self.tw_meta, + self.stage, + self.block) + return retval + + def _toXASMISAFormat(self, *extra_args) -> str: + """ + Converts the instruction to ASM format. + + Parameters: + extra_args: Additional arguments, which are not supported. + + Returns: + str: The instruction in ASM format. + + Raises: + ValueError: If extra arguments are provided. + """ + assert(len(self.dests) == Instruction._OP_NUM_DESTS) + assert(len(self.sources) == Instruction._OP_NUM_SOURCES) + + if extra_args: + raise ValueError('`extra_args` not supported.') + + return super()._toXASMISAFormat(self.tw_meta, + self.stage, + self.block, + self.N) \ No newline at end of file diff --git a/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/xinstruction.py b/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/xinstruction.py new file mode 100644 index 00000000..a2bbce57 --- /dev/null +++ b/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/xinstruction.py @@ -0,0 +1,255 @@ +from argparse import Namespace + +from assembler.common import constants +from assembler.common.cycle_tracking import CycleType +from assembler.common.decorators import * +from assembler.memory_model.variable import Variable +from assembler.memory_model.register_file import Register +from ..instruction import BaseInstruction +from .. import tokenizeFromLine + +class XInstruction(BaseInstruction): + """ + This class is used to encapsulate the properties and behaviors of an xinstruction, + including its throughput, latency, and optional residual. + + Static Methods: + tokenizeFromPISALine: Checks if the specified instruction can be parsed from the specified line and returns the tokenized line. + parsePISASourceDestsFromTokens: Parses the sources and destinations for an instruction from tokens in P-ISA format. + reset_GlobalCycleReady: Resets global cycle tracking for derived classes. + + Methods: + N: Returns the ring size for the operation. + res: Returns the residual for the operation. + """ + + @staticmethod + def tokenizeFromPISALine(op_name: str, line: str) -> list: + """ + Checks whether the specified instruction can be parsed from the specified + line and, if so, returns the tokenized line. + + Parameters: + op_name (str): Name of operation that should be contained in the line. + + line (str): Line to tokenize. + + Returns: + tuple: A tuple containing tokens (tuple of str) and comment (str), or None if the instruction cannot be parsed from the line. + """ + retval = None + tokens, comment = tokenizeFromLine(line) + if len(tokens) > 1 and tokens[1] == op_name: + retval = (tokens, comment) + return retval + + @staticmethod + def parsePISASourceDestsFromTokens(tokens: list, + num_dests: int, + num_sources: int, + offset: int = 0) -> dict: + """ + Parses the sources and destinations for an instruction, given sources and + destinations in tokens in P-ISA format. + + Parameters: + tokens (list of str): List of string tokens where each token corresponds to a destination or + a source for the instruction being parsed, in order. + + num_dests (int): Number of destinations for the instruction. + + num_sources (int): Number of sources for the instruction. + + offset (int, optional): Offset in the list of tokens where to start parsing. Defaults to 0. + + Returns: + dict: A dictionary with, at most, two keys: "src" and "dst", representing the parsed sources + and destinations for the instruction. The value for each key is a list of parsed + `Variable` tuples. + """ + retval = {} + dst_start = offset + dst_end = dst_start + num_dests + dst = [] + for dst_token in tokens[dst_start:dst_end]: + dst.append(Variable.parseFromPISAFormat(dst_token)) + src_start = dst_end + src_end = src_start + num_sources + src = [] + for src_token in tokens[src_start:src_end]: + src.append(Variable.parseFromPISAFormat(src_token)) + if dst: + retval["dst"] = dst + if src: + retval["src"] = src + return retval + + @classmethod + def reset_GlobalCycleReady(cls, value=CycleType(0, 0)): + """ + If derived classes have global cycle tracking, they should override this + method to reset their global cycle tracking when called. + + Parameters: + value (CycleType, optional): The cycle type value to reset to. Defaults to CycleType(0, 0). + """ + pass + + def __init__(self, + id: int, + N: int, + throughput: int, + latency: int, + res: int = None, + comment: str = ""): + """ + Constructs a new XInstruction. + + Parameters: + id (int): User-defined ID for the instruction. It will be bundled with a nonce to form a unique ID. + + N (int): Ring size for the operation, Log2(PMD). Set to `0` if not known. + + throughput (int): The throughput of the instruction. + + latency (int): The latency of the instruction. + + res (int, optional): The residual for the operation. Defaults to None. + + comment (str, optional): A comment for the instruction. Defaults to an empty string. + """ + if res is not None and res >= constants.MemoryModel.MAX_RESIDUALS: + comment = f"res = {res}" + ("; " + comment if comment else "") + super().__init__(id, throughput, latency, comment=comment) + self.__n = N # Read-only ring size for the operation + self.__res = res # Read-only residual + + @property + def N(self) -> int: + """ + Ring size, Log2(PMD). This is just for tracking purposes. Set to `0` if not known. + + Returns: + int: The ring size for the operation. + """ + return self.__n + + @property + def res(self) -> int: + """ + Residual for the operation, or None if operation does not support residuals. + + Returns: + int: The residual for the operation. + """ + return self.__res + + def _schedule(self, cycle_count: CycleType, schedule_id: int) -> int: + """ + Schedules the instruction, simulating timings of executing this instruction. + + The ready cycle for all destinations is updated based on input `cycle_count` and + this instruction latency. + + All variables in the instruction sources and destinations are updated to reflect + the variable access. + + Derived classes can override to add their own simulation rules. + + Parameters: + cycle_count (CycleType): Current cycle of execution. + + schedule_id (int): 1-based index for this instruction in its schedule listing. + + Raises: + RuntimeError: The instruction is not ready to execute yet. Based on current cycle, + the instruction is ready to execute if its cycle_ready value is less than or + equal to `cycle_count`. + + Returns: + int: The throughput for this instruction. i.e. the number of cycles by which to advance + the current cycle counter. + """ + retval = super()._schedule(cycle_count, schedule_id) + + # Update accessed cycle and access instruction of variables + vars = set(v for v in self.sources + self.dests if isinstance(v, Variable)) + for v in vars: + # Check that variable is in register file + if not v.register: + # All variables must be in register before scheduling instruction + raise RuntimeError('Instruction( {}, id={} ): Variable {} not in register file.'.format(self.name, + self.id, + v.name)) + # Update accessed cycle + v.last_x_access = cycle_count + # Remove this instruction from access list + accessed_idx = -1 + for idx, access_element in enumerate(v.accessed_by_xinsts): + if access_element.instruction_id == self.id: + accessed_idx = idx + break + assert(accessed_idx >= 0) + v.accessed_by_xinsts = v.accessed_by_xinsts[:accessed_idx] + v.accessed_by_xinsts[accessed_idx + 1:] + + # Update ready cycle and dirty state of dests + for dst in self.dests: + dst.cycle_ready = CycleType(cycle_count.bundle, cycle_count.cycle + self.latency) + dst.register_dirty = True + + return retval + + def _toPISAFormat(self, *extra_args) -> str: + """ + Converts the instruction to P-ISA kernel format. + + See inherited for more information. + + Parameters: + extra_args: Additional arguments for formatting. + + Returns: + str: The instruction in P-ISA kernel format. + """ + preamble = (self.N,) + extra_args = tuple(src.toPISAFormat() for src in self.sources) + extra_args + extra_args = tuple(dst.toPISAFormat() for dst in self.dests) + extra_args + if self.res is not None: + extra_args += (self.res,) + return self.toStringFormat(preamble, + self.OP_NAME_PISA, + *extra_args) + + def _toXASMISAFormat(self, *extra_args) -> str: + """ + Converts the instruction to ASM-ISA format. + + See inherited for more information. + + Parameters: + extra_args: Additional arguments for formatting. + + Returns: + str: The instruction in ASM-ISA format. + """ + # preamble = (self.id[0], self.N) + preamble = (self.id[0],) + # Instruction sources + extra_args = tuple(src.toXASMISAFormat() for src in self.sources) + extra_args + # Instruction destinations + extra_args = tuple(dst.toXASMISAFormat() for dst in self.dests) + extra_args + if self.res is not None: + extra_args += (self.res % constants.MemoryModel.MAX_RESIDUALS,) + return self.toStringFormat(preamble, + self.OP_NAME_ASM, + *extra_args) + + @classmethod + def _get_OP_NAME_ASM(cls) -> str: + """ + Returns the operation name in ASM format. + + Returns: + str: ASM format operation. + """ + return "default_op" # Provide a default operation name or a meaningful one if applicable \ No newline at end of file diff --git a/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/xstore.py b/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/xstore.py new file mode 100644 index 00000000..e01b1d31 --- /dev/null +++ b/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/xstore.py @@ -0,0 +1,283 @@ +from assembler.common.cycle_tracking import CycleType +from .xinstruction import XInstruction +from assembler.memory_model import MemoryModel +from assembler.memory_model.variable import Variable + +class Instruction(XInstruction): + """ + Encapsulates an `xstore` MInstruction. + + Instruction `xstore` transfers a word from a CE register into the intermediate data + buffer. The intermediate data buffer features a FIFO structure, which means that the + transferred data is pushed at the end of the queue. + + For more information, check the specification: + https://github.com/IntelLabs/hec-assembler-tools/blob/master/docsrc/inst_spec/xinst/xinst_xstore.md + + Attributes: + dest_spad_address (int): The SPAD address where the source variable will be stored. + + Methods: + reset_GlobalCycleReady: Resets the global cycle ready for `xstore` instructions. + """ + + __xstore_global_cycle_ready = CycleType(0, 0) # private class attribute to track cycle ready among xstores + + @classmethod + def _get_OP_NAME_ASM(cls) -> str: + """ + Returns the ASM name of the operation. + + Returns: + str: The name of the operation in ASM format, which is 'xstore'. + """ + return "xstore" + + def __init__(self, + id: int, + src: list, + mem_model: MemoryModel, + dest_spad_addr: int = -1, + throughput: int = None, + latency: int = None, + comment: str = ""): + """ + Constructs a new `xstore` MInstruction. + + Parameters: + id (int): User-defined ID for the instruction. It will be bundled with a nonce to form a unique ID. + + src (list of Variable): A list containing a single Variable object indicating the source variable to store into SPAD. + Variable must be assigned to a register. + Variable `spad_address` must be negative (not assigned) or match the address of the corresponding + `cstore` instruction. + + mem_model (MemoryModel): The memory model used for storing the source variable. + + dest_spad_addr (int, optional): The SPAD address where the source variable will be stored. Defaults to -1. + + throughput (int, optional): The throughput of the instruction. Defaults to the class's default throughput. + + latency (int, optional): The latency of the instruction. Defaults to the class's default latency. + + comment (str, optional): A comment for the instruction. Defaults to an empty string. + + Raises: + ValueError: If `mem_model` is not an instance of `MemoryModel` or if `dest_spad_addr` is invalid. + """ + if not isinstance(mem_model, MemoryModel): + raise ValueError('`mem_model` must be an instance of `MemoryModel`.') + if not throughput: + throughput = Instruction._OP_DEFAULT_THROUGHPUT + if not latency: + latency = Instruction._OP_DEFAULT_LATENCY + N = 0 # Does not require ring-size + super().__init__(id, N, throughput, latency, comment=comment) + self.__mem_model = mem_model + self._set_sources(src) + self.__internal_set_dests(src) + + if dest_spad_addr < 0 and src[0].spad_address < 0: + raise ValueError('`dest_spad_addr` must be a valid SPAD address if source variable is not allocated in SPAD.') + if dest_spad_addr >= 0 and src[0].spad_address >= 0 and dest_spad_addr != src[0].spad_address: + raise ValueError('`dest_spad_addr` must be null SPAD address (negative) if source variable is allocated in SPAD.') + self.dest_spad_address = src[0].spad_address if dest_spad_addr < 0 else dest_spad_addr + + def __repr__(self): + """ + Returns a string representation of the Instruction object. + + Returns: + str: A string representation of the Instruction object, including + its type, name, memory address, ID, source, memory model, destination SPAD address, throughput, and latency. + """ + retval=('<{}({}) object at {}>(id={}[0], ' + 'src={}, mem_model, dest_spad_addr={}, ' + 'throughput={}, latency={})').format(type(self).__name__, + self.name, + hex(id(self)), + self.id, + self.dests, + self.dest_spad_address, + self.throughput, + self.latency) + return retval + + @classmethod + def __set_xstoreGlobalCycleReady(cls, value: CycleType): + """ + Sets the global cycle ready for xstore instructions. + + Parameters: + value (CycleType): The cycle type value to set. + """ + if (value > cls.__xstore_global_cycle_ready): + cls.__xstore_global_cycle_ready = value + + @classmethod + def reset_GlobalCycleReady(cls, value=CycleType(0, 0)): + """ + Resets the global cycle ready for xstore instructions. + + Parameters: + value (CycleType, optional): The cycle type value to reset to. Defaults to CycleType(0, 0). + """ + cls.__xstore_global_cycle_ready = value + + def _set_dests(self, value): + """ + Raises an error as the instruction only supports setting sources. + + Parameters: + value: The value to set as destination, which is not applicable. + + Raises: + RuntimeError: Always raised as the instruction only supports setting sources. + """ + raise RuntimeError(f"Instruction `{self.name}` only supports setting sources.") + + def __internal_set_dests(self, value): + """ + Sets the destination variables for the instruction. + + Parameters: + value (list): A list of `Variable` objects to set as destinations. + + Raises: + ValueError: If the number of destinations is incorrect. + TypeError: If the list does not contain `Variable` objects. + """ + if len(value) != Instruction._OP_NUM_DESTS: + raise ValueError(("`value`: Expected list of {} `Variable` objects, " + "but list with {} elements received.".format(Instruction._OP_NUM_SOURCES, + len(value)))) + if not all(isinstance(x, Variable) for x in value): + raise ValueError("`value`: Expected list of `Variable` objects.") + super()._set_dests(value) + + def _set_sources(self, value): + """ + Sets the source variables for the instruction. + + Parameters: + value (list): A list of `Variable` objects to set as sources. + + Raises: + ValueError: If the number of sources is incorrect. + TypeError: If the list does not contain `Variable` objects. + """ + if len(value) != Instruction._OP_NUM_SOURCES: + raise ValueError(("`value`: Expected list of {} `Variable` objects, " + "but list with {} elements received.".format(Instruction._OP_NUM_SOURCES, + len(value)))) + if not all(isinstance(x, Variable) for x in value): + raise ValueError("`value`: Expected list of `Variable` objects.") + super()._set_sources(value) + self.__internal_set_dests(value) + + def _get_cycle_ready(self): + """ + Returns the current value for ready cycle. + + Overrides :func:`BaseInstruction._get_cycle_ready`. + + Returns: + CycleType: The maximum cycle ready among this instruction's sources and the global cycles-ready for other xstores. + """ + # This will return the maximum cycle ready among this instruction + # sources and the global cycles-ready for other xstores. + # An xstore cannot be within _OP_DEFAULT_LATENCY cycles from another xstore + # because they both use the SPAD-CE data channel. + return max(super()._get_cycle_ready(), + Instruction.__xstore_global_cycle_ready) + + def _schedule(self, cycle_count: CycleType, schedule_id: int) -> int: + """ + Schedules the instruction, simulating timings of executing this instruction. + + Scheduling `xstore` XInst will not cause the involved registers and variables to be + updated. Scheduling a `xstore` should be accompanied with the scheduling of a matching + CInst `cstore` to occur immediately after this `xstore`'s bundle is fetched by `ifetch`. + + Parameters: + cycle_count (CycleType): Current cycle of execution. + + schedule_id (int): 1-based index for this instruction in its schedule listing. + + Raises: + RuntimeError: If the source is not a `Variable` or if the instruction is already scheduled. + See inherited method for more exceptions. + + Returns: + int: The throughput for this instruction. i.e. the number of cycles by which to advance + the current cycle counter. + """ + assert(Instruction._OP_NUM_SOURCES > 0 and len(self.sources) == Instruction._OP_NUM_SOURCES) + assert(Instruction._OP_NUM_DESTS > 0 and len(self.dests) == Instruction._OP_NUM_DESTS) + assert(all(src == dst for src, dst in zip(self.sources, self.dests))) + + if not isinstance(self.sources[0], Variable): + raise RuntimeError('XInstruction ({}, id = {}) already scheduled.'.format(self.name, self.id)) + + store_buffer_item = MemoryModel.StoreBufferValueType(variable=self.sources[0], + dest_spad_address=self.dest_spad_address) + register = self.sources[0].register + retval = super()._schedule(cycle_count, schedule_id) + # Perform xstore + register.register_dirty = False # Register has been flushed + register.allocateVariable(None) + self.sources[0] = register # Make the register the source for freezing, since variable is no longer in it + self.__mem_model.store_buffer[store_buffer_item.variable.name] = store_buffer_item + # Matching CInst cstore completes the xstore + + if self.comment: + self.comment += ';' + self.comment += ' variable "{}": SPAD({}) <- {}'.format(store_buffer_item.variable.name, + store_buffer_item.dest_spad_address, + register.name) + + # Set the global cycle ready for next xstore + Instruction.__set_xstoreGlobalCycleReady(CycleType(cycle_count.bundle, cycle_count.cycle + self.latency)) + return retval + + def _toPISAFormat(self, *extra_args) -> str: + """ + This instruction has no PISA equivalent. + + Returns: + None + """ + return None + + def _toXASMISAFormat(self, *extra_args) -> str: + """ + Converts the instruction to ASM-ISA format. + + Parameters: + extra_args: Variable number of arguments to add before the residual in the resulting string. + + Returns: + str: A string representation of the instruction in ASM-ISA format. The string has the form: + `id[0], N, xstore, dst_spad_addr, src_register, res=0 [# comment]` + Since the residual is mandatory in the format, it is set to `0` in the output if the + instruction does not support residual. + `dst_spad_addr` may be ignored as it is for bookkeeping purposes only. + + Raises: + ValueError: If extra arguments are provided. + """ + assert(len(self.dests) == Instruction._OP_NUM_DESTS) + assert(len(self.sources) == Instruction._OP_NUM_SOURCES) + + if extra_args: + raise ValueError('`extra_args` not supported.') + + preamble = (self.id[0],) + # Instruction sources + extra_args = tuple(src.toXASMISAFormat() for src in self.sources) + extra_args + # Instruction destinations + # extra_args = tuple(dst.toCASMISAFormat() for dst in self.dests) + extra_args + # extra_args += (0,) # res = 0 + return self.toStringFormat(preamble, + self.OP_NAME_ASM, + *extra_args) \ No newline at end of file diff --git a/assembler_tools/hec-assembler-tools/assembler/isa_spec/__init__.py b/assembler_tools/hec-assembler-tools/assembler/isa_spec/__init__.py new file mode 100644 index 00000000..eff3f06a --- /dev/null +++ b/assembler_tools/hec-assembler-tools/assembler/isa_spec/__init__.py @@ -0,0 +1,144 @@ +import os +import json +import assembler.instructions.cinst as cinst +import assembler.instructions.minst as minst +import assembler.instructions.xinst as xinst + +class SpecConfig: + __target_cops = { + "bload" : cinst.bload.Instruction, + "bones" : cinst.bones.Instruction, + "exit" : cinst.cexit.Instruction, + "cload" : cinst.cload.Instruction, + "nop" : cinst.cnop.Instruction, + "cstore" : cinst.cstore.Instruction, + "csyncm" : cinst.csyncm.Instruction, + "ifetch" : cinst.ifetch.Instruction, + "kgload" : cinst.kgload.Instruction, + "kgseed" : cinst.kgseed.Instruction, + "kgstart" : cinst.kgstart.Instruction, + "nload" : cinst.nload.Instruction, + "xinstfetch": cinst.xinstfetch.Instruction, + } + + __target_xops = { + "add" : xinst.add.Instruction, + "copy" : xinst.copy_mod.Instruction, + "exit" : xinst.exit_mod.Instruction, + "intt" : xinst.intt.Instruction, + "irshuffle": xinst.irshuffle.Instruction, + "mac" : xinst.mac.Instruction, + "maci" : xinst.maci.Instruction, + "move" : xinst.move.Instruction, + "mul" : xinst.mul.Instruction, + "muli" : xinst.muli.Instruction, + "nop" : xinst.nop.Instruction, + "ntt" : xinst.ntt.Instruction, + "rshuffle" : xinst.rshuffle.Instruction, + "sub" : xinst.sub.Instruction, + "twintt" : xinst.twintt.Instruction, + "twntt" : xinst.twntt.Instruction, + "xstore" : xinst.xstore.Instruction, + } + + __target_mops = { + "mload" : minst.mload.Instruction, + "mstore": minst.mstore.Instruction, + "msyncc": minst.msyncc.Instruction, + } + + _target_ops = { + "xinst": __target_xops, + "cinst": __target_cops, + "minst": __target_mops + } + + _target_attributes = { + "num_tokens" : "SetNumTokens", + "num_dests" : "SetNumDests", + "num_sources" : "SetNumSources", + "default_throughput" : "SetDefaultThroughput", + "default_latency" : "SetDefaultLatency", + "special_latency_max" : "SetSpecialLatencyMax", + "special_latency_increment": "SetSpecialLatencyIncrement", + } + + @classmethod + def dump_isa_spec_to_json(cls, filename): + """ + Dumps the attributes of all ops' classes as a JSON file under the "isa_spec" section. + + Args: + filename (str): The name of the JSON file to write to. + """ + isa_spec_dict = {} + + for inst_type, ops in cls._target_ops.items(): + isa_spec_dict[inst_type] = {} + + for op_name, op in ops.items(): + # Call the as_dict method to get attributes + class_dict = op.isa_spec_as_dict() + # Store the attributes in the dictionary + isa_spec_dict[inst_type][op_name] = class_dict + + # Wrap the isa_spec_dict in a top-level dictionary + output_dict = {"isa_spec": isa_spec_dict} + + # Write the dictionary to a JSON file + with open(filename, 'w') as json_file: + json.dump(output_dict, json_file, indent=4) + + @classmethod + def init_isa_spec_from_json(cls, filename): + """ + Updates ops' class attributes using methods specified in the target_attributes dictionary based on a JSON file. + This method checks wether values found on json file exists in target dictionaries. + + Args: + filename (str): The name of the JSON file to read from. + """ + with open(filename, 'r') as json_file: + data = json.load(json_file) + + # Check for the "isa_spec" section + if "isa_spec" not in data: + raise ValueError("The JSON file does not contain the 'isa_spec' section.") + + isa_spec = data["isa_spec"] + + for inst_type, ops in cls._target_ops.items(): + if inst_type not in isa_spec: + raise ValueError(f"Instruction type '{inst_type}' is not found in the JSON file.") + + for op_name, op in ops.items(): + if op_name not in isa_spec[inst_type]: + raise ValueError(f"Operation '{op_name}' is not found in the JSON file for instruction type '{inst_type}'.") + + attributes = isa_spec[inst_type][op_name] + + for attr_name, value in attributes.items(): + if attr_name in cls._target_attributes: + method_name = cls._target_attributes[attr_name] + setter = getattr(op, method_name) + setter(value) + else: + raise ValueError(f"Attribute '{attr_name}' is not recognized.") + + @classmethod + def initialize_isa_spec(cls, module_dir, isa_spec_file): + + if not isa_spec_file: + isa_spec_file = os.path.join(module_dir, "config/isa_spec.json") + isa_spec_file = os.path.abspath(isa_spec_file) + + if not os.path.exists(isa_spec_file): + raise FileNotFoundError( + f"Required ISA Spec file not found: {isa_spec_file}\n" + "Please provide a valid path using the `isa_spec` option, " + "or use a valid default file at: `/config/isa_spec.json`." + ) + + cls.init_isa_spec_from_json(isa_spec_file) + + return isa_spec_file diff --git a/assembler_tools/hec-assembler-tools/assembler/isa_spec/cinst/__init__.py b/assembler_tools/hec-assembler-tools/assembler/isa_spec/cinst/__init__.py new file mode 100644 index 00000000..6e827e85 --- /dev/null +++ b/assembler_tools/hec-assembler-tools/assembler/isa_spec/cinst/__init__.py @@ -0,0 +1,638 @@ +from .. import ISASpecInstruction + +class BLoad(ISASpecInstruction): + """ + Represents a `bload` instruction. + + This instruction loads metadata from scratchpad to register file. + + For more information, check the `bload` Specification: + https://github.com/IntelLabs/hec-assembler-tools/blob/master/docsrc/inst_spec/cinst/cinst_bload.md + """ + + @classmethod + def _get_numDests(cls) -> int: + """ + Gets the number of destination operands. + + Returns: + int: The number of destination operands, which is 0. + """ + return 0 + + @classmethod + def _get_numSources(cls) -> int: + """ + Gets the number of source operands. + + Returns: + int: The number of source operands, which is 1. + """ + return 1 + + @classmethod + def _get_throughput(cls) -> int: + """ + Gets the throughput of the instruction. + + Returns: + int: The throughput, which is 1. + """ + return 1 + + @classmethod + def _get_latency(cls) -> int: + """ + Gets the latency of the instruction. + + Returns: + int: The latency, which is 5. + """ + return 5 + +class BOnes(ISASpecInstruction): + """ + Represents a `bones` instruction. + + The `bones` instruction loads metadata of identity (one) from the scratchpad to the register file. + + For more information, check the `bones` Specification: + https://github.com/IntelLabs/hec-assembler-tools/blob/master/docsrc/inst_spec/cinst/cinst_bones.md + """ + + @classmethod + def _get_numDests(cls) -> int: + """ + Gets the number of destination operands. + + Returns: + int: The number of destination operands, which is 0. + """ + return 0 + + @classmethod + def _get_numSources(cls) -> int: + """ + Gets the number of source operands. + + Returns: + int: The number of source operands, which is 1. + """ + return 1 + + @classmethod + def _get_throughput(cls) -> int: + """ + Gets the throughput of the instruction. + + Returns: + int: The throughput, which is 1. + """ + return 1 + + @classmethod + def _get_latency(cls) -> int: + """ + Gets the latency of the instruction. + + Returns: + int: The latency, which is 5. + """ + return 5 + +class Exit(ISASpecInstruction): + """ + Represents an `cexit` instruction. + + This instruction terminates execution of a HERACLES program. + + For more information, check the `cexit` Specification: + https://github.com/IntelLabs/hec-assembler-tools/blob/master/docsrc/inst_spec/cinst/cinst_cexit.md + """ + + @classmethod + def _get_numDests(cls) -> int: + """ + Gets the number of destination operands. + + Returns: + int: The number of destination operands, which is 0. + """ + return 0 + + @classmethod + def _get_numSources(cls) -> int: + """ + Gets the number of source operands. + + Returns: + int: The number of source operands, which is 0. + """ + return 0 + + @classmethod + def _get_throughput(cls) -> int: + """ + Gets the throughput of the instruction. + + Returns: + int: The throughput, which is 1. + """ + return 1 + + @classmethod + def _get_latency(cls) -> int: + """ + Gets the latency of the instruction. + + Returns: + int: The latency, which is 1. + """ + return 1 + +class CLoad(ISASpecInstruction): + """ + Represents a `cload` instruction. + + This instruction loads a single polynomial residue from scratchpad into a register. + + For more information, check the `cload` Specification: + https://github.com/IntelLabs/hec-assembler-tools/blob/master/docsrc/inst_spec/cinst/cinst_cload.md + """ + + @classmethod + def _get_numDests(cls) -> int: + """ + Gets the number of destination operands. + + Returns: + int: The number of destination operands, which is 1. + """ + return 1 + + @classmethod + def _get_numSources(cls) -> int: + """ + Gets the number of source operands. + + Returns: + int: The number of source operands, which is 1. + """ + return 1 + + @classmethod + def _get_throughput(cls) -> int: + """ + Gets the throughput of the instruction. + + Returns: + int: The throughput, which is 4. + """ + return 4 + + @classmethod + def _get_latency(cls) -> int: + """ + Gets the latency of the instruction. + + Returns: + int: The latency, which is 4. + """ + return 4 + +class Nop(ISASpecInstruction): + """ + Represents a `nop` instruction. + + This instruction adds desired amount of idle cycles in the Cfetch flow. + + For more information, check the `nop` Specification: + https://github.com/IntelLabs/hec-assembler-tools/blob/master/docsrc/inst_spec/cinst/cinst_nop.md + """ + + @classmethod + def _get_numDests(cls) -> int: + """ + Gets the number of destination operands. + + Returns: + int: The number of destination operands, which is 0. + """ + return 0 + + @classmethod + def _get_numSources(cls) -> int: + """ + Gets the number of source operands. + + Returns: + int: The number of source operands, which is 0. + """ + return 0 + + @classmethod + def _get_throughput(cls) -> int: + """ + Gets the throughput of the instruction. + + Returns: + int: The throughput, which is 1. + """ + return 1 + + @classmethod + def _get_latency(cls) -> int: + """ + Gets the latency of the instruction. + + Returns: + int: The latency, which is 1. + """ + return 1 + +class CStore(ISASpecInstruction): + """ + Represents a `cstore` instruction. + + This instruction fetchs a single polynomial residue from the intermediate data buffer and store back to SPAD. + + For more information, check the `cstore` Specification: + https://github.com/IntelLabs/hec-assembler-tools/blob/master/docsrc/inst_spec/cinst/cinst_cstore.md + """ + + @classmethod + def _get_numDests(cls) -> int: + """ + Gets the number of destination operands. + + Returns: + int: The number of destination operands, which is 0. + """ + return 0 + + @classmethod + def _get_numSources(cls) -> int: + """ + Gets the number of source operands. + + Returns: + int: The number of source operands, which is 0. + """ + return 0 + + @classmethod + def _get_throughput(cls) -> int: + """ + Gets the throughput of the instruction. + + Returns: + int: The throughput, which is 1. + """ + return 1 + + @classmethod + def _get_latency(cls) -> int: + """ + Gets the latency of the instruction. + + Returns: + int: The latency, which is 5. + """ + return 5 + +class CSyncM(ISASpecInstruction): + """ + Represents a `csyncm` instruction. + + Wait instruction similar to a barrier that stalls the execution of CINST + queue until the specified instruction from MINST queue has completed. + + For more information, check the `csyncm` Specification: + https://github.com/IntelLabs/hec-assembler-tools/blob/master/docsrc/inst_spec/cinst/cinst_csyncm.md + """ + + @classmethod + def _get_numDests(cls) -> int: + """ + Gets the number of destination operands. + + Returns: + int: The number of destination operands, which is 0. + """ + return 0 + + @classmethod + def _get_numSources(cls) -> int: + """ + Gets the number of source operands. + + Returns: + int: The number of source operands, which is 0. + """ + return 0 + + @classmethod + def _get_throughput(cls) -> int: + """ + Gets the throughput of the instruction. + + Returns: + int: The throughput, which is 1. + """ + return 1 + + @classmethod + def _get_latency(cls) -> int: + """ + Gets the latency of the instruction. + + Returns: + int: The latency, which is 1. + """ + return 1 + +class iFetch(ISASpecInstruction): + """ + Represents an `ifetch` instruction. + + This instruction fetchs a bundle of instructions from the XINST queue and send it to the CE for execution. + + For more information, check the `ifetch` Specification: + https://github.com/IntelLabs/hec-assembler-tools/blob/master/docsrc/inst_spec/cinst/cinst_ifetch.md + """ + + @classmethod + def _get_numDests(cls) -> int: + """ + Gets the number of destination operands. + + Returns: + int: The number of destination operands, which is 0. + """ + return 0 + + @classmethod + def _get_numSources(cls) -> int: + """ + Gets the number of source operands. + + Returns: + int: The number of source operands, which is 0. + """ + return 0 + + @classmethod + def _get_throughput(cls) -> int: + """ + Gets the throughput of the instruction. + + Returns: + int: The throughput, which is 1. + """ + return 1 + + @classmethod + def _get_latency(cls) -> int: + """ + Gets the latency of the instruction. + + Returns: + int: The latency, which is 5. + """ + return 5 + +class KGLoad(ISASpecInstruction): + """ + Represents a `kgload` instruction. + """ + + @classmethod + def _get_numDests(cls) -> int: + """ + Gets the number of destination operands. + + Returns: + int: The number of destination operands, which is 1. + """ + return 1 + + @classmethod + def _get_numSources(cls) -> int: + """ + Gets the number of source operands. + + Returns: + int: The number of source operands, which is 0. + """ + return 0 + + @classmethod + def _get_throughput(cls) -> int: + """ + Gets the throughput of the instruction. + + Returns: + int: The throughput, which is 4. + """ + return 4 + + @classmethod + def _get_latency(cls) -> int: + """ + Gets the latency of the instruction. + + Returns: + int: The latency, which is 40. + """ + return 40 + +class KGSeed(ISASpecInstruction): + """ + Represents a `kgseed` instruction. + """ + + @classmethod + def _get_numDests(cls) -> int: + """ + Gets the number of destination operands. + + Returns: + int: The number of destination operands, which is 0. + """ + return 0 + + @classmethod + def _get_numSources(cls) -> int: + """ + Gets the number of source operands. + + Returns: + int: The number of source operands, which is 1. + """ + return 1 + + @classmethod + def _get_throughput(cls) -> int: + """ + Gets the throughput of the instruction. + + Returns: + int: The throughput, which is 1. + """ + return 1 + + @classmethod + def _get_latency(cls) -> int: + """ + Gets the latency of the instruction. + + Returns: + int: The latency, which is 1. + """ + return 1 + +class KGStart(ISASpecInstruction): + """ + Represents a `kgstart` instruction. + """ + + @classmethod + def _get_numDests(cls) -> int: + """ + Gets the number of destination operands. + + Returns: + int: The number of destination operands, which is 0. + """ + return 0 + + @classmethod + def _get_numSources(cls) -> int: + """ + Gets the number of source operands. + + Returns: + int: The number of source operands, which is 0. + """ + return 0 + + @classmethod + def _get_throughput(cls) -> int: + """ + Gets the throughput of the instruction. + + Returns: + int: The throughput, which is 1. + """ + return 1 + + @classmethod + def _get_latency(cls) -> int: + """ + Gets the latency of the instruction. + + Returns: + int: The latency, which is 40. + """ + return 40 + +class NLoad(ISASpecInstruction): + """ + Represents a `nload` instruction. + + This instruction loads metadata (for NTT/iNTT routing mapping) from + scratchpad into a special routing table register. + + For more information, check the `nload` Specification: + https://github.com/IntelLabs/hec-assembler-tools/blob/master/docsrc/inst_spec/cinst/cinst_nload.md + """ + + @classmethod + def _get_numDests(cls) -> int: + """ + Gets the number of destination operands. + + Returns: + int: The number of destination operands, which is 0. + """ + return 0 + + @classmethod + def _get_numSources(cls) -> int: + """ + Gets the number of source operands. + + Returns: + int: The number of source operands, which is 1. + """ + return 1 + + @classmethod + def _get_throughput(cls) -> int: + """ + Gets the throughput of the instruction. + + Returns: + int: The throughput, which is 4. + """ + return 4 + + @classmethod + def _get_latency(cls) -> int: + """ + Gets the latency of the instruction. + + Returns: + int: The latency, which is 4. + """ + return 4 + +class XInstFetch(ISASpecInstruction): + """ + Represents an `xinstfetch` instruction. + + Fetches instructions from the HBM and sends it to the XINST queue. + + For more information, check the `xinstfetch` Specification: + https://github.com/IntelLabs/hec-assembler-tools/blob/master/docsrc/inst_spec/cinst/cinst_xinstfetch.md + """ + + @classmethod + def _get_numDests(cls) -> int: + """ + Gets the number of destination operands. + + Returns: + int: The number of destination operands, which is 0. + """ + return 0 + + @classmethod + def _get_numSources(cls) -> int: + """ + Gets the number of source operands. + + Returns: + int: The number of source operands, which is 0. + """ + return 0 + + @classmethod + def _get_throughput(cls) -> int: + """ + Gets the throughput of the instruction. + + Returns: + int: The throughput, which is 1. + """ + return 1 + + @classmethod + def _get_latency(cls) -> int: + """ + Gets the latency of the instruction. + + Returns: + int: The latency, which is 1. + """ + return 1 \ No newline at end of file diff --git a/assembler_tools/hec-assembler-tools/assembler/isa_spec/minst/__init__.py b/assembler_tools/hec-assembler-tools/assembler/isa_spec/minst/__init__.py new file mode 100644 index 00000000..463df89b --- /dev/null +++ b/assembler_tools/hec-assembler-tools/assembler/isa_spec/minst/__init__.py @@ -0,0 +1,152 @@ +from .. import ISASpecInstruction + +class MLoad(ISASpecInstruction): + """ + Represents an `mload` instruction, inheriting from ISASpecInstruction. + + This instruction loads a single polynomial residue from local memory to scratchpad. + + For more information, check the specification: + https://github.com/IntelLabs/hec-assembler-tools/blob/master/docsrc/inst_spec/minst/minst_mload.md + """ + + @classmethod + def _get_numDests(cls) -> int: + """ + Gets the number of destination operands for the instruction. + + Returns: + int: The number of destination operands, which is 1. + """ + return 1 + + @classmethod + def _get_numSources(cls) -> int: + """ + Gets the number of source operands for the instruction. + + Returns: + int: The number of source operands, which is 1. + """ + return 1 + + @classmethod + def _get_throughput(cls) -> int: + """ + Gets the throughput of the instruction. + + Returns: + int: The throughput, which is 1. + """ + return 1 + + @classmethod + def _get_latency(cls) -> int: + """ + Gets the latency of the instruction. + + Returns: + int: The latency, which is 1. + """ + return 1 + +class MStore(ISASpecInstruction): + """ + Represents an `mstore` instruction, inheriting from ISASpecInstruction. + + This instruction stores a single polynomial residue from scratchpad to local memory. + + For more information, check the specification: + https://github.com/IntelLabs/hec-assembler-tools/blob/master/docsrc/inst_spec/minst/minst_mstore.md + """ + + @classmethod + def _get_numDests(cls) -> int: + """ + Gets the number of destination operands for the instruction. + + Returns: + int: The number of destination operands, which is 1. + """ + return 1 + + @classmethod + def _get_numSources(cls) -> int: + """ + Gets the number of source operands for the instruction. + + Returns: + int: The number of source operands, which is 1. + """ + return 1 + + @classmethod + def _get_throughput(cls) -> int: + """ + Gets the throughput of the instruction. + + Returns: + int: The throughput, which is 1. + """ + return 1 + + @classmethod + def _get_latency(cls) -> int: + """ + Gets the latency of the instruction. + + Returns: + int: The latency, which is 1. + """ + return 1 + +class MSyncC(ISASpecInstruction): + """ + Represents an MSyncC instruction, inheriting from ISASpecInstruction. + + Wait instruction similar to a barrier that stalls the execution of MINST + queue until the specified instruction from CINST queue has completed. + + For more information, check the specification: + https://github.com/IntelLabs/hec-assembler-tools/blob/master/docsrc/inst_spec/minst/minst_msyncc.md + """ + + @classmethod + def _get_numDests(cls) -> int: + """ + Gets the number of destination operands for the instruction. + + Returns: + int: The number of destination operands, which is 0. + """ + return 0 + + @classmethod + def _get_numSources(cls) -> int: + """ + Gets the number of source operands for the instruction. + + Returns: + int: The number of source operands, which is 0. + """ + return 0 + + @classmethod + def _get_throughput(cls) -> int: + """ + Gets the throughput of the instruction. + + Returns: + int: The throughput, which is 1. + """ + return 1 + + @classmethod + def _get_latency(cls) -> int: + """ + Gets the latency of the instruction. + + Returns: + int: The latency, which is 1. + """ + return 1 \ No newline at end of file diff --git a/assembler_tools/hec-assembler-tools/assembler/isa_spec/xinst/__init__.py b/assembler_tools/hec-assembler-tools/assembler/isa_spec/xinst/__init__.py new file mode 100644 index 00000000..0fd3e554 --- /dev/null +++ b/assembler_tools/hec-assembler-tools/assembler/isa_spec/xinst/__init__.py @@ -0,0 +1,919 @@ +from assembler.common.decorators import * +from .. import ISASpecInstruction + +class Add(ISASpecInstruction): + """ + Represents an `add` instruction. + + This instructions adds two polynomials stored in the register file and + store the result in a register. + + For more information, check the specification: + https://github.com/IntelLabs/hec-assembler-tools/blob/master/docsrc/inst_spec/xinst/xinst_add.md + """ + + @classmethod + def _get_numDests(cls) -> int: + """ + Gets the number of destination operands. + + Returns: + int: The number of destination operands, which is 1. + """ + return 1 + + @classmethod + def _get_numSources(cls) -> int: + """ + Gets the number of source operands. + + Returns: + int: The number of source operands, which is 2. + """ + return 2 + + @classmethod + def _get_throughput(cls) -> int: + """ + Gets the throughput of the instruction. + + Returns: + int: The throughput, which is 1. + """ + return 1 + + @classmethod + def _get_latency(cls) -> int: + """ + Gets the latency of the instruction. + + Returns: + int: The latency, which is 6. + """ + return 6 + +class Copy(ISASpecInstruction): + """ + Represents a Copy instruction. + """ + + @classmethod + def _get_numDests(cls) -> int: + """ + Gets the number of destination operands. + + Returns: + int: The number of destination operands, which is 1. + """ + return 1 + + @classmethod + def _get_numSources(cls) -> int: + """ + Gets the number of source operands. + + Returns: + int: The number of source operands, which is 1. + """ + return 1 + + @classmethod + def _get_throughput(cls) -> int: + """ + Gets the throughput of the instruction. + + Returns: + int: The throughput, which is 1. + """ + return 1 + + @classmethod + def _get_latency(cls) -> int: + """ + Gets the latency of the instruction. + + Returns: + int: The latency, which is 6. + """ + return 6 + +class Exit(ISASpecInstruction): + """ + Represents an `exit` instruction. + + This instruction terminates execution of an instruction bundle. + + For more information, check the specification: + https://github.com/IntelLabs/hec-assembler-tools/blob/master/docsrc/inst_spec/xinst/xinst_exit.md + """ + + @classmethod + def _get_numDests(cls) -> int: + """ + Gets the number of destination operands. + + Returns: + int: The number of destination operands, which is 0. + """ + return 0 + + @classmethod + def _get_numSources(cls) -> int: + """ + Gets the number of source operands. + + Returns: + int: The number of source operands, which is 0. + """ + return 0 + + @classmethod + def _get_throughput(cls) -> int: + """ + Gets the throughput of the instruction. + + Returns: + int: The throughput, which is 1. + """ + return 1 + + @classmethod + def _get_latency(cls) -> int: + """ + Gets the latency of the instruction. + + Returns: + int: The latency, which is 1. + """ + return 1 + +class iNTT(ISASpecInstruction): + """ + Represents an `intt` instruction. + + The Inverse Number Theoretic Transform (iNTT), converts NTT form to positional form. + + For more information, check the specification: + https://github.com/IntelLabs/hec-assembler-tools/blob/master/docsrc/inst_spec/xinst/xinst_intt.md + """ + + @classmethod + def _get_numDests(cls) -> int: + """ + Gets the number of destination operands. + + Returns: + int: The number of destination operands, which is 2. + """ + return 2 + + @classmethod + def _get_numSources(cls) -> int: + """ + Gets the number of source operands. + + Returns: + int: The number of source operands, which is 3. + """ + return 3 + + @classmethod + def _get_throughput(cls) -> int: + """ + Gets the throughput of the instruction. + + Returns: + int: The throughput, which is 1. + """ + return 1 + + @classmethod + def _get_latency(cls) -> int: + """ + Gets the latency of the instruction. + + Returns: + int: The latency, which is 6. + """ + return 6 + +class irShuffle(ISASpecInstruction): + """ + Represents an irShuffle instruction with special latency properties. + + Properties: + SpecialLatency: Indicates the first increment at which another irshuffle instruction + can be scheduled within `SpecialLatencyMax` latency. + SpecialLatencyMax: Cannot enqueue any other irshuffle instruction within this latency + unless it is in `SpecialLatencyIncrement`. + SpecialLatencyIncrement: Can only enqueue any other irshuffle instruction + within `SpecialLatencyMax` only in increments of this value. + """ + + @classproperty + def SpecialLatency(cls): + """ + Special latency (indicates the first increment at which another irshuffle instruction + can be scheduled within `SpecialLatencyMax` latency). + + Returns: + int: The special latency increment. + """ + return cls.SpecialLatencyIncrement + + @classproperty + def SpecialLatencyMax(cls): + """ + Special latency maximum (cannot enqueue any other irshuffle instruction within this latency + unless it is in `SpecialLatencyIncrement`). + + Returns: + int: The special latency maximum, which is 17. + """ + return 17 + + @classproperty + def SpecialLatencyIncrement(cls): + """ + Special latency increment (can only enqueue any other irshuffle instruction + within `SpecialLatencyMax` only in increments of this value). + + Returns: + int: The special latency increment, which is 5. + """ + return 5 + + @classmethod + def _get_numDests(cls) -> int: + """ + Gets the number of destination operands. + + Returns: + int: The number of destination operands, which is 2. + """ + return 2 + + @classmethod + def _get_numSources(cls) -> int: + """ + Gets the number of source operands. + + Returns: + int: The number of source operands, which is 2. + """ + return 2 + + @classmethod + def _get_throughput(cls) -> int: + """ + Gets the throughput of the instruction. + + Returns: + int: The throughput, which is 1. + """ + return 1 + + @classmethod + def _get_latency(cls) -> int: + """ + Gets the latency of the instruction. + + Returns: + int: The latency, which is 23. + """ + return 23 + +class Mac(ISASpecInstruction): + """ + Represents a `mac` instruction. + + Element-wise polynomial multiplication and accumulation. + + For more information, check the specification: + https://github.com/IntelLabs/hec-assembler-tools/blob/master/docsrc/inst_spec/xinst/xinst_mac.md + """ + + @classmethod + def _get_numDests(cls) -> int: + """ + Gets the number of destination operands. + + Returns: + int: The number of destination operands, which is 1. + """ + return 1 + + @classmethod + def _get_numSources(cls) -> int: + """ + Gets the number of source operands. + + Returns: + int: The number of source operands, which is 2. + """ + return 2 + + @classmethod + def _get_throughput(cls) -> int: + """ + Gets the throughput of the instruction. + + Returns: + int: The throughput, which is 1. + """ + return 1 + + @classmethod + def _get_latency(cls) -> int: + """ + Gets the latency of the instruction. + + Returns: + int: The latency, which is 6. + """ + return 6 + +class Maci(ISASpecInstruction): + """ + Represents a `maci` instruction. + + Element-wise polynomial scaling by an immediate value and accumulation. + + For more information, check the specification: + https://github.com/IntelLabs/hec-assembler-tools/blob/master/docsrc/inst_spec/xinst/xinst_maci.md + """ + + @classmethod + def _get_numDests(cls) -> int: + """ + Gets the number of destination operands. + + Returns: + int: The number of destination operands, which is 1. + """ + return 1 + + @classmethod + def _get_numSources(cls) -> int: + """ + Gets the number of source operands. + + Returns: + int: The number of source operands, which is 1. + """ + return 1 + + @classmethod + def _get_throughput(cls) -> int: + """ + Gets the throughput of the instruction. + + Returns: + int: The throughput, which is 1. + """ + return 1 + + @classmethod + def _get_latency(cls) -> int: + """ + Gets the latency of the instruction. + + Returns: + int: The latency, which is 6. + """ + return 6 + +class Move(ISASpecInstruction): + """ + Represents a `move` instruction. + + This instruction copies data from one register to a different one. + + For more information, check the specification: + https://github.com/IntelLabs/hec-assembler-tools/blob/master/docsrc/inst_spec/xinst/xinst_move.md + """ + + @classmethod + def _get_numDests(cls) -> int: + """ + Gets the number of destination operands. + + Returns: + int: The number of destination operands, which is 1. + """ + return 1 + + @classmethod + def _get_numSources(cls) -> int: + """ + Gets the number of source operands. + + Returns: + int: The number of source operands, which is 1. + """ + return 1 + + @classmethod + def _get_throughput(cls) -> int: + """ + Gets the throughput of the instruction. + + Returns: + int: The throughput, which is 1. + """ + return 1 + + @classmethod + def _get_latency(cls) -> int: + """ + Gets the latency of the instruction. + + Returns: + int: The latency, which is 6. + """ + return 6 + +class Mul(ISASpecInstruction): + """ + Represents a `mul` instruction. + + This instructions performs element-wise polynomial multiplication. + + For more information, check the specification: + https://github.com/IntelLabs/hec-assembler-tools/blob/master/docsrc/inst_spec/xinst/xinst_mul.md + """ + + @classmethod + def _get_numDests(cls) -> int: + """ + Gets the number of destination operands. + + Returns: + int: The number of destination operands, which is 1. + """ + return 1 + + @classmethod + def _get_numSources(cls) -> int: + """ + Gets the number of source operands. + + Returns: + int: The number of source operands, which is 2. + """ + return 2 + + @classmethod + def _get_throughput(cls) -> int: + """ + Gets the throughput of the instruction. + + Returns: + int: The throughput, which is 1. + """ + return 1 + + @classmethod + def _get_latency(cls) -> int: + """ + Gets the latency of the instruction. + + Returns: + int: The latency, which is 6. + """ + return 6 + +class Muli(ISASpecInstruction): + """ + Represents a Muli instruction. + + This instruction performs element-wise polynomial scaling by an immediate value. + + For more information, check the specification: + https://github.com/IntelLabs/hec-assembler-tools/blob/master/docsrc/inst_spec/xinst/xinst_muli.md + """ + + @classmethod + def _get_numDests(cls) -> int: + """ + Gets the number of destination operands. + + Returns: + int: The number of destination operands, which is 1. + """ + return 1 + + @classmethod + def _get_numSources(cls) -> int: + """ + Gets the number of source operands. + + Returns: + int: The number of source operands, which is 1. + """ + return 1 + + @classmethod + def _get_throughput(cls) -> int: + """ + Gets the throughput of the instruction. + + Returns: + int: The throughput, which is 1. + """ + return 1 + + @classmethod + def _get_latency(cls) -> int: + """ + Gets the latency of the instruction. + + Returns: + int: The latency, which is 6. + """ + return 6 + +class Nop(ISASpecInstruction): + """ + Represents a `nop` instruction. + + This instruction adds a desired amount of idle cycles to the compute flow. + + For more information, check the specification: + https://github.com/IntelLabs/hec-assembler-tools/blob/master/docsrc/inst_spec/xinst/xinst_nop.md + """ + + @classmethod + def _get_numDests(cls) -> int: + """ + Gets the number of destination operands. + + Returns: + int: The number of destination operands, which is 0. + """ + return 0 + + @classmethod + def _get_numSources(cls) -> int: + """ + Gets the number of source operands. + + Returns: + int: The number of source operands, which is 0. + """ + return 0 + + @classmethod + def _get_throughput(cls) -> int: + """ + Gets the throughput of the instruction. + + Returns: + int: The throughput, which is 1. + """ + return 1 + + @classmethod + def _get_latency(cls) -> int: + """ + Gets the latency of the instruction. + + Returns: + int: The latency, which is 1. + """ + return 1 + +class NTT(ISASpecInstruction): + """ + Represents an `ntt` instruction (Number Theoretic Transform). + Converts positional form to NTT form. + + For more information, check the specification: + https://github.com/IntelLabs/hec-assembler-tools/blob/master/docsrc/inst_spec/xinst/xinst_ntt.md + """ + + @classmethod + def _get_numDests(cls) -> int: + """ + Gets the number of destination operands. + + Returns: + int: The number of destination operands, which is 2. + """ + return 2 + + @classmethod + def _get_numSources(cls) -> int: + """ + Gets the number of source operands. + + Returns: + int: The number of source operands, which is 3. + """ + return 3 + + @classmethod + def _get_throughput(cls) -> int: + """ + Gets the throughput of the instruction. + + Returns: + int: The throughput, which is 1. + """ + return 1 + + @classmethod + def _get_latency(cls) -> int: + """ + Gets the latency of the instruction. + + Returns: + int: The latency, which is 6. + """ + return 6 + +class rShuffle(ISASpecInstruction): + """ + Represents an rShuffle instruction with special latency properties. + + Properties: + SpecialLatency: Indicates the first increment at which another rshuffle instruction + can be scheduled within `SpecialLatencyMax` latency. + SpecialLatencyMax: Cannot enqueue any other rshuffle instruction within this latency + unless it is in `SpecialLatencyIncrement`. + SpecialLatencyIncrement: Can only enqueue any other rshuffle instruction + within `SpecialLatencyMax` only in increments of this value. + """ + + @classproperty + def SpecialLatency(cls): + """ + Special latency (indicates the first increment at which another rshuffle instruction + can be scheduled within `SpecialLatencyMax` latency). + + Returns: + int: The special latency increment. + """ + return cls.SpecialLatencyIncrement + + @classproperty + def SpecialLatencyMax(cls): + """ + Special latency maximum (cannot enqueue any other rshuffle instruction within this latency + unless it is in `SpecialLatencyIncrement`). + + Returns: + int: The special latency maximum, which is 17. + """ + return 17 + + @classproperty + def SpecialLatencyIncrement(cls): + """ + Special latency increment (can only enqueue any other rshuffle instruction + within `SpecialLatencyMax` only in increments of this value). + + Returns: + int: The special latency increment, which is 5. + """ + return 5 + + @classmethod + def _get_numDests(cls) -> int: + """ + Gets the number of destination operands. + + Returns: + int: The number of destination operands, which is 2. + """ + return 2 + + @classmethod + def _get_numSources(cls) -> int: + """ + Gets the number of source operands. + + Returns: + int: The number of source operands, which is 2. + """ + return 2 + + @classmethod + def _get_throughput(cls) -> int: + """ + Gets the throughput of the instruction. + + Returns: + int: The throughput, which is 1. + """ + return 1 + + @classmethod + def _get_latency(cls) -> int: + """ + Gets the latency of the instruction. + + Returns: + int: The latency, which is 23. + """ + return 23 + +class Sub(ISASpecInstruction): + """ + Represents a `sub` instruction. + + Element-wise polynomial subtraction. + + For more information, check the specification: + https://github.com/IntelLabs/hec-assembler-tools/blob/master/docsrc/inst_spec/xinst/xinst_sub.md + """ + + @classmethod + def _get_numDests(cls) -> int: + """ + Gets the number of destination operands. + + Returns: + int: The number of destination operands, which is 1. + """ + return 1 + + @classmethod + def _get_numSources(cls) -> int: + """ + Gets the number of source operands. + + Returns: + int: The number of source operands, which is 2. + """ + return 2 + + @classmethod + def _get_throughput(cls) -> int: + """ + Gets the throughput of the instruction. + + Returns: + int: The throughput, which is 1. + """ + return 1 + + @classmethod + def _get_latency(cls) -> int: + """ + Gets the latency of the instruction. + + Returns: + int: The latency, which is 6. + """ + return 6 + +class twiNTT(ISASpecInstruction): + """ + Represents a `twintt` instruction. + + This instruction performs on-die generation of twiddle factors for the next stage of iNTT. + + For more information, check the specification: + https://github.com/IntelLabs/hec-assembler-tools/blob/master/docsrc/inst_spec/xinst/xinst_twintt.md + """ + + @classmethod + def _get_numDests(cls) -> int: + """ + Gets the number of destination operands. + + Returns: + int: The number of destination operands, which is 1. + """ + return 1 + + @classmethod + def _get_numSources(cls) -> int: + """ + Gets the number of source operands. + + Returns: + int: The number of source operands, which is 1. + """ + return 1 + + @classmethod + def _get_throughput(cls) -> int: + """ + Gets the throughput of the instruction. + + Returns: + int: The throughput, which is 1. + """ + return 1 + + @classmethod + def _get_latency(cls) -> int: + """ + Gets the latency of the instruction. + + Returns: + int: The latency, which is 6. + """ + return 6 + +class twNTT(ISASpecInstruction): + """ + Represents a `twntt` instruction. + + This instruction performs on-die generation of twiddle factors for the next stage of NTT. + + For more information, check the specification: + https://github.com/IntelLabs/hec-assembler-tools/blob/master/docsrc/inst_spec/xinst/xinst_twntt.md + """ + + @classmethod + def _get_numDests(cls) -> int: + """ + Gets the number of destination operands. + + Returns: + int: The number of destination operands, which is 1. + """ + return 1 + + @classmethod + def _get_numSources(cls) -> int: + """ + Gets the number of source operands. + + Returns: + int: The number of source operands, which is 1. + """ + return 1 + + @classmethod + def _get_throughput(cls) -> int: + """ + Gets the throughput of the instruction. + + Returns: + int: The throughput, which is 1. + """ + return 1 + + @classmethod + def _get_latency(cls) -> int: + """ + Gets the latency of the instruction. + + Returns: + int: The latency, which is 6. + """ + return 6 + +class xStore(ISASpecInstruction): + """ + Represents an `xstore` instruction. + + This instruction transfers data from a register into the intermediate data buffer for subsequent transfer into SPAD. + + For more information, check the specification: + https://github.com/IntelLabs/hec-assembler-tools/blob/master/docsrc/inst_spec/xinst/xinst_xstore.md + """ + + @classmethod + def _get_numDests(cls) -> int: + """ + Gets the number of destination operands. + + Returns: + int: The number of destination operands, which is 1. + """ + return 1 + + @classmethod + def _get_numSources(cls) -> int: + """ + Gets the number of source operands. + + Returns: + int: The number of source operands, which is 1. + """ + return 1 + + @classmethod + def _get_throughput(cls) -> int: + """ + Gets the throughput of the instruction. + + Returns: + int: The throughput, which is 1. + """ + return 1 + + @classmethod + def _get_latency(cls) -> int: + """ + Gets the latency of the instruction. + + Returns: + int: The latency, which is 6. + """ + return 6 \ No newline at end of file diff --git a/assembler_tools/hec-assembler-tools/assembler/memory_model/__init__.py b/assembler_tools/hec-assembler-tools/assembler/memory_model/__init__.py new file mode 100644 index 00000000..2dfbb96b --- /dev/null +++ b/assembler_tools/hec-assembler-tools/assembler/memory_model/__init__.py @@ -0,0 +1,455 @@ +import os +import math +import pathlib +from typing import NamedTuple + +from assembler.common import constants +from assembler.common.decorators import * +from assembler.common.queue_dict import QueueDict +from . import hbm +from . import spad +from . import register_file +from .variable import Variable +from .variable import findVarByName +from pickle import NONE + +class MemoryModel: + """ + Represents a memory model with various components such as HBM, SPAD, and register banks. + + This class provides methods and properties to manage and interact with different parts + of the memory model, including metadata variables and output variables. + """ + class StoreBufferValueType(NamedTuple): + """ + Represents a value type for the store buffer. + + Attributes: + variable (Variable): The variable associated with the store buffer entry. + dest_spad_address (int): The destination SPAD address for the variable. + """ + variable: Variable + dest_spad_address: int + + __MAX_TWIDDLE_META_VARS_PER_SEGMENT = math.ceil(constants.MemoryModel.NUM_TWIDDLE_META_REGISTERS * \ + constants.MemoryModel.TWIDDLE_META_REGISTER_SIZE_BYTES / \ + constants.Constants.WORD_SIZE) + + @classproperty + def MAX_TWIDDLE_META_VARS_PER_SEGMENT(cls): + """ + Gets the number of variables needed to fill up the twiddle factor metadata registers. + + Returns: + int: The number of variables per segment. + """ + return cls.__MAX_TWIDDLE_META_VARS_PER_SEGMENT + + + # Constructor + # ----------- + + def __init__(self, + hbm_capacity_words: int, + spad_capacity_words: int, + num_register_banks: int = constants.MemoryModel.NUM_REGISTER_BANKS, + register_range: range = None): + """ + Initializes a new MemoryModel object. + + Args: + hbm_capacity_words (int): The capacity of the HBM in words. + spad_capacity_words (int): The capacity of the SPAD in words. + num_register_banks (int, optional): The number of register banks. Defaults to constants.MemoryModel.NUM_REGISTER_BANKS. + register_range (range, optional): A range for the indices of the registers contained in this register bank. + Defaults to `range(constants.MemoryModel.NUM_REGISTER_PER_BANKS)`. + + Raises: + ValueError: If the number of register banks is less than the required minimum. + """ + # check that constant is correct + assert self.MAX_TWIDDLE_META_VARS_PER_SEGMENT == 8 + + if num_register_banks < constants.MemoryModel.NUM_REGISTER_BANKS: + raise ValueError(('`num_register_banks`: there must be at least {} register banks, ' + 'but {} requested.').format(constants.MemoryModel.NUM_REGISTER_BANKS, + num_register_banks)) + self.__register_range = range(constants.MemoryModel.NUM_REGISTER_PER_BANKS) if not register_range else register_range + # initialize members + self.__store_buffer = QueueDict() # QueueDict(var_name: str, StoreBufferValueType) + self.__variables = {} # dict(var_name, Variable) + self.__meta_ones_vars = [] # list(QueueDict()) + self.meta_ntt_aux_table: str = "" # var name + self.meta_ntt_routing_table: str = "" # var name + self.meta_intt_aux_table: str = "" # var name + self.meta_intt_routing_table: str = "" # var name + self.__meta_twiddle_vars = [] # list(QueueDict()) + self.__meta_keygen_seed_vars = QueueDict() # QueueDict(var_name: str, None): set of variables that are seeds to this operation + self.__keygen_vars = dict() # dict(var_name: str, tuple(seed_idx: int, key_idx: int)): set of variables that are output to this operation + self.__output_vars = QueueDict() # QueueDict(var_name: str, None): set of variables that are output to this operation + self.__last_keygen_order = (0, -1) # tracks the generation order of last keygen variable; next must be 1 above this order. + self.__hbm = hbm.HBM(hbm_capacity_words) + self.__spad = spad.SPAD(spad_capacity_words) + self.__register_file = tuple([register_file.RegisterBank(idx, self.__register_range) \ + for idx in range(num_register_banks)]) + + # Special Methods + # --------------- + + def __repr__(self): + """ + Returns a string representation of the MemoryModel object. + + Returns: + str: The string representation. + """ + retval = ('<{} object at {}>(hbm_capacity_words={}, ' + 'spad_capacity_words={}, ' + 'num_register_banks={}, ' + 'register_range={})').format(type(self).__name__, + hex(id(self)), + self.spad.CAPACITY_WORDS, + self.hbm.CAPACITY_WORDS, + len(self.reister_banks), + self.__register_range) + return retval + + + # Methods and properties + # ---------------------- + + @property + def hbm(self) -> hbm.HBM: + """ + Gets the HBM component of the memory model. + + Returns: + hbm.HBM: The HBM component. + """ + return self.__hbm + + @property + def spad(self) -> spad.SPAD: + """ + Gets the SPAD component of the memory model. + + Returns: + spad.SPAD: The SPAD component. + """ + return self.__spad + + @property + def store_buffer(self) -> QueueDict: + """ + Gets the store buffer between SPAD and CE. + + Returns: + QueueDict: QueueDict(var_name: str, StoreBufferValueType) + """ + return self.__store_buffer + + @property + def register_banks(self) -> tuple: + """ + Gets the register banks in the memory model register file. + + Returns: + tuple: A tuple of `RegisterBank` objects. + """ + return self.__register_file + + @property + def variables(self) -> dict: + """ + Gets the dictionary of global variables, indexed by variable name. + + These are all the variables in the program. They may not be allocated in HBM. To + check if they are allocated check the class:`Variable.hbm_address` property. It + is allocated if greater than or equal to zero. + + Returns: + dict: A dictionary of variables. + """ + return self.__variables + + def add_meta_ones_var(self, var_name: str): + """ + Marks an existing variable as Metadata Ones Variable. + + Args: + var_name (str): The name of the variable to mark. + + Raises: + RuntimeError: If the variable is not in the memory model. + """ + if var_name not in self.variables: + raise RuntimeError(f'Variable "{var_name}" is not in memory model.') + self.__meta_ones_vars.append(QueueDict()) + self.__meta_ones_vars[-1].push(var_name, None) + + @property + def meta_ones_vars_segments(self) -> list: + """ + Retrieves the set of variable names that have been marked as Metadata Ones variables. + + A list of segments (list[QueueDict(str, None)]), where each segment is + the set of variable names that have been marked as Metadata Ones variables. + The size of each set is given by the number of variables needed to fill up + the ones metadata registers (see constants.MemoryModel.NUM_ONES_META_REGISTERS). + Clients should not change these values. Use add_meta_ones_var() to add new ones metadata. + + Returns: + list: A list of segments, each containing variable names. + + """ + return self.__meta_ones_vars + + def add_meta_twiddle_var(self, var_name: str): + """ + Marks an existing variable as a twiddle metadata variable. + + Args: + var_name (str): The name of the variable to mark. + + Raises: + RuntimeError: If the variable is not in the memory model. + """ + if var_name not in self.variables: + raise RuntimeError(f'Variable "{var_name}" is not in memory model.') + # Twiddle metadata variables are grouped in segments of 8 + if len(self.__meta_twiddle_vars) <= 0 \ + or len(self.__meta_twiddle_vars[-1]) >= self.MAX_TWIDDLE_META_VARS_PER_SEGMENT: + self.__meta_twiddle_vars.append(QueueDict()) + self.__meta_twiddle_vars[-1].push(var_name, None) + + @property + def meta_twiddle_vars_segments(self) -> list: + """ + Gets the variable names that have been marked as Metadata Twiddle variables. + + Clients should not change these values. Use meta_twiddle_vars_segments() to add + new twiddle metadata. + + A list of segments (list[QueueDict(str, None)]), where each segment is a set of + variable names that have been marked as Metadata Twiddle variables. The size + of each set is given by the number of variables needed to fill up the twiddle + factor metadata registers (see MemoryModel.MAX_TWIDDLE_META_VARS_PER_SEGMENT). + + Returns: + list: A list of segments containing variable names. + """ + return self.__meta_twiddle_vars + + def isMetaVar(self, var_name: str) -> bool: + """ + Checks whether a variable name is one of the meta variables. + + Args: + var_name (str): The name of the variable to check. + + Returns: + bool: True if the variable is a meta variable, False otherwise. + """ + return bool(var_name) and \ + (var_name in self.meta_keygen_seed_vars \ + or any(var_name in meta_twiddle_vars for meta_twiddle_vars in self.meta_twiddle_vars_segments) \ + or any(var_name in meta_ones_vars for meta_ones_vars in self.meta_ones_vars_segments) \ + or var_name in set((self.meta_ntt_aux_table, self.meta_ntt_routing_table, + self.meta_intt_aux_table, self.meta_intt_routing_table))) + + @property + def output_variables(self) -> QueueDict: + """ + Gets the set of variable names that have been marked as output variables. + + Returns: + QueueDict: The set of output variable names. + """ + return self.__output_vars + + def add_meta_keygen_seed_var(self, var_name: str): + """ + Marks an existing variable as a keygen seed. + + Args: + var_name (str): The name of the variable to mark. + + Raises: + RuntimeError: If the variable is not in the memory model. + """ + if var_name not in self.variables: + raise RuntimeError(f'Variable "{var_name}" is not in memory model.') + self.meta_keygen_seed_vars.push(var_name, None) + + @property + def meta_keygen_seed_vars(self) -> QueueDict: + """ + Gets the variable names that have been marked as keygen seed variables. + + Clients should not change these values. Use add_meta_keygen_seed_var() to add + new keygen seeds metadata. + + Returns: + QueueDict: The set of keygen seed variable names. + """ + return self.__meta_keygen_seed_vars + + @property + def keygen_variables(self) -> dict: + """ + Gets the set of variable names that have been marked as key material variables. + + Clients should not modify this list. Use add_keygen_variable() to mark a variable + as key material. + + Returns: + dict: A dictionary mapping variable names to their generation ordering. + """ + return self.__keygen_vars + + def add_keygen_variable(self, var_name: str, seed_index: int, key_index: int): + """ + Marks an existing variable as a key material variable. + + Args: + var_name (str): The name of the variable to mark. + seed_index (int): The index of the keygen seed. + key_index (int): The index of the key. + + Raises: + RuntimeError: If the variable is not used by the associated kernel, is already marked as key material, + or is marked as output. + IndexError: If the key_index is invalid or the seed_index is out of range. + """ + if var_name not in self.variables: + raise RuntimeError(f'Variable "{var_name}" is not used by associated kernel.') + if var_name in self.keygen_variables: + raise RuntimeError(f'Variable "{var_name}" is marked already as key material.') + if var_name in self.output_variables: + raise RuntimeError(f'Variable "{var_name}" is marked as output and cannot be marked as key material.') + if key_index < 0: + raise IndexError('`key_index` must be a valid zero-based index.') + if seed_index < 0 or seed_index >= len(self.meta_keygen_seed_vars): + raise IndexError(('`seed_index` must be a valid index into the existing keygen seeds. ' + 'Expected value in range [0, {}), but {} received.').format(len(self.meta_keygen_seed_vars), + seed_index)) + + self.keygen_variables[var_name] = (seed_index, key_index) + + def isVarInMem(self, var_name: str) -> bool: + """ + Checks whether the specified variable is in memory. + + Args: + var_name (str): The name of the variable to check. + + Returns: + bool: True if the variable is loaded into the register file, SPAD, or HBM. False otherwise. + + Raises: + ValueError: If the variable is not in the memory model. + """ + + if var_name not in self.variables: + raise ValueError(f'`var_name`: "{var_name}" not in memory model.') + + variable: Variable = self.variables[var_name] + return variable.hbm_address >= 0 or variable.spad_address >= 0 or variable.register is not None + + def retrieveVarAdd(self, + var_name: str, + suggested_bank: int = -1) -> Variable: + """ + Retrieves a Variable object from the global list of variables or add a new variable if not found. + + Args: + var_name (str): The name of the variable to retrieve or add. + suggested_bank (int, optional): The suggested bank for the variable. Defaults to -1. + + Returns: + Variable: The Variable object with the given name. + + Raises: + ValueError: If the suggested bank does not match the existing variable's suggested bank. + """ + + retval = self.variables[var_name] if var_name in self.variables else None + if not retval: + retval = Variable(var_name, suggested_bank) + self.variables[retval.name] = retval + if retval.suggested_bank < 0: + retval.suggested_bank = suggested_bank + elif suggested_bank >= 0: + if retval.suggested_bank != suggested_bank: + raise ValueError(('`suggested_bank`: value {} does not match existing variable "{}" ' + 'suggested bank of {}.').format(suggested_bank, + var_name, + retval.suggested_bank)) + return retval + + def findUniqueVarName(self) -> str: + """ + Find a unique variable name that is not already in use. + + Returns: + str: A unique variable name. + """ + retval = "_0" + idx = 1 + while retval in self.variables: + retval = f"_{idx}" + idx += 1 + return retval + + def __dumpVariables(self, ostream): + """ + Dump the variables to the specified output stream. + + Args: + ostream: The output stream to write the variable information to. + """ + print("name, hbm, spad, spad dirty, suggested bank, register, register_dirty, last xinst use, pending xinst use", file=ostream) + for _, variable in self.variables.items(): + print('{}, {}, {}, {}, {}, {}, {}'.format(variable.name, + variable.hbm_address, + variable.spad_address, + variable.spad_dirty, + variable.suggested_bank, + variable.register, + variable.register_dirty, + repr(variable.last_x_access), + repr(variable.accessed_by_xinsts)), + file = ostream) + + def dump(self, + output_dir = ''): + """ + Dump the memory model information to files in the specified output directory. + + Args: + output_dir (str, optional): + The directory to write the dump files to. + Defaults to the current working directory. + """ + if not output_dir: + output_dir = os.path.join(pathlib.Path.cwd(), "tmp") + pathlib.Path(output_dir).mkdir(exist_ok = True, parents=True) + print('******************') + print(f'Dumping to: {output_dir}') + + vars_filename = os.path.join(output_dir, "variables.dump.csv") + hbm_filename = os.path.join(output_dir, "hbm.dump.csv") + spad_filename = os.path.join(output_dir, "spad.dump.csv") + + with open(vars_filename, 'w') as outnum: + self.__dumpVariables(outnum) + with open(hbm_filename, 'w') as outnum: + self.hbm.dump(outnum) + with open(spad_filename, 'w') as outnum: + self.spad.dump(outnum) + for idx, rb in enumerate(self.register_banks): + register_filename = os.path.join(output_dir, f"register_bank_{idx}.dump.csv") + with open(register_filename, 'w') as outnum: + rb.dump(outnum) + + print('******************') diff --git a/assembler_tools/hec-assembler-tools/assembler/memory_model/hbm.py b/assembler_tools/hec-assembler-tools/assembler/memory_model/hbm.py new file mode 100644 index 00000000..105d790c --- /dev/null +++ b/assembler_tools/hec-assembler-tools/assembler/memory_model/hbm.py @@ -0,0 +1,160 @@ +from assembler.common.constants import MemoryModel as mmconstants +from assembler.common.decorators import * +from .memory_bank import MemoryBank +from .variable import Variable, findVarByName +from . import mem_utilities as utilities + +class HBM(MemoryBank): + """ + Encapsulates the high-bandwidth DRAM memory model, also known as HBM. + + This class provides methods for managing the allocation and deallocation of variables + within the HBM, as well as methods for finding available addresses and dumping the + current state of the HBM. + + Constructors: + HBM(data_capacity_words: int) -> HBM + Creates a new HBM object with a specified capacity in words. + + fromCapacityBytes(data_capacity_bytes: int) -> HBM + Creates an HBM object from a specified capacity in bytes. + + Methods: + allocateForce(hbm_addr: int, var: Variable) + Forces the allocation of an existing variable at a specific address. + + deallocate(hbm_addr: int) -> Variable + Frees up the slot at the specified memory address in the memory buffer. + + deallocateVariable(var: Variable) -> Variable + Deallocates the specified variable from HBM, freeing up its slot in the memory buffer. + + findAvailableAddress(live_var_names) -> int + Retrieves the next available HBM address. + + dump(ostream) + Dumps the current state of the HBM to the specified output stream. + """ + + def __init__(self, + hbm_data_capacity_words: int): + """ + Initializes a new HBM object. + + Args: + hbm_data_capacity_words (int): Capacity in words for the HBM data region. + + Raises: + ValueError: If the capacity exceeds the maximum allowed capacity. + """ + # validate input + if hbm_data_capacity_words > mmconstants.HBM.MAX_CAPACITY_WORDS: + raise ValueError(("`hbm_data_capacity_words` must be in the range (0, {}], " + "but {} received.".format(mmconstants.HBM.MAX_CAPACITY_WORDS, hbm_data_capacity_words))) + + # initialize base + super().__init__(hbm_data_capacity_words) + + def allocateForce(self, + hbm_addr: int, + var: Variable): + """ + Forces the allocation of an existing variable at a specific address. + + Args: + hbm_addr (int): Address in HBM where to allocate the variable. + var (Variable): Variable object to allocate. The variable's hbm_address must be clear (set to a negative value). + + Raises: + ValueError: If the variable is already allocated or if there is a conflicting allocation. + RuntimeError: If the HBM is out of capacity. + """ + # validate variable + if var.hbm_address >= 0: + # variable is already allocated (avoid dangling pointers) + raise ValueError(('`var`: Variable {} address is not cleared. ' + 'Expected negative address, but {} received.'.format(var, var.hbm_address))) + + # allocate in memory bank + super().allocateForce(hbm_addr, var) + var.hbm_address = hbm_addr + + def deallocate(self, hbm_addr: int) -> object: + """ + Frees up the slot at the specified memory address in the memory buffer. + + Args: + hbm_addr (int): Address of the memory slot to free. + + Raises: + ValueError: If the address is invalid or already freed. + + Returns: + Variable: The object that was contained in the deallocated slot. + """ + + # deallocate from memory bank + var = super().deallocate(hbm_addr) + var.hbm_address = -1 + + return var + + def deallocateVariable(self, var: Variable) -> Variable: + """ + Deallocates the specified variable from HBM, freeing up its slot in the memory buffer. + + Args: + var (Variable): Variable to free. + + Raises: + ValueError: If the variable is not allocated in HBM. + + Returns: + Variable: The object that was contained in the deallocated slot. + """ + retval = self.deallocate(var.hbm_address) + assert(retval.name == var.name) + return retval + + def findAvailableAddress(self, + live_var_names) -> int: + """ + Retrieves the next available HBM address. + + Args: + live_var_names (set or list): A collection of variable names that should not be removed from HBM. + + Returns: + int: The first empty address, or -1 if no suitable address is found. + """ + return utilities.findAvailableLocation(self.buffer, live_var_names) + + def dump(self, ostream): + """ + Dumps the current state of the HBM to the specified output stream. + + Args: + ostream: The output stream to write the HBM state to. + """ + print('HBM', file = ostream) + print(f'Max Capacity, {self.CAPACITY}, Bytes', file = ostream) + print(f'Max Capacity, {self.CAPACITY_WORDS}, Words', file = ostream) + print(f'Current Capacity, {self.currentCapacityWords}, Words', file = ostream) + print(f'Current Occupied, {self.CAPACITY_WORDS - self.currentCapacityWords}, Words', file = ostream) + print("", file = ostream) + print("address, variable, variable hbm", file = ostream) + last_addr = 0 + for addr, variable in enumerate(self.buffer): + if variable is not None: + for idx in range(last_addr, addr): + # empty addresses + print(f'{idx}, None', file = ostream) + if variable.name: + print('{}, {}'.format(addr, + variable.name, + variable.hbm_address), + file = ostream) + else: + print('f{addr}, Dummy_{variable.tag}', + file = ostream) + last_addr = addr + 1 diff --git a/assembler_tools/hec-assembler-tools/assembler/memory_model/mem_info.py b/assembler_tools/hec-assembler-tools/assembler/memory_model/mem_info.py new file mode 100644 index 00000000..7250116b --- /dev/null +++ b/assembler_tools/hec-assembler-tools/assembler/memory_model/mem_info.py @@ -0,0 +1,670 @@ +from assembler.common import constants +from assembler.instructions import tokenizeFromLine +from assembler.memory_model.variable import Variable +from . import MemoryModel + +class MemInfoVariable: + """ + Represents a memory information variable with a name and an HBM address. + + This class encapsulates the details of a variable, including its name and the + address in high-bandwidth memory (HBM) where it is stored. + """ + def __init__(self, + var_name: str, + hbm_address: int): + """ + Initializes a new MemInfoVariable object with a specified name and HBM address. + + Args: + var_name (str): The name of the variable. Must be a valid identifier. + hbm_address (int): The HBM address where the variable is stored. + + Raises: + RuntimeError: If the variable name is invalid. + """ + if not Variable.validateName(var_name): + raise RuntimeError(f'Invalid variable name "{var_name}"') + self.var_name = var_name.strip() + self.hbm_address = hbm_address + + def __repr__(self): + """ + Returns a string representation of the MemInfoVariable object. + + Returns: + str: A string representation of the object as a dictionary. + """ + return repr(self.as_dict()) + + def as_dict(self) -> dict: + """ + Converts the MemInfoVariable object to a dictionary. + + Returns: + dict: A dictionary representation of the variable, including its name and HBM address. + """ + return { 'var_name': self.var_name, + 'hbm_address': self.hbm_address } + +class MemInfoKeygenVariable(MemInfoVariable): + """ + Represents a memory information key generation variable. + + This class extends MemInfoVariable to include additional attributes for key generation, + specifically the seed index and key index associated with the variable. + """ + def __init__(self, + var_name: str, + seed_index: int, + key_index: int): + """ + Initializes a new MemInfoKeygenVariable object with a specified name, seed index, and key index. + + Args: + var_name (str): The name of the variable. Must be a valid identifier. + seed_index (int): The index of the seed used for key generation. Must be a zero-based index. + key_index (int): The index of the key. Must be a zero-based index. + + Raises: + IndexError: If the seed index or key index is negative. + """ + super().__init__(var_name, -1) + if seed_index < 0: + raise IndexError('seed_index: must be a zero-based index.') + if key_index < 0: + raise IndexError('key_index: must be a zero-based index.') + self.seed_index = seed_index + self.key_index = key_index + + def as_dict(self) -> dict: + """ + Converts the MemInfoKeygenVariable object to a dictionary. + + Returns: + dict: A dictionary representation of the variable, including its name, seed index, and key index. + """ + return { 'var_name': self.var_name, + 'seed_index': self.seed_index, + 'key_index': self.key_index } + +class MemInfo: + """ + Represents memory information for a set of variables and metadata fields. + + This class encapsulates the parsing and management of memory information variables, + including key generation variables, input and output variables, and various metadata fields. + """ + + Const = constants.MemInfo + + class Metadata: + """ + Encapsulates metadata fields within the memory information. + + This class provides methods for parsing and accessing metadata variables such as + ones, NTT auxiliary tables, NTT routing tables, iNTT auxiliary tables, iNTT routing tables, + twiddle factors, and keygen seeds. + """ + + class Ones: + @classmethod + def parseFromMemLine(cls, tokens: list) -> MemInfoVariable: + """ + Parses a ones metadata variable from a tokenized line. + + Args: + tokens (list[str]): Fully tokenized line to parse from. This must include all tokens, including the initial keyword. + + Returns: + MemInfoVariable: The parsed ones metadata variable. + """ + return MemInfo.Metadata.parseMetaFieldFromMemLine(tokens, + MemInfo.Const.Keyword.LOAD_ONES, + var_prefix=MemInfo.Const.Keyword.LOAD_ONES) + + class NTTAuxTable: + @classmethod + def parseFromMemLine(cls, tokens: list) -> MemInfoVariable: + """ + Parses an NTT auxiliary table metadata variable from a tokenized line. + + Args: + tokens (list[str]): Fully tokenized line to parse from. This must include all tokens, including the initial keyword. + + Returns: + MemInfoVariable: The parsed NTT auxiliary table metadata variable. + """ + return MemInfo.Metadata.parseMetaFieldFromMemLine(tokens, + MemInfo.Const.Keyword.LOAD_NTT_AUX_TABLE, + var_prefix=MemInfo.Const.Keyword.LOAD_NTT_AUX_TABLE) + + class NTTRoutingTable: + @classmethod + def parseFromMemLine(cls, tokens: list) -> MemInfoVariable: + """ + Parses an NTT routing table metadata variable from a tokenized line. + + Args: + tokens (list[str]): Fully tokenized line to parse from. This must include all tokens, including the initial keyword. + + Returns: + MemInfoVariable: The parsed NTT routing table metadata variable. + """ + return MemInfo.Metadata.parseMetaFieldFromMemLine(tokens, + MemInfo.Const.Keyword.LOAD_NTT_ROUTING_TABLE, + var_prefix=MemInfo.Const.Keyword.LOAD_NTT_ROUTING_TABLE) + + class iNTTAuxTable: + @classmethod + def parseFromMemLine(cls, tokens: list) -> MemInfoVariable: + """ + Parses an iNTT auxiliary table metadata variable from a tokenized line. + + Args: + tokens (list[str]): Fully tokenized line to parse from. This must include all tokens, including the initial keyword. + + Returns: + MemInfoVariable: The parsed iNTT auxiliary table metadata variable. + """ + return MemInfo.Metadata.parseMetaFieldFromMemLine(tokens, + MemInfo.Const.Keyword.LOAD_iNTT_AUX_TABLE, + var_prefix=MemInfo.Const.Keyword.LOAD_iNTT_AUX_TABLE) + + class iNTTRoutingTable: + @classmethod + def parseFromMemLine(cls, tokens: list) -> MemInfoVariable: + """ + Parses an iNTT routing table metadata variable from a tokenized line. + + Args: + tokens (list[str]): Fully tokenized line to parse from. This must include all tokens, including the initial keyword. + + Returns: + MemInfoVariable: The parsed iNTT routing table metadata variable. + """ + return MemInfo.Metadata.parseMetaFieldFromMemLine(tokens, + MemInfo.Const.Keyword.LOAD_iNTT_ROUTING_TABLE, + var_prefix=MemInfo.Const.Keyword.LOAD_iNTT_ROUTING_TABLE) + + class Twiddle: + @classmethod + def parseFromMemLine(cls, tokens: list) -> MemInfoVariable: + """ + Parses a twiddle metadata variable from a tokenized line. + + Args: + tokens (list[str]): Fully tokenized line to parse from. This must include all tokens, including the initial keyword. + + Returns: + MemInfoVariable: The parsed twiddle metadata variable. + """ + return MemInfo.Metadata.parseMetaFieldFromMemLine(tokens, + MemInfo.Const.Keyword.LOAD_TWIDDLE, + var_prefix=MemInfo.Const.Keyword.LOAD_TWIDDLE) + + class KeygenSeed: + @classmethod + def parseFromMemLine(cls, tokens: list) -> MemInfoVariable: + """ + Parses a keygen seed metadata variable from a tokenized line. + + Args: + tokens (list[str]): Fully tokenized line to parse from. This must include all tokens, including the initial keyword. + + Returns: + MemInfoVariable: The parsed keygen seed metadata variable. + """ + return MemInfo.Metadata.parseMetaFieldFromMemLine(tokens, + MemInfo.Const.Keyword.LOAD_KEYGEN_SEED, + var_prefix=MemInfo.Const.Keyword.LOAD_KEYGEN_SEED) + + @classmethod + def parseMetaFieldFromMemLine(cls, + tokens: list, + meta_field_name: str, + var_prefix: str = "meta", + var_extra: str = None) -> MemInfoVariable: + """ + Parses a metadata variable name from a tokenized line. + + Args: + tokens (list[str]): Fully tokenized line from which to parse. Expected format: `dload, , [, var_name]`. + meta_field_name (str): Name identifying the meta field to parse from the tokens. + var_prefix (str, optional): Prefix for the metadata variable. Ignored if a name is supplied in the tokens. + var_extra (str, optional): Extra postfix to add to the variable name. Ignored if a name is supplied in the tokens. + + Returns: + MemInfoVariable: The mem info for the parsed variable, or None if no variable could be parsed. + """ + retval = None + if len(tokens) >= 3: + if tokens[0] == MemInfo.Const.Keyword.LOAD \ + and tokens[1] == meta_field_name: + hbm_addr = int(tokens[2]) + if len(tokens) >= 4 and tokens[3]: + # name supplied in the tokenized line + var_name = tokens[3] + else: + if var_extra is None: + var_extra = f'_{hbm_addr}' + else: + var_extra = var_extra.strip() + var_name = f'{var_prefix}{var_extra}' + retval = MemInfoVariable(var_name = var_name, + hbm_address = hbm_addr) + return retval + + def __init__(self, **kwargs): + """ + Initializes a new Metadata object with specified metadata fields. + + Args: + kwargs (dict): A dictionary containing metadata fields and their corresponding MemInfoVariable objects. + """ + self.__meta_dict = {} + for meta_field in MemInfo.Const.FIELD_METADATA_SUBFIELDS: + self.__meta_dict[meta_field] = [ MemInfoVariable(**d) for d in kwargs.get(meta_field, []) ] + + def __getitem__(self, key): + """ + Retrieves the list of MemInfoVariable objects for the specified metadata field. + + Args: + key: The metadata field key. + + Returns: + list: A list of MemInfoVariable objects. + """ + return self.__meta_dict[key] + + + @property + def ones(self) -> list: + """ + Retrieves the list of ones metadata variables. + + Returns: + list: Ones metadata variables. + """ + return self.__meta_dict[MemInfo.Const.MetaFields.FIELD_ONES] + + @property + def ntt_auxiliary_table(self) -> list: + """ + Retrieves the list of NTT auxiliary table metadata variables. + + Returns: + list: Metadata variables. + """ + return self.__meta_dict[MemInfo.Const.MetaFields.FIELD_NTT_AUX_TABLE] + + @property + def ntt_routing_table(self) -> list: + """ + Retrieves the list of NTT routing table metadata variables. + + Returns: + list: Metadata variables. + """ + return self.__meta_dict[MemInfo.Const.MetaFields.FIELD_NTT_ROUTING_TABLE] + + @property + def intt_auxiliary_table(self) -> list: + """ + Retrieves the list of iNTT auxiliary table metadata variables. + + Returns: + list: Metadata variables. + """ + return self.__meta_dict[MemInfo.Const.MetaFields.FIELD_iNTT_AUX_TABLE] + + @property + def intt_routing_table(self) -> list: + """ + Retrieves the list of iNTT routing table metadata variables. + + Returns: + list: Metadata variables. + """ + return self.__meta_dict[MemInfo.Const.MetaFields.FIELD_iNTT_ROUTING_TABLE] + + @property + def twiddle(self) -> list: + """ + Retrieves the list of twiddle metadata variables. + + Returns: + list: Twiddle metadata variables. + """ + return self.__meta_dict[MemInfo.Const.MetaFields.FIELD_TWIDDLE] + + @property + def keygen_seeds(self) -> list: + """ + Retrieves the list of keygen seed metadata variables. + + Returns: + list: Keygen seed metadata variables. + """ + return self.__meta_dict[MemInfo.Const.MetaFields.FIELD_KEYGEN_SEED] + + class Keygen: + @classmethod + def parseFromMemLine(cls, tokens: list) -> MemInfoVariable: + """ + Parses a keygen variable from a tokenized line. + + Args: + tokens (list[str]): Fully tokenized line to parse from. This must include all tokens, including the initial keyword. + + Returns: + MemInfoKeygenVariable: Mem Info describing a keygen variable. + """ + retval = None + if len(tokens) >= 4: + if tokens[0] == MemInfo.Const.Keyword.KEYGEN: + seed_idx = int(tokens[1]) + key_idx = int(tokens[2]) + var_name = tokens[3] + retval = MemInfoKeygenVariable(var_name = var_name, + seed_index = seed_idx, + key_index = key_idx) + return retval + + class Input: + @classmethod + def parseFromMemLine(cls, tokens: list) -> MemInfoVariable: + """ + Parses an input variable from a tokenized line. + + Args: + tokens (list[str]): Fully tokenized line to parse from. This must include all tokens, including the initial keyword. + + Returns: + MemInfoVariable: The parsed input variable. + """ + retval = None + if len(tokens) >= 4: + if tokens[0] == MemInfo.Const.Keyword.LOAD \ + and tokens[1] == MemInfo.Const.Keyword.LOAD_INPUT: + hbm_addr = int(tokens[2]) + var_name = tokens[3] + if Variable.validateName(var_name): + retval = MemInfoVariable(var_name = var_name, + hbm_address = hbm_addr) + return retval + + class Output: + @classmethod + def parseFromMemLine(cls, tokens: list) -> MemInfoVariable: + """ + Parses an output variable from a tokenized line. + + Args: + tokens (list[str]): Fully tokenized line to parse from. This must include all tokens, including the initial keyword. + + Returns: + MemInfoVariable: The parsed output variable. + """ + retval = None + if len(tokens) >= 3: + if tokens[0] == MemInfo.Const.Keyword.STORE: + hbm_addr = int(tokens[2]) + var_name = tokens[1] + if Variable.validateName(var_name): + retval = MemInfoVariable(var_name = var_name, + hbm_address = hbm_addr) + return retval + + def __init__(self, **kwargs): + """ + Initializes a new MemInfo object. + + Clients may call this method without parameters for default initialization. + Clients should use MemInfo.from_iter() constructor to parse the contents of a .mem file. + + Args: + kwargs (dict): A dictionary as generated by the method MemInfo.as_dict(). This is provided as + a shortcut to creating a MemInfo object from structured data such as the contents of a YAML file. + """ + self.__keygens = [ MemInfoVariable(**d) for d in kwargs.get(MemInfo.Const.FIELD_KEYGENS, []) ] + self.__inputs = [ MemInfoVariable(**d) for d in kwargs.get(MemInfo.Const.FIELD_INPUTS, []) ] + self.__outputs = [ MemInfoVariable(**d) for d in kwargs.get(MemInfo.Const.FIELD_OUTPUTS, []) ] + self.__metadata = MemInfo.Metadata(**kwargs.get(MemInfo.Const.FIELD_METADATA, {})) + self.validate() + + @classmethod + def from_iter(cls, line_iter): + """ + Creates a new MemInfo object from an iterator of strings, where each string is a line of text to parse. + + This constructor is intended to parse a .mem file. + + Args: + line_iter (iter): Iterator of strings. Each string is considered a line of text to parse. + + Raises: + RuntimeError: If there is an error parsing the lines. + + Returns: + MemInfo: The constructed MemInfo object. + """ + + retval = cls() + + factory_dict = { MemInfo.Keygen: retval.keygens, + MemInfo.Input: retval.inputs, + MemInfo.Output: retval.outputs, + MemInfo.Metadata.KeygenSeed: retval.metadata.keygen_seeds, + MemInfo.Metadata.Ones: retval.metadata.ones, + MemInfo.Metadata.NTTAuxTable: retval.metadata.ntt_auxiliary_table, + MemInfo.Metadata.NTTRoutingTable: retval.metadata.ntt_routing_table, + MemInfo.Metadata.iNTTAuxTable: retval.metadata.intt_auxiliary_table, + MemInfo.Metadata.iNTTRoutingTable: retval.metadata.intt_routing_table, + MemInfo.Metadata.Twiddle: retval.metadata.twiddle } + for line_no, s_line in enumerate(line_iter, 1): + s_line = s_line.strip() + if s_line: # skip empty lines + tokens, _ = tokenizeFromLine(s_line) + if tokens and len(tokens) > 0: + b_parsed = False + for mem_info_type in factory_dict: + miv: MemInfoVariable = mem_info_type.parseFromMemLine(tokens) + if miv is not None: + factory_dict[mem_info_type].append(miv) + b_parsed = True + break # next line + if not b_parsed: + raise RuntimeError(f'Could not parse line {line_no}: "{s_line}"') + retval.validate() + return retval + + @property + def keygens(self) -> list: + """ + Retrieves the list of keygen variables. + + Returns: + list: Keygen variables. + """ + return self.__keygens + + @property + def inputs(self) -> list: + """ + Retrieves the list of input variables. + + Returns: + list: Input variables. + """ + return self.__inputs + + @property + def outputs(self) -> list: + """ + Retrieves the list of output variables. + + Returns: + list: Output variables. + """ + return self.__outputs + + @property + def metadata(self) -> Metadata: + """ + Retrieves the metadata associated with this MemInfo object. + + Returns: + Metadata: MemInfo's metadata. + """ + return self.__metadata + + def as_dict(self): + """ + Returns a dictionary representation of this MemInfo object. + + Returns: + dict: A dictionary representation of the MemInfo object. + """ + return { MemInfo.Const.FIELD_KEYGENS: [ x.as_dict() for x in self.keygens ], + MemInfo.Const.FIELD_INPUTS: [ x.as_dict() for x in self.inputs ], + MemInfo.Const.FIELD_OUTPUTS: [ x.as_dict() for x in self.outputs ], + MemInfo.Const.FIELD_METADATA: { meta_field: [ x.as_dict() for x in self.metadata[meta_field] ] \ + for meta_field in MemInfo.Const.FIELD_METADATA_SUBFIELDS if self.metadata[meta_field] } } + + def validate(self): + """ + Validates the MemInfo object to ensure consistency and correctness. + + Raises: + RuntimeError: If the validation fails due to inconsistent metadata or duplicate variable names. + """ + if len(self.metadata.ones) * MemoryModel.MAX_TWIDDLE_META_VARS_PER_SEGMENT != len(self.metadata.twiddle): + raise RuntimeError(('Expected {} times as many twiddles as ones metadata values, ' + 'but received {} twiddles and {} ones.').format(MemoryModel.MAX_TWIDDLE_META_VARS_PER_SEGMENT, + len(self.metadata.twiddle), + len(self.metadata.ones))) + # Avoid duplicate variable names with different HBM addresses. + mem_info_vars = {} + all_var_info = self.inputs + self.outputs \ + + self.metadata.intt_auxiliary_table + self.metadata.intt_routing_table \ + + self.metadata.ntt_auxiliary_table + self.metadata.ntt_routing_table \ + + self.metadata.ones + self.metadata.twiddle + for var_info in all_var_info: + if var_info.var_name not in mem_info_vars: + mem_info_vars[var_info.var_name] = var_info + elif mem_info_vars[var_info.var_name].hbm_address != var_info.hbm_address: + raise RuntimeError(('Variable "{}" already allocated in HBM address {}, ' + 'but new allocation requested into address {}.').format(var_info.var_name, + mem_info_vars[var_info.var_name].hbm_address, + var_info.hbm_address)) + +def __allocateMemInfoVariable(mem_model: MemoryModel, + v_info: MemInfoVariable): + """ + Allocates a memory information variable in the memory model. + + This function ensures that the specified variable is allocated in the high-bandwidth memory (HBM) + of the memory model. It checks if the variable is present in the memory model and allocates it + at the specified HBM address if it is not already allocated. + + Args: + mem_model (MemoryModel): The memory model in which to allocate the variable. + v_info (MemInfoVariable): The memory information variable to allocate. + + Raises: + RuntimeError: If the variable is not present in the memory model or if there is a conflicting + allocation request. + """ + assert v_info.hbm_address >= 0 + if v_info.var_name not in mem_model.variables: + raise RuntimeError(f'Variable {v_info.var_name} not in memory model. All variables used in mem info must be present in P-ISA kernel.') + if mem_model.variables[v_info.var_name].hbm_address < 0: + mem_model.hbm.allocateForce(v_info.hbm_address, mem_model.variables[v_info.var_name]) + elif v_info.hbm_address != mem_model.variables[v_info.var_name].hbm_address: + raise RuntimeError(('Variable {} already allocated in HBM address {}, ' + 'but new allocation requested into address {}.').format(v_info.var_name, + mem_model.variables[v_info.var_name].hbm_address, + v_info.hbm_address)) + +def updateMemoryModelWithMemInfo(mem_model: MemoryModel, + mem_info: MemInfo): + """ + Updates the memory model with memory information. + + This function updates the memory model by allocating variables and metadata fields + specified in the memory information. It processes inputs, outputs, metadata, and keygen + variables, ensuring they are correctly allocated and added to the memory model. + + Args: + mem_model (MemoryModel): The memory model to update. + mem_info (MemInfo): The memory information containing variables and metadata to allocate. + + Raises: + RuntimeError: If there are inconsistencies or errors during the allocation process. + """ + + # Inputs + for v_info in mem_info.inputs: + __allocateMemInfoVariable(mem_model, v_info) + + # Outputs + for v_info in mem_info.outputs: + __allocateMemInfoVariable(mem_model, v_info) + mem_model.output_variables.push(v_info.var_name, None) + + # Metadata + + # Ones + for v_info in mem_info.metadata.ones: + mem_model.retrieveVarAdd(v_info.var_name) + __allocateMemInfoVariable(mem_model, v_info) + mem_model.add_meta_ones_var(v_info.var_name) + + # Shuffle meta vars + if mem_info.metadata.ntt_auxiliary_table: + assert(len(mem_info.metadata.ntt_auxiliary_table) == 1) + v_info = mem_info.metadata.ntt_auxiliary_table[0] + mem_model.retrieveVarAdd(v_info.var_name) + __allocateMemInfoVariable(mem_model, v_info) + mem_model.meta_ntt_aux_table = v_info.var_name + + if mem_info.metadata.ntt_routing_table: + assert(len(mem_info.metadata.ntt_routing_table) == 1) + v_info = mem_info.metadata.ntt_routing_table[0] + mem_model.retrieveVarAdd(v_info.var_name) + __allocateMemInfoVariable(mem_model, v_info) + mem_model.meta_ntt_routing_table = v_info.var_name + + if mem_info.metadata.intt_auxiliary_table: + assert(len(mem_info.metadata.intt_auxiliary_table) == 1) + v_info = mem_info.metadata.intt_auxiliary_table[0] + mem_model.retrieveVarAdd(v_info.var_name) + __allocateMemInfoVariable(mem_model, v_info) + mem_model.meta_intt_aux_table = v_info.var_name + + if mem_info.metadata.intt_routing_table: + assert(len(mem_info.metadata.intt_routing_table) == 1) + v_info = mem_info.metadata.intt_routing_table[0] + mem_model.retrieveVarAdd(v_info.var_name) + __allocateMemInfoVariable(mem_model, v_info) + mem_model.meta_intt_routing_table = v_info.var_name + + # Twiddle + for v_info in mem_info.metadata.twiddle: + mem_model.retrieveVarAdd(v_info.var_name) + __allocateMemInfoVariable(mem_model, v_info) + mem_model.add_meta_twiddle_var(v_info.var_name) + + # Keygen seeds + for v_info in mem_info.metadata.keygen_seeds: + mem_model.retrieveVarAdd(v_info.var_name) + __allocateMemInfoVariable(mem_model, v_info) + mem_model.add_meta_keygen_seed_var(v_info.var_name) + + # End metadata + + # Keygen variables + for v_info in mem_info.keygens: + mem_model.add_keygen_variable(**v_info.as_dict()) diff --git a/assembler_tools/hec-assembler-tools/assembler/memory_model/mem_utilities.py b/assembler_tools/hec-assembler-tools/assembler/memory_model/mem_utilities.py new file mode 100644 index 00000000..ed9b29dc --- /dev/null +++ b/assembler_tools/hec-assembler-tools/assembler/memory_model/mem_utilities.py @@ -0,0 +1,134 @@ + +from assembler.common.constants import Constants +from assembler.common.cycle_tracking import CycleType +from assembler.common.priority_queue import PriorityQueue + +def computePriority(variable, replacement_policy): + """ + Computes the priority for reusing the location of a specified variable. + + The priority is determined based on the replacement policy. The smaller the priority value, + the higher the priority for reuse. Tuples are used for staged comparisons. + + Args: + variable (Variable): The variable for which to compute the priority. + replacement_policy (str): The policy to use for determining priority. Must be one of + `Constants.REPLACEMENT_POLICIES`. + + Returns: + tuple: A tuple representing the priority for reusing the variable's location. + """ + retval = (float("-inf"), ) # Default: highest priority if no variable + if variable: + # Register in use + # last_x_access = variable.last_x_access.bundle * Constants.MAX_BUNDLE_SIZE + variable.last_x_access.cycle \ + last_x_access = variable.last_x_access if variable.last_x_access \ + else CycleType(0, 0) + if replacement_policy == Constants.REPLACEMENT_POLICY_FTBU: + if variable.accessed_by_xinsts: + # Priority by + retval = (-variable.accessed_by_xinsts[0].index, # Largest (furthest) accessing instruction + *last_x_access, # Oldest accessed cycle (oldest == smallest) + len(variable.accessed_by_xinsts)) # How many more uses this variable has + elif replacement_policy == Constants.REPLACEMENT_POLICY_LRU: + # Priority by oldest accessed cycle (oldest == smallest) + retval = (*last_x_access, ) + else: + raise ValueError(f'`replacement_policy`: invalid value "{replacement_policy}". Expected value in {REPLACEMENT_POLICIES}.') + return retval + +def flushRegisterBank(register_bank, + current_cycle: CycleType, + replacement_policy, + live_var_names = None, + pct: float = 0.5): + """ + Cleans up a register bank by removing variables assigned to registers. + + The function attempts to free up to pct * 100% of registers. Only non-dirty registers + that do not contain live variables are cleaned up. Dummy variables are considered live. + + Args: + register_bank (RegisterBank): + The register bank to clean up. + current_cycle (CycleType): + The current cycle to consider for readiness. + replacement_policy (str): + The policy to use for determining which variables to replace. + Must be one of `Constants.REPLACEMENT_POLICIES`. + live_var_names (set or list, optional): + A collection of variable names that are not available + for replacement. Defaults to None. + pct (float, optional): + The fraction of the register bank to clean up. Defaults to 0.5. + """ + local_heap = PriorityQueue() + occupied_count: int = 0 + for idx, reg in enumerate(register_bank): + # Traverse the registers in the bank and put occupied, non-dirty ones + # in a heap where priority is based on replacement_policy + v = reg.contained_variable + if v is not None: + occupied_count += 1 + if not reg.register_dirty \ + and (v.name and v.name not in live_var_names) \ + and current_cycle >= v.cycle_ready: + # Variable can be cleared from the register if needed + priority = computePriority(v, replacement_policy) + local_heap.push(priority, reg, (idx, )) + + # Clean up registers until we reach the specified pct occupancy or we have + # no registers left that can be cleaned up + while local_heap \ + and occupied_count / register_bank.register_count > pct: + _, reg = local_heap.pop() + reg.allocateVariable(None) + occupied_count -= 1 + +def findAvailableLocation(vars_lst, + live_var_names, + replacement_policy: str = None): + """ + Retrieves the index of the next available location in a collection of Variable objects. + + The function proposes a location to use if all are occupied, based on a replacement policy. + Locations with dummy variables (with empty names) are considered live and will not be selected. + + Args: + vars_lst (iterable): + An iterable collection of Variable objects. Can contain `None`s. + live_var_names (set or list): + A collection of variable names that are not available for replacement. + replacement_policy (str, optional): + The policy to use for determining which variables to replace. + Must be one of `Constants.REPLACEMENT_POLICIES`. Defaults to None. + + Raises: + ValueError: If the replacement policy is invalid. + + Returns: + int: The index of the first empty location found in `vars_lst`, or the index of the suggested + location to replace if `vars_lst` is full and a replacement policy was specified. Returns -1 + if no suitable location is found. + """ + if replacement_policy and replacement_policy not in Constants.REPLACEMENT_POLICIES: + raise ValueError(('`replacement_policy`: invalid value "{}". ' + 'Expected value in {} or None.').format(replacement_policy, + Constants.REPLACEMENT_POLICIES)) + + retval = -1 + priority = (float("inf"), float("inf"), float("inf")) + for idx, v in enumerate(vars_lst): + if not v: + retval = idx + break # Found an empty spot + elif replacement_policy \ + and (v.name and v.name not in live_var_names): # Avoids dummy variables + # Find priority for replacement of this location + v_priority = computePriority(v, replacement_policy) + if v_priority < priority: + retval = idx + priority = v_priority + # At this point, highest priority location has been found in `retval`, if any + + return retval diff --git a/assembler_tools/hec-assembler-tools/assembler/memory_model/memory_bank.py b/assembler_tools/hec-assembler-tools/assembler/memory_model/memory_bank.py new file mode 100644 index 00000000..3029ea0e --- /dev/null +++ b/assembler_tools/hec-assembler-tools/assembler/memory_model/memory_bank.py @@ -0,0 +1,172 @@ +from assembler.common import constants + +class MemoryBank: + """ + Base class for memory banks. + + This class simulates a memory bank and its locations, where each address in the memory bank's buffer + represents a slot with space for a word. + + Constructors: + MemoryBank(data_capacity_words: int) + Creates a MemoryBank with a specified capacity in words. + + fromCapacityBytes(data_capacity_bytes: int) -> MemoryBank + Creates a MemoryBank with a specified capacity in bytes. + + Attributes: + _current_data_capacity_words (int): Protected attribute representing the current capacity in words. + This is typically modified when allocating or deallocating spaces in the memory bank buffer. + + Properties: + CAPACITY (int): Total capacity of the memory bank in bytes. + CAPACITY_WORDS (int): Total capacity of the memory bank in words. + buffer (list): The memory bank's buffer, allocated to hold up to CAPACITY_WORDS objects. + If `buffer[i]` is None, then index `i` (address) is considered an empty/available memory slot. + currentCapacityWords (int): Current available capacity for the memory bank in words. + + Methods: + allocateForce(addr: int, obj: object) + Forces allocation of an existing object at a specific address. + + deallocate(addr: int) -> object + Frees up the slot at the specified memory address in the memory buffer. + """ + + # Constructor wrappers + # -------------------- + + @classmethod + def fromCapacityBytes(cls, data_capacity_bytes: int): + """ + Creates a new MemoryBank object with a specified capacity in bytes. + + Args: + data_capacity_bytes (int): Maximum capacity in bytes for the memory bank. + + Returns: + MemoryBank: A new instance of MemoryBank with the specified capacity. + """ + return cls(constants.convertBytes2Words(data_capacity_bytes)) + + # Constructor + # ----------- + + def __init__(self, + data_capacity_words: int): + """ + Initializes a new MemoryBank object with a specified capacity in words. + + Args: + data_capacity_words (int): Maximum capacity in words for the memory bank. + + Raises: + ValueError: If the capacity is not a positive number. + """ + if data_capacity_words <= 0: + raise ValueError(("`data_capacity_words` must be a positive number, " + "but {} received.".format(data_capacity_words))) + self.__data_capacity_words = data_capacity_words # max capacity in words + self.__data_capacity = constants.convertWords2Bytes(data_capacity_words) + self.__buffer = [None for _ in range(self.__data_capacity_words)] + self._current_data_capacity_words = self.__data_capacity_words + + # Methods and properties + # ---------------------- + + @property + def CAPACITY(self): + """ + Gets the total capacity of the memory bank in bytes. + + Returns: + int: The total capacity in bytes. + """ + return self.__data_capacity + + @property + def CAPACITY_WORDS(self): + """ + Gets the total capacity of the memory bank in words. + + Returns: + int: The total capacity in words. + """ + return self.__data_capacity_words + + @property + def currentCapacityWords(self): + """ + Gets the current available capacity for the memory bank in words. + + Returns: + int: The current available capacity in words. + """ + return self._current_data_capacity_words + + @property + def buffer(self): + """ + Gets the memory bank's buffer. + + Returns: + list: The buffer allocated to hold up to CAPACITY_WORDS objects. + """ + return self.__buffer + + def allocateForce(self, + addr: int, + obj: object): + """ + Force the allocation of an existing object at a specific address. + + Each object is considered to occupy one word. The current capacity is decreased by one word. + This method returns immediately if the object is already allocated to the specified address. + + Args: + addr (int): Address in the memory bank where to allocate the object. Must not be already occupied by a different object. + obj (object): Object to allocate. It will be assigned to `buffer[addr]`. + + Raises: + ValueError: If the address is out of range or already occupied by a different object. + RuntimeError: If the memory bank is out of capacity. + """ + if self.currentCapacityWords <= 0: + raise RuntimeError("Critical error: Out of memory.") + if addr < 0 or addr >= len(self.buffer): + raise ValueError(("`addr` out of range. Must be in range [0, {})," + "but {} received.".format(len(self.buffer), addr))) + if not self.buffer[addr]: + # track the obj our buffer + self.buffer[addr] = obj + # update capacity + self._current_data_capacity_words -= 1 + else: + if self.buffer[addr] != obj: + raise ValueError("`addr` {} already occupied.".format(addr)) + + def deallocate(self, addr) -> object: + """ + Free up the slot at the specified memory address in the memory buffer. + + Args: + addr (int): Address of the memory slot to free. + + Raises: + ValueError: If the address is out of range or already free. + + Returns: + object: The object that was contained in the deallocated slot. + """ + if addr < 0 or addr >= len(self.buffer): + raise ValueError(("`addr` out of range. Must be in range [0, {})," + "but {} received.".format(len(self.buffer), addr))) + + obj = self.buffer[addr] + if not obj: + raise ValueError('`addr`: Adress "{}" is already free.'.format(addr)) + + self.buffer[addr] = None + self._current_data_capacity_words += 1 + + return obj diff --git a/assembler_tools/hec-assembler-tools/assembler/memory_model/register_file.py b/assembler_tools/hec-assembler-tools/assembler/memory_model/register_file.py new file mode 100644 index 00000000..9925be9f --- /dev/null +++ b/assembler_tools/hec-assembler-tools/assembler/memory_model/register_file.py @@ -0,0 +1,393 @@ + +from assembler.common import constants +from assembler.common.cycle_tracking import CycleTracker +from .variable import Variable +from . import mem_utilities as utilities + +class RegisterBank: + """ + Encapsulates a register bank. + + This class provides an iterable over the registers contained in the register bank and + offers methods to retrieve and manage registers. + + Properties: + bank_index (int): Index for the bank as specified during construction. + register_count (int): Number of registers contained in the bank. + + Methods: + getRegister(idx: int) -> Register: + Retrieves the register associated with the specified index. + + findAvailableRegister(live_var_names, replacement_policy: str) -> Register: + Retrieves the next available register or proposes a register to use if all are + occupied, based on a replacement policy. + """ + + class __RBIterator: + """ + Allows iteration over the registers in a register bank. + """ + def __init__(self, obj): + assert obj is not None and obj.register_count > 0 + self.__obj = obj + self.__i = 0 + + def __next__(self): + if self.__i >= self.__obj.register_count: + raise StopIteration() + retval = self.__obj.getRegister(self.__i) + self.__i += 1 + return retval + + # Constructor + # ----------- + + def __init__(self, + bank_index: int, + register_range: range = None): + """ + Constructs a new RegisterBank object. + + Args: + bank_index (int): Zero-based index for the bank to create. The memory model typically has 4 banks, + but this is flexible to create more banks if needed. + register_range (range, optional): A range for the indices of the registers contained in this register bank. + Defaults to `range(constants.MemoryModel.NUM_REGISTER_PER_BANKS)`. + + Raises: + ValueError: If the bank index is negative or if the register range is invalid. + """ + if bank_index < 0: + raise ValueError((f'`bank_index`: expected non-negative a index for bank, ' + f'but {bank_index} received.')) + if not register_range: + register_range = range(constants.MemoryModel.NUM_REGISTER_PER_BANKS) + elif len(register_range) < 1: + raise ValueError((f'`register_range`: expected a range within [0, {constants.MemoryModel.NUM_REGISTER_PER_BANKS}) with, ' + f'at least, 1 element, but {register_range} received.')) + elif abs(register_range.step) != 1: + raise ValueError((f'`register_range`: expected a range within step of 1 or -1, ' + f'but {register_range} received.')) + self.__bank_index = bank_index + # list of registers in this bank + self.__registers = [ Register(self, register_i) for register_i in register_range ] + + # Special methods + # --------------- + + def __iter__(self): + """ + Returns an iterator over the registers in the register bank. + + Returns: + __RBIterator: An iterator over the registers. + """ + return RegisterBank.__RBIterator(self) + + def __repr__(self): + """ + Returns a string representation of the RegisterBank object. + + Returns: + str: A string representation of the RegisterBank. + """ + return '<{} object at {}>(bank_index = {})'.format(type(self).__name__, + hex(id(self)), + self.bank_index) + + # Methods and properties + # ---------------------- + + @property + def bank_index(self) -> int: + """ + Gets the index of the bank. + + Returns: + int: The index of the bank. + """ + return self.__bank_index + + @property + def register_count(self) -> int: + """ + Gets the number of registers in this bank. + + Returns: + int: The number of registers. + """ + return len(self.__registers) + + def getRegister(self, idx: int): + """ + Retrieves the register associated with the specified index. + + Args: + idx (int): Index for the register to retrieve. This can be a negative value. + + Returns: + Register: The register associated with the specified index. + + Raises: + ValueError: If the index is out of range. + """ + if idx < -self.register_count or idx >= self.register_count: + raise ValueError((f'`idx`: expected an index for register in the range [-{self.register_count}, {self.register_count}), ' + f'but {idx} received.')) + return self.__registers[idx] + + def findAvailableRegister(self, + live_var_names, + replacement_policy: str = None): + """ + Retrieve the next available register or propose a register to use if all are occupied. + + Args: + live_var_names (set or list): + A set of variable names containing the variables that are not available for replacement + i.e. live variables. This is used to avoid replacing variables that were just allocated + as dependencies for an upcoming instruction. + + replacement_policy (str, optional): + If specified, it must be a value from `Constants.REPLACEMENT_POLICIES`. Otherwise, + this method will not find a location to replace if all registers are occupied. + Values: + - `Constants.REPLACEMENT_POLICY_FTBU`: suggests replacement of variable that is furthest accessed + (using LRU and number of usages left as tie breakers). + - `Constants.REPLACEMENT_POLICY_LRU`: suggests replacement of the least recently accessed variable. + + Returns: + Register: The first empty register, or the register to replace if all are occupied. Returns None if no suitable register is found. + """ + retval_idx = utilities.findAvailableLocation((register.contained_variable for register in self.__registers), + live_var_names, + replacement_policy) + return self.getRegister(retval_idx) if retval_idx >= 0 else None + + def dump(self, ostream): + """ + Dump the current state of the register bank to the specified output stream. + + Args: + ostream: The output stream to write the register bank state to. + """ + print(f'Register bank, {self.bank_index}', file = ostream) + print(f'Number of registers, {self.register_count}', file = ostream) + print("", file = ostream) + print("register, variable, variable register, dirty", file = ostream) + for idx in range(self.register_count): + register = self.getRegister(idx) + if not register: + print('ERROR: None Register') + else: + var_data = 'None' + variable = register.contained_variable + if variable is not None: + if variable.name: + var_data = '{}, {}'.format(variable.name, + variable.register, + variable.register_dirty) + else: + var_data = f'Dummy_{variable.tag}' + print('{}, {}'.format(register.name, + var_data), + file = ostream) + +class Register(CycleTracker): + """ + Represents a register in the register file. + + Inherits from CycleTracker to manage the cycle when the register is ready to be used. + This class tracks the register name, the variable contained within the register as a form + of inverse look-up, and whether the register contents are "dirty". + + A register is identified by its bank and index inside the bank. The name of the + register is formatted as `rb`. For example, register 5 in bank 1 has the name `r5b1`. + + Properties: + bank (RegisterBank): The bank where this register resides. + name (str): The name of this register, built from the bank and register indices. + register_index (int): The index for this register inside its bank. + register_dirty (bool): Specifies whether the register is "dirty". A register is dirty if it has + been written to but has not been saved into SPAD. + contained_variable (Variable): The variable contained in this register, or None if no variable is currently + contained in this register. This is used as a form of inverse look-up. + """ + + # Constructor + # ----------- + + def __init__(self, + bank: RegisterBank, + register_index: int): + """ + Initializes a new Register object. + + Args: + bank (RegisterBank): The bank to which this register belongs. + register_index (int): The index of the register inside the bank. + + Raises: + ValueError: If the register index is out of the valid range. + """ + if register_index < 0 or register_index >= constants.MemoryModel.NUM_REGISTER_PER_BANKS: + raise ValueError((f'`register_index`: expected an index for register in the range [0, {constants.MemoryModel.NUM_REGISTER_PER_BANKS}), ' + f'but {register_index} received.')) + super().__init__((0, 0)) + self.register_dirty = False + self.__bank = bank + self.__register_index = register_index + self.__contained_var = None + + # Special methods + # --------------- + + def __eq__(self, other): + """ + Checks equality with another Register object. + + Args: + other (Register): The other Register to compare with. + + Returns: + bool: True if the other Register is the same as this one, False otherwise. + """ + return other is self \ + or (isinstance(other, Register) and other.name == self.name) + + def __hash__(self): + """ + Returns the hash of the register's name. + + Returns: + int: The hash of the register's name. + """ + return hash(self.name) + + def __str__(self): + """ + Returns the name of the register as its string representation. + + Returns: + str: The name of the register. + """ + return self.name + + def __repr__(self): + """ + Returns a string representation of the Register object. + + Returns: + str: A string representation of the Register. + """ + var_section = "" + if self.contained_variable: + var_section = "Variable='{}'".format(self.contained_variable.name) + return '<{}({}) object at {}>({})'.format(type(self).__name__, + self.name, + hex(id(self)), + var_section) + + # Methods and properties + # ---------------------- + + @property + def name(self) -> str: + """ + Gets the name of the register. + + Returns: + str: The name of the register. + """ + return f"r{self.register_index}b{self.bank.bank_index}" + + @property + def bank(self) -> RegisterBank: + """ + Gets the bank where this register resides. + + Returns: + RegisterBank: The bank of the register. + """ + return self.__bank + + @property + def register_index(self) -> int: + """ + Gets the index of the register inside its bank. + + Returns: + int: The index of the register. + """ + return self.__register_index + + @property + def contained_variable(self) -> Variable: + """ + Gets or sets the variable contained in this register. + + Returns: + Variable: The variable contained in this register, or None if no variable is contained. + """ + return self.__contained_var + + def _set_contained_variable(self, value): + """ + Sets the variable contained in this register. + + Args: + value (Variable): The variable to set, or None to clear the register. + + Raises: + ValueError: If the value is not a Variable. + """ + if value: + if not isinstance(value, Variable): + raise ValueError('`value`: expected a `Variable`.') + self.__contained_var = value + # register no longer dirty because we are overwriting it with new variable (or None to clear) + self.register_dirty = False + + def allocateVariable(self, variable: Variable = None): + """ + Allocates the specified variable into this register, or frees this register if + the specified variable is None. + + The register and the newly allocated variable are no longer dirty after this allocation. + + Args: + variable (Variable, optional): The variable to allocate, or None to free the register. + """ + old_var: Variable = self.contained_variable + if old_var: + # make old variable aware that it is no longer in this register + assert(not old_var.register_dirty) # we should not be deallocating dirty variables + old_var.register = None + if variable: + # make variable aware of new register + old_reg = variable.register + if old_reg: + # free old register, if any + old_reg._set_contained_variable(None) + variable.register = self + + self._set_contained_variable(variable) + + def toCASMISAFormat(self) -> str: + """ + Converts the register to CInst ASM-ISA format. + + Returns: + str: The CInst ASM-ISA format of the register. + """ + return self.name + + def toXASMISAFormat(self) -> str: + """ + Converts the register to XInst ASM-ISA format. + + Returns: + str: The XInst ASM-ISA format of the register. + """ + return self.name diff --git a/assembler_tools/hec-assembler-tools/assembler/memory_model/spad.py b/assembler_tools/hec-assembler-tools/assembler/memory_model/spad.py new file mode 100644 index 00000000..dc32fc32 --- /dev/null +++ b/assembler_tools/hec-assembler-tools/assembler/memory_model/spad.py @@ -0,0 +1,322 @@ +import itertools + +from assembler.common.constants import MemoryModel as mmconstants +from assembler.common.counter import Counter +from assembler.common.decorators import * +from .memory_bank import MemoryBank +from .variable import Variable +from . import mem_utilities as utilities + +class SPAD(MemoryBank): + """ + Encapsulates the SRAM cache, also known as SPAD, within the memory model. + + This class provides methods for managing the allocation and deallocation of variables + within the SPAD, as well as methods for tracking access and finding available addresses. + + Constructors: + SPAD(data_capacity_words: int) -> SPAD + Creates a new SPAD object with a specified capacity in words. + + fromCapacityBytes(data_capacity_bytes: int) -> SPAD + Creates a SPAD object from a specified capacity in bytes. + + Properties: + buffer (list): Inherited property that returns the SPAD's buffer, allocated to hold up to + CAPACITY_WORDS `Variable` objects. + + Methods: + allocateForce(addr: int, variable: Variable) + Forces the allocation of an existing `Variable` object at a specific address. + + deallocate(addr: int) -> Variable + Frees up the slot at the specified memory address in the memory buffer. + + findAvailableAddress(live_var_names: set or list, replacement_policy: str = None) -> int + Retrieves the next available SPAD address or proposes an address to use if all are + occupied, based on a replacement policy. + """ + + class AccessTracker: + """ + Tracks access to SPAD addresses by various instructions. + + This class maintains a count and the last access instruction for each type of access, + allowing clients to determine the order of accesses. + """ + + __idx_counter = Counter.count(0) # internal unique sequence counter to generate monotonous indices + + def __init__(self, + last_mload = None, + last_mstore = None, + last_cload = None, + last_cstore = None): + self.__last_mload = (next(SPAD.AccessTracker.__idx_counter), last_mload) + self.__last_mstore = (next(SPAD.AccessTracker.__idx_counter), last_mstore) + self.__last_cload = (next(SPAD.AccessTracker.__idx_counter), last_cload) + self.__last_cstore = (next(SPAD.AccessTracker.__idx_counter), last_cstore) + + @property + def last_mload(self) -> tuple: + """ + Retrieves the last `mload` access. + + Retrieved tuple contains: + count - a count number that can be used to compare with other accesses in this object. + This value monotonically increases with each access regardless of access type. It + can be used to identify which access occurred first when two accesses are needed. + minstr - the last `mload` instruction to access, or None, if no last access. + + Returns: + tuple: A tuple containing a count and the last `mload` instruction. + """ + return self.__last_mload + + @last_mload.setter + def last_mload(self, value: object): + self.__last_mload = (next(SPAD.AccessTracker.__idx_counter), value) + + @property + def last_mstore(self) -> tuple: + """ + Gets the last `mstore` access. + + Returns: + tuple: A tuple containing a count and the last `mstore` instruction. + """ + return self.__last_mstore + + @last_mstore.setter + def last_mstore(self, value: object): + self.__last_mstore = (next(SPAD.AccessTracker.__idx_counter), value) + + @property + def last_cload(self) -> tuple: + """ + Gets the last `cload` access. + + Returns: + tuple: A tuple containing a count and the last `cload` instruction. + """ + return self.__last_cload + + @last_cload.setter + def last_cload(self, value: object): + self.__last_cload = (next(SPAD.AccessTracker.__idx_counter), value) + + @property + def last_cstore(self) -> tuple: + """ + Gets the last `cstore` access. + + Returns: + tuple: A tuple containing a count and the last `cstore` instruction. + """ + return self.__last_cstore + + @last_cstore.setter + def last_cstore(self, value: object): + self.__last_cstore = (next(SPAD.AccessTracker.__idx_counter), value) + + # Constructor + # ----------- + + def __init__(self, + data_capacity_words: int): + """ + Initializes a new SPAD object representing the SRAM cache or scratchpad. + + Args: + data_capacity_words (int): Capacity in words for the SPAD. + + Raises: + ValueError: If the capacity exceeds the maximum allowed capacity. + """ + # validate input + if data_capacity_words > mmconstants.SPAD.MAX_CAPACITY_WORDS: + raise ValueError(("`data_capacity_words` must be in the range (0, {}], " + "but {} received.").format(mmconstants.SPAD.MAX_CAPACITY_WORDS, data_capacity_words)) + + # initialize base + super().__init__(data_capacity_words) + self.__var_lookup = {} # dict(var_name: str, variable: Variable) - reverse look-up on variable name + self.__access_tracker = [ SPAD.AccessTracker() for _ in range(len(self.buffer)) ] + + # Special methods + # --------------- + + def __contains__(self, var_name): + """ + Checks if a variable name is contained within the SPAD. + + Args: + var_name (str or Variable): The variable name or Variable object to check. + + Returns: + bool: True if the variable is contained within the SPAD, False otherwise. + """ + return self._contains(var_name.name) if isinstance(var_name, Variable) else self._contains(var_name) + + def __getitem__(self, key): + """ + Retrieves a contained Variable object by name or index. + + Args: + key (str or int): The variable name or index to retrieve. + + Returns: + Variable: The contained Variable object, or None if not found. + """ + return self.findContainedVariable(key) if isinstance(key, str) else self.buffer[key] + + # Methods and properties + # ---------------------- + + def getAccessTracking(self, spad_address: int) -> AccessTracker: + """ + Gets the access tracker object for the specified SPAD address. + + This is used to track last access to specified SPAD address by CInstructions + and MInstructions. See `AccessTracker` for tracking information. + + Clients can either use the returned object to query for last access or + to specify a new last access. + + Args: + spad_address (int): SPAD address for which access tracking is requested. + + Returns: + AccessTracker: A mutable AccessTracker object containing the last access instructions. + + Raises: + IndexError: If the SPAD address is out of range. + """ + if spad_address < 0 or spad_address >= len(self.__access_tracker): + raise IndexError("`spad_address` out of range.") + return self.__access_tracker[spad_address] + + def _contains(self, var_name) -> bool: + """ + Checks if a variable name is contained within the SPAD. + + Args: + var_name (str): The variable name to check. + + Returns: + bool: True if the variable is contained within the SPAD, False otherwise. + """ + return var_name in self.__var_lookup + + def findContainedVariable(self, var_name: str) -> Variable: + """ + Retrieves a contained Variable object by name. + + Args: + var_name (str): The name of the variable to retrieve. + + Returns: + Variable: The contained Variable object, or None if not found. + """ + return self.__var_lookup[var_name] if var_name in self.__var_lookup else None + + def allocateForce(self, + addr: int, + variable: Variable): + """ + Forces the allocation of an existing `Variable` object at a specific address. + + Args: + addr (int): Address in SPAD where to allocate the `Variable` object. + variable (Variable): Variable object to allocate. The variable's spad_address must be clear. + + Raises: + ValueError: If the variable is already allocated or if there is a conflicting allocation. + RuntimeError: If the SPAD is out of capacity. + """ + if variable.spad_address < 0: + assert(variable.name not in self.__var_lookup) + # Allocate variable in SPAD + super().allocateForce(addr, variable) + variable.spad_address = addr + if variable.name: # avoid dummy vars + self.__var_lookup[variable.name] = variable + elif addr >= 0 and variable.spad_address != addr: + # Multiple allocations not allowed + raise ValueError(('`variable` already allocated in address "{}", ' + 'but new allocation requested in address "{}".'.format(variable.spad_address, + addr))) + + def deallocate(self, addr) -> object: + """ + Frees up the slot at the specified memory address in the memory buffer. + + Args: + addr (int): Address of the memory slot to free. + + Raises: + ValueError: If the address is invalid or already free. + + Returns: + Variable: The Variable object that was contained in the deallocated slot. + """ + retval = super().deallocate(addr) + retval.spad_address = -1 # deallocate variable + if retval.name: # avoid dummy vars + self.__var_lookup.pop(retval.name) + return retval + + def findAvailableAddress(self, + live_var_names, + replacement_policy: str = None) -> int: + """ + Retrieves the next available SPAD address or propose an address to use if all are occupied. + + Args: + live_var_names (set or list): A collection of variable names that are not available for replacement. + replacement_policy (str, optional): The policy to use for determining which variables to replace. + + Returns: + int: The first empty address, or the address to replace if all are occupied. Returns -1 if no suitable address is found. + """ + return utilities.findAvailableLocation(self.buffer, + live_var_names, + replacement_policy) + + def dump(self, ostream): + """ + Dumps the current state of the SPAD to the specified output stream. + + Args: + ostream: The output stream to write the SPAD state to. + """ + print('SPAD', file = ostream) + print(f'Max Capacity, {self.CAPACITY}, Bytes', file = ostream) + print(f'Max Capacity, {self.CAPACITY_WORDS}, Words', file = ostream) + print(f'Current Capacity, {self.currentCapacityWords}, Words', file = ostream) + print(f'Current Occupied, {self.CAPACITY_WORDS - self.currentCapacityWords}, Words', file = ostream) + print("", file = ostream) + print("address, variable, variable spad, dirty, last mload, last mstore, last cload, last cstore", file = ostream) + last_addr = 0 + for addr, variable in enumerate(self.buffer): + if variable is not None: + for idx in range(last_addr, addr): + # empty addresses + print(f'{idx}, None', file = ostream) + if variable.name: + spad_access_tracker = self.getAccessTracking(addr) + print('{}, {}, {}, {}, {}, {}, {}'.format(addr, + variable.name, + variable.spad_address, + variable.spad_dirty, + repr(spad_access_tracker.last_mload), + repr(spad_access_tracker.last_mstore), + repr(spad_access_tracker.last_cload), + repr(spad_access_tracker.last_cstore)), + + file = ostream) + else: + print('f{addr}, Dummy_{variable.tag}', + file = ostream) + + last_addr = addr + 1 diff --git a/assembler_tools/hec-assembler-tools/assembler/memory_model/variable.py b/assembler_tools/hec-assembler-tools/assembler/memory_model/variable.py new file mode 100644 index 00000000..1bf179fc --- /dev/null +++ b/assembler_tools/hec-assembler-tools/assembler/memory_model/variable.py @@ -0,0 +1,431 @@ +import re +from typing import NamedTuple + +from assembler.common import constants +from assembler.common.config import GlobalConfig +from assembler.common.cycle_tracking import CycleTracker, CycleType + +class Variable(CycleTracker): + """ + Class to represent a variable within a memory model. + + Inherits from CycleTracker to manage the cycle when the variable is ready to be used. + This class tracks the variable's name, its location across the memory model, and its readiness cycle. + + Attributes: + hbm_address (int): + HBM data region address (zero-based word index) where this variable is stored. + Set to -1 if not stored. + accessed_by_xinsts (list[AccessElement]): + List of XInstruction IDs that will access this variable. + The elements of the list also contain an ordering index estimating the index of the instruction + in the instructions listing. + last_x_access (XInstruction): + Last XInstruction that accessed this variable (either read or write). + + Properties: + name (str): + Name of the variable. + suggested_bank (int): + The suggested bank for the variable in the range [0, NUM_REGISTER_BANKS) + or a negative number if no bank is suggested. Setting it to a negative number is ignored. + register (Register): + Specifies the register for this variable. `None` if not allocated in a register. + register_dirty (bool): + Specifies whether the register for this variable is "dirty". + spad_address (int): + Specifies the SPAD address for this variable. -1 if not stored in SPAD. + spad_dirty (bool): + Specifies whether the SPAD location for this variable is "dirty". + """ + + class AccessElement(NamedTuple): + """ + Structured tuple to contain an instruction ID and its index in the ordered instruction listing. + + Attributes: + index (int): The index of the instruction in the listing. + instruction_id (tuple): The ID of the instruction. + """ + index: int + instruction_id: tuple + + # Static methods + # -------------- + + @classmethod + def parseFromPISAFormat(cls, s_pisa: str): + """ + Parses a `Variable` from P-ISA format and return a tuple that can be used to construct a `Variable` object. + + Args: + s_pisa (str): String containing the P-ISA variable format. It has the form: + ` ()` where is optional (parentheses are ignored). + + Raises: + ValueError: If the input is in an invalid P-ISA format. + + Returns: + tuple: A tuple representing the parsed information that can be used to construct a `Variable` object. + """ + tokens = list(map(lambda s: s.strip(), s_pisa.split())) + if len(tokens) > 2 or len(tokens) < 1: + raise ValueError(f'Invalid format for P-ISA variable: {s_pisa}.') + if len(tokens) < 2: + # default to suggested bank -1 + tokens.append(-1) + else: + tokens[1] = int(tokens[1].strip("()")) + return tuple(tokens) + + @classmethod + def validateName(cls, name: str) -> bool: + """ + Validates whether a name is an appropriate identifier for a variable. + + Args: + name (str): Variable name to validate. + + Returns: + bool: True if the name is a valid variable identifier, False otherwise. + """ + retval = True + if name: + name = name.strip() + if not name: + retval = False + if retval and not re.search('^[A-Za-z_][A-Za-z0-9_]*', name): + retval = False + return retval + + + # Constructor + # ----------- + + def __init__(self, + var_name: str, + suggested_bank: int = -1): + """ + Constructs a new Variable object with a specified name and suggested bank number. + + Args: + var_name (str): Name of the variable. Must be an identifier. + suggested_bank (int, optional): Suggested bank for the variable in the range [0, NUM_REGISTER_BANKS) + or a negative number if no bank is suggested. Defaults to -1. + + Raises: + ValueError: If the variable name is invalid or the suggested bank is out of range. + """ + + # validate the variable name to be an identifier + if not self.validateName(var_name): + raise ValueError((f'`var_name`: Invalid variable name "{var_name}".')) + self.__var_name = var_name.strip() + # validate bank number + if suggested_bank >= constants.MemoryModel.NUM_REGISTER_BANKS: + raise ValueError(("`suggested_bank`: Expected negative to indicate no " + "suggestion or a bank index less than {}, but {} received.").format( + constants.MemoryModel.NUM_REGISTER_BANKS, suggested_bank)) + + super().__init__(CycleType(0, 0)) # cycle ready in the form (bundle, clock_cycle) + + self.__suggested_bank = suggested_bank + # HBM data region address (zero-based word index) where this variable is stored. + # Set to -1 if not stored. + self.hbm_address = -1 + self.__spad_address = -1 + self.__spad_dirty = False + self.__register = None # Register + self.__register_dirty = False + self.accessed_by_xinsts = [] # list of AccessElements containing instruction IDs that access this variable + self.last_x_access = None # last xinstruction that accessed this variable + + # Special methods + # --------------- + + def __repr__(self): + """ + Returns a string representation of the Variable object. + + Returns: + str: A string representation. + """ + retval = '<{} object at {}>(var_name="{}", suggested_bank={})'.format(type(self).__name__, + hex(id(self)), + self.name, + self.suggested_bank) + return retval + + def __str__(self): + """ + Returns the name of the variable as its string representation. + + Returns: + str: The name of the variable. + """ + return self.name + + def __eq__(self, other): + """ + Checks equality with another Variable object. + + Args: + other (Variable): The other Variable to compare with. + + Returns: + bool: True if the other Variable is the same as this one, False otherwise. + """ + return other is self + + def __hash__(self): + """ + Returns the hash of the variable's name. + + Returns: + int: The hash. + """ + return hash(self.name) + + # Methods and properties + # ---------------------- + + def _get_var_name(self): + """ + Gets the name of the variable. + + Returns: + str: The name of the variable. + """ + return self.__var_name + + @property + def name(self): + """ + Gets the name of the variable. + + Returns: + str: The name of the variable. + """ + return self._get_var_name() + + @property + def suggested_bank(self): + """ + Gets or sets the suggested bank for the variable. + + Returns: + int: The suggested bank for the variable. + """ + return self.__suggested_bank + + @suggested_bank.setter + def suggested_bank(self, value: int): + if value >= constants.MemoryModel.NUM_REGISTER_BANKS: + raise ValueError('`value`: must be in range [0, {}), but {} received.'.format(constants.MemoryModel.NUM_REGISTER_BANKS, + str(value))) + if value >= 0: # ignore negative values + self.__suggested_bank = value + + @property + def register(self): + """ + Gets or sets the register for this variable. + + Returns: + Register: The register for this variable, or `None` if not allocated in a register. + """ + return self.__register + + @register.setter + def register(self, value): + self._set_register(value) + + def _set_register(self, value): + from .register_file import Register + if value: + if not isinstance(value, Register): + raise ValueError(('`value`: expected a `Register`, but received a `{}`.'.format(type(value).__name__))) + self.__register = value + else: + self.__register = None + self.register_dirty = False + self.last_x_access = None # new Register, so, no XInst access yet + + @property + def register_dirty(self) -> bool: + """ + Gets or sets whether the register for this variable is "dirty". + + Returns: + bool: True if the register is dirty, False otherwise. + """ + return self.register.register_dirty if self.register else False + + @register_dirty.setter + def register_dirty(self, value: bool): + if self.register: + self.register.register_dirty = value + + @property + def spad_address(self) -> int: + """ + Gets or sets the SPAD address for this variable. + + Returns: + int: The SPAD address, or -1 if not stored in SPAD. + """ + return self.__spad_address + + @spad_address.setter + def spad_address(self, value: int): + self._set_spad_address(value) + + def _set_spad_address(self, value: int): + self.spad_dirty = False # SPAD is no longer dirty because we are overwriting it + if value < 0: + self.__spad_address = -1 + else: + self.__spad_address = value + + @property + def spad_dirty(self) -> bool: + """ + Gets or sets whether the SPAD location for this variable is "dirty". + + Returns: + bool: True if the SPAD location is dirty, False otherwise. + """ + return self.spad_address >= 0 and self.__spad_dirty + + @spad_dirty.setter + def spad_dirty(self, value: bool): + self.__spad_dirty = value + + def _get_cycle_ready(self) -> CycleType: + """ + Returns the current value for the ready cycle. + + Ready cycle for a variable is the maximum among its internal ready cycle and + the ready cycle of any of its locations (currently, only registers have a ready cycle). + + Returns: + CycleType: The current value for the ready cycle. + """ + retval = super()._get_cycle_ready() + if self.register and self.register.cycle_ready > retval: + retval = self.register.cycle_ready + + return retval + + def toPISAFormat(self) -> str: + """ + Converts the variable to P-ISA kernel format. + + Returns: + str: The P-ISA format of the variable. + """ + retval = f'{self.name}' + if self.suggested_bank >= 0: + retval += f' ({self.suggested_bank})' + return retval + + def toXASMISAFormat(self) -> str: + """ + Converts the variable to XInst ASM-ISA format. + + Returns: + str: The XInst ASM-ISA format of the variable. + + Raises: + RuntimeError: If the variable is not allocated to a register. + """ + if not self.register: + raise RuntimeError("`Variable` object not allocated to register. Cannot convert to XInst ASM-ISA format.") + return self.register.toXASMISAFormat() + + def toCASMISAFormat(self) -> str: + """ + Converts the variable to CInst ASM-ISA format. + + Returns: + str: The CInst ASM-ISA format of the variable. + + Raises: + RuntimeError: If the variable is not stored in SPAD. + """ + if self.spad_address < 0: + raise RuntimeError("`Variable` object not allocated in SPAD. Cannot convert to CInst ASM-ISA format.") + return self.spad_address if GlobalConfig.hasHBM else self.name + + def toMASMISAFormat(self) -> str: + """ + Converts the variable to MInst ASM-ISA format. + + Returns: + str: The MInst ASM-ISA format of the variable. + + Raises: + RuntimeError: If the variable is not stored in HBM. + """ + if self.hbm_address < 0: + raise RuntimeError("`Variable` object not allocated in HBM. Cannot convert to MInst ASM-ISA format.") + return self.name if GlobalConfig.useHBMPlaceHolders else self.hbm_address + +def findVarByName(vars_lst, var_name: str) -> Variable: + """ + Finds the first variable in an iterable of Variable objects that matches the specified name. + + Args: + vars_lst (iterable[Variable]): An iterable collection of Variable objects. + var_name (str): The name of the variable to find in `vars_lst`. + + Returns: + Variable: The first Variable object in `vars_lst` with a name matching `var_name`, or None if no match is found. + """ + return next((var for var in vars_lst if var.name == var_name), None) + +class DummyVariable(Variable): + """ + Represents a dummy variable used as a placeholder. + + A dummy variable serves as a placeholder to indicate registers that will be available in the next bundle, + but not in the current one, such as after a `move` operation. It can be identified by its empty name. + """ + + # Constructor + # ----------- + + def __init__(self, tag = None): + """ + Initializes a new DummyVariable object. + + Args: + tag (optional): An optional tag to associate with the dummy variable. Defaults to 0 if not provided. + """ + super().__init__("dummy") + self.tag = 0 if tag is None else tag + + def _get_var_name(self): + """ + Get the name of the dummy variable. + + Returns: + str: An empty string, indicating the variable is a dummy. + """ + return "" + + def _set_register(self, value): + """ + Overrides the method to set the register for the dummy variable. + + This method does nothing for a dummy variable. + """ + pass + + def _set_spad_address(self, value: int): + """ + Overrides the method to set the SPAD address for the dummy variable. + + This method does nothing for a dummy variable. + """ + pass diff --git a/assembler_tools/hec-assembler-tools/assembler/stages/__init__.py b/assembler_tools/hec-assembler-tools/assembler/stages/__init__.py new file mode 100644 index 00000000..98e940f8 --- /dev/null +++ b/assembler_tools/hec-assembler-tools/assembler/stages/__init__.py @@ -0,0 +1,29 @@ +import networkx as nx +from assembler.memory_model.variable import Variable + +def buildVarAccessListFromTopoSort(dependency_graph: nx.DiGraph): + """ + Given the dependency directed acyclic graph of XInsts, builds the list of + estimated usage order for the variables. + + This is used when deciding which variable to evict from register files or SPAD when + a memory location is needed and all are occupied (furthest used: FTBU). Least recently + used (LRU) is used as tie breaker. + + Usage order is estimated because order of instructions may change based on their + dependencies and timings during scheduling. + + Returns: + list(instruction_id: tuple): + The topological sort of the instructions. Since the topological sort is required + for this function, it is returned to caller to be reused if needed. + """ + + topo_sort = list(nx.topological_sort(dependency_graph)) + for idx, node in enumerate(topo_sort): + instr = dependency_graph.nodes[node]['instruction'] + vars = set(instr.sources + instr.dests) + for v in vars: + v.accessed_by_xinsts.append(Variable.AccessElement(idx, instr.id)) + + return topo_sort diff --git a/assembler_tools/hec-assembler-tools/assembler/stages/asm_scheduler.py b/assembler_tools/hec-assembler-tools/assembler/stages/asm_scheduler.py new file mode 100644 index 00000000..29e38374 --- /dev/null +++ b/assembler_tools/hec-assembler-tools/assembler/stages/asm_scheduler.py @@ -0,0 +1,2794 @@ +import warnings +from typing import NamedTuple +import networkx as nx + +from . import buildVarAccessListFromTopoSort +from assembler.common.config import GlobalConfig +from assembler.common import constants +from assembler.common.cycle_tracking import CycleType +from assembler.common.priority_queue import PriorityQueue +from assembler.common.queue_dict import QueueDict +from assembler.instructions import xinst, cinst, minst +from assembler.memory_model import mem_utilities +from assembler.memory_model import MemoryModel +from assembler.memory_model.variable import Variable, DummyVariable +from assembler.memory_model.register_file import Register, RegisterBank + +Constants = constants.Constants + +# TODO: +# - Keep instruction being processed into the next bundle if bundle needed to be flushed in mid preparation. +# - Refactor class `Simulation` +# + +# TODO: +# Add auto_allocate as configurable? +# -------------------- + +# FUTURE: +# - Analyze about adding instruction window to dependecy graph creation. +# - Analize about adding terms to priority that will prioritize P-ISA instructions over all others as tie-breaker +# in simulation priority queue. +# Maybe add a way to track preparation stage of instructions as part of the priority. +# - Separate variable xinst usage by inputs and outputs to avoid xstoring vars where next usage is a write. +# May require reorganization and book keeping, look ahead, etc. + +auto_allocate = True + +class XStoreAssign(xinst.XStore): + """ + Encapsulates a compound operation of an `xstore` instruction and a + register assignment. + + This is used for variable eviction from the register file, + when the register being flushed is needed for a new variable. + """ + def __init__(self, + id: int, + src: list, + mem_model: MemoryModel, + var_target: Variable, + dest_spad_addr: int = -1, + throughput : int = None, + latency : int = None, + comment: str = ""): + """ + Constructs a new `XStoreAssign` object. + + Parameters: + id (int): User-defined ID for the instruction. It will be bundled with a nonce to form a unique ID. + src (list): A list containing a single Variable object indicating the source variable to store into SPAD. + Variable must be assigned to a register. + Variable `spad_address` must be negative (not assigned) or match the address of the corresponding + `cstore` instruction. + mem_model (MemoryModel): The memory model containing the variables. + var_target (Variable): Variable object that will be allocated in the freed-up register after scheduling the corresponding + `xstore` instruction. + dest_spad_addr (int, optional): The destination SPAD address. Defaults to -1. + throughput (int, optional): The throughput of the instruction. Defaults to None. + latency (int, optional): The latency of the instruction. Defaults to None. + comment (str, optional): A comment associated with the instruction. Defaults to an empty string. + + Raises: + ValueError: If `var_target` is an invalid empty or dummy `Variable` object. + """ + if not var_target or isinstance(var_target, DummyVariable): + raise ValueError('`var_target`: Invalid empty or dummy `Variable` object.') + super().__init__(id, src, mem_model, dest_spad_addr, throughput, latency, comment) + self.__var_target = var_target + + def _schedule(self, cycle_count: CycleType, schedule_id: int) -> int: + """ + Schedules the instruction, simulating the timing of executing this instruction. + + The ready cycle for all destinations is updated based on input `cycle_count` and + this instruction latency. The register is then allocated to the target variable. + + Parameters: + cycle_count (CycleType): Current cycle of execution. + schedule_id (int): The schedule identifier. + + Returns: + int: The throughput for this instruction, i.e., the number of cycles by which to advance + the current cycle counter. + """ + register = self.sources[0].register + retval = super()._schedule(cycle_count, schedule_id) + # Perform assignment of register + register.allocateVariable(self.__var_target) + return retval + +class BundleData(NamedTuple): + """ + Structure for a completed bundle of instructions. + + Attributes: + xinsts (list): List of XInstruction objects in the bundle. + latency (int): Total latency of the bundle. + latency_from_xstore (int): Bundle latency from the last XStore instruction. This is less than or equal to `latency`. + It is used to track latency before scheduling the next bundle with ifetch more easily, + and to attempt to avoid too many idle cycles. + """ + xinsts: list + latency: int + latency_from_xstore: int + +class XWriteCycleTrack(NamedTuple): + """ + Tracks the cycle where a write occurs to the register file by an XInstruction + and which banks are being written to. + + Attributes: + cycle (CycleType): The cycle in which the write happens. + banks (set): A set of indices of banks being written to in this cycle. + """ + cycle: CycleType + banks: set + +class CurrentrShuffleTable(NamedTuple): + """ + Tracks the current rShuffle routing table. + + Attributes: + r_type (type): The type of rShuffle currently loaded. It can be one of {rShuffle, irShuffle, None}. + bundle (int): The bundle where the specified r_type was set. + """ + r_type: type + bundle: int + +class Simulation: + """ + Simulates the scheduling of instructions in a dependency graph. + + Attributes: + INSTRUCTION_WINDOW_SIZE (int): The size of the instruction window. + MIN_INSTRUCTIONS_IN_TOPO_SORT (int): The minimum number of instructions in the topological sort. + BUNDLE_INSTRUCTION_MIN_LIMIT (int): The minimum number of instructions for a bundle to be considered short. + + Methods: + addXInstrBackIntoPipeline(xinstr: object): + Adds an instruction back into the pipeline. + addXInstrToTopoSort(xinstr_id: tuple): + Adds an instruction to the topological sort. + addDependency(new_dependency_instr, original_instr): + Adds a new dependency instruction to the instruction listing. + addLiveVar(var_name: str, instr): + Adds a live variable to the current bundle. + addUsedVar(var_name: str, instr): + Removes a used variable from the current bundle. + appendXInstToBundle(xinstr): + Appends an XInstruction to the current bundle. + cleanupPendingWriteCycles(): + Cleans up pending write cycles that have passed. + canSchedulerShuffle(xinstr) -> CycleType: + Checks whether the specified xrshuffle XInst can be scheduled now. + canSchedulerShuffleType(xinstr) -> bool: + Checks whether the specified rshuffle XInst can be scheduled now. + canScheduleArithmeticXInstr(xinstr: xinst.XInstruction) -> bool: + Checks whether the specified XInst can be scheduled now based on the currently loaded metadata. + findNextInstructionToSchedule() -> object: + Finds the next instruction to schedule. + flushBundle(): + Flushes the current bundle. + flushOutputVariableFromRegister(variable, xinstr = None) -> bool: + Flushes an output variable from the register. + generateKeyMaterial(instr_id: int, variable: Variable, register: Register, dep_id = None) -> int: + Generates key material for the specified variable. + loadrShuffleRoutingTable(rshuffle_data_type_name: str): + Queues CInstructions needed to load the `rshuffle` routing table into CE. + loadBOnesMetadata(spad_addr_offset: int, ones_metadata_segment: int) -> int: + Queues MInstructions and CInstructions needed to load the Ones metadata. + loadTwiddleMetadata(spad_addr_offset: int, twid_metadata_segment: int): + Queues MInstructions and CInstructions needed to load the Twiddle factor generation metadata. + loadKeygenSeedMetadata(spad_addr_offset: int, kgseed_idx: int) -> int: + Queues MInstructions and CInstructions needed to load a new keygen seed. + loadMetadata(): + Loads initial metadata at the start of the program. + prepareShuffleMetadata(spad_addr_offset: int) -> int: + Queues MInstructions needed to load the `rshuffle` metadata into SPAD. + priority_queue_push(xinstr, tie_breaker = None): + Adds a new instruction to the priority queue. + priority_queue_remove(xinstr): + Removes an instruction from the priority queue. + queueCSyncmLoad(instr_id: int, source_spad_addr: int): + Checks if needed, and, if so, queues a CSyncm CInstruction to sync to SPAD access from HBM. + queueMLoad(instr_id: int, target_spad_addr: int, variable, comment = ""): + Generates instructions to copy from HBM into SPAD. + queueMSynccLoad(instr_id: int, target_spad_addr: int): + Checks if needed, and, if so, queues an MSyncc MInstruction to sync to SPAD access. + updateQueuesSyncsPass2(): + Updates the msyncc and csyncm to correct instruction index after the scheduling completes. + updateSchedule(instr) -> bool: + Updates the simulation pending schedule after `instr` has been scheduled. + """ + + INSTRUCTION_WINDOW_SIZE = 100 + MIN_INSTRUCTIONS_IN_TOPO_SORT = 10 + # Amount of instructions for a bundle to be considered short + BUNDLE_INSTRUCTION_MIN_LIMIT = Constants.MAX_BUNDLE_SIZE // 4 # 10 + + def __init__(self, + dependency_graph: nx.DiGraph, + max_bundle_size: int, # Max number of instructions in a bundle + mem_model: MemoryModel, + replacement_policy: str, + progress_verbose: bool): + """ + Initializes the simulation of schedule. + + Parameters: + dependency_graph (nx.DiGraph): The dependency graph of instructions. + max_bundle_size (int): The maximum number of instructions in a bundle. + mem_model (MemoryModel): The memory model containing the variables. + replacement_policy (str): The replacement policy to use. + progress_verbose (bool): If True, enables verbose progress output. + """ + assert max_bundle_size == Constants.MAX_BUNDLE_SIZE + + self.__mem_model = mem_model + self.__replacement_policy = replacement_policy + + self.minsts = [] + self.cinsts = [] + self.xinsts = [] # List of bundles + + # Scheduling vars + + self.current_cycle = CycleType(bundle = len(self.xinsts), cycle = 1) + self.full_topo_sort = buildVarAccessListFromTopoSort(dependency_graph) + self.topo_start_idx = 0 # Starting index of the instruction window in full topo_sort + self.topo_sort = [] # Current slice of topo sort being scheduled + self.b_topo_sort_changed = True # All changed-tracking flags start as true because scheduling has changed (brought into existence) + self.dependency_graph = nx.DiGraph(dependency_graph) # Make a copy of the incoming graph to avoid modifying input + self.b_dependency_graph_changed = True + # Contains instructions without parent dependencies: sorted list by priority: ready cycle + # (never edit directly unless absolutely necessary; use priority_queue_remove/push instead) + self.priority_queue = PriorityQueue() + self.xstore_pq = PriorityQueue() # Sorted list by priority: ready cycle + self.b_priority_queue_changed = True # Tracks when there are changes in the priority queue + self.total_idle_cycles = 0 + # Tracks instructions that are in priority queue or have been removed from graph to avoid encountering + # if duplicated in the topo sort (instructions are only added to this when extracting them from the topo sort) + # (instructions that get pushed back into the topo sort for any reason, are removed from this set) + self.set_extracted_xinstrs = set() + + # Tracks the last xrshuffle scheduled + # self.last_rshuffle_cycle = CycleType(bundle = -1, cycle = 0) + self.last_xrshuffle = None + + # Bundle vars + + self.__max_bundle_size = max_bundle_size + self.b_empty_bundle: bool = False # Tracks if last bundle was empty + # Tracks if last bundle was flushed with very few instructions + self.num_short_bundles: int = 0 + # Local dummy variable to be updated per bundle: used to indicate that a register in bank 0 is live + # for current_cycle.bundle and should not be used by CInsts until next bundle + self.bundle_dummy_var = DummyVariable(self.current_cycle.bundle) + # Tracks instructions in current bundle getting constructed + # (never add to this manually, use appendXInstToBundle() method instead) + self.xinsts_bundle = [] + self.current_bundle_latency = 0 # Tracks current bundle latency + self.pre_bundle_csync_minstr = (0, None) + self.post_bundle_cinsts = [] + # Initial value for live vars (these will always be live) + self.live_vars_0 = dict() + # Add meta variables as always live vars: + # rshuffle routing tables + if self.mem_model.meta_ntt_aux_table: + self.live_vars_0[self.mem_model.meta_ntt_aux_table] = None + if self.mem_model.meta_ntt_routing_table: + self.live_vars_0[self.mem_model.meta_ntt_routing_table] = None + if self.mem_model.meta_intt_aux_table: + self.live_vars_0[self.mem_model.meta_intt_aux_table] = None + if self.mem_model.meta_intt_routing_table: + self.live_vars_0[self.mem_model.meta_intt_routing_table] = None + # Meta ones + for meta_ones_vars_segment in self.mem_model.meta_ones_vars_segments: + for meta_ones_var_name in meta_ones_vars_segment: + self.live_vars_0[meta_ones_var_name] = None + # Meta twids + for meta_twid_vars_segment in self.mem_model.meta_twiddle_vars_segments: + for meta_twid_var_name in meta_twid_vars_segment: + self.live_vars_0[meta_twid_var_name] = None + # Tracks live in variable names for current bundle (variables to be used by current bundle) + self.live_vars: dict = self.live_vars_0 # dict(var_name:str, pending_uses: set(XInstruction)) + self.live_outs = set() # Contains variables being stored in this bundle to avoid reusing them + # Ordered list of XWriteCycleTrack to track the cycle in which rshuffles are writing. + # This is used to avoid scheduling instructions that write to these banks on the same cycle as + # rshuffles. + self.pending_write_cycles = [] + + # Metadata tracking + + # Book-keeping to track keygen metadata + + # Starting SPAD address for keygen seed metadata: + # this will be overwritten by new keygen seed metadata whenever a swap is needed. + self.metadata_spad_addr_start_kgseed = -1 + self.bundle_current_kgseed = -1 # Tracks current index of keygen seed metadata loaded + self.bundle_used_kg_seed = -1 # Tracks the last bundle that used current keygen seed + self.last_keygen_index = -1 # Tracks the last key material generation index with current seed + + # Book-keeping to track residual metadata + + # Starting SPAD address for ones metadata: + # this will be overwritten by new ones metadata whenever a swap is needed. + self.metadata_spad_addr_start_ones = -1 + # Metadata for ones segment `i` supports computation of arithmetic operations + # with rns in range `[i * 64, (i + 1) * 64)` + # i == -1 means uninitialized + self.bundle_current_ones_segment = -1 # Tracks current ones segment metadata loaded + self.bundle_needed_ones_segment = -1 # Signals the ones segment metadata needed + + # Starting SPAD address for twid metadata: + # this will be overwritten by new twid metadata whenever a swap is needed. + self.metadata_spad_addr_start_twid = -1 + # Metadata for twiddles segment `i` supports computation of twiddle factors + # with rns in range `[i * 64, (i + 1) * 64)` + # i == -1 means uninitialized + self.bundle_current_twid_segment = -1 # Tracks current twid segment metadata loaded + self.bundle_needed_twid_segment = -1 # Signals the twid segment metadata needed + + # Book-keeping to track that rShuffle and irShuffle don't mix in the same bundle + + # Tracks the current type of rshuffle supported (rShuffle, irShuffle, None), + # and what bundle was it last set + self.bundle_current_rshuffle_type = (None, 0) # (type: {rShuffle, irShuffle, None}, bundle: int) + self.bundle_needed_rshuffle_type = None # Type of last rshuffle {rShuffle, irShuffle, None} scheduled in current bundle + + # xinstfetch vars + + self.xinstfetch_hbm_addr = 0 + self.xinstfetch_xq_addr = 0 + self.__max_bundles_per_xinstfetch = Constants.WORD_SIZE / (self.max_bundle_size * Constants.XINSTRUCTION_SIZE_BYTES) + self.xinstfetch_cinsts_buffer = [] # Used to group all xinstfetch per capacity of XInst queue + self.xinstfetch_location_idx_in_cinsts = 0 # Location in cinst where to insert xinstfetch's when a group is completed + + # Progress report vars + + # Tracks the number of instruction (in original dependency graph) that have been scheduled + self.scheduled_xinsts_count = 0 + self.verbose = progress_verbose + + @property + def last_xinstr(self) -> object: + """ + Provides the last XInstruction in the current bundle or the last bundle with instructions. + + Returns: + object: The last XInstruction or None if no instructions are found. + """ + retval = None + if len(self.xinsts_bundle) > 0: + retval = self.xinsts_bundle[-1] + elif len(self.xinsts) > 0: + # Find the last bundle with instructions + # (this should be the last bundle in the list of bundles) + for bundle in reversed(self.xinsts): + if len(bundle) > 0: + # Return last instruction in bundle + retval = bundle[-1] + break + return retval + + @property + def max_bundle_size(self) -> int: + """ + Provides the maximum bundle size. + + Returns: + int: The maximum bundle size. + """ + return self.__max_bundle_size + + @property + def max_bundles_per_xinstfetch(self) -> int: + """ + Provides the maximum number of bundles per xinstfetch. + + Returns: + int: The maximum number of bundles per xinstfetch. + """ + return self.__max_bundles_per_xinstfetch + + @property + def mem_model(self) -> str: + """ + Provides the memory model. + + Returns: + str: The memory model. + """ + return self.__mem_model + + @property + def progress_pct(self) -> float: + """ + Provides the progress percentage. + + Returns: + float: The progress percentage. + """ + return self.scheduled_xinsts_count * 100.0 / self.total_instructions + + @property + def replacement_policy(self) -> str: + """ + Provides the replacement policy. + + Returns: + str: The replacement policy. + """ + return self.__replacement_policy + + @property + def total_instructions(self) -> int: + """ + Provides the total number of instructions. + + Returns: + int: The total number of instructions. + """ + return self.dependency_graph.number_of_nodes() + self.scheduled_xinsts_count + + def addXInstrBackIntoPipeline(self, xinstr: object): + """ + Adds an instruction back into the pipeline. + + Parameters: + xinstr (object): The instruction to add back into the pipeline. + + Raises: + ValueError: If `xinstr` is a `Move` instruction or is already scheduled. + """ + if isinstance(xinstr, xinst.Move): + raise ValueError('`xinstr` is a `Move` instruction. `Move` instructions cannot be inserted into the pipeline.') + if xinstr.is_scheduled: + raise ValueError('`xinstr` already scheduled.') + assert xinstr.id in self.dependency_graph + if self.dependency_graph.in_degree(xinstr.id) > 0: + if xinstr in self.priority_queue: + # Remove from priority queue because it now has a dependency + self.priority_queue_remove(xinstr) + # Add back to topo sort because we have new dependencies + self.addXInstrToTopoSort(xinstr.id) + else: + # Original instruction has no dependencies, so, put it back in priority queue + self.priority_queue_push(xinstr) + + # Remove instruction vars from live list since it's being demoted. + # Pending xstore variables must be kept alive to avoid attempts to flush them again. + if not isinstance(xinstr, xinst.XStore): + for v in xinstr.sources + xinstr.dests: + if isinstance(v, Variable) \ + and v.name in self.live_vars \ + and xinstr in self.live_vars[v.name]: + self.addUsedVar(v.name, xinstr) + + def addXInstrToTopoSort(self, xinstr_id: tuple): + """ + Adds an instruction to the topological sort. + + Parameters: + xinstr_id (tuple): The ID of the instruction to add. + + Raises: + ValueError: If `xinstr_id` is not part of the dependency graph or is in the priority queue. + """ + if xinstr_id not in self.dependency_graph: + raise ValueError("`xinstr_id`: cannot add an instruction to topo sort that is not part of the dependency graph.") + if xinstr_id in self.priority_queue: + # Adding back to topo sort, xinstr cannot be in priority queue + raise ValueError("`xinstr_id`: cannot be in priority queue.") + # Find position in topo sort + target_idx = len(self.topo_sort) + match_idxs = [] # Locations where the same xinstr was found in topo sort + for idx, topo_instr_id in enumerate(self.topo_sort): + if topo_instr_id == xinstr_id: + match_idxs.append(idx) + elif topo_instr_id in self.dependency_graph \ + and self.dependency_graph.in_degree(topo_instr_id) >= self.dependency_graph.in_degree(xinstr_id): + target_idx = idx + break + self.topo_sort = self.topo_sort[:target_idx] + [ xinstr_id ] + self.topo_sort[target_idx:] + # Remove the previous instances found of xinstr from topo sort as it has incorrect order now + for idx, match_idx in enumerate(match_idxs): + del self.topo_sort[match_idx - idx] + self.b_topo_sort_changed = True + self.set_extracted_xinstrs.discard(xinstr_id) + + def addDependency(self, + new_dependency_instr, + original_instr): + """ + Adds `new_dependency_instr` to the instruction listing as a new dependency of + `original_instr`. + + Dependency graph and topo sort are updated as appropriate. `new_dependency_instr` is NOT + added to the topo_sort. + + Variables in sources and dests for `new_dependency_instr` are added to `live_vars`. + + Parameters: + new_dependency_instr: The new dependency instruction to add. + original_instr: The original instruction to which the new dependency is added. + """ + # Add new instruction to instructions listing (in dependency graph) + self.dependency_graph.add_node(new_dependency_instr.id, instruction=new_dependency_instr) + self.b_dependency_graph_changed = True + if original_instr: + assert original_instr.id in self.dependency_graph + self.dependency_graph.add_edge(new_dependency_instr.id, original_instr.id) # Link as dependency to input instruction + self.addXInstrBackIntoPipeline(original_instr) + + all_vars = set(v for v in new_dependency_instr.sources + new_dependency_instr.dests \ + if isinstance(v, Variable) and not isinstance(v, DummyVariable)) + for v in all_vars: + # Add dependencies to all other instructions + deps_added = 0 + for idx, next_instr_id in v.accessed_by_xinsts: + if idx > self.topo_start_idx + 2 * Simulation.INSTRUCTION_WINDOW_SIZE: + # Only add dependencies within the instruction window and next 2 instruction windows + if deps_added > 0 or len(v.accessed_by_xinsts) <= 0: + # Add, at least, one dependency if needed + break + if next_instr_id != new_dependency_instr.id: + assert next_instr_id in self.dependency_graph + self.dependency_graph.add_edge(new_dependency_instr.id, next_instr_id) # Link as dependency to input instruction + if self.dependency_graph.in_degree(next_instr_id) == 1: + # We need to add next instruction back to topo sort because it will have a dependency + next_instr = self.dependency_graph.nodes[next_instr_id]['instruction'] + self.addXInstrBackIntoPipeline(next_instr) + deps_added += 1 + self.addLiveVar(v.name, new_dependency_instr) # Source and dests variables are now a live-in for new_dependency_instr + + def addLiveVar(self, + var_name: str, + instr): + """ + Adds a live variable to the current bundle. + + Parameters: + var_name (str): The name of the variable to add. + instr: The instruction associated with the variable. + """ + if var_name not in self.live_vars: + self.live_vars[var_name] = set() + self.live_vars[var_name].add(instr) + + def addUsedVar(self, + var_name: str, + instr): + """ + Removes a used variable from the current bundle. + + Parameters: + var_name (str): The name of the variable to remove. + instr: The instruction associated with the variable. + """ + self.live_vars[var_name].remove(instr) + if len(self.live_vars[var_name]) <= 0: + self.live_vars.pop(var_name) + + def appendXInstToBundle(self, xinstr): + """ + Appends an XInstruction to the current bundle. + + Parameters: + xinstr: The XInstruction to append. + + Raises: + ValueError: If `xinstr` is None. + AssertionError: If the bundle is already full. + """ + if not xinstr: + raise ValueError('`xinstr` cannot be `None`.') + assert len(self.xinsts_bundle) < self.max_bundle_size, 'Cannot append XInstruction to full bundle.' + self.xinsts_bundle.append(xinstr) + if self.current_bundle_latency < self.current_cycle.cycle + xinstr.latency: + self.current_bundle_latency = self.current_cycle.cycle + xinstr.latency + + def cleanupPendingWriteCycles(self): + """ + Cleans up pending write cycles that have passed. + """ + # Remove any write cycles that passed + front_write_cycle_idx = -1 # len(self.pending_write_cycles) + for idx, write_cycle in enumerate(self.pending_write_cycles): + if write_cycle.cycle < self.current_cycle: # Not <= because no instruction writes on its decoding (first) cycle + # Found first write cycle in the list that occurs after current cycle + front_write_cycle_idx = idx + break + self.pending_write_cycles = self.pending_write_cycles[front_write_cycle_idx + 1:] + + def canSchedulerShuffle(self, xinstr) -> CycleType: + """ + Checks whether the specified xrshuffle XInst can be scheduled now, + based on the special latency timing. + + Returns: + CycleType: The cycle where the specified instruction can be scheduled: + - Returns cycle <= current_cycle if instruction can be scheduled now. + - Returns cycle > current_cycle the cycle where the instruction can be scheduled. + If instruction is an xrshuffle, this takes into account the slotting rule. + """ + # This is used to check whether an rshuffle is in the slotted latency from + # the previous rshuffle, or outside of the special latency. + + instr_ready_cycle = max(xinstr.cycle_ready, self.current_cycle) + + retval = instr_ready_cycle + + if xinstr.cycle_ready.bundle <= self.current_cycle.bundle \ + and self.last_xrshuffle is not None \ + and isinstance(xinstr, (xinst.rShuffle, xinst.irShuffle)): + # Attempting to schedule an rshuffle after a previous one already got + # scheduled in the same bundle + + last_rshuffle_cycle = self.last_xrshuffle.schedule_timing.cycle + assert self.current_cycle.bundle >= last_rshuffle_cycle.bundle, \ + "Last scheduled rshuffle cannot be in the future." + + if self.current_cycle.bundle == last_rshuffle_cycle.bundle: + # Last scheduled rshuffle was in this bundle + cycle_delta = abs(instr_ready_cycle.cycle - last_rshuffle_cycle.cycle) + if (isinstance(xinstr, xinst.rShuffle) and isinstance(self.last_xrshuffle, xinst.rShuffle)) \ + or (isinstance(xinstr, xinst.irShuffle) and isinstance(self.last_xrshuffle, xinst.irShuffle)): + # New rshuffle and previous are of the same kind + if cycle_delta < self.last_xrshuffle.SpecialLatencyMax: + # Trying to schedule within max special latency: attempt to slot + r = cycle_delta % self.last_xrshuffle.SpecialLatencyIncrement + cycle_delta += ((0 if r == 0 else self.last_xrshuffle.SpecialLatencyIncrement) - r) + if cycle_delta >= self.last_xrshuffle.SpecialLatencyMax: + # Slot found is greater than max latency, so, we can schedule at max latency + cycle_delta = self.last_xrshuffle.SpecialLatencyMax + retval = CycleType(bundle = self.current_cycle.bundle, + cycle = last_rshuffle_cycle.cycle + cycle_delta) + else: + # New rshuffle and previous are inverse: only schedule outside + # of the full latency + retval = CycleType(bundle = self.current_cycle.bundle, + cycle = max(self.current_cycle.cycle, last_rshuffle_cycle.cycle + self.last_xrshuffle.latency)) + + if retval < instr_ready_cycle: + retval = instr_ready_cycle + + return retval + + def canSchedulerShuffleType(self, xinstr) -> bool: + """ + Checks whether the specified rshuffle XInst can be scheduled now, + or if there are inverse rshuffles in queue that can be scheduled because + the currently loaded routing table matches them instead. + + Returns: + bool: True if the specified instruction is an xrshuffle that can be scheduled in + this bundle, or specified instruction is not an xrshuffle. False otherwise. + + This is used to avoid switching tables back and forth while there + are still xrshuffles of correct table type pending. + """ + retval = True + + if isinstance(xinstr, (xinst.rShuffle, xinst.irShuffle)): + # Check if a routing table change is needed for specified rshuffle + + # Can schedule if not on this bundle, or is instance of previously scheduled + # rshuffles in this bundle + retval = xinstr.cycle_ready.bundle > self.current_cycle.bundle \ + or self.bundle_needed_rshuffle_type is None \ + or isinstance(xinstr, self.bundle_needed_rshuffle_type) + + if self.bundle_current_rshuffle_type[0] is not None \ + and not isinstance(xinstr, self.bundle_current_rshuffle_type[0]) \ + and retval: + # Routing table change will be needed if we want to schedule specified xrshuffle + + # Search priority queue to see if there are any rshuffles matching + # current routing tables that can be queued instead + # NOTE: Traversing a priority queue is not good practice because we should not be + # messing with its contents, but it is needed for the single type + # of rshuffle per bundle restriction. + + retval = next((False for _, inv_rshuffle in self.priority_queue \ + if isinstance(inv_rshuffle, self.bundle_current_rshuffle_type[0]) \ + and inv_rshuffle.cycle_ready.bundle <= self.current_cycle.bundle), + retval) + + assert not retval \ + or xinstr.cycle_ready.bundle > self.current_cycle.bundle \ + or self.bundle_needed_rshuffle_type is None or isinstance(xinstr, self.bundle_needed_rshuffle_type), \ + f'Found rshuffle of type {type(xinstr)}, but type {self.bundle_needed_rshuffle_type} already scheduled in bundle.' + + return retval + + def canScheduleArithmeticXInstr(self, xinstr: xinst.XInstruction) -> bool: + """ + Checks whether the specified XInst can be scheduled now based on + the currently loaded metadata. + + Returns: + bool: True if the specified XInst can be scheduled in this bundle (may require + change of metadata). False otherwise. + + This is used to avoid switching metadata back and forth while there + are still XInsts of current metadata pending. + """ + retval = True + + if xinstr.res is not None: + # Instruction has residual + + assert self.bundle_current_ones_segment == self.bundle_current_twid_segment, \ + 'Current Ones and Twiddle metadata segments are not synchronized.' + assert self.bundle_needed_ones_segment == self.bundle_needed_twid_segment, \ + 'Needed Ones and Twiddle metadata segments are not synchronized.' + + xinstr_required_segment = xinstr.res // constants.MemoryModel.MAX_RESIDUALS + # Can schedule if not on this bundle, or required residual segment + # is same as previously scheduled in this bundle + retval = xinstr.cycle_ready.bundle > self.current_cycle.bundle \ + or self.bundle_needed_ones_segment == -1 \ + or self.bundle_needed_ones_segment == xinstr_required_segment + + # Check if a metadata change is needed for specified XInst + if self.bundle_current_ones_segment != -1 \ + and self.bundle_current_ones_segment != xinstr_required_segment \ + and retval: + # Metadata change will be needed if we want to schedule specified XInst + + # Search priority queue to see if there are any arithmetic instructions matching + # current metadata that can be queued instead + # NOTE: Traversing a priority queue is not good practice because we should + # not be messing with its contents, but it is needed for the single + # metadata segment per bundle restriction. + retval = next((False for _, other_xinstr in self.priority_queue \ + if other_xinstr.res is not None \ + and other_xinstr.res // constants.MemoryModel.MAX_RESIDUALS == self.bundle_current_ones_segment \ + and other_xinstr.cycle_ready.bundle <= self.current_cycle.bundle), + retval) + + assert not retval \ + or xinstr.cycle_ready.bundle > self.current_cycle.bundle \ + or self.bundle_needed_ones_segment == -1 or xinstr_required_segment == self.bundle_needed_ones_segment, \ + f'Found XInst of residual segment {xinstr_required_segment}, but segment {self.bundle_needed_ones_segment} already scheduled in bundle.' + + return retval + + def findNextInstructionToSchedule(self) -> object: + """ + Finds the next instruction to schedule. + + Returns: + object: The next instruction to schedule or None if priority_queue is empty. + + Returned instruction may be an injected xexit if no instructions are left to be + scheduled for current bundle. + """ + retval = None + + if self.priority_queue: + while retval is None \ + and self.priority_queue.peek()[1].cycle_ready.bundle <= self.current_cycle.bundle: + # Check if there is any immediate instruction we can schedule + # in this cycle + immediate_instr = self.priority_queue.find(self.current_cycle) + while retval is None and immediate_instr is not None: + # Check found instruction has correct priority + if immediate_instr.cycle_ready == self.current_cycle: + # Check for write cycle conflicts + if hasBankWriteConflict(immediate_instr, self): + # Write cycle conflict found, so, update found instruction cycle ready + new_cycle_ready = CycleType(bundle = self.current_cycle.bundle, + cycle = max(immediate_instr.cycle_ready.cycle, self.current_cycle.cycle) + 1) + immediate_instr.cycle_ready = new_cycle_ready + else: + new_cycle_ready = self.canSchedulerShuffle(immediate_instr) + if immediate_instr.cycle_ready != new_cycle_ready: + # Only xrshuffles should have a changed cycle ready if slotted + # and got picked outside of a slot to schedule. + assert immediate_instr.cycle_ready < new_cycle_ready, \ + "Computed new cycle ready cannot be earlier than instruction's cycle ready." + immediate_instr.cycle_ready = new_cycle_ready # Update instruction's cycle ready + else: + # Found immediate instruction + self.priority_queue_remove(immediate_instr) + retval = immediate_instr + if not retval: + # Found instruction that has incorrect priority, so, correct it + self.priority_queue_push(immediate_instr) + # See if there is any other immediate instruction we can schedule + immediate_instr = self.priority_queue.find(self.current_cycle) + + # If no immediate instruction found: + # Find the first we can schedule + while retval is None: + priority, p_inst = self.priority_queue.peek() + if p_inst.cycle_ready.bundle < self.current_cycle.bundle: + # Correct instruction ready cycle to this bundle + p_inst.cycle_ready = CycleType(bundle = self.current_cycle.bundle, + cycle = 0) + # Check found instruction has correct priority + if p_inst.cycle_ready == priority: + # Check for write cycle conflicts + if hasBankWriteConflict(p_inst, self): + # Write cycle conflict found, so, update found instruction cycle ready + new_cycle_ready = CycleType(bundle = self.current_cycle.bundle, + cycle = max(p_inst.cycle_ready.cycle, self.current_cycle.cycle) + 1) + p_inst.cycle_ready = new_cycle_ready + else: + new_cycle_ready = self.canSchedulerShuffle(p_inst) + if p_inst.cycle_ready != new_cycle_ready: + # Only xrshuffles should have a changed cycle ready if slotted + # and got picked outside of a slot to schedule. + assert p_inst.cycle_ready < new_cycle_ready, \ + "Computed new cycle ready cannot be earlier than instruction's cycle ready." + p_inst.cycle_ready = new_cycle_ready # Update instruction's cycle ready + else: + # Found instruction to schedule at the head of queue + priority, retval = self.priority_queue.pop() + assert(retval.id == p_inst.id and priority == retval.cycle_ready) + if not retval: + # Found instruction that has incorrect priority, so, correct it + # (this may change its order in the priority queue) + self.priority_queue_push(p_inst) + + assert(retval) + + if not self.canSchedulerShuffleType(retval): + # Found rshuffle that requires routing table change, but other + # rshuffles with current routing table still available: + # Move rshuffle to next bundle + retval.cycle_ready = CycleType(bundle = self.current_cycle.bundle + 1, + cycle = 0) + # Put back in priority queue + self.priority_queue_push(retval) + retval = None # Continue looping to find another suitable instruction + + if retval and not self.canScheduleArithmeticXInstr(retval): + # Found XInst that requires metadata change, but other + # XInst with current metadata still available: + # Move XInst to next bundle + retval.cycle_ready = CycleType(bundle = self.current_cycle.bundle + 1, + cycle = 0) + # Put back in priority queue + self.priority_queue_push(retval) + retval = None # Continue looping to find another suitable instruction + + return retval + + def flushBundle(self): + """ + Flushes the current bundle. + + Raises: + RuntimeError: If the previous bundle was short and the current bundle is empty. + """ + if self.b_empty_bundle and len(self.xinsts_bundle) <= 0: + # Previous bundle was short + raise RuntimeError('Cannot flush an empty bundle.') + + self.b_empty_bundle = len(self.xinsts_bundle) <= 0 # Flag whether this is an empty bundle + # Flag if this is a short bundle + if len(self.xinsts_bundle) <= self.BUNDLE_INSTRUCTION_MIN_LIMIT: + self.num_short_bundles += 1 + else: + self.num_short_bundles = 0 + + # Complete the bundle + + instr = None + if len(self.xinsts_bundle) < self.max_bundle_size: + # Bundle not full: + # Schedule an exit bundle + tmp_comment = f" terminating bundle {self.current_cycle.bundle}" + if self.num_short_bundles > 0: + tmp_comment += ": short bundle" + instr = xinst.Exit(len(self.xinsts), comment=tmp_comment) + self.current_cycle += instr.schedule(self.current_cycle, len(self.xinsts_bundle) + 1) + self.appendXInstToBundle(instr) + + # Find bundle latency measurements before padding bundle + assert(not isinstance(self.xinsts_bundle[-1], xinst.Nop)) # Last instruction in bundle is not a nop + bundle_latency = self.current_bundle_latency + # Find last xstore in bundle + bundle_last_xstore = next((self.xinsts_bundle[idx] for idx in reversed(range(len(self.xinsts_bundle))) \ + if isinstance(self.xinsts_bundle[idx], xinst.XStore)), + None) + # Latency from last xstore is the total bundle latency minus the cycle where the xstore was scheduled: + # Measured from the cycle where last xstore was scheduled to the total latency + bundle_latency_from_last_xstore = (bundle_latency - bundle_last_xstore.schedule_timing.cycle.cycle) \ + if bundle_last_xstore \ + else bundle_latency + if bundle_latency_from_last_xstore < 0: + bundle_latency_from_last_xstore = 0 + + if not instr: + _, instr = self.priority_queue.peek() if len(self.priority_queue) > 0 else (0, self.xinsts_bundle[-1]) + for _ in range(self.max_bundle_size - len(self.xinsts_bundle)): + # Pad incomplete bundle with nops: + # Incomplete bundles are finished by an xexit, but need to be padded to max_bundle_size + b_scheduled = scheduleXNOP(instr, + 1, # Idle cycles + self, + force_nop=True) # We want nop to be added regardless of last in bundle + assert(b_scheduled) + + assert(len(self.xinsts_bundle) == self.max_bundle_size) + + # See if we need to sync to MInstQ before fetching bundle + if self.pre_bundle_csync_minstr[1]: + minstr = self.pre_bundle_csync_minstr[1] + assert(minstr.is_scheduled) + csyncm = cinst.CSyncm(minstr.id[0], minstr) + csyncm.schedule(self.current_cycle, len(self.cinsts) + 1) + self.cinsts.append(csyncm) + self.pre_bundle_csync_minstr = (0, None) # Clear sync because we may not need in next bundle + + # Schedule the bundle fetch + ifetch = cinst.IFetch(self.xinsts_bundle[0].id[1], # ID of first instruction in bundle, just for book-keeping + self.current_cycle.bundle) + + # See if we need idle CInstQ cycles from previous bundle before ifetch this bundle + if len(self.xinsts) > 0: + # Find latency for the CInstQ since last cstore (or ifetch if not cstore) + idx = len(self.cinsts) - 1 + cq_throughput = 0 + while idx >= 0 \ + and not isinstance(self.cinsts[idx], (cinst.IFetch, cinst.CStore)): + cq_throughput += self.cinsts[idx].throughput + idx -= 1 + + # Added ifetch latency to avoid timing errors when bundles are short or empty + idle_c_cycles = self.xinsts[-1].latency_from_xstore - cq_throughput \ + + ifetch.latency + if idle_c_cycles > 0: + cnop = cinst.CNop(self.current_cycle.bundle, idle_c_cycles) + cnop.schedule(self.current_cycle, len(self.cinsts) + 1) + self.cinsts.append(cnop) + + # See if we need to load a new rshuffle routing table + # (not counted in the nops before next bundle because we don't want to + # switch routing tables in mid rshuffle if it is still in flight) + if self.bundle_needed_rshuffle_type is not None \ + and self.bundle_current_rshuffle_type[0] != self.bundle_needed_rshuffle_type: + self.loadrShuffleRoutingTable(self.bundle_needed_rshuffle_type.RSHUFFLE_DATA_TYPE) + self.bundle_current_rshuffle_type = (self.bundle_needed_rshuffle_type, self.current_cycle.bundle) + + # See if we need to load new twid metadata + # (not counted in the nops before next bundle because we don't want to + # switch twid metadata in mid bundle if it is still in flight) + if self.bundle_needed_twid_segment >= 0 \ + and self.bundle_current_twid_segment != self.bundle_needed_twid_segment: + self.loadTwiddleMetadata(self.metadata_spad_addr_start_twid, self.bundle_needed_twid_segment) + self.bundle_current_twid_segment = self.bundle_needed_twid_segment + + # See if we need to load new ones metadata + # (not counted in the nops before next bundle because we don't want to + # switch ones metadata in mid bundle if it is still in flight) + if self.bundle_needed_ones_segment >= 0 \ + and self.bundle_current_ones_segment != self.bundle_needed_ones_segment: + self.loadBOnesMetadata(self.metadata_spad_addr_start_ones, self.bundle_needed_ones_segment) + self.bundle_current_ones_segment = self.bundle_needed_ones_segment + + ifetch.schedule(self.current_cycle, len(self.cinsts) + 1) + self.cinsts.append(ifetch) + + # Add bundle to list of bundles + self.xinsts.append(BundleData(xinsts=self.xinsts_bundle, + latency=bundle_latency, + latency_from_xstore=bundle_latency_from_last_xstore)) + + # Schedule all the pending CInsts + for idx, cstore_instr in enumerate(self.post_bundle_cinsts): + var_name, (variable, dst_spad_addr) = self.mem_model.store_buffer.peek() + cstore_instr.schedule(self.current_cycle, len(self.cinsts) + idx + 1) + + # Check if this is an output variable which is done + if variable.name in self.mem_model.output_variables \ + and not variable.accessed_by_xinsts: + # Variable is output and it is not used anymore + # Sync to last CInst access to avoid storing before access completes + assert(self.mem_model.spad.getAccessTracking(dst_spad_addr).last_cstore[1] == cstore_instr) + msyncc = minst.MSyncc(cstore_instr.id[0], + cstore_instr) + msyncc.schedule(self.current_cycle, len(self.minsts) + 1) + self.minsts.append(msyncc) + dest_hbm_addr = variable.hbm_address + if dest_hbm_addr < 0: + if not auto_allocate: + raise RuntimeError(f"Variable {variable.name} not found in HBM.") + dest_hbm_addr = self.mem_model.hbm.findAvailableAddress(self.mem_model.output_variables) + if dest_hbm_addr < 0: + raise RuntimeError("Out of HBM space.") + mstore = minst.MStore(cstore_instr.id[0], + [ variable ], + self.mem_model, + dest_hbm_addr, + comment=(' id: {} - flushing').format(cstore_instr.id)) + mstore.schedule(self.current_cycle, len(self.minsts) + 1) + self.minsts.append(mstore) + + self.cinsts += self.post_bundle_cinsts + + # Clean up for next bundle + + self.current_bundle_latency = 0 + self.xinsts_bundle = [] + self.post_bundle_cinsts = [] + self.pending_write_cycles = [] + self.live_outs = set() + + self.bundle_needed_rshuffle_type = None + self.bundle_needed_ones_segment = -1 + self.bundle_needed_twid_segment = -1 + + # Reset all global cycle trackings + for xinstr_type in xinst.GLOBAL_CYCLE_TRACKING_INSTRUCTIONS: + xinstr_type.reset_GlobalCycleReady() + + # Free up bank 0 registers with stale dummy variables + # (dummies left as placeholders in bank 0 by previous bundles) + for idx in range(len(self.mem_model.register_banks)): + bank = self.mem_model.register_banks[idx] + for reg in bank: + if isinstance(reg.contained_variable, DummyVariable) \ + and reg.contained_variable.tag < self.current_cycle.bundle: + # Register was used more than a bundle ago and can be re-used + reg.allocateVariable(None) + + self.b_dependency_graph_changed = True + # Next bundle starts + assert(len(self.xinsts) == self.current_cycle.bundle + 1) + self.current_cycle = CycleType(bundle = len(self.xinsts), cycle = 1) + self.bundle_dummy_var = DummyVariable(self.current_cycle.bundle) # Dummy variable for new bundle + + def flushOutputVariableFromRegister(self, + variable, + xinstr = None) -> bool: + """ + Flushes an output variable from the register. + + Parameters: + variable: The variable to flush. + xinstr (optional): The instruction associated with the variable. Defaults to None. + + Returns: + bool: True if the variable was successfully flushed, or it didn't need flushing. + + Raises: + ValueError: If `xinstr` is None when there are no other XInstructions available in the listing. + """ + retval = True + + if not xinstr: + xinstr = self.last_xinstr + if not xinstr: + raise ValueError('`xinstr`: cannot be None when there are no other XInstructions available in the listing.') + if variable.register_dirty: + # Variable is in a dirty register: + # Flush the register + + # Find a location in SPAD + dest_spad_addr = variable.spad_address + if dest_spad_addr < 0: + dest_spad_addr = findSPADAddress(xinstr, self) + if dest_spad_addr < 0: + retval = False # No SPAD available, flush later + + if retval: + xstore = _createXStore(xinstr.id[0], + dest_spad_addr, + variable, + None, + ' flushing output', + self) + self.addDependency(xstore, None) + # Add to topo_sort + self.addXInstrToTopoSort(xstore.id) + return retval + + def generateKeyMaterial(self, + instr_id: int, + variable: Variable, + register: Register, + dep_id = None) -> int: + """ + Generates key material for the specified variable. + + Parameters: + instr_id (int): The ID of the instruction. + variable (Variable): The variable for which to generate key material. + register (Register): The register associated with the variable. + dep_id (optional): The dependency ID. Defaults to None. + + Returns: + int: 1 if good to go, 2 if leave for next bundle (could not generate key material). + + Raises: + ValueError: If the variable is not keygen. + RuntimeError: If the keygen variable has already been generated or if the keygen variable generation is out of order. + """ + # Key material cannot be generated if it requires a seed change, but current seed + # was already used in this bundle. + + retval = 1 + + if variable.name not in self.mem_model.keygen_variables: + raise ValueError('Variable "{}" is not keygen.'.format(variable.name)) + if self.mem_model.isVarInMem(variable.name): + raise RuntimeError('Keygen variable "{}" has already been generated.'.format(variable.name)) + + seed_idx, key_idx = self.mem_model.keygen_variables[variable.name] + + # Check the seed + if seed_idx != self.bundle_current_kgseed: + if self.bundle_used_kg_seed >= self.current_cycle.bundle: + # Current seed already used in this bundle: cannot change seeds + retval = 2 + else: + # Change seed to needed + self.loadKeygenSeedMetadata(self.metadata_spad_addr_start_kgseed, seed_idx) + + if retval == 1: + # Seed ready to be used to generate new key material + + if key_idx != self.last_keygen_index + 1: + raise RuntimeError(('Keygen variable "{}" generation out of order. ' + 'Expected key index {}, but received {} for seed {}.').format(variable.name, + self.last_keygen_index + 1, + key_idx, + self.bundle_current_kgseed)) + + comment = "" if dep_id is None else 'dep id: {}'.format(dep_id) + kg_load = cinst.KGLoad(instr_id, register, [ variable ], comment=comment) + # Nop required because kg_load/kg_start instructions have a resource dependency among them + cnop = cinst.CNop(instr_id, + kg_load.latency, + comment='kg_load {} wait period'.format(kg_load.id)) + cnop.schedule(self.current_cycle, len(self.cinsts) + 1) + self.cinsts.append(cnop) + + kg_load.schedule(self.current_cycle, len(self.cinsts) + 1) + self.cinsts.append(kg_load) + + # Seed used this bundle + self.bundle_used_kg_seed = self.current_cycle.bundle + self.last_keygen_index = key_idx # Advance the last generated index tracker + + return retval + + def loadrShuffleRoutingTable(self, + rshuffle_data_type_name: str): + """ + Queues CInstructions needed to load the `rshuffle` routing table into CE. + + Parameters: + rshuffle_data_type_name (str): One of { 'ntt', 'intt' }. + + Raises: + ValueError: If `rshuffle_data_type_name` is invalid. + RuntimeError: If the required routing table for the specified type is not present in metadata. + """ + # Select the correct targets based on rshuffle or irshuffle + RegisterTargets = constants.MemInfo.MetaTargets + aux_table_name = "" + aux_table_target = -1 + routing_table_name = "" + routing_table_target = -1 + if rshuffle_data_type_name == xinst.rShuffle.RSHUFFLE_DATA_TYPE: + aux_table_name = self.mem_model.meta_ntt_aux_table + routing_table_name = self.mem_model.meta_ntt_routing_table + elif rshuffle_data_type_name == xinst.irShuffle.RSHUFFLE_DATA_TYPE: + aux_table_name = self.mem_model.meta_intt_aux_table + routing_table_name = self.mem_model.meta_intt_routing_table + else: + raise ValueError(('`rshuffle_data_type_name`: invalid value "{}". Expected one of {}.').format(rshuffle_data_type_name, + { xinst.rShuffle.RSHUFFLE_DATA_TYPE, + xinst.irShuffle.RSHUFFLE_DATA_TYPE })) + # Only NTT targets are supported for both NTT and iNTT in RTL 0.9 + aux_table_target = RegisterTargets.TARGET_NTT_AUX_TABLE + routing_table_target = RegisterTargets.TARGET_NTT_ROUTING_TABLE + if aux_table_name and routing_table_name: + spad_map = QueueDict() # dict(var_name, (Variable, target_register)) + spad_map[aux_table_name] = (self.mem_model.variables[aux_table_name], + aux_table_target) + spad_map[routing_table_name] = (self.mem_model.variables[routing_table_name], + routing_table_target) + # Load meta SPAD -> special CE rshuffle registers + for shuffle_meta_table_name in spad_map: + variable, target_idx = spad_map[shuffle_meta_table_name] + assert variable.spad_address >= 0, f'Metadata variable {variable.name} must be in SPAD' + self.queueCSyncmLoad(0, variable.spad_address) + nload = cinst.NLoad(0, target_idx, variable, self.mem_model) + nload.comment = f' loading routing table for `{rshuffle_data_type_name}`' + nload.schedule(self.current_cycle, len(self.cinsts) + 1) + self.cinsts.append(nload) + else: + raise RuntimeError(f'`rshuffle`: required routing table for `{rshuffle_data_type_name}` not present in metadata.') + + def loadBOnesMetadata(self, + spad_addr_offset: int, + ones_metadata_segment: int) -> int: + """ + Queues MInstructions and CInstructions needed to load the Ones metadata. + + Parameters: + spad_addr_offset (int): SPAD address offset where to store the metadata variables. + ones_metadata_segment (int): Segment of Metadata Ones variables to load. + The number of each segment is computed as + `rns // constants.MemoryModel.MAX_RESIDUALS (64)`. + Each segment contains the name of the variable containing identity metadata required + to perform arithmetic computations for the corresponding set of residuals. + + Returns: + int: Offset inside SPAD following the last location used to store the metadata variables. + + Raises: + IndexError: If the requested segment index is out of range. + RuntimeError: If the required number of twiddle metadata variables per segment is not met. + """ + # Assert constants + assert constants.MemoryModel.NUM_ONES_META_REGISTERS == 1 + + if ones_metadata_segment < 0 or ones_metadata_segment >= len(self.mem_model.meta_ones_vars_segments): + raise IndexError(('`twid_metadata_segment`: requested segment index {}, but there are only {} ' + 'segments of ones metadata available for up to {} residuals.').format(ones_metadata_segment, + len(self.mem_model.meta_ones_vars_segments), + len(self.mem_model.meta_ones_vars_segments) * constants.MemoryModel.MAX_RESIDUALS)) + + RegisterTargets = constants.MemInfo.MetaTargets + spad_addr = 0 + spad_map = QueueDict() # dict(var_name, (Variable, target_register)) + meta_ones_vars = self.mem_model.meta_ones_vars_segments[ones_metadata_segment] + + if meta_ones_vars \ + and len(meta_ones_vars) != constants.MemoryModel.NUM_ONES_META_REGISTERS: + raise RuntimeError("Required {} twiddle metadata variables per segment, but {} received.".format(constants.MemoryModel.NUM_ONES_META_REGISTERS, + len(meta_ones_vars))) + + # Load HBM -> SPAD + for meta_ones_var_name in meta_ones_vars: + target_spad_addr = spad_addr_offset + spad_addr + # Clean up SPAD location (will cause undefined behavior if XInstQ is still executing) + if self.mem_model.spad.buffer[target_spad_addr]: + self.mem_model.spad.deallocate(target_spad_addr) + # Load variable into SPAD + variable = self.mem_model.variables[meta_ones_var_name] + self.queueMLoad(0, target_spad_addr, variable, + comment='loading ones metadata for residuals [{}, {})'.format(ones_metadata_segment * constants.MemoryModel.MAX_RESIDUALS, + (ones_metadata_segment + 1) * constants.MemoryModel.MAX_RESIDUALS)) + spad_map[constants.MemInfo.MetaFields.FIELD_ONES] = (variable, RegisterTargets.TARGET_ONES) + spad_addr += 1 + + # Load meta SPAD -> special CE ones register + for ones_meta_name in spad_map: + variable, target_idx = spad_map[ones_meta_name] + self.queueCSyncmLoad(0, variable.spad_address) + bones = cinst.BOnes(0, target_idx, variable, self.mem_model, + comment='loading ones metadata for residuals [{}, {})'.format(ones_metadata_segment * constants.MemoryModel.MAX_RESIDUALS, + (ones_metadata_segment + 1) * constants.MemoryModel.MAX_RESIDUALS)) + bones.schedule(self.current_cycle, len(self.cinsts) + 1) + self.cinsts.append(bones) + + # Update the currently loaded segment + self.bundle_current_ones_segment = ones_metadata_segment + + return spad_addr_offset + spad_addr + + def loadTwiddleMetadata(self, + spad_addr_offset: int, + twid_metadata_segment: int): + """ + Queues MInstructions and CInstructions needed to load the Twiddle factor generation metadata. + + Must not be called while XInstQ is executing. + + Parameters: + spad_addr_offset (int): SPAD address offset where to store the metadata variables. + twid_metadata_segment (int): Segment of Metadata Twiddle variables to load. Each segment is a list that + contains self.mem_model.MAX_TWIDDLE_META_VARS_PER_SEGMENT (8) variable names. + The number of each segment is computed as + `rns // constants.MemoryModel.MAX_RESIDUALS (64)`. + Each segment contains the name of the metadata variables required to compute + the twiddle factors for the corresponding set of residuals. + + Returns: + int: Offset inside SPAD following the last location used to store the metadata variables. + + Raises: + IndexError: If the requested segment index is out of range. + RuntimeError: If the required number of twiddle metadata variables per segment is not met. + """ + spad_addr = 0 + + if twid_metadata_segment < 0 or twid_metadata_segment >= len(self.mem_model.meta_twiddle_vars_segments): + raise IndexError(('`twid_metadata_segment`: requested segment index {}, but there are only {} ' + 'segments of twiddle metadata available for up to {} residuals.').format(twid_metadata_segment, + len(self.mem_model.meta_twiddle_vars_segments), + len(self.mem_model.meta_twiddle_vars_segments) * constants.MemoryModel.MAX_RESIDUALS)) + + meta_twiddle_vars = self.mem_model.meta_twiddle_vars_segments[twid_metadata_segment] + + if meta_twiddle_vars \ + and len(meta_twiddle_vars) != self.mem_model.MAX_TWIDDLE_META_VARS_PER_SEGMENT: + raise RuntimeError("Required {} twiddle metadata variables per segment, but {} received.".format(self.mem_model.MAX_TWIDDLE_META_VARS_PER_SEGMENT, + len(meta_twiddle_vars))) + + # Load HBM -> SPAD + for meta_twiddle_var_name in meta_twiddle_vars: + target_spad_addr = spad_addr_offset + spad_addr + # Clean up SPAD location (will cause undefined behavior if XInstQ is still executing) + if self.mem_model.spad.buffer[target_spad_addr]: + self.mem_model.spad.deallocate(target_spad_addr) + # Load variable into SPAD + variable = self.mem_model.variables[meta_twiddle_var_name] + self.queueMLoad(0, target_spad_addr, variable, + comment='loading twid metadata for residuals [{}, {})'.format(twid_metadata_segment * constants.MemoryModel.MAX_RESIDUALS, + (twid_metadata_segment + 1) * constants.MemoryModel.MAX_RESIDUALS)) + spad_addr += 1 + + # Load meta SPAD -> special CE twiddle registers + target_bload_register = 0 + for meta_twiddle_var_name in meta_twiddle_vars: + variable = self.mem_model.variables[meta_twiddle_var_name] + for col_num in range(constants.MemoryModel.NUM_BLOCKS_PER_TWID_META_WORD): # Block + self.queueCSyncmLoad(0, variable.spad_address) + bload = cinst.BLoad(0, + col_num, + target_bload_register, + variable, + self.mem_model, + comment='loading twid metadata for residuals [{}, {})'.format(twid_metadata_segment * constants.MemoryModel.MAX_RESIDUALS, + (twid_metadata_segment + 1) * constants.MemoryModel.MAX_RESIDUALS)) + bload.schedule(self.current_cycle, len(self.cinsts) + 1) + self.cinsts.append(bload) + target_bload_register += 1 + + # Update the currently loaded segment + self.bundle_current_twid_segment = twid_metadata_segment + + return spad_addr_offset + spad_addr + + def loadKeygenSeedMetadata(self, + spad_addr_offset: int, + kgseed_idx: int) -> int: + """ + Queues MInstructions and CInstructions needed to load a new keygen seed. + + Keygen does not affect the XInstQ, and can be called to switch seeds when + needed. + + Parameters: + spad_addr_offset (int): SPAD address offset where to store the seed variable. + kgseed_idx (int): Index of the seed to load. There are 4 seeds in a word. This index will + be properly mapped into (word, block) to load the proper seed. + + Returns: + int: Offset inside SPAD following the last location used to store the metadata variable. + + Raises: + IndexError: If the seed index is out of range. + """ + if kgseed_idx < 0 \ + or kgseed_idx >= len(self.mem_model.meta_keygen_seed_vars) * constants.MemoryModel.NUM_BLOCKS_PER_KGSEED_META_WORD: + raise IndexError('`kgseed_idx` must index in the range [0, {}), but {} received'.format(len(self.mem_model.meta_keygen_seed_vars) * constants.MemoryModel.NUM_BLOCKS_PER_KGSEED_META_WORD, + kgseed_idx)) + # Only switch seeds if different from current + if kgseed_idx != self.bundle_current_kgseed: + + spad_addr = 0 + # One word contains 4 seeds: find the right seed + seed_word_block = kgseed_idx % constants.MemoryModel.NUM_BLOCKS_PER_KGSEED_META_WORD + seed_word_idx = kgseed_idx // constants.MemoryModel.NUM_BLOCKS_PER_KGSEED_META_WORD + seed_variable = None + # Unfortunately, kg seeds are not contained in a list, + # so we have to loop through the container to find the var name based on index + for idx, var_name in enumerate(self.mem_model.meta_keygen_seed_vars): + if idx == seed_word_idx: + seed_variable = self.mem_model.variables[var_name] + break + + # Load HBM -> SPAD + target_spad_addr = spad_addr_offset + spad_addr + # Clean up SPAD location (will cause undefined behavior if XInstQ is still executing) + if self.mem_model.spad.buffer[target_spad_addr]: + self.mem_model.spad.deallocate(target_spad_addr) + # Load variable into SPAD + self.queueMLoad(0, target_spad_addr, seed_variable, + comment='loading keygen seed ({}, block = {})'.format(seed_word_idx, + seed_word_block)) + spad_addr += 1 + + # Load seed SPAD -> key material generation subsystem + + self.queueCSyncmLoad(len(self.cinsts), seed_variable.spad_address) + + kg_seed = cinst.KGSeed(len(self.cinsts), + seed_word_block, + seed_variable, + self.mem_model) + kg_start = cinst.KGStart(len(self.cinsts) + 1, comment=f'seed {kgseed_idx}') + + kg_seed.schedule(self.current_cycle, len(self.cinsts) + 1) + self.cinsts.append(kg_seed) + kg_start.schedule(self.current_cycle, len(self.cinsts) + 1) + self.cinsts.append(kg_start) + + # Update the currently loaded seed + self.bundle_current_kgseed = kgseed_idx + self.last_keygen_index = -1 # Restart the keygen index + + return spad_addr_offset + spad_addr + + def loadMetadata(self): + """ + Loads initial metadata at the start of the program. + """ + spad_addr_offset = 0 + spad_addr_offset = self.prepareShuffleMetadata(spad_addr_offset) + self.metadata_spad_addr_start_twid = spad_addr_offset + spad_addr_offset = self.loadTwiddleMetadata(spad_addr_offset, 0) + self.metadata_spad_addr_start_ones = spad_addr_offset + spad_addr_offset = self.loadBOnesMetadata(spad_addr_offset, 0) + if len(self.mem_model.meta_keygen_seed_vars) > 0: + # Keygen used in this program + self.metadata_spad_addr_start_kgseed = spad_addr_offset + spad_addr_offset = self.loadKeygenSeedMetadata(spad_addr_offset, 0) + + def prepareShuffleMetadata(self, + spad_addr_offset: int) -> int: + """ + Queues MInstructions needed to load the `rshuffle` metadata into SPAD. + + Parameters: + spad_addr_offset (int): SPAD address offset where to store the metadata variables. + + Returns: + int: Offset inside SPAD following the last location used to store the metadata variables. + + Raises: + RuntimeError: If both NTT Auxiliary table and Routing table or both iNTT Auxiliary table and Routing table do not exist in memory model. + """ + spad_addr = 0 + + # Load HBM -> SPAD + if self.mem_model.meta_ntt_aux_table \ + and self.mem_model.meta_ntt_routing_table: + variable = self.mem_model.variables[self.mem_model.meta_ntt_aux_table] + self.queueMLoad(0, spad_addr_offset + spad_addr, variable) + spad_addr += 1 + + variable = self.mem_model.variables[self.mem_model.meta_ntt_routing_table] + self.queueMLoad(0, spad_addr_offset + spad_addr, variable) + spad_addr += 1 + else: + # If one of NTT aux table or routing table is specified, so must be the other + raise RuntimeError('Both, NTT Auxiliary table and Routing table must exist in memory model.') + + if self.mem_model.meta_intt_aux_table \ + and self.mem_model.meta_intt_routing_table: + variable = self.mem_model.variables[self.mem_model.meta_intt_aux_table] + self.queueMLoad(0, spad_addr_offset + spad_addr, variable) + spad_addr += 1 + + variable = self.mem_model.variables[self.mem_model.meta_intt_routing_table] + self.queueMLoad(0, spad_addr_offset + spad_addr, variable) + spad_addr += 1 + else: + # If one of iNTT aux table or routing table is specified, so must be the other + raise RuntimeError('Both, iNTT Auxiliary table and Routing table must exist in memory model.') + + return spad_addr_offset + spad_addr + + def priority_queue_push(self, xinstr, tie_breaker = None): + """ + Adds a new instruction to the priority queue. + + Instructions added will be correctly handled by all priority queues. + + Parameters: + xinstr: The instruction to add to the priority queue. + tie_breaker (optional): The tie breaker value. Defaults to None. + + Raises: + AssertionError: If the instruction is not in the dependency graph. + """ + assert xinstr.id in self.dependency_graph, f'{xinstr.id} NOT in simulation.dependency_graph' + if isinstance(xinstr, xinst.XStore): + if tie_breaker is None: + tie_breaker = (-1, ) + self.xstore_pq.push(xinstr.cycle_ready, xinstr, tie_breaker) + if isinstance(xinstr, xinst.Move): + if tie_breaker is None: + tie_breaker = (-2, ) + self.priority_queue.push(xinstr.cycle_ready, xinstr, tie_breaker) + self.set_extracted_xinstrs.add(xinstr.id) + + def priority_queue_remove(self, xinstr): + """ + Removes an instruction from the priority queue. + + Instructions removed will be correctly handled by all priority queues. + + Parameters: + xinstr: The instruction to remove from the priority queue. + """ + self.priority_queue.remove(xinstr) + if xinstr in self.xstore_pq: + assert isinstance(xinstr, xinst.XStore) + self.xstore_pq.remove(xinstr) + + def queueCSyncmLoad(self, + instr_id: int, + source_spad_addr: int): + """ + Checks if needed, and, if so, queues a CSyncm CInstruction to sync to + SPAD access from HBM in order to write from SPAD into CE. + + Parameters: + instr_id (int): ID for the MSyncc instruction. + source_spad_addr (int): SPAD address to sync to for writing. + """ + last_mload_access = self.mem_model.spad.getAccessTracking(source_spad_addr).last_mload[1] + if last_mload_access: + # Need to sync to MInst + csyncm = cinst.CSyncm(instr_id, last_mload_access) + csyncm.schedule(self.current_cycle, len(self.cinsts) + 1) + self.cinsts.append(csyncm) + + def queueMLoad(self, + instr_id: int, + target_spad_addr: int, + variable, + comment = ""): + """ + Generates instructions to copy from HBM into SPAD. + + Parameters: + instr_id (int): The ID of the instruction. + target_spad_addr (int): The target SPAD address. + variable: The variable to load. + comment (optional): A comment associated with the instruction. Defaults to an empty string. + + Raises: + ValueError: If the target SPAD address is negative. + RuntimeError: If the variable is not found in HBM or if out of HBM space. + """ + # Generate instructions to copy from HBM into SPAD + if target_spad_addr < 0: + raise ValueError('Argument Null Exception: Target SPAD address cannot be null (negative address).') + + self.queueMSynccLoad(instr_id, target_spad_addr) + if variable.hbm_address < 0: + if not auto_allocate: + raise RuntimeError(f"Variable {variable.name} not found in HBM.") + hbm_addr = self.mem_model.hbm.findAvailableAddress(self.mem_model.output_variables) + if hbm_addr < 0: + raise RuntimeError("Out of HBM space.") + self.mem_model.hbm.allocateForce(hbm_addr, variable) + mload = minst.MLoad(instr_id, [ variable ], self.mem_model, target_spad_addr, comment=comment) + mload.schedule(self.current_cycle, len(self.minsts) + 1) + self.minsts.append(mload) + + def queueMSynccLoad(self, + instr_id: int, + target_spad_addr: int): + """ + Checks if needed, and, if so, queues an MSyncc MInstruction to sync to + SPAD access to write from HBM into specified SPAD address. + + Parameters: + instr_id (int): ID for the MSyncc instruction. + target_spad_addr (int): SPAD address to sync to for writing. + + Raises: + ValueError: If the target SPAD address is negative. + """ + if target_spad_addr < 0: + raise ValueError('Argument Null Exception: Target SPAD address cannot be null (negative address).') + + # mload depends on the last c access (cload or cstore) + last_access = self.mem_model.spad.getAccessTracking(target_spad_addr) + last_c_access = last_access.last_cstore + if not last_access.last_cstore[1] \ + or (last_access.last_cload[1] \ + and last_access.last_cload[0] > last_access.last_cstore[0]): + # No last cstore or cload happened after cstore + last_c_access = last_access.last_cload + last_c_access = last_c_access[1] + if last_c_access: + # Need to sync to CInst + assert(last_c_access.is_scheduled) + msyncc = minst.MSyncc(instr_id, last_c_access) + msyncc.schedule(self.current_cycle, len(self.minsts) + 1) + self.minsts.append(msyncc) + + def updateQueuesSyncsPass2(self): + """ + Updates the msyncc and csyncm to correct instruction index + after the scheduling completes. + + This is the second pass. + """ + # Create reverse look-up maps for all CInsts and MInsts + + map_cinsts = dict((cinstr.id, idx) for idx, cinstr in enumerate(self.cinsts)) + map_minsts = dict((minstr.id, idx) for idx, minstr in enumerate(self.minsts)) + + # Traverse MInstQ and update msyncc targets + for minstr in self.minsts: + if isinstance(minstr, minst.MSyncc): + target_cinstr = minstr.cinstr + if isinstance(target_cinstr, cinst.CExit): + # Rule, msyncc pointing to cexit, has to point to the next instruction + target_cinstr.set_schedule_timing_index(map_cinsts[target_cinstr.id] + 1) + else: + target_cinstr.set_schedule_timing_index(map_cinsts[target_cinstr.id]) + minstr.freeze() # Re-freeze with new value + + # Traverse CInstQ and update csyncm targets + for cinstr in self.cinsts: + if isinstance(cinstr, cinst.CSyncm): + target_minstr = cinstr.minstr + target_minstr.set_schedule_timing_index(map_minsts[target_minstr.id]) + cinstr.freeze() # Re-freeze with new value + + def updateSchedule(self, instr) -> bool: + """ + Updates the simulation pending schedule after `instr` has been scheduled. + + Parameters: + instr: An instruction in the dependency graph. + + Returns: + bool: True if bundle is full after scheduling the instruction, False otherwise. + + Raises: + ValueError: If `instr` is None or not in the dependency graph. + RuntimeError: If the bundle is already full or if an attempt is made to schedule an instruction in a bundle that only allows specific types or residuals. + """ + if not instr: + raise ValueError('`instr` cannot be `None`.') + if instr.id not in self.dependency_graph: + raise ValueError(f'`instr`: invalid instruction "{instr}" not in dependency graph.') + if len(self.xinsts_bundle) >= self.max_bundle_size: + raise RuntimeError("Bundle already full.") + + dependents = list(self.dependency_graph.successors(instr.id)) # Find instructions that depend on this instruction + self.dependency_graph.remove_node(instr.id) # Remove from graph to update the in_degree of dependent instrs + self.b_dependency_graph_changed = True + # "move" dependent instrs that have no other dependencies to the top of the topo sort + if isinstance(instr, xinst.XStore): + for instr_id in dependents: + if self.dependency_graph.in_degree(instr_id) <= 0: + if instr_id not in self.set_extracted_xinstrs: + self.priority_queue_push(self.dependency_graph.nodes[instr_id]['instruction']) + else: + self.topo_sort = [ instr_id for instr_id in dependents if self.dependency_graph.in_degree(instr_id) <= 0 ] + self.topo_sort + self.b_topo_sort_changed = True + + if instr in self.priority_queue: + self.priority_queue_remove(instr) + + # Do not search the topo sort to actually remove the duplicated instrs because it is O(N) costly: + # set_extracted_xinstrs will take care of skipping them once encountered. + + self.scheduled_xinsts_count += 1 + + if isinstance(instr, xinst.XStore): + # Add corresponding cstore + cstore = cinst.CStore(instr.id[0], + self.mem_model, + comment=instr.comment) + self.post_bundle_cinsts.append(cstore) + # Make sure bundle syncs to last mstore before fetching because + # it does cstores that overwrite SPAD addresses that may still be in process + # of storing to HBM: + last_mstore = self.mem_model.spad.getAccessTracking(instr.dest_spad_address).last_mstore + if self.pre_bundle_csync_minstr[0] <= last_mstore[0] \ + and last_mstore[1] is not None: + self.pre_bundle_csync_minstr = last_mstore + + if isinstance(instr, (xinst.rShuffle, xinst.irShuffle)): + # Rule: no more than one write to same bank in the same cycle. + # rshuffles have different latency than other XInsts, so, we must ensure that their + # write cycle is respected. + + # Add rshuffle to list of pending writes + scheduled_cycle = instr.schedule_timing.cycle + write_cycle = XWriteCycleTrack(cycle = CycleType(bundle = scheduled_cycle.bundle, + cycle = scheduled_cycle.cycle + instr.latency - 1), + banks = set(v.suggested_bank for v in instr.dests)) + self.pending_write_cycles.append(write_cycle) + + # Track the scheduled xrshuffle to try to schedule others in slotted intervals + self.last_xrshuffle = instr + + # Rule: cannot mix rShuffle and irShuffle in same bundle. + + if self.bundle_needed_rshuffle_type is None: + self.bundle_needed_rshuffle_type = type(instr) + elif not isinstance(instr, self.bundle_needed_rshuffle_type): + raise RuntimeError('Attempted to schedule {} in bundle that only allows {}.'.format(instr, + self.bundle_needed_rshuffle_type)) + + # Rule: cannot mix XInsts of different residual segments in same bundle. + if instr.res is not None: + instr_needed_segment = instr.res // constants.MemoryModel.MAX_RESIDUALS + assert self.bundle_needed_ones_segment == self.bundle_needed_twid_segment, \ + 'Needed Ones and Twiddle metadata segments are not synchronized.' + if self.bundle_needed_ones_segment == -1: + self.bundle_needed_ones_segment = instr_needed_segment + elif self.bundle_needed_ones_segment != instr_needed_segment: + raise RuntimeError(('Attempted to schedule XInstruction "{}", residual = {}, ' + 'in bundle {} that only allows residuals in range [{}, {}).').format(str(instr), + instr.res, + self.current_cycle.bundle, + self.bundle_needed_ones_segment * constants.MemoryModel.MAX_RESIDUALS, + (self.bundle_needed_ones_segment + 1) * constants.MemoryModel.MAX_RESIDUALS)) + if self.bundle_needed_twid_segment == -1: + self.bundle_needed_twid_segment = instr_needed_segment + elif self.bundle_needed_twid_segment != instr_needed_segment: + raise RuntimeError(('Attempted to schedule XInstruction {}, residual = {}, ' + 'in bundle {} that only allows residuals in range [{}, {}).').format(str(instr), + instr.res, + self.current_cycle.bundle, + self.bundle_needed_twid_segment * constants.MemoryModel.MAX_RESIDUALS, + (self.bundle_needed_twid_segment + 1) * constants.MemoryModel.MAX_RESIDUALS)) + + self.appendXInstToBundle(instr) # add instruction to bundle + + # True <=> bundle needs to be flushed (because of exit or full) + return isinstance(instr, xinst.Exit) \ + or (len(self.xinsts_bundle) >= self.max_bundle_size) + +def __canScheduleInBundle(instr, simulation: Simulation, padding: int = 1) -> bool: + """ + Determines if an instruction can be scheduled in the current bundle. + + Parameters: + instr: The instruction to be scheduled. + simulation (Simulation): The current simulation context. + padding (int): Extra instruction padding (like number of other instructions before this one, such as a nop). + + Returns: + bool: True if the instruction can be scheduled in the current bundle, False otherwise. + """ + # TODO: + # Look into this function to see if we can bring back skip scheduling of rshuffles at the end of bundles. + # Right now, this featuer is disabled because ifetch does not have the same latency as XInstrs, so + # the simulation keeps track of the whole bundle latency and just adds nops to the CInstQ as needed. + #----------------- + return len(simulation.xinsts_bundle) < simulation.max_bundle_size and True + +def __flushVariableFromSPAD(instr, dest_hbm_addr: int, variable: Variable, simulation: Simulation) -> bool: + """ + Flushes a variable from the SPAD to HBM. + + Parameters: + instr: The instruction triggering the flush. + dest_hbm_addr (int): The destination address in HBM. + variable (Variable): The variable to be flushed. + simulation (Simulation): The current simulation context. + + Returns: + bool: True if the flush was scheduled successfully, False otherwise. + + Raises: + AssertionError: If the destination HBM address is invalid. + """ + assert(dest_hbm_addr >= 0) + + spad = simulation.mem_model.spad + + comment = (' id: {} - flushing').format(instr.id) + last_cstore = spad.getAccessTracking(variable.spad_address).last_cstore[1] + if last_cstore: + # mstore needs to happen after last cstore + assert(last_cstore.is_scheduled) + # Sync to last CInst access to avoid storing before access completes + msyncc = minst.MSyncc(instr.id[0], last_cstore, comment=comment) + msyncc.schedule(simulation.current_cycle, len(simulation.minsts) + 1) + simulation.minsts.append(msyncc) + + mstore = minst.MStore(instr.id[0], [variable], simulation.mem_model, dest_hbm_addr, comment=comment) + mstore.schedule(simulation.current_cycle, len(simulation.minsts) + 1) + simulation.minsts.append(mstore) + + return True + +def _createXStore(instr_id: int, dest_spad_addr: int, evict_variable: Variable, new_variable: Variable, comment: str, simulation: Simulation) -> object: + """ + Creates an XStore instruction to move a variable into SPAD. + + Parameters: + instr_id (int): The instruction ID. + dest_spad_addr (int): The destination SPAD address. + evict_variable (Variable): The variable to evict. + new_variable (Variable): The variable to be allocated in register after eviction, or None to keep register free. + comment (str): A comment for the instruction. + simulation (Simulation): The current simulation context. + + Returns: + object: The created XStore instruction. + + Raises: + AssertionError: If the evict variable's register is None or if the SPAD address is invalid. + """ + assert(evict_variable.register is not None) + assert(evict_variable.spad_address < 0 or evict_variable.spad_address == dest_spad_addr) + + spad = simulation.mem_model.spad + + # Block SPAD address to avoid it being found by another findSPADAddress + if spad[dest_spad_addr]: + assert(not isinstance(spad[dest_spad_addr], DummyVariable)) + spad.deallocate(dest_spad_addr) + spad.allocateForce(dest_spad_addr, DummyVariable()) + # Generate the xstore instruction to move variable into SPAD + xstore = XStoreAssign(instr_id, [evict_variable], simulation.mem_model, new_variable, dest_spad_addr=dest_spad_addr, comment=comment) \ + if new_variable else \ + xinst.XStore(instr_id, [evict_variable], simulation.mem_model, dest_spad_addr=dest_spad_addr, comment=comment) + evict_variable.accessed_by_xinsts = [Variable.AccessElement(0, xstore.id)] + evict_variable.accessed_by_xinsts + + return xstore + +def __flushVariableFromRegisterFile(instr, dest_spad_addr: int, evict_variable: Variable, new_variable: Variable, simulation: Simulation) -> object: + """ + Flushes a variable from the register file to SPAD. + + Parameters: + instr: The instruction triggering the flush. + dest_spad_addr (int): The destination SPAD address. + evict_variable (Variable): The variable to evict. + new_variable (Variable): The variable to be allocated in register after eviction, or None to keep register free. + simulation (Simulation): The current simulation context. + + Returns: + object: The created XStore instruction. + """ + comment = (' dep id: {} - flushing'.format(instr.id)) + xstore = _createXStore(instr.id[0], dest_spad_addr, evict_variable, new_variable, comment, simulation) + simulation.addDependency(xstore, instr) + + return xstore + +def scheduleXNOP(instr, idle_cycles: int, simulation: Simulation, force_nop: bool = False) -> bool: + """ + Schedules a NOP instruction if necessary. + + Parameters: + instr: The instruction that requires the NOP. + idle_cycles (int): The number of idle cycles to schedule. + simulation (Simulation): The current simulation context. + force_nop (bool): Whether to force the scheduling of a NOP. + + Returns: + bool: True if the NOP was scheduled, False otherwise. + + Raises: + ValueError: If idle_cycles is not greater than 0. + """ + if idle_cycles <= 0: + raise ValueError(f'`idle_cycles`: expected greater than `0`, but {idle_cycles} received.') + + retval = True + + comment = "" + if not isinstance(instr, xinst.Exit): + comment = f" nop for not ready instr {instr.id}" + #prev_xinst = simulation.xinsts_bundle[-1] if len(simulation.xinsts_bundle) > 0 else None + prev_xinst = None # rshuffle wait cycle no longer works + if not force_nop and isinstance(prev_xinst, (xinst.rShuffle, xinst.irShuffle)): + # Add idle cycles using previous rshuffle + prev_xinst.wait_cyc = idle_cycles + if comment: + prev_xinst.comment += "{} {}".format(";" if len(prev_xinst.comment) > 0 else "", comment) + prev_xinst.freeze() # Refreeze rshuffle to reflect the new wait_cyc + simulation.current_cycle += idle_cycles # Advance current cycle + else: + retval = force_nop or len(simulation.xinsts_bundle) < simulation.max_bundle_size - 1 + if retval: + assert len(simulation.xinsts_bundle) < simulation.max_bundle_size, 'Cannot queue NOP into full bundle.' + xnop = xinst.Nop(instr.id[0], idle_cycles, comment=comment) + simulation.current_cycle += xnop.schedule(simulation.current_cycle, len(simulation.xinsts_bundle) + 1) + simulation.appendXInstToBundle(xnop) + + return retval + +def findSPADAddress(instr, simulation: Simulation) -> int: + """ + Finds an available SPAD address for an instruction. + + Parameters: + instr: The instruction needing SPAD. + simulation (Simulation): The current simulation context. + + Returns: + int: The SPAD address, or -1 if no address is available. + + Raises: + RuntimeError: If no SPAD address is available or if HBM is full. + """ + # Logic: + # if found empty spad_address: + # return spad_address + # else if no empty spad_address: + # Eviction (removal) of variable from SPAD needed + # find spad_address to evict using replacement policy and avoid live variables + # if found spad_address to evict: + # Eviction: + # if variable to evict is in register: + # no need to flush SPAD, just evict variable since it is in active use in register file (mark as dirty in register). + # else if variable is dirty in spad: + # flush (copy) variable to HBM, and evict from SPAD + # return spad_address + # if no spad_address found: + # return null spad_address (-1) + # + # returns retval_addr: int + # retval_addr < 0 if no address is available in SPAD for this bundle + + # Find an address in SPAD + spad = simulation.mem_model.spad + # Make live_vars be all variables in SPAD not in registers + live_vars = set(var_name for var_name in simulation.live_vars if var_name in spad and spad[var_name].register is None) + retval_addr: int = spad.findAvailableAddress(live_vars, simulation.replacement_policy) + if retval_addr < 0: + # Drastic measure to avoid running our of SPAD + # retval_addr: int = spad.findAvailableAddress(set(), simulation.replacement_policy) + + # Not implemented: throws if spad is full of live variables + raise RuntimeError(f"No SPAD address available. Bundle {simulation.current_cycle.bundle}") + if retval_addr >= 0: + # SPAD address found + variable: Variable = spad.buffer[retval_addr] + if variable: # Contains a variable + assert(variable.spad_address == retval_addr) + # Address needs to be evicted + if variable.spad_dirty: + # Check usage + if len(variable.accessed_by_xinsts) > 0 or variable.name in simulation.mem_model.output_variables: + # Check if SPAD flush is necessary + if not variable.register: + # SPAD flush necessary + + if variable.hbm_address < 0: + # Need a new location in HBM to store the variable + new_hbm_addr = simulation.mem_model.hbm.findAvailableAddress(simulation.mem_model.output_variables) + else: + # Variable already has a location in SPAD + new_hbm_addr = variable.hbm_address + + if new_hbm_addr < 0: + # HBM full + raise RuntimeError("Out of HBM space.") + + # Found HBM address + + # Queue operations to flush SPAD + # (this will deallocate variable from SPAD) + evict_scheduled = __flushVariableFromSPAD(instr, new_hbm_addr, variable, simulation) + if not evict_scheduled: + retval_addr = -1 # Could not schedule the eviction in this cycle + + else: # Variable resides in register + # Mark register as dirty to make sure we flush to cache when done in the register file + variable.register_dirty = True + # Now, just clear cache and keep it in register + else: + # Variable no longer used by remaining XInstructions, + # so, just get rid of it + variable.spad_dirty = False + + if retval_addr >= 0: + if variable.spad_address >= 0: + assert(variable.spad_address == retval_addr) + # Variable still in SPAD + # SPAD address now clean, just free the address + spad.deallocate(retval_addr) + + return retval_addr + +def findRegister(instr, bank_idx: int, simulation: Simulation, override_replacement_policy: str = None, dest_var: Variable = None) -> object: + """ + Finds an available register for an instruction. + + Parameters: + instr: The instruction needing a register. + bank_idx (int): The index of the register bank. + simulation (Simulation): The current simulation context. + override_replacement_policy (str): The replacement policy to override, if any. + dest_var (Variable): The variable to be allocated in register after eviction, or None to keep register free. + + Returns: + tuple: A tuple containing the ready value (int) and the register or XInstruction. + + Raises: + RuntimeError: If no SPAD address is available or if HBM is full. + """ + # Logic: + # find empty retval_register in register file + # if retval_register found: + # return retval_register + # else if no empty register: + # Eviction (removal) of variable from register file needed + # find retval_register to evict using replacement policy and avoid live variables + # if found retval_register to evict: + # Eviction: + # if register is clean: + # no need to flush register, just evict variable since it has not been writen to. + # else, register is dirty: + # need to flush variable to SPAD cache: + # flush logic: + # find appropriate SPAD address to flush to. + # if no SPAD address found: + # return null retval_register (None) + # else, SPAD address found: + # copy variable from register to SPAD + # evict variable from register + # return retval_register + # if no retval_register found: + # return null retval_register (None) + # + # returns ready_value: int, retval_register: Register if ready_value == 1 else XInstruction + # retval_register is None if no register is available for this bundle + + def inner_computeLiveVars(register_bank): + # Returns an iterable over all variable names in the register bank that are live variables + for r_i in range(register_bank.register_count): + v: Variable = register_bank.getRegister(r_i).contained_variable + if v and v.name and ((v.name in simulation.live_vars) or (v.cycle_ready > simulation.current_cycle)): + yield v.name + + retval = 1 + if override_replacement_policy is None: + override_replacement_policy = simulation.replacement_policy + + # Find a register from specified bank + register_bank = simulation.mem_model.register_banks[bank_idx] + # Compute live variables + live_vars = set(inner_computeLiveVars(register_bank)) + + retval_register: Register = register_bank.findAvailableRegister(live_vars, override_replacement_policy) + if retval_register: + # Register found + if retval_register.contained_variable: + # Register needs to evict contained variable + variable = retval_register.contained_variable + assert(not isinstance(variable, DummyVariable)) + if variable.register_dirty: + # Check usage + if len(variable.accessed_by_xinsts) > 0 or variable.name in simulation.mem_model.output_variables: + # Flush necessary + if variable.spad_address < 0: + # Need a new location in SPAD to store the variable + new_spad_addr = findSPADAddress(instr, simulation) + else: + # Variable already has a location in SPAD + new_spad_addr = variable.spad_address + + if new_spad_addr < 0: + # No SPAD address available this bundle + retval_register = None + retval = 0 + else: + # Found SPAD address + + # Evict variable + retval_register = __flushVariableFromRegisterFile(instr, new_spad_addr, variable, dest_var, simulation) + retval = 2 + else: + # Variable no longer used by remaining XInstructions, + # so, just get rid of it + variable.register_dirty = False + + if retval == 1: # Register clean + # No eviction is necessary, just free the register for destination variable (or none if no destination) + retval_register.allocateVariable(dest_var) + + if not retval_register: + retval = 0 + + return retval, retval_register + +def loadVariableHBMToSPAD(instr, variable: Variable, simulation: Simulation) -> bool: + """ + Loads a variable from HBM to SPAD. + + Parameters: + instr: The instruction needing the variable. + variable (Variable): The variable to be loaded. + simulation (Simulation): The current simulation context. + + Returns: + bool: True if the variable was loaded successfully, False otherwise. + + Raises: + RuntimeError: If the variable is not found in HBM or if HBM is full. + """ + # Schedules a list of instructions needed to load the specified variable from HBM into SPAD. + spad = simulation.mem_model.spad + + target_spad_addr = -1 # This will be used as our flag to track valid state (-1 = not valid) + + if variable.name not in simulation.mem_model.store_buffer: # Check variable is not in transit from CE + if variable.spad_address >= 0: + # Variable already in SPAD + target_spad_addr = variable.spad_address + else: + # Bring variable from HBM into SPAD + + # Need a new location in SPAD to store the variable + target_spad_addr = findSPADAddress(instr, simulation) + if target_spad_addr >= 0: + # We are still in valid state + + # Generate instructions to copy from HBM into SPAD + + # Mload depends on the last c access (cload or cstore) + last_access = spad.getAccessTracking(target_spad_addr) + last_c_access = last_access.last_cstore + if not last_access.last_cstore[1] or (last_access.last_cload[1] and last_access.last_cload[0] > last_access.last_cstore[0]): + # No last cstore or cload happened after cstore + last_c_access = last_access.last_cload + last_c_access = last_c_access[1] + if last_c_access: + # Need to sync to CInst + assert(last_c_access.is_scheduled) + msyncc = minst.MSyncc(instr.id[0], last_c_access) + msyncc.schedule(simulation.current_cycle, len(simulation.minsts)) + simulation.minsts.append(msyncc) + if variable.hbm_address < 0: + hbm_addr = simulation.mem_model.hbm.findAvailableAddress(simulation.mem_model.output_variables) + if hbm_addr < 0: + raise RuntimeError("Out of HBM space.") + simulation.mem_model.hbm.allocateForce(hbm_addr, variable) + mload = minst.MLoad(instr.id[0], [variable], simulation.mem_model, target_spad_addr, comment="dep id: {}".format(instr.id)) + mload.schedule(simulation.current_cycle, len(simulation.minsts) + 1) + simulation.minsts.append(mload) + + return target_spad_addr >= 0 + +def hasBankWriteConflictGeneral(ready_cycle: CycleType, latency: int, banks, simulation: Simulation) -> bool: + """ + Checks for bank write conflicts in general. + + Parameters: + ready_cycle (CycleType): The cycle when the instruction is ready. + latency (int): The latency of the instruction. + banks: An iterable of bank indices. + simulation (Simulation): The current simulation context. + + Returns: + bool: True if there is a bank write conflict, False otherwise. + """ + retval = False + if ready_cycle.bundle <= simulation.current_cycle.bundle: # Instruction has no conflicts if it is on a later bundle + instr_write_cycle = XWriteCycleTrack(cycle=CycleType(bundle=simulation.current_cycle.bundle, cycle=max(ready_cycle.cycle, simulation.current_cycle.cycle) + latency - 1), banks=set(banks)) + if len(instr_write_cycle.banks) > 0: + for rshuffle_write_cycle in simulation.pending_write_cycles: + if instr_write_cycle.cycle < rshuffle_write_cycle.cycle: + # Instruction write cycle happens before conflicting with examined write cycle + # and thus will not conflict with any other write cycles in the list because it + # is ordered by write cycle + break + # Check if we conflict + if instr_write_cycle.cycle == rshuffle_write_cycle.cycle and len(instr_write_cycle.banks & rshuffle_write_cycle.banks) > 0: + # Instruction bank writes conflict with a previous write cycle + retval = True + break + + return retval + +def hasBankWriteConflict(instr, simulation: Simulation) -> bool: + """ + Checks for bank write conflicts for a specific instruction. + + Parameters: + instr: The instruction to check. + simulation (Simulation): The current simulation context. + + Returns: + bool: True if there is a bank write conflict, False otherwise. + """ + ready_cycle = instr.cycle_ready + if ready_cycle.bundle < simulation.current_cycle.bundle: + ready_cycle = CycleType(bundle=simulation.current_cycle.bundle, cycle=0) + + if isinstance(instr, xinst.XStore): + banks = set() # Xstore does not write to register file + else: + banks = set(v.suggested_bank for v in instr.dests if isinstance(v, Variable)) + banks |= set(r.bank.bank_index for r in instr.dests if isinstance(r, Register)) + + return hasBankWriteConflictGeneral(ready_cycle, instr.latency, banks, simulation) + +def prepareInstruction(original_xinstr, simulation: Simulation) -> int: + """ + Prepares an instruction for scheduling. + + Parameters: + original_xinstr: The original instruction to prepare. + simulation (Simulation): The current simulation context. + + Returns: + tuple: A tuple containing the ready value (int) and the instruction or None. + + Raises: + RuntimeError: If a variable is not in the suggested bank. + """ + # Schedules the specified instruction into the current bundle of xinsts. + retval = 1 # Tracks whether we are valid for scheduling in current bundle + retval_instr = original_xinstr + + # Check sources + expanded_dests = original_xinstr.dests[:] # Create a copy + if retval == 1: + for idx, src_var in enumerate(original_xinstr.sources): + if idx == 2 and isinstance(original_xinstr, (xinst.NTT, xinst.iNTT)): + # Special case for xntt: twiddles for stage 0 are ignored + if original_xinstr.stage == 0: + expanded_dests.append(src_var) + continue # Next source, but this should end the for-loop + + if retval != 1: + break + + b_generated_keygen_var = False # Flag to track whether this source variable is a generated keygen this time + # Make sure destination variables are on a register + if isinstance(src_var, Variable): + if src_var.name in simulation.live_outs: + # Stop preparing and move instruction to next bundle: one of its variables is marked for eviction + retval = 0 + else: + simulation.addLiveVar(src_var.name, original_xinstr) # Add variable as live + if not src_var.register: + # Needs to start at bank 0 + + b_generated_keygen_var = not simulation.mem_model.isVarInMem(src_var.name) and src_var.name in simulation.mem_model.keygen_variables + + if not b_generated_keygen_var: + # Variable is not keygen or it has already been generated + + # Load into SPAD + if src_var.spad_address < 0: + assert src_var.name not in simulation.mem_model.store_buffer, f'Attempting to load from HBM: "{src_var.name}"; already in transit in SPAD store buffer.' + if not loadVariableHBMToSPAD(original_xinstr, src_var, simulation): + # Could not find location in SPAD, move to next bundle + retval = 0 + + if retval != 0: + retval, new_instr_or_reg = findRegister(original_xinstr, 0, simulation, override_replacement_policy="") # No replacement policy for bank 0 + # retval == 1 => register good to go + # retval == 2 => xstore needed for eviction + if retval == 1: + # Register ready, load from SPAD + assert(new_instr_or_reg.bank.bank_index == 0) + + if b_generated_keygen_var: + # This is a keygen variable that has not been generated + keygen_retval = simulation.generateKeyMaterial(original_xinstr.id[0], src_var, new_instr_or_reg) + if keygen_retval == 2: + # Could not generate key material this bundle + retval = 3 + else: + # Generate instructions to load variable from SPAD into bank 0 + last_mload_access = simulation.mem_model.spad.getAccessTracking(src_var.spad_address).last_mload[1] + if last_mload_access: + # Need to sync to MInst + csyncm = cinst.CSyncm(original_xinstr.id[0], last_mload_access) + csyncm.schedule(simulation.current_cycle, len(simulation.cinsts) + 1) + simulation.cinsts.append(csyncm) + cload = cinst.CLoad(original_xinstr.id[0], new_instr_or_reg, [src_var], simulation.mem_model, comment="dep id: {}".format(original_xinstr.id)) + cload.schedule(simulation.current_cycle, len(simulation.cinsts) + 1) + simulation.cinsts.append(cload) + if retval == 2: + # Register needs eviction + assert isinstance(new_instr_or_reg, xinst.XStore) + retval_instr = new_instr_or_reg + elif retval == 3: + # Could not generate key material this bundle + assert b_generated_keygen_var, f"Variable {src_var.name} is not keygen" + retval = 2 + + if retval == 1: + if src_var.register.bank.bank_index == 0: + # Already in bank 0, so, bring to correct bank + retval, new_instr_or_reg = findRegister(original_xinstr, src_var.suggested_bank, simulation) + # retval == 1 => register good to go + # retval == 2 => xstore needed for eviction + if retval == 1: + # Generate instruction to move variable from bank 0 to its suggested bank + new_instr_or_reg.allocateVariable(simulation.bundle_dummy_var) + xmove = xinst.Move(original_xinstr.id[0], new_instr_or_reg, [src_var], dummy_var=simulation.bundle_dummy_var) + if xmove.cycle_ready.bundle < simulation.current_cycle.bundle: + # Correct cycle ready's bundle + xmove.cycle_ready = CycleType(bundle=simulation.current_cycle.bundle, cycle=0) + if hasBankWriteConflict(xmove, simulation): + xmove.cycle_ready = CycleType(bundle=xmove.cycle_ready.bundle, cycle=xmove.cycle_ready.cycle + 1) + if not scheduleXNOP(xmove, 1, simulation): + retval = 0 + if retval != 0: + src_var.accessed_by_xinsts = [Variable.AccessElement(0, xmove.id)] + src_var.accessed_by_xinsts + simulation.addDependency(xmove, original_xinstr) + new_instr_or_reg = xmove + retval = 2 # Need xmove + + if retval == 2: + # XInsts needed to prepare variable + + # Moves should always be able to schedule at this point + assert isinstance(new_instr_or_reg, (xinst.Move, xinst.XStore)) + retval_instr = new_instr_or_reg + + if retval == 1: + if src_var.register.bank.bank_index != src_var.suggested_bank: + raise RuntimeError('Variable `{}` is in register `{}`, which is not in suggested bank {}.'.format(src_var.name, src_var.register.name, src_var.suggested_bank)) + if b_generated_keygen_var: + # Mark register as dirty since this variable is keygen and + # does not exist elsewhere: we want to preserve this value + src_var.register_dirty = True + + # Check destinations + if retval == 1: + # Expanded_dests is original_xinstr.dests + any extra variables that need to be here + for dst_var in expanded_dests: + if retval != 1: + break + # Make sure destination variables are on a register + if isinstance(dst_var, Variable): + if dst_var.name in simulation.live_outs: + # Stop preparing and move instruction to next bundle: one of its variables is marked for eviction + retval = 0 + else: + simulation.addLiveVar(dst_var.name, original_xinstr) # Add variable as live + if not dst_var.register: + # Find register for variable: + # This will schedule all the C and M instructions needed to secure that register + retval, new_instr_or_reg = findRegister(original_xinstr, dst_var.suggested_bank, simulation, dest_var=dst_var) + + # retval == 1 => register good to go + # retval == 2 => xstore needed for eviction + if retval == 2: + assert isinstance(new_instr_or_reg, xinst.XStore) + retval_instr = new_instr_or_reg + + if retval == 1: + if dst_var.register.bank.bank_index != dst_var.suggested_bank: + raise RuntimeError('Variable `{}` is in register `{}`, which is not in suggested bank {}.'.format(dst_var.name, dst_var.register.name, dst_var.suggested_bank)) + + assert retval == 0 or (retval_instr is not None and __canScheduleInBundle(retval_instr, simulation)) # We should always be able to schedule preparation instructions + + if retval == 0: + retval_instr = None + elif retval_instr: + assert retval_instr.id in simulation.dependency_graph + if hasBankWriteConflict(retval_instr, simulation): + assert not isinstance(retval_instr, xinst.Move) # Moves must be scheduled immediately + # Write cycle conflict found, so, update found instruction cycle ready + new_cycle_ready = CycleType(bundle=simulation.current_cycle.bundle, cycle=max(retval_instr.cycle_ready.cycle, simulation.current_cycle.cycle) + 1) + retval_instr.cycle_ready = new_cycle_ready + + return retval, retval_instr + +def scheduleASMISAInstructions(dependency_graph: nx.DiGraph, + max_bundle_size: int, + mem_model: MemoryModel, + replacement_policy, + progress_verbose: bool = False) -> (list, list, list, int): + """ + Schedules ASM-ISA instructions based on a dependency graph of XInsts to minimize idle cycles. + + Parameters: + dependency_graph (nx.DiGraph): The dependency directed acyclic graph of XInsts. + max_bundle_size (int): Maximum number of instructions in a bundle. + mem_model (MemoryModel): The memory model used in the simulation. + replacement_policy: The policy used for memory replacement. + progress_verbose (bool): Whether to print progress information. + + Returns: + tuple: A tuple containing lists of xinst, cinst, minst, and the total idle cycles. + """ + simulation = Simulation(dependency_graph, + max_bundle_size, # Max number of instructions in a bundle + mem_model, + replacement_policy, + progress_verbose) + # DEBUG + iter_counter = 0 + pisa_instr_counter = 0 + # ENDDEBUG + + if progress_verbose: + print('Dependency Graph') + print(f' Initial number of dependencies: {simulation.dependency_graph.size()}') + print('Scheduling metadata preparation.') + + simulation.loadMetadata() + + if progress_verbose: + print('Scheduling XInstructions...') + + try: + b_flush_bundle = False + fixed_last_short_bundle = -1 # Tracks last bundle considered short that got fixed + new_bundle = True + while simulation.dependency_graph: # Iterates per instruction to be scheduled + # DEBUG + iter_counter += 1 + if GlobalConfig.debugVerbose: + if iter_counter % int(GlobalConfig.debugVerbose) == 0: + print(iter_counter) + # ENDDEBUG + + # Check if current bundle needs to be flushed + if b_flush_bundle: + simulation.flushBundle() + b_flush_bundle = False # Bundle flushed + new_bundle = True + + # Check if we need to fetch more bundles + if new_bundle and len(simulation.xinsts) % simulation.max_bundles_per_xinstfetch == 0: + if progress_verbose: + pct = int(simulation.scheduled_xinsts_count * 100 / simulation.total_instructions) + print("{}% - {}/{}".format(pct, + simulation.scheduled_xinsts_count, + simulation.total_instructions)) + # Handle xinstfetch + xinstfetch = cinst.XInstFetch(len(simulation.xinstfetch_cinsts_buffer), + simulation.xinstfetch_xq_addr, + simulation.xinstfetch_hbm_addr) + xinstfetch.schedule(simulation.current_cycle, len(simulation.xinstfetch_cinsts_buffer) + 1) + simulation.xinstfetch_cinsts_buffer.append(xinstfetch) + + simulation.xinstfetch_xq_addr = (simulation.xinstfetch_xq_addr + 1) % constants.MemoryModel.XINST_QUEUE_MAX_CAPACITY_WORDS + simulation.xinstfetch_hbm_addr += 1 + # Check if we reached end of XInst queue + if simulation.xinstfetch_xq_addr <= 0: + if GlobalConfig.useXInstFetch: + if progress_verbose: + print("XInst queue filled: wrapping around...") + # Flush buffered xinstfetches to cinst + simulation.cinsts = simulation.cinsts[:simulation.xinstfetch_location_idx_in_cinsts] \ + + simulation.xinstfetch_cinsts_buffer \ + + simulation.cinsts[simulation.xinstfetch_location_idx_in_cinsts:] + # Point to next location to insert xinstfetches + simulation.xinstfetch_location_idx_in_cinsts = len(simulation.cinsts) + simulation.xinstfetch_cinsts_buffer = [] # Buffer flushed, start new + + new_bundle = False + + # Remove any write cycles that have passed + simulation.cleanupPendingWriteCycles() + + while True: # do/while + if simulation.topo_start_idx < len(simulation.full_topo_sort) \ + and len(simulation.topo_sort) < Simulation.MIN_INSTRUCTIONS_IN_TOPO_SORT: + if len(simulation.priority_queue) < Simulation.MIN_INSTRUCTIONS_IN_TOPO_SORT: + simulation.topo_sort += simulation.full_topo_sort[simulation.topo_start_idx:simulation.topo_start_idx + Simulation.INSTRUCTION_WINDOW_SIZE] + simulation.topo_start_idx += Simulation.INSTRUCTION_WINDOW_SIZE + simulation.b_topo_sort_changed = True # Added to topo window + + assert len(simulation.priority_queue) > 0 or len(simulation.topo_sort) > 0, 'Possible infinite loop detected.' + + # Try to exhaust the priority queue first: + # These may introduce some inefficiency to the schedule, but avoids + # memory thrashing when new instructions become available from the topo sort + xinstr = simulation.findNextInstructionToSchedule() + fill_pq = not xinstr or xinstr.cycle_ready > simulation.current_cycle + if fill_pq: + if simulation.b_topo_sort_changed or simulation.b_dependency_graph_changed: + # Extract all the instructions that can be executed without dependencies + # and merge to current instructions that can be executed without dependencies + last_idx = -1 + for idx, instr_id in enumerate(simulation.topo_sort): + if instr_id in simulation.set_extracted_xinstrs: + last_idx = idx # We want to remove repeated instructions + else: + assert instr_id in simulation.dependency_graph + if simulation.dependency_graph.in_degree(instr_id) > 0: + # Found first instruction with dependencies + last_idx = idx - 1 + break + instr = simulation.dependency_graph.nodes[instr_id]['instruction'] + simulation.priority_queue_push(instr) + simulation.b_priority_queue_changed = True + + # Remove all instructions that got queued for scheduling + if last_idx >= 0: + simulation.topo_sort = simulation.topo_sort[last_idx + 1:] + if xinstr: + # Next instruction to schedule may have changed after pulling from topo sort + simulation.priority_queue_push(xinstr) + xinstr = None + + # Graph and topo sort have been updated + simulation.b_topo_sort_changed = False + simulation.b_dependency_graph_changed = False + + if xinstr or len(simulation.priority_queue) > 0: # End do/while loop + # There must be, at least one instruction to schedule at this point + # (if condition was true), + # else, attempt to refill topo sort (restart the top of do/while loop) + break # There is, at least, one instruction to schedule + + assert(len(simulation.xinsts_bundle) < simulation.max_bundle_size) # We should have space in current bundle for an xinstruction + + # Find next xinstruction to schedule + if not xinstr: + assert(simulation.priority_queue) + xinstr = simulation.findNextInstructionToSchedule() + if not xinstr: + # No instruction left to schedule this bundle + + # If this bundle is too short, attempt to bring instructions from + # later bundles to schedule now, if possible + if len(simulation.xinsts_bundle) <= simulation.BUNDLE_INSTRUCTION_MIN_LIMIT: + # Do not fix two short bundles in a row + if fixed_last_short_bundle + 1 < simulation.current_cycle.bundle: + # Do not fix if bank 0 is full + b_bundle_needs_fix = any(reg.contained_variable is None for reg in simulation.mem_model.register_banks[0]) + + if b_bundle_needs_fix: + # DEBUG + if GlobalConfig.debugVerbose: + print(f'---- Fixing short bundle {simulation.current_cycle.bundle}') + # ENDDEBUG + + # Flush register banks and attempt to schedule again + for bank_idx in range(1, len(simulation.mem_model.register_banks)): + mem_utilities.flushRegisterBank(simulation.mem_model.register_banks[bank_idx], + simulation.current_cycle, + simulation.replacement_policy, + simulation.live_vars, + pct=0.5) + # Attempt to schedule instructions slated for next bundle in this bundle + tmp_set = set() + for _, xinstr in simulation.priority_queue: + if xinstr.cycle_ready.bundle == simulation.current_cycle.bundle + 1: + if xinstr.cycle_ready.cycle <= 1: + xinstr.cycle_ready = CycleType(bundle=simulation.current_cycle.bundle, + cycle=xinstr.cycle_ready.cycle) + tmp_set.add(xinstr) + for xinstr in tmp_set: + simulation.priority_queue_push(xinstr) + xinstr = simulation.findNextInstructionToSchedule() + fixed_last_short_bundle = simulation.current_cycle.bundle + + # Flush bundle anyway if no instruction was found after fixing + b_flush_bundle = xinstr is None + + if not b_flush_bundle: + assert(xinstr is not None) # Only None if priority queue is empty + + # Attempt to schedule xinstruction + + # Scheduling logic: + # - Block cstore locations in SPAD with dummy vars: + # * Add corresponding xstores to priority_queue and dependency graph with + # current xinstruction dependent on them (first xstore should replace current xinstruction to schedule). + # * If xmoves(target_bank, bank0) are required, they must be scheduled immediately. + # * Remove dependent xinstruction from priority_queue (make sure it is next in the topo_sort). + # - All other cinsts and minsts before the bundle should be scheduled. + # - Schedule all cstores after the bundle ifetch is scheduled (SPAD locations should be available because + # we blocked them in first step) (check that cstores can correctly allocate in SPAD with dummy var). + # - Add all input and output variables to live_ins + # - If xinstruction comes back from topo_sort it should not have pending dependencies, then schedule it. + # * Add all input and output variables to live_ins_used. + ############################################### + + prep_counter = 0 + original_xinstr = xinstr + while xinstr is not None: + # All xinstr at this point should be ready for current bundle + + if GlobalConfig.debugVerbose: + if iter_counter % int(GlobalConfig.debugVerbose) == 0: + print('prep_counter', prep_counter) + xinstr_prepped, xinstr = prepareInstruction(original_xinstr, simulation) + + if xinstr_prepped == 0: + assert xinstr is None + # Failed to prepare instruction in this bundle, leave it for next bundle + original_xinstr.cycle_ready = CycleType(bundle=simulation.current_cycle.bundle + 1, + cycle=0) + # Add back to priority queue + simulation.priority_queue_push(original_xinstr) + elif xinstr != original_xinstr: + # This is a preparation instruction + prep_counter += 1 + + if xinstr: + assert xinstr.id in simulation.dependency_graph + if simulation.dependency_graph.in_degree(xinstr.id) > 0: + # Instruction to schedule has new dependencies: + # This occurs if, while preparing the variables for the instruction, + # new dependencies were added. + + assert xinstr == original_xinstr + + xinstr = None + + # Ready to schedule xinstruction + # Check if xinstruction is cycle ready for scheduling + elif xinstr.cycle_ready > simulation.current_cycle: + + if prep_counter > 0: # Instructions were added to prep original + if original_xinstr == xinstr: + assert (xinstr_prepped == 1) + # Original instruction prepped in this group, but not ready to schedule yet: + # Put it back in the priority queue during schedule update phase + else: + # Xinstr is not the original, but one needed to prepare the original + + assert not isinstance(xinstr, xinst.Move), f'xinstr = {repr(xinstr)} \ncycle = {simulation.current_cycle}; iter = {iter_counter}' + + # Cycle for xinstr is not ready yet, so, + # put it back in the correct place in the simulation pipeline + assert xinstr.id in simulation.dependency_graph \ + and simulation.dependency_graph.in_degree(xinstr.id) <= 0 + simulation.addXInstrBackIntoPipeline(xinstr) + + # This will cause the schedule update phase below to put original instruction + # back in the correct place in the simulation pipeline (pq or topo sort) + xinstr = None + + if xinstr: + assert prep_counter == 0 + assert original_xinstr == xinstr + + # Nop required + idle_cycles_required = xinstr.cycle_ready.cycle - simulation.current_cycle.cycle + if scheduleXNOP(xinstr, + idle_cycles_required, + simulation): + simulation.total_idle_cycles += idle_cycles_required + else: + # Could not schedule required NOP in this bundle: + # Leave xinstruction for next bundle + xinstr.cycle_ready = CycleType(bundle=simulation.current_cycle.bundle + 1, + cycle=1) + # Add back to pipeline during schedule update phase + xinstr = None + + if xinstr: + # We are still valid for scheduling + + # At this point, xinstruction should be in ready cycle + assert(__canScheduleInBundle(xinstr, simulation, padding=0)) + assert(simulation.current_cycle >= xinstr.cycle_ready) + # Simulate schedule of xinstruction + simulation.current_cycle += xinstr.schedule(simulation.current_cycle, len(simulation.xinsts_bundle) + 1) + + # Mark the used lives + xinstr_var_names = set(v.name for v in xinstr.sources + xinstr.dests \ + if isinstance(v, Variable) and not isinstance(v, DummyVariable)) + if isinstance(xinstr, xinst.XStore): + simulation.live_outs.update(xinstr_var_names) + for var_name in xinstr_var_names: + simulation.addUsedVar(var_name, xinstr) + + # Schedule update phase + if xinstr: + # XInstruction scheduled: update remaining schedule + simulation.set_extracted_xinstrs.add(xinstr.id) + b_flush_bundle = simulation.updateSchedule(xinstr) + if original_xinstr == xinstr: + pisa_instr_counter += 1 + if GlobalConfig.debugVerbose: + if iter_counter % int(GlobalConfig.debugVerbose) == 0: + print(f'P-ISA scheduled: {pisa_instr_counter}') + + # Check for completed outputs to flush + for variable in original_xinstr.dests: + # This assertion may be broken if move instructions end up back in the topo sort + assert(variable.name not in simulation.mem_model.store_buffer \ + or isinstance(original_xinstr, xinst.XStore)) + if variable.name in simulation.mem_model.output_variables \ + and not variable.accessed_by_xinsts \ + and variable.name not in simulation.mem_model.store_buffer: + # Variable is an output variable + # and it is no longer needed + # and it is not in-flight to be stored already + if not simulation.flushOutputVariableFromRegister(variable): + break # Continue next bundle + + # Terminate loop + xinstr = None + elif b_flush_bundle: + # Add back to priority queue if we haven't scheduled original yet + # and bundle needs to be flushed + simulation.addXInstrBackIntoPipeline(original_xinstr) + # Terminate loop + xinstr = None + elif simulation.priority_queue.find(simulation.current_cycle): + # Immediate instruction ready: stop preparing current + # Add back to pipeline if we haven't scheduled original yet + simulation.addXInstrBackIntoPipeline(original_xinstr) + # Terminate loop + xinstr = None + + else: # Xinstr was None + # Put original instruction back in the correct place of the simulation pipeline + simulation.addXInstrBackIntoPipeline(original_xinstr) + + if not simulation.dependency_graph: + # Completed schedule: store output variables still in registers + last_xinstr = simulation.last_xinstr + if not last_xinstr: + last_xinstr = original_xinstr + for output_var_name in simulation.mem_model.output_variables: + variable = simulation.mem_model.variables[output_var_name] + assert(not variable.accessed_by_xinsts) # Variable should not be accessed any more + if not simulation.flushOutputVariableFromRegister(variable): + break # Continue next bundle + + # Next cycle starts + + # Completed scheduling - first pass + + # Flush last bundle + if len(simulation.xinsts_bundle) > 0: + simulation.flushBundle() + + # Flush buffered xinstfetches to cinst + if GlobalConfig.useXInstFetch: + if len(simulation.xinstfetch_cinsts_buffer) > 0: + simulation.cinsts = simulation.cinsts[:simulation.xinstfetch_location_idx_in_cinsts] \ + + simulation.xinstfetch_cinsts_buffer \ + + simulation.cinsts[simulation.xinstfetch_location_idx_in_cinsts:] + + # TODO: + ################################# + warnings.warn("Rework xinstfetch logic to stream as XInsts are consumed instead of blindly placing them.") + + # End the CInst queue + + # Wait for last instruction in MInstQ to complete + if len(simulation.minsts) > 0: + last_csyncm = cinst.CSyncm(simulation.minsts[-1].id[0], simulation.minsts[-1]) + last_csyncm.schedule(simulation.current_cycle, len(simulation.cinsts) + 1) + simulation.cinsts.append(last_csyncm) + + cexit = cinst.CExit(len(simulation.cinsts)) + cexit.schedule(simulation.current_cycle, len(simulation.cinsts) + 1) + simulation.cinsts.append(cexit) + + # Rule: last instruction in MInstQ must be a sync pointing to cexit + 1 + last_msyncc = minst.MSyncc(cexit.id[0], cexit, comment='terminating MInstQ') + last_msyncc.schedule(simulation.current_cycle, len(simulation.minsts) + 1) + simulation.minsts.append(last_msyncc) + + # Completed scheduling - second pass + + # Corrects the sync instructions to point to correct instruction number + simulation.updateQueuesSyncsPass2() + + if progress_verbose: + print("100% - {0}/{0}".format(simulation.total_instructions)) + + except KeyboardInterrupt as ex: + if GlobalConfig.debugVerbose: + cnt = 0 + while cnt < 10 and simulation.priority_queue: + _, xinstr = simulation.priority_queue.pop() + print('Cycle ready', xinstr.cycle_ready) + print(repr(xinstr)) + cnt += 1 + if len(simulation.priority_queue) > 10: + print('...') + print('priority_queue', len(simulation.priority_queue)) + print('topo_sort', len(simulation.topo_sort)) + print('current cycle', simulation.current_cycle) + simulation.mem_model.dump() + + import traceback + traceback.print_exc() + print(ex) + else: + raise + + return simulation.minsts, simulation.cinsts, simulation.xinsts, simulation.total_idle_cycles \ No newline at end of file diff --git a/assembler_tools/hec-assembler-tools/assembler/stages/preprocessor.py b/assembler_tools/hec-assembler-tools/assembler/stages/preprocessor.py new file mode 100644 index 00000000..b1cf68b8 --- /dev/null +++ b/assembler_tools/hec-assembler-tools/assembler/stages/preprocessor.py @@ -0,0 +1,290 @@ +import networkx as nx + +from assembler.common.constants import Constants +from assembler.instructions import xinst +from assembler.instructions.xinst.xinstruction import XInstruction +from assembler.instructions.xinst import parse_xntt +from assembler.memory_model import MemoryModel +from assembler.memory_model import variable + +def __dependencyGraphForVars(insts_list: list) -> (nx.Graph, set, set): + """ + Given the listing of instructions, this method returns the dependency graph + for the variables in the listing and the sets of destination and source variables. + + Parameters: + insts_list (list): List of corresponding pre-processed P-ISA instructions containing the variables + to process. + + Returns: + tuple: A tuple containing: + - nx.Graph: Dependency graph for the variables in the listing. + Nodes: variable names. + Edges: dependencies among variables. + Dependencies: + - All variables being read from at the same time in any one instruction + depend on each other because it is forbidden to read more than once from + the same register bank in the same instruction. + - All variables being written to at the same time in any one instruction + depend on each other because it is forbidden to write more than once to + the same register bank in one instruction. + - set: Set of all variables (name) that are destinations in the input `insts_list`. + - set: Set of all variables (name) that are sources in the input `insts_list`. + """ + retval = nx.Graph() + all_dests_vars = set() + all_sources_vars = set() + + for inst in insts_list: + extra_sources = [] + for idx, v in enumerate(inst.dests): + all_dests_vars.add(v.name) + if v.name not in retval: + retval.add_node(v.name) + for v_i in range(idx + 1, len(inst.dests)): + v_next = inst.dests[v_i] + if v.name == v_next.name: + raise RuntimeError(f"Cannot write to the same variable in the same instruction more than once: {inst.toPISAFormat()}") + if not retval.has_edge(v.name, v_next.name): + retval.add_edge(v.name, v_next.name) + # Mac deps already handled in the Mac instructions themselves + # if isinstance(inst, (xinst.Mac, xinst.Maci)): + # extra_sources.append(v) + + inst_all_sources = extra_sources + inst.sources + for idx, v in enumerate(inst_all_sources): + all_sources_vars.add(v.name) + if v.name not in retval: + retval.add_node(v.name) + for v_i in range(idx + 1, len(inst_all_sources)): + v_next = inst_all_sources[v_i] + if v.name == v_next.name: + raise RuntimeError(f"Cannot read from the same variable in the same instruction more than once: {inst.toPISAFormat()}") + if not retval.has_edge(v.name, v_next.name): + retval.add_edge(v.name, v_next.name) + + return retval, all_dests_vars, all_sources_vars + +def injectVariableCopy(mem_model: MemoryModel, + insts_list: list, + instruction_idx: int, + var_name: str) -> int: + """ + Injects a copy of a variable into the instruction list at the specified index. + + Parameters: + mem_model (MemoryModel): The memory model containing the variables. + insts_list (list): The list of instructions. + instruction_idx (int): The index at which to inject the copy. + var_name (str): The name of the variable to copy. + + Returns: + int: Index for the instruction in the list after injection. + + Raises: + IndexError: If the instruction index is out of range. + """ + if instruction_idx < 0 or instruction_idx >= len(insts_list): + raise IndexError(f'instruction_idx: Expected index in range [0, {len(insts_list)}), but received {instruction_idx}.') + last_instruction: XInstruction = insts_list[instruction_idx] + last_instruction_sources = last_instruction.sources[:] + for idx, variable in enumerate(last_instruction_sources): + if variable.name == var_name: + # Find next available temp var name + temp_name = mem_model.findUniqueVarName() + temp_var = mem_model.retrieveVarAdd(temp_name, -1) + # Copy source var into temp + copy_xinst = xinst.Copy(id = last_instruction.id[1], + N = 0, + dst = [ temp_var ], + src = [ variable ], + comment='Injected copy for bank reduction.') + insts_list.insert(instruction_idx, copy_xinst) + # Replace src by temp + last_instruction.sources[idx] = temp_var + instruction_idx += 1 + + return instruction_idx + +def reduceVarDepsByVar(mem_model: MemoryModel, + insts_list: list, + var_name: str): + """ + Reduces variable dependencies by injecting copies of the specified variable. + + Parameters: + mem_model (MemoryModel): The memory model containing the variables. + insts_list (list): The list of instructions. + var_name (str): The name of the variable to reduce dependencies for. + """ + last_pos = 0 + last_instruction = None + # Find all instructions* with specified variable and make it a copy + # * care with mac instructions + while last_pos < len(insts_list): + if var_name in (v.name for v in insts_list[last_pos].sources): + last_instruction: XInstruction = insts_list[last_pos] + if isinstance(last_instruction, (xinst.Mac, xinst.Maci)): + # Check if the conflicting variable is the accumulator + if last_instruction.sources[0].name == var_name: + # Turn all other variables into copies + for variable in last_instruction.sources[1:]: + last_pos = injectVariableCopy(mem_model, insts_list, last_pos, variable.name) + assert last_instruction == insts_list[last_pos] + last_instruction = None # avoid further processing of instruction + last_pos += 1 + continue + # If conflict variable was not the accumulator, proceed to change the other variables + # Skip copy, twxntt and xrshuffle + if not isinstance(last_instruction, (xinst.twiNTT, + xinst.twiNTT, + xinst.irShuffle, + xinst.rShuffle, + xinst.Copy)): + # Break up indicated variable in sources into a temp copy + last_pos = injectVariableCopy(mem_model, insts_list, last_pos, var_name) + assert last_instruction == insts_list[last_pos] + + last_pos += 1 + +def assignRegisterBanksToVars(mem_model: MemoryModel, + insts_list: list, + use_bank0: bool, + verbose = False) -> str: + """ + Assigns register banks to variables using vertex coloring graph algorithm. + + The variables contained in the MemoryModel object will be modified to reflect + their suggested bank. + + Parameters: + mem_model (MemoryModel): The MemoryModel object, where all variables are kept. Variables detected that are + not already in the MemoryModel collection of variables will be added automatically. + The variables contained in the MemoryModel object will be modified to reflect + their suggested bank. + insts_list (list): List of corresponding pre-processed P-ISA instructions containing the variables + to process. + use_bank0 (bool): All variables are written into registers in bank 0 from SPAD, while no XInst + should be writing its results in bank 0 to avoid write-write conflicts. + If `True`, variables that can remain in bank 0 will be kept there (variables + that are never written to). + If `False`, bank 0 will not be assigned to any variable. Resulting ASM instructions + should add corresponding `move` instructions to move variables from bank 0 to + correct bank. + verbose (bool, optional): If True, prints verbose output. Defaults to False. + + Raises: + ValueError: Thrown for these cases: + - Invalid input values for parameters. + - Variables in listing cannot be successfully assigned to banks (the number + of banks is insufficient to accommodate the listing given the rules). + This should not happen, as long as rules are respected, and all instructions + have, at most, 3 inputs, and at most, 3 outputs. + + Returns: + str: The unique dummy variable name that was not used by the collection of variables in the + instruction listing. + """ + reduced_vars = set() + needs_reduction = True + pass_counter = 0 + while needs_reduction: + pass_counter += 1 + if verbose: + print(f"Pass {pass_counter}") + # Extract the dependency graph for variables + dep_graph_vars, dest_names, source_names = __dependencyGraphForVars(insts_list) + only_sources = source_names - dest_names # Find which variables are ever only used as sources + color_dict = nx.greedy_color(dep_graph_vars) # Do coloring + + needs_reduction = False + for var_name, bank in color_dict.items(): + if bank > 2: + if var_name in reduced_vars: + raise RuntimeError(('Found invalid bank {} > 2 for variable {} already reduced.').format(bank, + var_name)) + # DEBUG print + if verbose: + print('Variable {} ({}) requires reduction.'.format(var_name, bank)) + reduceVarDepsByVar(mem_model, insts_list, var_name) + reduced_vars.add(var_name) # Track reduced variable + needs_reduction = True + + # Assign banks based on coloring algo results + for v in mem_model.variables.values(): + if not mem_model.isMetaVar(v.name): # Skip meta variables + assert(v.name in color_dict) + bank = color_dict[v.name] + assert bank < 3, f'{v.name}, {bank}' + # If requested, keep vars used only as sources in bank 0 + v.suggested_bank = bank + (0 if use_bank0 and (v.name in only_sources) else 1) + + retval: str = mem_model.findUniqueVarName() + + return retval + +def preprocessPISAKernelListing(mem_model: MemoryModel, + line_iter, + progress_verbose: bool = False) -> list: + """ + Parses a P-ISA kernel listing, given as an iterator for strings, where each is + a line representing a P-ISA instruction. + + Generates twiddle factors and bit shuffling for original P-ISA xntt instructions. + + Variables in `mem_model` associated with the output will have assigned banks automatically. + + Parameters: + mem_model (MemoryModel): The MemoryModel object, where all variables are kept. Variables parsed from the + input string will be automatically added to the memory model if they do not already + exist. The represented object may be modified if addition is needed. + line_iter (iterator): Iterator of strings where each is a line of the P-ISA kernel instruction listing. + progress_verbose (bool, optional): Specifies whether to output progress every hundred lines processed to stdout. + Defaults to False. + + Returns: + list: A list of `BaseInstruction`s where each object represents + the a parsed instruction. Some P-ISA instructions (such as ntts) are converted into + a group of XInst that implement the original P-ISA instruction. + + Variables in `mem_model` collection of variables will be modified to reflect + assigned bank in `suggested_bank` attribute. + """ + NTT_KERNEL_GRAMMAR = lambda line: parse_xntt.parseXNTTKernelLine(line, xinst.NTT.OP_NAME_PISA, Constants.TW_GRAMMAR_SEPARATOR) + iNTT_KERNEL_GRAMMAR = lambda line: parse_xntt.parseXNTTKernelLine(line, xinst.iNTT.OP_NAME_PISA, Constants.TW_GRAMMAR_SEPARATOR) + + retval = [] + + if progress_verbose: + print("0") + num_input_insts = 0 + for line_no, s_line in enumerate(line_iter, 1): + num_input_insts = line_no + if progress_verbose and line_no % 100 == 0: + print(f"{num_input_insts}") + + parsed_insts = None + if not parsed_insts: + parsed_op = NTT_KERNEL_GRAMMAR(s_line) + if not parsed_op: + parsed_op = iNTT_KERNEL_GRAMMAR(s_line) + if parsed_op: + # Instruction is a P-ISA xntt + parsed_insts = parse_xntt.generateXNTT(mem_model, + parsed_op, + new_id = line_no) + if not parsed_insts: + # Instruction is one that is represented by single XInst + inst = xinst.createFromPISALine(mem_model, s_line, line_no) + if inst: + parsed_insts = [ inst ] + + if not parsed_insts: + raise SyntaxError("Line {}: unable to parse kernel instruction:\n{}".format(line_no, s_line)) + + retval += parsed_insts + + if progress_verbose: + print(f"{num_input_insts}") + + return retval \ No newline at end of file diff --git a/assembler_tools/hec-assembler-tools/assembler/stages/scheduler.py b/assembler_tools/hec-assembler-tools/assembler/stages/scheduler.py new file mode 100644 index 00000000..4ee8f65c --- /dev/null +++ b/assembler_tools/hec-assembler-tools/assembler/stages/scheduler.py @@ -0,0 +1,440 @@ +import collections +import heapq +import networkx as nx +from typing import NamedTuple + +from . import buildVarAccessListFromTopoSort +from assembler.common.cycle_tracking import PrioritizedPlaceholder, CycleType +from assembler.instructions import xinst, cinst, minst +from assembler.memory_model import MemoryModel +from assembler.memory_model.variable import Variable +from pickle import TRUE + +def __orderKeygenVars(mem_model: MemoryModel) -> list: + """ + Returns the name of the keygen variables in the order they have to be generated. + + Parameters: + mem_model (MemoryModel): Completed memory model corresponding to the specified dependency graph and + input mem info. Used to extract the keygen variable information. + + Raises: + RuntimeError: Detected missing keygen variable in ordering. + + Returns: + list: A list of lists. Each element of the outer list represents a seed. Each element + in the inner list is the name of the keygen variable corresponding to that index + ordering for the corresponding seed. + """ + retval = list([] for _ in range(len(mem_model.meta_keygen_seed_vars))) + for var_name, (seed_idx, key_idx) in mem_model.keygen_variables.items(): + assert seed_idx < len(retval) + if key_idx >= len(retval[seed_idx]): + retval[seed_idx] += ((key_idx - len(retval[seed_idx]) + 1) * [None]) + retval[seed_idx][key_idx] = var_name + # Validate that no key material was skipped + for seed_idx, l in enumerate(retval): + for key_idx, var_name in enumerate(l): + if var_name is None: + raise RuntimeError(f'Detected key material {key_idx} generation skipped for seed {seed_idx}.') + + return retval + +def __findVarInPrevDeps(deps_graph: nx.DiGraph, + instr_id: tuple, + var_name: str, + b_only_sources: bool = False) -> tuple: + """ + Returns the ID for an instruction that uses the specified variable, and is + a dependency for input instruction. + + Parameters: + deps_graph (nx.DiGraph): Completed graph of dependencies among the instructions in the input listing. + instr_id (tuple): ID of instruction for which to find dependency. + var_name (str): Name of the variable that must be present in dependency. + b_only_sources (bool, optional): If True, only source variables are scanned. Otherwise, all variables are + checked when determining if `var_name` is in dependency instruction. + + Returns: + tuple or None: ID of first instruction found which is direct or indirect + dependency of `instr_id` in the dependency graph. The returned instruction + must have `var_name` as one of its variables. If no instruction is found, + returns None. + """ + retval = None + + if instr_id in deps_graph: + checked_instructions = set() # avoids checking same instruction multiple times + dep_instructions = collections.deque() + last_instr = deps_graph.nodes[instr_id]["instruction"] + # Repeat while we have instructions to process and we haven't found what we need + while last_instr is not None and retval is None: + # Add predecessors of last instruction to stack of predecessors + preds = (deps_graph.nodes[i_id]["instruction"] for i_id in deps_graph.predecessors(last_instr.id)) + for x in preds: + if x.id not in checked_instructions: + dep_instructions.append(x) + # Work on next instruction + last_instr = dep_instructions.pop() if len(dep_instructions) > 0 else None + if last_instr is not None: + checked_instructions.add(last_instr.id) + # Check if var_name is present in instruction + sources = set(src_var.name for src_var in last_instr.sources if isinstance(src_var, Variable)) + dests = set(dst_var.name for dst_var in last_instr.sources if not b_only_sources and isinstance(dst_var, Variable)) + if var_name in sources | dests: + # var_name found: return the instruction + retval = last_instr.id + + return retval + +def enforceKeygenOrdering(deps_graph: nx.DiGraph, + mem_model: MemoryModel, + verbose_ostream = None): + """ + Given the dependency graph for instructions and a complete memory model, injects + instructions and dependencies to enforce ordering required for the keygen subsystem. + + For all keygen variables of the same seed, `copy` instructions are injected and + all instructions using the same keygen variable as the `copy` instruction become + dependent of said `copy`. This ensures that the variable is generated before it is + used. Furthermore, `copy` instructions depend on each other based on the ordering + of the keygen variables. This ensures correct ordering of key material generation. + + Parameters: + deps_graph (nx.DiGraph): Completed graph of dependencies among the instructions in the input listing. + Will be changed with the injected instructions and dependencies. + mem_model (MemoryModel): Completed memory model corresponding to the specified dependency graph and + input mem info. Will be changed if new variables are needed during injection. + verbose_ostream: Stream where to print verbose output (object with a `write` method). + If None, no verbose output occurs. + """ + + # this function enforces the following dependency ordering: + # + # copy kg_var_0 -> op kg_var_0 + # -> op kg_var_0 + # -> op kg_var_0 + # ... + # -> copy kg_var_1 -> op kg_var_1 + # -> op kg_var_1 + # -> op kg_var_1 + # ... + # -> copy kg_var_2 -> ... + # + # This ordering ensures that kg_var_X is generated the first time its copy + # instruction is met and all the other uses can then occur with the generated + # value. Then, kg_var_X+1 can also be generated after kg_var_X has been generated + # without needing to wait for all kg_var_X uses to occur (only the first copy). + + ordered_kg_vars = __orderKeygenVars(mem_model) + + if ordered_kg_vars and verbose_ostream: + print("Enforcing keygen ordering", file = verbose_ostream) + + for seed_idx, kg_seed_list in enumerate(ordered_kg_vars): + if verbose_ostream: + print(f"Seed {seed_idx} / {len(ordered_kg_vars)}", file = verbose_ostream) + last_copy_id = None + b_copy_deps_found = False # tracks whether we have correctly added dependencies for the new copy + for key_idx, kg_var_name in enumerate(kg_seed_list): + # Create a copy instruction and make all instructions using this kg var depend on it + src = mem_model.variables[kg_var_name] + # Create temp target variable + dst = mem_model.retrieveVarAdd(mem_model.findUniqueVarName(), src.suggested_bank) + copy_instr = xinst.Copy(0, # id + 0, # N + [ dst ], + [ src ], + comment=f'injected copy to generate keygen var {kg_var_name} (seed = {seed_idx}, key = {key_idx})') + deps_graph.add_node(copy_instr.id, instruction=copy_instr) + # Enforce ordering of copies based on ordering of keygen + if last_copy_id is not None: + # Last copy -> current copy + deps_graph.add_edge(last_copy_id, copy_instr.id) + last_copy_id = copy_instr.id + + for instr_id in deps_graph: + if instr_id != copy_instr.id \ + and kg_var_name in set(src.name for src in deps_graph.nodes[instr_id]['instruction'].sources): + # Found instruction that uses the kg var: + + if not b_copy_deps_found: + # Find out if this instruction does not depend on another + # instruction that uses the same kg var + if __findVarInPrevDeps(deps_graph, instr_id, kg_var_name, b_only_sources=True) is None: + # instr_id does not depend on this kg variable: + # make its dependencies same as the injected copy in order to avoid + # copy being executed before it is needed + for dependency_id in deps_graph.predecessors(instr_id): + # dependency -> copy_instr + deps_graph.add_edge(dependency_id, copy_instr.id) + + b_copy_deps_found = True # found artificial dependencies for copy + + # Make instruction depend on this injected copy + # copy_instr -> instr + deps_graph.add_edge(copy_instr.id, instr_id) + + if ordered_kg_vars and verbose_ostream: + print(f"Seed {len(ordered_kg_vars)} / {len(ordered_kg_vars)}", file = verbose_ostream) + # We should not have introduced any cycles with these modifications + assert nx.is_directed_acyclic_graph(deps_graph) + +def generateInstrDependencyGraph(insts_listing: list, + verbose_ostream = None) -> nx.DiGraph: + """ + Given a pre-processed P-ISA instructions listing, generates a dependency graph + for the instructions based on their inputs and outputs, and any shared HW resources + among instructions. + + Parameters: + insts_listing (list): List of pre-processed P-ISA instructions. + verbose_ostream: Stream where to print verbose output (object with a `write` method). + If None, no verbose output occurs. + + Raises: + nx.NetworkXUnfeasible: Input listing results in a dependency graph that is not a Directed Acyclic Graph. + + Returns: + nx.DiGraph: A Directed Acyclic Graph representing the dependencies among the + instructions in the input listing. + """ + # Uses dynamic programming to track dependencies + + class VarTracking(NamedTuple): + # Used for clarity + last_write: object # last instruction that wrote to this variable + reads_after_last_write: list # all insts that read from this variable after last write + + retval = nx.DiGraph() + + verbose_report_every_x_insts = 1 + if verbose_ostream: + verbose_report_every_x_insts = len(insts_listing) // 10 + if verbose_report_every_x_insts < 1: + verbose_report_every_x_insts = 1 + + # Look up table for already seen variables + vars2insts = {} # dict(var_name, VarTracking ) + for idx, inst in enumerate(insts_listing): + + if verbose_ostream: + if idx % verbose_report_every_x_insts == 0: + print("{}% - {}/{}".format(idx * 100 // len(insts_listing), + idx, + len(insts_listing)), file = verbose_ostream) + + # Add new node + # All instructions are nodes + retval.add_node(inst.id, instruction=inst) + + # Find dependencies: + # prev_inst(x, dst) -> inst(dst, src) + # prev_inst(dst, x) -> inst(dst, src) + # prev_inst(src, x) -> inst(dst, src) + + for variable in inst.dests: + # Add dependencies + if variable.name in vars2insts: + # Check if last read + if vars2insts[variable.name].reads_after_last_write: + # Add deps to all reads after last write + for inst_dep in vars2insts[variable.name].reads_after_last_write: + if inst_dep.id != inst.id: + retval.add_edge(inst_dep.id, inst.id) + else: # Add dep to last write + inst_dep = vars2insts[variable.name].last_write # last instruction that wrote to this variable + if inst_dep and inst_dep.id != inst.id: + retval.add_edge(inst_dep.id, inst.id) + # Record write + vars2insts[variable.name] = VarTracking( inst, [] ) # (last inst that wrote to this, all insts that read from it after last write) + + for variable in inst.sources: + if variable.name in vars2insts: + # Add dependency to last write + inst_dep = vars2insts[variable.name].last_write # last instruction that wrote to this variable + if inst_dep and inst_dep.id != inst.id: + retval.add_edge(inst_dep.id, inst.id) + else: + # First time seeing this var + vars2insts[variable.name] = VarTracking( None, [] ) + # Record read + vars2insts[variable.name].reads_after_last_write.append(inst) + + # Different variants to enforce ordering + + #print('##### DEBUG #####') + ### sequential instructions (no reordering) + #print('***** Sequential *****') + #for idx in range(len(insts_listing) - 1): + # retval.add_edge(insts_listing[idx].id, insts_listing[idx + 1].id) + + ## tw before rshuffle + #print('***** twid before rshuffle *****') + #for idx in range(len(insts_listing) - 1): + # if isinstance(insts_listing[idx], xinst.rShuffle): + # if isinstance(insts_listing[idx + 1], xinst.twNTT): + # print(insts_listing[idx].id) + # retval.add_edge(insts_listing[idx + 1].id, insts_listing[idx].id) + + # rshuffle before tw + #print('***** rshuffle before twid *****') + #for idx in range(len(insts_listing) - 1): + # if isinstance(insts_listing[idx], xinst.rShuffle): + # if isinstance(insts_listing[idx + 1], xinst.twNTT): + # print(insts_listing[idx].id) + # retval.add_edge(insts_listing[idx].id, insts_listing[idx + 1].id) + + # rshuffles ordered + #print('***** Ordered rshuffles *****') + #for idx in range(len(insts_listing) - 1): + # if isinstance(insts_listing[idx], xinst.rShuffle): + # for j in range(len(insts_listing) - idx): + # jdx = j + idx + 1 + # if isinstance(insts_listing[jdx], xinst.rShuffle): + # print(insts_listing[idx].id) + # retval.add_edge(insts_listing[idx].id, insts_listing[jdx].id) + # break + + # twid ordered + #print('***** Ordered twntt *****') + #for idx in range(len(insts_listing) - 1): + # if isinstance(insts_listing[idx], xinst.twNTT): + # for jdx in range(idx + 1, len(insts_listing)): + # if isinstance(insts_listing[jdx], xinst.twNTT): + # print(insts_listing[idx].id) + # retval.add_edge(insts_listing[idx].id, insts_listing[jdx].id) + # break + + # Detect cycles in result + if not nx.is_directed_acyclic_graph(retval): + raise nx.NetworkXUnfeasible('Instruction listing must form a Directed Acyclic Graph dependency.') + + if verbose_ostream: + print("100% - {0}/{0}".format(len(insts_listing)), file = verbose_ostream) + + # retval contains the dependency graph + return retval + +def schedulePISAInstructions(dependency_graph: nx.DiGraph, + progress_verbose: bool = False) -> (list, int, int): + """ + Given the dependency directed acyclic graph of XInsts, returns a schedule + for the corresponding P-ISA instructions, that minimizes idle cycles. + + Parameters: + dependency_graph (nx.DiGraph): The dependency graph of XInsts. + progress_verbose (bool, optional): If True, prints progress information. Defaults to False. + + Returns: + tuple: A tuple containing: + - list: The scheduled instructions. + - int: The total number of idle cycles. + - int: The number of NOPs inserted. + """ + class PrioritizedInstruction(PrioritizedPlaceholder): + def __init__(self, + instruction, + priority_delta = (0, 0)): + super().__init__(priority_delta=priority_delta) + self.__instruction = instruction + + def __repr__(self): + return '<{} (priority = {})>(instruction={}, priority_delta={})'.format(type(self).__name__, + self.priority, + repr(self.instruction), + self.priority_delta) + + @property + def instruction(self): + return self.__instruction + + def _get_priority(self): + return self.instruction.cycle_ready + + retval = [] + topo_sort = buildVarAccessListFromTopoSort(dependency_graph) + dependency_graph = nx.DiGraph(dependency_graph) # make a copy of the incoming graph to avoid modifying input + total_idle_cycles = 0 + num_nops = 0 + set_processed_instrs = set() # track instructions that have been process to avoid encountering them after scheduling + current_cycle = CycleType(bundle = 0, cycle = 1) + p_queue = [] # Sorted list by priority: ready cycle + b_changed = True # Track when there are changes in the priority queue or dependency graph + total_insts = dependency_graph.number_of_nodes() + prev_report_pct = -1 + while dependency_graph: + + if progress_verbose: + pct = int(len(retval) * 100 / total_insts) + if pct > prev_report_pct and pct % 10 == 0: + prev_report_pct = pct + print(f"{pct}% - {len(retval)}/{total_insts}") + if b_changed: # If priority queue or dependency graph have changed since last iteration + + # Extract all the instructions that can be executed without dependencies + # and merge to current instructions that can be executed without dependencies + last_idx = -1 + for idx, instr_id in enumerate(topo_sort): + if instr_id not in set_processed_instrs: + if dependency_graph.in_degree(instr_id) > 0: + # Found first instruction with dependencies + break + instr = dependency_graph.nodes[instr_id]['instruction'] + p_queue.append(PrioritizedInstruction(instr)) + set_processed_instrs.add(instr.id) + last_idx = idx + # Remove all instructions that got queued for scheduling + if last_idx >= 0: + topo_sort = topo_sort[last_idx + 1:] + + # Reorder priority queue since the items' priorities may change after scheduling an instruction + assert(p_queue) + heapq.heapify(p_queue) + + # Schedule next instruction + + # See if there is an immediate instruction we can queue + element_idx = 0 + for idx, p_inst in enumerate(p_queue): + if p_inst.instruction.cycle_ready == current_cycle: + element_idx = idx + break + + # Instruction can be immediate in mid queue or + # the head of the queue if no immediate was found. + instr = p_queue[element_idx].instruction + + if instr.cycle_ready > current_cycle: + # We need nops because next instruction is not ready + num_idle_cycles = instr.cycle_ready.cycle - current_cycle.cycle + total_idle_cycles += num_idle_cycles + # Make new instruction to execute a nop + instr = xinst.Nop(instr.id[0], num_idle_cycles) + num_nops += 1 + b_changed = False # No changes in the queue or graph + + # Do not pop actual instruction from graph or queue since we had to add nops before its scheduling + else: + # Instruction ready: pop instruction from queue and update dependency graph + # (this breaks the heap invariant for p_queue, but we heapify + # on every iteration due to priorities changing based on latency) + p_queue = p_queue[:element_idx] + p_queue[element_idx + 1:] + dependents = list(dependency_graph.neighbors(instr.id)) # find instructions that depend on this instruction + dependency_graph.remove_node(instr.id) # remove from graph to update the in_degree of dependendent instrs + # "move" dependent instrs that have no other dependencies to the top of the topo sort + topo_sort = [ instr_id for instr_id in dependents if dependency_graph.in_degree(instr_id) <= 0 ] + topo_sort + # Do not search the topo sort to actually remove the duplicated instrs because it is O(N) costly: + # set_processed_instrs will take care of skipping them once encountered. + b_changed = True # queue and/or graph changed + + cycle_throughput = instr.schedule(current_cycle, len(retval) + 1) # simulate execution to update cycle ready of dependents + retval.append(instr) + + # Next cycle starts + current_cycle += cycle_throughput + + if progress_verbose: + print(f"100% - {total_insts}/{total_insts}") + + return retval, total_idle_cycles, num_nops \ No newline at end of file diff --git a/assembler_tools/hec-assembler-tools/atomic_tester.py b/assembler_tools/hec-assembler-tools/atomic_tester.py new file mode 100644 index 00000000..c3ae6929 --- /dev/null +++ b/assembler_tools/hec-assembler-tools/atomic_tester.py @@ -0,0 +1,261 @@ +import argparse +import io +import os +import pathlib +import subprocess +import yaml + +from assembler.common.constants import Constants +from assembler.common.run_config import RunConfig +import he_prep as preproc +import he_as as asm + +# module constants +DEFAULT_OPERATIONS = Constants.OPERATIONS[:6] + +class GenRunConfig(RunConfig): + """ + Maintains the configuration data for the run. + """ + + SCHEMES = [ 'bgv', 'ckks' ] + DEFAULT_SCHEME = SCHEMES[0] + + __initialized = False # specifies whether static members have been initialized + # contains the dictionary of all configuration items supported and their + # default value (or None if no default) + __default_config = {} + + def __init__(self, **kwargs): + """ + Constructs a new GenRunConfig Object from input parameters. + + See base class constructor for more arguments. + + Parameters + ---------- + N: int + Ring dimension: PMD = 2^N. + + min_nrns: int + Minimum number of residuals. + + max_nrns: int + Maximum number of residuals. + + key_nrns: int + Optional number of residuals for relinearization keys. Must be greater than `max_nrns`. + If missing, the `key_nrns` for each P-ISA kernel generated will be set to the kernel + `nrns` (number of residuals) + 1. + + scheme: str + FHE Scheme to use. Must be one of the schemes in `GenRunConfig.SCHEMES`. + Defaults to `GenRunConfig.DEFAULT_SCHEME`. + + op_list: list[str] + Optional list of name of operations to generate. If provided, it must be a non-empty + subset of `Constants.OPERATIONS`. + Defaults to `DEFAULT_OPERATIONS`. + + output_dir: str + Optional directory where to store all intermediate files and final output. + This will be created if it doesn't exists. + Defaults to /lib. + + Raises + ------ + TypeError + A mandatory configuration value was missing. + + ValueError + At least, one of the arguments passed is invalid. + """ + + self.__init_statics() + + super().__init__(**kwargs) + + for config_name, default_value in self.__default_config.items(): + assert(not hasattr(self, config_name)) + setattr(self, config_name, kwargs.get(config_name, default_value)) + if getattr(self, config_name) is None: + raise TypeError(f'Expected value for configuration `{config_name}`, but `None` received.') + + if self.scheme not in self.SCHEMES: + raise ValueError('Invalid acheme "{}". Expected one of {}'.format(self.scheme, self.SCHEMES)) + + for op in self.op_list: + if op not in Constants.OPERATIONS: + raise ValueError('Invalid operation in input list of ops "{}". Expected one of {}'.format(op, Constants.OPERATIONS)) + + if self.key_nrns > 0: + if self.key_nrns < self.max_nrns: + raise ValueError(('`key_nrns` must be greater than `max_nrns` when present. ' + 'Received {}, but expected greater than {}.').format(self.key_nrns, + self.max_nrns)) + + @classmethod + def __init_statics(cls): + if not cls.__initialized: + cls.__default_config["N"] = None + cls.__default_config["min_nrns"] = None + cls.__default_config["max_nrns"] = None + cls.__default_config["key_nrns"] = 0 + cls.__default_config["scheme"] = cls.DEFAULT_SCHEME + cls.__default_config["output_dir"] = os.path.join(pathlib.Path.cwd(), "lib") + cls.__default_config["op_list"] = DEFAULT_OPERATIONS + + cls.__initialized = True + + def __str__(self): + """ + Returns a string representation of the configuration. + """ + self_dict = self.as_dict() + with io.StringIO() as retval_f: + for key, value in self_dict.items(): + print("{}: {}".format(key, value), file=retval_f) + retval = retval_f.getvalue() + return retval + + def as_dict(self) -> dict: + retval = super().as_dict() + tmp_self_dict = vars(self) + retval.update({ config_name: tmp_self_dict[config_name] for config_name in self.__default_config }) + return retval + +def main(config: GenRunConfig, + b_verbose: bool = False): + + lib_dir = config.output_dir + + # create output directory to store outputs (if it doesn't already exist) + pathlib.Path(lib_dir).mkdir(exist_ok = True, parents=True) + + # point to the HERACLES-SEAL-isa-mapping repo + home_dir = pathlib.Path.home() + mapping_dir = os.getenv("HERACLES_MAPPING_PATH", os.path.join(home_dir, "HERACLES/HERACLES-SEAL-isa-mapping")) + # command to run the mapping script to generate operations kernels for our input + #generate_cmd = 'python3 "{}"'.format(os.path.join(mapping_dir, "kernels/run_he_op.py")) + generate_cmd = ['python3', '{}'.format(os.path.join(mapping_dir, "kernels/run_he_op.py"))] + + assert config.N < 1024 + assert config.min_nrns > 1 + assert (config.key_nrns == 0 or config.key_nrns > config.max_nrns) + assert(all(op in Constants.OPERATIONS for op in config.op_list)) + + pdegree = 2 ** config.N + regenerate_string = "" + for op in config.op_list: + for rn_el in range(config.min_nrns, config.max_nrns + 1): + key_nrns = config.key_nrns if config.key_nrns > 0 else rn_el + 1 + regenerate_string = "" + print("{} {} {} {} {}".format(config.scheme, op, config.N, rn_el, key_nrns)) + output_prefix = "t.{}.{}.{}".format(rn_el, op, config.N) + basef = os.path.join(lib_dir, output_prefix) + generate_cmdln = generate_cmd + [ str(x) for x in (config.scheme, op, pdegree, rn_el, key_nrns) ] + + csvfile = basef + ".csv" + memfile = basef + ".tw.mem" + + # call the external script to generate the kernel for this op + print(' '.join(generate_cmdln)) + run_result = subprocess.run(generate_cmdln, stdout=subprocess.PIPE) + if run_result.returncode != 0: + raise RuntimeError('Exit code: {}. Failure to complete kernel generation successfully.'.format(run_result.returncode)) + + # interpret output into correct kernel and mem files + merged_output = run_result.stdout.decode().splitlines() + with open(csvfile, 'w') as fout_csv: + with open(memfile, 'w') as fout_mem: + for s_line in merged_output: + if s_line: + if s_line.startswith('dload') \ + or s_line.startswith('dstore'): + print(s_line, file=fout_mem) + else: + print(s_line, file=fout_csv) + + # pre-process kernel + + # generate twiddle factors for this kernel + basef = basef + ".tw" #use the newly generated twiddle file + print() + print("Preprocessing") + preproc.main(basef + ".csv", + csvfile, + b_verbose=b_verbose) + + # prepare config for assembler + asm_config = asm.AssemblerRunConfig(input_file=basef + ".csv", + input_mem_file=memfile, + output_prefix=output_prefix, + **config.as_dict()) # convert config to a dictionary and expand it as arguments + print() + print("Assembling") + # run the assembler for this file + asm.main(asm_config, verbose=b_verbose) + + print(f'Completed "{output_prefix}"') + print() + +def parse_args(): + parser = argparse.ArgumentParser(description=("Generates a collection of HE operations based on input configuration."), + epilog=("To use, users should dump a default configuration file. Edit the file to " + "match the needs for the run, then execute the program with the modified " + "configuration. Note that dumping on top of an existing file will overwrite " + "its contents.")) + parser.add_argument("config_file", help=("YAML configuration file.")) + parser.add_argument("--dump", action="store_true", + help=("A default configuration will be writen into the file specified by `config_file`. " + "If the file already exists, it will be overwriten.")) + parser.add_argument("-v", "--verbose", dest="verbose", action="store_true", + help="If enabled, extra information and progress reports are printed to stdout.") + args = parser.parse_args() + + return args + +def readYAMLConfig(input_filename: str): + """ + Reads in a YAML file and returns a GenRunConfig object parsed from it. + """ + retval_dict = {} + with open(input_filename, "r") as infile: + retval_dict = yaml.safe_load(infile) + + return GenRunConfig(**retval_dict) + +def writeYAMLConfig(output_filename: str, config: GenRunConfig): + """ + Outputs the specified configuration to a YAML file. + """ + with open(output_filename, "w") as outfile: + yaml.dump(vars(config), outfile, sort_keys=False) + +if __name__ == "__main__": + module_name = os.path.basename(__file__) + print(module_name) + print() + + args = parse_args() + + if args.dump: + print("Writing default configuration to") + print(" ", args.config_file) + default_config = GenRunConfig(N=15, min_nrns=2, max_nrns=18) + writeYAMLConfig(args.config_file, default_config) + else: + print("Loading configuration file:") + print(" ", args.config_file) + config = readYAMLConfig(args.config_file) + print() + print("Gen Run Configuration") + print("=====================") + print(config) + print("=====================") + print() + main(config, + b_verbose=args.verbose) + + print() + print(module_name, "- Complete") diff --git a/assembler_tools/hec-assembler-tools/config/isa_spec.json b/assembler_tools/hec-assembler-tools/config/isa_spec.json new file mode 100644 index 00000000..3ea0b212 --- /dev/null +++ b/assembler_tools/hec-assembler-tools/config/isa_spec.json @@ -0,0 +1,225 @@ +{ + "isa_spec": { + "xinst": { + "add": { + "num_dests": 1, + "num_sources": 2, + "default_throughput": 1, + "default_latency": 6, + "num_tokens": 6 + }, + "copy": { + "num_dests": 1, + "num_sources": 1, + "default_throughput": 1, + "default_latency": 6, + "num_tokens": 5 + }, + "exit": { + "num_dests": 0, + "num_sources": 0, + "default_throughput": 1, + "default_latency": 1 + }, + "intt": { + "num_dests": 2, + "num_sources": 3, + "default_throughput": 1, + "default_latency": 6, + "num_tokens": 9 + }, + "irshuffle": { + "num_dests": 2, + "num_sources": 2, + "default_throughput": 1, + "default_latency": 23, + "num_tokens": 7, + "special_latency_max": 17, + "special_latency_increment": 5 + }, + "mac": { + "num_dests": 1, + "num_sources": 3, + "default_throughput": 1, + "default_latency": 6, + "num_tokens": 6 + }, + "maci": { + "num_dests": 1, + "num_sources": 2, + "default_throughput": 1, + "default_latency": 6, + "num_tokens": 6 + }, + "move": { + "num_dests": 1, + "num_sources": 1, + "default_throughput": 1, + "default_latency": 6 + }, + "mul": { + "num_dests": 1, + "num_sources": 2, + "default_throughput": 1, + "default_latency": 6, + "num_tokens": 6 + }, + "muli": { + "num_dests": 1, + "num_sources": 1, + "default_throughput": 1, + "default_latency": 6, + "num_tokens": 6 + }, + "nop": { + "num_dests": 0, + "num_sources": 0, + "default_throughput": 1, + "default_latency": 1 + }, + "ntt": { + "num_dests": 2, + "num_sources": 3, + "default_throughput": 1, + "default_latency": 6, + "num_tokens": 9 + }, + "rshuffle": { + "num_dests": 2, + "num_sources": 2, + "default_throughput": 1, + "default_latency": 23, + "num_tokens": 7, + "special_latency_max": 17, + "special_latency_increment": 5 + }, + "sub": { + "num_dests": 1, + "num_sources": 2, + "default_throughput": 1, + "default_latency": 6, + "num_tokens": 6 + }, + "twintt": { + "num_dests": 1, + "num_sources": 1, + "default_throughput": 1, + "default_latency": 6, + "num_tokens": 8 + }, + "twntt": { + "num_dests": 1, + "num_sources": 1, + "default_throughput": 1, + "default_latency": 6, + "num_tokens": 8 + }, + "xstore": { + "num_dests": 1, + "num_sources": 1, + "default_throughput": 1, + "default_latency": 6 + } + }, + "cinst": { + "bload": { + "num_dests": 0, + "num_sources": 1, + "default_throughput": 1, + "default_latency": 5 + }, + "bones": { + "num_dests": 0, + "num_sources": 1, + "default_throughput": 1, + "default_latency": 5 + }, + "exit": { + "num_dests": 0, + "num_sources": 0, + "default_throughput": 1, + "default_latency": 1 + }, + "cload": { + "num_dests": 1, + "num_sources": 1, + "default_throughput": 4, + "default_latency": 4 + }, + "nop": { + "num_dests": 0, + "num_sources": 0, + "default_throughput": 1, + "default_latency": 1 + }, + "cstore": { + "num_dests": 0, + "num_sources": 0, + "default_throughput": 1, + "default_latency": 5 + }, + "csyncm": { + "num_dests": 0, + "num_sources": 0, + "default_throughput": 1, + "default_latency": 1 + }, + "ifetch": { + "num_dests": 0, + "num_sources": 0, + "default_throughput": 1, + "default_latency": 5 + }, + "kgload": { + "num_dests": 1, + "num_sources": 0, + "default_throughput": 4, + "default_latency": 40 + }, + "kgseed": { + "num_dests": 0, + "num_sources": 1, + "default_throughput": 1, + "default_latency": 1 + }, + "kgstart": { + "num_dests": 0, + "num_sources": 0, + "default_throughput": 1, + "default_latency": 40 + }, + "nload": { + "num_dests": 0, + "num_sources": 1, + "default_throughput": 4, + "default_latency": 4 + }, + "xinstfetch": { + "num_dests": 0, + "num_sources": 0, + "default_throughput": 1, + "default_latency": 1 + } + }, + "minst": { + "mload": { + "num_dests": 1, + "num_sources": 1, + "default_throughput": 1, + "default_latency": 1 + }, + "mstore": { + "num_dests": 1, + "num_sources": 1, + "default_throughput": 1, + "default_latency": 1 + }, + "msyncc": { + "num_dests": 0, + "num_sources": 0, + "default_throughput": 1, + "default_latency": 1 + } + } + } +} \ No newline at end of file diff --git a/assembler_tools/hec-assembler-tools/debug_tools/README.md b/assembler_tools/hec-assembler-tools/debug_tools/README.md new file mode 100644 index 00000000..87a01d38 --- /dev/null +++ b/assembler_tools/hec-assembler-tools/debug_tools/README.md @@ -0,0 +1,108 @@ +# Debug Tools + +This folder contains a collection of scripts designed to assist with debugging and testing various aspects of the assembler and instruction scheduling processes. + +## Dependencies + +These tools are Python based. Dependencies for these scripts are the same as [those](../README.md#dependencies) for the parent project. + +> **Note**: Ensure the `assembler` folder is included in your Python `PATH`. +> ```bash +> export PYTHONPATH="${PYTHONPATH}:$(pwd)/assembler" +> ``` + +Below is a detailed description and usage example for each tool. + +--- + +## Tools Overview + +### `main.py` + + This script serves as the main entry point for running ASM-ISA assembly and P-ISA scheduling processes. It handles the preprocessing, assembly, and scheduling of instructions, transforming them from high-level representations to executable formats. + + The script is used to convert P-ISA kernels into ASM-ISA instructions, manage memory models, and ensure that instructions are scheduled correctly according to dependencies and resource constraints. + +- **Usage**: + ```bash + python main.py --mem_file --prefix --isa_spec -v + ``` + - `--mem_file`: Specifies the input memory file. + - `--prefix`: One or more input prefixes to process, representing different instructions or kernels. + - `--isa_spec`: Input ISA specification (.json) file that defines the parameters of the instruction set architecture. + - `-v`: Enables verbose mode. + +--- + +### `isolation_test.py` + + This script isolates specific variables in P-ISA by replacing instructions that do not affect the specified variables with NOPs (no operation instructions). The isolation test is used to focus on specific variables within a P-ISA kernel, allowing developers to analyze the impact of these variables. + +- **Usage**: + ```bash + python isolation_test.py --pisa_file --xinst_file --out_file --track -v + ``` + - `--pisa_file`: Input P-ISA prep (.csv) file containing instructions. + - `--xinst_file`: Input XInst instruction file. + - `--out_file`: Output file name where the modified instructions will be saved. + - `--track`: Set of variables to track. + - `-v`: Enables verbose mode. + +--- + +### `deadlock_test.py` + + This script checks for deadlocks in the CInstQ and MInstQ caused by sync instructions. It raises an exception if a deadlock is found, indicating a potential issue in instruction scheduling. + +- **Usage**: + ```bash + python deadlock_test.py [input_prefix] + ``` + - ``: Directory containing instruction files, typically organized by prefixes. + - `[input_prefix]`: Optional prefix for instruction files, used to specify particular sets of instructions. + +--- + +### `order_test.py` + This script tests all registers in an XInstQ to determine if any register is used out of order based on the P-ISA instruction order. It is specifically designed for kernels that do not involve evictions. + + The script helps ensure that registers are accessed in the correct sequence, which is crucial for maintaining the integrity of instruction execution in systems where register order matters. This is particularly important for debugging and optimizing instruction scheduling. + +- **Usage**: + ```bash + python order_test.py --input_file -v + ``` + - `--input_file`: Specifies the input (.xinst) file containing the XInstQ instructions. + - `-v`: Enables verbose mode for detailed output, providing insights into the processing steps and results. + +### `xinst_timing_check/inject_bundles.py` + + This script injects dummy bundles into instruction files after the first bundle, simulating additional instruction loads for testing purposes. The injection of dummy bundles is used to test the system's handling of instruction loads and synchronization points. + +- **Usage**: + ```bash + python inject_bundles.py [input_prefix] [output_prefix] --isa_spec -b -ne + ``` + - ``: Directory containing input files to be processed. + - ``: Directory to save output files with injected bundles. + - `[input_prefix]`: Optional prefix for input files, specifying the target instruction set. + - `[output_prefix]`: Optional prefix for output files, defining the naming convention for saved files. + - `--isa_spec`: Input ISA specification (.json) file, providing architectural details. + - `-b`: Number of dummy bundles to insert, simulating additional instruction loads. + - `-ne`: Skip exit in dummy bundles, altering the behavior of injected instructions. + +--- + +### `xinst_timing_check/xtiming_check.py` + + This script checks timing for register access, ensuring registers are not read before their write completes, and checks for bank write conflicts. + +- **Usage**: + ```bash + python xtiming_check.py [input_prefix] --isa_spec + ``` + - ``: Directory containing input files for timing analysis. + - `[input_prefix]`: Optional prefix for input files, specifying the target instruction set. + - `--isa_spec`: Input ISA specification (.json) file, providing architectural details for timing validation. + +--- diff --git a/assembler_tools/hec-assembler-tools/debug_tools/deadlock_test.py b/assembler_tools/hec-assembler-tools/debug_tools/deadlock_test.py new file mode 100644 index 00000000..b0009351 --- /dev/null +++ b/assembler_tools/hec-assembler-tools/debug_tools/deadlock_test.py @@ -0,0 +1,160 @@ +import argparse +import os + +# Searches the CInstQ and MInstQ to find deadlocks caused by sync instructions. +# Raises exception on first deadlock found, otherwise, completes successfully. + +def makeUniquePath(path: str): + """ + Normalizes and expand a given file path. + + Parameters: + path (str): The file path to normalize and expand. + + Returns: + str: The normalized and expanded file path. + """ + return os.path.normcase(os.path.realpath(os.path.expanduser(path))) + +def loadInstructions(istream) -> list: + """ + Loads instructions from an input iterator. + + Parameters: + istream: An iterator where each item is a string line considered to contain an instruction. + + Returns: + list: A list of tuples. Each tuple contains a list of tokens from the comma-separated instruction and the comment. + """ + retval = [] + for line in istream: + line = line.strip() + if line: + # Separate comment + s_instr = "" + s_comment = "" + comment_start_idx = line.find('#') + if comment_start_idx < 0: + s_instr = line + else: + s_instr = line[:comment_start_idx] + s_comment = line[comment_start_idx + 1:] + + # Tokenize instruction + s_instr = map(lambda s: s.strip(), s_instr.split(",")) + + # Add instruction to collection + retval.append((list(s_instr), s_comment)) + + return retval + +def findDeadlock(minsts: list, cinsts: list) -> tuple: + """ + Searches the CInstQ and MInstQ to find the first deadlock. + + Parameters: + minsts (list): List of MInst instructions. + cinsts (list): List of CInst instructions. + + Returns: + tuple: A tuple of indices where a deadlock was found, or None if no deadlock was found. + """ + retval = None + queue_order_watcher = 0 + deadlock_watcher = 0 # Tracks whenever a queue doesn't move: if both queues don't move back to back, a deadlock has occurred + q = minsts[:] + q1 = cinsts[:] + while retval is None and (q and q1): + # Remove all non-syncs from q + sync_idx = len(q) + for idx, instr in enumerate(q): + if 'sync' in instr[1]: + # Sync found + sync_idx = idx + break + q = q[sync_idx:] + if q: + assert 'sync' in q[0][1], 'Next instruction in queue is not a sync!' + + if sync_idx != 0: + # Queue moved: restart the deadlock watcher + deadlock_watcher = 0 + + if deadlock_watcher > 1: + # Deadlock detected: neither queue moved + if queue_order_watcher > 0: + # Swap queue to original input order + q, q1 = q1, q + assert len(q) > 0 and len(q1) > 0 + # Report indices where deadlock occurred + retval = (int(q[0][0]), int(q1[0][0])) + else: + # Check if syncing to an instruction already executed + sync_to_q1_idx = int(q[0][2]) + if q1 and int(q1[0][0]) < sync_to_q1_idx: + # q1 is NOT past synced instruction + + if sync_idx == 0: + # Queue didn't move + deadlock_watcher += 1 + + # Switch to execute q1 + q, q1 = q1, q + queue_order_watcher = (queue_order_watcher + 1) % 2 + else: + # q1 is past synced instruction + q = q[1:] + + return retval + +def main(input_dir: str, input_prefix: str = None): + """ + Main function to check for deadlocks in instruction queues. + + Parameters: + input_dir (str): The directory containing instruction files. + input_prefix (str): The prefix for instruction files. + """ + input_dir = makeUniquePath(input_dir) + if not input_prefix: + input_prefix = os.path.basename(input_dir) + + print('Deadlock test.') + print() + print('Input dir:', input_dir) + print('Input prefix:', input_prefix) + + xinst_file = os.path.join(input_dir, input_prefix + ".xinst") + cinst_file = os.path.join(input_dir, input_prefix + ".cinst") + minst_file = os.path.join(input_dir, input_prefix + ".minst") + + with open(xinst_file, 'r') as f_xin: + xinsts = loadInstructions(f_xin) + xinsts = [x for (x, _) in xinsts] + with open(cinst_file, 'r') as f_cin: + cinsts = loadInstructions(f_cin) + cinsts = [x for (x, _) in cinsts] + with open(minst_file, 'r') as f_min: + minsts = loadInstructions(f_min) + minsts = [x for (x, _) in minsts] + + deadlock_indices = findDeadlock(minsts, cinsts) + if deadlock_indices is not None: + raise RuntimeError('Deadlock detected: MinstQ: {}, CInstQ: {}'.format(deadlock_indices[0], deadlock_indices[1])) + + print('No deadlock detected between CInstQ and MInstQ.') + +if __name__ == "__main__": + module_name = os.path.basename(__file__) + print(module_name) + print() + + parser = argparse.ArgumentParser() + parser.add_argument("input_dir") + parser.add_argument("input_prefix", nargs="?") + args = parser.parse_args() + + main(args.input_dir, args.input_prefix) + + print() + print(module_name, "- Complete") \ No newline at end of file diff --git a/assembler_tools/hec-assembler-tools/debug_tools/isolation_test.py b/assembler_tools/hec-assembler-tools/debug_tools/isolation_test.py new file mode 100644 index 00000000..21ac5af0 --- /dev/null +++ b/assembler_tools/hec-assembler-tools/debug_tools/isolation_test.py @@ -0,0 +1,136 @@ +import argparse +import os +import re + +# Given a variable in P-ISA, this script will replace all instructions that do not +# affect the variable with appropriate NOPs. + + +def parse_args(): + """ + Parses command-line arguments for the preprocessing script. + + This function sets up the argument parser and defines the expected arguments for the script. + It returns a Namespace object containing the parsed arguments. + + Returns: + argparse.Namespace: Parsed command-line arguments. + """ + parser = argparse.ArgumentParser( + description=("Isolation Test.\n" + "Given a set of variables in P-ISA, this script will replace all instructions that do not" + " affect the variable with appropriate NOPs.")) + parser.add_argument("--pisa_file", required= True, help="Input P-ISA prep (.csv) file.") + parser.add_argument("--xinst_file", required=True, help="Input (xinst) instruction file.") + parser.add_argument("--out_file", default="", help="Output file name.") + parser.add_argument("--track", default="", dest="variables_set", nargs='+', help="Set of variables to track.") + parser.add_argument("-v", "--verbose", dest="verbose", action="count", default=0, + help=("If enabled, extra information and progress reports are printed to stdout. " + "Increase level of verbosity by specifying flag multiple times, e.g. -vv")) + args = parser.parse_args() + + return args + +if __name__ == "__main__": + module_name = os.path.basename(__file__) + + args = parse_args() + + # File paths for input and output + pisa_prep_file = args.pisa_file + xinst_file = args.xinst_file + output_file = "" + if (args.out_file): + output_file = args.out_file + else: + # Create the new file name + name, ext = os.path.splitext(xinst_file) + output_file = f"{name}.out{ext}" + + # Set of variables to track + variables_set = args.variables_set + + if args.verbose > 0: + print(module_name) + print() + print("P-ISA: {0}".format(pisa_prep_file)) + print("Xinst File: {0}".format(xinst_file)) + print("Output Name: {0}".format(output_file)) + print("Tracking: {0}".format(variables_set)) + + # Find all related variables + pisa_instrs = [] + pisa_file_contents = [] + with open(pisa_prep_file, 'r') as f_in_pisa: + pisa_file_contents = [line for line in f_in_pisa if line] + + l = [] + set_updated = True + while set_updated: + set_updated = False + for line_idx, line in enumerate(pisa_file_contents): + # Remove comment + s_split = line.split("#") + line = s_split[0] + # Split into components + tmp_split = map(lambda s: s.strip(), line.split(",")) + s_split = [] + for component in tmp_split: + s_split.append(component.split('(')[0].strip()) + pisa_instrs.append(s_split[1:]) + if any(x in s_split for x in variables_set): + # Add all other variables as dependents + if s_split[1] == 'muli' or s_split[1] == 'maci': + s_split = s_split[2:-2] + else: + s_split = s_split[2:-1] + new_vars = set(v for v in s_split if re.search('^[A-Za-z_][A-Za-z0-9_]*', v)) + if 'iN' in new_vars: + print('iN') + if new_vars - variables_set: + l += [x for x in new_vars if x not in variables_set] + variables_set |= new_vars + set_updated = True + + print(variables_set) + + pisa_instr_num_set = set() + for idx, s_split in enumerate(pisa_instrs): + if any(x in s_split for x in variables_set): + # Variable found in instruction: keep it + pisa_instr_num_set.add(idx + 1) + + # Keep only xinsts that are used for the kept p-isa instr + with open(xinst_file, 'r') as f_in: + with open(output_file, 'w') as f_out: + for line in f_in: + # Remove comment + s_split = line.split("#") + s_line = s_split[0].strip() + # Split into components + s_split = list(map(lambda s: s.strip(), line.split(","))) + out_line = '' + if int(s_split[1]) in pisa_instr_num_set: + # Xinstruction is needed to complete p-isa instr + if s_split[2] not in ('move', 'xstore', 'nop'): + out_line = s_line + " # " + str(pisa_instrs[int(s_split[1]) - 1]) + else: + out_line = line.strip() + elif 'xstore' in s_line: + # All xstores are required because they are sync points with CInstQ + out_line = s_line.strip() + elif 'exit' in s_line: + # Keep all exits + out_line = s_line.strip() + elif 'rshuffle' in s_line: + # Other rshuffles are converted to nops for timing + out_line = '{}, {}, nop, {} # rshuffle'.format(s_split[0], s_split[1], s_split[7]) + elif 'nop' in s_line: + # Keep nops timing + out_line = s_line.strip() + if not out_line: + # Any other instructions are converted to single cycle nop + out_line = '{}, {}, nop, 0'.format(s_split[0], s_split[1]) + print(out_line, file=f_out) + + print("Done") diff --git a/assembler_tools/hec-assembler-tools/debug_tools/main.py b/assembler_tools/hec-assembler-tools/debug_tools/main.py new file mode 100644 index 00000000..7ac6efed --- /dev/null +++ b/assembler_tools/hec-assembler-tools/debug_tools/main.py @@ -0,0 +1,426 @@ +import os +import sys +import time +import argparse + +from assembler.common import constants +from assembler.common.config import GlobalConfig +from assembler.instructions import xinst +from assembler.stages import preprocessor +from assembler.stages import scheduler +from assembler.stages.scheduler import schedulePISAInstructions +from assembler.stages.asm_scheduler import scheduleASMISAInstructions +from assembler.memory_model import MemoryModel +from assembler.memory_model import mem_info +from assembler.isa_spec import SpecConfig + +def parse_args(): + """ + Parses command-line arguments for the preprocessing script. + + This function sets up the argument parser and defines the expected arguments for the script. + It returns a Namespace object containing the parsed arguments. + + Returns: + argparse.Namespace: Parsed command-line arguments. + """ + parser = argparse.ArgumentParser( + description=("Main Test.\n")) + parser.add_argument("--mem_file", default="", help="Input memory file.") + parser.add_argument("--prefix", default="", dest="base_names", nargs='+', help="One or more input prefix to process.") + parser.add_argument("--isa_spec", default="", dest="isa_spec_file", + help=("Input ISA specification (.json) file.")) + parser.add_argument("-v", "--verbose", dest="verbose", action="count", default=0, + help=("If enabled, extra information and progress reports are printed to stdout. " + "Increase level of verbosity by specifying flag multiple times, e.g. -vv")) + args = parser.parse_args() + + return args + +def main_readmem(args): + """ + Reads and processes memory information from a file. + """ + import sys + import yaml + import io + + if args.mem_file: + mem_filename = args.mem_file + else: + raise argparse.ArgumentError(None, "Please provide input memory file using `--mem_file` option.") + + mem_meta_info = None + with open(mem_filename, 'r') as mem_ifnum: + mem_meta_info = mem_info.MemInfo.from_iter(mem_ifnum) + + if mem_meta_info: + with io.StringIO() as retval_f: + yaml.dump(mem_meta_info.as_dict(), retval_f, sort_keys=False) + yaml_str = retval_f.getvalue() + + print(yaml_str) + print("--------------") + new_meminfo = yaml.safe_load(yaml_str) + print(new_meminfo) + print("--------------") + mem_meta_info = mem_info.MemInfo(**new_meminfo) + yaml.dump(mem_meta_info.as_dict(), sys.stdout, sort_keys=False) + else: + print("None") + +def asmisa_preprocessing(input_filename: str, + output_filename: str, + b_use_bank_0: bool, + b_verbose=True) -> int: + """ + Preprocess P-ISA kernel and save the intermediate result. + + Parameters: + input_filename (str): The input file containing the P-ISA kernel. + output_filename (str): The output file to save the intermediate result. + b_use_bank_0 (bool): Whether to use bank 0. + b_verbose (bool): Whether to print verbose output. + + Returns: + int: The time taken for preprocessing in seconds. + """ + if b_verbose: + print('Preprocessing P-ISA kernel...') + + hec_mem_model = MemoryModel(constants.MemoryModel.HBM.MAX_CAPACITY_WORDS, + constants.MemoryModel.SPAD.MAX_CAPACITY_WORDS) + + start_time = time.time() + + with open(input_filename, 'r') as insts: + insts_listing = preprocessor.preprocessPISAKernelListing(hec_mem_model, + insts, + progress_verbose=b_verbose) + + if b_verbose: + print("Assigning register banks to variables...") + preprocessor.assignRegisterBanksToVars(hec_mem_model, insts_listing, use_bank0=b_use_bank_0) + + retval_timing = time.time() - start_time + + if b_verbose: + print("Saving intermediate...") + with open(output_filename, 'w') as outnum: + for inst in insts_listing: + inst_line = inst.toPISAFormat() # + f" # {inst.id}" + if inst_line: + print(inst_line, file=outnum) + + return retval_timing + +def asmisa_assembly(output_xinst_filename: str, + output_cinst_filename: str, + output_minst_filename: str, + output_mem_filename: str, + input_filename: str, + mem_filename: str, + max_bundle_size: int, + hbm_capcity_words: int, + spad_capacity_words: int, + num_register_banks: int = constants.MemoryModel.NUM_REGISTER_BANKS, + register_range: range = None, + b_verbose=True) -> tuple: + """ + Assembles ASM-ISA instructions from preprocessed P-ISA kernel. + + Parameters: + output_xinst_filename (str): The output file for XInst instructions. + output_cinst_filename (str): The output file for CInst instructions. + output_minst_filename (str): The output file for MInst instructions. + output_mem_filename (str): The output file for memory information. + input_filename (str): The input file containing the preprocessed P-ISA kernel. + mem_filename (str): The file containing memory information. + max_bundle_size (int): Maximum number of instructions in a bundle. + hbm_capcity_words (int): Capacity of HBM in words. + spad_capacity_words (int): Capacity of SPAD in words. + num_register_banks (int): Number of register banks. + register_range (range): Range of registers. + b_verbose (bool): Whether to print verbose output. + + Returns: + tuple: A tuple containing the number of XInsts, NOPs, idle cycles, dependency timing, and scheduling timing. + """ + if b_verbose: + print("Assembling!") + print("Reloading kernel from intermediate...") + + hec_mem_model = MemoryModel(hbm_capcity_words, spad_capacity_words, num_register_banks, register_range) + + insts_listing = [] + with open(input_filename, 'r') as insts: + for line_no, s_line in enumerate(insts, 1): + parsed_insts = None + if GlobalConfig.debugVerbose: + if line_no % 100 == 0: + print(f"{line_no}") + # Instruction is one that is represented by single XInst + inst = xinst.createFromPISALine(hec_mem_model, s_line, line_no) + if inst: + parsed_insts = [inst] + + if not parsed_insts: + raise SyntaxError("Line {}: unable to parse kernel instruction:\n{}".format(line_no, s_line)) + + insts_listing += parsed_insts + + if b_verbose: + print("Interpreting variable meta information...") + with open(mem_filename, 'r') as mem_ifnum: + mem_meta_info = mem_info.MemInfo.from_iter(mem_ifnum) + mem_info.updateMemoryModelWithMemInfo(hec_mem_model, mem_meta_info) + + if b_verbose: + print("Generating dependency graph...") + start_time = time.time() + dep_graph = scheduler.generateInstrDependencyGraph(insts_listing) + deps_end = time.time() - start_time + + if b_verbose: + print("Scheduling ASM-ISA instructions...") + start_time = time.time() + minsts, cinsts, xinsts, num_idle_cycles = scheduleASMISAInstructions(dep_graph, + max_bundle_size, + hec_mem_model, + constants.Constants.REPLACEMENT_POLICY_FTBU, + b_verbose) + sched_end = time.time() - start_time + num_nops = 0 + num_xinsts = 0 + for bundle_xinsts, *_ in xinsts: + for xinstr in bundle_xinsts: + num_xinsts += 1 + if isinstance(xinstr, xinst.Exit): + break # Stop counting instructions after bundle exit + if isinstance(xinstr, xinst.Nop): + num_nops += 1 + + if b_verbose: + print("Saving minst...") + with open(output_minst_filename, 'w') as outnum: + for idx, inst in enumerate(minsts): + inst_line = inst.toMASMISAFormat() + if inst_line: + print(f"{idx}, {inst_line}", file=outnum) + + if b_verbose: + print("Saving cinst...") + with open(output_cinst_filename, 'w') as outnum: + for idx, inst in enumerate(cinsts): + inst_line = inst.toCASMISAFormat() + if inst_line: + print(f"{idx}, {inst_line}", file=outnum) + + if b_verbose: + print("Saving xinst...") + with open(output_xinst_filename, 'w') as outnum: + for bundle_i, bundle_data in enumerate(xinsts): + for inst in bundle_data[0]: + inst_line = inst.toXASMISAFormat() + if inst_line: + print(f"F{bundle_i}, {inst_line}", file=outnum) + + if output_mem_filename: + if b_verbose: + print("Saving mem...") + with open(output_mem_filename, 'w') as outnum: + mem_meta_info.exportLegacyMem(outnum) + + return num_xinsts, num_nops, num_idle_cycles, deps_end, sched_end + +def main_asmisa(args): + """ + Main function to run ASM-ISA assembly process. + """ + b_use_bank_0: bool = False + b_use_old_mem_file = False + b_verbose = True if args.verbose > 0 else False + GlobalConfig.debugVerbose = 0 + GlobalConfig.suppressComments = False + GlobalConfig.useHBMPlaceHolders = True + GlobalConfig.useXInstFetch = False + + max_bundle_size = constants.Constants.MAX_BUNDLE_SIZE + hbm_capcity_words = constants.MemoryModel.HBM.MAX_CAPACITY_WORDS // 2 + spad_capacity_words = constants.MemoryModel.SPAD.MAX_CAPACITY_WORDS + num_register_banks = constants.MemoryModel.NUM_REGISTER_BANKS + register_range = None + + # All base names for processing + if len(args.base_names) > 0: + all_base_names = args.base_names + else: + raise argparse.ArgumentError(f"Please provide one or more input file prefixes using `--prefix` option.") + + for base_name in all_base_names: + in_kernel = f'{base_name}.csv' + mem_kernel = f'{base_name}.tw.mem' + mid_kernel = f'{base_name}.tw.csv' + out_xinst = f'{base_name}.xinst' + out_cinst = f'{base_name}.cinst' + out_minst = f'{base_name}.minst' + out_mem = f'{base_name}.mem' if b_use_old_mem_file else None + + if b_verbose: + print("Verbose mode: ON") + + print('Input:', in_kernel) + + # Preprocessing + insts_end = asmisa_preprocessing(in_kernel, mid_kernel, b_use_bank_0, b_verbose) + + if b_verbose: + print() + + num_xinsts, num_nops, num_idle_cycles, deps_end, sched_end = \ + asmisa_assembly(out_xinst, + out_cinst, + out_minst, + out_mem, + mid_kernel, + mem_kernel, + max_bundle_size, + hbm_capcity_words, + spad_capacity_words, + num_register_banks, + register_range, + b_verbose=b_verbose) + + if b_verbose: + print(f"Input: {in_kernel}") + print(f"Intermediate: {mid_kernel}") + print(f"--- Preprocessing time: {insts_end} seconds ---") + print(f"--- Total XInstructions: {num_xinsts} ---") + print(f"--- Deps time: {deps_end} seconds ---") + print(f"--- Scheduling time: {sched_end} seconds ---") + print(f"--- Minimum idle cycles: {num_idle_cycles} ---") + print(f"--- Minimum nops required: {num_nops} ---") + print() + + print("Complete") + +def main_pisa(args): + """ + Main function to run P-ISA scheduling process. + """ + b_use_bank_0: bool = False + b_verbose = True if args.verbose > 0 else False + + max_bundle_size = 8 + hec_mem_model = MemoryModel(constants.MemoryModel.HBM.MAX_CAPACITY_WORDS // 2, + 16, + 4, + range(8)) + + if len(args.base_names) == 1: + base_name = args.base_names[0] + else: + raise argparse.ArgumentError(None, f"Please provide an input file prefix using `--prefix` option.") + + print("HBM") + print(hec_mem_model.hbm.CAPACITY / constants.Constants.GIGABYTE, "GB") + print(hec_mem_model.hbm.CAPACITY_WORDS, "words") + print() + + + in_kernel = f'{base_name}.csv' + mid_kernel = f'{base_name}.tw.csv' + out_kernel = f'{base_name}.tw.new.csv' + out_xinst = f'{base_name}.xinst' + out_cinst = f'{base_name}.cinst' + out_minst = f'{base_name}.minst' + + insts_listing = [] + start_time = time.time() + # Read input kernel and pre-process P-ISA: + # Resulting instructions will be correctly transformed and ready to be converted into ASM-ISA instructions; + # Variables used in the kernel will be automatically assigned to banks. + with open(in_kernel, 'r') as insts: + insts_listing = preprocessor.preprocessPISAKernelListing(hec_mem_model, + insts, + progress_verbose=b_verbose) + + print("Assigning register banks to variables...") + preprocessor.assignRegisterBanksToVars(hec_mem_model, insts_listing, use_bank0=b_use_bank_0) + + hec_mem_model.output_variables.update(v_name for v_name in hec_mem_model.variables if 'output' in v_name) + + insts_end = time.time() - start_time + + print("Saving intermediate...") + with open(mid_kernel, 'w') as outnum: + for inst in insts_listing: + inst_line = inst.toPISAFormat() + f" # {inst.id}" + if inst_line: + print(inst_line, file=outnum) + + #print("Reloading kernel from intermediate...") + #insts_listing = [] + #with open(mid_kernel, 'r') as insts: + # for line_no, s_line in enumerate(insts, 1): + # parsed_insts = None + # if line_no % 100 == 0: + # print(f"{line_no}") + # # instruction is one that is represented by single XInst + # inst = xinst.createFromPISALine(hec_mem_model, s_line, line_no) + # if inst: + # parsed_insts = [ inst ] + + # if not parsed_insts: + # raise SyntaxError("Line {}: unable to parse kernel instruction:\n{}".format(line_no, s_line)) + + # insts_listing += parsed_insts + + print("Generating dependency graph...") + start_time = time.time() + dep_graph = preprocessor.generateInstrDependencyGraph(insts_listing) + deps_end = time.time() - start_time + + # Assign artificial register to allow scheduling of P-ISA + for v in hec_mem_model.variables.values(): + v.register = hec_mem_model.register_banks[0].getRegister(0) + + print("Scheduling P-ISA instructions...") + start_time = time.time() + pisa_insts_schedule, num_idle_cycles, num_nops = schedulePISAInstructions(dep_graph, progress_verbose=b_verbose) + sched_end = time.time() - start_time + + print("Saving...") + with open(out_kernel, 'w') as outnum: + for idx, inst in enumerate(pisa_insts_schedule): + inst_line = inst.toPISAFormat() + if inst_line: + print(inst_line, file=outnum) + + print(f"Input: {in_kernel}") + print(f"Intermediate: {mid_kernel}") + print(f"Output: {out_kernel}") + print(f"Instructions generated: {len(insts_listing)}") + print(f"--- Generation time: {insts_end} seconds ---") + print(f"--- Number of instructions: {len(insts_listing)} ---") + print(f"--- Deps time: {deps_end} seconds ---") + print(f"--- Scheduling time: {sched_end} seconds ---") + print(f"--- Minimum idle cycles: {num_idle_cycles} ---") + print(f"--- Minimum nops required: {num_nops} ---") + + print("Complete") + +if __name__ == "__main__": + module_dir = os.path.dirname(__file__) + module_name = os.path.basename(__file__) + + sys.path.append(os.path.join(module_dir,'xinst_timing_check')) + print(module_dir,'xinst_timing_check') + args = parse_args() + + repo_dir = os.path.join(module_dir,"..") + args.isa_spec_file = SpecConfig.initialize_isa_spec(repo_dir, args.isa_spec_file) + if args.verbose > 0: + print(f"ISA Spec: {args.isa_spec_file}") + + main_asmisa(args) diff --git a/assembler_tools/hec-assembler-tools/debug_tools/order_test.py b/assembler_tools/hec-assembler-tools/debug_tools/order_test.py new file mode 100644 index 00000000..615ff879 --- /dev/null +++ b/assembler_tools/hec-assembler-tools/debug_tools/order_test.py @@ -0,0 +1,100 @@ +import argparse +import re +import os + +# Tests all registers in an XInstQ for whether a register is used out of order based on P-ISA instruction order. +# This only works for kernels without evictions. +def parse_args(): + """ + Parses command-line arguments for the preprocessing script. + + This function sets up the argument parser and defines the expected arguments for the script. + It returns a Namespace object containing the parsed arguments. + + Returns: + argparse.Namespace: Parsed command-line arguments. + """ + parser = argparse.ArgumentParser( + description=("Order Test.\n" + "Tests all registers in an XInstQ for whether a register is used out of order based on P-ISA instruction order.\n" + "This only works for kernels without evictions.")) + parser.add_argument("--input_file", required= True, help="Input (.xinst) file.") + parser.add_argument("-v", "--verbose", dest="verbose", action="count", default=0, + help=("If enabled, extra information and progress reports are printed to stdout. " + "Increase level of verbosity by specifying flag multiple times, e.g. -vv")) + args = parser.parse_args() + + return args + +def convertRegNameToTuple(reg_name) -> tuple: + """ + Converts a register name to a tuple representation. + + Parameters: + reg_name (str): The register name in the format 'rb'. + + Returns: + tuple: A tuple containing the bank and register as integers. + """ + tmp_s = reg_name.split("r")[1] + tmp_s = tmp_s.split("b") + return (int(tmp_s[1]), int(tmp_s[0])) + +if __name__ == "__main__": + module_name = os.path.basename(__file__) + + args = parse_args() + input_file = args.input_file + + if args.verbose > 0: + print(module_name) + print() + print("Xinst File: {0}".format(input_file)) + print() + print("Starting") + + register_map = {} + + my_rx = "r[0-9]+b[0-3]" + prev_pisa_inst = 0 + instr_counter = 0 + with open(input_file, 'r') as f_in: + for line_idx, s_line in enumerate(f_in): + instr_regs = set() + s_split = s_line.split("#") + s_split = s_split[0].split(",") + pisa_instr_num = int(s_split[1]) + for s in s_split: + match = re.search(my_rx, s) + if match: + reg_name = s[match.start():match.end()] + if reg_name not in instr_regs: + instr_regs.add(reg_name) + reg = convertRegNameToTuple(reg_name) + if reg not in register_map: + register_map[reg] = [] + register_map[reg].append(pisa_instr_num) + + sorted_keys = [x for x in register_map] + sorted_keys.sort() + error_map = set() + + for reg in sorted_keys: + reg_name = f'r{reg[1]}b{reg[0]}' + print(reg_name, register_map[reg]) + reg_lst = register_map[reg] + inverted_map = {} + prev_in = 0 + for idx in range(len(reg_lst)): + if reg_lst[idx] >= prev_in: + prev_in = reg_lst[idx] + else: + inverted_map[idx] = (prev_in, reg_lst[idx]) + if inverted_map: + print('*** Ahead:', inverted_map) + error_map.add(reg_name) + + if error_map: + raise RuntimeError(f'Registers used out of order: {error_map}') + + print("Done") \ No newline at end of file diff --git a/assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/inject_bundles.py b/assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/inject_bundles.py new file mode 100644 index 00000000..4a2dae7a --- /dev/null +++ b/assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/inject_bundles.py @@ -0,0 +1,274 @@ +import argparse +import os + +from xinst import xinstruction +from spec_config import XTC_SpecConfig + +# Injects dummy bundles after bundle 1 + +NUM_BUNDLE_INSTRUCTIONS = 64 + +def makeUniquePath(path: str): + """ + Normalizes and expand a given file path. + + Parameters: + path (str): The file path to normalize and expand. + + Returns: + str: The normalized and expanded file path. + """ + return os.path.normcase(os.path.realpath(os.path.expanduser(path))) + +def transferNextBundle(xinst_in_stream, xinst_out_stream, bundle_number): + """ + Transfers the next bundle of instructions from input to output stream. + + Parameters: + xinst_in_stream: The input stream for XInst instructions. + xinst_out_stream: The output stream for XInst instructions. + bundle_number: The current bundle number. + """ + for _ in range(NUM_BUNDLE_INSTRUCTIONS): + s_line = xinst_in_stream.readline().strip() + # Must have an instruction + assert s_line + + # Split line into tokens + tokens, comment = xinstruction.tokenizeFromLine(s_line) + tokens = list(tokens) + tokens[0] = f"F{bundle_number}" + + s_line = ', '.join(tokens) + if comment: + s_line += f" # {comment}" + + print(s_line, file=xinst_out_stream) + +def main(nbundles: int, + input_dir: str, + output_dir: str, + input_prefix: str = None, + output_prefix: str = None, + b_use_exit: bool = True): + """ + Main function to inject dummy bundles into instruction files. + + Parameters: + nbundles (int): Number of dummy bundles to insert. + input_dir (str): Directory containing input files. + output_dir (str): Directory to save output files. + input_prefix (str): Prefix for input files. + output_prefix (str): Prefix for output files. + b_use_exit (bool): Whether to use 'bexit' in dummy bundles. + """ + print("Starting") + + input_dir = makeUniquePath(input_dir) + if not input_prefix: + input_prefix = os.path.basename(input_dir) + output_dir = makeUniquePath(output_dir) + if not output_prefix: + output_prefix = os.path.basename(output_dir) + + print('Input dir:', input_dir) + print('Input prefix:', input_prefix) + print('Output dir:', output_dir) + print('Output prefix:', output_prefix) + print('Dummy bundles to insert:', nbundles) + print('Use bexit:', b_use_exit) + + xinst_file_i = os.path.join(input_dir, input_prefix + ".xinst") + cinst_file_i = os.path.join(input_dir, input_prefix + ".cinst") + minst_file_i = os.path.join(input_dir, input_prefix + ".minst") + + xinst_file_o = os.path.join(output_dir, output_prefix + ".xinst") + cinst_file_o = os.path.join(output_dir, output_prefix + ".cinst") + minst_file_o = os.path.join(output_dir, output_prefix + ".minst") + + with open(xinst_file_i, 'r') as f_xinst_file_i, \ + open(cinst_file_i, 'r') as f_cinst_file_i, \ + open(minst_file_i, 'r') as f_minst_file_i: + with open(xinst_file_o, 'w') as f_xinst_file_o, \ + open(cinst_file_o, 'w') as f_cinst_file_o, \ + open(minst_file_o, 'w') as f_minst_file_o: + + current_bundle = 0 + + # Read xinst until first bundle is over + num_xstores = 0 + for _ in range(NUM_BUNDLE_INSTRUCTIONS): + line = f_xinst_file_i.readline().strip() + assert line # Cannot be EOF + + # Write line to output as is + print(line, file=f_xinst_file_o) + + # Split line into tokens + tokens, _ = xinstruction.tokenizeFromLine(line) + + # Must be bundle 0 + assert int(tokens[0][1:]) == current_bundle + + if tokens[2] == 'xstore': + # Encountered xstore + num_xstores += 1 + + cinst_line_no = 0 + cinst_insertion_line_start = 0 # Track which line we started inserting dummy bundles into CInstQ + cinst_insertion_line_count = 0 # Track how many lines of dummy bundles were inserted into CInstQ + + # Read cinst until first bundle is over + while True: # do-while + line = f_cinst_file_i.readline().strip() + # Cannot be EOF + assert line + + # Write line to output as is + print(line, file=f_cinst_file_o) + + # Split line into tokens + tokens, _ = xinstruction.tokenizeFromLine(line) + + cinst_line_no += 1 + + if tokens[1] == 'ifetch': + # Encountered first ifetch + assert int(tokens[2]) == current_bundle + break + + # Need to check if there are any xstores that have matching cstores + for _ in range(num_xstores): + line = f_cinst_file_i.readline().strip() + # Cannot be EOF + assert line + + # Write line to output as is + print(line, file=f_cinst_file_o) + + # Split line into tokens + tokens, _ = xinstruction.tokenizeFromLine(line) + # Must be a matching cstore + assert tokens[1] == 'cstore' + + cinst_line_no += 1 + + current_bundle += 1 # Next bundle + cinst_insertion_line_start = cinst_line_no # Start inserting dummy bundles + + # Start inserting dummy bundles + print("Inserting", nbundles, "dummy bundles...") + if nbundles > 0: + # Wait for last bundle to complete (use max possible bundle size) + print(f"{cinst_line_no}, cnop, 2000", file=f_cinst_file_o) + cinst_line_no += 1 + for idx in range(nbundles): + if idx % 5000 == 0: + print("{}% - {}/{}".format(idx * 100 // nbundles, idx, nbundles)) + # Cinst + print(f"{cinst_line_no}, ifetch, {current_bundle} # dummy bundle {idx + 1}", file=f_cinst_file_o) + print(f"{cinst_line_no + 1}, cnop, 70", file=f_cinst_file_o) + cinst_line_no += 2 + + # Xinst + if b_use_exit: + print(f"F{current_bundle}, 0, bexit # dummy bundle", file=f_xinst_file_o) + else: + print(f"F{current_bundle}, 0, nop, 0", file=f_xinst_file_o) + for _ in range(NUM_BUNDLE_INSTRUCTIONS - 1): + print(f"F{current_bundle}, 0, nop, 0", file=f_xinst_file_o) + + current_bundle += 1 + + print("100% - {0}/{0}".format(nbundles)) + + # Number of lines inserted in CInstQ + cinst_insertion_line_count = cinst_line_no - cinst_insertion_line_start + + # Complete CInstQ and XInstQ + print() + print('Transferring remaining CInstQ and XInstQ...') + print(cinst_line_no) + while True: # do-while + if cinst_line_no % 50000 == 0: + print(cinst_line_no) + + line = f_cinst_file_i.readline().strip() + if not line: # EOF + break + + # Split line into tokens + tokens, comment = xinstruction.tokenizeFromLine(line) + tokens = list(tokens) + + tokens[0] = str(cinst_line_no) + # Output line with correct line and bundle number + if tokens[1] == 'ifetch': + # Ensure fetching correct bundle + tokens[2] = str(current_bundle) + + # Output xinst bundle + transferNextBundle(f_xinst_file_i, f_xinst_file_o, current_bundle) + current_bundle += 1 + + line = ', '.join(tokens) + if comment: + line += f" # {comment}" + + print(line, file=f_cinst_file_o) + cinst_line_no += 1 + + print(cinst_line_no) + + # Fix sync points in MInstQ + print() + print('Fixing MInstQ sync points...') + for idx, line in enumerate(f_minst_file_i): + if idx % 5000 == 0: + print(idx) + + tokens, comment = xinstruction.tokenizeFromLine(line) + assert int(tokens[0]) == idx, 'Unexpected line number mismatch in MInstQ.' + + tokens = list(tokens) + # Process sync instruction + if tokens[1] == 'msyncc': + ctarget_line_no = int(tokens[2]) + if ctarget_line_no >= cinst_insertion_line_start: + ctarget_line_no += cinst_insertion_line_count + tokens[2] = str(ctarget_line_no) + + # Transfer minst line to output file + line = ', '.join(tokens) + if comment: + line += f" # {comment}" + + print(line, file=f_minst_file_o) + + print(idx) + +if __name__ == "__main__": + module_dir = os.path.dirname(__file__) + module_name = os.path.basename(__file__) + print(module_name) + + parser = argparse.ArgumentParser() + parser.add_argument("input_dir") + parser.add_argument("output_dir") + parser.add_argument("input_prefix", nargs="?") + parser.add_argument("output_prefix", nargs="?") + parser.add_argument("--isa_spec", default="", dest="isa_spec_file", + help=("Input ISA specification (.json) file.")) + parser.add_argument("-b", "--dummy_bundles", dest='nbundles', type=int, default=0) + parser.add_argument("-ne", "--skip_exit", dest='b_use_exit', action='store_false') + args = parser.parse_args() + args.isa_spec_file = XTC_SpecConfig.initialize_isa_spec(module_dir, args.isa_spec_file) + + print(f"ISA Spec: {args.isa_spec_file}") + print() + + main(args.nbundles, args.input_dir, args.output_dir, + args.input_prefix, args.output_prefix, args.b_use_exit) + + print() + print(module_name, "- Complete") \ No newline at end of file diff --git a/assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/spec_config.py b/assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/spec_config.py new file mode 100644 index 00000000..04a7d98d --- /dev/null +++ b/assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/spec_config.py @@ -0,0 +1,62 @@ +import os +import xinst +from assembler.isa_spec import SpecConfig + +class XTC_SpecConfig (SpecConfig): + + __target_xops = { + "add" : xinst.add.Instruction, + "exit" : xinst.exit_mod.Instruction, + "intt" : xinst.intt.Instruction, + "mac" : xinst.mac.Instruction, + "maci" : xinst.maci.Instruction, + "move" : xinst.move.Instruction, + "mul" : xinst.mul.Instruction, + "muli" : xinst.muli.Instruction, + "nop" : xinst.nop.Instruction, + "ntt" : xinst.ntt.Instruction, + "rshuffle" : xinst.rshuffle.Instruction, + "sub" : xinst.sub.Instruction, + "twintt" : xinst.twintt.Instruction, + "twntt" : xinst.twntt.Instruction, + "xstore" : xinst.xstore.Instruction, + } + + _target_ops = { + "xinst": __target_xops + } + + _target_attributes = { + "num_tokens" : "SetNumTokens", + "num_dests" : "SetNumDests", + "num_sources" : "SetNumSources", + "default_throughput" : "SetDefaultThroughput", + "default_latency" : "SetDefaultLatency", + "special_latency_max" : "SetSpecialLatencyMax", + "special_latency_increment": "SetSpecialLatencyIncrement", + } + + @classmethod + def dump_isa_spec_to_json(cls, filename): + """ + Uninmplemented for this child class. + """ + print("WARNING: 'dump_isa_spec_to_json' unimplemented for xinst_timing_check") + + @classmethod + def initialize_isa_spec(cls, module_dir, isa_spec_file): + + if not isa_spec_file: + isa_spec_file = os.path.join(module_dir, "../../config/isa_spec.json") + isa_spec_file = os.path.abspath(isa_spec_file) + + if not os.path.exists(isa_spec_file): + raise FileNotFoundError( + f"Required ISA Spec file not found: {isa_spec_file}\n" + "Please provide a valid path using the `isa_spec` option, " + "or use a valid default file at: `/config/isa_spec.json`." + ) + + cls.init_isa_spec_from_json(isa_spec_file) + + return isa_spec_file diff --git a/assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/xinst/__init__.py b/assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/xinst/__init__.py new file mode 100644 index 00000000..5d49e768 --- /dev/null +++ b/assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/xinst/__init__.py @@ -0,0 +1,24 @@ +from .xinstruction import XInstruction +from . import add, mul, muli, mac, maci, ntt, intt, twntt, twintt, rshuffle, sub, move, xstore, nop +from . import exit as exit_mod + +# XInst aliases + +Add = add.Instruction +Mul = mul.Instruction +Muli = muli.Instruction +Mac = mac.Instruction +Maci = maci.Instruction +NTT = ntt.Instruction +iNTT = intt.Instruction +twNTT = twntt.Instruction +twiNTT = twintt.Instruction +rShuffle = rshuffle.Instruction +Sub = sub.Instruction +Move = move.Instruction +XStore = xstore.Instruction +Exit = exit_mod.Instruction +Nop = nop.Instruction + +# collection of XInstructions with P-ISA or intermediate P-ISA equivalents +ASMISA_INSTRUCTIONS = ( Add, Mul, Muli, Mac, Maci, NTT, iNTT, twNTT, twiNTT, rShuffle, Sub, Move, XStore, Exit, Nop ) diff --git a/assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/xinst/add.py b/assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/xinst/add.py new file mode 100644 index 00000000..648d0909 --- /dev/null +++ b/assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/xinst/add.py @@ -0,0 +1,77 @@ +from .xinstruction import XInstruction + +class Instruction(XInstruction): + """ + Represents an `add` instruction, inheriting from XInstruction. + + This instructions adds two polynomials stored in the register file and + store the result in a register. + + For more information, check the specification: + https://github.com/IntelLabs/hec-assembler-tools/blob/master/docsrc/inst_spec/xinst/xinst_add.md + """ + + @classmethod + def fromASMISALine(cls, line: str) -> list: + """ + Parse an ASM ISA line to create an Instruction object. + + Parameters: + line (str): The line of text to parse. + + Returns: + list: An Instruction object if parsing is successful, None otherwise. + + Raises: + ValueError: If the line cannot be parsed into an Instruction. + """ + retval = None + tokens = XInstruction.tokenizeFromASMISALine(cls.name, line) + if tokens: + tokens, comment = tokens + if len(tokens) < 3 or tokens[2] != cls.name: + raise ValueError('`line`: could not parse f{cls.name} from specified line.') + dst_src_map = XInstruction.parseASMISASourceDestsFromTokens(tokens, cls._OP_NUM_DESTS, cls._OP_NUM_SOURCES, 3) + retval = cls(int(tokens[0][1:]), # Bundle + int(tokens[1]), # PISA instruction number + dst_src_map['dst'], + dst_src_map['src'], + cls._OP_DEFAULT_THROUGHPUT, + cls._OP_DEFAULT_LATENCY, + tokens[3 + cls._OP_NUM_DESTS + cls._OP_NUM_SOURCES:], + comment) + return retval + + @classmethod + def _get_name(cls) -> str: + """ + Gets the name of the instruction. + + Returns: + str: The name of the instruction, which is "add". + """ + return "add" + + def __init__(self, + bundle: int, + pisa_instr: int, + dsts: list, + srcs: list, + throughput: int, + latency: int, + other: list = [], + comment: str = ""): + """ + Constructs a new Instruction object. + + Parameters: + bundle (int): The bundle number. + pisa_instr (int): The PISA instruction number. + dsts (list): List of destination registers. + srcs (list): List of source registers. + throughput (int): The throughput of the instruction. + latency (int): The latency of the instruction. + other (list): Additional parameters for the instruction. + comment (str): Optional comment for the instruction. + """ + super().__init__(bundle, pisa_instr, dsts, srcs, throughput, latency, other, comment) diff --git a/assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/xinst/exit.py b/assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/xinst/exit.py new file mode 100644 index 00000000..e6747286 --- /dev/null +++ b/assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/xinst/exit.py @@ -0,0 +1,75 @@ +from .xinstruction import XInstruction + +class Instruction(XInstruction): + """ + Represents an `bexit` instruction, inheriting from XInstruction. + + This instruction terminates execution of an instruction bundle. + + For more information, check the specification: + https://github.com/IntelLabs/hec-assembler-tools/blob/master/docsrc/inst_spec/xinst/xinst_exit.md + """ + + @classmethod + def fromASMISALine(cls, line: str) -> list: + """ + Parses an ASM ISA line to create an Instruction object. + + Parameters: + line (str): The line of text to parse. + + Returns: + list: An Instruction object if parsing is successful, None otherwise. + + Raises: + ValueError: If the line cannot be parsed into an Instruction. + """ + retval = None + tokens = XInstruction.tokenizeFromASMISALine(cls.name, line) + if tokens: + tokens, comment = tokens + if len(tokens) < 3 or tokens[2] != cls.name: + raise ValueError('`line`: could not parse f{cls.name} from specified line.') + retval = cls(int(tokens[0][1:]), # Bundle + int(tokens[1]), # PISA instruction number + [], + [], + cls._OP_DEFAULT_THROUGHPUT, + cls._OP_DEFAULT_LATENCY, + tokens[3 + cls._OP_NUM_DESTS + cls._OP_NUM_SOURCES:], + comment) + return retval + + @classmethod + def _get_name(cls) -> str: + """ + Gets the name of the instruction. + + Returns: + str: The name of the instruction, which is "bexit". + """ + return "bexit" + + def __init__(self, + bundle: int, + pisa_instr: int, + dsts: list, + srcs: list, + throughput: int, + latency: int, + other: list = [], + comment: str = ""): + """ + Constructs a new Instruction object. + + Parameters: + bundle (int): The bundle number. + pisa_instr (int): The PISA instruction number. + dsts (list): List of destination registers. + srcs (list): List of source registers. + throughput (int): The throughput of the instruction. + latency (int): The latency of the instruction. + other (list): Additional parameters for the instruction. + comment (str): Optional comment for the instruction. + """ + super().__init__(bundle, pisa_instr, dsts, srcs, throughput, latency, other, comment) \ No newline at end of file diff --git a/assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/xinst/intt.py b/assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/xinst/intt.py new file mode 100644 index 00000000..18c4e7d6 --- /dev/null +++ b/assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/xinst/intt.py @@ -0,0 +1,76 @@ +from .xinstruction import XInstruction + +class Instruction(XInstruction): + """ + Represents an `intt` instruction, inheriting from XInstruction. + + The Inverse Number Theoretic Transform (iNTT), converts NTT form to positional form. + + For more information, check the specification: + https://github.com/IntelLabs/hec-assembler-tools/blob/master/docsrc/inst_spec/xinst/xinst_intt.md + """ + + @classmethod + def fromASMISALine(cls, line: str) -> list: + """ + Parses an ASM ISA line to create an Instruction object. + + Parameters: + line (str): The line of text to parse. + + Returns: + list: An Instruction object if parsing is successful, None otherwise. + + Raises: + ValueError: If the line cannot be parsed into an Instruction. + """ + retval = None + tokens = XInstruction.tokenizeFromASMISALine(cls.name, line) + if tokens: + tokens, comment = tokens + if len(tokens) < 3 or tokens[2] != cls.name: + raise ValueError('`line`: could not parse f{cls.name} from specified line.') + dst_src_map = XInstruction.parseASMISASourceDestsFromTokens(tokens, cls._OP_NUM_DESTS, cls._OP_NUM_SOURCES, 3) + retval = cls(int(tokens[0][1:]), # Bundle + int(tokens[1]), # PISA instruction number + dst_src_map['dst'], + dst_src_map['src'], + cls._OP_DEFAULT_THROUGHPUT, + cls._OP_DEFAULT_LATENCY, + tokens[3 + cls._OP_NUM_DESTS + cls._OP_NUM_SOURCES:], + comment) + return retval + + @classmethod + def _get_name(cls) -> str: + """ + Gets the name of the instruction. + + Returns: + str: The name of the instruction, which is "intt". + """ + return "intt" + + def __init__(self, + bundle: int, + pisa_instr: int, + dsts: list, + srcs: list, + throughput: int, + latency: int, + other: list = [], + comment: str = ""): + """ + Constructs a new Instruction object. + + Parameters: + bundle (int): The bundle number. + pisa_instr (int): The PISA instruction number. + dsts (list): List of destination registers. + srcs (list): List of source registers. + throughput (int): The throughput of the instruction. + latency (int): The latency of the instruction. + other (list): Additional parameters for the instruction. + comment (str): Optional comment for the instruction. + """ + super().__init__(bundle, pisa_instr, dsts, srcs, throughput, latency, other, comment) \ No newline at end of file diff --git a/assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/xinst/mac.py b/assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/xinst/mac.py new file mode 100644 index 00000000..c5ce986c --- /dev/null +++ b/assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/xinst/mac.py @@ -0,0 +1,77 @@ +from .xinstruction import XInstruction + +class Instruction(XInstruction): + """ + Represents a `mac` Instruction for element-wise polynomial multiplication and accumulation. + + For more information, check the specification: + https://github.com/IntelLabs/hec-assembler-tools/blob/master/docsrc/inst_spec/xinst/xinst_mac.md + + Methods: + fromASMISALine: Parses an ASM ISA line to create an Instruction instance. + """ + + @classmethod + def fromASMISALine(cls, line: str) -> list: + """ + Parses an ASM ISA line to create an Instruction instance. + + Args: + line (str): The ASM ISA line to parse. + + Returns: + list: A list containing the parsed Instruction instance. + + Raises: + ValueError: If the line cannot be parsed into the expected format. + """ + retval = None + tokens = XInstruction.tokenizeFromASMISALine(cls.name, line) + if tokens: + tokens, comment = tokens + if len(tokens) < 3 or tokens[2] != cls.name: + raise ValueError('`line`: could not parse f{cls.name} from specified line.') + dst_src_map = XInstruction.parseASMISASourceDestsFromTokens(tokens, cls._OP_NUM_DESTS, cls._OP_NUM_SOURCES, 3) + retval = cls(int(tokens[0][1:]), # Bundle + int(tokens[1]), # Pisa + dst_src_map['dst'], + dst_src_map['src'], + cls._OP_DEFAULT_THROUGHPUT, + cls._OP_DEFAULT_LATENCY, + tokens[3 + cls._OP_NUM_DESTS + cls._OP_NUM_SOURCES:], + comment) + return retval + + @classmethod + def _get_name(cls) -> str: + """ + Gets the name of the instruction. + + Returns: + str: The name of the instruction, which is "mac". + """ + return "mac" + + def __init__(self, + bundle: int, + pisa_instr: int, + dsts: list, + srcs: list, + throughput: int, + latency: int, + other: list = [], + comment: str = ""): + """ + Initializes an Instruction instance. + + Args: + bundle (int): The bundle identifier. + pisa_instr (int): The PISA instruction identifier. + dsts (list): The list of destination operands. + srcs (list): The list of source operands. + throughput (int): The throughput of the instruction. + latency (int): The latency of the instruction. + other (list, optional): Additional parameters. Defaults to an empty list. + comment (str, optional): A comment associated with the instruction. Defaults to an empty string. + """ + super().__init__(bundle, pisa_instr, dsts, srcs, throughput, latency, other, comment) \ No newline at end of file diff --git a/assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/xinst/maci.py b/assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/xinst/maci.py new file mode 100644 index 00000000..3525f06d --- /dev/null +++ b/assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/xinst/maci.py @@ -0,0 +1,79 @@ +from .xinstruction import XInstruction + +class Instruction(XInstruction): + """ + Represents a `maci` Instruction. + + Element-wise polynomial scaling by an immediate value and accumulation. + + For more information, check the specification: + https://github.com/IntelLabs/hec-assembler-tools/blob/master/docsrc/inst_spec/xinst/xinst_maci.md + + Methods: + fromASMISALine: Parses an ASM ISA line to create an Instruction instance. + """ + + @classmethod + def fromASMISALine(cls, line: str) -> list: + """ + Parses an ASM ISA line to create an Instruction instance. + + Args: + line (str): The ASM ISA line to parse. + + Returns: + list: A list containing the parsed Instruction instance. + + Raises: + ValueError: If the line cannot be parsed into the expected format. + """ + retval = None + tokens = XInstruction.tokenizeFromASMISALine(cls.name, line) + if tokens: + tokens, comment = tokens + if len(tokens) < 3 or tokens[2] != cls.name: + raise ValueError('`line`: could not parse f{cls.name} from specified line.') + dst_src_map = XInstruction.parseASMISASourceDestsFromTokens(tokens, cls._OP_NUM_DESTS, cls._OP_NUM_SOURCES, 3) + retval = cls(int(tokens[0][1:]), # Bundle + int(tokens[1]), # Psisa + dst_src_map['dst'], + dst_src_map['src'], + cls._OP_DEFAULT_THROUGHPUT, + cls._OP_DEFAULT_LATENCY, + tokens[3 + cls._OP_NUM_DESTS + cls._OP_NUM_SOURCES:], + comment) + return retval + + @classmethod + def _get_name(cls) -> str: + """ + Gets the name of the instruction. + + Returns: + str: The name of the instruction, which is "maci". + """ + return "maci" + + def __init__(self, + bundle: int, + pisa_instr: int, + dsts: list, + srcs: list, + throughput: int, + latency: int, + other: list = [], + comment: str = ""): + """ + Initializes an Instruction instance. + + Args: + bundle (int): The bundle identifier. + pisa_instr (int): The PISA instruction identifier. + dsts (list): The list of destination operands. + srcs (list): The list of source operands. + throughput (int): The throughput of the instruction. + latency (int): The latency of the instruction. + other (list, optional): Additional parameters. Defaults to an empty list. + comment (str, optional): A comment associated with the instruction. Defaults to an empty string. + """ + super().__init__(bundle, pisa_instr, dsts, srcs, throughput, latency, other, comment) \ No newline at end of file diff --git a/assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/xinst/move.py b/assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/xinst/move.py new file mode 100644 index 00000000..4d9205ae --- /dev/null +++ b/assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/xinst/move.py @@ -0,0 +1,79 @@ +from .xinstruction import XInstruction + +class Instruction(XInstruction): + """ + Represents a `move` Instruction. + + This instruction copies data from one register to a different one. + + For more information, check the specification: + https://github.com/IntelLabs/hec-assembler-tools/blob/master/docsrc/inst_spec/xinst/xinst_move.md + + Methods: + fromASMISALine: Parses an ASM ISA line to create an Instruction instance. + """ + + @classmethod + def fromASMISALine(cls, line: str) -> list: + """ + Parses an ASM ISA line to create an Instruction instance. + + Args: + line (str): The ASM ISA line to parse. + + Returns: + list: A list containing the parsed Instruction instance. + + Raises: + ValueError: If the line cannot be parsed into the expected format. + """ + retval = None + tokens = XInstruction.tokenizeFromASMISALine(cls.name, line) + if tokens: + tokens, comment = tokens + if len(tokens) < 3 or tokens[2] != cls.name: + raise ValueError('`line`: could not parse f{cls.name} from specified line.') + dst_src_map = XInstruction.parseASMISASourceDestsFromTokens(tokens, cls._OP_NUM_DESTS, cls._OP_NUM_SOURCES, 3) + retval = cls(int(tokens[0][1:]), # Bundle + int(tokens[1]), # Pisa + dst_src_map['dst'], + dst_src_map['src'], + cls._OP_DEFAULT_THROUGHPUT, + cls._OP_DEFAULT_LATENCY, + tokens[3 + cls._OP_NUM_DESTS + cls._OP_NUM_SOURCES:], + comment) + return retval + + @classmethod + def _get_name(cls) -> str: + """ + Gets the name of the instruction. + + Returns: + str: The name of the instruction, which is "move". + """ + return "move" + + def __init__(self, + bundle: int, + pisa_instr: int, + dsts: list, + srcs: list, + throughput: int, + latency: int, + other: list = [], + comment: str = ""): + """ + Initializes an Instruction instance. + + Args: + bundle (int): The bundle identifier. + pisa_instr (int): The PISA instruction identifier. + dsts (list): The list of destination operands. + srcs (list): The list of source operands. + throughput (int): The throughput of the instruction. + latency (int): The latency of the instruction. + other (list, optional): Additional parameters. Defaults to an empty list. + comment (str, optional): A comment associated with the instruction. Defaults to an empty string. + """ + super().__init__(bundle, pisa_instr, dsts, srcs, throughput, latency, other, comment) \ No newline at end of file diff --git a/assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/xinst/mul.py b/assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/xinst/mul.py new file mode 100644 index 00000000..d243b0d5 --- /dev/null +++ b/assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/xinst/mul.py @@ -0,0 +1,79 @@ +from .xinstruction import XInstruction + +class Instruction(XInstruction): + """ + Represents a `mul` Instruction. + + This instructions performs element-wise polynomial multiplication. + + For more information, check the specification: + https://github.com/IntelLabs/hec-assembler-tools/blob/master/docsrc/inst_spec/xinst/xinst_mul.md + + Methods: + fromASMISALine: Parses an ASM ISA line to create an Instruction instance. + """ + + @classmethod + def fromASMISALine(cls, line: str) -> list: + """ + Parses an ASM ISA line to create an Instruction instance. + + Args: + line (str): The ASM ISA line to parse. + + Returns: + list: A list containing the parsed Instruction instance. + + Raises: + ValueError: If the line cannot be parsed into the expected format. + """ + retval = None + tokens = XInstruction.tokenizeFromASMISALine(cls.name, line) + if tokens: + tokens, comment = tokens + if len(tokens) < 3 or tokens[2] != cls.name: + raise ValueError('`line`: could not parse f{cls.name} from specified line.') + dst_src_map = XInstruction.parseASMISASourceDestsFromTokens(tokens, cls._OP_NUM_DESTS, cls._OP_NUM_SOURCES, 3) + retval = cls(int(tokens[0][1:]), # Bundle + int(tokens[1]), # Pisa + dst_src_map['dst'], + dst_src_map['src'], + cls._OP_DEFAULT_THROUGHPUT, + cls._OP_DEFAULT_LATENCY, + tokens[3 + cls._OP_NUM_DESTS + cls._OP_NUM_SOURCES:], + comment) + return retval + + @classmethod + def _get_name(cls) -> str: + """ + Gets the name of the instruction. + + Returns: + str: The name of the instruction, which is "mul". + """ + return "mul" + + def __init__(self, + bundle: int, + pisa_instr: int, + dsts: list, + srcs: list, + throughput: int, + latency: int, + other: list = [], + comment: str = ""): + """ + Initializes an Instruction instance. + + Args: + bundle (int): The bundle identifier. + pisa_instr (int): The PISA instruction identifier. + dsts (list): The list of destination operands. + srcs (list): The list of source operands. + throughput (int): The throughput of the instruction. + latency (int): The latency of the instruction. + other (list, optional): Additional parameters. Defaults to an empty list. + comment (str, optional): A comment associated with the instruction. Defaults to an empty string. + """ + super().__init__(bundle, pisa_instr, dsts, srcs, throughput, latency, other, comment) \ No newline at end of file diff --git a/assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/xinst/muli.py b/assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/xinst/muli.py new file mode 100644 index 00000000..e117e7cd --- /dev/null +++ b/assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/xinst/muli.py @@ -0,0 +1,79 @@ +from .xinstruction import XInstruction + +class Instruction(XInstruction): + """ + Represents a `muli` Instruction. + + This instruction performs element-wise polynomial scaling by an immediate value. + + For more information, check the specification: + https://github.com/IntelLabs/hec-assembler-tools/blob/master/docsrc/inst_spec/xinst/xinst_muli.md + + Methods: + fromASMISALine: Parses an ASM ISA line to create an Instruction instance. + """ + + @classmethod + def fromASMISALine(cls, line: str) -> list: + """ + Parses an ASM ISA line to create an Instruction instance. + + Args: + line (str): The ASM ISA line to parse. + + Returns: + list: A list containing the parsed Instruction instance. + + Raises: + ValueError: If the line cannot be parsed into the expected format. + """ + retval = None + tokens = XInstruction.tokenizeFromASMISALine(cls.name, line) + if tokens: + tokens, comment = tokens + if len(tokens) < 3 or tokens[2] != cls.name: + raise ValueError('`line`: could not parse f{cls.name} from specified line.') + dst_src_map = XInstruction.parseASMISASourceDestsFromTokens(tokens, cls._OP_NUM_DESTS, cls._OP_NUM_SOURCES, 3) + retval = cls(int(tokens[0][1:]), # Bundle + int(tokens[1]), # Pisa + dst_src_map['dst'], + dst_src_map['src'], + cls._OP_DEFAULT_THROUGHPUT, + cls._OP_DEFAULT_LATENCY, + tokens[3 + cls._OP_NUM_DESTS + cls._OP_NUM_SOURCES:], + comment) + return retval + + @classmethod + def _get_name(cls) -> str: + """ + Gets the name of the instruction. + + Returns: + str: The name of the instruction, which is "muli". + """ + return "muli" + + def __init__(self, + bundle: int, + pisa_instr: int, + dsts: list, + srcs: list, + throughput: int, + latency: int, + other: list = [], + comment: str = ""): + """ + Initializes an Instruction instance. + + Args: + bundle (int): The bundle identifier. + pisa_instr (int): The PISA instruction identifier. + dsts (list): The list of destination operands. + srcs (list): The list of source operands. + throughput (int): The throughput of the instruction. + latency (int): The latency of the instruction. + other (list, optional): Additional parameters. Defaults to an empty list. + comment (str, optional): A comment associated with the instruction. Defaults to an empty string. + """ + super().__init__(bundle, pisa_instr, dsts, srcs, throughput, latency, other, comment) \ No newline at end of file diff --git a/assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/xinst/nop.py b/assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/xinst/nop.py new file mode 100644 index 00000000..02355b75 --- /dev/null +++ b/assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/xinst/nop.py @@ -0,0 +1,81 @@ +from .xinstruction import XInstruction + +class Instruction(XInstruction): + """ + Represents a `nop` Instruction. + + This instruction adds a desired amount of idle cycles to the compute flow. + + For more information, check the specification: + https://github.com/IntelLabs/hec-assembler-tools/blob/master/docsrc/inst_spec/xinst/xinst_nop.md + + Methods: + fromASMISALine: Parses an ASM ISA line to create an Instruction instance. + """ + + @classmethod + def fromASMISALine(cls, line: str) -> list: + """ + Parses an ASM ISA line to create an Instruction instance. + + Args: + line (str): The ASM ISA line to parse. + + Returns: + list: A list containing the parsed Instruction instance. + + Raises: + ValueError: If the line cannot be parsed into the expected format. + """ + retval = None + tokens = XInstruction.tokenizeFromASMISALine(cls.name, line) + if tokens: + tokens, comment = tokens + if len(tokens) < 4 or tokens[2] != cls.name: + raise ValueError('`line`: could not parse f{cls.name} from specified line.') + idle_cycles = int(tokens[3]) + 1 + retval = cls( + int(tokens[0][1:]), # Bundle + int(tokens[1]), # Pisa + [], + [], + idle_cycles, + idle_cycles, + tokens[3:], + comment + ) + return retval + + @classmethod + def _get_name(cls) -> str: + """ + Gets the name of the instruction. + + Returns: + str: The name of the instruction, which is "nop". + """ + return "nop" + + def __init__(self, + bundle: int, + pisa_instr: int, + dsts: list, + srcs: list, + throughput: int, + latency: int, + other: list = [], + comment: str = ""): + """ + Initializes an Instruction instance. + + Args: + bundle (int): The bundle identifier. + pisa_instr (int): The PISA instruction identifier. + dsts (list): The list of destination operands. + srcs (list): The list of source operands. + throughput (int): The throughput of the instruction. + latency (int): The latency of the instruction. + other (list, optional): Additional parameters. Defaults to an empty list. + comment (str, optional): A comment associated with the instruction. Defaults to an empty string. + """ + super().__init__(bundle, pisa_instr, dsts, srcs, throughput, latency, other, comment) \ No newline at end of file diff --git a/assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/xinst/ntt.py b/assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/xinst/ntt.py new file mode 100644 index 00000000..22eac064 --- /dev/null +++ b/assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/xinst/ntt.py @@ -0,0 +1,80 @@ +from argparse import Namespace + +from .xinstruction import XInstruction + +class Instruction(XInstruction): + """ + Represents an `ntt` instruction (Number Theoretic Transform). + + Converts positional form to NTT form. + For more information, check the specification: + https://github.com/IntelLabs/hec-assembler-tools/blob/master/docsrc/inst_spec/xinst/xinst_ntt.md + + Methods: + fromASMISALine: Parses an ASM ISA line to create an Instruction instance. + """ + + @classmethod + def fromASMISALine(cls, line: str) -> list: + """ + Parses an ASM ISA line to create an Instruction instance. + + Args: + line (str): The ASM ISA line to parse. + + Returns: + list: A list containing the parsed Instruction instance. + + Raises: + ValueError: If the line cannot be parsed into the expected format. + """ + retval = None + tokens = XInstruction.tokenizeFromASMISALine(cls.name, line) + if tokens: + tokens, comment = tokens + if len(tokens) < 3 or tokens[2] != cls.name: + raise ValueError('`line`: could not parse f{cls.name} from specified line.') + dst_src_map = XInstruction.parseASMISASourceDestsFromTokens(tokens, cls._OP_NUM_DESTS, cls._OP_NUM_SOURCES, 3) + retval = cls(int(tokens[0][1:]), # bundle + int(tokens[1]), # pisa + dst_src_map['dst'], + dst_src_map['src'], + cls._OP_DEFAULT_THROUGHPUT, + cls._OP_DEFAULT_LATENCY, + tokens[3 + cls._OP_NUM_DESTS + cls._OP_NUM_SOURCES:], + comment) + return retval + + @classmethod + def _get_name(cls) -> str: + """ + Gets the name of the instruction. + + Returns: + str: The name of the instruction, which is "ntt". + """ + return "ntt" + + def __init__(self, + bundle: int, + pisa_instr: int, + dsts: list, + srcs: list, + throughput: int, + latency: int, + other: list = [], + comment: str = ""): + """ + Initializes an Instruction instance. + + Args: + bundle (int): The bundle identifier. + pisa_instr (int): The PISA instruction identifier. + dsts (list): The list of destination operands. + srcs (list): The list of source operands. + throughput (int): The throughput of the instruction. + latency (int): The latency of the instruction. + other (list, optional): Additional parameters. Defaults to an empty list. + comment (str, optional): A comment associated with the instruction. Defaults to an empty string. + """ + super().__init__(bundle, pisa_instr, dsts, srcs, throughput, latency, other, comment) \ No newline at end of file diff --git a/assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/xinst/rshuffle.py b/assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/xinst/rshuffle.py new file mode 100644 index 00000000..94ad23bc --- /dev/null +++ b/assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/xinst/rshuffle.py @@ -0,0 +1,140 @@ +from .xinstruction import XInstruction + +class Instruction(XInstruction): + """ + Represents an Instruction with specific operational parameters and special latency properties. + + Methods: + fromASMISALine: Parses an ASM ISA line to create an Instruction instance. + _get_name: Gets the name of the instruction. + + Properties: + data_type: Gets the data type from the 'other' parameters. + wait_cycles: Gets the wait cycles from the 'other' parameters. + special_latency_max: Gets the special latency maximum. + special_latency_increment: Gets the special latency increment. + """ + + # To be initialized from ASM ISA spec + _OP_RMOVE_LATENCY : int + _OP_RMOVE_LATENCY_MAX: int + _OP_RMOVE_LATENCY_INC: int + + @classmethod + def SetSpecialLatencyMax(cls, val): + cls._OP_RMOVE_LATENCY_MAX = val + cls._OP_RMOVE_LATENCY = cls._OP_RMOVE_LATENCY_MAX + + @classmethod + def SetSpecialLatencyIncrement(cls, val): + cls._OP_RMOVE_LATENCY_INC = val + + @classmethod + def fromASMISALine(cls, line: str) -> list: + """ + Parses an ASM ISA line to create an Instruction instance. + + Args: + line (str): The ASM ISA line to parse. + + Returns: + list: A list containing the parsed Instruction instance. + + Raises: + ValueError: If the line cannot be parsed into the expected format. + """ + retval = None + tokens = XInstruction.tokenizeFromASMISALine(cls.name, line) + if tokens: + tokens, comment = tokens + if len(tokens) < 9 or tokens[2] != cls.name: + raise ValueError('`line`: could not parse f{cls.name} from specified line.') + dst_src_map = XInstruction.parseASMISASourceDestsFromTokens(tokens, cls._OP_NUM_DESTS, cls._OP_NUM_SOURCES, 3) + retval = cls(int(tokens[0][1:]), # bundle + int(tokens[1]), # pisa + dst_src_map['dst'], + dst_src_map['src'], + cls._OP_DEFAULT_THROUGHPUT, + cls._OP_DEFAULT_LATENCY, + tokens[3 + cls._OP_NUM_DESTS + cls._OP_NUM_SOURCES:], + comment) + return retval + + @classmethod + def _get_name(cls) -> str: + """ + Gets the name of the instruction. + + Returns: + str: The name of the instruction, which is "rshuffle". + """ + return "rshuffle" + + def __init__(self, + bundle: int, + pisa_instr: int, + dsts: list, + srcs: list, + throughput: int, + latency: int, + other: list = [], + comment: str = ""): + """ + Initializes an Instruction instance. + + Args: + bundle (int): The bundle identifier. + pisa_instr (int): The PISA instruction identifier. + dsts (list): The list of destination operands. + srcs (list): The list of source operands. + throughput (int): The throughput of the instruction. + latency (int): The latency of the instruction. + other (list, optional): Additional parameters. Defaults to an empty list. + comment (str, optional): A comment associated with the instruction. Defaults to an empty string. + + Raises: + ValueError: If the 'other' list does not contain at least two parameters. + """ + if len(other) < 2: + raise ValueError('`other`: requires two parameters after sources.') + super().__init__(bundle, pisa_instr, dsts, srcs, throughput + int(other[0]), latency, other, comment) + + @property + def data_type(self): + """ + Gets the data type from the 'other' parameters. + + Returns: + The data type. + """ + return self.other[1] + + @property + def wait_cycles(self): + """ + Gets the wait cycles from the 'other' parameters. + + Returns: + The wait cycles. + """ + return self.other[0] + + @property + def special_latency_max(self): + """ + Gets the special latency maximum. + + Returns: + int: The special latency maximum. + """ + return self._OP_RMOVE_LATENCY + + @property + def special_latency_increment(self): + """ + Gets the special latency increment. + + Returns: + int: The special latency increment. + """ + return self._OP_RMOVE_LATENCY_INC diff --git a/assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/xinst/sub.py b/assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/xinst/sub.py new file mode 100644 index 00000000..1d495849 --- /dev/null +++ b/assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/xinst/sub.py @@ -0,0 +1,79 @@ +from .xinstruction import XInstruction + +class Instruction(XInstruction): + """ + Represents a `sub` Instruction. + + This instruction performs element-wise polynomial subtraction. + + For more information, check the specification: + https://github.com/IntelLabs/hec-assembler-tools/blob/master/docsrc/inst_spec/xinst/xinst_sub.md + + Methods: + fromASMISALine: Parses an ASM ISA line to create an Instruction instance. + """ + + @classmethod + def fromASMISALine(cls, line: str) -> list: + """ + Parses an ASM ISA line to create an Instruction instance. + + Args: + line (str): The ASM ISA line to parse. + + Returns: + list: A list containing the parsed Instruction instance. + + Raises: + ValueError: If the line cannot be parsed into the expected format. + """ + retval = None + tokens = XInstruction.tokenizeFromASMISALine(cls.name, line) + if tokens: + tokens, comment = tokens + if len(tokens) < 3 or tokens[2] != cls.name: + raise ValueError('`line`: could not parse f{cls.name} from specified line.') + dst_src_map = XInstruction.parseASMISASourceDestsFromTokens(tokens, cls._OP_NUM_DESTS, cls._OP_NUM_SOURCES, 3) + retval = cls(int(tokens[0][1:]), # Bundle + int(tokens[1]), # Pisa + dst_src_map['dst'], + dst_src_map['src'], + cls._OP_DEFAULT_THROUGHPUT, + cls._OP_DEFAULT_LATENCY, + tokens[3 + cls._OP_NUM_DESTS + cls._OP_NUM_SOURCES:], + comment) + return retval + + @classmethod + def _get_name(cls) -> str: + """ + Gets the name of the instruction. + + Returns: + str: The name of the instruction, which is "sub". + """ + return "sub" + + def __init__(self, + bundle: int, + pisa_instr: int, + dsts: list, + srcs: list, + throughput: int, + latency: int, + other: list = [], + comment: str = ""): + """ + Initializes an Instruction instance. + + Args: + bundle (int): The bundle identifier. + pisa_instr (int): The PISA instruction identifier. + dsts (list): The list of destination operands. + srcs (list): The list of source operands. + throughput (int): The throughput of the instruction. + latency (int): The latency of the instruction. + other (list, optional): Additional parameters. Defaults to an empty list. + comment (str, optional): A comment associated with the instruction. Defaults to an empty string. + """ + super().__init__(bundle, pisa_instr, dsts, srcs, throughput, latency, other, comment) \ No newline at end of file diff --git a/assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/xinst/twintt.py b/assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/xinst/twintt.py new file mode 100644 index 00000000..3334a584 --- /dev/null +++ b/assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/xinst/twintt.py @@ -0,0 +1,79 @@ +from .xinstruction import XInstruction + +class Instruction(XInstruction): + """ + Represents a `twintt` Instruction. + + This instruction performs on-die generation of twiddle factors for the next stage of iNTT. + + For more information, check the specification: + https://github.com/IntelLabs/hec-assembler-tools/blob/master/docsrc/inst_spec/xinst/xinst_twintt.md + + Methods: + fromASMISALine: Parses an ASM ISA line to create an Instruction instance. + """ + + @classmethod + def fromASMISALine(cls, line: str) -> list: + """ + Parses an ASM ISA line to create an Instruction instance. + + Args: + line (str): The ASM ISA line to parse. + + Returns: + list: A list containing the parsed Instruction instance. + + Raises: + ValueError: If the line cannot be parsed into the expected format. + """ + retval = None + tokens = XInstruction.tokenizeFromASMISALine(cls.name, line) + if tokens: + tokens, comment = tokens + if len(tokens) < 3 or tokens[2] != cls.name: + raise ValueError('`line`: could not parse f{cls.name} from specified line.') + dst_src_map = XInstruction.parseASMISASourceDestsFromTokens(tokens, cls._OP_NUM_DESTS, cls._OP_NUM_SOURCES, 3) + retval = cls(int(tokens[0][1:]), # Bundle + int(tokens[1]), # Pisa + dst_src_map['dst'], + dst_src_map['src'], + cls._OP_DEFAULT_THROUGHPUT, + cls._OP_DEFAULT_LATENCY, + tokens[3 + cls._OP_NUM_DESTS + cls._OP_NUM_SOURCES:], + comment) + return retval + + @classmethod + def _get_name(cls) -> str: + """ + Gets the name of the instruction. + + Returns: + str: The name of the instruction, which is "twntt". + """ + return "twntt" + + def __init__(self, + bundle: int, + pisa_instr: int, + dsts: list, + srcs: list, + throughput: int, + latency: int, + other: list = [], + comment: str = ""): + """ + Initializes an Instruction instance. + + Args: + bundle (int): The bundle identifier. + pisa_instr (int): The PISA instruction identifier. + dsts (list): The list of destination operands. + srcs (list): The list of source operands. + throughput (int): The throughput of the instruction. + latency (int): The latency of the instruction. + other (list, optional): Additional parameters. Defaults to an empty list. + comment (str, optional): A comment associated with the instruction. Defaults to an empty string. + """ + super().__init__(bundle, pisa_instr, dsts, srcs, throughput, latency, other, comment) \ No newline at end of file diff --git a/assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/xinst/twntt.py b/assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/xinst/twntt.py new file mode 100644 index 00000000..175db3b4 --- /dev/null +++ b/assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/xinst/twntt.py @@ -0,0 +1,79 @@ +from .xinstruction import XInstruction + +class Instruction(XInstruction): + """ + Represents a `twntt` Instruction. + + This instruction performs on-die generation of twiddle factors for the next stage of NTT. + + For more information, check the specification: + https://github.com/IntelLabs/hec-assembler-tools/blob/master/docsrc/inst_spec/xinst/xinst_twntt.md + + Methods: + fromASMISALine: Parses an ASM ISA line to create an Instruction instance. + """ + + @classmethod + def fromASMISALine(cls, line: str) -> list: + """ + Parses an ASM ISA line to create an Instruction instance. + + Args: + line (str): The ASM ISA line to parse. + + Returns: + list: A list containing the parsed Instruction instance. + + Raises: + ValueError: If the line cannot be parsed into the expected format. + """ + retval = None + tokens = XInstruction.tokenizeFromASMISALine(cls.name, line) + if tokens: + tokens, comment = tokens + if len(tokens) < 3 or tokens[2] != cls.name: + raise ValueError('`line`: could not parse f{cls.name} from specified line.') + dst_src_map = XInstruction.parseASMISASourceDestsFromTokens(tokens, cls._OP_NUM_DESTS, cls._OP_NUM_SOURCES, 3) + retval = cls(int(tokens[0][1:]), # Bundle + int(tokens[1]), # Pisa + dst_src_map['dst'], + dst_src_map['src'], + cls._OP_DEFAULT_THROUGHPUT, + cls._OP_DEFAULT_LATENCY, + tokens[3 + cls._OP_NUM_DESTS + cls._OP_NUM_SOURCES:], + comment) + return retval + + @classmethod + def _get_name(cls) -> str: + """ + Gets the name of the instruction. + + Returns: + str: The name of the instruction, which is "twintt". + """ + return "twintt" + + def __init__(self, + bundle: int, + pisa_instr: int, + dsts: list, + srcs: list, + throughput: int, + latency: int, + other: list = [], + comment: str = ""): + """ + Initializes an Instruction instance. + + Args: + bundle (int): The bundle identifier. + pisa_instr (int): The PISA instruction identifier. + dsts (list): The list of destination operands. + srcs (list): The list of source operands. + throughput (int): The throughput of the instruction. + latency (int): The latency of the instruction. + other (list, optional): Additional parameters. Defaults to an empty list. + comment (str, optional): A comment associated with the instruction. Defaults to an empty string. + """ + super().__init__(bundle, pisa_instr, dsts, srcs, throughput, latency, other, comment) \ No newline at end of file diff --git a/assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/xinst/xinstruction.py b/assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/xinst/xinstruction.py new file mode 100644 index 00000000..d5a47482 --- /dev/null +++ b/assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/xinst/xinstruction.py @@ -0,0 +1,178 @@ +import re +from assembler.common.decorators import * +from assembler.instructions import tokenizeFromLine + +class XInstruction: + + # To be initialized from ASM ISA spec + _OP_NUM_DESTS : int + _OP_NUM_SOURCES : int + _OP_DEFAULT_THROUGHPUT : int + _OP_DEFAULT_LATENCY : int + + @classmethod + def SetNumTokens(cls, val): + pass + + @classmethod + def SetNumDests(cls, val): + cls._OP_NUM_DESTS = val + + @classmethod + def SetNumSources(cls, val): + cls._OP_NUM_SOURCES = val + + @classmethod + def SetDefaultThroughput(cls, val): + cls._OP_DEFAULT_THROUGHPUT = val + + @classmethod + def SetDefaultLatency(cls, val): + cls._OP_DEFAULT_LATENCY = val + + # Static methods + # -------------- + @staticmethod + def tokenizeFromASMISALine(op_name: str, line: str) -> list: + """ + Checks if the specified instruction can be parsed from the specified + line and, if so, return the tokenized line. + + Parameters: + op_name (str): Name of operation that should be contained in the line. + line (str): Line to tokenize. + + Returns: + tuple: A tuple containing tokens (tuple of strings) and comment (str). + None if instruction cannot be parsed from the line. + """ + retval = None + tokens, comment = tokenizeFromLine(line) + if len(tokens) > 2 and tokens[2] == op_name: + retval = (tokens, comment) + return retval + + @staticmethod + def parseASMISASourceDestsFromTokens(tokens: list, num_dests: int, num_sources: int, offset: int = 0) -> dict: + """ + Parses the sources and destinations for an instruction, given sources and + destinations in tokens in P-ISA format. + + Parameters: + tokens (list): List of string tokens where each token corresponds to a destination or + a source for the instruction being parsed, in order. + num_dests (int): Number of destinations for the instruction. + num_sources (int): Number of sources for the instruction. + offset (int): Offset in the list of tokens where to start parsing. + + Returns: + dict: A dictionary with, at most, two keys: "src" and "dst", representing the parsed sources + and destinations for the instruction. The value for each key is a list of parsed + registers, where a register is of the form tuple(register: int, bank: int). + + Raises: + ValueError: If an invalid register name is encountered. + """ + retval = {} + dst_start = offset + dst_end = dst_start + num_dests + dst = [] + for dst_tokens in tokens[dst_start:dst_end]: + if not re.search("r[0-9]+b[0-3]", dst_tokens): + raise ValueError(f'Invalid register name: `{dst_tokens}`.') + # Parse rXXbXX into a tuple of the form (reg, bank) + tmp = dst_tokens[1:] + reg = tuple(map(int, tmp.split('b'))) + dst.append(reg) + src_start = dst_end + src_end = src_start + num_sources + src = [] + for src_tokens in tokens[src_start:src_end]: + if not re.search("r[0-9]+b[0-3]", src_tokens): + raise ValueError(f'Invalid register name: `{src_tokens}`.') + # Parse rXXbXX into a tuple of the form (reg, bank) + tmp = src_tokens[1:] + reg = tuple(map(int, tmp.split('b'))) + src.append(reg) + if dst: + retval["dst"] = dst + if src: + retval["src"] = src + return retval + + @classproperty + def name(cls) -> str: + """ + Gets the name for the instruction. + + Returns: + str: The name of the instruction. + """ + return cls._get_name() + + @classmethod + def _get_name(cls) -> str: + """ + Gets the name for the instruction. + + Raises: + NotImplementedError: If the method is not implemented in a derived class. + """ + raise NotImplementedError('Abstract base') + + # Constructor + # ----------- + + def __init__(self, + bundle: int, + pisa_instr: int, + dsts: list, + srcs: list, + throughput: int, + latency: int, + other: list = [], + comment: str = ""): + """ + Initializes an XInstruction object. + + Parameters: + bundle (int): The bundle number. + pisa_instr (int): The PISA instruction number. + dsts (list): List of destination registers. + srcs (list): List of source registers. + throughput (int): The throughput of the instruction. + latency (int): The latency of the instruction. + other (list): Additional parameters for the instruction. + comment (str): Optional comment for the instruction. + """ + self.bundle = bundle + self.pisa_instr = pisa_instr + self.srcs = srcs + self.dsts = dsts + self.throughput = throughput + self.latency = latency + self.other = other + self.comment = comment + + def __str__(self): + """ + Gets the string representation of the XInstruction. + + Returns: + str: The string representation of the instruction. + """ + retval = "f{}, {}, {}".format(self.bundle, + self.pisa_instr, + self.name) + if self.dsts: + dsts = ['r{}b{}'.format(r, b) for r, b in self.dsts] + retval += ', {}'.format(', '.join(dsts)) + if self.srcs: + srcs = ['r{}b{}'.format(r, b) for r, b in self.srcs] + retval += ', {}'.format(', '.join(srcs)) + if self.other: + retval += ', {}'.format(', '.join(self.other)) + if self.comment: + retval += f' # {self.comment}' + + return retval \ No newline at end of file diff --git a/assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/xinst/xstore.py b/assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/xinst/xstore.py new file mode 100644 index 00000000..79a8d697 --- /dev/null +++ b/assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/xinst/xstore.py @@ -0,0 +1,83 @@ +from .xinstruction import XInstruction + +class Instruction(XInstruction): + """ + Represents a `xstore` Instruction. + + This instruction transfers data from a register into the intermediate data buffer for subsequent transfer into SPAD. + + For more information, check the specification: + https://github.com/IntelLabs/hec-assembler-tools/blob/master/docsrc/inst_spec/xinst/xinst_xstore.md + + Methods: + fromASMISALine: Parses an ASM ISA line to create an Instruction instance. + """ + + @classmethod + def SetNumDests(cls, val): + cls._OP_NUM_DESTS = 0 + + @classmethod + def fromASMISALine(cls, line: str) -> list: + """ + Parses an ASM ISA line to create an Instruction instance. + + Args: + line (str): The ASM ISA line to parse. + + Returns: + list: A list containing the parsed Instruction instance. + + Raises: + ValueError: If the line cannot be parsed into the expected format. + """ + retval = None + tokens = XInstruction.tokenizeFromASMISALine(cls.name, line) + if tokens: + tokens, comment = tokens + if len(tokens) < 3 or tokens[2] != cls.name: + raise ValueError('`line`: could not parse f{cls.name} from specified line.') + dst_src_map = XInstruction.parseASMISASourceDestsFromTokens(tokens, cls._OP_NUM_DESTS, cls._OP_NUM_SOURCES, 3) + retval = cls(int(tokens[0][1:]), # bundle + int(tokens[1]), # pisa + [], + dst_src_map['src'], + cls._OP_DEFAULT_THROUGHPUT, + cls._OP_DEFAULT_LATENCY, + tokens[3 + cls._OP_NUM_DESTS + cls._OP_NUM_SOURCES:], + comment) + return retval + + @classmethod + def _get_name(cls) -> str: + """ + Gets the name of the instruction. + + Returns: + str: The name of the instruction, which is "xstore". + """ + return "xstore" + + def __init__(self, + bundle: int, + pisa_instr: int, + dsts: list, + srcs: list, + throughput: int, + latency: int, + other: list = [], + comment: str = ""): + """ + Initializes an Instruction instance. + + Args: + bundle (int): The bundle identifier. + pisa_instr (int): The PISA instruction identifier. + dsts (list): The list of destination operands. + srcs (list): The list of source operands. + throughput (int): The throughput of the instruction. + latency (int): The latency of the instruction. + other (list, optional): Additional parameters. Defaults to an empty list. + comment (str, optional): A comment associated with the instruction. Defaults to an empty string. + """ + super().__init__(bundle, pisa_instr, dsts, srcs, throughput, latency, other, comment) \ No newline at end of file diff --git a/assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/xtiming_check.py b/assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/xtiming_check.py new file mode 100644 index 00000000..73ce674c --- /dev/null +++ b/assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/xtiming_check.py @@ -0,0 +1,371 @@ +import argparse +import os + +import xinst +from spec_config import XTC_SpecConfig + +# Checks timing for register access. +# - Checks if a register is being read from before its write completes. +# - Checks if rshuffles are within correct timing of each other. +# - Checks for bank write conflicts between rshuffles and other instructions. + +NUM_BUNDLE_INSTRUCTIONS = 64 + +def makeUniquePath(path: str): + """ + Normalizes and expand a given file path. + + Parameters: + path (str): The file path to normalize and expand. + + Returns: + str: The normalized and expanded file path. + """ + return os.path.normcase(os.path.realpath(os.path.expanduser(path))) + +def computeXBundleLatency(xinstr_bundle: list) -> int: + """ + Computes the latency of a bundle of XInstructions. + + Parameters: + xinstr_bundle (list): A list of XInstructions in a bundle. + + Returns: + int: The computed latency of the bundle. + + Raises: + RuntimeError: If the bundle size is invalid. + """ + if len(xinstr_bundle) != NUM_BUNDLE_INSTRUCTIONS: + raise RuntimeError('Invalid bundle size for bundle. Expected {} instructions, but {} found.'.format(bundle_id, + NUM_BUNDLE_INSTRUCTIONS, + len(xinstrs[idx:]))) + current_bundle_cycle_count = 0 # Tracks number of cycles since last sync point (bundle start is the first sync point) + current_bundle_latency = 0 + for xinstr in xinstr_bundle: + if isinstance(xinstr, xinst.XStore): + # Reset the latency because `xstore`s are synchronization points + current_bundle_cycle_count = 0 + current_bundle_latency = 0 + else: + assert xinstr.throughput <= xinstr.latency + # Check if latency of new instruction at current cycle is greater than previous bundle latency + if current_bundle_latency < current_bundle_cycle_count + xinstr.latency: + current_bundle_latency = current_bundle_cycle_count + xinstr.latency + # Advance cycle by instruction throughput + current_bundle_cycle_count += xinstr.throughput + + if isinstance(xinstr, xinst.Exit): + break # Stop on exit + + return current_bundle_latency + +def computeXBundleLatencies(xinstrs: list) -> list: + """ + Computes latencies for all bundles of XInstructions. + + Parameters: + xinstrs (list): A list of XInstructions. + + Returns: + list: A list of latencies for each bundle. + """ + print('WARNING: Check latency for `exit` XInstruction.') + print('Computing x bundle latencies') + retval = [] + total_xinstr = len(xinstrs) + bundle_id = 0 + while xinstrs: + if bundle_id % 1000 == 0: + print("{}% - {}/{}".format((total_xinstr - len(xinstrs)) * 100 // total_xinstr, (total_xinstr - len(xinstrs)), total_xinstr)) + bundle = xinstrs[:NUM_BUNDLE_INSTRUCTIONS] + xinstrs = xinstrs[NUM_BUNDLE_INSTRUCTIONS:] + assert bundle[0].bundle == bundle_id and bundle[-1].bundle == bundle_id + retval.append(computeXBundleLatency(bundle)) + bundle_id += 1 + + print("100% - {0}/{0}".format(total_xinstr)) + + return retval + +def computeCBundleLatencies(cinstr_lines) -> list: + """ + Computes latencies for all bundles of CInstructions. + + Parameters: + cinstr_lines: An iterable of CInstruction lines. + + Returns: + list: A list of latencies for each bundle. + """ + print('Computing c bundle latencies') + retval = [] + bundle_id = 0 + bundle_latency = 0 + for idx, c_line in enumerate(cinstr_lines): + if idx % 500 == 0: + print(idx) + + if c_line.strip(): + # remove comment and tokenize + s_split = [s.strip() for s in c_line.split("#")[0].split(',')] + if bundle_id < 0 and ('cnop' not in s_split[1]): + raise RuntimeError('Invalid CInstruction detected after end of CInstQ') + if 'ifetch' == s_split[1]: + # New bundle + assert int(s_split[2]) == bundle_id, f'ifetch, {s_split[2]} | expected {bundle_id}' + retval.append(bundle_latency) + bundle_id += 1 + bundle_latency = 0 + elif 'exit' in s_split[1]: + # CInstQ terminate + retval.append(bundle_latency) + bundle_id = -1 # Will assert if more instructions after exit + elif 'cstore' == s_split[1]: + # Reset latency + bundle_latency = 0 + else: + instruction_throughput = 1 + if 'nop' in s_split[1]: + instruction_throughput = int(s_split[2]) + elif 'cload' in s_split[1]: + instruction_throughput = 4 + elif 'nload' in s_split[1]: + instruction_throughput = 4 + bundle_latency += instruction_throughput + return retval[1:] + +def main(input_dir: str, input_prefix: str = None): + """ + Main function to check timing for register access and synchronization. + + Parameters: + input_dir (str): Directory containing input files. + input_prefix (str): Prefix for input files. + """ + print("Starting") + + input_dir = makeUniquePath(input_dir) + if not input_prefix: + input_prefix = os.path.basename(input_dir) + + print('Input dir:', input_dir) + print('Input prefix:', input_prefix) + + xinst_file = os.path.join(input_dir, input_prefix + ".xinst") + cinst_file = os.path.join(input_dir, input_prefix + ".cinst") + + xinstrs = [] + with open(xinst_file, 'r') as f_in: + for idx, line in enumerate(f_in): + if idx % 50000 == 0: + print(idx) + if line.strip(): + # Remove comment + s_split = line.split("#")[0].split(',') + # Parse the line into an instruction + instr_name = s_split[2].strip() + b_parsed = False + for xinstr_type in xinst.ASMISA_INSTRUCTIONS: + if xinstr_type.name == instr_name: + xinstr = xinstr_type.fromASMISALine(line) + xinstrs.append(xinstr) + b_parsed = True + break + if not b_parsed: + raise ValueError(f'Could not parse line f{idx + 1}: {line}') + + # Check synchronization between C and X queues + print("--------------") + print("Checking synchronization between C and X queues...") + xbundle_cycles = computeXBundleLatencies(xinstrs) + with open(cinst_file, 'r') as f_in: + cbundle_cycles = computeCBundleLatencies(f_in) + + if len(xbundle_cycles) != len(cbundle_cycles): + raise RuntimeError('Mismatched bundles: {} xbundles vs. {} cbundles'.format(len(xbundle_cycles), + len(cbundle_cycles))) + print("Comparing latencies...") + bundle_cycles_violation_list = [] + for idx in range(len(xbundle_cycles)): + if xbundle_cycles[idx] > cbundle_cycles[idx]: + bundle_cycles_violation_list.append('Bundle {} | X {} cycles; C {} cycles'.format(idx, + xbundle_cycles[idx], + cbundle_cycles[idx])) + + # Check timings for register access + print("--------------") + print("Checking timings for register access...") + violation_lst = [] # list(tuple(xinstr_idx, violating_idx, register: str, cycle_counter)) + for idx, xinstr in enumerate(xinstrs): + if idx % 50000 == 0: + print("{}% - {}/{}".format(idx * 100 // len(xinstrs), idx, len(xinstrs))) + + # Check bank conflict + + banks = set() + for r, b in xinstr.srcs: + if b in banks: + violation_lst.append((idx + 1, f"Bank conflict source {b}", xinstr.name)) + break + banks.add(b) + + banks = set() + for r, b in xinstr.dsts: + if b in banks: + violation_lst.append((idx + 1, f"Bank conflict dests {b}", xinstr.name)) + break + banks.add(b) + + if xinstr.name == 'move': + # Make sure move is only moving from bank zero + src_bank = xinstr.srcs[0][1] + dst_bank = xinstr.dsts[0][1] + if src_bank != 0: + violation_lst.append((idx + 1, f"Move bank error sources {src_bank}", xinstr.name)) + if dst_bank == src_bank: + violation_lst.append((idx + 1, f"Move bank error dests {dst_bank}", xinstr.name)) + + # Check timing + + cycle_counter = xinstr.throughput + for jdx in range(idx + 1, len(xinstrs)): + if cycle_counter >= xinstr.latency: + break # Instruction outputs are ready + next_xinstr = xinstrs[jdx] + if next_xinstr.bundle != xinstr.bundle: + assert(next_xinstr.bundle == xinstr.bundle + 1) + break # Different bundle + + # Check + all_next_regs = set(next_xinstr.srcs + next_xinstr.dsts) + for reg in xinstr.dsts: + if reg in all_next_regs: + # Register is not ready and still used by an instruction + violation_lst.append((idx + 1, jdx + 1, f"r{reg[0]}b{reg[1]}", cycle_counter)) + + cycle_counter += next_xinstr.throughput + + print("100% - {}/{}".format(idx, len(xinstrs))) + + # Check rshuffle separation + print("--------------") + print("Checking rshuffle separation...") + rshuffle_violation_lst = [] # list(tuple(xinstr_idx, violating_idx, data_types: str, cycle_counter)) + print("WARNING: No distinction between `rshuffle` and `irshuffle`.") + for idx, xinstr in enumerate(xinstrs): + if idx % 50000 == 0: + print("{}% - {}/{}".format(idx * 100 // len(xinstrs), idx, len(xinstrs))) + + if isinstance(xinstr, xinst.rShuffle): + cycle_counter = xinstr.throughput + for jdx in range(idx + 1, len(xinstrs)): + if cycle_counter >= xinstr.latency: + break # Instruction outputs are ready + next_xinstr = xinstrs[jdx] + if next_xinstr.bundle != xinstr.bundle: + assert(next_xinstr.bundle == xinstr.bundle + 1) + break # Different bundle + + # Check + if isinstance(next_xinstr, xinst.rShuffle): + if next_xinstr.data_type != xinstr.data_type: + # Mixing ntt and intt rshuffle inside the latency of first rshuffle + rshuffle_violation_lst.append((idx + 1, jdx + 1, f"{xinstr.data_type} != {next_xinstr.data_type}", cycle_counter)) + elif cycle_counter < xinstr.special_latency_max \ + and cycle_counter % xinstr.special_latency_increment != 0: + # Same data type + rshuffle_violation_lst.append((idx + 1, jdx + 1, f"{xinstr.data_type} == {next_xinstr.data_type}", cycle_counter)) + + cycle_counter += next_xinstr.throughput + + print("100% - {}/{}".format(idx, len(xinstrs))) + + # Check bank conflicts with rshuffle + print("--------------") + print("Checking bank conflicts with rshuffle...") + rshuffle_bank_violation_lst = [] # list(tuple(xinstr_idx, violating_idx, banks: str, cycle_counter)) + for idx, xinstr in enumerate(xinstrs): + if idx % 50000 == 0: + print("{}% - {}/{}".format(idx * 100 // len(xinstrs), idx, len(xinstrs))) + + if isinstance(xinstr, xinst.rShuffle): + # No instruction should write to same bank at the write phase of rshuffle + rshuffle_write_cycle = xinstr.latency - 1 + rshuffle_banks = set(bank for _, bank in xinstr.dsts) + cycle_counter = xinstr.throughput + for jdx in range(idx + 1, len(xinstrs)): + if cycle_counter >= xinstr.latency: + break # Instruction outputs are ready + next_xinstr = xinstrs[jdx] + if next_xinstr.bundle != xinstr.bundle: + assert(next_xinstr.bundle == xinstr.bundle + 1) + break # Different bundle + # Check + if cycle_counter + next_xinstr.latency - 1 == rshuffle_write_cycle: + # Instruction writes in same cycle as rshuffle + # Check for bank conflicts + next_xinstr_banks = set(bank for _, bank in next_xinstr.dsts) + if rshuffle_banks & next_xinstr_banks: + rshuffle_bank_violation_lst.append((idx + 1, jdx + 1, "{} | banks: {}".format(next_xinstr.name, rshuffle_banks & next_xinstr_banks), cycle_counter)) + + cycle_counter += next_xinstr.throughput + + print("100% - {}/{}".format(idx, len(xinstrs))) + + s_error_msgs = [] + + if bundle_cycles_violation_list: + # Log violation list + print() + for x in bundle_cycles_violation_list: + print(x) + s_error_msgs.append('Bundle cycle violations detected.') + + if violation_lst: + # Log violation list + print() + for x in violation_lst: + print(x) + s_error_msgs.append('Register access violations detected.') + + if rshuffle_violation_lst: + # Log violation list + print() + for x in rshuffle_violation_lst: + print(x) + s_error_msgs.append('rShuffle special latency violations detected.') + + if rshuffle_bank_violation_lst: + # Log violation list + print() + for x in rshuffle_bank_violation_lst: + print(x) + s_error_msgs.append('rShuffle bank access violations detected.') + + if s_error_msgs: + raise RuntimeError('\n'.join(s_error_msgs)) + + print() + print('No timing errors found.') + +if __name__ == "__main__": + module_dir = os.path.dirname(__file__) + module_name = os.path.basename(__file__) + print(module_name) + + parser = argparse.ArgumentParser() + parser.add_argument("input_dir") + parser.add_argument("input_prefix", nargs="?") + parser.add_argument("--isa_spec", default="", dest="isa_spec_file", + help=("Input ISA specification (.json) file.")) + args = parser.parse_args() + + args.isa_spec_file = XTC_SpecConfig.initialize_isa_spec(module_dir, args.isa_spec_file) + print(f"ISA Spec: {args.isa_spec_file}") + + print() + main(args.input_dir, args.input_prefix) + + print() + print(module_name, "- Complete") \ No newline at end of file diff --git a/assembler_tools/hec-assembler-tools/docsrc/changelog.md b/assembler_tools/hec-assembler-tools/docsrc/changelog.md new file mode 100644 index 00000000..a9674c15 --- /dev/null +++ b/assembler_tools/hec-assembler-tools/docsrc/changelog.md @@ -0,0 +1,28 @@ +# Changelog + +### 2024-07-11 +- Updated SPAD capacity to reflect change from 64MB to 48MB. +- Updated the range of values for `cnop` parameters. +- Updated `rshuffle` to reflect slotting rules for 4KCE. + +### 2024-01-23 +- Updated `rshuffle` to reflect latency changes and rules for 4KCE. +- Updated `nload` to reflect the lack of support for multiple routing tables. +- Updated HBM capacity to reflect change from 64GB to 48GB. +- No keygen updates yet because the feature is work in progress. + +### 2023-07-25 +- Updated throughput for CInsts `cload`, and `nload`. +- Updated throughput and latency for `cstore`, and `bload`. +- Updated latency of `xstore`. +- Updated `rshuffle` to reflect the discontinuation of `wait_cyc` parameter. + +### 2023-07-24 +- Added XInstruction `sub`, required by CKKS scheme P-ISA kernels. + +### 2023-06-30 +- Updated latencies of XInstruction `rshuffle` based on Sim0.9 version. + +### 2023-06-12 +- XInstruction `exit` op name is now `bexit` to match the ISA spec, as required by Sim0.9. +- CInstructions `bload` and `bones` format changed to match philosophy of dests before sources. diff --git a/assembler_tools/hec-assembler-tools/docsrc/inst_spec/cinst/cinst_bload.md b/assembler_tools/hec-assembler-tools/docsrc/inst_spec/cinst/cinst_bload.md new file mode 100644 index 00000000..c4012762 --- /dev/null +++ b/assembler_tools/hec-assembler-tools/docsrc/inst_spec/cinst/cinst_bload.md @@ -0,0 +1,46 @@ +# BLOAD - CInst {#bload_CInst} + +Load metadata from scratchpad to register file. + +## Definition + +Loads metadata from scratchpad to special registers in register file. + +| instr | operand 0 | operand 1 | operand 2 | +|-|-|-|-| +| bload | meta_target_idx | spad_src | src_col_num | + +## Timing + +**Throughput**: 1 clock cycle + +**Latency**: 1 clock cycles + +## Details + +| Operand | Type | Description | +|-|-|-| +| `meta_target_idx` | int32 | Metadata register index in the range `[0, 32)`. | +| `spad_src` | spad_addr | SPAD address of metadata variable to load. | +| `src_col_num` | int32 | Block number inside metadata source variable in the range `[0, 4)` (see notes). | + +### Notes + +**Uses SPAD-CE data path**: Yes + +Some operations require metadata to indicate parameters and modes of operation that don't change throughout the execution of a program. Metadata is usually loaded once at the start of the program into special registers in the CE. + +The main use of this type of metadata is to generate twiddle factors. + +The destination registers for this instruction are special metadata registers of size 1/4 word. Each is indexed by parameter `meta_target_idx`. + +The source metadata to load is a word in SPAD addressed by `spad_src`. It needs to be partitioned into 4 blocks of size 1/4 word each to fit into the target registers. Since the smallest addressable unit is the word, `bload` features the parameter `src_col_num` to address the block inside the word as shown in the diagram below. + +``` +word [--------------------------] +block [--0--][--1--][--2--][--3--] +``` + +Metadata sources must be loaded into the destination registers in the order they appear. If there are `N` metadata variables to load, metadata target index, `meta_target_idx`, monotonically increases with every `bload` from `0` to `4 * N - 1 < 32`. + +There are only 32 destination registers in the CE, supporting up to 8 words of metadata. If there are more words of metadata, then metadata needs to be swapped as needed, requiring tracking of which metadata is loaded for twiddle factor generation. diff --git a/assembler_tools/hec-assembler-tools/docsrc/inst_spec/cinst/cinst_bones.md b/assembler_tools/hec-assembler-tools/docsrc/inst_spec/cinst/cinst_bones.md new file mode 100644 index 00000000..4e6b77ce --- /dev/null +++ b/assembler_tools/hec-assembler-tools/docsrc/inst_spec/cinst/cinst_bones.md @@ -0,0 +1,32 @@ +# BONES - CInst {#bones_CInst} + +Load metadata of identity (one) from scratchpad to register file. + +## Definition + +Loads metadata of identity representation for cryptographic operations from scratchpad to special registers in register file. + +| instr | operand 0 | operand 1 | +|-|-|-| +| bones | spad_src | col_num | + +## Timing + +**Throughput**: 1 clock cycle + +**Latency**: 5 clock cycles + +## Details + +| Operand | Type | Description | +|-|-|-| +| `spad_src` | spad_addr | SPAD address of identity metadata to load. | +| `col_num` | int32 | Block to load from source word. Must be `0` | + +### Notes + +**Uses SPAD-CE data path**: Yes + +The metadata block for identity is always `0`. + +Some operations require metadata to indicate parameters and modes of operation that don't change throughout the execution of the program. Metadata is loaded once at the start of the program into special registers in the CE. diff --git a/assembler_tools/hec-assembler-tools/docsrc/inst_spec/cinst/cinst_cexit.md b/assembler_tools/hec-assembler-tools/docsrc/inst_spec/cinst/cinst_cexit.md new file mode 100644 index 00000000..7b4210a5 --- /dev/null +++ b/assembler_tools/hec-assembler-tools/docsrc/inst_spec/cinst/cinst_cexit.md @@ -0,0 +1,25 @@ +# CEXIT - CInst {#exit_CInst} + +Exits a program. + +## Definition + +Terminates execution of a HERACLES program. + +| instr | +|-| +| cexit | + +## Timing + +**Throughput**: `1` clock cycle + +**Latency**: `1` clock cycle + +## Details + +This instruction has no operands. + +### Notes + +**Uses SPAD-CE data path**: No diff --git a/assembler_tools/hec-assembler-tools/docsrc/inst_spec/cinst/cinst_cload.md b/assembler_tools/hec-assembler-tools/docsrc/inst_spec/cinst/cinst_cload.md new file mode 100644 index 00000000..53b3142b --- /dev/null +++ b/assembler_tools/hec-assembler-tools/docsrc/inst_spec/cinst/cinst_cload.md @@ -0,0 +1,39 @@ +# CLOAD - CInst {#cload_CInst} + +Load a single polynomial residue from scratchpad into a register. + +## Definition + +Load a word, corresponding to a single polynomial residue, from scratchpad memory into the register file memory. + +| instr | operand 0 | operand 1 | +|-|-|-| +| cload | dst | src | + +## Timing + +**Throughput**: 4 clock cycles + +**Latency**: 4 clock cycles + +## Details + +| Operand | Type | Description | +|-|-|-| +| `dst` | register | Destination register where to load the word. | +| `src` | spad_addr | SPAD address from where to load. | + +### Notes + +**Uses SPAD-CE data path**: Yes + +The register is ready to be used on the clock cycle after `cload` completes. Using the register during `cload` is undefined. + +Instruction `cload` writes to CE registers and thus, it can conflict with another XInst if both are writing to the same bank and get scheduled such that their write phases happen on the same cycle. Because the rule that there cannot be more than one write to the same bank in the same cycle this conflict must be avoided. + +Two ways to mitigate the above conflict are: + +- Carefully track and sync `cload` instructions with all XInsts. +- Use a convention such that `cload` instructions always write to one bank while all other XInsts cannot write their outputs to that bank. + +We have adopted the second option to implement the assembler. First option is cumbersome and prone to errors: By convention, `cload` should always load into **bank 0**. No XInst can write outputs to **bank 0**. diff --git a/assembler_tools/hec-assembler-tools/docsrc/inst_spec/cinst/cinst_cstore.md b/assembler_tools/hec-assembler-tools/docsrc/inst_spec/cinst/cinst_cstore.md new file mode 100644 index 00000000..bd1f1953 --- /dev/null +++ b/assembler_tools/hec-assembler-tools/docsrc/inst_spec/cinst/cinst_cstore.md @@ -0,0 +1,37 @@ +# CSTORE - CInst {#cstore_CInst} + +Fetch a single polynomial residue from the intermediate data buffer and store back to SPAD. + +## Definition + +Pops the top word from the intermediate data buffer queue and stores it in SPAD. + +| instr | operand 0 | +|-|-| +| cstore | dst | + +## Timing + +**Throughput**: 1 clock cycles* + +**Latency**: 1 clock cycles* + +Variable timing because `cstore` is a blocking instruction. See notes. + +## Details + +| Operand | Type | Description | +|-|-|-| +| `dst` | spad_addr | Destination SPAD address where to store the word. | + +### Notes + +**Uses SPAD-CE data path**: No (see XInst [`xstore`](../xinst/xinst_xstore.md) ). + +This instruction will pop the word at the top of the intermediate buffer queue where Xinst `xstore` pushes data to store from CE registers. + +WARNING: If the intermediate buffer is empty, `cstore` blocks the CINST queue execution until there is data ready to pop. Produced code must ensure that there is data in the intermediate buffer, or that there will be a matching `xstore` in the bundle being executed to avoid a deadlock. + +## See Also + +- CInst [`xstore`](../xinst/xinst_xstore.md) diff --git a/assembler_tools/hec-assembler-tools/docsrc/inst_spec/cinst/cinst_csyncm.md b/assembler_tools/hec-assembler-tools/docsrc/inst_spec/cinst/cinst_csyncm.md new file mode 100644 index 00000000..e7b67fe0 --- /dev/null +++ b/assembler_tools/hec-assembler-tools/docsrc/inst_spec/cinst/cinst_csyncm.md @@ -0,0 +1,31 @@ +# CSYNCM - CInst {#csyncm_CInst} + +CINST execution waits for a particular Mfetch instruction to complete. + +## Definition + +Wait instruction similar to a barrier that stalls the execution of CINST queue until the specified instruction from MINST queue has completed. + +| instr | operand 0 | +|-|-| +| csyncm | inst_num | + +## Timing + +**Throughput**: varies + +**Latency**: Same as throughput. + +## Details + +| Operand | Type | Description | +|-|-|-| +| `inst_num` | int32 | Instruction number from the MINST queue for which to wait. | + +### Notes + +**Uses SPAD-CE data path**: No + +CINST execution resumes with the following instruction in the CINST queue, on the clock cycle after the specified MInst completed. + +Typically used to wait for a value to be loaded from HBM into SPAD. Since load times from HBM vary, assembler cannot assume that all `mload`s from MINST queue will complete in order, thus, every `mload` should have a matching `csyncm`. diff --git a/assembler_tools/hec-assembler-tools/docsrc/inst_spec/cinst/cinst_ifetch.md b/assembler_tools/hec-assembler-tools/docsrc/inst_spec/cinst/cinst_ifetch.md new file mode 100644 index 00000000..a61dc28f --- /dev/null +++ b/assembler_tools/hec-assembler-tools/docsrc/inst_spec/cinst/cinst_ifetch.md @@ -0,0 +1,43 @@ +# IFETCH - CInst {#ifetch_CInst} + +Fetch a bundle of instructions for execution. + +## Definition + +Fetch a bundle of instructions from the XINST queue and send it to the CE for execution. + +| instr | operand 0 | +|-|-| +| ifetch | bundle_idx | + +## Timing + +**Throughput**: 1 clock cycles `*` + +`*` This instruction has the ability to block execution of CINST queue. See details. + +**Latency**: 5 clock cycles + +## Details + +| Operand | Type | Description | +|-|-|-| +| `bundle_idx` | int32 | Index for the bundle of instructions to fetch. The bundle must exist in the current XINST queue. | + +### Notes + +**Uses SPAD-CE data path**: Yes + +A bundle is a collection of a pre-defined number contiguous instructions. The instructions in the bundle must have a minimum throughput of 64 clock cycles (except when there is an `exit` instruction in the bundle). The bundle index to which an instruction belongs is clearly indicated in the encoding of the instruction output by the assembler. + +XINST queue contains the instructions for the CE. This queue contains more than one bundle at a time. Instruction `ifetch` schedules the next bundle (from those bundles in XINST) to execute into the CE. + +It takes `ifetch` 2 cycles to start, and 4 more cycles before the CE is completely loaded with the entry point to the new bundle. + +The loaded bundle starts execution in the clock cycle after `ifetch` completed. + +Calling another `ifetch` while a bundle is executing causes undefined behavior. Thus, a new bundle should be fetched after the last bundle's latency elapses. + +Note that this instruction uses the SPAD-CE data path, so, code produced must ensure that an `ifetch` is not executed when a XInst `xstore` is in flight from the current bundle in-flight. This can be mitigated by a matching `cstore` before `ifetch`. + +`*` XINST queue content is filled out by instruction [`xinstfetch`](cinst_xinstfetch.md) . Instruction `xinstfetch` has a variable latency, and there may be occasions when an `ifetch` is encountered while the referenced bundle is part of code still being loaded by `xinstfetch`. If this is the case, `ifetch` will block until the bundle is ready. diff --git a/assembler_tools/hec-assembler-tools/docsrc/inst_spec/cinst/cinst_nload.md b/assembler_tools/hec-assembler-tools/docsrc/inst_spec/cinst/cinst_nload.md new file mode 100644 index 00000000..6a486369 --- /dev/null +++ b/assembler_tools/hec-assembler-tools/docsrc/inst_spec/cinst/cinst_nload.md @@ -0,0 +1,32 @@ +# NLOAD - CInst {#nload_CInst} + +Loads NTT/iNTT routing mapping data. + +## Definition + +Loads metadata (for NTT/iNTT routing mapping) from scratchpad into a special routing table register. + +| instr | operand 0 | operand 1 | +|-|-|-| +| nload | table_idx_dst | spad_src | + +## Timing + +**Throughput**: 4 clock cycles + +**Latency**: 4 clock cycless + +## Details + +| Operand | Type | Description | +|-|-|-| +| `table_idx_dst` | int32 | Destination routing table. Must be in range `[0, 6)` as there are 6 possible routing tables. | +| `spad_src` | spad_addr | SPAD address of metadata variable to load. | + +### Notes + +**Uses SPAD-CE data path**: Yes + +This instruction loads metadata indicating how will [`rshuffle`](../xinst/xinst_rshuffle.md) instruction shuffle tile pair registers for NTT outputs and iNTT inputs. Shuffling for one of these instructions requires two tables: routing table and auxiliary table. Therefore, the 6 special table registers only support 3 different shuffling operations (NTT, iNTT, and MISC). + +In the current HERACLES implementation, only routing tables `0` and `1` are functional, thus, assembler is able to perform only shuffling instructions for one of NTT or iNTT per bundle, requiring routing table changes whenever the other is needed. diff --git a/assembler_tools/hec-assembler-tools/docsrc/inst_spec/cinst/cinst_nop.md b/assembler_tools/hec-assembler-tools/docsrc/inst_spec/cinst/cinst_nop.md new file mode 100644 index 00000000..197720ed --- /dev/null +++ b/assembler_tools/hec-assembler-tools/docsrc/inst_spec/cinst/cinst_nop.md @@ -0,0 +1,35 @@ +# NOP - CInst {#nop_CInst} + +Adds desired amount of idle cycles in the Cfetch flow. + +## Definition + +Introduces idle cycles in the CINST execution flow. + +| instr | operand 0 | +|-|-| +| cnop | cycles | + +## Timing + +**Throughput**: `1 + cycles` clock cycles + +**Latency**: `1 + cycles` clock cycles + +## Details + +| Operand | Type | Description | +|-|-|-| +| `cycles` | int, 10bits | Number of idle cycles to introduce (see notes). | + +### Notes + +**Uses SPAD-CE data path**: No + +Note that this instruction will cause the compute flow to stall for `1 + cycles` since it takes 1 clock cycle to dispatch the instruction. Therefore, to introduce a single idle cycle, the correct instruction is: + +``` +cnop, 0 +``` + +Parameter `cycles` is encoded into a 10 bits field, and thus, its value must be less than 1024. If more thatn 1024 idle cycles is required, multiple `cnop` instructions must be scheduled back to back. diff --git a/assembler_tools/hec-assembler-tools/docsrc/inst_spec/cinst/cinst_xinstfetch.md b/assembler_tools/hec-assembler-tools/docsrc/inst_spec/cinst/cinst_xinstfetch.md new file mode 100644 index 00000000..e2ebc13f --- /dev/null +++ b/assembler_tools/hec-assembler-tools/docsrc/inst_spec/cinst/cinst_xinstfetch.md @@ -0,0 +1,38 @@ +# XINTFETCH - CInst {#xinstfetch_CInst} + +Fetches instructions from the HBM and sends it to the XINST queue. + +## Definition + +Fetches 1 word (32KB) worth of instructions from the HBM XInst region and sends it to the XINST queue. + +| instr | operand 0 | operand 1 | +|-|-|-| +| xinstfetch | xq_dst | hbm_src | + +## Timing + +**Throughput**: 1 clock cycles + +**Latency**: Varies + +## Details + +| Operand | Type | Description | +|-|-|-| +| `xq_dst` | int32 | Dest in XINST queue. | +| `hbm_src` | hbm_addr | Address where to read instructions from HBM XInst region. | + +### Notes + +**Uses SPAD-CE data path**: No (See CInst [`ifetch`](cinst_ifetch.md)) + +This instruction is special because it moves XInst data from HBM into XINST queue bypassing SPAD because SPAD is cache only for actual data (not instructions). + +Parameter `hbm_src` refers to an HBM address from the XInst region, not the HBM data region. + +The destination `xq_dst` indexes the XINST queue word-wise; this is, the 1MB XINST queue capacity is equivalent to 32 words capacity. Thus this parameter is between `0` and `31` and indicates where to load the XInst inside XINST queue. + +Instruction `xinstfetch` only loads the XInst to execute, but does not initiate their execution. This is initiated by instruction [`ifetch`](cinst_ifetch.md). + +The latency when accessing HBM varies, therefore, there is no specific latency for this instruction. diff --git a/assembler_tools/hec-assembler-tools/docsrc/inst_spec/minst/minst_mload.md b/assembler_tools/hec-assembler-tools/docsrc/inst_spec/minst/minst_mload.md new file mode 100644 index 00000000..12d1ac9d --- /dev/null +++ b/assembler_tools/hec-assembler-tools/docsrc/inst_spec/minst/minst_mload.md @@ -0,0 +1,28 @@ +# MLOAD - MInst {#mload_MInst} + +Load a single polynomial residue from local memory to scratchpad. + +## Definition + +Load a word, corresponding to a single polynomial residue, from HBM data region into the SPAD memory. + +| instr | operand 0 | operand 1 | +|-|-|-| +| mload | dst | src | + +## Timing + +**Throughput**: 1 clock cycles + +**Latency**: varies + +## Details + +| Operand | Type | Description | +|-|-|-| +| `dst` | spad_addr | Destination SPAD address where to load the word. | +| `src` | hbm_addr | HBM data region address from where to load. | + +### Notes + +Latency for read/write times involving HBM vary. CINST queue can use `csyncm` instruction to synchronize with the MINST queue. On the other hand, MINST can use `msyncc` to synchronize with CINST queue. diff --git a/assembler_tools/hec-assembler-tools/docsrc/inst_spec/minst/minst_mstore.md b/assembler_tools/hec-assembler-tools/docsrc/inst_spec/minst/minst_mstore.md new file mode 100644 index 00000000..d152d3bd --- /dev/null +++ b/assembler_tools/hec-assembler-tools/docsrc/inst_spec/minst/minst_mstore.md @@ -0,0 +1,28 @@ +# MSTORE - MInst {#mstore_MInst} + +Store a single polynomial residue from scratchpad to local memory. + +## Definition + +Store a word, corresponding to a single polynomial residue, from SPAD memory into HBM data region. + +| instr | operand 0 | operand 1 | +|-|-|-| +| mstore | dst | src | + +## Timing + +**Throughput**: 1 clock cycles + +**Latency**: varies + +## Details + +| Operand | Type | Description | +|-|-|-| +| `dst` | hbm_addr | Destination HBM data region address where to store the word. | +| `src` | spad_addr | SPAD address of the word to store. | + +### Notes + +Latency for read/write times involving HBM vary. CINST queue can use `csyncm` instruction to synchronize with the MINST queue. On the other hand, MINST can use `msyncc` to synchronize with CINST queue. diff --git a/assembler_tools/hec-assembler-tools/docsrc/inst_spec/minst/minst_msyncc.md b/assembler_tools/hec-assembler-tools/docsrc/inst_spec/minst/minst_msyncc.md new file mode 100644 index 00000000..50393dff --- /dev/null +++ b/assembler_tools/hec-assembler-tools/docsrc/inst_spec/minst/minst_msyncc.md @@ -0,0 +1,29 @@ +# MSYNCC - MInst {#msyncc_MInst} + +MINST execution waits for a particular instruction in CINST queue to complete. + +## Definition + +Wait instruction similar to a barrier that stalls the execution of MINST queue until the specified instruction from CINST queue has completed. + +| instr | operand 0 | +|-|-| +| msyncc | inst_num | + +## Timing + +**Throughput**: varies + +**Latency**: Same as throughput. + +## Details + +| Operand | Type | Description | +|-|-|-| +| `inst_num` | int32 | Instruction number from the QINST queue for which to wait. | + +### Notes + +MINST execution resumes with the following instruction in the MINST queue, on the clock cycle after the specified CInst completed. + +Typically used to wait for a value cached in SPAD to be updated before storing it into HMB. diff --git a/assembler_tools/hec-assembler-tools/docsrc/inst_spec/xinst/xinst_add.md b/assembler_tools/hec-assembler-tools/docsrc/inst_spec/xinst/xinst_add.md new file mode 100644 index 00000000..e3ad1b73 --- /dev/null +++ b/assembler_tools/hec-assembler-tools/docsrc/inst_spec/xinst/xinst_add.md @@ -0,0 +1,30 @@ +# ADD - XInst {#add_XInst} + +Element-wise polynomial addition. + +## Definition + +Add two polynomials stored in the register file and store the result in a register. + +| instr | operand 0 | operand 1 | operand 2 | operand 3 | +|-|-|-|-|-| +| add | dst | src0 | src1 | res | + +## Timing + +**Throughput**: 1 clock cycle + +**Latency**: 6 clock cycles + +## Details + +| Operand | Type | Description | +|-|-|-| +| `dst` | register | Destination register for result of `src0 + src1`. | +| `src0` | register | Source register for first operand. | +| `src1` | register | Source register for second operand. Must be different than `src0`. | +| `res` | int32 | Residue to use for modular reduction | + +### Notes + +**Uses SPAD-CE data path**: No diff --git a/assembler_tools/hec-assembler-tools/docsrc/inst_spec/xinst/xinst_exit.md b/assembler_tools/hec-assembler-tools/docsrc/inst_spec/xinst/xinst_exit.md new file mode 100644 index 00000000..4a070c49 --- /dev/null +++ b/assembler_tools/hec-assembler-tools/docsrc/inst_spec/xinst/xinst_exit.md @@ -0,0 +1,27 @@ +# EXIT - XInst {#exit_XInst} + +Terminates execution of an instruction bundle. + +## Definition + +Terminates execution of an instruction bundle at the point it is encountered. CE will stop execution until a new bundle is enqueued. + +| instr | +|-| +| bexit | + +## Timing + +**Throughput**: `1` clock cycle + +**Latency**: `1` clock cycle + +## Details + +This instruction has no operands. + +### Notes + +**Uses SPAD-CE data path**: No + +As bundles require a minimum of 64 cycles to execute, `exit` is used as a bundle exit instruction to maximize the efficiency of short (incomplete) bundles. diff --git a/assembler_tools/hec-assembler-tools/docsrc/inst_spec/xinst/xinst_intt.md b/assembler_tools/hec-assembler-tools/docsrc/inst_spec/xinst/xinst_intt.md new file mode 100644 index 00000000..b9786e42 --- /dev/null +++ b/assembler_tools/hec-assembler-tools/docsrc/inst_spec/xinst/xinst_intt.md @@ -0,0 +1,37 @@ +# iNTT - XInst {#intt_XInst} + +Inverse Number Theoretic Transform. Convert NTT form to positional form. + +## Definition + +Performs one-stage of inverse NTT. + +| instr | operand 0 | operand 1 | operand 2 | operand 3 | operand 4 | operand 5 | operand 6 | +|-|-|-|-|-|-|-|-| +| intt | dst_top | dest_bot | src_top | src_bot | src_tw | stage | res | + +## Timing + +**Throughput**: 1 clock cycle + +**Latency**: 6 clock cycles + +## Details + +| Operand | Type | Description | +|-|-|-| +| `dst_top` | register | Destination register for top part of the iNTT result. | +| `dst_bot` | register | Destination register for bottom part of the iNTT result. Must be different than `dst_top`. | +| `src_top` | register | Source register for top part of the input NTT. | +| `src_bot` | register | Source register for bottom part of the input NTT. Must be different than `src_top`. | +| `src_tw` | register | Source register for twiddle factors. Must be different than `src_top` and `src_bot`. | +| `stage` | int32 | Stage number of the current iNTT instruction. | +| `res` | int32 | Residue to use for modular reduction. | + +### Notes + +**Uses SPAD-CE data path**: No + +Both NTT and inverse NTT instructions are defined as one-stage of the transformation. A complete NTT/iNTT transformation is composed of LOG_N such one-stage instructions. + +This instruction matches to HERACLES ISA `intt`. It requires a preceeding, matching [`rmove`](xinst_rmove.md) to shuffle the input bits into correct tile-pairs. diff --git a/assembler_tools/hec-assembler-tools/docsrc/inst_spec/xinst/xinst_mac.md b/assembler_tools/hec-assembler-tools/docsrc/inst_spec/xinst/xinst_mac.md new file mode 100644 index 00000000..aecdaa33 --- /dev/null +++ b/assembler_tools/hec-assembler-tools/docsrc/inst_spec/xinst/xinst_mac.md @@ -0,0 +1,31 @@ +# MAC - XInst {#mac_XInst} + +Element-wise polynomial multiplication and accumulation. + +## Definition + +Element-wise multiplication of two polynomials added to a third polynomial, all stored in the register file, and store the result in a register. + +| instr | operand 0 | operand 1 | operand 2 | operand 3 | operand 4 | +|-|-|-|-|-|-| +| mac | dst | src0 | src1 | src2 | res | + +## Timing + +**Throughput**: 1 clock cycle + +**Latency**: 6 clock cycles + +## Details + +| Operand | Type | Description | +|-|-|-| +| `dst` | register | Destination register for result of `src1[i] * src2[i] + src0[i]`. | +| `src0` | register | Source register for first operand. Must be the same as `dst`. | +| `src1` | register | Source register for second operand. Must be different than `src0`. | +| `src2` | register | Source register for third operand. Must be different than `src0` and `src1`. | +| `res` | int32 | Residue to use for modular reduction | + +### Notes + +**Uses SPAD-CE data path**: No diff --git a/assembler_tools/hec-assembler-tools/docsrc/inst_spec/xinst/xinst_maci.md b/assembler_tools/hec-assembler-tools/docsrc/inst_spec/xinst/xinst_maci.md new file mode 100644 index 00000000..fbdad0a7 --- /dev/null +++ b/assembler_tools/hec-assembler-tools/docsrc/inst_spec/xinst/xinst_maci.md @@ -0,0 +1,31 @@ +# MACI - XInst {#maci_XInst} + +Element-wise polynomial scaling by an immediate value and accumulation. + +## Definition + +Scale a polynomial in the register file by an immediate added to a third polynomial stored in a register, and store the result in a register. + +| instr | operand 0 | operand 1 | operand 2 | operand 3 | operand 4 | +|-|-|-|-|-|-| +| maci | dst | src0 | src1 | imm | res | + +## Timing + +**Throughput**: 1 clock cycle + +**Latency**: 6 clock cycles + +## Details + +| Operand | Type | Description | +|-|-|-| +| `dst` | register | Destination register for result of `src1[i] * imm[i] + src0[i]`. | +| `src0` | register | Source register for first operand. Must be same as `dst`. | +| `src1` | register | Source register for second operand. Must be different than `src0`. | +| `imm` | string | Named immediate value. | +| `res` | int32 | Residue to use for modular reduction | + +### Notes + +**Uses SPAD-CE data path**: No diff --git a/assembler_tools/hec-assembler-tools/docsrc/inst_spec/xinst/xinst_move.md b/assembler_tools/hec-assembler-tools/docsrc/inst_spec/xinst/xinst_move.md new file mode 100644 index 00000000..a508e5e0 --- /dev/null +++ b/assembler_tools/hec-assembler-tools/docsrc/inst_spec/xinst/xinst_move.md @@ -0,0 +1,28 @@ +# MOVE - XInst {#move_XInst} + +Copies data from one register to a different one. + +## Definition + +Copies data from a source register into a different destination register. + +| instr | operand 0 | operand 1 | +|-|-|-| +| move | dst | src | + +## Timing + +**Throughput**: 1 clock cycle + +**Latency**: 6 clock cycles + +## Details + +| Operand | Type | Description | +|-|-|-| +| `dst` | register | Destination register. | +| `src` | register | Source register to copy. Must be different than `dst`, but can be in the same bank. | + +### Notes + +**Uses SPAD-CE data path**: No diff --git a/assembler_tools/hec-assembler-tools/docsrc/inst_spec/xinst/xinst_mul.md b/assembler_tools/hec-assembler-tools/docsrc/inst_spec/xinst/xinst_mul.md new file mode 100644 index 00000000..859dc7b2 --- /dev/null +++ b/assembler_tools/hec-assembler-tools/docsrc/inst_spec/xinst/xinst_mul.md @@ -0,0 +1,30 @@ +# MUL - XInst {#mul_XInst} + +Element-wise polynomial multiplication. + +## Definition + +Element-wise multiplication of two polynomials stored in the register file and store the result in a register. + +| instr | operand 0 | operand 1 | operand 2 | operand 3 | +|-|-|-|-|-| +| mul | dst | src0 | src1 | res | + +## Timing + +**Throughput**: 1 clock cycle + +**Latency**: 6 clock cycles + +## Details + +| Operand | Type | Description | +|-|-|-| +| `dst` | register | Destination register for result of `src0[i] * src1[i]`. | +| `src0` | register | Source register for first operand. | +| `src1` | register | Source register for second operand. Must be different than `src0`. | +| `res` | int32 | Residue to use for modular reduction | + +### Notes + +**Uses SPAD-CE data path**: No diff --git a/assembler_tools/hec-assembler-tools/docsrc/inst_spec/xinst/xinst_muli.md b/assembler_tools/hec-assembler-tools/docsrc/inst_spec/xinst/xinst_muli.md new file mode 100644 index 00000000..f348a568 --- /dev/null +++ b/assembler_tools/hec-assembler-tools/docsrc/inst_spec/xinst/xinst_muli.md @@ -0,0 +1,30 @@ +# MULI - XInst {#muli_XInst} + +Element-wise polynomial scaling by an immediate value. + +## Definition + +Scale a polynomial in the register file by an immediate and store the result in a register. + +| instr | operand 0 | operand 1 | operand 2 | operand 3 | +|-|-|-|-|-| +| muli | dst | src0 | imm | res | + +## Timing + +**Throughput**: 1 clock cycle + +**Latency**: 6 clock cycles + +## Details + +| Operand | Type | Description | +|-|-|-| +| `dst` | register | Destination register for result of `src0[i] * imm[i]`. | +| `src0` | register | Source register for first operand. | +| `imm` | string | Named immediate value. | +| `res` | int32 | Residue to use for modular reduction | + +### Notes + +**Uses SPAD-CE data path**: No diff --git a/assembler_tools/hec-assembler-tools/docsrc/inst_spec/xinst/xinst_nop.md b/assembler_tools/hec-assembler-tools/docsrc/inst_spec/xinst/xinst_nop.md new file mode 100644 index 00000000..336b51da --- /dev/null +++ b/assembler_tools/hec-assembler-tools/docsrc/inst_spec/xinst/xinst_nop.md @@ -0,0 +1,33 @@ +# NOP - XInst {#nop_XInst} + +Adds desired amount of idle cycles to the compute flow. + +## Definition + +Introduces idle cycles in the compute engine. + +| instr | operand 0 | +|-|-| +| nop | cycles | + +## Timing + +**Throughput**: `1 + cycles` clock cycles + +**Latency**: `1 + cycles` clock cycles + +## Details + +| Operand | Type | Description | +|-|-|-| +| `cycles` | int32 | Number of idle cycles to introduce (see notes). | + +### Notes + +**Uses SPAD-CE data path**: No + +Note that this instruction will cause the compute flow to stall for `1 + cycles` since it takes 1 clock cycle to dispatch the instruction. Therefore, to introduce a single idle cycle, the correct instruction is: + +``` +nop, 0 +``` diff --git a/assembler_tools/hec-assembler-tools/docsrc/inst_spec/xinst/xinst_ntt.md b/assembler_tools/hec-assembler-tools/docsrc/inst_spec/xinst/xinst_ntt.md new file mode 100644 index 00000000..3fccca29 --- /dev/null +++ b/assembler_tools/hec-assembler-tools/docsrc/inst_spec/xinst/xinst_ntt.md @@ -0,0 +1,37 @@ +# NTT - XInst {#ntt_XInst} + +Number Theoretic Transform. Convert positional form to NTT form. + +## Definition + +Performs one-stage of NTT on an input positional polynomial. + +| instr | operand 0 | operand 1 | operand 2 | operand 3 | operand 4 | operand 5 | operand 6 | +|-|-|-|-|-|-|-|-| +| ntt | dst_top | dest_bot | src_top | src_bot | src_tw | stage | res | + +## Timing + +**Throughput**: 1 clock cycle + +**Latency**: 6 clock cycles + +## Details + +| Operand | Type | Description | +|-|-|-| +| `dst_top` | register | Destination register for top part of the NTT result. | +| `dst_bot` | register | Destination register for bottom part of the NTT result. Must be different than `dst_top`. | +| `src_top` | register | Source register for top part of the input polynomial. | +| `src_bot` | register | Source register for bottom part of the input polynomial. Must be different than `src_top`. | +| `src_tw` | register | Source register for original twiddle factors. Must be different than `src_top` and `src_bot`. | +| `stage` | int32 | Stage number of the current NTT instruction. | +| `res` | int32 | Residue to use for modular reduction. | + +### Notes + +**Uses SPAD-CE data path**: No + +Both NTT and inverse NTT instructions are defined as one-stage of the transformation. A complete NTT/iNTT transformation is composed of LOG_N such one-stage instructions. + +This instruction matches to HERACLES ISA `ntt` with `store_local` bit set. i.e. it requires a subsequent, matching [`rmove`](xinst_rmove.md) to shuffle the output bits into correct tile-pairs. diff --git a/assembler_tools/hec-assembler-tools/docsrc/inst_spec/xinst/xinst_rshuffle.md b/assembler_tools/hec-assembler-tools/docsrc/inst_spec/xinst/xinst_rshuffle.md new file mode 100644 index 00000000..a0bc2cbc --- /dev/null +++ b/assembler_tools/hec-assembler-tools/docsrc/inst_spec/xinst/xinst_rshuffle.md @@ -0,0 +1,76 @@ +# RSHUFFLE - XInst {#rshuffle_XInst} + +Shuffles/routes NTT/iNTT outputs/inputs across tile pairs. + +## Definition + +Shuffles two register locations among the tile pairs, based on destinations defined by currently loaded routing table metadata. + +| instr | operand 0 | operand 1 | operand 2 | operand 3 | operand 4 | operand 5 | +|-|-|-|-|-|-|-| +| rshuffle | dst0 | dst1 | src0 | src1 | wait_cyc | data_type | + +## Timing + +**Throughput**: 1 clock cycle + +**Latency**: 23 clock cycles + +**Special Latency**: + +| data_type | latency | +|-|-| +| `ntt` | 5*`n` or 17 clock cycles | +| `intt` | 5*`n` or 17 clock cycles | + +### Restrictions + +Hardware resources are tied up for each `rshuffle`. Behavior when attempting to execute another `rshuffle` in the pipeline before the corresponding latency elapses is undefined. + +- Two `rshuffle` instructions with the same `data_type` cannot overlap on *special latency*. This is: `rshuffle` instructions must be scheduled at multiples of 5 clock cycle intervals (up to 15 cycles) or at 17 or greater clock cycles from each other to avoid resource contention. +- Two `rshuffle` instructions with different `data_type` cannot execute in the same bundle as there is only one routing table metadata available. + +Any other XInstruction can overlap with `rshuffle`. However, no XInstruction that writes to the same banks as an `rshuffle` can end in the same cycle of said `rshuffle`. + +Instruction `rshuffle` will overwrite the contents of its output registers on its last cycle of latency. Values written to these same registers by any other XInstruction before `rshuffle` completes, will be overwritten at this time. + +## Details + +| Operand | Type | Description | +|-|-|-| +| `dst0` | register | Destination register where to shuffle `src0`. | +| `dst1` | register | Destination register where to shuffle `src1`. Must be different than `dst0`. | +| `src0` | register | Source register to shuffle. | +| `src1` | register | Source register to shuffle. Must be different than `src0`. | +| `wait_cyc` | int32 | Not used. Set to `0`. | +| `data_type` | string | One of [`ntt`, `intt`] to indicate the type of shuffling depending on the corresponding matching operation. See notes. | + +### Notes + +**Uses SPAD-CE data path**: No + +This instruction is mostly intended to shuffle values in registers among tile-pairs (akin to bit shuffling). Due to the nature of NTT and iNTT butterflies, their respective outputs and inputs are shuffled among tile-pairs by the mathematical operations. This instruction is intended to re-shuffle the outputs from an NTT or inputs for an iNTT into the correct bit order. + +The typical usage to shuffle the tile-pairs of NTT outputs is as such: + +```csv +ntt, dst_top, dst_bot, src_top, src_bot, src_tw, stage, res # SL=1 +nop 4 # dependency latency for ntt results +rshuffle, dst_top, dst_bot, dst_top, dst_bot, 0, ntt +``` + +Notice that while `rshuffle`'s source and destination registers are the same in the example above, the actual contents will be shuffled among the tile-pairs. + +The typical usage to shuffle the tile-pairs of iNTT inputs is as such: + +```csv +rshuffle, src_top, src_bot, src_top, src_bot, 0, intt +nop 21 # dependency latency for rshuffle results +intt, dst_top, dest_bot, src_top, src_bot, src_tw, stage, res +``` + +Notice that while `rshuffle`'s source and destination registers are the same in the example above, the actual contents will be shuffled among the tile-pairs. + +Routing table metadata defining shuffling patterns is loaded by [`nload`](../cinst/cinst_nload.md). + +Parameter `data_type` is intended to select the correct routing table, however, in the current HERACLES implementation, only one routing table is availabe, and this parameter is used only for book keeping and error detection during scheduling. diff --git a/assembler_tools/hec-assembler-tools/docsrc/inst_spec/xinst/xinst_sub.md b/assembler_tools/hec-assembler-tools/docsrc/inst_spec/xinst/xinst_sub.md new file mode 100644 index 00000000..265826e6 --- /dev/null +++ b/assembler_tools/hec-assembler-tools/docsrc/inst_spec/xinst/xinst_sub.md @@ -0,0 +1,30 @@ +# SUB - XInst {#sub_XInst} + +Element-wise polynomial subtraction. + +## Definition + +Add two polynomials stored in the register file and store the result in a register. + +| instr | operand 0 | operand 1 | operand 2 | operand 3 | +|-|-|-|-|-| +| sub | dst | src0 | src1 | res | + +## Timing + +**Throughput**: 1 clock cycle + +**Latency**: 6 clock cycles + +## Details + +| Operand | Type | Description | +|-|-|-| +| `dst` | register | Destination register for result of `src0 - src1`. | +| `src0` | register | Source register for first operand. | +| `src1` | register | Source register for second operand. Must be different than `src0`. | +| `res` | int32 | Residue to use for modular reduction | + +### Notes + +**Uses SPAD-CE data path**: No diff --git a/assembler_tools/hec-assembler-tools/docsrc/inst_spec/xinst/xinst_twintt.md b/assembler_tools/hec-assembler-tools/docsrc/inst_spec/xinst/xinst_twintt.md new file mode 100644 index 00000000..12c1c95c --- /dev/null +++ b/assembler_tools/hec-assembler-tools/docsrc/inst_spec/xinst/xinst_twintt.md @@ -0,0 +1,35 @@ +# TWiNTT - XInst {#twintt_XInst} + +Compute twiddle factors for iNTT. + +## Definition + +Performs on-die generation of twiddle factors for the next stage of iNTT. + +| instr | operand 0 | operand 1 | operand 2 | operand 3 | operand 4 | operand 5 | operand 6 | +|-|-|-|-|-|-|-|-| +| twintt | dst_tw | src_tw | tw_meta | stage | block | ring_dim | res | + +## Timing + +**Throughput**: 1 clock cycle + +**Latency**: 6 clock cycles + +## Details + +| Operand | Type | Description | +|-|-|-| +| `dst_tw` | register | Destination register for resulting twiddles. | +| `src_tw` | register | Source register for original twiddle values. | +| `tw_meta` | int32 | Indexing information of the twiddle metadata. | +| `stage` | int32 | Stage number of the corresponding iNTT instruction | +| `block` | int32 | Index of current 16k polynomial chunk. | +| `ring_dim` | int32 | Ring dimension. This is `PMD = 2^ring_dim`, where `PMD` is the poly-modulus degree. | +| `res` | int32 | Residue to use for modular reduction. | + +### Notes + +**Uses SPAD-CE data path**: No + +Both NTT and inverse NTT instructions are defined as one-stage of the transformation. A complete NTT/iNTT transformation is composed of LOG_N such one-stage instructions. diff --git a/assembler_tools/hec-assembler-tools/docsrc/inst_spec/xinst/xinst_twntt.md b/assembler_tools/hec-assembler-tools/docsrc/inst_spec/xinst/xinst_twntt.md new file mode 100644 index 00000000..58740943 --- /dev/null +++ b/assembler_tools/hec-assembler-tools/docsrc/inst_spec/xinst/xinst_twntt.md @@ -0,0 +1,35 @@ +# TWNTT - XInst {#twntt_XInst} + +Compute twiddle factors for NTT. + +## Definition + +Performs on-die generation of twiddle factors for the next stage of NTT. + +| instr | operand 0 | operand 1 | operand 2 | operand 3 | operand 4 | operand 5 | operand 6 | +|-|-|-|-|-|-|-|-| +| twntt | dst_tw | src_tw | tw_meta | stage | block | ring_dim | res | + +## Timing + +**Throughput**: 1 clock cycle + +**Latency**: 6 clock cycles + +## Details + +| Operand | Type | Description | +|-|-|-| +| `dst_tw` | register | Destination register for resulting twiddles. | +| `src_tw` | register | Source register for original twiddle values. | +| `tw_meta` | int32 | Indexing information of the twiddle metadata. | +| `stage` | int32 | Stage number of the corresponding NTT instruction | +| `block` | int32 | Index of current 16k polynomial chunk. | +| `ring_dim` | int32 | Ring dimension. This is `PMD = 2^ring_dim`, where `PMD` is the poly-modulus degree. | +| `res` | int32 | Residue to use for modular reduction. | + +### Notes + +**Uses SPAD-CE data path**: No + +Both NTT and inverse NTT instructions are defined as one-stage of the transformation. A complete NTT/iNTT transformation is composed of LOG_N such one-stage instructions. diff --git a/assembler_tools/hec-assembler-tools/docsrc/inst_spec/xinst/xinst_xstore.md b/assembler_tools/hec-assembler-tools/docsrc/inst_spec/xinst/xinst_xstore.md new file mode 100644 index 00000000..2872f56f --- /dev/null +++ b/assembler_tools/hec-assembler-tools/docsrc/inst_spec/xinst/xinst_xstore.md @@ -0,0 +1,36 @@ +# XSTORE - XInst {#xstore_XInst} + +Transfers data from a register into the intermediate data buffer for subsequent transfer into SPAD. + +## Definition + +Transfers a word from a CE register into the intermediate data buffer. The intermediate data buffer features a FIFO structure, which means that the transferred data is pushed at the end of the queue. + +| instr | operand 0 | +|-|-| +| xstore | src | + +## Timing + +**Throughput**: 1 clock cycle + +**Latency**: 4 clock cycles + +## Details + +| Operand | Type | Description | +|-|-|-| +| `src` | register | Source register to store into SPAD. | + +### Notes + +**Uses SPAD-CE data path**: Yes + +This instruction pushes data blindly into the intermediate data buffer in a LIFO structure. It is the responsibility of the CINST execution queue to pop this data buffer timely, via `cstore` instruction, to avoid overflows. + +The data will be ready in the intermediate data buffer queue one clock cycle after `xstore` completes. Writing to the `src` register during `xstore` is undefined. + +## See Also + +- CInst [`cstore`](../cinst/cinst_cstore.md) +- CInst [`cload`](../cinst/cinst_cload.md) diff --git a/assembler_tools/hec-assembler-tools/docsrc/specs.md b/assembler_tools/hec-assembler-tools/docsrc/specs.md new file mode 100644 index 00000000..cd042449 --- /dev/null +++ b/assembler_tools/hec-assembler-tools/docsrc/specs.md @@ -0,0 +1,278 @@ +# HCGF Instruction Specification {#HCGF_specs} + +Terms used in this document are defined in the HERACLES Instruction Set Architecture (ISA). + +[Changelog](changelog.md) + +## Table of Contents +1. [Introduction](#introduction) +2. [Memory Specification](#mem_spec) + 1. [Word Size](#word_spec) + 2. [High-Bandwidth Memory (HBM)](#hbm_spec) + 3. [Scratch Pad (SPAD)](#spad_spec) + 4. [Register Banks](#registers_spec) +3. [Output File Formats](#output_format) + 1. [XINST File](#output_xinst) + 2. [CINST File](#output_cinst) + 3. [MINST File](#output_minst) +4. [Instruction Set](#instr_spec) + +## Introduction + +HERACLES architecture allows fine-grained control of memory movement between DRAM, SRAM, and register files. The architecture features three execution queues that control memory movement among the different levels, as well as control of the compute engine. + +The three execution queues are: + +- XINST: contains the Compute Engine (CE) instructions (XInst) to be loaded into the instruction queue of tiles to carry out computations. + + **Queue Capacity**: 1MB. + + **Instruction size**: 64 bits. + +- CINST: contains the instructions (CInst) to coordinate memory movement between the register banks in the CE and the SRAM cache (also known as scratch pad, or SPAD). + + **Queue Capacity**: 128KB. + + **Instruction size**: 64 bits. + +- MINST: contains the instructions (MInst) to coordinate memory movement between the SPAD and DRAM (also known as high-bandwidth memory, or HBM). + + **Queue Capacity**: 128KB. + + **Instruction size**: 64 bits. + +The three queues must work in concert to ensure memory consistency and optimized, functional correctness of the execution. + +The output of the assembler tool is composed of three files containing the instructions for each of the three execution queues for the HERACLES respectively. + +## Memory Specification + +#### Summary of Sizes and Memory Capacities in HERACLES Instructions and Memory Model + +| Parameter | Size | Unit | Word size | Description | +|-|-|-|-|-| +| Word size | 32 | KB | 1 | Data unit size. | +| XINST Queue Capacity | 1 | MB | 512 | Capacity of XInst queue in Cfetch engine. | +| CINST Queue Capacity | 128 | KB | 4 | Capacity of CInst queue in Cfetch engine. | +| MINST Queue Capacity | 128 | KB | 4 | Capacity of MInst queue in Mfetch engine. | +| XINST Intruction width | 64 | bits | | Size of instructions in the Xinst queue. | +| CINST Intruction width | 64 | bits | | Size of instructions in the Cinst queue. | +| MINST Intruction width | 64 | bits | | Size of instructions in the Minst queue. | +| XINST bundle size | 64 | instructions | | Number of instructions in an XInst bundle. | +| HBM Capacity`*` | 48 | GB | 1,572,864 | Capacity of DRAM. | +| SPAD Capacity`*` | 48 | MB | 1,536 | Capacity of SRAM/cache. | +| Store Buffer capacity | 128 | KB | 4 | Capacity of the intermediate data buffer queue for `xstore`. | +| Register banks | 4 | banks | | Number of register banks in a compute tile pair. | +| Registers per bank | 72 | registers | | Number of registers in a register bank. | +| Register capacity | 32 | KB | 1 | Capacity of a combined register for all compute tile pairs. | + +`*`Capacities configurable during assembling. + +See HERACLES ISA for more information. + +### Word Size + +A "word" is the *smallest addressable data unit* in the HERACLES memory model. + +**Word size**: 32KB. + +HERACLES is a polynomial computation engine, and thus, a word contains the coefficients for a polynomial. + +**Polynomial coefficient size**: 4B or 32bits. + +A word has capacity to hold polynomials of, up to, 8192 coefficients. + +Operations on larger polynomials is possible by splitting them into equivalent mathematical operations on smaller polynomials. + +Data sizes in bytes are offered for *reference purposes only*. Since the smallest addressable unit is the word, **all HERACLES instructions are word-based**. + +Note that the information included here is at the abstraction level that concerns the assembler. The HERACLES architecture further partitions the memory into blocks, and the compute engine into 64 compute tile pairs. For more information on the low level architecture refer to the HERACLES architecture documentation. + +### High-Bandwidth Memory (HBM) + +**HBM capacity**: 48GB == 1,572,864 words. + +HBM is partitioned into four regions: + +- Data region +- XInst region +- CInst region +- MInst region + +Each memory region will contain the namesake data for the whole HERACLES program. Their size is custom defined (in words) by the host during initialization through the driver. Their total size must add up to the total capacity of the HBM. + +While all the regions live in HBM space, their logical base address is always `0`. This means that, for example, the address `p` from the data region is in a different location of HBM than the address `p` from the XInst region. The definition of instructions that access HBM will specify which region they access. + +During HERACLES initialization, the host will copy the program parts into the corresponding memory regions. Once the transfer is complete, the host signals HERACLES through the driver to start the program. + +When the program starts, a state machine in the hardware will start streaming the contents of MInst and CInst memory regions into their corresponding queues in 64KB chunks. Note that these queues have a capacity of 128KB each, thus, the streaming will occur into the upper or lower 64KB portion not currently being executed, overwriting code already executed. + +Instruction pointers for MINST and CINST queues will automatically start once there is code ready to execute. The hardware state machine will ensure that there is always code ready for execution. + +Execution of XINST queue, however is controlled by CInst. + +### Scratch Pad (SPAD) + +**SPAD capacity**: 48MB == 1,536 words. + +Data transfers from CE into SPAD are initiated by XINST via `xstore` instruction. The data is temporarily pushed into an intermediate data buffer queue. It is the CINST's responsibility to pop this intermediate queue and complete the transfer into SPAD before the buffer overflows. + +Data transfers from SPAD into the CE's register file are available in the corresponding registers one clock cycle after the transfer instruction completes. + +#### SPAD Restrictions + +- There cannot be multiple operations that use the data path between SPAD and CE in flight. + +**Temporary Store Queue Buffer capacity**: 4 words. + +### Register Banks + +The CE features 64 compute tile pairs (these are the compute units of the CE) arranged in 8 rows of 8 tile pairs. Each tile pair has 4 register banks with 72 registers. Each register has a capacity of 512 bytes. However, the architecture will ensure that all tile pairs will execute the same instruction on the same clock cycle on the same registers; therefore, we can treat the CE as a unit with the characteristics listed below. + +As a unit, the CE features a register file with 4 **register banks**. + +**Registers per bank**: 72. + +**Register capacity**: 1 word. + +#### Bank Restrictions + +- A register bank cannot be accessed for reading more than once in the same cycle. Reads normally occur on the first cycle of instructions. + +- A register bank cannot be accessed for writing more than once in the same cycle. Writes normally occur on the last cycle of instructions. + +- A register bank can be accessed for a single read and a single write simulataneously in the same cycle. + +## Output File Formats + +The assembler provides its output in three csv-style files. + +#### Comments + +All output files support inline comments using the hash symbol `#`. All text to the right of a `#` in a line is ignored as a comment. + +Note that full line comments are not supported, and every line must contain an instruction. + +### XINST File + +Contains the instructions for the XINST execution queue. + +File extension: `.xinst` + +File format: + +```csv +F, , , , , , +``` + +| Field | Type | Description | +|-|-|-| +| `bundle_num` | int32 | ID of bundle to which this instruction belongs. Instructions are grouped by bundles, so, this value is never smaller than previous instructions. | +| `trace_instr_num` | int32 | Matching input kernel instruction that caused the generation of this instruction. For book keeping purposes. | +| `op` | string | Name of the instruction. | +| `dests` | csv_string | Comma-separated list of all destinations for the instruction. 0 or more values. | +| `sources` | csv_string | Comma-separated list of all sources for the instruction. 0 or more values. | +| `other` | csv_string | Comma-separated list of any extra parameters required for the operation that are not specifically listed here. 0 or more values. | +| `residual` | int32 | Residual for the operation. | + +Note that some of the elements after the instruction name may be missing, depending on the instruction. + +Example: + +```csv +F99, 1056, ntt, r24b2, r25b3, r60b2, r61b3, r35b1, 13, 12 # dst: r24b2, r25b3, src: r60b2, r61b3, r35b1, stage: 13, res: 12 +``` + +Check instruction specification for exceptions. + +### CINST File + +Contains the instructions for the CINST execution queue. + +File extension: `.cinst` + +File format: + +```csv +, , , , +``` + +| Field | Type | Description | +|-|-|-| +| `instr_num` | int32 | Monotonically increasing instruction number. | +| `op` | string | Name of the instruction. | +| `dests` | csv_string | Comma-separated list of all destinations for the instruction. 0 or more values. | +| `sources` | csv_string | Comma-separated list of all sources for the instruction. 0 or more values. | +| `other` | csv_string | Comma-separated list of any extra parameters required for the operation that are not specifically listed here. 0 or more values. | + +Note that some of these elements after the instruction name may be missing, depending on the instruction. + +Example: + +```csv +55, cload, r60b0, 9 # dst: r60b0, src: 9 +``` + +Check instruction specification for exceptions. + +### MINST File + +Contains the instructions for the MINST execution queue. + +File extension: `.minst` + +File format: + +```csv +, , , , +``` + +| Field | Type | Description | +|-|-|-| +| `instr_num` | int32 | Monotonically increasing instruction number. | +| `op` | string | Name of the instruction. | +| `dests` | csv_string | Comma-separated list of all destinations for the instruction. 0 or more values. | +| `sources` | csv_string | Comma-separated list of all sources for the instruction. 0 or more values. | +| `other` | csv_string | Comma-separated list of any extra parameters required for the operation that are not specifically listed here. 0 or more values. | + +Note that some of these elements after the instruction name may be missing, depending on the instruction. + +Example: + +```csv +54, mload, 40, 29 # dst: 40, src: 29 +``` + +Check instruction specification for exceptions. + +## Instruction Set + +Instructions are pipelined. Thus, they will have a throughput time and a total latency. + +Throughput time is the number of clock cycles that it takes for the instruction to be dispatched. The execution engine will not move to the next instruction in the queue until the throughput time for the current instruction has elapsed. Most instructions have a throughput time of 1 clock cycle. + +Latency is the number of clock cycles it takes for the instruction to complete and its outputs to be ready. It includes the throughput time. Most Xinst have a latency of 6 clock cycles. See the instruction specification for details and exceptions. + +Because of pipelining, there can be several instructions in flight at the same time. The code produced by the assembler ensures that dependent instructions don't read or write data before their dependency has resolved, this is, until the previous instruction's latency has elapsed. + +The following instruction set functionally matches those of HERACLES ISA. It is provided here as a reference for syntax and semantics of the output generated by the HCGF assembler. + +#### Instructions - Assembler Output + +| MINST | CINST | XINST | +|-|-|-| +| [msyncc](inst_spec/minst/minst_msyncc.md) | [bload](inst_spec/cinst/cinst_bload.md) | [move](inst_spec/xinst/xinst_move.md) | +| [mload](inst_spec/minst/minst_mload.md) | [bones](inst_spec/cinst/cinst_bones.md) | [xstore](inst_spec/xinst/xinst_xstore.md) | +| [mstore](inst_spec/minst/minst_mstore.md) | [nload](inst_spec/cinst/cinst_nload.md) | [rshuffle](inst_spec/xinst/xinst_rshuffle.md) | +| | [xinstfetch](inst_spec/cinst/cinst_xinstfetch.md) | [ntt](inst_spec/xinst/xinst_ntt.md) | +| | [ifetch](inst_spec/cinst/cinst_ifetch.md) | [twntt](inst_spec/xinst/xinst_twntt.md) | +| | [cload](inst_spec/cinst/cinst_cload.md) | [twintt](inst_spec/xinst/xinst_twintt.md) | +| | [cstore](inst_spec/cinst/cinst_cstore.md) | [intt](inst_spec/xinst/xinst_intt.md) | +| | [csyncm](inst_spec/cinst/cinst_csyncm.md) | [add](inst_spec/xinst/xinst_add.md) | +| | [cexit](inst_spec/cinst/cinst_cexit.md) | [sub](inst_spec/xinst/xinst_sub.md) | +| | [nop](inst_spec/cinst/cinst_nop.md) | [mul](inst_spec/xinst/xinst_mul.md) | +| | | [muli](inst_spec/xinst/xinst_muli.md) | +| | | [mac](inst_spec/xinst/xinst_mac.md) | +| | | [maci](inst_spec/xinst/xinst_maci.md) | +| | | [exit](inst_spec/xinst/xinst_exit.md) | +| | | [nop](inst_spec/xinst/xinst_nop.md) | diff --git a/assembler_tools/hec-assembler-tools/gen_he_ops.py b/assembler_tools/hec-assembler-tools/gen_he_ops.py new file mode 100644 index 00000000..364b3290 --- /dev/null +++ b/assembler_tools/hec-assembler-tools/gen_he_ops.py @@ -0,0 +1,264 @@ +import argparse +import io +import os +import pathlib +import subprocess +import sys +import yaml + +from assembler.common.constants import Constants +from assembler.common.run_config import RunConfig +import he_prep as preproc +import he_as as asm +import he_link as linker + +# module constants +DEFAULT_OPERATIONS = Constants.OPERATIONS[:6] + +class GenRunConfig(RunConfig): + """ + Maintains the configuration data for the run. + """ + + __initialized = False # specifies whether static members have been initialized + # contains the dictionary of all configuration items supported and their + # default value (or None if no default) + __default_config = {} + + def __init__(self, **kwargs): + """ + Constructs a new GenRunConfig Object from input parameters. + + See base class constructor for more arguments. + + Parameters + ---------- + scheme: str + FHE Scheme to use + + N: int + Ring dimension: PMD = 2^N. + + min_nrns: int + Minimum number of residuals. + + max_nrns: int + Maximum number of residuals. + + key_nrns: int + Optional number of residuals for relinearization keys. Must be greater than `max_nrns`. + If missing, the `key_nrns` for each P-ISA kernel generated will be set to the kernel + `nrns` (number of residuals) + 1. + + op_list: list[str] + Optional list of name of operations to generate. If provided, it must be a non-empty + subset of `Constants.OPERATIONS`. + Defaults to `DEFAULT_OPERATIONS`. + + output_dir: str + Optional directory where to store all intermediate files and final output. + This will be created if it doesn't exists. + Defaults to /lib. + + Raises + ------ + TypeError + A mandatory configuration value was missing. + + ValueError + At least, one of the arguments passed is invalid. + """ + + self.__init_statics() + + super().__init__(**kwargs) + + for config_name, default_value in self.__default_config.items(): + assert(not hasattr(self, config_name)) + setattr(self, config_name, kwargs.get(config_name, default_value)) + if getattr(self, config_name) is None: + raise TypeError(f'Expected value for configuration `{config_name}`, but `None` received.') + + for op in self.op_list: + if op not in Constants.OPERATIONS: + raise ValueError('Invalid operation in input list of ops "{}". Expected one of {}'.format(op, Constants.OPERATIONS)) + + if self.key_nrns > 0: + if self.key_nrns < self.max_nrns: + raise ValueError(('`key_nrns` must be greater than `max_nrns` when present. ' + 'Received {}, but expected greater than {}.').format(self.key_nrns, + self.max_nrns)) + + @classmethod + def __init_statics(cls): + if not cls.__initialized: + cls.__default_config["scheme"] = "bgv" + cls.__default_config["N"] = None + cls.__default_config["min_nrns"] = None + cls.__default_config["max_nrns"] = None + cls.__default_config["key_nrns"] = 0 + cls.__default_config["output_dir"] = os.path.join(pathlib.Path.cwd(), "lib") + cls.__default_config["op_list"] = DEFAULT_OPERATIONS + + cls.__initialized = True + + def __str__(self): + """ + Returns a string representation of the configuration. + """ + self_dict = self.as_dict() + with io.StringIO() as retval_f: + for key, value in self_dict.items(): + print("{}: {}".format(key, value), file=retval_f) + retval = retval_f.getvalue() + return retval + + def as_dict(self) -> dict: + retval = super().as_dict() + tmp_self_dict = vars(self) + retval.update({ config_name: tmp_self_dict[config_name] for config_name in self.__default_config }) + return retval + +def main(config: GenRunConfig, + b_verbose: bool = False): + + lib_dir = config.output_dir + + # create output directory to store outputs (if it doesn't already exist) + pathlib.Path(lib_dir).mkdir(exist_ok = True, parents=True) + + # point to the HERACLES-SEAL-isa-mapping repo + home_dir = pathlib.Path.home() + mapping_dir = os.getenv("HERACLES_MAPPING_PATH", os.path.join(home_dir, "HERACLES/HERACLES-SEAL-isa-mapping")) + # command to run the mapping script to generate operations kernels for our input + #generate_cmd = 'python3 "{}"'.format(os.path.join(mapping_dir, "kernels/run_he_op.py")) + generate_cmd = ['python3', '{}'.format(os.path.join(mapping_dir, "kernels/run_he_op.py"))] + + assert config.N < 1024 + assert config.min_nrns > 1 + assert (config.key_nrns == 0 or config.key_nrns > config.max_nrns) + assert(all(op in Constants.OPERATIONS for op in config.op_list)) + + pdegree = 2 ** config.N + for op in config.op_list: + for rn_el in range(config.min_nrns, config.max_nrns + 1): + key_nrns = config.key_nrns if config.key_nrns > 0 else rn_el + 1 + print(f"{config.scheme} {op} {config.N} {rn_el} {key_nrns}") + + output_prefix = "t.{}.{}.{}.{}".format(rn_el,op,config.N,key_nrns) + basef = os.path.join(lib_dir, output_prefix) + memfile = basef + ".tw.mem" + generate_cmdln = generate_cmd + [ "--map-file" , memfile ] + [ str(x) for x in (config.scheme, op, pdegree, rn_el, key_nrns) ] + + csvfile = basef + ".csv" + + # call the external script to generate the kernel for this op + print(' '.join(generate_cmdln)) + with open(csvfile, 'w') as fout_csv: + run_result = subprocess.run(generate_cmdln, stdout=fout_csv) + if run_result.returncode != 0: + raise RuntimeError('Exit code: {}. Failure to complete kernel generation successfully.'.format(run_result.returncode)) + + + # pre-process kernel step + #------------------------- + + # generate twiddle factors for this kernel + basef = basef + ".tw" #use the newly generated twiddle file + print() + print("Preprocessing") + preproc.main(basef + ".csv", + csvfile, + b_verbose=b_verbose) + + # assembling step + #----------------- + + # prepare config for assembler + asm_config = asm.AssemblerRunConfig(input_file=basef + ".csv", + input_mem_file=memfile, + output_prefix=output_prefix + '.o', + **config.as_dict()) # convert config to a dictionary and expand it as arguments + # temp path to store assembled output before linking set + asm_config.output_dir = os.path.join(asm_config.output_dir, 'obj') + print() + print("Assembling") + # run the assembler for this file + asm.main(asm_config, verbose=b_verbose) + + # linking step + #-------------- + + # prepare config for linker + linker_config = linker.LinkerRunConfig(input_prefixes = [os.path.join(asm_config.output_dir, asm_config.output_prefix)], + input_mem_file=memfile, + output_prefix=output_prefix, + **config.as_dict()) # convert config to a dictionary and expand it as arguments + print() + print("Linking") + # run the linker on the assembler output + linker.main(linker_config, sys.stdout if b_verbose else None) + + print(f'Completed "{output_prefix}"') + print() + +def parse_args(): + parser = argparse.ArgumentParser(description=("Generates a collection of HE operations based on input configuration."), + epilog=("To use, users should dump a default configuration file. Edit the file to " + "match the needs for the run, then execute the program with the modified " + "configuration. Note that dumping on top of an existing file will overwrite " + "its contents.")) + parser.add_argument("config_file", help=("YAML configuration file.")) + parser.add_argument("--dump", action="store_true", + help=("A default configuration will be writen into the file specified by `config_file`. " + "If the file already exists, it will be overwriten.")) + parser.add_argument("-v", "--verbose", dest="verbose", action="store_true", + help="If enabled, extra information and progress reports are printed to stdout.") + args = parser.parse_args() + + return args + +def readYAMLConfig(input_filename: str): + """ + Reads in a YAML file and returns a GenRunConfig object parsed from it. + """ + retval_dict = {} + with open(input_filename, "r") as infile: + retval_dict = yaml.safe_load(infile) + + return GenRunConfig(**retval_dict) + +def writeYAMLConfig(output_filename: str, config: GenRunConfig): + """ + Outputs the specified configuration to a YAML file. + """ + with open(output_filename, "w") as outfile: + yaml.dump(vars(config), outfile, sort_keys=False) + +if __name__ == "__main__": + module_name = os.path.basename(__file__) + print(module_name) + print() + + args = parse_args() + + if args.dump: + print("Writing default configuration to") + print(" ", args.config_file) + default_config = GenRunConfig(N=15, min_nrns=2, max_nrns=18) + writeYAMLConfig(args.config_file, default_config) + else: + print("Loading configuration file:") + print(" ", args.config_file) + config = readYAMLConfig(args.config_file) + print() + print("Gen Run Configuration") + print("=====================") + print(config) + print("=====================") + print() + main(config, + b_verbose=args.verbose) + + print() + print(module_name, "- Complete") diff --git a/assembler_tools/hec-assembler-tools/he_as.py b/assembler_tools/hec-assembler-tools/he_as.py new file mode 100644 index 00000000..2ee92c49 --- /dev/null +++ b/assembler_tools/hec-assembler-tools/he_as.py @@ -0,0 +1,411 @@ +#! /usr/bin/env python3 +""" +This module provides functionality for assembling pre-processed P-ISA kernel programs into valid assembly code for execution queues: MINST, CINST, and XINST. + +Classes: + AssemblerRunConfig + Maintains the configuration data for the run. + +Functions: + asmisaAssemble(run_config, output_minst_filename: str, output_cinst_filename: str, output_xinst_filename: str, b_verbose=True) -> tuple + Assembles the P-ISA kernel into ASM-ISA instructions and saves them to specified output files. + + main(config: AssemblerRunConfig, verbose: bool = False) + Executes the assembly process using the provided configuration. + + parse_args() -> argparse.Namespace + Parses command-line arguments for the assembler script. + +Usage: + This script is intended to be run as a standalone program. It requires specific command-line arguments + to specify input and output files and configuration options for the assembly process. + +""" +import argparse +import io +import os +import pathlib +import sys +import time + +from assembler.common.run_config import RunConfig +from assembler.common.run_config import static_initializer + +from assembler.common import constants +from assembler.common import makeUniquePath +from assembler.common.config import GlobalConfig +from assembler.common.counter import Counter +from assembler.isa_spec import SpecConfig +from assembler.instructions import xinst +from assembler.stages import scheduler +from assembler.stages.asm_scheduler import scheduleASMISAInstructions +from assembler.memory_model import MemoryModel +from assembler.memory_model import mem_info + +script_dir = os.path.dirname(os.path.realpath(__file__)) + +# module constants +DEFAULT_XINST_FILE_EXT = "xinst" +DEFAULT_CINST_FILE_EXT = "cinst" +DEFAULT_MINST_FILE_EXT = "minst" +DEFAULT_MEM_FILE_EXT = "mem" + +@static_initializer +class AssemblerRunConfig(RunConfig): + """ + Maintains the configuration data for the run. + + Methods: + as_dict() -> dict + Returns the configuration as a dictionary. + """ + + __initialized = False # specifies whether static members have been initialized + # contains the dictionary of all configuration items supported and their + # default value (or None if no default) + __default_config = {} + + def __init__(self, **kwargs): + """ + Constructs a new AssemblerRunConfig Object from input parameters. + + See base class constructor for more parameters. + + Args: + input_file (str): + Input file containing the kernel code to assemble. + Kernel code should have twiddle factors added already as appropriate. + input_mem_file (str): + Optional input memory file associated with the kernel. + If missing, the memory file is expected to be same as `input_file`, but with extension ".mem". + output_dir (str): + Optional directory where to store all intermediate files and final output. + This will be created if it doesn't exists. + Defaults to the same directory as the input file. + output_prefix (str): + Optional prefix for the output file names. + Defaults to the name of the input file without extension. + + Raises: + TypeError: + A mandatory configuration value was missing. + ValueError: + At least, one of the arguments passed is invalid. + """ + + super().__init__(**kwargs) + + + # class members based on configuration + for config_name, default_value in self.__default_config.items(): + assert(not hasattr(self, config_name)) + setattr(self, config_name, kwargs.get(config_name, default_value)) + if getattr(self, config_name) is None: + raise TypeError(f'Expected value for configuration `{config_name}`, but `None` received.') + + # class members + self.input_prefix = "" + + # fix file names + + self.input_file = makeUniquePath(self.input_file) + input_dir = os.path.dirname(os.path.realpath(self.input_file)) + if not self.output_dir: + self.output_dir = input_dir + self.output_dir = makeUniquePath(self.output_dir) + + self.input_prefix = os.path.splitext(os.path.basename(self.input_file))[0] + + if not self.input_mem_file: + self.input_mem_file = "{}.{}".format(os.path.join(input_dir, self.input_prefix), + DEFAULT_MEM_FILE_EXT) + self.input_mem_file = makeUniquePath(self.input_mem_file) + + @classmethod + def init_static(cls): + """ + Initializes static members of the class. + """ + if not cls.__initialized: + cls.__default_config["input_file"] = None + cls.__default_config["input_mem_file"] = "" + cls.__default_config["output_dir"] = "" + cls.__default_config["output_prefix"] = "" + cls.__default_config["has_hbm"] = True + cls.__initialized = True + + def __str__(self): + """ + Provides a string representation of the configuration. + + Returns: + str: The string for the configuration. + """ + self_dict = self.as_dict() + with io.StringIO() as retval_f: + for key, value in self_dict.items(): + print("{}: {}".format(key, value), file=retval_f) + retval = retval_f.getvalue() + return retval + + def as_dict(self) -> dict: + """ + Provides the configuration as a dictionary. + + Returns: + dict: The configuration. + """ + retval = super().as_dict() + tmp_self_dict = vars(self) + retval.update({ config_name: tmp_self_dict[config_name] for config_name in self.__default_config }) + return retval + +def asmisaAssemble(run_config, + output_minst_filename: str, + output_cinst_filename: str, + output_xinst_filename: str, + b_verbose=True) -> tuple: + """ + Assembles the P-ISA kernel into ASM-ISA instructions and saves them to specified output files. + + This function reads the input kernel file, interprets variable meta information, generates a dependency graph, + schedules ASM-ISA instructions, and saves the results to output files. + + Args: + run_config: The configuration object containing run parameters. + output_minst_filename (str): The filename for saving MINST instructions. + output_cinst_filename (str): The filename for saving CINST instructions. + output_xinst_filename (str): The filename for saving XINST instructions. + b_verbose (bool): Flag indicating whether verbose output is enabled. + + Returns: + tuple: A tuple containing the number of XInstructions, number of NOPs, number of idle cycles, dependency timing, and scheduling timing. + """ + + max_bundle_size = 64 + + input_filename: str = run_config.input_file + mem_filename: str = run_config.input_mem_file + hbm_capcity_words: int = constants.convertBytes2Words(run_config.hbm_size * constants.Constants.KILOBYTE) + spad_capacity_words: int = constants.convertBytes2Words(run_config.spad_size * constants.Constants.KILOBYTE) + num_register_banks: int = constants.MemoryModel.NUM_REGISTER_BANKS + register_range: range = None + + if b_verbose: + print("Assembling!") + print("Reloading kernel from intermediate...") + + hec_mem_model = MemoryModel(hbm_capcity_words, spad_capacity_words, num_register_banks, register_range) + + insts_listing = [] + with open(input_filename, 'r') as insts: + for line_no, s_line in enumerate(insts, 1): + parsed_insts = None + if GlobalConfig.debugVerbose: + if line_no % 100 == 0: + print(f"{line_no}") + # instruction is one that is represented by single XInst + inst = xinst.createFromPISALine(hec_mem_model, s_line, line_no) + if inst: + parsed_insts = [ inst ] + + if not parsed_insts: + raise SyntaxError("Line {}: unable to parse kernel instruction:\n{}".format(line_no, s_line)) + + insts_listing += parsed_insts + + if b_verbose: + print("Interpreting variable meta information...") + with open(mem_filename, 'r') as mem_ifnum: + mem_meta_info = mem_info.MemInfo.from_iter(mem_ifnum) + mem_info.updateMemoryModelWithMemInfo(hec_mem_model, mem_meta_info) + + if b_verbose: + print("Generating dependency graph...") + start_time = time.time() + dep_graph = scheduler.generateInstrDependencyGraph(insts_listing, + sys.stdout if b_verbose else None) + scheduler.enforceKeygenOrdering(dep_graph, hec_mem_model, sys.stdout if b_verbose else None) + deps_end = time.time() - start_time + + if b_verbose: + print("Preparing to schedule ASM-ISA instructions...") + start_time = time.time() + minsts, cinsts, xinsts, num_idle_cycles = scheduleASMISAInstructions(dep_graph, + max_bundle_size, # max number of instructions in a bundle + hec_mem_model, + run_config.repl_policy, + b_verbose) + sched_end = time.time() - start_time + num_nops = 0 + num_xinsts = 0 + for bundle_xinsts, *_ in xinsts: + for xinstr in bundle_xinsts: + num_xinsts += 1 + if isinstance(xinstr, xinst.Exit): + break # stop counting instructions after bundle exit + if isinstance(xinstr, xinst.Nop): + num_nops += 1 + + if b_verbose: + print("Saving minst...") + with open(output_minst_filename, 'w') as outnum: + for idx, inst in enumerate(minsts): + inst_line = inst.toMASMISAFormat() + if inst_line: + print(f"{idx}, {inst_line}", file=outnum) + + if b_verbose: + print("Saving cinst...") + with open(output_cinst_filename, 'w') as outnum: + for idx, inst in enumerate(cinsts): + inst_line = inst.toCASMISAFormat() + if inst_line: + print(f"{idx}, {inst_line}", file=outnum) + + if b_verbose: + print("Saving xinst...") + with open(output_xinst_filename, 'w') as outnum: + for bundle_i, bundle_data in enumerate(xinsts): + for inst in bundle_data[0]: + inst_line = inst.toXASMISAFormat() + if inst_line: + print(f"F{bundle_i}, {inst_line}", file=outnum) + + return num_xinsts, num_nops, num_idle_cycles, deps_end, sched_end + +def main(config: AssemblerRunConfig, verbose: bool = False): + """ + Executes the assembly process using the provided configuration. + + This function sets up the output directory, initializes output filenames, tests output writability, + and performs the assembly process, printing results if verbose output is enabled. + + Args: + config (AssemblerRunConfig): The configuration object containing run parameters. + verbose (bool): Flag indicating whether verbose output is enabled. + + Returns: + None + """ + # check defaults + + # make a copy to avoid changing original config + config = AssemblerRunConfig(**config.as_dict()) + + # create output directory to store outputs (if it doesn't already exist) + pathlib.Path(config.output_dir).mkdir(exist_ok = True, parents=True) + + # initialize output filenames + + output_basef = os.path.join(config.output_dir, config.output_prefix) \ + if config.output_prefix \ + else os.path.join(config.output_dir, config.input_prefix) + + output_xinst_file = f'{output_basef}.{DEFAULT_XINST_FILE_EXT}' + output_cinst_file = f'{output_basef}.{DEFAULT_CINST_FILE_EXT}' + output_minst_file = f'{output_basef}.{DEFAULT_MINST_FILE_EXT}' + + # test output is writable + for filename in (output_minst_file, output_cinst_file, output_xinst_file): + try: + with open(filename, 'w') as outnum: + print("", file=outnum) + except Exception as ex: + raise Exception(f'Failed to write to output location "{filename}"') from ex + + GlobalConfig.useHBMPlaceHolders = True #config.use_hbm_placeholders + GlobalConfig.useXInstFetch = config.use_xinstfetch + GlobalConfig.supressComments = config.suppress_comments + GlobalConfig.hasHBM = config.has_hbm + GlobalConfig.debugVerbose = config.debug_verbose + + Counter.reset() + + num_xinsts, num_nops, num_idle_cycles, deps_end, sched_end = \ + asmisaAssemble(config, + output_minst_file, + output_cinst_file, + output_xinst_file, + b_verbose=verbose) + + if verbose: + print(f"Output:") + for filename in (output_minst_file, output_cinst_file, output_xinst_file): + print(f" {filename}") + print(f"--- Total XInstructions: {num_xinsts} ---") + print(f"--- Deps time: {deps_end} seconds ---") + print(f"--- Scheduling time: {sched_end} seconds ---") + print(f"--- Minimum idle cycles: {num_idle_cycles} ---") + print(f"--- Minimum nops required: {num_nops} ---") + +def parse_args(): + """ + Parses command-line arguments for the assembler script. + + This function sets up the argument parser and defines the expected arguments for the script. + It returns a Namespace object containing the parsed arguments. + + Returns: + argparse.Namespace: Parsed command-line arguments. + """ + parser = argparse.ArgumentParser( + description=("HERACLES Assembler.\n" + "The assembler takes a pre-processed P-ISA kernel program and generates " + "valid assembly code for each of the three execution queues: MINST, CINST, and XINST.")) + parser.add_argument("input_file", + help=("Input pre-processed P-ISA kernel file. " + "File must be the result of pre-processing a P-ISA kernel with he_prep.py")) + parser.add_argument("--isa_spec", default="", dest="isa_spec_file", + help=("Input ISA specification (.json) file.")) + parser.add_argument("--input_mem_file", default="", help=("Input memory mapping file associated with the kernel. " + "Defaults to the same name as the input file, but with `.mem` extension.")) + parser.add_argument("--output_dir", default="", help=("Directory where to store all intermediate files and final output. " + "This will be created if it doesn't exists. " + "Defaults to the same directory as the input file.")) + parser.add_argument("--output_prefix", default="", help=("Prefix for the output files. " + "Defaults to the same the input file without extension.")) + parser.add_argument("--spad_size", type=int, default=AssemblerRunConfig.DEFAULT_SPAD_SIZE_KB, + help="Scratchpad size in KB. Defaults to {} KB.".format(AssemblerRunConfig.DEFAULT_SPAD_SIZE_KB)) + parser.add_argument("--hbm_size", type=int, default=AssemblerRunConfig.DEFAULT_HBM_SIZE_KB, + help="HBM size in KB. Defaults to {} KB.".format(AssemblerRunConfig.DEFAULT_HBM_SIZE_KB)) + parser.add_argument("--no_hbm", dest="has_hbm", action="store_false", + help="If set, this flag tells he_prep there is no HBM in the target chip.") + parser.add_argument("--repl_policy", default=AssemblerRunConfig.DEFAULT_REPL_POLICY, + choices=constants.Constants.REPLACEMENT_POLICIES, + help="Replacement policy for cache evictions. Defaults to {}.".format(AssemblerRunConfig.DEFAULT_REPL_POLICY)) + parser.add_argument("--use_xinstfetch", dest="use_xinstfetch", action="store_true", + help=("When enabled, `xinstfetch` instructions are generated in the CInstQ.")) + parser.add_argument("--suppress_comments", "--no_comments", dest="suppress_comments", action="store_true", + help=("When enabled, no comments will be emited on the output generated by the assembler.")) + parser.add_argument("--debug_verbose", type=int, default=0) + parser.add_argument("-v", "--verbose", dest="verbose", action="count", default=0, + help=("If enabled, extra information and progress reports are printed to stdout. " + "Increase level of verbosity by specifying flag multiple times, e.g. -vv")) + args = parser.parse_args() + + return args + +if __name__ == "__main__": + module_dir = os.path.dirname(__file__) + module_name = os.path.basename(__file__) + + args = parse_args() + args.isa_spec_file = SpecConfig.initialize_isa_spec(module_dir, args.isa_spec_file) + + config = AssemblerRunConfig(**vars(args)) # convert argsparser into a dictionary + + if args.verbose > 0: + print(module_name) + print() + print("Run Configuration") + print("=================") + print(config) + print("=================") + print() + + main(config, verbose = args.verbose > 1) + + if args.verbose > 0: + print() + print(module_name, "- Complete") diff --git a/assembler_tools/hec-assembler-tools/he_link.py b/assembler_tools/hec-assembler-tools/he_link.py new file mode 100644 index 00000000..b68439fe --- /dev/null +++ b/assembler_tools/hec-assembler-tools/he_link.py @@ -0,0 +1,378 @@ +#! /usr/bin/env python3 +# encoding: utf-8 +""" +This module provides functionality for linking assembled kernels into a full HERACLES program for execution queues: MINST, CINST, and XINST. + +Classes: + LinkerRunConfig + Maintains the configuration data for the run. + + KernelFiles + Structure for kernel files. + +Functions: + main(run_config: LinkerRunConfig, verbose_stream=None) + Executes the linking process using the provided configuration. + + parse_args() -> argparse.Namespace + Parses command-line arguments for the linker script. + +Usage: + This script is intended to be run as a standalone program. It requires specific command-line arguments + to specify input and output files and configuration options for the linking process. + +""" +import argparse +import io +import os +import pathlib +import sys +import time +import warnings + +import linker + +from typing import NamedTuple + +from assembler.common import constants +from assembler.common import makeUniquePath +from assembler.common.counter import Counter +from assembler.common.run_config import RunConfig +from assembler.common.run_config import static_initializer +from assembler.common.config import GlobalConfig +from assembler.memory_model import mem_info +from linker import loader +from linker.steps import variable_discovery +from linker.steps import program_linker + +@static_initializer +class LinkerRunConfig(RunConfig): + """ + Maintains the configuration data for the run. + + Methods: + as_dict() -> dict + Returns the configuration as a dictionary. + """ + + __initialized = False # specifies whether static members have been initialized + # contains the dictionary of all configuration items supported and their + # default value (or None if no default) + __default_config = {} + + def __init__(self, **kwargs): + """ + Constructs a new LinkerRunConfig Object from input parameters. + + See base class constructor for more parameters. + + Args: + input_prefixes (list[str]): + List of input prefixes, including full path. For an input prefix, linker will + assume there are three files named `input_prefixes[i] + '.minst'`, + `input_prefixes[i] + '.cinst'`, and `input_prefixes[i] + '.xinst'`. + This list must not be empty. + output_prefix (str): + Prefix for the output file names. + Three files will be generated: + `output_dir/output_prefix.minst`, `output_dir/output_prefix.cinst`, and + `output_dir/output_prefix.xinst`. + Output filenames cannot match input file names. + input_mem_file (str): + Input memory file associated with the result kernel. + output_dir (str): current working directory + OPTIONAL directory where to store all intermediate files and final output. + This will be created if it doesn't exists. + Defaults to current working directory. + + Raises: + TypeError: + A mandatory configuration value was missing. + ValueError: + At least, one of the arguments passed is invalid. + """ + + super().__init__(**kwargs) + + + # class members based on configuration + for config_name, default_value in self.__default_config.items(): + assert(not hasattr(self, config_name)) + setattr(self, config_name, kwargs.get(config_name, default_value)) + if getattr(self, config_name) is None: + raise TypeError(f'Expected value for configuration `{config_name}`, but `None` received.') + + # fix file names + self.output_dir = makeUniquePath(self.output_dir) + self.input_mem_file = makeUniquePath(self.input_mem_file) + + @classmethod + def init_static(cls): + """ + Initializes static members of the class. + """ + if not cls.__initialized: + cls.__default_config["input_prefixes"] = None + cls.__default_config["input_mem_file"] = None + cls.__default_config["output_dir"] = os.getcwd() + cls.__default_config["output_prefix"] = None + cls.__default_config["has_hbm"] = True + + cls.__initialized = True + + def __str__(self): + """ + Provides a string representation of the configuration. + + Returns: + str: The string for the configuration. + """ + self_dict = self.as_dict() + with io.StringIO() as retval_f: + for key, value in self_dict.items(): + print("{}: {}".format(key, value), file=retval_f) + retval = retval_f.getvalue() + return retval + + def as_dict(self) -> dict: + """ + Provides the configuration as a dictionary. + + Returns: + dict: The configuration. + """ + retval = super().as_dict() + tmp_self_dict = vars(self) + retval.update({ config_name: tmp_self_dict[config_name] for config_name in self.__default_config }) + return retval + +class KernelFiles(NamedTuple): + """ + Structure for kernel files. + + Attributes: + minst (str): + Index = 0. Name for file containing MInstructions for represented kernel. + cinst (str): + Index = 1. Name for file containing CInstructions for represented kernel. + xinst (str): + Index = 2. Name for file containing XInstructions for represented kernel. + prefix (str): + Index = 3 + """ + minst: str + cinst: str + xinst: str + prefix: str + +def main(run_config: LinkerRunConfig, verbose_stream = None): + """ + Executes the linking process using the provided configuration. + + This function prepares input and output file names, initializes the memory model, discovers variables, + and links each kernel, writing the output to specified files. + + Args: + run_config (LinkerRunConfig): The configuration object containing run parameters. + verbose_stream: The stream to which verbose output is printed. Defaults to None. + + Returns: + None + """ + if verbose_stream: + print("Linking...", file=verbose_stream) + + if run_config.use_xinstfetch: + warnings.warn(f'Ignoring configuration flag "use_xinstfetch".') + + # Update global config + GlobalConfig.hasHBM = run_config.has_hbm + + mem_filename: str = run_config.input_mem_file + hbm_capcity_words: int = constants.convertBytes2Words(run_config.hbm_size * constants.Constants.KILOBYTE) + input_files = [] # list(KernelFiles) + output_files: KernelFiles = None + + # prepare output file names + output_prefix = os.path.join(run_config.output_dir, run_config.output_prefix) + output_dir = os.path.dirname(output_prefix) + pathlib.Path(output_dir).mkdir(exist_ok = True, parents=True) + output_files = KernelFiles(minst=makeUniquePath(output_prefix + '.minst'), + cinst=makeUniquePath(output_prefix + '.cinst'), + xinst=makeUniquePath(output_prefix + '.xinst'), + prefix=makeUniquePath(output_prefix)) + + # prepare input file names + for file_prefix in run_config.input_prefixes: + input_files.append(KernelFiles(minst=makeUniquePath(file_prefix + '.minst'), + cinst=makeUniquePath(file_prefix + '.cinst'), + xinst=makeUniquePath(file_prefix + '.xinst'), + prefix=makeUniquePath(file_prefix))) + for input_filename in input_files[-1][:-1]: + if not os.path.isfile(input_filename): + raise FileNotFoundError(input_filename) + if input_filename in output_files: + raise RuntimeError(f'Input files cannot match output files: "{input_filename}"') + + # reset counters + Counter.reset() + + # parse mem file + + if verbose_stream: + print("", file=verbose_stream) + print("Interpreting variable meta information...", file=verbose_stream) + + with open(mem_filename, 'r') as mem_ifnum: + mem_meta_info = mem_info.MemInfo.from_iter(mem_ifnum) + + # initialize memory model + if verbose_stream: + print("Initializing linker memory model", file=verbose_stream) + + mem_model = linker.MemoryModel(hbm_capcity_words, mem_meta_info) + if verbose_stream: + print(f" HBM capacity: {mem_model.hbm.capacity} words", file=verbose_stream) + + # find all variables and usage across all the input kernels + + if verbose_stream: + print(" Finding all program variables...", file=verbose_stream) + print(" Scanning", file=verbose_stream) + + for idx, kernel in enumerate(input_files): + if not GlobalConfig.hasHBM: + if verbose_stream: + print(" {}/{}".format(idx + 1, len(input_files)), kernel.cinst, + file=verbose_stream) + # load next CInst kernel and scan for variables used in SPAD + kernel_cinstrs = loader.loadCInstKernelFromFile(kernel.cinst) + for var_name in variable_discovery.discoverVariablesSPAD(kernel_cinstrs): + mem_model.addVariable(var_name) + else: + if verbose_stream: + print(" {}/{}".format(idx + 1, len(input_files)), kernel.minst, + file=verbose_stream) + # load next MInst kernel and scan for variables used + kernel_minstrs = loader.loadMInstKernelFromFile(kernel.minst) + for var_name in variable_discovery.discoverVariables(kernel_minstrs): + mem_model.addVariable(var_name) + + # check that all non-keygen variables from MemInfo are used + for var_name in mem_model.mem_info_vars: + if var_name not in mem_model.variables: + if GlobalConfig.hasHBM or var_name not in mem_model.mem_info_meta: # skip checking meta vars when no HBM + raise RuntimeError(f'Unused variable from input mem file: "{var_name}" not in memory model.') + + if verbose_stream: + print(f" Variables found: {len(mem_model.variables)}", file=verbose_stream) + + if verbose_stream: + print("Linking started", file=verbose_stream) + + # open the output files + with open(output_files.minst, 'w') as fnum_output_minst, \ + open(output_files.cinst, 'w') as fnum_output_cinst, \ + open(output_files.xinst, 'w') as fnum_output_xinst: + + # prepare the linker class + result_program = program_linker.LinkedProgram(fnum_output_minst, + fnum_output_cinst, + fnum_output_xinst, + mem_model, + supress_comments=run_config.suppress_comments) + # start linking each kernel + for idx, kernel in enumerate(input_files): + if verbose_stream: + print("[ {: >3}% ]".format(idx * 100 // len(input_files)), kernel.prefix, + file=verbose_stream) + kernel_minstrs = loader.loadMInstKernelFromFile(kernel.minst) + kernel_cinstrs = loader.loadCInstKernelFromFile(kernel.cinst) + kernel_xinstrs = loader.loadXInstKernelFromFile(kernel.xinst) + + result_program.linkKernel(kernel_minstrs, kernel_cinstrs, kernel_xinstrs) + + if verbose_stream: + print("[ 100% ] Finalizing output", output_files.prefix, file=verbose_stream) + + # signal that we have linked all kernels + result_program.close() + + if verbose_stream: + print("Output written to files:", file=verbose_stream) + print(" ", output_files.minst, file=verbose_stream) + print(" ", output_files.cinst, file=verbose_stream) + print(" ", output_files.xinst, file=verbose_stream) + +def parse_args(): + """ + Parses command-line arguments for the linker script. + + This function sets up the argument parser and defines the expected arguments for the script. + It returns a Namespace object containing the parsed arguments. + + Returns: + argparse.Namespace: Parsed command-line arguments. + """ + parser = argparse.ArgumentParser( + description=("HERACLES Linker.\n" + "Links assembled kernels into a full HERACLES program " + "for each of the three execution queues: MINST, CINST, and XINST.\n\n" + "To link several kernels, specify each kernel's input prefix in order. " + "Variables that should carry on across kernels should be have the same name. " + "Linker will recognize matching variables and keep their values between kernels. " + "Variables that are inputs and outputs (and metadata) for the whole program must " + "be indicated in the input memory mapping file.")) + parser.add_argument("input_prefixes", nargs="+", + help=("List of input prefixes, including full path. For an input prefix, linker will " + "assume three files exist named `input_prefixes[i] + '.minst'`, " + "`input_prefixes[i] + '.cinst'`, and `input_prefixes[i] + '.xinst'`.")) + parser.add_argument("-im", "--input_mem_file", dest="input_mem_file", required=True, + help=("Input memory mapping file associated with the resulting program. " + "Specifies the names for input, output, and metadata variables for the full program. " + "This file is usually the same as the kernel's when converting a single kernel into " + "a program, but, when linking multiple kernels together, it should be tailored to the " + "whole program.")) + parser.add_argument("-o", "--output_prefix", dest="output_prefix", required=True, + help=("Prefix for the output file names. " + "Three files will be generated: \n" + "`output_dir/output_prefix.minst`, `output_dir/output_prefix.cinst`, and " + "`output_dir/output_prefix.xinst`. \n" + "Output filenames cannot match input file names.")) + parser.add_argument("-od", "--output_dir", dest="output_dir", default="", + help=("Directory where to store all intermediate files and final output. " + "This will be created if it doesn't exists. " + "Defaults to current working directory.")) + parser.add_argument("--hbm_size", type=int, default=LinkerRunConfig.DEFAULT_HBM_SIZE_KB, + help="HBM size in KB. Defaults to {} KB.".format(LinkerRunConfig.DEFAULT_HBM_SIZE_KB)) + parser.add_argument("--no_hbm", dest="has_hbm", action="store_false", + help="If set, this flag tells he_prep there is no HBM in the target chip.") + parser.add_argument("--suppress_comments", "--no_comments", dest="suppress_comments", action="store_true", + help=("When enabled, no comments will be emited on the output generated.")) + parser.add_argument("-v", "--verbose", dest="verbose", action="count", default=0, + help=("If enabled, extra information and progress reports are printed to stdout. " + "Increase level of verbosity by specifying flag multiple times, e.g. -vv")) + args = parser.parse_args() + + return args + +if __name__ == "__main__": + module_name = os.path.basename(__file__) + + args = parse_args() + config = LinkerRunConfig(**vars(args)) # convert argsparser into a dictionary + + if args.verbose > 0: + print(module_name) + print() + print("Run Configuration") + print("=================") + print(config) + print("=================") + print() + + main(config, sys.stdout if args.verbose > 1 else None) + + if args.verbose > 0: + print() + print(module_name, "- Complete") diff --git a/assembler_tools/hec-assembler-tools/he_prep.py b/assembler_tools/hec-assembler-tools/he_prep.py new file mode 100644 index 00000000..7c207e99 --- /dev/null +++ b/assembler_tools/hec-assembler-tools/he_prep.py @@ -0,0 +1,155 @@ +#! /usr/bin/env python3 + +""" +This module provides functionality for preprocessing P-ISA abstract kernels before further assembling for HERACLES. + +Functions: + __savePISAListing(out_stream, instr_listing: list) + Stores instructions to a stream in P-ISA format. + + main(output_file_name: str, input_file_name: str, b_verbose: bool) + Preprocesses the P-ISA kernel and saves the output to a specified file. + + parse_args() -> argparse.Namespace + Parses command-line arguments for the preprocessing script. + +Usage: + This script is intended to be run as a standalone program. It requires specific command-line arguments + to specify input and output files and verbosity options for the preprocessing process. + +""" +import argparse +import os +import time + +from assembler.common import constants +from assembler.isa_spec import SpecConfig +from assembler.stages import preprocessor +from assembler.memory_model import MemoryModel + +def __savePISAListing(out_stream, + instr_listing: list): + """ + Stores the instructions to a stream in P-ISA format. + + This function iterates over a list of instructions and prints each instruction in P-ISA format + to the specified output stream. + + Args: + out_stream: The output stream to which the instructions are printed. + instr_listing (list): A list of instructions to be printed in P-ISA format. + + Returns: + None + """ + for inst in instr_listing: + inst_line = inst.toPISAFormat() + if inst_line: + print(inst_line, file=out_stream) + +def main(output_file_name: str, + input_file_name: str, + b_verbose: bool): + """ + Preprocesses the P-ISA kernel and saves the output to a specified file. + + This function reads an input kernel file, preprocesses it to transform instructions into ASM-ISA format, + assigns register banks to variables, and saves the processed instructions to an output file. + + Args: + output_file_name (str): The name of the output file where processed instructions are saved. + input_file_name (str): The name of the input file containing the P-ISA kernel. + b_verbose (bool): Flag indicating whether verbose output is enabled. + + Returns: + None + """ + # used for timings + insts_end: int = 0 + + # check for default `output_file_name` + # e.g. of default + # input_file_name = /path/to/some/file.csv + # output_file_name = /path/to/some/file.tw.csv + if not output_file_name: + output_file_name = os.path.splitext(input_file_name) + output_file_name = ''.join(output_file_name[:-1] + (".tw",) + output_file_name[-1:]) + + hec_mem_model = MemoryModel(constants.MemoryModel.HBM.MAX_CAPACITY_WORDS, + constants.MemoryModel.SPAD.MAX_CAPACITY_WORDS) + + insts_listing = [] + start_time = time.time() + # read input kernel and pre-process P-ISA: + # resulting instructions will be correctly transformed and ready to be converted into ASM-ISA instructions; + # variables used in the kernel will be automatically assigned to banks. + with open(input_file_name, 'r') as insts: + insts_listing = preprocessor.preprocessPISAKernelListing(hec_mem_model, + insts, + progress_verbose=b_verbose) + num_input_instr: int = len(insts_listing) # track number of instructions in input kernel + if b_verbose: + print("Assigning register banks to variables...") + preprocessor.assignRegisterBanksToVars(hec_mem_model, + insts_listing, + use_bank0=False, + verbose=b_verbose) + insts_end = time.time() - start_time + + if b_verbose: + print("Saving...") + with open(output_file_name, 'w') as outnum: + __savePISAListing(outnum, insts_listing) + + if b_verbose: + print(f"Input: {input_file_name}") + print(f"Output: {output_file_name}") + print(f"Instructions in input: {num_input_instr}") + print(f"Instructions in output: {len(insts_listing)}") + print(f"--- Generation time: {insts_end} seconds ---") + +def parse_args(): + """ + Parses command-line arguments for the preprocessing script. + + This function sets up the argument parser and defines the expected arguments for the script. + It returns a Namespace object containing the parsed arguments. + + Returns: + argparse.Namespace: Parsed command-line arguments. + """ + parser = argparse.ArgumentParser( + description="HERACLES Assembling Pre-processor.\nThis program performs the preprocessing of P-ISA abstract kernels before further assembling.") + parser.add_argument("input_file_name", help="Input abstract kernel file to which to add twiddle factors.") + parser.add_argument("output_file_name", nargs="?", help="Output file name. Defaults to .tw.") + parser.add_argument("--isa_spec", default="", dest="isa_spec_file", + help=("Input ISA specification (.json) file.")) + parser.add_argument("-v", "--verbose", dest="verbose", action="count", default=0, + help=("If enabled, extra information and progress reports are printed to stdout. " + "Increase level of verbosity by specifying flag multiple times, e.g. -vv")) + args = parser.parse_args() + + return args + +if __name__ == "__main__": + module_dir = os.path.dirname(__file__) + module_name = os.path.basename(__file__) + + args = parse_args() + + args.isa_spec_file = SpecConfig.initialize_isa_spec(module_dir, args.isa_spec_file) + + if args.verbose > 0: + print(module_name) + print() + print("Input: {0}".format(args.input_file_name)) + print("Output: {0}".format(args.output_file_name)) + print("ISA Spec: {0}".format(args.isa_spec_file)) + + main(output_file_name=args.output_file_name, + input_file_name=args.input_file_name, + b_verbose=(args.verbose > 1)) + + if args.verbose > 0: + print() + print(module_name, "- Complete") diff --git a/assembler_tools/hec-assembler-tools/linker/__init__.py b/assembler_tools/hec-assembler-tools/linker/__init__.py new file mode 100644 index 00000000..a32676d9 --- /dev/null +++ b/assembler_tools/hec-assembler-tools/linker/__init__.py @@ -0,0 +1,246 @@ +import collections.abc as collections +from assembler.common.config import GlobalConfig +from assembler.memory_model import mem_info + +# linker/__init__.py contains classes to encapsulate the memory model used +# by the linker. + +class VariableInfo(mem_info.MemInfoVariable): + """ + Represents information about a variable in the memory model. + """ + + def __init__(self, var_name, hbm_address=-1): + """ + Initializes a VariableInfo object. + + Parameters: + var_name (str): The name of the variable. + hbm_address (int): The HBM address of the variable. Defaults to -1. + """ + super().__init__(var_name, hbm_address) + self.uses = 0 + self.last_kernel_used = -1 + +class HBM: + """ + Represents the HBM model. + """ + + def __init__(self, hbm_size_words: int): + """ + Initializes an HBM object. + + Parameters: + hbm_size_words (int): The size of the HBM in words. + + Raises: + ValueError: If hbm_size_words is less than 1. + """ + if hbm_size_words < 1: + raise ValueError('`hbm_size_words` must be a positive integer.') + # Represents the memory buffer where variables live + self.__buffer = [None] * hbm_size_words + + @property + def capacity(self) -> int: + """ + Gets the capacity in words for the HBM buffer. + + Returns: + int: The capacity of the HBM buffer. + """ + return len(self.buffer) + + @property + def buffer(self) -> list: + """ + Gets the HBM buffer. + + Returns: + list: The HBM buffer. + """ + return self.__buffer + + def forceAllocate(self, var_info: VariableInfo, hbm_address: int): + """ + Forcefully allocates a variable at a specific HBM address. + + Parameters: + var_info (VariableInfo): The variable information. + hbm_address (int): The HBM address to allocate the variable. + + Raises: + IndexError: If hbm_address is out of bounds. + ValueError: If the variable is already allocated at a different address. + RuntimeError: If the HBM address is already occupied by another variable. + """ + if hbm_address < 0 or hbm_address >= len(self.buffer): + raise IndexError('`hbm_address` out of bounds. Expected a word address in range [0, {}), but {} received'.format(len(self.buffer), + hbm_address)) + if var_info.hbm_address != hbm_address: + if var_info.hbm_address >= 0: + raise ValueError(f'`var_info`: variable {var_info.var_name} already allocated in address {var_info.hbm_address}.') + + in_var_info = self.buffer[hbm_address] + # Validate hbm address + if not GlobalConfig.hasHBM: + # Attempt to recycle SPAD locations inside kernel when no HBM + # Note: there is no HBM, so, SPAD is used as the sole memory space + if in_var_info and in_var_info.uses > 0: + raise RuntimeError(('HBM address {} already occupied by variable {} ' + 'when attempting to allocate variable {}').format(hbm_address, + in_var_info.var_name, + var_info.var_name)) + else: + if in_var_info \ + and (in_var_info.uses > 0 or in_var_info.last_kernel_used >= var_info.last_kernel_used): + raise RuntimeError(('HBM address {} already occupied by variable {} ' + 'when attempting to allocate variable {}').format(hbm_address, + in_var_info.var_name, + var_info.var_name)) + var_info.hbm_address = hbm_address + self.buffer[hbm_address] = var_info + + def allocate(self, var_info: VariableInfo): + """ + Allocates a variable in the HBM. + + Parameters: + var_info (VariableInfo): The variable information. + + Raises: + RuntimeError: If there is no available HBM memory. + """ + # Find next available HBM address + retval = -1 + for idx, in_var_info in enumerate(self.buffer): + if not GlobalConfig.hasHBM: + # Attempt to recycle SPAD locations inside kernel when no HBM + # Note: there is no HBM, so, SPAD is used as the sole memory space + if not in_var_info or in_var_info.uses <= 0: + retval = idx + break + else: + if not in_var_info \ + or (in_var_info.uses <= 0 and in_var_info.last_kernel_used < var_info.last_kernel_used): + retval = idx + break + if retval < 0: + raise RuntimeError('Out of HBM memory.') + self.forceAllocate(var_info, retval) + +class MemoryModel: + """ + Encapsulates the memory model for a linker run, tracking HBM usage and program variables. + """ + + def __init__(self, hbm_size_words: int, mem_meta_info: mem_info.MemInfo): + """ + Initializes a MemoryModel object. + + Parameters: + hbm_size_words (int): The size of the HBM in words. + mem_meta_info (mem_info.MemInfo): The memory metadata information. + """ + self.hbm = HBM(hbm_size_words) + self.__mem_info = mem_meta_info + self.__variables = {} # dict(var_name: str, VariableInfo) + self.__keygen_vars = {var_info.var_name: var_info for var_info in self.__mem_info.keygens} + self.__mem_info_inputs = {var_info.var_name: var_info for var_info in self.__mem_info.inputs} + self.__mem_info_outputs = {var_info.var_name: var_info for var_info in self.__mem_info.outputs} + self.__mem_info_meta = {var_info.var_name: var_info for var_info in self.__mem_info.metadata.intt_auxiliary_table} \ + | {var_info.var_name: var_info for var_info in self.__mem_info.metadata.intt_routing_table} \ + | {var_info.var_name: var_info for var_info in self.__mem_info.metadata.ntt_auxiliary_table} \ + | {var_info.var_name: var_info for var_info in self.__mem_info.metadata.ntt_routing_table} \ + | {var_info.var_name: var_info for var_info in self.__mem_info.metadata.ones} \ + | {var_info.var_name: var_info for var_info in self.__mem_info.metadata.twiddle} \ + | {var_info.var_name: var_info for var_info in self.__mem_info.metadata.keygen_seeds} + self.__mem_info_fixed_addr_vars = self.__mem_info_outputs | self.__mem_info_meta + # Keygen variables should not be part of mem_info_vars set since they + # do not start in HBM + self.__mem_info_vars = self.__mem_info_inputs | self.__mem_info_outputs | self.__mem_info_meta + + @property + def mem_info_meta(self) -> collections.Collection: + """ + Set of metadata variable names in MemInfo used to construct this object. + Clients must not modify this set. + """ + return self.__mem_info_meta + + @property + def mem_info_vars(self) -> collections.Collection: + """ + Gets the set of variable names in MemInfo used to construct this object. + + Returns: + collections.Collection: The set of variable names. + """ + return self.__mem_info_vars + + @property + def variables(self) -> dict: + """ + Gets direct access to internal variables dictionary. + + Clients should use as read-only. Must not add, replace, remove or change + contents in any way. Use provided helper functions to manipulate. + + Returns: + dict: A dictionary of variables. + """ + return self.__variables + + def addVariable(self, var_name: str): + """ + Adds a variable to the HBM model. If variable already exists, its `uses` + field is incremented. + + Parameters: + var_name (str): The name of the variable to add. + """ + var_info: VariableInfo + if var_name in self.variables: + var_info = self.variables[var_name] + else: + var_info = VariableInfo(var_name) + if var_name in self.__mem_info_vars: + # Variables explicitly marked in mem file must persist throughout the program + # with predefined HBM address + if var_name in self.__mem_info_fixed_addr_vars: + var_info.uses = float('inf') + self.hbm.forceAllocate(var_info, + self.__mem_info_vars[var_name].hbm_address) + self.variables[var_name] = var_info + var_info.uses += 1 + + def useVariable(self, var_name: str, kernel: int) -> int: + """ + Uses a variable, decrementing its usage count. + + If a variable usage count reaches zero, it will be deallocated from HBM, if needed, + when a future kernel requires HBM space. + + Parameters: + var_name (str): The name of the variable to use. + kernel (int): The kernel that is using the variable. + + Returns: + int: The HBM address for the variable. + """ + var_info: VariableInfo = self.variables[var_name] + assert var_info.uses > 0 + + var_info.uses -= 1 # Mark the usage + var_info.last_kernel_used = kernel + + if var_info.hbm_address < 0: + # Find HBM address for variable + self.hbm.allocate(var_info) + + assert var_info.hbm_address >= 0 + assert self.hbm.buffer[var_info.hbm_address].var_name == var_info.var_name, \ + f'Expected variable {var_info.var_name} in HBM {var_info.hbm_address}, but variable {self.hbm[var_info.hbm_address].var_name} found instead.' + + return var_info.hbm_address \ No newline at end of file diff --git a/assembler_tools/hec-assembler-tools/linker/instructions/__init__.py b/assembler_tools/hec-assembler-tools/linker/instructions/__init__.py new file mode 100644 index 00000000..66afd290 --- /dev/null +++ b/assembler_tools/hec-assembler-tools/linker/instructions/__init__.py @@ -0,0 +1,25 @@ +from assembler.instructions import tokenizeFromLine +from linker.instructions.instruction import BaseInstruction + +def fromStrLine(line: str, factory) -> BaseInstruction: + """ + Parses an instruction from a line of text. + + Parameters: + line (str): Line of text from which to parse an instruction. + + Returns: + BaseInstruction or None: The parsed BaseInstruction object, or None if no object could be + parsed from the specified input line. + """ + retval = None + tokens, comment = tokenizeFromLine(line) + for instr_type in factory: + try: + retval = instr_type(tokens, comment) + except: + retval = None + if retval: + break + + return retval diff --git a/assembler_tools/hec-assembler-tools/linker/instructions/cinst/__init__.py b/assembler_tools/hec-assembler-tools/linker/instructions/cinst/__init__.py new file mode 100644 index 00000000..850a5c14 --- /dev/null +++ b/assembler_tools/hec-assembler-tools/linker/instructions/cinst/__init__.py @@ -0,0 +1,40 @@ + +from . import bload, bones, cexit, cload, cnop, cstore, csyncm, ifetch, kgload, kgseed, kgstart, nload, xinstfetch + +# MInst aliases + +BLoad = bload.Instruction +BOnes = bones.Instruction +CExit = cexit.Instruction +CLoad = cload.Instruction +CNop = cnop.Instruction +CStore = cstore.Instruction +CSyncm = csyncm.Instruction +IFetch = ifetch.Instruction +KGLoad = kgload.Instruction +KGSeed = kgseed.Instruction +KGStart = kgstart.Instruction +NLoad = nload.Instruction +XInstFetch = xinstfetch.Instruction + +def factory() -> set: + """ + Creates a set of all instruction classes. + + Returns: + set: A set containing all instruction classes. + """ + + return { BLoad, + BOnes, + CExit, + CLoad, + CNop, + CStore, + CSyncm, + IFetch, + KGLoad, + KGSeed, + KGStart, + NLoad, + XInstFetch } diff --git a/assembler_tools/hec-assembler-tools/linker/instructions/cinst/bload.py b/assembler_tools/hec-assembler-tools/linker/instructions/cinst/bload.py new file mode 100644 index 00000000..2fc070bd --- /dev/null +++ b/assembler_tools/hec-assembler-tools/linker/instructions/cinst/bload.py @@ -0,0 +1,64 @@ +from .cinstruction import CInstruction + +class Instruction(CInstruction): + """ + Encapsulates the `bload` CInstruction. + + The `bload` instruction loads metadata from the scratchpad to special registers in the register file. + + For more information, check the `bload` Specification: + https://github.com/IntelLabs/hec-assembler-tools/blob/master/docsrc/inst_spec/cinst/cinst_bload.md + """ + + @classmethod + def _get_num_tokens(cls)->int: + """ + Gets the number of tokens required for the instruction. + + The `bload` instruction requires 5 tokens: + , bload, , , + + Returns: + int: The number of tokens, which is 5. + """ + # 5 tokens: + # , bload, , , + # No HBM variant: + # , bload, , , + return 5 + + @classmethod + def _get_name(cls) -> str: + """ + Gets the name of the instruction. + + Returns: + str: The name of the instruction, which is "bload". + """ + return "bload" + + def __init__(self, tokens: list, comment: str = ""): + """ + Constructs a new `bload` CInstruction. + + Args: + tokens (list): A list of tokens representing the instruction. + comment (str, optional): An optional comment for the instruction. Defaults to an empty string. + + Raises: + ValueError: If the number of tokens is invalid or the instruction name is incorrect. + """ + super().__init__(tokens, comment=comment) + + + @property + def source(self) -> str: + """ + Name of the source. + This is a Variable name when loaded. Should be set to HBM address to write back. + """ + return self.tokens[3] + + @source.setter + def source(self, value: str): + self.tokens[3] = value diff --git a/assembler_tools/hec-assembler-tools/linker/instructions/cinst/bones.py b/assembler_tools/hec-assembler-tools/linker/instructions/cinst/bones.py new file mode 100644 index 00000000..e2459f21 --- /dev/null +++ b/assembler_tools/hec-assembler-tools/linker/instructions/cinst/bones.py @@ -0,0 +1,63 @@ +from .cinstruction import CInstruction + +class Instruction(CInstruction): + """ + Encapsulates a `bones` CInstruction. + + The `bones` instruction loads metadata of identity (one) from the scratchpad to the register file. + + For more information, check the `bones` Specification: + https://github.com/IntelLabs/hec-assembler-tools/blob/master/docsrc/inst_spec/cinst/cinst_bones.md + """ + + @classmethod + def _get_num_tokens(cls)->int: + """ + Gets the number of tokens required for the instruction. + + The `bones` instruction requires 4 tokens: + , bones, , + + Returns: + int: The number of tokens, which is 4. + """ + # 4 tokens: + # , bones, , + # No HBM variant: + # , bones, , + return 4 + + @classmethod + def _get_name(cls) -> str: + """ + Gets the name of the instruction. + + Returns: + str: The name of the instruction, which is "bones". + """ + return "bones" + + def __init__(self, tokens: list, comment: str = ""): + """ + Constructs a new `bones` CInstruction. + + Args: + tokens (list): A list of tokens representing the instruction. + comment (str, optional): An optional comment for the instruction. Defaults to an empty string. + + Raises: + ValueError: If the number of tokens is invalid or the instruction name is incorrect. + """ + super().__init__(tokens, comment=comment) + + @property + def source(self) -> str: + """ + Name of the source. + This is a Variable name when loaded. Should be set to HBM address to write back. + """ + return self.tokens[2] + + @source.setter + def source(self, value: str): + self.tokens[2] = value diff --git a/assembler_tools/hec-assembler-tools/linker/instructions/cinst/cexit.py b/assembler_tools/hec-assembler-tools/linker/instructions/cinst/cexit.py new file mode 100644 index 00000000..23137724 --- /dev/null +++ b/assembler_tools/hec-assembler-tools/linker/instructions/cinst/cexit.py @@ -0,0 +1,47 @@ +from .cinstruction import CInstruction + +class Instruction(CInstruction): + """ + Encapsulates a `cexit` CInstruction. + + This instruction terminates execution of a HERACLES program. + + For more information, check the `cexit` Specification: + https://github.com/IntelLabs/hec-assembler-tools/blob/master/docsrc/inst_spec/cinst/cinst_cexit.md + """ + + @classmethod + def _get_num_tokens(cls) -> int: + """ + Gets the number of tokens required for the instruction. + + The `cexit` instruction requires 2 tokens: + , cexit + + Returns: + int: The number of tokens, which is 2. + """ + return 2 + + @classmethod + def _get_name(cls) -> str: + """ + Gets the name of the instruction. + + Returns: + str: The name of the instruction, which is "cexit". + """ + return "cexit" + + def __init__(self, tokens: list, comment: str = ""): + """ + Constructs a new `cexit` CInstruction. + + Args: + tokens (list): A list of tokens representing the instruction. + comment (str, optional): An optional comment for the instruction. Defaults to an empty string. + + Raises: + ValueError: If the number of tokens is invalid or the instruction name is incorrect. + """ + super().__init__(tokens, comment=comment) \ No newline at end of file diff --git a/assembler_tools/hec-assembler-tools/linker/instructions/cinst/cinstruction.py b/assembler_tools/hec-assembler-tools/linker/instructions/cinst/cinstruction.py new file mode 100644 index 00000000..a82fb36d --- /dev/null +++ b/assembler_tools/hec-assembler-tools/linker/instructions/cinst/cinstruction.py @@ -0,0 +1,41 @@ +from linker.instructions.instruction import BaseInstruction + +class CInstruction(BaseInstruction): + """ + Represents a CInstruction, inheriting from BaseInstruction. + """ + + @classmethod + def _get_name_token_index(cls) -> int: + """ + Gets the index of the token containing the name of the instruction. + + Returns: + int: The index of the name token, which is 1. + """ + return 1 + + # Constructor + # ----------- + + def __init__(self, tokens: list, comment: str = ""): + """ + Constructs a new CInstruction. + + Parameters: + tokens (list): List of tokens for the instruction. + comment (str): Optional comment for the instruction. + + Raises: + ValueError: If the number of tokens is invalid or the instruction name is incorrect. + """ + super().__init__(tokens, comment=comment) + + def to_line(self) -> str: + """ + Retrieves the string form of the instruction to write to the instruction file. + + Returns: + str: The string representation of the instruction, excluding the first token. + """ + return ", ".join(self.tokens[1:]) \ No newline at end of file diff --git a/assembler_tools/hec-assembler-tools/linker/instructions/cinst/cload.py b/assembler_tools/hec-assembler-tools/linker/instructions/cinst/cload.py new file mode 100644 index 00000000..41218a24 --- /dev/null +++ b/assembler_tools/hec-assembler-tools/linker/instructions/cinst/cload.py @@ -0,0 +1,63 @@ +from .cinstruction import CInstruction + +class Instruction(CInstruction): + """ + Encapsulates a `cload` CInstruction. + + This instruction loads a single polynomial residue from scratchpad into a register. + + For more information, check the `cload` Specification: + https://github.com/IntelLabs/hec-assembler-tools/blob/master/docsrc/inst_spec/cinst/cinst_cload.md + """ + + @classmethod + def _get_num_tokens(cls)->int: + """ + Gets the number of tokens required for the instruction. + + The `cload` instruction requires 4 tokens: + , cload, , + + Returns: + int: The number of tokens, which is 4. + """ + # 4 tokens: + # , cload, , + # No HBM variant + # , cload, , + return 4 + + @classmethod + def _get_name(cls) -> str: + """ + Gets the name of the instruction. + + Returns: + str: The name of the instruction, which is "cload". + """ + return "cload" + + def __init__(self, tokens: list, comment: str = ""): + """ + Constructs a new `cload` CInstruction. + + Args: + tokens (list): A list of tokens representing the instruction. + comment (str, optional): An optional comment for the instruction. Defaults to an empty string. + + Raises: + ValueError: If the number of tokens is invalid or the instruction name is incorrect. + """ + super().__init__(tokens, comment=comment) + + @property + def source(self) -> str: + """ + Name of the source. + This is a Variable name when loaded. Should be set to HBM address to write back. + """ + return self.tokens[3] + + @source.setter + def source(self, value: str): + self.tokens[3] = value diff --git a/assembler_tools/hec-assembler-tools/linker/instructions/cinst/cnop.py b/assembler_tools/hec-assembler-tools/linker/instructions/cinst/cnop.py new file mode 100644 index 00000000..266b719b --- /dev/null +++ b/assembler_tools/hec-assembler-tools/linker/instructions/cinst/cnop.py @@ -0,0 +1,75 @@ +from .cinstruction import CInstruction + +class Instruction(CInstruction): + """ + Encapsulates a `cnop` CInstruction. + + This instruction adds a desired amount of idle cycles in the Cfetch flow. + + For more information, check the `cnop` Specification: + https://github.com/IntelLabs/hec-assembler-tools/blob/master/docsrc/inst_spec/cinst/cinst_nop.md + + Properties: + cycles: Gets or sets the number of idle cycles. + """ + + @classmethod + def _get_num_tokens(cls) -> int: + """ + Gets the number of tokens required for the instruction. + + The `cnop` instruction requires 3 tokens: + , cnop, + + Returns: + int: The number of tokens, which is 3. + """ + return 3 + + @classmethod + def _get_name(cls) -> str: + """ + Gets the name of the instruction. + + Returns: + str: The name of the instruction, which is "cnop". + """ + return "cnop" + + def __init__(self, tokens: list, comment: str = ""): + """ + Constructs a new `cnop` CInstruction. + + Args: + tokens (list): A list of tokens representing the instruction. + comment (str, optional): An optional comment for the instruction. Defaults to an empty string. + + Raises: + ValueError: If the number of tokens is invalid or the instruction name is incorrect. + """ + super().__init__(tokens, comment=comment) + + @property + def cycles(self) -> int: + """ + Gets the number of idle cycles. + + Returns: + int: The number of idle cycles. + """ + return int(self.tokens[2]) + + @cycles.setter + def cycles(self, value: int): + """ + Sets the number of idle cycles. + + Args: + value (int): The number of idle cycles to set. + + Raises: + ValueError: If the value is negative. + """ + if value < 0: + raise ValueError(f'`value` must be non-negative, but {value} received.') + self.tokens[2] = str(value) \ No newline at end of file diff --git a/assembler_tools/hec-assembler-tools/linker/instructions/cinst/cstore.py b/assembler_tools/hec-assembler-tools/linker/instructions/cinst/cstore.py new file mode 100644 index 00000000..a95abad6 --- /dev/null +++ b/assembler_tools/hec-assembler-tools/linker/instructions/cinst/cstore.py @@ -0,0 +1,63 @@ +from .cinstruction import CInstruction + +class Instruction(CInstruction): + """ + Encapsulates a `cstore` CInstruction. + + This instruction fetches a single polynomial residue from the intermediate data buffer and stores it back to SPAD. + + For more information, check the `cstore` Specification: + https://github.com/IntelLabs/hec-assembler-tools/blob/master/docsrc/inst_spec/cinst/cinst_cstore.md + """ + + @classmethod + def _get_num_tokens(cls)->int: + """ + Gets the number of tokens required for the instruction. + + The `cstore` instruction requires 3 tokens: + , cstore, + + Returns: + int: The number of tokens, which is 3. + """ + # 3 tokens: + # , cstore, + # No HBM variant + # , cstore, + return 3 + + @classmethod + def _get_name(cls) -> str: + """ + Gets the name of the instruction. + + Returns: + str: The name of the instruction, which is "cstore". + """ + return "cstore" + + def __init__(self, tokens: list, comment: str = ""): + """ + Constructs a new `cstore` CInstruction. + + Args: + tokens (list): A list of tokens representing the instruction. + comment (str, optional): An optional comment for the instruction. Defaults to an empty string. + + Raises: + ValueError: If the number of tokens is invalid or the instruction name is incorrect. + """ + super().__init__(tokens, comment=comment) + + @property + def dest(self) -> str: + """ + Name of the destination. + This is a Variable name when loaded. Should be set to HBM address to write back. + """ + return self.tokens[2] + + @dest.setter + def dest(self, value: str): + self.tokens[2] = value diff --git a/assembler_tools/hec-assembler-tools/linker/instructions/cinst/csyncm.py b/assembler_tools/hec-assembler-tools/linker/instructions/cinst/csyncm.py new file mode 100644 index 00000000..204bb4e7 --- /dev/null +++ b/assembler_tools/hec-assembler-tools/linker/instructions/cinst/csyncm.py @@ -0,0 +1,76 @@ +from .cinstruction import CInstruction + +class Instruction(CInstruction): + """ + Encapsulates a `csyncm` CInstruction. + + Wait instruction similar to a barrier that stalls the execution of the CINST + queue until the specified instruction from the MINST queue has completed. + + For more information, check the `csyncm` Specification: + https://github.com/IntelLabs/hec-assembler-tools/blob/master/docsrc/inst_spec/cinst/cinst_csyncm.md + + Properties: + target: Gets or sets the target MInst. + """ + + @classmethod + def _get_num_tokens(cls) -> int: + """ + Gets the number of tokens required for the instruction. + + The `csyncm` instruction requires 3 tokens: + , csyncm, + + Returns: + int: The number of tokens, which is 3. + """ + return 3 + + @classmethod + def _get_name(cls) -> str: + """ + Gets the name of the instruction. + + Returns: + str: The name of the instruction, which is "csyncm". + """ + return "csyncm" + + def __init__(self, tokens: list, comment: str = ""): + """ + Constructs a new `csyncm` CInstruction. + + Args: + tokens (list): A list of tokens representing the instruction. + comment (str, optional): An optional comment for the instruction. Defaults to an empty string. + + Raises: + ValueError: If the number of tokens is invalid or the instruction name is incorrect. + """ + super().__init__(tokens, comment=comment) + + @property + def target(self) -> int: + """ + Gets the target MInst. + + Returns: + int: The target MInst. + """ + return int(self.tokens[2]) + + @target.setter + def target(self, value: int): + """ + Sets the target MInst. + + Args: + value (int): The target MInst to set. + + Raises: + ValueError: If the value is negative. + """ + if value < 0: + raise ValueError(f'`value`: expected non-negative target, but {value} received.') + self.tokens[2] = str(value) \ No newline at end of file diff --git a/assembler_tools/hec-assembler-tools/linker/instructions/cinst/ifetch.py b/assembler_tools/hec-assembler-tools/linker/instructions/cinst/ifetch.py new file mode 100644 index 00000000..96924223 --- /dev/null +++ b/assembler_tools/hec-assembler-tools/linker/instructions/cinst/ifetch.py @@ -0,0 +1,75 @@ +from .cinstruction import CInstruction + +class Instruction(CInstruction): + """ + Encapsulates an `ifetch` CInstruction. + + This instruction fetches a bundle of instructions from the XINST queue and sends it to the CE for execution. + + For more information, check the `ifetch` Specification: + https://github.com/IntelLabs/hec-assembler-tools/blob/master/docsrc/inst_spec/cinst/cinst_ifetch.md + + Properties: + bundle: Gets or sets the target bundle index. + """ + + @classmethod + def _get_num_tokens(cls) -> int: + """ + Gets the number of tokens required for the instruction. + + The `ifetch` instruction requires 3 tokens: + , ifetch, + + Returns: + int: The number of tokens, which is 3. + """ + return 3 + + @classmethod + def _get_name(cls) -> str: + """ + Gets the name of the instruction. + + Returns: + str: The name of the instruction, which is "ifetch". + """ + return "ifetch" + + def __init__(self, tokens: list, comment: str = ""): + """ + Constructs a new `ifetch` CInstruction. + + Args: + tokens (list): A list of tokens representing the instruction. + comment (str, optional): An optional comment for the instruction. Defaults to an empty string. + + Raises: + ValueError: If the number of tokens is invalid or the instruction name is incorrect. + """ + super().__init__(tokens, comment=comment) + + @property + def bundle(self) -> int: + """ + Gets the target bundle index. + + Returns: + int: The target bundle index. + """ + return int(self.tokens[2]) + + @bundle.setter + def bundle(self, value: int): + """ + Sets the target bundle index. + + Args: + value (int): The target bundle index to set. + + Raises: + ValueError: If the value is negative. + """ + if value < 0: + raise ValueError(f'`value`: expected non-negative bundle index, but {value} received.') + self.tokens[2] = str(value) \ No newline at end of file diff --git a/assembler_tools/hec-assembler-tools/linker/instructions/cinst/kgload.py b/assembler_tools/hec-assembler-tools/linker/instructions/cinst/kgload.py new file mode 100644 index 00000000..2c7126e9 --- /dev/null +++ b/assembler_tools/hec-assembler-tools/linker/instructions/cinst/kgload.py @@ -0,0 +1,42 @@ +from .cinstruction import CInstruction + +class Instruction(CInstruction): + """ + Encapsulates a `kg_load` CInstruction. + """ + + @classmethod + def _get_num_tokens(cls) -> int: + """ + Gets the number of tokens required for the instruction. + + The `kg_load` instruction requires 3 tokens: + , kg_load, + + Returns: + int: The number of tokens, which is 3. + """ + return 3 + + @classmethod + def _get_name(cls) -> str: + """ + Gets the name of the instruction. + + Returns: + str: The name of the instruction, which is "kg_load". + """ + return "kg_load" + + def __init__(self, tokens: list, comment: str = ""): + """ + Constructs a new `kg_load` CInstruction. + + Args: + tokens (list): A list of tokens representing the instruction. + comment (str, optional): An optional comment for the instruction. Defaults to an empty string. + + Raises: + ValueError: If the number of tokens is invalid or the instruction name is incorrect. + """ + super().__init__(tokens, comment=comment) \ No newline at end of file diff --git a/assembler_tools/hec-assembler-tools/linker/instructions/cinst/kgseed.py b/assembler_tools/hec-assembler-tools/linker/instructions/cinst/kgseed.py new file mode 100644 index 00000000..f067f0e2 --- /dev/null +++ b/assembler_tools/hec-assembler-tools/linker/instructions/cinst/kgseed.py @@ -0,0 +1,42 @@ +from .cinstruction import CInstruction + +class Instruction(CInstruction): + """ + Encapsulates a `kg_seed` CInstruction. + """ + + @classmethod + def _get_num_tokens(cls) -> int: + """ + Gets the number of tokens required for the instruction. + + The `kg_seed` instruction requires 4 tokens: + , kg_seed, , + + Returns: + int: The number of tokens, which is 4. + """ + return 4 + + @classmethod + def _get_name(cls) -> str: + """ + Gets the name of the instruction. + + Returns: + str: The name of the instruction, which is "kg_seed". + """ + return "kg_seed" + + def __init__(self, tokens: list, comment: str = ""): + """ + Constructs a new `kg_seed` CInstruction. + + Args: + tokens (list): A list of tokens representing the instruction. + comment (str, optional): An optional comment for the instruction. Defaults to an empty string. + + Raises: + ValueError: If the number of tokens is invalid or the instruction name is incorrect. + """ + super().__init__(tokens, comment=comment) \ No newline at end of file diff --git a/assembler_tools/hec-assembler-tools/linker/instructions/cinst/kgstart.py b/assembler_tools/hec-assembler-tools/linker/instructions/cinst/kgstart.py new file mode 100644 index 00000000..89e8669d --- /dev/null +++ b/assembler_tools/hec-assembler-tools/linker/instructions/cinst/kgstart.py @@ -0,0 +1,42 @@ +from .cinstruction import CInstruction + +class Instruction(CInstruction): + """ + Encapsulates a `kg_start` CInstruction. + """ + + @classmethod + def _get_num_tokens(cls) -> int: + """ + Gets the number of tokens required for the instruction. + + The `kg_start` instruction requires 2 tokens: + , kg_start + + Returns: + int: The number of tokens, which is 2. + """ + return 2 + + @classmethod + def _get_name(cls) -> str: + """ + Gets the name of the instruction. + + Returns: + str: The name of the instruction, which is "kg_start". + """ + return "kg_start" + + def __init__(self, tokens: list, comment: str = ""): + """ + Constructs a new `kg_start` CInstruction. + + Args: + tokens (list): A list of tokens representing the instruction. + comment (str, optional): An optional comment for the instruction. Defaults to an empty string. + + Raises: + ValueError: If the number of tokens is invalid or the instruction name is incorrect. + """ + super().__init__(tokens, comment=comment) \ No newline at end of file diff --git a/assembler_tools/hec-assembler-tools/linker/instructions/cinst/nload.py b/assembler_tools/hec-assembler-tools/linker/instructions/cinst/nload.py new file mode 100644 index 00000000..c2d13df7 --- /dev/null +++ b/assembler_tools/hec-assembler-tools/linker/instructions/cinst/nload.py @@ -0,0 +1,64 @@ +from .cinstruction import CInstruction + +class Instruction(CInstruction): + """ + Encapsulates a `nload` CInstruction. + + This instruction loads metadata (for NTT/iNTT routing mapping) from + scratchpad into a special routing table register. + + For more information, check the `nload` Specification: + https://github.com/IntelLabs/hec-assembler-tools/blob/master/docsrc/inst_spec/cinst/cinst_nload.md + """ + + @classmethod + def _get_num_tokens(cls)->int: + """ + Gets the number of tokens required for the instruction. + + The `nload` instruction requires 4 tokens: + , nload, , + + Returns: + int: The number of tokens, which is 4. + """ + # 4 tokens: + # , nload, , + # No HBM variant: + # , nload, , + return 4 + + @classmethod + def _get_name(cls) -> str: + """ + Gets the name of the instruction. + + Returns: + str: The name of the instruction, which is "nload". + """ + return "nload" + + def __init__(self, tokens: list, comment: str = ""): + """ + Constructs a new `nload` CInstruction. + + Args: + tokens (list): A list of tokens representing the instruction. + comment (str, optional): An optional comment for the instruction. Defaults to an empty string. + + Raises: + ValueError: If the number of tokens is invalid or the instruction name is incorrect. + """ + super().__init__(tokens, comment=comment) + + @property + def source(self) -> str: + """ + Name of the source. + This is a Variable name when loaded. Should be set to HBM address to write back. + """ + return self.tokens[3] + + @source.setter + def source(self, value: str): + self.tokens[3] = value diff --git a/assembler_tools/hec-assembler-tools/linker/instructions/cinst/xinstfetch.py b/assembler_tools/hec-assembler-tools/linker/instructions/cinst/xinstfetch.py new file mode 100644 index 00000000..f82fbcf7 --- /dev/null +++ b/assembler_tools/hec-assembler-tools/linker/instructions/cinst/xinstfetch.py @@ -0,0 +1,103 @@ +from .cinstruction import CInstruction + +class Instruction(CInstruction): + """ + Encapsulates a `xinstfetch` CInstruction. + + This instruction fetches instructions from the HBM and sends them to the XINST queue. + + For more information, check the `xinstfetch` Specification: + https://github.com/IntelLabs/hec-assembler-tools/blob/master/docsrc/inst_spec/cinst/cinst_xinstfetch.md + + Properties: + dstXQueue: Gets or sets the destination in the XINST queue. + srcHBM: Gets or sets the source in the HBM. + """ + + @classmethod + def _get_num_tokens(cls) -> int: + """ + Gets the number of tokens required for the instruction. + + The `xinstfetch` instruction requires 4 tokens: + , xinstfetch, , + + Returns: + int: The number of tokens, which is 4. + """ + return 4 + + @classmethod + def _get_name(cls) -> str: + """ + Gets the name of the instruction. + + Returns: + str: The name of the instruction, which is "xinstfetch". + """ + return "xinstfetch" + + def __init__(self, tokens: list, comment: str = ""): + """ + Constructs a new `xinstfetch` CInstruction. + + Args: + tokens (list): A list of tokens representing the instruction. + comment (str, optional): An optional comment for the instruction. Defaults to an empty string. + + Raises: + ValueError: If the number of tokens is invalid or the instruction name is incorrect. + NotImplementedError: If the `xinstfetch` CInstruction is not supported in the linker. + """ + super().__init__(tokens, comment=comment) + raise NotImplementedError('`xinstfetch` CInstruction is not currently supported in linker.') + + @property + def dstXQueue(self) -> int: + """ + Gets the destination in the XINST queue. + + Returns: + int: The destination in the XINST queue. + """ + return int(self.tokens[2]) + + @dstXQueue.setter + def dstXQueue(self, value: int): + """ + Sets the destination in the XINST queue. + + Args: + value (int): The destination value to set. + + Raises: + ValueError: If the value is negative. + """ + if value < 0: + raise ValueError(f'`value`: expected non-negative value, but {value} received.') + self.tokens[2] = str(value) + + @property + def srcHBM(self) -> int: + """ + Gets the source in the HBM. + + Returns: + int: The source in the HBM. + """ + return int(self.tokens[3]) + + @srcHBM.setter + def srcHBM(self, value: int): + """ + Sets the source in the HBM. + + Args: + value (int): The source value to set. + + Raises: + ValueError: If the value is negative. + """ + if value < 0: + raise ValueError(f'`value`: expected non-negative value, but {value} received.') + self.tokens[3] = str(value) \ No newline at end of file diff --git a/assembler_tools/hec-assembler-tools/linker/instructions/instruction.py b/assembler_tools/hec-assembler-tools/linker/instructions/instruction.py new file mode 100644 index 00000000..ad029607 --- /dev/null +++ b/assembler_tools/hec-assembler-tools/linker/instructions/instruction.py @@ -0,0 +1,175 @@ +from assembler.common.decorators import * +from assembler.common.counter import Counter + +class BaseInstruction: + """ + Base class for all instructions. + + This class provides common functionality for all instructions in the linker. + + Class Properties: + name (str): Returns the name of the represented operation. + + Attributes: + comment (str): Comment for the instruction. + + Properties: + tokens (list[str]): List of tokens for the instruction. + id (int): Unique instruction ID. This is a unique nonce representing the instruction. + + Methods: + to_line(self) -> str: + Retrieves the string form of the instruction to write to the instruction file. + """ + + __id_count = Counter.count(0) # Internal unique sequence counter to generate unique IDs + + # Class methods and properties + # ---------------------------- + + @classproperty + def name(cls) -> str: + """ + Name for the instruction. + + Returns: + str: The name of the instruction. + """ + return cls._get_name() + + @classmethod + def _get_name(cls) -> str: + """ + Derived classes should implement this method and return correct + name for the instruction. + + Raises: + NotImplementedError: Abstract method. This base method should not be called. + """ + raise NotImplementedError() + + @classproperty + def NAME_TOKEN_INDEX(cls) -> int: + """ + Index for the token containing the name of the instruction + in the list of tokens. + + Returns: + int: The index of the name token. + """ + return cls._get_name_token_index() + + @classmethod + def _get_name_token_index(cls) -> int: + """ + Derived classes should implement this method and return correct + index for the token containing the name of the instruction + in the list of tokens. + + Raises: + NotImplementedError: Abstract method. This base method should not be called. + """ + raise NotImplementedError() + + @classproperty + def NUM_TOKENS(cls) -> int: + """ + Number of tokens required for this instruction. + + Returns: + int: The number of tokens required. + """ + return cls._get_num_tokens() + + @classmethod + def _get_num_tokens(cls) -> int: + """ + Derived classes should implement this method and return correct + required number of tokens for the instruction. + + Raises: + NotImplementedError: Abstract method. This base method should not be called. + """ + raise NotImplementedError() + + # Constructor + # ----------- + + def __init__(self, tokens: list, comment: str = ""): + """ + Creates a new BaseInstruction object. + + Parameters: + tokens (list): List of tokens for the instruction. + comment (str): Optional comment for the instruction. + + Raises: + ValueError: If the number of tokens is invalid or the instruction name is incorrect. + """ + assert self.NAME_TOKEN_INDEX < self.NUM_TOKENS + + if len(tokens) != self.NUM_TOKENS: + raise ValueError(('`tokens`: invalid amount of tokens. ' + 'Instruction {} requires {}, but {} received').format(self.name, + self.NUM_TOKENS, + len(tokens))) + if tokens[self.NAME_TOKEN_INDEX] != self.name: + raise ValueError('`tokens`: invalid name. Expected {}, but {} received'.format(self.name, + tokens[self.NAME_TOKEN_INDEX])) + + self.__id = next(BaseInstruction.__id_count) + + self.__tokens = list(tokens) + self.comment = comment + + def __repr__(self): + retval = ('<{}({}, id={}) object at {}>(tokens={})').format(type(self).__name__, + self.name, + self.id, + hex(id(self)), + self.token) + return retval + + def __eq__(self, other): + # Equality operator== overload + return self is other + + def __hash__(self): + return hash(self.id) + + def __str__(self): + return f'{self.name}({self.id})' + + # Methods and properties + # ---------------------------- + + @property + def id(self) -> tuple: + """ + Unique ID for the instruction. + + This is a combination of the client ID specified during construction and a unique nonce per instruction. + + Returns: + tuple: (client_id: int, nonce: int) where client_id is the id specified at construction. + """ + return self.__id + + @property + def tokens(self) -> list: + """ + Gets the list of tokens for the instruction. + + Returns: + list: The list of tokens. + """ + return self.__tokens + + def to_line(self) -> str: + """ + Retrieves the string form of the instruction to write to the instruction file. + + Returns: + str: The string representation of the instruction. + """ + return ", ".join(self.tokens) \ No newline at end of file diff --git a/assembler_tools/hec-assembler-tools/linker/instructions/minst/__init__.py b/assembler_tools/hec-assembler-tools/linker/instructions/minst/__init__.py new file mode 100644 index 00000000..fd037e60 --- /dev/null +++ b/assembler_tools/hec-assembler-tools/linker/instructions/minst/__init__.py @@ -0,0 +1,19 @@ + +from . import mload, mstore, msyncc + +# MInst aliases + +MLoad = mload.Instruction +MStore = mstore.Instruction +MSyncc = msyncc.Instruction + +def factory() -> set: + """ + Creates a set of all instruction classes. + + Returns: + set: A set containing all instruction classes. + """ + return { MLoad, + MStore, + MSyncc } diff --git a/assembler_tools/hec-assembler-tools/linker/instructions/minst/minstruction.py b/assembler_tools/hec-assembler-tools/linker/instructions/minst/minstruction.py new file mode 100644 index 00000000..00d40452 --- /dev/null +++ b/assembler_tools/hec-assembler-tools/linker/instructions/minst/minstruction.py @@ -0,0 +1,42 @@ + +from linker.instructions.instruction import BaseInstruction + +class MInstruction(BaseInstruction): + """ + Represents an MInstruction, inheriting from BaseInstruction. + """ + + @classmethod + def _get_name_token_index(cls) -> int: + """ + Gets the index of the token containing the name of the instruction. + + Returns: + int: The index of the name token, which is 1. + """ + return 1 + + # Constructor + # ----------- + + def __init__(self, tokens: list, comment: str = ""): + """ + Constructs a new MInstruction. + + Parameters: + tokens (list): List of tokens for the instruction. + comment (str): Optional comment for the instruction. + + Raises: + ValueError: If the number of tokens is invalid or the instruction name is incorrect. + """ + super().__init__(tokens, comment=comment) + + def to_line(self) -> str: + """ + Retrieves the string form of the instruction to write to the instruction file. + + Returns: + str: The string representation of the instruction, excluding the first token. + """ + return ", ".join(self.tokens[1:]) \ No newline at end of file diff --git a/assembler_tools/hec-assembler-tools/linker/instructions/minst/mload.py b/assembler_tools/hec-assembler-tools/linker/instructions/minst/mload.py new file mode 100644 index 00000000..b662101b --- /dev/null +++ b/assembler_tools/hec-assembler-tools/linker/instructions/minst/mload.py @@ -0,0 +1,72 @@ +from .minstruction import MInstruction + +class Instruction(MInstruction): + """ + Encapsulates an `mload` MInstruction. + + This instruction loads a single polynomial residue from local memory to scratchpad. + + For more information, check the specification: + https://github.com/IntelLabs/hec-assembler-tools/blob/master/docsrc/inst_spec/minst/minst_mload.md + + Properties: + source: Gets or sets the name of the source. + """ + + @classmethod + def _get_num_tokens(cls) -> int: + """ + Gets the number of tokens required for the instruction. + + The `mload` instruction requires 4 tokens: + , mload, , + + Returns: + int: The number of tokens, which is 4. + """ + return 4 + + @classmethod + def _get_name(cls) -> str: + """ + Gets the name of the instruction. + + Returns: + str: The name of the instruction, which is "mload". + """ + return "mload" + + def __init__(self, tokens: list, comment: str = ""): + """ + Constructs a new `mload` MInstruction. + + Args: + tokens (list): A list of tokens representing the instruction. + comment (str, optional): An optional comment for the instruction. Defaults to an empty string. + + Raises: + ValueError: If the number of tokens is invalid or the instruction name is incorrect. + """ + super().__init__(tokens, comment=comment) + + @property + def source(self) -> str: + """ + Gets the name of the source. + + This is a Variable name when loaded. Should be set to HBM address to write back. + + Returns: + str: The name of the source. + """ + return self.tokens[3] + + @source.setter + def source(self, value: str): + """ + Sets the name of the source. + + Args: + value (str): The name of the source to set. + """ + self.tokens[3] = value \ No newline at end of file diff --git a/assembler_tools/hec-assembler-tools/linker/instructions/minst/mstore.py b/assembler_tools/hec-assembler-tools/linker/instructions/minst/mstore.py new file mode 100644 index 00000000..6046ea0f --- /dev/null +++ b/assembler_tools/hec-assembler-tools/linker/instructions/minst/mstore.py @@ -0,0 +1,72 @@ +from .minstruction import MInstruction + +class Instruction(MInstruction): + """ + Encapsulates an `mstore` MInstruction. + + This instruction stores a single polynomial residue from scratchpad to local memory. + + For more information, check the specification: + https://github.com/IntelLabs/hec-assembler-tools/blob/master/docsrc/inst_spec/minst/minst_mstore.md + + Properties: + dest: Gets or sets the name of the destination. + """ + + @classmethod + def _get_num_tokens(cls) -> int: + """ + Gets the number of tokens required for the instruction. + + The `mstore` instruction requires 4 tokens: + , mstore, , + + Returns: + int: The number of tokens, which is 4. + """ + return 4 + + @classmethod + def _get_name(cls) -> str: + """ + Gets the name of the instruction. + + Returns: + str: The name of the instruction, which is "mstore". + """ + return "mstore" + + def __init__(self, tokens: list, comment: str = ""): + """ + Constructs a new `mstore` MInstruction. + + Args: + tokens (list): A list of tokens representing the instruction. + comment (str, optional): An optional comment for the instruction. Defaults to an empty string. + + Raises: + ValueError: If the number of tokens is invalid or the instruction name is incorrect. + """ + super().__init__(tokens, comment=comment) + + @property + def dest(self) -> str: + """ + Gets the name of the destination. + + This is a Variable name when loaded. Should be set to HBM address to write back. + + Returns: + str: The name of the destination. + """ + return self.tokens[2] + + @dest.setter + def dest(self, value: str): + """ + Sets the name of the destination. + + Args: + value (str): The name of the destination to set. + """ + self.tokens[2] = value \ No newline at end of file diff --git a/assembler_tools/hec-assembler-tools/linker/instructions/minst/msyncc.py b/assembler_tools/hec-assembler-tools/linker/instructions/minst/msyncc.py new file mode 100644 index 00000000..4291d239 --- /dev/null +++ b/assembler_tools/hec-assembler-tools/linker/instructions/minst/msyncc.py @@ -0,0 +1,76 @@ +from .minstruction import MInstruction + +class Instruction(MInstruction): + """ + Encapsulates an `msyncc` MInstruction. + + Wait instruction similar to a barrier that stalls the execution of the MINST + queue until the specified instruction from the CINST queue has completed. + + For more information, check the specification: + https://github.com/IntelLabs/hec-assembler-tools/blob/master/docsrc/inst_spec/minst/minst_msyncc.md + + Properties: + target: Gets or sets the target CInst. + """ + + @classmethod + def _get_num_tokens(cls) -> int: + """ + Gets the number of tokens required for the instruction. + + The `msyncc` instruction requires 3 tokens: + , msyncc, + + Returns: + int: The number of tokens, which is 3. + """ + return 3 + + @classmethod + def _get_name(cls) -> str: + """ + Gets the name of the instruction. + + Returns: + str: The name of the instruction, which is "msyncc". + """ + return "msyncc" + + def __init__(self, tokens: list, comment: str = ""): + """ + Constructs a new `msyncc` MInstruction. + + Args: + tokens (list): A list of tokens representing the instruction. + comment (str, optional): An optional comment for the instruction. Defaults to an empty string. + + Raises: + ValueError: If the number of tokens is invalid or the instruction name is incorrect. + """ + super().__init__(tokens, comment=comment) + + @property + def target(self) -> int: + """ + Gets the target CInst. + + Returns: + int: The target CInst. + """ + return int(self.tokens[2]) + + @target.setter + def target(self, value: int): + """ + Sets the target CInst. + + Args: + value (int): The target CInst to set. + + Raises: + ValueError: If the value is negative. + """ + if value < 0: + raise ValueError(f'`value`: expected non-negative target, but {value} received.') + self.tokens[2] = str(value) \ No newline at end of file diff --git a/assembler_tools/hec-assembler-tools/linker/instructions/xinst/__init__.py b/assembler_tools/hec-assembler-tools/linker/instructions/xinst/__init__.py new file mode 100644 index 00000000..3a63845c --- /dev/null +++ b/assembler_tools/hec-assembler-tools/linker/instructions/xinst/__init__.py @@ -0,0 +1,46 @@ +from . import add, sub, mul, muli, mac, maci, ntt, intt, twntt, twintt, rshuffle, move, xstore, nop +from . import exit as exit_mod +#from . import copy as copy_mod + +# XInst aliases + +# XInsts with P-ISA equivalent +Add = add.Instruction +Sub = sub.Instruction +Mul = mul.Instruction +Muli = muli.Instruction +Mac = mac.Instruction +Maci = maci.Instruction +NTT = ntt.Instruction +iNTT = intt.Instruction +twNTT = twntt.Instruction +twiNTT = twintt.Instruction +rShuffle = rshuffle.Instruction +# All other XInsts +Move = move.Instruction +XStore = xstore.Instruction +Exit = exit_mod.Instruction +Nop = nop.Instruction + +def factory() -> set: + """ + Creates a set of all instruction classes. + + Returns: + set: A set containing all instruction classes. + """ + return { Add, + Sub, + Mul, + Muli, + Mac, + Maci, + NTT, + iNTT, + twNTT, + twiNTT, + rShuffle, + Move, + XStore, + Exit, + Nop } diff --git a/assembler_tools/hec-assembler-tools/linker/instructions/xinst/add.py b/assembler_tools/hec-assembler-tools/linker/instructions/xinst/add.py new file mode 100644 index 00000000..ac7194e8 --- /dev/null +++ b/assembler_tools/hec-assembler-tools/linker/instructions/xinst/add.py @@ -0,0 +1,48 @@ +from .xinstruction import XInstruction + +class Instruction(XInstruction): + """ + Encapsulates an `add` XInstruction. + + This instruction adds two polynomials stored in the register file and + stores the result in a register. + + For more information, check the specification: + https://github.com/IntelLabs/hec-assembler-tools/blob/master/docsrc/inst_spec/xinst/xinst_add.md + """ + + @classmethod + def _get_num_tokens(cls) -> int: + """ + Gets the number of tokens required for the instruction. + + The `add` instruction requires 7 tokens: + F, , add, , , , + + Returns: + int: The number of tokens, which is 7. + """ + return 7 + + @classmethod + def _get_name(cls) -> str: + """ + Gets the name of the instruction. + + Returns: + str: The name of the instruction, which is "add". + """ + return "add" + + def __init__(self, tokens: list, comment: str = ""): + """ + Constructs a new `add` XInstruction. + + Args: + tokens (list): A list of tokens representing the instruction. + comment (str, optional): An optional comment for the instruction. Defaults to an empty string. + + Raises: + ValueError: If the number of tokens is invalid or the instruction name is incorrect. + """ + super().__init__(tokens, comment=comment) \ No newline at end of file diff --git a/assembler_tools/hec-assembler-tools/linker/instructions/xinst/exit.py b/assembler_tools/hec-assembler-tools/linker/instructions/xinst/exit.py new file mode 100644 index 00000000..98d36ada --- /dev/null +++ b/assembler_tools/hec-assembler-tools/linker/instructions/xinst/exit.py @@ -0,0 +1,47 @@ +from .xinstruction import XInstruction + +class Instruction(XInstruction): + """ + Encapsulates an `bexit` XInstruction. + + This instruction terminates execution of an instruction bundle. + + For more information, check the specification: + https://github.com/IntelLabs/hec-assembler-tools/blob/master/docsrc/inst_spec/xinst/xinst_exit.md + """ + + @classmethod + def _get_num_tokens(cls) -> int: + """ + Gets the number of tokens required for the instruction. + + The `bexit` instruction requires 3 tokens: + F, , bexit + + Returns: + int: The number of tokens, which is 3. + """ + return 3 + + @classmethod + def _get_name(cls) -> str: + """ + Gets the name of the instruction. + + Returns: + str: The name of the instruction, which is "bexit". + """ + return "bexit" + + def __init__(self, tokens: list, comment: str = ""): + """ + Constructs a new `bexit` XInstruction. + + Args: + tokens (list): A list of tokens representing the instruction. + comment (str, optional): An optional comment for the instruction. Defaults to an empty string. + + Raises: + ValueError: If the number of tokens is invalid or the instruction name is incorrect. + """ + super().__init__(tokens, comment=comment) \ No newline at end of file diff --git a/assembler_tools/hec-assembler-tools/linker/instructions/xinst/intt.py b/assembler_tools/hec-assembler-tools/linker/instructions/xinst/intt.py new file mode 100644 index 00000000..c9240440 --- /dev/null +++ b/assembler_tools/hec-assembler-tools/linker/instructions/xinst/intt.py @@ -0,0 +1,47 @@ +from .xinstruction import XInstruction + +class Instruction(XInstruction): + """ + Encapsulates an `intt` XInstruction. + + The Inverse Number Theoretic Transform (iNTT) converts NTT form to positional form. + + For more information, check the specification: + https://github.com/IntelLabs/hec-assembler-tools/blob/master/docsrc/inst_spec/xinst/xinst_intt.md + """ + + @classmethod + def _get_num_tokens(cls) -> int: + """ + Gets the number of tokens required for the instruction. + + The `intt` instruction requires 10 tokens: + F, , intt, , , , , , , + + Returns: + int: The number of tokens, which is 10. + """ + return 10 + + @classmethod + def _get_name(cls) -> str: + """ + Gets the name of the instruction. + + Returns: + str: The name of the instruction, which is "intt". + """ + return "intt" + + def __init__(self, tokens: list, comment: str = ""): + """ + Constructs a new `intt` XInstruction. + + Args: + tokens (list): A list of tokens representing the instruction. + comment (str, optional): An optional comment for the instruction. Defaults to an empty string. + + Raises: + ValueError: If the number of tokens is invalid or the instruction name is incorrect. + """ + super().__init__(tokens, comment=comment) \ No newline at end of file diff --git a/assembler_tools/hec-assembler-tools/linker/instructions/xinst/mac.py b/assembler_tools/hec-assembler-tools/linker/instructions/xinst/mac.py new file mode 100644 index 00000000..d7e9800d --- /dev/null +++ b/assembler_tools/hec-assembler-tools/linker/instructions/xinst/mac.py @@ -0,0 +1,47 @@ +from .xinstruction import XInstruction + +class Instruction(XInstruction): + """ + Encapsulates a `mac` XInstruction. + + Element-wise polynomial multiplication and accumulation. + + For more information, check the specification: + https://github.com/IntelLabs/hec-assembler-tools/blob/master/docsrc/inst_spec/xinst/xinst_mac.md + """ + + @classmethod + def _get_num_tokens(cls) -> int: + """ + Gets the number of tokens required for the instruction. + + The `mac` instruction requires 8 tokens: + F, , mac, , , , , + + Returns: + int: The number of tokens, which is 8. + """ + return 8 + + @classmethod + def _get_name(cls) -> str: + """ + Gets the name of the instruction. + + Returns: + str: The name of the instruction, which is "mac". + """ + return "mac" + + def __init__(self, tokens: list, comment: str = ""): + """ + Constructs a new `mac` XInstruction. + + Args: + tokens (list): A list of tokens representing the instruction. + comment (str, optional): An optional comment for the instruction. Defaults to an empty string. + + Raises: + ValueError: If the number of tokens is invalid or the instruction name is incorrect. + """ + super().__init__(tokens, comment=comment) \ No newline at end of file diff --git a/assembler_tools/hec-assembler-tools/linker/instructions/xinst/maci.py b/assembler_tools/hec-assembler-tools/linker/instructions/xinst/maci.py new file mode 100644 index 00000000..a1703be8 --- /dev/null +++ b/assembler_tools/hec-assembler-tools/linker/instructions/xinst/maci.py @@ -0,0 +1,47 @@ +from .xinstruction import XInstruction + +class Instruction(XInstruction): + """ + Encapsulates a `maci` XInstruction. + + Element-wise polynomial scaling by an immediate value and accumulation. + + For more information, check the specification: + https://github.com/IntelLabs/hec-assembler-tools/blob/master/docsrc/inst_spec/xinst/xinst_maci.md + """ + + @classmethod + def _get_num_tokens(cls) -> int: + """ + Gets the number of tokens required for the instruction. + + The `maci` instruction requires 8 tokens: + F, , maci, , , , , + + Returns: + int: The number of tokens, which is 8. + """ + return 8 + + @classmethod + def _get_name(cls) -> str: + """ + Gets the name of the instruction. + + Returns: + str: The name of the instruction, which is "maci". + """ + return "maci" + + def __init__(self, tokens: list, comment: str = ""): + """ + Constructs a new `maci` XInstruction. + + Args: + tokens (list): A list of tokens representing the instruction. + comment (str, optional): An optional comment for the instruction. Defaults to an empty string. + + Raises: + ValueError: If the number of tokens is invalid or the instruction name is incorrect. + """ + super().__init__(tokens, comment=comment) \ No newline at end of file diff --git a/assembler_tools/hec-assembler-tools/linker/instructions/xinst/move.py b/assembler_tools/hec-assembler-tools/linker/instructions/xinst/move.py new file mode 100644 index 00000000..8ae2233a --- /dev/null +++ b/assembler_tools/hec-assembler-tools/linker/instructions/xinst/move.py @@ -0,0 +1,47 @@ +from .xinstruction import XInstruction + +class Instruction(XInstruction): + """ + Encapsulates a `move` XInstruction. + + This instruction copies data from one register to a different one. + + For more information, check the specification: + https://github.com/IntelLabs/hec-assembler-tools/blob/master/docsrc/inst_spec/xinst/xinst_move.md + """ + + @classmethod + def _get_num_tokens(cls) -> int: + """ + Gets the number of tokens required for the instruction. + + The `move` instruction requires 5 tokens: + F, , move, , + + Returns: + int: The number of tokens, which is 5. + """ + return 5 + + @classmethod + def _get_name(cls) -> str: + """ + Gets the name of the instruction. + + Returns: + str: The name of the instruction, which is "move". + """ + return "move" + + def __init__(self, tokens: list, comment: str = ""): + """ + Constructs a new `move` XInstruction. + + Args: + tokens (list): A list of tokens representing the instruction. + comment (str, optional): An optional comment for the instruction. Defaults to an empty string. + + Raises: + ValueError: If the number of tokens is invalid or the instruction name is incorrect. + """ + super().__init__(tokens, comment=comment) \ No newline at end of file diff --git a/assembler_tools/hec-assembler-tools/linker/instructions/xinst/mul.py b/assembler_tools/hec-assembler-tools/linker/instructions/xinst/mul.py new file mode 100644 index 00000000..40bbb0c6 --- /dev/null +++ b/assembler_tools/hec-assembler-tools/linker/instructions/xinst/mul.py @@ -0,0 +1,47 @@ +from .xinstruction import XInstruction + +class Instruction(XInstruction): + """ + Encapsulates a `mul` XInstruction. + + This instruction performs element-wise polynomial multiplication. + + For more information, check the specification: + https://github.com/IntelLabs/hec-assembler-tools/blob/master/docsrc/inst_spec/xinst/xinst_mul.md + """ + + @classmethod + def _get_num_tokens(cls) -> int: + """ + Gets the number of tokens required for the instruction. + + The `mul` instruction requires 7 tokens: + F, , mul, , , , + + Returns: + int: The number of tokens, which is 7. + """ + return 7 + + @classmethod + def _get_name(cls) -> str: + """ + Gets the name of the instruction. + + Returns: + str: The name of the instruction, which is "mul". + """ + return "mul" + + def __init__(self, tokens: list, comment: str = ""): + """ + Constructs a new `mul` XInstruction. + + Args: + tokens (list): A list of tokens representing the instruction. + comment (str, optional): An optional comment for the instruction. Defaults to an empty string. + + Raises: + ValueError: If the number of tokens is invalid or the instruction name is incorrect. + """ + super().__init__(tokens, comment=comment) \ No newline at end of file diff --git a/assembler_tools/hec-assembler-tools/linker/instructions/xinst/muli.py b/assembler_tools/hec-assembler-tools/linker/instructions/xinst/muli.py new file mode 100644 index 00000000..3590a6c6 --- /dev/null +++ b/assembler_tools/hec-assembler-tools/linker/instructions/xinst/muli.py @@ -0,0 +1,47 @@ +from .xinstruction import XInstruction + +class Instruction(XInstruction): + """ + Encapsulates a `muli` XInstruction. + + This instruction performs element-wise polynomial scaling by an immediate value. + + For more information, check the specification: + https://github.com/IntelLabs/hec-assembler-tools/blob/master/docsrc/inst_spec/xinst/xinst_muli.md + """ + + @classmethod + def _get_num_tokens(cls) -> int: + """ + Gets the number of tokens required for the instruction. + + The `muli` instruction requires 7 tokens: + F, , muli, , , , + + Returns: + int: The number of tokens, which is 7. + """ + return 7 + + @classmethod + def _get_name(cls) -> str: + """ + Gets the name of the instruction. + + Returns: + str: The name of the instruction, which is "muli". + """ + return "muli" + + def __init__(self, tokens: list, comment: str = ""): + """ + Constructs a new `muli` XInstruction. + + Args: + tokens (list): A list of tokens representing the instruction. + comment (str, optional): An optional comment for the instruction. Defaults to an empty string. + + Raises: + ValueError: If the number of tokens is invalid or the instruction name is incorrect. + """ + super().__init__(tokens, comment=comment) \ No newline at end of file diff --git a/assembler_tools/hec-assembler-tools/linker/instructions/xinst/nop.py b/assembler_tools/hec-assembler-tools/linker/instructions/xinst/nop.py new file mode 100644 index 00000000..f6b3f513 --- /dev/null +++ b/assembler_tools/hec-assembler-tools/linker/instructions/xinst/nop.py @@ -0,0 +1,47 @@ +from .xinstruction import XInstruction + +class Instruction(XInstruction): + """ + Encapsulates a `nop` XInstruction. + + This instruction adds a desired amount of idle cycles to the compute flow. + + For more information, check the specification: + https://github.com/IntelLabs/hec-assembler-tools/blob/master/docsrc/inst_spec/xinst/xinst_nop.md + """ + + @classmethod + def _get_num_tokens(cls) -> int: + """ + Gets the number of tokens required for the instruction. + + The `nop` instruction requires 4 tokens: + F, , nop, + + Returns: + int: The number of tokens, which is 4. + """ + return 4 + + @classmethod + def _get_name(cls) -> str: + """ + Gets the name of the instruction. + + Returns: + str: The name of the instruction, which is "nop". + """ + return "nop" + + def __init__(self, tokens: list, comment: str = ""): + """ + Constructs a new `nop` XInstruction. + + Args: + tokens (list): A list of tokens representing the instruction. + comment (str, optional): An optional comment for the instruction. Defaults to an empty string. + + Raises: + ValueError: If the number of tokens is invalid or the instruction name is incorrect. + """ + super().__init__(tokens, comment=comment) \ No newline at end of file diff --git a/assembler_tools/hec-assembler-tools/linker/instructions/xinst/ntt.py b/assembler_tools/hec-assembler-tools/linker/instructions/xinst/ntt.py new file mode 100644 index 00000000..ead438ea --- /dev/null +++ b/assembler_tools/hec-assembler-tools/linker/instructions/xinst/ntt.py @@ -0,0 +1,46 @@ +from .xinstruction import XInstruction + +class Instruction(XInstruction): + """ + Encapsulates an `ntt` XInstruction (Number Theoretic Transform). + Converts positional form to NTT form. + + For more information, check the specification: + https://github.com/IntelLabs/hec-assembler-tools/blob/master/docsrc/inst_spec/xinst/xinst_ntt.md + """ + + @classmethod + def _get_num_tokens(cls) -> int: + """ + Gets the number of tokens required for the instruction. + + The `ntt` instruction requires 10 tokens: + F, , ntt, , , , , , , + + Returns: + int: The number of tokens, which is 10. + """ + return 10 + + @classmethod + def _get_name(cls) -> str: + """ + Gets the name of the instruction. + + Returns: + str: The name of the instruction, which is "ntt". + """ + return "ntt" + + def __init__(self, tokens: list, comment: str = ""): + """ + Constructs a new `ntt` XInstruction. + + Args: + tokens (list): A list of tokens representing the instruction. + comment (str, optional): An optional comment for the instruction. Defaults to an empty string. + + Raises: + ValueError: If the number of tokens is invalid or the instruction name is incorrect. + """ + super().__init__(tokens, comment=comment) \ No newline at end of file diff --git a/assembler_tools/hec-assembler-tools/linker/instructions/xinst/rshuffle.py b/assembler_tools/hec-assembler-tools/linker/instructions/xinst/rshuffle.py new file mode 100644 index 00000000..daf0d754 --- /dev/null +++ b/assembler_tools/hec-assembler-tools/linker/instructions/xinst/rshuffle.py @@ -0,0 +1,42 @@ +from .xinstruction import XInstruction + +class Instruction(XInstruction): + """ + Encapsulates an `rshuffle` XInstruction. + """ + + @classmethod + def _get_num_tokens(cls) -> int: + """ + Gets the number of tokens required for the instruction. + + The `rshuffle` instruction requires 9 tokens: + F, , rshuffle, , , , , , + + Returns: + int: The number of tokens, which is 9. + """ + return 9 + + @classmethod + def _get_name(cls) -> str: + """ + Gets the name of the instruction. + + Returns: + str: The name of the instruction, which is "rshuffle". + """ + return "rshuffle" + + def __init__(self, tokens: list, comment: str = ""): + """ + Constructs a new `rshuffle` XInstruction. + + Args: + tokens (list): A list of tokens representing the instruction. + comment (str, optional): An optional comment for the instruction. Defaults to an empty string. + + Raises: + ValueError: If the number of tokens is invalid or the instruction name is incorrect. + """ + super().__init__(tokens, comment=comment) \ No newline at end of file diff --git a/assembler_tools/hec-assembler-tools/linker/instructions/xinst/sub.py b/assembler_tools/hec-assembler-tools/linker/instructions/xinst/sub.py new file mode 100644 index 00000000..f4cecb19 --- /dev/null +++ b/assembler_tools/hec-assembler-tools/linker/instructions/xinst/sub.py @@ -0,0 +1,47 @@ +from .xinstruction import XInstruction + +class Instruction(XInstruction): + """ + Encapsulates a `sub` XInstruction. + + This instruction performs element-wise polynomial subtraction. + + For more information, check the specification: + https://github.com/IntelLabs/hec-assembler-tools/blob/master/docsrc/inst_spec/xinst/xinst_sub.md + """ + + @classmethod + def _get_num_tokens(cls) -> int: + """ + Gets the number of tokens required for the instruction. + + The `sub` instruction requires 7 tokens: + F, , sub, , , , + + Returns: + int: The number of tokens, which is 7. + """ + return 7 + + @classmethod + def _get_name(cls) -> str: + """ + Gets the name of the instruction. + + Returns: + str: The name of the instruction, which is "sub". + """ + return "sub" + + def __init__(self, tokens: list, comment: str = ""): + """ + Constructs a new `sub` XInstruction. + + Args: + tokens (list): A list of tokens representing the instruction. + comment (str, optional): An optional comment for the instruction. Defaults to an empty string. + + Raises: + ValueError: If the number of tokens is invalid or the instruction name is incorrect. + """ + super().__init__(tokens, comment=comment) \ No newline at end of file diff --git a/assembler_tools/hec-assembler-tools/linker/instructions/xinst/twintt.py b/assembler_tools/hec-assembler-tools/linker/instructions/xinst/twintt.py new file mode 100644 index 00000000..5ea3b62c --- /dev/null +++ b/assembler_tools/hec-assembler-tools/linker/instructions/xinst/twintt.py @@ -0,0 +1,47 @@ +from .xinstruction import XInstruction + +class Instruction(XInstruction): + """ + Encapsulates a `twintt` XInstruction. + + This instruction performs on-die generation of twiddle factors for the next stage of iNTT. + + For more information, check the specification: + https://github.com/IntelLabs/hec-assembler-tools/blob/master/docsrc/inst_spec/xinst/xinst_twintt.md. + """ + + @classmethod + def _get_num_tokens(cls) -> int: + """ + Gets the number of tokens required for the instruction. + + The `twintt` instruction requires 10 tokens: + F, , twintt, , , , , , , + + Returns: + int: The number of tokens, which is 10. + """ + return 10 + + @classmethod + def _get_name(cls) -> str: + """ + Gets the name of the instruction. + + Returns: + str: The name of the instruction, which is "twintt". + """ + return "twintt" + + def __init__(self, tokens: list, comment: str = ""): + """ + Constructs a new `twintt` XInstruction. + + Args: + tokens (list): A list of tokens representing the instruction. + comment (str, optional): An optional comment for the instruction. Defaults to an empty string. + + Raises: + ValueError: If the number of tokens is invalid or the instruction name is incorrect. + """ + super().__init__(tokens, comment=comment) \ No newline at end of file diff --git a/assembler_tools/hec-assembler-tools/linker/instructions/xinst/twntt.py b/assembler_tools/hec-assembler-tools/linker/instructions/xinst/twntt.py new file mode 100644 index 00000000..1e1ae65b --- /dev/null +++ b/assembler_tools/hec-assembler-tools/linker/instructions/xinst/twntt.py @@ -0,0 +1,47 @@ +from .xinstruction import XInstruction + +class Instruction(XInstruction): + """ + Encapsulates a `twntt` XInstruction. + + This instruction performs on-die generation of twiddle factors for the next stage of NTT. + + For more information, check the specification: + https://github.com/IntelLabs/hec-assembler-tools/blob/master/docsrc/inst_spec/xinst/xinst_twntt.md + """ + + @classmethod + def _get_num_tokens(cls) -> int: + """ + Gets the number of tokens required for the instruction. + + The `twntt` instruction requires 10 tokens: + F, , twntt, , , , , , , + + Returns: + int: The number of tokens, which is 10. + """ + return 10 + + @classmethod + def _get_name(cls) -> str: + """ + Gets the name of the instruction. + + Returns: + str: The name of the instruction, which is "twntt". + """ + return "twntt" + + def __init__(self, tokens: list, comment: str = ""): + """ + Constructs a new `twntt` XInstruction. + + Args: + tokens (list): A list of tokens representing the instruction. + comment (str, optional): An optional comment for the instruction. Defaults to an empty string. + + Raises: + ValueError: If the number of tokens is invalid or the instruction name is incorrect. + """ + super().__init__(tokens, comment=comment) \ No newline at end of file diff --git a/assembler_tools/hec-assembler-tools/linker/instructions/xinst/xinstruction.py b/assembler_tools/hec-assembler-tools/linker/instructions/xinst/xinstruction.py new file mode 100644 index 00000000..3a622539 --- /dev/null +++ b/assembler_tools/hec-assembler-tools/linker/instructions/xinst/xinstruction.py @@ -0,0 +1,64 @@ + +from linker.instructions.instruction import BaseInstruction + +class XInstruction(BaseInstruction): + """ + Represents an XInstruction, inheriting from BaseInstruction. + """ + + @classmethod + def _get_name_token_index(cls) -> int: + """ + Gets the index of the token containing the name of the instruction. + + Returns: + int: The index of the name token, which is 2. + """ + # Name at index 2. + return 2 + + # Constructor + # ----------- + + def __init__(self, tokens: list, comment: str = ""): + """ + Constructs a new XInstruction. + + Parameters: + tokens (list): List of tokens for the instruction. + comment (str): Optional comment for the instruction. + + Raises: + ValueError: If the number of tokens is invalid or the instruction name is incorrect. + """ + super().__init__(tokens, comment=comment) + + @property + def bundle(self) -> int: + """ + Gets the bundle index. + + Returns: + int: The bundle index. + + Raises: + RuntimeError: If the bundle format is invalid. + """ + if len(self.tokens[0]) < 2 or self.tokens[0][0] != 'F': + raise RuntimeError(f'Invalid bundle format detected: "{self.tokens[0]}".') + return int(self.tokens[0][1:]) + + @bundle.setter + def bundle(self, value: int): + """ + Sets the bundle index. + + Parameters: + value (int): The new bundle index. + + Raises: + ValueError: If the value is negative. + """ + if value < 0: + raise ValueError(f'`value`: expected non-negative bundle index, but {value} received.') + self.tokens[0] = f'F{value}' \ No newline at end of file diff --git a/assembler_tools/hec-assembler-tools/linker/instructions/xinst/xstore.py b/assembler_tools/hec-assembler-tools/linker/instructions/xinst/xstore.py new file mode 100644 index 00000000..8ed88dbb --- /dev/null +++ b/assembler_tools/hec-assembler-tools/linker/instructions/xinst/xstore.py @@ -0,0 +1,47 @@ +from .xinstruction import XInstruction + +class Instruction(XInstruction): + """ + Encapsulates an `xstore` XInstruction. + + This instruction transfers data from a register into the intermediate data buffer for subsequent transfer into SPAD. + + For more information, check the specification: + https://github.com/IntelLabs/hec-assembler-tools/blob/master/docsrc/inst_spec/xinst/xinst_xstore.md + """ + + @classmethod + def _get_num_tokens(cls) -> int: + """ + Gets the number of tokens required for the instruction. + + The `xstore` instruction requires 4 tokens: + F, , xstore, + + Returns: + int: The number of tokens, which is 4. + """ + return 4 + + @classmethod + def _get_name(cls) -> str: + """ + Gets the name of the instruction. + + Returns: + str: The name of the instruction, which is "xstore". + """ + return "xstore" + + def __init__(self, tokens: list, comment: str = ""): + """ + Constructs a new `xstore` XInstruction. + + Args: + tokens (list): A list of tokens representing the instruction. + comment (str, optional): An optional comment for the instruction. Defaults to an empty string. + + Raises: + ValueError: If the number of tokens is invalid or the instruction name is incorrect. + """ + super().__init__(tokens, comment=comment) \ No newline at end of file diff --git a/assembler_tools/hec-assembler-tools/linker/loader.py b/assembler_tools/hec-assembler-tools/linker/loader.py new file mode 100644 index 00000000..26894912 --- /dev/null +++ b/assembler_tools/hec-assembler-tools/linker/loader.py @@ -0,0 +1,124 @@ +from linker.instructions import minst +from linker.instructions import cinst +from linker.instructions import xinst +from linker import instructions + +def loadMInstKernel(line_iter) -> list: + """ + Loads MInstruction kernel from an iterator of lines. + + Parameters: + line_iter: An iterator over lines of MInstruction strings. + + Returns: + list: A list of MInstruction objects. + + Raises: + RuntimeError: If a line cannot be parsed into an MInstruction. + """ + retval = [] + for idx, s_line in enumerate(line_iter): + minstr = instructions.fromStrLine(s_line, minst.factory()) + if not minstr: + raise RuntimeError(f'Error parsing line {idx + 1}: {s_line}') + retval.append(minstr) + return retval + +def loadMInstKernelFromFile(filename: str) -> list: + """ + Loads MInstruction kernel from a file. + + Parameters: + filename (str): The file containing MInstruction strings. + + Returns: + list: A list of MInstruction objects. + + Raises: + RuntimeError: If an error occurs while loading the file. + """ + with open(filename, 'r') as kernel_minsts: + try: + return loadMInstKernel(kernel_minsts) + except Exception as e: + raise RuntimeError(f'Error occurred loading file "{filename}"') from e + +def loadCInstKernel(line_iter) -> list: + """ + Loads CInstruction kernel from an iterator of lines. + + Parameters: + line_iter: An iterator over lines of CInstruction strings. + + Returns: + list: A list of CInstruction objects. + + Raises: + RuntimeError: If a line cannot be parsed into a CInstruction. + """ + retval = [] + for idx, s_line in enumerate(line_iter): + cinstr = instructions.fromStrLine(s_line, cinst.factory()) + if not cinstr: + raise RuntimeError(f'Error parsing line {idx + 1}: {s_line}') + retval.append(cinstr) + return retval + +def loadCInstKernelFromFile(filename: str) -> list: + """ + Loads CInstruction kernel from a file. + + Parameters: + filename (str): The file containing CInstruction strings. + + Returns: + list: A list of CInstruction objects. + + Raises: + RuntimeError: If an error occurs while loading the file. + """ + with open(filename, 'r') as kernel_cinsts: + try: + return loadCInstKernel(kernel_cinsts) + except Exception as e: + raise RuntimeError(f'Error occurred loading file "{filename}"') from e + +def loadXInstKernel(line_iter) -> list: + """ + Loads XInstruction kernel from an iterator of lines. + + Parameters: + line_iter: An iterator over lines of XInstruction strings. + + Returns: + list: A list of XInstruction objects. + + Raises: + RuntimeError: If a line cannot be parsed into an XInstruction. + """ + retval = [] + for idx, s_line in enumerate(line_iter): + xinstr = instructions.fromStrLine(s_line, xinst.factory()) + if not xinstr: + raise RuntimeError(f'Error parsing line {idx + 1}: {s_line}') + retval.append(xinstr) + return retval + +def loadXInstKernelFromFile(filename: str) -> list: + """ + Loads XInstruction kernel from a file. + + Parameters: + filename (str): The file containing XInstruction strings. + + Returns: + list: A list of XInstruction objects. + + Raises: + RuntimeError: If an error occurs while loading the file. + """ + with open(filename, 'r') as kernel_xinsts: + try: + return loadXInstKernel(kernel_xinsts) + except Exception as e: + raise RuntimeError(f'Error occurred loading file "{filename}"') from e \ No newline at end of file diff --git a/assembler_tools/hec-assembler-tools/linker/steps/__init__.py b/assembler_tools/hec-assembler-tools/linker/steps/__init__.py new file mode 100644 index 00000000..e02abfc9 --- /dev/null +++ b/assembler_tools/hec-assembler-tools/linker/steps/__init__.py @@ -0,0 +1 @@ + diff --git a/assembler_tools/hec-assembler-tools/linker/steps/program_linker.py b/assembler_tools/hec-assembler-tools/linker/steps/program_linker.py new file mode 100644 index 00000000..d1377e39 --- /dev/null +++ b/assembler_tools/hec-assembler-tools/linker/steps/program_linker.py @@ -0,0 +1,352 @@ + +from linker import MemoryModel +from linker.instructions import minst, cinst, xinst +from assembler.common.config import GlobalConfig +from assembler.isa_spec import cinst as ISACInst + +class LinkedProgram: + """ + Encapsulates a linked program. + + This class offers facilities to track and link kernels, and + outputs the linked program to specified output streams as kernels + are linked. + + The program itself is not contained in this object. + """ + + def __init__(self, + program_minst_ostream, + program_cinst_ostream, + program_xinst_ostream, + mem_model: MemoryModel, + supress_comments: bool): + """ + Initializes a LinkedProgram object. + + Parameters: + program_minst_ostream: Output stream for MInst instructions. + program_cinst_ostream: Output stream for CInst instructions. + program_xinst_ostream: Output stream for XInst instructions. + mem_model (MemoryModel): Correctly initialized linker memory model. It must already contain the + variables used throughout the program and their usage. + This memory model will be modified by this object when linking kernels. + supress_comments (bool): Whether to suppress comments in the output. + """ + self.__minst_ostream = program_minst_ostream + self.__cinst_ostream = program_cinst_ostream + self.__xinst_ostream = program_xinst_ostream + self.__mem_model = mem_model + self.__supress_comments = supress_comments + self.__bundle_offset = 0 + self.__minst_line_offset = 0 + self.__cinst_line_offset = 0 + self.__xinst_line_offset = 0 + self.__kernel_count = 0 # Number of kernels linked into this program + self.__is_open = True # Tracks whether this program is still accepting kernels to link + + @property + def isOpen(self) -> bool: + """ + Checks if the program is open for linking new kernels. + + Returns: + bool: True if the program is open, False otherwise. + """ + return self.__is_open + + @property + def supressComments(self) -> bool: + """ + Checks if comments are suppressed in the output. + + Returns: + bool: True if comments are suppressed, False otherwise. + """ + return self.__supress_comments + + def close(self): + """ + Completes the program by terminating the queues with the correct exit code. + + Program will not accept new kernels to link after this call. + + Raises: + RuntimeError: If the program is already closed. + """ + if not self.isOpen: + raise RuntimeError('Program is already closed.') + + # Add closing `cexit` + tokens = [str(self.__cinst_line_offset), cinst.CExit.name] + cexit_cinstr = cinst.CExit(tokens) + print(f'{cexit_cinstr.tokens[0]}, {cexit_cinstr.to_line()}', file=self.__cinst_ostream) + + # Add closing msyncc + tokens = [str(self.__minst_line_offset), minst.MSyncc.name, str(self.__cinst_line_offset + 1)] + cmsyncc_minstr = minst.MSyncc(tokens) + print(f'{cmsyncc_minstr.tokens[0]}, {cmsyncc_minstr.to_line()}', end="", file=self.__minst_ostream) + if not self.supressComments: + print(' # terminating MInstQ', end="", file=self.__minst_ostream) + print(file=self.__minst_ostream) + + # Program has been closed + self.__is_open = False + + def __validateHBMAddress(self, var_name: str, hbm_address: int): + """ + Validates the HBM address for a variable. + + Parameters: + var_name (str): The name of the variable. + hbm_address (int): The HBM address to validate. + + Raises: + RuntimeError: If the HBM address is invalid or does not match the declared address. + """ + if hbm_address < 0: + raise RuntimeError(f'Invalid negative HBM address for variable "{var_name}".') + if var_name in self.__mem_model.mem_info_vars: + if self.__mem_model.mem_info_vars[var_name].hbm_address != hbm_address: + raise RuntimeError(('Declared HBM address ({}) of mem Variable "{}"' + ' differs from allocated HBM address ({}).').format(self.__mem_model.mem_info_vars[var_name].hbm_address, + var_name, + hbm_address)) + + def __validateSPADAddress(self, var_name: str, spad_address: int): + # only available when no HBM + assert not GlobalConfig.hasHBM + + # this method will validate the variable SPAD address against the + # original HBM address, since ther is no HBM + if spad_address < 0: + raise RuntimeError(f'Invalid negative SPAD address for variable "{var_name}".') + if var_name in self.__mem_model.mem_info_vars: + if self.__mem_model.mem_info_vars[var_name].hbm_address != spad_address: + raise RuntimeError(('Declared HBM address ({}) of mem Variable "{}"' + ' differs from allocated HBM address ({}).').format(self.__mem_model.mem_info_vars[var_name].hbm_address, + var_name, + spad_address)) + + def __updateMInsts(self, kernel_minstrs: list): + """ + Updates the MInsts in the kernel to offset to the current expected + synchronization points, and convert variable placeholders/names into + the corresponding HBM address. + + All MInsts in the kernel are expected to synchronize with CInsts starting at line 0. + Does not change the `LinkedProgram` object. + + Parameters: + kernel_minstrs (list): List of MInstructions to update. + """ + for minstr in kernel_minstrs: + # Update msyncc + if isinstance(minstr, minst.MSyncc): + minstr.target = minstr.target + self.__cinst_line_offset + # Change mload variable names into HBM addresses + if isinstance(minstr, minst.MLoad): + var_name = minstr.source + hbm_address = self.__mem_model.useVariable(var_name, self.__kernel_count) + self.__validateHBMAddress(var_name, hbm_address) + minstr.source = str(hbm_address) + minstr.comment = f" var: {var_name} - HBM({hbm_address})" + f";{minstr.comment}" if minstr.comment else "" + # Change mstore variable names into HBM addresses + if isinstance(minstr, minst.MStore): + var_name = minstr.dest + hbm_address = self.__mem_model.useVariable(var_name, self.__kernel_count) + self.__validateHBMAddress(var_name, hbm_address) + minstr.dest = str(hbm_address) + minstr.comment = f" var: {var_name} - HBM({hbm_address})" + f";{minstr.comment}" if minstr.comment else "" + + def __updateCInsts(self, kernel_cinstrs: list): + """ + Updates the CInsts in the kernel to offset to the current expected bundle + and synchronization points. + + All CInsts in the kernel are expected to start at bundle 0, and to + synchronize with MInsts starting at line 0. + Does not change the `LinkedProgram` object. + + Parameters: + kernel_cinstrs (list): List of CInstructions to update. + """ + + if not GlobalConfig.hasHBM: + # Remove csyncm instructions + i = 0 + current_bundle = 0 + csyncm_count = 0 # Used by 1st code block: plz remove if second code block ends up being the one used + while i < len(kernel_cinstrs): + cinstr = kernel_cinstrs[i] + cinstr.tokens[0] = i # Update the line number + + #------------------------------ + # This code block will remove csyncm instructions and keep track, + # later adding their throughput into a cnop instruction before + # a new bundle is fetched. + + if isinstance(cinstr, cinst.CNop): + # Add the missing cycles to any cnop we encounter up to this point + cinstr.cycles += (csyncm_count * ISACInst.CSyncM.Throughput) + csyncm_count = 0 # Idle cycles to account for the csyncm have been added + + if isinstance(cinstr, (cinst.IFetch, cinst.NLoad, cinst.BLoad)): + if csyncm_count > 0: + # Extra cycles needed before scheduling next bundle + cinstr_nop = cinst.CNop([i, cinst.CNop.name, str(csyncm_count * ISACInst.CSyncM.Throughput - 1)]) # Subtract 1 because cnop n, waits for n+1 cycles + kernel_cinstrs.insert(i, cinstr_nop) + csyncm_count = 0 # Idle cycles to account for the csyncm have been added + i += 1 + if isinstance(cinstr, cinst.IFetch): + current_bundle = cinstr.bundle + 1 + cinstr.tokens[0] = i # Update the line number + + if isinstance(cinstr, cinst.CSyncm): + # Remove instruction + kernel_cinstrs.pop(i) + if current_bundle > 0: + csyncm_count += 1 + else: + i += 1 # Next instruction + + #------------------------------ + # This code block differs from previous in that csyncm instructions + # are replaced in place by cnops with the corresponding throughput. + # This may result in several continuous cnop instructions, so, + # the cnop merging code afterwards is needed to remove this side effect + # if contiguous cnops are not desired. + + # if isinstance(cinstr, cinst.IFetch): + # current_bundle = cinstr.bundle + 1 + # + # if isinstance(cinstr, cinst.CSyncm): + # # replace instruction by cnop + # kernel_cinstrs.pop(i) + # if current_bundle > 0: + # cinstr_nop = cinst.CNop([i, cinst.CNop.name, str(ISACInst.CSyncM.Throughput)]) # Subtract 1 because cnop n, waits for n+1 cycles + # kernel_cinstrs.insert(i, cinstr_nop) + # + # i += 1 # next instruction + + # Merge continuous cnop + i = 0 + while i < len(kernel_cinstrs): + cinstr = kernel_cinstrs[i] + cinstr.tokens[0] = i # Update the line number + + if isinstance(cinstr, cinst.CNop): + # Do look ahead + if i + 1 < len(kernel_cinstrs): + if isinstance(kernel_cinstrs[i + 1], cinst.CNop): + kernel_cinstrs[i + 1].cycles += (cinstr.cycles + 1) # Add 1 because cnop n, waits for n+1 cycles + kernel_cinstrs.pop(i) + i -= 1 + i += 1 + + for cinstr in kernel_cinstrs: + # Update ifetch + if isinstance(cinstr, cinst.IFetch): + cinstr.bundle = cinstr.bundle + self.__bundle_offset + # Update xinstfetch + if isinstance(cinstr, cinst.XInstFetch): + raise NotImplementedError('`xinstfetch` not currently supported by linker.') + # Update csyncm + if isinstance(cinstr, cinst.CSyncm): + cinstr.target = cinstr.target + self.__minst_line_offset + + if not GlobalConfig.hasHBM: + # update all SPAD instruction variable names to be SPAD addresses + # change xload variable names into SPAD addresses + if isinstance(cinstr, (cinst.BLoad, cinst.BOnes, cinst.CLoad, cinst.NLoad)): + var_name = cinstr.source + hbm_address = self.__mem_model.useVariable(var_name, self.__kernel_count) + self.__validateSPADAddress(var_name, hbm_address) + cinstr.source = str(hbm_address) + cinstr.comment = f" var: {var_name} - HBM({hbm_address})" + f";{cinstr.comment}" if cinstr.comment else "" + if isinstance(cinstr, cinst.CStore): + var_name = cinstr.dest + hbm_address = self.__mem_model.useVariable(var_name, self.__kernel_count) + self.__validateSPADAddress(var_name, hbm_address) + cinstr.dest = str(hbm_address) + cinstr.comment = f" var: {var_name} - HBM({hbm_address})" + f";{cinstr.comment}" if cinstr.comment else "" + + def __updateXInsts(self, kernel_xinstrs: list) -> int: + """ + Updates the XInsts in the kernel to offset to the current expected bundle. + + All XInsts in the kernel are expected to start at bundle 0. + Does not change the `LinkedProgram` object. + + Parameters: + kernel_xinstrs (list): List of XInstructions to update. + + Returns: + int: The last bundle number after updating. + """ + last_bundle = self.__bundle_offset + for xinstr in kernel_xinstrs: + xinstr.bundle = xinstr.bundle + self.__bundle_offset + if last_bundle > xinstr.bundle: + raise RuntimeError(f'Detected invalid bundle. Instruction bundle is less than previous: "{xinstr.to_line()}"') + last_bundle = xinstr.bundle + return last_bundle + + def linkKernel(self, + kernel_minstrs: list, + kernel_cinstrs: list, + kernel_xinstrs: list): + """ + Links a specified kernel (given by its three instruction queues) into this + program. + + The adjusted kernels will be appended into the output streams specified during + construction of this object. + + Parameters: + kernel_minstrs (list): List of MInstructions for the MInst Queue corresponding to the kernel to link. + These instructions will be modified by this method. + kernel_cinstrs (list): List of CInstructions for the CInst Queue corresponding to the kernel to link. + These instructions will be modified by this method. + kernel_xinstrs (list): List of XInstructions for the XInst Queue corresponding to the kernel to link. + These instructions will be modified by this method. + + Raises: + RuntimeError: If the program is closed and does not accept new kernels. + """ + if not self.isOpen: + raise RuntimeError('Program is closed and does not accept new kernels.') + + # No minsts without HBM + if not GlobalConfig.hasHBM: + kernel_minstrs = [] + + self.__updateMInsts(kernel_minstrs) + self.__updateCInsts(kernel_cinstrs) + self.__bundle_offset = self.__updateXInsts(kernel_xinstrs) + 1 + + # Append the kernel to the output + + for xinstr in kernel_xinstrs: + print(xinstr.to_line(), end="", file=self.__xinst_ostream) + if not self.supressComments and xinstr.comment: + print(f' #{xinstr.comment}', end="", file=self.__xinst_ostream) + print(file=self.__xinst_ostream) + + for idx, cinstr in enumerate(kernel_cinstrs[:-1]): # Skip the `cexit` + line_no = idx + self.__cinst_line_offset + print(f'{line_no}, {cinstr.to_line()}', end="", file=self.__cinst_ostream) + if not self.supressComments and cinstr.comment: + print(f' #{cinstr.comment}', end="", file=self.__cinst_ostream) + print(file=self.__cinst_ostream) + + for idx, minstr in enumerate(kernel_minstrs[:-1]): # Skip the exit `msyncc` + line_no = idx + self.__minst_line_offset + print(f'{line_no}, {minstr.to_line()}', end="", file=self.__minst_ostream) + if not self.supressComments and minstr.comment: + print(f' #{minstr.comment}', end="", file=self.__minst_ostream) + print(file=self.__minst_ostream) + + self.__minst_line_offset += (len(kernel_minstrs) - 1) # Subtract last line that is getting removed + self.__cinst_line_offset += (len(kernel_cinstrs) - 1) # Subtract last line that is getting removed + self.__kernel_count += 1 # Count the appended kernel \ No newline at end of file diff --git a/assembler_tools/hec-assembler-tools/linker/steps/variable_discovery.py b/assembler_tools/hec-assembler-tools/linker/steps/variable_discovery.py new file mode 100644 index 00000000..69457786 --- /dev/null +++ b/assembler_tools/hec-assembler-tools/linker/steps/variable_discovery.py @@ -0,0 +1,65 @@ +from assembler.memory_model.variable import Variable +from linker.instructions import minst, cinst +from linker.instructions.minst.minstruction import MInstruction +from linker.instructions.cinst.cinstruction import CInstruction + +def discoverVariablesSPAD(cinstrs: list): + """ + Finds Variable names used in a list of CInstructions. + + Attributes: + cinstrs (list[CInstruction]): + List of CInstructions where to find variable names. + Raises: + RuntimeError: + Invalid Variable name detected in an CInstruction. + Returns: + Iterable: + Yields an iterable over variable names identified in the listing + of CInstructions specified. + """ + for idx, cinstr in enumerate(cinstrs): + if not isinstance(cinstr, CInstruction): + raise TypeError(f'Item {idx} in list of MInstructions is not a valid MInstruction.') + retval = None + # TODO: Implement variable counting for CInst + ############### + # Raise NotImplementedError("Implement variable counting for CInst") + if isinstance(cinstr, (cinst.BLoad, cinst.CLoad, cinst.BOnes, cinst.NLoad)): + retval = cinstr.source + elif isinstance(cinstr, cinst.CStore): + retval = cinstr.dest + + if retval is not None: + if not Variable.validateName(retval): + raise RuntimeError(f'Invalid Variable name "{retval}" detected in instruction "{idx}, {cinstr.to_line()}"') + yield retval + +def discoverVariables(minstrs: list): + """ + Finds variable names used in a list of MInstructions. + + Parameters: + minstrs (list[MInstruction]): List of MInstructions where to find variable names. + + Raises: + TypeError: If an item in the list is not a valid MInstruction. + RuntimeError: If an invalid variable name is detected in an MInstruction. + + Returns: + Iterable: Yields an iterable over variable names identified in the listing + of MInstructions specified. + """ + for idx, minstr in enumerate(minstrs): + if not isinstance(minstr, MInstruction): + raise TypeError(f'Item {idx} in list of MInstructions is not a valid MInstruction.') + retval = None + if isinstance(minstr, minst.MLoad): + retval = minstr.source + elif isinstance(minstr, minst.MStore): + retval = minstr.dest + + if retval is not None: + if not Variable.validateName(retval): + raise RuntimeError(f'Invalid Variable name "{retval}" detected in instruction "{idx}, {minstr.to_line()}"') + yield retval \ No newline at end of file diff --git a/assembler_tools/hec-assembler-tools/requirements.txt b/assembler_tools/hec-assembler-tools/requirements.txt new file mode 100644 index 00000000..8105d54f --- /dev/null +++ b/assembler_tools/hec-assembler-tools/requirements.txt @@ -0,0 +1,14 @@ +contourpy==1.0.7 +cycler==0.11.0 +fonttools==4.39.0 +importlib-resources==5.12.0 +kiwisolver==1.4.4 +matplotlib==3.7.1 +networkx==3.0 +numpy==1.24.2 +packaging==23.0 +Pillow==10.0.1 +python-dateutil==2.8.2 +PyYAML==6.0 +six==1.16.0 +zipp==3.15.0 From 9887c6955edad49120bd4d346d314e398d5af449 Mon Sep 17 00:00:00 2001 From: Flavio Bergamaschi Date: Thu, 22 May 2025 22:54:10 +0000 Subject: [PATCH 02/12] Update README.md --- assembler_tools/hec-assembler-tools/README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/assembler_tools/hec-assembler-tools/README.md b/assembler_tools/hec-assembler-tools/README.md index 90d45b38..9f6214e5 100644 --- a/assembler_tools/hec-assembler-tools/README.md +++ b/assembler_tools/hec-assembler-tools/README.md @@ -1,6 +1,6 @@ -# HERACLES Code Generation Framework User Guide +# HERACLES Code Generation Framework (Reference Implementation) -This tool, also known as the "assembler", takes a pre-generated Polynomial Instruction Set Architecture (P-ISA) kernel containing instructions that use an abstract, flat memory model for polynomial operations, such as those applied in homomorphic encryption (HE), and maps them to a corresponding set of instructions compatible with the HERACLES architecture, accounting for hardware restrictions, including memory management for the HERACLES memory model. +The tools in this directory are the reference implementation of and Assembler codegenerator that takes a pre-generated Polynomial Instruction Set Architecture (P-ISA) program containing instructions that use an abstract, flat memory model for polynomial operations, such as those applied in homomorphic encryption (HE), and maps them to a corresponding set of instructions compatible with the HERACLES architecture, accounting for hardware restrictions, including memory management for the HERACLES memory model. ## Table of Contents 1. [Dependencies](#dependencies) From cc1f7f464d127793739a3784aaaf38ac9b703270 Mon Sep 17 00:00:00 2001 From: Flavio Bergamaschi Date: Thu, 22 May 2025 19:15:56 -0400 Subject: [PATCH 03/12] update pre-commit configuration Signed-off-by: Flavio Bergamaschi --- .pre-commit-config.yaml | 4 + .../hec-assembler-tools/.gitignore | 205 ------------------ .../.pre-commit-config.yaml | 17 -- 3 files changed, 4 insertions(+), 222 deletions(-) delete mode 100644 assembler_tools/hec-assembler-tools/.gitignore delete mode 100644 assembler_tools/hec-assembler-tools/.pre-commit-config.yaml diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 943d7bb9..72b98915 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -38,6 +38,10 @@ repos: - HEADER - --comment-style - // # defaults to: # + - id: remove-tabs + name: remove-tabs + files: \.(py)$ + args: [--whitespaces-count, '4'] - repo: https://github.com/psf/black-pre-commit-mirror rev: 24.4.0 # Updated 2024/04 hooks: diff --git a/assembler_tools/hec-assembler-tools/.gitignore b/assembler_tools/hec-assembler-tools/.gitignore deleted file mode 100644 index 5daf1d05..00000000 --- a/assembler_tools/hec-assembler-tools/.gitignore +++ /dev/null @@ -1,205 +0,0 @@ -#======================== -# Intermediate and Output files -#======================== - -*.tmp -*.temp -*.mem -*.minst -*.cinst -*.xinst -*.csv -*.out -#*# -*~ -tmp/ - - -# Local files -#======================== - -*.yml -*.pyc -*.bak -*.pkl -*.lock -*.swp -tfedlrn.egg-info/ -bin/out - -#======================== -# Generated docs -#======================== - -*.htm -*.html -*.pdf -html -latex -[Dd]ocs/ - - -#======================== -# Eclipse & PyDev intermediate files -#======================== - -.metadata/ -RemoteSystemsTempFiles/ -.settings -.project -.pydevproject - -#======================== -# Visual Studio -#======================== - -# User-specific files -*.rsuser -*.suo -*.user -*.userosscache -*.sln.docstates - -# User-specific files (MonoDevelop/Xamarin Studio) -*.userprefs - -# Mono auto generated files -mono_crash.* - -# Visual Studio cache/options directory -.vs/ -.vscode/ - -#======================== -# Python general stuff -#======================== - -# Byte-compiled / optimized / DLL files -__pycache__/ -*.py[cod] -*$py.class - -# C extensions -*.so - -# Distribution / packaging -.Python -build/ -develop-eggs/ -dist/ -downloads/ -eggs/ -.eggs/ -lib/ -lib64/ -parts/ -sdist/ -var/ -wheels/ -pip-wheel-metadata/ -share/python-wheels/ -*.egg-info/ -.installed.cfg -*.egg -MANIFEST - -# PyInstaller -# Usually these files are written by a python script from a template -# before PyInstaller builds the exe, so as to inject date/other infos into it. -*.manifest -*.spec - -# Installer logs -pip-log.txt -pip-delete-this-directory.txt - -# Unit test / coverage reports -htmlcov/ -.tox/ -.nox/ -.coverage -.coverage.* -.cache -nosetests.xml -coverage.xml -*.cover -*.py,cover -.hypothesis/ -.pytest_cache/ - -# Translations -*.mo -*.pot - -# Django stuff: -*.log -local_settings.py -db.sqlite3 -db.sqlite3-journal - -# Flask stuff: -instance/ -.webassets-cache - -# Scrapy stuff: -.scrapy - -# Sphinx documentation -docs/_build/ - -# PyBuilder -target/ - -# Jupyter Notebook -.ipynb_checkpoints - -# IPython -profile_default/ -ipython_config.py - -# pyenv -.python-version - -# pipenv -# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. -# However, in case of collaboration, if having platform-specific dependencies or dependencies -# having no cross-platform support, pipenv may install dependencies that don't work, or not -# install all needed dependencies. -#Pipfile.lock - -# PEP 582; used by e.g. github.com/David-OConnor/pyflow -__pypackages__/ - -# Celery stuff -celerybeat-schedule -celerybeat.pid - -# SageMath parsed files -*.sage.py - -# Environments -.env -.venv -env/ -venv/ -ENV/ -env.bak/ -venv.bak/ - -# Spyder project settings -.spyderproject -.spyproject - -# Rope project settings -.ropeproject - -# mkdocs documentation -/site - -# mypy -.mypy_cache/ -.dmypy.json -dmypy.json - -# Pyre type checker -.pyre/ diff --git a/assembler_tools/hec-assembler-tools/.pre-commit-config.yaml b/assembler_tools/hec-assembler-tools/.pre-commit-config.yaml deleted file mode 100644 index 6e479482..00000000 --- a/assembler_tools/hec-assembler-tools/.pre-commit-config.yaml +++ /dev/null @@ -1,17 +0,0 @@ -# Copyright (C) 2023 Intel Corporation - -repos: - - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.4.0 # Updated 2023/02 - hooks: - - id: trailing-whitespace - - id: end-of-file-fixer - - id: check-merge-conflict - - id: mixed-line-ending - - id: check-yaml - - repo: https://github.com/Lucas-C/pre-commit-hooks - rev: v1.1.15 - hooks: - - id: remove-tabs - files: \.(py)$ - args: [--whitespaces-count, '4'] From 3ec5b76d9392e51871e8c6d512e41fbb4e6a823b Mon Sep 17 00:00:00 2001 From: Flavio Bergamaschi Date: Thu, 22 May 2025 19:17:49 -0400 Subject: [PATCH 04/12] update .gitignore configuration Signed-off-by: Flavio Bergamaschi --- .gitignore | 205 +++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 205 insertions(+) create mode 100644 .gitignore diff --git a/.gitignore b/.gitignore new file mode 100644 index 00000000..5daf1d05 --- /dev/null +++ b/.gitignore @@ -0,0 +1,205 @@ +#======================== +# Intermediate and Output files +#======================== + +*.tmp +*.temp +*.mem +*.minst +*.cinst +*.xinst +*.csv +*.out +#*# +*~ +tmp/ + + +# Local files +#======================== + +*.yml +*.pyc +*.bak +*.pkl +*.lock +*.swp +tfedlrn.egg-info/ +bin/out + +#======================== +# Generated docs +#======================== + +*.htm +*.html +*.pdf +html +latex +[Dd]ocs/ + + +#======================== +# Eclipse & PyDev intermediate files +#======================== + +.metadata/ +RemoteSystemsTempFiles/ +.settings +.project +.pydevproject + +#======================== +# Visual Studio +#======================== + +# User-specific files +*.rsuser +*.suo +*.user +*.userosscache +*.sln.docstates + +# User-specific files (MonoDevelop/Xamarin Studio) +*.userprefs + +# Mono auto generated files +mono_crash.* + +# Visual Studio cache/options directory +.vs/ +.vscode/ + +#======================== +# Python general stuff +#======================== + +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +pip-wheel-metadata/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +.python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ From 9ff5e8d5e32ee13bdad10af8aff149867a7cbb91 Mon Sep 17 00:00:00 2001 From: Flavio Bergamaschi Date: Fri, 23 May 2025 13:50:54 +0000 Subject: [PATCH 05/12] Delete assembler_tools/hec-assembler-tools/CODEOWNERS Delete this CODEOWNERS file as it is not required in a sub-directory --- assembler_tools/hec-assembler-tools/CODEOWNERS | 2 -- 1 file changed, 2 deletions(-) delete mode 100644 assembler_tools/hec-assembler-tools/CODEOWNERS diff --git a/assembler_tools/hec-assembler-tools/CODEOWNERS b/assembler_tools/hec-assembler-tools/CODEOWNERS deleted file mode 100644 index ffa15bdf..00000000 --- a/assembler_tools/hec-assembler-tools/CODEOWNERS +++ /dev/null @@ -1,2 +0,0 @@ -# Default codeowners for all files -* @faberga @ChrisWilkerson @sidezrw @jlhcrawford @hamishun @kylanerace @jobottle From ebb902823dacaf3343b3344b84435e9f4c84ebcc Mon Sep 17 00:00:00 2001 From: Flavio Bergamaschi Date: Fri, 23 May 2025 19:44:06 +0000 Subject: [PATCH 06/12] Update README.md Update README.md to be inline with latest codebase --- assembler_tools/hec-assembler-tools/README.md | 26 ------------------- 1 file changed, 26 deletions(-) diff --git a/assembler_tools/hec-assembler-tools/README.md b/assembler_tools/hec-assembler-tools/README.md index 9f6214e5..862dc260 100644 --- a/assembler_tools/hec-assembler-tools/README.md +++ b/assembler_tools/hec-assembler-tools/README.md @@ -9,7 +9,6 @@ The tools in this directory are the reference implementation of and Assembler co 1. [Assembler Instruction Specs](#asm_specs) 4. [Executing the Assembler](#executing_asm) 1. [Running for a Pre-Generated Kernel](#executing_single) - 2. [Running for a Batch of Operations](#executing_batch) 5. [Debug Tools](./debug_tools/README.md) ## Dependencies @@ -112,28 +111,3 @@ python3 he_prep.py -h python3 he_as.py -h python3 he_link.py -h ``` - -### Running for a Batch of Operations - -This project provides script `gen_he_ops.py` that allows for assembling a batch of P-ISA kernels generated for HE operations. It calls the generator script internally to generate a batch of kernels, and then runs them through the assembler. - -Since the script to generate P-ISA kernels resides in another repo (HERACLES-SEAL-isa-mapping), we must specify the location of the cloned external repo using the environment variable `HERACLES_MAPPING_PATH`. Correctly setting this variable should result in the following path being valid: `$HERACLES_MAPPING_PATH/kernels/run_he_ops.py` . - -Provided script, `gen_he_ops.py`, takes in a YAML configuration file that specifies parameters for operations to assemble. To obtain a template for the configuration file, use the script itself with the `--dump` command line flag. Use `-h` flag for more information. - -```bash -# save template for configuration file to ./config.yaml -python3 gen_he_ops.py config.yaml --dump -``` - -Set the parameters in the configuration file to match your needs and then execute the script as shown below (code for Linux terminal). - -```bash -# env variable pointing to HERACLES-SEAL-isa-mapping -export HERACLES_MAPPING_PATH=/path/to/HERACLES-SEAL-isa-mapping -python3 gen_he_ops.py config.yaml -``` - -Based on your chosen configuration, this will generate kernels, run them through the assembler and place all the outputs (and intermediate files) in the output directory specified in the configuration file. - -This way of executing is mostly intended for testing purposes. From 1d9652d1ba766047bfd4c949652efd98104f6a40 Mon Sep 17 00:00:00 2001 From: Flavio Bergamaschi Date: Fri, 23 May 2025 19:56:50 +0000 Subject: [PATCH 07/12] Update README.md Update README.md with correct reference to the PGM --- assembler_tools/hec-assembler-tools/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/assembler_tools/hec-assembler-tools/README.md b/assembler_tools/hec-assembler-tools/README.md index 862dc260..26b3d3a4 100644 --- a/assembler_tools/hec-assembler-tools/README.md +++ b/assembler_tools/hec-assembler-tools/README.md @@ -37,7 +37,7 @@ The assembler framework requires two inputs: Kernels and metadata are structured in comma-separated value (csv) files. -P-ISA kernels, along with corresponding memory metadata required as input to the assembler, are generated by Python script `HERACLES-SEAL-isa-mapping/kernels/run_he_op.py` in the repo [HERACLES-SEAL-isa-mapping](https://github.com/IntelLabs/HERACLES-SEAL-isa-mapping) +P-ISA kernels, along with corresponding memory metadata required as input to the assembler, are generated by the upper layers in the Encrypted Computing SDK stack, e.g. the [Program Mapper](../../README.md#encrypted-computing-sdk-phase-1-components-and-tasks) component of the [p-ISA tools](../../p-isa_tools). ## Outputs From 4c77016c31ced89f80014e48a8427a9bd23074be Mon Sep 17 00:00:00 2001 From: Flavio Bergamaschi Date: Fri, 23 May 2025 20:18:10 +0000 Subject: [PATCH 08/12] Delete assembler_tools/hec-assembler-tools/gen_he_ops.py Not required after for the new assembler --- .../hec-assembler-tools/gen_he_ops.py | 264 ------------------ 1 file changed, 264 deletions(-) delete mode 100644 assembler_tools/hec-assembler-tools/gen_he_ops.py diff --git a/assembler_tools/hec-assembler-tools/gen_he_ops.py b/assembler_tools/hec-assembler-tools/gen_he_ops.py deleted file mode 100644 index 364b3290..00000000 --- a/assembler_tools/hec-assembler-tools/gen_he_ops.py +++ /dev/null @@ -1,264 +0,0 @@ -import argparse -import io -import os -import pathlib -import subprocess -import sys -import yaml - -from assembler.common.constants import Constants -from assembler.common.run_config import RunConfig -import he_prep as preproc -import he_as as asm -import he_link as linker - -# module constants -DEFAULT_OPERATIONS = Constants.OPERATIONS[:6] - -class GenRunConfig(RunConfig): - """ - Maintains the configuration data for the run. - """ - - __initialized = False # specifies whether static members have been initialized - # contains the dictionary of all configuration items supported and their - # default value (or None if no default) - __default_config = {} - - def __init__(self, **kwargs): - """ - Constructs a new GenRunConfig Object from input parameters. - - See base class constructor for more arguments. - - Parameters - ---------- - scheme: str - FHE Scheme to use - - N: int - Ring dimension: PMD = 2^N. - - min_nrns: int - Minimum number of residuals. - - max_nrns: int - Maximum number of residuals. - - key_nrns: int - Optional number of residuals for relinearization keys. Must be greater than `max_nrns`. - If missing, the `key_nrns` for each P-ISA kernel generated will be set to the kernel - `nrns` (number of residuals) + 1. - - op_list: list[str] - Optional list of name of operations to generate. If provided, it must be a non-empty - subset of `Constants.OPERATIONS`. - Defaults to `DEFAULT_OPERATIONS`. - - output_dir: str - Optional directory where to store all intermediate files and final output. - This will be created if it doesn't exists. - Defaults to /lib. - - Raises - ------ - TypeError - A mandatory configuration value was missing. - - ValueError - At least, one of the arguments passed is invalid. - """ - - self.__init_statics() - - super().__init__(**kwargs) - - for config_name, default_value in self.__default_config.items(): - assert(not hasattr(self, config_name)) - setattr(self, config_name, kwargs.get(config_name, default_value)) - if getattr(self, config_name) is None: - raise TypeError(f'Expected value for configuration `{config_name}`, but `None` received.') - - for op in self.op_list: - if op not in Constants.OPERATIONS: - raise ValueError('Invalid operation in input list of ops "{}". Expected one of {}'.format(op, Constants.OPERATIONS)) - - if self.key_nrns > 0: - if self.key_nrns < self.max_nrns: - raise ValueError(('`key_nrns` must be greater than `max_nrns` when present. ' - 'Received {}, but expected greater than {}.').format(self.key_nrns, - self.max_nrns)) - - @classmethod - def __init_statics(cls): - if not cls.__initialized: - cls.__default_config["scheme"] = "bgv" - cls.__default_config["N"] = None - cls.__default_config["min_nrns"] = None - cls.__default_config["max_nrns"] = None - cls.__default_config["key_nrns"] = 0 - cls.__default_config["output_dir"] = os.path.join(pathlib.Path.cwd(), "lib") - cls.__default_config["op_list"] = DEFAULT_OPERATIONS - - cls.__initialized = True - - def __str__(self): - """ - Returns a string representation of the configuration. - """ - self_dict = self.as_dict() - with io.StringIO() as retval_f: - for key, value in self_dict.items(): - print("{}: {}".format(key, value), file=retval_f) - retval = retval_f.getvalue() - return retval - - def as_dict(self) -> dict: - retval = super().as_dict() - tmp_self_dict = vars(self) - retval.update({ config_name: tmp_self_dict[config_name] for config_name in self.__default_config }) - return retval - -def main(config: GenRunConfig, - b_verbose: bool = False): - - lib_dir = config.output_dir - - # create output directory to store outputs (if it doesn't already exist) - pathlib.Path(lib_dir).mkdir(exist_ok = True, parents=True) - - # point to the HERACLES-SEAL-isa-mapping repo - home_dir = pathlib.Path.home() - mapping_dir = os.getenv("HERACLES_MAPPING_PATH", os.path.join(home_dir, "HERACLES/HERACLES-SEAL-isa-mapping")) - # command to run the mapping script to generate operations kernels for our input - #generate_cmd = 'python3 "{}"'.format(os.path.join(mapping_dir, "kernels/run_he_op.py")) - generate_cmd = ['python3', '{}'.format(os.path.join(mapping_dir, "kernels/run_he_op.py"))] - - assert config.N < 1024 - assert config.min_nrns > 1 - assert (config.key_nrns == 0 or config.key_nrns > config.max_nrns) - assert(all(op in Constants.OPERATIONS for op in config.op_list)) - - pdegree = 2 ** config.N - for op in config.op_list: - for rn_el in range(config.min_nrns, config.max_nrns + 1): - key_nrns = config.key_nrns if config.key_nrns > 0 else rn_el + 1 - print(f"{config.scheme} {op} {config.N} {rn_el} {key_nrns}") - - output_prefix = "t.{}.{}.{}.{}".format(rn_el,op,config.N,key_nrns) - basef = os.path.join(lib_dir, output_prefix) - memfile = basef + ".tw.mem" - generate_cmdln = generate_cmd + [ "--map-file" , memfile ] + [ str(x) for x in (config.scheme, op, pdegree, rn_el, key_nrns) ] - - csvfile = basef + ".csv" - - # call the external script to generate the kernel for this op - print(' '.join(generate_cmdln)) - with open(csvfile, 'w') as fout_csv: - run_result = subprocess.run(generate_cmdln, stdout=fout_csv) - if run_result.returncode != 0: - raise RuntimeError('Exit code: {}. Failure to complete kernel generation successfully.'.format(run_result.returncode)) - - - # pre-process kernel step - #------------------------- - - # generate twiddle factors for this kernel - basef = basef + ".tw" #use the newly generated twiddle file - print() - print("Preprocessing") - preproc.main(basef + ".csv", - csvfile, - b_verbose=b_verbose) - - # assembling step - #----------------- - - # prepare config for assembler - asm_config = asm.AssemblerRunConfig(input_file=basef + ".csv", - input_mem_file=memfile, - output_prefix=output_prefix + '.o', - **config.as_dict()) # convert config to a dictionary and expand it as arguments - # temp path to store assembled output before linking set - asm_config.output_dir = os.path.join(asm_config.output_dir, 'obj') - print() - print("Assembling") - # run the assembler for this file - asm.main(asm_config, verbose=b_verbose) - - # linking step - #-------------- - - # prepare config for linker - linker_config = linker.LinkerRunConfig(input_prefixes = [os.path.join(asm_config.output_dir, asm_config.output_prefix)], - input_mem_file=memfile, - output_prefix=output_prefix, - **config.as_dict()) # convert config to a dictionary and expand it as arguments - print() - print("Linking") - # run the linker on the assembler output - linker.main(linker_config, sys.stdout if b_verbose else None) - - print(f'Completed "{output_prefix}"') - print() - -def parse_args(): - parser = argparse.ArgumentParser(description=("Generates a collection of HE operations based on input configuration."), - epilog=("To use, users should dump a default configuration file. Edit the file to " - "match the needs for the run, then execute the program with the modified " - "configuration. Note that dumping on top of an existing file will overwrite " - "its contents.")) - parser.add_argument("config_file", help=("YAML configuration file.")) - parser.add_argument("--dump", action="store_true", - help=("A default configuration will be writen into the file specified by `config_file`. " - "If the file already exists, it will be overwriten.")) - parser.add_argument("-v", "--verbose", dest="verbose", action="store_true", - help="If enabled, extra information and progress reports are printed to stdout.") - args = parser.parse_args() - - return args - -def readYAMLConfig(input_filename: str): - """ - Reads in a YAML file and returns a GenRunConfig object parsed from it. - """ - retval_dict = {} - with open(input_filename, "r") as infile: - retval_dict = yaml.safe_load(infile) - - return GenRunConfig(**retval_dict) - -def writeYAMLConfig(output_filename: str, config: GenRunConfig): - """ - Outputs the specified configuration to a YAML file. - """ - with open(output_filename, "w") as outfile: - yaml.dump(vars(config), outfile, sort_keys=False) - -if __name__ == "__main__": - module_name = os.path.basename(__file__) - print(module_name) - print() - - args = parse_args() - - if args.dump: - print("Writing default configuration to") - print(" ", args.config_file) - default_config = GenRunConfig(N=15, min_nrns=2, max_nrns=18) - writeYAMLConfig(args.config_file, default_config) - else: - print("Loading configuration file:") - print(" ", args.config_file) - config = readYAMLConfig(args.config_file) - print() - print("Gen Run Configuration") - print("=====================") - print(config) - print("=====================") - print() - main(config, - b_verbose=args.verbose) - - print() - print(module_name, "- Complete") From 104f5416853a7e66b8f3d9b55ca62b54684a153e Mon Sep 17 00:00:00 2001 From: Flavio Bergamaschi Date: Fri, 23 May 2025 20:20:19 +0000 Subject: [PATCH 09/12] Delete assembler_tools/hec-assembler-tools/atomic_tester.py Not required in the new assembler --- .../hec-assembler-tools/atomic_tester.py | 261 ------------------ 1 file changed, 261 deletions(-) delete mode 100644 assembler_tools/hec-assembler-tools/atomic_tester.py diff --git a/assembler_tools/hec-assembler-tools/atomic_tester.py b/assembler_tools/hec-assembler-tools/atomic_tester.py deleted file mode 100644 index c3ae6929..00000000 --- a/assembler_tools/hec-assembler-tools/atomic_tester.py +++ /dev/null @@ -1,261 +0,0 @@ -import argparse -import io -import os -import pathlib -import subprocess -import yaml - -from assembler.common.constants import Constants -from assembler.common.run_config import RunConfig -import he_prep as preproc -import he_as as asm - -# module constants -DEFAULT_OPERATIONS = Constants.OPERATIONS[:6] - -class GenRunConfig(RunConfig): - """ - Maintains the configuration data for the run. - """ - - SCHEMES = [ 'bgv', 'ckks' ] - DEFAULT_SCHEME = SCHEMES[0] - - __initialized = False # specifies whether static members have been initialized - # contains the dictionary of all configuration items supported and their - # default value (or None if no default) - __default_config = {} - - def __init__(self, **kwargs): - """ - Constructs a new GenRunConfig Object from input parameters. - - See base class constructor for more arguments. - - Parameters - ---------- - N: int - Ring dimension: PMD = 2^N. - - min_nrns: int - Minimum number of residuals. - - max_nrns: int - Maximum number of residuals. - - key_nrns: int - Optional number of residuals for relinearization keys. Must be greater than `max_nrns`. - If missing, the `key_nrns` for each P-ISA kernel generated will be set to the kernel - `nrns` (number of residuals) + 1. - - scheme: str - FHE Scheme to use. Must be one of the schemes in `GenRunConfig.SCHEMES`. - Defaults to `GenRunConfig.DEFAULT_SCHEME`. - - op_list: list[str] - Optional list of name of operations to generate. If provided, it must be a non-empty - subset of `Constants.OPERATIONS`. - Defaults to `DEFAULT_OPERATIONS`. - - output_dir: str - Optional directory where to store all intermediate files and final output. - This will be created if it doesn't exists. - Defaults to /lib. - - Raises - ------ - TypeError - A mandatory configuration value was missing. - - ValueError - At least, one of the arguments passed is invalid. - """ - - self.__init_statics() - - super().__init__(**kwargs) - - for config_name, default_value in self.__default_config.items(): - assert(not hasattr(self, config_name)) - setattr(self, config_name, kwargs.get(config_name, default_value)) - if getattr(self, config_name) is None: - raise TypeError(f'Expected value for configuration `{config_name}`, but `None` received.') - - if self.scheme not in self.SCHEMES: - raise ValueError('Invalid acheme "{}". Expected one of {}'.format(self.scheme, self.SCHEMES)) - - for op in self.op_list: - if op not in Constants.OPERATIONS: - raise ValueError('Invalid operation in input list of ops "{}". Expected one of {}'.format(op, Constants.OPERATIONS)) - - if self.key_nrns > 0: - if self.key_nrns < self.max_nrns: - raise ValueError(('`key_nrns` must be greater than `max_nrns` when present. ' - 'Received {}, but expected greater than {}.').format(self.key_nrns, - self.max_nrns)) - - @classmethod - def __init_statics(cls): - if not cls.__initialized: - cls.__default_config["N"] = None - cls.__default_config["min_nrns"] = None - cls.__default_config["max_nrns"] = None - cls.__default_config["key_nrns"] = 0 - cls.__default_config["scheme"] = cls.DEFAULT_SCHEME - cls.__default_config["output_dir"] = os.path.join(pathlib.Path.cwd(), "lib") - cls.__default_config["op_list"] = DEFAULT_OPERATIONS - - cls.__initialized = True - - def __str__(self): - """ - Returns a string representation of the configuration. - """ - self_dict = self.as_dict() - with io.StringIO() as retval_f: - for key, value in self_dict.items(): - print("{}: {}".format(key, value), file=retval_f) - retval = retval_f.getvalue() - return retval - - def as_dict(self) -> dict: - retval = super().as_dict() - tmp_self_dict = vars(self) - retval.update({ config_name: tmp_self_dict[config_name] for config_name in self.__default_config }) - return retval - -def main(config: GenRunConfig, - b_verbose: bool = False): - - lib_dir = config.output_dir - - # create output directory to store outputs (if it doesn't already exist) - pathlib.Path(lib_dir).mkdir(exist_ok = True, parents=True) - - # point to the HERACLES-SEAL-isa-mapping repo - home_dir = pathlib.Path.home() - mapping_dir = os.getenv("HERACLES_MAPPING_PATH", os.path.join(home_dir, "HERACLES/HERACLES-SEAL-isa-mapping")) - # command to run the mapping script to generate operations kernels for our input - #generate_cmd = 'python3 "{}"'.format(os.path.join(mapping_dir, "kernels/run_he_op.py")) - generate_cmd = ['python3', '{}'.format(os.path.join(mapping_dir, "kernels/run_he_op.py"))] - - assert config.N < 1024 - assert config.min_nrns > 1 - assert (config.key_nrns == 0 or config.key_nrns > config.max_nrns) - assert(all(op in Constants.OPERATIONS for op in config.op_list)) - - pdegree = 2 ** config.N - regenerate_string = "" - for op in config.op_list: - for rn_el in range(config.min_nrns, config.max_nrns + 1): - key_nrns = config.key_nrns if config.key_nrns > 0 else rn_el + 1 - regenerate_string = "" - print("{} {} {} {} {}".format(config.scheme, op, config.N, rn_el, key_nrns)) - output_prefix = "t.{}.{}.{}".format(rn_el, op, config.N) - basef = os.path.join(lib_dir, output_prefix) - generate_cmdln = generate_cmd + [ str(x) for x in (config.scheme, op, pdegree, rn_el, key_nrns) ] - - csvfile = basef + ".csv" - memfile = basef + ".tw.mem" - - # call the external script to generate the kernel for this op - print(' '.join(generate_cmdln)) - run_result = subprocess.run(generate_cmdln, stdout=subprocess.PIPE) - if run_result.returncode != 0: - raise RuntimeError('Exit code: {}. Failure to complete kernel generation successfully.'.format(run_result.returncode)) - - # interpret output into correct kernel and mem files - merged_output = run_result.stdout.decode().splitlines() - with open(csvfile, 'w') as fout_csv: - with open(memfile, 'w') as fout_mem: - for s_line in merged_output: - if s_line: - if s_line.startswith('dload') \ - or s_line.startswith('dstore'): - print(s_line, file=fout_mem) - else: - print(s_line, file=fout_csv) - - # pre-process kernel - - # generate twiddle factors for this kernel - basef = basef + ".tw" #use the newly generated twiddle file - print() - print("Preprocessing") - preproc.main(basef + ".csv", - csvfile, - b_verbose=b_verbose) - - # prepare config for assembler - asm_config = asm.AssemblerRunConfig(input_file=basef + ".csv", - input_mem_file=memfile, - output_prefix=output_prefix, - **config.as_dict()) # convert config to a dictionary and expand it as arguments - print() - print("Assembling") - # run the assembler for this file - asm.main(asm_config, verbose=b_verbose) - - print(f'Completed "{output_prefix}"') - print() - -def parse_args(): - parser = argparse.ArgumentParser(description=("Generates a collection of HE operations based on input configuration."), - epilog=("To use, users should dump a default configuration file. Edit the file to " - "match the needs for the run, then execute the program with the modified " - "configuration. Note that dumping on top of an existing file will overwrite " - "its contents.")) - parser.add_argument("config_file", help=("YAML configuration file.")) - parser.add_argument("--dump", action="store_true", - help=("A default configuration will be writen into the file specified by `config_file`. " - "If the file already exists, it will be overwriten.")) - parser.add_argument("-v", "--verbose", dest="verbose", action="store_true", - help="If enabled, extra information and progress reports are printed to stdout.") - args = parser.parse_args() - - return args - -def readYAMLConfig(input_filename: str): - """ - Reads in a YAML file and returns a GenRunConfig object parsed from it. - """ - retval_dict = {} - with open(input_filename, "r") as infile: - retval_dict = yaml.safe_load(infile) - - return GenRunConfig(**retval_dict) - -def writeYAMLConfig(output_filename: str, config: GenRunConfig): - """ - Outputs the specified configuration to a YAML file. - """ - with open(output_filename, "w") as outfile: - yaml.dump(vars(config), outfile, sort_keys=False) - -if __name__ == "__main__": - module_name = os.path.basename(__file__) - print(module_name) - print() - - args = parse_args() - - if args.dump: - print("Writing default configuration to") - print(" ", args.config_file) - default_config = GenRunConfig(N=15, min_nrns=2, max_nrns=18) - writeYAMLConfig(args.config_file, default_config) - else: - print("Loading configuration file:") - print(" ", args.config_file) - config = readYAMLConfig(args.config_file) - print() - print("Gen Run Configuration") - print("=====================") - print(config) - print("=====================") - print() - main(config, - b_verbose=args.verbose) - - print() - print(module_name, "- Complete") From a06589637d2b96634c21d03b8f4aa7ac3eaa6360 Mon Sep 17 00:00:00 2001 From: Flavio Bergamaschi Date: Fri, 23 May 2025 20:25:56 +0000 Subject: [PATCH 10/12] Update README.md Remove old style execution instructions --- assembler_tools/hec-assembler-tools/README.md | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/assembler_tools/hec-assembler-tools/README.md b/assembler_tools/hec-assembler-tools/README.md index 26b3d3a4..a410b92f 100644 --- a/assembler_tools/hec-assembler-tools/README.md +++ b/assembler_tools/hec-assembler-tools/README.md @@ -52,19 +52,6 @@ On a successful run, given a P-ISA kernel in file `filename.csv` (and correspond The format for the output files and instruction set can be found at [HCGF Instruction Specification](docsrc/specs.md). ## Executing the Assembler - -There are two ways to execute the assembler: - -- [Running on a pre-generated kernel](#executing_single): uses the main interface of the assembler to assemble a single pre-existing kernel. - - This method is intended for a production chain. - -or - -- [Running for a batch of kernels](#executing_batch): uses a provided script wrapper to generate a collection of kernels and runs them through the assembler. - - This method is intended for testing purposes as it generates test kernels using external tools before assembling. - ### Running for a Pre-Generated Kernel Given a P-ISA kernel (`filename.csv`) and corresponding memory mapping file (`filename.mem`), there are three steps to assemble them into HERACLES code. From 6f5aa2c4e66de7e7daa461c3f3c846acd337080c Mon Sep 17 00:00:00 2001 From: Flavio Bergamaschi Date: Fri, 23 May 2025 20:49:10 +0000 Subject: [PATCH 11/12] Update changelog.md --- .../hec-assembler-tools/docsrc/changelog.md | 28 ++----------------- 1 file changed, 2 insertions(+), 26 deletions(-) diff --git a/assembler_tools/hec-assembler-tools/docsrc/changelog.md b/assembler_tools/hec-assembler-tools/docsrc/changelog.md index a9674c15..52c4ec04 100644 --- a/assembler_tools/hec-assembler-tools/docsrc/changelog.md +++ b/assembler_tools/hec-assembler-tools/docsrc/changelog.md @@ -1,28 +1,4 @@ # Changelog -### 2024-07-11 -- Updated SPAD capacity to reflect change from 64MB to 48MB. -- Updated the range of values for `cnop` parameters. -- Updated `rshuffle` to reflect slotting rules for 4KCE. - -### 2024-01-23 -- Updated `rshuffle` to reflect latency changes and rules for 4KCE. -- Updated `nload` to reflect the lack of support for multiple routing tables. -- Updated HBM capacity to reflect change from 64GB to 48GB. -- No keygen updates yet because the feature is work in progress. - -### 2023-07-25 -- Updated throughput for CInsts `cload`, and `nload`. -- Updated throughput and latency for `cstore`, and `bload`. -- Updated latency of `xstore`. -- Updated `rshuffle` to reflect the discontinuation of `wait_cyc` parameter. - -### 2023-07-24 -- Added XInstruction `sub`, required by CKKS scheme P-ISA kernels. - -### 2023-06-30 -- Updated latencies of XInstruction `rshuffle` based on Sim0.9 version. - -### 2023-06-12 -- XInstruction `exit` op name is now `bexit` to match the ISA spec, as required by Sim0.9. -- CInstructions `bload` and `bones` format changed to match philosophy of dests before sources. +### 20250523 +- First release of the HERACLES Code Generation Framework (Reference Implementation) From d86a8a04534326cd72e5a5f05d08f3f00216d019 Mon Sep 17 00:00:00 2001 From: Flavio Bergamaschi Date: Fri, 23 May 2025 20:50:07 +0000 Subject: [PATCH 12/12] Update changelog.md --- assembler_tools/hec-assembler-tools/docsrc/changelog.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/assembler_tools/hec-assembler-tools/docsrc/changelog.md b/assembler_tools/hec-assembler-tools/docsrc/changelog.md index 52c4ec04..f15ca368 100644 --- a/assembler_tools/hec-assembler-tools/docsrc/changelog.md +++ b/assembler_tools/hec-assembler-tools/docsrc/changelog.md @@ -1,4 +1,4 @@ # Changelog -### 20250523 +### 2025-05-23 - First release of the HERACLES Code Generation Framework (Reference Implementation)