diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml new file mode 100644 index 0000000..0cda1ed --- /dev/null +++ b/.github/workflows/rust.yml @@ -0,0 +1,52 @@ +name: Rust + +on: + push: + branches: [ "main" ] + pull_request: + branches: [ "main" ] + +env: + CARGO_TERM_COLOR: always + +jobs: + build: + + runs-on: ubuntu-latest + + steps: + - name: Install protoc + run: sudo apt-get update && sudo apt-get install -y protobuf-compiler + - uses: actions/checkout@v4 + - name: Build + run: cargo build --verbose + - name: Start Chat Server + run: | + cargo run --bin server -- --listen 0.0.0.0:4433 > server.log & + SERVER_PID=$! + sleep 5 + if ! ps -p $SERVER_PID > /dev/null; then + echo "Server failed to start." + exit 1 + fi + env: + RUST_LOG: debug + - name: wait server to start + run: sleep 5 + - name: Run tests + run: cargo test --verbose + - name: Test Chat Client with Interactive Input + run: | + cargo run --bin chat-client -- --server-addr 127.0.0.1:4433 --server-name localhost --name test-client + if [ $? -ne 0 ]; then + echo "Client exited with a failure." + exit 1 + fi + env: + RUST_LOG: debug + - name: Check Logs for Specific String + run: | + if ! grep "User test-client joined" server.log; then + echo "String not found in logs. No evidence to client messages to the server. Halting workflow." + exit 1 + fi diff --git a/.hooks/pre-commit b/.hooks/pre-commit new file mode 100755 index 0000000..a5f35c2 --- /dev/null +++ b/.hooks/pre-commit @@ -0,0 +1,11 @@ +#!/bin/bash +set -e + +echo "Running cargo fmt..." +cargo fmt -- --check + +echo "Running cargo check..." +cargo check + +echo "Running cargo clippy..." +cargo clippy -- -D warnings diff --git a/.vscode/launch.json b/.vscode/launch.json new file mode 100644 index 0000000..7442bd3 --- /dev/null +++ b/.vscode/launch.json @@ -0,0 +1,133 @@ +{ + // Use IntelliSense to learn about possible attributes. + // Hover to view descriptions of existing attributes. + // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 + "version": "0.2.0", + "configurations": [ + { + "type": "lldb", + "request": "launch", + "name": "Debug unit tests in library 'server'", + "cargo": { + "args": [ + "test", + "--no-run", + "--lib", + "--package=server" + ], + "filter": { + "name": "server", + "kind": "lib" + } + }, + "args": [], + "cwd": "${workspaceFolder}" + }, + { + "type": "lldb", + "request": "launch", + "name": "Debug unit tests in library 'chat_contract'", + "cargo": { + "args": [ + "test", + "--no-run", + "--lib", + "--package=chat-contract" + ], + "filter": { + "name": "chat_contract", + "kind": "lib" + } + }, + "args": [], + "cwd": "${workspaceFolder}" + }, + { + "type": "lldb", + "request": "launch", + "name": "Debug executable 'server'", + "cargo": { + "args": [ + "build", + "--bin=server", + "--package=chat-server", + ], + "filter": { + "name": "server", + "kind": "bin" + } + }, + "args": [ + "--listen", + "127.0.0.1:4433" + ], + "cwd": "${workspaceFolder}", + "env": { + "RUST_LOG": "debug" // Set the RUST_LOG environment variable + } + }, + { + "type": "lldb", + "request": "launch", + "name": "Debug executable 'client'", + "cargo": { + "args": [ + "build", + "--bin=chat-client", + "--package=chat-client" + ], + "filter": { + "name": "chat-client", + "kind": "bin" + } + }, + "args": [ + "--server-addr", + "127.0.0.1:4433", + "--server-name", + "localhost", + "--name", + "test-client-debugger", + ], + "cwd": "${workspaceFolder}" + }, + { + "type": "lldb", + "request": "launch", + "name": "Debug unit tests in executable 'client-cli'", + "cargo": { + "args": [ + "test", + "--no-run", + "--bin=client-cli", + "--package=client-cli" + ], + "filter": { + "name": "client-cli", + "kind": "bin" + } + }, + "args": [], + "cwd": "${workspaceFolder}" + }, + { + "type": "lldb", + "request": "launch", + "name": "Debug unit tests in library 'common'", + "cargo": { + "args": [ + "test", + "--no-run", + "--lib", + "--package=common" + ], + "filter": { + "name": "common", + "kind": "lib" + } + }, + "args": [], + "cwd": "${workspaceFolder}" + } + ] +} \ No newline at end of file diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 0000000..c189089 --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,7 @@ +[workspace] +resolver = "2" +members = [ + "server", + "client", + "chat-contract" +] diff --git a/README.md b/README.md index 8c4d4e1..abb18e4 100644 --- a/README.md +++ b/README.md @@ -69,3 +69,48 @@ without error, and is free of clippy errors. send a message to the server from the client. Make sure that niether the server or client exit with a failure. This action should be run anytime new code is pushed to a branch or landed on the main branch. + +### Demo +![Demo](assets/example-server-clients.gif) + +## Getting Started + +### 1. Install Git Pre-Commit Hook + +To ensure all code is formatted, compiles correctly, and passes Clippy checks before committing, run the following script once after cloning the repository: + +```bash +bash scripts/install-hooks.sh +``` +### 2. Install Protocol Buffers (protoc) +This project uses Protocol Buffers (Protobuf) for serializing structured data. In order to compile .proto files, you need to have the protoc compiler installed. + +Follow the instructions for your operating system here: +https://protobuf.dev/installation/ + +Make sure protoc is available in your PATH: +```bash +protoc --version +``` +Should output something like: libprotoc 3.21.12 + +Note: Some components of this project might auto-generate code from .proto files during the build process. Ensure protoc is installed and accessible before building or running the project. + +### 3. Run Server +```bash +cargo run --bin server +``` + +### 4. Run Client +```bash +cargo run --bin chat-client -- --server-addr 127.0.0.1:4433 --server-name localhost --name test-client +``` + +### 4. Run Tests +```bash +cargo run --bin server +``` +In other terminal +```bash +cargo test +``` \ No newline at end of file diff --git a/assets/example-server-clients.gif b/assets/example-server-clients.gif new file mode 100644 index 0000000..465c3e8 Binary files /dev/null and b/assets/example-server-clients.gif differ diff --git a/chat-contract/Cargo.toml b/chat-contract/Cargo.toml new file mode 100644 index 0000000..1f314a7 --- /dev/null +++ b/chat-contract/Cargo.toml @@ -0,0 +1,13 @@ +[package] +name = "chat-contract" +version = "0.1.0" +edition = "2021" +authors = ["Ron Zigelman "] + +[dependencies] +prost = "0.12" +prost-types = "0.12" +uuid = { version = "1.16.0", features = ["v4"] } + +[build-dependencies] +prost-build = "0.12" diff --git a/chat-contract/build.rs b/chat-contract/build.rs new file mode 100644 index 0000000..88b6166 --- /dev/null +++ b/chat-contract/build.rs @@ -0,0 +1,5 @@ +use std::io::Result; +fn main() -> Result<()> { + prost_build::compile_protos(&["../proto/chat.proto"], &["../proto/"])?; + Ok(()) +} diff --git a/chat-contract/src/builders.rs b/chat-contract/src/builders.rs new file mode 100644 index 0000000..76dd91c --- /dev/null +++ b/chat-contract/src/builders.rs @@ -0,0 +1,843 @@ +use prost_types::Timestamp; +use uuid::Uuid; + +use crate::{ + chat::{self, ChatMessage, Error, ErrorCode, Header, Join, Leave, MessageType}, + current_timestamp, +}; + +const DEFAULT_ROOM: &str = "main"; + +/// A builder for creating [`Header`] messages with optional defaults. +/// +/// This builder helps construct a `Header` by requiring a `username` and `room`, +/// and optionally setting `message_id` and `timestamp`. +/// +/// You can use `build()` for strict validation (everything must be set), +/// or `build_with_defaults()` to automatically generate a message ID and timestamp. +/// +/// # Example (using default room and generated values) +/// ``` +/// use chat_contract::builders::HeaderBuilder; +/// +/// let header = HeaderBuilder::new() +/// .username("r-zig") +/// .with_default_room() +/// .build_with_defaults() +/// .unwrap(); +/// ``` +pub struct HeaderBuilder { + username: Option, + room: Option, + timestamp: Option, + message_id: Option, +} + +impl HeaderBuilder { + pub fn new() -> Self { + Self { + username: None, + room: None, + timestamp: None, + message_id: None, + } + } + + pub fn username(mut self, username: impl Into) -> Self { + self.username = Some(username.into()); + self + } + + pub fn room(mut self, room: impl Into) -> Self { + self.room = Some(room.into()); + self + } + + pub fn with_default_room(mut self) -> Self { + self.room = Some(DEFAULT_ROOM.to_string()); + self + } + + pub fn timestamp(mut self, ts: Timestamp) -> Self { + self.timestamp = Some(ts); + self + } + + pub fn message_id(mut self, id: impl Into) -> Self { + self.message_id = Some(id.into()); + self + } + + /// Strict build: all fields must be provided. + pub fn build(self) -> Result { + if self.username.as_ref().map_or(true, |s| s.trim().is_empty()) { + return Err(ErrorCode::UsernameRequired); + } + if self.room.as_ref().map_or(true, |s| s.trim().is_empty()) { + return Err(ErrorCode::RoomRequired); + } + if self.message_id.is_none() { + return Err(ErrorCode::MessageIdRequired); + } + if self.timestamp.is_none() { + return Err(ErrorCode::TimestampRequired); + } + + Ok(Header { + username: self.username.unwrap(), + room: self.room.unwrap(), + message_id: self.message_id.unwrap(), + timestamp: self.timestamp, + }) + } + + /// Lenient build: fills in missing message_id and timestamp. + pub fn build_with_defaults(self) -> Result { + if self.username.as_ref().map_or(true, |s| s.trim().is_empty()) { + return Err("username is required"); + } + if self.room.as_ref().map_or(true, |s| s.trim().is_empty()) { + return Err("room is required"); + } + + Ok(Header { + username: self.username.unwrap(), + room: self.room.unwrap(), + message_id: self + .message_id + .unwrap_or_else(|| Uuid::new_v4().to_string()), + timestamp: Some(self.timestamp.unwrap_or_else(current_timestamp)), + }) + } +} + +impl Default for HeaderBuilder { + fn default() -> Self { + Self::new() + } +} + +/// Builder for [`Join`] message. +/// +/// This builder allows you to create a `Join` message by specifying a username, +/// and optionally a room. If no room is specified, it defaults to `"custom_room"`. +/// +/// The `build()` method will automatically fill in the `message_id` and `timestamp`. +/// +/// # Example (with default room) +/// ``` +/// use chat_contract::builders::JoinBuilder; +/// +/// let join = JoinBuilder::new() +/// .username("r-zig") +/// .with_default_room() +/// .build() +/// .unwrap(); +/// +/// let header = join.header.unwrap(); +/// assert_eq!(header.username, "r-zig"); +/// assert_eq!(header.room, "main"); +/// assert!(!header.message_id.is_empty()); +/// assert!(header.timestamp.is_some()); +/// ``` +pub struct JoinBuilder { + header_builder: HeaderBuilder, +} + +impl JoinBuilder { + pub fn new() -> Self { + Self { + header_builder: HeaderBuilder::new(), + } + } + + pub fn username(mut self, username: impl Into) -> Self { + self.header_builder = self.header_builder.username(username.into()); + self + } + + pub fn room(mut self, room: impl Into) -> Self { + self.header_builder = self.header_builder.room(room.into()); + self + } + + pub fn with_default_room(mut self) -> Self { + self.header_builder = self.header_builder.with_default_room(); + self + } + + pub fn build(self) -> Result { + let header = self + .header_builder + .message_id(Uuid::new_v4().to_string()) + .timestamp(current_timestamp()) + .build()?; + Ok(Join { + header: Some(header), + }) + } +} + +impl Default for JoinBuilder { + fn default() -> Self { + Self::new() + } +} + +/// Builder for [`Leave`] message. +/// +/// This builder allows you to create a `Leave` message by specifying a username, +/// and optionally a room. If no room is specified, it defaults to "custom_room". +/// +/// The `build()` method will automatically fill in the `message_id` and `timestamp`. +/// +/// # Example (with default room) +/// ``` +/// use chat_contract::builders::LeaveBuilder; +/// +/// let leave = LeaveBuilder::new() +/// .username("r-zig") +/// .with_default_room() +/// .build() +/// .unwrap(); +/// +/// let header = leave.header.unwrap(); +/// assert_eq!(header.username, "r-zig"); +/// assert_eq!(header.room, "main"); +/// assert!(!header.message_id.is_empty()); +/// assert!(header.timestamp.is_some()); +/// ``` +pub struct LeaveBuilder { + header_builder: HeaderBuilder, +} + +impl LeaveBuilder { + pub fn new() -> Self { + Self { + header_builder: HeaderBuilder::new(), + } + } + + pub fn username(mut self, username: impl Into) -> Self { + self.header_builder = self.header_builder.username(username.into()); + self + } + + pub fn room(mut self, room: impl Into) -> Self { + self.header_builder = self.header_builder.username(room.into()); + self + } + + pub fn with_default_room(mut self) -> Self { + self.header_builder = self.header_builder.with_default_room(); + self + } + + pub fn build(self) -> Result { + let header = self + .header_builder + .message_id(Uuid::new_v4().to_string()) + .timestamp(current_timestamp()) + .build()?; + Ok(Leave { + header: Some(header), + }) + } +} + +impl Default for LeaveBuilder { + fn default() -> Self { + Self::new() + } +} +/// Builder for [`ChatMessage`] message. +/// +/// This builder allows you to create a `ChatMessage` by specifying a username and content, +/// and optionally a room. If no room is specified, it defaults to "custom_room". +/// +/// The `build()` method will automatically fill in the `message_id` and `timestamp`. +/// +/// # Example (with default room) +/// ``` +/// use chat_contract::builders::ChatMessageBuilder; +/// +/// let chat = ChatMessageBuilder::new() +/// .username("r-zig") +/// .with_default_room() +/// .content("Hello") +/// .build() +/// .unwrap(); +/// +/// let header = chat.header.unwrap(); +/// assert_eq!(header.username, "r-zig"); +/// assert_eq!(header.room, "main"); +/// assert!(header.timestamp.is_some()); +/// assert!(!header.message_id.is_empty()); +/// assert_eq!(chat.content, "Hello"); +/// ``` +pub struct ChatMessageBuilder { + header_builder: HeaderBuilder, + content: Option, +} + +impl ChatMessageBuilder { + pub fn new() -> Self { + Self { + header_builder: HeaderBuilder::new(), + content: None, + } + } + + pub fn username(mut self, username: impl Into) -> Self { + self.header_builder = self.header_builder.username(username.into()); + self + } + + pub fn room(mut self, room: impl Into) -> Self { + self.header_builder = self.header_builder.room(room.into()); + self + } + + pub fn with_default_room(mut self) -> Self { + self.header_builder = self.header_builder.with_default_room(); + self + } + + pub fn content(mut self, content: impl Into) -> Self { + self.content = Some(content.into()); + self + } + + pub fn build(self) -> Result { + let header = self + .header_builder + .message_id(Uuid::new_v4().to_string()) + .timestamp(current_timestamp()) + .build()?; + Ok(ChatMessage { + header: Some(header), + content: self.content.ok_or(ErrorCode::ContentRequired)?, + }) + } +} + +impl Default for ChatMessageBuilder { + fn default() -> Self { + Self::new() + } +} + +/// Builder for [`ServerMessage`] message. +/// +/// This builder allows you to create a `ServerMessage` by specifying either an error +/// or a chat message, but not both. The `build()` method will enforce this constraint. +/// +/// # Example (with chat message) +/// ``` +/// use chat_contract::builders::{ServerMessageBuilder, ChatMessageBuilder}; +/// +/// let chat_message = ChatMessageBuilder::new() +/// .username("r-zig") +/// .content("Hello") +/// .with_default_room() +/// .build() +/// .unwrap(); +/// +/// let server_message = ServerMessageBuilder::new() +/// .chat_message(chat_message) +/// .build() +/// .unwrap(); +/// assert!(server_message.chat.is_some()); +/// assert!(server_message.error.is_none()); +/// ``` +/// +/// # Example (with error) +/// ``` +/// use chat_contract::builders::{ServerMessageBuilder}; +/// use chat_contract::chat::{Error, ErrorCode}; +/// +/// let error = Error { +/// message_id: "123".to_string(), +/// related_message_id: "456".to_string(), +/// r#type: 1, +/// code: ErrorCode::UsernameRequired as i32, +/// }; +/// +/// let server_message = ServerMessageBuilder::new() +/// .error(error) +/// .build() +/// .unwrap(); +/// assert!(server_message.error.is_some()); +/// assert!(server_message.chat.is_none()); +/// ``` +pub struct ServerMessageBuilder { + error: Option, + chat: Option, +} + +impl ServerMessageBuilder { + /// Creates a new `ServerMessageBuilder`. + pub fn new() -> Self { + Self { + error: None, + chat: None, + } + } + + /// Builds an ServerMessage `Error` message from the original `Header`. + /// + /// # Arguments + /// - `header`: The `Header` from the original message that failed to proceed. + /// - `error_type`: The type of message that caused the error (e.g., `MessageType::Join`). + /// - `error_code`: The specific error code (e.g., `ErrorCode::UsernameAlreadyTaken`). + /// + /// # Returns + /// An `Error` message populated with the relevant fields. + pub fn error_from_header( + self, + header: &Header, + error_type: MessageType, + error_code: ErrorCode, + ) -> Self { + self.error(Error { + message_id: uuid::Uuid::new_v4().to_string(), // Generate a unique ID for the error message + related_message_id: header.message_id.clone(), // Use the original message ID + r#type: error_type as i32, // Convert `MessageType` to its integer representation + code: error_code as i32, // Convert `ErrorCode` to its integer representation + }) + } + + /// Sets the error for the `ServerMessage`. + pub fn error(mut self, error: chat::Error) -> Self { + self.error = Some(error); + self + } + + /// Sets the chat message for the `ServerMessage`. + pub fn chat_message(mut self, chat_message: chat::ChatMessage) -> Self { + self.chat = Some(chat_message); + self + } + + /// Builds the `ServerMessage`. + /// + /// Returns an error if neither an error nor a chat message is set. + pub fn build(self) -> Result { + if self.error.is_some() && self.chat.is_some() { + return Err("ServerMessage cannot have both error and chat_message set"); + } + + if self.error.is_none() && self.chat.is_none() { + return Err("ServerMessage must have either error or chat_message set"); + } + + Ok(chat::ServerMessage { + error: self.error, + chat: self.chat, + }) + } +} + +impl Default for ServerMessageBuilder { + fn default() -> Self { + Self::new() + } +} + +/// Builder for [`ClientMessage`] message. +/// +/// This builder allows you to create a `ClientMessage` by specifying either a join, +/// leave, or chat message payload. The `build()` method will enforce that at least +/// one payload is set. +/// +/// # Example (with join payload) +/// ``` +/// use chat_contract::builders::ClientMessageBuilder; +/// +/// let client_message = ClientMessageBuilder::new() +/// .join("r-zig", Some("main")) +/// .build() +/// .unwrap(); +/// assert!(client_message.payload.is_some()); +/// ``` +/// +/// # Example (with chat payload) +/// ``` +/// use chat_contract::builders::ClientMessageBuilder; +/// +/// let client_message = ClientMessageBuilder::new() +/// .chat("r-zig", "Hello", None) +/// .build() +/// .unwrap(); +/// assert!(client_message.payload.is_some()); +/// ``` +pub struct ClientMessageBuilder { + join_builder: Option, + leave_builder: Option, + chat_message_builder: Option, +} + +impl ClientMessageBuilder { + /// Creates a new `ClientMessageBuilder`. + pub fn new() -> Self { + Self { + join_builder: None, + leave_builder: None, + chat_message_builder: None, + } + } + + /// Sets the `Join` payload using a `JoinBuilder`. + pub fn join(mut self, username: impl Into, room: Option<&str>) -> Self { + let mut builder = JoinBuilder::new().username(username); + if let Some(room) = room { + builder = builder.room(room); + } else { + builder = builder.with_default_room(); + } + self.join_builder = Some(builder); + self + } + + /// Sets the `Leave` payload using a `LeaveBuilder`. + pub fn leave(mut self, username: impl Into, room: Option) -> Self { + let mut builder = LeaveBuilder::new().username(username); + if let Some(room) = room { + builder = builder.room(room); + } else { + builder = builder.with_default_room(); + } + self.leave_builder = Some(builder); + self + } + + /// Sets the `ChatMessage` payload using a `ChatMessageBuilder`. + pub fn chat( + mut self, + username: impl Into, + content: impl Into, + room: Option<&str>, + ) -> Self { + let mut builder = ChatMessageBuilder::new() + .username(username) + .content(content); + if let Some(room) = room { + builder = builder.room(room); + } else { + builder = builder.with_default_room(); + } + self.chat_message_builder = Some(builder); + self + } + + /// Builds the `ClientMessage`. + /// + /// Returns an error if no payload is set. + pub fn build(self) -> Result { + let payload = self + .join_builder + .map(|builder| chat::client_message::Payload::Join(builder.build().unwrap())) + .or_else(|| { + self.leave_builder + .map(|builder| chat::client_message::Payload::Leave(builder.build().unwrap())) + }) + .or_else(|| { + self.chat_message_builder + .map(|builder| chat::client_message::Payload::Chat(builder.build().unwrap())) + }); + + if let Some(payload) = payload { + Ok(chat::ClientMessage { + payload: Some(payload), + }) + } else { + Err("ClientMessage must have at least one payload set") + } + } +} + +impl Default for ClientMessageBuilder { + fn default() -> Self { + Self::new() + } +} + +#[cfg(test)] +mod tests { + use crate::current_timestamp; + + use super::*; + #[test] + fn build_header_strict_success() { + let header = HeaderBuilder::new() + .username("r-zig") + .room("custom_room") + .message_id("custom-id-123") + .timestamp(current_timestamp()) + .build() + .expect("should build successfully"); + + assert_eq!(header.username, "r-zig"); + assert_eq!(header.room, "custom_room"); + assert_eq!(header.message_id, "custom-id-123"); + assert!(header.timestamp.is_some()); + } + + #[test] + fn build_header_strict_missing_username() { + let result = HeaderBuilder::new() + .room("custom_room") + .message_id("id") + .timestamp(current_timestamp()) + .build(); + + assert_eq!(result, Err(ErrorCode::UsernameRequired)); + } + + #[test] + fn build_header_strict_missing_room() { + let result = HeaderBuilder::new() + .username("r-zig") + .message_id("id") + .timestamp(current_timestamp()) + .build(); + + assert_eq!(result, Err(ErrorCode::RoomRequired)); + } + + #[test] + fn build_header_strict_missing_message_id() { + let result = HeaderBuilder::new() + .username("r-zig") + .room("custom_room") + .timestamp(current_timestamp()) + .build(); + + assert_eq!(result, Err(ErrorCode::MessageIdRequired)); + } + + #[test] + fn build_header_strict_missing_timestamp() { + let result = HeaderBuilder::new() + .username("r-zig") + .room("custom_room") + .message_id("id") + .build(); + + assert_eq!(result, Err(ErrorCode::TimestampRequired)); + } + + #[test] + fn build_header_with_defaults_success() { + let header = HeaderBuilder::new() + .username("r-zig") + .with_default_room() + .build_with_defaults() + .expect("should build successfully"); + + assert_eq!(header.username, "r-zig"); + assert_eq!(header.room, DEFAULT_ROOM); + assert!(header.message_id.len() > 10); + assert!(header.timestamp.is_some()); + } + + #[test] + fn build_header_with_defaults_missing_username() { + let result = HeaderBuilder::new() + .with_default_room() + .build_with_defaults(); + + assert_eq!(result, Err("username is required")); + } + + #[test] + fn build_header_with_defaults_missing_room() { + let result = HeaderBuilder::new().username("r-zig").build_with_defaults(); + + assert_eq!(result, Err("room is required")); + } + + #[test] + fn join_builder_sets_username_and_room() { + let join = JoinBuilder::new() + .username("r-zig") + .room("custom_room") + .build() + .unwrap(); + assert_eq!(join.header.as_ref().unwrap().username, "r-zig"); + assert_eq!(join.header.as_ref().unwrap().room, "custom_room"); + } + + #[test] + fn leave_builder_default_room() { + let leave = LeaveBuilder::new() + .username("r-zig") + .with_default_room() + .build() + .unwrap(); + assert_eq!(leave.header.as_ref().unwrap().room, DEFAULT_ROOM); + } + + #[test] + fn chat_message_builder_success() { + let chat = ChatMessageBuilder::new() + .username("r-zig") + .content("Hello") + .with_default_room() + .build() + .unwrap(); + assert_eq!(chat.content, "Hello"); + assert_eq!(chat.header.as_ref().unwrap().room, DEFAULT_ROOM); + } + + #[test] + fn chat_message_missing_content() { + let result = ChatMessageBuilder::new() + .username("r-zig") + .with_default_room() + .build(); + assert_eq!(result, Err(ErrorCode::ContentRequired)); + } + + #[test] + fn server_message_builder_with_chat_message() { + let chat_message = ChatMessageBuilder::new() + .username("r-zig") + .content("Hello") + .with_default_room() + .build() + .unwrap(); + + let server_message = ServerMessageBuilder::new() + .chat_message(chat_message) + .build() + .unwrap(); + + assert!(server_message.chat.is_some()); + assert!(server_message.error.is_none()); + } + + #[test] + fn server_message_builder_with_error() { + let error = chat::Error { + message_id: "123".to_string(), + related_message_id: "456".to_string(), + r#type: 1, + code: ErrorCode::UsernameRequired as i32, + }; + + let server_message = ServerMessageBuilder::new().error(error).build().unwrap(); + + assert!(server_message.error.is_some()); + assert!(server_message.chat.is_none()); + } + + #[test] + fn server_message_builder_error_when_both_set() { + let chat_message = ChatMessageBuilder::new() + .username("r-zig") + .content("Hello") + .with_default_room() + .build() + .unwrap(); + + let error = chat::Error { + message_id: "123".to_string(), + related_message_id: "456".to_string(), + r#type: 1, + code: ErrorCode::UsernameRequired as i32, + }; + + let result = ServerMessageBuilder::new() + .chat_message(chat_message) + .error(error) + .build(); + + assert_eq!( + result, + Err("ServerMessage cannot have both error and chat_message set") + ); + } + + #[test] + fn server_message_builder_error_when_none_set() { + let result = ServerMessageBuilder::new().build(); + + assert_eq!( + result, + Err("ServerMessage must have either error or chat_message set") + ); + } + + #[test] + fn server_message_builder_error_from_header() { + // Create a mock header + let header = Header { + username: "test_user".to_string(), + room: "test_room".to_string(), + message_id: "original-message-id".to_string(), + timestamp: Some(current_timestamp()), + }; + + // Build an error message using `error_from_header` + let server_message = ServerMessageBuilder::new() + .error_from_header(&header, MessageType::Join, ErrorCode::UsernameAlreadyTaken) + .build() + .unwrap(); + + // Assert that the error is correctly populated + let error = server_message.error.unwrap(); + assert_eq!(error.related_message_id, "original-message-id"); + assert_eq!(error.r#type, MessageType::Join as i32); + assert_eq!(error.code, ErrorCode::UsernameAlreadyTaken as i32); + assert!(!error.message_id.is_empty()); // Ensure a unique message ID is generated + } + + #[test] + fn build_client_message_with_join() { + let join_message = ClientMessageBuilder::new() + .join("test_user", Some("test_room")) + .build() + .unwrap(); + + assert!(matches!( + join_message.payload, + Some(chat::client_message::Payload::Join(_)) + )); + } + + #[test] + fn build_client_message_with_leave() { + let leave_message = ClientMessageBuilder::new() + .leave("test_user", None) // Defaults to the default room + .build() + .unwrap(); + + assert!(matches!( + leave_message.payload, + Some(chat::client_message::Payload::Leave(_)) + )); + } + + #[test] + fn build_client_message_with_chat() { + let chat_message = ClientMessageBuilder::new() + .chat("test_user", "Hello, world!", Some("test_room")) + .build() + .unwrap(); + + assert!(matches!( + chat_message.payload, + Some(chat::client_message::Payload::Chat(_)) + )); + } + + #[test] + fn build_client_message_without_payload() { + let result = ClientMessageBuilder::new().build(); + assert_eq!( + result, + Err("ClientMessage must have at least one payload set") + ); + } +} diff --git a/chat-contract/src/lib.rs b/chat-contract/src/lib.rs new file mode 100644 index 0000000..a0b823b --- /dev/null +++ b/chat-contract/src/lib.rs @@ -0,0 +1,115 @@ +// Include the generated Rust code from the .proto definitions +pub mod chat { + include!(concat!(env!("OUT_DIR"), "/chat.rs")); +} + +use chat::{ClientMessage, Header, ServerMessage}; +use prost::Message; +pub use prost_types::Timestamp; +use std::time::{SystemTime, UNIX_EPOCH}; +use uuid::Uuid; +pub mod builders; + +/// Create a new Header with current timestamp and a random message ID +pub fn generate_header(username: &str, room: &str) -> Header { + Header { + message_id: Uuid::new_v4().to_string(), + username: username.to_owned(), + room: room.to_owned(), + timestamp: Some(current_timestamp()), + } +} + +/// Return the current time as a prost_types::Timestamp +pub fn current_timestamp() -> Timestamp { + let now = SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("Time went backwards"); + Timestamp { + seconds: now.as_secs() as i64, + nanos: now.subsec_nanos() as i32, + } +} + +/// Validate a Header according to current rules: +/// - username must not be empty +/// - room must not be empty +/// - timestamp must be present +pub fn validate_header(header: &Header) -> Result<(), &'static str> { + if header.username.trim().is_empty() { + return Err("Username is empty"); + } + if header.room.trim().is_empty() { + return Err("Room name is empty"); + } + if header.timestamp.is_none() { + return Err("Missing timestamp"); + } + Ok(()) +} + +impl TryFrom for Vec { + type Error = prost::EncodeError; + fn try_from(value: ServerMessage) -> Result { + let mut buf = Vec::new(); + value.encode_length_delimited(&mut buf)?; + Ok(buf) + } +} + +impl TryFrom for Vec { + type Error = prost::EncodeError; + fn try_from(value: ClientMessage) -> Result { + let mut buf = Vec::new(); + value.encode_length_delimited(&mut buf)?; + Ok(buf) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_generate_header_has_values() { + let username = "r-zig"; + let room = "general"; + let header = generate_header(username, room); + + assert_eq!(header.username, username); + assert_eq!(header.room, room); + assert!(header.message_id.len() > 10, "message_id should be a UUID"); + assert!(header.timestamp.is_some(), "timestamp should be set"); + } + + #[test] + fn test_validate_header_success() { + let header = generate_header("r-zig", "general"); + let result = validate_header(&header); + assert!(result.is_ok()); + } + + #[test] + fn test_validate_header_empty_username() { + let mut header = generate_header("", "general"); + header.username.clear(); // explicitly empty + let result = validate_header(&header); + assert_eq!(result, Err("Username is empty")); + } + + #[test] + fn test_validate_header_empty_room() { + let mut header = generate_header("r-zig", ""); + header.room.clear(); // explicitly empty + let result = validate_header(&header); + assert_eq!(result, Err("Room name is empty")); + } + + #[test] + fn test_validate_header_missing_timestamp() { + let mut header = generate_header("r-zig", "general"); + header.timestamp = None; + let result = validate_header(&header); + assert_eq!(result, Err("Missing timestamp")); + } +} diff --git a/client/Cargo.toml b/client/Cargo.toml new file mode 100644 index 0000000..87d62cf --- /dev/null +++ b/client/Cargo.toml @@ -0,0 +1,19 @@ +[package] +name = "chat-client" +version = "0.1.0" +edition = "2021" +authors = ["Ron Zigelman "] + +[dependencies] +anyhow = "1.0.98" +chat-contract = { version = "0.1.0", path = "../chat-contract" } +clap = { version = "4.5.37", features = ["derive"] } +futures = "0.3.31" +protobuf-stream = { git = "https://github.com/r-zig/protobuf-stream.git", version = "0.1.3" } +quinn = "0.11.8" +quinn-proto = "0.11.11" +rustls = { version = "0.23.5", default-features = false, features = ["std"] } +rustls-pki-types = "1.7" +tokio = { version = "1.45.0", features = ["rt-multi-thread", "macros", "time" ,"signal", "io-std"] } +tracing = "0.1.41" +tracing-subscriber = { version = "0.3.19", features = ["registry", "env-filter"] } diff --git a/client/src/main.rs b/client/src/main.rs new file mode 100644 index 0000000..a50f24f --- /dev/null +++ b/client/src/main.rs @@ -0,0 +1,367 @@ +use anyhow::Result; +use clap::Parser; +use futures::StreamExt; +use protobuf_stream::protobuf_stream::{ProtobufStream, ProtobufStreamError}; +use quinn::{ + crypto::rustls::QuicClientConfig, rustls::pki_types::CertificateDer, ClientConfig, Endpoint, + RecvStream, SendStream, +}; +use tokio::{ + io::{self, AsyncBufReadExt, BufReader}, + sync::mpsc::Receiver, +}; +use tracing_subscriber::{fmt, layer::SubscriberExt, EnvFilter, Registry}; + +use std::{ + net::{IpAddr, Ipv4Addr, SocketAddr}, + sync::Arc, +}; +use tokio::sync::Mutex; +use tracing::{debug, error, info}; + +use chat_contract::{ + builders::ClientMessageBuilder, + chat::{client_message::Payload, ClientMessage, ServerMessage}, +}; + +#[derive(Parser, Debug)] +#[clap(name = "client")] +pub struct Opt { + /// Server address to connect to + #[clap(long = "server-addr")] + server_addr: SocketAddr, + + /// Server name for identification + #[clap(long = "server-name")] + server_name: String, + + /// Client name for identification + #[clap(long = "name")] + name: String, +} + +#[tokio::main] +async fn main() -> Result<()> { + let app_name = concat!(env!("CARGO_PKG_NAME"), "-", env!("CARGO_PKG_VERSION")).to_string(); + let subscriber = Registry::default() + .with(EnvFilter::from_default_env()) + .with(fmt::layer()); + tracing::subscriber::set_global_default(subscriber).unwrap(); + println!("Starting {}", app_name); + println!("Created by: {}", env!("CARGO_PKG_AUTHORS")); + + let options = Opt::parse(); + + // Create a QUIC client endpoint + let endpoint = create_client_endpoint()?; + info!("Connecting to server at {}", options.server_addr); + let connection = endpoint + .connect(options.server_addr, &options.server_name) + .unwrap() + .await?; + println!("Connected to server at {}", options.server_addr); + + let (send_stream, mut recv_stream) = connection.open_bi().await.unwrap(); + let send_stream = Arc::new(tokio::sync::Mutex::new(send_stream)); + + let send_stream_clone = send_stream.clone(); + // Send a join message + send_join(&options.name, send_stream.clone()).await?; + + // Create a channel for sending messages from the input handler to the sender + let (message_tx, message_rx) = tokio::sync::mpsc::channel::(100); + + // Spawn a task to handle user input + let user_name = options.name.clone(); + + // Wait for all tasks to complete or Ctrl+C + tokio::select! { + _ = handle_user_input(&user_name, message_tx) => { + debug!("Input task completed"); + } + _ = send_message(send_stream,message_rx) => { + debug!("Send task completed"); + } + _ = read_logic(&mut recv_stream) => { + debug!("Read task completed"); + } + _ = tokio::signal::ctrl_c() => { + debug!("Ctrl+C received, shutting down..."); + send_leave(&options.name, send_stream_clone).await?; + // let the server receive the leave message + tokio::time::sleep(std::time::Duration::from_secs(1)).await; + } + } + + connection.close(0u32.into(), b"done"); + endpoint.wait_idle().await; + Ok(()) +} + +fn create_client_endpoint() -> Result { + let mut endpoint = Endpoint::client(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0))?; + + endpoint.set_default_client_config(ClientConfig::new(Arc::new(QuicClientConfig::try_from( + rustls::ClientConfig::builder() + .dangerous() + .with_custom_certificate_verifier(SkipServerVerification::new()) + .with_no_client_auth(), + )?))); + Ok(endpoint) +} + +async fn send_message( + send_stream: Arc>, + mut message_rx: Receiver, +) -> Result<()> { + let mut should_shutdown = false; + while let Some(message) = message_rx.recv().await { + if let Some(payload) = message.payload.clone() { + match payload { + Payload::Leave(_) => { + debug!("Leave message sent, will start shutdown"); + should_shutdown = true; + } + _ => { + debug!("Other message sent"); + } + } + } + + match Vec::::try_from(message) { + Ok(buf) => { + if let Err(e) = send_stream.lock().await.write_all(&buf).await { + error!("Failed to send message: {:?}", e); + } else { + debug!("Message sent successfully"); + } + } + Err(e) => { + error!("Failed to encode message: {:?}", e); + } + } + if should_shutdown { + // let the server receive the leave message + tokio::time::sleep(std::time::Duration::from_secs(1)).await; + debug!("Leave message sent, shutting down..."); + break; + } + } + Ok(()) +} + +async fn read_logic(recv_stream: &mut RecvStream) -> Result<()> { + println!("Waiting for messages..."); + // Read the server's response + let reader = BufReader::new(recv_stream); + let mut stream = ProtobufStream::<_, ServerMessage>::new(reader); + + // Loop over the stream to handle multiple messages + loop { + let response = stream.next().await; + match response { + Some(Ok(msg)) => { + if let Some(e) = msg.error { + println!("Server error occurred: {:?}", e); + } else { + match msg.chat { + Some(message) => { + println!( + "Received message from: {}", + message + .header + .map(|h| h.username) + .unwrap_or("Unknown".to_string()) + ); + println!("{}", message.content); + } + None => { + println!("Received empty message"); + } + } + } + } + Some(Err(ProtobufStreamError::Recoverable { code, source })) => { + // Handle the specific "Pending" error + debug!( + "Stream operation is pending: {:?}, error: {:?}, continue to next item", + code, source + ); + continue; + } + Some(Err(ProtobufStreamError::NonRecoverable { code, source })) => { + error!( + "Stream operation failed. code {:?}, error: {:?}", + code, source + ); + println!("Non recoverable error occurred {:?}", code); + break; + } + Some(Err(ProtobufStreamError::Other { + message, + code, + source, + })) => { + error!( + "Stream operation failed. code: {:?}, message: {}, error: {:?}", + code, message, source + ); + println!("Other Error occurred {:?}", code); + break; + } + None => { + // Stream ended + debug!("Stream ended"); + break; + } + } + } + debug!("Read logic task completed"); + Ok(()) +} + +async fn handle_user_input(user: &str, tx: tokio::sync::mpsc::Sender) { + let mut reader = BufReader::new(io::stdin()).lines(); + + println!("Enter commands (e.g., 'send ' or 'leave'):"); + + loop { + // Read user input + let line = reader.next_line().await; + match line { + Ok(Some(line)) => { + let trimmed = line.trim(); + if trimmed.is_empty() { + continue; + } + + // Split the input into command and the rest of the string + let (command, rest) = match trimmed.split_once(' ') { + Some((cmd, rest)) => (cmd, rest), + None => (trimmed, ""), // If there's no space, treat the whole input as the command + }; + + let message = match command { + "send" => { + let content = rest.to_string(); + ClientMessageBuilder::new() + .chat(user, content, None) + .build() + } + "leave" => ClientMessageBuilder::new().leave(user, None).build(), + _ => { + println!("Unknown command: {}", command); + continue; + } + }; + + if let Ok(msg) = message { + if tx.send(msg).await.is_err() { + error!("Failed to send message to sender task"); + break; + } + } + + if command == "leave" { + println!("Leaving the connection..."); + } else { + println!("Enter commands (e.g., 'send ' or 'leave'):"); + } + } + Ok(None) => { + // EOF reached + break; + } + Err(e) => { + error!("Error reading input: {:?}", e); + break; + } + } + } + debug!("User input task completed"); +} + +async fn send_join(name: &str, send_stream: Arc>) -> Result<()> { + let client_message = ClientMessageBuilder::new() + .join(name, None) + .build() + .unwrap(); + + // Send the `Join` message - encode it to bytes with length prefix + let buf: Vec = client_message.try_into().unwrap(); + + send_stream.lock().await.write_all(&buf).await?; + + println!("Sent join message as {}", name); + Ok(()) +} + +async fn send_leave(name: &str, send_stream: Arc>) -> Result<()> { + let client_message = ClientMessageBuilder::new() + .leave(name, None) + .build() + .unwrap(); + + // Send the `Join` message - encode it to bytes with length prefix + let buf: Vec = client_message.try_into().unwrap(); + + send_stream.lock().await.write_all(&buf).await?; + println!("User {} send leave message", name); + Ok(()) +} + +/// Dummy certificate verifier that treats any certificate as valid. +/// NOTE, such verification is vulnerable to MITM attacks, but convenient for testing. +#[derive(Debug)] +struct SkipServerVerification(Arc); + +impl SkipServerVerification { + fn new() -> Arc { + Arc::new(Self(Arc::new(rustls::crypto::ring::default_provider()))) + } +} + +impl rustls::client::danger::ServerCertVerifier for SkipServerVerification { + fn verify_server_cert( + &self, + _end_entity: &CertificateDer<'_>, + _intermediates: &[CertificateDer<'_>], + _server_name: &rustls_pki_types::ServerName<'_>, + _ocsp: &[u8], + _now: rustls_pki_types::UnixTime, + ) -> Result { + Ok(rustls::client::danger::ServerCertVerified::assertion()) + } + + fn verify_tls12_signature( + &self, + message: &[u8], + cert: &CertificateDer<'_>, + dss: &rustls::DigitallySignedStruct, + ) -> Result { + rustls::crypto::verify_tls12_signature( + message, + cert, + dss, + &self.0.signature_verification_algorithms, + ) + } + + fn verify_tls13_signature( + &self, + message: &[u8], + cert: &CertificateDer<'_>, + dss: &rustls::DigitallySignedStruct, + ) -> Result { + rustls::crypto::verify_tls13_signature( + message, + cert, + dss, + &self.0.signature_verification_algorithms, + ) + } + + fn supported_verify_schemes(&self) -> Vec { + self.0.signature_verification_algorithms.supported_schemes() + } +} diff --git a/proto/chat.proto b/proto/chat.proto new file mode 100644 index 0000000..ac07b30 --- /dev/null +++ b/proto/chat.proto @@ -0,0 +1,134 @@ +syntax = "proto3"; + +package chat; + +import "google/protobuf/timestamp.proto"; + +// +// Represents the type of client message that caused an error. +// +enum MessageType { + // Default value if the message type is unknown. + UNKNOWN = 0; + + // Client attempted to join a chat room. + JOIN = 1; + + // Client attempted to leave a chat room. + LEAVE = 2; + + // Client attempted to send a chat message. + CHAT = 3; +} + +// +// Shared metadata included in all messages sent from the client. +// +message Header { + // Unique ID for matching messages and errors. + string message_id = 1; + + // The unique username of the client. + string username = 2; + + // The time the message was sent. + google.protobuf.Timestamp timestamp = 3; + + // The name of the chat room (used for routing). + string room = 4; +} + +// +// Message sent by the client to join a chat room. +// +message Join { + // Includes the username and target room. + Header header = 1; +} + +// +// Message sent by the client to leave a chat room. +// +message Leave { + // Includes the username and target room. + Header header = 1; +} + +// +// Message sent by the client to broadcast a message to a chat room. +// +message ChatMessage { + // Includes the username and target room. + Header header = 1; + + // The actual content of the chat message. + string content = 2; +} + +enum ErrorCode { + UNKNOWN_ERROR = 0; + + // Client sent join with empty username + USERNAME_REQUIRED = 1; + + // Client tried to join with a name that's in use + USERNAME_ALREADY_TAKEN = 2; + + // Client sent chat/leave but hasn't joined yet + USER_NOT_REGISTERED = 3; + + // Trying to send message without room + ROOM_REQUIRED = 10; + + // Trying to send message with non existing room + ROOM_NOT_FOUND = 11; + + // Trying to send message without content + CONTENT_REQUIRED = 20; + + // Trying to send message without message id + MESSAGE_ID_REQUIRED = 30; + + // Message must contain timestamp + TIMESTAMP_REQUIRED = 40; +} + +// +// Message sent from the server to indicate an error with a client request. +// +message Error { + // The unique ID of this error message. + string message_id = 1; + + // The message ID of the original client message that this error relates to. + string related_message_id = 2; + + // The type of message that triggered the error (e.g., JOIN, CHAT). + MessageType type = 3; + + // Error codes + ErrorCode code = 4; +} + +// +// Top-level message for all client→server communications. +// +message ClientMessage { + // The type of message the client is sending. + oneof payload { + Join join = 1; + Leave leave = 2; + ChatMessage chat = 3; + } +} + +// +// Top-level message for all server→client communications. +// +message ServerMessage { + // A message from another client (or self) within the same room. + ChatMessage chat = 1; + + // An error message sent only to the client that caused it. + Error error = 2; +} \ No newline at end of file diff --git a/scripts/install-hooks.sh b/scripts/install-hooks.sh new file mode 100755 index 0000000..7f58174 --- /dev/null +++ b/scripts/install-hooks.sh @@ -0,0 +1,15 @@ +#!/bin/bash +set -e + +HOOK_SRC=".hooks/pre-commit" +HOOK_DEST=".git/hooks/pre-commit" + +if [ ! -f "$HOOK_SRC" ]; then + echo "❌ Hook file $HOOK_SRC does not exist. Aborting." + exit 1 +fi + +echo "🔗 Linking $HOOK_DEST to $HOOK_SRC..." +ln -s "$(pwd)/.hooks/pre-commit" .git/hooks/pre-commit +chmod +x $HOOK_SRC +echo "✅ Git pre-commit hook installed." \ No newline at end of file diff --git a/server/Cargo.toml b/server/Cargo.toml new file mode 100644 index 0000000..7e43bfd --- /dev/null +++ b/server/Cargo.toml @@ -0,0 +1,30 @@ +[package] +name = "chat-server" +version = "0.1.0" +edition = "2021" +authors = ["Ron Zigelman "] + +[[bin]] +name = "server" +path = "src/main.rs" + +[dependencies] +actix = "0.13.5" +chat-contract = { version = "0.1.0", path = "../chat-contract" } +quinn = "0.11.7" +protobuf-stream = { git = "https://github.com/r-zig/protobuf-stream.git", version = "0.1.3" } +async-trait = "0.1.88" +prost = "0.12" +tokio = "1.44.2" +tracing = "0.1.41" +futures = "0.3.31" +clap = { version = "4.5.37", features = ["derive"] } +anyhow = "1.0.98" +rustls = { version = "0.23.5", default-features = false, features = ["std"] } +rustls-pemfile = "2" +rustls-platform-verifier = "0.5" +rustls-pki-types = "1.7" +quinn-proto = "0.11.11" +rcgen = "=0.13.2" +tracing-futures = "0.2.5" +tracing-subscriber = { version = "0.3.19", features = ["registry", "env-filter"] } diff --git a/server/src/actors/chat_actor.rs b/server/src/actors/chat_actor.rs new file mode 100644 index 0000000..2d12340 --- /dev/null +++ b/server/src/actors/chat_actor.rs @@ -0,0 +1,151 @@ +use crate::messages::{ChatMessage, Join, Leave}; +use crate::transport::TransportSender; +use actix::prelude::*; +use chat_contract::builders::{ChatMessageBuilder, ServerMessageBuilder}; +use chat_contract::chat::{ErrorCode, MessageType, ServerMessage}; +use std::collections::HashMap; +use std::sync::Arc; +use tokio::sync::Mutex; +use tracing::{debug, error, warn}; + +pub struct ChatActor { + users: HashMap>>, // Username -> TransportSender +} + +impl Default for ChatActor { + fn default() -> Self { + Self::new() + } +} +impl ChatActor { + pub fn new() -> Self { + Self { + users: HashMap::new(), + } + } + + /// Broadcast message to all users in the room except the excluded user. + fn broadcast(&mut self, msg: ServerMessage, exclude_user: Option<&str>) { + self.users + .iter() + .filter(|(user, _)| exclude_user.map_or(true, |exclude| user != &exclude)) // Exclude the user if specified + .for_each(|(user, user_sender)| { + let user_sender = user_sender.clone(); // Clone the sender for async move + let msg = msg.clone(); // Clone the message for async move + let user = user.clone(); // Clone the username for logging + + actix::spawn(async move { + if let Err(e) = user_sender.lock().await.send_message(msg).await { + error!("Failed to send message to {}: {}", user, e); + } + }); + }); + } +} + +// Implement the Actor trait for ChatActor +impl Actor for ChatActor { + type Context = Context; +} + +impl Handler> for ChatActor { + type Result = (); + + fn handle(&mut self, msg: Join, _: &mut Context) { + let username = msg.header.as_ref().unwrap().username.clone(); + + match self.users.contains_key(&username) { + true => { + warn!("User {} already exists", username); + // Build the error message + let server_message = ServerMessageBuilder::new() + .error_from_header( + msg.header.as_ref().unwrap(), + MessageType::Join, + ErrorCode::UsernameAlreadyTaken, + ) + .build() + .unwrap(); + + // Send the error message back to the client + let sender = msg.sender.clone(); + actix::spawn(async move { + if let Err(e) = sender.lock().await.send_message(server_message).await { + error!("Failed to send error message to {}: {}", username, e); + } + }); + } + false => { + // Add the user to the users map + let sender = msg.sender.clone(); + // Associate the user with their TransportSender + self.users.insert(username.clone(), sender); + + debug!("User {} joined", username); + let server_message = ServerMessageBuilder::new() + .chat_message( + ChatMessageBuilder::new() + .username(username.clone()) + .with_default_room() + .content(format!("User {} has joined the room", username)) + .build() + .unwrap(), + ) + .build() + .unwrap(); + self.broadcast(server_message, Some(&username)); + } + } + } +} + +impl Handler for ChatActor { + type Result = (); + + fn handle(&mut self, msg: Leave, _: &mut Context) { + let username = msg.header.as_ref().unwrap().username.clone(); + + // Remove the user from the room + if self.users.remove(&username).is_some() { + let server_message = ServerMessageBuilder::new() + .chat_message( + ChatMessageBuilder::new() + .username(username.clone()) + .with_default_room() + .content(format!("User {} has left the room", username)) + .build() + .unwrap(), + ) + .build() + .unwrap(); + self.broadcast(server_message, Some(&username)); + + debug!( + "User {} left. room contains {} users", + username, + self.users.len() + ); + } + } +} + +impl Handler for ChatActor { + type Result = (); + + fn handle(&mut self, msg: ChatMessage, _: &mut Context) { + let username = msg.header.as_ref().unwrap().username.clone(); + let room = msg.header.as_ref().unwrap().room.clone(); + let content = msg.content.clone(); + + debug!( + "User {} sent message in room {}: {}", + username, room, content + ); + + let server_message = ServerMessageBuilder::new() + .chat_message(msg.inner().clone()) + .build() + .unwrap(); + self.broadcast(server_message, Some(&username)); + } +} diff --git a/server/src/actors/mod.rs b/server/src/actors/mod.rs new file mode 100644 index 0000000..e2498d4 --- /dev/null +++ b/server/src/actors/mod.rs @@ -0,0 +1 @@ +pub(crate) mod chat_actor; diff --git a/server/src/main.rs b/server/src/main.rs new file mode 100644 index 0000000..b15e826 --- /dev/null +++ b/server/src/main.rs @@ -0,0 +1,216 @@ +use std::time::Duration; +use std::{error::Error, net::SocketAddr, sync::Arc}; +mod actors; +mod messages; +mod transport; +use crate::actors::chat_actor::ChatActor; +use crate::messages::{ChatMessage, Join, Leave}; +use crate::transport::TransportReceiver; +use crate::transport::{QuinnTransportReceiver, QuinnTransportSender}; +use actix::{Actor, Addr}; +use chat_contract::chat; +use rustls::pki_types::{CertificateDer, PrivatePkcs8KeyDer}; + +use quinn::{Endpoint, IdleTimeout, ServerConfig, VarInt}; +use tokio::sync::Mutex; +use tracing::{debug, error, info, info_span}; +use tracing_futures::Instrument; +use tracing_subscriber::fmt; +use tracing_subscriber::{layer::SubscriberExt, EnvFilter, Registry}; + +use anyhow::Result; +use clap::Parser; + +#[actix::main] +async fn main() -> Result<()> { + let app_name = concat!(env!("CARGO_PKG_NAME"), "-", env!("CARGO_PKG_VERSION")).to_string(); + let subscriber = Registry::default() + .with(EnvFilter::from_default_env()) + .with(fmt::layer()); + tracing::subscriber::set_global_default(subscriber).unwrap(); + println!("Starting {}", app_name); + println!("Created by: {}", env!("CARGO_PKG_AUTHORS")); + + let options = Opt::parse(); + run(options) + .await + .inspect_err(|e| { + error!("Error: {reason}", reason = e.to_string()); + }) + .unwrap_or_else(|_| { + println!("Server stopped"); + }); + Ok(()) +} + +#[derive(Parser, Debug)] +#[clap(name = "server")] +pub struct Opt { + /// Enable stateless retries + #[clap(long = "stateless-retry")] + stateless_retry: bool, + /// Address to listen on + #[clap(long = "listen", default_value = "127.0.0.1:4433")] + listen: SocketAddr, + /// Client address to block + #[clap(long = "block")] + block: Option, + /// Maximum number of concurrent connections to allow + #[clap(long = "connection-limit")] + connection_limit: Option, + + #[clap(long = "max-uni_streams")] + max_concurrent_uni_streams: Option, + + #[clap(long = "max-bidi-streams")] + max_concurrent_bidi_streams: Option, + + #[clap(long = "max_idle_timeout", default_value = "30")] + max_idle_timeout: Option, + + #[clap(long = "keep_alive_interval", default_value = "10")] + keep_alive_interval: Option, +} + +pub async fn run(options: Opt) -> Result<()> { + let (endpoint, _server_cert) = make_server_endpoint(&options).unwrap(); + info!("listening on {}", endpoint.local_addr()?); + let chat_actor = ChatActor::::new().start(); + + while let Some(conn) = endpoint.accept().await { + if options + .connection_limit + .is_some_and(|n| endpoint.open_connections() >= n) + { + info!("refusing due to open connection limit"); + conn.refuse(); + } else if Some(conn.remote_address()) == options.block { + info!("refusing blocked client IP address"); + conn.refuse(); + } else if options.stateless_retry && !conn.remote_address_validated() { + info!("requiring connection to validate its address"); + conn.retry().unwrap(); + } else { + info!("accepting connection"); + let fut = handle_connection(conn, chat_actor.clone()); + tokio::spawn(async move { + if let Err(e) = fut.await { + error!("connection failed: {reason}", reason = e.to_string()) + } + }); + } + } + + Ok(()) +} + +async fn handle_connection( + conn: quinn::Incoming, + chat_actor: Addr>, +) -> Result<()> { + let connection = conn.await?; + let span = info_span!( + "connection", + remote = %connection.remote_address(), + protocol = %connection + .handshake_data() + .unwrap() + .downcast::().unwrap() + .protocol + .map_or_else(|| "".into(), |x| String::from_utf8_lossy(&x).into_owned()) + ); + async { + info!("established"); + + // Each stream initiated by the client constitutes a new request. + loop { + let stream = connection.accept_bi().await; + let stream = match stream { + Err(quinn::ConnectionError::ApplicationClosed { .. }) => { + info!("connection closed"); + return Ok(()); + } + Err(e) => { + error!("error accepting stream: {reason}", reason = e.to_string()); + return Err(e); + } + Ok(s) => { + debug!("Accepted a bidirectional stream"); + s + } + }; + + // Create the TransportSender and TransportReceiver + let sender = Arc::new(Mutex::new(QuinnTransportSender::new(stream.0))); + let mut receiver = QuinnTransportReceiver::new( + stream.1, + create_message_handler(chat_actor.clone(), sender.clone()), // Use the message handler + ); + receiver.start().await.unwrap(); + } + } + .instrument(span) + .await?; + Ok(()) +} + +pub fn create_message_handler( + chat_actor: Addr>, + sender: Arc>, +) -> Box { + Box::new(move |client_message: chat::ClientMessage| { + if let Some(payload) = client_message.payload { + match payload { + chat::client_message::Payload::Join(join) => { + chat_actor.do_send(Join::new(join, sender.clone())); + } + chat::client_message::Payload::Leave(leave) => { + chat_actor.do_send(Leave::new(leave)); + } + chat::client_message::Payload::Chat(chat_message) => { + chat_actor.do_send(ChatMessage::new(chat_message)); + } + } + } else { + error!("ClientMessage payload is empty"); + } + }) +} + +/// Returns default server configuration along with its certificate. +fn configure_server( + option: &Opt, +) -> Result<(ServerConfig, CertificateDer<'static>), Box> { + let cert = rcgen::generate_simple_self_signed(vec!["localhost".into()]).unwrap(); + let cert_der = CertificateDer::from(cert.cert); + let priv_key = PrivatePkcs8KeyDer::from(cert.key_pair.serialize_der()); + + let mut server_config = + ServerConfig::with_single_cert(vec![cert_der.clone()], priv_key.into())?; + let transport_config = Arc::get_mut(&mut server_config.transport).unwrap(); + if let Some(max_concurrent_uni_streams) = option.max_concurrent_uni_streams { + transport_config.max_concurrent_uni_streams(VarInt::from_u64(max_concurrent_uni_streams)?); + } + if let Some(max_concurrent_bidi_streams) = option.max_concurrent_bidi_streams { + transport_config + .max_concurrent_bidi_streams(VarInt::from_u64(max_concurrent_bidi_streams)?); + } + if let Some(max_idle_timeout) = option.max_idle_timeout { + transport_config.max_idle_timeout(Some(IdleTimeout::try_from(Duration::from_secs( + max_idle_timeout, + ))?)); + } + if let Some(keep_alive_interval) = option.keep_alive_interval { + transport_config.keep_alive_interval(Some(Duration::from_secs(keep_alive_interval))); + } + + Ok((server_config, cert_der)) +} + +pub fn make_server_endpoint( + options: &Opt, +) -> Result<(Endpoint, CertificateDer<'static>), Box> { + let (server_config, server_cert) = configure_server(options)?; + let endpoint = Endpoint::server(server_config, options.listen)?; + Ok((endpoint, server_cert)) +} diff --git a/server/src/messages.rs b/server/src/messages.rs new file mode 100644 index 0000000..52c4847 --- /dev/null +++ b/server/src/messages.rs @@ -0,0 +1,105 @@ +use actix::Message; +use chat_contract::chat; +use std::{ + ops::{Deref, DerefMut}, + sync::Arc, +}; +use tokio::sync::Mutex; // Protobuf types + +use crate::transport::TransportSender; + +#[derive(Message, Debug, Clone)] +#[rtype(result = "()")] +pub struct Join +where + T: TransportSender + Clone + Unpin, +{ + inner: chat::Join, + pub sender: Arc>, +} + +impl Join +where + T: TransportSender + Clone + Unpin, +{ + pub fn new(inner: chat::Join, sender: Arc>) -> Self { + Self { inner, sender } + } +} + +#[derive(Message, Debug, Clone)] +#[rtype(result = "()")] +pub struct Leave { + inner: chat::Leave, +} +impl Leave { + pub(crate) fn new(leave: chat::Leave) -> Self { + Self { inner: leave } + } +} + +#[derive(Message, Debug, Clone)] +#[rtype(result = "()")] +pub struct ChatMessage { + inner: chat::ChatMessage, +} + +impl ChatMessage { + pub(crate) fn new(chat_message: chat::ChatMessage) -> Self { + Self { + inner: chat_message, + } + } + + pub fn inner(&self) -> &chat::ChatMessage { + &self.inner + } +} +// Implement Deref to expose fields of the inner Protobuf type +impl Deref for Join +where + T: TransportSender + Clone + Unpin, +{ + type Target = chat::Join; + + fn deref(&self) -> &Self::Target { + &self.inner + } +} + +impl DerefMut for Join +where + T: TransportSender + Clone + Unpin, +{ + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.inner + } +} + +impl Deref for Leave { + type Target = chat::Leave; + + fn deref(&self) -> &Self::Target { + &self.inner + } +} + +impl DerefMut for Leave { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.inner + } +} + +impl Deref for ChatMessage { + type Target = chat::ChatMessage; + + fn deref(&self) -> &Self::Target { + &self.inner + } +} + +impl DerefMut for ChatMessage { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.inner + } +} diff --git a/server/src/transport/mod.rs b/server/src/transport/mod.rs new file mode 100644 index 0000000..3bcb6a1 --- /dev/null +++ b/server/src/transport/mod.rs @@ -0,0 +1,22 @@ +use std::fmt::Debug; + +use async_trait::async_trait; +use chat_contract::chat; +mod quic; + +pub use quic::{QuinnTransportReceiver, QuinnTransportSender}; + +#[async_trait] +pub trait TransportSender: Send + Sync + 'static + Debug { + /// Sends a Protobuf `ServerMessage` to the client. + async fn send_message( + &mut self, + message: chat::ServerMessage, + ) -> Result<(), Box>; +} + +#[async_trait] +pub trait TransportReceiver: Send + Sync + 'static { + /// Starts listening for incoming `ClientMessage` and forwards them to the provided handler. + async fn start(&mut self) -> Result<(), Box>; +} diff --git a/server/src/transport/quic.rs b/server/src/transport/quic.rs new file mode 100644 index 0000000..1b5b0e9 --- /dev/null +++ b/server/src/transport/quic.rs @@ -0,0 +1,97 @@ +use async_trait::async_trait; +use chat_contract::chat; +use futures::StreamExt; +use protobuf_stream::{ + self, + protobuf_stream::{ProtobufStream, ProtobufStreamError}, +}; +use quinn::{RecvStream, SendStream}; +use std::sync::Arc; +use tokio::{io::BufReader, sync::Mutex}; +use tracing::{debug, error, warn}; + +use super::{TransportReceiver, TransportSender}; + +#[derive(Debug, Clone)] +pub struct QuinnTransportSender { + send_stream: Arc>, +} + +impl QuinnTransportSender { + pub fn new(send_stream: SendStream) -> Self { + Self { + send_stream: Arc::new(Mutex::new(send_stream)), + } + } +} + +#[async_trait] +impl TransportSender for QuinnTransportSender { + async fn send_message( + &mut self, + message: chat::ServerMessage, + ) -> Result<(), Box> { + let buf: Vec = message.try_into()?; + self.send_stream.lock().await.write_all(&buf).await?; + Ok(()) + } +} + +pub struct QuinnTransportReceiver { + recv_stream: Option, + message_handler: Option>, +} + +impl QuinnTransportReceiver { + pub fn new( + recv_stream: RecvStream, + message_handler: Box, + ) -> Self { + Self { + recv_stream: Some(recv_stream), + message_handler: Some(message_handler), + } + } +} + +#[async_trait] +impl TransportReceiver for QuinnTransportReceiver { + async fn start(&mut self) -> Result<(), Box> { + let recv_stream = self + .recv_stream + .take() + .ok_or("Receiver stream already taken")?; + let message_handler = self + .message_handler + .take() + .ok_or("Message handler already taken")?; + + tokio::spawn(async move { + let mut client_stream = ProtobufStream::new(BufReader::new(recv_stream)); + while let Some(result) = client_stream.next().await { + match result { + Ok(msg) => message_handler(msg), + Err(ProtobufStreamError::Recoverable { code, source }) => { + warn!("Stream operation failed with recoverable error with error code: {:?}, error: {:?}, will continue to fetch messages", code, source); + continue; + } + Err(ProtobufStreamError::NonRecoverable { code, source }) => { + error!("Stream operation failed with non recoverable error and will stop. error code: {:?}, error: {:?}", code, source); + break; + } + Err(ProtobufStreamError::Other { + message, + code, + source, + }) => { + error!("Stream operation failed with other error and will stop. error message: {:?} error code: {:?}, error: {:?}",message, code, source); + break; + } + } + } + debug!("Stream closed"); + }); + + Ok(()) + } +} diff --git a/server/tests/server_test.rs b/server/tests/server_test.rs new file mode 100644 index 0000000..b0f3377 --- /dev/null +++ b/server/tests/server_test.rs @@ -0,0 +1,274 @@ +use chat_contract::{ + builders::ClientMessageBuilder, + chat::{ErrorCode, MessageType, ServerMessage}, +}; +use futures::StreamExt; +use protobuf_stream::protobuf_stream::{ProtobufStream, ProtobufStreamError}; +use quinn::{crypto::rustls::QuicClientConfig, ClientConfig, Endpoint}; +use rustls_pki_types::CertificateDer; +use std::{ + error::Error, + net::{IpAddr, Ipv4Addr, SocketAddr}, + time::UNIX_EPOCH, +}; +use std::{sync::Arc, time::SystemTime}; +use tokio::io::BufReader; + +const DEFAULT_SERVER_ADDRESS: &str = "127.0.0.1:4433"; + +async fn connect_to_server() -> Result<(quinn::Connection, Endpoint), Box> { + let mut endpoint = Endpoint::client(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0))?; + + endpoint.set_default_client_config(ClientConfig::new(Arc::new(QuicClientConfig::try_from( + rustls::ClientConfig::builder() + .dangerous() + .with_custom_certificate_verifier(SkipServerVerification::new()) + .with_no_client_auth(), + )?))); + + // Connect to the server + let connection = endpoint + .connect(DEFAULT_SERVER_ADDRESS.parse().unwrap(), "localhost") + .unwrap() + .await?; + Ok((connection, endpoint)) +} + +/// Dummy certificate verifier that treats any certificate as valid. +/// NOTE, such verification is vulnerable to MITM attacks, but convenient for testing. +#[derive(Debug)] +struct SkipServerVerification(Arc); + +impl SkipServerVerification { + fn new() -> Arc { + Arc::new(Self(Arc::new(rustls::crypto::ring::default_provider()))) + } +} + +impl rustls::client::danger::ServerCertVerifier for SkipServerVerification { + fn verify_server_cert( + &self, + _end_entity: &CertificateDer<'_>, + _intermediates: &[CertificateDer<'_>], + _server_name: &rustls_pki_types::ServerName<'_>, + _ocsp: &[u8], + _now: rustls_pki_types::UnixTime, + ) -> Result { + Ok(rustls::client::danger::ServerCertVerified::assertion()) + } + + fn verify_tls12_signature( + &self, + message: &[u8], + cert: &CertificateDer<'_>, + dss: &rustls::DigitallySignedStruct, + ) -> Result { + rustls::crypto::verify_tls12_signature( + message, + cert, + dss, + &self.0.signature_verification_algorithms, + ) + } + + fn verify_tls13_signature( + &self, + message: &[u8], + cert: &CertificateDer<'_>, + dss: &rustls::DigitallySignedStruct, + ) -> Result { + rustls::crypto::verify_tls13_signature( + message, + cert, + dss, + &self.0.signature_verification_algorithms, + ) + } + + fn supported_verify_schemes(&self) -> Vec { + self.0.signature_verification_algorithms.supported_schemes() + } +} + +#[tokio::test] +async fn test_server_accepts_connection() { + let (connection, endpoint) = connect_to_server() + .await + .expect("Failed to connect to server"); + let _ = connection.open_bi().await.unwrap(); + connection.close(0u32.into(), b"done"); + endpoint.wait_idle().await; +} + +#[tokio::test] +async fn test_server_accepts_join() { + let (connection, endpoint) = connect_to_server() + .await + .expect("Failed to connect to server"); + let (mut send_stream, recv_stream) = connection.open_bi().await.unwrap(); + + let client_message = ClientMessageBuilder::new() + .join("r-zig", None) + .build() + .unwrap(); + + // Send the `Join` message - encode it to bytes with length prefix + let buf: Vec = client_message.try_into().unwrap(); + for _ in 0..2 { + match send_stream.write_all(&buf).await { + Ok(_) => println!("Message sent successfully"), + Err(e) => panic!("Unexpected error: {:?}", e), + } + } + send_stream.finish().unwrap(); + + // Read the server's response + let reader = BufReader::new(recv_stream); + let mut stream = ProtobufStream::<_, ServerMessage>::new(reader); + + // loop over the stream to handle multiple messages + while let Some(response) = stream.next().await { + match response { + Ok(msg) => { + let error = msg.error.unwrap(); + assert_eq!(error.r#type, MessageType::Join as i32); + assert_eq!(error.code, ErrorCode::UsernameAlreadyTaken as i32); + break; + } + Err(ProtobufStreamError::Recoverable { code, source }) => { + // Handle the specific "Pending" error + println!( + "Stream operation is pending: {:?}, error: {:?}, continue to next item", + code, source + ); + continue; + } + Err(ProtobufStreamError::NonRecoverable { code, source }) => { + panic!("Stream operation failed: {:?}, error: {:?}", code, source); + } + Err(ProtobufStreamError::Other { + code, + message, + source, + }) => { + panic!( + "Stream operation failed with message: {:?}, code: {:?}, error: {:?}", + message, code, source + ); + } + } + } + connection.close(0u32.into(), b"done"); + endpoint.wait_idle().await; +} + +#[tokio::test] +async fn test_join_broadcast() { + // Connect the first client + let (connection1, endpoint1) = connect_to_server() + .await + .expect("Failed to connect to server"); + let (mut send_stream1, recv_stream1) = connection1.open_bi().await.unwrap(); + + // Connect the second client + let (connection2, endpoint2) = connect_to_server() + .await + .expect("Failed to connect to server"); + let (mut send_stream2, recv_stream2) = connection2.open_bi().await.unwrap(); + + // Prepare receiving streams for both clients + let reader1 = BufReader::new(recv_stream1); + let mut stream1 = ProtobufStream::<_, ServerMessage>::new(reader1); + + let reader2 = BufReader::new(recv_stream2); + let mut stream2 = ProtobufStream::<_, ServerMessage>::new(reader2); + + // Generate unique usernames using the current timestamp + let timestamp = SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("Time went backwards") + .as_millis(); + let username1 = format!("client1_{}", timestamp); + let username2 = format!("client2_{}", timestamp); + + // First client sends a "Join" message + let client_message1 = ClientMessageBuilder::new() + .join(username1.clone(), None) + .build() + .unwrap(); + let buf1: Vec = client_message1.try_into().unwrap(); + send_stream1.write_all(&buf1).await.unwrap(); + send_stream1.finish().unwrap(); + + // Second client sends a "Join" message + let client_message2 = ClientMessageBuilder::new() + .join(username2, None) + .build() + .unwrap(); + let buf2: Vec = client_message2.try_into().unwrap(); + send_stream2.write_all(&buf2).await.unwrap(); + send_stream2.finish().unwrap(); + + // First client should not receive the broadcast + tokio::select! { + msg = stream1.next() => { + match msg { + Some(Ok(response)) => { + if let Some(chat) = response.chat { + assert_ne!(chat.content, format!("User {} has joined the room", username1)); + } + } + Some(Err(e)) => { + println!("First client received an error: {:?}", e); + panic!("Unexpected error for first client"); + } + None => { + println!("First client stream ended unexpectedly"); + panic!("First client stream ended unexpectedly"); + } + } + } + _ = tokio::time::sleep(std::time::Duration::from_secs(1)) => { + // Timeout, as expected + println!("First client did not receive any broadcast (as expected)"); + } + } + + // Second client should receive the broadcast + tokio::select! { + msg = stream2.next() => { + match msg { + Some(Ok(response)) => { + if let Some(chat) = response.chat { + assert_eq!( + chat.content, + format!("User {} has joined the room", username1), + "Second client received an unexpected broadcast message" + ); + println!("Second client received the expected broadcast: {:?}", chat.content); + } else { + panic!("Second client received a message without chat content"); + } + } + Some(Err(e)) => { + println!("Second client received an error: {:?}", e); + panic!("Unexpected error for second client"); + } + None => { + println!("Second client stream ended unexpectedly"); + panic!("Second client stream ended unexpectedly"); + } + } + } + _ = tokio::time::sleep(std::time::Duration::from_secs(5)) => { + panic!("Second client did not receive the broadcast within the timeout"); + } + } + + // Close connections + connection1.close(0u32.into(), b"done"); + endpoint1.wait_idle().await; + + connection2.close(0u32.into(), b"done"); + endpoint2.wait_idle().await; +}