Create App trait to manage application startup/shutdown

This commit is contained in:
Garen Tyler 2024-12-08 21:37:10 -07:00
parent b87c71737d
commit adf4f37536
Signed by: garentyler
SSH Key Fingerprint: SHA256:G4ke7blZMdpWPbkescyZ7IQYE4JAtwpI85YoJdq+S7U
8 changed files with 264 additions and 239 deletions

View File

@ -17,6 +17,8 @@ pub(crate) mod world;
use config::Subcommand; use config::Subcommand;
use once_cell::sync::OnceCell; use once_cell::sync::OnceCell;
use std::time::Instant; use std::time::Instant;
use tokio_util::sync::CancellationToken;
use tracing::info;
pub const PROTOCOL_VERSION: i32 = 762; pub const PROTOCOL_VERSION: i32 = 762;
pub const GAME_VERSION: &str = "1.19.4"; pub const GAME_VERSION: &str = "1.19.4";
@ -26,12 +28,51 @@ pub const GAME_VERSION: &str = "1.19.4";
/// This should be set immediately on startup. /// This should be set immediately on startup.
pub static START_TIME: OnceCell<Instant> = OnceCell::new(); pub static START_TIME: OnceCell<Instant> = OnceCell::new();
pub async fn run(command: Subcommand) { pub async fn run(command: Subcommand, running: CancellationToken) {
match command { match command {
#[cfg(feature = "server")] #[cfg(feature = "server")]
Subcommand::Server => server::Server::run().await, Subcommand::Server => server::Server::run(running).await,
#[cfg(feature = "proxy")] #[cfg(feature = "proxy")]
Subcommand::Proxy => proxy::Proxy::run().await, Subcommand::Proxy => proxy::Proxy::run(running).await,
Subcommand::None => unreachable!(), Subcommand::None => unreachable!(),
} }
} }
#[async_trait::async_trait]
pub(crate) trait App: Sized {
type Error: std::fmt::Debug;
fn startup_message() -> String;
async fn new(running: CancellationToken) -> Result<Self, Self::Error>;
async fn update(&mut self) -> Result<(), Self::Error>;
async fn shutdown(self) -> Result<(), Self::Error>;
async fn run(running: CancellationToken) {
info!("{}", Self::startup_message());
let mut app = Self::new(running.clone()).await.expect("app to start");
info!(
"Done! Start took {:?}",
crate::START_TIME.get().unwrap().elapsed()
);
// The main loop.
loop {
tokio::select! {
_ = running.cancelled() => {
break;
}
r = app.update() => {
if r.is_err() {
break;
}
}
}
}
// Run shutdown tasks.
match tokio::time::timeout(std::time::Duration::from_secs(10), app.shutdown()).await {
Ok(_) => std::process::exit(0),
Err(_) => std::process::exit(1),
}
}
}

View File

@ -1,3 +1,4 @@
use tokio_util::sync::CancellationToken;
use tracing::{info, warn}; use tracing::{info, warn};
use tracing_subscriber::prelude::*; use tracing_subscriber::prelude::*;
@ -65,7 +66,17 @@ pub fn main() {
} }
.unwrap() .unwrap()
.block_on(async move { .block_on(async move {
let running = CancellationToken::new();
// Spawn the ctrl-c task.
let r = running.clone();
tokio::spawn(async move {
tokio::signal::ctrl_c().await.unwrap();
info!("Ctrl-C received, shutting down");
r.cancel();
});
let args = composition::config::Args::instance(); let args = composition::config::Args::instance();
composition::run(args.subcommand).await; composition::run(args.subcommand, running).await;
}); });
} }

View File

@ -5,15 +5,140 @@ use crate::protocol::{
ClientState, ClientState,
}; };
use futures::{stream::StreamExt, SinkExt}; use futures::{stream::StreamExt, SinkExt};
use std::time::{Duration, Instant}; use std::{
use tokio::{io::BufStream, net::TcpStream}; collections::HashMap,
time::{Duration, Instant},
};
use tokio::{io::BufStream, net::TcpStream, sync::mpsc};
use tokio::{
net::{TcpListener, ToSocketAddrs},
task::JoinHandle,
};
use tokio_util::codec::{Decoder, Framed}; use tokio_util::codec::{Decoder, Framed};
use tracing::trace; use tokio_util::sync::CancellationToken;
use tracing::{error, trace};
#[derive(Debug)]
pub struct ConnectionManager {
clients: HashMap<u128, Connection>,
channel: (
mpsc::UnboundedSender<Connection>,
mpsc::UnboundedReceiver<Connection>,
),
}
impl ConnectionManager {
pub fn new() -> ConnectionManager {
ConnectionManager {
clients: HashMap::new(),
channel: mpsc::unbounded_channel(),
}
}
pub fn client(&self, id: u128) -> Option<&Connection> {
self.clients.get(&id)
}
pub fn client_mut(&mut self, id: u128) -> Option<&mut Connection> {
self.clients.get_mut(&id)
}
pub async fn spawn_listener<A>(
&self,
bind_address: A,
running: CancellationToken,
) -> Result<JoinHandle<()>, std::io::Error>
where
A: 'static + ToSocketAddrs + Send + std::fmt::Debug,
{
trace!("Starting listener task");
let fmt_addr = format!("{:?}", bind_address);
let listener = TcpListener::bind(bind_address)
.await
.inspect_err(|_| error!("Could not bind to {}.", fmt_addr))?;
let sender = self.channel.0.clone();
let join_handle = tokio::spawn(async move {
let mut client_id = 0u128;
loop {
tokio::select! {
_ = running.cancelled() => {
break;
}
result = listener.accept() => {
if let Ok((stream, _)) = result {
trace!("Listener task got connection (id {})", client_id);
let client = Connection::new_client(client_id, stream);
if sender.send(client).is_err() {
trace!("Client receiver disconnected");
break;
}
client_id += 1;
}
}
}
}
trace!("Listener task shutting down");
});
Ok(join_handle)
}
pub fn update(&mut self) -> Result<(), std::io::Error> {
use std::io::{Error, ErrorKind};
// Receive new clients from the sender.
loop {
match self.channel.1.try_recv() {
Ok(connection) => {
let id = connection.id();
self.clients.insert(id, connection);
}
Err(mpsc::error::TryRecvError::Disconnected) => {
return Err(Error::new(
ErrorKind::BrokenPipe,
"all senders disconnected",
))
}
Err(mpsc::error::TryRecvError::Empty) => break,
};
}
// Remove disconnected clients.
self.clients
.retain(|_id, c| c.client_state() != ClientState::Disconnected);
Ok(())
}
pub async fn disconnect(
&mut self,
id: u128,
reason: Option<Chat>,
) -> Option<Result<(), std::io::Error>> {
let client = self.clients.remove(&id)?;
Some(client.disconnect(reason).await)
}
pub async fn shutdown(mut self, reason: Option<Chat>) -> Result<(), std::io::Error> {
let reason = reason.unwrap_or(serde_json::json!({
"text": "You have been disconnected!"
}));
let disconnections = self
.clients
.drain()
.map(|(_, c)| c)
.map(|c| c.disconnect(Some(reason.clone())))
.collect::<Vec<_>>();
// We don't actually care if the disconnections succeed,
// the connection is going to be dropped anyway.
let _disconnections: Vec<Result<(), std::io::Error>> =
futures::future::join_all(disconnections).await;
Ok(())
}
}
#[derive(Debug)] #[derive(Debug)]
pub struct Connection { pub struct Connection {
/// The `Connection`'s unique id. /// The `Connection`'s unique id.
pub id: u128, id: u128,
stream: Framed<BufStream<TcpStream>, PacketCodec>, stream: Framed<BufStream<TcpStream>, PacketCodec>,
last_received_data_time: Instant, last_received_data_time: Instant,
last_sent_data_time: Instant, last_sent_data_time: Instant,
@ -35,6 +160,9 @@ impl Connection {
pub fn new_server(id: u128, stream: TcpStream) -> Self { pub fn new_server(id: u128, stream: TcpStream) -> Self {
Self::new(id, PacketDirection::Clientbound, stream) Self::new(id, PacketDirection::Clientbound, stream)
} }
pub fn id(&self) -> u128 {
self.id
}
pub fn client_state(&self) -> ClientState { pub fn client_state(&self) -> ClientState {
self.stream.codec().client_state self.stream.codec().client_state
} }

View File

@ -1,101 +0,0 @@
use super::connection::Connection;
use crate::protocol::types::Chat;
use std::{
collections::HashMap,
sync::{Arc, Weak},
};
use tokio::{
net::{TcpListener, ToSocketAddrs},
sync::RwLock,
};
use tokio_util::sync::CancellationToken;
use tracing::{error, trace};
pub type Callback = dyn Fn(u128, Arc<RwLock<Connection>>) + Send;
#[derive(Clone, Debug)]
pub struct NetworkListener {
running: CancellationToken,
clients: Arc<RwLock<HashMap<u128, Arc<RwLock<Connection>>>>>,
}
impl NetworkListener {
pub async fn new<A: 'static + ToSocketAddrs + Send + std::fmt::Debug>(
bind_address: A,
running: CancellationToken,
callback: Option<Box<Callback>>,
) -> Result<NetworkListener, std::io::Error> {
let listener = TcpListener::bind(bind_address)
.await
.inspect_err(|_| error!("Could not bind to given address."))?;
let clients = Arc::new(RwLock::new(HashMap::new()));
let r = running.clone();
let c = clients.clone();
tokio::spawn(async move {
trace!("Starting listener task");
let mut client_id = 0u128;
loop {
tokio::select! {
_ = r.cancelled() => {
break;
}
result = listener.accept() => {
if let Ok((stream, _)) = result {
trace!("Listener task got connection (id {})", client_id);
let client = Arc::new(RwLock::new(Connection::new_client(client_id, stream)));
c.write().await.insert(client_id, client.clone());
if let Some(ref callback) = callback {
callback(client_id, client);
}
client_id += 1;
}
}
}
}
});
Ok(NetworkListener { running, clients })
}
pub async fn get_client(&self, id: u128) -> Option<Weak<RwLock<Connection>>> {
self.clients.read().await.get(&id).map(Arc::downgrade)
}
pub async fn disconnect_client(
&self,
id: u128,
reason: Option<Chat>,
) -> Result<Result<(), std::io::Error>, ()> {
// Remove the client from the hashmap.
let client = self.clients.write().await.remove(&id).ok_or(())?;
let client: Connection = Arc::into_inner(client)
.expect("only one reference")
.into_inner();
// let mut client = client.write().await;
// Send a disconnect packet.
Ok(client.disconnect(reason).await)
}
pub async fn shutdown(self, reason: Option<Chat>) -> Result<(), std::io::Error> {
self.running.cancel();
let reason = reason.unwrap_or(serde_json::json!({
"text": "You have been disconnected!"
}));
let disconnections = self
.clients
.write()
.await
.drain()
.map(|(_, c)| c)
.map(|c| Arc::into_inner(c).expect("only one reference").into_inner())
.map(|c| c.disconnect(Some(reason.clone())))
.collect::<Vec<_>>();
// We don't actually care if the disconnections succeed,
// the connection is going to be dropped anyway.
let _disconnections: Vec<Result<(), std::io::Error>> =
futures::future::join_all(disconnections).await;
Ok(())
}
}

View File

@ -2,4 +2,3 @@
pub mod codec; pub mod codec;
pub mod connection; pub mod connection;
pub mod listener;

View File

@ -1,83 +1,64 @@
pub mod config; pub mod config;
use crate::{config::Config, net::listener::NetworkListener}; use crate::App;
use crate::{config::Config, net::connection::ConnectionManager};
use config::ProxyConfig; use config::ProxyConfig;
use tokio::net::ToSocketAddrs; use tokio::task::JoinHandle;
use tokio_util::sync::CancellationToken; use tokio_util::sync::CancellationToken;
use tracing::{info, trace}; use tracing::info;
#[derive(Debug)] #[derive(Debug)]
pub struct Proxy { pub struct Proxy {
_network_listener: NetworkListener, running: CancellationToken,
connections: ConnectionManager,
listener: JoinHandle<()>,
} }
impl Proxy { #[async_trait::async_trait]
/// Start the proxy. impl App for Proxy {
#[tracing::instrument] type Error = ();
pub async fn run() {
fn startup_message() -> String {
let config = Config::instance(); let config = Config::instance();
info!( format!(
"Starting {} on port {}", "Starting {} on port {}",
ProxyConfig::default().version, ProxyConfig::default().version,
config.proxy.port config.proxy.port
); )
let (mut proxy, running) = Self::new(format!("0.0.0.0:{}", config.proxy.port)).await; }
info!( #[tracing::instrument]
"Done! Start took {:?}", async fn new(running: CancellationToken) -> Result<Self, Self::Error> {
crate::START_TIME.get().unwrap().elapsed() let config = Config::instance();
); let bind_address = format!("0.0.0.0:{}", config.proxy.port);
let connections = ConnectionManager::new();
let listener = connections
.spawn_listener(bind_address, running.child_token())
.await
.map_err(|_| ())?;
info!( info!(
"Upstream server: {}:{}", "Upstream server: {}:{}",
config.proxy.upstream_host, config.proxy.upstream_port config.proxy.upstream_host, config.proxy.upstream_port
); );
// Spawn the ctrl-c task. Ok(Proxy {
let r = running.clone(); running,
tokio::spawn(async move { connections,
tokio::signal::ctrl_c().await.unwrap(); listener,
info!("Ctrl-C received, shutting down"); })
r.cancel();
});
// The main loop.
loop {
tokio::select! {
_ = running.cancelled() => {
break;
}
_ = proxy.update() => {}
}
}
match tokio::time::timeout(std::time::Duration::from_secs(10), proxy.shutdown()).await {
Ok(_) => std::process::exit(0),
Err(_) => std::process::exit(1),
}
} }
#[tracing::instrument] #[tracing::instrument]
async fn new<A: 'static + ToSocketAddrs + Send + std::fmt::Debug>( async fn update(&mut self) -> Result<(), Self::Error> {
bind_address: A, todo!()
) -> (Proxy, CancellationToken) {
trace!("Proxy::new()");
let running = CancellationToken::new();
let network_listener = NetworkListener::new(bind_address, running.child_token(), None)
.await
.expect("listener to bind properly");
let proxy = Proxy {
_network_listener: network_listener,
};
(proxy, running)
} }
#[tracing::instrument] #[tracing::instrument]
async fn update(&mut self) -> Result<(), ()> { async fn shutdown(self) -> Result<(), Self::Error> {
// TODO // Ensure any child tasks have been shut down.
self.running.cancel();
let _ = self.listener.await.map_err(|_| ())?;
let _ = self.connections.shutdown(None).await.map_err(|_| ())?;
Ok(()) Ok(())
} }
#[tracing::instrument]
async fn shutdown(self) {
trace!("Proxy.shutdown()");
// TODO
}
} }

View File

@ -5,6 +5,3 @@ pub enum Error {
#[error("the server is not running")] #[error("the server is not running")]
NotRunning, NotRunning,
} }
/// Alias for a Result with the error type `composition_core::server::Error`.
pub type Result<T> = std::result::Result<T, Error>;

View File

@ -7,14 +7,14 @@ pub mod net;
use crate::config::Config; use crate::config::Config;
use crate::protocol::ClientState; use crate::protocol::ClientState;
use crate::App;
use config::ServerConfig; use config::ServerConfig;
use error::Result;
use net::{NetworkClient, NetworkClientState}; use net::{NetworkClient, NetworkClientState};
use std::sync::Arc; use std::sync::Arc;
use tokio::net::{TcpListener, ToSocketAddrs}; use tokio::net::{TcpListener, ToSocketAddrs};
use tokio::{sync::RwLock, task::JoinHandle}; use tokio::{sync::RwLock, task::JoinHandle};
use tokio_util::sync::CancellationToken; use tokio_util::sync::CancellationToken;
use tracing::{error, info, trace}; use tracing::{error, trace};
/// The main state and logic of the program. /// The main state and logic of the program.
#[derive(Debug)] #[derive(Debug)]
@ -23,65 +23,6 @@ pub struct Server {
net_tasks_handle: JoinHandle<()>, net_tasks_handle: JoinHandle<()>,
} }
impl Server { impl Server {
/// Start the server.
#[tracing::instrument]
pub async fn run() {
let config = Config::instance();
info!(
"Starting {} on port {}",
ServerConfig::default().version,
config.server.port
);
let (mut server, running) = Self::new(format!("0.0.0.0:{}", config.server.port)).await;
info!(
"Done! Start took {:?}",
crate::START_TIME.get().unwrap().elapsed()
);
// Spawn the ctrl-c task.
let r = running.clone();
tokio::spawn(async move {
tokio::signal::ctrl_c().await.unwrap();
info!("Ctrl-C received, shutting down");
r.cancel();
});
// The main server loop.
loop {
tokio::select! {
_ = running.cancelled() => {
break;
}
_ = server.update() => {}
}
}
match tokio::time::timeout(std::time::Duration::from_secs(10), server.shutdown()).await {
Ok(_) => std::process::exit(0),
Err(_) => std::process::exit(1),
}
}
#[tracing::instrument]
async fn new<A: 'static + ToSocketAddrs + Send + std::fmt::Debug>(
bind_address: A,
) -> (Server, CancellationToken) {
trace!("Server::new()");
let running = CancellationToken::new();
let clients = Arc::new(RwLock::new(vec![]));
let net_tasks_handle = tokio::spawn(Self::create_network_tasks(
bind_address,
clients.clone(),
running.clone(),
));
let server = Server {
clients,
net_tasks_handle,
};
(server, running)
}
#[tracing::instrument] #[tracing::instrument]
async fn create_network_tasks<A: 'static + ToSocketAddrs + Send + std::fmt::Debug>( async fn create_network_tasks<A: 'static + ToSocketAddrs + Send + std::fmt::Debug>(
bind_address: A, bind_address: A,
@ -187,10 +128,38 @@ impl Server {
.await .await
.expect("Disconnection task crashed"); .expect("Disconnection task crashed");
} }
#[tracing::instrument] }
pub async fn update(&mut self) -> Result<()> { #[async_trait::async_trait]
trace!("Server.update()"); impl App for Server {
type Error = error::Error;
fn startup_message() -> String {
let config = Config::instance();
format!(
"Starting {} on port {}",
ServerConfig::default().version,
config.server.port
)
}
#[tracing::instrument]
async fn new(running: CancellationToken) -> Result<Self, Self::Error> {
let config = Config::instance();
let bind_address = format!("0.0.0.0:{}", config.server.port);
let clients = Arc::new(RwLock::new(vec![]));
let net_tasks_handle = tokio::spawn(Self::create_network_tasks(
bind_address,
clients.clone(),
running.child_token(),
));
Ok(Server {
clients,
net_tasks_handle,
})
}
#[tracing::instrument]
async fn update(&mut self) -> Result<(), Self::Error> {
let mut clients = self.clients.write().await; let mut clients = self.clients.write().await;
// Handle packets from the clients. // Handle packets from the clients.
@ -321,9 +290,7 @@ impl Server {
Ok(()) Ok(())
} }
#[tracing::instrument] #[tracing::instrument]
pub async fn shutdown(self) { async fn shutdown(self) -> Result<(), Self::Error> {
trace!("Server.shutdown()");
// Close the concurrent tasks. // Close the concurrent tasks.
let _ = self.net_tasks_handle.await; let _ = self.net_tasks_handle.await;
@ -335,5 +302,7 @@ impl Server {
)) ))
.await; .await;
} }
Ok(())
} }
} }