Skip to content

Commit ae560f0

Browse files
authored
Merge pull request #70 from flashbots/peg/get-tls-cert-given-trust-root
Fix get-tls-cert to accept custom CA and gracefully close connection
2 parents 06aafe4 + 91ddb3a commit ae560f0

File tree

2 files changed

+45
-7
lines changed

2 files changed

+45
-7
lines changed

src/lib.rs

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,8 @@ impl ProxyServer {
189189
attestation_generator: AttestationGenerator,
190190
attestation_verifier: AttestationVerifier,
191191
) -> Result<(), ProxyError> {
192+
tracing::debug!("proxy-server accepted connection");
193+
192194
// Do TLS handshake
193195
let mut tls_stream = acceptor.accept(inbound).await?;
194196
let (_io, connection) = tls_stream.get_ref();
@@ -524,6 +526,8 @@ impl ProxyClient {
524526
inbound: TcpStream,
525527
requests_tx: mpsc::Sender<RequestWithResponseSender>,
526528
) -> Result<(), ProxyError> {
529+
tracing::debug!("proxy-client accepted connection");
530+
527531
// Setup http server and handler
528532
let http = hyper::server::conn::http1::Builder::new();
529533
let service = service_fn(move |req| {
@@ -688,11 +692,28 @@ impl ProxyClient {
688692
pub async fn get_tls_cert(
689693
server_name: String,
690694
attestation_verifier: AttestationVerifier,
695+
remote_certificate: Option<CertificateDer<'_>>,
691696
) -> Result<Vec<CertificateDer<'static>>, ProxyError> {
692-
let root_store = RootCertStore::from_iter(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
693-
let client_config = ClientConfig::builder()
697+
tracing::debug!("Getting remote TLS cert");
698+
// If a remote CA cert was given, use it as the root store, otherwise use webpki_roots
699+
let root_store = match remote_certificate {
700+
Some(remote_certificate) => {
701+
let mut root_store = RootCertStore::empty();
702+
root_store.add(remote_certificate)?;
703+
root_store
704+
}
705+
None => RootCertStore::from_iter(webpki_roots::TLS_SERVER_ROOTS.iter().cloned()),
706+
};
707+
708+
let mut client_config = ClientConfig::builder()
694709
.with_root_certificates(root_store)
695710
.with_no_client_auth();
711+
712+
client_config.alpn_protocols = SUPPORTED_ALPN_PROTOCOL_VERSIONS
713+
.into_iter()
714+
.map(|p| p.to_vec())
715+
.collect();
716+
696717
get_tls_cert_with_config(server_name, attestation_verifier, client_config.into()).await
697718
}
698719

@@ -737,6 +758,8 @@ async fn get_tls_cert_with_config(
737758
.verify_attestation(remote_attestation_message, remote_input_data)
738759
.await?;
739760

761+
tls_stream.shutdown().await?;
762+
740763
Ok(remote_cert_chain)
741764
}
742765

src/main.rs

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ enum CliCommand {
7878
/// The path to a PEM encoded private key
7979
#[arg(long, env = "TLS_PRIVATE_KEY_PATH")]
8080
tls_private_key_path: PathBuf,
81-
/// The path to a PEM encoded certificate chain
81+
/// Additional CA certificate to verify against (PEM) Defaults to no additional TLS certs.
8282
#[arg(long, env = "TLS_CERTIFICATE_PATH")]
8383
tls_certificate_path: PathBuf,
8484
/// Whether to use client authentication. If the client is running in a CVM this must be
@@ -99,6 +99,9 @@ enum CliCommand {
9999
GetTlsCert {
100100
/// The hostname:port or ip:port of the proxy server (port defaults to 443)
101101
server: String,
102+
/// Additional CA certificate to verify against (PEM) Defaults to no additional TLS certs.
103+
#[arg(long)]
104+
tls_ca_certificate: Option<PathBuf>,
102105
},
103106
/// Serve a filesystem path over an attested channel
104107
AttestedFileServer {
@@ -114,7 +117,7 @@ enum CliCommand {
114117
/// The path to a PEM encoded private key
115118
#[arg(long, env = "TLS_PRIVATE_KEY_PATH")]
116119
tls_private_key_path: PathBuf,
117-
/// The path to a PEM encoded certificate chain
120+
/// Additional CA certificate to verify against (PEM) Defaults to no additional TLS certs.
118121
#[arg(long, env = "TLS_CERTIFICATE_PATH")]
119122
tls_certificate_path: PathBuf,
120123
/// URL of the remote dummy attestation service. Only use with --server-attestation-type
@@ -145,7 +148,7 @@ async fn main() -> anyhow::Result<()> {
145148
"Exactly one of --measurements-file or --allowed-remote-attestation-type must be provided"
146149
);
147150

148-
let crate_name = env!("CARGO_PKG_NAME");
151+
let crate_name = env!("CARGO_CRATE_NAME");
149152

150153
let env_filter = tracing_subscriber::EnvFilter::builder()
151154
.with_default_directive(LevelFilter::WARN.into()) // global default
@@ -281,8 +284,20 @@ async fn main() -> anyhow::Result<()> {
281284
}
282285
}
283286
}
284-
CliCommand::GetTlsCert { server } => {
285-
let cert_chain = get_tls_cert(server, attestation_verifier).await?;
287+
CliCommand::GetTlsCert {
288+
server,
289+
tls_ca_certificate,
290+
} => {
291+
let remote_tls_cert = match tls_ca_certificate {
292+
Some(remote_cert_filename) => Some(
293+
load_certs_pem(remote_cert_filename)?
294+
.first()
295+
.ok_or(anyhow!("Filename given but no ceritificates found"))?
296+
.clone(),
297+
),
298+
None => None,
299+
};
300+
let cert_chain = get_tls_cert(server, attestation_verifier, remote_tls_cert).await?;
286301
println!("{}", certs_to_pem_string(&cert_chain)?);
287302
}
288303
CliCommand::AttestedFileServer {

0 commit comments

Comments
 (0)