diff options
Diffstat (limited to 'components/net/connector.rs')
-rw-r--r-- | components/net/connector.rs | 114 |
1 files changed, 110 insertions, 4 deletions
diff --git a/components/net/connector.rs b/components/net/connector.rs index 058a27c47d6..dc44002a85a 100644 --- a/components/net/connector.rs +++ b/components/net/connector.rs @@ -8,8 +8,13 @@ use hyper::client::HttpConnector as HyperHttpConnector; use hyper::rt::Future; use hyper::{Body, Client}; use hyper_openssl::HttpsConnector; -use openssl::ssl::{SslConnector, SslConnectorBuilder, SslMethod, SslOptions}; -use openssl::x509; +use openssl::ex_data::Index; +use openssl::ssl::{ + Ssl, SslConnector, SslConnectorBuilder, SslContext, SslMethod, SslOptions, SslVerifyMode, +}; +use openssl::x509::{self, X509StoreContext}; +use std::collections::hash_map::{Entry, HashMap}; +use std::sync::{Arc, Mutex}; use tokio::prelude::future::Executor; pub const BUF_SIZE: usize = 32768; @@ -30,6 +35,38 @@ const SIGNATURE_ALGORITHMS: &'static str = concat!( "RSA+SHA512:RSA+SHA384:RSA+SHA256" ); +#[derive(Clone)] +pub struct ConnectionCerts { + certs: Arc<Mutex<HashMap<String, (Vec<u8>, u32)>>>, +} + +impl ConnectionCerts { + pub fn new() -> Self { + Self { + certs: Arc::new(Mutex::new(HashMap::new())), + } + } + + fn store(&self, host: String, cert_bytes: Vec<u8>) { + let mut certs = self.certs.lock().unwrap(); + let entry = certs.entry(host).or_insert((cert_bytes, 0)); + entry.1 += 1; + } + + pub(crate) fn remove(&self, host: String) -> Option<Vec<u8>> { + match self.certs.lock().unwrap().entry(host) { + Entry::Vacant(_) => return None, + Entry::Occupied(mut e) => { + e.get_mut().1 -= 1; + if e.get().1 == 0 { + return Some((e.remove_entry().1).0); + } + Some(e.get().0.clone()) + }, + } + } +} + pub struct HttpConnector { inner: HyperHttpConnector, } @@ -60,7 +97,34 @@ impl Connect for HttpConnector { pub type Connector = HttpsConnector<HttpConnector>; pub type TlsConfig = SslConnectorBuilder; -pub fn create_tls_config(certs: &str, alpn: &[u8]) -> TlsConfig { +#[derive(Clone)] +pub struct ExtraCerts(Arc<Mutex<Vec<Vec<u8>>>>); + +impl ExtraCerts { + pub fn new() -> Self { + Self(Arc::new(Mutex::new(vec![]))) + } + + pub fn add(&self, bytes: Vec<u8>) { + self.0.lock().unwrap().push(bytes); + } +} + +struct Host(String); + +lazy_static! { + static ref EXTRA_INDEX: Index<SslContext, ExtraCerts> = SslContext::new_ex_index().unwrap(); + static ref CONNECTION_INDEX: Index<SslContext, ConnectionCerts> = + SslContext::new_ex_index().unwrap(); + static ref HOST_INDEX: Index<Ssl, Host> = Ssl::new_ex_index().unwrap(); +} + +pub fn create_tls_config( + certs: &str, + alpn: &[u8], + extra_certs: ExtraCerts, + connection_certs: ConnectionCerts, +) -> TlsConfig { // certs include multiple certificates. We could add all of them at once, // but if any of them were already added, openssl would fail to insert all // of them. @@ -104,6 +168,44 @@ pub fn create_tls_config(certs: &str, alpn: &[u8]) -> TlsConfig { SslOptions::NO_COMPRESSION, ); + cfg.set_ex_data(*EXTRA_INDEX, extra_certs); + cfg.set_ex_data(*CONNECTION_INDEX, connection_certs); + cfg.set_verify_callback(SslVerifyMode::PEER, |verified, x509_store_context| { + if verified { + return true; + } + + let ssl_idx = X509StoreContext::ssl_idx().unwrap(); + let ssl = x509_store_context.ex_data(ssl_idx).unwrap(); + + // Obtain the cert bytes for this connection. + let cert = match x509_store_context.current_cert() { + Some(cert) => cert, + None => return false, + }; + let pem = match cert.to_pem() { + Ok(pem) => pem, + Err(_) => return false, + }; + + let ssl_context = ssl.ssl_context(); + + // Ensure there's an entry stored in the set of known connection certs for this connection. + if let Some(host) = ssl.ex_data(*HOST_INDEX) { + let connection_certs = ssl_context.ex_data(*CONNECTION_INDEX).unwrap(); + connection_certs.store((*host).0.clone(), pem.clone()); + } + + // Fall back to the dynamic set of allowed certs. + let extra_certs = ssl_context.ex_data(*EXTRA_INDEX).unwrap(); + for cert in &*extra_certs.0.lock().unwrap() { + if pem == *cert { + return true; + } + } + false + }); + cfg } @@ -111,7 +213,11 @@ pub fn create_http_client<E>(tls_config: TlsConfig, executor: E) -> Client<Conne where E: Executor<Box<dyn Future<Error = (), Item = ()> + Send + 'static>> + Sync + Send + 'static, { - let connector = HttpsConnector::with_connector(HttpConnector::new(), tls_config).unwrap(); + let mut connector = HttpsConnector::with_connector(HttpConnector::new(), tls_config).unwrap(); + connector.set_callback(|configuration, destination| { + configuration.set_ex_data(*HOST_INDEX, Host(destination.host().to_owned())); + Ok(()) + }); Client::builder() .http1_title_case_headers(true) |