From adf4f37536636662274648cccd9eca0c58b1936f Mon Sep 17 00:00:00 2001 From: Garen Tyler Date: Sun, 8 Dec 2024 21:37:10 -0700 Subject: [PATCH] Create App trait to manage application startup/shutdown --- src/lib.rs | 47 ++++++++++++++- src/main.rs | 13 +++- src/net/connection.rs | 136 ++++++++++++++++++++++++++++++++++++++++-- src/net/listener.rs | 101 ------------------------------- src/net/mod.rs | 1 - src/proxy/mod.rs | 99 +++++++++++++----------------- src/server/error.rs | 3 - src/server/mod.rs | 103 +++++++++++--------------------- 8 files changed, 264 insertions(+), 239 deletions(-) delete mode 100644 src/net/listener.rs diff --git a/src/lib.rs b/src/lib.rs index ce7661b..b624e43 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -17,6 +17,8 @@ pub(crate) mod world; use config::Subcommand; use once_cell::sync::OnceCell; use std::time::Instant; +use tokio_util::sync::CancellationToken; +use tracing::info; pub const PROTOCOL_VERSION: i32 = 762; pub const GAME_VERSION: &str = "1.19.4"; @@ -26,12 +28,51 @@ pub const GAME_VERSION: &str = "1.19.4"; /// This should be set immediately on startup. pub static START_TIME: OnceCell = OnceCell::new(); -pub async fn run(command: Subcommand) { +pub async fn run(command: Subcommand, running: CancellationToken) { match command { #[cfg(feature = "server")] - Subcommand::Server => server::Server::run().await, + Subcommand::Server => server::Server::run(running).await, #[cfg(feature = "proxy")] - Subcommand::Proxy => proxy::Proxy::run().await, + Subcommand::Proxy => proxy::Proxy::run(running).await, Subcommand::None => unreachable!(), } } + +#[async_trait::async_trait] +pub(crate) trait App: Sized { + type Error: std::fmt::Debug; + + fn startup_message() -> String; + async fn new(running: CancellationToken) -> Result; + async fn update(&mut self) -> Result<(), Self::Error>; + async fn shutdown(self) -> Result<(), Self::Error>; + + async fn run(running: CancellationToken) { + info!("{}", Self::startup_message()); + let mut app = Self::new(running.clone()).await.expect("app to start"); + info!( + "Done! Start took {:?}", + crate::START_TIME.get().unwrap().elapsed() + ); + + // The main loop. + loop { + tokio::select! { + _ = running.cancelled() => { + break; + } + r = app.update() => { + if r.is_err() { + break; + } + } + } + } + + // Run shutdown tasks. + match tokio::time::timeout(std::time::Duration::from_secs(10), app.shutdown()).await { + Ok(_) => std::process::exit(0), + Err(_) => std::process::exit(1), + } + } +} diff --git a/src/main.rs b/src/main.rs index 070b0a6..954fa34 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,3 +1,4 @@ +use tokio_util::sync::CancellationToken; use tracing::{info, warn}; use tracing_subscriber::prelude::*; @@ -65,7 +66,17 @@ pub fn main() { } .unwrap() .block_on(async move { + let running = CancellationToken::new(); + + // Spawn the ctrl-c task. + let r = running.clone(); + tokio::spawn(async move { + tokio::signal::ctrl_c().await.unwrap(); + info!("Ctrl-C received, shutting down"); + r.cancel(); + }); + let args = composition::config::Args::instance(); - composition::run(args.subcommand).await; + composition::run(args.subcommand, running).await; }); } diff --git a/src/net/connection.rs b/src/net/connection.rs index 5b978a5..78f9e71 100644 --- a/src/net/connection.rs +++ b/src/net/connection.rs @@ -5,15 +5,140 @@ use crate::protocol::{ ClientState, }; use futures::{stream::StreamExt, SinkExt}; -use std::time::{Duration, Instant}; -use tokio::{io::BufStream, net::TcpStream}; +use std::{ + collections::HashMap, + time::{Duration, Instant}, +}; +use tokio::{io::BufStream, net::TcpStream, sync::mpsc}; +use tokio::{ + net::{TcpListener, ToSocketAddrs}, + task::JoinHandle, +}; use tokio_util::codec::{Decoder, Framed}; -use tracing::trace; +use tokio_util::sync::CancellationToken; +use tracing::{error, trace}; + +#[derive(Debug)] +pub struct ConnectionManager { + clients: HashMap, + channel: ( + mpsc::UnboundedSender, + mpsc::UnboundedReceiver, + ), +} +impl ConnectionManager { + pub fn new() -> ConnectionManager { + ConnectionManager { + clients: HashMap::new(), + channel: mpsc::unbounded_channel(), + } + } + pub fn client(&self, id: u128) -> Option<&Connection> { + self.clients.get(&id) + } + pub fn client_mut(&mut self, id: u128) -> Option<&mut Connection> { + self.clients.get_mut(&id) + } + pub async fn spawn_listener( + &self, + bind_address: A, + running: CancellationToken, + ) -> Result, std::io::Error> + where + A: 'static + ToSocketAddrs + Send + std::fmt::Debug, + { + trace!("Starting listener task"); + let fmt_addr = format!("{:?}", bind_address); + let listener = TcpListener::bind(bind_address) + .await + .inspect_err(|_| error!("Could not bind to {}.", fmt_addr))?; + + let sender = self.channel.0.clone(); + + let join_handle = tokio::spawn(async move { + let mut client_id = 0u128; + + loop { + tokio::select! { + _ = running.cancelled() => { + break; + } + result = listener.accept() => { + if let Ok((stream, _)) = result { + trace!("Listener task got connection (id {})", client_id); + let client = Connection::new_client(client_id, stream); + if sender.send(client).is_err() { + trace!("Client receiver disconnected"); + break; + } + client_id += 1; + } + } + } + } + trace!("Listener task shutting down"); + }); + + Ok(join_handle) + } + pub fn update(&mut self) -> Result<(), std::io::Error> { + use std::io::{Error, ErrorKind}; + + // Receive new clients from the sender. + loop { + match self.channel.1.try_recv() { + Ok(connection) => { + let id = connection.id(); + self.clients.insert(id, connection); + } + Err(mpsc::error::TryRecvError::Disconnected) => { + return Err(Error::new( + ErrorKind::BrokenPipe, + "all senders disconnected", + )) + } + Err(mpsc::error::TryRecvError::Empty) => break, + }; + } + + // Remove disconnected clients. + self.clients + .retain(|_id, c| c.client_state() != ClientState::Disconnected); + Ok(()) + } + pub async fn disconnect( + &mut self, + id: u128, + reason: Option, + ) -> Option> { + let client = self.clients.remove(&id)?; + Some(client.disconnect(reason).await) + } + pub async fn shutdown(mut self, reason: Option) -> Result<(), std::io::Error> { + let reason = reason.unwrap_or(serde_json::json!({ + "text": "You have been disconnected!" + })); + + let disconnections = self + .clients + .drain() + .map(|(_, c)| c) + .map(|c| c.disconnect(Some(reason.clone()))) + .collect::>(); + + // We don't actually care if the disconnections succeed, + // the connection is going to be dropped anyway. + let _disconnections: Vec> = + futures::future::join_all(disconnections).await; + + Ok(()) + } +} #[derive(Debug)] pub struct Connection { /// The `Connection`'s unique id. - pub id: u128, + id: u128, stream: Framed, PacketCodec>, last_received_data_time: Instant, last_sent_data_time: Instant, @@ -35,6 +160,9 @@ impl Connection { pub fn new_server(id: u128, stream: TcpStream) -> Self { Self::new(id, PacketDirection::Clientbound, stream) } + pub fn id(&self) -> u128 { + self.id + } pub fn client_state(&self) -> ClientState { self.stream.codec().client_state } diff --git a/src/net/listener.rs b/src/net/listener.rs deleted file mode 100644 index 04ff000..0000000 --- a/src/net/listener.rs +++ /dev/null @@ -1,101 +0,0 @@ -use super::connection::Connection; -use crate::protocol::types::Chat; -use std::{ - collections::HashMap, - sync::{Arc, Weak}, -}; -use tokio::{ - net::{TcpListener, ToSocketAddrs}, - sync::RwLock, -}; -use tokio_util::sync::CancellationToken; -use tracing::{error, trace}; - -pub type Callback = dyn Fn(u128, Arc>) + Send; - -#[derive(Clone, Debug)] -pub struct NetworkListener { - running: CancellationToken, - clients: Arc>>>>, -} -impl NetworkListener { - pub async fn new( - bind_address: A, - running: CancellationToken, - callback: Option>, - ) -> Result { - let listener = TcpListener::bind(bind_address) - .await - .inspect_err(|_| error!("Could not bind to given address."))?; - let clients = Arc::new(RwLock::new(HashMap::new())); - - let r = running.clone(); - let c = clients.clone(); - tokio::spawn(async move { - trace!("Starting listener task"); - let mut client_id = 0u128; - - loop { - tokio::select! { - _ = r.cancelled() => { - break; - } - result = listener.accept() => { - if let Ok((stream, _)) = result { - trace!("Listener task got connection (id {})", client_id); - let client = Arc::new(RwLock::new(Connection::new_client(client_id, stream))); - c.write().await.insert(client_id, client.clone()); - if let Some(ref callback) = callback { - callback(client_id, client); - } - client_id += 1; - } - } - } - } - }); - - Ok(NetworkListener { running, clients }) - } - pub async fn get_client(&self, id: u128) -> Option>> { - self.clients.read().await.get(&id).map(Arc::downgrade) - } - pub async fn disconnect_client( - &self, - id: u128, - reason: Option, - ) -> Result, ()> { - // Remove the client from the hashmap. - let client = self.clients.write().await.remove(&id).ok_or(())?; - let client: Connection = Arc::into_inner(client) - .expect("only one reference") - .into_inner(); - // let mut client = client.write().await; - // Send a disconnect packet. - Ok(client.disconnect(reason).await) - } - pub async fn shutdown(self, reason: Option) -> Result<(), std::io::Error> { - self.running.cancel(); - - let reason = reason.unwrap_or(serde_json::json!({ - "text": "You have been disconnected!" - })); - - let disconnections = self - .clients - .write() - .await - .drain() - .map(|(_, c)| c) - .map(|c| Arc::into_inner(c).expect("only one reference").into_inner()) - .map(|c| c.disconnect(Some(reason.clone()))) - .collect::>(); - - // We don't actually care if the disconnections succeed, - // the connection is going to be dropped anyway. - let _disconnections: Vec> = - futures::future::join_all(disconnections).await; - - Ok(()) - } -} diff --git a/src/net/mod.rs b/src/net/mod.rs index ff2a2a7..91ff74e 100644 --- a/src/net/mod.rs +++ b/src/net/mod.rs @@ -2,4 +2,3 @@ pub mod codec; pub mod connection; -pub mod listener; diff --git a/src/proxy/mod.rs b/src/proxy/mod.rs index 6a51a53..a477723 100644 --- a/src/proxy/mod.rs +++ b/src/proxy/mod.rs @@ -1,83 +1,64 @@ pub mod config; -use crate::{config::Config, net::listener::NetworkListener}; +use crate::App; +use crate::{config::Config, net::connection::ConnectionManager}; use config::ProxyConfig; -use tokio::net::ToSocketAddrs; +use tokio::task::JoinHandle; use tokio_util::sync::CancellationToken; -use tracing::{info, trace}; +use tracing::info; #[derive(Debug)] pub struct Proxy { - _network_listener: NetworkListener, + running: CancellationToken, + connections: ConnectionManager, + listener: JoinHandle<()>, } -impl Proxy { - /// Start the proxy. - #[tracing::instrument] - pub async fn run() { +#[async_trait::async_trait] +impl App for Proxy { + type Error = (); + + fn startup_message() -> String { let config = Config::instance(); - info!( + format!( "Starting {} on port {}", ProxyConfig::default().version, config.proxy.port - ); - let (mut proxy, running) = Self::new(format!("0.0.0.0:{}", config.proxy.port)).await; - info!( - "Done! Start took {:?}", - crate::START_TIME.get().unwrap().elapsed() - ); + ) + } + #[tracing::instrument] + async fn new(running: CancellationToken) -> Result { + let config = Config::instance(); + let bind_address = format!("0.0.0.0:{}", config.proxy.port); + + let connections = ConnectionManager::new(); + let listener = connections + .spawn_listener(bind_address, running.child_token()) + .await + .map_err(|_| ())?; + info!( "Upstream server: {}:{}", config.proxy.upstream_host, config.proxy.upstream_port ); - // Spawn the ctrl-c task. - let r = running.clone(); - tokio::spawn(async move { - tokio::signal::ctrl_c().await.unwrap(); - info!("Ctrl-C received, shutting down"); - r.cancel(); - }); - - // The main loop. - loop { - tokio::select! { - _ = running.cancelled() => { - break; - } - _ = proxy.update() => {} - } - } - - match tokio::time::timeout(std::time::Duration::from_secs(10), proxy.shutdown()).await { - Ok(_) => std::process::exit(0), - Err(_) => std::process::exit(1), - } + Ok(Proxy { + running, + connections, + listener, + }) } #[tracing::instrument] - async fn new( - bind_address: A, - ) -> (Proxy, CancellationToken) { - trace!("Proxy::new()"); - let running = CancellationToken::new(); - - let network_listener = NetworkListener::new(bind_address, running.child_token(), None) - .await - .expect("listener to bind properly"); - - let proxy = Proxy { - _network_listener: network_listener, - }; - - (proxy, running) + async fn update(&mut self) -> Result<(), Self::Error> { + todo!() } #[tracing::instrument] - async fn update(&mut self) -> Result<(), ()> { - // TODO + async fn shutdown(self) -> Result<(), Self::Error> { + // Ensure any child tasks have been shut down. + self.running.cancel(); + + let _ = self.listener.await.map_err(|_| ())?; + let _ = self.connections.shutdown(None).await.map_err(|_| ())?; + Ok(()) } - #[tracing::instrument] - async fn shutdown(self) { - trace!("Proxy.shutdown()"); - // TODO - } } diff --git a/src/server/error.rs b/src/server/error.rs index 109bb97..c2fd3da 100644 --- a/src/server/error.rs +++ b/src/server/error.rs @@ -5,6 +5,3 @@ pub enum Error { #[error("the server is not running")] NotRunning, } - -/// Alias for a Result with the error type `composition_core::server::Error`. -pub type Result = std::result::Result; diff --git a/src/server/mod.rs b/src/server/mod.rs index 2598d1b..be03275 100644 --- a/src/server/mod.rs +++ b/src/server/mod.rs @@ -7,14 +7,14 @@ pub mod net; use crate::config::Config; use crate::protocol::ClientState; +use crate::App; use config::ServerConfig; -use error::Result; use net::{NetworkClient, NetworkClientState}; use std::sync::Arc; use tokio::net::{TcpListener, ToSocketAddrs}; use tokio::{sync::RwLock, task::JoinHandle}; use tokio_util::sync::CancellationToken; -use tracing::{error, info, trace}; +use tracing::{error, trace}; /// The main state and logic of the program. #[derive(Debug)] @@ -23,65 +23,6 @@ pub struct Server { net_tasks_handle: JoinHandle<()>, } impl Server { - /// Start the server. - #[tracing::instrument] - pub async fn run() { - let config = Config::instance(); - info!( - "Starting {} on port {}", - ServerConfig::default().version, - config.server.port - ); - let (mut server, running) = Self::new(format!("0.0.0.0:{}", config.server.port)).await; - info!( - "Done! Start took {:?}", - crate::START_TIME.get().unwrap().elapsed() - ); - - // Spawn the ctrl-c task. - let r = running.clone(); - tokio::spawn(async move { - tokio::signal::ctrl_c().await.unwrap(); - info!("Ctrl-C received, shutting down"); - r.cancel(); - }); - - // The main server loop. - loop { - tokio::select! { - _ = running.cancelled() => { - break; - } - _ = server.update() => {} - } - } - - match tokio::time::timeout(std::time::Duration::from_secs(10), server.shutdown()).await { - Ok(_) => std::process::exit(0), - Err(_) => std::process::exit(1), - } - } - #[tracing::instrument] - async fn new( - bind_address: A, - ) -> (Server, CancellationToken) { - trace!("Server::new()"); - - let running = CancellationToken::new(); - let clients = Arc::new(RwLock::new(vec![])); - let net_tasks_handle = tokio::spawn(Self::create_network_tasks( - bind_address, - clients.clone(), - running.clone(), - )); - - let server = Server { - clients, - net_tasks_handle, - }; - - (server, running) - } #[tracing::instrument] async fn create_network_tasks( bind_address: A, @@ -187,10 +128,38 @@ impl Server { .await .expect("Disconnection task crashed"); } - #[tracing::instrument] - pub async fn update(&mut self) -> Result<()> { - trace!("Server.update()"); +} +#[async_trait::async_trait] +impl App for Server { + type Error = error::Error; + fn startup_message() -> String { + let config = Config::instance(); + format!( + "Starting {} on port {}", + ServerConfig::default().version, + config.server.port + ) + } + #[tracing::instrument] + async fn new(running: CancellationToken) -> Result { + let config = Config::instance(); + let bind_address = format!("0.0.0.0:{}", config.server.port); + + let clients = Arc::new(RwLock::new(vec![])); + let net_tasks_handle = tokio::spawn(Self::create_network_tasks( + bind_address, + clients.clone(), + running.child_token(), + )); + + Ok(Server { + clients, + net_tasks_handle, + }) + } + #[tracing::instrument] + async fn update(&mut self) -> Result<(), Self::Error> { let mut clients = self.clients.write().await; // Handle packets from the clients. @@ -321,9 +290,7 @@ impl Server { Ok(()) } #[tracing::instrument] - pub async fn shutdown(self) { - trace!("Server.shutdown()"); - + async fn shutdown(self) -> Result<(), Self::Error> { // Close the concurrent tasks. let _ = self.net_tasks_handle.await; @@ -335,5 +302,7 @@ impl Server { )) .await; } + + Ok(()) } }