diff --git a/MerkleSet.py b/MerkleSet.py index 9b9067b..417720f 100644 --- a/MerkleSet.py +++ b/MerkleSet.py @@ -64,6 +64,9 @@ DONE = 6 FULL = 7 +HASHSIZE = 32 +SHORTSIZE = 2 + def from_bytes(f): return int.from_bytes(f, 'big') @@ -110,6 +113,84 @@ def __setitem__(self, index, thing): assert index < len(self) bytearray.__setitem__(self, index, thing) +class Node: + def __init__(self,data,pos): + self.data = data + self.pos = pos + + def get_offset(self): + return self.pos + + def make_unused(self,nextent): + self.data[self.pos:self.pos + 2] = to_bytes(nextent,2) + self.data[self.pos + 2:self.pos + 68] = bytes(66) + + def make_unused_unsafe(self,nextent): + self.data[self.pos:self.pos + 2] = to_bytes(nextent,2) + + def get_unused_ptr(self): + assert self.data[self.pos+2:self.pos+68] == bytes(66) and from_bytes(self.data[self.pos:self.pos+2]) != 0 + return from_bytes(self.data[self.pos:self.pos+2]) + + def get_unused_ptr_unsafe(self): + return from_bytes(self.data[self.pos:self.pos+2]) + + def hash_loc(self,n): + return self.pos + (n * HASHSIZE) + + def pos_loc(self,n): + return self.pos + 64 + (n * SHORTSIZE) + + def get_hash(self,n): + offset = self.hash_loc(n) + return self.data[offset:offset+HASHSIZE] + + def set_hash(self,n,v): + offset = self.hash_loc(n) + self.data[offset:offset+HASHSIZE] = v + + def get_pos(self,n): + offset = self.pos_loc(n) + res = from_bytes(self.data[offset:offset+SHORTSIZE]) + if res != 0xffff: + return res - 1 + else: + return res + + def set_pos(self,n,v): + offset = self.pos_loc(n) + if v != 0xffff: + v += 1 + self.data[offset:offset+SHORTSIZE] = to_bytes(v, 2) + + def get_type(self,n): + offset = self.hash_loc(n) + return get_type(self.data, offset) + + def make_invalid(self,n): + offset = self.hash_loc(n) + return make_invalid(self.data, offset) + +class Leaf(safearray): + def get_next_ptr(self): + return from_bytes(self[:2]) + + def set_next_ptr(self,ptr): + self[:2] = to_bytes(ptr, 2) + + def get_node(self,pos): + assert pos >= 0 and (4 + 68 * (pos + 1)) <= len(self) + return Node(self, 4 + 68 * pos) + + def get_inputs(self): + return from_bytes(self[2:4]) + + def set_inputs(self,i): + self[2:4] = to_bytes(i, 2) + + def get_hash(self,n): + return self[n * HASHSIZE:(n+1) * HASHSIZE] + class MerkleSet: # depth sets the size of branches, it's power of two scale with a smallest value of 0 # leaf_units is the size of leaves, its smallest possible value is 1 @@ -140,16 +221,16 @@ def audit(self, hashes): else: allblocks = set() e = (self.root if t == MIDDLE else None) - self._audit_branch(self._deref(self.rootblock), 0, allblocks, e, newhashes, True) + self._audit_branch(self._addrof(self.rootblock), 0, allblocks, e, newhashes, True) assert allblocks == set(self.pointers_to_arrays.keys()) - s = sorted([flip_terminal(x) for x in hashes]) + s = sorted([set_terminal(x) for x in hashes]) assert newhashes == s def _audit_branch(self, branch, depth, allblocks, expected, hashes, can_terminate): assert branch not in allblocks allblocks.add(branch) outputs = {} - branch = self._ref(branch) + branch = self._deref(branch) assert len(branch) == 8 + self.subblock_lengths[-1] self._audit_branch_inner(branch, 8, depth, len(self.subblock_lengths) - 1, outputs, allblocks, expected, hashes, can_terminate) active = branch[:8] @@ -169,7 +250,7 @@ def _audit_branch_inner(self, branch, pos, depth, moddepth, outputs, allblocks, self._audit_branch(output, depth, allblocks, expected, hashes, can_terminate) else: outputs.setdefault(output, []).append((newpos, expected)) - self._add_hashes_leaf(self._ref(output), newpos, hashes, can_terminate) + self._add_hashes_leaf(self._deref(output), newpos, hashes, can_terminate) return t0 = get_type(branch, pos) t1 = get_type(branch, pos + 32) @@ -198,19 +279,18 @@ def _audit_branch_inner(self, branch, pos, depth, moddepth, outputs, allblocks, self._audit_branch_inner(branch, pos + 64 + self.subblock_lengths[moddepth - 1], depth + 1, moddepth - 1, outputs, allblocks, e, hashes, t0 != EMPTY) def _add_hashes_leaf(self, leaf, pos, hashes, can_terminate): - assert pos >= 0 - rpos = 4 + pos * 68 - t0 = get_type(leaf, rpos) - t1 = get_type(leaf, rpos + 32) + node = leaf.get_node(pos) + t0 = node.get_type(0) + t1 = node.get_type(1) if t0 == TERMINAL: - hashes.append(leaf[rpos:rpos + 32]) + hashes.append(node.get_hash(0)) assert can_terminate or t1 != TERMINAL elif t0 != EMPTY: - self._add_hashes_leaf(leaf, from_bytes(leaf[rpos + 64:rpos + 66]) - 1, hashes, t1 != EMPTY) + self._add_hashes_leaf(leaf, node.get_pos(0), hashes, t1 != EMPTY) if t1 == TERMINAL: - hashes.append(leaf[rpos + 32:rpos + 64]) + hashes.append(node.get_hash(1)) elif t1 != EMPTY: - self._add_hashes_leaf(leaf, from_bytes(leaf[rpos + 66:rpos + 68]) - 1, hashes, t0 != EMPTY) + self._add_hashes_leaf(leaf, node.get_pos(1), hashes, t0 != EMPTY) def _audit_branch_inner_empty(self, branch, pos, moddepth): if moddepth == 0: @@ -221,14 +301,14 @@ def _audit_branch_inner_empty(self, branch, pos, moddepth): self._audit_branch_inner_empty(branch, pos + 64 + self.subblock_lengths[moddepth - 1], moddepth - 1) def _audit_whole_leaf(self, leaf, inputs): - leaf = self._ref(leaf) + leaf = self._deref(leaf) assert len(leaf) == 4 + self.leaf_units * 68 - assert len(inputs) == from_bytes(leaf[2:4]) + assert len(inputs) == leaf.get_inputs() # 88 is the ASCII value for 'X' mycopy = bytearray([88] * (4 + self.leaf_units * 68)) for pos, expected in inputs: self._audit_whole_leaf_inner(leaf, mycopy, pos, expected) - i = from_bytes(leaf[:2]) + i = leaf.get_next_ptr() while i != 0xFFFF: nexti = from_bytes(leaf[4 + i * 68:4 + i * 68 + 2]) assert mycopy[4 + i * 68:4 + i * 68 + 68] == b'X' * 68 @@ -242,59 +322,64 @@ def _audit_whole_leaf_inner(self, leaf, mycopy, pos, expected): rpos = 4 + pos * 68 assert mycopy[rpos:rpos + 68] == b'X' * 68 mycopy[rpos:rpos + 68] = leaf[rpos:rpos + 68] + + source_node = Node(leaf,rpos) + t0 = get_type(leaf, rpos) t1 = get_type(leaf, rpos + 32) + if expected is not None: - assert t0 != INVALID and t1 != INVALID and hashaudit(leaf[rpos:rpos + 64]) == expected + assert t0 != INVALID and t1 != INVALID and hashaudit(source_node.get_hash(0) + source_node.get_hash(1)) == expected if t0 == EMPTY: assert t1 != EMPTY assert t1 != TERMINAL - assert leaf[rpos:rpos + 32] == BLANK - assert leaf[rpos + 64:rpos + 66] == bytes(2) + assert source_node.get_hash(0) == BLANK + assert source_node.get_pos(0) == -1 elif t0 == TERMINAL: assert t1 != EMPTY - assert leaf[rpos + 64:rpos + 66] == bytes(2) + assert source_node.get_pos(0) == -1 else: - e = (leaf[rpos:rpos + 32] if t0 == MIDDLE else None) - self._audit_whole_leaf_inner(leaf, mycopy, from_bytes(leaf[rpos + 64:rpos + 66]) - 1, e) + e = (source_node.get_hash(0) if t0 == MIDDLE else None) + self._audit_whole_leaf_inner(leaf, mycopy, source_node.get_pos(0), e) if t1 == EMPTY: - assert leaf[rpos + 32:rpos + 64] == BLANK - assert leaf[rpos + 66:rpos + 68] == bytes(2) + assert source_node.get_hash(1) == BLANK + assert source_node.get_pos(1) == -1 elif t1 == TERMINAL: - assert leaf[rpos + 66:rpos + 68] == bytes(2) + assert source_node.get_pos(1) == -1 else: - e = (leaf[rpos + 32:rpos + 64] if t1 == MIDDLE else None) - self._audit_whole_leaf_inner(leaf, mycopy, from_bytes(leaf[rpos + 66:rpos + 68]) - 1, e) + e = (source_node.get_hash(1) if t1 == MIDDLE else None) + self._audit_whole_leaf_inner(leaf, mycopy, source_node.get_pos(1), e) # In C this should be malloc/new def _allocate_branch(self): b = safearray(8 + self.subblock_lengths[-1]) - self.pointers_to_arrays[self._deref(b)] = b + self.pointers_to_arrays[self._addrof(b, False)] = b return b # In C this should be malloc/new def _allocate_leaf(self): - leaf = safearray(4 + self.leaf_units * 68) + leaf = Leaf(4 + self.leaf_units * 68) for i in range(self.leaf_units): p = 4 + i * 68 leaf[p:p + 2] = to_bytes((i + 1) if i != self.leaf_units - 1 else 0xFFFF, 2) - self.pointers_to_arrays[self._deref(leaf)] = leaf + self.pointers_to_arrays[self._addrof(leaf, False)] = leaf return leaf # In C this should be calloc/free def _deallocate(self, thing): - del self.pointers_to_arrays[self._deref(thing)] + del self.pointers_to_arrays[self._addrof(thing)] # In C this should be * - def _ref(self, ref): + def _deref(self, ref): assert len(ref) == 8 if ref == bytes(8): return None return self.pointers_to_arrays[bytes(ref)] # In C this should be & - def _deref(self, thing): + def _addrof(self, thing, check=True): assert thing is not None + assert not check or any(x == thing for x in self.pointers_to_arrays.values()) return to_bytes(id(thing), 8) def get_root(self): @@ -304,7 +389,7 @@ def get_root(self): def _force_calculation_branch(self, block, pos, moddepth): if moddepth == 0: - block2 = self._ref(block[pos:pos + 8]) + block2 = self._deref(block[pos:pos + 8]) pos = from_bytes(block[pos + 8:pos + 10]) if pos == 0xFFFF: return self._force_calculation_branch(block2, 8, len(self.subblock_lengths) - 1) @@ -318,10 +403,11 @@ def _force_calculation_branch(self, block, pos, moddepth): def _force_calculation_leaf(self, block, pos): pos = 4 + pos * 68 + node = Node(block, pos) if get_type(block, pos) == INVALID: - block[pos:pos + 32] = self._force_calculation_leaf(block, from_bytes(block[pos + 64:pos + 66]) - 1) + block[pos:pos + 32] = self._force_calculation_leaf(block, node.get_pos(0)) if get_type(block, pos + 32) == INVALID: - block[pos + 32:pos + 64] = self._force_calculation_leaf(block, from_bytes(block[pos + 66:pos + 68]) - 1) + block[pos + 32:pos + 64] = self._force_calculation_leaf(block, node.get_pos(1)) return hashaudit(block[pos:pos + 64]) # Convenience function @@ -329,7 +415,7 @@ def add(self, toadd): return self.add_already_hashed(sha256(toadd).digest()) def add_already_hashed(self, toadd): - self._add(flip_terminal(toadd)) + self._add(set_terminal(toadd)) def _add(self, toadd): t = get_type(self.root, 0) @@ -352,7 +438,7 @@ def _add_to_branch(self, toadd, block, depth): # returns NOTSTARTED, INVALIDATING, DONE def _add_to_branch_inner(self, toadd, block, pos, depth, moddepth): if moddepth == 0: - nextblock = self._ref(block[pos:pos + 8]) + nextblock = self._deref(block[pos:pos + 8]) if nextblock is None: return NOTSTARTED nextpos = from_bytes(block[pos + 8:pos + 10]) @@ -438,7 +524,7 @@ def _add_to_branch_inner(self, toadd, block, pos, depth, moddepth): def _insert_branch(self, things, block, pos, depth, moddepth): assert 2 <= len(things) <= 3 if moddepth == 0: - child = self._ref(block[:8]) + child = self._deref(block[:8]) r = FULL if child is not None: r, leafpos = self._insert_leaf(things, child, depth) @@ -448,14 +534,14 @@ def _insert_branch(self, things, block, pos, depth, moddepth): if r == FULL: self._deallocate(child) newb = self._allocate_branch() - block[pos:pos + 8] = self._deref(newb) + block[pos:pos + 8] = self._addrof(newb) block[pos + 8:pos + 10] = to_bytes(0xFFFF, 2) self._insert_branch(things, newb, 8, depth, len(self.subblock_lengths) - 1) return - block[:8] = self._deref(child) + block[:8] = self._addrof(child) # increment the number of inputs in the active child - child[2:4] = to_bytes(from_bytes(child[2:4]) + 1, 2) - block[pos:pos + 8] = self._deref(child) + child.set_inputs(child.get_inputs() + 1) + block[pos:pos + 8] = self._addrof(child) block[pos + 8:pos + 10] = to_bytes(leafpos, 2) return things.sort() @@ -486,19 +572,19 @@ def _add_to_leaf(self, toadd, branch, branchpos, leaf, leafpos, depth): r = self._add_to_leaf_inner(toadd, leaf, leafpos, depth) if r != FULL: return r - if from_bytes(leaf[2:4]) == 1: + if leaf.get_inputs() == 1: # leaf is full and only has one input # it cannot be split so it must be replaced with a branch newb = self._allocate_branch() self._copy_leaf_to_branch(newb, 8, len(self.subblock_lengths) - 1, leaf, leafpos) self._add_to_branch(toadd, newb, depth) - branch[branchpos:branchpos + 8] = self._deref(newb) + branch[branchpos:branchpos + 8] = self._addrof(newb) branch[branchpos + 8:branchpos + 10] = to_bytes(0xFFFF, 2) - if branch[:8] == self._deref(leaf): + if branch[:8] == self._addrof(leaf): branch[:8] = bytes(8) self._deallocate(leaf) return INVALIDATING - active = self._ref(branch[:8]) + active = self._deref(branch[:8]) if active is None or active is leaf: active = self._allocate_leaf() r, newpos = self._copy_between_leafs(leaf, active, leafpos) @@ -506,9 +592,9 @@ def _add_to_leaf(self, toadd, branch, branchpos, leaf, leafpos, depth): active = self._allocate_leaf() r, newpos = self._copy_between_leafs(leaf, active, leafpos) assert r == DONE - branch[branchpos:branchpos + 8] = self._deref(active) - if branch[:8] != self._deref(active): - branch[:8] = self._deref(active) + branch[branchpos:branchpos + 8] = self._addrof(active) + if branch[:8] != self._addrof(active): + branch[:8] = self._addrof(active) branch[branchpos + 8:branchpos + 10] = to_bytes(newpos, 2) self._delete_from_leaf(leaf, leafpos) return self._add_to_leaf(toadd, branch, branchpos, active, newpos, depth) @@ -516,88 +602,86 @@ def _add_to_leaf(self, toadd, branch, branchpos, leaf, leafpos, depth): # returns INVALIDATING, DONE, FULL def _add_to_leaf_inner(self, toadd, leaf, pos, depth): assert pos >= 0 - rpos = pos * 68 + 4 + node = leaf.get_node(pos) if get_bit(toadd, depth) == 0: - t = get_type(leaf, rpos) + t = node.get_type(0) if t == EMPTY: - leaf[rpos:rpos + 32] = toadd + node.set_hash(0, toadd) return INVALIDATING elif t == TERMINAL: - oldval0 = leaf[rpos:rpos + 32] + oldval0 = node.get_hash(0) if oldval0 == toadd: return DONE - t1 = get_type(leaf, rpos + 32) + t1 = node.get_type(1) if t1 == TERMINAL: - oldval1 = leaf[rpos + 32:rpos + 64] + oldval1 = node.get_hash(1) if toadd == oldval1: return DONE - nextpos = from_bytes(leaf[:2]) - leaf[:2] = to_bytes(pos, 2) - leaf[rpos:rpos + 64] = bytes(64) - leaf[rpos:rpos + 2] = to_bytes(nextpos, 2) + nextpos = leaf.get_next_ptr() + leaf.set_next_ptr(pos) + node.make_unused(nextpos) r, nextnextpos = self._insert_leaf([toadd, oldval0, oldval1], leaf, depth) if r == FULL: - leaf[:2] = to_bytes(nextpos, 2) - leaf[rpos:rpos + 32] = oldval0 - leaf[rpos + 32:rpos + 64] = oldval1 + leaf.set_next_ptr(nextpos) + node.set_hash(0, oldval0) + node.set_hash(1, oldval1) return FULL assert nextnextpos == pos return INVALIDATING r, newpos = self._insert_leaf([toadd, oldval0], leaf, depth + 1) if r == FULL: return FULL - leaf[rpos + 64:rpos + 66] = to_bytes(newpos + 1, 2) - make_invalid(leaf, rpos) - if get_type(leaf, rpos + 32) == INVALID: + node.set_pos(0, newpos) + node.make_invalid(0) + if node.get_type(1) == INVALID: return DONE return INVALIDATING else: - r = self._add_to_leaf_inner(toadd, leaf, from_bytes(leaf[rpos + 64:rpos + 66]) - 1, depth + 1) + r = self._add_to_leaf_inner(toadd, leaf, node.get_pos(0), depth + 1) if r == INVALIDATING: if t == MIDDLE: - make_invalid(leaf, rpos) + node.make_invalid(0) return INVALIDATING return DONE return r else: - t = get_type(leaf, rpos + 32) + t = node.get_type(1) if t == EMPTY: - leaf[rpos + 32:rpos + 64] = toadd + node.set_hash(1, toadd) return INVALIDATING elif t == TERMINAL: - oldval1 = leaf[rpos + 32:rpos + 64] + oldval1 = node.get_hash(1) if oldval1 == toadd: return DONE - t0 = get_type(leaf, rpos) + t0 = node.get_type(0) if t0 == TERMINAL: - oldval0 = leaf[rpos:rpos + 32] + oldval0 = node.get_hash(0) if toadd == oldval0: return DONE - nextpos = from_bytes(leaf[:2]) - leaf[:2] = to_bytes(pos, 2) - leaf[rpos:rpos + 64] = bytes(64) - leaf[rpos:rpos + 2] = to_bytes(nextpos, 2) + nextpos = leaf.get_next_ptr() + leaf.set_next_ptr(pos) + node.make_unused(nextpos) r, nextnextpos = self._insert_leaf([toadd, oldval0, oldval1], leaf, depth) if r == FULL: - leaf[:2] = to_bytes(nextpos, 2) - leaf[rpos:rpos + 32] = oldval0 - leaf[rpos + 32:rpos + 64] = oldval1 + leaf.set_next_ptr(nextpos) + node.set_hash(0, oldval0) + node.set_hash(1, oldval1) return FULL assert nextnextpos == pos return INVALIDATING r, newpos = self._insert_leaf([toadd, oldval1], leaf, depth + 1) if r == FULL: return FULL - leaf[rpos + 66:rpos + 68] = to_bytes(newpos + 1, 2) - make_invalid(leaf, rpos + 32) - if get_type(leaf, rpos) == INVALID: + node.set_pos(1, newpos) + node.make_invalid(1) + if node.get_type(0) == INVALID: return DONE return INVALIDATING else: - r = self._add_to_leaf_inner(toadd, leaf, from_bytes(leaf[rpos + 66:rpos + 68]) - 1, depth + 1) + r = self._add_to_leaf_inner(toadd, leaf, node.get_pos(1), depth + 1) if r == INVALIDATING: if t == MIDDLE: - make_invalid(leaf, rpos + 32) + node.make_invalid(1) return INVALIDATING return DONE return r @@ -607,78 +691,79 @@ def _add_to_leaf_inner(self, toadd, leaf, pos, depth): def _copy_between_leafs(self, fromleaf, toleaf, frompos): r, pos = self._copy_between_leafs_inner(fromleaf, toleaf, frompos) if r == DONE: - toleaf[2:4] = to_bytes(from_bytes(toleaf[2:4]) + 1, 2) - fromleaf[2:4] = to_bytes(from_bytes(fromleaf[2:4]) - 1, 2) + toleaf.set_inputs(toleaf.get_inputs() + 1) + fromleaf.set_inputs(fromleaf.get_inputs() - 1) return r, pos # returns state, newpos # state can be FULL, DONE def _copy_between_leafs_inner(self, fromleaf, toleaf, frompos): - topos = from_bytes(toleaf[:2]) + topos = toleaf.get_next_ptr() if topos == 0xFFFF: return FULL, None - rfrompos = 4 + frompos * 68 - rtopos = 4 + topos * 68 - toleaf[0:2] = toleaf[rtopos:rtopos + 2] - t0 = get_type(fromleaf, rfrompos) + from_node = fromleaf.get_node(frompos) + to_node = toleaf.get_node(topos) + toleaf.set_next_ptr(to_node.get_unused_ptr()) + t0 = from_node.get_type(0) lowpos = None highpos = None if t0 == MIDDLE or t0 == INVALID: - r, lowpos = self._copy_between_leafs_inner(fromleaf, toleaf, from_bytes(fromleaf[rfrompos + 64:rfrompos + 66]) - 1) + r, lowpos = self._copy_between_leafs_inner(fromleaf, toleaf, from_node.get_pos(0)) if r == FULL: - assert toleaf[:2] == toleaf[rtopos:rtopos + 2] - toleaf[:2] = to_bytes(topos, 2) + assert toleaf.get_next_ptr() == to_node.get_unused_ptr() + toleaf.set_next_ptr(topos) return FULL, None - t1 = get_type(fromleaf, rfrompos + 32) + t1 = from_node.get_type(1) if t1 == MIDDLE or t1 == INVALID: - r, highpos = self._copy_between_leafs_inner(fromleaf, toleaf, from_bytes(fromleaf[rfrompos + 66:rfrompos + 68]) - 1) + r, highpos = self._copy_between_leafs_inner(fromleaf, toleaf, from_node.get_pos(1)) if r == FULL: if t0 == MIDDLE or t0 == INVALID: self._delete_from_leaf(toleaf, lowpos) - assert toleaf[:2] == toleaf[rtopos:rtopos + 2] - toleaf[:2] = to_bytes(topos, 2) + assert toleaf.get_next_ptr() == to_node.get_unused_ptr() + toleaf.set_next_ptr(topos) return FULL, None - toleaf[rtopos:rtopos + 64] = fromleaf[rfrompos:rfrompos + 64] + to_node.set_hash(0, from_node.get_hash(0)) + to_node.set_hash(1, from_node.get_hash(1)) if lowpos is not None: - toleaf[rtopos + 64:rtopos + 66] = to_bytes(lowpos + 1, 2) + to_node.set_pos(0, lowpos) if highpos is not None: - toleaf[rtopos + 66:rtopos + 68] = to_bytes(highpos + 1, 2) + to_node.set_pos(1, highpos) return DONE, topos def _delete_from_leaf(self, leaf, pos): assert pos >= 0 rpos = 4 + pos * 68 + node = Node(leaf, rpos) t = get_type(leaf, rpos) if t == MIDDLE or t == INVALID: - self._delete_from_leaf(leaf, from_bytes(leaf[rpos + 64:rpos + 66]) - 1) + self._delete_from_leaf(leaf, node.get_pos(0)) t = get_type(leaf, rpos + 32) if t == MIDDLE or t == INVALID: - self._delete_from_leaf(leaf, from_bytes(leaf[rpos + 66:rpos + 68]) - 1) - leaf[rpos + 2:rpos + 68] = bytes(66) - leaf[rpos:rpos + 2] = leaf[:2] - leaf[:2] = to_bytes(pos, 2) + self._delete_from_leaf(leaf, node.get_pos(1)) + node.make_unused(leaf.get_next_ptr()) + leaf.set_next_ptr(pos) def _copy_leaf_to_branch(self, branch, branchpos, moddepth, leaf, leafpos): assert leafpos >= 0 - rleafpos = 4 + leafpos * 68 + node = leaf.get_node(leafpos) if moddepth == 0: - active = self._ref(branch[:8]) + active = self._deref(branch[:8]) if active is None: active = self._allocate_leaf() - branch[0:8] = self._deref(active) + branch[0:8] = self._addrof(active) r, newpos = self._copy_between_leafs_inner(leaf, active, leafpos) assert r == DONE - active[2:4] = to_bytes(from_bytes(active[2:4]) + 1, 2) - branch[branchpos:branchpos + 8] = self._deref(active) + active.set_inputs(active.get_inputs() + 1) + branch[branchpos:branchpos + 8] = self._addrof(active) branch[branchpos + 8:branchpos + 10] = to_bytes(newpos, 2) return - branch[branchpos:branchpos + 64] = leaf[rleafpos:rleafpos + 64] - t = get_type(leaf, rleafpos) + branch[branchpos:branchpos + 64] = node.get_hash(0) + node.get_hash(1) + t = node.get_type(0) if t == MIDDLE or t == INVALID: - self._copy_leaf_to_branch(branch, branchpos + 64, moddepth - 1, leaf, from_bytes(leaf[rleafpos + 64:rleafpos + 66]) - 1) - t = get_type(leaf, rleafpos + 32) + self._copy_leaf_to_branch(branch, branchpos + 64, moddepth - 1, leaf, node.get_pos(0)) + t = node.get_type(1) if t == MIDDLE or t == INVALID: - self._copy_leaf_to_branch(branch, branchpos + 64 + self.subblock_lengths[moddepth - 1], moddepth - 1, leaf, from_bytes(leaf[rleafpos + 66:rleafpos + 68]) - 1) + self._copy_leaf_to_branch(branch, branchpos + 64 + self.subblock_lengths[moddepth - 1], moddepth - 1, leaf, node.get_pos(1)) # returns (status, pos) # status can be INVALIDATING, FULL @@ -686,46 +771,46 @@ def _insert_leaf(self, things, leaf, depth): assert 2 <= len(things) <= 3 for thing in things: assert len(thing) == 32 - pos = from_bytes(leaf[:2]) + pos = leaf.get_next_ptr() if pos == 0xFFFF: return FULL, None - lpos = pos * 68 + 4 - leaf[:2] = leaf[lpos:lpos + 2] + node = leaf.get_node(pos) + leaf.set_next_ptr(node.get_unused_ptr_unsafe()) things.sort() if len(things) == 2: - leaf[lpos:lpos + 32] = things[0] - leaf[lpos + 32:lpos + 64] = things[1] + node.set_hash(0, things[0]) + node.set_hash(1, things[1]) return INVALIDATING, pos bits = [get_bit(thing, depth) for thing in things] if bits[0] == bits[1] == bits[2]: r, laterpos = self._insert_leaf(things, leaf, depth + 1) if r == FULL: - leaf[:2] = to_bytes(pos, 2) + leaf.set_next_ptr(pos) return FULL, None if bits[0] == 0: - leaf[lpos + 64:lpos + 66] = to_bytes(laterpos + 1, 2) - make_invalid(leaf, lpos) + node.set_pos(0, laterpos) + node.make_invalid(0) else: - leaf[lpos + 66:lpos + 68] = to_bytes(laterpos + 1, 2) - make_invalid(leaf, lpos + 32) - leaf[lpos:lpos + 2] = bytes(2) + node.set_pos(1, laterpos) + node.make_invalid(1) + node.make_unused_unsafe(0) return INVALIDATING, pos elif bits[0] == bits[1]: r, laterpos = self._insert_leaf([things[0], things[1]], leaf, depth + 1) if r == FULL: - leaf[:2] = to_bytes(pos, 2) + leaf.set_next_ptr(pos) return FULL, None - leaf[lpos + 32:lpos + 64] = things[2] - leaf[lpos + 64:lpos + 66] = to_bytes(laterpos + 1, 2) - make_invalid(leaf, lpos) + node.set_hash(1, things[2]) + node.set_pos(0, laterpos) + node.make_invalid(0) else: r, laterpos = self._insert_leaf([things[1], things[2]], leaf, depth + 1) if r == FULL: - leaf[:2] = to_bytes(pos, 2) + leaf.set_next_ptr(pos) return FULL, None - leaf[lpos:lpos + 32] = things[0] - leaf[lpos + 66:lpos + 68] = to_bytes(laterpos + 1, 2) - make_invalid(leaf, lpos + 32) + node.set_hash(0, things[0]) + node.set_pos(1, laterpos) + node.make_invalid(1) return INVALIDATING, pos # Convenience function @@ -733,7 +818,7 @@ def remove(self, toremove): return self.remove_already_hashed(sha256(toremove).digest()) def remove_already_hashed(self, toremove): - return self._remove(flip_terminal(toremove)) + return self._remove(set_terminal(toremove)) def _remove(self, toremove): t = get_type(self.root, 0) @@ -772,9 +857,9 @@ def _remove_branch_inner(self, toremove, block, pos, depth, moddepth): return NOTSTARTED, None p = from_bytes(block[pos + 8:pos + 10]) if p == 0xFFFF: - r, val = self._remove_branch(toremove, self._ref(block[pos:pos + 8]), depth) + r, val = self._remove_branch(toremove, self._deref(block[pos:pos + 8]), depth) else: - r, val = self._remove_leaf(toremove, self._ref(block[pos:pos + 8]), p, depth, block) + r, val = self._remove_leaf(toremove, self._deref(block[pos:pos + 8]), p, depth, block) if r == ONELEFT: block[pos:pos + 10] = bytes(10) return r, val @@ -901,127 +986,128 @@ def _remove_branch_inner(self, toremove, block, pos, depth, moddepth): def _remove_leaf(self, toremove, block, pos, depth, branch): result, val = self._remove_leaf_inner(toremove, block, pos, depth) if result == ONELEFT: - numin = from_bytes(block[2:4]) + numin = block.get_inputs() if numin == 1: - self._deallocate(block) - if branch[:8] == self._deref(block): + if branch[:8] == self._addrof(block): branch[:8] = bytes(8) + self._deallocate(block) else: - block[2:4] = to_bytes(numin - 1, 2) + block.set_inputs(numin - 1) return result, val def _deallocate_leaf_node(self, leaf, pos): assert pos >= 0 rpos = 4 + pos * 68 - next = leaf[:2] - leaf[rpos:rpos + 2] = leaf[:2] - leaf[rpos + 2:rpos + 68] = bytes(66) - leaf[:2] = to_bytes(pos, 2) + node = Node(leaf, rpos) + next_entry = leaf.get_next_ptr() + target_node = Node(leaf, rpos) + target_node.make_unused(next_entry) + leaf.set_next_ptr(pos) # returns (status, oneval) # status can be ONELEFT, FRAGILE, INVALIDATING, DONE def _remove_leaf_inner(self, toremove, block, pos, depth): assert pos >= 0 - rpos = 4 + pos * 68 + node = block.get_node(pos) if get_bit(toremove, depth) == 0: - t = get_type(block, rpos) + t = node.get_type(0) if t == EMPTY: return DONE, None if t == TERMINAL: - t1 = get_type(block, rpos + 32) - if block[rpos:rpos + 32] == toremove: + t1 = node.get_type(1) + if node.get_hash(0) == toremove: if t1 == TERMINAL: - left = block[rpos + 32:rpos + 64] + left = node.get_hash(1) self._deallocate_leaf_node(block, pos) return ONELEFT, left - block[rpos:rpos + 32] = bytes(32) + node.set_hash(0, bytes(32)) return FRAGILE, None - if block[rpos + 32:rpos + 64] == toremove: - left = block[rpos:rpos + 32] + if node.get_hash(1) == toremove: + left = node.get_hash(0) self._deallocate_leaf_node(block, pos) return ONELEFT, left return DONE, None else: - r, val = self._remove_leaf_inner(toremove, block, from_bytes(block[rpos + 64:rpos + 66]) - 1, depth + 1) + r, val = self._remove_leaf_inner(toremove, block, node.get_pos(0), depth + 1) if r == DONE: return DONE, None if r == INVALIDATING: if t == MIDDLE: - make_invalid(block, rpos) - if get_type(block, rpos + 32) != INVALID: + node.make_invalid(0) + if node.get_type(1) != INVALID: return INVALIDATING, None return DONE, None if r == ONELEFT: - t1 = get_type(block, rpos + 32) + t1 = node.get_type(1) assert t1 != EMPTY - block[rpos:rpos + 32] = val - block[rpos + 64:rpos + 66] = bytes(2) + node.set_hash(0, val) + node.set_pos(0, -1) if t1 == TERMINAL: return FRAGILE, None if t != INVALID and t1 != INVALID: return INVALIDATING, None return DONE, None assert r == FRAGILE - t1 = get_type(block, rpos + 32) + t1 = node.get_type(1) if t1 == EMPTY: if t != INVALID: - make_invalid(block, rpos) + node.make_invalid(0) return FRAGILE, None - self._catch_leaf(block, from_bytes(block[rpos + 64:rpos + 66]) - 1) + self._catch_leaf(block, node.get_pos(0)) if t == INVALID: return DONE, None - make_invalid(block, rpos) + node.make_invalid(0) if t1 == INVALID: return DONE, None return INVALIDATING, None else: - t = get_type(block, rpos + 32) + t = node.get_type(1) if t == EMPTY: return DONE, None elif t == TERMINAL: - t0 = get_type(block, rpos) - if block[rpos + 32:rpos + 64] == toremove: + t0 = node.get_type(0) + if node.get_hash(1) == toremove: if t0 == TERMINAL: - left = block[rpos:rpos + 32] + left = node.get_hash(0) self._deallocate_leaf_node(block, pos) return ONELEFT, left - block[rpos + 32:rpos + 64] = bytes(32) + node.set_hash(1, bytes(32)) return FRAGILE, None - if block[rpos:rpos + 32] == toremove: - left = block[rpos + 32:rpos + 64] + if node.get_hash(0) == toremove: + left = node.get_hash(1) self._deallocate_leaf_node(block, pos) return ONELEFT, left return DONE, None else: - r, val = self._remove_leaf_inner(toremove, block, from_bytes(block[rpos + 66:rpos + 68]) - 1, depth + 1) + r, val = self._remove_leaf_inner(toremove, block, node.get_pos(1), depth + 1) if r == DONE: return DONE, None if r == INVALIDATING: if t == MIDDLE: - make_invalid(block, rpos + 32) - if get_type(block, rpos) != INVALID: + node.make_invalid(1) + if node.get_type(0) != INVALID: return INVALIDATING, None return DONE, None if r == ONELEFT: - t0 = get_type(block, rpos) + t0 = node.get_type(0) assert t0 != EMPTY - block[rpos + 32:rpos + 64] = val - block[rpos + 66:rpos + 68] = bytes(2) + node.set_hash(1, val) + node.set_pos(1, -1) if t0 == TERMINAL: return FRAGILE, None if t != INVALID and t0 != INVALID: return INVALIDATING, None return DONE, None assert r == FRAGILE - t0 = get_type(block, rpos) + t0 = node.get_type(0) if t0 == EMPTY: if t != INVALID: - make_invalid(block, rpos + 32) + node.make_invalid(1) return FRAGILE, None - self._catch_leaf(block, from_bytes(block[rpos + 66:rpos + 68]) - 1) - if get_type(block, rpos + 32) == INVALID: + self._catch_leaf(block, node.get_pos(1)) + if node.get_type(1) == INVALID: return DONE, None - make_invalid(block, rpos + 32) + node.make_invalid(1) if t0 == INVALID: return DONE, None return INVALIDATING, None @@ -1030,9 +1116,9 @@ def _catch_branch(self, block, pos, moddepth): if moddepth == 0: leafpos = from_bytes(block[pos + 8:pos + 10]) if leafpos == 0xFFFF: - self._catch_branch(self._ref(block[pos:pos + 8]), 8, len(self.subblock_lengths) - 1) + self._catch_branch(self._deref(block[pos:pos + 8]), 8, len(self.subblock_lengths) - 1) else: - self._catch_leaf(self._ref(block[pos:pos + 8]), leafpos) + self._catch_leaf(self._deref(block[pos:pos + 8]), leafpos) return if get_type(block, pos) == EMPTY: assert get_type(block, pos + 32) != TERMINAL @@ -1058,9 +1144,9 @@ def _collapse_branch_inner(self, block, pos, moddepth): if moddepth == 0: leafpos = from_bytes(block[pos + 8:pos + 10]) if leafpos == 0xFFFF: - r = self._collapse_branch(self._ref(block[pos:pos + 8])) + r = self._collapse_branch(self._deref(block[pos:pos + 8])) else: - r = self._collapse_leaf(self._ref(block[pos:pos + 8]), from_bytes(block[pos + 8:pos + 10]), block) + r = self._collapse_leaf(self._deref(block[pos:pos + 8]), from_bytes(block[pos + 8:pos + 10]), block) if r != None: block[pos:pos + 10] = bytes(10) return r @@ -1085,19 +1171,21 @@ def _collapse_branch_inner(self, block, pos, moddepth): def _catch_leaf(self, leaf, pos): assert pos >= 0 rpos = 4 + pos * 68 + node = Node(leaf, rpos) t0 = get_type(leaf, rpos) t1 = get_type(leaf, rpos + 32) if t0 == EMPTY: - r = self._collapse_leaf_inner(leaf, from_bytes(leaf[rpos + 66:rpos + 68]) - 1) + r = self._collapse_leaf_inner(leaf, node.get_pos(1)) if r != None: leaf[rpos + 66:rpos + 68] = bytes(2) leaf[rpos:rpos + 64] = r return if t1 == EMPTY: - r = self._collapse_leaf_inner(leaf, from_bytes(leaf[rpos + 64:rpos + 66]) - 1) + r = self._collapse_leaf_inner(leaf, node.get_pos(0)) if r != None: - leaf[rpos + 64:rpos + 66] = bytes(2) - leaf[rpos:rpos + 64] = r + node.set_pos(0, -1) + node.set_hash(0, r[:32]) + node.set_hash(1, r[32:]) return # returns two hashes string or None @@ -1105,33 +1193,34 @@ def _collapse_leaf(self, leaf, pos, branch): assert pos >= 0 r = self._collapse_leaf_inner(leaf, pos) if r != None: - inputs = from_bytes(leaf[2:4]) + inputs = leaf.get_inputs() if inputs == 1: - self._deallocate(leaf) - if branch[:8] == self._deref(leaf): + if branch[:8] == self._addrof(leaf): branch[:8] = bytes(8) + self._deallocate(leaf) return r - leaf[2:4] = to_bytes(inputs - 1, 2) + leaf.set_inputs(inputs - 1) return r # returns two hashes string or None def _collapse_leaf_inner(self, leaf, pos): assert pos >= 0 rpos = 4 + pos * 68 + node = Node(leaf, rpos) t0 = get_type(leaf, rpos) t1 = get_type(leaf, rpos + 32) r = None if t0 == TERMINAL and t1 == TERMINAL: - r = leaf[rpos:rpos + 64] + r = node.get_hash(0) + node.get_hash(1) elif t0 == EMPTY: - r = self._collapse_leaf_inner(leaf, from_bytes(leaf[rpos + 66:rpos + 68]) - 1) + r = self._collapse_leaf_inner(leaf, node.get_pos(1)) elif t1 == EMPTY: - r = self._collapse_leaf_inner(leaf, from_bytes(leaf[rpos + 64:rpos + 66]) - 1) + r = self._collapse_leaf_inner(leaf, node.get_pos(0)) if r is not None: # this leaf node is being collapsed, deallocate it - leaf[rpos + 2:rpos + 68] = bytes(66) - leaf[rpos:rpos + 2] = leaf[:2] - leaf[:2] = to_bytes(pos, 2) + next_entry = leaf.get_next_ptr() + node.make_unused(next_entry) + leaf.set_next_ptr(pos) return r # Convenience function @@ -1140,7 +1229,7 @@ def is_included(self, tocheck): # returns (boolean, proof string) def is_included_already_hashed(self, tocheck): - return self._is_included(flip_terminal(tocheck)) + return self._is_included(set_terminal(tocheck)) # returns (boolean, proof string) def _is_included(self, tocheck): @@ -1159,9 +1248,9 @@ def _is_included(self, tocheck): def _is_included_branch(self, tocheck, block, pos, depth, moddepth, buf): if moddepth == 0: if block[pos + 8:pos + 10] == bytes([0xFF, 0xFF]): - return self._is_included_branch(tocheck, self._ref(block[pos:pos + 8]), 8, depth, len(self.subblock_lengths) - 1, buf) + return self._is_included_branch(tocheck, self._deref(block[pos:pos + 8]), 8, depth, len(self.subblock_lengths) - 1, buf) else: - return self._is_included_leaf(tocheck, self._ref(block[pos:pos + 8]), from_bytes(block[pos + 8:pos + 10]), depth, buf) + return self._is_included_leaf(tocheck, self._deref(block[pos:pos + 8]), from_bytes(block[pos + 8:pos + 10]), depth, buf) buf.append(bytes([MIDDLE])) if block[pos:pos + 32] == tocheck or block[pos + 32:pos + 64] == tocheck: _finish_proof(block[pos:pos + 64], depth, buf) @@ -1188,27 +1277,28 @@ def _is_included_branch(self, tocheck, block, pos, depth, moddepth, buf): def _is_included_leaf(self, tocheck, block, pos, depth, buf): assert pos >= 0 pos = 4 + pos * 68 + node = Node(block, pos) buf.append(bytes([MIDDLE])) - if block[pos:pos + 32] == tocheck or block[pos + 32:pos + 64] == tocheck: - _finish_proof(block[pos:pos + 64], depth, buf) + if node.get_hash(0) == tocheck or node.get_hash(1) == tocheck: + _finish_proof(node.get_hash(0) + node.get_hash(1), depth, buf) return True if get_bit(tocheck, depth) == 0: t = get_type(block, pos) if t == EMPTY or t == TERMINAL: - _finish_proof(block[pos:pos + 64], depth, buf) + _finish_proof(node.get_hash(0) + node.get_hash(1), depth, buf) return False assert t == MIDDLE - r = self._is_included_leaf(tocheck, block, from_bytes(block[pos + 64:pos + 66]) - 1, depth + 1, buf) - buf.append(_quick_summary(block[pos + 32:pos + 64])) + r = self._is_included_leaf(tocheck, block, node.get_pos(0), depth + 1, buf) + buf.append(_quick_summary(node.get_hash(1))) return r else: t = get_type(block, pos + 32) if t == EMPTY or t == TERMINAL: - _finish_proof(block[pos:pos + 64], depth, buf) + _finish_proof(node.get_hash(0) + node.get_hash(1), depth, buf) return False assert t == MIDDLE - buf.append(_quick_summary(block[pos:pos + 32])) - return self._is_included_leaf(tocheck, block, from_bytes(block[pos + 66:pos + 68]) - 1, depth + 1, buf) + buf.append(_quick_summary(node.get_hash(0))) + return self._is_included_leaf(tocheck, block, node.get_pos(1), depth + 1, buf) def _finish_proof(val, depth, buf): v0 = val[:32] @@ -1236,4 +1326,4 @@ def _quick_summary(val): if t == TERMINAL: return val assert t == MIDDLE - return flip_invalid(val) + return set_invalid(val) diff --git a/ReferenceMerkleSet.py b/ReferenceMerkleSet.py index a2c5227..4be94eb 100644 --- a/ReferenceMerkleSet.py +++ b/ReferenceMerkleSet.py @@ -37,15 +37,15 @@ BLANK = bytes([0] * 32) -def flip_terminal(mystr): +def set_terminal(mystr): assert len(mystr) == 32 return bytes([TERMINAL | (mystr[0] & 0x3F)]) + mystr[1:] -def flip_middle(mystr): +def set_middle(mystr): assert len(mystr) == 32 return bytes([MIDDLE | (mystr[0] & 0x3F)]) + mystr[1:] -def flip_invalid(mystr): +def set_invalid(mystr): assert len(mystr) == 32 return bytes([INVALID | (mystr[0] & 0x3F)]) + mystr[1:] @@ -68,13 +68,13 @@ def get_root(self): return self.root.hash def add_already_hashed(self, toadd): - self.root = self.root.add(flip_terminal(toadd), 0) + self.root = self.root.add(set_terminal(toadd), 0) def remove_already_hashed(self, toremove): - self.root = self.root.remove(flip_terminal(toremove), 0) + self.root = self.root.remove(set_terminal(toremove), 0) def is_included_already_hashed(self, tocheck): - tocheck = flip_terminal(tocheck) + tocheck = set_terminal(tocheck) p = [] r = self.root.is_included(tocheck, 0, p) return r, b''.join(p) @@ -83,7 +83,7 @@ def audit(self, hashes): newhashes = [] self.root.audit(newhashes, []) assert newhashes == sorted(newhashes) - assert newhashes == sorted([flip_terminal(x) for x in hashes]) + assert newhashes == sorted([set_terminal(x) for x in hashes]) class EmptyNode: def __init__(self): @@ -228,7 +228,7 @@ def is_included(self, tocheck, depth, p): def other_included(self, tocheck, depth, p, collapse): if collapse or not self.is_double(): - p.append(flip_invalid(self.hash)) + p.append(set_invalid(self.hash)) else: self.is_included(tocheck, depth, p) @@ -253,7 +253,7 @@ def is_included(self, tocheck, depth, p): raise SetError() def other_included(self, tocheck, depth, p, collapse): - p.append(flip_invalid(self.hash)) + p.append(set_invalid(self.hash)) class SetError(BaseException): pass @@ -298,7 +298,7 @@ def _deserialize(proof, pos, bits): if t == TERMINAL: return TerminalNode(proof[pos:pos + 32], bits), pos + 32 if t == INVALID: - return UnknownNode(flip_middle(proof[pos:pos + 32])), pos + 32 + return UnknownNode(set_middle(proof[pos:pos + 32])), pos + 32 if proof[pos] != MIDDLE: raise SetError() v0, pos = _deserialize(proof, pos + 1, bits + [0])