Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 37 additions & 2 deletions backend/src/lib/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@ use futures_util::{stream::SplitStream, SinkExt, StreamExt};
use monad_exec_events::ExecEvent;
use tokio::net::{TcpListener, TcpStream};
use tokio::sync::broadcast;
use tokio_tungstenite::{accept_async, tungstenite::Message, WebSocketStream};
use tokio_tungstenite::{accept_hdr_async, tungstenite::Message, WebSocketStream};
use tokio_tungstenite::tungstenite::handshake::server::{Request as WsRequest, Response as WsResponse, ErrorResponse};
use tracing::{error, info, warn};
use serde::{Deserialize, Serialize};

Expand All @@ -26,6 +27,9 @@ use super::event_filter::EventFilter;
use super::event_listener::EventData;
use super::serializable_event::SerializableEventData;

/// Allowed origin for WebSocket connections (production frontend)
const ALLOWED_ORIGIN: &str = "https://node.monad.xyz";

/// Stores the Unix timestamp (in seconds) of the last event received from the ring
type LastEventTime = Arc<AtomicU64>;

Expand Down Expand Up @@ -320,7 +324,38 @@ async fn handle_connection(
) {
info!("New WebSocket connection from: {}", addr);

let ws_stream = match accept_async(stream).await {
// Callback to validate the Origin header during WebSocket handshake
let origin_callback = |request: &WsRequest, response: WsResponse| -> Result<WsResponse, ErrorResponse> {
let origin = request
.headers()
.get("Origin")
.and_then(|v| v.to_str().ok());

match origin {
Some(o) if o == ALLOWED_ORIGIN => {
info!("Accepted connection from allowed origin: {}", o);
Ok(response)
}
Some(o) => {
warn!("Rejected connection from disallowed origin: {}", o);
let error_response = WsResponse::builder()
.status(hyper::StatusCode::FORBIDDEN)
.body(None)
.unwrap();
Err(error_response)
}
None => {
warn!("Rejected connection with no Origin header from {}", addr);
let error_response = WsResponse::builder()
.status(hyper::StatusCode::FORBIDDEN)
.body(None)
.unwrap();
Err(error_response)
}
}
Comment on lines +334 to +355
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
match origin {
Some(o) if o == ALLOWED_ORIGIN => {
info!("Accepted connection from allowed origin: {}", o);
Ok(response)
}
Some(o) => {
warn!("Rejected connection from disallowed origin: {}", o);
let error_response = WsResponse::builder()
.status(hyper::StatusCode::FORBIDDEN)
.body(None)
.unwrap();
Err(error_response)
}
None => {
warn!("Rejected connection with no Origin header from {}", addr);
let error_response = WsResponse::builder()
.status(hyper::StatusCode::FORBIDDEN)
.body(None)
.unwrap();
Err(error_response)
}
}
match origin {
Some(o) => {
if o == ALLOWED_ORIGIN {
info!("Accepted connection from allowed origin: {}", o);
Ok(response)
} else {
warn!("Rejected connection from disallowed origin: {}", o);
let error_response = WsResponse::builder()
.status(hyper::StatusCode::FORBIDDEN)
.body(None)
.unwrap();
Err(error_response)
}
}
None => {
warn!("Rejected connection with no Origin header from {}", addr);
let error_response = WsResponse::builder()
.status(hyper::StatusCode::FORBIDDEN)
.body(None)
.unwrap();
Err(error_response)
}
}

};

let ws_stream = match accept_hdr_async(stream, origin_callback).await {
Ok(ws) => ws,
Err(e) => {
error!("Error during WebSocket handshake: {}", e);
Expand Down