aboutsummaryrefslogtreecommitdiffstats
path: root/components/net/connector.rs
diff options
context:
space:
mode:
Diffstat (limited to 'components/net/connector.rs')
-rw-r--r--components/net/connector.rs114
1 files changed, 110 insertions, 4 deletions
diff --git a/components/net/connector.rs b/components/net/connector.rs
index 058a27c47d6..dc44002a85a 100644
--- a/components/net/connector.rs
+++ b/components/net/connector.rs
@@ -8,8 +8,13 @@ use hyper::client::HttpConnector as HyperHttpConnector;
use hyper::rt::Future;
use hyper::{Body, Client};
use hyper_openssl::HttpsConnector;
-use openssl::ssl::{SslConnector, SslConnectorBuilder, SslMethod, SslOptions};
-use openssl::x509;
+use openssl::ex_data::Index;
+use openssl::ssl::{
+ Ssl, SslConnector, SslConnectorBuilder, SslContext, SslMethod, SslOptions, SslVerifyMode,
+};
+use openssl::x509::{self, X509StoreContext};
+use std::collections::hash_map::{Entry, HashMap};
+use std::sync::{Arc, Mutex};
use tokio::prelude::future::Executor;
pub const BUF_SIZE: usize = 32768;
@@ -30,6 +35,38 @@ const SIGNATURE_ALGORITHMS: &'static str = concat!(
"RSA+SHA512:RSA+SHA384:RSA+SHA256"
);
+#[derive(Clone)]
+pub struct ConnectionCerts {
+ certs: Arc<Mutex<HashMap<String, (Vec<u8>, u32)>>>,
+}
+
+impl ConnectionCerts {
+ pub fn new() -> Self {
+ Self {
+ certs: Arc::new(Mutex::new(HashMap::new())),
+ }
+ }
+
+ fn store(&self, host: String, cert_bytes: Vec<u8>) {
+ let mut certs = self.certs.lock().unwrap();
+ let entry = certs.entry(host).or_insert((cert_bytes, 0));
+ entry.1 += 1;
+ }
+
+ pub(crate) fn remove(&self, host: String) -> Option<Vec<u8>> {
+ match self.certs.lock().unwrap().entry(host) {
+ Entry::Vacant(_) => return None,
+ Entry::Occupied(mut e) => {
+ e.get_mut().1 -= 1;
+ if e.get().1 == 0 {
+ return Some((e.remove_entry().1).0);
+ }
+ Some(e.get().0.clone())
+ },
+ }
+ }
+}
+
pub struct HttpConnector {
inner: HyperHttpConnector,
}
@@ -60,7 +97,34 @@ impl Connect for HttpConnector {
pub type Connector = HttpsConnector<HttpConnector>;
pub type TlsConfig = SslConnectorBuilder;
-pub fn create_tls_config(certs: &str, alpn: &[u8]) -> TlsConfig {
+#[derive(Clone)]
+pub struct ExtraCerts(Arc<Mutex<Vec<Vec<u8>>>>);
+
+impl ExtraCerts {
+ pub fn new() -> Self {
+ Self(Arc::new(Mutex::new(vec![])))
+ }
+
+ pub fn add(&self, bytes: Vec<u8>) {
+ self.0.lock().unwrap().push(bytes);
+ }
+}
+
+struct Host(String);
+
+lazy_static! {
+ static ref EXTRA_INDEX: Index<SslContext, ExtraCerts> = SslContext::new_ex_index().unwrap();
+ static ref CONNECTION_INDEX: Index<SslContext, ConnectionCerts> =
+ SslContext::new_ex_index().unwrap();
+ static ref HOST_INDEX: Index<Ssl, Host> = Ssl::new_ex_index().unwrap();
+}
+
+pub fn create_tls_config(
+ certs: &str,
+ alpn: &[u8],
+ extra_certs: ExtraCerts,
+ connection_certs: ConnectionCerts,
+) -> TlsConfig {
// certs include multiple certificates. We could add all of them at once,
// but if any of them were already added, openssl would fail to insert all
// of them.
@@ -104,6 +168,44 @@ pub fn create_tls_config(certs: &str, alpn: &[u8]) -> TlsConfig {
SslOptions::NO_COMPRESSION,
);
+ cfg.set_ex_data(*EXTRA_INDEX, extra_certs);
+ cfg.set_ex_data(*CONNECTION_INDEX, connection_certs);
+ cfg.set_verify_callback(SslVerifyMode::PEER, |verified, x509_store_context| {
+ if verified {
+ return true;
+ }
+
+ let ssl_idx = X509StoreContext::ssl_idx().unwrap();
+ let ssl = x509_store_context.ex_data(ssl_idx).unwrap();
+
+ // Obtain the cert bytes for this connection.
+ let cert = match x509_store_context.current_cert() {
+ Some(cert) => cert,
+ None => return false,
+ };
+ let pem = match cert.to_pem() {
+ Ok(pem) => pem,
+ Err(_) => return false,
+ };
+
+ let ssl_context = ssl.ssl_context();
+
+ // Ensure there's an entry stored in the set of known connection certs for this connection.
+ if let Some(host) = ssl.ex_data(*HOST_INDEX) {
+ let connection_certs = ssl_context.ex_data(*CONNECTION_INDEX).unwrap();
+ connection_certs.store((*host).0.clone(), pem.clone());
+ }
+
+ // Fall back to the dynamic set of allowed certs.
+ let extra_certs = ssl_context.ex_data(*EXTRA_INDEX).unwrap();
+ for cert in &*extra_certs.0.lock().unwrap() {
+ if pem == *cert {
+ return true;
+ }
+ }
+ false
+ });
+
cfg
}
@@ -111,7 +213,11 @@ pub fn create_http_client<E>(tls_config: TlsConfig, executor: E) -> Client<Conne
where
E: Executor<Box<dyn Future<Error = (), Item = ()> + Send + 'static>> + Sync + Send + 'static,
{
- let connector = HttpsConnector::with_connector(HttpConnector::new(), tls_config).unwrap();
+ let mut connector = HttpsConnector::with_connector(HttpConnector::new(), tls_config).unwrap();
+ connector.set_callback(|configuration, destination| {
+ configuration.set_ex_data(*HOST_INDEX, Host(destination.host().to_owned()));
+ Ok(())
+ });
Client::builder()
.http1_title_case_headers(true)