From 5e15682549623e2f7aacadb2177bb22567c3a799 Mon Sep 17 00:00:00 2001 From: Galin Chung Nguyen Date: Sun, 23 Feb 2025 20:34:47 +0700 Subject: [PATCH 01/10] add fork-resistant gossip-DAG protocol# --- .gitignore | 3 + dagpool/consensus_client.py | 959 ++++++++++++++++++++++++++++++++++++ dagpool/schemas.py | 7 + 3 files changed, 969 insertions(+) create mode 100644 dagpool/consensus_client.py create mode 100644 dagpool/schemas.py diff --git a/.gitignore b/.gitignore index 2a39912..75f3dac 100644 --- a/.gitignore +++ b/.gitignore @@ -8,3 +8,6 @@ Cargo.lock # These are backup files generated by rustfmt **/*.rs.bk + +**/__pycache__/** +**/.DS_Store \ No newline at end of file diff --git a/dagpool/consensus_client.py b/dagpool/consensus_client.py new file mode 100644 index 0000000..c6452ea --- /dev/null +++ b/dagpool/consensus_client.py @@ -0,0 +1,959 @@ +import asyncio +import random +from dataclasses import dataclass +from enum import Enum +from typing import Dict, Set, Optional, List, Tuple +import time +from queue import Queue +import random +import networkx as nx +import matplotlib.pyplot as plt +from matplotlib.colors import LinearSegmentedColormap +import hashlib +from schemas import TransactionId, PeerId, NodeId, NodeLabel +import json +from math import log2, ceil # Added log2 and ceil imports + +class Node: + def __init__(self, peer_id: PeerId, round: int, is_witness: bool, newly_seen_txs_list: list[TransactionId], self_parent_hash: NodeId, cross_parent_hash: NodeId): + self.peer_id = peer_id + self.is_witness = is_witness + self.round = round + self.self_parent_hash = self_parent_hash + self.cross_parent_hash = cross_parent_hash + # TODO: migrate to Sparse Merkle Tree + proof of SMT transition + self.newly_seen_txs_list = newly_seen_txs_list + self.node_hash = self.hash_node(peer_id, round, is_witness, self_parent_hash, cross_parent_hash, newly_seen_txs_list) + ## fork-related data + self.equivocated_peers: Set[PeerId] = None # set of peers that current node 
believes are equivocated, and this node won't SEE (i.e. UNSEE) all nodes created by them. Note that this doesn't affect STRONGLY SEEING property of this node. + self.seen_nodes: Set[NodeId] = None # set of nodes that current node sees + + def clone(self): + return Node(self.peer_id, self.round, self.is_witness, [txs for txs in self.newly_seen_txs_list], self.self_parent_hash, self.cross_parent_hash) + + def label(self) -> NodeLabel: + return f"{self.peer_id}:{self.node_hash}" + + @staticmethod + def merkle_root_of_transaction_list(txs: list[TransactionId]) -> NodeId: + assert len(txs) > 0 + # do the merkle tree construction using a while loop + res = [tx for tx in txs] + while len(res) > 1: + new_res = [] + for i in range(0, len(res), 2): + if i+1 < len(res): + new_res.append(hashlib.sha256(f"{res[i]}{res[i+1]}".encode()).hexdigest()) + else: + new_res.append(res[i]) + res = new_res + return res[0] + + @staticmethod + def hash_node(creator: PeerId, round: int, is_witness: bool, self_parent_hash: NodeId, cross_parent_hash: NodeId, newly_seen_txs_list: list[TransactionId]) -> NodeId: + """Create deterministic hash for a node""" + components = [creator, str(round), str(is_witness)] + if cross_parent_hash: + components.append(cross_parent_hash) + if self_parent_hash: + components.append(self_parent_hash) + if newly_seen_txs_list: + components.append(Node.merkle_root_of_transaction_list(newly_seen_txs_list)) + return hashlib.sha256(''.join(components).encode()).hexdigest()[:8] # the hash value of a node basically depends deterministically on all of its content + + def verify_node_hash(self) -> bool: + return self.node_hash == Node.hash_node(self.peer_id, self.round, self.is_witness, self.self_parent_hash, self.cross_parent_hash, self.newly_seen_txs_list) + + def is_genesis(self): + is_genesis = self.round == 0 and self.is_witness == True and self.self_parent_hash == "" and self.cross_parent_hash == "" and self.verify_node_hash() + return is_genesis + + def 
is_non_genesis(self): + return not self.is_genesis() + + def __str__(self): + return f"{"GENESIS " if self.is_genesis() else ""}Node(node_hash={self.node_hash}, peer_id={self.peer_id}, round={self.round}, is_witness={self.is_witness}, self_parent_hash={self.self_parent_hash}, cross_parent_hash={self.cross_parent_hash}, newly_seen_txs_list={self.newly_seen_txs_list})" # , equivocated_peers={self.equivocated_peers}, seen_nodes={self.seen_nodes})" + +class ConnectionState(Enum): + CLOSED = 0 + SYN_SENT = 1 + SYN_RECEIVED = 2 + ESTABLISHED = 3 + FIN_WAIT = 4 + +@dataclass +class Checkpoint: + timestamp: float + nodes: Dict[PeerId, Node] # peer_id -> Node mapping + accumulated_txs: Set[TransactionId] # All transactions up to this checkpoint + + def verify_dag_structure(self) -> bool: + """Verify that the checkpoint forms a valid DAG structure""" + for node in self.nodes.values(): + if node.self_parent_hash and node.self_parent_hash.peer_id != node.peer_id: + return False + if node.self_parent_hash and node.self_parent_hash.round >= node.round: + return False + if node.cross_parent_hash and node.cross_parent_hash.round >= node.round: + return False + return True + +class NetworkSimulator: + def __init__(self, latency_ms_range=(50, 200), packet_loss_prob=0.1, random_instance: random.Random=random.Random(0)): + self.connections: Dict[tuple, ConnectionState] = {} + self.latency_range = latency_ms_range + self.packet_loss_prob = packet_loss_prob + self.random_instance = random_instance + # this NetworkSimulator also works like a "beacon chain" to manage the common knowledge of all peers + self.peers: List['ConsensusPeer'] = [] + # Checkpoint-related attributes + self.checkpoints: List[Checkpoint] = [] + self.genesis_checkpoint: Optional[Checkpoint] = None + self.last_checkpoint_round = 0 + + # global mempool: simulate a global mempool of all transactions from all clients + self.global_mempool = set() + + def register_peer(self, peer: 'ConsensusPeer'): + # peer_id must be 
unique + assert peer.peer_id not in [p.peer_id for p in self.peers] + self.peers.append(peer) + + def get_all_peer_ids(self) -> List[PeerId]: + return [p.peer_id for p in self.peers] + + def unregister_peer(self, peer: 'ConsensusPeer'): + assert peer in self.peers + self.peers.remove(peer) + + async def register_genesis_nodes(self): + """Register genesis nodes from all peers as the first checkpoint""" + # Create genesis nodes and initial gossip + genesis_nodes: Dict[PeerId, Node] = {} + for peer1 in self.peers: + genesis_node = peer1.create_genesis_node() + + genesis_nodes[peer1.peer_id] = genesis_node + for peer2 in self.peers: + if peer2.peer_id == peer1.peer_id: + continue + + cloned_genesis_node = genesis_node.clone() # simulate the process of serializing and deserializing the nodes in internet protocols + + # Gossip genesis node to neighbors + success = await self.gossip_send_node_and_ancestry(peer1.peer_id, peer2.peer_id, cloned_genesis_node) + if not success: + print(f"peer {peer1.peer_id} gossiped send to {peer2.peer_id} genesis node {genesis_node.node_hash} failed") + else: + print(f"peer {peer1.peer_id} gossiped send to {peer2.peer_id} genesis node {genesis_node.node_hash} successfully") + + self.genesis_checkpoint = Checkpoint( + timestamp=time.time(), + nodes=genesis_nodes.copy(), + accumulated_txs=set() # Empty at genesis + ) + self.checkpoints.append(self.genesis_checkpoint) + + # TODO: implement checkpoint mechanism + + def new_txs_from_user_client(self) -> (list[TransactionId], list['ConsensusPeer']): + """Pick a random transaction from the global mempool that will be sent to random peers""" + num_txs = self.random_instance.randint(1, 10) + + mempool_txs = sorted(self.global_mempool) + txs = [] + for _ in range(num_txs): + # 50% pick a random txs already in the global mempool + if len(mempool_txs) > 0 and self.random_instance.random() < 0.5: + txs.append(self.random_instance.choice(mempool_txs)) + else: + # 50% add a new txs + new_txs = 
f"tx_{len(self.global_mempool) + 1}" + txs.append(new_txs) + self.global_mempool.add(new_txs) + + # pick random peers from the network to send the txs to + # choose ⌈log2(log2(N))⌉ peers on average (because usually a client only sends to 1 rpc portal) + num_peers = len(self.peers) + num_peers_to_send = max(1, ceil(log2(log2(num_peers)))) + peer_ids_to_send = self.random_instance.sample([p.peer_id for p in self.peers], num_peers_to_send) + + return txs, [peer for peer in self.peers if peer.peer_id in peer_ids_to_send] + + def get_accumulated_txs_until_node(self, node: Node) -> Set[TransactionId]: + """Get all transactions accumulated up to a specific node""" + # First, find the latest checkpoint before this node + latest_applicable_checkpoint = None + for checkpoint in reversed(self.checkpoints): + if checkpoint.round < node.round: + latest_applicable_checkpoint = checkpoint + break + + accumulated_txs = set() + if latest_applicable_checkpoint: + accumulated_txs.update(latest_applicable_checkpoint.accumulated_txs) + + # Add transactions from the node and its ancestors back to the checkpoint + def collect_txs(current_node): + if not current_node or (latest_applicable_checkpoint and + current_node.round <= latest_applicable_checkpoint.round): + return + accumulated_txs.update(current_node.newly_seen_txs_list) + if current_node.self_parent_hash: + collect_txs(current_node.self_parent_hash) + if current_node.cross_parent_hash: + collect_txs(current_node.cross_parent_hash) + + collect_txs(node) + return accumulated_txs + + async def connect(self, peer_a: PeerId, peer_b: PeerId): + """Simulate TCP three-way handshake""" + conn_key = (peer_a, peer_b) + + # SYN + if self.random_instance.random() > self.packet_loss_prob: + self.connections[conn_key] = ConnectionState.SYN_SENT + await self._delay() + + # SYN-ACK + if self.random_instance.random() > self.packet_loss_prob: + self.connections[conn_key] = ConnectionState.SYN_RECEIVED + await self._delay() + # ACK + if 
self.random_instance.random() > self.packet_loss_prob: + self.connections[conn_key] = ConnectionState.ESTABLISHED + + ## update neighbors + ## TODO: move this to a better place + peer_a_peer: ConsensusPeer = [p for p in self.peers if p.peer_id == peer_a][0] + peer_b_peer: ConsensusPeer = [p for p in self.peers if p.peer_id == peer_b][0] + + if peer_a_peer.peer_id not in peer_b_peer.neighbors: + peer_b_peer.neighbors.append(peer_a_peer.peer_id) + if peer_b_peer.peer_id not in peer_a_peer.neighbors: + peer_a_peer.neighbors.append(peer_b_peer.peer_id) + + return True + + self.connections[conn_key] = ConnectionState.CLOSED + return False + + def is_connected(self, peer_a: PeerId, peer_b: PeerId) -> bool: + conn_key = (peer_a, peer_b) + conn_key_rev = (peer_b, peer_a) + return self.connections.get(conn_key, ConnectionState.CLOSED) == ConnectionState.ESTABLISHED or self.connections.get(conn_key_rev, ConnectionState.CLOSED) == ConnectionState.ESTABLISHED + + async def disconnect(self, peer_a: PeerId, peer_b: PeerId): + """Simulate TCP connection termination""" + conn_key = (peer_a, peer_b) + if conn_key in self.connections: + self.connections[conn_key] = ConnectionState.FIN_WAIT + await self._delay() + self.connections.pop(conn_key) + + async def gossip_send_node_and_ancestry(self, sender: PeerId, receiver: PeerId, node1: Node) -> bool: + """Gossip a node and its ancestry to a peer""" + assert sender != receiver + + for _ in range(3): + if not self.is_connected(sender, receiver): + if _ < 2: + await self.connect(sender, receiver) + else: + print(f"peer {sender} can't connect to {receiver} even after 2 attempts") + return False # can't help any more + + ### TODO: migrate this to efficient proof-based gossip where sender does not have to send the whole ancestry of node1 backwards to node2 + + sender_peer = [p for p in self.peers if p.peer_id == sender][0] + receiver_peer = [p for p in self.peers if p.peer_id == receiver][0] + + receiver_has_all_needed_ancestors = False + 
all_received_nodes = [] + current_gossiped_list = [node1] # we don't send the ancestry of the last node of the sender because node1 might be an equivocated node + while not receiver_has_all_needed_ancestors: + # receiver receives nodes from sender + receiver_has_all_needed_ancestors = True + + new_gossiped_list = [] + for node in current_gossiped_list: + if receiver_peer.has_seen_valid_node(node): + continue + else: + all_received_nodes.append(node) + receiver_has_all_needed_ancestors = False # not stop + + if not node.is_genesis(): + self_parent_node = sender_peer.get_node_by_hash(node.self_parent_hash) + cross_parent_node = sender_peer.get_node_by_hash(node.cross_parent_hash) + assert self_parent_node is not None and cross_parent_node is not None + + # use clone() to simulate the process of serializing and deserializing the nodes in internet protocols + new_gossiped_list.append(self_parent_node.clone()) + new_gossiped_list.append(cross_parent_node.clone()) + + current_gossiped_list = new_gossiped_list + + ### reverse all_received_nodes because the early nodes in the ancestry are in the right of the list, but we want them to be in the left + all_received_nodes = all_received_nodes[::-1] + + for i in range(len(all_received_nodes)): + current_node = all_received_nodes[i] + if not receiver_peer.verify_node_and_add_to_local_view(current_node): + continue + + return receiver_peer.has_seen_valid_node(node1) + + async def _delay(self): + """Simulate network latency""" + delay = 0 + # TODO: turn on delay (a value in the self.latency_range range) for full simulation + await asyncio.sleep(delay) + + def find_node_by_hash(self, node_hash: str) -> Optional[Node]: + """Find a node by its hash across all peers""" + if not node_hash: + return None + for peer in self.peers: + for node in peer.my_nodes(): + if node.node_hash == node_hash: + return node + return None + +class ConsensusPeer: + def __init__(self, peer_id: PeerId, is_adversary: bool, seed: int, network: 
NetworkSimulator): + self.peer_id = peer_id + self.is_adversary = is_adversary + self.random_instance = random.Random(seed) + self.current_round = 0 + self.seen_valid_nodes: Dict[PeerId, List[Node]] = {} + self.accumulated_txs = set() # Track all transactions seen by this peer + self.network = network + ## adversary-related data + self.equivocated_peers: Set[PeerId] = set() # set of peers that current peer believes they actively create equivocated nodes + self.equivocated_nodes: Set[NodeId] = set() # set of nodes that current peer believes are equivocated + self.equivocation_prob = 0.5 if is_adversary else 0.0 + self.neighbors: list[PeerId] = [] # Track neighboring peers + # if each peer connects to log(N) neighbors, a transaction would takes O(log(N)/log(log(N))) gossip hops to reach the whole network + # for N = 10^6, it would be 7 hops + self.tx_receive_times = {} # Track when transactions were received + self.pending_txs: List[Tuple[TransactionId, float]] = [] + self.local_graph: Dict[NodeId, Set[NodeId]] = {} # node_hash to set of node_hashes (parent, child) + + def my_nodes(self) -> List[Node]: + return self.seen_valid_nodes[self.peer_id] + + def get_my_last_node(self): + return self.my_nodes()[-1] + + def get_seen_valid_nodes(self) -> List[Node]: + res = [] + for peer_id in self.seen_valid_nodes: + res.extend(self.seen_valid_nodes[peer_id]) + return res + + def count_seen_valid_nodes(self): + return len(self.get_seen_valid_nodes()) + + def get_predecessors(self, node: Node) -> list[Node]: + if node.is_genesis(): + return [] + + predecessors = [] + self_parent_node = self.get_node_by_hash(node.self_parent_hash) + cross_parent_node = self.get_node_by_hash(node.cross_parent_hash) + assert self_parent_node is not None and cross_parent_node is not None + predecessors.append(self_parent_node) + predecessors.append(cross_parent_node) + return predecessors + + def get_successors(self, node: Node) -> list[Node]: + successors = [] + for successor in 
self.local_graph.get(node.node_hash, set()): + successors.append(self.get_node_by_hash(successor)) + return sorted(successors, key=lambda x: x.round) + + def get_ancestry(self, node: Node) -> Set[Node]: + ancestry: Dict[NodeId, Node] = {} + def dfs_backwards(node: Node): + if node.node_hash in ancestry: + return + ancestry[node.node_hash] = node + for predecessor in self.get_predecessors(node): + dfs_backwards(predecessor) + dfs_backwards(node) + return set(ancestry.values()) + + def get_lineage(self, node: Node) -> Set[Node]: + lineage: Dict[NodeId, Node] = {} + def dfs_forwards(node: Node): + if node.node_hash in lineage: + return + lineage[node.node_hash] = node + for successor in self.get_successors(node): + dfs_forwards(successor) + dfs_forwards(node) + return set(lineage.values()) + + def get_graph_info(self) -> Dict[NodeId, Node]: + res: Dict[NodeId, Node] = {} + for peer_id in self.seen_valid_nodes: + for node in self.seen_valid_nodes[peer_id]: + res[node.node_hash] = node + + return res + + def visualize_view(self): + print(f"peer {self.peer_id} sees the DAG:") + for peer_id in self.seen_valid_nodes: + for node in self.seen_valid_nodes[peer_id]: + print(f"{node}") + + # pos = {} + round_colors = { + 0: '#8b0000', # Dark red for round 0 + 1: '#ff6600', # Orange for round 1 + 2: '#b7950b', # Yellow for round 2 + 3: '#00cc00', # Green for round 3 + 4: '#0066cc', # Blue for round 4 + 5: '#6600cc', # Purple for round 5 + 6: '#008080', # Teal for round 6 + 7: '#CD5C5C', # IndianRed for round 7 + 8: '#DE3163', # #DE3163 for round 8 + 9: '#800080', # #800080 for round 9 + } + + # Assign positions to nodes + peers = self.network.peers + node_positions: Dict[NodeId, Tuple[int, int]] = {} + seen_valid_nodes = self.get_seen_valid_nodes() + + count_lines = 0 + lines_of_peers: Dict[PeerId, list[int]] = {} + count_self_direct_children = {} # including the equivocated self direct children + + # initial line id for each peer and its nodes + for node in seen_valid_nodes: + 
peer_id = node.peer_id + line_id = -1 + if peer_id not in lines_of_peers: + count_lines += 1 + lines_of_peers[peer_id] = [count_lines] + line_id = count_lines + else: + line_id = lines_of_peers[peer_id][0] + + node_positions[node.node_hash] = (0, line_id) + + # Construct the DiGraph for plotting + edges = [] + G = nx.DiGraph() + + visited = {} + # Adjust the positioning to ensure chronological order + def adjust_position(node, visited): + nonlocal count_lines + if node.node_hash in visited: + return + visited[node.node_hash] = True + + for predecessor in self.get_predecessors(node): + edges.append((predecessor.node_hash, node.node_hash)) + adjust_position(predecessor, visited) + line_id = node_positions[node.node_hash][1] # keep current line id + + ## calculate line id for the current node if the predecessor is self-parent and it has equivocated children + if predecessor.node_hash == node.self_parent_hash: + if predecessor.node_hash not in count_self_direct_children: + count_self_direct_children[predecessor.node_hash] = 0 + count_self_direct_children[predecessor.node_hash] += 1 + if count_self_direct_children[predecessor.node_hash] > 1: # self parent has equivocated children + count_lines += 1 + lines_of_peers[predecessor.peer_id].append(count_lines) + line_id = count_lines + + node_positions[node.node_hash] = (max(node_positions[predecessor.node_hash][0] + 3, node_positions[node.node_hash][0]), line_id) + + for node in seen_valid_nodes: + if not node.node_hash in visited: + adjust_position(node, visited) + + G.add_edges_from(edges) + + ## Rescale the y-coordinates based on the line ids from all nodes + line_id_to_peer: Dict[int, PeerId] = {} + for peer_id in lines_of_peers: + for line_id in lines_of_peers[peer_id]: + line_id_to_peer[line_id] = peer_id + + all_line_ids = sorted(set([node_positions[node.node_hash][1] for node in seen_valid_nodes]), key=lambda x: (line_id_to_peer[x], x)) + remapped_line_ids: Dict[int, int] = {} + for i in range(len(all_line_ids)): + 
remapped_line_ids[all_line_ids[i]] = i + + for node in seen_valid_nodes: + node_positions[node.node_hash] = (node_positions[node.node_hash][0], - 2 * remapped_line_ids[node_positions[node.node_hash][1]]) + + # Draw nodes and edges + plt.figure(figsize=(50, 10)) # Increase figure size for better visibility + + for node in seen_valid_nodes: + round_number = node.round + node_color = round_colors.get(round_number, 'gray') + nx.draw_networkx_nodes( + G, + node_positions, + nodelist=[node.node_hash], + node_size=1000, + node_shape='s', + node_color=node_color, + label=[node.peer_id] + ) + + nx.draw_networkx_edges(G, node_positions, edge_color='black', arrows=True, arrowsize=20, width=1, node_size=1000) # Ensure arrows are properly sized relative to nodes + + # Draw node labels with a border + for node in seen_valid_nodes: + x, y = node_positions[node.node_hash] + # Draw the text multiple times with slight offsets to create a border + # for dx, dy in [(-0.5, -0.5), (-0.5, 0.5), (0.5, -0.5), (0.5, 0.5)]: + # plt.text(x + dx * 0.01, y + dy * 0.01, node.node_hash, fontsize=10, ha='center', va='center', color='black') + # Draw the actual text in white + plt.text(x, y, node.label(), fontsize=5, ha='center', va='center', color='white') + + + # Draw horizontal separator lines between different peers + unique_lines = sorted(remapped_line_ids.values()) + peer_boundaries = set() + + for i in range(len(unique_lines) - 1): + peer1 = line_id_to_peer[all_line_ids[i]] + peer2 = line_id_to_peer[all_line_ids[i + 1]] + if peer1 != peer2: # If the next line belongs to a different peer, draw a separator + boundary_y = -2 * unique_lines[i] - 1 # Slightly below the last line of peer1 + peer_boundaries.add(boundary_y) + + for y in peer_boundaries: + plt.plot([min(x for x, _ in node_positions.values()), + max(x for x, _ in node_positions.values())], + [y, y], color='black', linestyle='dashed', linewidth=1) + + # plot the figure + plt.axis('off') + plt.show() + + def get_max_num_neighbors(self): 
+ return ceil(log2(len(self.network.peers))) + + def create_genesis_node(self): + """Create a genesis/bootstrap node""" + node = Node( + peer_id=self.peer_id, + round=0, + is_witness=True, + newly_seen_txs_list=[], + self_parent_hash="", + cross_parent_hash="" + ) + assert self.verify_node_and_add_to_local_view(node) == True + return node + + # TODO: implement bootstrap node (first node refers to parents in a checkpoint after a node rejoins the network) + + def has_seen_valid_node(self, node: Node) -> bool: + if node.peer_id == self.peer_id: + return node in self.my_nodes() + else: + return (node.peer_id in self.seen_valid_nodes) and (node.node_hash in [node.node_hash for node in self.seen_valid_nodes[node.peer_id]]) + + def get_node_by_hash(self, node_hash: NodeId) -> Optional[Node]: + for node in self.my_nodes(): + if node.node_hash == node_hash: + return node + ## find in seen_valid_nodes of other peers + for peer_id in self.seen_valid_nodes: + for node in self.seen_valid_nodes[peer_id]: + if node.node_hash == node_hash: + return node + return None + + def record_transaction_receipt(self, tx_id: TransactionId, timestamp: float): + """Record when a transaction was received""" + self.tx_receive_times[tx_id] = timestamp + self.pending_txs.append((tx_id, timestamp)) + + def select_neighbors(self, all_peers: List[PeerId]): + """Randomly select neighbors from available peers""" + potential_neighbors = [p for p in all_peers if p != self.peer_id] + num_neighbors = min(self.get_max_num_neighbors(), len(all_peers) - 1) + self.neighbors = self.random_instance.sample(potential_neighbors, num_neighbors) + + def compute_seen_nodes_of_new_node(self, node: Node): + """Compute the list of seen_nodes of the new node""" + assert node.seen_nodes is None and node.equivocated_peers is None + node.seen_nodes = set() + equivocated_peers = set() + + ancestry_of_node = self.get_ancestry(node) + self_parent_set: Set[NodeId] = set() + + for cur_node in ancestry_of_node: + if not 
cur_node.is_genesis(): + if cur_node.self_parent_hash in self_parent_set: + equivocated_peers.add(cur_node.peer_id) + else: + self_parent_set.add(cur_node.self_parent_hash) + + for cur_node in ancestry_of_node: + if cur_node.peer_id in equivocated_peers: + continue + + node.seen_nodes.add(cur_node.node_hash) + node.equivocated_peers = equivocated_peers + + return + + def verify_node_and_add_to_local_view(self, node: Node = None) -> bool: + """Verify a node and its transactions, and add it to the local view""" + if not self.verify_node(node): + return False + + # self.add_node_to_local_view(node) + + # add to seen_valid_nodes + if node.peer_id not in self.seen_valid_nodes: + self.seen_valid_nodes[node.peer_id] = [] + + should_add_node = node.node_hash not in [node.node_hash for node in self.seen_valid_nodes[node.peer_id]] + if should_add_node: + self.seen_valid_nodes[node.peer_id].append(node) + + for predecessor in self.get_predecessors(node): + if predecessor.node_hash not in self.local_graph: + self.local_graph[predecessor.node_hash] = set() + self.local_graph[predecessor.node_hash].add(node.node_hash) + + # compute the list of seen_nodes of the new node + self.compute_seen_nodes_of_new_node(node) + # do the cleanup if the node is created by the current peer + if node.peer_id == self.peer_id: + self.pending_txs.clear() # because all txs in the pending_txs are now in the new node + self.current_round = node.round # this makes the current round of the peer = the round of the last node in the list of its nodes + + print(f"Peer {self.peer_id} added node {node.node_hash} to its local view => new round = {self.current_round}") + + return True + + def get_strongly_seen_valid_witnesses(self, dest_node: Node, r: int) -> list["Node"]: + ## check if this witness strongly sees > 2/3 of witnesses of r + ## if some witnesses are descendants of equivocated nodes, they are ignored completely + ## NOTE: we already make sure the ancestry of dest_node is verified + ## NOTE: at this 
point, dest_node is not called compute_seen_nodes_of_new_node() yet + + strongly_sees_threshold = 2/3 * len(self.network.peers) + strongly_seen_witnesses: list["Node"] = [] + ancestry_of_dest_node = self.get_ancestry(dest_node) + # sorted deterministically + witnesses_in_round_r = sorted([node for node in ancestry_of_dest_node if node.is_witness and node.round == r], key=lambda x: (x.peer_id, x.node_hash)) + + # itearate through all witnesses in round r and check if the dest_node can strongly see them + for witness in witnesses_in_round_r: + lineage_of_witness = self.get_lineage(witness) + crossed_peers = set() + can_conclude_strongly_seen = False + for mid_node in lineage_of_witness: + if (mid_node.peer_id in crossed_peers) or (mid_node not in ancestry_of_dest_node): + continue # can include the witness itself if satisfies the condition + + should_count_as_valid_path_from_witness_to_dest_node = (mid_node.node_hash == dest_node.node_hash) or (witness.node_hash in mid_node.seen_nodes) + if should_count_as_valid_path_from_witness_to_dest_node: + crossed_peers.add(mid_node.peer_id) + can_conclude_strongly_seen = len(crossed_peers) > strongly_sees_threshold + + if can_conclude_strongly_seen: + break + + if can_conclude_strongly_seen: + if witness.peer_id not in [node.peer_id for node in strongly_seen_witnesses]: + # at most 1 witness per peer is counted + strongly_seen_witnesses.append(witness) + + return strongly_seen_witnesses + + def check_round_number_of_non_genesis_node(self, node: Node) -> bool: + """ + if a node is of round r: + - it must not strongly sees > 2N/3 of witnesses of round r + - if its self parent is of round r, it is valid. 
if its self parent is of round r-1, it must strongly see > 2N/3 of witnesses of round r-1
self.get_all_seen_txs_up_to_a_verified_node(cross_parent) + + return sorted((set([txs for txs, _ in pending_txs]) | all_seen_txs_up_to_cross_parent) - all_seen_txs_up_to_self_parent) + + def compute_new_node(self, self_parent: Node, cross_parent: Node) -> Node: + """A view-only method that computes a new node based on the self parent and cross parent + @return: the newly created node, if there is no new txs, return None + """ + assert self_parent is not None and cross_parent is not None + assert self_parent == self.get_my_last_node() + + # the newly seen list of txs in the new node must be not empty + # TODO: sort this list by timestamp of receipt of the transactions + newly_seen_txs_list: List[TransactionId] = list(self.calculate_newly_seen_txs_list_of_new_node(self_parent, cross_parent, self.pending_txs)) + + if len(newly_seen_txs_list) <= 0: + # can't extend the node sequence because there is no new txs, this is to save the network capacity + print(f"Peer {self.peer_id} can't extend the node sequence because there is no new txs, this is to save the network capacity, self_parent = {self_parent.node_hash}, cross_parent = {cross_parent.node_hash}") + return None + + round_num = len(self.my_nodes()) + base_hash = f"{self.peer_id}{str(round_num).zfill(3)}" + + assert self_parent.round == self.current_round + + new_node = Node( + peer_id=self.peer_id, + round=self_parent.round, + is_witness=False, + newly_seen_txs_list=newly_seen_txs_list, + self_parent_hash=self_parent.node_hash, + cross_parent_hash=cross_parent.node_hash + ) + + if not self.verify_node(new_node): + new_node = Node( + peer_id=self.peer_id, + round=self_parent.round + 1, + is_witness=True, + newly_seen_txs_list=newly_seen_txs_list, + self_parent_hash=self_parent.node_hash, + cross_parent_hash=cross_parent.node_hash + ) + assert self.verify_node(new_node) + + print(f"Peer {self.peer_id} COMPUTED NEW NODE {new_node.node_hash} from {self_parent.node_hash} and {cross_parent.node_hash}") + + return 
new_node + + async def gossip_push(self): + """Sync with another peer through gossip, potentially sending different views""" + # generate a random permutation of connected peers + # try to extend the node sequence and push it to the neighbors + + self_parent_node = self.get_my_last_node() + # pick a random peer with non-empty seen_valid_nodes + possible_cross_peers = [peer_id for peer_id in self.seen_valid_nodes if self.seen_valid_nodes[peer_id]] + + if len(possible_cross_peers) <= 1: + return # can't extend the node sequence because there is no cross parent for the new node + + randomness = min([self.network.random_instance.random(), self.network.random_instance.random()]) + num_nodes_to_create = min(len(possible_cross_peers), 1 + (1 if randomness < self.equivocation_prob else 0)) + new_nodes = [] # if there are more than 1 node in this list, they are equivocated nodes and that means current peer is an adversary + + # NOTE: currently, the equivocation logic is simple, an adversary basically picks the last node of the current peer as the self parent, and the latest nodes of different cross peers as the cross parents + + for _ in range(num_nodes_to_create): + found_unique_cross_parent = False + while not found_unique_cross_parent: + found_unique_cross_parent = True + + cross_parent_peer_id = self.random_instance.choice(possible_cross_peers) + while cross_parent_peer_id == self.peer_id: + cross_parent_peer_id = self.random_instance.choice(possible_cross_peers) + + cross_parent_node = self.seen_valid_nodes[cross_parent_peer_id][-1] + + new_node = self.compute_new_node(self_parent=self_parent_node, cross_parent=cross_parent_node) + if new_node is None: + return + + if new_node.cross_parent_hash not in [node.cross_parent_hash for node in new_nodes]: + new_nodes.append(new_node) + else: + found_unique_cross_parent = False + + assert len(new_nodes) == num_nodes_to_create + for new_node in new_nodes: + assert self.verify_node_and_add_to_local_view(new_node) + + # start 
gossiping to neighboring peers + for i in range(len(self.neighbors)): + other_peer_id = self.neighbors[i] + # select randomly nodes from new_nodes + node_to_send = (new_nodes[0] if i * 2 < len(self.neighbors) else new_nodes[-1]).clone() # simulate the process of serializing and deserializing the nodes in internet protocols + + # Send the selected node + success = await self.network.gossip_send_node_and_ancestry(self.peer_id, other_peer_id, node_to_send) + if not success: + print(f"peer {self.peer_id} gossiped send to {other_peer_id} node {node_to_send.node_hash} failed") + else: + print(f"peer {self.peer_id} gossiped send to {other_peer_id} node {node_to_send.node_hash} successfully") + + # TODO: finish equivocation detection + def detect_equivocation(self) -> list: + pass + +async def main(): + # Create network simulator + network = NetworkSimulator( + latency_ms_range=(50, 200), + packet_loss_prob=0.1, + random_instance=random.Random(0) + ) + + # Create peers + num_peers = 4 # next threshold for count_adversary = 2 is N = 7 + count_adversary = 0 + for i in range(num_peers): + is_adversary = (count_adversary + 1) < 1 * num_peers / 3 and network.random_instance.random() < 0.5 + + if is_adversary: + count_adversary += 1 + + peer = ConsensusPeer( + peer_id=f"P{i}", + is_adversary=is_adversary, + seed=i, + network=network + ) + network.register_peer(peer) + peers = network.peers + + # Initialize peer neighborhoods + all_peer_ids = network.get_all_peer_ids() + for peer in peers: + peer.select_neighbors(all_peer_ids) + + # Register genesis checkpoint + await network.register_genesis_nodes() + + MIN_NUM_ROUNDS = 10 + current_simluated_timestamp = 0 + + # Main consensus loop + i = 0 + while True and i < 100: + i += 1 + if i % 100 == 0: + print(f"{i}th iteration") + # Count peers that have reached MIN_NUM_ROUNDS rounds + peers_completed = sum(1 for c in peers if c.current_round >= MIN_NUM_ROUNDS) + if peers_completed > (2 * num_peers // 3): + break + + # Randomly select an 
action for a random peer + action = network.random_instance.random() + + if action < 0.15: # Generate new transactions + txs, peers_to_send = network.new_txs_from_user_client() + for tx in txs: + for peer in peers_to_send: + if network.random_instance.random() < 0.3: + current_simluated_timestamp += 1 # advance current timestamp + + # Record receipt time for self + peer.record_transaction_receipt(tx, current_simluated_timestamp) + else: + # pick a random set of peers to do gossip push from it to its neighbors + random_peers = network.random_instance.sample(peers, network.random_instance.randint(1, len(peers))) + for peer in random_peers: + await peer.gossip_push() + + # Create checkpoints periodically + # TODO: create network checkpoints dynamically via network.create_checkpoint() + + for peer in peers: + peer.visualize_view() + if peer.is_adversary: + print(f"Peer {peer.peer_id} is an adversary") + print(f"Neighbors of {peer.peer_id}: {peer.neighbors}") + print(f"Consensus completed with first {peers_completed} peers reaching round {MIN_NUM_ROUNDS}") + + def validate_consistency(): + graph_info_list = [peer.get_graph_info() for peer in peers] + + found_node_conflict = False + global_info_of_node: Dict[NodeId, str] = {} + + for i in range(len(graph_info_list)): + for node_id in graph_info_list[i]: + node_description = graph_info_list[i][node_id].__str__() + if node_id in global_info_of_node: + if global_info_of_node[node_id] != node_description: + found_node_conflict = True + print(f"FAILED: found conflict in node info of node {node_id}: {node_description} vs {global_info_of_node[node_id]}") + else: + global_info_of_node[node_id] = node_description + + assert found_node_conflict == False + + print("SUCCESS: There is no conflict in node info between peers") + + validate_consistency() + +# Run the simulation +asyncio.run(main()) + +### Possible attacks: +# Long-Range Attacks: If validators controlling past checkpoints sell their keys, an attacker can re-sign an 
alternative history, leading to checkpoint reversals. +# => Dangerous once attacker can control > 2/3 of the OLD validators +# Majority Takeover: If an attacker gains control of 2/3 of the validators (BFT threshold), they could re-finalize a new chain with different checkpoints. +# => recursive validity proof + proof of finality +# Solution: Post-Unstaking Slashing for X blocks after unstaking (but not able to withdraw before X blocks yet) + +# [] TODO: finish gossip-DAG architecture +# [] TODO: finish DAGPool's order fairness gadget \ No newline at end of file diff --git a/dagpool/schemas.py b/dagpool/schemas.py new file mode 100644 index 0000000..1262d80 --- /dev/null +++ b/dagpool/schemas.py @@ -0,0 +1,7 @@ +from typing import NewType +TransactionId = NewType("TransactionId", str) +PeerId = NewType("PeerId", str) +NodeId = NewType("NodeId", str) +NodeLabel = NewType("NodeLabel", str) +Pubkey = NewType("Pubkey", str) +Signature = NewType("Signature", str) \ No newline at end of file From d6d1be05fa8f685b9222af54dfbced78ff965bb4 Mon Sep 17 00:00:00 2001 From: Galin Chung Nguyen Date: Mon, 24 Feb 2025 12:33:32 +0700 Subject: [PATCH 02/10] add more adversary strategies, add rejection logic for nodes from equivocated peers for honest peers --- dagpool/consensus_client.py | 105 +++++++++++++++++++++++++----------- 1 file changed, 75 insertions(+), 30 deletions(-) diff --git a/dagpool/consensus_client.py b/dagpool/consensus_client.py index c6452ea..9740c11 100644 --- a/dagpool/consensus_client.py +++ b/dagpool/consensus_client.py @@ -34,6 +34,9 @@ def clone(self): def label(self) -> NodeLabel: return f"{self.peer_id}:{self.node_hash}" + def has_computed_seen_nodes(self) -> bool: + return self.seen_nodes is not None and self.equivocated_peers is not None + @staticmethod def merkle_root_of_transaction_list(txs: list[TransactionId]) -> NodeId: assert len(txs) > 0 @@ -143,9 +146,9 @@ async def register_genesis_nodes(self): # Gossip genesis node to neighbors success = await 
self.gossip_send_node_and_ancestry(peer1.peer_id, peer2.peer_id, cloned_genesis_node) if not success: - print(f"peer {peer1.peer_id} gossiped send to {peer2.peer_id} genesis node {genesis_node.node_hash} failed") + print(f"peer {peer1.peer_id} gossiped to {peer2.peer_id} genesis node {genesis_node.node_hash} failed") else: - print(f"peer {peer1.peer_id} gossiped send to {peer2.peer_id} genesis node {genesis_node.node_hash} successfully") + print(f"peer {peer1.peer_id} gossiped to {peer2.peer_id} genesis node {genesis_node.node_hash} successfully") self.genesis_checkpoint = Checkpoint( timestamp=time.time(), @@ -164,7 +167,7 @@ def new_txs_from_user_client(self) -> (list[TransactionId], list['ConsensusPeer' txs = [] for _ in range(num_txs): # 50% pick a random txs already in the global mempool - if len(mempool_txs) > 0 and self.random_instance.random() < 0.5: + if len(mempool_txs) > 0 and self.random_instance.random() < 0.1: txs.append(self.random_instance.choice(mempool_txs)) else: # 50% add a new txs @@ -301,7 +304,10 @@ async def gossip_send_node_and_ancestry(self, sender: PeerId, receiver: PeerId, for i in range(len(all_received_nodes)): current_node = all_received_nodes[i] if not receiver_peer.verify_node_and_add_to_local_view(current_node): + print(f"Peer {receiver_peer.peer_id} rejected node {current_node.node_hash} from {sender_peer.peer_id}") continue + else: + print(f"Peer {receiver_peer.peer_id} accepted node {current_node.node_hash} from {sender_peer.peer_id}") return receiver_peer.has_seen_valid_node(node1) @@ -333,7 +339,7 @@ def __init__(self, peer_id: PeerId, is_adversary: bool, seed: int, network: Netw ## adversary-related data self.equivocated_peers: Set[PeerId] = set() # set of peers that current peer believes they actively create equivocated nodes self.equivocated_nodes: Set[NodeId] = set() # set of nodes that current peer believes are equivocated - self.equivocation_prob = 0.5 if is_adversary else 0.0 + self.equivocation_prob = 0.2 if 
is_adversary else 0.0 self.neighbors: list[PeerId] = [] # Track neighboring peers # if each peer connects to log(N) neighbors, a transaction would takes O(log(N)/log(log(N))) gossip hops to reach the whole network # for N = 10^6, it would be 7 hops @@ -616,11 +622,10 @@ def compute_seen_nodes_of_new_node(self, node: Node): def verify_node_and_add_to_local_view(self, node: Node = None) -> bool: """Verify a node and its transactions, and add it to the local view""" + if not self.verify_node(node): return False - # self.add_node_to_local_view(node) - # add to seen_valid_nodes if node.peer_id not in self.seen_valid_nodes: self.seen_valid_nodes[node.peer_id] = [] @@ -634,8 +639,6 @@ def verify_node_and_add_to_local_view(self, node: Node = None) -> bool: self.local_graph[predecessor.node_hash] = set() self.local_graph[predecessor.node_hash].add(node.node_hash) - # compute the list of seen_nodes of the new node - self.compute_seen_nodes_of_new_node(node) # do the cleanup if the node is created by the current peer if node.peer_id == self.peer_id: self.pending_txs.clear() # because all txs in the pending_txs are now in the new node @@ -649,7 +652,6 @@ def get_strongly_seen_valid_witnesses(self, dest_node: Node, r: int) -> list["No ## check if this witness strongly sees > 2/3 of witnesses of r ## if some witnesses are descendants of equivocated nodes, they are ignored completely ## NOTE: we already make sure the ancestry of dest_node is verified - ## NOTE: at this point, dest_node is not called compute_seen_nodes_of_new_node() yet strongly_sees_threshold = 2/3 * len(self.network.peers) strongly_seen_witnesses: list["Node"] = [] @@ -681,7 +683,7 @@ def get_strongly_seen_valid_witnesses(self, dest_node: Node, r: int) -> list["No return strongly_seen_witnesses - def check_round_number_of_non_genesis_node(self, node: Node) -> bool: + def check_round_number_of_non_genesis_node_with_valid_parents(self, node: Node) -> bool: """ if a node is of round r: - it must not strongly sees > 
2N/3 of witnesses of round r @@ -712,17 +714,41 @@ def verify_node(self, node: Node = None) -> bool: - round number must be valid - node hash must be valid => This method should be called recursively for all ancestors of a node before it's verified + + If the node accepts any parents from an equivocated peer, it is invalid """ if node is None: return False - if node.is_genesis(): - return True + # an honest peer must not accept a node which itself or its parents are from equivocated peers + try: + if not node.has_computed_seen_nodes(): # in case a node is called verify_node_and_add_to_local_view() multiple times + self.compute_seen_nodes_of_new_node(node) - if not node.verify_node_hash(): + if node.is_genesis(): + return True + + if not node.verify_node_hash(): + return False + + predecessors = self.get_predecessors(node) + + # must have valid parents + if len(predecessors) < 2: + return False + + predecessors_and_current_node = predecessors + [node] + for node_to_check in predecessors_and_current_node: + if node_to_check.peer_id in node.equivocated_peers: + is_allowed_to_bypass = self.is_adversary and node_to_check.peer_id == self.peer_id # adversary don't accept invalid nodes from other adversaries + if not is_allowed_to_bypass: + return False + except Exception as e: + # adversary sending invalid nodes + print("error = ", e) return False - return self.check_round_number_of_non_genesis_node(node) + return self.check_round_number_of_non_genesis_node_with_valid_parents(node) def get_all_transactions(self) -> Set[TransactionId]: """Get all transactions known to this peer""" @@ -748,7 +774,8 @@ def compute_new_node(self, self_parent: Node, cross_parent: Node) -> Node: @return: the newly created node, if there is no new txs, return None """ assert self_parent is not None and cross_parent is not None - assert self_parent == self.get_my_last_node() + if not self.is_adversary: + assert self_parent == self.get_my_last_node() # the newly seen list of txs in the new node 
must be not empty # TODO: sort this list by timestamp of receipt of the transactions @@ -762,7 +789,8 @@ def compute_new_node(self, self_parent: Node, cross_parent: Node) -> Node: round_num = len(self.my_nodes()) base_hash = f"{self.peer_id}{str(round_num).zfill(3)}" - assert self_parent.round == self.current_round + if not self.is_adversary: + assert self_parent.round == self.current_round new_node = Node( peer_id=self.peer_id, @@ -782,7 +810,10 @@ def compute_new_node(self, self_parent: Node, cross_parent: Node) -> Node: self_parent_hash=self_parent.node_hash, cross_parent_hash=cross_parent.node_hash ) - assert self.verify_node(new_node) + + if not self.verify_node(new_node): + # this new_node is invalid because either its parents are from equivocated peers + return None print(f"Peer {self.peer_id} COMPUTED NEW NODE {new_node.node_hash} from {self_parent.node_hash} and {cross_parent.node_hash}") @@ -793,7 +824,7 @@ async def gossip_push(self): # generate a random permutation of connected peers # try to extend the node sequence and push it to the neighbors - self_parent_node = self.get_my_last_node() + self_parent_node = self.get_my_last_node() if not self.is_adversary else self.random_instance.choice(self.my_nodes()) # pick a random peer with non-empty seen_valid_nodes possible_cross_peers = [peer_id for peer_id in self.seen_valid_nodes if self.seen_valid_nodes[peer_id]] @@ -807,9 +838,9 @@ async def gossip_push(self): # NOTE: currently, the equivocation logic is simple, an adversary basically picks the last node of the current peer as the self parent, and the latest nodes of different cross peers as the cross parents for _ in range(num_nodes_to_create): - found_unique_cross_parent = False - while not found_unique_cross_parent: - found_unique_cross_parent = True + max_num_retries = 10 + + for i in range(max_num_retries): cross_parent_peer_id = self.random_instance.choice(possible_cross_peers) while cross_parent_peer_id == self.peer_id: @@ -817,16 +848,28 @@ async 
def gossip_push(self): cross_parent_node = self.seen_valid_nodes[cross_parent_peer_id][-1] + if cross_parent_node.node_hash in [node.cross_parent_hash for node in new_nodes]: + # duplicated cross parent + continue + new_node = self.compute_new_node(self_parent=self_parent_node, cross_parent=cross_parent_node) - if new_node is None: - return - if new_node.cross_parent_hash not in [node.cross_parent_hash for node in new_nodes]: + if new_node is not None: + # found a valid node with unique cross parent new_nodes.append(new_node) + break else: - found_unique_cross_parent = False + print(f"Peer {self.peer_id} can't compute any new nodes from {self_parent_node.node_hash} and {cross_parent_node.node_hash}") + pass + # can't construct a valid node from the current tuple of self_parent and cross_parent + + assert len(new_nodes) <= num_nodes_to_create + if len(new_nodes) > 0: + print(f"Peer {self.peer_id}, is_adversary = {self.is_adversary}, computed {len(new_nodes)} nodes, its neighbors = {self.neighbors}, its equivocated peers = {self.my_nodes()[-1].equivocated_peers}, seen_peers = {[peer_id for peer_id in self.seen_valid_nodes]}") + else: + print(f"Peer {self.peer_id}, is_adversary = {self.is_adversary}, can't compute any new nodes, its neighbors = {self.neighbors}, its equivocated peers = {self.my_nodes()[-1].equivocated_peers}, seen_peers = {[peer_id for peer_id in self.seen_valid_nodes]}") + return - assert len(new_nodes) == num_nodes_to_create for new_node in new_nodes: assert self.verify_node_and_add_to_local_view(new_node) @@ -836,12 +879,14 @@ async def gossip_push(self): # select randomly nodes from new_nodes node_to_send = (new_nodes[0] if i * 2 < len(self.neighbors) else new_nodes[-1]).clone() # simulate the process of serializing and deserializing the nodes in internet protocols + print(f"Peer {self.peer_id} try to gossip to {other_peer_id} node {node_to_send.node_hash}:") + # Send the selected node success = await 
self.network.gossip_send_node_and_ancestry(self.peer_id, other_peer_id, node_to_send) if not success: - print(f"peer {self.peer_id} gossiped send to {other_peer_id} node {node_to_send.node_hash} failed") + print(f"peer {self.peer_id} gossiped to {other_peer_id} node {node_to_send.node_hash} failed") else: - print(f"peer {self.peer_id} gossiped send to {other_peer_id} node {node_to_send.node_hash} successfully") + print(f"peer {self.peer_id} gossiped to {other_peer_id} node {node_to_send.node_hash} successfully") # TODO: finish equivocation detection def detect_equivocation(self) -> list: @@ -856,7 +901,7 @@ async def main(): ) # Create peers - num_peers = 4 # next threshold for count_adversary = 2 is N = 7 + num_peers = 7 # next threshold for count_adversary = 2 is N = 7 count_adversary = 0 for i in range(num_peers): is_adversary = (count_adversary + 1) < 1 * num_peers / 3 and network.random_instance.random() < 0.5 @@ -886,7 +931,7 @@ async def main(): # Main consensus loop i = 0 - while True and i < 100: + while True and i < 50: i += 1 if i % 100 == 0: print(f"{i}th iteration") From 83a49e26505a99ef4a5da7e820b0d0021ef9a5e8 Mon Sep 17 00:00:00 2001 From: Galin Chung Nguyen Date: Tue, 25 Feb 2025 13:14:32 +0700 Subject: [PATCH 03/10] implement graph utils --- dagpool/graph.py | 169 +++++++++++++++++++++++++++++++++++++++++++++ dagpool/schemas.py | 4 +- 2 files changed, 172 insertions(+), 1 deletion(-) create mode 100644 dagpool/graph.py diff --git a/dagpool/graph.py b/dagpool/graph.py new file mode 100644 index 0000000..51780f5 --- /dev/null +++ b/dagpool/graph.py @@ -0,0 +1,169 @@ +""" +This module implements the algorithms that find a Hamiltonian path and a Hamiltonian cycle on a tournament graph. + +References: +- Hamiltonian Path in a tournament graph: + P. Hell and M. Rosenfeld, *The complexity of finding generalized paths in tournaments*, + J. Algorithms 4 (1982) 303-309. 
+ [Link](https://www.sciencedirect.com/science/article/abs/pii/0196677483900111) + +- Hamiltonian Cycle in a strongly connected tournament graph: + Y. Manoussakis, *A Linear-Time Algorithm for Finding Hamiltonian Cycles in Tournaments*, + Discrete Appl. Math. 36, 2 (1992), 199–201. + [Link](https://www.sciencedirect.com/science/article/pii/0166218X9290233Z) +""" + +from typing import Set, Dict, List, Tuple +from schemas import OrderFairnessGraphNodeId, HashValue + +class DirectedGraph: + def __init__(self): + self.nodes: Set[OrderFairnessGraphNodeId] = {} + self.edges: Dict[OrderFairnessGraphNodeId, Set[OrderFairnessGraphNodeId]] = {} + self.is_tournament_graph = None + self.connected_components: List[Tuple[OrderFairnessGraphNodeId, HashValue]] = None + + def reset_graph_properties(self): + self.is_tournament_graph = None + self.connected_components = None + + def add_node(self, nodeId: OrderFairnessGraphNodeId): + if nodeId in self.nodes: + return + self.nodes.add(nodeId) + self.edges[nodeId] = set() + # reset the graph properties + self.reset_graph_properties() + + def add_directed_edge(self, nodeId1: OrderFairnessGraphNodeId, nodeId2: OrderFairnessGraphNodeId): + if nodeId1 not in self.nodes: + self.add_node(nodeId1) + self.edges[nodeId1].add(nodeId2) + # reset the graph properties + self.reset_graph_properties() + + def has_edge(self, nodeId1: OrderFairnessGraphNodeId, nodeId2: OrderFairnessGraphNodeId) -> bool: + return nodeId1 in self.edges and nodeId2 in self.edges[nodeId1] + + def assert_is_tournament_graph(self): + if self.is_tournament_graph is None: + self.is_tournament_graph = self.is_tournament_graph() + assert self.is_tournament_graph + + def is_tournament_graph(self) -> bool: + for node1 in self.nodes: + for node2 in self.nodes: + if node1 != node2: + count = (1 if self.has_edge(node1, node2) else 0) + (1 if self.has_edge(node2, node1) else 0) + if count != 1: + return False + return True + + def find_strongly_connected_components(self) -> 
List[List[OrderFairnessGraphNodeId]]: + assert self.connected_components is None + self.connected_components: List[Tuple[OrderFairnessGraphNodeId, HashValue]] = [] + + # tarjan's algorithm + index = 0 + indices = {} + lowlinks = {} + stack = [] + components: List[List[OrderFairnessGraphNodeId]] = [] + + def strongconnect(nodeId: OrderFairnessGraphNodeId) -> None: + assert nodeId not in indices + nonlocal index, indices, lowlinks, stack, on_stack, components + # Set the depth index for node + indices[nodeId] = index + lowlinks[nodeId] = index + index += 1 + stack.append(nodeId) + + # Consider successors of node + for successorId in self.edges[nodeId]: + if successorId not in indices: + # Successor has not yet been visited; recurse on it + strongconnect(successorId) + lowlinks[nodeId] = min(lowlinks[nodeId], lowlinks[successorId]) + else: + lowlinks[nodeId] = min(lowlinks[nodeId], indices[successorId]) + + # If node is a root node, pop the stack and generate an SCC + if lowlinks[nodeId] == indices[nodeId]: + component = [] + while True: + vertexId = stack.pop() + component.append(vertexId) + if vertexId == nodeId: + break + components.append([component, hash(tuple(sorted(component)))]) + + # Find SCCs for all nodes + for nodeId in self.nodes: + if nodeId not in indices: + strongconnect(nodeId) + + # Store components with their hash values + self.connected_components = components + + return components + + def assert_is_strongly_connected_component(self, hamiltonian_path: List[OrderFairnessGraphNodeId]): + assert self.connected_components is not None + hash_value = hash(tuple(sorted(hamiltonian_path))) + assert hash_value in [connected_component[1] for connected_component in self.connected_components] + + def assert_is_hamiltonian_path(self, hamiltonian_path: List[OrderFairnessGraphNodeId]): + for i in range(1, len(hamiltonian_path)): + assert self.has_edge(hamiltonian_path[i - 1], hamiltonian_path[i]) + +class TournamentGraph(DirectedGraph): + def __init__(self): + 
super().__init__() + + def is_tournament_graph(self) -> bool: + return super().is_tournament_graph() + + def find_hamiltonian_path(self, scc: List[OrderFairnessGraphNodeId]) -> List[OrderFairnessGraphNodeId]: + # must be a strongly connected component + self.assert_is_strongly_connected_component(scc) + + # Complexity: O(len(scc)^2) + hamiltonian_path: List[OrderFairnessGraphNodeId] = [scc[0]] + for k in range(1, len(scc)): + # find first i < k | has_edge(hamiltonian_path[i], scc[k]) & has_edge(scc[k], hamiltonian_path[i+1]) + i = 0 + while i + 1 < k and not (self.has_edge(hamiltonian_path[i], scc[k]) and self.has_edge(scc[k], hamiltonian_path[i+1])): + i += 1 + + hamiltonian_path.insert(i + 1, scc[k]) + + return hamiltonian_path + + def find_hamiltonian_cycle(self, hamiltonian_path_of_scc: List[OrderFairnessGraphNodeId]) -> List[OrderFairnessGraphNodeId]: + self.assert_is_tournament_graph() + self.assert_is_strongly_connected_component(hamiltonian_path_of_scc) + self.assert_is_hamiltonian_path(hamiltonian_path_of_scc) + + # Complexity: O(len(hamiltonian_path)^2) + accumulated_hamiltonian_cycle: List[OrderFairnessGraphNodeId] = [hamiltonian_path_of_scc[0]] + j = 1 + while j < len(hamiltonian_path_of_scc): + p = j + 1 + r = -1 + found_backward_edge = False + while not found_backward_edge and p < len(hamiltonian_path_of_scc): + nonlocal r + # check if there is a backward edge from hamiltonian_path[p] to accumulated_hamiltonian_cycle[r] + r = 0 # the out-going node of the backward edge + while r < len(accumulated_hamiltonian_cycle) and not self.has_edge(hamiltonian_path_of_scc[p], accumulated_hamiltonian_cycle[r]): + r += 1 + + if r < len(accumulated_hamiltonian_cycle): + found_backward_edge = True + + # reorder the accumulated_hamiltonian_cycle: accumulated_hamiltonian_cycle[0 -> r - 1] -> hamiltonian_path_of_scc[j+1 -> p] -> accumulated_hamiltonian_cycle[r -> ...] 
-> hamiltonian_path_of_scc[0] + j = p + accumulated_hamiltonian_cycle = accumulated_hamiltonian_cycle[0:r] + hamiltonian_path_of_scc[j+1:p+1] + accumulated_hamiltonian_cycle[r:] + + return accumulated_hamiltonian_cycle \ No newline at end of file diff --git a/dagpool/schemas.py b/dagpool/schemas.py index 1262d80..bda9f1c 100644 --- a/dagpool/schemas.py +++ b/dagpool/schemas.py @@ -4,4 +4,6 @@ NodeId = NewType("NodeId", str) NodeLabel = NewType("NodeLabel", str) Pubkey = NewType("Pubkey", str) -Signature = NewType("Signature", str) \ No newline at end of file +Signature = NewType("Signature", str) +OrderFairnessGraphNodeId = NewType("OrderFairnessGraphNodeId", str) +HashValue = NewType("HashValue", str) \ No newline at end of file From 1985793fc120412c6c3a0d8811671ff2c53f0598 Mon Sep 17 00:00:00 2001 From: Galin Chung Nguyen Date: Tue, 25 Feb 2025 17:23:17 +0700 Subject: [PATCH 04/10] add batch proposal construction --- dagpool/consensus_client.py | 187 ++++++++++++++++++++++++++++++------ dagpool/graph.py | 64 ++++++++---- dagpool/schemas.py | 5 +- 3 files changed, 207 insertions(+), 49 deletions(-) diff --git a/dagpool/consensus_client.py b/dagpool/consensus_client.py index 9740c11..ce6e160 100644 --- a/dagpool/consensus_client.py +++ b/dagpool/consensus_client.py @@ -10,12 +10,72 @@ import matplotlib.pyplot as plt from matplotlib.colors import LinearSegmentedColormap import hashlib -from schemas import TransactionId, PeerId, NodeId, NodeLabel +from schemas import TransactionId, PeerId, NodeId, NodeLabel, Signature, BatchId, Pubkey, HashValue import json from math import log2, ceil # Added log2 and ceil imports +from graph import TournamentGraph + +BEACON_PACE = 4 +# the first BEACON_PACE rounds are derived directly from the list of peers +# BEACON_PACE should be chosen large enough to make sure peers have enough time to realize that they are the leader of the next rounds +BEACON_FIELD_PRIME = 
28948022309329048855892746252171976963363056481941560715954676764349967630337 # equal to Pallas base field prime + +class Utils: + @staticmethod + def merkle_root_of_transaction_list(txs: list[TransactionId]) -> NodeId: + assert len(txs) > 0 + # do the merkle tree construction using a while loop + res = [tx for tx in txs] + while len(res) > 1: + new_res = [] + for i in range(0, len(res), 2): + if i+1 < len(res): + new_res.append(hashlib.sha256(f"{res[i]}{res[i+1]}".encode()).hexdigest()) + else: + new_res.append(res[i]) + res = new_res + return res[0] + +class BatchProposal: + batch_hash: HashValue + prev_batch_hash: HashValue + final_fair_ordering: List[TransactionId] + next_beacon_randomness: HashValue # peers use this to derive the leader of round r + BEACON_PACE + + def __init__(self, prev_batch_hash: HashValue, final_fair_ordering: List[TransactionId], prev_beacon_randomness: HashValue): + self.prev_batch_hash = prev_batch_hash + self.final_fair_ordering = final_fair_ordering + self.next_beacon_randomness = self.compute_next_beacon_randomness(round, prev_beacon_randomness, final_fair_ordering) + self.batch_hash = self.compute_batch_hash() + + def compute_next_beacon_randomness(self, round_number: int, prev_beacon_randomness: HashValue, final_fair_ordering: List[TransactionId]) -> HashValue: + components = [prev_beacon_randomness, str(round_number + BEACON_PACE), Utils.merkle_root_of_transaction_list(final_fair_ordering)] + return hashlib.sha256(''.join(components).encode()).hexdigest() + + def compute_batch_hash(self) -> HashValue: + # use hashlib of [prev_batch_hash, final_fair_ordering, next_beacon_randomness] + components = [self.prev_batch_hash, Utils.merkle_root_of_transaction_list(self.final_fair_ordering), self.next_beacon_randomness] + return hashlib.sha256(''.join(components).encode()).hexdigest()[:8] + + def verify_batch_proposal_is_well_formed(self, round_number: int, prev_beacon_randomness: HashValue) -> bool: + if not self.next_beacon_randomness 
== self.compute_next_beacon_randomness(round_number, prev_beacon_randomness, self.final_fair_ordering): + return False + if not self.batch_hash == self.compute_batch_hash(): + return False + return True + + def clone(self): + return BatchProposal(self.prev_batch_hash, [tx for tx in self.final_fair_ordering], self.next_beacon_randomness) + +class NodeMetadata: + batch_proposal: BatchProposal # can be None if the node is not a head node (head node means the witness of the leader in its selected round) + creator_signature: Signature + + def clone(self): + return NodeMetadata(self.batch_proposal.clone() if self.batch_proposal else None, self.creator_signature) class Node: - def __init__(self, peer_id: PeerId, round: int, is_witness: bool, newly_seen_txs_list: list[TransactionId], self_parent_hash: NodeId, cross_parent_hash: NodeId): + def __init__(self, peer_id: PeerId, round: int, is_witness: bool, newly_seen_txs_list: list[TransactionId], self_parent_hash: NodeId, cross_parent_hash: NodeId, metadata: NodeMetadata): self.peer_id = peer_id self.is_witness = is_witness self.round = round @@ -27,9 +87,11 @@ def __init__(self, peer_id: PeerId, round: int, is_witness: bool, newly_seen_txs ## fork-related data self.equivocated_peers: Set[PeerId] = None # set of peers that current node believes are equivocated, and this node won't SEE (i.e. UNSEE) all nodes created by them. Note that this doesn't affect STRONGLY SEEING property of this node. 
self.seen_nodes: Set[NodeId] = None # set of nodes that current node sees + self.metadata: NodeMetadata = metadata + self.update_signature() def clone(self): - return Node(self.peer_id, self.round, self.is_witness, [txs for txs in self.newly_seen_txs_list], self.self_parent_hash, self.cross_parent_hash) + return Node(self.peer_id, self.round, self.is_witness, [txs for txs in self.newly_seen_txs_list], self.self_parent_hash, self.cross_parent_hash, self.metadata.clone()) def label(self) -> NodeLabel: return f"{self.peer_id}:{self.node_hash}" @@ -37,21 +99,22 @@ def label(self) -> NodeLabel: def has_computed_seen_nodes(self) -> bool: return self.seen_nodes is not None and self.equivocated_peers is not None - @staticmethod - def merkle_root_of_transaction_list(txs: list[TransactionId]) -> NodeId: - assert len(txs) > 0 - # do the merkle tree construction using a while loop - res = [tx for tx in txs] - while len(res) > 1: - new_res = [] - for i in range(0, len(res), 2): - if i+1 < len(res): - new_res.append(hashlib.sha256(f"{res[i]}{res[i+1]}".encode()).hexdigest()) - else: - new_res.append(res[i]) - res = new_res - return res[0] + def compute_signature(self, _creator_private_key: PrivateKey) -> Signature: + # TODO: use real private key + if self.metadata.batch_proposal: + # signature = hash(batch_proposal.batch_hash, node_hash) + return hashlib.sha256([self.metadata.batch_proposal.batch_hash, self.node_hash].encode()).hexdigest()[:8] + else: + # signature = hash(node_hash) + return hashlib.sha256([self.node_hash].encode()).hexdigest()[:8] + def update_signature(self): + self.metadata.creator_signature = self.compute_signature() + + def verify_signature(self, creator_pubkey: Pubkey) -> bool: + # TODO: use real pubkey + return self.metadata.creator_signature == self.compute_signature() + @staticmethod def hash_node(creator: PeerId, round: int, is_witness: bool, self_parent_hash: NodeId, cross_parent_hash: NodeId, newly_seen_txs_list: list[TransactionId]) -> NodeId: 
"""Create deterministic hash for a node""" @@ -61,12 +124,26 @@ def hash_node(creator: PeerId, round: int, is_witness: bool, self_parent_hash: N if self_parent_hash: components.append(self_parent_hash) if newly_seen_txs_list: - components.append(Node.merkle_root_of_transaction_list(newly_seen_txs_list)) + components.append(Utils.merkle_root_of_transaction_list(newly_seen_txs_list)) return hashlib.sha256(''.join(components).encode()).hexdigest()[:8] # the hash value of a node basically depends deterministically on all of its content def verify_node_hash(self) -> bool: return self.node_hash == Node.hash_node(self.peer_id, self.round, self.is_witness, self.self_parent_hash, self.cross_parent_hash, self.newly_seen_txs_list) + def validate_node_data(self, creator_pubkey: Pubkey) -> bool: + if not self.verify_node_hash(): + return False + + # TODO: validate that all the node data is well-formed (use schema validator) + + if self.metadata.batch_proposal and not self.metadata.batch_proposal.verify_batch_proposal_is_well_formed(self.round, self.metadata.batch_proposal.prev_beacon_randomness): + return False + + if not self.verify_signature(creator_pubkey): + return False + + return True + def is_genesis(self): is_genesis = self.round == 0 and self.is_witness == True and self.self_parent_hash == "" and self.cross_parent_hash == "" and self.verify_node_hash() return is_genesis @@ -117,6 +194,19 @@ def __init__(self, latency_ms_range=(50, 200), packet_loss_prob=0.1, random_inst # global mempool: simulate a global mempool of all transactions from all clients self.global_mempool = set() + def get_first_leaders(self) -> List[PeerId]: + all_peers = [p.peer_id for p in self.peers] + value_bytes = hashlib.sha256(",".join(all_peers).encode()).digest() + value = int.from_bytes(value_bytes, "big") + # choose the first BEACON_PACE leaders for the first BEACON_PACE rounds, using the first BEACON_FIELD_PRIME as modulo somehow + beacon_random = random.Random(value) + leaders = 
beacon_random.sample(all_peers, BEACON_PACE) + return leaders + + def get_peer_pubkey(self, peer_id: PeerId) -> Pubkey: + # TODO: use real pubkey + return "DUMMY_PUBKEY" + def register_peer(self, peer: 'ConsensusPeer'): # peer_id must be unique assert peer.peer_id not in [p.peer_id for p in self.peers] @@ -135,6 +225,7 @@ async def register_genesis_nodes(self): genesis_nodes: Dict[PeerId, Node] = {} for peer1 in self.peers: genesis_node = peer1.create_genesis_node() + peer1.construct_batch_proposal_if_needed(genesis_node) genesis_nodes[peer1.peer_id] = genesis_node for peer2 in self.peers: @@ -335,7 +426,7 @@ def __init__(self, peer_id: PeerId, is_adversary: bool, seed: int, network: Netw self.current_round = 0 self.seen_valid_nodes: Dict[PeerId, List[Node]] = {} self.accumulated_txs = set() # Track all transactions seen by this peer - self.network = network + self.network: NetworkSimulator = network ## adversary-related data self.equivocated_peers: Set[PeerId] = set() # set of peers that current peer believes they actively create equivocated nodes self.equivocated_nodes: Set[NodeId] = set() # set of nodes that current peer believes are equivocated @@ -560,11 +651,48 @@ def create_genesis_node(self): is_witness=True, newly_seen_txs_list=[], self_parent_hash="", - cross_parent_hash="" + cross_parent_hash="", + metadata=NodeMetadata() ) assert self.verify_node_and_add_to_local_view(node) == True return node + def construct_batch_proposal_if_needed(self, node: Node): + """ + Construct a batch proposal for the given node + """ + if node.peer_id != self.peer_id or not node.is_witness: + return + + # fork-choice rule: select the previous batch using the Heaviest Observed Subtree (HOS) selection rule + heaviest_batch = 0 + while self.batch_has_children(heaviest_batch): + heaviest_batch = self.get_heaviest_child(heaviest_batch) + + prev_batch_beacon_randomness: HashValue = self.get_beacon_randomness_of_batch(heaviest_batch) + + # construct the truncated cone = the 
intersection between the lineage of the heaviest batch and the ancestry of the current head node + truncated_cone: Dict[NodeId, Node] = self.get_truncated_cone(heaviest_batch, node) + # gather all transactions in the truncated cone into a tournament graph + tournament_graph: TournamentGraph = self.construct_tournament_graph_of_transactions(truncated_cone) + # calculate the fair ordering of the transactions and construct the batch proposal + sccs = tournament_graph.find_strongly_connected_components() + sccs_with_fair_orderings = [[scc, tournament_graph.find_hamiltonian_cycle(scc)] for scc in sccs] + final_fair_ordering: List[TransactionId] = [] + for scc, fair_ordering in sccs_with_fair_orderings: + final_fair_ordering.extend(fair_ordering) + + batch_proposal = BatchProposal( + prev_batch_hash=heaviest_batch, + final_fair_ordering=final_fair_ordering, + prev_beacon_randomness=prev_batch_beacon_randomness + ) + + assert batch_proposal.verify_batch_proposal_is_well_formed(node.round, prev_batch_beacon_randomness) + + node.metadata = NodeMetadata(batch_proposal=batch_proposal) + node.update_signature() + # TODO: implement bootstrap node (first node refers to parents in a checkpoint after a node rejoins the network) def has_seen_valid_node(self, node: Node) -> bool: @@ -728,7 +856,7 @@ def verify_node(self, node: Node = None) -> bool: if node.is_genesis(): return True - if not node.verify_node_hash(): + if not node.validate_node_data(self.network.get_peer_pubkey(node.peer_id)): return False predecessors = self.get_predecessors(node) @@ -743,6 +871,11 @@ def verify_node(self, node: Node = None) -> bool: is_allowed_to_bypass = self.is_adversary and node_to_check.peer_id == self.peer_id # adversary don't accept invalid nodes from other adversaries if not is_allowed_to_bypass: return False + + # TODO: verify the batch proposal if it exists + # 1. check if the current node is a head witness + # 2. 
verify the batch proposal is constructed correctly + except Exception as e: # adversary sending invalid nodes print("error = ", e) @@ -798,8 +931,10 @@ def compute_new_node(self, self_parent: Node, cross_parent: Node) -> Node: is_witness=False, newly_seen_txs_list=newly_seen_txs_list, self_parent_hash=self_parent.node_hash, - cross_parent_hash=cross_parent.node_hash + cross_parent_hash=cross_parent.node_hash, + metadata=NodeMetadata() ) + self.construct_batch_proposal_if_needed(new_node) if not self.verify_node(new_node): new_node = Node( @@ -808,8 +943,10 @@ def compute_new_node(self, self_parent: Node, cross_parent: Node) -> Node: is_witness=True, newly_seen_txs_list=newly_seen_txs_list, self_parent_hash=self_parent.node_hash, - cross_parent_hash=cross_parent.node_hash + cross_parent_hash=cross_parent.node_hash, + metadata=NodeMetadata() ) + self.construct_batch_proposal_if_needed(new_node) if not self.verify_node(new_node): # this new_node is invalid because either its parents are from equivocated peers @@ -888,10 +1025,6 @@ async def gossip_push(self): else: print(f"peer {self.peer_id} gossiped to {other_peer_id} node {node_to_send.node_hash} successfully") - # TODO: finish equivocation detection - def detect_equivocation(self) -> list: - pass - async def main(): # Create network simulator network = NetworkSimulator( diff --git a/dagpool/graph.py b/dagpool/graph.py index 51780f5..dfa426b 100644 --- a/dagpool/graph.py +++ b/dagpool/graph.py @@ -14,20 +14,24 @@ """ from typing import Set, Dict, List, Tuple -from schemas import OrderFairnessGraphNodeId, HashValue +from schemas import GraphNodeId, HashValue class DirectedGraph: def __init__(self): - self.nodes: Set[OrderFairnessGraphNodeId] = {} - self.edges: Dict[OrderFairnessGraphNodeId, Set[OrderFairnessGraphNodeId]] = {} + self.nodes: Set[GraphNodeId] = {} + self.edges: Dict[GraphNodeId, Set[GraphNodeId]] = {} self.is_tournament_graph = None - self.connected_components: List[Tuple[OrderFairnessGraphNodeId, 
HashValue]] = None + self.connected_components: List[Tuple[GraphNodeId, HashValue]] = None + self.node_to_scc_hash: Dict[GraphNodeId, HashValue] = None + self.hash_to_scc_nodes: Dict[HashValue, List[GraphNodeId]] = None def reset_graph_properties(self): self.is_tournament_graph = None self.connected_components = None + self.node_to_scc_hash = None + self.hash_to_scc_nodes = None - def add_node(self, nodeId: OrderFairnessGraphNodeId): + def add_node(self, nodeId: GraphNodeId): if nodeId in self.nodes: return self.nodes.add(nodeId) @@ -35,14 +39,14 @@ def add_node(self, nodeId: OrderFairnessGraphNodeId): # reset the graph properties self.reset_graph_properties() - def add_directed_edge(self, nodeId1: OrderFairnessGraphNodeId, nodeId2: OrderFairnessGraphNodeId): + def add_directed_edge(self, nodeId1: GraphNodeId, nodeId2: GraphNodeId): if nodeId1 not in self.nodes: self.add_node(nodeId1) self.edges[nodeId1].add(nodeId2) # reset the graph properties self.reset_graph_properties() - def has_edge(self, nodeId1: OrderFairnessGraphNodeId, nodeId2: OrderFairnessGraphNodeId) -> bool: + def has_edge(self, nodeId1: GraphNodeId, nodeId2: GraphNodeId) -> bool: return nodeId1 in self.edges and nodeId2 in self.edges[nodeId1] def assert_is_tournament_graph(self): @@ -59,18 +63,21 @@ def is_tournament_graph(self) -> bool: return False return True - def find_strongly_connected_components(self) -> List[List[OrderFairnessGraphNodeId]]: + """ + Find the strongly connected components of the graph and return them in the topological order of the SCCs + """ + def find_strongly_connected_components(self) -> List[Tuple[List[GraphNodeId], HashValue]]: assert self.connected_components is None - self.connected_components: List[Tuple[OrderFairnessGraphNodeId, HashValue]] = [] + self.connected_components: List[Tuple[GraphNodeId, HashValue]] = [] # tarjan's algorithm index = 0 indices = {} lowlinks = {} stack = [] - components: List[List[OrderFairnessGraphNodeId]] = [] + components: 
List[Tuple[List[GraphNodeId], HashValue]] = [] - def strongconnect(nodeId: OrderFairnessGraphNodeId) -> None: + def strongconnect(nodeId: GraphNodeId) -> None: assert nodeId not in indices nonlocal index, indices, lowlinks, stack, on_stack, components # Set the depth index for node @@ -103,17 +110,34 @@ def strongconnect(nodeId: OrderFairnessGraphNodeId) -> None: if nodeId not in indices: strongconnect(nodeId) - # Store components with their hash values - self.connected_components = components + self.node_to_scc_hash: Dict[GraphNodeId, HashValue] = {nodeId: component[1] for component in enumerate(components) for nodeId in component[0]} + self.hash_to_scc_nodes: Dict[HashValue, List[GraphNodeId]] = {component[1]: component[0] for component in components} - return components + self.connected_components: List[Tuple[List[GraphNodeId], HashValue]] = [] + + visited_sccs = set() + def sort_sccs(comp: Tuple[List[GraphNodeId], HashValue]): + assert comp[1] not in visited_sccs + self.connected_components.append(comp) + for u in comp[0]: + if u in self.edges: + for v in self.edges[u]: + if not self.node_to_scc_hash[v] in visited_sccs: + sort_sccs(self.hash_to_scc_nodes[self.node_to_scc_hash[v]]) + visited_sccs.add(comp[1]) + + for component in components: + if component[1] not in visited_sccs: + sort_sccs(component) + + return self.connected_components - def assert_is_strongly_connected_component(self, hamiltonian_path: List[OrderFairnessGraphNodeId]): + def assert_is_strongly_connected_component(self, hamiltonian_path: List[GraphNodeId]): assert self.connected_components is not None hash_value = hash(tuple(sorted(hamiltonian_path))) assert hash_value in [connected_component[1] for connected_component in self.connected_components] - def assert_is_hamiltonian_path(self, hamiltonian_path: List[OrderFairnessGraphNodeId]): + def assert_is_hamiltonian_path(self, hamiltonian_path: List[GraphNodeId]): for i in range(1, len(hamiltonian_path)): assert self.has_edge(hamiltonian_path[i 
- 1], hamiltonian_path[i]) @@ -124,12 +148,12 @@ def __init__(self): def is_tournament_graph(self) -> bool: return super().is_tournament_graph() - def find_hamiltonian_path(self, scc: List[OrderFairnessGraphNodeId]) -> List[OrderFairnessGraphNodeId]: + def find_hamiltonian_path(self, scc: List[GraphNodeId]) -> List[GraphNodeId]: # must be a strongly connected component self.assert_is_strongly_connected_component(scc) # Complexity: O(len(scc)^2) - hamiltonian_path: List[OrderFairnessGraphNodeId] = [scc[0]] + hamiltonian_path: List[GraphNodeId] = [scc[0]] for k in range(1, len(scc)): # find first i < k | has_edge(hamiltonian_path[i], scc[k]) & has_edge(scc[k], hamiltonian_path[i+1]) i = 0 @@ -140,13 +164,13 @@ def find_hamiltonian_path(self, scc: List[OrderFairnessGraphNodeId]) -> List[Ord return hamiltonian_path - def find_hamiltonian_cycle(self, hamiltonian_path_of_scc: List[OrderFairnessGraphNodeId]) -> List[OrderFairnessGraphNodeId]: + def find_hamiltonian_cycle(self, hamiltonian_path_of_scc: List[GraphNodeId]) -> List[GraphNodeId]: self.assert_is_tournament_graph() self.assert_is_strongly_connected_component(hamiltonian_path_of_scc) self.assert_is_hamiltonian_path(hamiltonian_path_of_scc) # Complexity: O(len(hamiltonian_path)^2) - accumulated_hamiltonian_cycle: List[OrderFairnessGraphNodeId] = [hamiltonian_path_of_scc[0]] + accumulated_hamiltonian_cycle: List[GraphNodeId] = [hamiltonian_path_of_scc[0]] j = 1 while j < len(hamiltonian_path_of_scc): p = j + 1 diff --git a/dagpool/schemas.py b/dagpool/schemas.py index bda9f1c..e9fbf99 100644 --- a/dagpool/schemas.py +++ b/dagpool/schemas.py @@ -5,5 +5,6 @@ NodeLabel = NewType("NodeLabel", str) Pubkey = NewType("Pubkey", str) Signature = NewType("Signature", str) -OrderFairnessGraphNodeId = NewType("OrderFairnessGraphNodeId", str) -HashValue = NewType("HashValue", str) \ No newline at end of file +GraphNodeId = NewType("GraphNodeId", str) +HashValue = NewType("HashValue", str) +BatchId = NewType("BatchId", str) \ No 
newline at end of file From a3f6330904122b237b4dda61e527e58f09b79cfd Mon Sep 17 00:00:00 2001 From: Galin Chung Nguyen Date: Tue, 25 Feb 2025 18:35:14 +0700 Subject: [PATCH 05/10] heaviest observed subtree fork-choice rule --- dagpool/consensus_client.py | 125 ++++++++++++++++++++++++++++++------ 1 file changed, 104 insertions(+), 21 deletions(-) diff --git a/dagpool/consensus_client.py b/dagpool/consensus_client.py index ce6e160..898c918 100644 --- a/dagpool/consensus_client.py +++ b/dagpool/consensus_client.py @@ -36,6 +36,9 @@ def merkle_root_of_transaction_list(txs: list[TransactionId]) -> NodeId: res = new_res return res[0] +class Vote: + head_batch_hash: HashValue + class BatchProposal: batch_hash: HashValue prev_batch_hash: HashValue @@ -69,10 +72,11 @@ def clone(self): class NodeMetadata: batch_proposal: BatchProposal # can be None if the node is not a head node (head node means the witness of the leader in its selected round) + vote: Vote # BFT vote for the head batch of the network, which determines the order of transactions in the network creator_signature: Signature def clone(self): - return NodeMetadata(self.batch_proposal.clone() if self.batch_proposal else None, self.creator_signature) + return NodeMetadata(self.batch_proposal.clone() if self.batch_proposal else None, Vote(self.vote.head_batch_hash), self.creator_signature) class Node: def __init__(self, peer_id: PeerId, round: int, is_witness: bool, newly_seen_txs_list: list[TransactionId], self_parent_hash: NodeId, cross_parent_hash: NodeId, metadata: NodeMetadata): @@ -86,9 +90,12 @@ def __init__(self, peer_id: PeerId, round: int, is_witness: bool, newly_seen_txs self.node_hash = self.hash_node(peer_id, round, is_witness, self_parent_hash, cross_parent_hash, newly_seen_txs_list) ## fork-related data self.equivocated_peers: Set[PeerId] = None # set of peers that current node believes are equivocated, and this node won't SEE (i.e. UNSEE) all nodes created by them. 
Note that this doesn't affect STRONGLY SEEING property of this node. + self.non_equivocated_peers: Set[PeerId] = None # set of peers that current node believes are not equivocated, and this node will SEE all nodes created by them. self.seen_nodes: Set[NodeId] = None # set of nodes that current node sees + self.seen_votes_by_peers: Dict[PeerId, HashValue] = None # latest votes from each peer that current node can see self.metadata: NodeMetadata = metadata self.update_signature() + self.has_filled_node_data = False def clone(self): return Node(self.peer_id, self.round, self.is_witness, [txs for txs in self.newly_seen_txs_list], self.self_parent_hash, self.cross_parent_hash, self.metadata.clone()) @@ -96,9 +103,9 @@ def clone(self): def label(self) -> NodeLabel: return f"{self.peer_id}:{self.node_hash}" - def has_computed_seen_nodes(self) -> bool: - return self.seen_nodes is not None and self.equivocated_peers is not None - + def has_filled_node_data(self) -> bool: + return self.has_filled_node_data + def compute_signature(self, _creator_private_key: PrivateKey) -> Signature: # TODO: use real private key if self.metadata.batch_proposal: @@ -136,8 +143,12 @@ def validate_node_data(self, creator_pubkey: Pubkey) -> bool: # TODO: validate that all the node data is well-formed (use schema validator) - if self.metadata.batch_proposal and not self.metadata.batch_proposal.verify_batch_proposal_is_well_formed(self.round, self.metadata.batch_proposal.prev_beacon_randomness): - return False + if self.metadata.batch_proposal: + if not self.metadata.batch_proposal.verify_batch_proposal_is_well_formed(self.round, self.metadata.batch_proposal.prev_beacon_randomness): + return False + # it must vote for its own batch proposal + if self.metadata.vote.head_batch_hash != self.metadata.batch_proposal.batch_hash: + return False if not self.verify_signature(creator_pubkey): return False @@ -151,6 +162,9 @@ def is_genesis(self): def is_non_genesis(self): return not self.is_genesis() + def 
get_seen_valid_peers(self) -> Set[PeerId]: + eq + def __str__(self): return f"{"GENESIS " if self.is_genesis() else ""}Node(node_hash={self.node_hash}, peer_id={self.peer_id}, round={self.round}, is_witness={self.is_witness}, self_parent_hash={self.self_parent_hash}, cross_parent_hash={self.cross_parent_hash}, newly_seen_txs_list={self.newly_seen_txs_list})" # , equivocated_peers={self.equivocated_peers}, seen_nodes={self.seen_nodes})" @@ -437,7 +451,10 @@ def __init__(self, peer_id: PeerId, is_adversary: bool, seed: int, network: Netw self.tx_receive_times = {} # Track when transactions were received self.pending_txs: List[Tuple[TransactionId, float]] = [] self.local_graph: Dict[NodeId, Set[NodeId]] = {} # node_hash to set of node_hashes (parent, child) - + # batch-related data + self.observed_valid_batches: List[Node] = [] + self.observed_votes_from_peers: Dict[PeerId, List[Node]] = {} + def my_nodes(self) -> List[Node]: return self.seen_valid_nodes[self.peer_id] @@ -657,17 +674,55 @@ def create_genesis_node(self): assert self.verify_node_and_add_to_local_view(node) == True return node + def get_heaviest_batch_amongst_strict_ancestors(self, node: Node) -> HashValue: + """ + Get the heaviest batch in the strict ancestors of the given node using a fork-choice rule based on the Heaviest Observed Subtree (HOS) selection rule + """ + # 1. construct a tree of batch proposals + tree: Dict[HashValue, Set[HashValue]] = {} + for node in self.observed_valid_batches: # TODO: only use the batches from the last finalized branch, so we can filter out the batches from equivocated peers + tree[node.batch_hash] = set() + parent_batch_hash = node.metadata.batch_proposal.prev_batch_hash + if parent_batch_hash not in tree: + tree[parent_batch_hash] = set() + tree[parent_batch_hash].add(node.batch_hash) + # 2. 
collect the latest votes of all peers that the node can see + weight_of_branch: Dict[HashValue, int] = {} + + non_equivocated_peers = node.non_equivocated_peers + for peer_id in node.seen_votes_by_peers: + voted_batch_hash = node.seen_votes_by_peers[peer_id] + assert voted_batch_hash in tree + weight_of_branch[voted_batch_hash] += 1 + # 3. accumulate weights upwards from the leaves to the root + def traverse(batch_hash: HashValue): + weight_of_branch[batch_hash] = 0 + for child_batch_hash in tree[batch_hash]: + traverse(child_batch_hash) + weight_of_branch[batch_hash] += weight_of_branch[child_batch_hash] + root = "" + traverse(root) + # 4. traverse downwards from the root to leaves using the HOS rule + heaviest_batch = "" + while heaviest_batch in tree: + heaviest_batch = max(tree[heaviest_batch], key=lambda x: weight_of_branch[x]) + return heaviest_batch + + def verify_node_is_head_node(self, node: Node) -> bool: + """ + Verify that the given node is a head node + """ + pass + def construct_batch_proposal_if_needed(self, node: Node): """ Construct a batch proposal for the given node """ - if node.peer_id != self.peer_id or not node.is_witness: + if node.peer_id != self.peer_id or not node.is_witness or not self.verify_node_is_head_node(node): return # fork-choice rule: select the previous batch using the Heaviest Observed Subtree (HOS) selection rule - heaviest_batch = 0 - while self.batch_has_children(heaviest_batch): - heaviest_batch = self.get_heaviest_child(heaviest_batch) + heaviest_batch = self.get_heaviest_batch_amongst_strict_ancestors(node) prev_batch_beacon_randomness: HashValue = self.get_beacon_randomness_of_batch(heaviest_batch) @@ -695,6 +750,27 @@ def construct_batch_proposal_if_needed(self, node: Node): # TODO: implement bootstrap node (first node refers to parents in a checkpoint after a node rejoins the network) + def construct_vote_for_node(self, node: Node): + """ + Construct a vote for the given node + """ + if 
self.verify_node_is_head_node(node): + node.metadata.vote = Vote(head_batch_hash=node.metadata.batch_proposal.batch_hash) + else: + node.metadata.vote = Vote(head_batch_hash=self.get_heaviest_batch_amongst_strict_ancestors(node)) + + def fill_node_data(self, node: Node) -> Node: + """ + Fill the data for the given node + """ + if not node.has_filled_node_data(): + self.compute_seen_nodes_of_new_node(node) + self.construct_batch_proposal_if_needed(node) + self.construct_vote_for_node(node) + node.has_filled_node_data = True + + return node + def has_seen_valid_node(self, node: Node) -> bool: if node.peer_id == self.peer_id: return node in self.my_nodes() @@ -728,6 +804,8 @@ def compute_seen_nodes_of_new_node(self, node: Node): assert node.seen_nodes is None and node.equivocated_peers is None node.seen_nodes = set() equivocated_peers = set() + non_equivocated_peers = set() + seen_votes_by_peers: Dict[PeerId, HashValue] = {} ancestry_of_node = self.get_ancestry(node) self_parent_set: Set[NodeId] = set() @@ -738,14 +816,22 @@ def compute_seen_nodes_of_new_node(self, node: Node): equivocated_peers.add(cur_node.peer_id) else: self_parent_set.add(cur_node.self_parent_hash) + non_equivocated_peers.add(cur_node.peer_id) for cur_node in ancestry_of_node: if cur_node.peer_id in equivocated_peers: continue + + if cur_node.node_hash not in self_parent_set: # there should be at most 1 such node for each peer + seen_votes_by_peers[cur_node.peer_id] = cur_node.metadata.vote.head_batch_hash node.seen_nodes.add(cur_node.node_hash) + + non_equivocated_peers = non_equivocated_peers.difference(equivocated_peers) + # assign the computed values to the node node.equivocated_peers = equivocated_peers - + node.non_equivocated_peers = non_equivocated_peers + node.seen_votes_by_peers = seen_votes_by_peers return def verify_node_and_add_to_local_view(self, node: Node = None) -> bool: @@ -850,9 +936,6 @@ def verify_node(self, node: Node = None) -> bool: # an honest peer must not accept a node 
which itself or its parents are from equivocated peers try: - if not node.has_computed_seen_nodes(): # in case a node is called verify_node_and_add_to_local_view() multiple times - self.compute_seen_nodes_of_new_node(node) - if node.is_genesis(): return True @@ -875,6 +958,8 @@ def verify_node(self, node: Node = None) -> bool: # TODO: verify the batch proposal if it exists # 1. check if the current node is a head witness # 2. verify the batch proposal is constructed correctly + + # TODO: verify the vote except Exception as e: # adversary sending invalid nodes @@ -925,7 +1010,7 @@ def compute_new_node(self, self_parent: Node, cross_parent: Node) -> Node: if not self.is_adversary: assert self_parent.round == self.current_round - new_node = Node( + new_node = self.fill_node_data(Node( peer_id=self.peer_id, round=self_parent.round, is_witness=False, @@ -933,11 +1018,10 @@ def compute_new_node(self, self_parent: Node, cross_parent: Node) -> Node: self_parent_hash=self_parent.node_hash, cross_parent_hash=cross_parent.node_hash, metadata=NodeMetadata() - ) - self.construct_batch_proposal_if_needed(new_node) + )) if not self.verify_node(new_node): - new_node = Node( + new_node = self.fill_node_data(Node( peer_id=self.peer_id, round=self_parent.round + 1, is_witness=True, @@ -945,8 +1029,7 @@ def compute_new_node(self, self_parent: Node, cross_parent: Node) -> Node: self_parent_hash=self_parent.node_hash, cross_parent_hash=cross_parent.node_hash, metadata=NodeMetadata() - ) - self.construct_batch_proposal_if_needed(new_node) + )) if not self.verify_node(new_node): # this new_node is invalid because either its parents are from equivocated peers From d5042fdef85e3c8907e7eec49a48826fd91b26fe Mon Sep 17 00:00:00 2001 From: Galin Chung Nguyen Date: Wed, 26 Feb 2025 17:46:42 +0700 Subject: [PATCH 06/10] update node verification --- dagpool/consensus_client.py | 268 ++++++++++++++++++++++++------------ dagpool/graph.py | 6 +- 2 files changed, 182 insertions(+), 92 deletions(-) 
diff --git a/dagpool/consensus_client.py b/dagpool/consensus_client.py index 898c918..c0450c8 100644 --- a/dagpool/consensus_client.py +++ b/dagpool/consensus_client.py @@ -19,7 +19,7 @@ # the first BEACON_PACE rounds are derived directly from the list of peers # BEACON_PACE should be chosen large enough to make sure peers have enough time to realize that they are the leader of the next rounds BEACON_FIELD_PRIME = 28948022309329048855892746252171976963363056481941560715954676764349967630337 # equal to Pallas base field prime - +GENESIS_BATCH = "" class Utils: @staticmethod def merkle_root_of_transaction_list(txs: list[TransactionId]) -> NodeId: @@ -39,6 +39,9 @@ def merkle_root_of_transaction_list(txs: list[TransactionId]) -> NodeId: class Vote: head_batch_hash: HashValue + def __init__(self, head_batch_hash: HashValue): + self.head_batch_hash = head_batch_hash + class BatchProposal: batch_hash: HashValue prev_batch_hash: HashValue @@ -75,6 +78,11 @@ class NodeMetadata: vote: Vote # BFT vote for the head batch of the network, which determines the order of transactions in the network creator_signature: Signature + def __init__(self, batch_proposal: BatchProposal = None, vote: Vote = Vote(GENESIS_BATCH), creator_signature: Signature = ""): + self.batch_proposal = batch_proposal + self.vote = vote + self.creator_signature = creator_signature + def clone(self): return NodeMetadata(self.batch_proposal.clone() if self.batch_proposal else None, Vote(self.vote.head_batch_hash), self.creator_signature) @@ -103,17 +111,17 @@ def clone(self): def label(self) -> NodeLabel: return f"{self.peer_id}:{self.node_hash}" - def has_filled_node_data(self) -> bool: - return self.has_filled_node_data - - def compute_signature(self, _creator_private_key: PrivateKey) -> Signature: + def is_head_node(self) -> bool: + return self.is_witness and self.metadata.batch_proposal is not None + + def compute_signature(self, _creator_private_key: str = None) -> Signature: # TODO: use real private key 
if self.metadata.batch_proposal: # signature = hash(batch_proposal.batch_hash, node_hash) - return hashlib.sha256([self.metadata.batch_proposal.batch_hash, self.node_hash].encode()).hexdigest()[:8] + return hashlib.sha256('#'.join([self.metadata.batch_proposal.batch_hash, self.node_hash]).encode()).hexdigest()[:8] else: # signature = hash(node_hash) - return hashlib.sha256([self.node_hash].encode()).hexdigest()[:8] + return hashlib.sha256(self.node_hash.encode()).hexdigest()[:8] def update_signature(self): self.metadata.creator_signature = self.compute_signature() @@ -208,15 +216,26 @@ def __init__(self, latency_ms_range=(50, 200), packet_loss_prob=0.1, random_inst # global mempool: simulate a global mempool of all transactions from all clients self.global_mempool = set() - def get_first_leaders(self) -> List[PeerId]: + # TODO: use secure cryptographic randomness source + def get_first_beacon_randomness(self) -> HashValue: all_peers = [p.peer_id for p in self.peers] value_bytes = hashlib.sha256(",".join(all_peers).encode()).digest() value = int.from_bytes(value_bytes, "big") + return value + + def get_first_leaders(self) -> List[PeerId]: # for round 1 -> BEACON_PACE + all_peers = [p.peer_id for p in self.peers] + first_beacon_randomness = self.get_first_beacon_randomness() # choose the first BEACON_PACE leaders for the first BEACON_PACE rounds, using the first BEACON_FIELD_PRIME as modulo somehow - beacon_random = random.Random(value) + beacon_random = random.Random(first_beacon_randomness) leaders = beacon_random.sample(all_peers, BEACON_PACE) return leaders + def get_leader_after_BEACON_PACE_rounds(self, beacon_randomness: HashValue) -> PeerId: + beacon_random = random.Random(beacon_randomness) + all_peers = [p.peer_id for p in self.peers] + return beacon_random.choice(all_peers) + def get_peer_pubkey(self, peer_id: PeerId) -> Pubkey: # TODO: use real pubkey return "DUMMY_PUBKEY" @@ -442,8 +461,6 @@ def __init__(self, peer_id: PeerId, is_adversary: bool, seed: 
int, network: Netw self.accumulated_txs = set() # Track all transactions seen by this peer self.network: NetworkSimulator = network ## adversary-related data - self.equivocated_peers: Set[PeerId] = set() # set of peers that current peer believes they actively create equivocated nodes - self.equivocated_nodes: Set[NodeId] = set() # set of nodes that current peer believes are equivocated self.equivocation_prob = 0.2 if is_adversary else 0.0 self.neighbors: list[PeerId] = [] # Track neighboring peers # if each peer connects to log(N) neighbors, a transaction would takes O(log(N)/log(log(N))) gossip hops to reach the whole network @@ -452,8 +469,7 @@ def __init__(self, peer_id: PeerId, is_adversary: bool, seed: int, network: Netw self.pending_txs: List[Tuple[TransactionId, float]] = [] self.local_graph: Dict[NodeId, Set[NodeId]] = {} # node_hash to set of node_hashes (parent, child) # batch-related data - self.observed_valid_batches: List[Node] = [] - self.observed_votes_from_peers: Dict[PeerId, List[Node]] = {} + self.observed_valid_batches: Dict[HashValue, Node] = {} # map from the hash value of the batch to the node that proposes it def my_nodes(self) -> List[Node]: return self.seen_valid_nodes[self.peer_id] @@ -679,15 +695,16 @@ def get_heaviest_batch_amongst_strict_ancestors(self, node: Node) -> HashValue: Get the heaviest batch in the strict ancestors of the given node using a fork-choice rule based on the Heaviest Observed Subtree (HOS) selection rule """ # 1. 
construct a tree of batch proposals - tree: Dict[HashValue, Set[HashValue]] = {} - for node in self.observed_valid_batches: # TODO: only use the batches from the last finalized branch, so we can filter out the batches from equivocated peers + tree: Dict[HashValue, Set[HashValue]] = {GENESIS_BATCH: set()} + + for node in self.observed_valid_batches.values(): # TODO: only use the batches from the last finalized branch, so we can filter out the batches from equivocated peers tree[node.batch_hash] = set() parent_batch_hash = node.metadata.batch_proposal.prev_batch_hash if parent_batch_hash not in tree: tree[parent_batch_hash] = set() tree[parent_batch_hash].add(node.batch_hash) # 2. collect the latest votes of all peers that the node can see - weight_of_branch: Dict[HashValue, int] = {} + weight_of_branch: Dict[HashValue, int] = {GENESIS_BATCH: 0} non_equivocated_peers = node.non_equivocated_peers for peer_id in node.seen_votes_by_peers: @@ -700,29 +717,60 @@ def traverse(batch_hash: HashValue): for child_batch_hash in tree[batch_hash]: traverse(child_batch_hash) weight_of_branch[batch_hash] += weight_of_branch[child_batch_hash] - root = "" + root = GENESIS_BATCH traverse(root) # 4. 
traverse downwards from the root to leaves using the HOS rule - heaviest_batch = "" + heaviest_batch = GENESIS_BATCH while heaviest_batch in tree: + if len(tree[heaviest_batch]) == 0: + break heaviest_batch = max(tree[heaviest_batch], key=lambda x: weight_of_branch[x]) return heaviest_batch + + def get_beacon_randomness_of_batch(self, batch_hash: HashValue) -> HashValue: + """ + Get the beacon randomness of the given batch + """ + if batch_hash == GENESIS_BATCH: + return self.network.get_first_beacon_randomness() + else: + return self.observed_valid_batches[batch_hash].metadata.batch_proposal.next_beacon_randomness def verify_node_is_head_node(self, node: Node) -> bool: """ - Verify that the given node is a head node + Verify that the given node is a head node, given that the vote is valid """ - pass + return False + # TODO: turn this on + + if not node.is_witness or node.is_genesis(): # this function might be called before the batch proposal is constructed so we only needs to check whether the node is a witness + return False - def construct_batch_proposal_if_needed(self, node: Node): + # find the batch that contains the randomness that determines the leader of the current round + beacon_batch:HashValue = node.metadata.vote.head_batch_hash + steps_back = BEACON_PACE - 1 + (1 if node.metadata.batch_proposal is not None else 0) + + for i in range(steps_back): + if beacon_batch == "": + break + # find parent batch of the current beacon batch + beacon_batch: HashValue = self.observed_valid_batches[beacon_batch].metadata.batch_proposal.prev_batch_hash + + if beacon_batch == "": + # genesis beacon randomness + first_leaders: List[PeerId] = self.network.get_first_leaders() + is_head_node = len(first_leaders) == BEACON_PACE and node.round > 0 and node.round <= BEACON_PACE and node.peer_id == first_leaders[node.round - 1] + return is_head_node + else: + # non-genesis beacon randomness + return node.peer_id == 
self.network.get_leader_after_BEACON_PACE_rounds(self.get_beacon_randomness_of_batch(beacon_batch)) + + def compute_batch_proposal(self, node: Node) -> BatchProposal: """ - Construct a batch proposal for the given node + Compute a batch proposal for the given node """ - if node.peer_id != self.peer_id or not node.is_witness or not self.verify_node_is_head_node(node): - return - # fork-choice rule: select the previous batch using the Heaviest Observed Subtree (HOS) selection rule - heaviest_batch = self.get_heaviest_batch_amongst_strict_ancestors(node) + heaviest_batch: HashValue = self.get_heaviest_batch_amongst_strict_ancestors(node) prev_batch_beacon_randomness: HashValue = self.get_beacon_randomness_of_batch(heaviest_batch) @@ -742,34 +790,72 @@ def construct_batch_proposal_if_needed(self, node: Node): final_fair_ordering=final_fair_ordering, prev_beacon_randomness=prev_batch_beacon_randomness ) - assert batch_proposal.verify_batch_proposal_is_well_formed(node.round, prev_batch_beacon_randomness) + return batch_proposal + + def construct_batch_proposal_if_needed(self, node: Node): + """ + Construct a batch proposal for the given node. This must be called after the vote is constructed for the node. 
+ """ + if node.peer_id != self.peer_id or not self.verify_node_is_head_node(node): + return + + batch_proposal: BatchProposal = self.compute_batch_proposal(node) + node.metadata = NodeMetadata(batch_proposal=batch_proposal) - node.update_signature() # TODO: implement bootstrap node (first node refers to parents in a checkpoint after a node rejoins the network) + def verify_batch_proposal_is_valid(self, node: Node) -> bool: + """ + Verify that the given batch proposal is valid + """ + assert node.metadata.batch_proposal is not None + + correct_batch_proposal: BatchProposal = self.compute_batch_proposal(node) + + return node.metadata.batch_proposal.batch_hash == correct_batch_proposal.batch_hash + + def compute_vote_for_node(self, node: Node) -> Vote: + """ + Compute a vote for the given node + """ + is_head_node = self.verify_node_is_head_node(node) + if is_head_node: + return Vote(head_batch_hash=node.metadata.batch_proposal.batch_hash) # vote for its own head batch + else: + return Vote(head_batch_hash=self.get_heaviest_batch_amongst_strict_ancestors(node)) # vote for the heaviest batch amongst its strict ancestors + def construct_vote_for_node(self, node: Node): """ Construct a vote for the given node """ - if self.verify_node_is_head_node(node): - node.metadata.vote = Vote(head_batch_hash=node.metadata.batch_proposal.batch_hash) - else: - node.metadata.vote = Vote(head_batch_hash=self.get_heaviest_batch_amongst_strict_ancestors(node)) + vote = self.compute_vote_for_node(node) + node.metadata.vote = vote - def fill_node_data(self, node: Node) -> Node: + def verify_vote_is_valid(self, node: Node) -> bool: """ - Fill the data for the given node + Verify that the given vote is valid """ - if not node.has_filled_node_data(): - self.compute_seen_nodes_of_new_node(node) - self.construct_batch_proposal_if_needed(node) - self.construct_vote_for_node(node) - node.has_filled_node_data = True + return node.metadata.vote.head_batch_hash == 
self.compute_vote_for_node(node).head_batch_hash - return node + def fill_node_data(self, node: Node) -> bool: + """ + Fill the data for the given node + """ + if not node.has_filled_node_data: + try: + self.compute_seen_nodes_of_new_node(node) + self.construct_batch_proposal_if_needed(node) + self.construct_vote_for_node(node) # this will throw errors if predecessors of the node are not available + node.has_filled_node_data = True + node.update_signature() + except Exception as e: + print(f"error = {e}") + return False + + return True def has_seen_valid_node(self, node: Node) -> bool: if node.peer_id == self.peer_id: @@ -778,9 +864,6 @@ def has_seen_valid_node(self, node: Node) -> bool: return (node.peer_id in self.seen_valid_nodes) and (node.node_hash in [node.node_hash for node in self.seen_valid_nodes[node.peer_id]]) def get_node_by_hash(self, node_hash: NodeId) -> Optional[Node]: - for node in self.my_nodes(): - if node.node_hash == node_hash: - return node ## find in seen_valid_nodes of other peers for peer_id in self.seen_valid_nodes: for node in self.seen_valid_nodes[peer_id]: @@ -807,25 +890,26 @@ def compute_seen_nodes_of_new_node(self, node: Node): non_equivocated_peers = set() seen_votes_by_peers: Dict[PeerId, HashValue] = {} - ancestry_of_node = self.get_ancestry(node) - self_parent_set: Set[NodeId] = set() + if not node.is_genesis(): + ancestry_of_node = self.get_ancestry(node) + self_parent_set: Set[NodeId] = set() - for cur_node in ancestry_of_node: - if not cur_node.is_genesis(): - if cur_node.self_parent_hash in self_parent_set: - equivocated_peers.add(cur_node.peer_id) - else: - self_parent_set.add(cur_node.self_parent_hash) - non_equivocated_peers.add(cur_node.peer_id) + for cur_node in ancestry_of_node: + if not cur_node.is_genesis(): + if cur_node.self_parent_hash in self_parent_set: + equivocated_peers.add(cur_node.peer_id) + else: + self_parent_set.add(cur_node.self_parent_hash) + non_equivocated_peers.add(cur_node.peer_id) - for cur_node in 
ancestry_of_node: - if cur_node.peer_id in equivocated_peers: - continue + for cur_node in ancestry_of_node: + if cur_node.peer_id in equivocated_peers: + continue - if cur_node.node_hash not in self_parent_set: # there should be at most 1 such node for each peer - seen_votes_by_peers[cur_node.peer_id] = cur_node.metadata.vote.head_batch_hash - - node.seen_nodes.add(cur_node.node_hash) + if cur_node.node_hash not in self_parent_set: # there should be at most 1 such node for each peer + seen_votes_by_peers[cur_node.peer_id] = cur_node.metadata.vote.head_batch_hash + + node.seen_nodes.add(cur_node.node_hash) non_equivocated_peers = non_equivocated_peers.difference(equivocated_peers) # assign the computed values to the node @@ -837,6 +921,8 @@ def compute_seen_nodes_of_new_node(self, node: Node): def verify_node_and_add_to_local_view(self, node: Node = None) -> bool: """Verify a node and its transactions, and add it to the local view""" + self.fill_node_data(node) + if not self.verify_node(node): return False @@ -847,7 +933,9 @@ def verify_node_and_add_to_local_view(self, node: Node = None) -> bool: should_add_node = node.node_hash not in [node.node_hash for node in self.seen_valid_nodes[node.peer_id]] if should_add_node: self.seen_valid_nodes[node.peer_id].append(node) - + if node.metadata.batch_proposal is not None: + self.observed_valid_batches[node.metadata.batch_proposal.batch_hash] = node + for predecessor in self.get_predecessors(node): if predecessor.node_hash not in self.local_graph: self.local_graph[predecessor.node_hash] = set() @@ -943,6 +1031,7 @@ def verify_node(self, node: Node = None) -> bool: return False predecessors = self.get_predecessors(node) + predecessors = self.get_predecessors(node) # must have valid parents if len(predecessors) < 2: @@ -955,12 +1044,17 @@ def verify_node(self, node: Node = None) -> bool: if not is_allowed_to_bypass: return False - # TODO: verify the batch proposal if it exists - # 1. 
check if the current node is a head witness - # 2. verify the batch proposal is constructed correctly + # verify the batch proposal if it exists + if self.verify_node_is_head_node(node): + if not self.verify_batch_proposal_is_valid(node): + return False + else: + assert node.metadata.batch_proposal is None + + # verify the vote + if not self.verify_vote_is_valid(node): + return False - # TODO: verify the vote - except Exception as e: # adversary sending invalid nodes print("error = ", e) @@ -1010,34 +1104,28 @@ def compute_new_node(self, self_parent: Node, cross_parent: Node) -> Node: if not self.is_adversary: assert self_parent.round == self.current_round - new_node = self.fill_node_data(Node( - peer_id=self.peer_id, - round=self_parent.round, - is_witness=False, - newly_seen_txs_list=newly_seen_txs_list, - self_parent_hash=self_parent.node_hash, - cross_parent_hash=cross_parent.node_hash, - metadata=NodeMetadata() - )) + for round_num in range(self_parent.round, self_parent.round + 2): + print(f"Peer {self.peer_id} computing new node for round {round_num}") + new_node = Node( + peer_id=self.peer_id, + round=round_num, + is_witness=False if round_num == self_parent.round else True, + newly_seen_txs_list=newly_seen_txs_list, + self_parent_hash=self_parent.node_hash, + cross_parent_hash=cross_parent.node_hash, + metadata=NodeMetadata() + ) + assert self.fill_node_data(new_node) - if not self.verify_node(new_node): - new_node = self.fill_node_data(Node( - peer_id=self.peer_id, - round=self_parent.round + 1, - is_witness=True, - newly_seen_txs_list=newly_seen_txs_list, - self_parent_hash=self_parent.node_hash, - cross_parent_hash=cross_parent.node_hash, - metadata=NodeMetadata() - )) - - if not self.verify_node(new_node): - # this new_node is invalid because either its parents are from equivocated peers - return None + # found a valid new node + is_node_valid = self.verify_node(new_node) - print(f"Peer {self.peer_id} COMPUTED NEW NODE {new_node.node_hash} from 
{self_parent.node_hash} and {cross_parent.node_hash}") + if is_node_valid: + print(f"Peer {self.peer_id} COMPUTED NEW {'HEAD' if new_node.is_head_node() else "NON-HEAD"} NODE {new_node.node_hash} from {self_parent.node_hash} and {cross_parent.node_hash}") + return new_node - return new_node + # the created new nodes is invalid because either its parents are from equivocated peers + return None async def gossip_push(self): """Sync with another peer through gossip, potentially sending different views""" diff --git a/dagpool/graph.py b/dagpool/graph.py index dfa426b..14a72c1 100644 --- a/dagpool/graph.py +++ b/dagpool/graph.py @@ -79,7 +79,6 @@ def find_strongly_connected_components(self) -> List[Tuple[List[GraphNodeId], Ha def strongconnect(nodeId: GraphNodeId) -> None: assert nodeId not in indices - nonlocal index, indices, lowlinks, stack, on_stack, components # Set the depth index for node indices[nodeId] = index lowlinks[nodeId] = index @@ -177,7 +176,6 @@ def find_hamiltonian_cycle(self, hamiltonian_path_of_scc: List[GraphNodeId]) -> r = -1 found_backward_edge = False while not found_backward_edge and p < len(hamiltonian_path_of_scc): - nonlocal r # check if there is a backward edge from hamiltonian_path[p] to accumulated_hamiltonian_cycle[r] r = 0 # the out-going node of the backward edge while r < len(accumulated_hamiltonian_cycle) and not self.has_edge(hamiltonian_path_of_scc[p], accumulated_hamiltonian_cycle[r]): @@ -186,6 +184,10 @@ def find_hamiltonian_cycle(self, hamiltonian_path_of_scc: List[GraphNodeId]) -> if r < len(accumulated_hamiltonian_cycle): found_backward_edge = True + if not found_backward_edge: + # If no backward edge is found, the graph is not a tournament or not strongly connected + raise ValueError("No Hamiltonian cycle exists for the given path.") + # reorder the accumulated_hamiltonian_cycle: accumulated_hamiltonian_cycle[0 -> r - 1] -> hamiltonian_path_of_scc[j+1 -> p] -> accumulated_hamiltonian_cycle[r -> ...] 
-> hamiltonian_path_of_scc[0] j = p accumulated_hamiltonian_cycle = accumulated_hamiltonian_cycle[0:r] + hamiltonian_path_of_scc[j+1:p+1] + accumulated_hamiltonian_cycle[r:] From ac57b535e46395dca6c0eb74385bdafb117194c8 Mon Sep 17 00:00:00 2001 From: Galin Chung Nguyen Date: Wed, 5 Mar 2025 13:35:03 +0700 Subject: [PATCH 07/10] optimize witness checking --- dagpool/consensus_client.py | 423 +++++++++++++++++++++++++----------- dagpool/graph.py | 4 + 2 files changed, 304 insertions(+), 123 deletions(-) diff --git a/dagpool/consensus_client.py b/dagpool/consensus_client.py index c0450c8..3b4754c 100644 --- a/dagpool/consensus_client.py +++ b/dagpool/consensus_client.py @@ -1,4 +1,5 @@ import asyncio +import sys import random from dataclasses import dataclass from enum import Enum @@ -20,10 +21,13 @@ # BEACON_PACE should be chosen large enough to make sure peers have enough time to realize that they are the leader of the next rounds BEACON_FIELD_PRIME = 28948022309329048855892746252171976963363056481941560715954676764349967630337 # equal to Pallas base field prime GENESIS_BATCH = "" +EMPTY_NODE_HASH = "" +ORDER_FAIRNESS_THRESHOLD = 4/7 + class Utils: @staticmethod def merkle_root_of_transaction_list(txs: list[TransactionId]) -> NodeId: - assert len(txs) > 0 + len(txs) > 0 or (_ for _ in ()).throw(ValueError("txs is empty")) # do the merkle tree construction using a while loop res = [tx for tx in txs] while len(res) > 1: @@ -96,11 +100,14 @@ def __init__(self, peer_id: PeerId, round: int, is_witness: bool, newly_seen_txs # TODO: migrate to Sparse Merkle Tree + proof of SMT transition self.newly_seen_txs_list = newly_seen_txs_list self.node_hash = self.hash_node(peer_id, round, is_witness, self_parent_hash, cross_parent_hash, newly_seen_txs_list) - ## fork-related data + ## fork-related data, must all be None until computed self.equivocated_peers: Set[PeerId] = None # set of peers that current node believes are equivocated, and this node won't SEE (i.e. 
UNSEE) all nodes created by them. Note that this doesn't affect STRONGLY SEEING property of this node. self.non_equivocated_peers: Set[PeerId] = None # set of peers that current node believes are not equivocated, and this node will SEE all nodes created by them. self.seen_nodes: Set[NodeId] = None # set of nodes that current node sees + self.latest_seen_node_by_peers: Dict[PeerId, NodeId] = None # latest seen nodes from each peer that current node can see + self.latest_seen_witness_by_peers: Dict[PeerId, NodeId] = None # set of witnesses of the previous round that current node can see self.seen_votes_by_peers: Dict[PeerId, HashValue] = None # latest votes from each peer that current node can see + # metdata self.metadata: NodeMetadata = metadata self.update_signature() self.has_filled_node_data = False @@ -242,14 +249,14 @@ def get_peer_pubkey(self, peer_id: PeerId) -> Pubkey: def register_peer(self, peer: 'ConsensusPeer'): # peer_id must be unique - assert peer.peer_id not in [p.peer_id for p in self.peers] + (peer.peer_id not in [p.peer_id for p in self.peers]) or (_ for _ in ()).throw(ValueError("peer_id must be unique")) self.peers.append(peer) def get_all_peer_ids(self) -> List[PeerId]: return [p.peer_id for p in self.peers] def unregister_peer(self, peer: 'ConsensusPeer'): - assert peer in self.peers + (peer in self.peers) or (_ for _ in ()).throw(ValueError("peer must be registered")) self.peers.remove(peer) async def register_genesis_nodes(self): @@ -381,7 +388,7 @@ async def disconnect(self, peer_a: PeerId, peer_b: PeerId): async def gossip_send_node_and_ancestry(self, sender: PeerId, receiver: PeerId, node1: Node) -> bool: """Gossip a node and its ancestry to a peer""" - assert sender != receiver + sender != receiver or (_ for _ in ()).throw(ValueError("sender and receiver must be different")) for _ in range(3): if not self.is_connected(sender, receiver): @@ -414,7 +421,7 @@ async def gossip_send_node_and_ancestry(self, sender: PeerId, receiver: PeerId, 
if not node.is_genesis(): self_parent_node = sender_peer.get_node_by_hash(node.self_parent_hash) cross_parent_node = sender_peer.get_node_by_hash(node.cross_parent_hash) - assert self_parent_node is not None and cross_parent_node is not None + (self_parent_node is not None and cross_parent_node is not None) or (_ for _ in ()).throw(ValueError("self_parent_node and cross_parent_node must be non-None")) # use clone() to simulate the process of serializing and deserializing the nodes in internet protocols new_gossiped_list.append(self_parent_node.clone()) @@ -427,11 +434,13 @@ async def gossip_send_node_and_ancestry(self, sender: PeerId, receiver: PeerId, for i in range(len(all_received_nodes)): current_node = all_received_nodes[i] - if not receiver_peer.verify_node_and_add_to_local_view(current_node): - print(f"Peer {receiver_peer.peer_id} rejected node {current_node.node_hash} from {sender_peer.peer_id}") + try: + receiver_peer.verify_node_and_add_to_local_view(current_node) or (_ for _ in ()).throw(ValueError("failed to verify and add node")) + except Exception as e: + print(f"Peer {receiver_peer.peer_id} rejected node {current_node.node_hash} from {sender_peer.peer_id}: {e}") continue - else: - print(f"Peer {receiver_peer.peer_id} accepted node {current_node.node_hash} from {sender_peer.peer_id}") + + print(f"Peer {receiver_peer.peer_id} accepted node {current_node.node_hash} from {sender_peer.peer_id}") return receiver_peer.has_seen_valid_node(node1) @@ -458,6 +467,8 @@ def __init__(self, peer_id: PeerId, is_adversary: bool, seed: int, network: Netw self.random_instance = random.Random(seed) self.current_round = 0 self.seen_valid_nodes: Dict[PeerId, List[Node]] = {} + self.pos_in_seen_valid_nodes: Dict[NodeId, (PeerId, int)] = {} # maps from node_hash to (peer_id, position) of the node in self.seen_valid_nodes[peer_id] + self.accumulated_txs = set() # Track all transactions seen by this peer self.network: NetworkSimulator = network ## adversary-related data @@ 
-493,7 +504,7 @@ def get_predecessors(self, node: Node) -> list[Node]: predecessors = [] self_parent_node = self.get_node_by_hash(node.self_parent_hash) cross_parent_node = self.get_node_by_hash(node.cross_parent_hash) - assert self_parent_node is not None and cross_parent_node is not None + (self_parent_node is not None and cross_parent_node is not None) or (_ for _ in ()).throw(ValueError("self_parent_node and cross_parent_node must be non-None")) predecessors.append(self_parent_node) predecessors.append(cross_parent_node) return predecessors @@ -505,6 +516,7 @@ def get_successors(self, node: Node) -> list[Node]: return sorted(successors, key=lambda x: x.round) def get_ancestry(self, node: Node) -> Set[Node]: + node is not None or (_ for _ in ()).throw(ValueError("node must be non-None to find its ancestry")) ancestry: Dict[NodeId, Node] = {} def dfs_backwards(node: Node): if node.node_hash in ancestry: @@ -516,6 +528,7 @@ def dfs_backwards(node: Node): return set(ancestry.values()) def get_lineage(self, node: Node) -> Set[Node]: + node is not None or (_ for _ in ()).throw(ValueError("node must be non-None to find its lineage")) lineage: Dict[NodeId, Node] = {} def dfs_forwards(node: Node): if node.node_hash in lineage: @@ -683,11 +696,11 @@ def create_genesis_node(self): round=0, is_witness=True, newly_seen_txs_list=[], - self_parent_hash="", - cross_parent_hash="", + self_parent_hash=EMPTY_NODE_HASH, + cross_parent_hash=EMPTY_NODE_HASH, metadata=NodeMetadata() ) - assert self.verify_node_and_add_to_local_view(node) == True + (self.verify_node_and_add_to_local_view(node) == True) or (_ for _ in ()).throw(ValueError("failed to verify and add genesis node")) return node def get_heaviest_batch_amongst_strict_ancestors(self, node: Node) -> HashValue: @@ -709,7 +722,7 @@ def get_heaviest_batch_amongst_strict_ancestors(self, node: Node) -> HashValue: non_equivocated_peers = node.non_equivocated_peers for peer_id in node.seen_votes_by_peers: voted_batch_hash = 
node.seen_votes_by_peers[peer_id] - assert voted_batch_hash in tree + (voted_batch_hash in tree) or (_ for _ in ()).throw(ValueError("voted_batch_hash must be in tree")) weight_of_branch[voted_batch_hash] += 1 # 3. accumulate weights upwards from the leaves to the root def traverse(batch_hash: HashValue): @@ -736,13 +749,20 @@ def get_beacon_randomness_of_batch(self, batch_hash: HashValue) -> HashValue: else: return self.observed_valid_batches[batch_hash].metadata.batch_proposal.next_beacon_randomness + def get_node_of_batch(self, batch_hash: HashValue) -> Node: + """ + Get the node of the given batch + """ + if batch_hash == GENESIS_BATCH: + return None + else: + return self.observed_valid_batches[batch_hash] + def verify_node_is_head_node(self, node: Node) -> bool: """ Verify that the given node is a head node, given that the vote is valid """ - return False - # TODO: turn this on - + return False if not node.is_witness or node.is_genesis(): # this function might be called before the batch proposal is constructed so we only needs to check whether the node is a witness return False @@ -765,6 +785,77 @@ def verify_node_is_head_node(self, node: Node) -> bool: # non-genesis beacon randomness return node.peer_id == self.network.get_leader_after_BEACON_PACE_rounds(self.get_beacon_randomness_of_batch(beacon_batch)) + def get_truncated_cone(self, prev_batch_hash: HashValue, head_node: Node) -> Dict[NodeId, Node]: + """ + Get the truncated cone of the given heaviest batch and node + """ + + node_of_prev_batch: Node = self.get_node_of_batch(prev_batch_hash) + + print(f"found node of prev batch = {node_of_prev_batch}") + # the difference between the ancestry of the current node and the ancestry of the node containing the previous batch + ancestry_of_cur_node = self.get_ancestry(head_node) + truncated_cone_candidates: Set[Node] = ancestry_of_cur_node.difference(set() if node_of_prev_batch is None else self.get_ancestry(node_of_prev_batch)) + truncated_cone: Set[Node] = set() 
+ + for node in truncated_cone_candidates: + if node.node_hash in head_node.seen_nodes: # only picks the one that the head node can SEE + truncated_cone.add(node) + + res: Dict[NodeId, Node] = {node.node_hash: node for node in truncated_cone} + # TODO: we might implement the logic that allows a leader peer to limit the number of rounds of the truncated cone + return res + + def construct_tournament_graph_of_transactions(self, truncated_cone: Dict[NodeId, Node]) -> TournamentGraph: + """ + Construct the tournament graph of the transactions in the truncated cone + """ + all_peers = self.network.get_all_peer_ids() # TODO: handle the case that the list of peers changes + preference_graphs: Dict[PeerId, List[TransactionId]] = {} + for peer_id in all_peers: + preference_graphs[peer_id] = [] + + all_txs: Set[TransactionId] = set() + + visited: Set[NodeId] = set() + def constructPreferenceGraph(node: Node): + visited.add(node.node_hash) + + if node.self_parent_hash in truncated_cone: + constructPreferenceGraph(self.get_node_by_hash(node.self_parent_hash)) + else: + preference_graphs[node.peer_id].extend(node.newly_seen_txs_list) + + for node in truncated_cone.values(): + if not node.node_hash in visited: + constructPreferenceGraph(node) + + for peer_id in all_peers: + all_txs.update(preference_graphs[peer_id]) + + ## validate the preference graphs: only contains unique transactions + for peer_id in all_peers: + (len(preference_graphs[peer_id]) == len(set(preference_graphs[peer_id]))) or (_ for _ in ()).throw(ValueError("preference graphs must contain unique transactions")) + + ## construct the final tournament graph + tournament_graph: TournamentGraph = TournamentGraph() + count_edge_frequency: Dict[Tuple[TransactionId, TransactionId], int] = {} + + for peer_id in all_peers: + for i in range(len(preference_graphs[peer_id])): + for j in range(i + 1, len(preference_graphs[peer_id])): + tx1 = preference_graphs[peer_id][i] + tx2 = preference_graphs[peer_id][j] + if (tx1, tx2) 
not in count_edge_frequency: + count_edge_frequency[(tx1, tx2)] = 0 + count_edge_frequency[(tx1, tx2)] += 1 + + for (tx1, tx2), frequency in count_edge_frequency.items(): + if frequency > ORDER_FAIRNESS_THRESHOLD * len(all_peers): + tournament_graph.add_edge(tx1, tx2) + + return preference_graphs + def compute_batch_proposal(self, node: Node) -> BatchProposal: """ Compute a batch proposal for the given node @@ -772,10 +863,16 @@ def compute_batch_proposal(self, node: Node) -> BatchProposal: # fork-choice rule: select the previous batch using the Heaviest Observed Subtree (HOS) selection rule heaviest_batch: HashValue = self.get_heaviest_batch_amongst_strict_ancestors(node) + print(f"Before computing batch proposal for {node.node_hash}, heaviest batch = {heaviest_batch}") + prev_batch_beacon_randomness: HashValue = self.get_beacon_randomness_of_batch(heaviest_batch) - # construct the truncated cone = the intersection between the lineage of the heaviest batch and the ancestry of the current head node + print(f"Before computing truncated cone for {node.node_hash}, heaviest batch = {heaviest_batch}") + + # construct the truncated cone truncated_cone: Dict[NodeId, Node] = self.get_truncated_cone(heaviest_batch, node) + print(f"After computing truncated cone for {node.node_hash}, truncated cone = {truncated_cone}") + # gather all transactions in the truncated cone into a tournament graph tournament_graph: TournamentGraph = self.construct_tournament_graph_of_transactions(truncated_cone) # calculate the fair ordering of the transactions and construct the batch proposal @@ -790,7 +887,7 @@ def compute_batch_proposal(self, node: Node) -> BatchProposal: final_fair_ordering=final_fair_ordering, prev_beacon_randomness=prev_batch_beacon_randomness ) - assert batch_proposal.verify_batch_proposal_is_well_formed(node.round, prev_batch_beacon_randomness) + (batch_proposal.verify_batch_proposal_is_well_formed(node.round, prev_batch_beacon_randomness)) or (_ for _ in 
()).throw(ValueError("batch proposal is not well-formed")) return batch_proposal @@ -811,7 +908,7 @@ def verify_batch_proposal_is_valid(self, node: Node) -> bool: """ Verify that the given batch proposal is valid """ - assert node.metadata.batch_proposal is not None + (node.metadata.batch_proposal is not None) or (_ for _ in ()).throw(ValueError("batch proposal must be non-None")) correct_batch_proposal: BatchProposal = self.compute_batch_proposal(node) @@ -840,36 +937,35 @@ def verify_vote_is_valid(self, node: Node) -> bool: """ return node.metadata.vote.head_batch_hash == self.compute_vote_for_node(node).head_batch_hash - def fill_node_data(self, node: Node) -> bool: + def fill_node_data(self, node: Node): """ Fill the data for the given node """ if not node.has_filled_node_data: try: + print(f"Peer {self.peer_id} before computing seen nodes of new node {node.node_hash}") self.compute_seen_nodes_of_new_node(node) + print(f"Peer {self.peer_id} after computing seen nodes of new node {node.node_hash} => latest_seen_node_by_peers = {node.latest_seen_node_by_peers}, non_equivocated_peers = {node.non_equivocated_peers}, equivocated_peers = {node.equivocated_peers}, seen_votes_by_peers = {node.seen_votes_by_peers}, latest_seen_witness_by_peers = {node.latest_seen_witness_by_peers}") self.construct_batch_proposal_if_needed(node) + print(f"Peer {self.peer_id} after constructing batch proposal for node {node.node_hash}") self.construct_vote_for_node(node) # this will throw errors if predecessors of the node are not available + print(f"Peer {self.peer_id} after constructing vote for node {node.node_hash}") node.has_filled_node_data = True node.update_signature() except Exception as e: - print(f"error = {e}") - return False - - return True + print(f"Peer {self.peer_id} error when filling data for node {node.node_hash} = {e}") + raise e + print(f"Peer {node.peer_id} finished filling data for node {node.node_hash}, with seen nodes = {node.latest_seen_node_by_peers}") def 
has_seen_valid_node(self, node: Node) -> bool: - if node.peer_id == self.peer_id: - return node in self.my_nodes() - else: - return (node.peer_id in self.seen_valid_nodes) and (node.node_hash in [node.node_hash for node in self.seen_valid_nodes[node.peer_id]]) + return node.node_hash in self.pos_in_seen_valid_nodes def get_node_by_hash(self, node_hash: NodeId) -> Optional[Node]: - ## find in seen_valid_nodes of other peers - for peer_id in self.seen_valid_nodes: - for node in self.seen_valid_nodes[peer_id]: - if node.node_hash == node_hash: - return node - return None + if node_hash in self.pos_in_seen_valid_nodes: + peer_id, pos = self.pos_in_seen_valid_nodes[node_hash] + return self.seen_valid_nodes[peer_id][pos] + else: + return None def record_transaction_receipt(self, tx_id: TransactionId, timestamp: float): """Record when a transaction was received""" @@ -882,41 +978,99 @@ def select_neighbors(self, all_peers: List[PeerId]): num_neighbors = min(self.get_max_num_neighbors(), len(all_peers) - 1) self.neighbors = self.random_instance.sample(potential_neighbors, num_neighbors) - def compute_seen_nodes_of_new_node(self, node: Node): - """Compute the list of seen_nodes of the new node""" - assert node.seen_nodes is None and node.equivocated_peers is None - node.seen_nodes = set() - equivocated_peers = set() - non_equivocated_peers = set() - seen_votes_by_peers: Dict[PeerId, HashValue] = {} - - if not node.is_genesis(): - ancestry_of_node = self.get_ancestry(node) - self_parent_set: Set[NodeId] = set() - - for cur_node in ancestry_of_node: - if not cur_node.is_genesis(): - if cur_node.self_parent_hash in self_parent_set: - equivocated_peers.add(cur_node.peer_id) + def is_valid_descendant_and_self_ancestor(self, descendant_node_hash: NodeId, self_ancestor_node_hash: NodeId) -> bool: + try: + (peer1, pos1) = self.pos_in_seen_valid_nodes[descendant_node_hash] + (peer2, pos2) = self.pos_in_seen_valid_nodes[self_ancestor_node_hash] + return peer1 == peer2 and pos1 > pos2 
+ except: + return False + + def get_self_descendant(self, node_hash_1: NodeId, node_hash_2: NodeId) -> Optional[NodeId]: + try: + (peer1, pos1) = self.pos_in_seen_valid_nodes[node_hash_1] + (peer2, pos2) = self.pos_in_seen_valid_nodes[node_hash_2] + if peer1 == peer2: + return node_hash_1 if pos1 > pos2 else node_hash_2 + else: + return None + except: + return None + + def compute_seen_nodes_of_new_node(self, dest_node: Node): + """Compute seen data of the new node by aggregating from parents and itself.""" + (dest_node.latest_seen_node_by_peers is None or + dest_node.non_equivocated_peers is None or + dest_node.equivocated_peers is None or + dest_node.seen_votes_by_peers is None or + dest_node.latest_seen_witness_by_peers is None) or (_ for _ in ()).throw(ValueError("Node must not have precomputed values.")) + + print(f"Peer {self.peer_id} computing seen _nodes for {dest_node.node_hash} {dest_node}") + latest_seen_node_by_peers: Dict[PeerId, NodeId] = {} + non_equivocated_peers: Set[PeerId] = set() + equivocated_peers: Set[PeerId] = set() + seen_votes_by_peers: Dict[PeerId, HashValue] = {} + latest_seen_witness_by_peers: Dict[PeerId, NodeId] = {} + + aggregated_references = [dest_node.self_parent_hash, dest_node.cross_parent_hash] + aggregated_references = [ref for ref in aggregated_references if ref is not EMPTY_NODE_HASH] + + def aggregate_latest_seen_by_peers( + dest_dict: Dict[PeerId, NodeId], + parent_dict: Dict[PeerId, NodeId], + equivocated_peers: Set[PeerId], + ): + for peer_id, seen_node_hash in parent_dict.items(): + seen_node = self.get_node_by_hash(seen_node_hash) + seen_node is not None or (_ for _ in ()).throw(ValueError("seen node must not be None")) + + existing_seen_node_hash = dest_dict.get(peer_id) + if existing_seen_node_hash is not None: + self_descendant_node_id = self.get_self_descendant(existing_seen_node_hash, seen_node_hash) + if self_descendant_node_id is not None: + dest_dict[peer_id] = self_descendant_node_id + else: + # two nodes 
form forks + equivocated_peers.add(peer_id) else: - self_parent_set.add(cur_node.self_parent_hash) - non_equivocated_peers.add(cur_node.peer_id) + dest_dict[peer_id] = seen_node_hash + + # Aggregate data from parents + for parent_hash in aggregated_references: + if parent_hash == EMPTY_NODE_HASH: continue + parent_node = self.get_node_by_hash(parent_hash) + latest_seen_by_peers_of_parent = parent_node.latest_seen_node_by_peers + latest_seen_witness_by_peers_of_parent = parent_node.latest_seen_witness_by_peers + + aggregate_latest_seen_by_peers( + dest_dict=latest_seen_node_by_peers, + parent_dict=latest_seen_by_peers_of_parent, + equivocated_peers=equivocated_peers + ) + aggregate_latest_seen_by_peers( + dest_dict=latest_seen_witness_by_peers, + parent_dict=latest_seen_witness_by_peers_of_parent, + equivocated_peers=equivocated_peers + ) + + print(f"Peer {self.peer_id} is merging {parent_node} with {parent_node.equivocated_peers}") + equivocated_peers.update(parent_node.equivocated_peers) + seen_votes_by_peers.update(parent_node.seen_votes_by_peers) # the vote inside the dest_node would be updated later when verifying the vote of the dest_node - for cur_node in ancestry_of_node: - if cur_node.peer_id in equivocated_peers: - continue + for peer_id in equivocated_peers: + latest_seen_node_by_peers.pop(peer_id) + latest_seen_witness_by_peers.pop(peer_id) + + for peer_id in latest_seen_node_by_peers.keys(): + non_equivocated_peers.add(peer_id) + # store the computed values - if cur_node.node_hash not in self_parent_set: # there should be at most 1 such node for each peer - seen_votes_by_peers[cur_node.peer_id] = cur_node.metadata.vote.head_batch_hash - - node.seen_nodes.add(cur_node.node_hash) - - non_equivocated_peers = non_equivocated_peers.difference(equivocated_peers) - # assign the computed values to the node - node.equivocated_peers = equivocated_peers - node.non_equivocated_peers = non_equivocated_peers - node.seen_votes_by_peers = seen_votes_by_peers - return + 
dest_node.non_equivocated_peers = non_equivocated_peers + dest_node.seen_votes_by_peers = seen_votes_by_peers + dest_node.equivocated_peers = equivocated_peers + # we only updates a node sees itself when it is verified and added to the local view in self.verify_node_and_add_to_local_view() + dest_node.latest_seen_node_by_peers = latest_seen_node_by_peers + dest_node.latest_seen_witness_by_peers = latest_seen_witness_by_peers def verify_node_and_add_to_local_view(self, node: Node = None) -> bool: """Verify a node and its transactions, and add it to the local view""" @@ -926,13 +1080,19 @@ def verify_node_and_add_to_local_view(self, node: Node = None) -> bool: if not self.verify_node(node): return False - # add to seen_valid_nodes - if node.peer_id not in self.seen_valid_nodes: - self.seen_valid_nodes[node.peer_id] = [] + should_add_node = not self.has_seen_valid_node(node) - should_add_node = node.node_hash not in [node.node_hash for node in self.seen_valid_nodes[node.peer_id]] if should_add_node: + if node.peer_id not in self.seen_valid_nodes: + self.seen_valid_nodes[node.peer_id] = [] self.seen_valid_nodes[node.peer_id].append(node) + self.pos_in_seen_valid_nodes[node.node_hash] = (node.peer_id, len(self.seen_valid_nodes[node.peer_id]) - 1) + + # make a node sees itself so its descendants can use these accumulated values + node.latest_seen_node_by_peers[node.peer_id] = node.node_hash + if node.is_witness: + node.latest_seen_witness_by_peers[node.peer_id] = node.node_hash + if node.metadata.batch_proposal is not None: self.observed_valid_batches[node.metadata.batch_proposal.batch_hash] = node @@ -955,61 +1115,77 @@ def get_strongly_seen_valid_witnesses(self, dest_node: Node, r: int) -> list["No ## if some witnesses are descendants of equivocated nodes, they are ignored completely ## NOTE: we already make sure the ancestry of dest_node is verified - strongly_sees_threshold = 2/3 * len(self.network.peers) - strongly_seen_witnesses: list["Node"] = [] - 
ancestry_of_dest_node = self.get_ancestry(dest_node) - # sorted deterministically - witnesses_in_round_r = sorted([node for node in ancestry_of_dest_node if node.is_witness and node.round == r], key=lambda x: (x.peer_id, x.node_hash)) - - # itearate through all witnesses in round r and check if the dest_node can strongly see them - for witness in witnesses_in_round_r: - lineage_of_witness = self.get_lineage(witness) - crossed_peers = set() - can_conclude_strongly_seen = False - for mid_node in lineage_of_witness: - if (mid_node.peer_id in crossed_peers) or (mid_node not in ancestry_of_dest_node): - continue # can include the witness itself if satisfies the condition - - should_count_as_valid_path_from_witness_to_dest_node = (mid_node.node_hash == dest_node.node_hash) or (witness.node_hash in mid_node.seen_nodes) - if should_count_as_valid_path_from_witness_to_dest_node: - crossed_peers.add(mid_node.peer_id) - can_conclude_strongly_seen = len(crossed_peers) > strongly_sees_threshold - - if can_conclude_strongly_seen: - break + N = len(self.network.peers) + + latest_seen_witness_by_peers: Dict[PeerId, NodeId] = dest_node.latest_seen_witness_by_peers + seen_witnesses_in_round_r = [self.get_node_by_hash(node_hash) for node_hash in latest_seen_witness_by_peers.values() if node_hash and self.get_node_by_hash(node_hash).round == r] + seen_witnesses_in_round_r = [node for node in seen_witnesses_in_round_r if node is not None] + + # keep only the strongly seen ones + latest_seen_node_by_peers: Dict[PeerId, NodeId] = dest_node.latest_seen_node_by_peers + seen_nodes_in_round_r = [self.get_node_by_hash(node_hash) for node_hash in latest_seen_node_by_peers.values() if node_hash and self.get_node_by_hash(node_hash).round == r] + seen_nodes_in_round_r = [node for node in seen_nodes_in_round_r if node is not None] + + # compute the number of times each witness is seen by each possible (latest) mid node of peers that are seen by the dest_node + count_seens: Dict[NodeId, int] = {} + 
+ # O(N^2) where N is the number of peers + for mid_node in seen_nodes_in_round_r: + for witness in seen_witnesses_in_round_r: + # check if mid_node can strongly see witness + witness_peer_id = witness.peer_id + if witness_peer_id not in mid_node.latest_seen_node_by_peers: + continue + + latest_seen_node_of_witness_peer_id: NodeId = mid_node.latest_seen_node_by_peers[witness_peer_id] + if self.is_valid_descendant_and_self_ancestor(latest_seen_node_of_witness_peer_id, mid_node.node_hash): + count_seens[witness.node_hash] = count_seens.get(witness.node_hash, 0) + 1 - if can_conclude_strongly_seen: - if witness.peer_id not in [node.peer_id for node in strongly_seen_witnesses]: - # at most 1 witness per peer is counted - strongly_seen_witnesses.append(witness) + # the dest_node must see the witness of its own peer + found_self_peer_witness = [witness for witness in seen_witnesses_in_round_r if witness.peer_id == self.peer_id and count_seens.get(witness.node_hash, 0) > 2 * N / 3] + if len(found_self_peer_witness) == 0: + return [] + # keep only the witnesses that are seen by > 2/3 of the mid nodes + strongly_seen_witnesses: list["Node"] = [] + for witness in seen_witnesses_in_round_r: + if count_seens[witness.node_hash] > 2 * N / 3: + strongly_seen_witnesses.append(witness) + return strongly_seen_witnesses - def check_round_number_of_non_genesis_node_with_valid_parents(self, node: Node) -> bool: + def check_round_number_of_non_genesis_node_with_valid_parents(self, dest_node: Node) -> bool: """ if a node is of round r: - it must not strongly sees > 2N/3 of witnesses of round r - if its self parent is of round r, it is valid. 
if its self parent is of round r-1, it must strongly sees > 2N/3 of witnesses of round r-1 """ N = len(self.network.peers) - r = node.round + r = dest_node.round - # the node must not strongly sees > 2/3 of witnesses of round r - strongly_seen_witnesses_in_round_r = self.get_strongly_seen_valid_witnesses(node, r) - - if len(strongly_seen_witnesses_in_round_r) > 2 * N / 3: + self_parent_node = self.get_node_by_hash(dest_node.self_parent_hash) + (self_parent_node is not None) or (_ for _ in ()).throw(ValueError("self_parent_node must be non-None")) + if self_parent_node.round < r - 1: return False - # check non-witness node case - self_parent_node = self.get_node_by_hash(node.self_parent_hash) - assert self_parent_node is not None - if self_parent_node.round == r: - return True + if self_parent_node.round == r - 1: + # the dest_node is a witness of round r so it must strongly sees > 2/3 of witnesses of round r-1 + strongly_seen_witnesses_in_round_r_minus_1 = self.get_strongly_seen_valid_witnesses(dest_node, r-1) + is_witness_of_round_r = len(strongly_seen_witnesses_in_round_r_minus_1) > 2 * N / 3 + + print(f"Peer {self.peer_id} checking if {dest_node} is a witness of round {r}: it strongly sees {strongly_seen_witnesses_in_round_r_minus_1}") + if dest_node.node_hash == "03b372de": + sys.exit(0) + if not is_witness_of_round_r: + return False - # check witness node case - strongly_seen_witnesses_in_round_r_minus_1 = self.get_strongly_seen_valid_witnesses(node, r-1) + # the dest_node must not strongly sees > 2/3 of witnesses of round r + strongly_seen_witnesses_in_round_r = self.get_strongly_seen_valid_witnesses(dest_node, r) - return len(strongly_seen_witnesses_in_round_r_minus_1) > 2 * N / 3 + if len(strongly_seen_witnesses_in_round_r) > 2 * N / 3: + return False + + return True def verify_node(self, node: Node = None) -> bool: """Verify a node and its transactions @@ -1049,7 +1225,7 @@ def verify_node(self, node: Node = None) -> bool: if not 
self.verify_batch_proposal_is_valid(node): return False else: - assert node.metadata.batch_proposal is None + (node.metadata.batch_proposal is None) or (_ for _ in ()).throw(ValueError("batch proposal must be None")) # verify the vote if not self.verify_vote_is_valid(node): @@ -1085,9 +1261,9 @@ def compute_new_node(self, self_parent: Node, cross_parent: Node) -> Node: """A view-only method that computes a new node based on the self parent and cross parent @return: the newly created node, if there is no new txs, return None """ - assert self_parent is not None and cross_parent is not None + (self_parent is not None and cross_parent is not None) or (_ for _ in ()).throw(ValueError("self_parent and cross_parent must be non-None")) if not self.is_adversary: - assert self_parent == self.get_my_last_node() + (self_parent == self.get_my_last_node()) or (_ for _ in ()).throw(ValueError("self_parent must be the last node of the current peer")) # the newly seen list of txs in the new node must be not empty # TODO: sort this list by timestamp of receipt of the transactions @@ -1102,9 +1278,10 @@ def compute_new_node(self, self_parent: Node, cross_parent: Node) -> Node: base_hash = f"{self.peer_id}{str(round_num).zfill(3)}" if not self.is_adversary: - assert self_parent.round == self.current_round + (self_parent.round == self.current_round) or (_ for _ in ()).throw(ValueError("self_parent.round must be the current round")) - for round_num in range(self_parent.round, self_parent.round + 2): + print(f"Peer {self.peer_id} start computing new node:") + for round_num in range(self_parent.round + 1, self_parent.round - 1, -1): print(f"Peer {self.peer_id} computing new node for round {round_num}") new_node = Node( peer_id=self.peer_id, @@ -1115,7 +1292,7 @@ def compute_new_node(self, self_parent: Node, cross_parent: Node) -> Node: cross_parent_hash=cross_parent.node_hash, metadata=NodeMetadata() ) - assert self.fill_node_data(new_node) + self.fill_node_data(new_node) # found a valid 
new node is_node_valid = self.verify_node(new_node) @@ -1171,7 +1348,7 @@ async def gossip_push(self): pass # can't construct a valid node from the current tuple of self_parent and cross_parent - assert len(new_nodes) <= num_nodes_to_create + (len(new_nodes) <= num_nodes_to_create) or (_ for _ in ()).throw(ValueError("number of new nodes must be less than or equal to num_nodes_to_create")) if len(new_nodes) > 0: print(f"Peer {self.peer_id}, is_adversary = {self.is_adversary}, computed {len(new_nodes)} nodes, its neighbors = {self.neighbors}, its equivocated peers = {self.my_nodes()[-1].equivocated_peers}, seen_peers = {[peer_id for peer_id in self.seen_valid_nodes]}") else: @@ -1179,7 +1356,7 @@ async def gossip_push(self): return for new_node in new_nodes: - assert self.verify_node_and_add_to_local_view(new_node) + (self.verify_node_and_add_to_local_view(new_node)) or (_ for _ in ()).throw(ValueError("failed to verify and add new node")) # start gossiping to neighboring peers for i in range(len(self.neighbors)): @@ -1205,7 +1382,7 @@ async def main(): ) # Create peers - num_peers = 7 # next threshold for count_adversary = 2 is N = 7 + num_peers = 4 # next threshold for count_adversary = 2 is N = 7 count_adversary = 0 for i in range(num_peers): is_adversary = (count_adversary + 1) < 1 * num_peers / 3 and network.random_instance.random() < 0.5 @@ -1235,7 +1412,7 @@ async def main(): # Main consensus loop i = 0 - while True and i < 50: + while True and i < 40: i += 1 if i % 100 == 0: print(f"{i}th iteration") @@ -1288,7 +1465,7 @@ def validate_consistency(): else: global_info_of_node[node_id] = node_description - assert found_node_conflict == False + (found_node_conflict == False) or (_ for _ in ()).throw(ValueError("found node conflict in node info between peers")) print("SUCCESS: There is no conflict in node info between peers") diff --git a/dagpool/graph.py b/dagpool/graph.py index 14a72c1..a0207dd 100644 --- a/dagpool/graph.py +++ b/dagpool/graph.py @@ -25,6 
+25,9 @@ def __init__(self): self.node_to_scc_hash: Dict[GraphNodeId, HashValue] = None self.hash_to_scc_nodes: Dict[HashValue, List[GraphNodeId]] = None + def count_nodes(self) -> int: + return len(self.nodes) + def reset_graph_properties(self): self.is_tournament_graph = None self.connected_components = None @@ -148,6 +151,7 @@ def is_tournament_graph(self) -> bool: return super().is_tournament_graph() def find_hamiltonian_path(self, scc: List[GraphNodeId]) -> List[GraphNodeId]: + self.assert_is_tournament_graph() # must be a strongly connected component self.assert_is_strongly_connected_component(scc) From b24213748edc054bfc1c739d2e0a007c3f84e22f Mon Sep 17 00:00:00 2001 From: Galin Chung Nguyen Date: Thu, 6 Mar 2025 20:20:17 +0700 Subject: [PATCH 08/10] fix the witness checking --- dagpool/consensus_client.py | 238 ++++++++++++++++++++---------------- 1 file changed, 136 insertions(+), 102 deletions(-) diff --git a/dagpool/consensus_client.py b/dagpool/consensus_client.py index 3b4754c..4b90ad8 100644 --- a/dagpool/consensus_client.py +++ b/dagpool/consensus_client.py @@ -16,6 +16,14 @@ from math import log2, ceil # Added log2 and ceil imports from graph import TournamentGraph +import builtins +original_print = builtins.print + +SHOULD_PRINT = True +def print(*args, **kwargs): + if SHOULD_PRINT: + original_print(*args, **kwargs) + BEACON_PACE = 4 # the first BEACON_PACE rounds are derived directly from the list of peers # BEACON_PACE should be chosen large enough to make sure peers have enough time to realize that they are the leader of the next rounds @@ -374,6 +382,7 @@ async def connect(self, peer_a: PeerId, peer_b: PeerId): return False def is_connected(self, peer_a: PeerId, peer_b: PeerId) -> bool: + return True conn_key = (peer_a, peer_b) conn_key_rev = (peer_b, peer_a) return self.connections.get(conn_key, ConnectionState.CLOSED) == ConnectionState.ESTABLISHED or self.connections.get(conn_key_rev, ConnectionState.CLOSED) == ConnectionState.ESTABLISHED @@ 
-402,40 +411,30 @@ async def gossip_send_node_and_ancestry(self, sender: PeerId, receiver: PeerId, sender_peer = [p for p in self.peers if p.peer_id == sender][0] receiver_peer = [p for p in self.peers if p.peer_id == receiver][0] + + if receiver_peer.has_seen_valid_node(node1): + return True - receiver_has_all_needed_ancestors = False - all_received_nodes = [] - current_gossiped_list = [node1] # we don't send the ancestry of the last node of the sender because node1 might be an equivocated node - while not receiver_has_all_needed_ancestors: - # receiver receives nodes from sender - receiver_has_all_needed_ancestors = True - - new_gossiped_list = [] - for node in current_gossiped_list: - if receiver_peer.has_seen_valid_node(node): - continue - else: - all_received_nodes.append(node) - receiver_has_all_needed_ancestors = False # not stop + all_received_nodes: List[Node] = [] + traced: Dict[NodeId, bool] = {} - if not node.is_genesis(): - self_parent_node = sender_peer.get_node_by_hash(node.self_parent_hash) - cross_parent_node = sender_peer.get_node_by_hash(node.cross_parent_hash) - (self_parent_node is not None and cross_parent_node is not None) or (_ for _ in ()).throw(ValueError("self_parent_node and cross_parent_node must be non-None")) + def trace(node: Node): + if node.node_hash in traced or receiver_peer.has_seen_valid_node(node): + return + traced[node.node_hash] = True - # use clone() to simulate the process of serializing and deserializing the nodes in internet protocols - new_gossiped_list.append(self_parent_node.clone()) - new_gossiped_list.append(cross_parent_node.clone()) - - current_gossiped_list = new_gossiped_list + predecessors = sender_peer.get_predecessors(node) + for predecessor in predecessors: + trace(predecessor) + + all_received_nodes.append(node) + + trace(node1) - ### reverse all_received_nodes because the early nodes in the ancestry are in the right of the list, but we want them to be in the left - all_received_nodes = 
all_received_nodes[::-1] - for i in range(len(all_received_nodes)): current_node = all_received_nodes[i] try: - receiver_peer.verify_node_and_add_to_local_view(current_node) or (_ for _ in ()).throw(ValueError("failed to verify and add node")) + receiver_peer.verify_node_and_add_to_local_view(current_node.clone()) or (_ for _ in ()).throw(ValueError("failed to verify and add node")) except Exception as e: print(f"Peer {receiver_peer.peer_id} rejected node {current_node.node_hash} from {sender_peer.peer_id}: {e}") continue @@ -448,8 +447,9 @@ async def _delay(self): """Simulate network latency""" delay = 0 # TODO: turn on delay (a value in the self.latency_range range) for full simulation - await asyncio.sleep(delay) - + # await asyncio.sleep(delay) + pass + def find_node_by_hash(self, node_hash: str) -> Optional[Node]: """Find a node by its hash across all peers""" if not node_hash: @@ -472,7 +472,8 @@ def __init__(self, peer_id: PeerId, is_adversary: bool, seed: int, network: Netw self.accumulated_txs = set() # Track all transactions seen by this peer self.network: NetworkSimulator = network ## adversary-related data - self.equivocation_prob = 0.2 if is_adversary else 0.0 + self.equivocation_prob = 0.3 if is_adversary else 0.0 + self.equivocated_peers: Set[PeerId] = set() self.neighbors: list[PeerId] = [] # Track neighboring peers # if each peer connects to log(N) neighbors, a transaction would takes O(log(N)/log(log(N))) gossip hops to reach the whole network # for N = 10^6, it would be 7 hops @@ -616,7 +617,7 @@ def adjust_position(node, visited): lines_of_peers[predecessor.peer_id].append(count_lines) line_id = count_lines - node_positions[node.node_hash] = (max(node_positions[predecessor.node_hash][0] + 3, node_positions[node.node_hash][0]), line_id) + node_positions[node.node_hash] = (max(node_positions[predecessor.node_hash][0] + 10, node_positions[node.node_hash][0]), line_id) for node in seen_valid_nodes: if not node.node_hash in visited: @@ -648,13 
+649,13 @@ def adjust_position(node, visited): G, node_positions, nodelist=[node.node_hash], - node_size=1000, + node_size=500, node_shape='s', node_color=node_color, label=[node.peer_id] ) - nx.draw_networkx_edges(G, node_positions, edge_color='black', arrows=True, arrowsize=20, width=1, node_size=1000) # Ensure arrows are properly sized relative to nodes + nx.draw_networkx_edges(G, node_positions, edge_color='black', arrows=True, arrowsize=20, width=1, node_size=500) # Ensure arrows are properly sized relative to nodes # Draw node labels with a border for node in seen_valid_nodes: @@ -792,7 +793,6 @@ def get_truncated_cone(self, prev_batch_hash: HashValue, head_node: Node) -> Dic node_of_prev_batch: Node = self.get_node_of_batch(prev_batch_hash) - print(f"found node of prev batch = {node_of_prev_batch}") # the difference between the ancestry of the current node and the ancestry of the node containing the previous batch ancestry_of_cur_node = self.get_ancestry(head_node) truncated_cone_candidates: Set[Node] = ancestry_of_cur_node.difference(set() if node_of_prev_batch is None else self.get_ancestry(node_of_prev_batch)) @@ -863,15 +863,10 @@ def compute_batch_proposal(self, node: Node) -> BatchProposal: # fork-choice rule: select the previous batch using the Heaviest Observed Subtree (HOS) selection rule heaviest_batch: HashValue = self.get_heaviest_batch_amongst_strict_ancestors(node) - print(f"Before computing batch proposal for {node.node_hash}, heaviest batch = {heaviest_batch}") - prev_batch_beacon_randomness: HashValue = self.get_beacon_randomness_of_batch(heaviest_batch) - print(f"Before computing truncated cone for {node.node_hash}, heaviest batch = {heaviest_batch}") - # construct the truncated cone truncated_cone: Dict[NodeId, Node] = self.get_truncated_cone(heaviest_batch, node) - print(f"After computing truncated cone for {node.node_hash}, truncated cone = {truncated_cone}") # gather all transactions in the truncated cone into a tournament graph 
tournament_graph: TournamentGraph = self.construct_tournament_graph_of_transactions(truncated_cone) @@ -943,20 +938,16 @@ def fill_node_data(self, node: Node): """ if not node.has_filled_node_data: try: - print(f"Peer {self.peer_id} before computing seen nodes of new node {node.node_hash}") self.compute_seen_nodes_of_new_node(node) - print(f"Peer {self.peer_id} after computing seen nodes of new node {node.node_hash} => latest_seen_node_by_peers = {node.latest_seen_node_by_peers}, non_equivocated_peers = {node.non_equivocated_peers}, equivocated_peers = {node.equivocated_peers}, seen_votes_by_peers = {node.seen_votes_by_peers}, latest_seen_witness_by_peers = {node.latest_seen_witness_by_peers}") self.construct_batch_proposal_if_needed(node) - print(f"Peer {self.peer_id} after constructing batch proposal for node {node.node_hash}") self.construct_vote_for_node(node) # this will throw errors if predecessors of the node are not available - print(f"Peer {self.peer_id} after constructing vote for node {node.node_hash}") node.has_filled_node_data = True node.update_signature() except Exception as e: print(f"Peer {self.peer_id} error when filling data for node {node.node_hash} = {e}") raise e - print(f"Peer {node.peer_id} finished filling data for node {node.node_hash}, with seen nodes = {node.latest_seen_node_by_peers}") + print(f"Peer {self.peer_id} finished filling data for node {node.node_hash}, with seen nodes = {node.latest_seen_node_by_peers}") def has_seen_valid_node(self, node: Node) -> bool: return node.node_hash in self.pos_in_seen_valid_nodes @@ -976,7 +967,7 @@ def select_neighbors(self, all_peers: List[PeerId]): """Randomly select neighbors from available peers""" potential_neighbors = [p for p in all_peers if p != self.peer_id] num_neighbors = min(self.get_max_num_neighbors(), len(all_peers) - 1) - self.neighbors = self.random_instance.sample(potential_neighbors, num_neighbors) + self.neighbors = [peer_id for peer_id in 
self.random_instance.sample(potential_neighbors, num_neighbors) if peer_id not in self.equivocated_peers] def is_valid_descendant_and_self_ancestor(self, descendant_node_hash: NodeId, self_ancestor_node_hash: NodeId) -> bool: try: @@ -1070,46 +1061,66 @@ def aggregate_latest_seen_by_peers( dest_node.equivocated_peers = equivocated_peers # we only updates a node sees itself when it is verified and added to the local view in self.verify_node_and_add_to_local_view() dest_node.latest_seen_node_by_peers = latest_seen_node_by_peers + dest_node.latest_seen_witness_by_peers = latest_seen_witness_by_peers + ### + self.equivocated_peers.update(dest_node.equivocated_peers) + def verify_node_and_add_to_local_view(self, node: Node = None) -> bool: """Verify a node and its transactions, and add it to the local view""" + if self.has_seen_valid_node(node): + return True + self.fill_node_data(node) if not self.verify_node(node): return False - should_add_node = not self.has_seen_valid_node(node) - - if should_add_node: - if node.peer_id not in self.seen_valid_nodes: - self.seen_valid_nodes[node.peer_id] = [] - self.seen_valid_nodes[node.peer_id].append(node) - self.pos_in_seen_valid_nodes[node.node_hash] = (node.peer_id, len(self.seen_valid_nodes[node.peer_id]) - 1) + if node.peer_id not in self.seen_valid_nodes: + self.seen_valid_nodes[node.peer_id] = [] + self.seen_valid_nodes[node.peer_id].append(node) + self.pos_in_seen_valid_nodes[node.node_hash] = (node.peer_id, len(self.seen_valid_nodes[node.peer_id]) - 1) - # make a node sees itself so its descendants can use these accumulated values - node.latest_seen_node_by_peers[node.peer_id] = node.node_hash - if node.is_witness: - node.latest_seen_witness_by_peers[node.peer_id] = node.node_hash + # make a node sees itself so its descendants can use these accumulated values + node.latest_seen_node_by_peers[node.peer_id] = node.node_hash + if node.is_witness: + node.latest_seen_witness_by_peers[node.peer_id] = node.node_hash - if 
node.metadata.batch_proposal is not None: - self.observed_valid_batches[node.metadata.batch_proposal.batch_hash] = node - - for predecessor in self.get_predecessors(node): - if predecessor.node_hash not in self.local_graph: - self.local_graph[predecessor.node_hash] = set() - self.local_graph[predecessor.node_hash].add(node.node_hash) - - # do the cleanup if the node is created by the current peer - if node.peer_id == self.peer_id: - self.pending_txs.clear() # because all txs in the pending_txs are now in the new node - self.current_round = node.round # this makes the current round of the peer = the round of the last node in the list of its nodes + if node.metadata.batch_proposal is not None: + self.observed_valid_batches[node.metadata.batch_proposal.batch_hash] = node + + for predecessor in self.get_predecessors(node): + if predecessor.node_hash not in self.local_graph: + self.local_graph[predecessor.node_hash] = set() + self.local_graph[predecessor.node_hash].add(node.node_hash) + + # do the cleanup if the node is created by the current peer + if node.peer_id == self.peer_id: + self.pending_txs.clear() # because all txs in the pending_txs are now in the new node + self.current_round = node.round # this makes the current round of the peer = the round of the last node in the list of its nodes print(f"Peer {self.peer_id} added node {node.node_hash} to its local view => new round = {self.current_round}") return True + def find_prev_witness_at_round(self, cur_witness: Node, r: int) -> Optional[Node]: + while cur_witness.round > r: + parent_node = self.get_node_by_hash(cur_witness.self_parent_hash) + if parent_node is None: + return None + + # jump to prev witness + prev_witness = self.get_node_by_hash(parent_node.latest_seen_witness_by_peers[cur_witness.peer_id]) + + if prev_witness is None: + return None + + cur_witness = prev_witness + + return cur_witness + def get_strongly_seen_valid_witnesses(self, dest_node: Node, r: int) -> list["Node"]: ## check if this witness 
strongly sees > 2/3 of witnesses of r ## if some witnesses are descendants of equivocated nodes, they are ignored completely @@ -1118,38 +1129,51 @@ def get_strongly_seen_valid_witnesses(self, dest_node: Node, r: int) -> list["No N = len(self.network.peers) latest_seen_witness_by_peers: Dict[PeerId, NodeId] = dest_node.latest_seen_witness_by_peers - seen_witnesses_in_round_r = [self.get_node_by_hash(node_hash) for node_hash in latest_seen_witness_by_peers.values() if node_hash and self.get_node_by_hash(node_hash).round == r] - seen_witnesses_in_round_r = [node for node in seen_witnesses_in_round_r if node is not None] - + seen_witnesses_in_round_gte_r = [self.get_node_by_hash(node_hash) for node_hash in latest_seen_witness_by_peers.values() if node_hash and self.get_node_by_hash(node_hash).round >= r] + seen_witnesses_in_round_r = [self.find_prev_witness_at_round(node, r) for node in seen_witnesses_in_round_gte_r if node is not None] + # keep only the strongly seen ones latest_seen_node_by_peers: Dict[PeerId, NodeId] = dest_node.latest_seen_node_by_peers - seen_nodes_in_round_r = [self.get_node_by_hash(node_hash) for node_hash in latest_seen_node_by_peers.values() if node_hash and self.get_node_by_hash(node_hash).round == r] - seen_nodes_in_round_r = [node for node in seen_nodes_in_round_r if node is not None] + + seen_nodes_in_round_gte_r = [self.get_node_by_hash(node_hash) for node_hash in latest_seen_node_by_peers.values() if node_hash and self.get_node_by_hash(node_hash).round >= r] + seen_nodes_in_round_gte_r = [node for node in seen_nodes_in_round_gte_r if node is not None] # compute the number of times each witness is seen by each possible (latest) mid node of peers that are seen by the dest_node count_seens: Dict[NodeId, int] = {} - # O(N^2) where N is the number of peers - for mid_node in seen_nodes_in_round_r: + for mid_node in seen_nodes_in_round_gte_r: for witness in seen_witnesses_in_round_r: # check if mid_node can strongly see witness witness_peer_id 
= witness.peer_id + + # if dest_node.node_hash == "03b372de": + # if witness.node_hash == "858a63d3": + # print(f"Peer {self.peer_id} checking if {mid_node} can strongly see {witness}, and seen nodes of {mid_node} = {mid_node.latest_seen_node_by_peers}") + if witness_peer_id not in mid_node.latest_seen_node_by_peers: continue latest_seen_node_of_witness_peer_id: NodeId = mid_node.latest_seen_node_by_peers[witness_peer_id] - if self.is_valid_descendant_and_self_ancestor(latest_seen_node_of_witness_peer_id, mid_node.node_hash): - count_seens[witness.node_hash] = count_seens.get(witness.node_hash, 0) + 1 + # if dest_node.node_hash == "03b372de": + # if witness.node_hash == "858a63d3": + # print(f"Let's validate {latest_seen_node_of_witness_peer_id} is a descendant of {witness.node_hash} => {self.is_valid_descendant_and_self_ancestor(latest_seen_node_of_witness_peer_id, witness.node_hash)}") + if self.is_valid_descendant_and_self_ancestor(latest_seen_node_of_witness_peer_id, witness.node_hash): + count_seens[witness.node_hash] = count_seens.get(witness.node_hash, 0) + 1 + # the dest_node must see the witness of its own peer - found_self_peer_witness = [witness for witness in seen_witnesses_in_round_r if witness.peer_id == self.peer_id and count_seens.get(witness.node_hash, 0) > 2 * N / 3] + found_self_peer_witness = [witness for witness in seen_witnesses_in_round_r if witness.peer_id == self.peer_id and count_seens.get(witness.node_hash, 0) > 0] + # if dest_node.node_hash == "03b372de": + # print(f"Valid seen witnesses in round {r} = {[node.node_hash for node in seen_witnesses_in_round_r]}") + # print(f"Valid seen nodes in round {r} = {[node.node_hash for node in seen_nodes_in_round_r]}") + # print(f"Peer {self.peer_id} found_self_peer_witness = {found_self_peer_witness}, count_seens = {count_seens}") if len(found_self_peer_witness) == 0: return [] # keep only the witnesses that are seen by > 2/3 of the mid nodes strongly_seen_witnesses: list["Node"] = [] for witness 
in seen_witnesses_in_round_r: - if count_seens[witness.node_hash] > 2 * N / 3: + if count_seens.get(witness.node_hash, 0) > 2 * N / 3: strongly_seen_witnesses.append(witness) return strongly_seen_witnesses @@ -1165,20 +1189,15 @@ def check_round_number_of_non_genesis_node_with_valid_parents(self, dest_node: N self_parent_node = self.get_node_by_hash(dest_node.self_parent_hash) (self_parent_node is not None) or (_ for _ in ()).throw(ValueError("self_parent_node must be non-None")) - if self_parent_node.round < r - 1: + if self_parent_node.round < r - 1 or self_parent_node.round > r: return False - if self_parent_node.round == r - 1: # the dest_node is a witness of round r so it must strongly sees > 2/3 of witnesses of round r-1 strongly_seen_witnesses_in_round_r_minus_1 = self.get_strongly_seen_valid_witnesses(dest_node, r-1) is_witness_of_round_r = len(strongly_seen_witnesses_in_round_r_minus_1) > 2 * N / 3 - print(f"Peer {self.peer_id} checking if {dest_node} is a witness of round {r}: it strongly sees {strongly_seen_witnesses_in_round_r_minus_1}") - if dest_node.node_hash == "03b372de": - sys.exit(0) if not is_witness_of_round_r: return False - # the dest_node must not strongly sees > 2/3 of witnesses of round r strongly_seen_witnesses_in_round_r = self.get_strongly_seen_valid_witnesses(dest_node, r) @@ -1198,17 +1217,23 @@ def verify_node(self, node: Node = None) -> bool: if node is None: return False + if node.peer_id in self.equivocated_peers: + return False + # an honest peer must not accept a node which itself or its parents are from equivocated peers try: if node.is_genesis(): return True + parent_node = self.get_node_by_hash(node.self_parent_hash) + # TODO: verify newly_seen_txs_list of the node + (parent_node is not None) or (_ for _ in ()).throw(ValueError("parent_node must be non-None")) + (parent_node.node_hash == self.seen_valid_nodes[parent_node.peer_id][-1].node_hash or (self.is_adversary and node.peer_id == self.peer_id)) or (_ for _ in 
()).throw(ValueError("parent_node must be the last node of the sender")) if not node.validate_node_data(self.network.get_peer_pubkey(node.peer_id)): return False predecessors = self.get_predecessors(node) predecessors = self.get_predecessors(node) - # must have valid parents if len(predecessors) < 2: return False @@ -1219,23 +1244,19 @@ def verify_node(self, node: Node = None) -> bool: is_allowed_to_bypass = self.is_adversary and node_to_check.peer_id == self.peer_id # adversary don't accept invalid nodes from other adversaries if not is_allowed_to_bypass: return False - # verify the batch proposal if it exists if self.verify_node_is_head_node(node): if not self.verify_batch_proposal_is_valid(node): return False else: (node.metadata.batch_proposal is None) or (_ for _ in ()).throw(ValueError("batch proposal must be None")) - # verify the vote if not self.verify_vote_is_valid(node): return False - except Exception as e: # adversary sending invalid nodes print("error = ", e) return False - return self.check_round_number_of_non_genesis_node_with_valid_parents(node) def get_all_transactions(self) -> Set[TransactionId]: @@ -1309,7 +1330,8 @@ async def gossip_push(self): # generate a random permutation of connected peers # try to extend the node sequence and push it to the neighbors - self_parent_node = self.get_my_last_node() if not self.is_adversary else self.random_instance.choice(self.my_nodes()) + should_equivocate = self.is_adversary and self.network.random_instance.random() < self.equivocation_prob + self_parent_node = self.get_my_last_node() if not should_equivocate else self.random_instance.choice(self.my_nodes()[-2:]) # pick a random peer with non-empty seen_valid_nodes possible_cross_peers = [peer_id for peer_id in self.seen_valid_nodes if self.seen_valid_nodes[peer_id]] @@ -1317,7 +1339,7 @@ async def gossip_push(self): return # can't extend the node sequence because there is no cross parent for the new node randomness = 
min([self.network.random_instance.random(), self.network.random_instance.random()]) - num_nodes_to_create = min(len(possible_cross_peers), 1 + (1 if randomness < self.equivocation_prob else 0)) + num_nodes_to_create = min(len(possible_cross_peers), 1 + (1 if should_equivocate else 0)) new_nodes = [] # if there are more than 1 node in this list, they are equivocated nodes and that means current peer is an adversary # NOTE: currently, the equivocation logic is simple, an adversary basically picks the last node of the current peer as the self parent, and the latest nodes of different cross peers as the cross parents @@ -1328,8 +1350,11 @@ async def gossip_push(self): for i in range(max_num_retries): cross_parent_peer_id = self.random_instance.choice(possible_cross_peers) - while cross_parent_peer_id == self.peer_id: - cross_parent_peer_id = self.random_instance.choice(possible_cross_peers) + for j in range(10000): + if cross_parent_peer_id == self.peer_id or cross_parent_peer_id in self.equivocated_peers: + cross_parent_peer_id = self.random_instance.choice(possible_cross_peers) + else: + break cross_parent_node = self.seen_valid_nodes[cross_parent_peer_id][-1] @@ -1356,7 +1381,14 @@ async def gossip_push(self): return for new_node in new_nodes: - (self.verify_node_and_add_to_local_view(new_node)) or (_ for _ in ()).throw(ValueError("failed to verify and add new node")) + res = self.verify_node_and_add_to_local_view(new_node) + if res: + print(f"Peer {self.peer_id} successfully added self-node {new_node.node_hash} to its local view") + else: + print(f"Peer {self.peer_id} failed to add self-node {new_node.node_hash} to its local view") + (res or self.is_adversary) or (_ for _ in ()).throw(ValueError("failed to verify and add new node")) + + self.select_neighbors(self.network.get_all_peer_ids()) # re-select neighbors # start gossiping to neighboring peers for i in range(len(self.neighbors)): @@ -1364,8 +1396,6 @@ async def gossip_push(self): # select randomly nodes from 
new_nodes node_to_send = (new_nodes[0] if i * 2 < len(self.neighbors) else new_nodes[-1]).clone() # simulate the process of serializing and deserializing the nodes in internet protocols - print(f"Peer {self.peer_id} try to gossip to {other_peer_id} node {node_to_send.node_hash}:") - # Send the selected node success = await self.network.gossip_send_node_and_ancestry(self.peer_id, other_peer_id, node_to_send) if not success: @@ -1389,7 +1419,6 @@ async def main(): if is_adversary: count_adversary += 1 - peer = ConsensusPeer( peer_id=f"P{i}", is_adversary=is_adversary, @@ -1412,7 +1441,7 @@ async def main(): # Main consensus loop i = 0 - while True and i < 40: + while True and i < 1000: i += 1 if i % 100 == 0: print(f"{i}th iteration") @@ -1442,7 +1471,7 @@ async def main(): # Create checkpoints periodically # TODO: create network checkpoints dynamically via network.create_checkpoint() - for peer in peers: + for i, peer in enumerate(peers): peer.visualize_view() if peer.is_adversary: print(f"Peer {peer.peer_id} is an adversary") @@ -1471,7 +1500,13 @@ def validate_consistency(): validate_consistency() -# Run the simulation + print(f"Total number of transactions = {len(network.global_mempool)}") + for peer in peers: + print(f"Peer {peer.peer_id} has {sum([len(peer.seen_valid_nodes[peer_id]) for peer_id in peer.seen_valid_nodes])} nodes") + print(f"Peer {peer.peer_id} has {len(peer.neighbors)} neighbors") + print(f"Peer {peer.peer_id} sees that {peer.equivocated_peers} peers are equivocated") + print(f"Peer {peer.peer_id} is at round {peer.current_round}") + asyncio.run(main()) ### Possible attacks: @@ -1481,5 +1516,4 @@ def validate_consistency(): # => recursive validity proof + proof of finality # Solution: Post-Unstaking Slashing for X blocks after unstaking (but not able to withdraw before X blocks yet) -# [] TODO: finish gossip-DAG architecture # [] TODO: finish DAGPool's order fairness gadget \ No newline at end of file From 
206deba067060ca262eedaf13fb7d2a9e88ae8d8 Mon Sep 17 00:00:00 2001 From: Galin Chung Nguyen Date: Fri, 14 Mar 2025 12:44:20 +0700 Subject: [PATCH 09/10] feat: add 3-round deferred ordering and fair order finalization --- dagpool/consensus_client.py | 614 +++++++++++++++++++++++------------- dagpool/graph.py | 133 +++++--- 2 files changed, 491 insertions(+), 256 deletions(-) diff --git a/dagpool/consensus_client.py b/dagpool/consensus_client.py index 4b90ad8..2ac6c5d 100644 --- a/dagpool/consensus_client.py +++ b/dagpool/consensus_client.py @@ -1,3 +1,4 @@ +# self_parent_hash=fc7983b5, cross_parent_hash=223e2386 import asyncio import sys import random @@ -14,7 +15,7 @@ from schemas import TransactionId, PeerId, NodeId, NodeLabel, Signature, BatchId, Pubkey, HashValue import json from math import log2, ceil # Added log2 and ceil imports -from graph import TournamentGraph +from graph import SemiCompleteDiGraph import builtins original_print = builtins.print @@ -24,18 +25,32 @@ def print(*args, **kwargs): if SHOULD_PRINT: original_print(*args, **kwargs) -BEACON_PACE = 4 +printed = set() + +def ddebug(peer_id: PeerId, *args, **kwargs): + if SHOULD_PRINT: + if len(printed) <= 0: + with open("debug.txt", "w") as f: + f.write("") + # write to file debug.txt + with open("debug.txt", "a") as f: + line = f"{peer_id}: {' '.join([str(arg) for arg in args])}\n" + if line not in printed: + f.write(line) + printed.add(line) + +BEACON_PACE = 10 # the first BEACON_PACE rounds are derived directly from the list of peers # BEACON_PACE should be chosen large enough to make sure peers have enough time to realize that they are the leader of the next rounds BEACON_FIELD_PRIME = 28948022309329048855892746252171976963363056481941560715954676764349967630337 # equal to Pallas base field prime GENESIS_BATCH = "" EMPTY_NODE_HASH = "" -ORDER_FAIRNESS_THRESHOLD = 4/7 +ORDER_FAIRNESS_THRESHOLD = 0.51 class Utils: @staticmethod - def merkle_root_of_transaction_list(txs: list[TransactionId]) -> 
NodeId: - len(txs) > 0 or (_ for _ in ()).throw(ValueError("txs is empty")) + def merkle_root_of_transaction_list(txs: list[TransactionId]) -> HashValue: + if len(txs) == 0: return EMPTY_NODE_HASH # do the merkle tree construction using a while loop res = [tx for tx in txs] while len(res) > 1: @@ -54,36 +69,56 @@ class Vote: def __init__(self, head_batch_hash: HashValue): self.head_batch_hash = head_batch_hash + def __str__(self): + return f"Vote(head_batch_hash={self.head_batch_hash})" + class BatchProposal: batch_hash: HashValue prev_batch_hash: HashValue - final_fair_ordering: List[TransactionId] + deferred_ordering: List[TransactionId] # the deferred ordering of the third last head witness (if exists) from current witness backwards next_beacon_randomness: HashValue # peers use this to derive the leader of round r + BEACON_PACE + solid_transactions: Set[TransactionId] # transactions that are received by n - f peers at this batch - def __init__(self, prev_batch_hash: HashValue, final_fair_ordering: List[TransactionId], prev_beacon_randomness: HashValue): + def __init__(self, round_number: int, prev_batch_hash: HashValue, deferred_ordering: List[TransactionId], prev_beacon_randomness: HashValue, solid_transactions: Set[TransactionId]): self.prev_batch_hash = prev_batch_hash - self.final_fair_ordering = final_fair_ordering - self.next_beacon_randomness = self.compute_next_beacon_randomness(round, prev_beacon_randomness, final_fair_ordering) + self.deferred_ordering = deferred_ordering + self.solid_transactions = solid_transactions + self.round_number = round_number + self.prev_beacon_randomness = prev_beacon_randomness + self.next_beacon_randomness = self.compute_next_beacon_randomness() self.batch_hash = self.compute_batch_hash() + self.non_solid_txs: List[TransactionId] = [] + + def store_non_solid_txs(self, txs_list: List[TransactionId]): + self.non_solid_txs = txs_list # store all the non solid transactions that aren't included in all the batch up to it + + def 
merkle_root_of_deferred_ordering(self) -> HashValue: + return Utils.merkle_root_of_transaction_list(self.deferred_ordering) + + def merkle_root_of_solid_transactions(self) -> HashValue: + return Utils.merkle_root_of_transaction_list(sorted(list(self.solid_transactions))) - def compute_next_beacon_randomness(self, round_number: int, prev_beacon_randomness: HashValue, final_fair_ordering: List[TransactionId]) -> HashValue: - components = [prev_beacon_randomness, str(round_number + BEACON_PACE), Utils.merkle_root_of_transaction_list(final_fair_ordering)] + def compute_next_beacon_randomness(self) -> HashValue: + components = [self.prev_beacon_randomness, str(self.round_number + BEACON_PACE), self.prev_batch_hash, self.merkle_root_of_deferred_ordering(), self.merkle_root_of_solid_transactions()] return hashlib.sha256(''.join(components).encode()).hexdigest() def compute_batch_hash(self) -> HashValue: - # use hashlib of [prev_batch_hash, final_fair_ordering, next_beacon_randomness] - components = [self.prev_batch_hash, Utils.merkle_root_of_transaction_list(self.final_fair_ordering), self.next_beacon_randomness] + # use hashlib of [prev_batch_hash, deferred_ordering, next_beacon_randomness] + components = [str(self.round_number), self.prev_batch_hash, self.merkle_root_of_deferred_ordering(), self.merkle_root_of_solid_transactions(), self.next_beacon_randomness] return hashlib.sha256(''.join(components).encode()).hexdigest()[:8] - def verify_batch_proposal_is_well_formed(self, round_number: int, prev_beacon_randomness: HashValue) -> bool: - if not self.next_beacon_randomness == self.compute_next_beacon_randomness(round_number, prev_beacon_randomness, self.final_fair_ordering): + def verify_batch_proposal_is_well_formed(self) -> bool: + if not self.next_beacon_randomness == self.compute_next_beacon_randomness(): return False if not self.batch_hash == self.compute_batch_hash(): return False return True def clone(self): - return BatchProposal(self.prev_batch_hash, [tx for tx 
in self.final_fair_ordering], self.next_beacon_randomness) + return BatchProposal(round_number=self.round_number, prev_batch_hash=self.prev_batch_hash, deferred_ordering=[tx for tx in self.deferred_ordering], prev_beacon_randomness=self.prev_beacon_randomness, solid_transactions=[tx for tx in self.solid_transactions]) + + def __str__(self): + return f"BatchProposal(batch_hash={self.batch_hash}, prev_batch_hash={self.prev_batch_hash}, prev_beacon_randomness={self.prev_beacon_randomness}, deferred_ordering={self.deferred_ordering}, next_beacon_randomness={self.next_beacon_randomness}, solid_transactions={self.solid_transactions})" class NodeMetadata: batch_proposal: BatchProposal # can be None if the node is not a head node (head node means the witness of the leader in its selected round) @@ -167,7 +202,7 @@ def validate_node_data(self, creator_pubkey: Pubkey) -> bool: # TODO: validate that all the node data is well-formed (use schema validator) if self.metadata.batch_proposal: - if not self.metadata.batch_proposal.verify_batch_proposal_is_well_formed(self.round, self.metadata.batch_proposal.prev_beacon_randomness): + if not self.metadata.batch_proposal.verify_batch_proposal_is_well_formed(): return False # it must vote for its own batch proposal if self.metadata.vote.head_batch_hash != self.metadata.batch_proposal.batch_hash: @@ -189,7 +224,9 @@ def get_seen_valid_peers(self) -> Set[PeerId]: eq def __str__(self): - return f"{"GENESIS " if self.is_genesis() else ""}Node(node_hash={self.node_hash}, peer_id={self.peer_id}, round={self.round}, is_witness={self.is_witness}, self_parent_hash={self.self_parent_hash}, cross_parent_hash={self.cross_parent_hash}, newly_seen_txs_list={self.newly_seen_txs_list})" # , equivocated_peers={self.equivocated_peers}, seen_nodes={self.seen_nodes})" + batch_info = ("batch_hash=" + self.metadata.batch_proposal.__str__() if self.metadata.batch_proposal else "") + vote_info = ("vote=" + self.metadata.vote.__str__() if self.metadata.vote 
else "none") + return f"{"GENESIS " if self.is_genesis() else ""}Node(node_hash={self.node_hash}, peer_id={self.peer_id}, round={self.round}, is_witness={self.is_witness}, self_parent_hash={self.self_parent_hash}, cross_parent_hash={self.cross_parent_hash}, newly_seen_txs_list={self.newly_seen_txs_list}, {batch_info}, {vote_info})" # , equivocated_peers={self.equivocated_peers}, seen_nodes={self.seen_nodes})" class ConnectionState(Enum): CLOSED = 0 @@ -234,16 +271,16 @@ def __init__(self, latency_ms_range=(50, 200), packet_loss_prob=0.1, random_inst # TODO: use secure cryptographic randomness source def get_first_beacon_randomness(self) -> HashValue: all_peers = [p.peer_id for p in self.peers] - value_bytes = hashlib.sha256(",".join(all_peers).encode()).digest() - value = int.from_bytes(value_bytes, "big") - return value + return hashlib.sha256(",".join(all_peers).encode()).hexdigest() def get_first_leaders(self) -> List[PeerId]: # for round 1 -> BEACON_PACE all_peers = [p.peer_id for p in self.peers] first_beacon_randomness = self.get_first_beacon_randomness() # choose the first BEACON_PACE leaders for the first BEACON_PACE rounds, using the first BEACON_FIELD_PRIME as modulo somehow beacon_random = random.Random(first_beacon_randomness) - leaders = beacon_random.sample(all_peers, BEACON_PACE) + leaders = [] + for i in range(BEACON_PACE): + leaders.append(beacon_random.choice(all_peers)) return leaders def get_leader_after_BEACON_PACE_rounds(self, beacon_randomness: HashValue) -> PeerId: @@ -322,33 +359,6 @@ def new_txs_from_user_client(self) -> (list[TransactionId], list['ConsensusPeer' return txs, [peer for peer in self.peers if peer.peer_id in peer_ids_to_send] - def get_accumulated_txs_until_node(self, node: Node) -> Set[TransactionId]: - """Get all transactions accumulated up to a specific node""" - # First, find the latest checkpoint before this node - latest_applicable_checkpoint = None - for checkpoint in reversed(self.checkpoints): - if checkpoint.round < 
node.round: - latest_applicable_checkpoint = checkpoint - break - - accumulated_txs = set() - if latest_applicable_checkpoint: - accumulated_txs.update(latest_applicable_checkpoint.accumulated_txs) - - # Add transactions from the node and its ancestors back to the checkpoint - def collect_txs(current_node): - if not current_node or (latest_applicable_checkpoint and - current_node.round <= latest_applicable_checkpoint.round): - return - accumulated_txs.update(current_node.newly_seen_txs_list) - if current_node.self_parent_hash: - collect_txs(current_node.self_parent_hash) - if current_node.cross_parent_hash: - collect_txs(current_node.cross_parent_hash) - - collect_txs(node) - return accumulated_txs - async def connect(self, peer_a: PeerId, peer_b: PeerId): """Simulate TCP three-way handshake""" conn_key = (peer_a, peer_b) @@ -436,10 +446,10 @@ def trace(node: Node): try: receiver_peer.verify_node_and_add_to_local_view(current_node.clone()) or (_ for _ in ()).throw(ValueError("failed to verify and add node")) except Exception as e: - print(f"Peer {receiver_peer.peer_id} rejected node {current_node.node_hash} from {sender_peer.peer_id}: {e}") + # print(f"Peer {receiver_peer.peer_id} rejected node {current_node.node_hash} from {sender_peer.peer_id} created by {current_node.peer_id}: {e}") continue - print(f"Peer {receiver_peer.peer_id} accepted node {current_node.node_hash} from {sender_peer.peer_id}") + # print(f"Peer {receiver_peer.peer_id} accepted node {current_node.node_hash} from {sender_peer.peer_id} created by {current_node.peer_id}") return receiver_peer.has_seen_valid_node(node1) @@ -482,6 +492,8 @@ def __init__(self, peer_id: PeerId, is_adversary: bool, seed: int, network: Netw self.local_graph: Dict[NodeId, Set[NodeId]] = {} # node_hash to set of node_hashes (parent, child) # batch-related data self.observed_valid_batches: Dict[HashValue, Node] = {} # map from the hash value of the batch to the node that proposes it + self.first_inclusion_of_txs_at_peer: 
Dict[PeerId, Dict[TransactionId, NodeId]] = {} # map from peer_id to map from tx_id to the first time it is included at that peer + self.cached_heaviest_batch_amongst_strict_ancestors: Dict[NodeId, HashValue] = {} def my_nodes(self) -> List[Node]: return self.seen_valid_nodes[self.peer_id] @@ -516,30 +528,6 @@ def get_successors(self, node: Node) -> list[Node]: successors.append(self.get_node_by_hash(successor)) return sorted(successors, key=lambda x: x.round) - def get_ancestry(self, node: Node) -> Set[Node]: - node is not None or (_ for _ in ()).throw(ValueError("node must be non-None to find its ancestry")) - ancestry: Dict[NodeId, Node] = {} - def dfs_backwards(node: Node): - if node.node_hash in ancestry: - return - ancestry[node.node_hash] = node - for predecessor in self.get_predecessors(node): - dfs_backwards(predecessor) - dfs_backwards(node) - return set(ancestry.values()) - - def get_lineage(self, node: Node) -> Set[Node]: - node is not None or (_ for _ in ()).throw(ValueError("node must be non-None to find its lineage")) - lineage: Dict[NodeId, Node] = {} - def dfs_forwards(node: Node): - if node.node_hash in lineage: - return - lineage[node.node_hash] = node - for successor in self.get_successors(node): - dfs_forwards(successor) - dfs_forwards(node) - return set(lineage.values()) - def get_graph_info(self) -> Dict[NodeId, Node]: res: Dict[NodeId, Node] = {} for peer_id in self.seen_valid_nodes: @@ -549,10 +537,10 @@ def get_graph_info(self) -> Dict[NodeId, Node]: return res def visualize_view(self): - print(f"peer {self.peer_id} sees the DAG:") - for peer_id in self.seen_valid_nodes: - for node in self.seen_valid_nodes[peer_id]: - print(f"{node}") + # print(f"peer {self.peer_id} sees the DAG:") + # for peer_id in self.seen_valid_nodes: + # for node in self.seen_valid_nodes[peer_id]: + # print(f"{node}") # pos = {} round_colors = { @@ -708,15 +696,21 @@ def get_heaviest_batch_amongst_strict_ancestors(self, node: Node) -> HashValue: """ Get the heaviest 
batch in the strict ancestors of the given node using a fork-choice rule based on the Heaviest Observed Subtree (HOS) selection rule """ + if node.node_hash in self.cached_heaviest_batch_amongst_strict_ancestors: + return self.cached_heaviest_batch_amongst_strict_ancestors[node.node_hash] + # 1. construct a tree of batch proposals tree: Dict[HashValue, Set[HashValue]] = {GENESIS_BATCH: set()} - - for node in self.observed_valid_batches.values(): # TODO: only use the batches from the last finalized branch, so we can filter out the batches from equivocated peers - tree[node.batch_hash] = set() + + valid_batches = self.observed_valid_batches.values() + valid_batches = [cur_node for cur_node in valid_batches if self.is_seen_by(cur_node.node_hash, node)] + + for node in valid_batches: # TODO: only use the batches from the last finalized branch, so we can filter out the batches from equivocated peers + tree[node.metadata.batch_proposal.batch_hash] = set() parent_batch_hash = node.metadata.batch_proposal.prev_batch_hash if parent_batch_hash not in tree: tree[parent_batch_hash] = set() - tree[parent_batch_hash].add(node.batch_hash) + tree[parent_batch_hash].add(node.metadata.batch_proposal.batch_hash) # 2. 
collect the latest votes of all peers that the node can see weight_of_branch: Dict[HashValue, int] = {GENESIS_BATCH: 0} @@ -739,8 +733,10 @@ def traverse(batch_hash: HashValue): if len(tree[heaviest_batch]) == 0: break heaviest_batch = max(tree[heaviest_batch], key=lambda x: weight_of_branch[x]) + + self.cached_heaviest_batch_amongst_strict_ancestors[node.node_hash] = heaviest_batch return heaviest_batch - + def get_beacon_randomness_of_batch(self, batch_hash: HashValue) -> HashValue: """ Get the beacon randomness of the given batch @@ -763,126 +759,286 @@ def verify_node_is_head_node(self, node: Node) -> bool: """ Verify that the given node is a head node, given that the vote is valid """ - return False if not node.is_witness or node.is_genesis(): # this function might be called before the batch proposal is constructed so we only needs to check whether the node is a witness return False + heaviest_batch_amongst_strict_ancestors = self.get_heaviest_batch_amongst_strict_ancestors(node) + # find the batch that contains the randomness that determines the leader of the current round - beacon_batch:HashValue = node.metadata.vote.head_batch_hash - steps_back = BEACON_PACE - 1 + (1 if node.metadata.batch_proposal is not None else 0) + beacon_batch:HashValue = heaviest_batch_amongst_strict_ancestors + steps_back = BEACON_PACE - for i in range(steps_back): - if beacon_batch == "": + for i in range(steps_back - 1): + if beacon_batch == GENESIS_BATCH: break # find parent batch of the current beacon batch - beacon_batch: HashValue = self.observed_valid_batches[beacon_batch].metadata.batch_proposal.prev_batch_hash + beacon_batch = self.observed_valid_batches[beacon_batch].metadata.batch_proposal.prev_batch_hash + + beacon_randomness = str(0 if beacon_batch == GENESIS_BATCH else self.get_beacon_randomness_of_batch(beacon_batch)) + ":" + str(node.round) + beacon_randomer = random.Random(beacon_randomness) + all_peers = [p.peer_id for p in self.network.peers] + return 
beacon_randomer.choice(all_peers) == node.peer_id + + # compute the leader of the current round based on the beacon randomness - if beacon_batch == "": - # genesis beacon randomness - first_leaders: List[PeerId] = self.network.get_first_leaders() - is_head_node = len(first_leaders) == BEACON_PACE and node.round > 0 and node.round <= BEACON_PACE and node.peer_id == first_leaders[node.round - 1] - return is_head_node - else: - # non-genesis beacon randomness - return node.peer_id == self.network.get_leader_after_BEACON_PACE_rounds(self.get_beacon_randomness_of_batch(beacon_batch)) + # if beacon_batch == GENESIS_BATCH: + # first_leaders: List[PeerId] = self.network.get_first_leaders() + # is_head_node = len(first_leaders) == BEACON_PACE and node.round > 0 and node.round <= BEACON_PACE and node.peer_id == first_leaders[node.round - 1] + # return is_head_node + # else: + # # non-genesis beacon randomness + # return node.peer_id == self.network.get_leader_after_BEACON_PACE_rounds(self.get_beacon_randomness_of_batch(beacon_batch)) + + def is_seen_by(self, node_id: NodeId, dest_node: Node) -> bool: + if dest_node == None: + return False + + if dest_node.node_hash == node_id: + return True + + for mid_node_hash in dest_node.latest_seen_node_by_peers.values(): + if self.is_valid_descendant_and_self_ancestor(mid_node_hash, node_id): + return True + + return False - def get_truncated_cone(self, prev_batch_hash: HashValue, head_node: Node) -> Dict[NodeId, Node]: + def check_transaction_included_up_to_batch(self, tx: TransactionId, batch_hash: HashValue) -> bool: + # TODO: use efficient lookup strategy + while batch_hash != GENESIS_BATCH: + if tx in self.observed_valid_batches[batch_hash].metadata.batch_proposal.solid_transactions: + return True + batch_hash = self.observed_valid_batches[batch_hash].metadata.batch_proposal.prev_batch_hash + return False + + + def map_to_cone_region(self, first_inclusion_location: NodeId, left_boundary_node: Node, head_witness: Node) -> NodeId: 
""" - Get the truncated cone of the given heaviest batch and node + Map the first inclusion location of a transaction to the cone region of the given head witness """ + if first_inclusion_location == EMPTY_NODE_HASH: + return EMPTY_NODE_HASH + + # check if the first inclusion location is seen by head_witness but not left_boundary_node + if self.is_seen_by(first_inclusion_location, head_witness) and not self.is_seen_by(first_inclusion_location, left_boundary_node): + return first_inclusion_location + + return EMPTY_NODE_HASH + + def print_final_transaction_order(self) -> List[TransactionId]: + res: List[TransactionId] = [] + heaviest_batch = self.get_heaviest_batch_amongst_strict_ancestors(self.get_my_last_node()) + + while heaviest_batch != GENESIS_BATCH: + cur_witness: Node = self.get_node_of_batch(heaviest_batch) + res.extend(cur_witness.metadata.batch_proposal.deferred_ordering[::-1]) + heaviest_batch = self.observed_valid_batches[heaviest_batch].metadata.batch_proposal.prev_batch_hash + return res[::-1] - node_of_prev_batch: Node = self.get_node_of_batch(prev_batch_hash) + def construct_semi_complete_digraph_of_transactions(self, last_3_head_witnesses: List[Node]) -> SemiCompleteDiGraph: + (len(last_3_head_witnesses) == 3) or (_ for _ in ()).throw(ValueError("last_3_head_witnesses must have length 3")) - # the difference between the ancestry of the current node and the ancestry of the node containing the previous batch - ancestry_of_cur_node = self.get_ancestry(head_node) - truncated_cone_candidates: Set[Node] = ancestry_of_cur_node.difference(set() if node_of_prev_batch is None else self.get_ancestry(node_of_prev_batch)) - truncated_cone: Set[Node] = set() + if any(head_witness is None for head_witness in last_3_head_witnesses): + return SemiCompleteDiGraph() - for node in truncated_cone_candidates: - if node.node_hash in head_node.seen_nodes: # only picks the one that the head node can SEE - truncated_cone.add(node) + dest_node = last_3_head_witnesses[2] + 
rounds: List[int] = [node.round for node in last_3_head_witnesses] + ordered_head_witness: Node = last_3_head_witnesses[0] + left_boundary_node: Node = self.get_node_of_batch(ordered_head_witness.metadata.batch_proposal.prev_batch_hash) + solid_transactions: Set[TransactionId] = ordered_head_witness.metadata.batch_proposal.solid_transactions + dependency_graph: SemiCompleteDiGraph = SemiCompleteDiGraph() + for u in solid_transactions: + dependency_graph.add_node(u) + + traced: Dict[NodeId, bool] = {} + nodes_by_peer_id: Dict[PeerId, Node] = {} + + all_legitimate_peers = [peer_id for peer_id in self.network.get_all_peer_ids() if peer_id not in dest_node.equivocated_peers] # TODO: handle the case that the list of peers changes + + preference_graphs: Dict[PeerId, List[TransactionId]] = {} + for peer_id in all_legitimate_peers: + preference_graphs[peer_id] = [] + + solid_edge: Set[Tuple[TransactionId, TransactionId]] = set() + for tx1 in solid_transactions: + for tx2 in solid_transactions: + if tx1 < tx2: + count_receipts_of_edge_at_batches: List[int] = [0] * len(last_3_head_witnesses) + count_opposite_receipts_of_edge_at_batches: List[int] = [0] * len(last_3_head_witnesses) + + for i in range(len(last_3_head_witnesses)): + head_witness: Node = last_3_head_witnesses[i] + for peer_id in all_legitimate_peers: + if peer_id not in self.first_inclusion_of_txs_at_peer: continue + first_inclusion_of_tx1_at_peer = self.first_inclusion_of_txs_at_peer[peer_id].get(tx1, EMPTY_NODE_HASH) + first_inclusion_of_tx2_at_peer = self.first_inclusion_of_txs_at_peer[peer_id].get(tx2, EMPTY_NODE_HASH) + # map the first inclusion locations into the cone region of the current head witness + mapped_first_inclusion_of_tx1_at_peer = self.map_to_cone_region(first_inclusion_of_tx1_at_peer, left_boundary_node, head_witness) + mapped_first_inclusion_of_tx2_at_peer = self.map_to_cone_region(first_inclusion_of_tx2_at_peer, left_boundary_node, head_witness) + + is_left_seen = 
mapped_first_inclusion_of_tx1_at_peer != EMPTY_NODE_HASH + is_right_seen = mapped_first_inclusion_of_tx2_at_peer != EMPTY_NODE_HASH + + if is_left_seen: + if not is_right_seen: # (tx1 -> tx2) + count_opposite_receipts_of_edge_at_batches[i] += 1 + else: + if is_right_seen: # (tx2 -> tx1) + count_receipts_of_edge_at_batches[i] += 1 + + if is_left_seen and is_right_seen: + if self.is_seen_by(mapped_first_inclusion_of_tx1_at_peer, self.get_node_by_hash(mapped_first_inclusion_of_tx2_at_peer)): # (tx1 -> tx2) + count_receipts_of_edge_at_batches[i] += 1 + else: # (tx2 -> tx1) + count_opposite_receipts_of_edge_at_batches[i] += 1 + + # solid edge + if count_receipts_of_edge_at_batches[2] >= ORDER_FAIRNESS_THRESHOLD * len(all_legitimate_peers): + dependency_graph.add_directed_edge(tx1, tx2) + solid_edge.add((tx1, tx2)) + elif count_opposite_receipts_of_edge_at_batches[2] >= ORDER_FAIRNESS_THRESHOLD * len(all_legitimate_peers): + dependency_graph.add_directed_edge(tx2, tx1) + else: + # soft edge + for i in range(len(last_3_head_witnesses)): + if count_receipts_of_edge_at_batches[i] >= count_opposite_receipts_of_edge_at_batches[i]: + dependency_graph.add_directed_edge(tx1, tx2) + elif count_receipts_of_edge_at_batches[i] <= count_opposite_receipts_of_edge_at_batches[i]: + dependency_graph.add_directed_edge(tx2, tx1) + + dependency_graph.assert_is_semi_complete_digraph() # there must be at least one edge between any two transactions - res: Dict[NodeId, Node] = {node.node_hash: node for node in truncated_cone} - # TODO: we might implement the logic that allows a leader peer to limit the number of rounds of the truncated cone - return res + return dependency_graph - def construct_tournament_graph_of_transactions(self, truncated_cone: Dict[NodeId, Node]) -> TournamentGraph: + def get_solid_transactions_and_non_solid_transactions_in_cone(self, head_node: Node, prev_batch_hash: HashValue) -> Tuple[List[TransactionId], List[TransactionId]]: """ - Construct the tournament graph of the 
transactions in the truncated cone + Get the solid transactions in the truncated cone of the given head node and previous batch hash """ - all_peers = self.network.get_all_peer_ids() # TODO: handle the case that the list of peers changes - preference_graphs: Dict[PeerId, List[TransactionId]] = {} - for peer_id in all_peers: - preference_graphs[peer_id] = [] + traced: Dict[NodeId, bool] = {} + current_cone: List[Node] = [] - all_txs: Set[TransactionId] = set() + # TODO: gather transactions from previous cones that aren't included as solid transactions + # => Not only transactions in the current truncated cone - visited: Set[NodeId] = set() - def constructPreferenceGraph(node: Node): - visited.add(node.node_hash) + prev_head_witness = self.get_node_of_batch(prev_batch_hash) - if node.self_parent_hash in truncated_cone: - constructPreferenceGraph(self.get_node_by_hash(node.self_parent_hash)) - else: - preference_graphs[node.peer_id].extend(node.newly_seen_txs_list) - - for node in truncated_cone.values(): - if not node.node_hash in visited: - constructPreferenceGraph(node) + unincluded_solid_txs_from_prev_batch: Set[TransactionId] = set() + + if prev_head_witness: + unincluded_solid_txs_from_prev_batch.update(prev_head_witness.metadata.batch_proposal.non_solid_txs) # it now becomes solid due to strongly-seeing property (i.e. 
n - f peers each of these transactions) - for peer_id in all_peers: - all_txs.update(preference_graphs[peer_id]) + def find_node_in_current_cone(node: Node): + if node.node_hash in traced: + return + + in_prev_cone_ancestry: bool = False if not prev_head_witness else self.is_seen_by(node.node_hash, prev_head_witness) + in_current_cone: bool = node.node_hash == head_node.node_hash or (self.is_seen_by(node.node_hash, head_node) and not in_prev_cone_ancestry) + + if not in_current_cone: + traced[node.node_hash] = False + return + + traced[node.node_hash] = True - ## validate the preference graphs: only contains unique transactions - for peer_id in all_peers: - (len(preference_graphs[peer_id]) == len(set(preference_graphs[peer_id]))) or (_ for _ in ()).throw(ValueError("preference graphs must contain unique transactions")) + predecessors = self.get_predecessors(node) + for predecessor in predecessors: + find_node_in_current_cone(predecessor) - ## construct the final tournament graph - tournament_graph: TournamentGraph = TournamentGraph() - count_edge_frequency: Dict[Tuple[TransactionId, TransactionId], int] = {} + current_cone.append(node) - for peer_id in all_peers: - for i in range(len(preference_graphs[peer_id])): - for j in range(i + 1, len(preference_graphs[peer_id])): - tx1 = preference_graphs[peer_id][i] - tx2 = preference_graphs[peer_id][j] - if (tx1, tx2) not in count_edge_frequency: - count_edge_frequency[(tx1, tx2)] = 0 - count_edge_frequency[(tx1, tx2)] += 1 + find_node_in_current_cone(head_node) + + # filter & keep only the transactions that aren't included in previous batches and received by n - f legitimate peers + count_legitimate_receipts_of_tx: Dict[TransactionId, int] = {} + for node in current_cone: + for tx in node.newly_seen_txs_list: + if node.peer_id not in head_node.equivocated_peers: + count_legitimate_receipts_of_tx[tx] = count_legitimate_receipts_of_tx.get(tx, 0) + 1 + + n = len(self.network.get_all_peer_ids()) + f = (n - 1)/3 + + 
new_candidate_solid_transactions: Set[TransactionId] = set() + new_pending_txs: List[TransactionId] = [] + + # TODO: use more lightweight data structure here + all_prev_solid_txs: Set[TransactionId] = set() + cur_batch = prev_batch_hash + while cur_batch != GENESIS_BATCH: + cur_head_witness = self.observed_valid_batches[cur_batch] + all_prev_solid_txs.update(cur_head_witness.metadata.batch_proposal.solid_transactions) + cur_batch = cur_head_witness.metadata.batch_proposal.prev_batch_hash + + for tx in count_legitimate_receipts_of_tx: + if tx not in all_prev_solid_txs and tx not in unincluded_solid_txs_from_prev_batch: + if count_legitimate_receipts_of_tx[tx] >= n - f: + new_candidate_solid_transactions.add(tx) + else: + new_pending_txs.append(tx) + + solid_transactions: List[TransactionId] = sorted(list(new_candidate_solid_transactions.union(unincluded_solid_txs_from_prev_batch))) + pending_txs: List[TransactionId] = sorted(new_pending_txs) + + return solid_transactions, pending_txs + + def get_deferred_ordering_of_transactions(self, last_3_head_witnesses: List[Node]) -> List[TransactionId]: + """ + Get the deferred ordering of the transactions in the truncated cone of the given head node and previous batch hash + """ + # gather all transactions in the truncated cone into a tournament graph + dependency_graph: SemiCompleteDiGraph = self.construct_semi_complete_digraph_of_transactions(last_3_head_witnesses) + # calculate the fair ordering of the transactions and construct the batch proposal + sccs = dependency_graph.find_strongly_connected_components() + sccs_with_fair_orderings = [[scc, dependency_graph.find_hamiltonian_path(scc[0])] for scc in sccs] + deferred_ordering: List[TransactionId] = [] + for scc, fair_ordering in sccs_with_fair_orderings: + deferred_ordering.extend([TransactionId(tx) for tx in fair_ordering]) - for (tx1, tx2), frequency in count_edge_frequency.items(): - if frequency > ORDER_FAIRNESS_THRESHOLD * len(all_peers): - 
tournament_graph.add_edge(tx1, tx2) + assert len(deferred_ordering) == len(dependency_graph.nodes) + print(f"Peer {self.peer_id} got deferred ordering for node_hash = {last_3_head_witnesses[2].node_hash} = {deferred_ordering}") - return preference_graphs + return deferred_ordering - def compute_batch_proposal(self, node: Node) -> BatchProposal: + def compute_batch_proposal(self, dest_node: Node) -> BatchProposal: """ Compute a batch proposal for the given node """ # fork-choice rule: select the previous batch using the Heaviest Observed Subtree (HOS) selection rule - heaviest_batch: HashValue = self.get_heaviest_batch_amongst_strict_ancestors(node) - prev_batch_beacon_randomness: HashValue = self.get_beacon_randomness_of_batch(heaviest_batch) + + heaviest_batch: HashValue = self.get_heaviest_batch_amongst_strict_ancestors(dest_node) + prev_beacon_randomness: HashValue = self.get_beacon_randomness_of_batch(heaviest_batch) + + ## find the last 3 head witnesses + last_3_head_witnesses: List[Node] = [dest_node] + cur_batch_hash = heaviest_batch + for i in range(2): + if cur_batch_hash == GENESIS_BATCH: + last_3_head_witnesses.append(None) + else: + prev_head_witness = self.observed_valid_batches[cur_batch_hash] if cur_batch_hash in self.observed_valid_batches else None + (prev_head_witness is not None) or (_ for _ in ()).throw(ValueError("prev_head_witness must be non-None")) + last_3_head_witnesses.append(prev_head_witness) + cur_batch_hash = prev_head_witness.metadata.batch_proposal.prev_batch_hash + last_3_head_witnesses.reverse() - # construct the truncated cone - truncated_cone: Dict[NodeId, Node] = self.get_truncated_cone(heaviest_batch, node) + # construct the truncated cone of round (r) + deferred_ordering: List[TransactionId] = [] + solid_transactions, non_solid_transactions = self.get_solid_transactions_and_non_solid_transactions_in_cone(dest_node, heaviest_batch) - # gather all transactions in the truncated cone into a tournament graph - tournament_graph: 
TournamentGraph = self.construct_tournament_graph_of_transactions(truncated_cone) - # calculate the fair ordering of the transactions and construct the batch proposal - sccs = tournament_graph.find_strongly_connected_components() - sccs_with_fair_orderings = [[scc, tournament_graph.find_hamiltonian_cycle(scc)] for scc in sccs] - final_fair_ordering: List[TransactionId] = [] - for scc, fair_ordering in sccs_with_fair_orderings: - final_fair_ordering.extend(fair_ordering) + if len(last_3_head_witnesses) >= 3: + deferred_ordering = self.get_deferred_ordering_of_transactions(last_3_head_witnesses) batch_proposal = BatchProposal( + round_number=dest_node.round, prev_batch_hash=heaviest_batch, - final_fair_ordering=final_fair_ordering, - prev_beacon_randomness=prev_batch_beacon_randomness + deferred_ordering=deferred_ordering, + prev_beacon_randomness=prev_beacon_randomness, + solid_transactions=solid_transactions ) - (batch_proposal.verify_batch_proposal_is_well_formed(node.round, prev_batch_beacon_randomness)) or (_ for _ in ()).throw(ValueError("batch proposal is not well-formed")) + + batch_proposal.store_non_solid_txs(non_solid_transactions) + + (batch_proposal.verify_batch_proposal_is_well_formed()) or (_ for _ in ()).throw(ValueError("batch proposal is not well-formed")) return batch_proposal @@ -890,9 +1046,9 @@ def construct_batch_proposal_if_needed(self, node: Node): """ Construct a batch proposal for the given node. This must be called after the vote is constructed for the node. 
""" - if node.peer_id != self.peer_id or not self.verify_node_is_head_node(node): + if not self.verify_node_is_head_node(node): return - + batch_proposal: BatchProposal = self.compute_batch_proposal(node) node.metadata = NodeMetadata(batch_proposal=batch_proposal) @@ -906,7 +1062,7 @@ def verify_batch_proposal_is_valid(self, node: Node) -> bool: (node.metadata.batch_proposal is not None) or (_ for _ in ()).throw(ValueError("batch proposal must be non-None")) correct_batch_proposal: BatchProposal = self.compute_batch_proposal(node) - + return node.metadata.batch_proposal.batch_hash == correct_batch_proposal.batch_hash def compute_vote_for_node(self, node: Node) -> Vote: @@ -914,6 +1070,7 @@ def compute_vote_for_node(self, node: Node) -> Vote: Compute a vote for the given node """ is_head_node = self.verify_node_is_head_node(node) + if is_head_node: return Vote(head_batch_hash=node.metadata.batch_proposal.batch_hash) # vote for its own head batch else: @@ -944,10 +1101,9 @@ def fill_node_data(self, node: Node): node.has_filled_node_data = True node.update_signature() except Exception as e: - print(f"Peer {self.peer_id} error when filling data for node {node.node_hash} = {e}") + # error when filling data for the node raise e - print(f"Peer {self.peer_id} finished filling data for node {node.node_hash}, with seen nodes = {node.latest_seen_node_by_peers}") def has_seen_valid_node(self, node: Node) -> bool: return node.node_hash in self.pos_in_seen_valid_nodes @@ -973,8 +1129,8 @@ def is_valid_descendant_and_self_ancestor(self, descendant_node_hash: NodeId, se try: (peer1, pos1) = self.pos_in_seen_valid_nodes[descendant_node_hash] (peer2, pos2) = self.pos_in_seen_valid_nodes[self_ancestor_node_hash] - return peer1 == peer2 and pos1 > pos2 - except: + return peer1 == peer2 and pos1 >= pos2 + except Exception as e: return False def get_self_descendant(self, node_hash_1: NodeId, node_hash_2: NodeId) -> Optional[NodeId]: @@ -996,7 +1152,6 @@ def 
compute_seen_nodes_of_new_node(self, dest_node: Node): dest_node.seen_votes_by_peers is None or dest_node.latest_seen_witness_by_peers is None) or (_ for _ in ()).throw(ValueError("Node must not have precomputed values.")) - print(f"Peer {self.peer_id} computing seen _nodes for {dest_node.node_hash} {dest_node}") latest_seen_node_by_peers: Dict[PeerId, NodeId] = {} non_equivocated_peers: Set[PeerId] = set() equivocated_peers: Set[PeerId] = set() @@ -1025,11 +1180,14 @@ def aggregate_latest_seen_by_peers( equivocated_peers.add(peer_id) else: dest_dict[peer_id] = seen_node_hash - + # Aggregate data from parents for parent_hash in aggregated_references: if parent_hash == EMPTY_NODE_HASH: continue parent_node = self.get_node_by_hash(parent_hash) + if parent_node is None: + raise ValueError(f"Parent node {parent_hash} is None") + latest_seen_by_peers_of_parent = parent_node.latest_seen_node_by_peers latest_seen_witness_by_peers_of_parent = parent_node.latest_seen_witness_by_peers @@ -1044,7 +1202,7 @@ def aggregate_latest_seen_by_peers( equivocated_peers=equivocated_peers ) - print(f"Peer {self.peer_id} is merging {parent_node} with {parent_node.equivocated_peers}") + # print(f"Peer {self.peer_id} is merging {parent_node} with {parent_node.equivocated_peers}") equivocated_peers.update(parent_node.equivocated_peers) seen_votes_by_peers.update(parent_node.seen_votes_by_peers) # the vote inside the dest_node would be updated later when verifying the vote of the dest_node @@ -1063,7 +1221,6 @@ def aggregate_latest_seen_by_peers( dest_node.latest_seen_node_by_peers = latest_seen_node_by_peers dest_node.latest_seen_witness_by_peers = latest_seen_witness_by_peers - ### self.equivocated_peers.update(dest_node.equivocated_peers) @@ -1078,11 +1235,23 @@ def verify_node_and_add_to_local_view(self, node: Node = None) -> bool: if not self.verify_node(node): return False + if node.is_witness and node.metadata.batch_proposal is not None: + print(f"Peer {self.peer_id} confirms node 
({node.node_hash}, {node.peer_id}) is head node of round {node.round}") + if node.peer_id not in self.seen_valid_nodes: self.seen_valid_nodes[node.peer_id] = [] self.seen_valid_nodes[node.peer_id].append(node) self.pos_in_seen_valid_nodes[node.node_hash] = (node.peer_id, len(self.seen_valid_nodes[node.peer_id]) - 1) + # store the first inclusion of the transactions at the peer + if node.peer_id not in self.first_inclusion_of_txs_at_peer: + self.first_inclusion_of_txs_at_peer[node.peer_id] = {} + + for tx in node.newly_seen_txs_list: + if tx in self.first_inclusion_of_txs_at_peer[node.peer_id]: # node.peer_id must be equivocated + return False + self.first_inclusion_of_txs_at_peer[node.peer_id][tx] = node.node_hash + # make a node sees itself so its descendants can use these accumulated values node.latest_seen_node_by_peers[node.peer_id] = node.node_hash if node.is_witness: @@ -1090,7 +1259,6 @@ def verify_node_and_add_to_local_view(self, node: Node = None) -> bool: if node.metadata.batch_proposal is not None: self.observed_valid_batches[node.metadata.batch_proposal.batch_hash] = node - for predecessor in self.get_predecessors(node): if predecessor.node_hash not in self.local_graph: self.local_graph[predecessor.node_hash] = set() @@ -1101,7 +1269,7 @@ def verify_node_and_add_to_local_view(self, node: Node = None) -> bool: self.pending_txs.clear() # because all txs in the pending_txs are now in the new node self.current_round = node.round # this makes the current round of the peer = the round of the last node in the list of its nodes - print(f"Peer {self.peer_id} added node {node.node_hash} to its local view => new round = {self.current_round}") + # print(f"Peer {self.peer_id} added node {node.node_hash} to its local view => new round = {self.current_round}") return True @@ -1146,27 +1314,16 @@ def get_strongly_seen_valid_witnesses(self, dest_node: Node, r: int) -> list["No # check if mid_node can strongly see witness witness_peer_id = witness.peer_id - # if 
dest_node.node_hash == "03b372de": - # if witness.node_hash == "858a63d3": - # print(f"Peer {self.peer_id} checking if {mid_node} can strongly see {witness}, and seen nodes of {mid_node} = {mid_node.latest_seen_node_by_peers}") - if witness_peer_id not in mid_node.latest_seen_node_by_peers: continue latest_seen_node_of_witness_peer_id: NodeId = mid_node.latest_seen_node_by_peers[witness_peer_id] - # if dest_node.node_hash == "03b372de": - # if witness.node_hash == "858a63d3": - # print(f"Let's validate {latest_seen_node_of_witness_peer_id} is a descendant of {witness.node_hash} => {self.is_valid_descendant_and_self_ancestor(latest_seen_node_of_witness_peer_id, witness.node_hash)}") if self.is_valid_descendant_and_self_ancestor(latest_seen_node_of_witness_peer_id, witness.node_hash): count_seens[witness.node_hash] = count_seens.get(witness.node_hash, 0) + 1 # the dest_node must see the witness of its own peer found_self_peer_witness = [witness for witness in seen_witnesses_in_round_r if witness.peer_id == self.peer_id and count_seens.get(witness.node_hash, 0) > 0] - # if dest_node.node_hash == "03b372de": - # print(f"Valid seen witnesses in round {r} = {[node.node_hash for node in seen_witnesses_in_round_r]}") - # print(f"Valid seen nodes in round {r} = {[node.node_hash for node in seen_nodes_in_round_r]}") - # print(f"Peer {self.peer_id} found_self_peer_witness = {found_self_peer_witness}, count_seens = {count_seens}") if len(found_self_peer_witness) == 0: return [] @@ -1224,7 +1381,6 @@ def verify_node(self, node: Node = None) -> bool: try: if node.is_genesis(): return True - parent_node = self.get_node_by_hash(node.self_parent_hash) # TODO: verify newly_seen_txs_list of the node (parent_node is not None) or (_ for _ in ()).throw(ValueError("parent_node must be non-None")) @@ -1245,11 +1401,15 @@ def verify_node(self, node: Node = None) -> bool: if not is_allowed_to_bypass: return False # verify the batch proposal if it exists - if 
self.verify_node_is_head_node(node): - if not self.verify_batch_proposal_is_valid(node): - return False - else: - (node.metadata.batch_proposal is None) or (_ for _ in ()).throw(ValueError("batch proposal must be None")) + + if node.is_witness: + is_head_node = self.verify_node_is_head_node(node) + + if is_head_node: + if not self.verify_batch_proposal_is_valid(node): + return False + else: + (node.metadata.batch_proposal is None) or (_ for _ in ()).throw(ValueError("batch proposal must be None")) # verify the vote if not self.verify_vote_is_valid(node): return False @@ -1286,15 +1446,16 @@ def compute_new_node(self, self_parent: Node, cross_parent: Node) -> Node: if not self.is_adversary: (self_parent == self.get_my_last_node()) or (_ for _ in ()).throw(ValueError("self_parent must be the last node of the current peer")) + (self.get_node_by_hash(self_parent.node_hash) is not None and self.get_node_by_hash(cross_parent.node_hash) is not None) or (_ for _ in ()).throw(ValueError("self_parent and cross_parent must be valid nodes")) + # the newly seen list of txs in the new node must be not empty # TODO: sort this list by timestamp of receipt of the transactions newly_seen_txs_list: List[TransactionId] = list(self.calculate_newly_seen_txs_list_of_new_node(self_parent, cross_parent, self.pending_txs)) if len(newly_seen_txs_list) <= 0: # can't extend the node sequence because there is no new txs, this is to save the network capacity - print(f"Peer {self.peer_id} can't extend the node sequence because there is no new txs, this is to save the network capacity, self_parent = {self_parent.node_hash}, cross_parent = {cross_parent.node_hash}") return None - + round_num = len(self.my_nodes()) base_hash = f"{self.peer_id}{str(round_num).zfill(3)}" @@ -1319,9 +1480,11 @@ def compute_new_node(self, self_parent: Node, cross_parent: Node) -> Node: is_node_valid = self.verify_node(new_node) if is_node_valid: - print(f"Peer {self.peer_id} COMPUTED NEW {'HEAD' if 
new_node.is_head_node() else "NON-HEAD"} NODE {new_node.node_hash} from {self_parent.node_hash} and {cross_parent.node_hash}") + # print(f"Peer {self.peer_id} COMPUTED NEW {'HEAD' if new_node.is_head_node() else "NON-HEAD"} NODE {new_node.node_hash} from {self_parent.node_hash} and {cross_parent.node_hash}") return new_node - + else: + # print(f"Peer {self.peer_id} computed invalid node {new_node.node_hash} from {self_parent.node_hash} and {cross_parent.node_hash}") + pass # the created new nodes is invalid because either its parents are from equivocated peers return None @@ -1369,23 +1532,19 @@ async def gossip_push(self): new_nodes.append(new_node) break else: - print(f"Peer {self.peer_id} can't compute any new nodes from {self_parent_node.node_hash} and {cross_parent_node.node_hash}") + # Peer can't compute any new nodes from the current tuple of self_parent and cross_parent pass - # can't construct a valid node from the current tuple of self_parent and cross_parent (len(new_nodes) <= num_nodes_to_create) or (_ for _ in ()).throw(ValueError("number of new nodes must be less than or equal to num_nodes_to_create")) - if len(new_nodes) > 0: - print(f"Peer {self.peer_id}, is_adversary = {self.is_adversary}, computed {len(new_nodes)} nodes, its neighbors = {self.neighbors}, its equivocated peers = {self.my_nodes()[-1].equivocated_peers}, seen_peers = {[peer_id for peer_id in self.seen_valid_nodes]}") - else: - print(f"Peer {self.peer_id}, is_adversary = {self.is_adversary}, can't compute any new nodes, its neighbors = {self.neighbors}, its equivocated peers = {self.my_nodes()[-1].equivocated_peers}, seen_peers = {[peer_id for peer_id in self.seen_valid_nodes]}") + if len(new_nodes) <= 0: return for new_node in new_nodes: res = self.verify_node_and_add_to_local_view(new_node) if res: - print(f"Peer {self.peer_id} successfully added self-node {new_node.node_hash} to its local view") + print(f"Peer {self.peer_id} successfully added self-node {new_node.node_hash} to its 
local view: {new_node}") else: - print(f"Peer {self.peer_id} failed to add self-node {new_node.node_hash} to its local view") + print(f"Peer {self.peer_id} failed to add self-node {new_node.node_hash} to its local view: {new_node}") (res or self.is_adversary) or (_ for _ in ()).throw(ValueError("failed to verify and add new node")) self.select_neighbors(self.network.get_all_peer_ids()) # re-select neighbors @@ -1403,7 +1562,7 @@ async def gossip_push(self): else: print(f"peer {self.peer_id} gossiped to {other_peer_id} node {node_to_send.node_hash} successfully") -async def main(): +async def main(num_peers, MIN_NUM_ROUNDS_OF_HONEST_PEERS): # Create network simulator network = NetworkSimulator( latency_ms_range=(50, 200), @@ -1412,10 +1571,9 @@ async def main(): ) # Create peers - num_peers = 4 # next threshold for count_adversary = 2 is N = 7 count_adversary = 0 for i in range(num_peers): - is_adversary = (count_adversary + 1) < 1 * num_peers / 3 and network.random_instance.random() < 0.5 + is_adversary = (count_adversary + 1) * 3 + 1 <= num_peers and network.random_instance.random() < 0.5 if is_adversary: count_adversary += 1 @@ -1436,17 +1594,16 @@ async def main(): # Register genesis checkpoint await network.register_genesis_nodes() - MIN_NUM_ROUNDS = 10 current_simluated_timestamp = 0 # Main consensus loop i = 0 - while True and i < 1000: + while True: i += 1 if i % 100 == 0: print(f"{i}th iteration") - # Count peers that have reached MIN_NUM_ROUNDS rounds - peers_completed = sum(1 for c in peers if c.current_round >= MIN_NUM_ROUNDS) + # Count peers that have reached MIN_NUM_ROUNDS_OF_HONEST_PEERS rounds + peers_completed = sum(1 for c in peers if c.current_round >= MIN_NUM_ROUNDS_OF_HONEST_PEERS) if peers_completed > (2 * num_peers // 3): break @@ -1472,11 +1629,11 @@ async def main(): # TODO: create network checkpoints dynamically via network.create_checkpoint() for i, peer in enumerate(peers): - peer.visualize_view() + # peer.visualize_view() if 
peer.is_adversary: print(f"Peer {peer.peer_id} is an adversary") print(f"Neighbors of {peer.peer_id}: {peer.neighbors}") - print(f"Consensus completed with first {peers_completed} peers reaching round {MIN_NUM_ROUNDS}") + print(f"Consensus completed with first {peers_completed} peers reaching round {MIN_NUM_ROUNDS_OF_HONEST_PEERS}") def validate_consistency(): graph_info_list = [peer.get_graph_info() for peer in peers] @@ -1507,13 +1664,30 @@ def validate_consistency(): print(f"Peer {peer.peer_id} sees that {peer.equivocated_peers} peers are equivocated") print(f"Peer {peer.peer_id} is at round {peer.current_round}") -asyncio.run(main()) + final_fair_order = None + for peer in peers: + peer_final_order = peer.print_final_transaction_order() + peer_final_order_hash = hashlib.sha256(str(peer_final_order).encode()).hexdigest() + print(f"Peer {peer.peer_id}'s final txs order: {peer_final_order_hash} : {peer_final_order}, cur round = {peer.current_round}") + + if not peer.is_adversary: + if not final_fair_order: final_fair_order = peer_final_order + else: + # find the first index where final_order and final_fair_order differ + for i in range(min(len(peer_final_order), len(final_fair_order))): + if peer_final_order[i] != final_fair_order[i]: + final_fair_order = peer_final_order[:i] + break + + print(f"Global mempool size: {len(network.global_mempool)}") + print(f"Final fair order: {final_fair_order}") + print(f"Ordered txs: {len(final_fair_order)}/{len(network.global_mempool)} after {MIN_NUM_ROUNDS_OF_HONEST_PEERS} rounds at honest peers") + +asyncio.run(main(num_peers=4, MIN_NUM_ROUNDS_OF_HONEST_PEERS=20)) ### Possible attacks: # Long-Range Attacks: If validators controlling past checkpoints sell their keys, an attacker can re-sign an alternative history, leading to checkpoint reversals. 
# => Dangerous once attacker can control > 2/3 of the OLD validators # Majority Takeover: If an attacker gains control of 2/3 of the validators (BFT threshold), they could re-finalize a new chain with different checkpoints. # => recursive validity proof + proof of finality -# Solution: Post-Unstaking Slashing for X blocks after unstaking (but not able to withdraw before X blocks yet) - -# [] TODO: finish DAGPool's order fairness gadget \ No newline at end of file +# Solution: Post-Unstaking Slashing for X blocks after unstaking (but not able to withdraw before X blocks yet) \ No newline at end of file diff --git a/dagpool/graph.py b/dagpool/graph.py index a0207dd..8a229b3 100644 --- a/dagpool/graph.py +++ b/dagpool/graph.py @@ -15,12 +15,15 @@ from typing import Set, Dict, List, Tuple from schemas import GraphNodeId, HashValue +import hashlib +import random class DirectedGraph: def __init__(self): - self.nodes: Set[GraphNodeId] = {} + self.nodes: Set[GraphNodeId] = set() self.edges: Dict[GraphNodeId, Set[GraphNodeId]] = {} - self.is_tournament_graph = None + self.flag_is_tournament_graph = None + self.flag_is_semi_complete_digraph = None self.connected_components: List[Tuple[GraphNodeId, HashValue]] = None self.node_to_scc_hash: Dict[GraphNodeId, HashValue] = None self.hash_to_scc_nodes: Dict[HashValue, List[GraphNodeId]] = None @@ -29,7 +32,8 @@ def count_nodes(self) -> int: return len(self.nodes) def reset_graph_properties(self): - self.is_tournament_graph = None + self.flag_is_tournament_graph = None + self.flag_is_semi_complete_digraph = None self.connected_components = None self.node_to_scc_hash = None self.hash_to_scc_nodes = None @@ -53,10 +57,23 @@ def has_edge(self, nodeId1: GraphNodeId, nodeId2: GraphNodeId) -> bool: return nodeId1 in self.edges and nodeId2 in self.edges[nodeId1] def assert_is_tournament_graph(self): - if self.is_tournament_graph is None: - self.is_tournament_graph = self.is_tournament_graph() - assert self.is_tournament_graph + if 
self.flag_is_tournament_graph is None: + self.flag_is_tournament_graph = self.is_tournament_graph() + assert self.flag_is_tournament_graph + def assert_is_semi_complete_digraph(self): + if self.flag_is_semi_complete_digraph is None: + self.flag_is_semi_complete_digraph = self.is_semi_complete_digraph() + assert self.flag_is_semi_complete_digraph + + def is_semi_complete_digraph(self) -> bool: + for node1 in self.nodes: + for node2 in self.nodes: + if node1 != node2: + if not self.has_edge(node1, node2) and not self.has_edge(node2, node1): + return False + return True + def is_tournament_graph(self) -> bool: for node1 in self.nodes: for node2 in self.nodes: @@ -66,6 +83,14 @@ def is_tournament_graph(self) -> bool: return False return True + def print_all_edges(self): + list_edges: List[Tuple[GraphNodeId, GraphNodeId]] = [] + for node1 in self.nodes: + for node2 in self.nodes: + if node1 != node2: + if self.has_edge(node1, node2): + list_edges.append((node1, node2)) + """ Find the strongly connected components of the graph and return them in the topological order of the SCCs """ @@ -77,10 +102,11 @@ def find_strongly_connected_components(self) -> List[Tuple[List[GraphNodeId], Ha index = 0 indices = {} lowlinks = {} - stack = [] + stack: List[GraphNodeId] = [] components: List[Tuple[List[GraphNodeId], HashValue]] = [] def strongconnect(nodeId: GraphNodeId) -> None: + nonlocal index assert nodeId not in indices # Set the depth index for node indices[nodeId] = index @@ -89,7 +115,7 @@ def strongconnect(nodeId: GraphNodeId) -> None: stack.append(nodeId) # Consider successors of node - for successorId in self.edges[nodeId]: + for successorId in sorted(list(self.edges[nodeId])): if successorId not in indices: # Successor has not yet been visited; recurse on it strongconnect(successorId) @@ -103,40 +129,45 @@ def strongconnect(nodeId: GraphNodeId) -> None: while True: vertexId = stack.pop() component.append(vertexId) + indices[vertexId] = float('inf') + lowlinks[vertexId] = 
float('inf') if vertexId == nodeId: break - components.append([component, hash(tuple(sorted(component)))]) + components.append([component, hashlib.sha256(str(sorted(component)).encode()).hexdigest()[:8]]) # Find SCCs for all nodes for nodeId in self.nodes: if nodeId not in indices: strongconnect(nodeId) - - self.node_to_scc_hash: Dict[GraphNodeId, HashValue] = {nodeId: component[1] for component in enumerate(components) for nodeId in component[0]} - self.hash_to_scc_nodes: Dict[HashValue, List[GraphNodeId]] = {component[1]: component[0] for component in components} + + self.node_to_scc_hash: Dict[GraphNodeId, HashValue] = {nodeId: component[1] for i, component in enumerate(components) for nodeId in component[0]} + self.hash_to_scc_nodes: Dict[HashValue, List[GraphNodeId]] = {component[1]: component for component in components} self.connected_components: List[Tuple[List[GraphNodeId], HashValue]] = [] visited_sccs = set() def sort_sccs(comp: Tuple[List[GraphNodeId], HashValue]): - assert comp[1] not in visited_sccs - self.connected_components.append(comp) + if comp[1] in visited_sccs: + return + visited_sccs.add(comp[1]) for u in comp[0]: if u in self.edges: - for v in self.edges[u]: + for v in sorted(list(self.edges[u])): if not self.node_to_scc_hash[v] in visited_sccs: sort_sccs(self.hash_to_scc_nodes[self.node_to_scc_hash[v]]) - visited_sccs.add(comp[1]) + self.connected_components.append(comp) for component in components: if component[1] not in visited_sccs: sort_sccs(component) - + + self.connected_components.reverse() + return self.connected_components def assert_is_strongly_connected_component(self, hamiltonian_path: List[GraphNodeId]): assert self.connected_components is not None - hash_value = hash(tuple(sorted(hamiltonian_path))) + hash_value = hashlib.sha256(str(sorted(hamiltonian_path)).encode()).hexdigest()[:8] assert hash_value in [connected_component[1] for connected_component in self.connected_components] def assert_is_hamiltonian_path(self, 
hamiltonian_path: List[GraphNodeId]): @@ -147,36 +178,51 @@ class TournamentGraph(DirectedGraph): def __init__(self): super().__init__() - def is_tournament_graph(self) -> bool: - return super().is_tournament_graph() + def assert_is_valid_graph(self): + assert self.is_tournament_graph() def find_hamiltonian_path(self, scc: List[GraphNodeId]) -> List[GraphNodeId]: - self.assert_is_tournament_graph() + self.assert_is_valid_graph() # must be a strongly connected component self.assert_is_strongly_connected_component(scc) + # deterministic shuffling to prevent censorship + seed = hashlib.sha256(str(sorted(scc)).encode()).hexdigest() + randomer = random.Random(seed) + shuffled_scc = scc.copy() + randomer.shuffle(shuffled_scc) + # Complexity: O(len(scc)^2) - hamiltonian_path: List[GraphNodeId] = [scc[0]] - for k in range(1, len(scc)): - # find first i < k | has_edge(hamiltonian_path[i], scc[k]) & has_edge(scc[k], hamiltonian_path[i+1]) - i = 0 - while i + 1 < k and not (self.has_edge(hamiltonian_path[i], scc[k]) and self.has_edge(scc[k], hamiltonian_path[i+1])): - i += 1 - - hamiltonian_path.insert(i + 1, scc[k]) + hamiltonian_path: List[GraphNodeId] = [shuffled_scc[0]] + for k in range(1, len(shuffled_scc)): + # find first i < k | has_edge(hamiltonian_path[i], shuffled_scc[k]) & has_edge(shuffled_scc[k], hamiltonian_path[i+1]) + if self.has_edge(shuffled_scc[k], hamiltonian_path[0]): + hamiltonian_path.insert(0, shuffled_scc[k]) + elif self.has_edge(hamiltonian_path[-1], shuffled_scc[k]): + hamiltonian_path.append(shuffled_scc[k]) + else: + i = 0 + while i + 1 < k and not (self.has_edge(hamiltonian_path[i], shuffled_scc[k]) and self.has_edge(shuffled_scc[k], hamiltonian_path[i+1])): + i += 1 + + assert i + 1 < k # must found such a position + hamiltonian_path.insert(i + 1, shuffled_scc[k]) + + for i in range(len(hamiltonian_path) - 1): + assert self.has_edge(hamiltonian_path[i], hamiltonian_path[i + 1]) return hamiltonian_path def find_hamiltonian_cycle(self, 
hamiltonian_path_of_scc: List[GraphNodeId]) -> List[GraphNodeId]: - self.assert_is_tournament_graph() + self.assert_is_valid_graph() self.assert_is_strongly_connected_component(hamiltonian_path_of_scc) self.assert_is_hamiltonian_path(hamiltonian_path_of_scc) # Complexity: O(len(hamiltonian_path)^2) accumulated_hamiltonian_cycle: List[GraphNodeId] = [hamiltonian_path_of_scc[0]] - j = 1 + j = 1 # next element to be added to the cycle while j < len(hamiltonian_path_of_scc): - p = j + 1 + p = j r = -1 found_backward_edge = False while not found_backward_edge and p < len(hamiltonian_path_of_scc): @@ -187,13 +233,28 @@ def find_hamiltonian_cycle(self, hamiltonian_path_of_scc: List[GraphNodeId]) -> if r < len(accumulated_hamiltonian_cycle): found_backward_edge = True + else: + p += 1 if not found_backward_edge: # If no backward edge is found, the graph is not a tournament or not strongly connected raise ValueError("No Hamiltonian cycle exists for the given path.") - # reorder the accumulated_hamiltonian_cycle: accumulated_hamiltonian_cycle[0 -> r - 1] -> hamiltonian_path_of_scc[j+1 -> p] -> accumulated_hamiltonian_cycle[r -> ...] 
-> hamiltonian_path_of_scc[0] - j = p - accumulated_hamiltonian_cycle = accumulated_hamiltonian_cycle[0:r] + hamiltonian_path_of_scc[j+1:p+1] + accumulated_hamiltonian_cycle[r:] + accumulated_hamiltonian_cycle = accumulated_hamiltonian_cycle[0:r] + hamiltonian_path_of_scc[j:p+1] + accumulated_hamiltonian_cycle[r:] + k = j + j = p + 1 + + cycle_size = len(accumulated_hamiltonian_cycle) + for i in range(cycle_size): + # if not self.has_edge(accumulated_hamiltonian_cycle[i], accumulated_hamiltonian_cycle[(i + 1) % cycle_size]): + # print(f"DEBUG: {accumulated_hamiltonian_cycle} vs {hamiltonian_path_of_scc}") + assert self.has_edge(accumulated_hamiltonian_cycle[i], accumulated_hamiltonian_cycle[(i + 1) % cycle_size]) + + return accumulated_hamiltonian_cycle + +class SemiCompleteDiGraph(TournamentGraph): + def __init__(self): + super().__init__() - return accumulated_hamiltonian_cycle \ No newline at end of file + def assert_is_valid_graph(self): + assert self.is_semi_complete_digraph() \ No newline at end of file From 45129e6c91a8cdd168a8f42b868e52dcf1fb9175 Mon Sep 17 00:00:00 2001 From: Galin Chung Nguyen Date: Sat, 15 Mar 2025 09:05:37 +0700 Subject: [PATCH 10/10] feat: update consistency logic when nodes see forks but still absorb them --- dagpool/consensus_client.py | 344 +++++++++++++++++++++++++----------- 1 file changed, 236 insertions(+), 108 deletions(-) diff --git a/dagpool/consensus_client.py b/dagpool/consensus_client.py index 2ac6c5d..61a6170 100644 --- a/dagpool/consensus_client.py +++ b/dagpool/consensus_client.py @@ -1,4 +1,3 @@ -# self_parent_hash=fc7983b5, cross_parent_hash=223e2386 import asyncio import sys import random @@ -21,23 +20,23 @@ original_print = builtins.print SHOULD_PRINT = True -def print(*args, **kwargs): - if SHOULD_PRINT: - original_print(*args, **kwargs) +def print(msg, force=False): + if SHOULD_PRINT or force: + original_print(msg) printed = set() +with open("debug.txt", "w") as f: + f.write("") + def ddebug(peer_id: PeerId, 
*args, **kwargs): - if SHOULD_PRINT: - if len(printed) <= 0: - with open("debug.txt", "w") as f: - f.write("") + # if SHOULD_PRINT: # write to file debug.txt with open("debug.txt", "a") as f: line = f"{peer_id}: {' '.join([str(arg) for arg in args])}\n" if line not in printed: f.write(line) - printed.add(line) + # printed.add(line) BEACON_PACE = 10 # the first BEACON_PACE rounds are derived directly from the list of peers @@ -87,7 +86,12 @@ def __init__(self, round_number: int, prev_batch_hash: HashValue, deferred_order self.prev_beacon_randomness = prev_beacon_randomness self.next_beacon_randomness = self.compute_next_beacon_randomness() self.batch_hash = self.compute_batch_hash() + # supporting data, can be locally inferred and don't have to be transferred over the network self.non_solid_txs: List[TransactionId] = [] + self.nodes_in_cone: Set[NodeId] = set() + + def store_nodes_in_cone(self, nodes: Set[NodeId]): + self.nodes_in_cone = nodes def store_non_solid_txs(self, txs_list: List[TransactionId]): self.non_solid_txs = txs_list # store all the non solid transactions that aren't included in all the batch up to it @@ -118,7 +122,7 @@ def clone(self): return BatchProposal(round_number=self.round_number, prev_batch_hash=self.prev_batch_hash, deferred_ordering=[tx for tx in self.deferred_ordering], prev_beacon_randomness=self.prev_beacon_randomness, solid_transactions=[tx for tx in self.solid_transactions]) def __str__(self): - return f"BatchProposal(batch_hash={self.batch_hash}, prev_batch_hash={self.prev_batch_hash}, prev_beacon_randomness={self.prev_beacon_randomness}, deferred_ordering={self.deferred_ordering}, next_beacon_randomness={self.next_beacon_randomness}, solid_transactions={self.solid_transactions})" + return f"BatchProposal(batch_hash={self.batch_hash}, prev_batch_hash={self.prev_batch_hash}, prev_beacon_randomness={self.prev_beacon_randomness}, deferred_ordering={self.deferred_ordering}, next_beacon_randomness={self.next_beacon_randomness}, 
solid_transactions={self.solid_transactions}, non_solid_txs={self.non_solid_txs})" class NodeMetadata: batch_proposal: BatchProposal # can be None if the node is not a head node (head node means the witness of the leader in its selected round) @@ -134,15 +138,16 @@ def clone(self): return NodeMetadata(self.batch_proposal.clone() if self.batch_proposal else None, Vote(self.vote.head_batch_hash), self.creator_signature) class Node: - def __init__(self, peer_id: PeerId, round: int, is_witness: bool, newly_seen_txs_list: list[TransactionId], self_parent_hash: NodeId, cross_parent_hash: NodeId, metadata: NodeMetadata): + def __init__(self, peer_id: PeerId, height: int, round: int, is_witness: bool, newly_seen_txs_list: list[TransactionId], self_parent_hash: NodeId, cross_parent_hash: NodeId, metadata: NodeMetadata): self.peer_id = peer_id + self.height = height self.is_witness = is_witness self.round = round self.self_parent_hash = self_parent_hash self.cross_parent_hash = cross_parent_hash # TODO: migrate to Sparse Merkle Tree + proof of SMT transition self.newly_seen_txs_list = newly_seen_txs_list - self.node_hash = self.hash_node(peer_id, round, is_witness, self_parent_hash, cross_parent_hash, newly_seen_txs_list) + self.node_hash = self.hash_node(peer_id, height, round, is_witness, self_parent_hash, cross_parent_hash, newly_seen_txs_list) ## fork-related data, must all be None until computed self.equivocated_peers: Set[PeerId] = None # set of peers that current node believes are equivocated, and this node won't SEE (i.e. UNSEE) all nodes created by them. Note that this doesn't affect STRONGLY SEEING property of this node. self.non_equivocated_peers: Set[PeerId] = None # set of peers that current node believes are not equivocated, and this node will SEE all nodes created by them. 
@@ -156,7 +161,7 @@ def __init__(self, peer_id: PeerId, round: int, is_witness: bool, newly_seen_txs self.has_filled_node_data = False def clone(self): - return Node(self.peer_id, self.round, self.is_witness, [txs for txs in self.newly_seen_txs_list], self.self_parent_hash, self.cross_parent_hash, self.metadata.clone()) + return Node(self.peer_id, self.height, self.round, self.is_witness, [txs for txs in self.newly_seen_txs_list], self.self_parent_hash, self.cross_parent_hash, self.metadata.clone()) def label(self) -> NodeLabel: return f"{self.peer_id}:{self.node_hash}" @@ -181,9 +186,9 @@ def verify_signature(self, creator_pubkey: Pubkey) -> bool: return self.metadata.creator_signature == self.compute_signature() @staticmethod - def hash_node(creator: PeerId, round: int, is_witness: bool, self_parent_hash: NodeId, cross_parent_hash: NodeId, newly_seen_txs_list: list[TransactionId]) -> NodeId: + def hash_node(creator: PeerId, height: int, round: int, is_witness: bool, self_parent_hash: NodeId, cross_parent_hash: NodeId, newly_seen_txs_list: list[TransactionId]) -> NodeId: """Create deterministic hash for a node""" - components = [creator, str(round), str(is_witness)] + components = [creator, str(height), str(round), str(is_witness)] if cross_parent_hash: components.append(cross_parent_hash) if self_parent_hash: @@ -193,7 +198,7 @@ def hash_node(creator: PeerId, round: int, is_witness: bool, self_parent_hash: N return hashlib.sha256(''.join(components).encode()).hexdigest()[:8] # the hash value of a node basically depends deterministically on all of its content def verify_node_hash(self) -> bool: - return self.node_hash == Node.hash_node(self.peer_id, self.round, self.is_witness, self.self_parent_hash, self.cross_parent_hash, self.newly_seen_txs_list) + return self.node_hash == Node.hash_node(self.peer_id, self.height, self.round, self.is_witness, self.self_parent_hash, self.cross_parent_hash, self.newly_seen_txs_list) def validate_node_data(self, creator_pubkey: 
Pubkey) -> bool: if not self.verify_node_hash(): @@ -226,7 +231,7 @@ def get_seen_valid_peers(self) -> Set[PeerId]: def __str__(self): batch_info = ("batch_hash=" + self.metadata.batch_proposal.__str__() if self.metadata.batch_proposal else "") vote_info = ("vote=" + self.metadata.vote.__str__() if self.metadata.vote else "none") - return f"{"GENESIS " if self.is_genesis() else ""}Node(node_hash={self.node_hash}, peer_id={self.peer_id}, round={self.round}, is_witness={self.is_witness}, self_parent_hash={self.self_parent_hash}, cross_parent_hash={self.cross_parent_hash}, newly_seen_txs_list={self.newly_seen_txs_list}, {batch_info}, {vote_info})" # , equivocated_peers={self.equivocated_peers}, seen_nodes={self.seen_nodes})" + return f"{"GENESIS " if self.is_genesis() else ""}Node(node_hash={self.node_hash}, peer_id={self.peer_id}, height={self.height}, round={self.round}, is_witness={self.is_witness}, self_parent_hash={self.self_parent_hash}, cross_parent_hash={self.cross_parent_hash}, newly_seen_txs_list={self.newly_seen_txs_list}, {batch_info}, {vote_info})" # , equivocated_peers={self.equivocated_peers}, seen_nodes={self.seen_nodes})" class ConnectionState(Enum): CLOSED = 0 @@ -267,6 +272,16 @@ def __init__(self, latency_ms_range=(50, 200), packet_loss_prob=0.1, random_inst # global mempool: simulate a global mempool of all transactions from all clients self.global_mempool = set() + self.should_exit = False + + def N(self) -> int: + return len(self.peers) + + def f(self) -> int: + return ((self.N() - 1) // 3) + + def safe_threshold(self) -> int: + return (self.N() - self.f()) # TODO: use secure cryptographic randomness source def get_first_beacon_randomness(self) -> HashValue: @@ -314,7 +329,7 @@ async def register_genesis_nodes(self): genesis_nodes[peer1.peer_id] = genesis_node for peer2 in self.peers: - if peer2.peer_id == peer1.peer_id: + if peer2.peer_id == peer1.peer_id or peer1.peer_id in peer2.equivocated_peers or peer2.peer_id in peer1.equivocated_peers: 
continue cloned_genesis_node = genesis_node.clone() # simulate the process of serializing and deserializing the nodes in internet protocols @@ -422,6 +437,9 @@ async def gossip_send_node_and_ancestry(self, sender: PeerId, receiver: PeerId, sender_peer = [p for p in self.peers if p.peer_id == sender][0] receiver_peer = [p for p in self.peers if p.peer_id == receiver][0] + if sender in receiver_peer.equivocated_peers: + return False + if receiver_peer.has_seen_valid_node(node1): return True @@ -444,12 +462,20 @@ def trace(node: Node): for i in range(len(all_received_nodes)): current_node = all_received_nodes[i] try: - receiver_peer.verify_node_and_add_to_local_view(current_node.clone()) or (_ for _ in ()).throw(ValueError("failed to verify and add node")) + is_success = receiver_peer.verify_node_and_add_to_local_view(current_node.clone(), sender=sender_peer.peer_id) or (_ for _ in ()).throw(ValueError("failed to verify and add node")) + (is_success == True) or (_ for _ in ()).throw(ValueError("failed to verify and add node")) except Exception as e: - # print(f"Peer {receiver_peer.peer_id} rejected node {current_node.node_hash} from {sender_peer.peer_id} created by {current_node.peer_id}: {e}") + # if sender_peer.peer_id not in receiver_peer.equivocated_peers: + ddebug(receiver_peer.peer_id, f"(R={receiver_peer.peer_id},S={sender_peer.peer_id}) rejected node {current_node.node_hash} created by {current_node.peer_id}: {e}") + isWrong = not sender_peer.is_adversary and not receiver_peer.is_adversary + if isWrong: + ddebug(receiver_peer.peer_id, f"e = {e}") + ddebug(receiver_peer.peer_id, f"INVALID REJECTION of {current_node.node_hash} with (sender={sender_peer.peer_id}, receiver={receiver_peer.peer_id})") + self.should_exit = True + break continue - # print(f"Peer {receiver_peer.peer_id} accepted node {current_node.node_hash} from {sender_peer.peer_id} created by {current_node.peer_id}") + print(f"(R={receiver_peer.peer_id},S={sender_peer.peer_id}) accepted node 
{current_node.node_hash} created by {current_node.peer_id}, {current_node}") return receiver_peer.has_seen_valid_node(node1) @@ -494,6 +520,7 @@ def __init__(self, peer_id: PeerId, is_adversary: bool, seed: int, network: Netw self.observed_valid_batches: Dict[HashValue, Node] = {} # map from the hash value of the batch to the node that proposes it self.first_inclusion_of_txs_at_peer: Dict[PeerId, Dict[TransactionId, NodeId]] = {} # map from peer_id to map from tx_id to the first time it is included at that peer self.cached_heaviest_batch_amongst_strict_ancestors: Dict[NodeId, HashValue] = {} + self.cached_head_node: Dict[NodeId, bool] = {} def my_nodes(self) -> List[Node]: return self.seen_valid_nodes[self.peer_id] @@ -682,6 +709,7 @@ def create_genesis_node(self): """Create a genesis/bootstrap node""" node = Node( peer_id=self.peer_id, + height=0, round=0, is_witness=True, newly_seen_txs_list=[], @@ -689,7 +717,7 @@ def create_genesis_node(self): cross_parent_hash=EMPTY_NODE_HASH, metadata=NodeMetadata() ) - (self.verify_node_and_add_to_local_view(node) == True) or (_ for _ in ()).throw(ValueError("failed to verify and add genesis node")) + (self.verify_node_and_add_to_local_view(node, sender=self.peer_id) == True) or (_ for _ in ()).throw(ValueError("failed to verify and add genesis node")) return node def get_heaviest_batch_amongst_strict_ancestors(self, node: Node) -> HashValue: @@ -759,6 +787,10 @@ def verify_node_is_head_node(self, node: Node) -> bool: """ Verify that the given node is a head node, given that the vote is valid """ + + if node.node_hash in self.cached_head_node: + return self.cached_head_node[node.node_hash] + if not node.is_witness or node.is_genesis(): # this function might be called before the batch proposal is constructed so we only needs to check whether the node is a witness return False @@ -774,10 +806,20 @@ def verify_node_is_head_node(self, node: Node) -> bool: # find parent batch of the current beacon batch beacon_batch = 
self.observed_valid_batches[beacon_batch].metadata.batch_proposal.prev_batch_hash - beacon_randomness = str(0 if beacon_batch == GENESIS_BATCH else self.get_beacon_randomness_of_batch(beacon_batch)) + ":" + str(node.round) - beacon_randomer = random.Random(beacon_randomness) - all_peers = [p.peer_id for p in self.network.peers] - return beacon_randomer.choice(all_peers) == node.peer_id + + beacon_randomness = f"{0 if beacon_batch == GENESIS_BATCH else self.get_beacon_randomness_of_batch(beacon_batch)}:{node.round}" + hash_to_big_int = int(hashlib.sha256(beacon_randomness.encode()).hexdigest(), 16) + + all_peers = sorted([p.peer_id for p in self.network.peers]) + selected_peer = all_peers[hash_to_big_int % len(all_peers)] + + if node.round >= 1: + if selected_peer == node.peer_id: + if self.peer_id == node.peer_id: + ddebug(self.peer_id, f" {node.node_hash} found leader of round {node.round}: beacon_batch = {beacon_batch}, beacon_randomness = {beacon_randomness} => {selected_peer}") + + self.cached_head_node[node.node_hash] = selected_peer == node.peer_id + return self.cached_head_node[node.node_hash] # compute the leader of the current round based on the beacon randomness @@ -795,7 +837,7 @@ def is_seen_by(self, node_id: NodeId, dest_node: Node) -> bool: if dest_node.node_hash == node_id: return True - + for mid_node_hash in dest_node.latest_seen_node_by_peers.values(): if self.is_valid_descendant_and_self_ancestor(mid_node_hash, node_id): return True @@ -909,16 +951,13 @@ def construct_semi_complete_digraph_of_transactions(self, last_3_head_witnesses: return dependency_graph - def get_solid_transactions_and_non_solid_transactions_in_cone(self, head_node: Node, prev_batch_hash: HashValue) -> Tuple[List[TransactionId], List[TransactionId]]: + def get_solid_transactions_and_non_solid_transactions_in_cone(self, head_node: Node, prev_batch_hash: HashValue) -> Tuple[List[TransactionId], List[TransactionId], List[NodeId]]: """ Get the solid transactions in the truncated 
cone of the given head node and previous batch hash """ traced: Dict[NodeId, bool] = {} current_cone: List[Node] = [] - # TODO: gather transactions from previous cones that aren't included as solid transactions - # => Not only transactions in the current truncated cone - prev_head_witness = self.get_node_of_batch(prev_batch_hash) unincluded_solid_txs_from_prev_batch: Set[TransactionId] = set() @@ -926,17 +965,24 @@ def get_solid_transactions_and_non_solid_transactions_in_cone(self, head_node: N if prev_head_witness: unincluded_solid_txs_from_prev_batch.update(prev_head_witness.metadata.batch_proposal.non_solid_txs) # it now becomes solid due to strongly-seeing property (i.e. n - f peers each of these transactions) + # TODO: use more lightweight data structure here + all_prev_solid_txs: Set[TransactionId] = set() + all_prev_cone_nodes: Set[NodeId] = set() + + cur_batch = prev_batch_hash + while cur_batch != GENESIS_BATCH: + cur_head_witness = self.observed_valid_batches[cur_batch] + all_prev_solid_txs.update(cur_head_witness.metadata.batch_proposal.solid_transactions) + all_prev_cone_nodes.update(cur_head_witness.metadata.batch_proposal.nodes_in_cone) + cur_batch = cur_head_witness.metadata.batch_proposal.prev_batch_hash + def find_node_in_current_cone(node: Node): if node.node_hash in traced: return - in_prev_cone_ancestry: bool = False if not prev_head_witness else self.is_seen_by(node.node_hash, prev_head_witness) - in_current_cone: bool = node.node_hash == head_node.node_hash or (self.is_seen_by(node.node_hash, head_node) and not in_prev_cone_ancestry) - - if not in_current_cone: - traced[node.node_hash] = False + if node.node_hash in all_prev_cone_nodes: return - + traced[node.node_hash] = True predecessors = self.get_predecessors(node) @@ -946,13 +992,16 @@ def find_node_in_current_cone(node: Node): current_cone.append(node) find_node_in_current_cone(head_node) - + # filter & keep only the transactions that aren't included in previous batches and received by n 
- f legitimate peers count_legitimate_receipts_of_tx: Dict[TransactionId, int] = {} for node in current_cone: + if node.peer_id in head_node.equivocated_peers: # ignore opinions of equivocated peers + # TODO: only ignore the suffix from which forks are detected + continue + for tx in node.newly_seen_txs_list: - if node.peer_id not in head_node.equivocated_peers: - count_legitimate_receipts_of_tx[tx] = count_legitimate_receipts_of_tx.get(tx, 0) + 1 + count_legitimate_receipts_of_tx[tx] = count_legitimate_receipts_of_tx.get(tx, 0) + 1 n = len(self.network.get_all_peer_ids()) f = (n - 1)/3 @@ -960,14 +1009,6 @@ def find_node_in_current_cone(node: Node): new_candidate_solid_transactions: Set[TransactionId] = set() new_pending_txs: List[TransactionId] = [] - # TODO: use more lightweight data structure here - all_prev_solid_txs: Set[TransactionId] = set() - cur_batch = prev_batch_hash - while cur_batch != GENESIS_BATCH: - cur_head_witness = self.observed_valid_batches[cur_batch] - all_prev_solid_txs.update(cur_head_witness.metadata.batch_proposal.solid_transactions) - cur_batch = cur_head_witness.metadata.batch_proposal.prev_batch_hash - for tx in count_legitimate_receipts_of_tx: if tx not in all_prev_solid_txs and tx not in unincluded_solid_txs_from_prev_batch: if count_legitimate_receipts_of_tx[tx] >= n - f: @@ -978,7 +1019,7 @@ def find_node_in_current_cone(node: Node): solid_transactions: List[TransactionId] = sorted(list(new_candidate_solid_transactions.union(unincluded_solid_txs_from_prev_batch))) pending_txs: List[TransactionId] = sorted(new_pending_txs) - return solid_transactions, pending_txs + return solid_transactions, pending_txs, [node.node_hash for node in current_cone] def get_deferred_ordering_of_transactions(self, last_3_head_witnesses: List[Node]) -> List[TransactionId]: """ @@ -1023,7 +1064,7 @@ def compute_batch_proposal(self, dest_node: Node) -> BatchProposal: # construct the truncated cone of round (r) deferred_ordering: List[TransactionId] = [] - 
solid_transactions, non_solid_transactions = self.get_solid_transactions_and_non_solid_transactions_in_cone(dest_node, heaviest_batch) + solid_transactions, non_solid_transactions, nodes_in_cone = self.get_solid_transactions_and_non_solid_transactions_in_cone(dest_node, heaviest_batch) if len(last_3_head_witnesses) >= 3: deferred_ordering = self.get_deferred_ordering_of_transactions(last_3_head_witnesses) @@ -1037,6 +1078,7 @@ def compute_batch_proposal(self, dest_node: Node) -> BatchProposal: ) batch_proposal.store_non_solid_txs(non_solid_transactions) + batch_proposal.store_nodes_in_cone(nodes_in_cone) (batch_proposal.verify_batch_proposal_is_well_formed()) or (_ for _ in ()).throw(ValueError("batch proposal is not well-formed")) @@ -1126,19 +1168,40 @@ def select_neighbors(self, all_peers: List[PeerId]): self.neighbors = [peer_id for peer_id in self.random_instance.sample(potential_neighbors, num_neighbors) if peer_id not in self.equivocated_peers] def is_valid_descendant_and_self_ancestor(self, descendant_node_hash: NodeId, self_ancestor_node_hash: NodeId) -> bool: - try: - (peer1, pos1) = self.pos_in_seen_valid_nodes[descendant_node_hash] - (peer2, pos2) = self.pos_in_seen_valid_nodes[self_ancestor_node_hash] - return peer1 == peer2 and pos1 >= pos2 - except Exception as e: - return False + ancestor_node = self.get_node_by_hash(self_ancestor_node_hash) + descendant_node = self.get_node_by_hash(descendant_node_hash) + if ancestor_node is None or descendant_node is None: + return False + + # TODO: optimize this to jump bigger steps for faster descendant check + while descendant_node is not None and descendant_node.round >= ancestor_node.round: + if descendant_node.node_hash == ancestor_node.node_hash: + return True + descendant_node = self.get_node_by_hash(descendant_node.self_parent_hash) + return False def get_self_descendant(self, node_hash_1: NodeId, node_hash_2: NodeId) -> Optional[NodeId]: try: - (peer1, pos1) = self.pos_in_seen_valid_nodes[node_hash_1] - 
(peer2, pos2) = self.pos_in_seen_valid_nodes[node_hash_2] - if peer1 == peer2: - return node_hash_1 if pos1 > pos2 else node_hash_2 + node1 = self.get_node_by_hash(node_hash_1) + node2 = self.get_node_by_hash(node_hash_2) + + if not node1 or not node2 or node1.peer_id != node2.peer_id: + return None + + res = node1 if node1.height > node2.height else node2 + + while node1.height > node2.height: + node1 = self.get_node_by_hash(node1.self_parent_hash) + if not node1: + return None + + while node2.height > node1.height: + node2 = self.get_node_by_hash(node2.self_parent_hash) + if not node2: + return None + + if node1.node_hash == node2.node_hash: + return res.node_hash else: return None except: @@ -1207,8 +1270,10 @@ def aggregate_latest_seen_by_peers( seen_votes_by_peers.update(parent_node.seen_votes_by_peers) # the vote inside the dest_node would be updated later when verifying the vote of the dest_node for peer_id in equivocated_peers: - latest_seen_node_by_peers.pop(peer_id) - latest_seen_witness_by_peers.pop(peer_id) + if peer_id in latest_seen_node_by_peers: + latest_seen_node_by_peers.pop(peer_id) + if peer_id in latest_seen_witness_by_peers: + latest_seen_witness_by_peers.pop(peer_id) for peer_id in latest_seen_node_by_peers.keys(): non_equivocated_peers.add(peer_id) @@ -1224,20 +1289,35 @@ def aggregate_latest_seen_by_peers( ### self.equivocated_peers.update(dest_node.equivocated_peers) - def verify_node_and_add_to_local_view(self, node: Node = None) -> bool: + def verify_node_and_add_to_local_view(self, node: Node, sender: PeerId = None) -> bool: """Verify a node and its transactions, and add it to the local view""" + ### Verification + if sender in self.equivocated_peers and sender != self.peer_id: + return False + + assert node is not None if self.has_seen_valid_node(node): return True self.fill_node_data(node) + + is_valid = self.verify_node(node) - if not self.verify_node(node): + if not is_valid: return False - + if node.is_witness and 
node.metadata.batch_proposal is not None: print(f"Peer {self.peer_id} confirms node ({node.node_hash}, {node.peer_id}) is head node of round {node.round}") + for tx in node.newly_seen_txs_list: + if tx in self.first_inclusion_of_txs_at_peer[node.peer_id]: # node.peer_id must be equivocated + self.equivocated_peers.add(node.peer_id) + if sender == node.peer_id and sender != self.peer_id: + return False + + ##################################################################################################################### + ### Update local view if node.peer_id not in self.seen_valid_nodes: self.seen_valid_nodes[node.peer_id] = [] self.seen_valid_nodes[node.peer_id].append(node) @@ -1247,11 +1327,6 @@ def verify_node_and_add_to_local_view(self, node: Node = None) -> bool: if node.peer_id not in self.first_inclusion_of_txs_at_peer: self.first_inclusion_of_txs_at_peer[node.peer_id] = {} - for tx in node.newly_seen_txs_list: - if tx in self.first_inclusion_of_txs_at_peer[node.peer_id]: # node.peer_id must be equivocated - return False - self.first_inclusion_of_txs_at_peer[node.peer_id][tx] = node.node_hash - # make a node sees itself so its descendants can use these accumulated values node.latest_seen_node_by_peers[node.peer_id] = node.node_hash if node.is_witness: @@ -1271,6 +1346,24 @@ def verify_node_and_add_to_local_view(self, node: Node = None) -> bool: # print(f"Peer {self.peer_id} added node {node.node_hash} to its local view => new round = {self.current_round}") + for tx in node.newly_seen_txs_list: + self.first_inclusion_of_txs_at_peer[node.peer_id][tx] = node.node_hash + + # # check consistency + # for other_peer in self.network.peers: + # if other_peer.peer_id == self.peer_id: + # continue + + # print(f"start cons check for {other_peer.peer_id} vs {self.peer_id}") + + # if other_peer.has_seen_valid_node(node): + # l1 = sorted(other_peer.get_node_by_hash(node.node_hash).latest_seen_node_by_peers.values()) + # l2 = 
sorted(self.get_node_by_hash(node.node_hash).latest_seen_node_by_peers.values()) + # if l1 != l2: + # ddebug(self.peer_id, f"## CONSISTENCY CHECK FAILED for node {node.node_hash}: {other_peer.peer_id} has {l1} while {self.peer_id} has {l2}") + # ddebug(self.peer_id, f"{self.peer_id} => {self.get_node_by_hash(node.node_hash)}, vs {other_peer.peer_id} => {other_peer.get_node_by_hash(node.node_hash)}") + # self.network.should_exit = True + return True def find_prev_witness_at_round(self, cur_witness: Node, r: int) -> Optional[Node]: @@ -1290,16 +1383,19 @@ def find_prev_witness_at_round(self, cur_witness: Node, r: int) -> Optional[Node return cur_witness def get_strongly_seen_valid_witnesses(self, dest_node: Node, r: int) -> list["Node"]: - ## check if this witness strongly sees > 2/3 of witnesses of r + ## check if this witness strongly sees >= N - f of witnesses of r ## if some witnesses are descendants of equivocated nodes, they are ignored completely ## NOTE: we already make sure the ancestry of dest_node is verified N = len(self.network.peers) - + latest_seen_witness_by_peers: Dict[PeerId, NodeId] = dest_node.latest_seen_witness_by_peers seen_witnesses_in_round_gte_r = [self.get_node_by_hash(node_hash) for node_hash in latest_seen_witness_by_peers.values() if node_hash and self.get_node_by_hash(node_hash).round >= r] seen_witnesses_in_round_r = [self.find_prev_witness_at_round(node, r) for node in seen_witnesses_in_round_gte_r if node is not None] + for witness in seen_witnesses_in_round_r: + assert witness.round == r + # keep only the strongly seen ones latest_seen_node_by_peers: Dict[PeerId, NodeId] = dest_node.latest_seen_node_by_peers @@ -1310,7 +1406,10 @@ def get_strongly_seen_valid_witnesses(self, dest_node: Node, r: int) -> list["No count_seens: Dict[NodeId, int] = {} # O(N^2) where N is the number of peers for mid_node in seen_nodes_in_round_gte_r: + if mid_node.peer_id in dest_node.equivocated_peers: + continue for witness in seen_witnesses_in_round_r: 
+ # we still accept witnesses from equivocated peers but don't count opinions from them # check if mid_node can strongly see witness witness_peer_id = witness.peer_id @@ -1327,10 +1426,10 @@ def get_strongly_seen_valid_witnesses(self, dest_node: Node, r: int) -> list["No if len(found_self_peer_witness) == 0: return [] - # keep only the witnesses that are seen by > 2/3 of the mid nodes + # keep only the witnesses that are seen by N - f of the mid nodes strongly_seen_witnesses: list["Node"] = [] for witness in seen_witnesses_in_round_r: - if count_seens.get(witness.node_hash, 0) > 2 * N / 3: + if count_seens.get(witness.node_hash, 0) >= self.network.safe_threshold(): strongly_seen_witnesses.append(witness) return strongly_seen_witnesses @@ -1338,8 +1437,8 @@ def get_strongly_seen_valid_witnesses(self, dest_node: Node, r: int) -> list["No def check_round_number_of_non_genesis_node_with_valid_parents(self, dest_node: Node) -> bool: """ if a node is of round r: - - it must not strongly sees > 2N/3 of witnesses of round r - - if its self parent is of round r, it is valid. if its self parent is of round r-1, it must strongly sees > 2N/3 of witnesses of round r-1 + - it must not strongly sees >= N - f of witnesses of round r + - if its self parent is of round r, it is valid. 
if its self parent is of round r-1, it must strongly sees >= N - f of witnesses of round r-1 """ N = len(self.network.peers) r = dest_node.round @@ -1348,17 +1447,19 @@ def check_round_number_of_non_genesis_node_with_valid_parents(self, dest_node: N (self_parent_node is not None) or (_ for _ in ()).throw(ValueError("self_parent_node must be non-None")) if self_parent_node.round < r - 1 or self_parent_node.round > r: return False + if self_parent_node.round == r - 1: - # the dest_node is a witness of round r so it must strongly sees > 2/3 of witnesses of round r-1 + # the dest_node is a witness of round r so it must strongly sees N - f of witnesses of round r-1 strongly_seen_witnesses_in_round_r_minus_1 = self.get_strongly_seen_valid_witnesses(dest_node, r-1) - is_witness_of_round_r = len(strongly_seen_witnesses_in_round_r_minus_1) > 2 * N / 3 + is_witness_of_round_r = len(strongly_seen_witnesses_in_round_r_minus_1) >= self.network.safe_threshold() if not is_witness_of_round_r: return False - # the dest_node must not strongly sees > 2/3 of witnesses of round r + + # the dest_node must not strongly sees >= N - f of witnesses of round r strongly_seen_witnesses_in_round_r = self.get_strongly_seen_valid_witnesses(dest_node, r) - if len(strongly_seen_witnesses_in_round_r) > 2 * N / 3: + if len(strongly_seen_witnesses_in_round_r) >= self.network.safe_threshold(): return False return True @@ -1374,17 +1475,24 @@ def verify_node(self, node: Node = None) -> bool: if node is None: return False - if node.peer_id in self.equivocated_peers: - return False + # if node.peer_id in self.equivocated_peers: + # return False + # still receives node from equivocated peers in case the sender is an honest peer # an honest peer must not accept a node which itself or its parents are from equivocated peers + try: if node.is_genesis(): return True parent_node = self.get_node_by_hash(node.self_parent_hash) # TODO: verify newly_seen_txs_list of the node (parent_node is not None) or (_ for _ in 
()).throw(ValueError("parent_node must be non-None")) - (parent_node.node_hash == self.seen_valid_nodes[parent_node.peer_id][-1].node_hash or (self.is_adversary and node.peer_id == self.peer_id)) or (_ for _ in ()).throw(ValueError("parent_node must be the last node of the sender")) + + valid_node_chain_extension = parent_node.node_hash == self.seen_valid_nodes[parent_node.peer_id][-1].node_hash + if not valid_node_chain_extension: + self.equivocated_peers.add(node.peer_id) + # TODO: log the equivocation activity here + if not node.validate_node_data(self.network.get_peer_pubkey(node.peer_id)): return False @@ -1404,7 +1512,6 @@ def verify_node(self, node: Node = None) -> bool: if node.is_witness: is_head_node = self.verify_node_is_head_node(node) - if is_head_node: if not self.verify_batch_proposal_is_valid(node): return False @@ -1417,7 +1524,10 @@ def verify_node(self, node: Node = None) -> bool: # adversary sending invalid nodes print("error = ", e) return False - return self.check_round_number_of_non_genesis_node_with_valid_parents(node) + + is_correct_round_number = self.check_round_number_of_non_genesis_node_with_valid_parents(node) + + return is_correct_round_number def get_all_transactions(self) -> Set[TransactionId]: """Get all transactions known to this peer""" @@ -1451,10 +1561,10 @@ def compute_new_node(self, self_parent: Node, cross_parent: Node) -> Node: # the newly seen list of txs in the new node must be not empty # TODO: sort this list by timestamp of receipt of the transactions newly_seen_txs_list: List[TransactionId] = list(self.calculate_newly_seen_txs_list_of_new_node(self_parent, cross_parent, self.pending_txs)) - - if len(newly_seen_txs_list) <= 0: - # can't extend the node sequence because there is no new txs, this is to save the network capacity - return None + + # if len(newly_seen_txs_list) <= 0: + # # can't extend the node sequence because there is no new txs, this is to save the network capacity + # return None round_num = 
len(self.my_nodes()) base_hash = f"{self.peer_id}{str(round_num).zfill(3)}" @@ -1467,6 +1577,7 @@ def compute_new_node(self, self_parent: Node, cross_parent: Node) -> Node: print(f"Peer {self.peer_id} computing new node for round {round_num}") new_node = Node( peer_id=self.peer_id, + height=self_parent.height + 1, round=round_num, is_witness=False if round_num == self_parent.round else True, newly_seen_txs_list=newly_seen_txs_list, @@ -1480,11 +1591,10 @@ def compute_new_node(self, self_parent: Node, cross_parent: Node) -> Node: is_node_valid = self.verify_node(new_node) if is_node_valid: - # print(f"Peer {self.peer_id} COMPUTED NEW {'HEAD' if new_node.is_head_node() else "NON-HEAD"} NODE {new_node.node_hash} from {self_parent.node_hash} and {cross_parent.node_hash}") + print(f"Peer {self.peer_id} COMPUTED NEW {'HEAD' if new_node.is_head_node() else "NON-HEAD"} NODE {new_node.node_hash} from {self_parent.node_hash} and {cross_parent.node_hash}") return new_node else: - # print(f"Peer {self.peer_id} computed invalid node {new_node.node_hash} from {self_parent.node_hash} and {cross_parent.node_hash}") - pass + print(f"Peer {self.peer_id} computed invalid node {new_node.node_hash} (r={new_node.round}) from {self_parent.node_hash} and {cross_parent.node_hash}") # the created new nodes is invalid because either its parents are from equivocated peers return None @@ -1508,19 +1618,25 @@ async def gossip_push(self): # NOTE: currently, the equivocation logic is simple, an adversary basically picks the last node of the current peer as the self parent, and the latest nodes of different cross peers as the cross parents for _ in range(num_nodes_to_create): - max_num_retries = 10 + max_num_retries = 3 for i in range(max_num_retries): cross_parent_peer_id = self.random_instance.choice(possible_cross_peers) - for j in range(10000): + for j in range(10): if cross_parent_peer_id == self.peer_id or cross_parent_peer_id in self.equivocated_peers: cross_parent_peer_id = 
self.random_instance.choice(possible_cross_peers) else: break - + + if cross_parent_peer_id == self.peer_id or cross_parent_peer_id in self.equivocated_peers or cross_parent_peer_id not in self.seen_valid_nodes or len(self.seen_valid_nodes[cross_parent_peer_id]) <= 0: + continue + cross_parent_node = self.seen_valid_nodes[cross_parent_peer_id][-1] + if self.peer_id == "P7" and self.current_round == 9: # now about to extend to 10 + ddebug(self.peer_id, f"## P7 should_equivocate = {should_equivocate} about to create for round 10: self peer id = {self.peer_id}, self_parent = {self_parent_node.node_hash}, cross peer id = {cross_parent_peer_id}, cross_parent_node = {cross_parent_node.node_hash}") + if cross_parent_node.node_hash in [node.cross_parent_hash for node in new_nodes]: # duplicated cross parent continue @@ -1533,6 +1649,7 @@ async def gossip_push(self): break else: # Peer can't compute any new nodes from the current tuple of self_parent and cross_parent + print(f"Peer {self.peer_id} can't compute any new nodes from ({self_parent_node.node_hash}, {cross_parent_node.node_hash})") pass (len(new_nodes) <= num_nodes_to_create) or (_ for _ in ()).throw(ValueError("number of new nodes must be less than or equal to num_nodes_to_create")) @@ -1540,7 +1657,7 @@ async def gossip_push(self): return for new_node in new_nodes: - res = self.verify_node_and_add_to_local_view(new_node) + res = self.verify_node_and_add_to_local_view(new_node, sender=self.peer_id) if res: print(f"Peer {self.peer_id} successfully added self-node {new_node.node_hash} to its local view: {new_node}") else: @@ -1575,17 +1692,20 @@ async def main(num_peers, MIN_NUM_ROUNDS_OF_HONEST_PEERS): for i in range(num_peers): is_adversary = (count_adversary + 1) * 3 + 1 <= num_peers and network.random_instance.random() < 0.5 - if is_adversary: - count_adversary += 1 peer = ConsensusPeer( peer_id=f"P{i}", is_adversary=is_adversary, seed=i, network=network ) + if is_adversary: + count_adversary += 1 + 
print(f"Peer {peer.peer_id} is an adversary") network.register_peer(peer) peers = network.peers + print("Safe threshold = ", network.safe_threshold()) + # Initialize peer neighborhoods all_peer_ids = network.get_all_peer_ids() for peer in peers: @@ -1596,15 +1716,18 @@ async def main(num_peers, MIN_NUM_ROUNDS_OF_HONEST_PEERS): current_simluated_timestamp = 0 + time_start = time.time() # Main consensus loop i = 0 - while True: + while not network.should_exit: i += 1 - if i % 100 == 0: - print(f"{i}th iteration") + if i % 50 == 0: + print(f"{i}th iteration at time {time.time() - time_start}, global mempool size = {len(network.global_mempool)}", force=True) + for peer in peers: + print(f"Peer {peer.peer_id} (is_adversary={peer.is_adversary}) is at round {peer.current_round} with {len(peer.my_nodes())} nodes", force=True) # Count peers that have reached MIN_NUM_ROUNDS_OF_HONEST_PEERS rounds peers_completed = sum(1 for c in peers if c.current_round >= MIN_NUM_ROUNDS_OF_HONEST_PEERS) - if peers_completed > (2 * num_peers // 3): + if peers_completed >= network.safe_threshold(): break # Randomly select an action for a random peer @@ -1679,15 +1802,20 @@ def validate_consistency(): final_fair_order = peer_final_order[:i] break - print(f"Global mempool size: {len(network.global_mempool)}") - print(f"Final fair order: {final_fair_order}") - print(f"Ordered txs: {len(final_fair_order)}/{len(network.global_mempool)} after {MIN_NUM_ROUNDS_OF_HONEST_PEERS} rounds at honest peers") + if len(peer_final_order) < len(final_fair_order): + final_fair_order = peer_final_order[:] + + print(f"Global mempool size: {len(network.global_mempool)}", force=True) + print(f"Final fair order: {final_fair_order}", force=True) + print(f"Ordered txs: {len(final_fair_order)}/{len(network.global_mempool)}={float(len(final_fair_order)) / len(network.global_mempool)} after {MIN_NUM_ROUNDS_OF_HONEST_PEERS} rounds at honest peers", force=True) -asyncio.run(main(num_peers=4, 
MIN_NUM_ROUNDS_OF_HONEST_PEERS=20)) +time_start = time.time() +asyncio.run(main(num_peers=7, MIN_NUM_ROUNDS_OF_HONEST_PEERS=30)) +print(f"Time taken: {time.time() - time_start} seconds") ### Possible attacks: # Long-Range Attacks: If validators controlling past checkpoints sell their keys, an attacker can re-sign an alternative history, leading to checkpoint reversals. -# => Dangerous once attacker can control > 2/3 of the OLD validators -# Majority Takeover: If an attacker gains control of 2/3 of the validators (BFT threshold), they could re-finalize a new chain with different checkpoints. +# => Dangerous once attacker can control N - f of the OLD validators +# Majority Takeover: If an attacker gains control of N - f of the validators (BFT threshold), they could re-finalize a new chain with different checkpoints. # => recursive validity proof + proof of finality # Solution: Post-Unstaking Slashing for X blocks after unstaking (but not able to withdraw before X blocks yet) \ No newline at end of file