diff --git a/api.go b/api.go
index df17a173..85b4246a 100644
--- a/api.go
+++ b/api.go
@@ -57,7 +57,6 @@ type Storage interface {
 }
 
 type Communication interface {
-	// Nodes returns all nodes that participate in the epoch.
 	Nodes() []NodeID
diff --git a/block_scheduler_test.go b/block_scheduler_test.go
index ec8fd6aa..d9ef385c 100644
--- a/block_scheduler_test.go
+++ b/block_scheduler_test.go
@@ -18,24 +18,6 @@ const (
 	defaultWaitDuration = 500 * time.Millisecond
 )
 
-func waitNoReceive(t *testing.T, ch <-chan struct{}) {
-	select {
-	case <-ch:
-		t.Fatal("channel unexpectedly signaled")
-	case <-time.After(defaultWaitDuration):
-		// good
-	}
-}
-
-func waitReceive(t *testing.T, ch <-chan struct{}) {
-	select {
-	case <-ch:
-		// good
-	case <-time.After(defaultWaitDuration):
-		t.Fatal("timed out waiting for signal")
-	}
-}
-
 func TestBlockVerificationScheduler(t *testing.T) {
 	t.Run("Schedules immediately when no dependencies", func(t *testing.T) {
 		scheduler := simplex.NewScheduler(testutil.MakeLogger(t), defaultMaxDeps)
diff --git a/epoch.go b/epoch.go
index c61daff3..ca9679ce 100644
--- a/epoch.go
+++ b/epoch.go
@@ -6,12 +6,10 @@ package simplex
 import (
 	"bytes"
 	"context"
-	"crypto/rand"
 	"encoding/binary"
 	"errors"
 	"fmt"
 	"math"
-	"math/big"
 	"slices"
 	"sync"
 	"sync/atomic"
@@ -25,9 +23,8 @@ import (
 var ErrAlreadyStarted = errors.New("epoch already started")
 
 const (
-	DefaultMaxRoundWindow   = 10
-	DefaultProcessingBlocks = 500
-
+	DefaultMaxRoundWindow              = 10
+	DefaultProcessingBlocks            = 500
 	DefaultMaxProposalWaitTime         = 5 * time.Second
 	DefaultReplicationRequestTimeout   = 5 * time.Second
 	DefaultEmptyVoteRebroadcastTimeout = 5 * time.Second
@@ -103,7 +100,7 @@ type Epoch struct {
 	monitor                        *Monitor
 	haltedError                    error
 	cancelWaitForBlockNotarization context.CancelFunc
-	timeoutHandler                 *TimeoutHandler
+	timeoutHandler                 *TimeoutHandler[string]
 	replicationState               *ReplicationState
 	timedOutRounds                 map[uint16]uint64 // NodeIndex -> round
 	redeemedRounds                 map[uint16]uint64 // NodeIndex -> round
@@ -166,9 +163,9 @@ func (e *Epoch) HandleMessage(msg *Message, from NodeID) error {
 		return e.handleFinalizeVoteMessage(msg.FinalizeVote, from)
 	case msg.Finalization != nil:
 		return e.handleFinalizationMessage(msg.Finalization, from)
-	case msg.ReplicationResponse != nil:
+	case msg.ReplicationResponse != nil && e.ReplicationEnabled:
 		return e.handleReplicationResponse(msg.ReplicationResponse, from)
-	case msg.ReplicationRequest != nil:
+	case msg.ReplicationRequest != nil && e.ReplicationEnabled:
 		return e.handleReplicationRequest(msg.ReplicationRequest, from)
 	default:
 		e.Logger.Debug("Invalid message type", zap.Stringer("from", from))
@@ -197,7 +194,7 @@ func (e *Epoch) init() error {
 	e.eligibleNodeIDs = make(map[string]struct{}, len(e.nodes))
 	e.futureMessages = make(messagesFromNode, len(e.nodes))
 	e.replicationState = NewReplicationState(e.Logger, e.Comm, e.ID, e.maxRoundWindow, e.ReplicationEnabled, e.StartTime, &e.lock)
-	e.timeoutHandler = NewTimeoutHandler(e.Logger, e.StartTime, e.nodes)
+	e.timeoutHandler = NewTimeoutHandler(e.Logger, "emptyVoteRebroadcast", e.StartTime, e.MaxRebroadcastWait, e.emptyVoteTimeoutTaskRunner)
 
 	for _, node := range e.nodes {
 		e.futureMessages[string(node)] = make(map[uint64]*messagesForRound)
@@ -602,7 +599,7 @@ func (e *Epoch) handleFinalizationMessage(message *Finalization, from NodeID) er
 func (e *Epoch) handleFinalizationForPendingOrFutureRound(message *Finalization, round uint64, nextSeqToCommit uint64) {
 	if round == e.round {
 		// delay collecting future finalization if we are verifying the proposal for that round
-		// and the finalization is for the current round
+		// and the finalization is for a round we already have
 		for _, msgs := range e.futureMessages {
 			msgForRound, exists := msgs[round]
 			if exists && msgForRound.proposalBeingProcessed {
@@ -613,12 +610,13 @@ func (e *Epoch) handleFinalizationForPendingOrFutureRound(message *Finalization,
 	}
 
 	// TODO: delay requesting future finalizations and blocks, since blocks could be in transit
-	e.Logger.Debug("Received finalization for a future round", zap.Uint64("round", round))
+	e.Logger.Debug("Received finalization for a pending or future round, and we don't have the block", zap.Uint64("round", round), zap.Uint64("our round", e.round))
 	if LeaderForRound(e.nodes, e.round).Equals(e.ID) {
 		e.Logger.Debug("We are the leader of this round, but a higher round has been finalized. Aborting block building.")
 		e.blockBuilderCancelFunc()
 	}
-	e.replicationState.replicateBlocks(message, nextSeqToCommit)
+
+	e.replicationState.ReceivedFutureFinalization(message, nextSeqToCommit)
 }
 
 func (e *Epoch) handleFinalizeVoteMessage(message *FinalizeVote, from NodeID) error {
@@ -700,7 +698,7 @@ func (e *Epoch) handleEmptyVoteMessage(message *EmptyVote, from NodeID) error {
 	vote := message.Vote
 
 	e.Logger.Verbo("Received empty vote message",
-		zap.Stringer("from", from), zap.Uint64("round", vote.Round))
+		zap.Stringer("from", from), zap.Uint64("round", vote.Round), zap.Uint64("our round", e.round))
 
 	// Only process point to point empty votes.
 	// A node will never need to forward to us someone else's vote.
@@ -714,6 +712,12 @@ func (e *Epoch) handleEmptyVoteMessage(message *EmptyVote, from NodeID) error {
 		e.Logger.Debug("Got empty vote from a past round",
 			zap.Uint64("round", vote.Round), zap.Uint64("my round", e.round), zap.Stringer("from", from))
 
+		// if this node has sent us an empty vote for a past round, it may be behind
+		// send it both the latest finalization and the highest round to help it catch up and initiate the replication process
+		e.sendLatestFinalization(from)
+		e.sendHighestRound(from)
+
+		// also send the notarization or finalization for this round
 		e.maybeSendNotarizationOrFinalization(from, vote.Round)
 		return nil
 	}
@@ -727,7 +731,6 @@ func (e *Epoch) handleEmptyVoteMessage(message *EmptyVote, from NodeID) error {
 	}
 
 	// Else, this is an empty vote for current round
-
 	e.Logger.Debug("Received an empty vote for the current round",
 		zap.Uint64("round", vote.Round), zap.Stringer("from", from))
 
@@ -747,10 +750,60 @@ func (e *Epoch) handleEmptyVoteMessage(message *EmptyVote, from NodeID) error {
 	return e.maybeAssembleEmptyNotarization()
 }
 
+func (e *Epoch) sendLatestFinalization(to NodeID) {
+	if e.lastBlock == nil {
+		e.Logger.Debug("No blocks committed yet, cannot send latest block", zap.Stringer("to", to))
+		return
+	}
+
+	msg := &Message{
+		Finalization: &e.lastBlock.Finalization,
+	}
+	e.Logger.Debug("Node appears behind, sending them the latest finalization", zap.Stringer("to", to), zap.Uint64("round", e.lastBlock.Finalization.Finalization.Round), zap.Uint64("sequence", e.lastBlock.Finalization.Finalization.Seq))
+	e.Comm.Send(msg, to)
+}
+
+func (e *Epoch) sendHighestRound(to NodeID) {
+	latestQR := e.getLatestVerifiedQuorumRound()
+
+	if latestQR == nil {
+		e.Logger.Debug("Cannot send latest round because there is none", zap.Stringer("to", to))
+		return
+	}
+
+	if latestQR.Notarization != nil {
+		msg := &Message{
+			Notarization: latestQR.Notarization,
+		}
+		e.Logger.Debug("Node appears behind, sending them the highest round",
+			zap.Stringer("to", to), zap.Uint64("round", latestQR.Notarization.Vote.Round))
+		e.Comm.Send(msg, to)
+		return
+	}
+
+	if latestQR.EmptyNotarization != nil {
+		msg := &Message{
+			EmptyNotarization: latestQR.EmptyNotarization,
+		}
+		e.Logger.Debug("Node appears behind, sending them the highest empty notarized round", zap.Stringer("to", to), zap.Uint64("round", latestQR.EmptyNotarization.Vote.Round))
+		e.Comm.Send(msg, to)
+		return
+	}
+}
+
 func (e *Epoch) maybeSendNotarizationOrFinalization(to NodeID, round uint64) {
 	r, ok := e.rounds[round]
 	if !ok {
+		// round could be an empty notarized round
+		evs, ok := e.emptyVotes[round]
+		if ok && evs.emptyNotarization != nil {
+			msg := &Message{
+				EmptyNotarization: evs.emptyNotarization,
+			}
+			e.Logger.Debug("Node appears behind, sending them an empty notarization", zap.Stringer("to", to), zap.Uint64("round", round))
+			e.Comm.Send(msg, to)
+		}
+
 		return
 	}
 
@@ -995,7 +1048,7 @@ func (e *Epoch) persistFinalization(finalization Finalization) error {
 		// we receive a finalization for a future round
 		e.Logger.Debug("Received a finalization for a future sequence", zap.Uint64("seq", finalization.Finalization.Seq), zap.Uint64("nextSeqToCommit", nextSeqToCommit))
-		e.replicationState.replicateBlocks(&finalization, nextSeqToCommit)
+		e.replicationState.ReceivedFutureFinalization(&finalization, nextSeqToCommit)
 
 		if err := e.rebroadcastPastFinalizeVotes(); err != nil {
 			return err
@@ -1222,7 +1275,7 @@ func (e *Epoch) persistEmptyNotarization(emptyVotes *EmptyVoteSet, shouldBroadca
 	}
 
 	e.blockVerificationScheduler.ExecuteEmptyRoundDependents(emptyNotarization.Vote.Round)
-
+	e.replicationState.DeleteRound(emptyNotarization.Vote.Round)
 	// don't increase the round if this is a empty notarization for a past round
 	if e.round != emptyNotarization.Vote.Round {
 		return nil
@@ -1239,6 +1292,7 @@ func (e *Epoch) persistEmptyNotarization(emptyVotes *EmptyVoteSet, shouldBroadca
 }
 
 func (e *Epoch) maybeMarkLeaderAsTimedOutForFutureBlacklisting(emptyNotarization *EmptyNotarization) error {
+	e.Logger.Debug("Marking the leader as timed out", zap.Uint64("round", emptyNotarization.Vote.Round), zap.Stringer("leader", LeaderForRound(e.nodes, emptyNotarization.Vote.Round)))
 	var blacklist Blacklist
 	if e.lastBlock != nil {
 		if e.lastBlock.VerifiedBlock == nil {
@@ -1327,7 +1381,9 @@ func (e *Epoch) persistNotarization(notarization Notarization) error {
 		}
 	}
 
-	e.increaseRound()
+	if notarization.Vote.Round == e.round && r.finalization == nil {
+		e.increaseRound()
+	}
 
 	return nil
 }
@@ -1354,15 +1410,10 @@ func (e *Epoch) handleEmptyNotarizationMessage(emptyNotarization *EmptyNotarizat
 	e.Logger.Verbo("Received empty notarization message", zap.Uint64("round", vote.Round))
 
-	if e.isRoundTooFarAhead(vote.Round) {
-		e.Logger.Debug("Received an empty notarization for a too high round",
-			zap.Uint64("round", vote.Round), zap.Uint64("our round", e.round))
-		return nil
-	}
-
 	if e.isVoteForFinalizedRound(vote.Round) {
 		e.Logger.Debug("Received an empty notarization for a too low round",
 			zap.Uint64("round", vote.Round), zap.Uint64("our round", e.round))
+
 		return nil
 	}
 
@@ -1371,14 +1422,23 @@ func (e *Epoch) handleEmptyNotarizationMessage(emptyNotarization *EmptyNotarizat
 		return nil
 	}
 
-	emptyVotes := e.getOrCreateEmptyVoteSetForRound(vote.Round)
-	emptyVotes.emptyNotarization = emptyNotarization
 	if vote.Round > e.round {
-		e.Logger.Debug("Received empty notarization for a future round",
+		e.Logger.Debug("Received an empty notarization for a higher round",
 			zap.Uint64("round", vote.Round), zap.Uint64("our round", e.round))
+
+		e.replicationState.ReceivedFutureRound(vote.Round, 0, e.round, emptyNotarization.QC.Signers())
+
+		// store in future state if within max round window
+		if e.isWithinMaxRoundWindow(vote.Round) {
+			emptyVotes := e.getOrCreateEmptyVoteSetForRound(vote.Round)
+			emptyVotes.emptyNotarization = emptyNotarization
+		}
 		return nil
 	}
 
+	emptyVotes := e.getOrCreateEmptyVoteSetForRound(vote.Round)
+	emptyVotes.emptyNotarization = emptyNotarization
+
 	// The empty notarization is for this round, so store it but don't broadcast it, as we've received it via a broadcast.
 	return e.persistEmptyNotarization(emptyVotes, false)
 }
@@ -1410,7 +1470,7 @@ func (e *Epoch) handleNotarizationMessage(message *Notarization, from NodeID) er
 	e.Logger.Verbo("Received notarization message",
 		zap.Stringer("from", from), zap.Uint64("round", vote.Round))
 
-	if !e.isVoteRoundValid(vote.Round) {
+	if e.isVoteForFinalizedRound(vote.Round) {
 		return nil
 	}
 
@@ -1418,6 +1478,17 @@ func (e *Epoch) handleNotarizationMessage(message *Notarization, from NodeID) er
 		return nil
 	}
 
+	if vote.Round > e.round {
+		e.Logger.Debug("Received a notarization for a future round",
+			zap.Uint64("round", vote.Round), zap.Uint64("our round", e.round))
+		e.replicationState.ReceivedFutureRound(vote.Round, vote.Seq, e.round, message.QC.Signers())
+		if e.isWithinMaxRoundWindow(vote.Round) {
+			e.storeFutureNotarization(message, from, vote.Round)
+		}
+
+		return nil
+	}
+
 	// Can we handle this notarization right away or should we handle it later?
 	round, exists := e.rounds[vote.Round]
 	// If we have already notarized the round, no need to continue
@@ -1425,11 +1496,14 @@ func (e *Epoch) handleNotarizationMessage(message *Notarization, from NodeID) er
 		e.Logger.Debug("Received a notarization for an already notarized round")
 		return nil
 	}
+
 	// If this notarization is for a round we are currently processing its proposal,
 	// or for a future round, then store it for later use.
-	if !exists || e.round < vote.Round {
-		e.Logger.Debug("Received a notarization for a future round", zap.Uint64("round", vote.Round))
+	if !exists {
+		e.Logger.Info("Received a notarization for this round, but we don't have a block for it yet", zap.Uint64("round", vote.Round), zap.Uint64("epoch round", e.round))
 		e.storeFutureNotarization(message, from, vote.Round)
+
+		// TODO: we need to request the block.
 		return nil
 	}
 
@@ -1560,10 +1634,11 @@ func (e *Epoch) blockDependencies(bh BlockHeader) (*Digest, []uint64) {
 	prevBlock, notarizationOrFinalization, found := e.locateBlock(bh.Seq-1, bh.Prev[:])
 	if !found {
 		// should never happen since we check this when we verify the proposal metadata
-		e.Logger.Error("Could not find predecessor block for proposal scheduling",
+		e.Logger.Info("Could not find predecessor block for proposal scheduling",
 			zap.Uint64("seq", bh.Seq-1), zap.Stringer("prev", bh.Prev))
 
+		// TODO: if not found we need to not schedule right away and wait to get the round of the parent so we know the empty round deps
 		return &bh.Prev, nil
 	}
 
@@ -1587,7 +1662,9 @@ func (e *Epoch) blockDependencies(bh BlockHeader) (*Digest, []uint64) {
 // processFinalizedBlocks processes a block that has a finalization.
 // if the block has already been verified, it will index the finalization,
 // otherwise it will verify the block first.
-func (e *Epoch) processFinalizedBlock(block Block, finalization Finalization) error {
+func (e *Epoch) processFinalizedBlock(block Block, finalization *Finalization) error {
+	e.Logger.Debug("Processing finalized block during replication", zap.Uint64("round", finalization.Finalization.Round), zap.Uint64("sequence", finalization.Finalization.Seq))
+
 	round, exists := e.rounds[finalization.Finalization.Round]
 	// dont create a block verification task if the block is already in the rounds map
 	if exists {
@@ -1600,7 +1677,7 @@ func (e *Epoch) processFinalizedBlock(block Block, finalization Finalization) er
 			delete(e.rounds, round.num)
 			return e.processFinalizedBlock(block, finalization)
 		}
-		round.finalization = &finalization
+		round.finalization = finalization
 		if err := e.indexFinalizations(round.num); err != nil {
 			e.Logger.Error("Failed to index finalization", zap.Error(err))
 			return err
@@ -1609,11 +1686,20 @@ func (e *Epoch) processFinalizedBlock(block Block, finalization Finalization) er
 		return e.processReplicationState()
 	}
 
+	blockDependency, missingRounds := e.blockDependencies(block.BlockHeader())
+	// because it's finalized we don't care about empty rounds
+	if blockDependency != nil {
+		e.Logger.Error(
+			"Received a finalization for nextSeqToCommit that breaks our chain",
+			zap.Stringer("block digest", block.BlockHeader().Digest),
+			zap.Stringer("expected digest", blockDependency),
+			zap.Uint64s("missing rounds", missingRounds),
+		)
+	}
+
 	// Create a task that will verify the block in the future, after its predecessors have also been verified.
 	task := e.createFinalizedBlockVerificationTask(e.oneTimeVerifier.Wrap(block), finalization)
-
-	// TODO: in a future PR, we need to handle collecting any potential dependencies for finalized blocks
-	e.blockVerificationScheduler.ScheduleTaskWithDependencies(task, block.BlockHeader().Seq, nil, []uint64{})
+	e.blockVerificationScheduler.ScheduleTaskWithDependencies(task, block.BlockHeader().Seq, blockDependency, []uint64{})
 
 	return nil
 }
@@ -1622,10 +1708,11 @@ func (e *Epoch) processFinalizedBlock(block Block, finalization Finalization) er
 // if the block has already been verified, it will persist the notarization,
 // otherwise it will verify the block first.
 func (e *Epoch) processNotarizedBlock(block Block, notarization *Notarization) error {
+	e.Logger.Debug("Processing notarized block during replication", zap.Uint64("round", notarization.Vote.Round), zap.Uint64("sequence", notarization.Vote.Seq))
 	md := block.BlockHeader()
 
 	round, exists := e.rounds[md.Round]
-	// dont create a block verification task if the block is already in the rounds map
+	// don't create a block verification task if the block is already in the rounds map
 	if exists {
 		// We could have a block in the rounds map, as well as an empty notarization.
 		// its important to not create a conflicting notarization for that round.
@@ -1666,8 +1753,8 @@ func (e *Epoch) processNotarizedBlock(block Block, notarization *Notarization) e
 	task := e.createNotarizedBlockVerificationTask(e.oneTimeVerifier.Wrap(block), *notarization)
 	blockDependency, missingRounds := e.blockDependencies(md)
 
-	// TODO: if we have dependencies during replication, we are stuck since this means a node
-	// must have sent us an empty notarization for a round this block depends on.
+	e.replicationState.CreateDependencyTasks(blockDependency, md.Seq-1, missingRounds)
+
 	e.blockVerificationScheduler.ScheduleTaskWithDependencies(task, md.Seq, blockDependency, missingRounds)
 
 	return nil
@@ -1762,7 +1849,7 @@ func (e *Epoch) createBlockVerificationTask(block Block, from NodeID, vote Vote)
 	}
 }
 
-func (e *Epoch) createFinalizedBlockVerificationTask(block Block, finalization Finalization) func() Digest {
+func (e *Epoch) createFinalizedBlockVerificationTask(block Block, finalization *Finalization) func() Digest {
 	return func() Digest {
 		md := block.BlockHeader()
 
@@ -1777,15 +1864,11 @@ func (e *Epoch) createFinalizedBlockVerificationTask(block Block, finalization F
 		if err != nil {
 			e.Logger.Debug("Failed verifying block", zap.Error(err))
 			// if we fail to verify the block, we re-add to request timeout
-			numSigners := int64(len(finalization.QC.Signers()))
-			index, err := rand.Int(rand.Reader, big.NewInt(numSigners))
+			err = e.replicationState.ResendFinalizationRequest(md.Seq, finalization.QC.Signers())
 			if err != nil {
 				e.haltedError = err
-				e.Logger.Debug("Failed to generate random index", zap.Error(err))
-				return md.Digest
+				e.Logger.Debug("Failed to resend finalization", zap.Error(err))
 			}
-
-			e.replicationState.sendRequestToNode(md.Seq, md.Seq, finalization.QC.Signers(), int(index.Int64()))
 			return md.Digest
 		}
 
@@ -1801,7 +1884,7 @@ func (e *Epoch) createFinalizedBlockVerificationTask(block Block, finalization F
 			return md.Digest
 		}
 
-		if err := e.indexFinalization(verifiedBlock, finalization); err != nil {
+		if err := e.indexFinalization(verifiedBlock, *finalization); err != nil {
 			e.haltedError = err
 			e.Logger.Error("Failed to index finalization", zap.Error(err))
 			return md.Digest
@@ -1832,6 +1915,7 @@ func (e *Epoch) createNotarizedBlockVerificationTask(block Block, notarization N
 		verifiedBlock, err := block.Verify(context.Background())
 		if err != nil {
 			e.Logger.Debug("Failed verifying block", zap.Error(err))
+			// TODO: if we fail to verify the block, we should re-request it from the replication state
 			return md.Digest
 		}
 
@@ -1841,8 +1925,7 @@ func (e *Epoch) createNotarizedBlockVerificationTask(block Block, notarization N
 		// we started verifying the block when we didn't have a notarization, however its
 		// possible we received a notarization or empty notarization for this block in the meantime.
 		round, ok := e.rounds[md.Round]
-		emptyVote, emptyOk := e.emptyVotes[md.Round]
-		if (ok && round.notarization != nil) || (emptyOk && emptyVote.emptyNotarization != nil) {
+		if ok && round.notarization != nil {
 			e.Logger.Debug("Verifying notarized block that already has a notarization for the round",
 				zap.Uint64("round", md.Round))
 			return md.Digest
@@ -1854,14 +1937,10 @@ func (e *Epoch) createNotarizedBlockVerificationTask(block Block, notarization N
 			return md.Digest
 		}
 
-		round, ok = e.rounds[block.BlockHeader().Round]
-		if !ok {
-			e.Logger.Warn("Unable to get proposed block for the round", zap.Uint64("round", md.Round))
-			return md.Digest
-		}
-
 		if err := e.persistNotarization(notarization); err != nil {
 			e.haltedError = err
+			e.Logger.Error("Failed to persist notarization", zap.Error(err))
+
 			return md.Digest
 		}
 
 		err = e.processReplicationState()
@@ -2232,7 +2311,7 @@ func (e *Epoch) triggerEmptyBlockNotarization(round uint64) {
 
 	e.Comm.Broadcast(&Message{EmptyVoteMessage: &signedEV})
 
-	e.addEmptyVoteRebroadcastTimeout(&signedEV)
+	e.addEmptyVoteRebroadcastTimeout()
 
 	if err := e.maybeAssembleEmptyNotarization(); err != nil {
 		e.Logger.Error("Failed assembling empty notarization", zap.Error(err))
@@ -2240,23 +2319,32 @@ func (e *Epoch) triggerEmptyBlockNotarization(round uint64) {
 	}
 }
 
-func (e *Epoch) addEmptyVoteRebroadcastTimeout(vote *EmptyVote) {
-	task := &TimeoutTask{
-		NodeID:   e.ID,
-		TaskID:   EmptyVoteTimeoutID,
-		Deadline: e.timeoutHandler.GetTime().Add(e.EpochConfig.MaxRebroadcastWait),
-		Task: func() {
-			e.Logger.Debug("Rebroadcasting empty vote because round has not advanced", zap.Uint64("round", vote.Vote.Round))
-			e.Comm.Broadcast(&Message{EmptyVoteMessage: vote})
-			e.addEmptyVoteRebroadcastTimeout(vote)
-		},
+func (e *Epoch) emptyVoteTimeoutTaskRunner(_ []string) {
+	e.lock.Lock()
+	roundVotes, ok := e.emptyVotes[e.round]
+	e.lock.Unlock()
+
+	if !ok {
+		e.Logger.Debug("No empty vote set found to rebroadcast, yet expected to rebroadcast", zap.Uint64("round", e.round))
+		return
+	}
+
+	ourVote, voted := roundVotes.votes[string(e.ID)]
+	if !voted {
+		e.Logger.Debug("Our empty vote not found in the set to rebroadcast, yet expected to rebroadcast", zap.Uint64("round", e.round))
+		return
 	}
 
-	e.timeoutHandler.AddTask(task)
+	e.Logger.Debug("Rebroadcasting empty vote because round has not advanced", zap.Uint64("round", ourVote.Vote.Round))
+	e.Comm.Broadcast(&Message{EmptyVoteMessage: ourVote})
+}
+
+func (e *Epoch) addEmptyVoteRebroadcastTimeout() {
+	e.timeoutHandler.AddTask(EmptyVoteTimeoutID)
 }
 
 func (e *Epoch) monitorProgress(round uint64) {
-	e.Logger.Debug("Monitoring progress", zap.Uint64("round", round))
+	e.Logger.Debug("Monitoring progress", zap.Uint64("round", round), zap.Uint64("currentRound", e.round))
 	ctx, cancelContext := context.WithCancel(context.Background())
 
 	noop := func() {}
@@ -2468,8 +2556,7 @@ func (e *Epoch) increaseRound() {
 	e.blockBuilderCancelFunc()
 
 	// remove the rebroadcast empty vote task
-	e.timeoutHandler.RemoveTask(e.ID, EmptyVoteTimeoutID)
-
+	e.timeoutHandler.RemoveTask(EmptyVoteTimeoutID)
 	prevLeader := LeaderForRound(e.nodes, e.round)
 	nextLeader := LeaderForRound(e.nodes, e.round+1)
 
@@ -2630,34 +2717,68 @@ func (e *Epoch) storeProposal(block VerifiedBlock) bool {
 
 // HandleRequest processes a request and returns a response. It also sends a response to the sender.
 func (e *Epoch) handleReplicationRequest(req *ReplicationRequest, from NodeID) error {
-	e.Logger.Debug("Received replication request", zap.Stringer("from", from), zap.Int("num seqs", len(req.Seqs)), zap.Uint64("latest round", req.LatestRound))
+	e.Logger.Debug("Received replication request", zap.Stringer("from", from), zap.Int("num seqs", len(req.Seqs)), zap.Int("num rounds", len(req.Rounds)), zap.Uint64("latest round", req.LatestRound))
 	if !e.ReplicationEnabled {
 		return nil
 	}
 	response := &VerifiedReplicationResponse{}
 
-	latestRound := e.getLatestVerifiedQuorumRound()
+	if len(req.Seqs) > int(e.maxRoundWindow) && len(req.Rounds) > int(e.maxRoundWindow) {
+		e.Logger.Info("Replication request exceeds maximum allowed seqs and rounds",
+			zap.Stringer("from", from),
+			zap.Int("num seqs", len(req.Seqs)),
+			zap.Int("num rounds", len(req.Rounds)),
+			zap.Uint64("max round window", e.maxRoundWindow))
+		return nil
+	}
 
-	if latestRound != nil && latestRound.GetRound() > req.LatestRound {
-		response.LatestRound = latestRound
+	if req.LatestRound > 0 {
+		latestRound := e.getLatestVerifiedQuorumRound()
+		if latestRound != nil && latestRound.GetRound() > req.LatestRound {
+			response.LatestRound = latestRound
+		}
+	}
+	if req.LatestFinalizedSeq > 0 {
+		if e.lastBlock != nil && e.lastBlock.Finalization.Finalization.Seq > req.LatestFinalizedSeq {
+			response.LatestFinalizedSeq = &VerifiedQuorumRound{
+				VerifiedBlock: e.lastBlock.VerifiedBlock,
+				Finalization:  &e.lastBlock.Finalization,
+			}
+		}
 	}
 
 	seqs := req.Seqs
 	slices.Sort(seqs)
-	data := make([]VerifiedQuorumRound, len(seqs))
+	seqData := make([]VerifiedQuorumRound, len(seqs))
 	for i, seq := range seqs {
 		quorumRound := e.locateQuorumRecord(seq)
 		if quorumRound == nil {
 			// since we are sorted, we can break early
-			data = data[:i]
+			seqData = seqData[:i]
 			break
 		}
-		data[i] = *quorumRound
+		seqData[i] = *quorumRound
 	}
 
+	rounds := req.Rounds
+	roundData := make([]VerifiedQuorumRound, 0, len(rounds))
+	slices.Sort(rounds)
+	for _, roundNum := range rounds {
+		quorumRound := e.locateQuorumRecordByRound(roundNum)
+		if quorumRound == nil {
+			// we cannot break early, since empty notarized rounds may leave gaps
+			continue
+		}
+		roundData = append(roundData, *quorumRound)
+	}
+
+	data := make([]VerifiedQuorumRound, 0, len(seqData)+len(roundData))
+	data = append(data, seqData...)
+	data = append(data, roundData...)
 	response.Data = data
-	if len(data) == 0 && response.LatestRound == nil {
+
+	if len(data) == 0 && response.LatestRound == nil && response.LatestFinalizedSeq == nil {
 		e.Logger.Debug("No data found for replication request", zap.Stringer("from", from))
 		return nil
 	}
 
@@ -2714,6 +2835,38 @@ func (e *Epoch) locateQuorumRecord(seq uint64) *VerifiedQuorumRound {
 	}
 }
 
+func (e *Epoch) locateQuorumRecordByRound(targetRound uint64) *VerifiedQuorumRound {
+	var qr *VerifiedQuorumRound
+
+	for _, round := range e.rounds {
+		blockRound := round.block.BlockHeader().Round
+		if blockRound == targetRound {
+			if round.finalization != nil || round.notarization != nil {
+				qr = &VerifiedQuorumRound{
+					VerifiedBlock: round.block,
+					Finalization:  round.finalization,
+					Notarization:  round.notarization,
+				}
+				break
+			}
+		}
+	}
+
+	// check if the round is empty notarized
+	emptyVoteForRound, exists := e.emptyVotes[targetRound]
+	if exists && emptyVoteForRound.emptyNotarization != nil {
+		if qr != nil {
+			qr.EmptyNotarization = emptyVoteForRound.emptyNotarization
+			return qr
+		}
+		qr = &VerifiedQuorumRound{
+			EmptyNotarization: emptyVoteForRound.emptyNotarization,
+		}
+	}
+
+	return qr
+}
+
 func (e *Epoch) haveNotFinalizedNotarizedRound() (uint64, bool) {
 	e.lock.Lock()
 	defer e.lock.Unlock()
@@ -2737,42 +2890,35 @@ func (e *Epoch) handleReplicationResponse(resp *ReplicationResponse, from NodeID
 		return nil
 	}
 
-	e.Logger.Debug("Received replication response", zap.Stringer("from", from), zap.Int("num seqs", len(resp.Data)), zap.Stringer("latest round", resp.LatestRound))
+	e.Logger.Debug("Received replication response", zap.Stringer("from", from), zap.Int("num seqs", len(resp.Data)), zap.Stringer("latest round", resp.LatestRound), zap.Stringer("latest seq", resp.LatestSeq))
 	nextSeqToCommit := e.nextSeqToCommit()
-	validRounds := make([]QuorumRound, 0, len(resp.Data))
 
 	for _, data := range resp.Data {
-		if err := data.IsWellFormed(); err != nil {
-			e.Logger.Debug("Malformed Quorum Round Received", zap.Error(err))
-			continue
-		}
-
-		if data.EmptyNotarization == nil && nextSeqToCommit > data.GetSequence() {
-			e.Logger.Debug("Received quorum round for a seq that is too far behind", zap.Uint64("seq", data.GetSequence()))
-			continue
-		}
-
-		if data.GetSequence() > nextSeqToCommit+e.maxRoundWindow {
+		if data.Finalization != nil && data.GetSequence() > nextSeqToCommit+e.maxRoundWindow {
 			e.Logger.Debug("Received quorum round for a seq that is too far ahead", zap.Uint64("seq", data.GetSequence()))
 			// we are too far behind, we should ignore this message
 			continue
 		}
 
-		if err := e.verifyQuorumRound(data, from); err != nil {
-			e.Logger.Debug("Received invalid quorum round", zap.Uint64("seq", data.GetSequence()), zap.Stringer("from", from))
+		// We may be really far behind, so we shouldn't process sequences unless they are the nextSequenceToCommit
+		if data.GetRound() > e.round+e.maxRoundWindow && data.GetSequence() != nextSeqToCommit {
+			e.Logger.Debug("Received quorum round for a round that is too far ahead", zap.Uint64("round", data.GetRound()))
+			// we are too far behind, we should ignore this message
 			continue
 		}
 
-		validRounds = append(validRounds, data)
-		e.replicationState.StoreQuorumRound(data)
+		if err := e.processQuorumRound(&data, from); err != nil {
+			e.Logger.Debug("Failed processing quorum round", zap.Error(err))
+		}
 	}
 
-	if err := e.processLatestRoundReceived(resp.LatestRound, from); err != nil {
+	if err := e.processQuorumRound(resp.LatestRound, from); err != nil {
 		e.Logger.Debug("Failed processing latest round",
 			zap.Error(err))
-		return nil
 	}
 
-	e.replicationState.receivedReplicationResponse(validRounds, from)
+	if err := e.processQuorumRound(resp.LatestSeq, from); err != nil {
+		e.Logger.Debug("Failed processing latest seq", zap.Error(err))
+	}
 
 	return e.processReplicationState()
 }
@@ -2807,6 +2953,7 @@ func (e *Epoch) verifyQuorumRound(q QuorumRound, from NodeID) error {
 }
 
 func (e *Epoch) processEmptyNotarization(emptyNotarization *EmptyNotarization) error {
+	e.Logger.Debug("Processing empty notarization due to replication", zap.Uint64("round", emptyNotarization.Vote.Round), zap.Uint64("our round", e.round))
 	emptyVotes := e.getOrCreateEmptyVoteSetForRound(emptyNotarization.Vote.Round)
 	emptyVotes.emptyNotarization = emptyNotarization
 
@@ -2818,72 +2965,77 @@ func (e *Epoch) processEmptyNotarization(emptyNotarization *EmptyNotarization) e
 	return e.processReplicationState()
 }
 
-func (e *Epoch) processLatestRoundReceived(latestRound *QuorumRound, from NodeID) error {
-	if latestRound == nil {
+// processQuorumRound processes a quorum round received from another node.
+// It verifies the quorum round and stores it in the replication state if valid.
+func (e *Epoch) processQuorumRound(round *QuorumRound, from NodeID) error {
+	if round == nil {
 		return nil
 	}
 
-	// make sure the latest round is well formed
-	if err := latestRound.IsWellFormed(); err != nil {
-		e.Logger.Debug("Received invalid latest round", zap.Error(err))
-		return err
+	// make sure the round is well formed
+	if err := round.IsWellFormed(); err != nil {
+		return fmt.Errorf("received malformed latest round: %w", err)
 	}
 
-	if err := e.verifyQuorumRound(*latestRound, from); err != nil {
-		e.Logger.Debug("Received invalid latest round", zap.Error(err))
-		return err
+	if round.Finalization == nil && e.isVoteForFinalizedRound(round.GetRound()) {
+		return fmt.Errorf("received a quorum round for a round that has been finalized. round: %d; seq: %d", round.GetRound(), round.GetSequence())
 	}
 
-	e.replicationState.StoreQuorumRound(*latestRound)
+	if round.Finalization != nil && e.lastBlock != nil && e.lastBlock.VerifiedBlock.BlockHeader().Seq > round.Finalization.Finalization.Seq {
+		return fmt.Errorf("received a finalized round for a committed sequence. round: %d; seq: %d", round.GetRound(), round.GetSequence())
+	}
+
+	if err := e.verifyQuorumRound(*round, from); err != nil {
+		return fmt.Errorf("failed verifying latest round: %w", err)
+	}
+
+	e.replicationState.StoreQuorumRound(round)
 	return nil
 }
 
 func (e *Epoch) processReplicationState() error {
 	nextSeqToCommit := e.nextSeqToCommit()
-	// check if we are done replicating and should start a new round
-	if e.replicationState.isReplicationComplete(nextSeqToCommit, e.round) {
-		// TODO: an adversarial node can send multiple empty replication responses, causing us
-		// to call start round multiple times. This is potentially bad if we are the leader, since we will
-		// propose multiple blocks for the same round.
-		return e.startRound()
-	}
-
-	e.replicationState.maybeCollectFutureSequences(e.nextSeqToCommit())
+	// We might have advanced the rounds from non-replicating paths such as future messages. Advance replication state accordingly.
+	e.replicationState.MaybeAdvancedState(nextSeqToCommit, e.round)
 
 	// first we check if we can commit the next sequence, it is ok to try and commit the next sequence
 	// directly, since if there are any empty notarizations, `indexFinalization` will
 	// increment the round properly.
 	block, finalization, exists := e.replicationState.GetFinalizedBlockForSequence(nextSeqToCommit)
 	if exists {
-		delete(e.replicationState.receivedQuorumRounds, block.BlockHeader().Round)
+		e.replicationState.DeleteSeq(nextSeqToCommit)
 		return e.processFinalizedBlock(block, finalization)
 	}
 
-	qRound, ok := e.replicationState.receivedQuorumRounds[e.round]
-	if ok && qRound.Notarization != nil {
-		if qRound.Finalization != nil {
-			e.Logger.Debug("Delaying processing a QuorumRound that has an Finalization != NextSeqToCommit", zap.Stringer("QuorumRound", &qRound))
-			return nil
+	// process the lowest round in our replication state (up to and including the current epoch round)
+	lowestRound := e.replicationState.GetLowestRound()
+	if lowestRound != nil && lowestRound.GetRound() <= e.round {
+		e.Logger.Debug("Process replication state", zap.Stringer("lowest round", lowestRound), zap.Uint64("Our round", e.round))
+
+		// Remove before processing to avoid infinite recursion
+		e.replicationState.DeleteRound(lowestRound.GetRound())
+		if lowestRound.Notarization != nil {
+			if err := e.processNotarizedBlock(lowestRound.Block, lowestRound.Notarization); err != nil {
+				return err
+			}
+		}
+		// we can also have an empty notarization
+		if lowestRound.EmptyNotarization != nil {
+			if err := e.processEmptyNotarization(lowestRound.EmptyNotarization); err != nil {
+				return err
+			}
 		}
-		delete(e.replicationState.receivedQuorumRounds, e.round)
-		return e.processNotarizedBlock(qRound.Block, qRound.Notarization)
-	}
-	// the current round is an empty notarization
-	if ok && qRound.EmptyNotarization != nil {
-		delete(e.replicationState.receivedQuorumRounds, qRound.GetRound())
-		return e.processEmptyNotarization(qRound.EmptyNotarization)
 	}
 
-	roundAdvanced, err := e.maybeAdvanceRoundFromEmptyNotarizations()
+	// maybe there are no replication rounds < our round but we can still advance from empty notarizations
+	err := e.maybeAdvanceRoundFromEmptyNotarizations()
 	if err != nil {
 		return err
 	}
 
-	if roundAdvanced {
-		return e.processReplicationState()
-	}
+	e.Logger.Debug("Nothing to process in replication state", zap.Uint64("Our round", e.round), zap.Uint64("NextSeqToCommit", nextSeqToCommit), zap.Stringer("lowest replication round", lowestRound))
 
 	return nil
 }
@@ -2896,22 +3048,24 @@ func (e *Epoch) processReplicationState() error {
 // QRound2 { round 8, seq 1 }
 //
 // in this case we can infer there was 8-1 empty notarizations during rounds [2, 8].
-func (e *Epoch) maybeAdvanceRoundFromEmptyNotarizations() (bool, error) {
+func (e *Epoch) maybeAdvanceRoundFromEmptyNotarizations() error {
 	round := e.round
 	expectedSeq := e.metadata().Seq
 
-	nextSeqQuorum := e.replicationState.GetQuorumRoundWithSeq(expectedSeq)
-	if nextSeqQuorum != nil {
+	block := e.replicationState.GetBlockWithSeq(expectedSeq)
+	if block != nil {
+		bh := block.BlockHeader()
 		// num empty notarizations
-		if round < nextSeqQuorum.GetRound() {
-			for range nextSeqQuorum.GetRound() - round {
+		if round < bh.Round {
+			e.Logger.Debug("Advancing round from a gap in empty notarizations", zap.Uint64("epoch round", round), zap.Uint64("block round", bh.Round))
+			for range bh.Round - round {
 				e.increaseRound()
 			}
-			return true, nil
+			return e.processReplicationState()
 		}
 	}
 
-	return false, nil
+	return nil
 }
 
 // getHighestRound returns the highest round that has either a notarization or finalization
@@ -2953,7 +3107,6 @@ func (e *Epoch) getLatestVerifiedQuorumRound() *VerifiedQuorumRound {
 	return GetLatestVerifiedQuorumRound(
 		e.getHighestRound(),
 		e.getHighestEmptyNotarization(),
-		e.lastBlock,
 	)
 }
 
diff --git a/epoch_multinode_test.go b/epoch_multinode_test.go
index 2ace9d53..432616da 100644
--- a/epoch_multinode_test.go
+++ b/epoch_multinode_test.go
@@ -4,6 +4,7 @@ package simplex_test
 
 import (
+	"fmt"
 	"sync"
 	"sync/atomic"
 	"testing"
@@ -155,10 +156,6 @@ func TestSimplexMultiNodeBlacklist(t *testing.T) {
 	testutil.NewSimplexNode(t, nodes[2], net, testEpochConfig)
 	testutil.NewSimplexNode(t, nodes[3], net, testEpochConfig)
 
-	for _, n := range net.Instances[:3] {
-		n.Silence()
-	}
-
 	net.StartInstances()
 
 	// Advance to the fourth node's turn by building three blocks
@@ -353,6 +350,9 @@ func TestSplitVotes(t *testing.T) {
 		}
 	}
 
+	// allow outstanding messages to be dropped
+	time.Sleep(100 * time.Millisecond)
+
 	net.SetAllNodesMessageFilter(testutil.AllowAllMessages)
 
 	time2 := splitNode2.E.StartTime
@@ -374,7 +374,7 @@ func TestSplitVotes(t *testing.T) {
 	splitNode3.WAL.AssertNotarization(0)
 
 	for _, n := range net.Instances {
-		require.Equal(t, uint64(0), n.Storage.NumBlocks())
+		require.Equal(t, uint64(0), n.Storage.NumBlocks(), fmt.Sprintf("node %s should not have any blocks", n.E.ID))
 		require.Equal(t, uint64(1), n.E.Metadata().Round)
 		require.Equal(t, uint64(1), n.E.Metadata().Seq)
 	}
diff --git a/epoch_test.go b/epoch_test.go
index 2acb3d2f..64576351 100644
--- a/epoch_test.go
+++ b/epoch_test.go
@@ -1263,14 +1263,12 @@ func TestEpochVotesForEquivocatedVotes(t *testing.T) {
 	equivocatedBlock.Data = []byte{1, 2, 3}
 	equivocatedBlock.ComputeDigest()
 	testutil.InjectTestVote(t, e, equivocatedBlock, nodes[1])
-	eqbh := equivocatedBlock.BlockHeader()
-
 	// We should not have sent a notarization yet, since we have not received enough votes for the block we received from the leader
 	require.Never(t, func() bool {
 		select {
 		case msg := <-recordedMessages:
 			if msg.Notarization != nil {
-				fmt.Println(msg.Notarization.Vote.BlockHeader.Equals(&eqbh))
 				return true
 			}
 		default:
@@ -1301,6 +1299,52 @@ func TestEpochVotesForEquivocatedVotes(t *testing.T) {
 	}
 }
 
+// Ensures we don't double increment the round on persisting a notarization
+func TestDoubleIncrementOnPersistNotarization(t *testing.T) {
+	// add an empty notarization, then a notarization for a previous round
+	bb := &testutil.TestBlockBuilder{Out: make(chan *testutil.TestBlock, 1)}
+	nodes := []NodeID{{1}, {2}, {3}, {4}}
+	conf, wal, _ := testutil.DefaultTestNodeEpochConfig(t, nodes[3], testutil.NewNoopComm(nodes), bb)
+	conf.ReplicationEnabled = true
+
+	e, err := NewEpoch(conf)
+	require.NoError(t, err)
+
+	require.NoError(t, e.Start())
+
+	advanceRoundFromEmpty(t, e)
+	require.Equal(t, uint64(1), e.Metadata().Round)
+
+	// create a notarization for round 0
+	md := ProtocolMetadata{
+		Epoch: 0,
+		Round: 0,
+		Seq:   0,
+	}
+	_, ok := bb.BuildBlock(context.Background(), md, emptyBlacklist)
+	require.True(t, ok)
+
+	block := <-bb.Out
+	notarization, err := testutil.NewNotarization(conf.Logger, conf.SignatureAggregator, block, nodes)
+	require.NoError(t, err)
+
+	err = e.HandleMessage(&Message{
+		ReplicationResponse: &ReplicationResponse{
+			Data: []QuorumRound{
+				{
+					Block:        block,
+					Notarization: &notarization,
+				},
+			},
+		},
+	}, nodes[0])
+	require.NoError(t, err)
+
+	wal.AssertWALSize(2)
+	// ensure the round is still 1
+	require.Equal(t, uint64(1), e.Metadata().Round)
+}
+
 // ListnerComm is a comm that listens for incoming messages
 // and sends them to the [in] channel
 type listenerComm struct {
diff --git a/msg.go b/msg.go
index 385e0ccd..9f40f288 100644
--- a/msg.go
+++ b/msg.go
@@ -223,18 +223,22 @@ type QuorumCertificate interface {
 }
 
 type ReplicationRequest struct {
-	Seqs        []uint64 // sequences we are requesting
-	LatestRound uint64   // latest round that we are aware of
+	Seqs               []uint64 // sequences we are requesting
+	Rounds             []uint64 // rounds we are requesting
+	LatestRound        uint64   // latest round that we are aware of
+	LatestFinalizedSeq uint64   // latest finalized sequence that we are aware of
 }
 
 type ReplicationResponse struct {
 	Data        []QuorumRound
 	LatestRound *QuorumRound
+	LatestSeq   *QuorumRound
 }
 
 type VerifiedReplicationResponse struct {
-	Data        []VerifiedQuorumRound
-	LatestRound *VerifiedQuorumRound
+	Data               []VerifiedQuorumRound
+	LatestRound        *VerifiedQuorumRound
+	LatestFinalizedSeq *VerifiedQuorumRound
 }
 
 // QuorumRound represents a round that has achieved quorum on either
@@ -250,13 +254,15 @@ type QuorumRound struct {
 // (block, notarization) or (block, finalization) or
 // (empty notarization)
 func (q *QuorumRound) IsWellFormed() error {
-	if q.EmptyNotarization != nil && q.Block == nil {
-		return nil
-	} else if q.Block != nil && (q.Notarization != nil || q.Finalization != nil) {
-		return nil
+	if q.Block == nil && q.EmptyNotarization == nil {
+		return fmt.Errorf("malformed QuorumRound, empty block and notarization fields")
+	}
+
+	if q.Block != nil && (q.Notarization == nil && q.Finalization == nil) {
+		return fmt.Errorf("malformed QuorumRound, block but no notarization or finalization")
 	}
 
-	return fmt.Errorf("malformed QuorumRound")
+	return nil
 }
 
 func (q *QuorumRound) GetRound() uint64 {
@@ -284,10 +290,15 @@ func (q *QuorumRound) VerifyQCConsistentWithBlock() error {
 		return err
 	}
 
-	if q.EmptyNotarization != nil {
+	if q.Block == nil {
 		return nil
 	}
 
+	// if an empty notarization is included, ensure the round is equal to the block round
+	if q.EmptyNotarization != nil && q.EmptyNotarization.Vote.Round != q.Block.BlockHeader().Round {
+		return fmt.Errorf("empty round does not match block round")
+	}
+
 	// ensure the finalization or notarization we get relates to the block
 	blockDigest := q.Block.BlockHeader().Digest
 
@@ -314,7 +325,7 @@ func (q *QuorumRound) String() string {
 	if err != nil {
 		return fmt.Sprintf("QuorumRound{Error: %s}", err)
 	} else {
-		return fmt.Sprintf("QuorumRound{Round: %d, Seq: %d}", q.GetRound(), q.GetSequence())
+		return fmt.Sprintf("QuorumRound{Round: %d, Seq: %d, Finalized: %t}", q.GetRound(), q.GetSequence(), q.Finalization != nil)
 	}
 }
 
diff --git a/msg_test.go b/msg_test.go
index 25f6cac0..123f9a4b 100644
--- a/msg_test.go
+++ b/msg_test.go
@@ -82,6 +82,24 @@ func TestQuorumRoundMalformed(t *testing.T) {
 			},
 			expectedErr: true,
 		},
+		{
+			name: "block and notarization and empty notarization",
+			qr: simplex.QuorumRound{
+				Block:             &testutil.TestBlock{},
+				Notarization:      &simplex.Notarization{},
+				EmptyNotarization: &simplex.EmptyNotarization{},
+			},
+			expectedErr: false,
+		},
+		{
+			name: "block and finalization and empty notarization",
+			qr: simplex.QuorumRound{
+				Block:             &testutil.TestBlock{},
+				Finalization:      &simplex.Finalization{},
+				EmptyNotarization: &simplex.EmptyNotarization{},
+			},
+			expectedErr: false,
+		},
 	}
 
 	for _, test := range tests {
diff --git a/pos_test.go b/pos_test.go
index 78650ae4..65e3345c 100644
--- a/pos_test.go
+++ b/pos_test.go
@@ -5,12 +5,12 @@ package simplex_test
 
 import (
 	"bytes"
-	"fmt"
+	"testing"
+	"time"
+
 	"github.com/ava-labs/simplex"
 	"github.com/ava-labs/simplex/testutil"
 	"github.com/stretchr/testify/require"
-	"testing"
-	"time"
 )
 
 func TestPoS(t *testing.T) {
@@ -133,8 +133,6 @@ func TestPoS(t *testing.T) {
 		testutil.WaitToEnterRound(t, n.E, 15)
 	}
 
-	fmt.Println(simplex.LeaderForRound(nodes, 15))
-
 	// Now, disconnect the node with the highest stake (node 3) and observe the network is stuck
 	net.Disconnect(nodes[2])
 	net.TriggerLeaderBlockBuilder(15)
diff --git a/replication.go b/replication.go
deleted file mode 100644
index fb5d2d34..00000000
--- a/replication.go
+++ /dev/null
@@ -1,345 +0,0 @@
-// Copyright (C) 2019-2025, Ava Labs, Inc. All rights reserved.
-// See the file LICENSE for licensing terms.
-
-package simplex
-
-import (
-	"fmt"
-	"math"
-	"slices"
-	"sync"
-	"time"
-
-	"go.uber.org/zap"
-)
-
-// signedSequence is a sequence that has been signed by a quorum certificate.
-// it essentially is a quorum round without the enforcement of needing a block with a
-// finalization or notarization.
-type signedSequence struct {
-	seq     uint64
-	signers NodeIDs
-}
-
-func newSignedSequenceFromRound(round QuorumRound) (*signedSequence, error) {
-	ss := &signedSequence{}
-	switch {
-	case round.Finalization != nil:
-		ss.signers = round.Finalization.QC.Signers()
-		ss.seq = round.Finalization.Finalization.Seq
-	case round.Notarization != nil:
-		ss.signers = round.Notarization.QC.Signers()
-		ss.seq = round.Notarization.Vote.Seq
-	case round.EmptyNotarization != nil:
-		return nil, fmt.Errorf("should not create signed sequence from empty notarization")
-	default:
-		return nil, fmt.Errorf("round does not contain a finalization, empty notarization, or notarization")
-	}
-
-	return ss, nil
-}
-
-type ReplicationState struct {
-	lock           *sync.Mutex
-	logger         Logger
-	enabled        bool
-	maxRoundWindow uint64
-	comm           Communication
-	id             NodeID
-
-	// latest seq requested
-	lastSequenceRequested uint64
-
-	// highest sequence we have received
-	highestSequenceObserved *signedSequence
-
-	// receivedQuorumRounds maps rounds to quorum rounds
-	receivedQuorumRounds map[uint64]QuorumRound
-
-	// request iterator
-	requestIterator int
-
-	timeoutHandler *TimeoutHandler
-}
-
-func NewReplicationState(logger Logger, comm Communication, id NodeID, maxRoundWindow uint64, enabled bool, start time.Time, lock *sync.Mutex) *ReplicationState {
-	return &ReplicationState{
-		lock:                 lock,
-		logger:               logger,
-		enabled:              enabled,
-		comm:                 comm,
-		id:                   id,
-		maxRoundWindow:       maxRoundWindow,
-		receivedQuorumRounds: make(map[uint64]QuorumRound),
-		timeoutHandler:       NewTimeoutHandler(logger, start, comm.Nodes()),
-	}
-}
-
-func (r *ReplicationState) AdvanceTime(now time.Time) {
-	r.timeoutHandler.Tick(now)
-}
-
-// isReplicationComplete returns true if we have finished the replication process.
-// The process is considered finished once [currentRound] has caught up to the highest round received.
-func (r *ReplicationState) isReplicationComplete(nextSeqToCommit uint64, currentRound uint64) bool {
-	if r.highestSequenceObserved == nil {
-		return true
-	}
-
-	return nextSeqToCommit > r.highestSequenceObserved.seq && currentRound > r.highestKnownRound()
-}
-
-func (r *ReplicationState) collectMissingSequences(observedSignedSeq *signedSequence, nextSeqToCommit uint64) {
-	observedSeq := observedSignedSeq.seq
-	// Node is behind, but we've already sent messages to collect future finalizations
-	if r.lastSequenceRequested >= observedSeq && r.highestSequenceObserved != nil {
-		return
-	}
-
-	if r.highestSequenceObserved == nil || observedSeq > r.highestSequenceObserved.seq {
-		r.highestSequenceObserved = observedSignedSeq
-	}
-
-	startSeq := math.Max(float64(nextSeqToCommit), float64(r.lastSequenceRequested))
-	// Don't exceed the max round window
-	endSeq := math.Min(float64(observedSeq), float64(r.maxRoundWindow+nextSeqToCommit))
-
-	r.logger.Debug("Node is behind, requesting missing finalizations", zap.Uint64("seq", observedSeq), zap.Uint64("startSeq", uint64(startSeq)), zap.Uint64("endSeq", uint64(endSeq)))
-	r.sendReplicationRequests(uint64(startSeq), uint64(endSeq))
-}
-
-// sendReplicationRequests sends requests for missing sequences for the
-// range of sequences [start, end] <- inclusive. It does so by splitting the
-// range of sequences equally amount the nodes that have signed [highestSequenceObserved].
-func (r *ReplicationState) sendReplicationRequests(start uint64, end uint64) {
-	// it's possible our node has signed [highestSequenceObserved].
-	// For example this may happen if our node has sent a finalization
-	// for [highestSequenceObserved] and has not received the
-	// finalization from the network.
-	nodes := r.highestSequenceObserved.signers.Remove(r.id)
-	numNodes := len(nodes)
-
-	seqRequests := DistributeSequenceRequests(start, end, numNodes)
-
-	r.logger.Debug("Distributing replication requests", zap.Uint64("start", start), zap.Uint64("end", end), zap.Stringer("nodes", NodeIDs(nodes)))
-	for i, seqs := range seqRequests {
-		index := (i + r.requestIterator) % numNodes
-		r.sendRequestToNode(seqs.Start, seqs.End, nodes, index)
-	}
-
-	r.lastSequenceRequested = end
-	// next time we send requests, we start with a different permutation
-	r.requestIterator++
-}
-
-// sendRequestToNode requests the sequences [start, end] from nodes[index].
-// In case the nodes[index] does not respond, we create a timeout that will
-// re-send the request.
-func (r *ReplicationState) sendRequestToNode(start uint64, end uint64, nodes []NodeID, index int) {
-	r.logger.Debug("Requesting missing finalizations ",
-		zap.Stringer("from", nodes[index]),
-		zap.Uint64("start", start),
-		zap.Uint64("end", end))
-	seqs := make([]uint64, (end+1)-start)
-	for i := start; i <= end; i++ {
-		seqs[i-start] = i
-	}
-	request := &ReplicationRequest{
-		Seqs:        seqs,
-		LatestRound: r.highestSequenceObserved.seq,
-	}
-	msg := &Message{ReplicationRequest: request}
-
-	task := r.createReplicationTimeoutTask(start, end, nodes, index)
-
-	r.timeoutHandler.AddTask(task)
-
-	r.comm.Send(msg, nodes[index])
-}
-
-func (r *ReplicationState) createReplicationTimeoutTask(start, end uint64, nodes []NodeID, index int) *TimeoutTask {
-	taskFunc := func() {
-		r.lock.Lock()
-		defer r.lock.Unlock()
-		r.sendRequestToNode(start, end, nodes, (index+1)%len(nodes))
-	}
-	timeoutTask := &TimeoutTask{
-		Start:    start,
-		End:      end,
-		NodeID:   nodes[index],
-		TaskID:   getTimeoutID(start, end),
-		Task:     taskFunc,
-		Deadline: r.timeoutHandler.GetTime().Add(DefaultReplicationRequestTimeout),
-	}
-
-	return timeoutTask
-}
-
-// receivedReplicationResponse notifies the task handler a response was received. If the response
-// was incomplete(meaning our timeout expected more seqs), then we will create a new timeout
-// for the missing sequences and send the request to a different node.
-func (r *ReplicationState) receivedReplicationResponse(data []QuorumRound, node NodeID) {
-	seqs := make([]uint64, 0, len(data))
-
-	// remove all sequences where we expect a finalization but only received a notarization
-	highestSeq := r.highestSequenceObserved.seq
-	for _, qr := range data {
-		if qr.GetSequence() <= highestSeq && qr.Finalization == nil && qr.Notarization != nil {
-			r.logger.Debug("Received notarization without finalization, skipping", zap.Stringer("from", node), zap.Uint64("seq", qr.GetSequence()))
-			continue
-		}
-
-		seqs = append(seqs, qr.GetSequence())
-	}
-
-	slices.Sort(seqs)
-
-	task := FindReplicationTask(r.timeoutHandler, node, seqs)
-	if task == nil {
-		r.logger.Debug("Could not find a timeout task associated with the replication response", zap.Stringer("from", node), zap.Any("seqs", seqs))
-		return
-	}
-	r.timeoutHandler.RemoveTask(node, task.TaskID)
-
-	// we found the timeout, now make sure all seqs were returned
-	missing := findMissingNumbersInRange(task.Start, task.End, seqs)
-	if len(missing) == 0 {
-		return
-	}
-
-	// if not all sequences were returned, create new timeouts
-	r.logger.Debug("Received missing sequences in the replication response", zap.Stringer("from", node), zap.Any("missing", missing))
-	nodes := r.highestSequenceObserved.signers.Remove(r.id)
-	numNodes := len(nodes)
-	segments := CompressSequences(missing)
-	for i, seqs := range segments {
-		index := i % numNodes
-		newTask := r.createReplicationTimeoutTask(seqs.Start, seqs.End, nodes, index)
-		r.timeoutHandler.AddTask(newTask)
-	}
-}
-
-// findMissingNumbersInRange finds numbers in an array constructed by [start...end] that are not in [nums]
-// ex. (3, 10, [1,2,3,4,5,6]) -> [7,8,9,10]
-func findMissingNumbersInRange(start, end uint64, nums []uint64) []uint64 {
-	numMap := make(map[uint64]struct{})
-	for _, num := range nums {
-		numMap[num] = struct{}{}
-	}
-
-	var result []uint64
-
-	for i := start; i <= end; i++ {
-		if _, exists := numMap[i]; !exists {
-			result = append(result, i)
-		}
-	}
-
-	return result
-}
-
-func (r *ReplicationState) replicateBlocks(finalization *Finalization, nextSeqToCommit uint64) {
-	if !r.enabled {
-		return
-	}
-
-	signedSequence := &signedSequence{
-		seq:     finalization.Finalization.Seq,
-		signers: finalization.QC.Signers(),
-	}
-
-	r.collectMissingSequences(signedSequence, nextSeqToCommit)
-}
-
-// maybeCollectFutureSequences attempts to collect future sequences if
-// there are more to be collected and the round has caught up for us to send the request.
-func (r *ReplicationState) maybeCollectFutureSequences(nextSequenceToCommit uint64) {
-	if !r.enabled {
-		return
-	}
-
-	if r.lastSequenceRequested >= r.highestSequenceObserved.seq {
-		return
-	}
-
-	// we send out more requests once our seq has caught up to 1/2 of the maxRoundWindow
-	if nextSequenceToCommit+r.maxRoundWindow/2 > r.lastSequenceRequested {
-		r.collectMissingSequences(r.highestSequenceObserved, nextSequenceToCommit)
-	}
-}
-
-func (r *ReplicationState) StoreQuorumRound(round QuorumRound) {
-	if _, ok := r.receivedQuorumRounds[round.GetRound()]; ok {
-		// maybe this quorum round was behind
-		if r.receivedQuorumRounds[round.GetRound()].Finalization == nil && round.Finalization != nil {
-			r.receivedQuorumRounds[round.GetRound()] = round
-		}
-		return
-	}
-
-	if round.EmptyNotarization == nil && round.GetSequence() > r.highestSequenceObserved.seq {
-		signedSeq, err := newSignedSequenceFromRound(round)
-		if err != nil {
-			// should never be here since we already checked the QuorumRound was valid
-			r.logger.Error("Error creating signed sequence from round", zap.Error(err))
-			return
-		}
-
-		r.highestSequenceObserved = signedSeq
-	}
-
-	r.logger.Debug("Stored quorum round ", zap.Stringer("qr", &round))
-	r.receivedQuorumRounds[round.GetRound()] = round
-}
-
-func (r *ReplicationState) GetFinalizedBlockForSequence(seq uint64) (Block, Finalization, bool) {
-	for _, round := range r.receivedQuorumRounds {
-		if round.GetSequence() == seq {
-			if round.Block == nil || round.Finalization == nil {
-				// this could be an empty notarization
-				continue
-			}
-			return round.Block, *round.Finalization, true
-		}
-	}
-	return nil, Finalization{}, false
-}
-
-func (r *ReplicationState) highestKnownRound() uint64 {
-	var highestRound uint64
-	for round := range r.receivedQuorumRounds {
-		if round > highestRound {
-			highestRound = round
-		}
-	}
-	return highestRound
-}
-
-func (r *ReplicationState) GetQuorumRoundWithSeq(seq uint64) *QuorumRound {
-	for _, round := range r.receivedQuorumRounds {
-		if round.GetSequence() == seq {
-			return &round
-		}
-	}
-	return nil
-}
-
-// FindReplicationTask returns a TimeoutTask assigned to [node] that contains the lowest sequence in [seqs].
-// A sequence is considered "contained" if it falls between a task's Start (inclusive) and End (inclusive).
-func FindReplicationTask(t *TimeoutHandler, node NodeID, seqs []uint64) *TimeoutTask {
-	var lowestTask *TimeoutTask
-
-	t.forEach(string(node), func(tt *TimeoutTask) {
-		for _, seq := range seqs {
-			if seq >= tt.Start && seq <= tt.End {
-				if lowestTask == nil {
-					lowestTask = tt
-				} else if seq < lowestTask.Start {
-					lowestTask = tt
-				}
-			}
-		}
-	})
-
-	return lowestTask
-}
diff --git a/replication_request_test.go b/replication_request_test.go
index db271c8b..57aab2ac 100644
--- a/replication_request_test.go
+++ b/replication_request_test.go
@@ -4,6 +4,7 @@ import (
 	"bytes"
 	"context"
 	"testing"
+	"time"
 
 	"github.com/ava-labs/simplex"
 	"github.com/ava-labs/simplex/testutil"
@@ -43,6 +44,7 @@ func TestReplicationRequestIndexedBlocks(t *testing.T) {
 	msg := <-comm.in
 	resp := msg.VerifiedReplicationResponse
 	require.Nil(t, resp.LatestRound)
+	require.Nil(t, resp.LatestFinalizedSeq)
 	require.Equal(t, len(sequences), len(resp.Data))
 
 	for i, data := range resp.Data {
@@ -60,9 +62,7 @@ func TestReplicationRequestIndexedBlocks(t *testing.T) {
 	err = e.HandleMessage(req, nodes[1])
 	require.NoError(t, err)
 
-	msg = <-comm.in
-	resp = msg.VerifiedReplicationResponse
-	require.Zero(t, len(resp.Data))
+	require.Never(t, func() bool { return len(comm.in) > 0 }, 5*time.Second, 100*time.Millisecond)
 }
 
 // TestReplicationRequestNotarizations tests replication requests for notarized blocks.
@@ -97,8 +97,9 @@ func TestReplicationRequestNotarizations(t *testing.T) {
 	}
 	req := &simplex.Message{
 		ReplicationRequest: &simplex.ReplicationRequest{
-			Seqs:        seqs,
-			LatestRound: 0,
+			Seqs:               seqs,
+			LatestRound:        1,
+			LatestFinalizedSeq: 0,
 		},
 	}
 
@@ -109,7 +110,30 @@ func TestReplicationRequestNotarizations(t *testing.T) {
 	resp := msg.VerifiedReplicationResponse
 	require.NoError(t, err)
 	require.NotNil(t, resp)
+	require.NotNil(t, resp.LatestRound)
+	require.Nil(t, resp.LatestFinalizedSeq)
 	require.Equal(t, *resp.LatestRound, rounds[numBlocks-1])
+
+	for _, round := range resp.Data {
+		require.Nil(t, round.EmptyNotarization)
+		notarizedBlock, ok := rounds[round.VerifiedBlock.BlockHeader().Round]
+		require.True(t, ok)
+		require.Equal(t, notarizedBlock.VerifiedBlock, round.VerifiedBlock)
+		require.Equal(t, notarizedBlock.Notarization, round.Notarization)
+	}
+
+	// now ask for the notarizations as rounds
+	req = &simplex.Message{
+		ReplicationRequest: &simplex.ReplicationRequest{
+			Rounds: seqs,
+		},
+	}
+
+	err = e.HandleMessage(req, nodes[1])
+	require.NoError(t, err)
+
+	msg = <-comm.in
+	resp = msg.VerifiedReplicationResponse
 	for _, round := range resp.Data {
 		require.Nil(t, round.EmptyNotarization)
 		notarizedBlock, ok := rounds[round.VerifiedBlock.BlockHeader().Round]
@@ -134,8 +158,11 @@ func TestReplicationRequestMixed(t *testing.T) {
 
 	numBlocks := uint64(8)
 	rounds := make(map[uint64]simplex.VerifiedQuorumRound)
+
+	numExpectedRounds := 0
 	// only produce a notarization for blocks we are the leader, otherwise produce an empty notarization
 	for i := range numBlocks {
+		numExpectedRounds++
 		leaderForRound := bytes.Equal(simplex.LeaderForRound(nodes, uint64(i)), e.ID)
 		emptyBlock := !leaderForRound
 		if emptyBlock {
@@ -155,18 +182,74 @@ func TestReplicationRequestMixed(t *testing.T) {
 			VerifiedBlock: block,
 			Notarization:  notarization,
 		}
+
 	}
 
 	require.Equal(t, uint64(numBlocks), e.Metadata().Round)
-	seqs := make([]uint64, 0, len(rounds))
+	roundsRequested := make([]uint64, 0, len(rounds))
 	for k := range rounds {
-		seqs = append(seqs, k)
+		roundsRequested = append(roundsRequested, k)
+	}
+
+	req := &simplex.Message{
&simplex.ReplicationRequest{ + Rounds: roundsRequested, + LatestRound: 1, + }, + } + + err = e.HandleMessage(req, nodes[1]) + require.NoError(t, err) + + msg := <-comm.in + resp := msg.VerifiedReplicationResponse + require.Equal(t, *resp.LatestRound, rounds[numBlocks-1]) + require.Equal(t, numExpectedRounds, len(resp.Data)) + + for _, round := range resp.Data { + notarizedBlock, ok := rounds[round.GetRound()] + require.True(t, ok) + require.Equal(t, notarizedBlock.VerifiedBlock, round.VerifiedBlock) + require.Equal(t, notarizedBlock.Notarization, round.Notarization) + require.Equal(t, notarizedBlock.EmptyNotarization, round.EmptyNotarization) + } +} + +func TestReplicationRequestTailingEmptyNotarizations(t *testing.T) { + bb := &testutil.TestBlockBuilder{Out: make(chan *testutil.TestBlock, 1)} + nodes := []simplex.NodeID{{1}, {2}, {3}, {4}} + comm := NewListenerComm(nodes) + conf, wal, _ := testutil.DefaultTestNodeEpochConfig(t, nodes[0], comm, bb) + conf.ReplicationEnabled = true + + e, err := simplex.NewEpoch(conf) + require.NoError(t, err) + require.NoError(t, e.Start()) + + numBlocks := uint64(8) + rounds := make(map[uint64]simplex.VerifiedQuorumRound) + // only produce a notarization for blocks we are the leader, otherwise produce an empty notarization + for i := range numBlocks { + emptyNotarization := testutil.NewEmptyNotarization(nodes, uint64(i)) + e.HandleMessage(&simplex.Message{ + EmptyNotarization: emptyNotarization, + }, nodes[1]) + wal.AssertNotarization(uint64(i)) + rounds[i] = simplex.VerifiedQuorumRound{ + EmptyNotarization: emptyNotarization, + } + } + + require.Equal(t, uint64(numBlocks), e.Metadata().Round) + roundsRequested := make([]uint64, 0, len(rounds)) + for k := range rounds { + roundsRequested = append(roundsRequested, k) } req := &simplex.Message{ ReplicationRequest: &simplex.ReplicationRequest{ - Seqs: seqs, - LatestRound: 0, + Rounds: roundsRequested, + LatestRound: 1, }, } @@ -177,6 +260,7 @@ func TestReplicationRequestMixed(t *testing.T) { resp := msg.VerifiedReplicationResponse require.Equal(t, *resp.LatestRound, rounds[numBlocks-1]) + require.Equal(t, len(roundsRequested), len(resp.Data)) for _, round := range resp.Data { notarizedBlock, ok := rounds[round.GetRound()] require.True(t, ok) @@ -186,6 +270,31 @@ func TestReplicationRequestMixed(t *testing.T) { } } +func TestReplicationRequestUnknownSeqsAndRounds(t *testing.T) { + bb := &testutil.TestBlockBuilder{Out: make(chan *testutil.TestBlock, 1)} + nodes := []simplex.NodeID{{1}, {2}, {3}, {4}} + comm := NewListenerComm(nodes) + conf, _, _ := testutil.DefaultTestNodeEpochConfig(t, nodes[0], comm, bb) + conf.ReplicationEnabled = true + + e, err := simplex.NewEpoch(conf) + require.NoError(t, err) + require.NoError(t, e.Start()) + + req := &simplex.Message{ + ReplicationRequest: &simplex.ReplicationRequest{ + Rounds: []uint64{100, 101, 102}, + Seqs: []uint64{200, 201, 202}, + LatestRound: 1, + }, + } + + err = e.HandleMessage(req, nodes[1]) + require.NoError(t, err) + + require.Never(t, func() bool { return len(comm.in) > 0 }, 5*time.Second, 100*time.Millisecond) +} + func TestNilReplicationResponse(t *testing.T) { nodes := []simplex.NodeID{{1}, {2}, {3}, {4}} net := testutil.NewInMemNetwork(t, nodes) @@ -202,7 +311,7 @@ func TestNilReplicationResponse(t *testing.T) { } // TestMalformedReplicationResponse tests that a malformed replication response is handled correctly. 
-// This replication response is malformeds since it must also include a notarization or +// This replication response is malformed since it must also include a notarization or // finalization. func TestMalformedReplicationResponse(t *testing.T) { nodes := []simplex.NodeID{{1}, {2}, {3}, {4}} diff --git a/replication_state.go b/replication_state.go new file mode 100644 index 00000000..0a9ce107 --- /dev/null +++ b/replication_state.go @@ -0,0 +1,326 @@ +// Copyright (C) 2019-2025, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. + +package simplex + +import ( + "crypto/rand" + "math/big" + "sync" + "time" + + "go.uber.org/zap" +) + +type finalizedQuorumRound struct { + block Block + finalization *Finalization +} + +type ReplicationState struct { + enabled bool + logger Logger + myNodeID NodeID + + // seqs maps sequences to a block and its associated finalization + seqs map[uint64]*finalizedQuorumRound + + // rounds maps round numbers to QuorumRounds + rounds map[uint64]*QuorumRound + + // digestTimeouts handles timeouts for fetching missing block digests. + // When a notarization depends on a block we haven’t received yet, + // it means a prior notarization for that block exists but is missing. + // Since we may not know which round that dependency belongs to, + // digestTimeouts ensures we re-request the missing digest until it arrives. + digestTimeouts *TimeoutHandler[Digest] + + // emptyRoundTimeouts handles timeouts for fetching missing empty round notarizations. + // When replication encounters a notarized block that depends on an empty round we haven't received, + // emptyRoundTimeouts ensures we re-request those empty rounds until they are received. + emptyRoundTimeouts *TimeoutHandler[uint64] + + roundRequestor *requestor + finalizationRequestor *requestor +} + +func NewReplicationState(logger Logger, comm Communication, myNodeID NodeID, maxRoundWindow uint64, enabled bool, start time.Time, lock *sync.Mutex) *ReplicationState { + if !enabled { + return &ReplicationState{ + enabled: enabled, + logger: logger, + } + } + + r := &ReplicationState{ + enabled: enabled, + myNodeID: myNodeID, + logger: logger, + + // seq replication + seqs: make(map[uint64]*finalizedQuorumRound), + finalizationRequestor: newRequestor(logger, start, lock, maxRoundWindow, comm, true), + + // round replication + rounds: make(map[uint64]*QuorumRound), + roundRequestor: newRequestor(logger, start, lock, maxRoundWindow, comm, false), + } + + r.digestTimeouts = NewTimeoutHandler(logger, "digest", start, DefaultReplicationRequestTimeout, r.requestDigests) + r.emptyRoundTimeouts = NewTimeoutHandler(logger, "empty", start, DefaultReplicationRequestTimeout, r.requestEmptyRounds) + + return r +} + +func (r *ReplicationState) AdvanceTime(now time.Time) { + if !r.enabled { + return + } + + r.finalizationRequestor.advanceTime(now) + r.roundRequestor.advanceTime(now) + r.digestTimeouts.Tick(now) + r.emptyRoundTimeouts.Tick(now) +} + +// deleteOldRounds cleans up the replication state for round replication after receiving a finalized round. 
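+// Rounds at or below [finalizedRound] are dropped, and any pending round-replication
+// or empty-round timeout tasks for those rounds are cancelled as well.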
+func (r *ReplicationState) deleteOldRounds(finalizedRound uint64) {
+	for round := range r.rounds {
+		if round <= finalizedRound {
+			r.logger.Debug("Replication State Deleting Old Round", zap.Uint64("round", round))
+			delete(r.rounds, round)
+		}
+	}
+
+	r.roundRequestor.removeOldTasks(finalizedRound)
+	r.emptyRoundTimeouts.RemoveOldTasks(func(r uint64, _ struct{}) bool {
+		return r <= finalizedRound
+	})
+}
+
+// storeSequence stores a block and its finalization into the replication state.
+func (r *ReplicationState) storeSequence(block Block, finalization *Finalization) {
+	if _, exists := r.seqs[finalization.Finalization.Seq]; exists {
+		return
+	}
+
+	r.seqs[finalization.Finalization.Seq] = &finalizedQuorumRound{
+		block:        block,
+		finalization: finalization,
+	}
+
+	r.finalizationRequestor.removeTask(finalization.Finalization.Seq)
+	r.digestTimeouts.RemoveTask(block.BlockHeader().Digest)
+}
+
+// storeRound adds or updates a quorum round in the replication state.
+// If the round already exists, it merges any missing notarizations or empty notarizations
+// from the provided quorum round. Otherwise, it stores the new round as is.
+func (r *ReplicationState) storeRound(qr *QuorumRound) {
+	if qr.Block != nil {
+		r.digestTimeouts.RemoveTask(qr.Block.BlockHeader().Digest)
+	}
+
+	existing, exists := r.rounds[qr.GetRound()]
+	if !exists {
+		r.rounds[qr.GetRound()] = qr
+		return
+	}
+
+	if qr.EmptyNotarization != nil && existing.EmptyNotarization == nil {
+		existing.EmptyNotarization = qr.EmptyNotarization
+	}
+
+	if (qr.Notarization != nil && qr.Block != nil) && existing.Block == nil {
+		existing.Notarization = qr.Notarization
+		existing.Block = qr.Block
+	}
+}
+
+// StoreQuorumRound stores the quorum round into the replication state.
+func (r *ReplicationState) StoreQuorumRound(round *QuorumRound) {
+	if !r.enabled {
+		return
+	}
+
+	r.logger.Debug("Replication State Storing Quorum Round", zap.Stringer("QR", round))
+
+	if round.Finalization != nil {
+		r.storeSequence(round.Block, round.Finalization)
+		r.finalizationRequestor.receivedSignedQuorum(newSignedQuorum(round, r.myNodeID))
+		r.deleteOldRounds(round.Finalization.Finalization.Round)
+		return
+	}
+
+	// otherwise we are storing a round without a finalization.
+	// don't bother storing rounds that are older than the highest finalized round we know of
+	if observed := r.finalizationRequestor.getHighestObserved(); observed != nil && observed.round >= round.GetRound() {
+		r.logger.Debug("Ignoring quorum round for a round we already know is finalized")
+		return
+	}
+
+	r.storeRound(round)
+	r.roundRequestor.receivedSignedQuorum(newSignedQuorum(round, r.myNodeID))
+	if round.EmptyNotarization != nil {
+		r.emptyRoundTimeouts.RemoveTask(round.GetRound())
+	}
+}
+
+// ReceivedFutureFinalization notifies the replication state that a finalization was created for a future round.
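+// The finalization's signers become candidates for replication requests, and any missing
+// sequences up to it are requested within the configured round window.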
+func (r *ReplicationState) ReceivedFutureFinalization(finalization *Finalization, nextSeqToCommit uint64) {
+	if !r.enabled {
+		return
+	}
+
+	signedSequence := newSignedQuorumFromFinalization(finalization, r.myNodeID)
+
+	// maybe this finalization was for a round that we initially thought only had notarizations
+	// remove it from the round replicator since we now have a finalization for this round
+	r.deleteOldRounds(finalization.Finalization.BlockHeader.Round)
+
+	// potentially send out requests for blocks/finalizations in between
+	r.finalizationRequestor.observedSignedQuorum(signedSequence, nextSeqToCommit)
+}
+
+// ReceivedFutureRound notifies the replication state of a future round.
+func (r *ReplicationState) ReceivedFutureRound(round, seq, currentRound uint64, signers []NodeID) {
+	if !r.enabled {
+		return
+	}
+
+	if observed := r.finalizationRequestor.getHighestObserved(); observed != nil && observed.round >= round {
+		r.logger.Debug("Ignoring round replication for a future round since we have a finalization for a higher round", zap.Uint64("round", round))
+		return
+	}
+
+	sq := newSignedQuorumFromRound(round, seq, signers, r.myNodeID)
+	r.roundRequestor.observedSignedQuorum(sq, currentRound)
+}
+
+// ResendFinalizationRequest notifies the replication state that `seq` should be re-requested.
+func (r *ReplicationState) ResendFinalizationRequest(seq uint64, signers []NodeID) error {
+	if !r.enabled {
+		return nil
+	}
+
+	signers = NodeIDs(signers).Remove(r.myNodeID)
+	numSigners := int64(len(signers))
+	index, err := rand.Int(rand.Reader, big.NewInt(numSigners))
+	if err != nil {
+		return err
+	}
+
+	// since we are resending because the block failed to verify, remove the stored quorum round
+	// so that we can try to get a new block & finalization
+	delete(r.seqs, seq)
+	r.finalizationRequestor.sendRequestToNode(seq, seq, signers[index.Int64()])
+	return nil
+}
+
+// CreateDependencyTasks creates tasks to refetch the given parent digest and empty rounds. If there are no
+// dependencies, no tasks are created.
+// TODO: in a future PR, these requests will be sent as specific digest requests.
+func (r *ReplicationState) CreateDependencyTasks(parent *Digest, parentSeq uint64, emptyRounds []uint64) {
+	if parent != nil {
+		r.digestTimeouts.AddTask(*parent)
+	}
+
+	if len(emptyRounds) > 0 {
+		for _, round := range emptyRounds {
+			r.emptyRoundTimeouts.AddTask(round)
+		}
+	}
+}
+
+func (r *ReplicationState) clearDependencyTasks(parent *Digest) {
+	// TODO: for a future PR
+}
+
+// MaybeAdvancedState attempts to collect more future sequences and rounds if
+// there are more to be collected and our state has caught up enough for us to send the next requests.
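+// It also prunes stored rounds below the next sequence to commit before nudging both requestors.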
+func (r *ReplicationState) MaybeAdvancedState(nextSequenceToCommit uint64, currentRound uint64) { + if !r.enabled { + return + } + + if nextSequenceToCommit > 0 { + r.deleteOldRounds(nextSequenceToCommit - 1) + } + + // update the requestors in case they need to send more requests + r.finalizationRequestor.updateState(nextSequenceToCommit) + r.roundRequestor.updateState(currentRound) +} + +func (r *ReplicationState) GetFinalizedBlockForSequence(seq uint64) (Block, *Finalization, bool) { + blockWithFinalization, exists := r.seqs[seq] + if !exists { + return nil, nil, false + } + + return blockWithFinalization.block, blockWithFinalization.finalization, true +} + +func (r *ReplicationState) GetLowestRound() *QuorumRound { + var lowestRound *QuorumRound + + for round, qr := range r.rounds { + if lowestRound == nil { + lowestRound = qr + continue + } + + if lowestRound.GetRound() > round { + lowestRound = qr + } + } + + return lowestRound +} + +func (r *ReplicationState) GetBlockWithSeq(seq uint64) Block { + block, _, _ := r.GetFinalizedBlockForSequence(seq) + if block != nil { + return block + } + + // check rounds replicator. + // note: this is not deterministic since we can have multiple blocks notarized with the same seq + // its fine to return since the caller can still optimistically advance the round + for _, round := range r.rounds { + if round.GetSequence() == seq && round.Block != nil { + return round.Block + } + } + + return nil +} + +func (r *ReplicationState) requestDigests(digests []Digest) { + // TODO: In a future PR, I will add a message that requests a specific digest. + r.logger.Debug("Not implemented yet", zap.Stringers("Digests", digests)) +} + +func (r *ReplicationState) requestEmptyRounds(emptyRounds []uint64) { + r.logger.Debug("Replication State requesting empty rounds", zap.Uint64s("empty rounds", emptyRounds)) + r.roundRequestor.resendReplicationRequests(emptyRounds) +} + +func (r *ReplicationState) DeleteRound(round uint64) { + if !r.enabled { + return + } + + r.logger.Debug("Replication State Removing Round", zap.Uint64("round", round)) + r.emptyRoundTimeouts.RemoveTask(round) + r.roundRequestor.removeTask(round) + delete(r.rounds, round) +} + +func (r *ReplicationState) DeleteSeq(seq uint64) { + if !r.enabled { + return + } + + delete(r.seqs, seq) +} diff --git a/replication_test.go b/replication_test.go index c8f3c72e..d1ed4168 100644 --- a/replication_test.go +++ b/replication_test.go @@ -13,6 +13,7 @@ import ( "time" "github.com/ava-labs/simplex" + "github.com/ava-labs/simplex/record" . "github.com/ava-labs/simplex/testutil" "go.uber.org/zap/zapcore" @@ -70,7 +71,6 @@ func testReplication(t *testing.T, startSeq uint64, nodes []simplex.NodeID) { for _, n := range net.Instances { n.Storage.WaitForBlockCommit(startSeq) } - assertEqualLedgers(t, net) } // TestReplicationAdversarialNode tests the replication process of a node that @@ -247,6 +247,7 @@ func TestRebroadcastingWithReplication(t *testing.T) { // TestReplicationEmptyNotarizations ensures a lagging node will properly replicate // many empty notarizations in a row. 
+// This test sometimes takes > 30 sec func TestReplicationEmptyNotarizations(t *testing.T) { nodes := []simplex.NodeID{{1}, {2}, {3}, {4}, {5}, {6}} @@ -322,6 +323,17 @@ func testReplicationEmptyNotarizations(t *testing.T, nodes []simplex.NodeID, end net.Connect(laggingNode.E.ID) net.TriggerLeaderBlockBuilder(endRound) for _, n := range net.Instances { + if n.E.ID.Equals(laggingNode.E.ID) { + // maybe lagging node has requested finalizations to a node without it, we may need to resend the request + for { + if n.Storage.NumBlocks() == 2 { + break + } + time.Sleep(10 * time.Millisecond) + n.AdvanceTime(2 * simplex.DefaultMaxProposalWaitTime) + } + continue + } n.Storage.WaitForBlockCommit(1) } @@ -574,15 +586,11 @@ func sendVotesToOneNode(filteredInNode simplex.NodeID) MessageFilter { func TestReplicationStuckInProposingBlock(t *testing.T) { var aboutToBuildBlock sync.WaitGroup - aboutToBuildBlock.Add(2) - - var cancelBlockBuilding sync.WaitGroup - cancelBlockBuilding.Add(1) + aboutToBuildBlock.Add(1) tbb := &TestBlockBuilder{Out: make(chan *TestBlock, 1), BlockShouldBeBuilt: make(chan struct{}, 1), In: make(chan *TestBlock, 1)} bb := NewTestControlledBlockBuilder(t) bb.TestBlockBuilder = *tbb - storage := NewInMemStorage() nodes := []simplex.NodeID{{1}, {2}, {3}, {4}} blocks := createBlocks(t, nodes, 5) @@ -600,9 +608,6 @@ func TestReplicationStuckInProposingBlock(t *testing.T) { if strings.Contains(entry.Message, "Scheduling block building") { aboutToBuildBlock.Done() } - if strings.Contains(entry.Message, "We are the leader of this round, but a higher round has been finalized. Aborting block building.") { - cancelBlockBuilding.Done() - } return nil }) @@ -696,9 +701,6 @@ func TestReplicationStuckInProposingBlock(t *testing.T) { }, nodes[1]) storage.WaitForBlockCommit(4) - - // Just for sanity, ensure that the block building was cancelled - cancelBlockBuilding.Wait() } // TestReplicationNodeDiverges tests that a node replicates blocks even if they @@ -1085,3 +1087,331 @@ func TestReplicationVerifyEmptyNotarization(t *testing.T) { return wal.ContainsEmptyNotarization(0) }, time.Millisecond*500, time.Millisecond*10, "Did not expect an empty notarization with a corrupt QC to be written to the WAL") } + +// almostFinalizeBlocks is a message filter that allows all messages except for finalized votes +// and finalizations, unless the message is from node 1. This way each node will have 2 finalized votes, +// which is one short from quorum. +func almostFinalizeBlocks(msg *simplex.Message, from, _ simplex.NodeID) bool { + // block finalized votes and finalizations + if msg.Finalization != nil || msg.FinalizeVote != nil { + return from.Equals(simplex.NodeID{1}) + } + return true +} + +// TestReplicationVotesForNotarizations tests that a lagging node will replicate +// finalizations and notarizations. It ensures the node sends finalized votes for rounds +// without finalizations. 
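+// The adversarial node is disconnected before the lagging node reconnects, so the lagging
+// node's finalize votes are required for the notarized rounds to ever finalize.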
+func TestReplicationVotesForNotarizations(t *testing.T) { + nodes := []simplex.NodeID{{1}, {2}, {3}, {4}} + + numFinalizedBlocks := uint64(5) + // number of notarized blocks after the finalized blocks + numNotarizedBlocks := uint64(11) + net := NewInMemNetwork(t, nodes) + + storageData := createBlocks(t, nodes, numFinalizedBlocks) + nodeConfig := func(from simplex.NodeID) *TestNodeConfig { + comm := NewTestComm(from, net, almostFinalizeBlocks) + return &TestNodeConfig{ + InitialStorage: storageData, + Comm: comm, + ReplicationEnabled: true, + } + } + + n1 := NewSimplexNode(t, nodes[0], net, nodeConfig(nodes[0])) + n2 := NewSimplexNode(t, nodes[1], net, nodeConfig(nodes[1])) + adversary := NewSimplexNode(t, nodes[2], net, nodeConfig(nodes[2])) + laggingNode := NewSimplexNode(t, nodes[3], net, &TestNodeConfig{ + ReplicationEnabled: true, + }) + + for _, n := range net.Instances { + if n.E.ID.Equals(laggingNode.E.ID) { + require.Equal(t, uint64(0), n.Storage.NumBlocks()) + continue + } + require.Equal(t, numFinalizedBlocks, n.Storage.NumBlocks()) + } + + // lagging node should be disconnected while nodes create notarizations without finalizations + net.Disconnect(laggingNode.E.ID) + + net.StartInstances() + + missedSeqs := uint64(0) + // normal nodes continue to make progress + for round := numFinalizedBlocks; round < numFinalizedBlocks+numNotarizedBlocks; round++ { + emptyRound := bytes.Equal(simplex.LeaderForRound(nodes, round), laggingNode.E.ID) + if emptyRound { + missedSeqs++ + net.AdvanceWithoutLeader(round, laggingNode.E.ID) + } else { + net.TriggerLeaderBlockBuilder(round) + for _, n := range net.Instances { + if n.E.ID.Equals(laggingNode.E.ID) { + continue + } + n.WAL.AssertNotarization(round) + } + } + } + + // all nodes should be on round [numFinalizedBlocks + numNotarizedBlocks - 1] + for _, n := range net.Instances { + if n.E.ID.Equals(laggingNode.E.ID) { + require.Equal(t, uint64(0), n.Storage.NumBlocks()) + require.Equal(t, uint64(0), n.E.Metadata().Round) + continue + } + require.Equal(t, numFinalizedBlocks, n.Storage.NumBlocks()) + require.Equal(t, numFinalizedBlocks+numNotarizedBlocks, n.E.Metadata().Round) + } + + // at this point in time, the adversarial node will disconnect + // since each node has sent 2 finalized votes, which is one short of a quorum + // the lagging node will need to replicate the finalizations, and then send votes for notarizations + net.Disconnect(adversary.E.ID) + net.Connect(laggingNode.E.ID) + net.SetAllNodesMessageFilter(AllowAllMessages) + + // the adversary should not be the leader(to simplify test) + isAdversaryLeader := bytes.Equal(simplex.LeaderForRound(nodes, numFinalizedBlocks+numNotarizedBlocks), adversary.E.ID) + require.False(t, isAdversaryLeader) + + // lagging node should not be leader + isLaggingNodeLeader := bytes.Equal(simplex.LeaderForRound(nodes, numFinalizedBlocks+numNotarizedBlocks), laggingNode.E.ID) + require.False(t, isLaggingNodeLeader) + + // trigger block building, but we only have 2 connected nodes so the nodes will time out + net.TriggerLeaderBlockBuilder(numFinalizedBlocks + numNotarizedBlocks) + + // ensure time out on required nodes + n1.TimeoutOnRound(numFinalizedBlocks + numNotarizedBlocks) + n2.TimeoutOnRound(numFinalizedBlocks + numNotarizedBlocks) + require.Equal(t, uint64(0), laggingNode.E.Metadata().Round) + + // when the lagging node times out, it will broadcast an empty vote. 
The other online nodes will reply with their latest tip which should kickstart replication + laggingNode.TimeoutOnRound(0) + + expectedNumBlocks := numFinalizedBlocks + numNotarizedBlocks - missedSeqs + // because the adversarial node is offline , we may need to send replication requests many times + for { + time.Sleep(time.Millisecond * 100) + if laggingNode.Storage.NumBlocks() == expectedNumBlocks { + break + } + + laggingNode.AdvanceTime(simplex.DefaultReplicationRequestTimeout) + } + + for _, n := range net.Instances { + if n.E.ID.Equals(adversary.E.ID) { + continue + } + n.Storage.WaitForBlockCommit(expectedNumBlocks - 1) // subtract -1 because seq starts at 0 + } + + laggingNode.TimeoutOnRound(numFinalizedBlocks + numNotarizedBlocks) + for _, n := range net.Instances { + if n.E.ID.Equals(adversary.E.ID) { + continue + } + WaitToEnterRound(t, n.E, numFinalizedBlocks+numNotarizedBlocks+1) + require.True(t, n.WAL.ContainsEmptyNotarization(numFinalizedBlocks+numNotarizedBlocks)) + } +} + +// TestReplicationEmptyNotarizations ensures a lagging node will properly replicate +// a tail of empty notarizations. +func TestReplicationEmptyNotarizationsTail(t *testing.T) { + nodes := []simplex.NodeID{{1}, {2}, {3}, {4}, {5}, {6}} + + for endRound := uint64(2); endRound <= 2*simplex.DefaultMaxRoundWindow; endRound++ { + isLaggingNodeLeader := bytes.Equal(simplex.LeaderForRound(nodes, endRound), nodes[5]) + if isLaggingNodeLeader { + continue + } + + testName := fmt.Sprintf("Empty_notarizations_end_round%d", endRound) + t.Run(testName, func(t *testing.T) { + t.Parallel() + testReplicationEmptyNotarizationsTail(t, nodes, endRound) + }) + } +} + +func testReplicationEmptyNotarizationsTail(t *testing.T, nodes []simplex.NodeID, endRound uint64) { + net := NewInMemNetwork(t, nodes) + newNodeConfig := func(from simplex.NodeID) *TestNodeConfig { + comm := NewTestComm(from, net, AllowAllMessages) + return &TestNodeConfig{ + Comm: comm, + ReplicationEnabled: true, + } + } + + NewSimplexNode(t, nodes[0], net, newNodeConfig(nodes[0])) + NewSimplexNode(t, nodes[1], net, newNodeConfig(nodes[1])) + NewSimplexNode(t, nodes[2], net, newNodeConfig(nodes[2])) + NewSimplexNode(t, nodes[3], net, newNodeConfig(nodes[3])) + NewSimplexNode(t, nodes[4], net, newNodeConfig(nodes[4])) + laggingNode := NewSimplexNode(t, nodes[5], net, newNodeConfig(nodes[5])) + + net.StartInstances() + + net.Disconnect(laggingNode.E.ID) + net.SetAllNodesMessageFilter(onlyAllowEmptyRoundMessages) + + // normal nodes continue to make progress + for i := uint64(0); i < endRound; i++ { + leader := simplex.LeaderForRound(nodes, i) + if !leader.Equals(laggingNode.E.ID) { + net.TriggerLeaderBlockBuilder(i) + } + + net.AdvanceWithoutLeader(i, laggingNode.E.ID) + } + + for _, n := range net.Instances { + if n.E.ID.Equals(laggingNode.E.ID) { + require.Equal(t, uint64(0), n.Storage.NumBlocks()) + require.Equal(t, uint64(0), n.E.Metadata().Round) + continue + } + + // assert metadata + require.Equal(t, uint64(endRound), n.E.Metadata().Round) + require.Equal(t, uint64(0), n.E.Metadata().Seq) + require.Equal(t, uint64(0), n.E.Storage.NumBlocks()) + } + + net.Connect(laggingNode.E.ID) + net.SetAllNodesMessageFilter(AllowAllMessages) + + // have the lagging node timeout to trigger replication + laggingNode.E.AdvanceTime(time.Now().Add(laggingNode.E.MaxProposalWait)) + + for _, n := range net.Instances { + WaitToEnterRound(t, n.E, endRound) + require.Equal(t, uint64(endRound), n.E.Metadata().Round) + } +} + +func sendEmptyNotarizationQuorumRounds(emptyNotes 
map[uint64]*simplex.EmptyNotarization) MessageFilter { + return func(msg *simplex.Message, from, to simplex.NodeID) bool { + if msg.VerifiedReplicationResponse != nil { + newData := make([]simplex.VerifiedQuorumRound, 0, len(msg.VerifiedReplicationResponse.Data)) + for _, qr := range msg.VerifiedReplicationResponse.Data { + newQR := simplex.VerifiedQuorumRound{ + EmptyNotarization: qr.EmptyNotarization, + } + if qr.EmptyNotarization == nil { + newQR.EmptyNotarization = emptyNotes[qr.GetRound()] + } + newData = append(newData, newQR) + } + msg.VerifiedReplicationResponse.Data = newData + } + + return allowFinalizeVotes(msg, from, to) + } +} + +func allowFinalizeVotes(msg *simplex.Message, from, to simplex.NodeID) bool { + if msg.Finalization != nil || msg.FinalizeVote != nil { + if to.Equals(simplex.NodeID{3}) || from.Equals(simplex.NodeID{3}) { + return false + } + } + return true +} + +// TestReplicationChain tests that a node can both empty notarizations and notarizations for the same round. +func TestReplicationChain(t *testing.T) { + // Digest message requests are needed for this test + t.Skip() + nodes := []simplex.NodeID{{1}, {2}, {3}, {4}} + net := NewInMemNetwork(t, nodes) + + newNodeConfig := func(from simplex.NodeID) *TestNodeConfig { + comm := NewTestComm(from, net, allowFinalizeVotes) + return &TestNodeConfig{ + Comm: comm, + ReplicationEnabled: true, + } + } + + // full nodes operate normally + fullNode1 := NewSimplexNode(t, nodes[0], net, newNodeConfig(nodes[0])) + fullNode2 := NewSimplexNode(t, nodes[1], net, newNodeConfig(nodes[1])) + + fullNode1.Silence() + fullNode2.Silence() + // node 3 will not receive finalize votes & finalizations + blockFinalize3 := NewSimplexNode(t, nodes[2], net, newNodeConfig(nodes[2])) + blockFinalize3.Silence() + // lagging node is disconnected initially. 
It initially receives only empty notarizations + // but then later receives notarizations and must send finalize votes for them + laggingNode := NewSimplexNode(t, nodes[3], net, newNodeConfig(nodes[3])) + + net.StartInstances() + net.Disconnect(laggingNode.E.ID) + + emptyNotarizations := make(map[uint64]*simplex.EmptyNotarization) + numNotarizations := uint64(8) + missedNotarizations := uint64(0) + for i := range numNotarizations { + // every round has an empty notarization(possible due to timeouts) + emptyNotarization := NewEmptyNotarization(nodes, i) + emptyNotarizations[i] = emptyNotarization + + leader := simplex.LeaderForRound(nodes, i) + if !leader.Equals(laggingNode.E.ID) { + net.TriggerLeaderBlockBuilder(i) + continue + } + + net.AdvanceWithoutLeader(i, laggingNode.E.ID) + missedNotarizations++ + } + + for _, n := range net.Instances { + if n.E.ID.Equals(laggingNode.E.ID) { + require.Equal(t, uint64(0), n.Storage.NumBlocks()) + require.Equal(t, uint64(0), n.E.Metadata().Round) + continue + } + + n.WAL.AssertNotarization(numNotarizations - 1) + // assert metadata + require.Equal(t, numNotarizations, n.E.Metadata().Round) + require.Equal(t, numNotarizations-missedNotarizations, n.E.Metadata().Seq) + require.Equal(t, uint64(0), n.E.Storage.NumBlocks()) + } + + // lagging node should not be the leader after reconnect + leader := simplex.LeaderForRound(nodes, numNotarizations) + require.NotEqual(t, laggingNode.E.ID, leader) + + net.SetAllNodesMessageFilter(sendEmptyNotarizationQuorumRounds(emptyNotarizations)) + net.Connect(laggingNode.E.ID) + net.TriggerLeaderBlockBuilder(numNotarizations) + + // now the lagging node should catch up on empty notarizations + empty := laggingNode.WAL.AssertNotarization(numNotarizations - 1) + require.Equal(t, empty, record.EmptyNotarizationRecordType) + + // seq should be 0, since filter gave empty notarizations + require.Equal(t, numNotarizations, laggingNode.E.Metadata().Round) + require.Equal(t, uint64(0), laggingNode.E.Metadata().Seq) + + // nodes send a block but can't notarize so they time out + net.SetAllNodesMessageFilter(allowFinalizeVotes) + + for _, n := range net.Instances { + n.TickUntilRoundAdvanced(numNotarizations+1, simplex.DefaultReplicationRequestTimeout) + require.Equal(t, numNotarizations+1-missedNotarizations, n.Storage.NumBlocks()) + } +} diff --git a/replication_timeout_test.go b/replication_timeout_test.go index 1acebf50..8646da5f 100644 --- a/replication_timeout_test.go +++ b/replication_timeout_test.go @@ -65,7 +65,8 @@ func TestReplicationRequestTimeout(t *testing.T) { // assert the lagging node has not received any replication requests require.Equal(t, uint64(0), laggingNode.Storage.NumBlocks()) - + // allow the replication state to cancel the request before setting filter + time.Sleep(100 * time.Millisecond) // after the timeout, the nodes should respond and the lagging node will replicate net.SetAllNodesMessageFilter(testutil.AllowAllMessages) laggingNode.E.AdvanceTime(laggingNode.E.StartTime.Add(simplex.DefaultReplicationRequestTimeout / 2)) @@ -122,16 +123,17 @@ func TestReplicationRequestTimeoutCancels(t *testing.T) { // all blocks except the lagging node start at round 8, seq 8. // lagging node starts at round 0, seq 0. 
 	// this asserts that the lagging node catches up to the latest round
-	for i := 0; i <= int(startSeq); i++ {
-		for _, n := range net.Instances {
-			n.Storage.WaitForBlockCommit(uint64(startSeq))
-		}
+	for _, n := range net.Instances {
+		n.Storage.WaitForBlockCommit(startSeq)
 	}
 
 	// ensure lagging node doesn't resend requests
 	mf := &testTimeoutMessageFilter{
 		t: t,
 	}
+
+	// allow the replication state to cancel the request before setting filter
+	time.Sleep(100 * time.Millisecond)
 	laggingNode.E.Comm.(*testutil.TestComm).SetFilter(mf.failOnReplicationRequest)
 	laggingNode.E.AdvanceTime(laggingNode.E.StartTime.Add(simplex.DefaultReplicationRequestTimeout * 2))
@@ -368,7 +370,10 @@ func (c *collectNotarizationComm) removeFinalizationsFromReplicationResponses(ms
 			newData = append(newData, qr)
 		}
 		msg.VerifiedReplicationResponse.Data = newData
-		c.replicationResponses <- struct{}{}
+		select {
+		case c.replicationResponses <- struct{}{}:
+		default:
+		}
 	}
 
 	if msg.Finalization != nil && msg.Finalization.Finalization.Round == 0 {
diff --git a/requestor.go b/requestor.go
new file mode 100644
index 00000000..6a60829b
--- /dev/null
+++ b/requestor.go
@@ -0,0 +1,262 @@
+// Copyright (C) 2019-2025, Ava Labs, Inc. All rights reserved.
+// See the file LICENSE for licensing terms.
+
+package simplex
+
+import (
+	"math"
+	"sync"
+	"time"
+
+	"go.uber.org/zap"
+)
+
+// signedQuorum is a round that has been signed by a quorum certificate.
+// If the round was empty notarized, seq is set to 0.
+type signedQuorum struct {
+	round   uint64
+	seq     uint64
+	signers NodeIDs
+}
+
+func newSignedQuorum(qr *QuorumRound, myNodeID NodeID) *signedQuorum {
+	// it's possible our node has signed this quorum.
+	// For example this may happen if our node has sent a finalized vote
+	// for this round and has not received the
+	// finalization from the network.
+	switch {
+	case qr.EmptyNotarization != nil:
+		return &signedQuorum{
+			signers: NodeIDs(qr.EmptyNotarization.QC.Signers()).Remove(myNodeID),
+			round:   qr.EmptyNotarization.Vote.Round,
+		}
+	case qr.Finalization != nil:
+		return &signedQuorum{
+			signers: NodeIDs(qr.Finalization.QC.Signers()).Remove(myNodeID),
+			round:   qr.Finalization.Finalization.Round,
+			seq:     qr.Finalization.Finalization.Seq,
+		}
+	case qr.Notarization != nil:
+		return &signedQuorum{
+			signers: NodeIDs(qr.Notarization.QC.Signers()).Remove(myNodeID),
+			round:   qr.Notarization.Vote.Round,
+			seq:     qr.Notarization.Vote.Seq,
+		}
+	default:
+		return nil
+	}
+}
+
+func newSignedQuorumFromFinalization(finalization *Finalization, nodeID NodeID) *signedQuorum {
+	return newSignedQuorum(&QuorumRound{
+		Finalization: finalization,
+	}, nodeID)
+}
+
+func newSignedQuorumFromRound(round, seq uint64, signers []NodeID, myNodeID NodeID) *signedQuorum {
+	return &signedQuorum{
+		round:   round,
+		seq:     seq,
+		signers: NodeIDs(signers).Remove(myNodeID),
+	}
+}
+
+type sender interface {
+	// Send sends a message to the given destination node
+	Send(msg *Message, destination NodeID)
+}
+
+// requestor fetches quorum rounds up to [highestObserved] from the network,
+// allowing up to [maxRoundWindow] concurrent requests to limit memory use.
+// Ensures all rounds/sequences are eventually received.
+type requestor struct {
+	epochLock *sync.Mutex
+
+	// highestRequested prevents duplicate requests and limits the number of outstanding requests.
+	highestRequested uint64
+
+	// the requestor stops requesting once all sequences/rounds up to and including `highestObserved` have been received.
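+	// highestObserved also records the signers of its quorum certificate; replication requests are distributed among them.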
+	highestObserved *signedQuorum
+
+	// Handles timeouts and retries for missing sequences/rounds.
+	timeoutHandler *TimeoutHandler[uint64]
+
+	logger Logger
+
+	// maxRoundWindow is the maximum number of requests we can request past highestRequested.
+	maxRoundWindow uint64
+
+	sender sender
+
+	// requestIterator is an iterator over NodeIDs in order to request quorum rounds
+	requestIterator int
+
+	// replicateSeqs is set to true if this requestor is for replicating sequences, and false if for rounds.
+	replicateSeqs bool
+}
+
+func newRequestor(logger Logger, start time.Time, lock *sync.Mutex, maxRoundWindow uint64, sender sender, replicateSeqs bool) *requestor {
+	r := &requestor{
+		logger:         logger,
+		epochLock:      lock,
+		maxRoundWindow: maxRoundWindow,
+		sender:         sender,
+		replicateSeqs:  replicateSeqs,
+	}
+	name := "seq-timeout-handler"
+	if !replicateSeqs {
+		name = "round-timeout-handler"
+	}
+	r.timeoutHandler = NewTimeoutHandler(logger, name, start, DefaultReplicationRequestTimeout, r.resendReplicationRequests)
+	return r
+}
+
+func (r *requestor) advanceTime(now time.Time) {
+	r.timeoutHandler.Tick(now)
+}
+
+func (r *requestor) resendReplicationRequests(missingIds []uint64) {
+	// we call this function in the timeout handler goroutine, so we need to
+	// ensure we don't have concurrent access to highestObserved
+	r.epochLock.Lock()
+	defer r.epochLock.Unlock()
+
+	segments := CompressSequences(missingIds)
+
+	r.sendSegments(segments)
+
+	r.requestIterator++
+}
+
+// observedSignedQuorum is called when we observe a signed quorum for a future round/sequence.
+// we do not mix sequences and rounds because we have separate instances of requestor for each.
+func (r *requestor) observedSignedQuorum(observed *signedQuorum, currentSeqOrRound uint64) {
+	observedSeqOrRound := r.getSeqOrRound(observed)
+
+	// we've observed something we've already requested
+	if r.highestRequested >= observedSeqOrRound && r.highestObserved != nil {
+		r.logger.Debug("Already requested observed value, skipping", zap.Uint64("value", observedSeqOrRound), zap.Bool("Seq Replication", r.replicateSeqs))
+		return
+	}
+
+	// if this is the highest observed sequence/round, update our state
+	if r.highestObserved == nil || observedSeqOrRound > r.getSeqOrRound(r.highestObserved) {
+		r.highestObserved = observed
+	}
+
+	r.sendMoreReplicationRequests(observedSeqOrRound, currentSeqOrRound)
+}
+
+// sendMoreReplicationRequests checks if we need to send more replication requests given an observed quorum.
+// It limits the number of outstanding requests to be at most [maxRoundWindow] ahead of [currentSeqOrRound].
+func (r *requestor) sendMoreReplicationRequests(observedSeqOrRound, currentSeqOrRound uint64) {
+	start := math.Max(float64(currentSeqOrRound), float64(r.highestRequested))
+	// we limit the number of outstanding requests to be at most maxRoundWindow ahead of [currentSeqOrRound]
+	end := math.Min(float64(observedSeqOrRound), float64(r.maxRoundWindow+currentSeqOrRound))
+
+	r.logger.Debug("Node is behind, attempting to request missing values", zap.Uint64("value", observedSeqOrRound), zap.Uint64("start", uint64(start)), zap.Uint64("end", uint64(end)), zap.Bool("seq requestor", r.replicateSeqs))
+	r.sendReplicationRequests(uint64(start), uint64(end))
+}
+
+// sendReplicationRequests sends requests for the missing sequences/rounds in the
+// range [start, end] (inclusive). It does so by splitting the
+// range equally among the nodes that have signed [highestObserved].
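+// Each node is sent one contiguous segment produced by DistributeSequenceRequests.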
+func (r *requestor) sendReplicationRequests(start uint64, end uint64) { + nodes := r.highestObserved.signers + numNodes := len(nodes) + + seqRequests := DistributeSequenceRequests(start, end, numNodes) + r.logger.Debug("Distributing replication requests", zap.Uint64("start", start), zap.Uint64("end", end), zap.Stringer("nodes", NodeIDs(nodes))) + + r.sendSegments(seqRequests) + + // next time we send requests, we start with a different permutation + r.requestIterator++ +} + +func (r *requestor) sendSegments(segments []Segment) { + numNodes := len(r.highestObserved.signers) + for i, seqsOrRounds := range segments { + index := (i + r.requestIterator) % numNodes + r.sendRequestToNode(seqsOrRounds.Start, seqsOrRounds.End, r.highestObserved.signers[index]) + } +} + +// sendRequestToNode requests [start, end] from nodes[index]. +// In case the nodes[index] does not respond, we create a timeout that will +// re-send the request. +func (r *requestor) sendRequestToNode(start uint64, end uint64, node NodeID) { + seqsOrRound := make([]uint64, (end+1)-start) + for i := start; i <= end; i++ { + seqsOrRound[i-start] = i + // ensure we set a timeout for this sequence + r.timeoutHandler.AddTask(i) + } + + if r.highestRequested < end { + r.highestRequested = end + } + + request := &ReplicationRequest{} + if r.replicateSeqs { + request.LatestFinalizedSeq = r.highestObserved.seq + request.Seqs = seqsOrRound + } else { + request.LatestRound = r.highestObserved.round + request.Rounds = seqsOrRound + } + + msg := &Message{ReplicationRequest: request} + + r.logger.Debug("Requesting missing rounds/sequences ", + zap.Stringer("from", node), + zap.Uint64("start", start), + zap.Uint64("end", end), + zap.Uint64("latestSeq", request.LatestFinalizedSeq), + zap.Uint64("latestRound", request.LatestRound), + ) + r.sender.Send(msg, node) +} + +func (r *requestor) receivedSignedQuorum(signedQuorum *signedQuorum) { + seqOrRound := r.getSeqOrRound(signedQuorum) + + // check if this is the highest round or seq we have seen + if r.highestObserved == nil || seqOrRound > r.getSeqOrRound(r.highestObserved) { + r.highestObserved = signedQuorum + } + + // we received this sequence, remove the timeout task + r.timeoutHandler.RemoveTask(seqOrRound) + r.logger.Debug("Received future quorum round", zap.Uint64("seq or round", seqOrRound), zap.Bool("is finalization", r.replicateSeqs)) +} + +func (r *requestor) updateState(currentRoundOrNextSeq uint64) { + // we send out more requests once our seq has caught up to 1/2 of the maxRoundWindow + if currentRoundOrNextSeq+r.maxRoundWindow/2 > r.highestRequested && r.highestObserved != nil { + r.observedSignedQuorum(r.highestObserved, currentRoundOrNextSeq) + } +} + +func (r *requestor) getHighestObserved() *signedQuorum { + return r.highestObserved +} + +func (r *requestor) getSeqOrRound(signedQuorum *signedQuorum) uint64 { + if r.replicateSeqs { + return signedQuorum.seq + } + + return signedQuorum.round +} + +// removes all tasks less or equal to the targetSeqOrRound +func (r *requestor) removeOldTasks(targetSeqOrRound uint64) { + r.timeoutHandler.RemoveOldTasks(func(seqOrRound uint64, _ struct{}) bool { + return seqOrRound <= targetSeqOrRound + }) +} + +func (r *requestor) removeTask(seqOrRound uint64) { + r.timeoutHandler.RemoveTask(seqOrRound) +} diff --git a/testutil/comm.go b/testutil/comm.go index da0d38af..f9225c72 100644 --- a/testutil/comm.go +++ b/testutil/comm.go @@ -95,34 +95,22 @@ func (c *TestComm) maybeTranslateOutoingToIncomingMessageTypes(msg *simplex.Mess quorumRound := 
simplex.QuorumRound{} if verifiedQuorumRound.EmptyNotarization != nil { quorumRound.EmptyNotarization = verifiedQuorumRound.EmptyNotarization - } else { + } + if verifiedQuorumRound.VerifiedBlock != nil { quorumRound.Block = verifiedQuorumRound.VerifiedBlock.(simplex.Block) - if verifiedQuorumRound.Notarization != nil { - quorumRound.Notarization = verifiedQuorumRound.Notarization - } - if verifiedQuorumRound.Finalization != nil { - quorumRound.Finalization = verifiedQuorumRound.Finalization - } + } + if verifiedQuorumRound.Notarization != nil { + quorumRound.Notarization = verifiedQuorumRound.Notarization + } + if verifiedQuorumRound.Finalization != nil { + quorumRound.Finalization = verifiedQuorumRound.Finalization } data = append(data, quorumRound) } - var latestRound *simplex.QuorumRound - if msg.VerifiedReplicationResponse.LatestRound != nil { - if msg.VerifiedReplicationResponse.LatestRound.EmptyNotarization != nil { - latestRound = &simplex.QuorumRound{ - EmptyNotarization: msg.VerifiedReplicationResponse.LatestRound.EmptyNotarization, - } - } else { - latestRound = &simplex.QuorumRound{ - Block: msg.VerifiedReplicationResponse.LatestRound.VerifiedBlock.(simplex.Block), - Notarization: msg.VerifiedReplicationResponse.LatestRound.Notarization, - Finalization: msg.VerifiedReplicationResponse.LatestRound.Finalization, - EmptyNotarization: msg.VerifiedReplicationResponse.LatestRound.EmptyNotarization, - } - } - } + latestRound := verifiedQRtoQR(msg.VerifiedReplicationResponse.LatestRound) + latestSeq := verifiedQRtoQR(msg.VerifiedReplicationResponse.LatestFinalizedSeq) require.Nil( c.net.t, @@ -133,6 +121,7 @@ func (c *TestComm) maybeTranslateOutoingToIncomingMessageTypes(msg *simplex.Mess msg.ReplicationResponse = &simplex.ReplicationResponse{ Data: data, LatestRound: latestRound, + LatestSeq: latestSeq, } } @@ -145,6 +134,24 @@ func (c *TestComm) maybeTranslateOutoingToIncomingMessageTypes(msg *simplex.Mess } } +func verifiedQRtoQR(vqr *simplex.VerifiedQuorumRound) *simplex.QuorumRound { + if vqr == nil { + return nil + } + + qr := &simplex.QuorumRound{ + Notarization: vqr.Notarization, + Finalization: vqr.Finalization, + EmptyNotarization: vqr.EmptyNotarization, + } + + if vqr.VerifiedBlock != nil { + qr.Block = vqr.VerifiedBlock.(simplex.Block) + } + + return qr +} + func (c *TestComm) isMessagePermitted(msg *simplex.Message, destination simplex.NodeID) bool { c.lock.RLock() defer c.lock.RUnlock() diff --git a/testutil/network.go b/testutil/network.go index 5970939a..83dd257e 100644 --- a/testutil/network.go +++ b/testutil/network.go @@ -86,6 +86,19 @@ type TestNetworkCommunication interface { SetFilter(filter MessageFilter) } +func (n *InMemNetwork) SetNodeMessageFilter(node simplex.NodeID, filter MessageFilter) { + for _, instance := range n.Instances { + if !instance.E.ID.Equals(node) { + continue + } + comm, ok := instance.E.Comm.(TestNetworkCommunication) + if !ok { + continue + } + comm.SetFilter(filter) + } +} + func (n *InMemNetwork) SetAllNodesMessageFilter(filter MessageFilter) { for _, instance := range n.Instances { comm, ok := instance.E.Comm.(TestNetworkCommunication) @@ -150,16 +163,3 @@ func (n *InMemNetwork) AdvanceWithoutLeader(round uint64, laggingNodeId simplex. 
require.Equal(n.t, record.EmptyNotarizationRecordType, recordType) } } - -func (n *InMemNetwork) SetNodeMessageFilter(node simplex.NodeID, filter MessageFilter) { - for _, instance := range n.Instances { - if !instance.E.ID.Equals(node) { - continue - } - comm, ok := instance.E.Comm.(TestNetworkCommunication) - if !ok { - continue - } - comm.SetFilter(filter) - } -} diff --git a/testutil/node.go b/testutil/node.go index 4decc8b8..95706353 100644 --- a/testutil/node.go +++ b/testutil/node.go @@ -122,3 +122,43 @@ func (t *TestNode) handleMessages() { } } } + +// TimeoutOnRound advances time until the node times out of the given round. +func (t *TestNode) TimeoutOnRound(round uint64) { + for { + currentRound := t.E.Metadata().Round + if currentRound > round { + return + } + if len(t.BB.BlockShouldBeBuilt) == 0 { + t.BB.BlockShouldBeBuilt <- struct{}{} + } + t.AdvanceTime(t.E.MaxProposalWait) + + // check the wal for an empty vote for that round + if hasVote := t.WAL.ContainsEmptyVote(round); hasVote { + return + } + + time.Sleep(50 * time.Millisecond) + } +} + +func (t *TestNode) TickUntilRoundAdvanced(round uint64, tick time.Duration) { + timeout := time.NewTimer(time.Minute) + defer timeout.Stop() + + for { + if t.E.Metadata().Round >= round { + return + } + + select { + case <-time.After(time.Millisecond * 10): + t.AdvanceTime(tick) + continue + case <-timeout.C: + require.Fail(t.t, "timed out waiting to enter round", "current round %d, waiting for round %d", t.E.Metadata().Round, round) + } + } +} diff --git a/testutil/util.go b/testutil/util.go index bbe6d2d1..6ef8bb5c 100644 --- a/testutil/util.go +++ b/testutil/util.go @@ -5,6 +5,7 @@ package testutil import ( "encoding/asn1" + "fmt" "testing" "time" @@ -206,7 +207,7 @@ func WaitForBlockProposerTimeout(t *testing.T, e *simplex.Epoch, startTime *time } // if we are expected to time out for this round, we should not have a notarization - require.False(t, e.WAL.(*TestWAL).ContainsNotarization(startRound)) + require.False(t, e.WAL.(*TestWAL).ContainsNotarization(startRound), fmt.Sprintf("should not have notarized %d for node %s", startRound, e.ID)) *startTime = startTime.Add(e.EpochConfig.MaxProposalWait / 5) e.AdvanceTime(*startTime) diff --git a/timeout_handler.go b/timeout_handler.go index bf6d4de6..bef67d49 100644 --- a/timeout_handler.go +++ b/timeout_handler.go @@ -4,35 +4,27 @@ package simplex import ( - "container/heap" - "fmt" + "maps" "sync" "time" "go.uber.org/zap" ) -type TimeoutTask struct { - NodeID NodeID - TaskID string - Task func() - Deadline time.Time - - // for replication tasks - Start uint64 - End uint64 - - index int // for heap to work more efficiently -} - -type TimeoutHandler struct { - lock sync.Mutex - - ticks chan time.Time - close chan struct{} - // nodeids -> range -> task - tasks map[string]map[string]*TimeoutTask - heap TaskHeap +type timeoutRunner[T comparable] func(ids []T) +type shouldRemoveFunc[T comparable] func(a T, b T) bool +type TimeoutHandler[T comparable] struct { + // helpful for logging + name string + // how often to run through the tasks + runInterval time.Duration + // function to run tasks + taskRunner timeoutRunner[T] + lock sync.Mutex + ticks chan time.Time + close chan struct{} + // maps id to a task + tasks map[T]struct{} now time.Time log Logger @@ -40,36 +32,35 @@ type TimeoutHandler struct { // NewTimeoutHandler returns a TimeoutHandler and starts a new goroutine that // listens for ticks and executes TimeoutTasks. 
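+// The generic handler keeps a set of pending task IDs and hands the whole batch to
+// taskRunner at most once per runInterval.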
-func NewTimeoutHandler(log Logger, startTime time.Time, nodes []NodeID) *TimeoutHandler { - tasks := make(map[string]map[string]*TimeoutTask) - for _, node := range nodes { - tasks[string(node)] = make(map[string]*TimeoutTask) - } - - t := &TimeoutHandler{ - now: startTime, - tasks: tasks, - ticks: make(chan time.Time, 1), - close: make(chan struct{}), - log: log, +func NewTimeoutHandler[T comparable](log Logger, name string, startTime time.Time, runInterval time.Duration, taskRunner timeoutRunner[T]) *TimeoutHandler[T] { + t := &TimeoutHandler[T]{ + name: name, + now: startTime, + tasks: make(map[T]struct{}), + ticks: make(chan time.Time, 1), + close: make(chan struct{}), + runInterval: runInterval, + taskRunner: taskRunner, + log: log, } - go t.run() + go t.run(startTime) return t } -func (t *TimeoutHandler) GetTime() time.Time { - t.lock.Lock() - defer t.lock.Unlock() - - return t.now -} +func (t *TimeoutHandler[T]) run(startTime time.Time) { + lastTickTime := startTime -func (t *TimeoutHandler) run() { for t.shouldRun() { select { case now := <-t.ticks: + if now.Sub(lastTickTime) < t.runInterval { + continue + } + lastTickTime = now + + // update the current time t.lock.Lock() t.now = now t.lock.Unlock() @@ -81,30 +72,24 @@ func (t *TimeoutHandler) run() { } } -func (t *TimeoutHandler) maybeRunTasks() { - // go through the heap executing relevant tasks - for { - t.lock.Lock() - if t.heap.Len() == 0 { - t.lock.Unlock() - break - } +func (t *TimeoutHandler[T]) maybeRunTasks() { + ids := make([]T, 0, len(t.tasks)) - next := t.heap[0] - if next.Deadline.After(t.now) { - t.lock.Unlock() - break - } + t.lock.Lock() + for id := range t.tasks { + ids = append(ids, id) + } + t.lock.Unlock() - heap.Pop(&t.heap) - delete(t.tasks[string(next.NodeID)], next.TaskID) - t.lock.Unlock() - t.log.Debug("Executing timeout task", zap.String("taskid", next.TaskID)) - next.Task() + if len(ids) == 0 { + return } + + t.log.Debug("Running task ids", zap.Any("task ids", ids), zap.String("name", t.name)) + t.taskRunner(ids) } -func (t *TimeoutHandler) shouldRun() bool { +func (t *TimeoutHandler[T]) shouldRun() bool { select { case <-t.close: return false @@ -113,72 +98,58 @@ func (t *TimeoutHandler) shouldRun() bool { } } -func (t *TimeoutHandler) Tick(now time.Time) { +func (t *TimeoutHandler[T]) Tick(now time.Time) { select { case t.ticks <- now: t.lock.Lock() t.now = now t.lock.Unlock() default: - t.log.Debug("Dropping tick in timeouthandler") + t.log.Debug("Dropping tick in timeouthandler", zap.String("name", t.name)) } } -func (t *TimeoutHandler) AddTask(task *TimeoutTask) { +func (t *TimeoutHandler[T]) AddTask(id T) { t.lock.Lock() defer t.lock.Unlock() - if _, ok := t.tasks[string(task.NodeID)]; !ok { - t.log.Debug("Attempting to add a task for an unknown node", zap.Stringer("from", task.NodeID)) - return - } - - // adds a task to the heap and the tasks map - if _, ok := t.tasks[string(task.NodeID)][task.TaskID]; ok { - t.log.Debug("Trying to add an already included task", zap.Stringer("from", task.NodeID), zap.String("Task ID", task.TaskID)) - return - } - - t.tasks[string(task.NodeID)][task.TaskID] = task - t.log.Debug("Adding timeout task", zap.Stringer("from", task.NodeID), zap.String("taskid", task.TaskID)) - heap.Push(&t.heap, task) + t.tasks[id] = struct{}{} + t.log.Debug("Adding timeout task", zap.Any("id", id), zap.String("name", t.name)) } -func (t *TimeoutHandler) RemoveTask(nodeID NodeID, ID string) { +func (t *TimeoutHandler[T]) RemoveTask(ID T) { t.lock.Lock() defer t.lock.Unlock() - if _, 
ok := t.tasks[string(nodeID)]; !ok { - t.log.Debug("Attempting to remove a task for an unknown node", zap.Stringer("from", nodeID)) - return - } - - if _, ok := t.tasks[string(nodeID)][ID]; !ok { + if _, ok := t.tasks[ID]; !ok { return } - // find the task using the task map - // remove it from the heap using the index - t.log.Debug("Removing timeout task", zap.Stringer("from", nodeID), zap.String("taskid", ID)) - heap.Remove(&t.heap, t.tasks[string(nodeID)][ID].index) - delete(t.tasks[string(nodeID)], ID) + t.log.Debug("Removing timeout task", zap.Any("id", ID), zap.String("name", t.name)) + delete(t.tasks, ID) } -func (t *TimeoutHandler) forEach(nodeID string, f func(tt *TimeoutTask)) { +// RemoveOldTasks removes all tasks where shouldRemove(id, task) is true. +// func (t *TimeoutHandler[T]) RemoveOldTasks(task T) { +// t.lock.Lock() +// defer t.lock.Unlock() + +// for id := range t.tasks { +// if t.shouldRemove(id, task) { +// // t.log.Debug("Removing old timeout task", zap.Any("id", id), zap.String("name", t.name)) +// delete(t.tasks, id) +// } +// } +// } + +func (t *TimeoutHandler[T]) RemoveOldTasks(shouldRemove func(id T, _ struct{}) bool) { t.lock.Lock() defer t.lock.Unlock() - tasks, exists := t.tasks[nodeID] - if !exists { - return - } - - for _, task := range tasks { - f(task) - } + maps.DeleteFunc(t.tasks, shouldRemove) } -func (t *TimeoutHandler) Close() { +func (t *TimeoutHandler[T]) Close() { select { case <-t.close: return @@ -187,39 +158,6 @@ func (t *TimeoutHandler) Close() { } } -const delimiter = "_" - -func getTimeoutID(start, end uint64) string { - return fmt.Sprintf("%d%s%d", start, delimiter, end) -} - -// ---------------------------------------------------------------------- -type TaskHeap []*TimeoutTask - -func (h *TaskHeap) Len() int { return len(*h) } - -// Less returns if the task at index [i] has a lower timeout than the task at index [j] -func (h *TaskHeap) Less(i, j int) bool { return (*h)[i].Deadline.Before((*h)[j].Deadline) } - -// Swap swaps the values at index [i] and [j] -func (h *TaskHeap) Swap(i, j int) { - (*h)[i], (*h)[j] = (*h)[j], (*h)[i] - (*h)[i].index = i - (*h)[j].index = j -} - -func (h *TaskHeap) Push(x any) { - task := x.(*TimeoutTask) - task.index = h.Len() - *h = append(*h, task) -} - -func (h *TaskHeap) Pop() any { - old := *h - len := h.Len() - task := old[len-1] - old[len-1] = nil - *h = old[0 : len-1] - task.index = -1 - return task +func alwaysFalseRemover[T comparable](a T, b T) bool { + return false } diff --git a/timeout_handler_test.go b/timeout_handler_test.go index 4c56126f..7cd4decb 100644 --- a/timeout_handler_test.go +++ b/timeout_handler_test.go @@ -4,8 +4,7 @@ package simplex_test import ( - "sync" - "sync/atomic" + "slices" "testing" "time" @@ -15,317 +14,212 @@ import ( "github.com/stretchr/testify/require" ) -func TestAddAndRunTask(t *testing.T) { +const testName = "test" +const testRunInterval = 200 * time.Millisecond + +func waitNoReceive[T any](t *testing.T, ch <-chan T) { + select { + case <-ch: + t.Fatal("channel unexpectedly signaled") + case <-time.After(defaultWaitDuration): + // good + } +} + +func waitReceive[T any](t *testing.T, ch <-chan T) T { + select { + case v := <-ch: + return v + case <-time.After(defaultWaitDuration): + t.Fatal("timed out waiting for signal") + return *new(T) + } +} + +func lessUint(a, b uint64) bool { return a < b } + +// Ensures tasks only run when at/after the runInterval boundary, and that +// "too-early" ticks are ignored by the runner loop. 
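+// Ticks that arrive before the interval has elapsed still advance the handler's clock
+// but must not trigger the runner.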
+func TestTimeoutHandlerRunsOnlyOnInterval(t *testing.T) { start := time.Now() - l := testutil.MakeLogger(t, 1) - nodes := []simplex.NodeID{{1}, {2}} - handler := simplex.NewTimeoutHandler(l, start, nodes) - defer handler.Close() - - sent := make(chan struct{}, 1) - var count atomic.Int64 - - task := &simplex.TimeoutTask{ - NodeID: nodes[0], - TaskID: "simplerun", - Deadline: start.Add(5 * time.Second), - Task: func() { - sent <- struct{}{} - count.Add(1) - }, + log := testutil.MakeLogger(t) + + ran := make(chan []uint64, 4) + runner := func(ids []uint64) { + // copy since ids is reused by caller + cp := append([]uint64(nil), ids...) + ran <- cp } - handler.AddTask(task) - handler.Tick(start.Add(2 * time.Second)) - time.Sleep(10 * time.Millisecond) + h := simplex.NewTimeoutHandler(log, testName, start, testRunInterval, runner) + defer h.Close() + + h.AddTask(1) + + // Too early: < runInterval since last tick -> should not run + h.Tick(start.Add(testRunInterval / 2)) + waitNoReceive(t, ran) - require.Zero(t, len(sent)) - handler.Tick(start.Add(6 * time.Second)) - <-sent - require.Equal(t, int64(1), count.Load()) + // Exactly at interval: should run once with id=1 + h.Tick(start.Add(testRunInterval)) + batch := waitReceive(t, ran) + require.Equal(t, []uint64{1}, sorted(batch)) - // test we only execute task once - handler.Tick(start.Add(12 * time.Second)) - time.Sleep(10 * time.Millisecond) - require.Equal(t, int64(1), count.Load()) + h.AddTask(2) + // Another tick but less than interval since the lastTickTime: should not run + h.Tick(start.Add(testRunInterval + testRunInterval/2)) + waitNoReceive(t, ran) + + // At 2*interval: should run again + h.Tick(start.Add(2 * testRunInterval)) + batch = waitReceive(t, ran) + require.Equal(t, []uint64{1, 2}, sorted(batch)) } -func TestRemoveTask(t *testing.T) { +// Add & Remove single task, verifying it stops running after removal. +func TestTimeoutHandler_AddThenRemoveTask(t *testing.T) { start := time.Now() - l := testutil.MakeLogger(t, 1) - nodes := []simplex.NodeID{{1}, {2}} - handler := simplex.NewTimeoutHandler(l, start, nodes) - defer handler.Close() - - var ran bool - task := &simplex.TimeoutTask{ - NodeID: nodes[0], - TaskID: "task2", - Deadline: start.Add(1 * time.Second), - Task: func() { - ran = true - }, - } + log := testutil.MakeLogger(t) + + ran := make(chan []uint64, 2) + runner := func(ids []uint64) { ran <- append([]uint64(nil), ids...) } - handler.AddTask(task) - handler.RemoveTask(nodes[0], "task2") - handler.Tick(start.Add(2 * time.Second)) - require.False(t, ran) + h := simplex.NewTimeoutHandler(log, testName, start, testRunInterval, runner) + defer h.Close() - // ensure no panic - handler.RemoveTask(nodes[1], "task-doesn't-exist") + h.AddTask(7) + h.Tick(start.Add(testRunInterval)) + require.Equal(t, []uint64{7}, sorted(waitReceive(t, ran))) + + // Remove then tick again; nothing should run + h.RemoveTask(7) + h.Tick(start.Add(2 * testRunInterval)) + waitNoReceive(t, ran) } -func TestTaskOrder(t *testing.T) { +// Multiple tasks get batched and delivered together; removing one leaves the rest. 
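+// Tasks are not consumed by a run: surviving IDs are delivered again on every subsequent interval tick.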
+func TestTimeoutHandler_MultipleTasksBatchAndPersist(t *testing.T) {
     start := time.Now()
-    l := testutil.MakeLogger(t, 1)
-    nodes := []simplex.NodeID{{1}, {2}}
-    handler := simplex.NewTimeoutHandler(l, start, nodes)
-    defer handler.Close()
-
-    finished := make(chan struct{})
-
-    var mu sync.Mutex
-    var results []string
-
-    handler.AddTask(&simplex.TimeoutTask{
-        NodeID:   nodes[0],
-        TaskID:   "first",
-        Deadline: start.Add(1 * time.Second),
-        Task: func() {
-            mu.Lock()
-            results = append(results, "first")
-            finished <- struct{}{}
-            mu.Unlock()
-        },
-    })
+    log := testutil.MakeLogger(t)

-    handler.AddTask(&simplex.TimeoutTask{
-        NodeID:   nodes[1],
-        TaskID:   "second",
-        Deadline: start.Add(2 * time.Second),
-        Task: func() {
-            mu.Lock()
-            results = append(results, "second")
-            finished <- struct{}{}
-            mu.Unlock()
-        },
-    })
+    ran := make(chan []uint64, 2)
+    runner := func(ids []uint64) { ran <- append([]uint64(nil), ids...) }

-    handler.AddTask(&simplex.TimeoutTask{
-        NodeID:   nodes[0],
-        TaskID:   "noruntask",
-        Deadline: start.Add(4 * time.Second),
-        Task: func() {
-            mu.Lock()
-            results = append(results, "norun")
-            mu.Unlock()
-        },
-    })
+    h := simplex.NewTimeoutHandler(log, testName, start, testRunInterval, runner)
+    defer h.Close()
+
+    h.AddTask(1)
+    h.AddTask(2)
+    h.AddTask(3)

-    handler.Tick(start.Add(3 * time.Second))
+    // First run: should see all three (order not guaranteed)
+    h.Tick(start.Add(testRunInterval))
+    got := sorted(waitReceive(t, ran))
+    require.Equal(t, []uint64{1, 2, 3}, got)

-    <-finished
-    <-finished
+    // Remove one; remaining should continue to run on next valid tick
+    h.RemoveTask(2)
+    h.Tick(start.Add(2 * testRunInterval))
+    got = sorted(waitReceive(t, ran))
+    require.Equal(t, []uint64{1, 3}, got)
+}
+
+// Adding the same task twice should not duplicate it in the batch (set semantics).
+func TestTimeoutHandler_AddDuplicateTaskIsIdempotent(t *testing.T) {
+    start := time.Now()
+    log := testutil.MakeLogger(t)

-    mu.Lock()
-    defer mu.Unlock()
+    ran := make(chan []uint64, 1)
+    runner := func(ids []uint64) { ran <- append([]uint64(nil), ids...) }

-    require.Equal(t, 2, len(results))
-    require.Equal(t, results[0], "first")
-    require.Equal(t, results[1], "second")
+    h := simplex.NewTimeoutHandler(log, testName, start, testRunInterval, runner)
+    defer h.Close()
+
+    h.AddTask(42)
+    h.AddTask(42) // duplicate
+
+    h.Tick(start.Add(testRunInterval))
+    got := sorted(waitReceive(t, ran))
+    require.Equal(t, []uint64{42}, got)
 }

-func TestAddTasksOutOfOrder(t *testing.T) {
+// RemoveOldTasks should drop tasks where shouldRemove(id, cutoff) == true.
+// With lessUint(a,b), that means id < cutoff.
+func TestTimeoutHandler_RemoveOldTasks(t *testing.T) {
     start := time.Now()
-    l := testutil.MakeLogger(t, 1)
-    nodes := []simplex.NodeID{{1}, {2}}
-    handler := simplex.NewTimeoutHandler(l, start, nodes)
-    defer handler.Close()
-
-    finished := make(chan struct{})
-    var mu sync.Mutex
-    var results []string
-
-    handler.AddTask(&simplex.TimeoutTask{
-        NodeID:   nodes[0],
-        TaskID:   "third",
-        Deadline: start.Add(3 * time.Second),
-        Task: func() {
-            mu.Lock()
-            results = append(results, "third")
-            finished <- struct{}{}
-            mu.Unlock()
-        },
-    })
+    log := testutil.MakeLogger(t)

-    handler.AddTask(&simplex.TimeoutTask{
-        NodeID:   nodes[0],
-        TaskID:   "second",
-        Deadline: start.Add(2 * time.Second),
-        Task: func() {
-            mu.Lock()
-            results = append(results, "second")
-            finished <- struct{}{}
-            mu.Unlock()
-        },
-    })
+    ran := make(chan []uint64, 2)
+    runner := func(ids []uint64) { ran <- append([]uint64(nil), ids...) }
+
+    h := simplex.NewTimeoutHandler(log, testName, start, testRunInterval, runner)
+    defer h.Close()

-    handler.AddTask(&simplex.TimeoutTask{
-        NodeID:   nodes[1],
-        TaskID:   "fourth",
-        Deadline: start.Add(4 * time.Second),
-        Task: func() {
-            mu.Lock()
-            results = append(results, "fourth")
-            finished <- struct{}{}
-            mu.Unlock()
-        },
+    for _, id := range []uint64{1, 2, 3, 4, 5} {
+        h.AddTask(id)
+    }
+
+    // Remove everything with id < 3 -> removes 1,2 ; keeps 3,4,5
+    h.RemoveOldTasks(func(id uint64, _ struct{}) bool {
+        return id < 3
     })
+    h.Tick(start.Add(testRunInterval))
+    got := sorted(waitReceive(t, ran))
+    require.Equal(t, []uint64{3, 4, 5}, got)

-    handler.AddTask(&simplex.TimeoutTask{
-        NodeID:   nodes[0],
-        TaskID:   "first",
-        Deadline: start.Add(1 * time.Second),
-        Task: func() {
-            mu.Lock()
-            results = append(results, "first")
-            finished <- struct{}{}
-            mu.Unlock()
-        },
+    // Now remove everything with id < 5 -> removes 3,4 ; keeps 5
+    h.RemoveOldTasks(func(id uint64, _ struct{}) bool {
+        return id < 5
     })
+    h.Tick(start.Add(2 * testRunInterval))
+    got = sorted(waitReceive(t, ran))
+    require.Equal(t, []uint64{5}, got)
+}
+
+// After Close, the goroutine stops and further ticks should not cause runs.
+func TestTimeoutHandler_CloseStopsRunner(t *testing.T) {
+    start := time.Now()
+    log := testutil.MakeLogger(t)
+
+    ran := make(chan []uint64, 1)
+    runner := func(ids []uint64) { ran <- append([]uint64(nil), ids...) }

-    handler.Tick(start.Add(1 * time.Second))
-    <-finished
-    mu.Lock()
-    require.Equal(t, 1, len(results))
-    require.Equal(t, results[0], "first")
-    mu.Unlock()
-
-    handler.Tick(start.Add(3 * time.Second))
-    <-finished
-    <-finished
-    mu.Lock()
-    require.Equal(t, 3, len(results))
-    require.Equal(t, results[1], "second")
-    require.Equal(t, results[2], "third")
-    mu.Unlock()
-
-    handler.Tick(start.Add(4 * time.Second))
-    <-finished
-    mu.Lock()
-    require.Equal(t, 4, len(results))
-    require.Equal(t, results[3], "fourth")
-    mu.Unlock()
+    h := simplex.NewTimeoutHandler(log, testName, start, testRunInterval, runner)
+
+    h.Close()
+    // Calls after Close should be safe and no-ops for scheduling.
+    h.AddTask(9)
+    h.Tick(start.Add(testRunInterval))
+    waitNoReceive(t, ran)
 }

-func TestFindTask(t *testing.T) {
-    // Setup a mock logger
-    l := testutil.MakeLogger(t, 1)
-    nodes := []simplex.NodeID{{1}, {2}}
-    startTime := time.Now()
-
-    handler := simplex.NewTimeoutHandler(l, startTime, nodes)
-    defer handler.Close()
-
-    // Create some test tasks
-    task1 := &simplex.TimeoutTask{
-        TaskID: "task1",
-        NodeID: nodes[0],
-        Start:  5,
-        End:    10,
-    }
+// If ticks come in faster than runInterval, the second "too-soon" tick should
+// not trigger execution (interval gating). This overlaps with RunsOnlyOnInterval,
+// but emphasizes back-to-back ticks.
+func TestTimeoutHandler_BackToBackTicksUnderIntervalDontRun(t *testing.T) {
+    start := time.Now()
+    log := testutil.MakeLogger(t)

-    taskSameRangeDiffNode := &simplex.TimeoutTask{
-        TaskID: "taskSameDiff",
-        NodeID: nodes[1],
-        Start:  5,
-        End:    10,
-    }
+    ran := make(chan []uint64, 2)
+    runner := func(ids []uint64) { ran <- append([]uint64(nil), ids...) }

-    task3 := &simplex.TimeoutTask{
-        TaskID: "task3",
-        NodeID: nodes[1],
-        Start:  25,
-        End:    30,
-    }
+    h := simplex.NewTimeoutHandler(log, testName, start, testRunInterval, runner)
+    defer h.Close()

-    task4 := &simplex.TimeoutTask{
-        TaskID: "task4",
-        NodeID: nodes[1],
-        Start:  31,
-        End:    36,
-    }
+    h.AddTask(100)

-    // Add tasks to handler
-    handler.AddTask(task1)
-    handler.AddTask(taskSameRangeDiffNode)
-    handler.AddTask(task3)
-    handler.AddTask(task4)
-
-    tests := []struct {
-        name     string
-        node     simplex.NodeID
-        seqs     []uint64
-        expected *simplex.TimeoutTask
-    }{
-        {
-            name:     "Find task with sequence in middle of range",
-            node:     nodes[0],
-            seqs:     []uint64{7, 8, 9},
-            expected: task1,
-        },
-        {
-            name:     "Find task with sequence at boundary (inclusive)",
-            node:     nodes[0],
-            seqs:     []uint64{5, 7},
-            expected: task1,
-        },
-        {
-            name:     "Find task with mixed sequences (first valid sequence)",
-            node:     nodes[0],
-            seqs:     []uint64{3, 4, 5, 11},
-            expected: task1, // 5 is in range
-        },
-        {
-            name:     "Same sequences, but different node",
-            node:     nodes[1],
-            seqs:     []uint64{7, 8, 9},
-            expected: taskSameRangeDiffNode,
-        },
-        {
-            name:     "No sequences in range",
-            node:     nodes[0],
-            seqs:     []uint64{1, 2, 3, 4, 11, 12, 13, 14},
-            expected: nil,
-        },
-        {
-            name:     "Span across many tasks",
-            node:     nodes[1],
-            seqs:     []uint64{26, 27, 30, 31, 33},
-            expected: task3,
-        },
-        {
-            name:     "Unknown node",
-            node:     simplex.NodeID("unknown"),
-            seqs:     []uint64{5, 15, 25},
-            expected: nil,
-        },
-        {
-            name:     "Empty sequence list",
-            node:     nodes[1],
-            seqs:     []uint64{},
-            expected: nil,
-        },
-    }
+    h.Tick(start.Add(testRunInterval)) // should run
+    require.Equal(t, []uint64{100}, sorted(waitReceive(t, ran)))
+    h.Tick(start.Add(testRunInterval + time.Millisecond)) // < interval since lastTickTime -> no run
+    waitNoReceive(t, ran)

-    for _, tt := range tests {
-        t.Run(tt.name, func(t *testing.T) {
-            result := simplex.FindReplicationTask(handler, tt.node, tt.seqs)
-            if tt.expected != result {
-                require.Fail(t, "not equal")
-            }
-            require.Equal(t, tt.expected, result)
-        })
-    }
+    // Next valid interval boundary -> runs again
+    h.Tick(start.Add(2 * testRunInterval))
+    require.Equal(t, []uint64{100}, sorted(waitReceive(t, ran)))
+}
+
+func sorted(v []uint64) []uint64 {
+    sorted := append([]uint64(nil), v...)
+    slices.Sort(sorted)
+    return sorted
 }
diff --git a/util.go b/util.go
index b4a3aca2..f47fb0a3 100644
--- a/util.go
+++ b/util.go
@@ -6,6 +6,7 @@ package simplex
 import (
     "context"
     "fmt"
+    "slices"
     "strings"
     "sync"
     "time"
@@ -85,13 +86,13 @@ func VerifyQC(qc QuorumCertificate, logger Logger, messageType string, isQuorum
 }

 // GetLatestVerifiedQuorumRound returns the latest verified quorum round given
-// a round, empty notarization, and last block. If all are nil, it returns nil.
-func GetLatestVerifiedQuorumRound(round *Round, emptyNotarization *EmptyNotarization, lastBlock *VerifiedFinalizedBlock) *VerifiedQuorumRound {
+// a round and empty notarization. If both are nil, it returns nil.
+func GetLatestVerifiedQuorumRound(round *Round, emptyNotarization *EmptyNotarization) *VerifiedQuorumRound {
     var verifiedQuorumRound *VerifiedQuorumRound
     var highestRound uint64
     var exists bool

-    if round != nil {
+    if round != nil && (round.finalization != nil || round.notarization != nil) {
         highestRound = round.num
         verifiedQuorumRound = &VerifiedQuorumRound{
             VerifiedBlock: round.block,
@@ -107,15 +108,6 @@ func GetLatestVerifiedQuorumRound(round *Round, emptyNotariza
             verifiedQuorumRound = &VerifiedQuorumRound{
                 EmptyNotarization: emptyNotarization,
             }
-            highestRound = emptyNotarization.Vote.Round
-            exists = true
-        }
-    }
-
-    if lastBlock != nil && (lastBlock.VerifiedBlock.BlockHeader().Round > highestRound || !exists) {
-        verifiedQuorumRound = &VerifiedQuorumRound{
-            VerifiedBlock: lastBlock.VerifiedBlock,
-            Finalization:  &lastBlock.Finalization,
         }
     }

@@ -203,10 +195,11 @@ type Segment struct {
     End uint64
 }

-// compressSequences takes a sorted slice of uint64 values representing
+// CompressSequences takes a slice of uint64 values representing
 // missing sequence numbers and compresses consecutive numbers into segments.
 // Each segment represents a continuous block of missing sequence numbers.
+// The input slice is sorted in place.
 func CompressSequences(missingSeqs []uint64) []Segment {
+    slices.Sort(missingSeqs)
     var segments []Segment

     if len(missingSeqs) == 0 {
diff --git a/util_test.go b/util_test.go
index 8bf081b1..a9340f71 100644
--- a/util_test.go
+++ b/util_test.go
@@ -188,7 +188,6 @@ func TestGetHighestQuorumRound(t *testing.T) {
     block10 := testutil.NewTestBlock(ProtocolMetadata{Seq: 10, Round: 10}, emptyBlacklist)
     notarization10, err := testutil.NewNotarization(l, signatureAggregator, block10, nodes)
     require.NoError(t, err)
-    finalization10, _ := testutil.NewFinalizationRecord(t, l, signatureAggregator, block10, nodes)

     tests := []struct {
         name string
@@ -205,18 +204,11 @@ func TestGetHighestQuorumRound(t *testing.T) {
             },
         },
         {
-            name: "only last block",
-            lastBlock: &VerifiedFinalizedBlock{
-                VerifiedBlock: block1,
-                Finalization:  finalization1,
-            },
-            expectedQr: &VerifiedQuorumRound{
-                VerifiedBlock: block1,
-                Finalization:  &finalization1,
-            },
+            name:       "nothing",
+            expectedQr: nil,
         },
         {
-            name:  "round",
+            name:  "round with finalization",
             round: SetRound(block1, nil, &finalization1),
             expectedQr: &VerifiedQuorumRound{
                 VerifiedBlock: block1,
@@ -232,36 +224,17 @@ func TestGetHighestQuorumRound(t *testing.T) {
             },
         },
         {
-            name:  "higher notarized round than indexed",
+            name:  "higher round than empty notarization",
             round: SetRound(block10, &notarization10, nil),
-            lastBlock: &VerifiedFinalizedBlock{
-                VerifiedBlock: block1,
-                Finalization:  finalization1,
-            },
+            eNote: testutil.NewEmptyNotarization(nodes, 1),
             expectedQr: &VerifiedQuorumRound{
                 VerifiedBlock: block10,
                 Notarization:  &notarization10,
             },
         },
-        {
-            name:  "higher indexed than in round",
-            round: SetRound(block1, &notarization1, nil),
-            lastBlock: &VerifiedFinalizedBlock{
-                VerifiedBlock: block10,
-                Finalization:  finalization10,
-            },
-            expectedQr: &VerifiedQuorumRound{
-                VerifiedBlock: block10,
-                Finalization:  &finalization10,
-            },
-        },
         {
             name:  "higher empty notarization",
             eNote: testutil.NewEmptyNotarization(nodes, 100),
-            lastBlock: &VerifiedFinalizedBlock{
-                VerifiedBlock: block1,
-                Finalization:  finalization1,
-            },
             round: SetRound(block10, &notarization10, nil),
             expectedQr: &VerifiedQuorumRound{
                 EmptyNotarization: testutil.NewEmptyNotarization(nodes, 100),
@@ -271,7 +244,7 @@
     for _, tt := range tests {
         t.Run(tt.name, func(t *testing.T) {
-            qr := GetLatestVerifiedQuorumRound(tt.round, tt.eNote, tt.lastBlock)
+            qr := GetLatestVerifiedQuorumRound(tt.round, tt.eNote)
             require.Equal(t, tt.expectedQr, qr)
         })
     }
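
The diff above replaces the per-task, deadline-heap handler with a generic TimeoutHandler[T] that keeps a set of task IDs and hands the whole batch to a single runner callback, gated so that ticks arriving sooner than the run interval since the last run do nothing. The sketch below is not the simplex implementation; it is a minimal, self-contained illustration of that interval-gated batching pattern under stated assumptions: the names batchTimeoutHandler and newBatchTimeoutHandler are hypothetical, the runner is invoked synchronously inside Tick, and the background goroutine plus Close handling of the real handler are omitted.

// Minimal sketch of interval-gated batch timeouts (illustrative only).
package main

import (
    "fmt"
    "maps"
    "sync"
    "time"
)

// batchTimeoutHandler collects IDs and invokes runner with every pending ID,
// at most once per runInterval.
type batchTimeoutHandler[T comparable] struct {
    lock        sync.Mutex
    tasks       map[T]struct{}
    lastRun     time.Time
    runInterval time.Duration
    runner      func(ids []T)
}

func newBatchTimeoutHandler[T comparable](start time.Time, interval time.Duration, runner func([]T)) *batchTimeoutHandler[T] {
    return &batchTimeoutHandler[T]{
        tasks:       make(map[T]struct{}),
        lastRun:     start,
        runInterval: interval,
        runner:      runner,
    }
}

// AddTask is idempotent: adding the same ID twice keeps a single entry.
func (h *batchTimeoutHandler[T]) AddTask(id T)    { h.lock.Lock(); h.tasks[id] = struct{}{}; h.lock.Unlock() }
func (h *batchTimeoutHandler[T]) RemoveTask(id T) { h.lock.Lock(); delete(h.tasks, id); h.lock.Unlock() }

// RemoveOldTasks mirrors the maps.DeleteFunc-based pruning added in this diff.
func (h *batchTimeoutHandler[T]) RemoveOldTasks(shouldRemove func(id T, _ struct{}) bool) {
    h.lock.Lock()
    defer h.lock.Unlock()
    maps.DeleteFunc(h.tasks, shouldRemove)
}

// Tick runs the batch only when at least runInterval has elapsed since the last run.
func (h *batchTimeoutHandler[T]) Tick(now time.Time) {
    h.lock.Lock()
    defer h.lock.Unlock()
    if now.Sub(h.lastRun) < h.runInterval {
        return // too early; ignore this tick
    }
    h.lastRun = now
    if len(h.tasks) == 0 {
        return
    }
    ids := make([]T, 0, len(h.tasks))
    for id := range h.tasks {
        ids = append(ids, id)
    }
    h.runner(ids)
}

func main() {
    start := time.Now()
    h := newBatchTimeoutHandler[uint64](start, 200*time.Millisecond, func(ids []uint64) {
        fmt.Println("rebroadcast for rounds:", ids)
    })
    h.AddTask(1)
    h.AddTask(2)
    h.Tick(start.Add(100 * time.Millisecond)) // too early: nothing runs
    h.Tick(start.Add(200 * time.Millisecond)) // runs with [1 2] (order not guaranteed)
    h.RemoveOldTasks(func(id uint64, _ struct{}) bool { return id < 2 })
    h.Tick(start.Add(400 * time.Millisecond)) // runs with [2]
}

Unlike the removed heap-based design, tasks here are not one-shot: they keep firing on every eligible tick until removed, which matches the persist-and-rerun behavior the new tests assert.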