aboutsummaryrefslogtreecommitdiffstats
path: root/components/net
diff options
context:
space:
mode:
authorFabrice Desré <fabrice@desre.org>2023-02-25 13:22:47 -0800
committerThe Capyloon Team <fabrice@capyloon.org>2023-05-20 21:55:00 +0000
commit0d0540fc9545ad5233f897a609a4ba48a228bb01 (patch)
treefc6096d70624533a9bc84aa942d97f58d21c1126 /components/net
parentbc8cea2495e928f071875c26b7fba262744a26ea (diff)
downloadservo-0d0540fc9545ad5233f897a609a4ba48a228bb01.tar.gz
servo-0d0540fc9545ad5233f897a609a4ba48a228bb01.zip
Update tungstenite
Diffstat (limited to 'components/net')
-rw-r--r--components/net/Cargo.toml7
-rw-r--r--components/net/websocket_loader.rs102
2 files changed, 46 insertions, 63 deletions
diff --git a/components/net/Cargo.toml b/components/net/Cargo.toml
index 5e9481fa599..773546efa7f 100644
--- a/components/net/Cargo.toml
+++ b/components/net/Cargo.toml
@@ -16,7 +16,7 @@ doctest = false
[dependencies]
async-recursion = "0.3.2"
-async-tungstenite = { version = "0.9", features = ["tokio-openssl"] }
+async-tungstenite = { version = "0.22", features = ["tokio-openssl"] }
base64 = { workspace = true }
brotli = "3"
bytes = "1"
@@ -60,10 +60,9 @@ servo_config = { path = "../config" }
servo_url = { path = "../url" }
sha2 = "0.10"
time = { workspace = true }
-tokio = { version = "1", package = "tokio", features = ["sync", "macros", "rt-multi-thread"] }
-tokio2 = { version = "0.2", package = "tokio", features = ["sync", "macros", "rt-threaded", "tcp"] }
+tokio = { workspace = true, features = ["sync", "macros", "rt-multi-thread"] }
tokio-stream = "0.1"
-tungstenite = "0.11"
+tungstenite = "0.19"
url = { workspace = true }
uuid = { workspace = true }
webrender_api = { git = "https://github.com/servo/webrender" }
diff --git a/components/net/websocket_loader.rs b/components/net/websocket_loader.rs
index a73b48e1f41..cc9016cfa7b 100644
--- a/components/net/websocket_loader.rs
+++ b/components/net/websocket_loader.rs
@@ -22,7 +22,7 @@ use embedder_traits::resources::{self, Resource};
use futures::future::TryFutureExt;
use futures::sink::SinkExt;
use futures::stream::StreamExt;
-use http::header::{HeaderMap, HeaderName, HeaderValue};
+use http::header::{self, HeaderName, HeaderValue};
use ipc_channel::ipc::{IpcReceiver, IpcSender};
use ipc_channel::router::ROUTER;
use net_traits::request::{RequestBuilder, RequestMode};
@@ -33,14 +33,13 @@ use servo_url::ServoUrl;
use std::fs;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, Mutex};
-use tokio2::net::TcpStream;
-use tokio2::runtime::Runtime;
-use tokio2::select;
-use tokio2::sync::mpsc::{unbounded_channel, UnboundedReceiver};
-use tungstenite::error::Error;
+use tokio::net::TcpStream;
+use tokio::runtime::Runtime;
+use tokio::select;
+use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver};
use tungstenite::error::Result as WebSocketResult;
+use tungstenite::error::{Error, ProtocolError, UrlError};
use tungstenite::handshake::client::{Request, Response};
-use tungstenite::http::header::{self as WSHeader, HeaderValue as WSHeaderValue};
use tungstenite::protocol::CloseFrame;
use tungstenite::Message;
use url::Url;
@@ -65,20 +64,32 @@ fn create_request(
) -> WebSocketResult<Request> {
let mut builder = Request::get(resource_url.as_str());
let headers = builder.headers_mut().unwrap();
- headers.insert("Origin", WSHeaderValue::from_str(origin)?);
+ headers.insert("Origin", HeaderValue::from_str(origin)?);
+
+ let origin = resource_url.origin();
+ let host = format!(
+ "{}",
+ origin
+ .host()
+ .ok_or_else(|| Error::Url(UrlError::NoHostName))?
+ );
+ headers.insert("Host", HeaderValue::from_str(&host)?);
+ headers.insert("Connection", HeaderValue::from_static("upgrade"));
+ headers.insert("Upgrade", HeaderValue::from_static("websocket"));
+ headers.insert("Sec-Websocket-Version", HeaderValue::from_static("13"));
+
+ let key = HeaderValue::from_str(&tungstenite::handshake::client::generate_key()).unwrap();
+ headers.insert("Sec-WebSocket-Key", key);
if !protocols.is_empty() {
let protocols = protocols.join(",");
- headers.insert(
- "Sec-WebSocket-Protocol",
- WSHeaderValue::from_str(&protocols)?,
- );
+ headers.insert("Sec-WebSocket-Protocol", HeaderValue::from_str(&protocols)?);
}
let mut cookie_jar = http_state.cookie_jar.write().unwrap();
cookie_jar.remove_expired_cookies_for_url(resource_url);
if let Some(cookie_list) = cookie_jar.cookies_for_url(resource_url, CookieSource::HTTP) {
- headers.insert("Cookie", WSHeaderValue::from_str(&cookie_list)?);
+ headers.insert("Cookie", HeaderValue::from_str(&cookie_list)?);
}
if resource_url.password().is_some() || resource_url.username() != "" {
@@ -89,7 +100,7 @@ fn create_request(
));
headers.insert(
"Authorization",
- WSHeaderValue::from_str(&format!("Basic {}", basic))?,
+ HeaderValue::from_str(&format!("Basic {}", basic))?,
);
}
@@ -110,18 +121,18 @@ fn process_ws_response(
trace!("processing websocket http response for {}", resource_url);
let mut protocol_in_use = None;
if let Some(protocol_name) = response.headers().get("Sec-WebSocket-Protocol") {
- let protocol_name = protocol_name.to_str().unwrap();
+ let protocol_name = protocol_name.to_str().unwrap_or("");
if !protocols.is_empty() && !protocols.iter().any(|p| protocol_name == (*p)) {
- return Err(Error::Protocol(
- "Protocol in use not in client-supplied protocol list".into(),
- ));
+ return Err(Error::Protocol(ProtocolError::InvalidHeader(
+ HeaderName::from_static("sec-websocket-protocol"),
+ )));
}
protocol_in_use = Some(protocol_name.to_string());
}
let mut jar = http_state.cookie_jar.write().unwrap();
// TODO(eijebong): Replace thise once typed headers settled on a cookie impl
- for cookie in response.headers().get_all(WSHeader::SET_COOKIE) {
+ for cookie in response.headers().get_all(header::SET_COOKIE) {
if let Ok(s) = std::str::from_utf8(cookie.as_bytes()) {
if let Some(cookie) =
Cookie::from_cookie_string(s.into(), resource_url, CookieSource::HTTP)
@@ -131,23 +142,11 @@ fn process_ws_response(
}
}
- // We need to make a new header map here because tungstenite depends on
- // a more recent version of http than the rest of the network stack, so the
- // HeaderMap types are incompatible.
- let mut headers = HeaderMap::new();
- for (key, value) in response.headers().iter() {
- if let (Ok(key), Ok(value)) = (
- HeaderName::from_bytes(key.as_ref()),
- HeaderValue::from_bytes(value.as_ref()),
- ) {
- headers.insert(key, value);
- }
- }
http_state
.hsts_list
.write()
.unwrap()
- .update_hsts_list_from_response(resource_url, &headers);
+ .update_hsts_list_from_response(resource_url, &response.headers());
Ok(protocol_in_use)
}
@@ -283,6 +282,10 @@ async fn run_ws_loop(
));
break;
}
+
+ Message::Frame(_) => {
+ warn!("Unexpected websocket frame message");
+ }
}
}
}
@@ -309,20 +312,20 @@ async fn start_websocket(
let host_str = client
.uri()
.host()
- .ok_or_else(|| Error::Url("No host string".into()))?;
+ .ok_or_else(|| Error::Url(UrlError::NoHostName))?;
let host = replace_host(host_str);
- let mut net_url =
- Url::parse(&client.uri().to_string()).map_err(|e| Error::Url(e.to_string().into()))?;
+ let mut net_url = Url::parse(&client.uri().to_string())
+ .map_err(|e| Error::Url(UrlError::UnableToConnect(e.to_string())))?;
net_url
.set_host(Some(&host))
- .map_err(|e| Error::Url(e.to_string().into()))?;
+ .map_err(|e| Error::Url(UrlError::UnableToConnect(e.to_string())))?;
let domain = net_url
.host()
- .ok_or_else(|| Error::Url("No host string".into()))?;
+ .ok_or_else(|| Error::Url(UrlError::NoHostName))?;
let port = net_url
.port_or_known_default()
- .ok_or_else(|| Error::Url("Unknown port".into()))?;
+ .ok_or_else(|| Error::Url(UrlError::UnableToConnect("Unknown port".into())))?;
let try_socket = TcpStream::connect((&*domain.to_string(), port)).await;
let socket = try_socket.map_err(Error::Io)?;
@@ -366,32 +369,13 @@ fn connect(
};
// https://fetch.spec.whatwg.org/#websocket-opening-handshake
- // By standard, we should work with an http(s):// URL (req_url),
- // but as ws-rs expects to be called with a ws(s):// URL (net_url)
- // we upgrade ws to wss, so we don't have to convert http(s) back to ws(s).
http_state
.hsts_list
.read()
.unwrap()
.apply_hsts_rules(&mut req_builder.url);
- let scheme = req_builder.url.scheme();
- let mut req_url = req_builder.url.clone();
- match scheme {
- "ws" => {
- req_url
- .as_mut_url()
- .set_scheme("http")
- .map_err(|()| "couldn't replace scheme".to_string())?;
- },
- "wss" => {
- req_url
- .as_mut_url()
- .set_scheme("https")
- .map_err(|()| "couldn't replace scheme".to_string())?;
- },
- _ => {},
- }
+ let req_url = req_builder.url.clone();
if should_be_blocked_due_to_bad_port(&req_url) {
return Err("Port blocked".to_string());
@@ -403,7 +387,7 @@ fn connect(
};
let client = match create_request(
- &req_builder.url,
+ &req_url,
&req_builder.origin.ascii_serialization(),
&protocols,
&*http_state,