From c9021da760ee124a744af91ba60690e7fd4983e0 Mon Sep 17 00:00:00 2001 From: cpunkzzz Date: Mon, 22 Aug 2022 11:29:18 -0400 Subject: [PATCH 01/12] Add raw bytes interface --- .vscode/settings.json | 27 +++++++++++++++++++- risc0/zkp/rust/src/core/sha.rs | 9 +++++++ risc0/zkvm/platform/io.h | 2 ++ risc0/zkvm/sdk/cpp/host/c_api.cpp | 7 ++++++ risc0/zkvm/sdk/cpp/host/c_api.h | 5 ++++ risc0/zkvm/sdk/cpp/host/receipt.cpp | 35 ++++++++++++++++++++++++++ risc0/zkvm/sdk/cpp/host/receipt.h | 2 ++ risc0/zkvm/sdk/rust/guest/src/env.rs | 14 ++++++++++- risc0/zkvm/sdk/rust/platform/src/io.rs | 1 + risc0/zkvm/sdk/rust/src/host/ffi.rs | 30 ++++++++++++++++++++++ 10 files changed, 130 insertions(+), 2 deletions(-) diff --git a/.vscode/settings.json b/.vscode/settings.json index 988f817d02..d9a52c68d8 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -27,6 +27,31 @@ "ios": "cpp", "iosfwd": "cpp", "vector": "cpp", - "charconv": "cpp" + "charconv": "cpp", + "__hash_table": "cpp", + "__split_buffer": "cpp", + "__tree": "cpp", + "array": "cpp", + "bitset": "cpp", + "deque": "cpp", + "initializer_list": "cpp", + "list": "cpp", + "map": "cpp", + "queue": "cpp", + "random": "cpp", + "regex": "cpp", + "set": "cpp", + "span": "cpp", + "stack": "cpp", + "string": "cpp", + "string_view": "cpp", + "unordered_map": "cpp", + "unordered_set": "cpp", + "valarray": "cpp", + "iterator": "cpp", + "utility": "cpp", + "rope": "cpp", + "slist": "cpp", + "ranges": "cpp" } } diff --git a/risc0/zkp/rust/src/core/sha.rs b/risc0/zkp/rust/src/core/sha.rs index c26d1f2318..c3e98baa43 100644 --- a/risc0/zkp/rust/src/core/sha.rs +++ b/risc0/zkp/rust/src/core/sha.rs @@ -77,6 +77,15 @@ impl Digest { &self.0 } + /// Returns the digest as an array of bytes, each word in big-endian order. + pub fn get_u8(&self) -> [u8; DIGEST_WORDS * 4] { + let mut res: [u8; DIGEST_WORDS * 4] = [0; DIGEST_WORDS * 4]; + for i in 0..DIGEST_WORDS { + res[4 * i..][..4].copy_from_slice(&self.0[i].to_be_bytes()); + } + res + } + /// Returns a mutable slice of words. 
pub fn get_mut(&mut self) -> &mut [u32; DIGEST_WORDS] { &mut self.0 diff --git a/risc0/zkvm/platform/io.h b/risc0/zkvm/platform/io.h index 90a1e7aa4f..83e20751c2 100644 --- a/risc0/zkvm/platform/io.h +++ b/risc0/zkvm/platform/io.h @@ -35,6 +35,8 @@ constexpr uint32_t kSendRecvChannel_InitialInput = 0; constexpr uint32_t kSendRecvChannel_Stdout = 1; // Write bytes to standard error constexpr uint32_t kSendRecvChannel_Stderr = 2; +// Request aux tape to the guest +constexpr uint32_t kSendRecvChannel_InitialInputAux = 3; // To invoke accelerated SHA, the guest writes ShaDescriptor structs // in sequence to the "SHA" memory region. Once the ShaDescriptor has diff --git a/risc0/zkvm/sdk/cpp/host/c_api.cpp b/risc0/zkvm/sdk/cpp/host/c_api.cpp index 638dc54125..8f5de9fa43 100644 --- a/risc0/zkvm/sdk/cpp/host/c_api.cpp +++ b/risc0/zkvm/sdk/cpp/host/c_api.cpp @@ -129,6 +129,13 @@ void risc0_prover_add_input(risc0_error* err, risc0_prover* ptr, const uint8_t* ffi_wrap_void(err, [&] { ptr->prover->writeInput(buf, len); }); } +void risc0_prover_add_aux_input(risc0_error* err, + risc0_prover* ptr, + const uint8_t* buf, + size_t len) { + ffi_wrap_void(err, [&] { ptr->prover->writeInputAux(buf, len); }); +} + const void* risc0_prover_get_output_buf(risc0_error* err, const risc0_prover* ptr) { return ffi_wrap(err, nullptr, [&] { return ptr->prover->getOutput().data(); }); } diff --git a/risc0/zkvm/sdk/cpp/host/c_api.h b/risc0/zkvm/sdk/cpp/host/c_api.h index 20172ebdee..b9f851078d 100644 --- a/risc0/zkvm/sdk/cpp/host/c_api.h +++ b/risc0/zkvm/sdk/cpp/host/c_api.h @@ -89,6 +89,11 @@ void risc0_prover_free(risc0_error* err, risc0_prover* ptr); void risc0_prover_add_input(risc0_error* err, risc0_prover* ptr, const uint8_t* buf, size_t len); +void risc0_prover_add_aux_input(risc0_error* err, + risc0_prover* ptr, + const uint8_t* buf, + size_t len); + size_t risc0_prover_get_num_outputs(risc0_error* err, risc0_prover* ptr); const void* risc0_prover_get_output_buf(risc0_error* err, const 
risc0_prover* ptr); diff --git a/risc0/zkvm/sdk/cpp/host/receipt.cpp b/risc0/zkvm/sdk/cpp/host/receipt.cpp index fb7d3eaf9d..732ab3e203 100644 --- a/risc0/zkvm/sdk/cpp/host/receipt.cpp +++ b/risc0/zkvm/sdk/cpp/host/receipt.cpp @@ -63,6 +63,7 @@ struct Prover::Impl : public IoHandler { , outputStream(outputBuffer) , commitStream(commitBuffer) , inputWriter(inputStream) + , inputWriterAux(inputStreamAux) , outputReader(outputStream) , commitReader(commitStream) { // Set default handlers: @@ -83,6 +84,13 @@ struct Prover::Impl : public IoHandler { LOG(1, "IoHandler::InitialInput, " << input.size() << " bytes"); return input; }); + setSendRecvHandler( + kSendRecvChannel_InitialInputAux, [this](uint32_t, const BufferU8& buf) -> BufferU8 { + const uint8_t* byte_ptr = reinterpret_cast(inputStreamAux.vec.data()); + BufferU8 input(byte_ptr, byte_ptr + inputStreamAux.vec.size() * sizeof(uint32_t)); + LOG(1, "IoHandler::InitialInputAux, " << input.size() << " bytes"); + return input; + }); } virtual ~Impl() {} @@ -116,9 +124,11 @@ struct Prover::Impl : public IoHandler { BufferU8 outputBuffer; BufferU8 commitBuffer; VectorStreamWriter inputStream; + VectorStreamWriter inputStreamAux; CheckedStreamReader outputStream; CheckedStreamReader commitStream; ArchiveWriter inputWriter; + ArchiveWriter inputWriterAux; ArchiveReader outputReader; ArchiveReader commitReader; @@ -219,6 +229,31 @@ void Prover::writeInput(const void* ptr, size_t size) { } } +void Prover::writeInputAux(const void* ptr, size_t size) { + LOG(1, "Prover::writeInputAux> size: " << size); + const uint8_t* ptr_u8 = static_cast(ptr); + while (size >= sizeof(uint32_t)) { + uint32_t word = 0; + word |= *ptr_u8++; + word |= *ptr_u8++ << 8; + word |= *ptr_u8++ << 16; + word |= *ptr_u8++ << 24; + LOG(1, " write_word: " << hex(word)); + impl->inputStreamAux.write_word(word); + size -= sizeof(uint32_t); + } + + if (size) { + LOG(1, " tail: " << size); + uint32_t word = 0; + for (size_t i = 0; i < size; i++) { + word |= 
*ptr_u8++ << (8 * i); + } + LOG(1, " write_word: " << hex(word)); + impl->inputStreamAux.write_word(word); + } +} + void Prover::setSendRecvHandler( uint32_t channelId, const std::function& handler) { diff --git a/risc0/zkvm/sdk/cpp/host/receipt.h b/risc0/zkvm/sdk/cpp/host/receipt.h index 9e419935fb..ebb7dadb2f 100644 --- a/risc0/zkvm/sdk/cpp/host/receipt.h +++ b/risc0/zkvm/sdk/cpp/host/receipt.h @@ -87,6 +87,8 @@ class Prover { void writeInput(const void* ptr, size_t size); + void writeInputAux(const void* ptr, size_t size); + template void writeInput(const T& obj) { getInputWriter().transfer(obj); } const BufferU8& getOutput(); diff --git a/risc0/zkvm/sdk/rust/guest/src/env.rs b/risc0/zkvm/sdk/rust/guest/src/env.rs index 2e54466da5..d2de503b31 100644 --- a/risc0/zkvm/sdk/rust/guest/src/env.rs +++ b/risc0/zkvm/sdk/rust/guest/src/env.rs @@ -17,7 +17,10 @@ use core::{cell::UnsafeCell, mem::MaybeUninit, slice}; use risc0_zkp::core::sha::Digest; use risc0_zkvm::{ platform::{ - io::{IoDescriptor, GPIO_COMMIT, SENDRECV_CHANNEL_INITIAL_INPUT, SENDRECV_CHANNEL_STDOUT}, + io::{ + IoDescriptor, GPIO_COMMIT, SENDRECV_CHANNEL_INITIAL_AUX_INPUT, + SENDRECV_CHANNEL_INITIAL_INPUT, SENDRECV_CHANNEL_STDOUT, + }, memory, WORD_SIZE, }, serde::{Deserializer, Serializer, Slice}, @@ -98,6 +101,11 @@ pub fn read>() -> T { ENV.get().read() } +/// Read private raw data from the host. +pub fn read_aux_input() -> &'static [u8] { + ENV.get().read_aux_input() +} + /// Write private data to the host. 
pub fn write(data: &T) { ENV.get().write(data); @@ -140,6 +148,10 @@ impl Env { self.initial_input_reader.as_mut().unwrap() } + pub fn read_aux_input(&mut self) -> &[u8] { + self.send_recv(SENDRECV_CHANNEL_INITIAL_AUX_INPUT, &[]) + } + pub fn read>(&mut self) -> T { self.initial_input().read() } diff --git a/risc0/zkvm/sdk/rust/platform/src/io.rs b/risc0/zkvm/sdk/rust/platform/src/io.rs index b072cb43a1..b75b641db3 100644 --- a/risc0/zkvm/sdk/rust/platform/src/io.rs +++ b/risc0/zkvm/sdk/rust/platform/src/io.rs @@ -86,3 +86,4 @@ pub struct GetKeyDescriptor { pub const SENDRECV_CHANNEL_INITIAL_INPUT: u32 = 0; pub const SENDRECV_CHANNEL_STDOUT: u32 = 1; pub const SENDRECV_CHANNEL_STDERR: u32 = 2; +pub const SENDRECV_CHANNEL_INITIAL_AUX_INPUT: u32 = 3; diff --git a/risc0/zkvm/sdk/rust/src/host/ffi.rs b/risc0/zkvm/sdk/rust/src/host/ffi.rs index f09f8131c7..7d209ef6d3 100644 --- a/risc0/zkvm/sdk/rust/src/host/ffi.rs +++ b/risc0/zkvm/sdk/rust/src/host/ffi.rs @@ -102,6 +102,13 @@ extern "C" { len: usize, ); + pub(crate) fn risc0_prover_add_aux_input( + err: *mut RawError, + prover: *mut RawProver, + buf: *const u8, + len: usize, + ); + pub(crate) fn risc0_prover_get_output_buf( err: *mut RawError, prover: *mut RawProver, @@ -388,6 +395,29 @@ impl<'a> Prover<'a> { check(err, || ()) } + /// Provide private input data that is available to guest-side method code + /// to 'read_aux_input'. 
+ pub fn add_aux_input(&mut self, slice: &[u32]) -> super::Result<()> { + let mut err = RawError::default(); + unsafe { + risc0_prover_add_aux_input( + &mut err, + self.ptr, + slice.as_ptr().cast(), + slice.len() * mem::size_of::(), + ) + }; + check(err, || ()) + } + + /// Allow auxiliary input to be passed in as u8 with zero-copy framework + pub fn add_input_u8_slice_aux(&mut self, slice: &[u8]) { + let mut v: Vec = Vec::new(); + v.resize((slice.len() + 3) / 4, 0); + bytemuck::cast_slice_mut(v.as_mut_slice())[..slice.len()].clone_from_slice(slice); + self.add_aux_input(v.as_slice()).unwrap() + } + /// Compatibility with pure-rust prover pub fn add_input_u8_slice(&mut self, slice: &[u8]) { let mut v: Vec = Vec::new(); From 3a6ac0ac728f91a47539faa0720388793bae235f Mon Sep 17 00:00:00 2001 From: cpunkzzz Date: Fri, 9 Sep 2022 12:56:30 -0400 Subject: [PATCH 02/12] Add env::log --- risc0/zkvm/sdk/rust/guest/src/env.rs | 11 ++++++++++- risc0/zkvm/sdk/rust/guest/src/lib.rs | 1 + 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/risc0/zkvm/sdk/rust/guest/src/env.rs b/risc0/zkvm/sdk/rust/guest/src/env.rs index d2de503b31..896eb7d0b7 100644 --- a/risc0/zkvm/sdk/rust/guest/src/env.rs +++ b/risc0/zkvm/sdk/rust/guest/src/env.rs @@ -18,7 +18,7 @@ use risc0_zkp::core::sha::Digest; use risc0_zkvm::{ platform::{ io::{ - IoDescriptor, GPIO_COMMIT, SENDRECV_CHANNEL_INITIAL_AUX_INPUT, + IoDescriptor, GPIO_COMMIT, GPIO_LOG, SENDRECV_CHANNEL_INITIAL_AUX_INPUT, SENDRECV_CHANNEL_INITIAL_INPUT, SENDRECV_CHANNEL_STDOUT, }, memory, WORD_SIZE, @@ -116,6 +116,15 @@ pub fn commit(data: &T) { ENV.get().commit(data); } +/// Print a message to the debug console. +pub fn log(msg: &str) { + // TODO: format! is expensive, replace with a better solution. 
+ let msg = alloc_crate::format!("{}\0", msg); + let ptr = msg.as_ptr(); + memory_barrier(ptr); + unsafe { GPIO_LOG.as_ptr().write_volatile(ptr) }; +} + impl Env { fn new() -> Self { Env { diff --git a/risc0/zkvm/sdk/rust/guest/src/lib.rs b/risc0/zkvm/sdk/rust/guest/src/lib.rs index d12df0a88f..a7fd6745da 100644 --- a/risc0/zkvm/sdk/rust/guest/src/lib.rs +++ b/risc0/zkvm/sdk/rust/guest/src/lib.rs @@ -21,6 +21,7 @@ #![cfg_attr(target_arch = "riscv32", feature(new_uninit))] extern crate alloc as _alloc; +pub extern crate alloc as alloc_crate; #[cfg(not(feature = "std"))] mod alloc; From 08ff5deb7c7d26acd26a046bf70b45482cb5e168 Mon Sep 17 00:00:00 2001 From: cpunkzzz Date: Sun, 2 Oct 2022 20:52:29 -0400 Subject: [PATCH 03/12] Add mul goldilocks --- risc0/zkvm/platform/io.h | 10 ++++ risc0/zkvm/platform/memory.h | 1 + risc0/zkvm/platform/risc0.ld | 1 + risc0/zkvm/prove/io_handler.cpp | 40 ++++++++++++++ risc0/zkvm/sdk/rust/guest/src/lib.rs | 3 ++ risc0/zkvm/sdk/rust/guest/src/mul.rs | 62 ++++++++++++++++++++++ risc0/zkvm/sdk/rust/platform/src/io.rs | 10 ++++ risc0/zkvm/sdk/rust/platform/src/memory.rs | 1 + 8 files changed, 128 insertions(+) create mode 100644 risc0/zkvm/sdk/rust/guest/src/mul.rs diff --git a/risc0/zkvm/platform/io.h b/risc0/zkvm/platform/io.h index 83e20751c2..55d3436b2a 100644 --- a/risc0/zkvm/platform/io.h +++ b/risc0/zkvm/platform/io.h @@ -26,6 +26,7 @@ constexpr size_t kGPIO_GetKey = 0x01F0010; constexpr size_t kGPIO_SendRecvChannel = 0x01F00014; constexpr size_t kGPIO_SendRecvSize = 0x01F00018; constexpr size_t kGPIO_SendRecvAddr = 0x01F0001C; +constexpr size_t kGPIO_Mul = 0x01F00020; // Standard ZKVM channels; must match zkvm/sdk/rust/platform/src/io.rs. 
@@ -67,6 +68,15 @@ struct ShaDescriptor { uint32_t digest; }; +struct MulDescriptor { + // Address of first byte of MUL data to process + // 64 bits for first operand and 64 bits for second + uint32_t source; + + // 64 bit result + uint32_t result; +}; + inline volatile ShaDescriptor* volatile* GPIO_SHA() { return reinterpret_cast(kGPIO_SHA); } diff --git a/risc0/zkvm/platform/memory.h b/risc0/zkvm/platform/memory.h index 7b427a43f9..e88098cd7e 100644 --- a/risc0/zkvm/platform/memory.h +++ b/risc0/zkvm/platform/memory.h @@ -46,6 +46,7 @@ MEM_REGION(SHA, 0x02A00000, k1MB) MEM_REGION(WOM, 0x02B00000, 21 * k1MB) MEM_REGION(Output, 0x02B00000, 20 * k1MB) MEM_REGION(Commit, 0x03F00000, k1MB) +MEM_REGION(MUL, 0x04000000, k1MB) // clang-format on #define PTR_TO(type, name) reinterpret_cast(kMem##name##Start); diff --git a/risc0/zkvm/platform/risc0.ld b/risc0/zkvm/platform/risc0.ld index ed4e89e238..7e5678db1b 100644 --- a/risc0/zkvm/platform/risc0.ld +++ b/risc0/zkvm/platform/risc0.ld @@ -30,6 +30,7 @@ MEMORY { prog (X) : ORIGIN = 0x02000000, LENGTH = 10M sha : ORIGIN = 0x02A00000, LENGTH = 1M wom : ORIGIN = 0x02B00000, LENGTH = 21M + mul : ORIGIN = 0x04000000, LENGTH = 1M } SECTIONS { diff --git a/risc0/zkvm/prove/io_handler.cpp b/risc0/zkvm/prove/io_handler.cpp index 68784408ef..73a681d7d4 100644 --- a/risc0/zkvm/prove/io_handler.cpp +++ b/risc0/zkvm/prove/io_handler.cpp @@ -46,6 +46,39 @@ static void processSHA(MemoryState& mem, const ShaDescriptor& desc) { } } +static void processMul(MemoryState& mem, const MulDescriptor& desc) { + uint32_t first_operand[2]; + uint32_t second_operand[2]; + + first_operand[0] = mem.loadBE(desc.source); + LOG(1, "Input[" << hex(0, 2) << "]: " << hex(desc.source) << " -> " << hex(first_operand[0])); + first_operand[1] = mem.loadBE(desc.source + 4); + LOG(1, "Input[" << hex(1, 2) << "]: " << hex(desc.source + 4) << " -> " << hex(first_operand[1])); + second_operand[0] = mem.loadBE(desc.source + 8); + LOG(1, + "Input[" << hex(2, 2) << "]: 
" << hex(desc.source + 8) << " -> " << hex(second_operand[0])); + second_operand[1] = mem.loadBE(desc.source + 12); + LOG(1, + "Input[" << hex(3, 2) << "]: " << hex(desc.source + 12) << " -> " << hex(second_operand[1])); + + // MSB is at 0 + uint64_t first = first_operand[1] | (uint64_t(first_operand[0]) << 32); + uint64_t second = second_operand[1] | (uint64_t(second_operand[0]) << 32); + + __uint128_t result = __uint128_t(first) * __uint128_t(second); + + // goldilocks + uint64_t moded_result = result % 0xFFFFFFFF00000001; + + uint32_t high = (uint32_t)((moded_result & 0xFFFFFFFF00000000LL) >> 32); + uint32_t low = (uint32_t)(moded_result & 0xFFFFFFFFLL); + + LOG(1, "Output[" << hex(0, 2) << "]: " << hex(desc.result) << " <- " << hex(high)); + mem.store(desc.result, high); + LOG(1, "Output[" << hex(1, 2) << "]: " << hex(desc.result + 4) << " <- " << hex(low)); + mem.store(desc.result + 4, low); +} + void IoHandler::onFault(const std::string& msg) { throw std::runtime_error(msg); } @@ -63,6 +96,13 @@ void MemoryHandler::onInit(MemoryState& mem) { void MemoryHandler::onWrite(MemoryState& mem, uint32_t cycle, uint32_t addr, uint32_t value) { LOG(2, "MemoryHandler::onWrite> " << hex(addr) << ": " << hex(value)); switch (addr) { + case kGPIO_Mul: { + LOG(1, "MemoryHandler::onWrite> GPIO_MUL"); + MulDescriptor desc; + mem.loadRegion(value, &desc, sizeof(desc)); + processMul(mem, desc); + break; + } case kGPIO_SHA: { LOG(1, "MemoryHandler::onWrite> GPIO_SHA"); ShaDescriptor desc; diff --git a/risc0/zkvm/sdk/rust/guest/src/lib.rs b/risc0/zkvm/sdk/rust/guest/src/lib.rs index a7fd6745da..094fc776d4 100644 --- a/risc0/zkvm/sdk/rust/guest/src/lib.rs +++ b/risc0/zkvm/sdk/rust/guest/src/lib.rs @@ -32,6 +32,9 @@ pub mod env; /// Functions for computing SHA-256 hashes. 
pub mod sha; +/// mul +pub mod mul; + /// Functions for handling input and output pub mod io; diff --git a/risc0/zkvm/sdk/rust/guest/src/mul.rs b/risc0/zkvm/sdk/rust/guest/src/mul.rs new file mode 100644 index 0000000000..f95cff644d --- /dev/null +++ b/risc0/zkvm/sdk/rust/guest/src/mul.rs @@ -0,0 +1,62 @@ +use core::{cell::UnsafeCell, mem}; + +use _alloc::{boxed::Box, vec::Vec}; +use risc0_zkvm::platform::{io::{MulDescriptor, GPIO_MUL}, memory}; + +// Current sha descriptor index. +struct CurOutput(UnsafeCell); + +// SAFETY: single threaded environment +unsafe impl Sync for CurOutput {} + +static CUR_OUTPUT: CurOutput = CurOutput(UnsafeCell::new(0)); + +pub struct MulGoldilocks([u32; 2]); + +fn alloc_output() -> *mut MulDescriptor { + // SAFETY: Single threaded and this is the only place we use CUR_DESC. + unsafe { + let cur_desc = CUR_OUTPUT.0.get(); + let ptr = (memory::MUL.start() as *mut MulDescriptor).add(*cur_desc); + *cur_desc += 1; + ptr + } +} + +pub fn mul_goldilocks(a: &u64, b: &u64) -> &'static MulGoldilocks { + // Allocate fresh memory that's guaranteed to be uninitialized so + // the host can write to it. 
+ let mut buf = Vec::::with_capacity(4); + let a_hi = (u32)((a & 0xFFFFFFFF00000000LL) >> 32); + let a_lo = (u32)(a & 0xFFFFFFFFLL); + + let b_hi = (u32)((b & 0xFFFFFFFF00000000LL) >> 32); + let b_lo = (u32)(b & 0xFFFFFFFFLL); + + buf.push(a_hi); + buf.push(a_lo); + buf.push(b_hi); + buf.push(b_lo); + + unsafe { + let alloced = Box::>::new( + mem::MaybeUninit::::uninit(), + ); + let output = (*Box::into_raw(alloced)).as_mut_ptr(); + mul_raw(&buf[..], output); + &*output + } +} + +pub(crate) unsafe fn mul_raw(data: &[u32], result: *mut MulGoldilocks) { + let output_ptr = alloc_output(); + + let ptr = data.as_ptr(); + super::memory_barrier(ptr); + output_ptr.write_volatile(MulDescriptor { + source: ptr as usize, + result: result as usize, + }); + + GPIO_MUL.as_ptr().write_volatile(output_ptr); +} diff --git a/risc0/zkvm/sdk/rust/platform/src/io.rs b/risc0/zkvm/sdk/rust/platform/src/io.rs index b75b641db3..34a950351f 100644 --- a/risc0/zkvm/sdk/rust/platform/src/io.rs +++ b/risc0/zkvm/sdk/rust/platform/src/io.rs @@ -49,6 +49,8 @@ pub const GPIO_SENDRECV_CHANNEL: Gpio = Gpio::new(0x01F0_0014); pub const GPIO_SENDRECV_SIZE: Gpio = Gpio::new(0x01F0_0018); pub const GPIO_SENDRECV_ADDR: Gpio<*const u8> = Gpio::new(0x01F0_001C); +pub const GPIO_MUL: Gpio<*const MulDescriptor> = Gpio::new(0x01F0_0020); + pub mod addr { pub const GPIO_SHA: u32 = super::GPIO_SHA.addr(); pub const GPIO_COMMIT: u32 = super::GPIO_COMMIT.addr(); @@ -59,6 +61,8 @@ pub mod addr { pub const GPIO_SENDRECV_CHANNEL: u32 = super::GPIO_SENDRECV_CHANNEL.addr(); pub const GPIO_SENDRECV_SIZE: u32 = super::GPIO_SENDRECV_SIZE.addr(); pub const GPIO_SENDRECV_ADDR: u32 = super::GPIO_SENDRECV_ADDR.addr(); + + pub const GPIO_MUL: u32 = super::GPIO_MUL.addr(); } #[repr(C)] @@ -75,6 +79,12 @@ pub struct SHADescriptor { pub digest: usize, } +#[repr(C)] +pub struct MulDescriptor { + pub source: usize, + pub result: usize, +} + #[repr(C)] pub struct GetKeyDescriptor { pub name: u32, diff --git 
a/risc0/zkvm/sdk/rust/platform/src/memory.rs b/risc0/zkvm/sdk/rust/platform/src/memory.rs index e9034d9d4a..df658de1fe 100644 --- a/risc0/zkvm/sdk/rust/platform/src/memory.rs +++ b/risc0/zkvm/sdk/rust/platform/src/memory.rs @@ -65,3 +65,4 @@ pub const SHA: Region = Region::new(0x02A0_0000, mb(1)); pub const WOM: Region = Region::new(0x02B0_0000, mb(21)); pub const OUTPUT: Region = Region::new(0x02B0_0000, mb(20)); pub const COMMIT: Region = Region::new(0x03F0_0000, mb(1)); +pub const MUL: Region = Region::new(0x0400_0000, mb(1)); From 9eeb01b6c5a9d11cdfe500832c53b658c5d495bc Mon Sep 17 00:00:00 2001 From: cpunkzzz Date: Mon, 3 Oct 2022 16:35:08 -0400 Subject: [PATCH 04/12] fix regions --- risc0/zkvm/platform/memory.h | 8 +++---- risc0/zkvm/platform/risc0.ld | 4 ++-- risc0/zkvm/sdk/rust/guest/src/mul.rs | 26 +++++++++++++++++----- risc0/zkvm/sdk/rust/platform/src/memory.rs | 8 +++---- 4 files changed, 31 insertions(+), 15 deletions(-) diff --git a/risc0/zkvm/platform/memory.h b/risc0/zkvm/platform/memory.h index e88098cd7e..c3fe992ba3 100644 --- a/risc0/zkvm/platform/memory.h +++ b/risc0/zkvm/platform/memory.h @@ -43,10 +43,10 @@ MEM_REGION(Input, 0x01E00000, k1MB) MEM_REGION(GPIO, 0x01F00000, k1MB) MEM_REGION(Prog, 0x02000000, 10 * k1MB) MEM_REGION(SHA, 0x02A00000, k1MB) -MEM_REGION(WOM, 0x02B00000, 21 * k1MB) -MEM_REGION(Output, 0x02B00000, 20 * k1MB) -MEM_REGION(Commit, 0x03F00000, k1MB) -MEM_REGION(MUL, 0x04000000, k1MB) +MEM_REGION(MUL, 0x02B00000, k1MB) +MEM_REGION(WOM, 0x02C00000, 21 * k1MB) +MEM_REGION(Output, 0x02C00000, 20 * k1MB) +MEM_REGION(Commit, 0x04000000, k1MB) // clang-format on #define PTR_TO(type, name) reinterpret_cast(kMem##name##Start); diff --git a/risc0/zkvm/platform/risc0.ld b/risc0/zkvm/platform/risc0.ld index 7e5678db1b..7a60a5a2cb 100644 --- a/risc0/zkvm/platform/risc0.ld +++ b/risc0/zkvm/platform/risc0.ld @@ -29,8 +29,8 @@ MEMORY { gpio : ORIGIN = 0x01F00000, LENGTH = 1M prog (X) : ORIGIN = 0x02000000, LENGTH = 10M sha : ORIGIN = 
0x02A00000, LENGTH = 1M - wom : ORIGIN = 0x02B00000, LENGTH = 21M - mul : ORIGIN = 0x04000000, LENGTH = 1M + mul : ORIGIN = 0x02B00000, LENGTH = 1M + wom : ORIGIN = 0x02C00000, LENGTH = 21M } SECTIONS { diff --git a/risc0/zkvm/sdk/rust/guest/src/mul.rs b/risc0/zkvm/sdk/rust/guest/src/mul.rs index f95cff644d..5fb7401aea 100644 --- a/risc0/zkvm/sdk/rust/guest/src/mul.rs +++ b/risc0/zkvm/sdk/rust/guest/src/mul.rs @@ -1,7 +1,10 @@ use core::{cell::UnsafeCell, mem}; use _alloc::{boxed::Box, vec::Vec}; -use risc0_zkvm::platform::{io::{MulDescriptor, GPIO_MUL}, memory}; +use risc0_zkvm::platform::{ + io::{MulDescriptor, GPIO_MUL}, + memory, +}; // Current sha descriptor index. struct CurOutput(UnsafeCell); @@ -11,8 +14,20 @@ unsafe impl Sync for CurOutput {} static CUR_OUTPUT: CurOutput = CurOutput(UnsafeCell::new(0)); +/// Result of multiply goldilocks pub struct MulGoldilocks([u32; 2]); +impl MulGoldilocks { + /// Get the result as u64 + pub fn get_u64(&self) -> u64 { + let mut res = 0u64; + for i in 0..2 { + res |= (self.0[i] as u64) << (32 * i); + } + res + } +} + fn alloc_output() -> *mut MulDescriptor { // SAFETY: Single threaded and this is the only place we use CUR_DESC. unsafe { @@ -23,15 +38,16 @@ fn alloc_output() -> *mut MulDescriptor { } } +/// Multiply goldilocks oracle, verification is done separately pub fn mul_goldilocks(a: &u64, b: &u64) -> &'static MulGoldilocks { // Allocate fresh memory that's guaranteed to be uninitialized so // the host can write to it. 
let mut buf = Vec::::with_capacity(4); - let a_hi = (u32)((a & 0xFFFFFFFF00000000LL) >> 32); - let a_lo = (u32)(a & 0xFFFFFFFFLL); + let a_hi = ((a & 0xFFFFFFFF00000000) >> 32) as u32; + let a_lo = (a & 0xFFFFFFFF) as u32; - let b_hi = (u32)((b & 0xFFFFFFFF00000000LL) >> 32); - let b_lo = (u32)(b & 0xFFFFFFFFLL); + let b_hi = ((b & 0xFFFFFFFF00000000) >> 32) as u32; + let b_lo = (b & 0xFFFFFFFF) as u32; buf.push(a_hi); buf.push(a_lo); diff --git a/risc0/zkvm/sdk/rust/platform/src/memory.rs b/risc0/zkvm/sdk/rust/platform/src/memory.rs index df658de1fe..c491a07a6c 100644 --- a/risc0/zkvm/sdk/rust/platform/src/memory.rs +++ b/risc0/zkvm/sdk/rust/platform/src/memory.rs @@ -62,7 +62,7 @@ pub const INPUT: Region = Region::new(0x01E0_0000, mb(1)); pub const GPIO: Region = Region::new(0x01F0_0000, mb(1)); pub const PROG: Region = Region::new(0x0200_0000, mb(10)); pub const SHA: Region = Region::new(0x02A0_0000, mb(1)); -pub const WOM: Region = Region::new(0x02B0_0000, mb(21)); -pub const OUTPUT: Region = Region::new(0x02B0_0000, mb(20)); -pub const COMMIT: Region = Region::new(0x03F0_0000, mb(1)); -pub const MUL: Region = Region::new(0x0400_0000, mb(1)); +pub const MUL: Region = Region::new(0x02B0_0000, mb(1)); +pub const WOM: Region = Region::new(0x02C0_0000, mb(21)); +pub const OUTPUT: Region = Region::new(0x02C0_0000, mb(20)); +pub const COMMIT: Region = Region::new(0x0400_0000, mb(1)); From aa258322625777d7da92473438b6b58d54d0495c Mon Sep 17 00:00:00 2001 From: cpunkzzz Date: Mon, 3 Oct 2022 20:32:43 -0400 Subject: [PATCH 05/12] Solved memory overflow --- risc0/zkvm/platform/memory.h | 6 +++--- risc0/zkvm/platform/risc0.ld | 2 +- risc0/zkvm/sdk/rust/platform/src/memory.rs | 6 +++--- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/risc0/zkvm/platform/memory.h b/risc0/zkvm/platform/memory.h index c3fe992ba3..66723d37f3 100644 --- a/risc0/zkvm/platform/memory.h +++ b/risc0/zkvm/platform/memory.h @@ -44,9 +44,9 @@ MEM_REGION(GPIO, 0x01F00000, k1MB) 
MEM_REGION(Prog, 0x02000000, 10 * k1MB) MEM_REGION(SHA, 0x02A00000, k1MB) MEM_REGION(MUL, 0x02B00000, k1MB) -MEM_REGION(WOM, 0x02C00000, 21 * k1MB) -MEM_REGION(Output, 0x02C00000, 20 * k1MB) -MEM_REGION(Commit, 0x04000000, k1MB) +MEM_REGION(WOM, 0x02C00000, 20 * k1MB) +MEM_REGION(Output, 0x02C00000, 19 * k1MB) +MEM_REGION(Commit, 0x03F00000, k1MB) // clang-format on #define PTR_TO(type, name) reinterpret_cast(kMem##name##Start); diff --git a/risc0/zkvm/platform/risc0.ld b/risc0/zkvm/platform/risc0.ld index 7a60a5a2cb..7b51400427 100644 --- a/risc0/zkvm/platform/risc0.ld +++ b/risc0/zkvm/platform/risc0.ld @@ -30,7 +30,7 @@ MEMORY { prog (X) : ORIGIN = 0x02000000, LENGTH = 10M sha : ORIGIN = 0x02A00000, LENGTH = 1M mul : ORIGIN = 0x02B00000, LENGTH = 1M - wom : ORIGIN = 0x02C00000, LENGTH = 21M + wom : ORIGIN = 0x02C00000, LENGTH = 20M } SECTIONS { diff --git a/risc0/zkvm/sdk/rust/platform/src/memory.rs b/risc0/zkvm/sdk/rust/platform/src/memory.rs index c491a07a6c..5b2fea9209 100644 --- a/risc0/zkvm/sdk/rust/platform/src/memory.rs +++ b/risc0/zkvm/sdk/rust/platform/src/memory.rs @@ -63,6 +63,6 @@ pub const GPIO: Region = Region::new(0x01F0_0000, mb(1)); pub const PROG: Region = Region::new(0x0200_0000, mb(10)); pub const SHA: Region = Region::new(0x02A0_0000, mb(1)); pub const MUL: Region = Region::new(0x02B0_0000, mb(1)); -pub const WOM: Region = Region::new(0x02C0_0000, mb(21)); -pub const OUTPUT: Region = Region::new(0x02C0_0000, mb(20)); -pub const COMMIT: Region = Region::new(0x0400_0000, mb(1)); +pub const WOM: Region = Region::new(0x02C0_0000, mb(20)); +pub const OUTPUT: Region = Region::new(0x02C0_0000, mb(19)); +pub const COMMIT: Region = Region::new(0x03F0_0000, mb(1)); From e23deab759c00849329ef9b469271b36ce3a99c0 Mon Sep 17 00:00:00 2001 From: cpunkzzz Date: Tue, 4 Oct 2022 20:28:18 -0400 Subject: [PATCH 06/12] fix alignment --- risc0/zkvm/prove/io_handler.cpp | 26 ++++++++++---------------- risc0/zkvm/sdk/rust/guest/src/mul.rs | 8 +++----- 2 files 
changed, 13 insertions(+), 21 deletions(-) diff --git a/risc0/zkvm/prove/io_handler.cpp b/risc0/zkvm/prove/io_handler.cpp index 73a681d7d4..d34070f20d 100644 --- a/risc0/zkvm/prove/io_handler.cpp +++ b/risc0/zkvm/prove/io_handler.cpp @@ -47,23 +47,17 @@ static void processSHA(MemoryState& mem, const ShaDescriptor& desc) { } static void processMul(MemoryState& mem, const MulDescriptor& desc) { - uint32_t first_operand[2]; - uint32_t second_operand[2]; + uint32_t a_hi = mem.load(desc.source); + LOG(1, "Input[" << hex(0, 2) << "]: " << hex(desc.source) << " -> " << hex(a_hi)); + uint32_t a_lo = mem.load(desc.source + 4); + LOG(1, "Input[" << hex(1, 2) << "]: " << hex(desc.source + 4) << " -> " << hex(a_lo)); + uint32_t b_hi = mem.load(desc.source + 8); + LOG(1, "Input[" << hex(2, 2) << "]: " << hex(desc.source + 8) << " -> " << hex(b_hi)); + uint32_t b_lo = mem.load(desc.source + 12); + LOG(1, "Input[" << hex(3, 2) << "]: " << hex(desc.source + 12) << " -> " << hex(b_lo)); - first_operand[0] = mem.loadBE(desc.source); - LOG(1, "Input[" << hex(0, 2) << "]: " << hex(desc.source) << " -> " << hex(first_operand[0])); - first_operand[1] = mem.loadBE(desc.source + 4); - LOG(1, "Input[" << hex(1, 2) << "]: " << hex(desc.source + 4) << " -> " << hex(first_operand[1])); - second_operand[0] = mem.loadBE(desc.source + 8); - LOG(1, - "Input[" << hex(2, 2) << "]: " << hex(desc.source + 8) << " -> " << hex(second_operand[0])); - second_operand[1] = mem.loadBE(desc.source + 12); - LOG(1, - "Input[" << hex(3, 2) << "]: " << hex(desc.source + 12) << " -> " << hex(second_operand[1])); - - // MSB is at 0 - uint64_t first = first_operand[1] | (uint64_t(first_operand[0]) << 32); - uint64_t second = second_operand[1] | (uint64_t(second_operand[0]) << 32); + uint64_t first = a_lo | (uint64_t(a_hi) << 32); + uint64_t second = b_lo | (uint64_t(b_hi) << 32); __uint128_t result = __uint128_t(first) * __uint128_t(second); diff --git a/risc0/zkvm/sdk/rust/guest/src/mul.rs 
b/risc0/zkvm/sdk/rust/guest/src/mul.rs index 5fb7401aea..73299ed4fe 100644 --- a/risc0/zkvm/sdk/rust/guest/src/mul.rs +++ b/risc0/zkvm/sdk/rust/guest/src/mul.rs @@ -1,5 +1,7 @@ use core::{cell::UnsafeCell, mem}; +use crate::env::log; +use _alloc::format; use _alloc::{boxed::Box, vec::Vec}; use risc0_zkvm::platform::{ io::{MulDescriptor, GPIO_MUL}, @@ -20,11 +22,7 @@ pub struct MulGoldilocks([u32; 2]); impl MulGoldilocks { /// Get the result as u64 pub fn get_u64(&self) -> u64 { - let mut res = 0u64; - for i in 0..2 { - res |= (self.0[i] as u64) << (32 * i); - } - res + (self.0[1] as u64) | ((self.0[0] as u64) << 32) } } From a9945a43f86fb0da1ce05cfb3695ff59d1643d2d Mon Sep 17 00:00:00 2001 From: cpunkzzz Date: Wed, 5 Oct 2022 20:46:32 -0400 Subject: [PATCH 07/12] slight improvement --- risc0/zkvm/sdk/rust/guest/src/mul.rs | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/risc0/zkvm/sdk/rust/guest/src/mul.rs b/risc0/zkvm/sdk/rust/guest/src/mul.rs index 73299ed4fe..22593744f7 100644 --- a/risc0/zkvm/sdk/rust/guest/src/mul.rs +++ b/risc0/zkvm/sdk/rust/guest/src/mul.rs @@ -38,19 +38,13 @@ fn alloc_output() -> *mut MulDescriptor { /// Multiply goldilocks oracle, verification is done separately pub fn mul_goldilocks(a: &u64, b: &u64) -> &'static MulGoldilocks { - // Allocate fresh memory that's guaranteed to be uninitialized so - // the host can write to it. 
- let mut buf = Vec::::with_capacity(4); let a_hi = ((a & 0xFFFFFFFF00000000) >> 32) as u32; let a_lo = (a & 0xFFFFFFFF) as u32; let b_hi = ((b & 0xFFFFFFFF00000000) >> 32) as u32; let b_lo = (b & 0xFFFFFFFF) as u32; - buf.push(a_hi); - buf.push(a_lo); - buf.push(b_hi); - buf.push(b_lo); + let buf = [a_hi, a_lo, b_hi, b_lo]; unsafe { let alloced = Box::>::new( From 4e5a4631687577efccef71b76449b01d9cca2251 Mon Sep 17 00:00:00 2001 From: starkoracles Date: Wed, 23 Nov 2022 16:47:52 -0500 Subject: [PATCH 08/12] Progress towards accel --- risc0/zkvm/prove/io_handler.cpp | 20 +++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/risc0/zkvm/prove/io_handler.cpp b/risc0/zkvm/prove/io_handler.cpp index d34070f20d..159e2eb299 100644 --- a/risc0/zkvm/prove/io_handler.cpp +++ b/risc0/zkvm/prove/io_handler.cpp @@ -46,6 +46,18 @@ static void processSHA(MemoryState& mem, const ShaDescriptor& desc) { } } +static uint64_t montRedCst(__uint128_t n) { + uint64_t xl = n & 0xFFFFFFFF; + uint64_t xh = (n >> 64) & 0xFFFFFFFF; + bool e = (__uint128_t(xl) + __uint128_t(xl << 32)) > UINT64_MAX; + uint64_t a = xl + (xl << 32); + uint64_t b = a - (a >> 32) - e; + bool c = (int64_t(xh) - int64_t(b)) < 0; + uint64_t r = xh - b; + uint64_t mont_result = r - (uint32_t(0) - uint32_t(c)); + return mont_result; +} + static void processMul(MemoryState& mem, const MulDescriptor& desc) { uint32_t a_hi = mem.load(desc.source); LOG(1, "Input[" << hex(0, 2) << "]: " << hex(desc.source) << " -> " << hex(a_hi)); @@ -60,12 +72,10 @@ static void processMul(MemoryState& mem, const MulDescriptor& desc) { uint64_t second = b_lo | (uint64_t(b_hi) << 32); __uint128_t result = __uint128_t(first) * __uint128_t(second); + uint64_t mont_result = montRedCst(result); - // goldilocks - uint64_t moded_result = result % 0xFFFFFFFF00000001; - - uint32_t high = (uint32_t)((moded_result & 0xFFFFFFFF00000000LL) >> 32); - uint32_t low = (uint32_t)(moded_result & 0xFFFFFFFFLL); + uint32_t high = 
(uint32_t)((mont_result & 0xFFFFFFFF00000000LL) >> 32); + uint32_t low = (uint32_t)(mont_result & 0xFFFFFFFFLL); LOG(1, "Output[" << hex(0, 2) << "]: " << hex(desc.result) << " <- " << hex(high)); mem.store(desc.result, high); From e8c520fdfebd656ed0f78cd1b4bca0ccb7c57aba Mon Sep 17 00:00:00 2001 From: starkoracles Date: Sat, 26 Nov 2022 10:23:16 -0500 Subject: [PATCH 09/12] Fix bug in mul impl --- risc0/zkvm/prove/io_handler.cpp | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/risc0/zkvm/prove/io_handler.cpp b/risc0/zkvm/prove/io_handler.cpp index 159e2eb299..9027777fe5 100644 --- a/risc0/zkvm/prove/io_handler.cpp +++ b/risc0/zkvm/prove/io_handler.cpp @@ -47,14 +47,17 @@ static void processSHA(MemoryState& mem, const ShaDescriptor& desc) { } static uint64_t montRedCst(__uint128_t n) { - uint64_t xl = n & 0xFFFFFFFF; - uint64_t xh = (n >> 64) & 0xFFFFFFFF; - bool e = (__uint128_t(xl) + __uint128_t(xl << 32)) > UINT64_MAX; + uint64_t xl = uint64_t(n); + uint64_t xh = uint64_t(n >> 64); + bool e = (__uint128_t(xl) + __uint128_t(xl << 32)) > UINT64_MAX; // overflow uint64_t a = xl + (xl << 32); uint64_t b = a - (a >> 32) - e; - bool c = (int64_t(xh) - int64_t(b)) < 0; + bool c = xh < b; uint64_t r = xh - b; uint64_t mont_result = r - (uint32_t(0) - uint32_t(c)); + // std::cout << "xl = " << xl << ", xh = " << xh << ", a = " << a << ", e = " << e << ", b = " << + // b + // << ", c = " << c << ", r = " << r << ", mont_result = " << mont_result << std::endl; return mont_result; } From 6b77b5a7a65330a013cdca89689dd0040afd3d2d Mon Sep 17 00:00:00 2001 From: starkoracles Date: Sat, 26 Nov 2022 20:32:25 -0500 Subject: [PATCH 10/12] move to mul extension --- risc0/zkvm/platform/io.h | 4 +- risc0/zkvm/prove/io_handler.cpp | 591 ++++++++++++++++----------- risc0/zkvm/sdk/rust/guest/src/mul.rs | 25 +- 3 files changed, 381 insertions(+), 239 deletions(-) diff --git a/risc0/zkvm/platform/io.h b/risc0/zkvm/platform/io.h index 55d3436b2a..2f9925c524 
100644 --- a/risc0/zkvm/platform/io.h +++ b/risc0/zkvm/platform/io.h @@ -70,10 +70,10 @@ struct ShaDescriptor { struct MulDescriptor { // Address of first byte of MUL data to process - // 64 bits for first operand and 64 bits for second + // 128 bits for first operand and 128 bits for second uint32_t source; - // 64 bit result + // 128 bit result uint32_t result; }; diff --git a/risc0/zkvm/prove/io_handler.cpp b/risc0/zkvm/prove/io_handler.cpp index 9027777fe5..0e939a00f3 100644 --- a/risc0/zkvm/prove/io_handler.cpp +++ b/risc0/zkvm/prove/io_handler.cpp @@ -21,263 +21,398 @@ #include "risc0/zkvm/platform/memory.h" #include "risc0/zkvm/prove/step.h" -namespace risc0 { - -static void processSHA(MemoryState& mem, const ShaDescriptor& desc) { - uint16_t type = (desc.typeAndCount & 0xFFFF) >> 4; - uint16_t count = desc.typeAndCount & 0xFFFF; - LOG(1, - "SHA256 type: " << type << ", count: " << count << ", idx: " << desc.idx - << ", source: " << hex(desc.source) << ", digest: " << hex(desc.digest)); - ShaDigest sha = impl::initState(); - uint32_t words[16]; - for (int i = 0; i < count; i++) { - for (int j = 0; j < 16; j++) { - uint32_t from = desc.source + i * 16 * 4 + j * 4; - words[j] = mem.loadBE(from); - LOG(1, "Input[" << hex(j, 2) << "]: " << hex(from) << " -> " << hex(words[j])); +namespace risc0 +{ + + class FpG + { + // implement just enough operations to support extension field multiplication + // all values are in mont form + public: + static CONSTSCALAR uint64_t M = 0xFFFFFFFF00000001; + uint64_t val; + + private: + static DEVSPEC constexpr uint64_t add(uint64_t a, uint64_t b) + { + bool c1 = b > M; + uint64_t x1 = a - (M - b); + uint32_t adj = 0 - uint32_t(c1); + return x1 - uint64_t(adj); } - LOG(1, "Compress"); - impl::compress(sha, words); + + static DEVSPEC constexpr uint64_t sub(uint64_t a, uint64_t b) + { + bool c1 = b > a; + uint64_t x1 = a - b; + uint32_t adj = 0 - uint32_t(c1); + return x1 - uint64_t(adj); + } + + static DEVSPEC constexpr uint64_t 
doubleVal(uint64_t a) + { + __uint128_t ret = __uint128_t(a) << 1; + uint64_t result = uint64_t(ret); + uint64_t over = uint64_t(ret >> 64); + return result - (M * over); + } + + static uint64_t montRedCst(__uint128_t n) + { + uint64_t xl = uint64_t(n); + uint64_t xh = uint64_t(n >> 64); + bool e = (__uint128_t(xl) + __uint128_t(xl << 32)) > UINT64_MAX; // overflow + uint64_t a = xl + (xl << 32); + uint64_t b = a - (a >> 32) - e; + bool c = xh < b; + uint64_t r = xh - b; + uint64_t mont_result = r - (uint32_t(0) - uint32_t(c)); + // std::cout << "xl = " << xl << ", xh = " << xh << ", a = " << a << ", e = " << e << ", b = " + // << + // b + // << ", c = " << c << ", r = " << r << ", mont_result = " << mont_result << + // std::endl; + return mont_result; + } + + static DEVSPEC constexpr uint64_t mul(uint64_t a, uint64_t b) + { + __uint128_t n = __uint128_t(a) * __uint128_t(b); + return montRedCst(n); + } + + public: + DEVSPEC constexpr FpG(uint64_t val) : val(val) {} + DEVSPEC constexpr FpG operator+(FpG rhs) const { return FpG(add(val, rhs.val)); } + DEVSPEC constexpr FpG operator-(FpG rhs) const { return FpG(sub(val, rhs.val)); } + DEVSPEC constexpr FpG operator*(FpG rhs) const { return FpG(mul(val, rhs.val)); } + DEVSPEC constexpr FpG doubleVal() const { return FpG(doubleVal(val)); } + }; + + static std::pair extensionMul(std::pair a, std::pair b) + { + FpG a0b0 = a.first * b.first; + + FpG first = a0b0 - (a.second * b.second).doubleVal(); + FpG second = (a.first + a.second) * (b.first + b.second) - a0b0; + + return std::pair(first, second); } - for (int i = 0; i < 8; i++) { - LOG(1, "Output[" << hex(i, 1) << "]: " << hex(sha.words[i])); - mem.store(desc.digest + i * 4, sha.words[i]); + + static void processSHA(MemoryState &mem, const ShaDescriptor &desc) + { + uint16_t type = (desc.typeAndCount & 0xFFFF) >> 4; + uint16_t count = desc.typeAndCount & 0xFFFF; + LOG(1, + "SHA256 type: " << type << ", count: " << count << ", idx: " << desc.idx + << ", source: " << 
hex(desc.source) << ", digest: " << hex(desc.digest)); + ShaDigest sha = impl::initState(); + uint32_t words[16]; + for (int i = 0; i < count; i++) + { + for (int j = 0; j < 16; j++) + { + uint32_t from = desc.source + i * 16 * 4 + j * 4; + words[j] = mem.loadBE(from); + LOG(1, "Input[" << hex(j, 2) << "]: " << hex(from) << " -> " << hex(words[j])); + } + LOG(1, "Compress"); + impl::compress(sha, words); + } + for (int i = 0; i < 8; i++) + { + LOG(1, "Output[" << hex(i, 1) << "]: " << hex(sha.words[i])); + mem.store(desc.digest + i * 4, sha.words[i]); + } + } + + static void processMul(MemoryState &mem, const MulDescriptor &desc) + { + uint32_t a0_hi = mem.load(desc.source); + LOG(1, "Input[" << hex(0, 2) << "]: " << hex(desc.source) << " -> " << hex(a0_hi)); + uint32_t a0_lo = mem.load(desc.source + 4); + LOG(1, "Input[" << hex(1, 2) << "]: " << hex(desc.source + 4) << " -> " << hex(a0_lo)); + uint32_t a1_hi = mem.load(desc.source + 8); + LOG(1, "Input[" << hex(2, 2) << "]: " << hex(desc.source + 8) << " -> " << hex(a1_hi)); + uint32_t a1_lo = mem.load(desc.source + 12); + LOG(1, "Input[" << hex(3, 2) << "]: " << hex(desc.source + 12) << " -> " << hex(a1_lo)); + + uint32_t b0_hi = mem.load(desc.source); + LOG(1, "Input[" << hex(4, 2) << "]: " << hex(desc.source + 16) << " -> " << hex(b0_hi)); + uint32_t b0_lo = mem.load(desc.source + 4); + LOG(1, "Input[" << hex(5, 2) << "]: " << hex(desc.source + 20) << " -> " << hex(b0_lo)); + uint32_t b1_hi = mem.load(desc.source + 8); + LOG(1, "Input[" << hex(6, 2) << "]: " << hex(desc.source + 24) << " -> " << hex(b1_hi)); + uint32_t b1_lo = mem.load(desc.source + 12); + LOG(1, "Input[" << hex(7, 2) << "]: " << hex(desc.source + 28) << " -> " << hex(b1_lo)); + + uint64_t a0 = a0_lo | (uint64_t(a0_hi) << 32); + uint64_t a1 = a1_lo | (uint64_t(a1_hi) << 32); + uint64_t b0 = b0_lo | (uint64_t(b0_hi) << 32); + uint64_t b1 = b1_lo | (uint64_t(b1_hi) << 32); + + std::pair a = std::pair(FpG(a0), FpG(a1)); + std::pair b = 
std::pair(FpG(b0), FpG(b1)); + std::pair result = extensionMul(a, b); + + uint64_t r0 = result.first.val; + uint32_t r0_high = (uint32_t)((r0 & 0xFFFFFFFF00000000LL) >> 32); + uint32_t r0_low = (uint32_t)(r0 & 0xFFFFFFFFLL); + + uint64_t r1 = result.second.val; + uint32_t r1_high = (uint32_t)((r1 & 0xFFFFFFFF00000000LL) >> 32); + uint32_t r1_low = (uint32_t)(r1 & 0xFFFFFFFFLL); + + LOG(1, "Output[" << hex(0, 2) << "]: " << hex(desc.result) << " <- " << hex(r0_high)); + mem.store(desc.result, r0_high); + LOG(1, "Output[" << hex(1, 2) << "]: " << hex(desc.result + 4) << " <- " << hex(r0_low)); + mem.store(desc.result + 4, r0_low); + LOG(1, "Output[" << hex(2, 2) << "]: " << hex(desc.result + 8) << " <- " << hex(r1_high)); + mem.store(desc.result + 8, r1_high); + LOG(1, "Output[" << hex(3, 2) << "]: " << hex(desc.result + 12) << " <- " << hex(r1_low)); + mem.store(desc.result + 12, r1_low); } -} - -static uint64_t montRedCst(__uint128_t n) { - uint64_t xl = uint64_t(n); - uint64_t xh = uint64_t(n >> 64); - bool e = (__uint128_t(xl) + __uint128_t(xl << 32)) > UINT64_MAX; // overflow - uint64_t a = xl + (xl << 32); - uint64_t b = a - (a >> 32) - e; - bool c = xh < b; - uint64_t r = xh - b; - uint64_t mont_result = r - (uint32_t(0) - uint32_t(c)); - // std::cout << "xl = " << xl << ", xh = " << xh << ", a = " << a << ", e = " << e << ", b = " << - // b - // << ", c = " << c << ", r = " << r << ", mont_result = " << mont_result << std::endl; - return mont_result; -} - -static void processMul(MemoryState& mem, const MulDescriptor& desc) { - uint32_t a_hi = mem.load(desc.source); - LOG(1, "Input[" << hex(0, 2) << "]: " << hex(desc.source) << " -> " << hex(a_hi)); - uint32_t a_lo = mem.load(desc.source + 4); - LOG(1, "Input[" << hex(1, 2) << "]: " << hex(desc.source + 4) << " -> " << hex(a_lo)); - uint32_t b_hi = mem.load(desc.source + 8); - LOG(1, "Input[" << hex(2, 2) << "]: " << hex(desc.source + 8) << " -> " << hex(b_hi)); - uint32_t b_lo = mem.load(desc.source + 12); - 
LOG(1, "Input[" << hex(3, 2) << "]: " << hex(desc.source + 12) << " -> " << hex(b_lo)); - - uint64_t first = a_lo | (uint64_t(a_hi) << 32); - uint64_t second = b_lo | (uint64_t(b_hi) << 32); - - __uint128_t result = __uint128_t(first) * __uint128_t(second); - uint64_t mont_result = montRedCst(result); - - uint32_t high = (uint32_t)((mont_result & 0xFFFFFFFF00000000LL) >> 32); - uint32_t low = (uint32_t)(mont_result & 0xFFFFFFFFLL); - - LOG(1, "Output[" << hex(0, 2) << "]: " << hex(desc.result) << " <- " << hex(high)); - mem.store(desc.result, high); - LOG(1, "Output[" << hex(1, 2) << "]: " << hex(desc.result + 4) << " <- " << hex(low)); - mem.store(desc.result + 4, low); -} - -void IoHandler::onFault(const std::string& msg) { - throw std::runtime_error(msg); -} - -MemoryHandler::MemoryHandler() : MemoryHandler(nullptr) {} - -MemoryHandler::MemoryHandler(IoHandler* io) : io(io), cur_host_to_guest_offset(kMemInputStart) {} - -void MemoryHandler::onInit(MemoryState& mem) { - if (io) { - io->onInit(mem); + + void IoHandler::onFault(const std::string &msg) + { + throw std::runtime_error(msg); } -} - -void MemoryHandler::onWrite(MemoryState& mem, uint32_t cycle, uint32_t addr, uint32_t value) { - LOG(2, "MemoryHandler::onWrite> " << hex(addr) << ": " << hex(value)); - switch (addr) { - case kGPIO_Mul: { - LOG(1, "MemoryHandler::onWrite> GPIO_MUL"); - MulDescriptor desc; - mem.loadRegion(value, &desc, sizeof(desc)); - processMul(mem, desc); - break; + + MemoryHandler::MemoryHandler() : MemoryHandler(nullptr) {} + + MemoryHandler::MemoryHandler(IoHandler *io) : io(io), cur_host_to_guest_offset(kMemInputStart) {} + + void MemoryHandler::onInit(MemoryState &mem) + { + if (io) + { + io->onInit(mem); + } } - case kGPIO_SHA: { - LOG(1, "MemoryHandler::onWrite> GPIO_SHA"); - ShaDescriptor desc; - mem.loadRegion(value, &desc, sizeof(desc)); - processSHA(mem, desc); - } break; - case kGPIO_Commit: { - LOG(1, "MemoryHandler::onWrite> GPIO_Commit"); - IoDescriptor desc; - 
mem.loadRegion(value, &desc, sizeof(desc)); - if (io) { - std::vector buf(desc.size); - mem.loadRegion(desc.addr, buf.data(), desc.size); - io->onCommit(buf); + + void MemoryHandler::onWrite(MemoryState &mem, uint32_t cycle, uint32_t addr, uint32_t value) + { + LOG(2, "MemoryHandler::onWrite> " << hex(addr) << ": " << hex(value)); + switch (addr) + { + case kGPIO_Mul: + { + LOG(1, "MemoryHandler::onWrite> GPIO_MUL"); + MulDescriptor desc; + mem.loadRegion(value, &desc, sizeof(desc)); + processMul(mem, desc); + break; } - } break; - case kGPIO_Fault: { - LOG(1, "MemoryHandler::onWrite> GPIO_Fault"); - if (io) { + case kGPIO_SHA: + { + LOG(1, "MemoryHandler::onWrite> GPIO_SHA"); + ShaDescriptor desc; + mem.loadRegion(value, &desc, sizeof(desc)); + processSHA(mem, desc); + } + break; + case kGPIO_Commit: + { + LOG(1, "MemoryHandler::onWrite> GPIO_Commit"); + IoDescriptor desc; + mem.loadRegion(value, &desc, sizeof(desc)); + if (io) + { + std::vector buf(desc.size); + mem.loadRegion(desc.addr, buf.data(), desc.size); + io->onCommit(buf); + } + } + break; + case kGPIO_Fault: + { + LOG(1, "MemoryHandler::onWrite> GPIO_Fault"); + if (io) + { + size_t len = mem.strlen(value); + std::vector buf(len); + mem.loadRegion(value, buf.data(), len); + std::string str(buf.data(), buf.size()); + io->onFault(str); + } + } + break; + case kGPIO_Log: + { + LOG(2, "MemoryHandler::onWrite> GPIO_Log"); size_t len = mem.strlen(value); std::vector buf(len); mem.loadRegion(value, buf.data(), len); std::string str(buf.data(), buf.size()); - io->onFault(str); + LOG(0, "R0VM[C" << cycle << "]> " << str); } - } break; - case kGPIO_Log: { - LOG(2, "MemoryHandler::onWrite> GPIO_Log"); - size_t len = mem.strlen(value); - std::vector buf(len); - mem.loadRegion(value, buf.data(), len); - std::string str(buf.data(), buf.size()); - LOG(0, "R0VM[C" << cycle << "]> " << str); - } break; - case kGPIO_GetKey: { - LOG(1, "MemoryHandler::onWrite> GPIO_GetKey"); - GetKeyDescriptor desc; - mem.loadRegion(value, 
&desc, sizeof(desc)); - if (!io) { - throw std::runtime_error("Get key called with no IO handler set"); + break; + case kGPIO_GetKey: + { + LOG(1, "MemoryHandler::onWrite> GPIO_GetKey"); + GetKeyDescriptor desc; + mem.loadRegion(value, &desc, sizeof(desc)); + if (!io) + { + throw std::runtime_error("Get key called with no IO handler set"); + } + size_t len = mem.strlen(desc.name); + std::vector buf(len); + mem.loadRegion(desc.name, buf.data(), len); + std::string str(buf.data(), buf.size()); + LOG(1, " addr = " << hex(desc.addr)); + LOG(1, " key = " << str); + LOG(1, " mode = " << desc.mode); + KeyStore &store = io->getKeyStore(); + if (desc.mode == 0 && store.count(str)) + { + throw std::runtime_error("GetKey Mode = NEW and key exists: " + str); + } + if (desc.mode == 1 && !store.count(str)) + { + throw std::runtime_error("GetKey Mode = EXISTING and key does not exist: " + str); + } + const Key &key = store[str]; + mem.store(desc.addr, reinterpret_cast(&key), sizeof(Key)); } - size_t len = mem.strlen(desc.name); - std::vector buf(len); - mem.loadRegion(desc.name, buf.data(), len); - std::string str(buf.data(), buf.size()); - LOG(1, " addr = " << hex(desc.addr)); - LOG(1, " key = " << str); - LOG(1, " mode = " << desc.mode); - KeyStore& store = io->getKeyStore(); - if (desc.mode == 0 && store.count(str)) { - throw std::runtime_error("GetKey Mode = NEW and key exists: " + str); + break; + case kGPIO_SendRecvAddr: + { + if (io) + { + uint32_t channel = mem.load(kGPIO_SendRecvChannel); + std::vector buf(mem.load(kGPIO_SendRecvSize)); + LOG(1, + "MemoryHandler::onWrite> GPIO_SendReceive, channel " << channel + << " size=" << buf.size()); + mem.loadRegion(value, buf.data(), buf.size()); + BufferU8 result = io->onSendRecv(channel, buf); + LOG(1, + "MemoryHandler::onWrite> GPIO_SendReceive, host replied with " << result.size() + << " bytes"); + size_t aligned_len = align(result.size()); + if ((cur_host_to_guest_offset + sizeof(uint32_t) + aligned_len) >= kMemInputEnd) + { 
+ throw(std::runtime_error("Read buffer overrun")); + } + mem.store(cur_host_to_guest_offset, result.size()); + cur_host_to_guest_offset += sizeof(uint32_t); + for (size_t i = 0; i < result.size(); ++i) + { + mem.storeByte(cur_host_to_guest_offset + i, result[i]); + } + cur_host_to_guest_offset += aligned_len; + } + else + { + throw std::runtime_error("SendRecv called with no IO handler set"); + } } - if (desc.mode == 1 && !store.count(str)) { - throw std::runtime_error("GetKey Mode = EXISTING and key does not exist: " + str); + break; } - const Key& key = store[str]; - mem.store(desc.addr, reinterpret_cast(&key), sizeof(Key)); - } break; - case kGPIO_SendRecvAddr: { - if (io) { - uint32_t channel = mem.load(kGPIO_SendRecvChannel); - std::vector buf(mem.load(kGPIO_SendRecvSize)); - LOG(1, - "MemoryHandler::onWrite> GPIO_SendReceive, channel " << channel - << " size=" << buf.size()); - mem.loadRegion(value, buf.data(), buf.size()); - BufferU8 result = io->onSendRecv(channel, buf); - LOG(1, - "MemoryHandler::onWrite> GPIO_SendReceive, host replied with " << result.size() - << " bytes"); - size_t aligned_len = align(result.size()); - if ((cur_host_to_guest_offset + sizeof(uint32_t) + aligned_len) >= kMemInputEnd) { - throw(std::runtime_error("Read buffer overrun")); - } - mem.store(cur_host_to_guest_offset, result.size()); - cur_host_to_guest_offset += sizeof(uint32_t); - for (size_t i = 0; i < result.size(); ++i) { - mem.storeByte(cur_host_to_guest_offset + i, result[i]); + } + + void MemoryState::dump(size_t logLevel) + { + LOG(logLevel, "MemoryState::dump> size: " << data.size()); + if (getLogLevel() >= logLevel) + { + for (auto pair : data) + { + LOG(logLevel, " " << hex(pair.first * 4) << ": " << hex(pair.second)); } - cur_host_to_guest_offset += aligned_len; - } else { - throw std::runtime_error("SendRecv called with no IO handler set"); } - } break; } -} -void MemoryState::dump(size_t logLevel) { - LOG(logLevel, "MemoryState::dump> size: " << data.size()); - if 
(getLogLevel() >= logLevel) { - for (auto pair : data) { - LOG(logLevel, " " << hex(pair.first * 4) << ": " << hex(pair.second)); + size_t MemoryState::strlen(uint32_t addr) + { + size_t len = 0; + while (loadByte(addr++)) + { + len++; } + return len; + } + + uint8_t MemoryState::loadByte(uint32_t addr) + { + // align to the nearest word + uint32_t aligned = addr & ~(sizeof(uint32_t) - 1); + size_t byte_offset = addr % sizeof(uint32_t); + uint32_t word = load(aligned); + return (word >> (byte_offset * 8)) & 0xff; } -} -size_t MemoryState::strlen(uint32_t addr) { - size_t len = 0; - while (loadByte(addr++)) { - len++; + uint32_t MemoryState::load(uint32_t addr) + { + auto it = data.find(addr / 4); + if (it == data.end()) + { + std::stringstream ss; + ss << "addr out of range: " << hex(addr); + throw std::out_of_range(ss.str()); + } + return it->second; } - return len; -} - -uint8_t MemoryState::loadByte(uint32_t addr) { - // align to the nearest word - uint32_t aligned = addr & ~(sizeof(uint32_t) - 1); - size_t byte_offset = addr % sizeof(uint32_t); - uint32_t word = load(aligned); - return (word >> (byte_offset * 8)) & 0xff; -} - -uint32_t MemoryState::load(uint32_t addr) { - auto it = data.find(addr / 4); - if (it == data.end()) { - std::stringstream ss; - ss << "addr out of range: " << hex(addr); - throw std::out_of_range(ss.str()); + + void MemoryState::loadRegion(uint32_t addr, void *ptr, uint32_t len) + { + uint8_t *bytes = static_cast(ptr); + for (size_t i = 0; i < len; i++) + { + bytes[i] = loadByte(addr++); + } } - return it->second; -} -void MemoryState::loadRegion(uint32_t addr, void* ptr, uint32_t len) { - uint8_t* bytes = static_cast(ptr); - for (size_t i = 0; i < len; i++) { - bytes[i] = loadByte(addr++); + uint32_t MemoryState::loadBE(uint32_t addr) + { + return loadByte(addr + 0) << 24 | // + loadByte(addr + 1) << 16 | // + loadByte(addr + 2) << 8 | // + loadByte(addr + 3); } -} - -uint32_t MemoryState::loadBE(uint32_t addr) { - return loadByte(addr 
+ 0) << 24 | // - loadByte(addr + 1) << 16 | // - loadByte(addr + 2) << 8 | // - loadByte(addr + 3); -} - -void MemoryState::storeByte(uint32_t addr, uint8_t byte) { - // align to the nearest word - uint32_t aligned = addr & ~(sizeof(uint32_t) - 1); - size_t byte_offset = addr % sizeof(uint32_t); - uint32_t word = data[aligned / 4] & ~(0xff << (byte_offset * 8)); - word |= byte << (byte_offset * 8); - store(aligned, word); -} - -void MemoryState::store(uint32_t addr, const void* ptr, uint32_t len) { - const uint8_t* bytes = static_cast(ptr); - for (size_t i = 0; i < len; i++) { - storeByte(addr++, bytes[i]); + + void MemoryState::storeByte(uint32_t addr, uint8_t byte) + { + // align to the nearest word + uint32_t aligned = addr & ~(sizeof(uint32_t) - 1); + size_t byte_offset = addr % sizeof(uint32_t); + uint32_t word = data[aligned / 4] & ~(0xff << (byte_offset * 8)); + word |= byte << (byte_offset * 8); + store(aligned, word); } -} -void MemoryState::store(uint32_t addr, uint32_t value) { - if (addr % 4 != 0) { - throw std::runtime_error("Unaligned store"); + void MemoryState::store(uint32_t addr, const void *ptr, uint32_t len) + { + const uint8_t *bytes = static_cast(ptr); + for (size_t i = 0; i < len; i++) + { + storeByte(addr++, bytes[i]); + } } - uint32_t key = addr / 4; - auto it = data.find(key); - if (it != data.end()) { - auto txn = history.lower_bound({key, 0, 0, 0}); - if (txn != history.end() && txn->addr == key && it->second != value) { - // The guest has actually touched this memory, and we are not writing the same value - throw std::runtime_error("Host cannot mutate existing memory."); + + void MemoryState::store(uint32_t addr, uint32_t value) + { + if (addr % 4 != 0) + { + throw std::runtime_error("Unaligned store"); + } + uint32_t key = addr / 4; + auto it = data.find(key); + if (it != data.end()) + { + auto txn = history.lower_bound({key, 0, 0, 0}); + if (txn != history.end() && txn->addr == key && it->second != value) + { + // The guest has 
actually touched this memory, and we are not writing the same value + throw std::runtime_error("Host cannot mutate existing memory."); + } + it->second = value; + } + else + { + data[key] = value; } - it->second = value; - } else { - data[key] = value; } -} } // namespace risc0 diff --git a/risc0/zkvm/sdk/rust/guest/src/mul.rs b/risc0/zkvm/sdk/rust/guest/src/mul.rs index 22593744f7..a10ecdc557 100644 --- a/risc0/zkvm/sdk/rust/guest/src/mul.rs +++ b/risc0/zkvm/sdk/rust/guest/src/mul.rs @@ -17,12 +17,15 @@ unsafe impl Sync for CurOutput {} static CUR_OUTPUT: CurOutput = CurOutput(UnsafeCell::new(0)); /// Result of multiply goldilocks -pub struct MulGoldilocks([u32; 2]); +pub struct MulGoldilocks([u32; 4]); impl MulGoldilocks { /// Get the result as u64 - pub fn get_u64(&self) -> u64 { - (self.0[1] as u64) | ((self.0[0] as u64) << 32) + pub fn get_u64(&self) -> [u64; 2] { + [ + (self.0[1] as u64) | ((self.0[0] as u64) << 32), + (self.0[3] as u64) | ((self.0[2] as u64) << 32), + ] } } @@ -37,14 +40,18 @@ fn alloc_output() -> *mut MulDescriptor { } /// Multiply goldilocks oracle, verification is done separately -pub fn mul_goldilocks(a: &u64, b: &u64) -> &'static MulGoldilocks { - let a_hi = ((a & 0xFFFFFFFF00000000) >> 32) as u32; - let a_lo = (a & 0xFFFFFFFF) as u32; +pub fn mul_goldilocks(a: &[u64; 2], b: &[u64; 2]) -> &'static MulGoldilocks { + let a0_hi = ((a[0] & 0xFFFFFFFF00000000) >> 32) as u32; + let a0_lo = (a[0] & 0xFFFFFFFF) as u32; + let a1_hi = ((a[1] & 0xFFFFFFFF00000000) >> 32) as u32; + let a1_lo = (a[1] & 0xFFFFFFFF) as u32; - let b_hi = ((b & 0xFFFFFFFF00000000) >> 32) as u32; - let b_lo = (b & 0xFFFFFFFF) as u32; + let b0_hi = ((b[0] & 0xFFFFFFFF00000000) >> 32) as u32; + let b0_lo = (b[0] & 0xFFFFFFFF) as u32; + let b1_hi = ((b[1] & 0xFFFFFFFF00000000) >> 32) as u32; + let b1_lo = (b[1] & 0xFFFFFFFF) as u32; - let buf = [a_hi, a_lo, b_hi, b_lo]; + let buf = [a0_hi, a0_lo, a1_hi, a1_lo, b0_hi, b0_lo, b1_hi, b1_lo]; unsafe { let alloced = Box::>::new( 
From b96f42768ec96d1b912b9478f11e4c7f336a02cd Mon Sep 17 00:00:00 2001 From: starkoracles Date: Sun, 27 Nov 2022 12:01:03 -0500 Subject: [PATCH 11/12] Fix bug --- risc0/zkvm/prove/io_handler.cpp | 35 ++++++++++++++++++++------------- 1 file changed, 21 insertions(+), 14 deletions(-) diff --git a/risc0/zkvm/prove/io_handler.cpp b/risc0/zkvm/prove/io_handler.cpp index 0e939a00f3..89fe8a5d62 100644 --- a/risc0/zkvm/prove/io_handler.cpp +++ b/risc0/zkvm/prove/io_handler.cpp @@ -35,10 +35,12 @@ namespace risc0 private: static DEVSPEC constexpr uint64_t add(uint64_t a, uint64_t b) { - bool c1 = b > M; + bool c1 = (M - b) > a; uint64_t x1 = a - (M - b); - uint32_t adj = 0 - uint32_t(c1); - return x1 - uint64_t(adj); + uint32_t adj = uint32_t(0) - uint32_t(c1); + uint64_t res = x1 - uint64_t(adj); + // std::cout << "c1: " << c1 << ", x1: " << x1 << ", adj: " << adj << ", res: " << res << std::endl; + return res; } static DEVSPEC constexpr uint64_t sub(uint64_t a, uint64_t b) @@ -67,11 +69,6 @@ namespace risc0 bool c = xh < b; uint64_t r = xh - b; uint64_t mont_result = r - (uint32_t(0) - uint32_t(c)); - // std::cout << "xl = " << xl << ", xh = " << xh << ", a = " << a << ", e = " << e << ", b = " - // << - // b - // << ", c = " << c << ", r = " << r << ", mont_result = " << mont_result << - // std::endl; return mont_result; } @@ -92,9 +89,19 @@ namespace risc0 static std::pair extensionMul(std::pair a, std::pair b) { FpG a0b0 = a.first * b.first; + FpG a1b1 = a.second * b.second; + FpG first = a0b0 - a1b1.doubleVal(); + + FpG a0a1 = a.first + a.second; + FpG b0b1 = b.first + b.second; + FpG second = a0a1 * b0b1 - a0b0; + + // std::cout << "CPP a: [" << a.first.val << ", " << a.second.val << "]" << std::endl; + // std::cout << "b: [" << b.first.val << ", " << b.second.val << "]" << std::endl; - FpG first = a0b0 - (a.second * b.second).doubleVal(); - FpG second = (a.first + a.second) * (b.first + b.second) - a0b0; + // std::cout << "a0b0: " << a0b0.val << ", a1b1: " << 
a1b1.val << ", first: " << first.val + // << ", a0a1: " << a0a1.val << ", b0b1: " << b0b1.val << ", second: " + // << second.val << std::endl; return std::pair(first, second); } @@ -137,13 +144,13 @@ namespace risc0 uint32_t a1_lo = mem.load(desc.source + 12); LOG(1, "Input[" << hex(3, 2) << "]: " << hex(desc.source + 12) << " -> " << hex(a1_lo)); - uint32_t b0_hi = mem.load(desc.source); + uint32_t b0_hi = mem.load(desc.source + 16); LOG(1, "Input[" << hex(4, 2) << "]: " << hex(desc.source + 16) << " -> " << hex(b0_hi)); - uint32_t b0_lo = mem.load(desc.source + 4); + uint32_t b0_lo = mem.load(desc.source + 20); LOG(1, "Input[" << hex(5, 2) << "]: " << hex(desc.source + 20) << " -> " << hex(b0_lo)); - uint32_t b1_hi = mem.load(desc.source + 8); + uint32_t b1_hi = mem.load(desc.source + 24); LOG(1, "Input[" << hex(6, 2) << "]: " << hex(desc.source + 24) << " -> " << hex(b1_hi)); - uint32_t b1_lo = mem.load(desc.source + 12); + uint32_t b1_lo = mem.load(desc.source + 28); LOG(1, "Input[" << hex(7, 2) << "]: " << hex(desc.source + 28) << " -> " << hex(b1_lo)); uint64_t a0 = a0_lo | (uint64_t(a0_hi) << 32); From fe226072448e09f17aaec572700b1b1f013bc1f9 Mon Sep 17 00:00:00 2001 From: starkoracles Date: Mon, 28 Nov 2022 14:03:43 -0500 Subject: [PATCH 12/12] fix warning --- risc0/zkvm/prove/io_handler.cpp | 666 +++++++++++++++----------------- 1 file changed, 303 insertions(+), 363 deletions(-) diff --git a/risc0/zkvm/prove/io_handler.cpp b/risc0/zkvm/prove/io_handler.cpp index 89fe8a5d62..7bdc9dd0ed 100644 --- a/risc0/zkvm/prove/io_handler.cpp +++ b/risc0/zkvm/prove/io_handler.cpp @@ -21,405 +21,345 @@ #include "risc0/zkvm/platform/memory.h" #include "risc0/zkvm/prove/step.h" -namespace risc0 -{ - - class FpG - { - // implement just enough operations to support extension field multiplication - // all values are in mont form - public: - static CONSTSCALAR uint64_t M = 0xFFFFFFFF00000001; - uint64_t val; - - private: - static DEVSPEC constexpr uint64_t add(uint64_t a, 
uint64_t b) - { - bool c1 = (M - b) > a; - uint64_t x1 = a - (M - b); - uint32_t adj = uint32_t(0) - uint32_t(c1); - uint64_t res = x1 - uint64_t(adj); - // std::cout << "c1: " << c1 << ", x1: " << x1 << ", adj: " << adj << ", res: " << res << std::endl; - return res; - } - - static DEVSPEC constexpr uint64_t sub(uint64_t a, uint64_t b) - { - bool c1 = b > a; - uint64_t x1 = a - b; - uint32_t adj = 0 - uint32_t(c1); - return x1 - uint64_t(adj); - } - - static DEVSPEC constexpr uint64_t doubleVal(uint64_t a) - { - __uint128_t ret = __uint128_t(a) << 1; - uint64_t result = uint64_t(ret); - uint64_t over = uint64_t(ret >> 64); - return result - (M * over); - } - - static uint64_t montRedCst(__uint128_t n) - { - uint64_t xl = uint64_t(n); - uint64_t xh = uint64_t(n >> 64); - bool e = (__uint128_t(xl) + __uint128_t(xl << 32)) > UINT64_MAX; // overflow - uint64_t a = xl + (xl << 32); - uint64_t b = a - (a >> 32) - e; - bool c = xh < b; - uint64_t r = xh - b; - uint64_t mont_result = r - (uint32_t(0) - uint32_t(c)); - return mont_result; - } - - static DEVSPEC constexpr uint64_t mul(uint64_t a, uint64_t b) - { - __uint128_t n = __uint128_t(a) * __uint128_t(b); - return montRedCst(n); - } - - public: - DEVSPEC constexpr FpG(uint64_t val) : val(val) {} - DEVSPEC constexpr FpG operator+(FpG rhs) const { return FpG(add(val, rhs.val)); } - DEVSPEC constexpr FpG operator-(FpG rhs) const { return FpG(sub(val, rhs.val)); } - DEVSPEC constexpr FpG operator*(FpG rhs) const { return FpG(mul(val, rhs.val)); } - DEVSPEC constexpr FpG doubleVal() const { return FpG(doubleVal(val)); } - }; - - static std::pair extensionMul(std::pair a, std::pair b) - { - FpG a0b0 = a.first * b.first; - FpG a1b1 = a.second * b.second; - FpG first = a0b0 - a1b1.doubleVal(); - - FpG a0a1 = a.first + a.second; - FpG b0b1 = b.first + b.second; - FpG second = a0a1 * b0b1 - a0b0; - - // std::cout << "CPP a: [" << a.first.val << ", " << a.second.val << "]" << std::endl; - // std::cout << "b: [" << b.first.val 
<< ", " << b.second.val << "]" << std::endl; - - // std::cout << "a0b0: " << a0b0.val << ", a1b1: " << a1b1.val << ", first: " << first.val - // << ", a0a1: " << a0a1.val << ", b0b1: " << b0b1.val << ", second: " - // << second.val << std::endl; - - return std::pair(first, second); +namespace risc0 { + +class FpG { + // implement just enough operations to support extension field multiplication + // all values are in mont form +public: + static CONSTSCALAR uint64_t M = 0xFFFFFFFF00000001; + uint64_t val; + +private: + static DEVSPEC constexpr uint64_t add(uint64_t a, uint64_t b) { + bool c1 = (M - b) > a; + uint64_t x1 = a - (M - b); + uint32_t adj = uint32_t(0) - uint32_t(c1); + uint64_t res = x1 - uint64_t(adj); + // std::cout << "c1: " << c1 << ", x1: " << x1 << ", adj: " << adj << ", res: " << res << + // std::endl; + return res; } - static void processSHA(MemoryState &mem, const ShaDescriptor &desc) - { - uint16_t type = (desc.typeAndCount & 0xFFFF) >> 4; - uint16_t count = desc.typeAndCount & 0xFFFF; - LOG(1, - "SHA256 type: " << type << ", count: " << count << ", idx: " << desc.idx - << ", source: " << hex(desc.source) << ", digest: " << hex(desc.digest)); - ShaDigest sha = impl::initState(); - uint32_t words[16]; - for (int i = 0; i < count; i++) - { - for (int j = 0; j < 16; j++) - { - uint32_t from = desc.source + i * 16 * 4 + j * 4; - words[j] = mem.loadBE(from); - LOG(1, "Input[" << hex(j, 2) << "]: " << hex(from) << " -> " << hex(words[j])); - } - LOG(1, "Compress"); - impl::compress(sha, words); - } - for (int i = 0; i < 8; i++) - { - LOG(1, "Output[" << hex(i, 1) << "]: " << hex(sha.words[i])); - mem.store(desc.digest + i * 4, sha.words[i]); - } + static DEVSPEC constexpr uint64_t sub(uint64_t a, uint64_t b) { + bool c1 = b > a; + uint64_t x1 = a - b; + uint32_t adj = 0 - uint32_t(c1); + return x1 - uint64_t(adj); } - static void processMul(MemoryState &mem, const MulDescriptor &desc) - { - uint32_t a0_hi = mem.load(desc.source); - LOG(1, "Input[" << 
hex(0, 2) << "]: " << hex(desc.source) << " -> " << hex(a0_hi)); - uint32_t a0_lo = mem.load(desc.source + 4); - LOG(1, "Input[" << hex(1, 2) << "]: " << hex(desc.source + 4) << " -> " << hex(a0_lo)); - uint32_t a1_hi = mem.load(desc.source + 8); - LOG(1, "Input[" << hex(2, 2) << "]: " << hex(desc.source + 8) << " -> " << hex(a1_hi)); - uint32_t a1_lo = mem.load(desc.source + 12); - LOG(1, "Input[" << hex(3, 2) << "]: " << hex(desc.source + 12) << " -> " << hex(a1_lo)); - - uint32_t b0_hi = mem.load(desc.source + 16); - LOG(1, "Input[" << hex(4, 2) << "]: " << hex(desc.source + 16) << " -> " << hex(b0_hi)); - uint32_t b0_lo = mem.load(desc.source + 20); - LOG(1, "Input[" << hex(5, 2) << "]: " << hex(desc.source + 20) << " -> " << hex(b0_lo)); - uint32_t b1_hi = mem.load(desc.source + 24); - LOG(1, "Input[" << hex(6, 2) << "]: " << hex(desc.source + 24) << " -> " << hex(b1_hi)); - uint32_t b1_lo = mem.load(desc.source + 28); - LOG(1, "Input[" << hex(7, 2) << "]: " << hex(desc.source + 28) << " -> " << hex(b1_lo)); - - uint64_t a0 = a0_lo | (uint64_t(a0_hi) << 32); - uint64_t a1 = a1_lo | (uint64_t(a1_hi) << 32); - uint64_t b0 = b0_lo | (uint64_t(b0_hi) << 32); - uint64_t b1 = b1_lo | (uint64_t(b1_hi) << 32); - - std::pair a = std::pair(FpG(a0), FpG(a1)); - std::pair b = std::pair(FpG(b0), FpG(b1)); - std::pair result = extensionMul(a, b); - - uint64_t r0 = result.first.val; - uint32_t r0_high = (uint32_t)((r0 & 0xFFFFFFFF00000000LL) >> 32); - uint32_t r0_low = (uint32_t)(r0 & 0xFFFFFFFFLL); - - uint64_t r1 = result.second.val; - uint32_t r1_high = (uint32_t)((r1 & 0xFFFFFFFF00000000LL) >> 32); - uint32_t r1_low = (uint32_t)(r1 & 0xFFFFFFFFLL); - - LOG(1, "Output[" << hex(0, 2) << "]: " << hex(desc.result) << " <- " << hex(r0_high)); - mem.store(desc.result, r0_high); - LOG(1, "Output[" << hex(1, 2) << "]: " << hex(desc.result + 4) << " <- " << hex(r0_low)); - mem.store(desc.result + 4, r0_low); - LOG(1, "Output[" << hex(2, 2) << "]: " << hex(desc.result + 8) << " <- 
" << hex(r1_high)); - mem.store(desc.result + 8, r1_high); - LOG(1, "Output[" << hex(3, 2) << "]: " << hex(desc.result + 12) << " <- " << hex(r1_low)); - mem.store(desc.result + 12, r1_low); + static DEVSPEC constexpr uint64_t doubleVal(uint64_t a) { + __uint128_t ret = __uint128_t(a) << 1; + uint64_t result = uint64_t(ret); + uint64_t over = uint64_t(ret >> 64); + return result - (M * over); } - void IoHandler::onFault(const std::string &msg) - { - throw std::runtime_error(msg); + static DEVSPEC constexpr uint64_t montRedCst(__uint128_t n) { + uint64_t xl = uint64_t(n); + uint64_t xh = uint64_t(n >> 64); + bool e = (__uint128_t(xl) + __uint128_t(xl << 32)) > UINT64_MAX; // overflow + uint64_t a = xl + (xl << 32); + uint64_t b = a - (a >> 32) - e; + bool c = xh < b; + uint64_t r = xh - b; + uint64_t mont_result = r - (uint32_t(0) - uint32_t(c)); + return mont_result; } - MemoryHandler::MemoryHandler() : MemoryHandler(nullptr) {} - - MemoryHandler::MemoryHandler(IoHandler *io) : io(io), cur_host_to_guest_offset(kMemInputStart) {} - - void MemoryHandler::onInit(MemoryState &mem) - { - if (io) - { - io->onInit(mem); - } + static DEVSPEC constexpr uint64_t mul(uint64_t a, uint64_t b) { + __uint128_t n = __uint128_t(a) * __uint128_t(b); + return montRedCst(n); } - void MemoryHandler::onWrite(MemoryState &mem, uint32_t cycle, uint32_t addr, uint32_t value) - { - LOG(2, "MemoryHandler::onWrite> " << hex(addr) << ": " << hex(value)); - switch (addr) - { - case kGPIO_Mul: - { - LOG(1, "MemoryHandler::onWrite> GPIO_MUL"); - MulDescriptor desc; - mem.loadRegion(value, &desc, sizeof(desc)); - processMul(mem, desc); - break; - } - case kGPIO_SHA: - { - LOG(1, "MemoryHandler::onWrite> GPIO_SHA"); - ShaDescriptor desc; - mem.loadRegion(value, &desc, sizeof(desc)); - processSHA(mem, desc); - } - break; - case kGPIO_Commit: - { - LOG(1, "MemoryHandler::onWrite> GPIO_Commit"); - IoDescriptor desc; - mem.loadRegion(value, &desc, sizeof(desc)); - if (io) - { - std::vector 
buf(desc.size); - mem.loadRegion(desc.addr, buf.data(), desc.size); - io->onCommit(buf); - } +public: + DEVSPEC constexpr FpG(uint64_t val) : val(val) {} + DEVSPEC constexpr FpG operator+(FpG rhs) const { return FpG(add(val, rhs.val)); } + DEVSPEC constexpr FpG operator-(FpG rhs) const { return FpG(sub(val, rhs.val)); } + DEVSPEC constexpr FpG operator*(FpG rhs) const { return FpG(mul(val, rhs.val)); } + DEVSPEC constexpr FpG doubleVal() const { return FpG(doubleVal(val)); } +}; + +static std::pair extensionMul(std::pair a, std::pair b) { + FpG a0b0 = a.first * b.first; + FpG a1b1 = a.second * b.second; + FpG first = a0b0 - a1b1.doubleVal(); + + FpG a0a1 = a.first + a.second; + FpG b0b1 = b.first + b.second; + FpG second = a0a1 * b0b1 - a0b0; + + // std::cout << "CPP a: [" << a.first.val << ", " << a.second.val << "]" << std::endl; + // std::cout << "b: [" << b.first.val << ", " << b.second.val << "]" << std::endl; + + // std::cout << "a0b0: " << a0b0.val << ", a1b1: " << a1b1.val << ", first: " << first.val + // << ", a0a1: " << a0a1.val << ", b0b1: " << b0b1.val << ", second: " + // << second.val << std::endl; + + return std::pair(first, second); +} + +static void processSHA(MemoryState& mem, const ShaDescriptor& desc) { + uint16_t type = (desc.typeAndCount & 0xFFFF) >> 4; + uint16_t count = desc.typeAndCount & 0xFFFF; + LOG(1, + "SHA256 type: " << type << ", count: " << count << ", idx: " << desc.idx + << ", source: " << hex(desc.source) << ", digest: " << hex(desc.digest)); + ShaDigest sha = impl::initState(); + uint32_t words[16]; + for (int i = 0; i < count; i++) { + for (int j = 0; j < 16; j++) { + uint32_t from = desc.source + i * 16 * 4 + j * 4; + words[j] = mem.loadBE(from); + LOG(1, "Input[" << hex(j, 2) << "]: " << hex(from) << " -> " << hex(words[j])); } + LOG(1, "Compress"); + impl::compress(sha, words); + } + for (int i = 0; i < 8; i++) { + LOG(1, "Output[" << hex(i, 1) << "]: " << hex(sha.words[i])); + mem.store(desc.digest + i * 4, sha.words[i]); + 
} +} + +static void processMul(MemoryState& mem, const MulDescriptor& desc) { + uint32_t a0_hi = mem.load(desc.source); + LOG(1, "Input[" << hex(0, 2) << "]: " << hex(desc.source) << " -> " << hex(a0_hi)); + uint32_t a0_lo = mem.load(desc.source + 4); + LOG(1, "Input[" << hex(1, 2) << "]: " << hex(desc.source + 4) << " -> " << hex(a0_lo)); + uint32_t a1_hi = mem.load(desc.source + 8); + LOG(1, "Input[" << hex(2, 2) << "]: " << hex(desc.source + 8) << " -> " << hex(a1_hi)); + uint32_t a1_lo = mem.load(desc.source + 12); + LOG(1, "Input[" << hex(3, 2) << "]: " << hex(desc.source + 12) << " -> " << hex(a1_lo)); + + uint32_t b0_hi = mem.load(desc.source + 16); + LOG(1, "Input[" << hex(4, 2) << "]: " << hex(desc.source + 16) << " -> " << hex(b0_hi)); + uint32_t b0_lo = mem.load(desc.source + 20); + LOG(1, "Input[" << hex(5, 2) << "]: " << hex(desc.source + 20) << " -> " << hex(b0_lo)); + uint32_t b1_hi = mem.load(desc.source + 24); + LOG(1, "Input[" << hex(6, 2) << "]: " << hex(desc.source + 24) << " -> " << hex(b1_hi)); + uint32_t b1_lo = mem.load(desc.source + 28); + LOG(1, "Input[" << hex(7, 2) << "]: " << hex(desc.source + 28) << " -> " << hex(b1_lo)); + + uint64_t a0 = a0_lo | (uint64_t(a0_hi) << 32); + uint64_t a1 = a1_lo | (uint64_t(a1_hi) << 32); + uint64_t b0 = b0_lo | (uint64_t(b0_hi) << 32); + uint64_t b1 = b1_lo | (uint64_t(b1_hi) << 32); + + std::pair a = std::pair(FpG(a0), FpG(a1)); + std::pair b = std::pair(FpG(b0), FpG(b1)); + std::pair result = extensionMul(a, b); + + uint64_t r0 = result.first.val; + uint32_t r0_high = (uint32_t)((r0 & 0xFFFFFFFF00000000LL) >> 32); + uint32_t r0_low = (uint32_t)(r0 & 0xFFFFFFFFLL); + + uint64_t r1 = result.second.val; + uint32_t r1_high = (uint32_t)((r1 & 0xFFFFFFFF00000000LL) >> 32); + uint32_t r1_low = (uint32_t)(r1 & 0xFFFFFFFFLL); + + LOG(1, "Output[" << hex(0, 2) << "]: " << hex(desc.result) << " <- " << hex(r0_high)); + mem.store(desc.result, r0_high); + LOG(1, "Output[" << hex(1, 2) << "]: " << hex(desc.result + 
4) << " <- " << hex(r0_low)); + mem.store(desc.result + 4, r0_low); + LOG(1, "Output[" << hex(2, 2) << "]: " << hex(desc.result + 8) << " <- " << hex(r1_high)); + mem.store(desc.result + 8, r1_high); + LOG(1, "Output[" << hex(3, 2) << "]: " << hex(desc.result + 12) << " <- " << hex(r1_low)); + mem.store(desc.result + 12, r1_low); +} + +void IoHandler::onFault(const std::string& msg) { + throw std::runtime_error(msg); +} + +MemoryHandler::MemoryHandler() : MemoryHandler(nullptr) {} + +MemoryHandler::MemoryHandler(IoHandler* io) : io(io), cur_host_to_guest_offset(kMemInputStart) {} + +void MemoryHandler::onInit(MemoryState& mem) { + if (io) { + io->onInit(mem); + } +} + +void MemoryHandler::onWrite(MemoryState& mem, uint32_t cycle, uint32_t addr, uint32_t value) { + LOG(2, "MemoryHandler::onWrite> " << hex(addr) << ": " << hex(value)); + switch (addr) { + case kGPIO_Mul: { + LOG(1, "MemoryHandler::onWrite> GPIO_MUL"); + MulDescriptor desc; + mem.loadRegion(value, &desc, sizeof(desc)); + processMul(mem, desc); break; - case kGPIO_Fault: - { - LOG(1, "MemoryHandler::onWrite> GPIO_Fault"); - if (io) - { - size_t len = mem.strlen(value); - std::vector buf(len); - mem.loadRegion(value, buf.data(), len); - std::string str(buf.data(), buf.size()); - io->onFault(str); - } + } + case kGPIO_SHA: { + LOG(1, "MemoryHandler::onWrite> GPIO_SHA"); + ShaDescriptor desc; + mem.loadRegion(value, &desc, sizeof(desc)); + processSHA(mem, desc); + } break; + case kGPIO_Commit: { + LOG(1, "MemoryHandler::onWrite> GPIO_Commit"); + IoDescriptor desc; + mem.loadRegion(value, &desc, sizeof(desc)); + if (io) { + std::vector buf(desc.size); + mem.loadRegion(desc.addr, buf.data(), desc.size); + io->onCommit(buf); } - break; - case kGPIO_Log: - { - LOG(2, "MemoryHandler::onWrite> GPIO_Log"); + } break; + case kGPIO_Fault: { + LOG(1, "MemoryHandler::onWrite> GPIO_Fault"); + if (io) { size_t len = mem.strlen(value); std::vector buf(len); mem.loadRegion(value, buf.data(), len); std::string 
str(buf.data(), buf.size()); - LOG(0, "R0VM[C" << cycle << "]> " << str); + io->onFault(str); } - break; - case kGPIO_GetKey: - { - LOG(1, "MemoryHandler::onWrite> GPIO_GetKey"); - GetKeyDescriptor desc; - mem.loadRegion(value, &desc, sizeof(desc)); - if (!io) - { - throw std::runtime_error("Get key called with no IO handler set"); - } - size_t len = mem.strlen(desc.name); - std::vector buf(len); - mem.loadRegion(desc.name, buf.data(), len); - std::string str(buf.data(), buf.size()); - LOG(1, " addr = " << hex(desc.addr)); - LOG(1, " key = " << str); - LOG(1, " mode = " << desc.mode); - KeyStore &store = io->getKeyStore(); - if (desc.mode == 0 && store.count(str)) - { - throw std::runtime_error("GetKey Mode = NEW and key exists: " + str); - } - if (desc.mode == 1 && !store.count(str)) - { - throw std::runtime_error("GetKey Mode = EXISTING and key does not exist: " + str); - } - const Key &key = store[str]; - mem.store(desc.addr, reinterpret_cast(&key), sizeof(Key)); + } break; + case kGPIO_Log: { + LOG(2, "MemoryHandler::onWrite> GPIO_Log"); + size_t len = mem.strlen(value); + std::vector buf(len); + mem.loadRegion(value, buf.data(), len); + std::string str(buf.data(), buf.size()); + LOG(0, "R0VM[C" << cycle << "]> " << str); + } break; + case kGPIO_GetKey: { + LOG(1, "MemoryHandler::onWrite> GPIO_GetKey"); + GetKeyDescriptor desc; + mem.loadRegion(value, &desc, sizeof(desc)); + if (!io) { + throw std::runtime_error("Get key called with no IO handler set"); } - break; - case kGPIO_SendRecvAddr: - { - if (io) - { - uint32_t channel = mem.load(kGPIO_SendRecvChannel); - std::vector buf(mem.load(kGPIO_SendRecvSize)); - LOG(1, - "MemoryHandler::onWrite> GPIO_SendReceive, channel " << channel - << " size=" << buf.size()); - mem.loadRegion(value, buf.data(), buf.size()); - BufferU8 result = io->onSendRecv(channel, buf); - LOG(1, - "MemoryHandler::onWrite> GPIO_SendReceive, host replied with " << result.size() - << " bytes"); - size_t aligned_len = align(result.size()); - 
if ((cur_host_to_guest_offset + sizeof(uint32_t) + aligned_len) >= kMemInputEnd) - { - throw(std::runtime_error("Read buffer overrun")); - } - mem.store(cur_host_to_guest_offset, result.size()); - cur_host_to_guest_offset += sizeof(uint32_t); - for (size_t i = 0; i < result.size(); ++i) - { - mem.storeByte(cur_host_to_guest_offset + i, result[i]); - } - cur_host_to_guest_offset += aligned_len; - } - else - { - throw std::runtime_error("SendRecv called with no IO handler set"); - } + size_t len = mem.strlen(desc.name); + std::vector buf(len); + mem.loadRegion(desc.name, buf.data(), len); + std::string str(buf.data(), buf.size()); + LOG(1, " addr = " << hex(desc.addr)); + LOG(1, " key = " << str); + LOG(1, " mode = " << desc.mode); + KeyStore& store = io->getKeyStore(); + if (desc.mode == 0 && store.count(str)) { + throw std::runtime_error("GetKey Mode = NEW and key exists: " + str); } - break; + if (desc.mode == 1 && !store.count(str)) { + throw std::runtime_error("GetKey Mode = EXISTING and key does not exist: " + str); } - } - - void MemoryState::dump(size_t logLevel) - { - LOG(logLevel, "MemoryState::dump> size: " << data.size()); - if (getLogLevel() >= logLevel) - { - for (auto pair : data) - { - LOG(logLevel, " " << hex(pair.first * 4) << ": " << hex(pair.second)); + const Key& key = store[str]; + mem.store(desc.addr, reinterpret_cast(&key), sizeof(Key)); + } break; + case kGPIO_SendRecvAddr: { + if (io) { + uint32_t channel = mem.load(kGPIO_SendRecvChannel); + std::vector buf(mem.load(kGPIO_SendRecvSize)); + LOG(1, + "MemoryHandler::onWrite> GPIO_SendReceive, channel " << channel + << " size=" << buf.size()); + mem.loadRegion(value, buf.data(), buf.size()); + BufferU8 result = io->onSendRecv(channel, buf); + LOG(1, + "MemoryHandler::onWrite> GPIO_SendReceive, host replied with " << result.size() + << " bytes"); + size_t aligned_len = align(result.size()); + if ((cur_host_to_guest_offset + sizeof(uint32_t) + aligned_len) >= kMemInputEnd) { + 
throw(std::runtime_error("Read buffer overrun")); + } + mem.store(cur_host_to_guest_offset, result.size()); + cur_host_to_guest_offset += sizeof(uint32_t); + for (size_t i = 0; i < result.size(); ++i) { + mem.storeByte(cur_host_to_guest_offset + i, result[i]); } + cur_host_to_guest_offset += aligned_len; + } else { + throw std::runtime_error("SendRecv called with no IO handler set"); } + } break; } +} - size_t MemoryState::strlen(uint32_t addr) - { - size_t len = 0; - while (loadByte(addr++)) - { - len++; +void MemoryState::dump(size_t logLevel) { + LOG(logLevel, "MemoryState::dump> size: " << data.size()); + if (getLogLevel() >= logLevel) { + for (auto pair : data) { + LOG(logLevel, " " << hex(pair.first * 4) << ": " << hex(pair.second)); } - return len; } +} - uint8_t MemoryState::loadByte(uint32_t addr) - { - // align to the nearest word - uint32_t aligned = addr & ~(sizeof(uint32_t) - 1); - size_t byte_offset = addr % sizeof(uint32_t); - uint32_t word = load(aligned); - return (word >> (byte_offset * 8)) & 0xff; +size_t MemoryState::strlen(uint32_t addr) { + size_t len = 0; + while (loadByte(addr++)) { + len++; } - - uint32_t MemoryState::load(uint32_t addr) - { - auto it = data.find(addr / 4); - if (it == data.end()) - { - std::stringstream ss; - ss << "addr out of range: " << hex(addr); - throw std::out_of_range(ss.str()); - } - return it->second; + return len; +} + +uint8_t MemoryState::loadByte(uint32_t addr) { + // align to the nearest word + uint32_t aligned = addr & ~(sizeof(uint32_t) - 1); + size_t byte_offset = addr % sizeof(uint32_t); + uint32_t word = load(aligned); + return (word >> (byte_offset * 8)) & 0xff; +} + +uint32_t MemoryState::load(uint32_t addr) { + auto it = data.find(addr / 4); + if (it == data.end()) { + std::stringstream ss; + ss << "addr out of range: " << hex(addr); + throw std::out_of_range(ss.str()); } + return it->second; +} - void MemoryState::loadRegion(uint32_t addr, void *ptr, uint32_t len) - { - uint8_t *bytes = 
static_cast(ptr); - for (size_t i = 0; i < len; i++) - { - bytes[i] = loadByte(addr++); - } +void MemoryState::loadRegion(uint32_t addr, void* ptr, uint32_t len) { + uint8_t* bytes = static_cast(ptr); + for (size_t i = 0; i < len; i++) { + bytes[i] = loadByte(addr++); } - - uint32_t MemoryState::loadBE(uint32_t addr) - { - return loadByte(addr + 0) << 24 | // - loadByte(addr + 1) << 16 | // - loadByte(addr + 2) << 8 | // - loadByte(addr + 3); +} + +uint32_t MemoryState::loadBE(uint32_t addr) { + return loadByte(addr + 0) << 24 | // + loadByte(addr + 1) << 16 | // + loadByte(addr + 2) << 8 | // + loadByte(addr + 3); +} + +void MemoryState::storeByte(uint32_t addr, uint8_t byte) { + // align to the nearest word + uint32_t aligned = addr & ~(sizeof(uint32_t) - 1); + size_t byte_offset = addr % sizeof(uint32_t); + uint32_t word = data[aligned / 4] & ~(0xff << (byte_offset * 8)); + word |= byte << (byte_offset * 8); + store(aligned, word); +} + +void MemoryState::store(uint32_t addr, const void* ptr, uint32_t len) { + const uint8_t* bytes = static_cast(ptr); + for (size_t i = 0; i < len; i++) { + storeByte(addr++, bytes[i]); } +} - void MemoryState::storeByte(uint32_t addr, uint8_t byte) - { - // align to the nearest word - uint32_t aligned = addr & ~(sizeof(uint32_t) - 1); - size_t byte_offset = addr % sizeof(uint32_t); - uint32_t word = data[aligned / 4] & ~(0xff << (byte_offset * 8)); - word |= byte << (byte_offset * 8); - store(aligned, word); +void MemoryState::store(uint32_t addr, uint32_t value) { + if (addr % 4 != 0) { + throw std::runtime_error("Unaligned store"); } - - void MemoryState::store(uint32_t addr, const void *ptr, uint32_t len) - { - const uint8_t *bytes = static_cast(ptr); - for (size_t i = 0; i < len; i++) - { - storeByte(addr++, bytes[i]); - } - } - - void MemoryState::store(uint32_t addr, uint32_t value) - { - if (addr % 4 != 0) - { - throw std::runtime_error("Unaligned store"); - } - uint32_t key = addr / 4; - auto it = data.find(key); - if (it 
!= data.end()) - { - auto txn = history.lower_bound({key, 0, 0, 0}); - if (txn != history.end() && txn->addr == key && it->second != value) - { - // The guest has actually touched this memory, and we are not writing the same value - throw std::runtime_error("Host cannot mutate existing memory."); - } - it->second = value; - } - else - { - data[key] = value; + uint32_t key = addr / 4; + auto it = data.find(key); + if (it != data.end()) { + auto txn = history.lower_bound({key, 0, 0, 0}); + if (txn != history.end() && txn->addr == key && it->second != value) { + // The guest has actually touched this memory, and we are not writing the same value + throw std::runtime_error("Host cannot mutate existing memory."); } + it->second = value; + } else { + data[key] = value; } +} } // namespace risc0