Skip to content

Commit

Permalink
Merge pull request #60 from polyphony-chat/gateway
Browse files Browse the repository at this point in the history
Working heartbeats and bugfixes
  • Loading branch information
bitfl0wer authored Oct 21, 2024
2 parents 20a8711 + a919632 commit 092e1e9
Show file tree
Hide file tree
Showing 9 changed files with 100 additions and 112 deletions.
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ DATABASE_PORT=[Postgres port, usually 5432]
DATABASE_USERNAME=[Your Postgres username]
DATABASE_PASSWORD=[Your Postgres password]
DATABASE_NAME=[Your Postgres database name]
API_BIND=[ip:port to bind the HTTP API server to. Defaults to 0.0.0.0:3001 if not set]
GATEWAY_BIND=[ip:port to bind the Gateway server to. Defaults to 0.0.0.0:3003 if not set]
```

4. Install the sqlx CLI with `cargo install sqlx-cli`
Expand Down
7 changes: 6 additions & 1 deletion src/api/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
* file, You can obtain one at https://mozilla.org/MPL/2.0/.
*/

static DEFAULT_API_BIND: &str = "0.0.0.0:3001";

use poem::{
listener::TcpListener,
middleware::{NormalizePath, TrailingSlash},
Expand Down Expand Up @@ -86,7 +88,10 @@ pub async fn start_api(
.with(NormalizePath::new(TrailingSlash::Trim))
.catch_all_error(custom_error);

let bind = std::env::var("API_BIND").unwrap_or_else(|_| String::from("localhost:3001"));
let bind = &std::env::var("API_BIND").unwrap_or_else(|_| {
log::warn!(target: "symfonia::db", "You did not specify API_BIND environment variable. Defaulting to '{DEFAULT_API_BIND}'.");
DEFAULT_API_BIND.to_string()
});
let bind_clone = bind.clone();

log::info!(target: "symfonia::api", "Starting HTTP Server");
Expand Down
10 changes: 5 additions & 5 deletions src/database/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,25 +22,25 @@ static DEFAULT_CONNECTION_PORT: u16 = 5432;

pub async fn establish_connection() -> Result<sqlx::PgPool, Error> {
let db_url = std::env::var("DATABASE_HOST").unwrap_or_else(|_| {
log::warn!(target: "symfonia::db", "You did not specify `DATABASE_HOST` environment variable, defaulting to '{DEFAULT_CONNECTION_HOST}'.");
log::warn!(target: "symfonia::db", "You did not specify DATABASE_HOST environment variable, defaulting to '{DEFAULT_CONNECTION_HOST}'.");
DEFAULT_CONNECTION_HOST.to_string()
});
let connect_options = PgConnectOptions::new()
.host(&db_url)
.port(std::env::var("DATABASE_PORT").unwrap_or_else(|_| {
log::warn!(target: "symfonia::db", "You did not specify `DATABASE_PORT` environment variable. Defaulting to '{DEFAULT_CONNECTION_PORT}'.");
log::warn!(target: "symfonia::db", "You did not specify DATABASE_PORT environment variable. Defaulting to '{DEFAULT_CONNECTION_PORT}'.");
DEFAULT_CONNECTION_PORT.to_string()
}).parse::<u16>().expect("DATABASE_PORT must be a valid 16 bit unsigned integer."))
.username(&std::env::var("DATABASE_USERNAME").unwrap_or_else(|_| {
log::warn!(target: "symfonia::db", "You did not specify `DATABASE_USERNAME` environment variable. Defaulting to '{DEFAULT_CONNECTION_USERNAME}'.");
log::warn!(target: "symfonia::db", "You did not specify DATABASE_USERNAME environment variable. Defaulting to '{DEFAULT_CONNECTION_USERNAME}'.");
DEFAULT_CONNECTION_USERNAME.to_string()
}))
.password(&std::env::var("DATABASE_PASSWORD").unwrap_or_else(|_| {
log::warn!(target: "symfonia::db", "You did not specify `DATABASE_PASSWORD` environment variable. Defaulting to '{DEFAULT_CONNECTION_PASSWORD}'.");
log::warn!(target: "symfonia::db", "You did not specify DATABASE_PASSWORD environment variable. Defaulting to '{DEFAULT_CONNECTION_PASSWORD}'.");
DEFAULT_CONNECTION_PASSWORD.to_string()
}))
.database(&std::env::var("DATABASE_NAME").unwrap_or_else(|_| {
log::warn!(target: "symfonia::db", "You did not specify `DATABASE_NAME` environment variable. Defaulting to '{DEFAULT_CONNECTION_NAME}'.");
log::warn!(target: "symfonia::db", "You did not specify DATABASE_NAME environment variable. Defaulting to '{DEFAULT_CONNECTION_NAME}'.");
DEFAULT_CONNECTION_NAME.to_string()
}));
let pool = PgPool::connect_with(connect_options).await?;
Expand Down
20 changes: 8 additions & 12 deletions src/gateway/establish_connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,6 @@ struct State {
config: Config,
connected_users: ConnectedUsers,
sequence_number: Arc<Mutex<u64>>,
kill_send: Sender<()>,
kill_receive: tokio::sync::broadcast::Receiver<()>,
/// Receiver for heartbeat messages. The `HeartbeatHandler` will receive messages from this channel.
heartbeat_receive: tokio::sync::broadcast::Receiver<GatewayHeartbeat>,
/// Sender for heartbeat messages. The main gateway task will send messages to this channel for the `HeartbeatHandler` to receive and handle.
Expand Down Expand Up @@ -100,8 +98,6 @@ pub(super) async fn establish_connection(
config: config.clone(),
connected_users: connected_users.clone(),
sequence_number: sequence_number.clone(),
kill_send: kill_send.clone(),
kill_receive: kill_receive.resubscribe(),
heartbeat_receive: message_receive.resubscribe(),
heartbeat_send: message_send.clone(),
session_id_send: session_id_send.clone(),
Expand Down Expand Up @@ -150,7 +146,11 @@ async fn finish_connecting(
Ok(next) => next,
Err(_) => {
log::debug!(target: "symfonia::gateway::finish_connecting", "Encountered error when trying to receive message. Sending kill signal...");
state.kill_send.send(()).expect("Failed to send kill_send");
state
.connection
.kill_send
.send(())
.expect("Failed to send kill_send");
return Err(GatewayError::Timeout.into());
}
};
Expand All @@ -172,8 +172,6 @@ async fn finish_connecting(
heartbeat_handler_handle = Some(tokio::spawn({
let mut heartbeat_handler = HeartbeatHandler::new(
state.connection.clone(),
state.kill_receive.resubscribe(),
state.kill_send.clone(),
state.heartbeat_receive.resubscribe(),
state.sequence_number.clone(),
state.session_id_receive.resubscribe(),
Expand Down Expand Up @@ -205,6 +203,7 @@ async fn finish_connecting(
Err(_) => {
log::trace!(target: "symfonia::gateway::establish_connection::finish_connecting", "Failed to verify token");
state
.connection
.kill_send
.send(())
.expect("Failed to send kill signal");
Expand All @@ -217,8 +216,6 @@ async fn finish_connecting(
let main_task_handle = tokio::spawn(gateway_task::gateway_task(
state.connection.clone(),
gateway_user.lock().await.inbox.resubscribe(),
state.kill_receive.resubscribe(),
state.kill_send.clone(),
state.heartbeat_send.clone(),
state.sequence_number.clone(),
));
Expand All @@ -235,8 +232,6 @@ async fn finish_connecting(
log::trace!(target: "symfonia::gateway::establish_connection::finish_connecting", "No heartbeat_handler yet. Creating one...");
let mut heartbeat_handler = HeartbeatHandler::new(
state.connection.clone(),
state.kill_receive.resubscribe(),
state.kill_send.clone(),
state.heartbeat_receive.resubscribe(),
state.sequence_number.clone(),
state.session_id_receive.resubscribe(),
Expand All @@ -246,7 +241,6 @@ async fn finish_connecting(
}
}),
},
state.kill_send.clone(),
&identify.event_data.token,
state.sequence_number.clone(),
)
Expand All @@ -256,6 +250,7 @@ async fn finish_connecting(
Err(_) => {
log::error!(target: "symfonia::gateway::establish_connection::finish_connecting", "Failed to send session_id to heartbeat handler");
state
.connection
.kill_send
.send(())
.expect("Failed to send kill signal");
Expand Down Expand Up @@ -289,6 +284,7 @@ async fn finish_connecting(
.into(),
})))?;
state
.connection
.kill_send
.send(())
.expect("Failed to send kill signal");
Expand Down
56 changes: 38 additions & 18 deletions src/gateway/gateway_task.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,11 @@ use super::{Event, GatewayClient, GatewayPayload};
pub(super) async fn gateway_task(
mut connection: super::WebSocketConnection,
mut inbox: tokio::sync::broadcast::Receiver<Event>,
mut kill_receive: tokio::sync::broadcast::Receiver<()>,
mut kill_send: tokio::sync::broadcast::Sender<()>,
mut heartbeat_send: tokio::sync::broadcast::Sender<GatewayHeartbeat>,
last_sequence_number: Arc<Mutex<u64>>,
) {
log::trace!(target: "symfonia::gateway::gateway_task", "Started a new gateway task!");
let inbox_processor = tokio::spawn(process_inbox(
connection.clone(),
inbox.resubscribe(),
kill_receive.resubscribe(),
));
let inbox_processor = tokio::spawn(process_inbox(connection.clone(), inbox.resubscribe()));

/*
Before we can respond to any gateway event we receive, we need to figure out what kind of event
Expand All @@ -39,18 +33,45 @@ pub(super) async fn gateway_task(

loop {
tokio::select! {
_ = kill_receive.recv() => {
_ = connection.kill_receive.recv() => {
return;
},
message_result = connection.receiver.recv() => {
match message_result {
Ok(message_of_unknown_type) => {
let event = unwrap_event(Event::try_from(message_of_unknown_type), connection.clone(), kill_send.clone());
// TODO: Handle event
log::trace!(target: "symfonia::gateway::gateway_task", "Received raw message {:?}", message_of_unknown_type);
let event = unwrap_event(Event::try_from(message_of_unknown_type), connection.clone(), connection.kill_send.clone());
log::trace!(target: "symfonia::gateway::gateway_task", "Event type of received message: {:?}", event);
match event {
Event::Dispatch(_) => {
// Receiving a dispatch event from a client is never correct
log::debug!(target: "symfonia::gateway::gateway_task", "Received an unexpected message: {:?}", event);
connection.sender.send(Message::Close(Some(CloseFrame { code: tokio_tungstenite::tungstenite::protocol::frame::coding::CloseCode::Library(4002), reason: "DECODE_ERROR".into() })));
connection.kill_send.send(()).expect("Failed to send kill_send");
panic!("Killing gateway task: Received an unexpected message");
},
Event::Heartbeat(hearbeat_event) => {
match heartbeat_send.send(hearbeat_event) {
Err(e) => {
log::debug!(target: "symfonia::gateway::gateway_task", "Received Heartbeat but HeartbeatHandler seems to be dead?");
connection.sender.send(Message::Close(Some(CloseFrame { code: tokio_tungstenite::tungstenite::protocol::frame::coding::CloseCode::Library(4002), reason: "DECODE_ERROR".into() })));
connection.kill_send.send(()).expect("Failed to send kill_send");
panic!("Killing gateway task: Received an unexpected message");
},
Ok(_) => {
log::trace!(target: "symfonia::gateway::gateway_task", "Forwarded heartbeat message to HeartbeatHandler!");
}
}
}
_ => {
log::error!(target: "symfonia::gateway::gateway_task", "Received an event type for which no code is yet implemented in the gateway_task. Please open a issue or PR at the symfonia repository. {:?}", event);
}
}

},
Err(error) => {
connection.sender.send(Message::Close(Some(CloseFrame { code: tokio_tungstenite::tungstenite::protocol::frame::coding::CloseCode::Library(4000), reason: "INTERNAL_SERVER_ERROR".into() })));
kill_send.send(()).expect("Failed to send kill_send");
connection.kill_send.send(()).expect("Failed to send kill_send");
return;
},
}
Expand Down Expand Up @@ -81,26 +102,26 @@ fn unwrap_event(
match e {
Error::Gateway(g) => match g {
GatewayError::UnexpectedOpcode(o) => {
log::debug!(target: "symfonia::gateway::gateway_task", "Received an unexpected opcode: {:?}", o);
log::debug!(target: "symfonia::gateway::gateway_task::unwrap_event", "Received an unexpected opcode: {:?}", o);
connection.sender.send(Message::Close(Some(CloseFrame { code: tokio_tungstenite::tungstenite::protocol::frame::coding::CloseCode::Library(4001), reason: "UNKNOWN_OPCODE".into() })));
kill_send.send(()).expect("Failed to send kill_send");
panic!("Killing gateway task: Received an unexpected opcode");
}
GatewayError::UnexpectedMessage(m) => {
log::debug!(target: "symfonia::gateway::gateway_task", "Received an unexpected message: {:?}", m);
log::debug!(target: "symfonia::gateway::gateway_task::unwrap_event", "Received an unexpected message: {:?}", m);
connection.sender.send(Message::Close(Some(CloseFrame { code: tokio_tungstenite::tungstenite::protocol::frame::coding::CloseCode::Library(4002), reason: "DECODE_ERROR".into() })));
kill_send.send(()).expect("Failed to send kill_send");
panic!("Killing gateway task: Received an unexpected message");
}
_ => {
log::debug!(target: "symfonia::gateway::gateway_task", "Received an unexpected error: {:?}", g);
log::debug!(target: "symfonia::gateway::gateway_task::unwrap_event", "Received an unexpected error: {:?}", g);
connection.sender.send(Message::Close(Some(CloseFrame { code: tokio_tungstenite::tungstenite::protocol::frame::coding::CloseCode::Library(4000), reason: "INTERNAL_SERVER_ERROR".into() })));
kill_send.send(()).expect("Failed to send kill_send");
panic!("Killing gateway task: Received an unexpected error");
}
},
_ => {
log::debug!(target: "symfonia::gateway::gateway_task", "Received an unexpected error: {:?}", e);
log::debug!(target: "symfonia::gateway::gateway_task::unwrap_event", "Received an unexpected error: {:?}", e);
connection.sender.send(Message::Close(Some(CloseFrame { code: tokio_tungstenite::tungstenite::protocol::frame::coding::CloseCode::Library(4000), reason: "INTERNAL_SERVER_ERROR".into() })));
kill_send.send(()).expect("Failed to send kill_send");
panic!("Killing gateway task: Received an unexpected error");
Expand All @@ -112,13 +133,12 @@ fn unwrap_event(
}

async fn process_inbox(
connection: super::WebSocketConnection,
mut connection: super::WebSocketConnection,
mut inbox: tokio::sync::broadcast::Receiver<Event>,
mut kill_receive: tokio::sync::broadcast::Receiver<()>,
) {
loop {
tokio::select! {
_ = kill_receive.recv() => {
_ = connection.kill_receive.recv() => {
return;
}
event = inbox.recv() => {
Expand Down
Loading

0 comments on commit 092e1e9

Please sign in to comment.