diff --git a/backend/src/lib/server.rs b/backend/src/lib/server.rs index 8aa9935..c9f500c 100644 --- a/backend/src/lib/server.rs +++ b/backend/src/lib/server.rs @@ -14,7 +14,8 @@ use futures_util::{stream::SplitStream, SinkExt, StreamExt}; use monad_exec_events::ExecEvent; use tokio::net::{TcpListener, TcpStream}; use tokio::sync::broadcast; -use tokio_tungstenite::{accept_async, tungstenite::Message, WebSocketStream}; +use tokio_tungstenite::{accept_hdr_async, tungstenite::Message, WebSocketStream}; +use tokio_tungstenite::tungstenite::handshake::server::{Request as WsRequest, Response as WsResponse, ErrorResponse}; use tracing::{error, info, warn}; use serde::{Deserialize, Serialize}; @@ -26,6 +27,9 @@ use super::event_filter::EventFilter; use super::event_listener::EventData; use super::serializable_event::SerializableEventData; +/// Allowed origin for WebSocket connections (production frontend) +const ALLOWED_ORIGIN: &str = "https://node.monad.xyz"; + /// Stores the Unix timestamp (in seconds) of the last event received from the ring type LastEventTime = Arc; @@ -320,7 +324,38 @@ async fn handle_connection( ) { info!("New WebSocket connection from: {}", addr); - let ws_stream = match accept_async(stream).await { + // Callback to validate the Origin header during WebSocket handshake + let origin_callback = |request: &WsRequest, response: WsResponse| -> Result { + let origin = request + .headers() + .get("Origin") + .and_then(|v| v.to_str().ok()); + + match origin { + Some(o) if o == ALLOWED_ORIGIN => { + info!("Accepted connection from allowed origin: {}", o); + Ok(response) + } + Some(o) => { + warn!("Rejected connection from disallowed origin: {}", o); + let error_response = WsResponse::builder() + .status(hyper::StatusCode::FORBIDDEN) + .body(None) + .unwrap(); + Err(error_response) + } + None => { + warn!("Rejected connection with no Origin header from {}", addr); + let error_response = WsResponse::builder() + .status(hyper::StatusCode::FORBIDDEN) + .body(None) + .unwrap(); + Err(error_response) + } + } + }; + + let ws_stream = match accept_hdr_async(stream, origin_callback).await { Ok(ws) => ws, Err(e) => { error!("Error during WebSocket handshake: {}", e);