From 91ddb3a29d2b30eb8a0adc7f00ba48f9c07d3472 Mon Sep 17 00:00:00 2001 From: peg Date: Thu, 18 Dec 2025 09:46:51 +0100 Subject: [PATCH] Fix get-tls-cert to accept custom CA and gracefully close connection --- src/lib.rs | 27 +++++++++++++++++++++++++-- src/main.rs | 25 ++++++++++++++++++++----- 2 files changed, 45 insertions(+), 7 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index d0a703d..9884376 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -189,6 +189,8 @@ impl ProxyServer { attestation_generator: AttestationGenerator, attestation_verifier: AttestationVerifier, ) -> Result<(), ProxyError> { + tracing::debug!("proxy-server accepted connection"); + // Do TLS handshake let mut tls_stream = acceptor.accept(inbound).await?; let (_io, connection) = tls_stream.get_ref(); @@ -524,6 +526,8 @@ impl ProxyClient { inbound: TcpStream, requests_tx: mpsc::Sender, ) -> Result<(), ProxyError> { + tracing::debug!("proxy-client accepted connection"); + // Setup http server and handler let http = hyper::server::conn::http1::Builder::new(); let service = service_fn(move |req| { @@ -688,11 +692,28 @@ impl ProxyClient { pub async fn get_tls_cert( server_name: String, attestation_verifier: AttestationVerifier, + remote_certificate: Option>, ) -> Result>, ProxyError> { - let root_store = RootCertStore::from_iter(webpki_roots::TLS_SERVER_ROOTS.iter().cloned()); - let client_config = ClientConfig::builder() + tracing::debug!("Getting remote TLS cert"); + // If a remote CA cert was given, use it as the root store, otherwise use webpki_roots + let root_store = match remote_certificate { + Some(remote_certificate) => { + let mut root_store = RootCertStore::empty(); + root_store.add(remote_certificate)?; + root_store + } + None => RootCertStore::from_iter(webpki_roots::TLS_SERVER_ROOTS.iter().cloned()), + }; + + let mut client_config = ClientConfig::builder() .with_root_certificates(root_store) .with_no_client_auth(); + + client_config.alpn_protocols = SUPPORTED_ALPN_PROTOCOL_VERSIONS + .into_iter() + .map(|p| p.to_vec()) + .collect(); + get_tls_cert_with_config(server_name, attestation_verifier, client_config.into()).await } @@ -737,6 +758,8 @@ async fn get_tls_cert_with_config( .verify_attestation(remote_attestation_message, remote_input_data) .await?; + tls_stream.shutdown().await?; + Ok(remote_cert_chain) } diff --git a/src/main.rs b/src/main.rs index d20facc..51792f0 100644 --- a/src/main.rs +++ b/src/main.rs @@ -78,7 +78,7 @@ enum CliCommand { /// The path to a PEM encoded private key #[arg(long, env = "TLS_PRIVATE_KEY_PATH")] tls_private_key_path: PathBuf, - /// The path to a PEM encoded certificate chain + /// Additional CA certificate to verify against (PEM) Defaults to no additional TLS certs. #[arg(long, env = "TLS_CERTIFICATE_PATH")] tls_certificate_path: PathBuf, /// Whether to use client authentication. If the client is running in a CVM this must be @@ -99,6 +99,9 @@ enum CliCommand { GetTlsCert { /// The hostname:port or ip:port of the proxy server (port defaults to 443) server: String, + /// Additional CA certificate to verify against (PEM) Defaults to no additional TLS certs. + #[arg(long)] + tls_ca_certificate: Option, }, /// Serve a filesystem path over an attested channel AttestedFileServer { @@ -114,7 +117,7 @@ enum CliCommand { /// The path to a PEM encoded private key #[arg(long, env = "TLS_PRIVATE_KEY_PATH")] tls_private_key_path: PathBuf, - /// The path to a PEM encoded certificate chain + /// Additional CA certificate to verify against (PEM) Defaults to no additional TLS certs. #[arg(long, env = "TLS_CERTIFICATE_PATH")] tls_certificate_path: PathBuf, /// URL of the remote dummy attestation service. Only use with --server-attestation-type @@ -145,7 +148,7 @@ async fn main() -> anyhow::Result<()> { "Exactly one of --measurements-file or --allowed-remote-attestation-type must be provided" ); - let crate_name = env!("CARGO_PKG_NAME"); + let crate_name = env!("CARGO_CRATE_NAME"); let env_filter = tracing_subscriber::EnvFilter::builder() .with_default_directive(LevelFilter::WARN.into()) // global default @@ -281,8 +284,20 @@ async fn main() -> anyhow::Result<()> { } } } - CliCommand::GetTlsCert { server } => { - let cert_chain = get_tls_cert(server, attestation_verifier).await?; + CliCommand::GetTlsCert { + server, + tls_ca_certificate, + } => { + let remote_tls_cert = match tls_ca_certificate { + Some(remote_cert_filename) => Some( + load_certs_pem(remote_cert_filename)? + .first() + .ok_or(anyhow!("Filename given but no ceritificates found"))? + .clone(), + ), + None => None, + }; + let cert_chain = get_tls_cert(server, attestation_verifier, remote_tls_cert).await?; println!("{}", certs_to_pem_string(&cert_chain)?); } CliCommand::AttestedFileServer {