diff --git a/Cargo.lock b/Cargo.lock index 51fafa85..a4348567 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2241,6 +2241,7 @@ dependencies = [ "default-net", "dstack-guest-agent-rpc", "dstack-types", + "ed25519-dalek", "figment", "fs-err", "git-version", @@ -2251,6 +2252,7 @@ dependencies = [ "load_config", "ra-rpc", "ra-tls", + "rand 0.8.5", "rcgen", "reqwest", "ring", @@ -2265,6 +2267,7 @@ dependencies = [ "strip-ansi-escapes", "sysinfo", "tdx-attest", + "tempfile", "tokio", "tracing", "tracing-subscriber", @@ -2574,6 +2577,31 @@ dependencies = [ "spki", ] +[[package]] +name = "ed25519" +version = "2.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "115531babc129696a58c64a4fef0a8bf9e9698629fb97e9e40767d235cfbcd53" +dependencies = [ + "pkcs8", + "signature", +] + +[[package]] +name = "ed25519-dalek" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "70e796c081cee67dc755e1a36a0a172b897fab85fc3f6bc48307991f64e4eca9" +dependencies = [ + "curve25519-dalek", + "ed25519", + "rand_core 0.6.4", + "serde", + "sha2 0.10.9", + "subtle", + "zeroize", +] + [[package]] name = "either" version = "1.15.0" diff --git a/Cargo.toml b/Cargo.toml index c0283de3..918afc71 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -171,6 +171,7 @@ blake2 = "0.10.6" tokio-rustls = { version = "0.26.2", features = ["ring"] } x25519-dalek = { version = "2.0.1", features = ["static_secrets"] } k256 = "0.13.4" +ed25519-dalek = { version = "2.2.0", features = ["rand_core"] } # Additional RustCrypto dependencies for sealed box xsalsa20poly1305 = "0.9.0" salsa20 = "0.10" diff --git a/guest-agent/Cargo.toml b/guest-agent/Cargo.toml index d17ec3b5..101c63d7 100644 --- a/guest-agent/Cargo.toml +++ b/guest-agent/Cargo.toml @@ -47,3 +47,6 @@ sha3.workspace = true strip-ansi-escapes.workspace = true cert-client.workspace = true ring.workspace = true +ed25519-dalek.workspace = true +tempfile.workspace = true +rand.workspace = true diff --git a/guest-agent/rpc/proto/agent_rpc.proto b/guest-agent/rpc/proto/agent_rpc.proto index 15606dbc..4d728dc3 100644 --- a/guest-agent/rpc/proto/agent_rpc.proto +++ b/guest-agent/rpc/proto/agent_rpc.proto @@ -37,7 +37,7 @@ service DstackGuest { // Returns the derived key along with its TLS certificate chain. rpc GetTlsKey(GetTlsKeyArgs) returns (GetTlsKeyResponse) {} - // Derives a new ECDSA key with k256 EC curve. + // Derives a new key. rpc GetKey(GetKeyArgs) returns (GetKeyResponse) {} // Generates a TDX quote with given report data. @@ -48,6 +48,12 @@ service DstackGuest { // Get app info rpc Info(google.protobuf.Empty) returns (AppInfo) {} + + // Sign a payload + rpc Sign(SignRequest) returns (SignResponse) {} + + // Verify a signature + rpc Verify(VerifyRequest) returns (VerifyResponse) {} } // The request to derive a key @@ -91,12 +97,14 @@ message GetTlsKeyResponse { repeated string certificate_chain = 2; } -// The request to derive a new ECDSA key with k256 EC curve +// The request to derive a new key message GetKeyArgs { // Path to the key to derive string path = 1; // Purpose of the key string purpose = 2; + // Algorithm of the key. Either `secp256k1` or `ed25519`. Defaults to `secp256k1` + string algorithm = 3; } // The response to a DeriveK256Key request @@ -109,9 +117,11 @@ message DeriveK256KeyResponse { // The response to a GetEthKey request message GetKeyResponse { - // Derived k256 key + // Derived key bytes key = 1; - // Derived k256 signature chain + // The signature chain consists of the following signatures: + // [0] - the k256 signature of the derived pK signed by the app root key + // [1] - the k256 signature of the app root pK signed by the KMS root key repeated bytes signature_chain = 2; } @@ -216,4 +226,38 @@ service Worker { rpc Info(google.protobuf.Empty) returns (AppInfo) {} // Get the guest agent version rpc Version(google.protobuf.Empty) returns (WorkerVersion) {} + // Get attestation + rpc GetAttestationForAppKey(GetAttestationForAppKeyRequest) returns (GetQuoteResponse) {} +} + +message SignRequest { + string algorithm = 1; + bytes data = 2; +} + +message SignResponse { + // the signature of the data + bytes signature = 1; + // The signature chain consists of the following signatures: + // [0] - the signature of the data + // [1] - the k256 signature of the message signing pubkey signed by the app root key + // [2] - the k256 signature of the app root pubkey signed by the KMS root key + repeated bytes signature_chain = 2; + // The public key signing the data + bytes public_key = 3; +} + +message VerifyRequest { + string algorithm = 1; + bytes data = 2; + bytes signature = 3; + bytes public_key = 4; +} + +message VerifyResponse { + bool valid = 1; +} + +message GetAttestationForAppKeyRequest { + string algorithm = 1; } diff --git a/guest-agent/src/rpc_service.rs b/guest-agent/src/rpc_service.rs index b4b53923..30b1f6d7 100644 --- a/guest-agent/src/rpc_service.rs +++ b/guest-agent/src/rpc_service.rs @@ -5,16 +5,22 @@ use std::sync::{Arc, RwLock}; use anyhow::{Context, Result}; +use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine as _}; use cert_client::CertRequestClient; use dstack_guest_agent_rpc::{ dstack_guest_server::{DstackGuestRpc, DstackGuestServer}, tappd_server::{TappdRpc, TappdServer}, worker_server::{WorkerRpc, WorkerServer}, - AppInfo, DeriveK256KeyResponse, DeriveKeyArgs, EmitEventArgs, GetKeyArgs, GetKeyResponse, - GetQuoteResponse, GetTlsKeyArgs, GetTlsKeyResponse, RawQuoteArgs, TdxQuoteArgs, - TdxQuoteResponse, WorkerVersion, + AppInfo, DeriveK256KeyResponse, DeriveKeyArgs, EmitEventArgs, GetAttestationForAppKeyRequest, + GetKeyArgs, GetKeyResponse, GetQuoteResponse, GetTlsKeyArgs, GetTlsKeyResponse, RawQuoteArgs, + SignRequest, SignResponse, TdxQuoteArgs, TdxQuoteResponse, VerifyRequest, VerifyResponse, + WorkerVersion, }; use dstack_types::{AppKeys, SysConfig}; +use ed25519_dalek::ed25519::signature::hazmat::{PrehashSigner, PrehashVerifier}; +use ed25519_dalek::{ + Signer as Ed25519Signer, SigningKey as Ed25519SigningKey, Verifier as Ed25519Verifier, +}; use fs_err as fs; use k256::ecdsa::SigningKey; use ra_rpc::{Attestation, CallContext, RpcCall}; @@ -210,16 +216,33 @@ impl DstackGuestRpc for InternalRpcHandler { async fn get_key(self, request: GetKeyArgs) -> Result { let k256_app_key = &self.state.inner.keys.k256_key; - let derived_k256_key = derive_ecdsa_key(k256_app_key, &[request.path.as_bytes()], 32) - .context("Failed to derive k256 key")?; - let derived_k256_key = - SigningKey::from_slice(&derived_k256_key).context("Failed to parse k256 key")?; - let derived_k256_pubkey = derived_k256_key.verifying_key(); - let msg_to_sign = format!( - "{}:{}", - request.purpose, - hex::encode(derived_k256_pubkey.to_sec1_bytes()) - ); + + let (key, pubkey_hex) = match request.algorithm.as_str() { + "ed25519" => { + let derived_key = derive_ecdsa_key(k256_app_key, &[request.path.as_bytes()], 32) + .context("Failed to derive ed25519 key")?; + let signing_key = Ed25519SigningKey::from_bytes( + &derived_key + .as_slice() + .try_into() + .or(Err(anyhow::anyhow!("Invalid key length")))?, + ); + let pubkey_hex = hex::encode(signing_key.verifying_key().as_bytes()); + (derived_key, pubkey_hex) + } + "secp256k1" | "secp256k1_prehashed" | "" => { + let derived_key = derive_ecdsa_key(k256_app_key, &[request.path.as_bytes()], 32) + .context("Failed to derive k256 key")?; + + let signing_key = + SigningKey::from_slice(&derived_key).context("Failed to parse k256 key")?; + let pubkey_hex = hex::encode(signing_key.verifying_key().to_sec1_bytes()); + (derived_key, pubkey_hex) + } + _ => return Err(anyhow::anyhow!("Unsupported algorithm")), + }; + + let msg_to_sign = format!("{}:{}", request.purpose, pubkey_hex); let app_signing_key = SigningKey::from_slice(k256_app_key).context("Failed to parse app k256 key")?; let digest = Keccak256::new_with_prefix(msg_to_sign); @@ -228,7 +251,7 @@ impl DstackGuestRpc for InternalRpcHandler { signature.push(recid.to_byte()); Ok(GetKeyResponse { - key: derived_k256_key.to_bytes().to_vec(), + key, signature_chain: vec![signature, self.state.inner.keys.k256_signature.clone()], }) } @@ -274,6 +297,83 @@ impl DstackGuestRpc for InternalRpcHandler { async fn info(self) -> Result { get_info(&self.state, false).await } + + async fn sign(self, request: SignRequest) -> Result { + let key_response = self + .get_key(GetKeyArgs { + path: "vms".to_string(), + purpose: "signing".to_string(), + algorithm: request.algorithm.clone(), + }) + .await?; + let (signature, public_key) = match request.algorithm.as_str() { + "ed25519" => { + let key_bytes: [u8; 32] = key_response.key.try_into().expect("Key is incorrect"); + let signing_key = Ed25519SigningKey::from_bytes(&key_bytes); + let signature = signing_key.sign(&request.data); + let public_key = signing_key.verifying_key().to_bytes().to_vec(); + (signature.to_bytes().to_vec(), public_key) + } + "secp256k1" => { + let signing_key = SigningKey::from_slice(&key_response.key) + .context("Failed to parse secp256k1 key")?; + let signature: k256::ecdsa::Signature = signing_key.sign(&request.data); + let public_key = signing_key.verifying_key().to_sec1_bytes().to_vec(); + (signature.to_bytes().to_vec(), public_key) + } + "secp256k1_prehashed" => { + if request.data.len() != 32 { + return Err(anyhow::anyhow!( + "Pre-hashed signing requires a 32-byte digest, but received {} bytes", + request.data.len() + )); + } + let signing_key = SigningKey::from_slice(&key_response.key) + .context("Failed to parse secp256k1 key")?; + let signature: k256::ecdsa::Signature = signing_key.sign_prehash(&request.data)?; + let public_key = signing_key.verifying_key().to_sec1_bytes().to_vec(); + (signature.to_bytes().to_vec(), public_key) + } + _ => return Err(anyhow::anyhow!("Unsupported algorithm")), + }; + Ok(SignResponse { + signature: signature.clone(), + signature_chain: vec![ + signature, + key_response.signature_chain[0].clone(), + key_response.signature_chain[1].clone(), + ], + public_key, + }) + } + + async fn verify(self, request: VerifyRequest) -> Result { + let valid = match request.algorithm.as_str() { + "ed25519" => { + let verifying_key = ed25519_dalek::VerifyingKey::from_bytes( + &request.public_key.as_slice().try_into().unwrap(), + )?; + let signature = ed25519_dalek::Signature::from_slice(&request.signature)?; + verifying_key.verify(&request.data, &signature).is_ok() + } + "secp256k1" => { + let verifying_key = + k256::ecdsa::VerifyingKey::from_sec1_bytes(&request.public_key)?; + let signature = k256::ecdsa::Signature::from_slice(&request.signature)?; + verifying_key.verify(&request.data, &signature).is_ok() + } + "secp256k1_prehashed" => { + let verifying_key = + k256::ecdsa::VerifyingKey::from_sec1_bytes(&request.public_key)?; + let signature = k256::ecdsa::Signature::from_slice(&request.signature)?; + verifying_key + .verify_prehash(&request.data, &signature) + .is_ok() + } + _ => return Err(anyhow::anyhow!("Unsupported algorithm")), + }; + Ok(VerifyResponse { valid }) + } } fn simulate_quote( @@ -445,6 +545,90 @@ impl WorkerRpc for ExternalRpcHandler { rev: super::GIT_REV.to_string(), }) } + + async fn get_attestation_for_app_key( + self, + request: GetAttestationForAppKeyRequest, + ) -> Result { + let key_response = InternalRpcHandler { + state: self.state.clone(), + } + .get_key(GetKeyArgs { + path: "vms".to_string(), + purpose: "signing".to_string(), + algorithm: request.algorithm.clone(), + }) + .await?; + + match request.algorithm.as_str() { + "ed25519" => { + let key_bytes: [u8; 32] = key_response.key.try_into().expect("Key is incorrect"); + let ed25519_key = Ed25519SigningKey::from_bytes(&key_bytes); + let ed25519_pubkey = ed25519_key.verifying_key().to_bytes(); + + let mut ed25519_report_data = [0u8; 64]; + let ed25519_b64 = URL_SAFE_NO_PAD.encode(ed25519_pubkey); + let ed25519_report_string = format!("dip1::ed25519-pk:{}", ed25519_b64); + let ed_bytes = ed25519_report_string.as_bytes(); + ed25519_report_data[..ed_bytes.len()].copy_from_slice(ed_bytes); + + if self.state.config().simulator.enabled { + Ok(simulate_quote( + self.state.config(), + ed25519_report_data, + &self.state.inner.vm_config, + )?) + } else { + let ed25519_quote = tdx_attest::get_quote(&ed25519_report_data, None) + .context("Failed to get ed25519 quote")? + .1; + let event_log = serde_json::to_string( + &read_event_logs().context("Failed to read event log")?, + )?; + Ok(GetQuoteResponse { + quote: ed25519_quote, + event_log: event_log.clone(), + report_data: ed25519_report_data.to_vec(), + vm_config: self.state.inner.vm_config.clone(), + }) + } + } + "secp256k1" | "secp256k1_prehashed" => { + let secp256k1_key = SigningKey::from_slice(&key_response.key) + .context("Failed to parse secp256k1 key")?; + let secp256k1_pubkey = secp256k1_key.verifying_key().to_sec1_bytes(); + + let mut secp256k1_report_data = [0u8; 64]; + let secp256k1_b64 = URL_SAFE_NO_PAD.encode(secp256k1_pubkey); + let secp256k1_report_string = format!("dip1::secp256k1c-pk:{}", secp256k1_b64); + let secp_bytes = secp256k1_report_string.as_bytes(); + secp256k1_report_data[..secp_bytes.len()].copy_from_slice(secp_bytes); + + if self.state.config().simulator.enabled { + Ok(simulate_quote( + self.state.config(), + secp256k1_report_data, + &self.state.inner.vm_config, + )?) + } else { + let secp256k1_quote = tdx_attest::get_quote(&secp256k1_report_data, None) + .context("Failed to get secp256k1 quote")? + .1; + let event_log = serde_json::to_string( + &read_event_logs().context("Failed to read event log")?, + )?; + + Ok(GetQuoteResponse { + quote: secp256k1_quote, + event_log, + report_data: secp256k1_report_data.to_vec(), + vm_config: self.state.inner.vm_config.clone(), + }) + } + } + _ => Err(anyhow::anyhow!("Unsupported algorithm")), + } + } } impl RpcCall for ExternalRpcHandler { @@ -456,3 +640,409 @@ impl RpcCall for ExternalRpcHandler { }) } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::config::{AppComposeWrapper, Config, Simulator}; + use dstack_guest_agent_rpc::{GetAttestationForAppKeyRequest, SignRequest}; + use dstack_types::{AppCompose, AppKeys, KeyProvider}; + use ed25519_dalek::ed25519::signature::hazmat::PrehashVerifier; + use ed25519_dalek::{ + Signature as Ed25519Signature, Verifier, VerifyingKey as Ed25519VerifyingKey, + }; + use k256::ecdsa::{Signature as K256Signature, VerifyingKey}; + use sha2::Sha256; + use std::collections::HashSet; + use std::convert::TryFrom; + use std::io::Write; + use tempfile; + + fn extract_pubkey_from_report_data(report_data: &[u8], prefix: &str) -> Result> { + let end = report_data + .iter() + .position(|&b| b == 0) + .unwrap_or(report_data.len()); + let report_str = std::str::from_utf8(&report_data[..end])?; + + if let Some(base64_pk) = report_str.strip_prefix(prefix) { + URL_SAFE_NO_PAD + .decode(base64_pk) + .context("Failed to decode base64") + } else { + Err(anyhow::anyhow!("Prefix not found in report data")) + } + } + + async fn setup_test_state() -> (AppState, tempfile::NamedTempFile, tempfile::NamedTempFile) { + let mut dummy_quote_file = tempfile::NamedTempFile::new().unwrap(); + let dummy_event_log_file = tempfile::NamedTempFile::new().unwrap(); + + let dummy_quote = vec![b'0'; 10020]; + dummy_quote_file.write_all(&dummy_quote).unwrap(); + dummy_quote_file.flush().unwrap(); + + let dummy_simulator = Simulator { + enabled: true, + quote_file: dummy_quote_file.path().to_str().unwrap().to_string(), + event_log_file: dummy_event_log_file.path().to_str().unwrap().to_string(), + }; + + let dummy_appcompose = AppCompose { + manifest_version: 0, + name: String::new(), + features: Vec::new(), + runner: String::new(), + docker_compose_file: None, + public_logs: false, + public_sysinfo: false, + public_tcbinfo: false, + kms_enabled: false, + gateway_enabled: false, + local_key_provider_enabled: false, + key_provider: None, + key_provider_id: Vec::new(), + allowed_envs: Vec::new(), + no_instance_id: false, + secure_time: false, + storage_fs: None, + swap_size: 0, + }; + + let dummy_appcompose_wrapper = AppComposeWrapper { + app_compose: dummy_appcompose, + raw: String::new(), + }; + + let dummy_config = Config { + keys_file: String::new(), + app_compose: dummy_appcompose_wrapper, + sys_config_file: String::new().into(), + pccs_url: None, + simulator: dummy_simulator, + data_disks: HashSet::new(), + }; + + const DUMMY_PEM_KEY: &str = r#"-----BEGIN PRIVATE KEY----- +MIIEvgIBADANBgkqhkiG9w0BAQEFAASCBKgwggSkAgEAAoIBAQCSeV81CKVqILf/ +bk+OarAkZeph4ggb1d9Qt4bzJjVNsowpc/iWbacO6dHvrjXrqNdK7WEHDuxYlQCS +xppINUCKyCoelAt2OJuUonLHtT3s41pGM0k69fcUb420fhKqNAHIaCCc38vOFDZ7 +aqLUGNDooc7bXgZxHUJHmq9QneeB74Ia+6TzA2KKXMu4ixvZWvrgRt64XKyL3+4J +sQ6QqSgopGeyTv0blxFxF6X8UTUO/nZPnqf7BN9GnkJtHglb0TLI1H7BYvFmnpjT +8yfjmdbRxvnczvRJuKCzTq9ePEvhRrwAzqQk3Ide0/KWdIiu2nrrfO/Imvia1DNp +GgJsV0L7AgMBAAECggEARUbTcV1kAwRzkgOF7CloouZzCxWhWSz4AJC06oadOmDi +qu53WgqFs2eCjBZ82TdTkFQiiniT7zeV/FWjfdh17M3MIgdKPoF6kDufBvahUcuc +FEzIa3MPB+LVBlOEl2yelT8ugZPVrGPh+tBOL/uGvyhckmNvr4szoHM4TOxKJSk/ +njFbJcoX3UmampyxSa6MMSGaxM2pdziTujoj5+sJ/a0x/wwIih/XEZSWgLzDjGZS +qaKmldjD0SRJQrZ1LTjjguKtkbOwKa2dtNOoHBkAtHyI+vWOLXNzZisXMazpmHNT +mE2X6oQFcAXI7HHuHzkLaLpEdqlHA16nwFPNF0LzAQKBgQDLaE1eZnutK+nxHpUq +cb3vMGN8dPxCrQJz/fvEb6lP93RCWBZbGen2gLGvFKyFwPcD/OR0HfBnFRjHIy25 +V4ta+iubQM3GFO2FOp9SwequCPY2H6YXah4LyXrCIw4Pv3x/I2bpbLOlltmMT5PS +qPV86dH546kxOsJS6VhMCcQXAQKBgQC4WJu9VTBPfKf8JL8f7b/K0+MBN3OBkhsN +V6nCR8JizAa1hxmxpMaeq7PqlGpJhQKinBblR314Cpqqrt7AL005gCxD0ddBM9Ib +/7HafmLrAuhEDxnYx/QAyprTOsqjLS8Vd+eaA0nGF68R1LLHLxfXfhiuAjMwScCs +afCrbdG1+wKBgAyZ3ZEnkCneOpPxbRRAD6AtwzwGk0oeJbTB20MEF90YW19wzZG/ +PTtEJb3O7hErLyJUHGMFJ8t7BxnvF/oPblaogOMRVK4cxconI4+g68T0USxxMXzp +2gqo5K36NfjLyA6oRsvXLBnqCngixembBfpDEfsFG4otNbSlOA8d28QBAoGBAKdG +YCtxPaEi8BtwDK2gQsR9eCMGeh08wqdcwIG2M8EKeZwGt13mswQPsfZOLhQASd/b +2zq5oDRpCueOPjoNsflXQNNZegWETEdzwaMNxByUSsZXHZED/3koX00EsBNZULwe +TV4HVc4Wd5mqc38iUHQNy78559ENW3QXvXcQ85Y5AoGBAIQlSbNRupo/5ATwJW0e +bggPyacIhS9GrsgP9qz9p8xxNSfcyAFRGiXnlGoiRbNchbUiZPRjoJ08lOHGxVQw +O17ivI85heZnG+i5Yz0ZolMd8fbc4h78oA9FnJQJV5AeTDqTxf528A2jyWCAmu11 +Sv2zO+vcYHN7bT2UTCEWkeAw +-----END PRIVATE KEY----- +"#; + + const DUMMY_PEM_CERT: &str = r#"-----BEGIN CERTIFICATE----- +MIIDCTCCAfGgAwIBAgIUYRX7SNHsL6EGSy0ACQzjX4cfaw0wDQYJKoZIhvcNAQEL +BQAwFDESMBAGA1UEAwwJbG9jYWxob3N0MB4XDTI1MTAwOTEyNDMyN1oXDTI2MTAw +OTEyNDMyN1owFDESMBAGA1UEAwwJbG9jYWxob3N0MIIBIjANBgkqhkiG9w0BAQEF +AAOCAQ8AMIIBCgKCAQEAknlfNQilaiC3/25PjmqwJGXqYeIIG9XfULeG8yY1TbKM +KXP4lm2nDunR764166jXSu1hBw7sWJUAksaaSDVAisgqHpQLdjiblKJyx7U97ONa +RjNJOvX3FG+NtH4SqjQByGggnN/LzhQ2e2qi1BjQ6KHO214GcR1CR5qvUJ3nge+C +Gvuk8wNiilzLuIsb2Vr64EbeuFysi9/uCbEOkKkoKKRnsk79G5cRcRel/FE1Dv52 +T56n+wTfRp5CbR4JW9EyyNR+wWLxZp6Y0/Mn45nW0cb53M70Sbigs06vXjxL4Ua8 +AM6kJNyHXtPylnSIrtp663zvyJr4mtQzaRoCbFdC+wIDAQABo1MwUTAdBgNVHQ4E +FgQUsnBjoCWFH3il0MvjO9p0o/vcACgwHwYDVR0jBBgwFoAUsnBjoCWFH3il0Mvj +O9p0o/vcACgwDwYDVR0TAQH/BAUwAwEB/zANBgkqhkiG9w0BAQsFAAOCAQEAj9rI +cHDTj9LhD2Nca/Mj2dNwUa1Fq81I5EF3GWi6mosTT4hfQupUC1i/6UE6ubLHRUGr +J3JnHBG8hUCddx5VxLncDmYP/4LHVEue/XdCURgY+K2WxQnUPDzZV2mXJXUzp8si +6xzFyiPyf4qsQaoRQnpOmyUXvBwtdf3M28EA/pTBBDZ4pZJ1QaSTlT7fpDgK2e6L +arBh7HebdS9UBaWLtYBMsRWRK5qpOQnLiy8H6J93/W6i4X3DSxeZXeYiMSO/jsJ8 +5XxL9zqOVjsw9Bxr79zCe7JF6fp6r3miUndMHQch/WXOY07lxH00cEqYo+2/Vk5D +pNs85uhOZE8z2jr8Pg== +-----END CERTIFICATE----- +"#; + + const DUMMY_K256_KEY: [u8; 32] = [ + 0x1A, 0x2B, 0x3C, 0x4D, 0x5E, 0x6F, 0x7A, 0x8B, 0x9C, 0x0D, 0x1E, 0x2F, 0x3A, 0x4B, + 0x5C, 0x6D, 0x7E, 0x8F, 0x9A, 0x0B, 0x1C, 0x2D, 0x3E, 0x4F, 0x5A, 0x6B, 0x7C, 0x8D, + 0x9E, 0x0F, 0x1A, 0x2B, + ]; + + let dummy_keys = AppKeys { + disk_crypt_key: Vec::new(), + env_crypt_key: Vec::new(), + k256_key: DUMMY_K256_KEY.to_vec(), + k256_signature: Vec::new(), + gateway_app_id: String::new(), + ca_cert: DUMMY_PEM_CERT.to_string(), + key_provider: KeyProvider::None { + key: DUMMY_PEM_KEY.to_string(), + }, + }; + + let dummy_cert_client = CertRequestClient::create(&dummy_keys, None, String::new()) + .await + .expect("Failed to create CertRequestClient"); + + let inner = AppStateInner { + config: dummy_config, + keys: dummy_keys, + vm_config: String::new(), + cert_client: dummy_cert_client, + demo_cert: RwLock::new(String::new()), + }; + + ( + AppState { + inner: Arc::new(inner), + }, + dummy_quote_file, + dummy_event_log_file, + ) + } + + #[tokio::test] + async fn test_verify_ed25519_success() { + let (state, _quote_file, _log_file) = setup_test_state().await; + let handler = InternalRpcHandler { + state: state.clone(), + }; + let data_to_sign = b"test message for ed25519"; + let sign_request = SignRequest { + algorithm: "ed25519".to_string(), + data: data_to_sign.to_vec(), + }; + + let sign_response = handler.sign(sign_request).await.unwrap(); + + let verify_request = VerifyRequest { + algorithm: "ed25519".to_string(), + data: data_to_sign.to_vec(), + signature: sign_response.signature, + public_key: sign_response.public_key, + }; + let handler = InternalRpcHandler { + state: state.clone(), + }; + let verify_response = handler.verify(verify_request).await.unwrap(); + assert!(verify_response.valid); + } + + #[tokio::test] + async fn test_verify_secp256k1_success() { + let (state, _quote_file, _log_file) = setup_test_state().await; + let handler = InternalRpcHandler { + state: state.clone(), + }; + let data_to_sign = b"test message for secp256k1"; + let sign_request = SignRequest { + algorithm: "secp256k1".to_string(), + data: data_to_sign.to_vec(), + }; + + let sign_response = handler.sign(sign_request).await.unwrap(); + + let verify_request = VerifyRequest { + algorithm: "secp256k1".to_string(), + data: data_to_sign.to_vec(), + signature: sign_response.signature, + public_key: sign_response.public_key, + }; + let handler = InternalRpcHandler { + state: state.clone(), + }; + let verify_response = handler.verify(verify_request).await.unwrap(); + assert!(verify_response.valid); + } + + #[tokio::test] + async fn test_sign_ed25519_success() { + let (state, _quote_file, _log_file) = setup_test_state().await; + let handler = InternalRpcHandler { + state: state.clone(), + }; + let data_to_sign = b"test message for ed25519"; + let request = SignRequest { + algorithm: "ed25519".to_string(), + data: data_to_sign.to_vec(), + }; + + let response = handler.sign(request).await.unwrap(); + + let attestation_response = ExternalRpcHandler::new(state) + .get_attestation_for_app_key(GetAttestationForAppKeyRequest { + algorithm: "ed25519".to_string(), + }) + .await + .unwrap(); + + let pk_bytes = + extract_pubkey_from_report_data(&attestation_response.report_data, "dip1::ed25519-pk:") + .unwrap(); + + let public_key = Ed25519VerifyingKey::try_from(pk_bytes.as_slice()).unwrap(); + let signature = Ed25519Signature::try_from(response.signature.as_slice()).unwrap(); + assert!(public_key.verify(data_to_sign, &signature).is_ok()); + } + + #[tokio::test] + async fn test_sign_secp256k1_success() { + let (state, _quote_file, _log_file) = setup_test_state().await; + let handler = InternalRpcHandler { + state: state.clone(), + }; + let data_to_sign = b"test message for secp256k1"; + let request = SignRequest { + algorithm: "secp256k1".to_string(), + data: data_to_sign.to_vec(), + }; + + let response = handler.sign(request).await.unwrap(); + + let attestation_response = ExternalRpcHandler::new(state) + .get_attestation_for_app_key(GetAttestationForAppKeyRequest { + algorithm: "secp256k1".to_string(), + }) + .await + .unwrap(); + + let pk_bytes = extract_pubkey_from_report_data( + &attestation_response.report_data, + "dip1::secp256k1c-pk:", + ) + .unwrap(); + + let public_key = VerifyingKey::from_sec1_bytes(&pk_bytes).unwrap(); + let signature = K256Signature::try_from(response.signature.as_slice()).unwrap(); + assert!(public_key.verify(data_to_sign, &signature).is_ok()); + } + + #[tokio::test] + async fn test_sign_secp256k1_prehashed_success() { + let (state, _quote_file, _log_file) = setup_test_state().await; + let handler = InternalRpcHandler { + state: state.clone(), + }; + let data_to_sign = b"test message for secp256k1 prehashed"; + + let digest = Sha256::digest(data_to_sign); + + let request = SignRequest { + algorithm: "secp256k1_prehashed".to_string(), + data: digest.to_vec(), + }; + + let response = handler.sign(request).await.unwrap(); + + let attestation_response = ExternalRpcHandler::new(state) + .get_attestation_for_app_key(GetAttestationForAppKeyRequest { + algorithm: "secp256k1".to_string(), + }) + .await + .unwrap(); + + let pk_bytes = extract_pubkey_from_report_data( + &attestation_response.report_data, + "dip1::secp256k1c-pk:", + ) + .unwrap(); + + let public_key = VerifyingKey::from_sec1_bytes(&pk_bytes).unwrap(); + let signature = K256Signature::try_from(response.signature.as_slice()).unwrap(); + assert!(public_key + .verify_prehash(digest.as_slice(), &signature) + .is_ok()); + } + + #[tokio::test] + async fn test_sign_secp256k1_prehashed_invalid_length_fails() { + let (state, _quote_file, _log_file) = setup_test_state().await; + let handler = InternalRpcHandler { + state: state.clone(), + }; + + // digest with an invalid length + let invalid_digest = vec![0; 31]; + + let request = SignRequest { + algorithm: "secp256k1_prehashed".to_string(), + data: invalid_digest, + }; + + let response = handler.sign(request).await; + assert!(response.is_err()); + assert!(response + .unwrap_err() + .to_string() + .contains("requires a 32-byte digest")); + } + + #[tokio::test] + async fn test_sign_unsupported_algorithm_fails() { + let (state, _quote_file, _log_file) = setup_test_state().await; + let handler = InternalRpcHandler { state }; + let request = SignRequest { + algorithm: "rsa".to_string(), // Unsupported algorithm + data: b"test message".to_vec(), + }; + + let result = handler.sign(request).await; + assert!(result.is_err()); + assert_eq!(result.unwrap_err().to_string(), "Unsupported algorithm"); + } + + #[tokio::test] + async fn test_get_attestation_for_app_key_ed25519_success() { + let (state, _quote_file, _log_file) = setup_test_state().await; + let handler = ExternalRpcHandler::new(state.clone()); + let request = GetAttestationForAppKeyRequest { + algorithm: "ed25519".to_string(), + }; + + let response = handler.get_attestation_for_app_key(request).await.unwrap(); + + const EXPECTED_REPORT_DATA: &str = + "dip1::ed25519-pk:5Pbre1Amf1hrp2V2bbfKlIfxpQb2pJAmrgmhxgVoG9s\0\0\0\0"; + assert_eq!(EXPECTED_REPORT_DATA.as_bytes(), response.report_data); + } + + #[tokio::test] + async fn test_get_attestation_for_app_key_secp256k1_success() { + let (state, _quote_file, _log_file) = setup_test_state().await; + let handler = ExternalRpcHandler::new(state.clone()); + let request = GetAttestationForAppKeyRequest { + algorithm: "secp256k1".to_string(), + }; + + let response = handler.get_attestation_for_app_key(request).await.unwrap(); + + const EXPECTED_REPORT_DATA: &str = + "dip1::secp256k1c-pk:A6t_JdVkVdMAocH3f1f20WGT6JzdntxcXimUtEax8zc9"; + assert_eq!(EXPECTED_REPORT_DATA.as_bytes(), response.report_data); + } + + #[tokio::test] + async fn test_get_attestation_for_app_key_unsupported_algorithm_fails() { + let (state, _quote_file, _log_file) = setup_test_state().await; + let handler = ExternalRpcHandler::new(state); + let request = GetAttestationForAppKeyRequest { + algorithm: "ecdsa".to_string(), // Unsupported algorithm + }; + + let result = handler.get_attestation_for_app_key(request).await; + assert!(result.is_err()); + assert_eq!(result.unwrap_err().to_string(), "Unsupported algorithm"); + } +} diff --git a/sdk/curl/api.md b/sdk/curl/api.md index 99cd119b..2a18c393 100644 --- a/sdk/curl/api.md +++ b/sdk/curl/api.md @@ -71,6 +71,7 @@ Generates an ECDSA key using the k256 elliptic curve, derived from the applicati |-------|------|-------------|----------| | `path` | string | Path for the key | `"my/key/path"` | | `purpose` | string | Purpose for the key. Can be any string. This is used in the signature chain. | `"signing"` | `"encryption"` | +| `algorithm` | string | Either `secp256k1` or `ed25519`. Defaults to `secp256k1` | `ed25519` | **Example:** ```bash @@ -79,14 +80,15 @@ curl --unix-socket /var/run/dstack.sock -X POST \ -H 'Content-Type: application/json' \ -d '{ "path": "my/key/path", - "purpose": "signing" + "purpose": "signing", + "algorithm": "ed25519", }' ``` Or ```bash -curl --unix-socket /var/run/dstack.sock http://dstack/GetKey?path=my/key/path&purpose=signing +curl --unix-socket /var/run/dstack.sock http://dstack/GetKey?path=my/key/path&purpose=signing&algorithm=ed25519 ``` **Response:** @@ -191,6 +193,80 @@ curl --unix-socket /var/run/dstack.sock -X POST \ **Response:** Empty response with HTTP 200 status code on success. +### 6. Sign + +Signs a payload. + +**Endpoint:** `/Sign` + +**Request Parameters:** + +| Field | Type | Description | Example | +|-------|------|-------------|----------| +| `algorithm` | string | `ed25519`, `secp256k1_prehashed` or `secp256k1`| `ed25519` | +| `data` | string | Hex-encoded payload data | `deadbeef` | + +**Example:** +```bash +curl --unix-socket /var/run/dstack.sock -X POST \ + http://dstack/Sign \ + -H 'Content-Type: application/json' \ + -d '{ + "algorithm": "ed25519", + "data": "deadbeef" + }' +``` + +**Response:** +```json +{ + "signature": "", + "signature_chain": [ + "", + "", + "" + ] + "public_key": "" +} +``` + +### 7. Verify + +Verifies a signature. + +**Endpoint:** `/Verify` + +**Request Parameters:** + +| Field | Type | Description | Example | +|-------|------|-------------|----------| +| `algorithm` | string | `ed25519`, `secp256k1_prehashed` or `secp256k1`| `ed25519` | +| `data` | string | Hex-encoded payload data | `deadbeef` | +| `signature` | string | Hex-encoded signature | `deadbeef` | +| `public_key` | string | Hex-encoded public key | `deadbeef` | + +**Example:** +```bash +curl --unix-socket /var/run/dstack.sock -X POST \ + http://dstack/Verify \ + -H 'Content-Type: application/json' \ + -d '{ + "algorithm": "ed25519", + "data": "deadbeef", + "signature": "deadbeef", + "public_key": "deadbeef" + }' +``` + +**Response:** +```json +{ + "valid": "" +} +``` + +``` + ## Error Responses All endpoints may return the following HTTP status codes: diff --git a/sdk/go/README.md b/sdk/go/README.md index a8b6a97d..4379d2e6 100644 --- a/sdk/go/README.md +++ b/sdk/go/README.md @@ -91,9 +91,11 @@ NOTE: Leave endpoint empty in production. You only need to add `volumes` in your #### Methods - `Info(ctx context.Context) (*InfoResponse, error)`: Retrieves information about the CVM instance. -- `GetKey(ctx context.Context, path string, purpose string) (*GetKeyResponse, error)`: Derives a key for the given path and purpose. +- `GetKey(ctx context.Context, path string, purpose string, algorithm string) (*GetKeyResponse, error)`: Derives a key for the given path, purpose and algorithm. - `GetQuote(ctx context.Context, reportData []byte) (*GetQuoteResponse, error)`: Generates a TDX quote using SHA512 as the hash algorithm. - `GetTlsKey(ctx context.Context, path string, subject string, altNames []string, usageRaTls bool, usageServerAuth bool, usageClientAuth bool, randomSeed bool) (*GetTlsKeyResponse, error)`: Derives a key for the given path and purpose. +- `Sign(ctx context.Context, algorithm string, data []byte) (*SignResponse, error)`: Signs a payload +- `Verify(ctx context.Context, algorithm string, data []byte, signature []byte, public_key []byte) (*VerifyResponse, error)`: Verifies a payload ## Development diff --git a/sdk/go/dstack/client.go b/sdk/go/dstack/client.go index 6d571429..ebff5724 100644 --- a/sdk/go/dstack/client.go +++ b/sdk/go/dstack/client.go @@ -351,10 +351,11 @@ func (c *DstackClient) GetTlsKey( } // Gets a key from the dstack service. -func (c *DstackClient) GetKey(ctx context.Context, path string, purpose string) (*GetKeyResponse, error) { +func (c *DstackClient) GetKey(ctx context.Context, path string, purpose string, algorithm string) (*GetKeyResponse, error) { payload := map[string]interface{}{ - "path": path, - "purpose": purpose, + "path": path, + "purpose": purpose, + "algorithm": algorithm, } data, err := c.sendRPCRequest(ctx, "/GetKey", payload) @@ -425,6 +426,83 @@ func (c *DstackClient) Info(ctx context.Context) (*InfoResponse, error) { return &response, nil } +type SignResponse struct { + Signature []byte + SignatureChain [][]byte + PublicKey []byte +} + +// Signs a payload. +func (c *DstackClient) Sign(ctx context.Context, algorithm string, data []byte) (*SignResponse, error) { + payload := map[string]interface{}{ + "algorithm": algorithm, + "data": hex.EncodeToString(data), + } + + respData, err := c.sendRPCRequest(ctx, "/Sign", payload) + if err != nil { + return nil, err + } + + var response struct { + Signature string `json:"signature"` + SignatureChain []string `json:"signature_chain"` + PublicKey string `json:"public_key"` + } + if err := json.Unmarshal(respData, &response); err != nil { + return nil, fmt.Errorf("failed to unmarshal sign response: %w", err) + } + + sig, err := hex.DecodeString(response.Signature) + if err != nil { + return nil, fmt.Errorf("failed to decode signature: %w", err) + } + pubKey, err := hex.DecodeString(response.PublicKey) + if err != nil { + return nil, fmt.Errorf("failed to decode public key: %w", err) + } + + sigChain := make([][]byte, len(response.SignatureChain)) + for i, s := range response.SignatureChain { + sigChain[i], err = hex.DecodeString(s) + if err != nil { + return nil, fmt.Errorf("failed to decode signature chain element %d: %w", i, err) + } + } + + return &SignResponse{ + Signature: sig, + SignatureChain: sigChain, + PublicKey: pubKey, + }, nil +} + +type VerifyResponse struct { + Valid bool `json:"valid"` +} + +// Verifies a payload. +func (c *DstackClient) Verify(ctx context.Context, algorithm string, data []byte, signature []byte, publicKey []byte) (*VerifyResponse, error) { + payload := map[string]interface{}{ + "algorithm": algorithm, + "data": hex.EncodeToString(data), + "signature": hex.EncodeToString(signature), + "public_key": hex.EncodeToString(publicKey), + } + + respData, err := c.sendRPCRequest(ctx, "/Verify", payload) + if err != nil { + return nil, err + } + + var response VerifyResponse + if err := json.Unmarshal(respData, &response); err != nil { + return nil, fmt.Errorf("failed to unmarshal verify response: %w", err) + } + + return &response, nil +} + // EmitEvent sends an event to be extended to RTMR3 on TDX platform. // The event will be extended to RTMR3 with the provided name and payload. // diff --git a/sdk/go/dstack/client_test.go b/sdk/go/dstack/client_test.go index 73c5c360..ee8df0ff 100644 --- a/sdk/go/dstack/client_test.go +++ b/sdk/go/dstack/client_test.go @@ -7,6 +7,7 @@ package dstack_test import ( "bytes" "context" + "crypto/sha256" "encoding/hex" "encoding/json" "fmt" @@ -25,7 +26,7 @@ import ( func TestGetKey(t *testing.T) { client := dstack.NewDstackClient() - resp, err := client.GetKey(context.Background(), "/", "test") + resp, err := client.GetKey(context.Background(), "/", "test", "ed25519") if err != nil { t.Fatal(err) } @@ -434,7 +435,7 @@ func TestGetKeySignatureVerification(t *testing.T) { client := dstack.NewDstackClient() path := "/test/path" purpose := "test-purpose" - resp, err := client.GetKey(context.Background(), path, purpose) + resp, err := client.GetKey(context.Background(), path, purpose, "secp256k1") if err != nil { t.Fatal(err) } @@ -608,3 +609,111 @@ func compressPublicKey(uncompressedKey []byte) ([]byte, error) { } return crypto.CompressPubkey(pubKey), nil } + +func TestSignAndVerifyEd25519(t *testing.T) { + client := dstack.NewDstackClient() + dataToSign := []byte("test message for ed25519") + algorithm := "ed25519" + + signResp, err := client.Sign(context.Background(), algorithm, dataToSign) + if err != nil { + t.Fatalf("Sign() error = %v", err) + } + + if len(signResp.Signature) == 0 { + t.Error("expected signature to not be empty") + } + if len(signResp.PublicKey) == 0 { + t.Error("expected public key to not be empty") + } + if len(signResp.SignatureChain) != 3 { + t.Errorf("expected signature chain to have 3 elements, got %d", len(signResp.SignatureChain)) + } + if !bytes.Equal(signResp.Signature, signResp.SignatureChain[0]) { + t.Error("expected Signature to be the same as SignatureChain[0]") + } + + verifyResp, err := client.Verify(context.Background(), algorithm, dataToSign, signResp.Signature, signResp.PublicKey) + if err != nil { + t.Fatalf("Verify() error = %v", err) + } + + if !verifyResp.Valid { + t.Error("expected verification to be valid") + } + + badData := []byte("wrong message") + verifyResp, err = client.Verify(context.Background(), algorithm, badData, signResp.Signature, signResp.PublicKey) + if err != nil { + t.Fatalf("Verify() with bad data error = %v", err) + } + + if verifyResp.Valid { + t.Error("expected verification with bad data to be invalid") + } +} + +func TestSignAndVerifySecp256k1(t *testing.T) { + client := dstack.NewDstackClient() + dataToSign := []byte("test message for secp256k1") + algorithm := "secp256k1" + + signResp, err := client.Sign(context.Background(), algorithm, dataToSign) + if err != nil { + t.Fatalf("Sign() error = %v", err) + } + + if len(signResp.Signature) == 0 { + t.Error("expected signature to not be empty") + } + if len(signResp.PublicKey) == 0 { + t.Error("expected public key to not be empty") + } + if len(signResp.SignatureChain) != 3 { + t.Errorf("expected signature chain to have 3 elements, got %d", len(signResp.SignatureChain)) + } + + verifyResp, err := client.Verify(context.Background(), algorithm, dataToSign, signResp.Signature, signResp.PublicKey) + if err != nil { + t.Fatalf("Verify() error = %v", err) + } + + if !verifyResp.Valid { + t.Error("expected verification to be valid") + } +} + +func TestSignAndVerifySecp256k1Prehashed(t *testing.T) { + client := dstack.NewDstackClient() + dataToSign := []byte("test message for secp256k1 prehashed") + digest := sha256.Sum256(dataToSign) + algorithm := "secp256k1_prehashed" + + signResp, err := client.Sign(context.Background(), algorithm, digest[:]) + if err != nil { + t.Fatalf("Sign() error = %v", err) + } + + if len(signResp.Signature) == 0 { + t.Error("expected signature to not be empty") + } + + verifyResp, err := client.Verify(context.Background(), algorithm, digest[:], signResp.Signature, signResp.PublicKey) + if err != nil { + t.Fatalf("Verify() error = %v", err) + } + + if !verifyResp.Valid { + t.Error("expected verification to be valid") + } + + // Test invalid digest length for signing + invalidDigest := []byte{1, 2, 3} + _, err = client.Sign(context.Background(), algorithm, invalidDigest) + if err == nil { + t.Fatal("expected error for invalid digest length, got nil") + } + if !strings.Contains(err.Error(), "32-byte digest") { + t.Errorf("expected error to mention '32-byte digest', got: %v", err) + } +} diff --git a/sdk/js/src/__tests__/index.test.ts b/sdk/js/src/__tests__/index.test.ts index 2236c138..dea1a1eb 100644 --- a/sdk/js/src/__tests__/index.test.ts +++ b/sdk/js/src/__tests__/index.test.ts @@ -3,6 +3,7 @@ // SPDX-License-Identifier: Apache-2.0 import { expect, describe, it, vi } from 'vitest' +import crypto from 'crypto' // Added for prehashed test import { DstackClient, TappdClient } from '../index' describe('DstackClient', () => { @@ -25,6 +26,18 @@ describe('DstackClient', () => { expect(result).toHaveProperty('signature_chain') }) + it('should able to get key with different algorithms', async () => { + const client = new DstackClient() + const resultSecp = await client.getKey('/secp', 'test', 'secp256k1') + expect(resultSecp.key).toBeInstanceOf(Uint8Array) + expect(resultSecp.key.length).toBe(32) // secp256k1 private key size + + const resultEd = await client.getKey('/ed', 'test', 'ed25519') + expect(resultEd.key).toBeInstanceOf(Uint8Array) + expect(resultEd.key.length).toBe(32) // ed25519 private key size (seed) + }) + + it('should able to request tdx quote', async () => { const client = new DstackClient() // You can put computation result as report data to tdxQuote. NOTE: it should serializable by JSON.stringify @@ -155,6 +168,84 @@ describe('DstackClient', () => { expect(typeof isReachable).toBe('boolean') }) + describe('Sign and Verify Methods', () => { + const client = new DstackClient() + const testData = 'Test message for signing' + const badData = 'This is not the original message' + + it('should sign and verify with ed25519', async () => { + const algorithm = 'ed25519' + const signResp = await client.sign(algorithm, testData) + + expect(signResp).toHaveProperty('signature') + expect(signResp).toHaveProperty('signature_chain') + expect(signResp).toHaveProperty('public_key') + expect(signResp.signature).toBeInstanceOf(Uint8Array) + expect(signResp.public_key).toBeInstanceOf(Uint8Array) + expect(signResp.signature_chain.length).toBeGreaterThan(0) // Should have at least the signature itself + expect(signResp.signature_chain[0]).toBeInstanceOf(Uint8Array) + + // Verify success + const verifyResp = await client.verify(algorithm, testData, signResp.signature, signResp.public_key) + expect(verifyResp).toHaveProperty('valid', true) + + // Verify failure (bad data) + const verifyRespBadData = await client.verify(algorithm, badData, signResp.signature, signResp.public_key) + expect(verifyRespBadData).toHaveProperty('valid', false) + }) + + it('should sign and verify with secp256k1', async () => { + const algorithm = 'secp256k1' + const signResp = await client.sign(algorithm, testData) + + expect(signResp.signature).toBeInstanceOf(Uint8Array) + expect(signResp.public_key).toBeInstanceOf(Uint8Array) + expect(signResp.signature_chain.length).toBeGreaterThan(0) + + // Verify success + const verifyResp = await client.verify(algorithm, testData, signResp.signature, signResp.public_key) + expect(verifyResp).toHaveProperty('valid', true) + + // Verify failure (bad data) + const verifyRespBadData = await client.verify(algorithm, badData, signResp.signature, signResp.public_key) + expect(verifyRespBadData).toHaveProperty('valid', false) + }) + + it('should sign and verify with secp256k1_prehashed', async () => { + const algorithm = 'secp256k1_prehashed' + const digest = crypto.createHash('sha256').update(testData).digest() + expect(digest.length).toBe(32) // Ensure it's 32 bytes + + const signResp = await client.sign(algorithm, digest) + + expect(signResp.signature).toBeInstanceOf(Uint8Array) + expect(signResp.public_key).toBeInstanceOf(Uint8Array) + + // Verify success + const verifyResp = await client.verify(algorithm, digest, signResp.signature, signResp.public_key) + expect(verifyResp).toHaveProperty('valid', true) + + // Verify failure (bad digest) + const badDigest = crypto.createHash('sha256').update(badData).digest() + const verifyRespBadData = await client.verify(algorithm, badDigest, signResp.signature, signResp.public_key) + expect(verifyRespBadData).toHaveProperty('valid', false) + }) + + it('should throw error when signing secp256k1_prehashed with incorrect data length', async () => { + const algorithm = 'secp256k1_prehashed' + const invalidData = 'This is not 32 bytes' + await expect(() => client.sign(algorithm, invalidData)).rejects.toThrow('Pre-hashed signing requires a 32-byte digest') + + const invalidBuffer = Buffer.alloc(31) // Not 32 bytes + await expect(() => client.sign(algorithm, invalidBuffer)).rejects.toThrow('Pre-hashed signing requires a 32-byte digest') + }) + + it('should throw error for unsupported sign algorithm', async () => { + const algorithm = 'rsa' + await expect(() => client.sign(algorithm, testData)).rejects.toThrow() // Specific error depends on server impl. + }) + }) + describe('deprecated methods with TappdClient', () => { it('should support deprecated deriveKey method with warning', async () => { const client = new TappdClient() diff --git a/sdk/js/src/index.ts b/sdk/js/src/index.ts index 218305a7..6e8e3206 100644 --- a/sdk/js/src/index.ts +++ b/sdk/js/src/index.ts @@ -24,6 +24,21 @@ export interface GetKeyResponse { signature_chain: Uint8Array[] } +export interface SignResponse { + __name__: Readonly<'SignResponse'> + + signature: Uint8Array + signature_chain: Uint8Array[] + public_key: Uint8Array +} + +export interface VerifyResponse { + __name__: Readonly<'VerifyResponse'> + + valid: boolean +} + + export type Hex = `${string}` export type TdxQuoteHashAlgorithms = @@ -166,10 +181,11 @@ export class DstackClient { this.endpoint = endpoint } - async getKey(path: string, purpose: string = ''): Promise { + async getKey(path: string, purpose: string = '', algorithm: string = 'secp256k1'): Promise { const payload = JSON.stringify({ path: path, - purpose: purpose + purpose: purpose, + algorithm: algorithm }) const result = await send_rpc_request<{ key: string, signature_chain: string[] }>(this.endpoint, '/GetKey', payload) return Object.freeze({ @@ -268,6 +284,62 @@ export class DstackClient { ) } + /** + * Signs a payload using a derived key. + * @param algorithm The algorithm to use (e.g., "ed25519", "secp256k1", "secp256k1_prehashed") + * @param data The data to sign. If algorithm is "secp256k1_prehashed", this must be a 32-byte hash. + * @returns A SignResponse containing the signature, signature chain, and public key. + */ + async sign(algorithm: string, data: string | Buffer | Uint8Array): Promise { + const hexData = to_hex(data); + if (algorithm === 'secp256k1_prehashed' && hexData.length !== 64) { + throw new Error(`Pre-hashed signing requires a 32-byte digest, but received ${hexData.length / 2} bytes`); + } + + const payload = JSON.stringify({ + algorithm: algorithm, + data: hexData + }); + + const result = await send_rpc_request<{ signature: string, signature_chain: string[], public_key: string }>(this.endpoint, '/Sign', payload); + + return Object.freeze({ + signature: new Uint8Array(Buffer.from(result.signature, 'hex')), + signature_chain: result.signature_chain.map(sig => new Uint8Array(Buffer.from(sig, 'hex'))), + public_key: new Uint8Array(Buffer.from(result.public_key, 'hex')), + __name__: 'SignResponse', + }); + } + + /** + * Verifies a payload signature. + * @param algorithm The algorithm to use (e.g., "ed25519", "secp256k1", "secp256k1_prehashed") + * @param data The data that was signed. + * @param signature The signature to verify. + * @param publicKey The public key to use for verification. + * @returns A VerifyResponse indicating if the signature is valid. + */ + async verify( + algorithm: string, + data: string | Buffer | Uint8Array, + signature: string | Buffer | Uint8Array, + publicKey: string | Buffer | Uint8Array + ): Promise { + const payload = JSON.stringify({ + algorithm: algorithm, + data: to_hex(data), + signature: to_hex(signature), + public_key: to_hex(publicKey) + }); + + const result = await send_rpc_request<{ valid: boolean }>(this.endpoint, '/Verify', payload); + + return Object.freeze({ + ...result, + __name__: 'VerifyResponse', + }); + } + // // Legacy methods for backward compatibility with a warning to notify users about migrating to new methods. // These methods don't mean fully compatible as past, but we keep them here until next major version. @@ -366,4 +438,4 @@ export class TappdClient extends DstackClient { }) return Object.freeze(result) } -} \ No newline at end of file +} diff --git a/sdk/python/README.md b/sdk/python/README.md index 8b81ef76..8b6d521a 100644 --- a/sdk/python/README.md +++ b/sdk/python/README.md @@ -556,13 +556,14 @@ Retrieves comprehensive information about the TEE instance. - `app_cert`: Application certificate in PEM format - `key_provider_info`: Key management configuration -##### `get_key(path: str | None = None, purpose: str | None = None) -> GetKeyResponse` +##### `get_key(path: str | None = None, purpose: str | None = None, algorithm: str = "secp256k1") -> GetKeyResponse` Derives a deterministic secp256k1/K256 private key for blockchain and Web3 applications. This is the primary method for obtaining cryptographic keys for wallets, signing, and other deterministic key scenarios. **Parameters:** - `path`: Unique identifier for key derivation (e.g., `"wallet/ethereum"`, `"signing/solana"`) - `purpose` (optional): Additional context for key usage (default: `""`) +- `algorithm` (optional): Key algorithm (e.g., "secp256k1", "ed25519"). Defaults to "secp256k1". **Returns:** `GetKeyResponse` - `key`: 32-byte secp256k1 private key as hex string (suitable for Ethereum, Bitcoin, Solana, etc.) @@ -636,6 +637,32 @@ Generates a fresh, random TLS key pair with X.509 certificate for TLS/SSL connec - **RA-TLS Support**: Optional remote attestation extension in certificates - **TEE-Signed**: Certificates signed by TEE-resident Certificate Authority +##### `sign(algorithm: str, data: str | bytes) -> SignResponse` + +Signs data using a derived key. + +**Parameters**: +- `algorithm`: The algorithm to use (e.g., "ed25519", "secp256k1", "secp256k1_prehashed"). +- `data`: The data to sign. If using "secp256k1_prehashed", this must be a 32-byte hash (bytes). + +**Returns**: `SignResponse` +- `signature`: The resulting signature as hex string. +- `signature_chain`: List of hex strings proving key authenticity. +- `public_key`: The public key corresponding to the derived signing key as hex string. + +##### `verify(algorithm: str, data: str | bytes, signature: str | bytes, public_key: str | bytes) -> VerifyResponse` + +Verifies a payload signature. + +**Parameters**: +- `algorithm`: The algorithm used for signing. +- `data`: The original data that was signed. +- `signature`: The signature to verify (hex string or bytes). +- `public_key`: The public key to use for verification (hex string or bytes). + +**Returns**: `VerifyResponse` +- `valid`: A bool indicating if the signature is valid. + ##### `emit_event(event: str, payload: str | bytes) -> None` Extends RTMR3 with a custom event for audit logging. @@ -811,4 +838,4 @@ DSTACK_SIMULATOR_ENDPOINT=/path/to/dstack/sdk/simulator/dstack.sock pdm run pyte ## License -Apache License 2.0 \ No newline at end of file +Apache License 2.0 diff --git a/sdk/python/src/dstack_sdk/__init__.py b/sdk/python/src/dstack_sdk/__init__.py index 8be9664c..2831a17d 100644 --- a/sdk/python/src/dstack_sdk/__init__.py +++ b/sdk/python/src/dstack_sdk/__init__.py @@ -10,8 +10,10 @@ from .dstack_client import GetQuoteResponse from .dstack_client import GetTlsKeyResponse from .dstack_client import InfoResponse +from .dstack_client import SignResponse from .dstack_client import TappdClient from .dstack_client import TcbInfo +from .dstack_client import VerifyResponse from .encrypt_env_vars import EnvVar from .encrypt_env_vars import encrypt_env_vars from .encrypt_env_vars import encrypt_env_vars_sync diff --git a/sdk/python/src/dstack_sdk/dstack_client.py b/sdk/python/src/dstack_sdk/dstack_client.py index aba1d6ca..372c3f55 100644 --- a/sdk/python/src/dstack_sdk/dstack_client.py +++ b/sdk/python/src/dstack_sdk/dstack_client.py @@ -152,6 +152,25 @@ def replay_rtmrs(self) -> Dict[int, str]: return rtmrs +class SignResponse(BaseModel): + signature: str + signature_chain: List[str] + public_key: str + + def decode_signature(self) -> bytes: + return bytes.fromhex(self.signature) + + def decode_signature_chain(self) -> List[bytes]: + return [bytes.fromhex(chain) for chain in self.signature_chain] + + def decode_public_key(self) -> bytes: + return bytes.fromhex(self.public_key) + + +class VerifyResponse(BaseModel): + valid: bool + + class EventLog(BaseModel): imr: int event_type: int @@ -332,9 +351,14 @@ async def get_key( self, path: str | None = None, purpose: str | None = None, + algorithm: str = "secp256k1", ) -> GetKeyResponse: - """Derive a key from the given path and purpose.""" - data: Dict[str, Any] = {"path": path or "", "purpose": purpose or ""} + """Derive a key from the given path, purpose, and algorithm.""" + data: Dict[str, Any] = { + "path": path or "", + "purpose": purpose or "", + "algorithm": algorithm, + } result = await self._send_rpc_request("GetKey", data) return GetKeyResponse(**result) @@ -396,6 +420,40 @@ async def get_tls_key( result = await self._send_rpc_request("GetTlsKey", data) return GetTlsKeyResponse(**result) + async def sign(self, algorithm: str, data: str | bytes) -> SignResponse: + """Signs data using a derived key.""" + data_bytes = data.encode() if isinstance(data, str) else data + if algorithm == "secp256k1_prehashed" and len(data_bytes) != 32: + raise ValueError( + f"Pre-hashed signing requires a 32-byte digest, but received {len(data_bytes)} bytes" + ) + + hex_data = binascii.hexlify(data_bytes).decode() + payload = {"algorithm": algorithm, "data": hex_data} + result = await self._send_rpc_request("Sign", payload) + return SignResponse(**result) + + async def verify( + self, + algorithm: str, + data: str | bytes, + signature: str | bytes, + public_key: str | bytes, + ) -> VerifyResponse: + """Verify a signature.""" + data_bytes = data.encode() if isinstance(data, str) else data + sig_bytes = signature.encode() if isinstance(signature, str) else signature + pk_bytes = public_key.encode() if isinstance(public_key, str) else public_key + + payload = { + "algorithm": algorithm, + "data": binascii.hexlify(data_bytes).decode(), + "signature": binascii.hexlify(sig_bytes).decode(), + "public_key": binascii.hexlify(pk_bytes).decode(), + } + result = await self._send_rpc_request("Verify", payload) + return VerifyResponse(**result) + async def is_reachable(self) -> bool: """Return True if the service responds to a quick health call.""" try: @@ -423,8 +481,9 @@ def get_key( self, path: str | None = None, purpose: str | None = None, + algorithm: str = "secp256k1", ) -> GetKeyResponse: - """Derive a key from the given path and purpose.""" + """Derive a key from the given path, purpose, and algorithm.""" raise NotImplementedError @call_async @@ -461,6 +520,22 @@ def get_tls_key( """Request a TLS key from the service with optional parameters.""" raise NotImplementedError + @call_async + def sign(self, algorithm: str, data: str | bytes) -> SignResponse: + """Signs data using a derived key.""" + raise NotImplementedError + + @call_async + def verify( + self, + algorithm: str, + data: str | bytes, + signature: str | bytes, + public_key: str | bytes, + ) -> VerifyResponse: + """Verify a signature.""" + raise NotImplementedError + @call_async def is_reachable(self) -> bool: """Return True if the service responds to a quick health call.""" diff --git a/sdk/python/tests/test_client.py b/sdk/python/tests/test_client.py index 9f948df7..437acf95 100644 --- a/sdk/python/tests/test_client.py +++ b/sdk/python/tests/test_client.py @@ -2,6 +2,7 @@ # # SPDX-License-Identifier: Apache-2.0 +import hashlib import warnings from evidence_api.tdx.quote import TdxQuote @@ -13,18 +14,28 @@ from dstack_sdk import GetKeyResponse from dstack_sdk import GetQuoteResponse from dstack_sdk import GetTlsKeyResponse +from dstack_sdk import SignResponse from dstack_sdk import TappdClient +from dstack_sdk import VerifyResponse from dstack_sdk.dstack_client import InfoResponse from dstack_sdk.dstack_client import TcbInfo def test_sync_client_get_key(): client = DstackClient() - result = client.get_key() + result = client.get_key() # Test default algorithm (secp256k1) assert isinstance(result, GetKeyResponse) assert isinstance(result.decode_key(), bytes) assert len(result.decode_key()) == 32 + # Test specifying algorithm + result_ed = client.get_key(algorithm="ed25519") + assert isinstance(result_ed, GetKeyResponse) + assert len(result_ed.decode_key()) == 32 + + with pytest.raises(Exception): # Assuming unsupported algo raises error + client.get_key(algorithm="rsa") + def test_sync_client_get_quote(): client = DstackClient() @@ -66,8 +77,18 @@ def check_info_response(result: InfoResponse): @pytest.mark.asyncio async def test_async_client_get_key(): client = AsyncDstackClient() - result = await client.get_key() + result = await client.get_key() # Test default algorithm (secp256k1) assert isinstance(result, GetKeyResponse) + assert isinstance(result.decode_key(), bytes) + assert len(result.decode_key()) == 32 + + # Test specifying algorithm + result_ed = await client.get_key(algorithm="ed25519") + assert isinstance(result_ed, GetKeyResponse) + assert len(result_ed.decode_key()) == 32 + + with pytest.raises(Exception): # Assuming unsupported algo raises error + await client.get_key(algorithm="rsa") @pytest.mark.asyncio @@ -257,6 +278,157 @@ def test_emit_event_validation(): assert "event name cannot be empty" in str(exc_info.value) +SIGN_TEST_DATA = b"Test message for signing" +SIGN_BAD_DATA = b"This is not the original message" + + +def test_sync_sign_verify_ed25519(): + client = DstackClient() + algo = "ed25519" + sign_resp = client.sign(algo, SIGN_TEST_DATA) + assert isinstance(sign_resp, SignResponse) + assert len(sign_resp.decode_signature()) > 0 + assert len(sign_resp.decode_public_key()) > 0 + assert len(sign_resp.signature_chain) > 0 + + verify_resp = client.verify( + algo, + SIGN_TEST_DATA, + sign_resp.decode_signature(), + sign_resp.decode_public_key(), + ) + assert isinstance(verify_resp, VerifyResponse) + assert verify_resp.valid is True + + verify_bad = client.verify( + algo, SIGN_BAD_DATA, sign_resp.decode_signature(), sign_resp.decode_public_key() + ) + assert verify_bad.valid is False + + +def test_sync_sign_verify_secp256k1(): + client = DstackClient() + algo = "secp256k1" + sign_resp = client.sign(algo, SIGN_TEST_DATA) + assert isinstance(sign_resp, SignResponse) + + verify_resp = client.verify( + algo, + SIGN_TEST_DATA, + sign_resp.decode_signature(), + sign_resp.decode_public_key(), + ) + assert verify_resp.valid is True + + verify_bad = client.verify( + algo, SIGN_BAD_DATA, sign_resp.decode_signature(), sign_resp.decode_public_key() + ) + assert verify_bad.valid is False + + +def test_sync_sign_verify_secp256k1_prehashed(): + client = DstackClient() + algo = "secp256k1_prehashed" + digest = hashlib.sha256(SIGN_TEST_DATA).digest() + assert len(digest) == 32 + + sign_resp = client.sign(algo, digest) + assert isinstance(sign_resp, SignResponse) + + verify_resp = client.verify( + algo, digest, sign_resp.decode_signature(), sign_resp.decode_public_key() + ) + assert verify_resp.valid is True + + bad_digest = hashlib.sha256(SIGN_BAD_DATA).digest() + verify_bad = client.verify( + algo, bad_digest, sign_resp.decode_signature(), sign_resp.decode_public_key() + ) + assert verify_bad.valid is False + + +def test_sync_sign_prehashed_length_error(): + client = DstackClient() + algo = "secp256k1_prehashed" + with pytest.raises(ValueError) as excinfo: + client.sign(algo, b"too short") + assert "32-byte digest" in str(excinfo.value) + + +@pytest.mark.asyncio +async def test_async_sign_verify_ed25519(): + client = AsyncDstackClient() + algo = "ed25519" + sign_resp = await client.sign(algo, SIGN_TEST_DATA) + assert isinstance(sign_resp, SignResponse) + assert len(sign_resp.decode_signature()) > 0 + assert len(sign_resp.decode_public_key()) > 0 + + verify_resp = await client.verify( + algo, + SIGN_TEST_DATA, + sign_resp.decode_signature(), + sign_resp.decode_public_key(), + ) + assert verify_resp.valid is True + + verify_bad = await client.verify( + algo, SIGN_BAD_DATA, sign_resp.decode_signature(), sign_resp.decode_public_key() + ) + assert verify_bad.valid is False + + +@pytest.mark.asyncio +async def test_async_sign_verify_secp256k1(): + client = AsyncDstackClient() + algo = "secp256k1" + sign_resp = await client.sign(algo, SIGN_TEST_DATA) + assert isinstance(sign_resp, SignResponse) + + verify_resp = await client.verify( + algo, + SIGN_TEST_DATA, + sign_resp.decode_signature(), + sign_resp.decode_public_key(), + ) + assert verify_resp.valid is True + + verify_bad = await client.verify( + algo, SIGN_BAD_DATA, sign_resp.decode_signature(), sign_resp.decode_public_key() + ) + assert verify_bad.valid is False + + +@pytest.mark.asyncio +async def test_async_sign_verify_secp256k1_prehashed(): + client = AsyncDstackClient() + algo = "secp256k1_prehashed" + digest = hashlib.sha256(SIGN_TEST_DATA).digest() + + sign_resp = await client.sign(algo, digest) + assert isinstance(sign_resp, SignResponse) + + verify_resp = await client.verify( + algo, digest, sign_resp.decode_signature(), sign_resp.decode_public_key() + ) + assert verify_resp.valid is True + + bad_digest = hashlib.sha256(SIGN_BAD_DATA).digest() + verify_bad = await client.verify( + algo, bad_digest, sign_resp.decode_signature(), sign_resp.decode_public_key() + ) + assert verify_bad.valid is False + + +@pytest.mark.asyncio +async def test_async_sign_prehashed_length_error(): + client = AsyncDstackClient() + algo = "secp256k1_prehashed" + with pytest.raises(ValueError) as excinfo: + await client.sign(algo, b"too short") + assert "32-byte digest" in str(excinfo.value) + + # Test deprecated TappdClient def test_tappd_client_deprecated(): """Test that TappdClient shows deprecation warning.""" diff --git a/sdk/rust/README.md b/sdk/rust/README.md index 339363d1..2ae47ad1 100644 --- a/sdk/rust/README.md +++ b/sdk/rust/README.md @@ -120,6 +120,12 @@ Sends an event log with associated binary payload to the runtime. #### `get_tls_key(...) -> GetTlsKeyResponse` Requests a key and X.509 certificate chain for RA-TLS or server/client authentication. +#### sign(algorithm: &str, data: Vec) -> SignResponse +Signs a payload using a derived key. + +#### verify(algorithm: &str, data: Vec, signature: Vec, public_key: Vec) -> VerifyResponse +Verifies a payload signature. + ### TappdClient Methods (Legacy API) #### `info(): TappdInfoResponse` @@ -147,6 +153,10 @@ Generates a TDX quote with exactly 64 bytes of raw report data. - `InfoResponse`: CVM instance metadata, including image and runtime measurements +- `SignResponse`: Holds a signature, signature chain, and public key + +- `VerifyResponse`: Holds a boolean valid result + ## API Reference ### Running the Simulator diff --git a/sdk/rust/examples/dstack_client_usage.rs b/sdk/rust/examples/dstack_client_usage.rs index ffe7e71a..722dbf2d 100644 --- a/sdk/rust/examples/dstack_client_usage.rs +++ b/sdk/rust/examples/dstack_client_usage.rs @@ -115,5 +115,24 @@ async fn main() -> anyhow::Result<()> { ); } + let data_to_sign = b"my secret message".to_vec(); + let algorithm = "secp256k1"; + println!("Signing data with algorithm '{}'...", algorithm); + let sign_resp = client.sign(algorithm, data_to_sign.clone()).await?; + println!(" Signature: {}", sign_resp.signature); + println!(" Public Key: {}", sign_resp.public_key); + + let sig_bytes = sign_resp.decode_signature()?; + let pub_key_bytes = sign_resp.decode_public_key()?; + + let verify_resp = client + .verify( + algorithm, + data_to_sign.clone(), + sig_bytes.clone(), + pub_key_bytes.clone(), + ) + .await?; + println!(" Verification successful: {}", verify_resp.valid); Ok(()) } diff --git a/sdk/rust/src/dstack_client.rs b/sdk/rust/src/dstack_client.rs index e2eaeba4..40a5242a 100644 --- a/sdk/rust/src/dstack_client.rs +++ b/sdk/rust/src/dstack_client.rs @@ -14,6 +14,21 @@ use std::env; pub use dstack_sdk_types::dstack::*; +// Internal request structs for hex encoding +#[derive(Debug, Serialize)] +struct SignRequest<'a> { + algorithm: &'a str, + data: String, +} + +#[derive(Debug, Serialize)] +struct VerifyRequest<'a> { + algorithm: &'a str, + data: String, + signature: String, + public_key: String, +} + fn get_endpoint(endpoint: Option<&str>) -> String { if let Some(e) = endpoint { return e.to_string(); @@ -106,6 +121,7 @@ impl DstackClient { let data = json!({ "path": path.unwrap_or_default(), "purpose": purpose.unwrap_or_default(), + "algorithm": "secp256k1", // Default or specify as needed }); let response = self.send_rpc_request("/GetKey", &data).await?; let response = serde_json::from_value::(response)?; @@ -146,4 +162,34 @@ impl DstackClient { Ok(response) } + + /// Signs a payload using a derived key. + pub async fn sign(&self, algorithm: &str, data: Vec) -> Result { + let payload = SignRequest { + algorithm, + data: hex_encode(data), + }; + let response = self.send_rpc_request("/Sign", &payload).await?; + let response = serde_json::from_value::(response)?; + Ok(response) + } + + /// Verifies a payload signature. + pub async fn verify( + &self, + algorithm: &str, + data: Vec, + signature: Vec, + public_key: Vec, + ) -> Result { + let payload = VerifyRequest { + algorithm, + data: hex_encode(data), + signature: hex_encode(signature), + public_key: hex_encode(public_key), + }; + let response = self.send_rpc_request("/Verify", &payload).await?; + let response = serde_json::from_value::(response)?; + Ok(response) + } } diff --git a/sdk/rust/tests/test_client.rs b/sdk/rust/tests/test_client.rs index 84b1407d..e7be67e0 100644 --- a/sdk/rust/tests/test_client.rs +++ b/sdk/rust/tests/test_client.rs @@ -7,6 +7,7 @@ use dcap_qvl::quote::Quote; use dstack_sdk::dstack_client::DstackClient as AsyncDstackClient; +use sha2::{Digest, Sha256}; #[tokio::test] async fn test_async_client_get_key() { @@ -95,3 +96,71 @@ async fn test_info() { assert!(!info.key_provider_info.is_empty()); assert!(!info.compose_hash.is_empty()); } + +#[tokio::test] +async fn test_async_client_sign_and_verify_ed25519() { + let client = AsyncDstackClient::new(None); + let data_to_sign = b"test message for ed25519".to_vec(); + let algorithm = "ed25519"; + + let sign_resp = client.sign(algorithm, data_to_sign.clone()).await.unwrap(); + assert!(!sign_resp.signature.is_empty()); + assert!(!sign_resp.public_key.is_empty()); + assert_eq!(sign_resp.signature_chain.len(), 3); + + let sig = sign_resp.decode_signature().unwrap(); + let pub_key = sign_resp.decode_public_key().unwrap(); + + let verify_resp = client + .verify( + algorithm, + data_to_sign.clone(), + sig.clone(), + pub_key.clone(), + ) + .await + .unwrap(); + assert!(verify_resp.valid); + + let bad_data = b"wrong message".to_vec(); + let verify_resp_bad = client + .verify(algorithm, bad_data, sig, pub_key) + .await + .unwrap(); + assert!(!verify_resp_bad.valid); +} + +#[tokio::test] +async fn test_async_client_sign_and_verify_secp256k1() { + let client = AsyncDstackClient::new(None); + let data_to_sign = b"test message for secp256k1".to_vec(); + let algorithm = "secp256k1"; + + let sign_resp = client.sign(algorithm, data_to_sign.clone()).await.unwrap(); + let sig = sign_resp.decode_signature().unwrap(); + let pub_key = sign_resp.decode_public_key().unwrap(); + + let verify_resp = client + .verify(algorithm, data_to_sign, sig, pub_key) + .await + .unwrap(); + assert!(verify_resp.valid); +} + +#[tokio::test] +async fn test_async_client_sign_and_verify_secp256k1_prehashed() { + let client = AsyncDstackClient::new(None); + let data_to_sign = b"test message for secp256k1 prehashed"; + let digest = Sha256::digest(data_to_sign).to_vec(); + let algorithm = "secp256k1_prehashed"; + + let sign_resp = client.sign(algorithm, digest.clone()).await.unwrap(); + let sig = sign_resp.decode_signature().unwrap(); + let pub_key = sign_resp.decode_public_key().unwrap(); + + let verify_resp = client + .verify(algorithm, digest.clone(), sig, pub_key) + .await + .unwrap(); + assert!(verify_resp.valid); +} diff --git a/sdk/rust/types/src/dstack.rs b/sdk/rust/types/src/dstack.rs index f27dc575..ba151fbc 100644 --- a/sdk/rust/types/src/dstack.rs +++ b/sdk/rust/types/src/dstack.rs @@ -224,3 +224,42 @@ pub struct GetTlsKeyResponse { /// The chain of certificates pub certificate_chain: Vec, } + +/// Response from a Sign request +#[derive(Debug, Serialize, Deserialize)] +#[cfg_attr(feature = "borsh", derive(BorshSerialize, BorshDeserialize))] +#[cfg_attr(feature = "borsh_schema", derive(BorshSchema))] +pub struct SignResponse { + /// The signature in hexadecimal format + pub signature: String, + /// The chain of signatures in hexadecimal format + pub signature_chain: Vec, + /// The public key in hexadecimal format + pub public_key: String, +} + +impl SignResponse { + /// Decodes the signature from hex to bytes + pub fn decode_signature(&self) -> Result, FromHexError> { + hex::decode(&self.signature) + } + + /// Decodes the public key from hex to bytes + pub fn decode_public_key(&self) -> Result, FromHexError> { + hex::decode(&self.public_key) + } + + /// Decodes the signature chain from hex to bytes + pub fn decode_signature_chain(&self) -> Result>, FromHexError> { + self.signature_chain.iter().map(hex::decode).collect() + } +} + +/// Response from a Verify request +#[derive(Debug, Serialize, Deserialize)] +#[cfg_attr(feature = "borsh", derive(BorshSerialize, BorshDeserialize))] +#[cfg_attr(feature = "borsh_schema", derive(BorshSchema))] +pub struct VerifyResponse { + /// Whether the signature is valid + pub valid: bool, +}