diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index d87a5b8..893e599 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -2,6 +2,7 @@ name: CI on: push: + branches: [main, master] pull_request: env: @@ -9,37 +10,150 @@ env: RUSTFLAGS: -D warnings jobs: + fmt: + name: Formatting + runs-on: ubuntu-latest + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Install Rust toolchain + uses: dtolnay/rust-toolchain@stable + with: + components: rustfmt + + - name: Check formatting + run: cargo fmt --all -- --check + + clippy: + name: Clippy + runs-on: ubuntu-latest + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Install Rust toolchain + uses: dtolnay/rust-toolchain@stable + with: + components: clippy + + - name: Cache cargo registry + uses: actions/cache@v4 + with: + path: | + ~/.cargo/registry + ~/.cargo/git + target + key: ${{ runner.os }}-cargo-clippy-${{ hashFiles('**/Cargo.lock') }} + + - name: Run Clippy + run: cargo clippy --all-targets --no-deps -- -D warnings + test: + name: Tests runs-on: ubuntu-latest + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Install Rust toolchain + uses: dtolnay/rust-toolchain@stable + + - name: Cache cargo registry + uses: actions/cache@v4 + with: + path: | + ~/.cargo/registry + ~/.cargo/git + target + key: ${{ runner.os }}-cargo-test-${{ hashFiles('**/Cargo.lock') }} + - name: Run tests + run: cargo test --all-features + + coverage: + name: Code Coverage + runs-on: ubuntu-latest steps: - name: Checkout repository uses: actions/checkout@v4 - name: Install Rust toolchain - uses: actions-rs/toolchain@v1 + uses: dtolnay/rust-toolchain@stable + with: + components: llvm-tools-preview + + - name: Install cargo-llvm-cov + uses: taiki-e/install-action@cargo-llvm-cov + + - name: Cache cargo registry + uses: actions/cache@v4 + with: + path: | + ~/.cargo/registry + ~/.cargo/git + target + key: ${{ runner.os }}-cargo-coverage-${{ hashFiles('**/Cargo.lock') }} + + - name: Generate coverage report + run: cargo llvm-cov --all-features --lcov --output-path lcov.info + + - name: Upload coverage to Codecov + uses: codecov/codecov-action@v4 with: - toolchain: stable - override: true - components: rustfmt, clippy + files: lcov.info + fail_ci_if_error: false + env: + CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} - - name: Lint + - name: Check coverage threshold run: | - cargo fmt -- --check - cargo clippy --all-targets --no-deps + cargo llvm-cov --all-features --fail-under-lines 90 - - name: Build Documentation - run: cargo doc --no-deps + doc: + name: Documentation + runs-on: ubuntu-latest + steps: + - name: Checkout repository + uses: actions/checkout@v4 - - name: Run tests - run: cargo test + - name: Install Rust toolchain + uses: dtolnay/rust-toolchain@stable + + - name: Cache cargo registry + uses: actions/cache@v4 + with: + path: | + ~/.cargo/registry + ~/.cargo/git + target + key: ${{ runner.os }}-cargo-doc-${{ hashFiles('**/Cargo.lock') }} - minimum-supported-rust-version: + - name: Build documentation + run: cargo doc --no-deps + env: + RUSTDOCFLAGS: -D warnings + + msrv: + name: Minimum Supported Rust Version runs-on: ubuntu-latest steps: - - uses: actions/checkout@v4 - - uses: actions-rs/toolchain@v1 + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Install Rust toolchain + uses: dtolnay/rust-toolchain@master with: - toolchain: 1.78.0 - override: true - - run: cargo check + toolchain: "1.78.0" + + - name: Cache cargo registry + 
uses: actions/cache@v4 + with: + path: | + ~/.cargo/registry + ~/.cargo/git + target + key: ${{ runner.os }}-cargo-msrv-${{ hashFiles('**/Cargo.lock') }} + + - name: Check MSRV + run: cargo check --all-features diff --git a/.github/workflows/rustdoc.yml b/.github/workflows/rustdoc.yml index aa59557..72a5e90 100644 --- a/.github/workflows/rustdoc.yml +++ b/.github/workflows/rustdoc.yml @@ -1,31 +1,41 @@ -name: rustdoc +name: Deploy Documentation on: push: - branches: [ master ] + branches: [main, master] env: CARGO_INCREMENTAL: 0 - RUSTFLAGS: -D warnings jobs: - rustdoc: + deploy-docs: + name: Deploy Documentation runs-on: ubuntu-latest + permissions: + contents: write steps: - name: Checkout repository uses: actions/checkout@v4 - name: Install Rust toolchain - uses: actions-rs/toolchain@v1 + uses: dtolnay/rust-toolchain@stable + + - name: Cache cargo registry + uses: actions/cache@v4 with: - toolchain: stable - override: true + path: | + ~/.cargo/registry + ~/.cargo/git + target + key: ${{ runner.os }}-cargo-doc-${{ hashFiles('**/Cargo.lock') }} - - name: Build Documentation + - name: Build documentation run: cargo doc --no-deps + env: + RUSTDOCFLAGS: -D warnings - - name: Deploy Docs + - name: Deploy to GitHub Pages uses: peaceiris/actions-gh-pages@v4 with: github_token: ${{ secrets.GITHUB_TOKEN }} diff --git a/CHANGELOG.md b/CHANGELOG.md index e6e689a..62d9c49 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,10 +1,48 @@ # Changelog +## [Unreleased] + +### Security + +- **Memory amplification protection**: Deserialization now validates that claimed sequence/string lengths are plausible given the remaining input, preventing DoS attacks where a small malicious payload could trigger large memory allocations. +- **Duplicate map key detection**: Serialization now returns `Error::NonCanonicalMap` when duplicate keys are encountered instead of silently dropping duplicates, ensuring data integrity. + +### Added + +- `to_bytes_with_capacity()` function for pre-allocating output buffers when the serialized size is known or estimated, reducing allocations. + +- Comprehensive `# Errors` documentation sections on all public functions. +- `#[must_use]` attribute on `is_human_readable()`. +- Unsafe code forbidden via `unsafe_code = "forbid"` in the Cargo.toml `[lints]` section (the Cargo equivalent of `#![forbid(unsafe_code)]`). +- Full pedantic clippy lint compliance with minimal, justified exceptions for binary serialization casts. +- `rustfmt.toml` configuration for consistent code formatting. + +### Changed + +- **Optimized ULEB128 encoding/decoding**: Added fast paths for single-byte values (0-127), which are common for sequence lengths and enum variant indices. +- **Optimized bulk byte reading**: Deserialization now uses slice splitting instead of byte-by-byte copying for integer parsing. +- **Added `#[inline]` hints** on hot serialization/deserialization paths for better performance. +- Replaced `sort_by` with `sort_unstable_by` for map key sorting (faster, no stability needed for unique keys). + +### CI/CD + +- Added separate CI jobs for formatting (`cargo fmt`), linting (`cargo clippy`), testing, coverage, documentation, and MSRV verification. +- Added code coverage reporting with Codecov integration and 90% line coverage threshold. +- Added Minimum Supported Rust Version (MSRV) check at Rust 1.78.0. +- Improved CI caching for faster builds. +- Documentation builds now use `-D warnings` to catch doc issues. + +### Testing + +- Expanded test suite with security-focused tests for memory amplification and duplicate key detection.
+- Added tests for `to_bytes_with_capacity`, `from_bytes_seed`, and other previously uncovered code paths. +- Improved benchmark suite with comprehensive type coverage and deserialization benchmarks. + ## [v0.1.1] - 2020-12-11 - Renaming crate into "bcs". ## [v0.1.0] - 2020-11-17 - Initial release. +[Unreleased]: https://github.com/diem/bcs/compare/v0.1.1...HEAD [v0.1.1]: https://github.com/diem/bcs/releases/tag/v0.1.1 [v0.1.0]: https://github.com/diem/bcs/releases/tag/v0.1.0 diff --git a/Cargo.toml b/Cargo.toml index ac48aac..5dbab33 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -21,3 +21,14 @@ proptest-derive = "0.2.0" [[bench]] name = "bcs_bench" harness = false + +[lints.rust] +unsafe_code = "forbid" + +[lints.clippy] +all = { level = "warn", priority = -1 } +pedantic = { level = "warn", priority = -1 } +# These casts are intentional for binary serialization +cast_possible_truncation = "allow" +cast_possible_wrap = "allow" +cast_sign_loss = "allow" diff --git a/benches/bcs_bench.rs b/benches/bcs_bench.rs index 03a7261..5b4b65b 100644 --- a/benches/bcs_bench.rs +++ b/benches/bcs_bench.rs @@ -1,28 +1,176 @@ // Copyright (c) The Diem Core Contributors // SPDX-License-Identifier: Apache-2.0 -use bcs::to_bytes; -use criterion::{criterion_group, criterion_main, Criterion}; +use bcs::{from_bytes, to_bytes, to_bytes_with_capacity}; +use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion, Throughput}; +use serde::{Deserialize, Serialize}; use std::collections::{BTreeMap, HashMap}; -pub fn bcs_benchmark(c: &mut Criterion) { +#[derive(Serialize, Deserialize, Clone)] +struct SimpleStruct { + a: u64, + b: u32, + c: u16, + d: u8, + e: bool, +} + +#[derive(Serialize, Deserialize, Clone)] +struct ComplexStruct { + id: u64, + name: String, + values: Vec<u64>, + nested: Option<SimpleStruct>, +} + +/// Benchmarks for BCS serialization. +/// +/// # Panics +/// +/// Panics if any serialization operation fails unexpectedly.
+pub fn serialize_benchmarks(c: &mut Criterion) { + let mut group = c.benchmark_group("serialize"); + + // Primitive types + let u64_val: u64 = 0x1234_5678_90AB_CDEF; + group.bench_function("u64", |b| b.iter(|| to_bytes(black_box(&u64_val)).unwrap())); + + // Simple struct + let simple = SimpleStruct { + a: 12_345_678_901_234, + b: 1_234_567_890, + c: 12345, + d: 123, + e: true, + }; + group.bench_function("simple_struct", |b| { + b.iter(|| to_bytes(black_box(&simple)).unwrap()); + }); + + // Complex struct with nested data + let complex = ComplexStruct { + id: 42, + name: "benchmark test string".to_string(), + values: (0..100).collect(), + nested: Some(simple.clone()), + }; + group.bench_function("complex_struct", |b| { + b.iter(|| to_bytes(black_box(&complex)).unwrap()); + }); + + // Test pre-allocation benefit + let serialized_size = to_bytes(&complex).unwrap().len(); + group.bench_function("complex_struct_with_capacity", |b| { + b.iter(|| to_bytes_with_capacity(black_box(&complex), serialized_size).unwrap()); + }); + + // Vec of u64s at various sizes + for size in &[10_u64, 100, 1000, 10000] { + let vec: Vec<u64> = (0..*size).collect(); + group.throughput(Throughput::Elements(*size)); + group.bench_with_input(BenchmarkId::new("vec_u64", size), &vec, |b, v| { + b.iter(|| to_bytes(black_box(v)).unwrap()); + }); + } + + // String serialization + let short_string = "hello".to_string(); + let long_string = "a".repeat(1000); + group.bench_function("short_string", |b| { + b.iter(|| to_bytes(black_box(&short_string)).unwrap()); + }); + group.bench_function("long_string", |b| { + b.iter(|| to_bytes(black_box(&long_string)).unwrap()); + }); + + // Maps let mut btree_map = BTreeMap::new(); let mut hash_map = HashMap::new(); for i in 0u32..2000u32 { btree_map.insert(i, i); hash_map.insert(i, i); } - c.bench_function("serialize btree map", |b| { - b.iter(|| { - to_bytes(&btree_map).unwrap(); - }) + group.bench_function("btree_map_2000", |b| { + b.iter(|| to_bytes(black_box(&btree_map)).unwrap()); + }); + group.bench_function("hash_map_2000", |b| { + b.iter(|| to_bytes(black_box(&hash_map)).unwrap()); + }); + + group.finish(); +} + +/// Benchmarks for BCS deserialization. +/// +/// # Panics +/// +/// Panics if any serialization or deserialization operation fails unexpectedly.
+pub fn deserialize_benchmarks(c: &mut Criterion) { + let mut group = c.benchmark_group("deserialize"); + + // Primitive types + let u64_bytes = to_bytes(&0x1234_5678_90AB_CDEF_u64).unwrap(); + group.bench_function("u64", |b| { + b.iter(|| from_bytes::<u64>(black_box(&u64_bytes)).unwrap()); + }); + + // Simple struct + let simple = SimpleStruct { + a: 12_345_678_901_234, + b: 1_234_567_890, + c: 12345, + d: 123, + e: true, + }; + let simple_bytes = to_bytes(&simple).unwrap(); + group.bench_function("simple_struct", |b| { + b.iter(|| from_bytes::<SimpleStruct>(black_box(&simple_bytes)).unwrap()); + }); + + // Complex struct + let complex = ComplexStruct { + id: 42, + name: "benchmark test string".to_string(), + values: (0..100).collect(), + nested: Some(simple.clone()), + }; + let complex_bytes = to_bytes(&complex).unwrap(); + group.bench_function("complex_struct", |b| { + b.iter(|| from_bytes::<ComplexStruct>(black_box(&complex_bytes)).unwrap()); + }); + + // Vec of u64s at various sizes + for size in &[10_u64, 100, 1000, 10000] { + let vec: Vec<u64> = (0..*size).collect(); + let vec_bytes = to_bytes(&vec).unwrap(); + group.throughput(Throughput::Elements(*size)); + group.bench_with_input(BenchmarkId::new("vec_u64", size), &vec_bytes, |b, bytes| { + b.iter(|| from_bytes::<Vec<u64>>(black_box(bytes)).unwrap()); + }); + } + + // String deserialization + let short_string_bytes = to_bytes(&"hello".to_string()).unwrap(); + let long_string_bytes = to_bytes(&"a".repeat(1000)).unwrap(); + group.bench_function("short_string", |b| { + b.iter(|| from_bytes::<String>(black_box(&short_string_bytes)).unwrap()); }); - c.bench_function("serialize hash map", |b| { - b.iter(|| { - to_bytes(&hash_map).unwrap(); - }) + group.bench_function("long_string", |b| { + b.iter(|| from_bytes::<String>(black_box(&long_string_bytes)).unwrap()); }); + + // Maps + let mut btree_map = BTreeMap::new(); + for i in 0u32..2000u32 { + btree_map.insert(i, i); + } + let map_bytes = to_bytes(&btree_map).unwrap(); + group.bench_function("btree_map_2000", |b| { + b.iter(|| from_bytes::<BTreeMap<u32, u32>>(black_box(&map_bytes)).unwrap()); + }); + + group.finish(); } -criterion_group!(benches, bcs_benchmark); +criterion_group!(benches, serialize_benchmarks, deserialize_benchmarks); criterion_main!(benches); diff --git a/rustfmt.toml b/rustfmt.toml new file mode 100644 index 0000000..c8f546a --- /dev/null +++ b/rustfmt.toml @@ -0,0 +1,4 @@ +# Rustfmt configuration for BCS +edition = "2018" +max_width = 100 +use_small_heuristics = "Default" diff --git a/src/de.rs b/src/de.rs index 240c8f0..f4d472c 100644 --- a/src/de.rs +++ b/src/de.rs @@ -3,13 +3,27 @@ use crate::error::{Error, Result}; use serde::de::{self, Deserialize, DeserializeSeed, IntoDeserializer, Visitor}; -use std::convert::TryFrom; +use std::convert::{TryFrom, TryInto}; /// Deserializes a `&[u8]` into a type. /// /// This function will attempt to interpret `bytes` as the BCS serialized form of `T` and /// deserialize `T` from `bytes`.
/// +/// # Errors +/// +/// Returns an error if: +/// - The bytes do not represent a valid BCS encoding of `T` +/// - The input ends unexpectedly ([`Error::Eof`]) +/// - A sequence length exceeds [`MAX_SEQUENCE_LENGTH`](crate::MAX_SEQUENCE_LENGTH) +/// - The container depth exceeds [`MAX_CONTAINER_DEPTH`](crate::MAX_CONTAINER_DEPTH) +/// - There are remaining bytes after deserialization ([`Error::RemainingInput`]) +/// - An unsupported type is encountered (f32, f64, char) +/// - UTF-8 validation fails for strings ([`Error::Utf8`]) +/// - A boolean value is not 0 or 1 ([`Error::ExpectedBoolean`]) +/// - An option tag is not 0 or 1 ([`Error::ExpectedOption`]) +/// - Map keys are not in canonical order ([`Error::NonCanonicalMap`]) +/// /// # Examples /// /// ``` @@ -40,10 +54,17 @@ where { let mut deserializer = Deserializer::new(bytes, crate::MAX_CONTAINER_DEPTH); let t = T::deserialize(&mut deserializer)?; - deserializer.end().map(move |_| t) + deserializer.end().map(move |()| t) } -/// Same as `from_bytes` but use `limit` as max container depth instead of MAX_CONTAINER_DEPTH` +/// Same as [`from_bytes`] but use `limit` as max container depth instead of +/// [`MAX_CONTAINER_DEPTH`](crate::MAX_CONTAINER_DEPTH). +/// +/// # Errors +/// +/// Returns an error if: +/// - `limit` exceeds [`MAX_CONTAINER_DEPTH`](crate::MAX_CONTAINER_DEPTH) +/// - Any error condition from [`from_bytes`] occurs pub fn from_bytes_with_limit<'a, T>(bytes: &'a [u8], limit: usize) -> Result<T> where T: Deserialize<'a>, @@ -55,20 +76,31 @@ where } let mut deserializer = Deserializer::new(bytes, limit); let t = T::deserialize(&mut deserializer)?; - deserializer.end().map(move |_| t) + deserializer.end().map(move |()| t) } /// Perform a stateful deserialization from a `&[u8]` using the provided `seed`. +/// +/// # Errors +/// +/// Returns an error if any error condition from [`from_bytes`] occurs. pub fn from_bytes_seed<'a, T>(seed: T, bytes: &'a [u8]) -> Result<T::Value> where T: DeserializeSeed<'a>, { let mut deserializer = Deserializer::new(bytes, crate::MAX_CONTAINER_DEPTH); let t = seed.deserialize(&mut deserializer)?; - deserializer.end().map(move |_| t) + deserializer.end().map(move |()| t) } -/// Same as `from_bytes_seed` but use `limit` as max container depth instead of MAX_CONTAINER_DEPTH` +/// Same as [`from_bytes_seed`] but use `limit` as max container depth instead of +/// [`MAX_CONTAINER_DEPTH`](crate::MAX_CONTAINER_DEPTH).
+/// +/// # Errors +/// +/// Returns an error if: +/// - `limit` exceeds [`MAX_CONTAINER_DEPTH`](crate::MAX_CONTAINER_DEPTH) +/// - Any error condition from [`from_bytes_seed`] occurs pub fn from_bytes_seed_with_limit<'a, T>(seed: T, bytes: &'a [u8], limit: usize) -> Result<T::Value> where T: DeserializeSeed<'a>, @@ -80,7 +112,7 @@ where } let mut deserializer = Deserializer::new(bytes, limit); let t = seed.deserialize(&mut deserializer)?; - deserializer.end().map(move |_| t) + deserializer.end().map(move |()| t) } /// Deserialization implementation for BCS @@ -112,16 +144,19 @@ impl<'de> Deserializer<'de> { } impl<'de> Deserializer<'de> { + #[inline] fn peek(&mut self) -> Result<u8> { self.input.first().copied().ok_or(Error::Eof) } + #[inline] fn next(&mut self) -> Result<u8> { let byte = self.peek()?; self.input = &self.input[1..]; Ok(byte) } + #[inline] fn parse_bool(&mut self) -> Result<bool> { let byte = self.next()?; @@ -132,51 +167,73 @@ impl<'de> Deserializer<'de> { } } - fn fill_slice(&mut self, slice: &mut [u8]) -> Result<()> { - for byte in slice { - *byte = self.next()?; + /// Reads exactly `n` bytes from input, returning a slice. + /// This is more efficient than byte-by-byte copying. + #[inline] + fn read_bytes(&mut self, n: usize) -> Result<&'de [u8]> { + if self.input.len() < n { + return Err(Error::Eof); } - Ok(()) + let (bytes, rest) = self.input.split_at(n); + self.input = rest; + Ok(bytes) } + #[inline] fn parse_u8(&mut self) -> Result<u8> { self.next() } + #[inline] fn parse_u16(&mut self) -> Result<u16> { - let mut le_bytes = [0; 2]; - self.fill_slice(&mut le_bytes)?; - Ok(u16::from_le_bytes(le_bytes)) + let bytes = self.read_bytes(2)?; + // `read_bytes` guarantees the slice holds exactly 2 bytes + Ok(u16::from_le_bytes([bytes[0], bytes[1]])) } + #[inline] fn parse_u32(&mut self) -> Result<u32> { - let mut le_bytes = [0; 4]; - self.fill_slice(&mut le_bytes)?; - Ok(u32::from_le_bytes(le_bytes)) + let bytes = self.read_bytes(4)?; + // `read_bytes` guarantees the slice holds exactly 4 bytes + Ok(u32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]])) } + #[inline] fn parse_u64(&mut self) -> Result<u64> { - let mut le_bytes = [0; 8]; - self.fill_slice(&mut le_bytes)?; - Ok(u64::from_le_bytes(le_bytes)) + let bytes = self.read_bytes(8)?; + // Use try_into for cleaner conversion from slice to array + Ok(u64::from_le_bytes(bytes.try_into().unwrap())) } + #[inline] fn parse_u128(&mut self) -> Result<u128> { - let mut le_bytes = [0; 16]; - self.fill_slice(&mut le_bytes)?; - Ok(u128::from_le_bytes(le_bytes)) + let bytes = self.read_bytes(16)?; + // Use try_into for cleaner conversion from slice to array + Ok(u128::from_le_bytes(bytes.try_into().unwrap())) } + /// Parse a ULEB128-encoded u32. Optimized for common small values. #[allow(clippy::arithmetic_side_effects)] + #[inline] fn parse_u32_from_uleb128(&mut self) -> Result<u32> { - let mut value: u64 = 0; - for shift in (0..32).step_by(7) { + // Fast path: single byte (values 0-127) + let first_byte = self.next()?; + if first_byte < 0x80 { + return Ok(u32::from(first_byte)); + } + + // Multi-byte path + let mut value = u64::from(first_byte & 0x7f); + let mut shift = 7; + + loop { let byte = self.next()?; let digit = byte & 0x7f; value |= u64::from(digit) << shift; + // If the highest bit of `byte` is 0, return the final value. if digit == byte { - if shift > 0 && digit == 0 { + if digit == 0 { // We only accept canonical ULEB128 encodings, therefore the // heaviest (and last) base-128 digit must be non-zero.
return Err(Error::NonCanonicalUleb128Encoding); @@ -185,31 +242,47 @@ impl<'de> Deserializer<'de> { return u32::try_from(value) .map_err(|_| Error::IntegerOverflowDuringUleb128Decoding); } + + shift += 7; + if shift >= 35 { + // Decoded integer must not overflow. + return Err(Error::IntegerOverflowDuringUleb128Decoding); + } } - // Decoded integer must not overflow. - Err(Error::IntegerOverflowDuringUleb128Decoding) } + /// Parse a sequence length, validating against both the maximum allowed length + /// and the remaining input size to prevent memory amplification attacks. + #[inline] fn parse_length(&mut self) -> Result<usize> { let len = self.parse_u32_from_uleb128()? as usize; if len > crate::MAX_SEQUENCE_LENGTH { return Err(Error::ExceededMaxLen(len)); } + // Security: Validate that the claimed length is plausible given remaining input. + // This prevents memory amplification attacks where a small payload claims a huge + // length, causing pre-allocation of gigabytes of memory before we detect EOF. + // Note: For sequences of multi-byte elements, this is a lower bound check only. + if len > self.input.len() { + return Err(Error::Eof); + } Ok(len) } + #[inline] fn parse_bytes(&mut self) -> Result<&'de [u8]> { let len = self.parse_length()?; - let slice = self.input.get(..len).ok_or(Error::Eof)?; - self.input = &self.input[len..]; - Ok(slice) + // Use read_bytes for consistent and efficient slice access + self.read_bytes(len) } + #[inline] fn parse_string(&mut self) -> Result<&'de str> { let slice = self.parse_bytes()?; std::str::from_utf8(slice).map_err(|_| Error::Utf8) } + #[inline] fn enter_named_container(&mut self, name: &'static str) -> Result<()> { if self.max_remaining_depth == 0 { return Err(Error::ExceededContainerDepthLimit(name)); @@ -218,6 +291,7 @@ Ok(()) } + #[inline] fn leave_named_container(&mut self) { self.max_remaining_depth += 1; } @@ -234,6 +308,7 @@ impl<'de> de::Deserializer<'de> for &mut Deserializer<'de> { Err(Error::NotSupported("deserialize_any")) } + #[inline] fn deserialize_bool<V>(self, visitor: V) -> Result<V::Value> where V: Visitor<'de>, { visitor.visit_bool(self.parse_bool()?) } + #[inline] fn deserialize_i8<V>(self, visitor: V) -> Result<V::Value> where V: Visitor<'de>, { visitor.visit_i8(self.parse_u8()? as i8) } + #[inline] fn deserialize_i16<V>(self, visitor: V) -> Result<V::Value> where V: Visitor<'de>, { visitor.visit_i16(self.parse_u16()? as i16) } + #[inline] fn deserialize_i32<V>(self, visitor: V) -> Result<V::Value> where V: Visitor<'de>, { visitor.visit_i32(self.parse_u32()? as i32) } + #[inline] fn deserialize_i64<V>(self, visitor: V) -> Result<V::Value> where V: Visitor<'de>, { visitor.visit_i64(self.parse_u64()? as i64) } + #[inline] fn deserialize_i128<V>(self, visitor: V) -> Result<V::Value> where V: Visitor<'de>, { visitor.visit_i128(self.parse_u128()? as i128) } + #[inline] fn deserialize_u8<V>(self, visitor: V) -> Result<V::Value> where V: Visitor<'de>, { visitor.visit_u8(self.parse_u8()?)
} + #[inline] fn deserialize_u16<V>(self, visitor: V) -> Result<V::Value> where V: Visitor<'de>, @@ -290,6 +372,7 @@ impl<'de> de::Deserializer<'de> for &mut Deserializer<'de> { visitor.visit_u16(self.parse_u16()?) } + #[inline] fn deserialize_u32<V>(self, visitor: V) -> Result<V::Value> where V: Visitor<'de>, @@ -297,6 +380,7 @@ impl<'de> de::Deserializer<'de> for &mut Deserializer<'de> { visitor.visit_u32(self.parse_u32()?) } + #[inline] fn deserialize_u64<V>(self, visitor: V) -> Result<V::Value> where V: Visitor<'de>, @@ -304,6 +388,7 @@ impl<'de> de::Deserializer<'de> for &mut Deserializer<'de> { visitor.visit_u64(self.parse_u64()?) } + #[inline] fn deserialize_u128<V>(self, visitor: V) -> Result<V::Value> where V: Visitor<'de>, @@ -332,6 +417,7 @@ impl<'de> de::Deserializer<'de> for &mut Deserializer<'de> { Err(Error::NotSupported("deserialize_char")) } + #[inline] fn deserialize_str<V>(self, visitor: V) -> Result<V::Value> where V: Visitor<'de>, @@ -339,6 +425,7 @@ impl<'de> de::Deserializer<'de> for &mut Deserializer<'de> { visitor.visit_borrowed_str(self.parse_string()?) } + #[inline] fn deserialize_string<V>(self, visitor: V) -> Result<V::Value> where V: Visitor<'de>, @@ -346,6 +433,7 @@ impl<'de> de::Deserializer<'de> for &mut Deserializer<'de> { self.deserialize_str(visitor) } + #[inline] fn deserialize_bytes<V>(self, visitor: V) -> Result<V::Value> where V: Visitor<'de>, @@ -353,6 +441,7 @@ impl<'de> de::Deserializer<'de> for &mut Deserializer<'de> { visitor.visit_borrowed_bytes(self.parse_bytes()?) } + #[inline] fn deserialize_byte_buf<V>(self, visitor: V) -> Result<V::Value> where V: Visitor<'de>, @@ -360,6 +449,7 @@ impl<'de> de::Deserializer<'de> for &mut Deserializer<'de> { self.deserialize_bytes(visitor) } + #[inline] fn deserialize_option<V>(self, visitor: V) -> Result<V::Value> where V: Visitor<'de>, @@ -373,6 +463,7 @@ impl<'de> de::Deserializer<'de> for &mut Deserializer<'de> { } } + #[inline] fn deserialize_unit<V>(self, visitor: V) -> Result<V::Value> where V: Visitor<'de>, @@ -380,6 +471,7 @@ impl<'de> de::Deserializer<'de> for &mut Deserializer<'de> { visitor.visit_unit() } + #[inline] fn deserialize_unit_struct<V>(self, name: &'static str, visitor: V) -> Result<V::Value> where V: Visitor<'de>, @@ -390,6 +482,7 @@ impl<'de> de::Deserializer<'de> for &mut Deserializer<'de> { r } + #[inline] fn deserialize_newtype_struct<V>(self, name: &'static str, visitor: V) -> Result<V::Value> where V: Visitor<'de>, @@ -400,6 +493,7 @@ impl<'de> de::Deserializer<'de> for &mut Deserializer<'de> { r } + #[inline] fn deserialize_seq<V>(self, visitor: V) -> Result<V::Value> where V: Visitor<'de>, @@ -408,6 +502,7 @@ impl<'de> de::Deserializer<'de> for &mut Deserializer<'de> { visitor.visit_seq(SeqDeserializer::new(self, len)) } + #[inline] fn deserialize_tuple<V>(self, len: usize, visitor: V) -> Result<V::Value> where V: Visitor<'de>, @@ -415,6 +510,7 @@ impl<'de> de::Deserializer<'de> for &mut Deserializer<'de> { visitor.visit_seq(SeqDeserializer::new(self, len)) } + #[inline] fn deserialize_tuple_struct<V>( self, name: &'static str, @@ -430,6 +526,7 @@ impl<'de> de::Deserializer<'de> for &mut Deserializer<'de> { r } + #[inline] fn deserialize_map<V>(self, visitor: V) -> Result<V::Value> where V: Visitor<'de>, @@ -438,6 +535,7 @@ impl<'de> de::Deserializer<'de> for &mut Deserializer<'de> { visitor.visit_map(MapDeserializer::new(self, len)) } + #[inline] fn deserialize_struct<V>( self, name: &'static str, @@ -453,6 +551,7 @@ impl<'de> de::Deserializer<'de> for &mut Deserializer<'de> { r } + #[inline] fn deserialize_enum<V>( self, name: &'static str, @@ -469,11 +568,11 @@ impl<'de> de::Deserializer<'de> for &mut Deserializer<'de> { } // BCS does not utilize identifiers, so throw them
away - fn deserialize_identifier<V>(self, _visitor: V) -> Result<V::Value> + fn deserialize_identifier<V>(self, visitor: V) -> Result<V::Value> where V: Visitor<'de>, { - self.deserialize_bytes(_visitor) + self.deserialize_bytes(visitor) } // BCS is not a self-describing format so we can't implement `deserialize_ignored_any` @@ -485,6 +584,7 @@ impl<'de> de::Deserializer<'de> for &mut Deserializer<'de> { } // BCS is not a human readable format + #[inline] fn is_human_readable(&self) -> bool { false } @@ -496,6 +596,7 @@ struct SeqDeserializer<'a, 'de: 'a> { } impl<'a, 'de> SeqDeserializer<'a, 'de> { + #[inline] fn new(de: &'a mut Deserializer<'de>, remaining: usize) -> Self { Self { de, remaining } } @@ -504,6 +605,7 @@ impl<'de> de::SeqAccess<'de> for SeqDeserializer<'_, 'de> { type Error = Error; + #[inline] fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>> where T: DeserializeSeed<'de>, @@ -516,6 +618,7 @@ impl<'de> de::SeqAccess<'de> for SeqDeserializer<'_, 'de> { } } + #[inline] fn size_hint(&self) -> Option<usize> { Some(self.remaining) } @@ -528,6 +631,7 @@ struct MapDeserializer<'a, 'de: 'a> { } impl<'a, 'de> MapDeserializer<'a, 'de> { + #[inline] fn new(de: &'a mut Deserializer<'de>, remaining: usize) -> Self { Self { de, @@ -540,6 +644,7 @@ impl<'de> de::MapAccess<'de> for MapDeserializer<'_, 'de> { type Error = Error; + #[inline] fn next_key_seed<K>(&mut self, seed: K) -> Result<Option<K::Value>> where K: DeserializeSeed<'de>, @@ -565,6 +670,7 @@ impl<'de> de::MapAccess<'de> for MapDeserializer<'_, 'de> { } } + #[inline] fn next_value_seed<V>(&mut self, seed: V) -> Result<V::Value> where V: DeserializeSeed<'de>, @@ -572,6 +678,7 @@ impl<'de> de::MapAccess<'de> for MapDeserializer<'_, 'de> { seed.deserialize(&mut *self.de) } + #[inline] fn size_hint(&self) -> Option<usize> { Some(self.remaining) } @@ -581,6 +688,7 @@ impl<'de> de::EnumAccess<'de> for &mut Deserializer<'de> { type Error = Error; type Variant = Self; + #[inline] fn variant_seed<V>(self, seed: V) -> Result<(V::Value, Self::Variant)> where V: DeserializeSeed<'de>, @@ -594,10 +702,12 @@ impl<'de> de::VariantAccess<'de> for &mut Deserializer<'de> { type Error = Error; + #[inline] fn unit_variant(self) -> Result<()> { Ok(()) } + #[inline] fn newtype_variant_seed<T>(self, seed: T) -> Result<T::Value> where T: DeserializeSeed<'de>, @@ -605,6 +715,7 @@ impl<'de> de::VariantAccess<'de> for &mut Deserializer<'de> { seed.deserialize(self) } + #[inline] fn tuple_variant<V>(self, len: usize, visitor: V) -> Result<V::Value> where V: Visitor<'de>, @@ -612,6 +723,7 @@ impl<'de> de::VariantAccess<'de> for &mut Deserializer<'de> { de::Deserializer::deserialize_tuple(self, len, visitor) } + #[inline] fn struct_variant<V>(self, fields: &'static [&'static str], visitor: V) -> Result<V::Value> where V: Visitor<'de>, diff --git a/src/lib.rs b/src/lib.rs index e9883a8..419421d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -318,5 +318,5 @@ pub use de::{from_bytes, from_bytes_seed, from_bytes_seed_with_limit, from_bytes pub use error::{Error, Result}; pub use ser::{ is_human_readable, serialize_into, serialize_into_with_limit, serialized_size, - serialized_size_with_limit, to_bytes, to_bytes_with_limit, + serialized_size_with_limit, to_bytes, to_bytes_with_capacity, to_bytes_with_limit, }; diff --git a/src/ser.rs b/src/ser.rs index 2410fcb..fe602c7 100644 --- a/src/ser.rs +++ b/src/ser.rs @@ -6,10 +6,14 @@ use serde::{ser, Serialize}; /// Serialize the given data structure as a `Vec<u8>` of BCS.
/// -/// Serialization can fail if `T`'s implementation of `Serialize` decides to -/// fail, if `T` contains sequences which are longer than `MAX_SEQUENCE_LENGTH`, -/// or if `T` attempts to serialize an unsupported datatype such as a f32, -/// f64, or char. +/// # Errors +/// +/// Returns an error if: +/// - `T`'s implementation of [`Serialize`] decides to fail +/// - `T` contains sequences longer than [`MAX_SEQUENCE_LENGTH`](crate::MAX_SEQUENCE_LENGTH) +/// - `T` attempts to serialize an unsupported datatype (f32, f64, or char) +/// - The container depth exceeds [`MAX_CONTAINER_DEPTH`](crate::MAX_CONTAINER_DEPTH) +/// - An I/O error occurs while writing /// /// # Examples /// /// ``` @@ -55,7 +59,47 @@ where Ok(output) } -/// Same as `to_bytes` but use `limit` as max container depth instead of MAX_CONTAINER_DEPTH +/// Same as [`to_bytes`] but pre-allocates the output buffer with the given capacity. +/// +/// This can improve performance when you have a good estimate of the serialized size, +/// as it avoids reallocations during serialization. +/// +/// # Errors +/// +/// Returns an error if any error condition from [`to_bytes`] occurs. +/// +/// # Examples +/// +/// ``` +/// use bcs::to_bytes_with_capacity; +/// use serde::Serialize; +/// +/// #[derive(Serialize)] +/// struct Data { +/// values: Vec<u64>, +/// } +/// +/// let data = Data { values: vec![1, 2, 3, 4, 5] }; +/// // Pre-allocate for length prefix (1 byte) + 5 u64s (40 bytes) +/// let bytes = to_bytes_with_capacity(&data, 41).unwrap(); +/// ``` +pub fn to_bytes_with_capacity<T>(value: &T, capacity: usize) -> Result<Vec<u8>> +where + T: ?Sized + Serialize, +{ + let mut output = Vec::with_capacity(capacity); + serialize_into(&mut output, value)?; + Ok(output) +} + +/// Same as [`to_bytes`] but use `limit` as max container depth instead of +/// [`MAX_CONTAINER_DEPTH`](crate::MAX_CONTAINER_DEPTH). +/// +/// # Errors +/// +/// Returns an error if: +/// - `limit` exceeds [`MAX_CONTAINER_DEPTH`](crate::MAX_CONTAINER_DEPTH) +/// - Any error condition from [`to_bytes`] occurs pub fn to_bytes_with_limit<T>(value: &T, limit: usize) -> Result<Vec<u8>> where T: ?Sized + Serialize, @@ -70,7 +114,11 @@ where Ok(output) } -/// Same as `to_bytes` but write directly into an `std::io::Write` object. +/// Same as [`to_bytes`] but write directly into an [`std::io::Write`] object. +/// +/// # Errors +/// +/// Returns an error if any error condition from [`to_bytes`] occurs. pub fn serialize_into<W, T>(write: &mut W, value: &T) -> Result<()> where W: ?Sized + std::io::Write, @@ -80,7 +128,14 @@ where value.serialize(serializer) } -/// Same as `serialize_into` but use `limit` as max container depth instead of MAX_CONTAINER_DEPTH +/// Same as [`serialize_into`] but use `limit` as max container depth instead of +/// [`MAX_CONTAINER_DEPTH`](crate::MAX_CONTAINER_DEPTH).
+/// +/// # Errors +/// +/// Returns an error if: +/// - `limit` exceeds [`MAX_CONTAINER_DEPTH`](crate::MAX_CONTAINER_DEPTH) +/// - Any error condition from [`serialize_into`] occurs pub fn serialize_into_with_limit<W, T>(write: &mut W, value: &T, limit: usize) -> Result<()> where W: ?Sized + std::io::Write, @@ -100,9 +155,10 @@ struct WriteCounter(usize); impl std::io::Write for WriteCounter { fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> { let len = buf.len(); - self.0 = self.0.checked_add(len).ok_or_else(|| { - std::io::Error::new(std::io::ErrorKind::Other, "WriteCounter reached max value") - })?; + self.0 = self + .0 + .checked_add(len) + .ok_or_else(|| std::io::Error::other("WriteCounter reached max value"))?; Ok(len) } @@ -111,7 +167,14 @@ impl std::io::Write for WriteCounter { } } -/// Same as `to_bytes` but only return the size of the serialized bytes. +/// Same as [`to_bytes`] but only return the size of the serialized bytes. +/// +/// This is useful for pre-allocating buffers or validating size constraints +/// without actually performing the serialization. +/// +/// # Errors +/// +/// Returns an error if any error condition from [`to_bytes`] occurs. pub fn serialized_size<T>(value: &T) -> Result<usize> where T: ?Sized + Serialize, @@ -121,7 +184,14 @@ where Ok(counter.0) } -/// Same as `serialized_size` but use `limit` as max container depth instead of MAX_CONTAINER_DEPTH +/// Same as [`serialized_size`] but use `limit` as max container depth instead of +/// [`MAX_CONTAINER_DEPTH`](crate::MAX_CONTAINER_DEPTH). +/// +/// # Errors +/// +/// Returns an error if: +/// - `limit` exceeds [`MAX_CONTAINER_DEPTH`](crate::MAX_CONTAINER_DEPTH) +/// - Any error condition from [`serialized_size`] occurs pub fn serialized_size_with_limit<T>(value: &T, limit: usize) -> Result<usize> where T: ?Sized + Serialize, @@ -136,6 +206,10 @@ where Ok(counter.0) } +/// Returns whether BCS is a human-readable format. +/// +/// This always returns `false` as BCS is a binary format. +#[must_use] pub fn is_human_readable() -> bool { let mut output = Vec::new(); let serializer = Serializer::new(&mut output, crate::MAX_CONTAINER_DEPTH); @@ -153,6 +227,7 @@ where W: ?Sized + std::io::Write, { /// Creates a new `Serializer` which will emit BCS. + #[inline] fn new(output: &'a mut W, max_remaining_depth: usize) -> Self { Self { output, @@ -160,23 +235,39 @@ where } } - fn output_u32_as_uleb128(&mut self, mut value: u32) -> Result<()> { - while value >= 0x80 { - // Write 7 (lowest) bits of data and set the 8th bit to 1. - let byte = (value & 0x7f) as u8; - self.output.write_all(&[byte | 0x80])?; - value >>= 7; + /// Encode a u32 as ULEB128. Optimized for common small values. + #[inline] + fn output_u32_as_uleb128(&mut self, value: u32) -> Result<()> { + // Fast path: single byte (values 0-127) - very common for lengths and variant indices + if value < 0x80 { + self.output.write_all(&[value as u8])?; + return Ok(()); + } + + // Multi-byte encoding - pre-compute all bytes to minimize write calls + // Max ULEB128 encoding for u32 is 5 bytes + let mut buf = [0u8; 5]; + let mut i = 0; + let mut v = value; + + while v >= 0x80 { + buf[i] = (v as u8 & 0x7f) | 0x80; + v >>= 7; + i += 1; } - // Write the remaining bits of data and set the highest bit to 0. - self.output.write_all(&[value as u8])?; + buf[i] = v as u8; + + self.output.write_all(&buf[..=i])?; Ok(()) } + #[inline] fn output_variant_index(&mut self, v: u32) -> Result<()> { self.output_u32_as_uleb128(v) } /// Serialize a sequence length as a u32.
+ #[inline] fn output_seq_len(&mut self, len: usize) -> Result<()> { if len > crate::MAX_SEQUENCE_LENGTH { return Err(Error::ExceededMaxLen(len)); @@ -184,6 +275,7 @@ where self.output_u32_as_uleb128(len as u32) } + #[inline] fn enter_named_container(&mut self, name: &'static str) -> Result<()> { if self.max_remaining_depth == 0 { return Err(Error::ExceededContainerDepthLimit(name)); @@ -207,50 +299,67 @@ where type SerializeStruct = Self; type SerializeStructVariant = Self; + #[inline] fn serialize_bool(self, v: bool) -> Result<()> { - self.serialize_u8(v.into()) + self.output.write_all(&[u8::from(v)])?; + Ok(()) } + #[inline] fn serialize_i8(self, v: i8) -> Result<()> { - self.serialize_u8(v as u8) + self.output.write_all(&[v as u8])?; + Ok(()) } + #[inline] fn serialize_i16(self, v: i16) -> Result<()> { - self.serialize_u16(v as u16) + self.output.write_all(&(v as u16).to_le_bytes())?; + Ok(()) } + #[inline] fn serialize_i32(self, v: i32) -> Result<()> { - self.serialize_u32(v as u32) + self.output.write_all(&(v as u32).to_le_bytes())?; + Ok(()) } + #[inline] fn serialize_i64(self, v: i64) -> Result<()> { - self.serialize_u64(v as u64) + self.output.write_all(&(v as u64).to_le_bytes())?; + Ok(()) } + #[inline] fn serialize_i128(self, v: i128) -> Result<()> { - self.serialize_u128(v as u128) + self.output.write_all(&(v as u128).to_le_bytes())?; + Ok(()) } + #[inline] fn serialize_u8(self, v: u8) -> Result<()> { self.output.write_all(&[v])?; Ok(()) } + #[inline] fn serialize_u16(self, v: u16) -> Result<()> { self.output.write_all(&v.to_le_bytes())?; Ok(()) } + #[inline] fn serialize_u32(self, v: u32) -> Result<()> { self.output.write_all(&v.to_le_bytes())?; Ok(()) } + #[inline] fn serialize_u64(self, v: u64) -> Result<()> { self.output.write_all(&v.to_le_bytes())?; Ok(()) } + #[inline] fn serialize_u128(self, v: u128) -> Result<()> { self.output.write_all(&v.to_le_bytes())?; Ok(()) @@ -269,11 +378,13 @@ where } // Just serialize the string as a raw byte array + #[inline] fn serialize_str(self, v: &str) -> Result<()> { self.serialize_bytes(v.as_bytes()) } // Serialize a byte array as an array of bytes. + #[inline] fn serialize_bytes(mut self, v: &[u8]) -> Result<()> { self.output_seq_len(v.len())?; self.output.write_all(v)?; @@ -281,11 +392,14 @@ where } // An absent optional is represented as `00` + #[inline] fn serialize_none(self) -> Result<()> { - self.serialize_u8(0) + self.output.write_all(&[0])?; + Ok(()) } // A present optional is represented as `01` followed by the serialized value + #[inline] fn serialize_some<T>(self, value: &T) -> Result<()> where T: ?Sized + Serialize, @@ -294,15 +408,18 @@ where value.serialize(self) } + #[inline] fn serialize_unit(self) -> Result<()> { Ok(()) } + #[inline] fn serialize_unit_struct(mut self, name: &'static str) -> Result<()> { self.enter_named_container(name)?; - self.serialize_unit() + Ok(()) } + #[inline] fn serialize_unit_variant( mut self, name: &'static str, @@ -313,6 +430,7 @@ where self.output_variant_index(variant_index) } + #[inline] fn serialize_newtype_struct<T>(mut self, name: &'static str, value: &T) -> Result<()> where T: ?Sized + Serialize, @@ -321,6 +439,7 @@ where value.serialize(self) } + #[inline] fn serialize_newtype_variant<T>( mut self, name: &'static str, @@ -340,6 +459,7 @@ where // method calls. This one is responsible only for serializing the start, // which for BCS is either nothing for fixed structures or for variable // length structures, the length encoded as a u32.
+ #[inline] fn serialize_seq(mut self, len: Option<usize>) -> Result<Self::SerializeSeq> { if let Some(len) = len { self.output_seq_len(len)?; @@ -350,10 +470,12 @@ where } // Tuples are fixed sized structs so we don't need to encode the length + #[inline] fn serialize_tuple(self, _len: usize) -> Result<Self::SerializeTuple> { Ok(self) } + #[inline] fn serialize_tuple_struct( mut self, name: &'static str, @@ -363,6 +485,7 @@ Ok(self) } + #[inline] fn serialize_tuple_variant( mut self, name: &'static str, @@ -375,10 +498,12 @@ Ok(self) } + #[inline] fn serialize_map(self, _len: Option<usize>) -> Result<Self::SerializeMap> { Ok(MapSerializer::new(self)) } + #[inline] fn serialize_struct( mut self, name: &'static str, @@ -388,6 +513,7 @@ Ok(self) } + #[inline] fn serialize_struct_variant( mut self, name: &'static str, @@ -401,6 +527,7 @@ } // BCS is not a human readable format + #[inline] fn is_human_readable(&self) -> bool { false } @@ -413,6 +540,7 @@ where type Ok = (); type Error = Error; + #[inline] fn serialize_element<T>(&mut self, value: &T) -> Result<()> where T: ?Sized + Serialize, @@ -420,6 +548,7 @@ where value.serialize(Serializer::new(self.output, self.max_remaining_depth)) } + #[inline] fn end(self) -> Result<()> { Ok(()) } @@ -432,6 +561,7 @@ where type Ok = (); type Error = Error; + #[inline] fn serialize_element<T>(&mut self, value: &T) -> Result<()> where T: ?Sized + Serialize, @@ -439,6 +569,7 @@ where value.serialize(Serializer::new(self.output, self.max_remaining_depth)) } + #[inline] fn end(self) -> Result<()> { Ok(()) } @@ -451,6 +582,7 @@ where type Ok = (); type Error = Error; + #[inline] fn serialize_field<T>(&mut self, value: &T) -> Result<()> where T: ?Sized + Serialize, @@ -458,6 +590,7 @@ where value.serialize(Serializer::new(self.output, self.max_remaining_depth)) } + #[inline] fn end(self) -> Result<()> { Ok(()) } @@ -470,6 +603,7 @@ where type Ok = (); type Error = Error; + #[inline] fn serialize_field<T>(&mut self, value: &T) -> Result<()> where T: ?Sized + Serialize, @@ -477,6 +611,7 @@ where value.serialize(Serializer::new(self.output, self.max_remaining_depth)) } + #[inline] fn end(self) -> Result<()> { Ok(()) } @@ -490,6 +625,7 @@ struct MapSerializer<'a, W: ?Sized> { } impl<'a, W: ?Sized> MapSerializer<'a, W> { + #[inline] fn new(serializer: Serializer<'a, W>) -> Self { MapSerializer { serializer, @@ -506,6 +642,7 @@ where type Ok = (); type Error = Error; + #[inline] fn serialize_key<T>(&mut self, key: &T) -> Result<()> where T: ?Sized + Serialize, @@ -523,6 +660,7 @@ where Ok(()) } + #[inline] fn serialize_value<T>(&mut self, value: &T) -> Result<()> where T: ?Sized + Serialize, @@ -545,8 +683,15 @@ where if self.next_key.is_some() { return Err(Error::ExpectedMapValue); } - self.entries.sort_by(|e1, e2| e1.0.cmp(&e2.0)); - self.entries.dedup_by(|e1, e2| e1.0.eq(&e2.0)); + self.entries.sort_unstable_by(|e1, e2| e1.0.cmp(&e2.0)); + + // Security: Detect and reject duplicate keys instead of silently dropping them. + // This ensures serialization is lossless and maintains data integrity.
+ for i in 1..self.entries.len() { + if self.entries[i].0 == self.entries[i - 1].0 { + return Err(Error::NonCanonicalMap); + } + } let len = self.entries.len(); self.serializer.output_seq_len(len)?; @@ -567,6 +712,7 @@ where type Ok = (); type Error = Error; + #[inline] fn serialize_field<T>(&mut self, _key: &'static str, value: &T) -> Result<()> where T: ?Sized + Serialize, @@ -574,6 +720,7 @@ where value.serialize(Serializer::new(self.output, self.max_remaining_depth)) } + #[inline] fn end(self) -> Result<()> { Ok(()) } @@ -586,6 +733,7 @@ where type Ok = (); type Error = Error; + #[inline] fn serialize_field<T>(&mut self, _key: &'static str, value: &T) -> Result<()> where T: ?Sized + Serialize, @@ -593,6 +741,7 @@ where value.serialize(Serializer::new(self.output, self.max_remaining_depth)) } + #[inline] fn end(self) -> Result<()> { Ok(()) } diff --git a/src/test_helpers.rs b/src/test_helpers.rs index 4c7e51f..9840ad0 100644 --- a/src/test_helpers.rs +++ b/src/test_helpers.rs @@ -1,11 +1,21 @@ // Copyright (c) The Diem Core Contributors // SPDX-License-Identifier: Apache-2.0 -pub fn assert_canonical_encode_decode<T>(t: T) +/// Asserts that a value can be serialized and deserialized back to an equal value. +/// +/// This is a test helper function that verifies round-trip serialization. +/// +/// # Panics +/// +/// Panics if: +/// - Serialization fails +/// - Deserialization fails +/// - The deserialized value is not equal to the original +pub fn assert_canonical_encode_decode<T>(t: &T) where T: serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug + PartialEq, { - let bytes = crate::to_bytes(&t).unwrap(); + let bytes = crate::to_bytes(t).unwrap(); let s: T = crate::from_bytes(&bytes).unwrap(); - assert_eq!(t, s); + assert_eq!(*t, s); } diff --git a/tests/serde.rs b/tests/serde.rs index 3ae3680..d5dd223 100644 --- a/tests/serde.rs +++ b/tests/serde.rs @@ -4,6 +4,12 @@ // For some reason deriving `Arbitrary` results in clippy firing a `unit_arg` violation #![allow(clippy::unit_arg)] #![allow(non_local_definitions)] +// Allow these clippy warnings for test code +#![allow(clippy::needless_pass_by_value)] +#![allow(clippy::zero_sized_map_values)] +#![allow(clippy::items_after_statements)] +#![allow(clippy::match_wildcard_for_single_variants)] +#![allow(clippy::owned_cow)] use std::{ collections::{BTreeMap, BTreeSet}, @@ -163,7 +169,7 @@ proptest!
{ #[test] fn proptest_option(v in any::<Option<u8>>()) { - let expected = v.map(|v| vec![1, v]).unwrap_or_else(|| vec![0]); + let expected = v.map_or_else(|| vec![0], |v| vec![1, v]); assert_eq!(to_bytes(&v)?, expected); is_same(v); @@ -683,3 +689,273 @@ fn test_recursion_limit_enum() { let bytes = to_bytes_with_limit(&a, 1).unwrap(); let _: EnumA = from_bytes_with_limit(&bytes, 1).unwrap(); } + +// ============================================================================ +// Additional tests for 100% coverage +// ============================================================================ + +#[test] +fn test_to_bytes_with_capacity() { + use bcs::to_bytes_with_capacity; + + let data = vec![1u32, 2, 3, 4, 5]; + let bytes = to_bytes_with_capacity(&data, 100).unwrap(); + let expected = to_bytes(&data).unwrap(); + assert_eq!(bytes, expected); +} + +#[test] +fn test_serialized_size_with_limit() { + use bcs::{serialized_size, serialized_size_with_limit}; + + let data = vec![1u32, 2, 3]; + let size = serialized_size_with_limit(&data, 10).unwrap(); + assert_eq!(size, serialized_size(&data).unwrap()); + + // Test exceeding limit + let err = serialized_size_with_limit(&data, 501).unwrap_err(); + assert!(matches!(err, Error::NotSupported(_))); +} + +#[test] +fn test_serialize_into_with_limit_exceeding() { + use bcs::serialize_into_with_limit; + + let data = 42u32; + let mut output = Vec::new(); + let err = serialize_into_with_limit(&mut output, &data, 501).unwrap_err(); + assert!(matches!(err, Error::NotSupported(_))); +} + +#[test] +fn test_is_human_readable() { + assert!(!bcs::is_human_readable()); +} + +#[test] +fn test_from_bytes_seed() { + use bcs::from_bytes_seed; + use serde::de::DeserializeSeed; + + struct U32Seed; + + impl<'de> DeserializeSeed<'de> for U32Seed { + type Value = u32; + + fn deserialize<D>(self, deserializer: D) -> std::result::Result<Self::Value, D::Error> + where + D: serde::Deserializer<'de>, + { + u32::deserialize(deserializer) + } + } + + let bytes = to_bytes(&42u32).unwrap(); + let result: u32 = from_bytes_seed(U32Seed, &bytes).unwrap(); + assert_eq!(result, 42); +} + +#[test] +fn test_from_bytes_seed_with_limit() { + use bcs::from_bytes_seed_with_limit; + use serde::de::DeserializeSeed; + + struct U32Seed; + + impl<'de> DeserializeSeed<'de> for U32Seed { + type Value = u32; + + fn deserialize<D>(self, deserializer: D) -> std::result::Result<Self::Value, D::Error> + where + D: serde::Deserializer<'de>, + { + u32::deserialize(deserializer) + } + } + + let bytes = to_bytes(&42u32).unwrap(); + let result: u32 = from_bytes_seed_with_limit(U32Seed, &bytes, 10).unwrap(); + assert_eq!(result, 42); + + // Test exceeding limit + let err = from_bytes_seed_with_limit(U32Seed, &bytes, 501).unwrap_err(); + assert!(matches!(err, Error::NotSupported(_))); +} + +#[test] +fn test_unit_struct_serde() { + #[derive(Serialize, Deserialize, Debug, PartialEq)] + struct UnitStruct; + + let bytes = to_bytes(&UnitStruct).unwrap(); + let result: UnitStruct = from_bytes(&bytes).unwrap(); + assert_eq!(result, UnitStruct); +} + +#[test] +fn test_tuple_struct_serde() { + #[derive(Serialize, Deserialize, Debug, PartialEq)] + struct TupleStruct(u32, String); + + let data = TupleStruct(42, "hello".to_string()); + let bytes = to_bytes(&data).unwrap(); + let result: TupleStruct = from_bytes(&bytes).unwrap(); + assert_eq!(result, data); +} + +#[test] +fn test_eof_error_in_read_bytes() { + // Try to deserialize with insufficient bytes + let bytes = vec![0x05]; // Length prefix says 5 bytes, but no payload bytes follow + let err = from_bytes::<Vec<u8>>(&bytes).unwrap_err();
assert!(matches!(err, Error::Eof)); } + +#[test] +fn test_test_helpers() { + use bcs::test_helpers::assert_canonical_encode_decode; + + assert_canonical_encode_decode(&42u32); + assert_canonical_encode_decode(&"hello".to_string()); + assert_canonical_encode_decode(&vec![1, 2, 3]); +} + +#[test] +fn test_map_serialization_error_paths() { + // Test empty map + let map: BTreeMap<u32, u32> = BTreeMap::new(); + let bytes = to_bytes(&map).unwrap(); + let result: BTreeMap<u32, u32> = from_bytes(&bytes).unwrap(); + assert_eq!(result, map); +} + +#[test] +fn test_byte_buf_deserialization() { + // Test round-tripping a Vec<u8> of raw bytes + let data: Vec<u8> = vec![1, 2, 3, 4, 5]; + let bytes = to_bytes(&data).unwrap(); + let result: Vec<u8> = from_bytes(&bytes).unwrap(); + assert_eq!(result, data); +} + +#[test] +fn test_custom_error() { + // Test Error::Custom path via serde + let err: Error = <Error as serde::de::Error>::custom("test error"); + assert!(matches!(err, Error::Custom(_))); +} + +#[test] +fn test_io_error_conversion() { + use std::io; + + let io_err = io::Error::other("test"); + let bcs_err: Error = io_err.into(); + assert!(matches!(bcs_err, Error::Io(_))); +} + +#[test] +fn test_serialization_custom_error() { + // Test ser::Error::custom path + let err: Error = <Error as serde::ser::Error>::custom("serialization error"); + assert!(matches!(err, Error::Custom(_))); +} + +#[test] +fn test_exceeded_max_sequence_length() { + // Create bytes that claim a huge sequence length + // ULEB128 encoding of MAX_SEQUENCE_LENGTH + 1 = 2^31 + let bytes = vec![0x80, 0x80, 0x80, 0x80, 0x08]; // 2^31 in ULEB128 + let err = from_bytes::<Vec<u8>>(&bytes).unwrap_err(); + assert!(matches!(err, Error::ExceededMaxLen(_))); +} + +// ============================================================================ +// Security tests +// ============================================================================ + +#[test] +fn test_memory_amplification_protection() { + // Security test: Ensure that claiming a large length with insufficient data + // returns Eof quickly without attempting to allocate large amounts of memory. + // + // A 3-byte payload claiming 1 million elements should fail immediately, + // not after trying to allocate memory for 1 million elements.
+ + // 1,000,000 = 0xF4240 = 0b1111_0100_0010_0100_0000 + // Split into 7-bit groups from the LSB: 0b100_0000 (0x40), 0b000_0100 (0x04), 0b11_1101 (0x3D) + // With continuation bits set on all but the last byte: + // 0xC0 (0x40 | 0x80), 0x84 (0x04 | 0x80), 0x3D (high bit clear) + let bytes = vec![0xC0, 0x84, 0x3D]; // ULEB128 for 1,000,000 + + let err = from_bytes::<Vec<u8>>(&bytes).unwrap_err(); + // Should fail with Eof because we only have 0 bytes of actual data, + // but claimed 1,000,000 bytes + assert!(matches!(err, Error::Eof)); +} + +#[test] +fn test_memory_amplification_protection_string() { + // Test the same protection for strings + // Claim 10,000 bytes but provide only 2 + let bytes = vec![0x90, 0x4E, b'h', b'i']; // ULEB128 for 10,000, then "hi" + + let err = from_bytes::<String>(&bytes).unwrap_err(); + assert!(matches!(err, Error::Eof)); +} + +#[test] +fn test_memory_amplification_small_valid_length() { + // Make sure valid small lengths still work + let bytes = vec![0x02, 0x41, 0x42]; // length 2, then "AB" + let result: Vec<u8> = from_bytes(&bytes).unwrap(); + assert_eq!(result, vec![0x41, 0x42]); +} + +#[test] +fn test_duplicate_map_keys_serialization() { + // Create a HashMap and manually serialize it to test the duplicate detection. + // Since HashMap doesn't allow duplicate keys directly, we need to test via + // serde's serialize_map interface with a custom type. + #[derive(Debug)] + struct DuplicateKeyMap; + + impl Serialize for DuplicateKeyMap { + fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error> + where + S: serde::Serializer, + { + use serde::ser::SerializeMap; + let mut map = serializer.serialize_map(Some(2))?; + map.serialize_entry(&1u32, &"first")?; + map.serialize_entry(&1u32, &"second")?; // Duplicate key! + map.end() + } + } + + let err = to_bytes(&DuplicateKeyMap).unwrap_err(); + assert!(matches!(err, Error::NonCanonicalMap)); +} + +#[test] +fn test_map_serialization_no_duplicates() { + use std::collections::HashMap; + + // Normal map without duplicates should work fine + let mut map = HashMap::new(); + map.insert(1u32, "one"); + map.insert(2u32, "two"); + map.insert(3u32, "three"); + + let bytes = to_bytes(&map).unwrap(); + let result: HashMap<u32, String> = from_bytes(&bytes).unwrap(); + + assert_eq!(result.len(), 3); + assert_eq!(result.get(&1).map(String::as_str), Some("one")); + assert_eq!(result.get(&2).map(String::as_str), Some("two")); + assert_eq!(result.get(&3).map(String::as_str), Some("three")); +}
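For orientation, here is a minimal sketch of how the sizing, pre-allocation, and round-trip APIs touched by this patch compose from a consumer crate. It is illustrative only and not part of the patch; `Point` is a hypothetical caller-side type, while `serialized_size`, `to_bytes_with_capacity`, and `from_bytes` are the public functions added or documented above.

```rust
use serde::{Deserialize, Serialize};

#[derive(Serialize, Deserialize, Debug, PartialEq)]
struct Point {
    x: u64,
    y: u64,
}

fn main() -> bcs::Result<()> {
    let p = Point { x: 1, y: 2 };
    // Compute the exact encoded size first, then serialize into a
    // pre-allocated buffer so no reallocation happens mid-serialization.
    let size = bcs::serialized_size(&p)?;
    let bytes = bcs::to_bytes_with_capacity(&p, size)?;
    assert_eq!(bytes.len(), size);
    // Round-trip back through the deserializer.
    let q: Point = bcs::from_bytes(&bytes)?;
    assert_eq!(p, q);
    Ok(())
}
```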