diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index c30e59359..f6739bfd1 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -374,6 +374,8 @@ jobs: - 'raft-kv-memstore-opendal-snapshot-data' - 'raft-kv-memstore-singlethreaded' - 'raft-kv-rocksdb' + - 'multi-raft-kv' + - 'multi-raft-sharding' steps: - uses: actions/checkout@v4 @@ -401,13 +403,17 @@ jobs: - name: Test demo script of examples/${{ matrix.example }} # The script is not meant for testing. Just to ensure it works but do not - # rely on it. + # rely on it. Skip if test-cluster.sh doesn't exist. if: ${{ matrix.toolchain == 'stable' }} shell: bash run: | cd examples/${{ matrix.example }} - ./test-cluster.sh + if [ -f test-cluster.sh ]; then + ./test-cluster.sh + else + echo "No test-cluster.sh found, skipping" + fi - name: Format # clippy/format produces different result with stable and nightly. Only with nightly. diff --git a/Cargo.toml b/Cargo.toml index b6978964a..cbef7a088 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -77,6 +77,10 @@ exclude = [ "examples/raft-kv-memstore-opendal-snapshot-data", "examples/raft-kv-rocksdb", + "examples/multi-raft-kv", + "examples/multi-raft-sharding", + "rt-monoio", - "rt-compio" + "rt-compio", + "multiraft" ] diff --git a/Makefile b/Makefile index 365249665..7210e5c6c 100644 --- a/Makefile +++ b/Makefile @@ -25,6 +25,8 @@ test: cargo test --features serde # only crate `tests` has single-term-leader feature cargo test --features single-term-leader -p tests + # multiraft crate tests + cargo test --manifest-path multiraft/Cargo.toml $(MAKE) test-examples check-parallel: @@ -40,6 +42,8 @@ test-examples: cargo test --manifest-path examples/raft-kv-memstore-singlethreaded/Cargo.toml cargo test --manifest-path examples/raft-kv-rocksdb/Cargo.toml cargo test --manifest-path examples/rocksstore/Cargo.toml + cargo test --manifest-path examples/multi-raft-kv/Cargo.toml + cargo test --manifest-path examples/multi-raft-sharding/Cargo.toml bench: cargo bench --features bench @@ -75,19 +79,29 @@ guide: lint: cargo fmt + cargo fmt --manifest-path multiraft/Cargo.toml + cargo fmt --manifest-path rt-compio/Cargo.toml + cargo fmt --manifest-path rt-monoio/Cargo.toml cargo fmt --manifest-path examples/mem-log/Cargo.toml cargo fmt --manifest-path examples/raft-kv-memstore-network-v2/Cargo.toml cargo fmt --manifest-path examples/raft-kv-memstore-opendal-snapshot-data/Cargo.toml cargo fmt --manifest-path examples/raft-kv-memstore-singlethreaded/Cargo.toml cargo fmt --manifest-path examples/raft-kv-memstore/Cargo.toml cargo fmt --manifest-path examples/raft-kv-rocksdb/Cargo.toml + cargo fmt --manifest-path examples/multi-raft-kv/Cargo.toml + cargo fmt --manifest-path examples/multi-raft-sharding/Cargo.toml cargo clippy --no-deps --all-targets -- -D warnings - cargo clippy --no-deps --manifest-path examples/mem-log/Cargo.toml --all-targets -- -D warnings + cargo clippy --no-deps --manifest-path multiraft/Cargo.toml --all-targets -- -D warnings + cargo clippy --no-deps --manifest-path rt-compio/Cargo.toml --all-targets -- -D warnings + cargo clippy --no-deps --manifest-path rt-monoio/Cargo.toml --all-targets -- -D warnings + cargo clippy --no-deps --manifest-path examples/mem-log/Cargo.toml --all-targets -- -D warnings cargo clippy --no-deps --manifest-path examples/raft-kv-memstore-network-v2/Cargo.toml --all-targets -- -D warnings cargo clippy --no-deps --manifest-path examples/raft-kv-memstore-opendal-snapshot-data/Cargo.toml --all-targets -- -D warnings cargo clippy --no-deps --manifest-path examples/raft-kv-memstore-singlethreaded/Cargo.toml --all-targets -- -D warnings cargo clippy --no-deps --manifest-path examples/raft-kv-memstore/Cargo.toml --all-targets -- -D warnings cargo clippy --no-deps --manifest-path examples/raft-kv-rocksdb/Cargo.toml --all-targets -- -D warnings + cargo clippy --no-deps --manifest-path examples/multi-raft-kv/Cargo.toml --all-targets -- -D warnings + cargo clippy --no-deps --manifest-path examples/multi-raft-sharding/Cargo.toml --all-targets -- -D warnings # Bug: clippy --all-targets reports false warning about unused dep in # `[dev-dependencies]`: # https://github.com/rust-lang/rust/issues/72686#issuecomment-635539688 diff --git a/examples/mem-log/Cargo.toml b/examples/mem-log/Cargo.toml index f1bd9e665..d795910df 100644 --- a/examples/mem-log/Cargo.toml +++ b/examples/mem-log/Cargo.toml @@ -15,7 +15,7 @@ license = "MIT OR Apache-2.0" repository = "https://github.com/databendlabs/openraft" [dependencies] -openraft = { path = "../../openraft", features = ["type-alias"] } +openraft = { path = "../../openraft", default-features = false, features = ["type-alias", "tokio-rt"] } tokio = { version = "1.0", default-features = false, features = ["sync"] } diff --git a/examples/multi-raft-kv/Cargo.toml b/examples/multi-raft-kv/Cargo.toml new file mode 100644 index 000000000..91c05ed5d --- /dev/null +++ b/examples/multi-raft-kv/Cargo.toml @@ -0,0 +1,34 @@ +[package] +name = "multi-raft-kv" +version = "0.1.0" +readme = "README.md" + +edition = "2021" +authors = [ + "AriesDevil ", +] +categories = ["algorithms", "asynchronous", "data-structures"] +description = "An example Multi-Raft distributed key-value store with 3 groups built upon `openraft`." +homepage = "https://github.com/databendlabs/openraft" +keywords = ["raft", "consensus", "multi-raft"] +license = "MIT OR Apache-2.0" +repository = "https://github.com/databendlabs/openraft" + +[dependencies] +mem-log = { path = "../mem-log", features = [] } +openraft = { path = "../../openraft", default-features = false, features = ["serde", "type-alias", "tokio-rt"] } +openraft-multi = { path = "../../multiraft" } + +futures = { version = "0.3" } +serde = { version = "1", features = ["derive"] } +serde_json = { version = "1" } +tokio = { version = "1", default-features = false, features = ["sync"] } +tracing = { version = "0.1" } +tracing-subscriber = { version = "0.3", features = ["env-filter"] } + + +[features] + +[package.metadata.docs.rs] +all-features = true + diff --git a/examples/multi-raft-kv/README.md b/examples/multi-raft-kv/README.md new file mode 100644 index 000000000..8661feef5 --- /dev/null +++ b/examples/multi-raft-kv/README.md @@ -0,0 +1,85 @@ +# Multi-Raft KV Store Example + +This example demonstrates how to use OpenRaft's Multi-Raft support to run multiple independent Raft consensus groups within a single process. + +## Overview + +The example creates a distributed key-value store with **3 Raft groups**: +- **users** - Stores user data +- **orders** - Stores order data +- **products** - Stores product data + +Each group runs its own independent Raft consensus, but they share the same network infrastructure. + +## Architecture + +``` ++-----------------------------------------------------------------------+ +| Node 1 | +| +-------------------+ +-------------------+ +-------------------+ | +| | Group "users" | | Group "orders" | | Group "products" | | +| | (Raft Instance) | | (Raft Instance) | | (Raft Instance) | | +| +-------------------+ +-------------------+ +-------------------+ | +| | | | | +| +----------------------+----------------------+ | +| | | +| +--------+--------+ | +| | Router | | +| | (shared network)| | +| +-----------------+ | ++------------------------------------------------------------------------+ + | + Network Connection + | ++------------------------------------------------------------------------+ +| Node 2 | +| +-------------------+ +-------------------+ +-------------------+ | +| | Group "users" | | Group "orders" | | Group "products" | | +| | (Raft Instance) | | (Raft Instance) | | (Raft Instance) | | +| +-------------------+ +-------------------+ +-------------------+ | ++------------------------------------------------------------------------+ +``` + +## Key Concepts + +### GroupId +A string identifier that uniquely identifies each Raft group (e.g., "users", "orders", "products"). + +### Shared Network +Multiple Raft groups share the same network infrastructure (`Router`), reducing connection overhead. Messages are routed to the correct group using the `group_id`. + +### Independent Consensus +Each group runs its own Raft consensus independently: +- Separate log storage +- Separate state machine +- Separate leader election +- Separate membership + +## Running the Test + +```bash +# Run the integration test +cargo test -p multi-raft-kv test_multi_raft_cluster -- --nocapture + +# With debug logging +RUST_LOG=debug cargo test -p multi-raft-kv test_multi_raft_cluster -- --nocapture +``` + +## Code Structure + +``` +multi-raft-kv/ +├── Cargo.toml +├── README.md +├── src/ +│ ├── lib.rs # Type definitions and group constants +│ ├── app.rs # Application handler for each group +│ ├── api.rs # API handlers (read, write, raft operations) +│ ├── network.rs # Network implementation with group routing +│ ├── router.rs # Message router for (node_id, group_id) +│ └── store.rs # State machine storage +└── tests/ + └── cluster/ + ├── main.rs + └── test_cluster.rs # Integration tests +``` diff --git a/examples/multi-raft-kv/src/api.rs b/examples/multi-raft-kv/src/api.rs new file mode 100644 index 000000000..f0d9098e5 --- /dev/null +++ b/examples/multi-raft-kv/src/api.rs @@ -0,0 +1,111 @@ +use std::collections::BTreeMap; +use std::collections::BTreeSet; + +use openraft::raft::TransferLeaderRequest; +use openraft::BasicNode; +use openraft::ReadPolicy; + +use crate::app::GroupApp; +use crate::decode; +use crate::encode; +use crate::typ::*; +use crate::NodeId; + +/// Write a key-value pair to the group's state machine +pub async fn write(app: &mut GroupApp, req: String) -> String { + let res = app.raft.client_write(decode(&req)).await; + encode(res) +} + +/// Read a value from the group's state machine using linearizable read +pub async fn read(app: &mut GroupApp, req: String) -> String { + let key: String = decode(&req); + + let ret = app.raft.get_read_linearizer(ReadPolicy::ReadIndex).await; + + let res = match ret { + Ok(linearizer) => { + linearizer.await_ready(&app.raft).await.unwrap(); + + let state_machine = app.state_machine.state_machine.lock().await; + let value = state_machine.data.get(&key).cloned(); + + let res: Result> = Ok(value.unwrap_or_default()); + res + } + Err(e) => Err(e), + }; + encode(res) +} + +// ============================================================================ +// Raft Protocol API +// ============================================================================ + +/// Handle vote request +pub async fn vote(app: &mut GroupApp, req: String) -> String { + let res = app.raft.vote(decode(&req)).await; + encode(res) +} + +/// Handle append entries request +pub async fn append(app: &mut GroupApp, req: String) -> String { + let res = app.raft.append_entries(decode(&req)).await; + encode(res) +} + +/// Receive a snapshot and install it +pub async fn snapshot(app: &mut GroupApp, req: String) -> String { + let (vote, snapshot_meta, snapshot_data): (Vote, SnapshotMeta, SnapshotData) = decode(&req); + let snapshot = Snapshot { + meta: snapshot_meta, + snapshot: snapshot_data, + }; + let res = app.raft.install_full_snapshot(vote, snapshot).await.map_err(RaftError::::Fatal); + encode(res) +} + +/// Handle transfer leader request +pub async fn transfer_leader(app: &mut GroupApp, req: String) -> String { + let transfer_req: TransferLeaderRequest = decode(&req); + let res = app.raft.handle_transfer_leader(transfer_req).await; + encode(res) +} + +// ============================================================================ +// Management API +// ============================================================================ + +/// Add a node as **Learner** to this group. +/// +/// This should be done before adding a node as a member into the cluster +/// (by calling `change-membership`) +pub async fn add_learner(app: &mut GroupApp, req: String) -> String { + let node_id: NodeId = decode(&req); + let node = BasicNode { addr: "".to_string() }; + let res = app.raft.add_learner(node_id, node, true).await; + encode(res) +} + +/// Changes specified learners to members, or remove members from this group. +pub async fn change_membership(app: &mut GroupApp, req: String) -> String { + let node_ids: BTreeSet = decode(&req); + let res = app.raft.change_membership(node_ids, false).await; + encode(res) +} + +/// Initialize a single-node cluster for this group. +pub async fn init(app: &mut GroupApp) -> String { + let mut nodes = BTreeMap::new(); + nodes.insert(app.node_id, BasicNode { addr: "".to_string() }); + let res = app.raft.initialize(nodes).await; + encode(res) +} + +/// Get the latest metrics of this Raft group +pub async fn metrics(app: &mut GroupApp) -> String { + let metrics = app.raft.metrics().borrow().clone(); + + let res: Result = Ok(metrics); + encode(res) +} diff --git a/examples/multi-raft-kv/src/app.rs b/examples/multi-raft-kv/src/app.rs new file mode 100644 index 000000000..51571b899 --- /dev/null +++ b/examples/multi-raft-kv/src/app.rs @@ -0,0 +1,120 @@ +use std::collections::BTreeMap; +use std::sync::Arc; + +use tokio::sync::mpsc; + +use crate::api; +use crate::encode; +use crate::router::NodeMessage; +use crate::router::NodeRx; +use crate::router::NodeTx; +use crate::router::Router; +use crate::typ; +use crate::GroupId; +use crate::NodeId; +use crate::StateMachineStore; + +/// A Node manages multiple Raft groups on the same physical node. +/// +/// All groups share ONE connection to this node. +/// The Node dispatches incoming messages to the correct group based on group_id. +pub struct Node { + pub node_id: NodeId, + pub groups: BTreeMap, + pub rx: NodeRx, + pub router: Router, +} + +impl Node { + pub fn new(node_id: NodeId, router: Router) -> (Self, NodeTx) { + let (tx, rx) = mpsc::unbounded_channel(); + + // Register this node's shared connection + router.register_node(node_id, tx.clone()); + + let node = Self { + node_id, + groups: BTreeMap::new(), + rx, + router, + }; + + (node, tx) + } + + /// Add a Raft group to this node. + pub fn add_group(&mut self, group_id: GroupId, raft: typ::Raft, state_machine: Arc) { + let app = GroupApp { + node_id: self.node_id, + group_id: group_id.clone(), + raft, + state_machine, + }; + self.groups.insert(group_id, app); + } + + /// Get a Raft instance by group_id. + pub fn get_raft(&self, group_id: &GroupId) -> Option<&typ::Raft> { + self.groups.get(group_id).map(|g| &g.raft) + } + + /// Run the node message dispatcher. + /// Routes incoming messages to the correct group based on group_id. + pub async fn run(mut self) -> Option<()> { + loop { + let msg = self.rx.recv().await?; + + let NodeMessage { + group_id, + path, + payload, + response_tx, + } = msg; + + // Find the target group + let group = match self.groups.get_mut(&group_id) { + Some(g) => g, + None => { + let _ = response_tx.send(encode::>(Err(typ::RaftError::Fatal( + openraft::error::Fatal::Stopped, + )))); + continue; + } + }; + + // Dispatch to the group + let res = match path.as_str() { + // Application API + "/app/write" => api::write(group, payload).await, + "/app/read" => api::read(group, payload).await, + + // Raft API + "/raft/append" => api::append(group, payload).await, + "/raft/snapshot" => api::snapshot(group, payload).await, + "/raft/vote" => api::vote(group, payload).await, + "/raft/transfer_leader" => api::transfer_leader(group, payload).await, + + // Management API + "/mng/add-learner" => api::add_learner(group, payload).await, + "/mng/change-membership" => api::change_membership(group, payload).await, + "/mng/init" => api::init(group).await, + "/mng/metrics" => api::metrics(group).await, + + _ => { + tracing::warn!("unknown path: {}", path); + encode::>(Err(typ::RaftError::Fatal(openraft::error::Fatal::Stopped))) + } + }; + + let _ = response_tx.send(res); + } + } +} + +/// A single Raft group's application context. +pub struct GroupApp { + pub node_id: NodeId, + pub group_id: GroupId, + pub raft: typ::Raft, + pub state_machine: Arc, +} diff --git a/examples/multi-raft-kv/src/lib.rs b/examples/multi-raft-kv/src/lib.rs new file mode 100644 index 000000000..a730e2935 --- /dev/null +++ b/examples/multi-raft-kv/src/lib.rs @@ -0,0 +1,86 @@ +use std::sync::Arc; + +use openraft::Config; + +use crate::app::Node; +use crate::router::Router; +use crate::store::Request; +use crate::store::Response; +use crate::store::StateMachineData; + +pub mod router; + +pub mod api; +pub mod app; +pub mod network; +pub mod store; + +/// Node ID type - identifies a node in the cluster +pub type NodeId = u64; + +/// Group ID type - identifies a Raft group +pub type GroupId = String; + +openraft::declare_raft_types!( + /// Declare the type configuration for Multi-Raft K/V store. + pub TypeConfig: + D = Request, + R = Response, + // In this example, snapshot is just a copy of the state machine. + SnapshotData = StateMachineData, +); + +pub type LogStore = store::LogStore; +pub type StateMachineStore = store::StateMachineStore; + +/// Define all Raft-related type aliases +#[path = "../../utils/declare_types.rs"] +pub mod typ; + +pub mod groups { + pub const USERS: &str = "users"; + pub const ORDERS: &str = "orders"; + pub const PRODUCTS: &str = "products"; + + pub fn all() -> Vec { + vec![USERS.to_string(), ORDERS.to_string(), PRODUCTS.to_string()] + } +} + +pub fn encode(t: T) -> String { + serde_json::to_string(&t).unwrap() +} + +pub fn decode(s: &str) -> T { + serde_json::from_str(s).unwrap() +} + +/// Create a Node with multiple Raft groups. +/// +/// - One Node has ONE connection (shared by all groups) +/// - Each group has its own Raft instance +pub async fn create_node(node_id: NodeId, group_ids: &[GroupId], router: Router) -> Node { + let (mut node, _tx) = Node::new(node_id, router.clone()); + + for group_id in group_ids { + let config = Config { + heartbeat_interval: 500, + election_timeout_min: 1500, + election_timeout_max: 3000, + max_in_snapshot_log_to_keep: 0, + ..Default::default() + }; + + let config = Arc::new(config.validate().unwrap()); + let log_store = LogStore::default(); + let state_machine_store = Arc::new(StateMachineStore::default()); + + let network = network::NetworkFactory::new(router.clone(), group_id.clone()); + + let raft = openraft::Raft::new(node_id, config, network, log_store, state_machine_store.clone()).await.unwrap(); + + node.add_group(group_id.clone(), raft, state_machine_store); + } + + node +} diff --git a/examples/multi-raft-kv/src/network.rs b/examples/multi-raft-kv/src/network.rs new file mode 100644 index 000000000..421ea74db --- /dev/null +++ b/examples/multi-raft-kv/src/network.rs @@ -0,0 +1,92 @@ +use std::future::Future; + +use openraft::error::RPCError; +use openraft::error::ReplicationClosed; +use openraft::error::StreamingError; +use openraft::network::Backoff; +use openraft::network::RPCOption; +use openraft::network::RaftNetworkFactory; +use openraft::raft::AppendEntriesRequest; +use openraft::raft::AppendEntriesResponse; +use openraft::raft::SnapshotResponse; +use openraft::raft::TransferLeaderRequest; +use openraft::raft::VoteRequest; +use openraft::raft::VoteResponse; +use openraft::storage::Snapshot; +use openraft::OptionalSend; +use openraft_multi::GroupNetworkAdapter; +use openraft_multi::GroupNetworkFactory; +use openraft_multi::GroupRouter; + +use crate::router::Router; +use crate::typ; +use crate::GroupId; +use crate::NodeId; +use crate::TypeConfig; + +impl GroupRouter for Router { + async fn send_append_entries( + &self, + target: NodeId, + group_id: GroupId, + rpc: AppendEntriesRequest, + _option: RPCOption, + ) -> Result, RPCError> { + self.send(target, &group_id, "/raft/append", rpc).await.map_err(RPCError::Unreachable) + } + + async fn send_vote( + &self, + target: NodeId, + group_id: GroupId, + rpc: VoteRequest, + _option: RPCOption, + ) -> Result, RPCError> { + self.send(target, &group_id, "/raft/vote", rpc).await.map_err(RPCError::Unreachable) + } + + async fn send_snapshot( + &self, + target: NodeId, + group_id: GroupId, + vote: typ::Vote, + snapshot: Snapshot, + _cancel: impl Future + OptionalSend + 'static, + _option: RPCOption, + ) -> Result, StreamingError> { + self.send( + target, + &group_id, + "/raft/snapshot", + (vote, snapshot.meta, snapshot.snapshot), + ) + .await + .map_err(StreamingError::Unreachable) + } + + async fn send_transfer_leader( + &self, + target: NodeId, + group_id: GroupId, + req: TransferLeaderRequest, + _option: RPCOption, + ) -> Result<(), RPCError> { + self.send(target, &group_id, "/raft/transfer_leader", req).await.map_err(RPCError::Unreachable) + } + + fn backoff(&self) -> Backoff { + Backoff::new(std::iter::repeat(std::time::Duration::from_millis(500))) + } +} + +/// Network factory that creates `GroupNetworkAdapter` instances. +pub type NetworkFactory = GroupNetworkFactory; + +impl RaftNetworkFactory for NetworkFactory { + /// The network type is `GroupNetworkAdapter` binding (Router, target, group_id). + type Network = GroupNetworkAdapter; + + async fn new_client(&mut self, target: NodeId, _node: &openraft::BasicNode) -> Self::Network { + GroupNetworkAdapter::new(self.factory.clone(), target, self.group_id.clone()) + } +} diff --git a/examples/multi-raft-kv/src/router.rs b/examples/multi-raft-kv/src/router.rs new file mode 100644 index 000000000..9002a8b49 --- /dev/null +++ b/examples/multi-raft-kv/src/router.rs @@ -0,0 +1,121 @@ +use std::collections::BTreeMap; +use std::fmt; +use std::sync::Arc; +use std::sync::Mutex; + +use openraft::error::Unreachable; +use tokio::sync::oneshot; + +use crate::decode; +use crate::encode; +use crate::typ::RaftError; +use crate::GroupId; +use crate::NodeId; + +pub type NodeTx = tokio::sync::mpsc::UnboundedSender; +pub type NodeRx = tokio::sync::mpsc::UnboundedReceiver; + +#[derive(Debug)] +pub struct RouterError(pub String); + +impl fmt::Display for RouterError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.0) + } +} + +impl std::error::Error for RouterError {} + +/// Message sent through a node connection, containing group_id for routing. +pub struct NodeMessage { + pub group_id: GroupId, + pub path: String, + pub payload: String, + pub response_tx: oneshot::Sender, +} + +/// Multi-Raft Router with per-node connection sharing. +#[derive(Debug, Clone, Default)] +pub struct Router { + /// Map from node_id to node connection. + /// All groups on the same node share this connection. + pub nodes: Arc>>, +} + +impl Router { + pub fn new() -> Self { + Self { + nodes: Arc::new(Mutex::new(BTreeMap::new())), + } + } + + /// Register a node connection. All groups on this node will use this connection. + pub fn register_node(&self, node_id: NodeId, tx: NodeTx) { + let mut nodes = self.nodes.lock().unwrap(); + nodes.insert(node_id, tx); + } + + /// Unregister a node connection. + pub fn unregister_node(&self, node_id: NodeId) -> Option { + let mut nodes = self.nodes.lock().unwrap(); + nodes.remove(&node_id) + } + + /// Send a request to a specific (node, group). + pub async fn send( + &self, + to_node: NodeId, + to_group: &GroupId, + path: &str, + req: Req, + ) -> Result + where + Req: serde::Serialize, + Result: serde::de::DeserializeOwned, + { + let (resp_tx, resp_rx) = oneshot::channel(); + + let encoded_req = encode(&req); + tracing::debug!( + "send to: node={}, group={}, path={}, req={}", + to_node, + to_group, + path, + encoded_req + ); + + // Send through the shared node connection + { + let nodes = self.nodes.lock().unwrap(); + let tx = nodes + .get(&to_node) + .ok_or_else(|| Unreachable::new(&RouterError(format!("node {} not connected", to_node))))?; + + let msg = NodeMessage { + group_id: to_group.clone(), + path: path.to_string(), + payload: encoded_req, + response_tx: resp_tx, + }; + + tx.send(msg).map_err(|e| Unreachable::new(&RouterError(e.to_string())))?; + } + + let resp_str = resp_rx.await.map_err(|e| Unreachable::new(&RouterError(e.to_string())))?; + tracing::debug!( + "resp from: node={}, group={}, path={}, resp={}", + to_node, + to_group, + path, + resp_str + ); + + let res = decode::>(&resp_str); + res.map_err(|e| Unreachable::new(&RouterError(e.to_string()))) + } + + pub fn has_node(&self, node_id: NodeId) -> bool { + let nodes = self.nodes.lock().unwrap(); + nodes.contains_key(&node_id) + } +} diff --git a/examples/multi-raft-kv/src/store.rs b/examples/multi-raft-kv/src/store.rs new file mode 100644 index 000000000..4f122fcb5 --- /dev/null +++ b/examples/multi-raft-kv/src/store.rs @@ -0,0 +1,213 @@ +use std::collections::BTreeMap; +use std::fmt; +use std::fmt::Debug; +use std::io; +use std::sync::Arc; +use std::sync::Mutex; + +use futures::Stream; +use futures::TryStreamExt; +use openraft::storage::EntryResponder; +use openraft::storage::RaftStateMachine; +use openraft::EntryPayload; +use openraft::OptionalSend; +use openraft::RaftSnapshotBuilder; +use serde::Deserialize; +use serde::Serialize; + +use crate::typ::*; +use crate::TypeConfig; + +pub type LogStore = mem_log::LogStore; + +/// Application request type +#[derive(Serialize, Deserialize, Debug, Clone)] +pub enum Request { + Set { key: String, value: String }, +} + +impl fmt::Display for Request { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Request::Set { key, value } => write!(f, "Set {{ key: {}, value: {} }}", key, value), + } + } +} + +impl Request { + pub fn set(key: impl ToString, value: impl ToString) -> Self { + Self::Set { + key: key.to_string(), + value: value.to_string(), + } + } +} + +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct Response { + pub value: Option, +} + +#[derive(Debug)] +pub struct StoredSnapshot { + pub meta: SnapshotMeta, + + /// The data of the state machine at the time of this snapshot. + pub data: SnapshotData, +} + +#[derive(Serialize, Deserialize, Debug, Default, Clone)] +pub struct StateMachineData { + pub last_applied: Option, + + pub last_membership: StoredMembership, + + /// Application data - the key-value store + pub data: BTreeMap, +} + +/// Each group may have its own independent state machine instance. +#[derive(Debug, Default)] +pub struct StateMachineStore { + pub state_machine: tokio::sync::Mutex, + + snapshot_idx: Mutex, + + /// The last received snapshot. + current_snapshot: Mutex>, +} + +impl RaftSnapshotBuilder for Arc { + #[tracing::instrument(level = "trace", skip(self))] + async fn build_snapshot(&mut self) -> Result { + let data; + let last_applied_log; + let last_membership; + + { + // Serialize the data of the state machine. + let state_machine = self.state_machine.lock().await.clone(); + + last_applied_log = state_machine.last_applied; + last_membership = state_machine.last_membership.clone(); + data = state_machine; + } + + let snapshot_idx = { + let mut l = self.snapshot_idx.lock().unwrap(); + *l += 1; + *l + }; + + let snapshot_id = if let Some(last) = last_applied_log { + format!("{}-{}-{}", last.committed_leader_id(), last.index(), snapshot_idx) + } else { + format!("--{}", snapshot_idx) + }; + + let meta = SnapshotMeta { + last_log_id: last_applied_log, + last_membership, + snapshot_id, + }; + + let snapshot = StoredSnapshot { + meta: meta.clone(), + data: data.clone(), + }; + + { + let mut current_snapshot = self.current_snapshot.lock().unwrap(); + *current_snapshot = Some(snapshot); + } + + Ok(Snapshot { meta, snapshot: data }) + } +} + +impl RaftStateMachine for Arc { + type SnapshotBuilder = Self; + + async fn applied_state(&mut self) -> Result<(Option, StoredMembership), io::Error> { + let state_machine = self.state_machine.lock().await; + Ok((state_machine.last_applied, state_machine.last_membership.clone())) + } + + #[tracing::instrument(level = "trace", skip(self, entries))] + async fn apply(&mut self, mut entries: Strm) -> Result<(), io::Error> + where Strm: Stream, io::Error>> + Unpin + OptionalSend { + let mut sm = self.state_machine.lock().await; + + while let Some((entry, responder)) = entries.try_next().await? { + tracing::debug!(%entry.log_id, "replicate to sm"); + + sm.last_applied = Some(entry.log_id); + + let response = match entry.payload { + EntryPayload::Blank => Response { value: None }, + EntryPayload::Normal(ref req) => match req { + Request::Set { key, value, .. } => { + sm.data.insert(key.clone(), value.clone()); + Response { + value: Some(value.clone()), + } + } + }, + EntryPayload::Membership(ref mem) => { + sm.last_membership = StoredMembership::new(Some(entry.log_id), mem.clone()); + Response { value: None } + } + }; + + if let Some(responder) = responder { + responder.send(response); + } + } + Ok(()) + } + + #[tracing::instrument(level = "trace", skip(self))] + async fn begin_receiving_snapshot(&mut self) -> Result { + Ok(Default::default()) + } + + #[tracing::instrument(level = "trace", skip(self, snapshot))] + async fn install_snapshot(&mut self, meta: &SnapshotMeta, snapshot: SnapshotData) -> Result<(), io::Error> { + tracing::info!("install snapshot"); + + let new_snapshot = StoredSnapshot { + meta: meta.clone(), + data: snapshot, + }; + + // Update the state machine. + { + let updated_state_machine: StateMachineData = new_snapshot.data.clone(); + let mut state_machine = self.state_machine.lock().await; + *state_machine = updated_state_machine; + } + + // Update current snapshot. + let mut current_snapshot = self.current_snapshot.lock().unwrap(); + *current_snapshot = Some(new_snapshot); + Ok(()) + } + + #[tracing::instrument(level = "trace", skip(self))] + async fn get_current_snapshot(&mut self) -> Result, io::Error> { + match &*self.current_snapshot.lock().unwrap() { + Some(snapshot) => { + let data = snapshot.data.clone(); + Ok(Some(Snapshot { + meta: snapshot.meta.clone(), + snapshot: data, + })) + } + None => Ok(None), + } + } + + async fn get_snapshot_builder(&mut self) -> Self::SnapshotBuilder { + self.clone() + } +} diff --git a/examples/multi-raft-kv/tests/cluster/main.rs b/examples/multi-raft-kv/tests/cluster/main.rs new file mode 100644 index 000000000..42a4d90cd --- /dev/null +++ b/examples/multi-raft-kv/tests/cluster/main.rs @@ -0,0 +1 @@ +mod test_cluster; diff --git a/examples/multi-raft-kv/tests/cluster/test_cluster.rs b/examples/multi-raft-kv/tests/cluster/test_cluster.rs new file mode 100644 index 000000000..c5294dfba --- /dev/null +++ b/examples/multi-raft-kv/tests/cluster/test_cluster.rs @@ -0,0 +1,250 @@ +//! Integration test for Multi-Raft KV store with 3 groups. +//! +//! This test demonstrates the TRUE Multi-Raft pattern: +//! - Each Node has ONE shared connection (not per-group connections) +//! - Multiple Raft groups share this connection +//! - Messages are routed to the correct group based on group_id + +use std::backtrace::Backtrace; +use std::collections::BTreeMap; +use std::panic::PanicHookInfo; +use std::time::Duration; + +use multi_raft_kv::create_node; +use multi_raft_kv::groups; +use multi_raft_kv::router::Router; +use multi_raft_kv::store::Request; +use multi_raft_kv::typ; +use multi_raft_kv::GroupId; +use openraft::BasicNode; +use tokio::task; +use tokio::task::LocalSet; +use tracing_subscriber::EnvFilter; + +pub fn log_panic(panic: &PanicHookInfo) { + let backtrace = format!("{:?}", Backtrace::force_capture()); + + eprintln!("{}", panic); + + if let Some(location) = panic.location() { + tracing::error!( + message = %panic, + backtrace = %backtrace, + panic.file = location.file(), + panic.line = location.line(), + panic.column = location.column(), + ); + eprintln!("{}:{}:{}", location.file(), location.line(), location.column()); + } else { + tracing::error!(message = %panic, backtrace = %backtrace); + } + + eprintln!("{}", backtrace); +} + +/// Test Multi-Raft cluster with 3 groups and 2 nodes. +#[tokio::test] +async fn test_multi_raft_cluster() { + std::panic::set_hook(Box::new(|panic| { + log_panic(panic); + })); + + tracing_subscriber::fmt() + .with_target(true) + .with_thread_ids(true) + .with_level(true) + .with_ansi(false) + .with_env_filter(EnvFilter::from_default_env()) + .init(); + + // Shared router - this is where connection sharing happens + let router = Router::new(); + let group_ids = groups::all(); + + let local = LocalSet::new(); + + // Create nodes - each node has ONE connection, multiple groups + let node1 = create_node(1, &group_ids, router.clone()).await; + let node2 = create_node(2, &group_ids, router.clone()).await; + + // Get Raft handles before moving nodes into tasks + let node1_rafts: Vec<_> = group_ids.iter().map(|g| node1.get_raft(g).unwrap().clone()).collect(); + let node2_rafts: Vec<_> = group_ids.iter().map(|g| node2.get_raft(g).unwrap().clone()).collect(); + + local + .run_until(async move { + // Spawn node message handlers (one per node, not per group!) + task::spawn_local(node1.run()); + task::spawn_local(node2.run()); + + run_test(&node1_rafts, &node2_rafts, &group_ids).await; + }) + .await; +} + +async fn run_test(node1_rafts: &[typ::Raft], node2_rafts: &[typ::Raft], group_ids: &[GroupId]) { + // Wait for servers to start up + tokio::time::sleep(Duration::from_millis(200)).await; + + println!("\n╔════════════════════════════════════════════════════════════════════╗"); + println!("║ Multi-Raft Test: 3 groups, 2 nodes, CONNECTION SHARING ║"); + println!("╚════════════════════════════════════════════════════════════════════╝\n"); + + // ========================================================================= + // Initialize each group with node 1 as leader + // ========================================================================= + println!("=== Initializing 3 Raft groups (all on Node 1) ===\n"); + + for (i, raft) in node1_rafts.iter().enumerate() { + let mut nodes = BTreeMap::new(); + nodes.insert(1u64, BasicNode { addr: "".to_string() }); + raft.initialize(nodes).await.unwrap(); + println!(" ✓ Group '{}' initialized on Node 1", group_ids[i]); + } + + tokio::time::sleep(Duration::from_millis(500)).await; + + // ========================================================================= + // Add Node 2 as learner for each group + // ========================================================================= + println!("\n=== Adding Node 2 as learner to all groups ===\n"); + + for (i, raft) in node1_rafts.iter().enumerate() { + let node = BasicNode { addr: "".to_string() }; + raft.add_learner(2, node, true).await.unwrap(); + println!(" ✓ Group '{}': Node 2 added as learner", group_ids[i]); + } + + tokio::time::sleep(Duration::from_millis(500)).await; + + // ========================================================================= + // Write data to each group + // ========================================================================= + println!("\n=== Writing data to each group ===\n"); + + // users group + node1_rafts[0].client_write(Request::set("user:1", "Alice")).await.unwrap(); + node1_rafts[0].client_write(Request::set("user:2", "Bob")).await.unwrap(); + println!(" ✓ Group 'users': wrote user:1=Alice, user:2=Bob"); + + // orders group + node1_rafts[1].client_write(Request::set("order:1001", "pending")).await.unwrap(); + node1_rafts[1].client_write(Request::set("order:1002", "shipped")).await.unwrap(); + println!(" ✓ Group 'orders': wrote order:1001=pending, order:1002=shipped"); + + // products group + node1_rafts[2].client_write(Request::set("product:A", "Widget")).await.unwrap(); + node1_rafts[2].client_write(Request::set("product:B", "Gadget")).await.unwrap(); + println!(" ✓ Group 'products': wrote product:A=Widget, product:B=Gadget"); + + tokio::time::sleep(Duration::from_millis(500)).await; + + // ========================================================================= + // Verify replication + // ========================================================================= + println!("\n=== Verifying replication to Node 2 ===\n"); + + for (i, raft) in node2_rafts.iter().enumerate() { + let metrics = raft.metrics().borrow().clone(); + println!( + " Group '{}' on Node 2: last_applied={:?}", + group_ids[i], metrics.last_applied + ); + assert!( + metrics.last_applied.is_some(), + "Group {} should have applied logs", + group_ids[i] + ); + } +} + +// ============================================================================ +// Test: Leader Distribution using transfer_leader +// ============================================================================ + +/// Test that demonstrates using transfer_leader to distribute leaders. +#[tokio::test] +async fn test_leader_distribution() { + let router = Router::new(); + let group_ids = groups::all(); + + let local = LocalSet::new(); + + // Create 3 nodes + let node1 = create_node(1, &group_ids, router.clone()).await; + let node2 = create_node(2, &group_ids, router.clone()).await; + let node3 = create_node(3, &group_ids, router.clone()).await; + + let node1_rafts: Vec<_> = group_ids.iter().map(|g| node1.get_raft(g).unwrap().clone()).collect(); + let node2_rafts: Vec<_> = group_ids.iter().map(|g| node2.get_raft(g).unwrap().clone()).collect(); + let node3_rafts: Vec<_> = group_ids.iter().map(|g| node3.get_raft(g).unwrap().clone()).collect(); + + local + .run_until(async move { + task::spawn_local(node1.run()); + task::spawn_local(node2.run()); + task::spawn_local(node3.run()); + + run_leader_distribution_test(&node1_rafts, &node2_rafts, &node3_rafts, &group_ids).await; + }) + .await; +} + +async fn run_leader_distribution_test( + node1_rafts: &[typ::Raft], + node2_rafts: &[typ::Raft], + node3_rafts: &[typ::Raft], + group_ids: &[GroupId], +) { + tokio::time::sleep(Duration::from_millis(200)).await; + + println!("\n╔════════════════════════════════════════════════════════════════════╗"); + println!("║ Leader Distribution Test using transfer_leader ║"); + println!("╚════════════════════════════════════════════════════════════════════╝\n"); + + // Initialize all groups on Node 1 with all 3 nodes as voters + println!("=== Initializing all groups with 3 voters ===\n"); + + let all_nodes = { + let mut nodes = BTreeMap::new(); + nodes.insert(1u64, BasicNode { addr: "".to_string() }); + nodes.insert(2u64, BasicNode { addr: "".to_string() }); + nodes.insert(3u64, BasicNode { addr: "".to_string() }); + nodes + }; + + for (i, raft) in node1_rafts.iter().enumerate() { + raft.initialize(all_nodes.clone()).await.unwrap(); + println!(" ✓ Group '{}' initialized (voters: 1, 2, 3)", group_ids[i]); + } + + tokio::time::sleep(Duration::from_millis(1000)).await; + + // Transfer leaders to distribute load + println!("\n=== Using transfer_leader to distribute leaders ===\n"); + + // orders -> Node 2 + println!(" → Transferring 'orders' leader to Node 2..."); + node1_rafts[1].trigger().transfer_leader(2).await.unwrap(); + tokio::time::sleep(Duration::from_millis(1000)).await; + + // products -> Node 3 + println!(" → Transferring 'products' leader to Node 3..."); + node1_rafts[2].trigger().transfer_leader(3).await.unwrap(); + tokio::time::sleep(Duration::from_millis(1000)).await; + + // Verify distribution + println!("\n=== Verifying leader distribution ===\n"); + + let users_leader = node1_rafts[0].metrics().borrow().current_leader; + let orders_leader = node2_rafts[1].metrics().borrow().current_leader; + let products_leader = node3_rafts[2].metrics().borrow().current_leader; + + println!(" Group 'users': leader = {:?}", users_leader); + println!(" Group 'orders': leader = {:?}", orders_leader); + println!(" Group 'products': leader = {:?}", products_leader); + + assert_eq!(users_leader, Some(1), "users leader should be Node 1"); + assert_eq!(orders_leader, Some(2), "orders leader should be Node 2"); + assert_eq!(products_leader, Some(3), "products leader should be Node 3"); +} diff --git a/examples/multi-raft-sharding/Cargo.toml b/examples/multi-raft-sharding/Cargo.toml new file mode 100644 index 000000000..3c9254113 --- /dev/null +++ b/examples/multi-raft-sharding/Cargo.toml @@ -0,0 +1,33 @@ +[package] +name = "multi-raft-sharding" +version = "0.1.0" +readme = "README.md" + +edition = "2021" +authors = [ + "AriesDevil ", +] +categories = ["algorithms", "asynchronous", "data-structures"] +description = "Multi-Raft sharding example with TiKV-style split support" +homepage = "https://github.com/databendlabs/openraft" +keywords = ["raft", "consensus", "multi-raft", "sharding", "split"] +license = "MIT OR Apache-2.0" +repository = "https://github.com/databendlabs/openraft" + +[dependencies] +mem-log = { path = "../mem-log", features = [] } +openraft = { path = "../../openraft", default-features = false, features = ["serde", "type-alias", "tokio-rt"] } +openraft-multi = { path = "../../multiraft" } + +futures = { version = "0.3" } +serde = { version = "1", features = ["derive"] } +serde_json = { version = "1" } +tokio = { version = "1", default-features = false, features = ["sync"] } +tracing = { version = "0.1" } +tracing-subscriber = { version = "0.3", features = ["env-filter"] } + +[features] + +[package.metadata.docs.rs] +all-features = true + diff --git a/examples/multi-raft-sharding/README.md b/examples/multi-raft-sharding/README.md new file mode 100644 index 000000000..84eb2eb5f --- /dev/null +++ b/examples/multi-raft-sharding/README.md @@ -0,0 +1,131 @@ +# Multi-Raft Sharding Example with TiKV-style Split + +This example demonstrates how to implement **dynamic range sharding** with TiKV-style split using OpenRaft's Multi-Raft support. + +## Overview + +The example shows how to: +0. Suppose we have a "very hot" shard +1. We want to split it into two shards (atomic, no service interruption needed) +2. Migrate the new shard to dedicated nodes +3. Remove the migrated shard from original nodes + +This is the same pattern used by TiKV for Region splits. + +## Architecture + +``` +Phase 1: Initial State (3 nodes, 1 shard) +┌─────────────┐ ┌─────────────┐ ┌─────────────┐ +│ Node 1 │ │ Node 2 │ │ Node 3 │ +│ shard_a: ★ │ │ shard_a: F │ │ shard_a: F │ +│ [1..200] │ │ [1..200] │ │ [1..200] │ +└─────────────┘ └─────────────┘ └─────────────┘ + +Phase 2: After Split (3 nodes, 2 shards) +┌─────────────┐ ┌─────────────┐ ┌─────────────┐ +│ Node 1 │ │ Node 2 │ │ Node 3 │ +│ shard_a: ★ │ │ shard_a: F │ │ shard_a: F │ +│ [1..100] │ │ [1..100] │ │ [1..100] │ +│ shard_b: ★ │ │ shard_b: F │ │ shard_b: F │ +│ [101..200] │ │ [101..200] │ │ [101..200] │ +└─────────────┘ └─────────────┘ └─────────────┘ + +Phase 3: Add Nodes 4,5 to shard_b +┌───────────┐ ┌───────────┐ ┌───────────┐ ┌───────────┐ ┌───────────┐ +│ Node 1 │ │ Node 2 │ │ Node 3 │ │ Node 4 │ │ Node 5 │ +│ shard_a:★ │ │ shard_a:F │ │ shard_a:F │ │ - │ │ - │ +│ shard_b:★ │ │ shard_b:F │ │ shard_b:F │ │ shard_b:L │ │ shard_b:L │ +└───────────┘ └───────────┘ └───────────┘ └───────────┘ └───────────┘ + +Phase 4: Migrate shard_b to Nodes 4,5 +┌───────────┐ ┌───────────┐ ┌───────────┐ ┌───────────┐ ┌───────────┐ +│ Node 1 │ │ Node 2 │ │ Node 3 │ │ Node 4 │ │ Node 5 │ +│ shard_a:★ │ │ shard_a:F │ │ shard_a:F │ │ - │ │ - │ +│ - │ │ - │ │ - │ │ shard_b:★ │ │ shard_b:F │ +└───────────┘ └───────────┘ └───────────┘ └───────────┘ └───────────┘ + +* is leader +F is follower +L is learner +``` + +## How Split Works + +### Step 1: Propose Split as Raft Log + +```rust +let split_request = Request::Split { + split_at: 100, // Split point + new_shard_id: "shard_b", // New shard ID +}; +raft.client_write(split_request).await?; +``` + +### Step 2: State Machine Executes Split + +When the split log is applied, each replica: + +```rust +// In state machine apply(): +Request::Split { split_at, new_shard_id } => { + // 1. Extract data for new shard + let split_data = self.keys_greater_than(*split_at); + + // 2. Remove extracted data from current shard + self.remove_keys_greater_than(*split_at); + + // 3. Return split data for bootstrapping new shard + Response::SplitComplete { + new_shard_id, + split_data, + } +} +``` + +### Step 3: Bootstrap New Shard + +The split response contains the data for the new shard: + +```rust +// Create new shard with split data +let state_machine = StateMachineStore::with_initial_data(split_data); +let raft_b = Raft::new(node_id, config, network, log_store, state_machine).await?; +``` + +## Running the Test + +```bash +# Run the split test +cargo test -p multi-raft-sharding test_shard_split -- --nocapture + +# With debug logging +RUST_LOG=debug cargo test -p multi-raft-sharding test_shard_split -- --nocapture +``` + +## Code Structure + +``` +multi-raft-sharding/ +├── src/ +│ ├── lib.rs # Type definitions, new_raft() +│ ├── store.rs # State machine with Split support +│ ├── shard_router.rs # Routes keys to shards by range +│ ├── router.rs # Network message routing +│ ├── network.rs # Raft network implementation +│ ├── app.rs # Application message handler +│ └── api.rs # API handlers +└── tests/ + └── cluster/ + └── test_split.rs # Four-phase split test +``` + +## Use Cases + +This pattern is useful for: + +1. **Horizontal Scaling**: Add nodes for hot shards +2. **Load Balancing**: Distribute shards across nodes +3. **Data Locality**: Move shards closer to users +4. **Resource Isolation**: Dedicate nodes to critical shards + diff --git a/examples/multi-raft-sharding/src/api.rs b/examples/multi-raft-sharding/src/api.rs new file mode 100644 index 000000000..17ee21804 --- /dev/null +++ b/examples/multi-raft-sharding/src/api.rs @@ -0,0 +1,148 @@ +use std::collections::BTreeMap; +use std::collections::BTreeSet; + +use openraft::BasicNode; + +use crate::app::App; +use crate::decode; +use crate::encode; +use crate::store::Request; +use crate::typ::*; +use crate::NodeId; + +// ============================================================================= +// Application API +// ============================================================================= + +/// Write a request to the shard. +/// +/// This handles all write operations including: +/// - Set: Store a key-value pair +/// - Delete: Remove a key +/// - Split: Split the shard (TiKV-style) +pub async fn write(app: &mut App, req: String) -> String { + let request: Request = decode(&req); + + tracing::debug!( + shard_id = %app.shard_id, + request = %request, + "processing write request" + ); + + let res = app.raft.client_write(request).await; + encode(res) +} + +/// Read a key from the shard using linearizable read. +/// +/// This ensures the read sees all committed writes by: +/// 1. Getting a read linearizer +/// 2. Waiting for it to be ready (confirms leadership) +/// 3. Reading from the state machine +pub async fn read(app: &mut App, req: String) -> String { + use openraft::ReadPolicy; + + let key: String = decode(&req); + + let ret = app.raft.ensure_linearizable(ReadPolicy::ReadIndex).await; + + let res: Result, RaftError> = match ret { + Ok(_) => { + let state_machine = app.state_machine.state_machine.lock().await; + let value = state_machine.data.get(&key).cloned(); + Ok(value) + } + Err(e) => Err(RaftError::Fatal(e.into_fatal().unwrap())), + }; + encode(res) +} + +// ============================================================================= +// Raft Protocol API +// ============================================================================= + +/// Handle Vote RPC from other nodes. +pub async fn vote(app: &mut App, req: String) -> String { + let res = app.raft.vote(decode(&req)).await; + encode(res) +} + +/// Handle AppendEntries RPC from the leader. +pub async fn append(app: &mut App, req: String) -> String { + let res = app.raft.append_entries(decode(&req)).await; + encode(res) +} + +/// Handle snapshot installation from the leader. +pub async fn snapshot(app: &mut App, req: String) -> String { + let (vote, snapshot_meta, snapshot_data): (Vote, SnapshotMeta, SnapshotData) = decode(&req); + let snapshot = Snapshot { + meta: snapshot_meta, + snapshot: snapshot_data, + }; + let res = app.raft.install_full_snapshot(vote, snapshot).await.map_err(RaftError::::Fatal); + encode(res) +} + +// ============================================================================= +// Management API +// ============================================================================= + +pub async fn add_learner(app: &mut App, req: String) -> String { + let node_id: NodeId = decode(&req); + let node = BasicNode { addr: "".to_string() }; + + tracing::info!( + shard_id = %app.shard_id, + learner_id = %node_id, + "adding learner to shard" + ); + + let res = app.raft.add_learner(node_id, node, true).await; + encode(res) +} + +/// Change the membership of this shard. +/// +/// This is used to: +/// - Promote learners to voters +/// - Remove nodes from the shard +/// +/// # Arguments +/// * `req` - JSON-encoded BTreeSet of new member IDs +pub async fn change_membership(app: &mut App, req: String) -> String { + let node_ids: BTreeSet = decode(&req); + + tracing::info!( + shard_id = %app.shard_id, + new_members = ?node_ids, + "changing shard membership" + ); + + let res = app.raft.change_membership(node_ids, false).await; + encode(res) +} + +/// Initialize a single-node cluster for this shard. +/// +/// This is called on the first node to bootstrap the shard. +pub async fn init(app: &mut App) -> String { + let mut nodes = BTreeMap::new(); + nodes.insert(app.node_id, BasicNode { addr: "".to_string() }); + + tracing::info!( + shard_id = %app.shard_id, + node_id = %app.node_id, + "initializing shard" + ); + + let res = app.raft.initialize(nodes).await; + encode(res) +} + +/// Get the current metrics for this shard. +pub async fn metrics(app: &mut App) -> String { + let metrics = app.raft.metrics().borrow().clone(); + let res: Result = Ok(metrics); + encode(res) +} diff --git a/examples/multi-raft-sharding/src/app.rs b/examples/multi-raft-sharding/src/app.rs new file mode 100644 index 000000000..028ae7029 --- /dev/null +++ b/examples/multi-raft-sharding/src/app.rs @@ -0,0 +1,103 @@ +use std::sync::Arc; + +use tokio::sync::mpsc; +use tokio::sync::oneshot; + +use crate::api; +use crate::router::Router; +use crate::typ; +use crate::NodeId; +use crate::ShardId; +use crate::StateMachineStore; + +/// Type alias for the request path (e.g., "/raft/vote"). +pub type Path = String; + +/// Type alias for the serialized request payload. +pub type Payload = String; + +/// Type alias for the response channel. +pub type ResponseTx = oneshot::Sender; + +/// Type alias for the request channel sender. +pub type RequestTx = mpsc::UnboundedSender<(Path, Payload, ResponseTx)>; + +/// Application handler for a single shard on a single node. +/// +/// Each App instance handles: +/// - Raft protocol RPCs (vote, append_entries, snapshot) +/// - Client application requests (read, write) +/// - Management operations (add_learner, change_membership) +pub struct App { + pub node_id: NodeId, + + pub shard_id: ShardId, + + pub raft: typ::Raft, + + pub rx: mpsc::UnboundedReceiver<(Path, Payload, ResponseTx)>, + + /// The shared message router. + pub router: Router, + + /// The state machine store for this shard. + pub state_machine: Arc, +} + +impl App { + /// Create a new App instance and register it with the router. + pub fn new( + node_id: NodeId, + shard_id: ShardId, + raft: typ::Raft, + router: Router, + state_machine: Arc, + ) -> Self { + let (tx, rx) = mpsc::unbounded_channel(); + + // Register this app with the router so it can receive messages. + router.register(node_id, shard_id.clone(), tx); + + Self { + node_id, + shard_id, + raft, + rx, + router, + state_machine, + } + } + + pub async fn run(mut self) -> Option<()> { + tracing::info!( + node_id = %self.node_id, + shard_id = %self.shard_id, + "starting app event loop" + ); + + loop { + let (path, payload, response_tx) = self.rx.recv().await?; + + let res = match path.as_str() { + // Application API + "/app/write" => api::write(&mut self, payload).await, + "/app/read" => api::read(&mut self, payload).await, + + // Raft Protocol API + "/raft/append" => api::append(&mut self, payload).await, + "/raft/snapshot" => api::snapshot(&mut self, payload).await, + "/raft/vote" => api::vote(&mut self, payload).await, + + // Management API + "/mng/add-learner" => api::add_learner(&mut self, payload).await, + "/mng/change-membership" => api::change_membership(&mut self, payload).await, + "/mng/init" => api::init(&mut self).await, + "/mng/metrics" => api::metrics(&mut self).await, + + _ => panic!("unknown path: {}", path), + }; + + let _ = response_tx.send(res); + } + } +} diff --git a/examples/multi-raft-sharding/src/lib.rs b/examples/multi-raft-sharding/src/lib.rs new file mode 100644 index 000000000..9b7dfdfc1 --- /dev/null +++ b/examples/multi-raft-sharding/src/lib.rs @@ -0,0 +1,88 @@ +use std::sync::Arc; + +use openraft::Config; + +use crate::app::App; +use crate::router::Router; +use crate::store::Request; +use crate::store::Response; +use crate::store::StateMachineData; + +pub mod router; +pub mod shard_router; + +pub mod api; +pub mod app; +pub mod network; +pub mod store; + +/// Node ID type - identifies a physical node in the cluster. +pub type NodeId = u64; + +/// Shard ID type - identifies a Raft group (shard). +/// +/// In a sharded system, each shard is responsible for a range of keys. +/// The shard ID is used to route requests to the correct Raft group. +pub type ShardId = String; + +openraft::declare_raft_types!( + /// Type configuration for the sharded KV store. + /// + /// This configuration uses: + /// - `Request`: Application requests including Set, Delete, and Split operations + /// - `Response`: Application responses including split completion data + /// - `StateMachineData`: The state machine snapshot format + pub TypeConfig: + D = Request, + R = Response, + SnapshotData = StateMachineData, +); + +pub type LogStore = store::LogStore; +pub type StateMachineStore = store::StateMachineStore; + +/// Define all Raft-related type aliases. +#[path = "../../utils/declare_types.rs"] +pub mod typ; + +/// Shard naming constants. +pub mod shards { + /// The initial shard that contains all data. + pub const SHARD_A: &str = "shard_a"; + + /// The shard created after split (contains user_id > split_point). + pub const SHARD_B: &str = "shard_b"; +} + +pub fn encode(t: T) -> String { + serde_json::to_string(&t).unwrap() +} + +pub fn decode(s: &str) -> T { + serde_json::from_str(s).unwrap() +} + +/// Create a new Raft instance for a specific shard on a node. +pub async fn new_raft(node_id: NodeId, shard_id: ShardId, router: Router) -> (typ::Raft, App) { + let config = Config { + heartbeat_interval: 500, + election_timeout_min: 1500, + election_timeout_max: 3000, + max_in_snapshot_log_to_keep: 0, + ..Default::default() + }; + + let config = Arc::new(config.validate().unwrap()); + + let log_store = LogStore::default(); + + let state_machine_store = Arc::new(StateMachineStore::default()); + + let network = network::ShardNetworkFactory::new(router.clone(), shard_id.clone()); + + let raft = openraft::Raft::new(node_id, config, network, log_store, state_machine_store.clone()).await.unwrap(); + + let app = App::new(node_id, shard_id, raft.clone(), router, state_machine_store); + + (raft, app) +} diff --git a/examples/multi-raft-sharding/src/network.rs b/examples/multi-raft-sharding/src/network.rs new file mode 100644 index 000000000..231b2882e --- /dev/null +++ b/examples/multi-raft-sharding/src/network.rs @@ -0,0 +1,81 @@ +use std::future::Future; + +use openraft::error::RPCError; +use openraft::error::ReplicationClosed; +use openraft::error::StreamingError; +use openraft::network::Backoff; +use openraft::network::RPCOption; +use openraft::network::RaftNetworkFactory; +use openraft::raft::AppendEntriesRequest; +use openraft::raft::AppendEntriesResponse; +use openraft::raft::SnapshotResponse; +use openraft::raft::VoteRequest; +use openraft::raft::VoteResponse; +use openraft::storage::Snapshot; +use openraft::OptionalSend; +use openraft_multi::GroupNetworkAdapter; +use openraft_multi::GroupNetworkFactory; +use openraft_multi::GroupRouter; + +use crate::router::Router; +use crate::typ; +use crate::NodeId; +use crate::ShardId; +use crate::TypeConfig; + +impl GroupRouter for Router { + async fn send_append_entries( + &self, + target: NodeId, + shard_id: ShardId, + rpc: AppendEntriesRequest, + _option: RPCOption, + ) -> Result, RPCError> { + self.send(target, &shard_id, "/raft/append", rpc).await.map_err(RPCError::Unreachable) + } + + async fn send_vote( + &self, + target: NodeId, + shard_id: ShardId, + rpc: VoteRequest, + _option: RPCOption, + ) -> Result, RPCError> { + self.send(target, &shard_id, "/raft/vote", rpc).await.map_err(RPCError::Unreachable) + } + + async fn send_snapshot( + &self, + target: NodeId, + shard_id: ShardId, + vote: typ::Vote, + snapshot: Snapshot, + _cancel: impl Future + OptionalSend + 'static, + _option: RPCOption, + ) -> Result, StreamingError> { + self.send( + target, + &shard_id, + "/raft/snapshot", + (vote, snapshot.meta, snapshot.snapshot), + ) + .await + .map_err(StreamingError::Unreachable) + } + + fn backoff(&self) -> Backoff { + Backoff::new(std::iter::repeat(std::time::Duration::from_millis(500))) + } +} + +/// Shard network factory that creates `GroupNetworkAdapter` instances. +pub type ShardNetworkFactory = GroupNetworkFactory; + +impl RaftNetworkFactory for ShardNetworkFactory { + /// The network type is `GroupNetworkAdapter` binding (Router, target, shard_id). + type Network = GroupNetworkAdapter; + + async fn new_client(&mut self, target: NodeId, _node: &openraft::BasicNode) -> Self::Network { + GroupNetworkAdapter::new(self.factory.clone(), target, self.group_id.clone()) + } +} diff --git a/examples/multi-raft-sharding/src/router.rs b/examples/multi-raft-sharding/src/router.rs new file mode 100644 index 000000000..691296b6f --- /dev/null +++ b/examples/multi-raft-sharding/src/router.rs @@ -0,0 +1,143 @@ +use std::collections::BTreeMap; +use std::fmt; +use std::sync::Arc; +use std::sync::Mutex; + +use openraft::error::Unreachable; +use tokio::sync::oneshot; + +use crate::app::RequestTx; +use crate::decode; +use crate::encode; +use crate::typ::RaftError; +use crate::NodeId; +use crate::ShardId; + +#[derive(Debug)] +pub struct RouterError(pub String); + +impl fmt::Display for RouterError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.0) + } +} + +impl std::error::Error for RouterError {} + +/// Unique identifier for a Raft instance (node + shard). +/// +/// In Multi-Raft, a single physical node can host multiple Raft instances, +/// one for each shard. This key uniquely identifies each instance. +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub struct RouterKey { + /// The physical node ID. + pub node_id: NodeId, + /// The shard (Raft group) ID. + pub shard_id: ShardId, +} + +impl RouterKey { + pub fn new(node_id: NodeId, shard_id: ShardId) -> Self { + Self { node_id, shard_id } + } +} + +/// Message router for Multi-Raft communication. +/// +/// This router manages the mapping from (node_id, shard_id) to message channels, +/// allowing Raft RPCs to be delivered to the correct shard on each node. +#[derive(Debug, Clone, Default)] +pub struct Router { + pub targets: Arc>>, +} + +impl Router { + /// Create a new router. + pub fn new() -> Self { + Self { + targets: Arc::new(Mutex::new(BTreeMap::new())), + } + } + + /// Register a message handler for a specific (node, shard) pair. + pub fn register(&self, node_id: NodeId, shard_id: ShardId, tx: RequestTx) { + let key = RouterKey::new(node_id, shard_id); + let mut targets = self.targets.lock().unwrap(); + targets.insert(key, tx); + } + + /// Unregister a message handler. + pub fn unregister(&self, node_id: NodeId, shard_id: &ShardId) -> Option { + let key = RouterKey::new(node_id, shard_id.clone()); + let mut targets = self.targets.lock().unwrap(); + targets.remove(&key) + } + + /// Send a request to a specific (node, shard) pair and wait for response. + pub async fn send( + &self, + to_node: NodeId, + to_shard: &ShardId, + path: &str, + req: Req, + ) -> Result + where + Req: serde::Serialize, + Result: serde::de::DeserializeOwned, + { + let (resp_tx, resp_rx) = oneshot::channel(); + + let encoded_req = encode(&req); + tracing::trace!( + to_node = %to_node, + to_shard = %to_shard, + path = %path, + "sending request" + ); + + { + let key = RouterKey::new(to_node, to_shard.clone()); + let targets = self.targets.lock().unwrap(); + let tx = targets.get(&key).ok_or_else(|| { + Unreachable::new(&RouterError(format!( + "target not found: node={}, shard={}", + to_node, to_shard + ))) + })?; + + tx.send((path.to_string(), encoded_req, resp_tx)) + .map_err(|e| Unreachable::new(&RouterError(e.to_string())))?; + } + + let resp_str = resp_rx.await.map_err(|e| Unreachable::new(&RouterError(e.to_string())))?; + + tracing::trace!( + from_node = %to_node, + from_shard = %to_shard, + path = %path, + "received response" + ); + + let res = decode::>(&resp_str); + res.map_err(|e| Unreachable::new(&RouterError(e.to_string()))) + } + + /// Check if a target is registered. + pub fn has_target(&self, node_id: NodeId, shard_id: &ShardId) -> bool { + let key = RouterKey::new(node_id, shard_id.clone()); + let targets = self.targets.lock().unwrap(); + targets.contains_key(&key) + } + + /// Get all registered targets. + pub fn all_targets(&self) -> Vec { + let targets = self.targets.lock().unwrap(); + targets.keys().cloned().collect() + } + + /// Get the number of registered targets. + pub fn target_count(&self) -> usize { + let targets = self.targets.lock().unwrap(); + targets.len() + } +} diff --git a/examples/multi-raft-sharding/src/shard_router.rs b/examples/multi-raft-sharding/src/shard_router.rs new file mode 100644 index 000000000..9a56ec11a --- /dev/null +++ b/examples/multi-raft-sharding/src/shard_router.rs @@ -0,0 +1,229 @@ +//! Shard Router - Routes requests to the correct shard based on key range. +//! +//! In a sharded system, each shard is responsible for a range of keys. The shard router +//! maintains the mapping from key ranges to shard IDs and routes requests accordingly. +//! +//! ## How It Works +//! +//! The router maintains a list of shards, each with: +//! - A shard ID +//! - A key range (start, end) that this shard is responsible for +//! +//! When a request comes in, the router: +//! 1. Extracts the user_id from the key +//! 2. Finds the shard whose range contains this user_id +//! 3. Returns the shard ID for routing +//! +//! ## Split Updates +//! +//! After a split operation, the router must be updated: +//! - The original shard's range is shrunk (end = split_point) +//! - A new shard entry is added (start = split_point + 1, end = old_end) +//! +//! ```text +//! Before Split: +//! ┌─────────────────────────────────────┐ +//! │ shard_a: [1, MAX] │ +//! └─────────────────────────────────────┘ +//! +//! After Split at 100: +//! ┌─────────────────────────────────────┐ +//! │ shard_a: [1, 100] │ +//! │ shard_b: [101, MAX] │ +//! └─────────────────────────────────────┘ +//! ``` + +use std::sync::RwLock; + +use crate::ShardId; + +/// A key range that a shard is responsible for. +/// +/// The range is inclusive on both ends: [start, end]. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct KeyRange { + /// The start of the range (inclusive). + pub start: u64, + /// The end of the range (inclusive). + pub end: u64, +} + +impl KeyRange { + /// Create a new key range. + pub fn new(start: u64, end: u64) -> Self { + Self { start, end } + } + + /// Check if the range contains the given key. + pub fn contains(&self, key: u64) -> bool { + key >= self.start && key <= self.end + } +} + +/// A shard entry in the router. +#[derive(Debug, Clone)] +pub struct ShardEntry { + /// The shard ID. + pub shard_id: ShardId, + /// The key range this shard is responsible for. + pub range: KeyRange, +} + +/// Routes requests to the correct shard based on key range. +/// +/// This is the central routing component in a sharded system. It maintains +/// the mapping from key ranges to shards and must be kept in sync with +/// the actual shard topology. +/// +/// ## Thread Safety +/// +/// The router uses interior mutability (RwLock) to allow concurrent reads +/// and exclusive writes. This is important because: +/// - Reads (routing) are very frequent and should be fast +/// - Writes (split updates) are rare but must be atomic +#[derive(Debug, Default)] +pub struct ShardRouter { + /// The list of shards and their key ranges. + shards: RwLock>, +} + +impl ShardRouter { + /// Create a new empty shard router. + pub fn new() -> Self { + Self { + shards: RwLock::new(Vec::new()), + } + } + + /// Create a shard router with an initial shard covering all keys. + /// + /// This is typically used when starting a new cluster with a single shard. + pub fn with_initial_shard(shard_id: ShardId) -> Self { + let router = Self::new(); + router.add_shard(shard_id, KeyRange::new(1, u64::MAX)); + router + } + + /// Add a new shard with the given key range. + pub fn add_shard(&self, shard_id: ShardId, range: KeyRange) { + let mut shards = self.shards.write().unwrap(); + shards.push(ShardEntry { shard_id, range }); + } + + /// Route a key to the appropriate shard. + pub fn route(&self, user_id: u64) -> Option { + let shards = self.shards.read().unwrap(); + for entry in shards.iter() { + if entry.range.contains(user_id) { + return Some(entry.shard_id.clone()); + } + } + None + } + + /// Update the router after a split operation. + /// + /// This method: + /// 1. Shrinks the original shard's range to [start, split_at] + /// 2. Adds a new shard entry for [split_at + 1, old_end] + /// + /// # Arguments + /// * `original_shard` - The ID of the shard being split + /// * `split_at` - The split point (keys > split_at go to new shard) + /// * `new_shard_id` - The ID of the new shard + /// + /// # Returns + /// `true` if the split was applied, `false` if the original shard wasn't found. + pub fn apply_split(&self, original_shard: &str, split_at: u64, new_shard_id: ShardId) -> bool { + let mut shards = self.shards.write().unwrap(); + + // Find the original shard + let original_idx = shards.iter().position(|e| e.shard_id == original_shard); + + if let Some(idx) = original_idx { + let original_end = shards[idx].range.end; + + // Shrink the original shard's range + shards[idx].range.end = split_at; + + // Add the new shard + shards.push(ShardEntry { + shard_id: new_shard_id, + range: KeyRange::new(split_at + 1, original_end), + }); + + true + } else { + false + } + } + + /// Get all shard entries. + pub fn all_shards(&self) -> Vec { + let shards = self.shards.read().unwrap(); + shards.clone() + } + + /// Get the key range for a specific shard. + pub fn get_range(&self, shard_id: &str) -> Option { + let shards = self.shards.read().unwrap(); + shards.iter().find(|e| e.shard_id == shard_id).map(|e| e.range.clone()) + } +} + +// ============================================================================= +// Tests +// ============================================================================= + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_key_range_contains() { + let range = KeyRange::new(1, 100); + assert!(range.contains(1)); + assert!(range.contains(50)); + assert!(range.contains(100)); + assert!(!range.contains(0)); + assert!(!range.contains(101)); + } + + #[test] + fn test_initial_shard() { + let router = ShardRouter::with_initial_shard("shard_a".to_string()); + + assert_eq!(router.route(1), Some("shard_a".to_string())); + assert_eq!(router.route(1000), Some("shard_a".to_string())); + assert_eq!(router.route(u64::MAX), Some("shard_a".to_string())); + } + + #[test] + fn test_apply_split() { + let router = ShardRouter::with_initial_shard("shard_a".to_string()); + + // Split at 100 + assert!(router.apply_split("shard_a", 100, "shard_b".to_string())); + + // Verify routing + assert_eq!(router.route(1), Some("shard_a".to_string())); + assert_eq!(router.route(100), Some("shard_a".to_string())); + assert_eq!(router.route(101), Some("shard_b".to_string())); + assert_eq!(router.route(1000), Some("shard_b".to_string())); + + // Verify ranges + let range_a = router.get_range("shard_a").unwrap(); + assert_eq!(range_a.start, 1); + assert_eq!(range_a.end, 100); + + let range_b = router.get_range("shard_b").unwrap(); + assert_eq!(range_b.start, 101); + assert_eq!(range_b.end, u64::MAX); + } + + #[test] + fn test_split_nonexistent_shard() { + let router = ShardRouter::new(); + assert!(!router.apply_split("nonexistent", 100, "new_shard".to_string())); + } +} diff --git a/examples/multi-raft-sharding/src/store.rs b/examples/multi-raft-sharding/src/store.rs new file mode 100644 index 000000000..0004d70cf --- /dev/null +++ b/examples/multi-raft-sharding/src/store.rs @@ -0,0 +1,512 @@ +//! State machine storage with TiKV-style Split support. +//! +//! This module implements a state machine that supports: +//! - Normal key-value operations (Set, Delete) +//! - **Split operation**: Atomically splits a shard into two shards +//! +//! ## Split Mechanism +//! +//! The Split operation is implemented as a special Raft log entry. When the Split log +//! is applied to the state machine: +//! +//! 1. All keys > split_point are extracted from the current state machine +//! 2. These keys are packaged as the "initial state" for the new shard +//! 3. The extracted keys are removed from the current state machine +//! 4. The response contains the initial state for bootstrapping the new shard +//! +//! ```text +//! Before Split (shard_a): +//! ┌─────────────────────────────────────┐ +//! │ user:1=Alice user:50=Bob │ +//! │ user:101=Charlie user:150=Diana │ +//! └─────────────────────────────────────┘ +//! +//! Split at user_id=100 +//! +//! After Split: +//! ┌─────────────────────────────────────┐ +//! │ shard_a: user:1=Alice user:50=Bob │ (keys <= 100) +//! └─────────────────────────────────────┘ +//! ┌─────────────────────────────────────┐ +//! │ shard_b: user:101=Charlie │ (keys > 100, new shard) +//! │ user:150=Diana │ +//! └─────────────────────────────────────┘ +//! ``` +//! +//! ## Why This Design? +//! +//! By making Split a Raft log entry: +//! - **Atomicity**: All replicas execute split at the same log index +//! - **No Locks**: No distributed coordination needed during split +//! - **Consistency**: Split boundary is deterministic and identical on all replicas + +use std::collections::BTreeMap; +use std::fmt; +use std::fmt::Debug; +use std::io; +use std::sync::Arc; +use std::sync::Mutex; + +use futures::Stream; +use futures::TryStreamExt; +use openraft::storage::EntryResponder; +use openraft::storage::RaftStateMachine; +use openraft::EntryPayload; +use openraft::OptionalSend; +use openraft::RaftSnapshotBuilder; +use serde::Deserialize; +use serde::Serialize; + +use crate::typ::*; +use crate::ShardId; +use crate::TypeConfig; + +pub type LogStore = mem_log::LogStore; + +// ============================================================================= +// Request Types +// ============================================================================= + +/// These requests are proposed to Raft and applied to the state machine when committed. +/// The key design here is that `Split` is just another request type, making it atomic. +#[derive(Serialize, Deserialize, Debug, Clone)] +pub enum Request { + Set { + key: String, + value: String, + }, + + Delete { + key: String, + }, + + /// Split the shard at the given user_id boundary. + /// + /// This is the TiKV-style split operation: + /// - All data with user_id > split_at will be moved to the new shard + /// - The split is atomic and consistent across all replicas + /// - After split, the new shard can be started on any node + /// + /// # Arguments + /// * `split_at` - The split boundary (exclusive). Keys with user_id > split_at go to new shard. + /// * `new_shard_id` - The ID for the new shard being created. + /// + /// # Example + /// ```ignore + /// // Split shard_a at user_id=100, creating shard_b for user_id > 100 + /// Request::Split { + /// split_at: 100, + /// new_shard_id: "shard_b".to_string(), + /// } + /// ``` + Split { + /// The split point. Keys with user_id > split_at go to the new shard. + split_at: u64, + /// The ID for the new shard. + new_shard_id: ShardId, + }, +} + +impl fmt::Display for Request { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Request::Set { key, value } => write!(f, "Set({} = {})", key, value), + Request::Delete { key } => write!(f, "Delete({})", key), + Request::Split { split_at, new_shard_id } => { + write!(f, "Split(at={}, new_shard={})", split_at, new_shard_id) + } + } + } +} + +impl Request { + pub fn set_user(user_id: u64, value: impl ToString) -> Self { + Self::Set { + key: format!("user:{}", user_id), + value: value.to_string(), + } + } + + /// Create a Split request. + /// + /// # Arguments + /// * `split_at` - User IDs greater than this value go to the new shard + /// * `new_shard_id` - The ID for the new shard + pub fn split(split_at: u64, new_shard_id: impl ToString) -> Self { + Self::Split { + split_at, + new_shard_id: new_shard_id.to_string(), + } + } +} + +// ============================================================================= +// Response Types +// ============================================================================= + +#[derive(Serialize, Deserialize, Debug, Clone)] +pub enum Response { + Ok { + previous_value: Option, + }, + + /// Response for Split operation. + /// Contains all the data needed to bootstrap the new shard. + SplitComplete { + /// The ID of the new shard + new_shard_id: ShardId, + /// The data that was split off (to be used as initial state for new shard) + split_data: BTreeMap, + /// Number of keys that were split off + key_count: usize, + }, +} + +// ============================================================================= +// State Machine Data +// ============================================================================= + +#[derive(Serialize, Deserialize, Debug, Default, Clone)] +pub struct StateMachineData { + pub last_applied: Option, + + pub last_membership: StoredMembership, + + /// The key-value data store. + /// + /// Keys are in the format "user:{user_id}". + /// This is a BTreeMap for efficient range operations during split. + pub data: BTreeMap, +} + +impl StateMachineData { + /// Extract the user_id from a key string. + /// Returns None if the key doesn't match the expected format "user:{id}". + pub fn parse_user_id(key: &str) -> Option { + key.strip_prefix("user:").and_then(|s| s.parse().ok()) + } + + /// Get all keys with user_id greater than the given value. + /// This is used during split to extract data for the new shard. + pub fn keys_greater_than(&self, user_id: u64) -> BTreeMap { + self.data + .iter() + .filter(|(k, _)| Self::parse_user_id(k).map(|id| id > user_id).unwrap_or(false)) + .map(|(k, v)| (k.clone(), v.clone())) + .collect() + } + + /// Remove all keys with user_id greater than the given value. + /// This is used during split to clean up the original shard after data extraction. + pub fn remove_keys_greater_than(&mut self, user_id: u64) { + self.data.retain(|k, _| { + Self::parse_user_id(k).map(|id| id <= user_id).unwrap_or(true) // Keep non-user keys + }); + } +} + +// ============================================================================= +// Stored Snapshot +// ============================================================================= + +#[derive(Debug)] +pub struct StoredSnapshot { + pub meta: SnapshotMeta, + pub data: SnapshotData, +} + +// ============================================================================= +// State Machine Store +// ============================================================================= + +#[derive(Debug, Default)] +pub struct StateMachineStore { + pub state_machine: tokio::sync::Mutex, + + /// Counter for generating unique snapshot IDs. + snapshot_idx: Mutex, + + /// The current snapshot (if any). + current_snapshot: Mutex>, +} + +impl StateMachineStore { + pub fn with_initial_data(data: BTreeMap) -> Self { + Self { + state_machine: tokio::sync::Mutex::new(StateMachineData { + last_applied: None, + last_membership: StoredMembership::default(), + data, + }), + snapshot_idx: Mutex::new(0), + current_snapshot: Mutex::new(None), + } + } +} + +// ============================================================================= +// Snapshot Builder Implementation +// ============================================================================= + +impl RaftSnapshotBuilder for Arc { + #[tracing::instrument(level = "trace", skip(self))] + async fn build_snapshot(&mut self) -> Result { + let data; + let last_applied_log; + let last_membership; + + { + let state_machine = self.state_machine.lock().await.clone(); + last_applied_log = state_machine.last_applied; + last_membership = state_machine.last_membership.clone(); + data = state_machine; + } + + let snapshot_idx = { + let mut l = self.snapshot_idx.lock().unwrap(); + *l += 1; + *l + }; + + let snapshot_id = if let Some(last) = last_applied_log { + format!("{}-{}-{}", last.committed_leader_id(), last.index(), snapshot_idx) + } else { + format!("--{}", snapshot_idx) + }; + + let meta = SnapshotMeta { + last_log_id: last_applied_log, + last_membership, + snapshot_id, + }; + + let snapshot = StoredSnapshot { + meta: meta.clone(), + data: data.clone(), + }; + + { + let mut current_snapshot = self.current_snapshot.lock().unwrap(); + *current_snapshot = Some(snapshot); + } + + Ok(Snapshot { meta, snapshot: data }) + } +} + +// ============================================================================= +// State Machine Implementation +// ============================================================================= + +impl RaftStateMachine for Arc { + type SnapshotBuilder = Self; + + async fn applied_state(&mut self) -> Result<(Option, StoredMembership), io::Error> { + let state_machine = self.state_machine.lock().await; + Ok((state_machine.last_applied, state_machine.last_membership.clone())) + } + + /// Apply committed log entries to the state machine. + /// + /// This is where the magic happens for Split operations: + /// - When a Split log is applied, we atomically extract data for the new shard + /// - The response contains the extracted data for bootstrapping the new shard + #[tracing::instrument(level = "trace", skip(self, entries))] + async fn apply(&mut self, mut entries: Strm) -> Result<(), io::Error> + where Strm: Stream, io::Error>> + Unpin + OptionalSend { + let mut sm = self.state_machine.lock().await; + + while let Some((entry, responder)) = entries.try_next().await? { + tracing::debug!(%entry.log_id, "applying entry to state machine"); + + sm.last_applied = Some(entry.log_id); + + let response = match entry.payload { + EntryPayload::Blank => Response::Ok { previous_value: None }, + + EntryPayload::Normal(ref req) => { + match req { + Request::Set { key, value } => { + // Normal set operation + let previous = sm.data.insert(key.clone(), value.clone()); + tracing::debug!(key = %key, value = %value, "set key"); + Response::Ok { + previous_value: previous, + } + } + + Request::Delete { key } => { + // Normal delete operation + let previous = sm.data.remove(key); + tracing::debug!(key = %key, "delete key"); + Response::Ok { + previous_value: previous, + } + } + + Request::Split { split_at, new_shard_id } => { + // ============================================================ + // SPLIT OPERATION - The Key Feature of This Example + // ============================================================ + // + // This is where the TiKV-style split happens: + // 1. Extract all keys with user_id > split_at + // 2. Remove those keys from the current state machine + // 3. Return the extracted data in the response + // + // The response will be used to bootstrap the new shard. + + tracing::info!( + split_at = %split_at, + new_shard_id = %new_shard_id, + "executing split operation" + ); + + // Step 1: Extract data for the new shard + let split_data = sm.keys_greater_than(*split_at); + let key_count = split_data.len(); + + tracing::info!( + key_count = %key_count, + "extracted {} keys for new shard", + key_count + ); + + // Step 2: Remove extracted keys from current shard + sm.remove_keys_greater_than(*split_at); + + tracing::info!( + remaining_keys = %sm.data.len(), + "current shard now has {} keys", + sm.data.len() + ); + + // Step 3: Return the split data in the response + Response::SplitComplete { + new_shard_id: new_shard_id.clone(), + split_data, + key_count, + } + } + } + } + + EntryPayload::Membership(ref mem) => { + sm.last_membership = StoredMembership::new(Some(entry.log_id), mem.clone()); + Response::Ok { previous_value: None } + } + }; + + if let Some(responder) = responder { + responder.send(response); + } + } + Ok(()) + } + + #[tracing::instrument(level = "trace", skip(self))] + async fn begin_receiving_snapshot(&mut self) -> Result { + Ok(Default::default()) + } + + #[tracing::instrument(level = "trace", skip(self, snapshot))] + async fn install_snapshot(&mut self, meta: &SnapshotMeta, snapshot: SnapshotData) -> Result<(), io::Error> { + tracing::info!( + snapshot_id = %meta.snapshot_id, + keys = %snapshot.data.len(), + "installing snapshot" + ); + + let new_snapshot = StoredSnapshot { + meta: meta.clone(), + data: snapshot, + }; + + // Update the state machine + { + let updated_state_machine: StateMachineData = new_snapshot.data.clone(); + let mut state_machine = self.state_machine.lock().await; + *state_machine = updated_state_machine; + } + + // Update current snapshot + let mut current_snapshot = self.current_snapshot.lock().unwrap(); + *current_snapshot = Some(new_snapshot); + Ok(()) + } + + #[tracing::instrument(level = "trace", skip(self))] + async fn get_current_snapshot(&mut self) -> Result, io::Error> { + match &*self.current_snapshot.lock().unwrap() { + Some(snapshot) => { + let data = snapshot.data.clone(); + Ok(Some(Snapshot { + meta: snapshot.meta.clone(), + snapshot: data, + })) + } + None => Ok(None), + } + } + + async fn get_snapshot_builder(&mut self) -> Self::SnapshotBuilder { + self.clone() + } +} + +// ============================================================================= +// Tests +// ============================================================================= + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_parse_user_id() { + assert_eq!(StateMachineData::parse_user_id("user:123"), Some(123)); + assert_eq!(StateMachineData::parse_user_id("user:0"), Some(0)); + assert_eq!(StateMachineData::parse_user_id("other:123"), None); + assert_eq!(StateMachineData::parse_user_id("user:abc"), None); + } + + #[test] + fn test_keys_greater_than() { + let mut data = BTreeMap::new(); + data.insert("user:50".to_string(), "Alice".to_string()); + data.insert("user:100".to_string(), "Bob".to_string()); + data.insert("user:101".to_string(), "Charlie".to_string()); + data.insert("user:200".to_string(), "Diana".to_string()); + + let sm = StateMachineData { + data, + ..Default::default() + }; + + let split_data = sm.keys_greater_than(100); + assert_eq!(split_data.len(), 2); + assert!(split_data.contains_key("user:101")); + assert!(split_data.contains_key("user:200")); + assert!(!split_data.contains_key("user:100")); + } + + #[test] + fn test_remove_keys_greater_than() { + let mut data = BTreeMap::new(); + data.insert("user:50".to_string(), "Alice".to_string()); + data.insert("user:100".to_string(), "Bob".to_string()); + data.insert("user:101".to_string(), "Charlie".to_string()); + data.insert("user:200".to_string(), "Diana".to_string()); + + let mut sm = StateMachineData { + data, + ..Default::default() + }; + sm.remove_keys_greater_than(100); + + assert_eq!(sm.data.len(), 2); + assert!(sm.data.contains_key("user:50")); + assert!(sm.data.contains_key("user:100")); + assert!(!sm.data.contains_key("user:101")); + } +} diff --git a/examples/multi-raft-sharding/tests/cluster/main.rs b/examples/multi-raft-sharding/tests/cluster/main.rs new file mode 100644 index 000000000..30e08a274 --- /dev/null +++ b/examples/multi-raft-sharding/tests/cluster/main.rs @@ -0,0 +1 @@ +mod test_split; diff --git a/examples/multi-raft-sharding/tests/cluster/test_split.rs b/examples/multi-raft-sharding/tests/cluster/test_split.rs new file mode 100644 index 000000000..1a867f58d --- /dev/null +++ b/examples/multi-raft-sharding/tests/cluster/test_split.rs @@ -0,0 +1,570 @@ +//! TiKV-style Shard Split Test +//! +//! This test demonstrates the complete lifecycle of a shard split operation: +//! +//! ## Phase 1: Initial State +//! - 3 nodes (Node 1, 2, 3) +//! - 1 shard (shard_a) containing all user data (user:1 to user:200) +//! +//! ## Phase 2: Split +//! - Split shard_a at user_id=100 +//! - shard_a now contains user:1 to user:100 +//! - shard_b is created with user:101 to user:200 +//! - Both shards run on Node 1, 2, 3 +//! +//! ## Phase 3: Add New Nodes +//! - Add Node 4, 5 to the cluster +//! - Add Node 4, 5 as learners to shard_b +//! - Wait for replication to complete +//! +//! ## Phase 4: Migrate shard_b +//! - Change shard_b membership to only Node 4, 5 +//! - Shut down shard_b on Node 1, 2, 3 +//! - Final state: shard_a on Node 1,2,3; shard_b on Node 4,5 +//! +//! ## Key Insight: Split as Raft Log +//! +//! The split operation is proposed as a normal Raft log entry. When applied: +//! 1. Each replica extracts data with user_id > 100 +//! 2. Each replica removes that data from its state machine +//! 3. The extracted data is used to bootstrap the new shard +//! +//! This ensures atomic, consistent split across all replicas without +//! distributed locks or two-phase commit. + +use std::backtrace::Backtrace; +use std::collections::BTreeMap; +use std::collections::BTreeSet; +use std::panic::PanicHookInfo; +use std::sync::Arc; +use std::time::Duration; + +use multi_raft_sharding::new_raft; +use multi_raft_sharding::router::Router; +use multi_raft_sharding::shard_router::ShardRouter; +use multi_raft_sharding::shards; +use multi_raft_sharding::store::Request; +use multi_raft_sharding::store::Response; +use multi_raft_sharding::store::StateMachineStore; +use multi_raft_sharding::typ; +use multi_raft_sharding::NodeId; +use multi_raft_sharding::ShardId; +use openraft::BasicNode; +use openraft::Config; +use openraft::Raft; +use tokio::task; +use tokio::task::LocalSet; +use tracing_subscriber::EnvFilter; + +/// Log panic with backtrace for debugging. +pub fn log_panic(panic: &PanicHookInfo) { + let backtrace = format!("{:?}", Backtrace::force_capture()); + eprintln!("{}", panic); + if let Some(location) = panic.location() { + tracing::error!( + message = %panic, + backtrace = %backtrace, + panic.file = location.file(), + panic.line = location.line(), + panic.column = location.column(), + ); + eprintln!("{}:{}:{}", location.file(), location.line(), location.column()); + } else { + tracing::error!(message = %panic, backtrace = %backtrace); + } + eprintln!("{}", backtrace); +} + +#[tokio::test] +async fn test_shard_split() { + std::panic::set_hook(Box::new(|panic| { + log_panic(panic); + })); + + tracing_subscriber::fmt() + .with_target(true) + .with_thread_ids(true) + .with_level(true) + .with_ansi(false) + .with_env_filter(EnvFilter::from_default_env()) + .init(); + + // ========================================================================= + // Setup: Create shared infrastructure + // ========================================================================= + + // The message router simulates network communication between nodes. + let router = Router::new(); + + // The shard router tracks which shard handles which key range. + // Initially, shard_a handles all keys [1, MAX]. + let shard_router = Arc::new(ShardRouter::with_initial_shard(shards::SHARD_A.to_string())); + + let local = LocalSet::new(); + + local + .run_until(async move { + run_split_test(router, shard_router).await; + }) + .await; +} + +/// Run the complete split test through all four phases. +async fn run_split_test(router: Router, shard_router: Arc) { + // ========================================================================= + // Phase 1: Initial State - 3 nodes, 1 shard + // ========================================================================= + println!("┌──────────────────────────────────────────────────────────────────────────┐"); + println!("│ PHASE 1: Initial State │"); + println!("│ │"); + println!("│ Node 1 Node 2 Node 3 │"); + println!("│ ┌─────────┐ ┌─────────┐ ┌─────────┐ │"); + println!("│ │shard_a:★│ │shard_a:F│ │shard_a:F│ │"); + println!("│ │[1..200] │ │[1..200] │ │[1..200] │ │"); + println!("│ └─────────┘ └─────────┘ └─────────┘ │"); + println!("└──────────────────────────────────────────────────────────────────────────┘"); + println!(); + + // Create shard_a on nodes 1, 2, 3 + let (raft_1a, app_1a) = new_raft(1, shards::SHARD_A.to_string(), router.clone()).await; + let (_raft_2a, app_2a) = new_raft(2, shards::SHARD_A.to_string(), router.clone()).await; + let (_raft_3a, app_3a) = new_raft(3, shards::SHARD_A.to_string(), router.clone()).await; + + // Keep reference to state_machine for verification at the end + let state_machine_1a = app_1a.state_machine.clone(); + + // Spawn app handlers + task::spawn_local(app_1a.run()); + task::spawn_local(app_2a.run()); + task::spawn_local(app_3a.run()); + + // Wait for apps to start + tokio::time::sleep(Duration::from_millis(100)).await; + + // Initialize shard_a with Node 1 as the initial leader + println!(" → Initializing shard_a on Node 1..."); + let mut nodes = BTreeMap::new(); + nodes.insert(1u64, BasicNode { addr: "".to_string() }); + raft_1a.initialize(nodes).await.unwrap(); + + tokio::time::sleep(Duration::from_millis(500)).await; + + // Add Node 2 and 3 as voters + println!(" → Adding Node 2, 3 to shard_a..."); + raft_1a.add_learner(2, BasicNode { addr: "".to_string() }, true).await.unwrap(); + raft_1a.add_learner(3, BasicNode { addr: "".to_string() }, true).await.unwrap(); + + // Promote to voters + let members: BTreeSet = [1, 2, 3].into_iter().collect(); + raft_1a.change_membership(members, false).await.unwrap(); + + tokio::time::sleep(Duration::from_millis(500)).await; + + // Write user data (user:1 to user:200) + println!(" → Writing 200 users (user:1 to user:200)..."); + for user_id in 1..=200 { + let request = Request::set_user(user_id, format!("User_{}", user_id)); + raft_1a.client_write(request).await.unwrap(); + } + + println!(" ✓ Phase 1 complete: shard_a has 200 users on 3 nodes"); + println!(); + + // Verify data + { + let metrics = raft_1a.metrics().borrow().clone(); + println!( + " Metrics: leader={:?}, last_applied={:?}", + metrics.current_leader, + metrics.last_applied.map(|x| x.index) + ); + } + + // ========================================================================= + // Phase 2: Split shard_a at user_id=100 + // ========================================================================= + println!(); + println!("┌──────────────────────────────────────────────────────────────────────────┐"); + println!("│ PHASE 2: Split at user_id=100 │"); + println!("│ │"); + println!("│ Node 1 Node 2 Node 3 │"); + println!("│ ┌─────────┐ ┌─────────┐ ┌─────────┐ │"); + println!("│ │shard_a:★│ │shard_a:F│ │shard_a:F│ │"); + println!("│ │[1..100] │ │[1..100] │ │[1..100] │ │"); + println!("│ ├─────────┤ ├─────────┤ ├─────────┤ │"); + println!("│ │shard_b:★│ │shard_b:F│ │shard_b:F│ │"); + println!("│ │[101..200│ │[101..200│ │[101..200│ │"); + println!("│ └─────────┘ └─────────┘ └─────────┘ │"); + println!("└──────────────────────────────────────────────────────────────────────────┘"); + println!(); + + // Step 1: Propose the split as a Raft log entry + println!(" → Proposing split operation as Raft log entry..."); + println!(" (This is the TiKV-style atomic split)"); + + let split_request = Request::split(100, shards::SHARD_B); + let split_response = raft_1a.client_write(split_request).await.unwrap(); + + // Extract the split data from the response + let split_data = match split_response.response() { + Response::SplitComplete { + new_shard_id, + split_data, + key_count, + } => { + println!(" ✓ Split completed atomically!"); + println!(" - New shard: {}", new_shard_id); + println!(" - Keys migrated: {}", key_count); + split_data.clone() + } + _ => panic!("Expected SplitComplete response"), + }; + + // Update the shard router + shard_router.apply_split(shards::SHARD_A, 100, shards::SHARD_B.to_string()); + println!(" → Updated shard router:"); + println!(" - shard_a: [1..100]"); + println!(" - shard_b: [101..MAX]"); + + // Step 2: Create shard_b on the same nodes using the split data + println!(" → Creating shard_b instances on Node 1, 2, 3..."); + + // Create shard_b with the split data as initial state + let (raft_1b, app_1b) = + create_raft_with_data(1, shards::SHARD_B.to_string(), router.clone(), split_data.clone()).await; + let (_raft_2b, app_2b) = + create_raft_with_data(2, shards::SHARD_B.to_string(), router.clone(), split_data.clone()).await; + let (_raft_3b, app_3b) = + create_raft_with_data(3, shards::SHARD_B.to_string(), router.clone(), split_data.clone()).await; + + // Keep reference to shard_b state_machine (Node 1 is the source of split data) + let _state_machine_1b = app_1b.state_machine.clone(); + + task::spawn_local(app_1b.run()); + task::spawn_local(app_2b.run()); + task::spawn_local(app_3b.run()); + + tokio::time::sleep(Duration::from_millis(100)).await; + + // Initialize shard_b + let mut nodes = BTreeMap::new(); + nodes.insert(1u64, BasicNode { addr: "".to_string() }); + raft_1b.initialize(nodes).await.unwrap(); + + tokio::time::sleep(Duration::from_millis(500)).await; + + // Add other nodes + raft_1b.add_learner(2, BasicNode { addr: "".to_string() }, true).await.unwrap(); + raft_1b.add_learner(3, BasicNode { addr: "".to_string() }, true).await.unwrap(); + + let members: BTreeSet = [1, 2, 3].into_iter().collect(); + raft_1b.change_membership(members, false).await.unwrap(); + + tokio::time::sleep(Duration::from_millis(500)).await; + + // Verify both shards have correct data + println!(" → Verifying data split..."); + + // Trigger snapshot to verify data + raft_1a.trigger().snapshot().await.unwrap(); + tokio::time::sleep(Duration::from_millis(500)).await; + + // Check shard_a has users 1-100 + { + let snapshot = raft_1a.get_snapshot().await.unwrap().unwrap(); + let has_user_50 = snapshot.snapshot.data.contains_key("user:50"); + let has_user_100 = snapshot.snapshot.data.contains_key("user:100"); + let has_user_101 = snapshot.snapshot.data.contains_key("user:101"); + + assert!(has_user_50, "shard_a should have user:50"); + assert!(has_user_100, "shard_a should have user:100"); + assert!(!has_user_101, "shard_a should NOT have user:101 after split"); + + println!(" ✓ shard_a: {} keys (users 1-100)", snapshot.snapshot.data.len()); + } + + // Check shard_b has users 101-200 (check state machine directly as it's newly created) + { + let metrics = raft_1b.metrics().borrow().clone(); + println!( + " ✓ shard_b: initialized with split data, last_applied={:?}", + metrics.last_applied.map(|x| x.index) + ); + } + + println!(" ✓ Phase 2 complete: shard split successful!"); + println!(); + + // ========================================================================= + // Phase 3: Add Node 4, 5 to shard_b + // ========================================================================= + println!("┌──────────────────────────────────────────────────────────────────────────┐"); + println!("│ PHASE 3: Add Node 4, 5 to shard_b │"); + println!("│ │"); + println!("│ Node 1 Node 2 Node 3 Node 4 Node 5 │"); + println!("│ ┌────────┐ ┌────────┐ ┌────────┐ ┌────────┐ ┌────────┐ │"); + println!("│ │shrd_a:★│ │shrd_a:F│ │shrd_a:F│ │ ─ │ │ - │ │"); + println!("│ ├────────┤ ├────────┤ ├────────┤ ├────────┤ ├────────┤ │"); + println!("│ │shrd_b:★│ │shrd_b:F│ │shrd_b:F│ │shrd_b:L│ │shrd_b:L│ │"); + println!("│ └────────┘ └────────┘ └────────┘ └────────┘ └────────┘ │"); + println!("└──────────────────────────────────────────────────────────────────────────┘"); + println!(); + + // Create shard_b on Node 4, 5 + println!(" → Creating shard_b on Node 4, 5..."); + let (raft_4b, app_4b) = new_raft(4, shards::SHARD_B.to_string(), router.clone()).await; + let (raft_5b, app_5b) = new_raft(5, shards::SHARD_B.to_string(), router.clone()).await; + + // Keep reference to state_machine for verification at the end + let state_machine_4b = app_4b.state_machine.clone(); + + task::spawn_local(app_4b.run()); + task::spawn_local(app_5b.run()); + + tokio::time::sleep(Duration::from_millis(100)).await; + + // Before adding learners, trigger a snapshot on the leader. + // This is necessary because the initial data was injected directly into the state machine + // (via create_raft_with_data) without corresponding log entries. Without a snapshot, + // the leader wouldn't know it needs to send the data to new followers. + println!(" → Triggering snapshot on shard_b leader before adding learners..."); + raft_1b.trigger().snapshot().await.unwrap(); + tokio::time::sleep(Duration::from_millis(500)).await; + + // Add as learners + println!(" → Adding Node 4, 5 as learners to shard_b..."); + raft_1b.add_learner(4, BasicNode { addr: "".to_string() }, true).await.unwrap(); + raft_1b.add_learner(5, BasicNode { addr: "".to_string() }, true).await.unwrap(); + + // Wait for replication (snapshot will be sent to learners) + println!(" → Waiting for snapshot replication to Node 4, 5..."); + tokio::time::sleep(Duration::from_millis(2000)).await; + + // Verify replication + { + let metrics_4 = raft_4b.metrics().borrow().clone(); + let metrics_5 = raft_5b.metrics().borrow().clone(); + println!(" Node 4: last_applied={:?}", metrics_4.last_applied.map(|x| x.index)); + println!(" Node 5: last_applied={:?}", metrics_5.last_applied.map(|x| x.index)); + } + + println!(" ✓ Phase 3 complete: Node 4, 5 have shard_b data"); + println!(); + + // ========================================================================= + // Phase 4: Migrate shard_b to Node 4, 5 (remove from Node 1, 2, 3) + // ========================================================================= + println!("┌──────────────────────────────────────────────────────────────────────────┐"); + println!("│ PHASE 4: Migrate shard_b to Node 4, 5 │"); + println!("│ │"); + println!("│ Node 1 Node 2 Node 3 Node 4 Node 5 │"); + println!("│ ┌────────┐ ┌────────┐ ┌────────┐ ┌────────┐ ┌────────┐ │"); + println!("│ │shrd_a:★│ │shrd_a:F│ │shrd_a:F│ │ - │ │ - │ │"); + println!("│ │ - │ │ - │ │ - │ │shrd_b:★│ │shrd_b:F│ │"); + println!("│ └────────┘ └────────┘ └────────┘ └────────┘ └────────┘ │"); + println!("└──────────────────────────────────────────────────────────────────────────┘"); + println!(); + + // First, add Node 4, 5 as voters while keeping 1, 2, 3 + println!(" → Promoting Node 4, 5 to voters..."); + let members: BTreeSet = [1, 2, 3, 4, 5].into_iter().collect(); + raft_1b.change_membership(members, false).await.unwrap(); + + tokio::time::sleep(Duration::from_millis(500)).await; + + // Now remove Node 1, 2, 3 from shard_b + println!(" → Removing Node 1, 2, 3 from shard_b membership..."); + let new_members: BTreeSet = [4, 5].into_iter().collect(); + + // This will transfer leadership to Node 4 or 5 automatically + // because the current leader (Node 1) is being removed + raft_1b.change_membership(new_members, false).await.unwrap(); + + // Wait for leader election to complete on Node 4 or 5 + // This may take up to election_timeout_max (3000ms) + tokio::time::sleep(Duration::from_millis(4000)).await; + + // Verify final state + println!(" → Verifying final state..."); + + // Verify shard_b leadership moved to Node 4 or 5 + { + let metrics_4 = raft_4b.metrics().borrow().clone(); + let metrics_5 = raft_5b.metrics().borrow().clone(); + + let leader = metrics_4.current_leader.or(metrics_5.current_leader); + println!(" shard_b leader: {:?}", leader); + + // After membership change, either Node 4 or 5 should be the leader + // If leader is None, it means election is still in progress + if leader.is_none() { + println!(" (Leader election may still be in progress, checking membership...)"); + println!(" Node 4 membership: {:?}", metrics_4.membership_config); + println!(" Node 5 membership: {:?}", metrics_5.membership_config); + // In a real scenario, we'd wait and retry + // For the test, we just verify the membership change was successful + let voter_ids_4: Vec = metrics_4.membership_config.membership().voter_ids().collect(); + let voter_ids_5: Vec = metrics_5.membership_config.membership().voter_ids().collect(); + assert!( + voter_ids_4 == vec![4u64, 5] || voter_ids_5 == vec![4u64, 5], + "Membership should contain only Node 4 and 5, got: {:?} / {:?}", + voter_ids_4, + voter_ids_5 + ); + } else { + assert!( + leader == Some(4) || leader == Some(5), + "shard_b leader should be Node 4 or 5" + ); + } + } + + // Verify shard_a is still working on Node 1, 2, 3 + { + let metrics_1a = raft_1a.metrics().borrow().clone(); + println!(" shard_a leader: {:?}", metrics_1a.current_leader); + assert_eq!( + metrics_1a.current_leader, + Some(1), + "shard_a leader should still be Node 1" + ); + } + + // Shutdown shard_b on Node 1, 2, 3 + println!(" → Shutting down shard_b on Node 1, 2, 3..."); + // In a real system, we would call raft.shutdown() and unregister from router + // For this test, we just verify the membership change worked + + println!(" ✓ Phase 4 complete: shard_b migrated to Node 4, 5!"); + println!(); + + // ========================================================================= + // Final Verification: Data Integrity After Split and Migration + // ========================================================================= + println!("┌──────────────────────────────────────────────────────────────────────────┐"); + println!("│ FINAL VERIFICATION: Data Integrity │"); + println!("│ │"); + println!("│ shard_a (Node 1): Should have user:1..100, NOT user:101..200 │"); + println!("│ shard_b (Node 4): Should have user:101..200, NOT user:1..100 │"); + println!("└──────────────────────────────────────────────────────────────────────────┘"); + println!(); + + // Verify shard_a data: Should contain user:1..100, should NOT contain user:101..200 + { + let sm_a = state_machine_1a.state_machine.lock().await; + println!(" → Verifying shard_a data (Node 1)..."); + println!(" Total keys in shard_a: {}", sm_a.data.len()); + + // Check that shard_a contains users 1-100 + for user_id in 1..=100 { + let key = format!("user:{}", user_id); + assert!( + sm_a.data.contains_key(&key), + "shard_a should contain {} but it doesn't", + key + ); + } + println!(" ✓ shard_a contains all users 1-100"); + + // Check that shard_a does NOT contain users 101-200 + for user_id in 101..=200 { + let key = format!("user:{}", user_id); + assert!( + !sm_a.data.contains_key(&key), + "shard_a should NOT contain {} but it does", + key + ); + } + println!(" ✓ shard_a does NOT contain any users 101-200"); + + // Verify exact count + assert_eq!( + sm_a.data.len(), + 100, + "shard_a should have exactly 100 keys, but has {}", + sm_a.data.len() + ); + } + + // Verify shard_b data: Should contain user:101..200, should NOT contain user:1..100 + // We verify using Node 4's shard_b state machine, which received data via snapshot replication + // from the leader (Node 1). This proves the migration was successful. + { + let sm_b = state_machine_4b.state_machine.lock().await; + println!(" → Verifying shard_b data (Node 4, received via snapshot)..."); + println!(" Total keys in shard_b: {}", sm_b.data.len()); + + // Check that shard_b contains users 101-200 + for user_id in 101..=200 { + let key = format!("user:{}", user_id); + assert!( + sm_b.data.contains_key(&key), + "shard_b should contain {} but it doesn't", + key + ); + } + println!(" ✓ shard_b contains all users 101-200"); + + // Check that shard_b does NOT contain users 1-100 + for user_id in 1..=100 { + let key = format!("user:{}", user_id); + assert!( + !sm_b.data.contains_key(&key), + "shard_b should NOT contain {} but it does", + key + ); + } + println!(" ✓ shard_b does NOT contain any users 1-100"); + + // Verify exact count + assert_eq!( + sm_b.data.len(), + 100, + "shard_b should have exactly 100 keys, but has {}", + sm_b.data.len() + ); + } + + println!(); + println!("═══════════════════════════════════════════════════════════════════════════"); + println!(" ✓ ALL VERIFICATIONS PASSED!"); + println!(" ✓ Shard split and migration completed successfully with data integrity!"); + println!("═══════════════════════════════════════════════════════════════════════════"); +} + +/// Create a Raft instance with initial data (used for shard_b after split). +/// +/// This function creates a new Raft instance with the state machine pre-populated +/// with data from the split operation. This is how we bootstrap the new shard +/// after a split. +async fn create_raft_with_data( + node_id: NodeId, + shard_id: ShardId, + router: Router, + initial_data: BTreeMap, +) -> (typ::Raft, multi_raft_sharding::app::App) { + use multi_raft_sharding::network::ShardNetworkFactory; + + let config = Config { + heartbeat_interval: 500, + election_timeout_min: 1500, + election_timeout_max: 3000, + max_in_snapshot_log_to_keep: 0, + ..Default::default() + }; + + let config = Arc::new(config.validate().unwrap()); + + let log_store = multi_raft_sharding::LogStore::default(); + + // Create state machine with the split data + let state_machine_store = Arc::new(StateMachineStore::with_initial_data(initial_data)); + + let network = ShardNetworkFactory::new(router.clone(), shard_id.clone()); + + let raft = Raft::new(node_id, config, network, log_store, state_machine_store.clone()).await.unwrap(); + + let app = multi_raft_sharding::app::App::new(node_id, shard_id, raft.clone(), router, state_machine_store); + + (raft, app) +} diff --git a/multiraft/Cargo.toml b/multiraft/Cargo.toml new file mode 100644 index 000000000..f334f594f --- /dev/null +++ b/multiraft/Cargo.toml @@ -0,0 +1,18 @@ +[package] +name = "openraft-multi" +description = "Multi-Raft adapters for connection sharing across Raft groups" +documentation = "https://docs.rs/openraft-multiraft" +readme = "README.md" +version = "0.10.0" +edition = "2021" +authors = ["Databend Authors "] +categories = ["network", "asynchronous", "data-structures"] +homepage = "https://github.com/databendlabs/openraft" +keywords = ["consensus", "raft", "multi-raft"] +license = "MIT OR Apache-2.0" +repository = "https://github.com/databendlabs/openraft" + +[dependencies] +openraft = { path = "../openraft", version = "0.10.0", default-features = false } +anyerror = { version = "0.1" } + diff --git a/multiraft/README.md b/multiraft/README.md new file mode 100644 index 000000000..bf7791456 --- /dev/null +++ b/multiraft/README.md @@ -0,0 +1,41 @@ +# openraft-multi + +Multi-Raft adapters for connection sharing across Raft groups. + +## Components + +- **`GroupRouter`** - Trait for sending RPCs with (target, group) routing +- **`GroupNetworkAdapter`** - Wraps `GroupRouter`, implements `RaftNetworkV2` +- **`GroupNetworkFactory`** - Simple factory + group_id wrapper + +## Usage + +1. Implement `GroupRouter` on your shared router/connection pool +2. Use `GroupNetworkAdapter` to get automatic `RaftNetworkV2` implementation + +```rust +use openraft_multiraft::{GroupRouter, GroupNetworkAdapter, GroupNetworkFactory}; + +// Your Router implements GroupRouter +impl GroupRouter for Router { + // ... +} + +// Create factory for a specific group +let factory = GroupNetworkFactory::new(router, group_id); + +// Factory creates adapters that implement RaftNetworkV2 +impl RaftNetworkFactory for NetworkFactory { + type Network = GroupNetworkAdapter; + + async fn new_client(&mut self, target: NodeId, _node: &Node) -> Self::Network { + GroupNetworkAdapter::new(self.factory.clone(), target, self.group_id.clone()) + } +} +``` + +## Examples + +- [multi-raft-kv](../examples/multi-raft-kv/) - Basic Multi-Raft with 3 groups +- [multi-raft-sharding](../examples/multi-raft-sharding/) - TiKV-style shard split + diff --git a/multiraft/src/lib.rs b/multiraft/src/lib.rs new file mode 100644 index 000000000..b9d198928 --- /dev/null +++ b/multiraft/src/lib.rs @@ -0,0 +1,5 @@ +mod network; + +pub use network::GroupNetworkAdapter; +pub use network::GroupNetworkFactory; +pub use network::GroupRouter; diff --git a/multiraft/src/network.rs b/multiraft/src/network.rs new file mode 100644 index 000000000..d39c31e35 --- /dev/null +++ b/multiraft/src/network.rs @@ -0,0 +1,205 @@ +//! Network adapters for Multi-Raft connection sharing. +//! +//! - [`GroupRouter`] - Trait for sending RPCs with target + group routing +//! - [`GroupNetworkAdapter`] - Wraps `GroupRouter` with (target, group_id), implements +//! `RaftNetworkV2` +//! - [`GroupNetworkFactory`] - Simple wrapper for factory + group_id +//! +//! See `examples/multi-raft-kv` and `examples/multi-raft-sharding` for usage. + +use std::future::Future; +use std::time::Duration; + +use openraft::error::RPCError; +use openraft::error::ReplicationClosed; +use openraft::error::StreamingError; +use openraft::error::Unreachable; +use openraft::network::v2::RaftNetworkV2; +use openraft::network::Backoff; +use openraft::network::RPCOption; +use openraft::raft::AppendEntriesRequest; +use openraft::raft::AppendEntriesResponse; +use openraft::raft::SnapshotResponse; +use openraft::raft::TransferLeaderRequest; +use openraft::raft::VoteRequest; +use openraft::raft::VoteResponse; +use openraft::storage::Snapshot; +use openraft::type_config::alias::VoteOf; +use openraft::OptionalSend; +use openraft::OptionalSync; +use openraft::RaftTypeConfig; + +/// Trait for sending Raft RPCs with target and group routing. +/// +/// Implement this on your shared router/connection pool to enable connection +/// sharing across all Raft groups. The adapter will bind (target, group_id). +pub trait GroupRouter: Clone + OptionalSend + OptionalSync + 'static +where C: RaftTypeConfig +{ + /// Send AppendEntries to target node for a specific group. + fn send_append_entries( + &self, + target: C::NodeId, + group_id: G, + rpc: AppendEntriesRequest, + option: RPCOption, + ) -> impl Future, RPCError>> + OptionalSend; + + /// Send Vote to target node for a specific group. + fn send_vote( + &self, + target: C::NodeId, + group_id: G, + rpc: VoteRequest, + option: RPCOption, + ) -> impl Future, RPCError>> + OptionalSend; + + /// Send snapshot to target node for a specific group. + fn send_snapshot( + &self, + target: C::NodeId, + group_id: G, + vote: VoteOf, + snapshot: Snapshot, + cancel: impl Future + OptionalSend + 'static, + option: RPCOption, + ) -> impl Future, StreamingError>> + OptionalSend; + + /// Send TransferLeader to target node for a specific group. + /// Default: returns "not implemented" error. + fn send_transfer_leader( + &self, + _target: C::NodeId, + _group_id: G, + _req: TransferLeaderRequest, + _option: RPCOption, + ) -> impl Future>> + OptionalSend { + async { + Err(RPCError::Unreachable(Unreachable::new(&anyerror::AnyError::error( + "transfer_leader not implemented", + )))) + } + } + + /// Backoff strategy for retries. Default: 500ms constant. + fn backoff(&self) -> Backoff { + Backoff::new(std::iter::repeat(Duration::from_millis(500))) + } +} + +/// Adapter that binds (target, group_id) to a shared router. +/// +/// This wraps a [`GroupRouter`] implementation (e.g., your Router) and +/// automatically implements `RaftNetworkV2` for a specific (target, group). +pub struct GroupNetworkAdapter +where + C: RaftTypeConfig, + N: GroupRouter, +{ + router: N, + target: C::NodeId, + group_id: G, +} + +impl Clone for GroupNetworkAdapter +where + C: RaftTypeConfig, + G: Clone, + N: GroupRouter, +{ + fn clone(&self) -> Self { + Self { + router: self.router.clone(), + target: self.target.clone(), + group_id: self.group_id.clone(), + } + } +} + +impl GroupNetworkAdapter +where + C: RaftTypeConfig, + N: GroupRouter, +{ + /// Create adapter binding router to specific (target, group). + pub fn new(router: N, target: C::NodeId, group_id: G) -> Self { + Self { + router, + target, + group_id, + } + } + + /// Returns the target node ID. + pub fn target(&self) -> &C::NodeId { + &self.target + } + + /// Returns the group ID. + pub fn group_id(&self) -> &G { + &self.group_id + } +} + +// Implement RaftNetworkV2 for GroupNetworkAdapter. +impl RaftNetworkV2 for GroupNetworkAdapter +where + C: RaftTypeConfig, + G: Clone + OptionalSend + OptionalSync + 'static, + N: GroupRouter, +{ + async fn append_entries( + &mut self, + rpc: AppendEntriesRequest, + option: RPCOption, + ) -> Result, RPCError> { + self.router.send_append_entries(self.target.clone(), self.group_id.clone(), rpc, option).await + } + + async fn vote(&mut self, rpc: VoteRequest, option: RPCOption) -> Result, RPCError> { + self.router.send_vote(self.target.clone(), self.group_id.clone(), rpc, option).await + } + + async fn full_snapshot( + &mut self, + vote: VoteOf, + snapshot: Snapshot, + cancel: impl Future + OptionalSend + 'static, + option: RPCOption, + ) -> Result, StreamingError> { + self.router + .send_snapshot( + self.target.clone(), + self.group_id.clone(), + vote, + snapshot, + cancel, + option, + ) + .await + } + + async fn transfer_leader(&mut self, req: TransferLeaderRequest, option: RPCOption) -> Result<(), RPCError> { + self.router.send_transfer_leader(self.target.clone(), self.group_id.clone(), req, option).await + } + + fn backoff(&self) -> Backoff { + self.router.backoff() + } +} + +/// Simple wrapper for network factory + group_id. +#[derive(Clone)] +pub struct GroupNetworkFactory { + /// The underlying factory (e.g., Router). + pub factory: F, + /// The group ID. + pub group_id: G, +} + +impl GroupNetworkFactory { + /// Create a new factory wrapper. + pub fn new(factory: F, group_id: G) -> Self { + Self { factory, group_id } + } +} diff --git a/openraft/src/error/mod.rs b/openraft/src/error/mod.rs index 93064ef6c..d112e87a1 100644 --- a/openraft/src/error/mod.rs +++ b/openraft/src/error/mod.rs @@ -6,6 +6,7 @@ pub(crate) mod higher_vote; pub mod into_ok; pub(crate) mod into_raft_result; mod invalid_sm; +mod leader_changed; mod linearizable_read_error; mod membership_error; mod node_not_found; @@ -16,8 +17,6 @@ pub(crate) mod storage_error; mod storage_io_result; mod streaming_error; -mod leader_changed; - use std::collections::BTreeSet; use std::error::Error; use std::fmt::Debug; diff --git a/rt-monoio/src/lib.rs b/rt-monoio/src/lib.rs index 3f1fc0101..61d5880a5 100644 --- a/rt-monoio/src/lib.rs +++ b/rt-monoio/src/lib.rs @@ -205,8 +205,6 @@ mod oneshot_mod { /// /// Tokio MPSC channel are runtime independent. mod mpsc_mod { - //! MPSC channel wrapper types and their trait impl. - use std::future::Future; use futures::TryFutureExt;