Create App trait to manage application startup/shutdown

This commit is contained in:
Garen Tyler 2024-12-08 21:37:10 -07:00
parent b87c71737d
commit adf4f37536
Signed by: garentyler
SSH Key Fingerprint: SHA256:G4ke7blZMdpWPbkescyZ7IQYE4JAtwpI85YoJdq+S7U
8 changed files with 264 additions and 239 deletions

View File

@ -17,6 +17,8 @@ pub(crate) mod world;
use config::Subcommand;
use once_cell::sync::OnceCell;
use std::time::Instant;
use tokio_util::sync::CancellationToken;
use tracing::info;
pub const PROTOCOL_VERSION: i32 = 762;
pub const GAME_VERSION: &str = "1.19.4";
@ -26,12 +28,51 @@ pub const GAME_VERSION: &str = "1.19.4";
/// This should be set immediately on startup.
pub static START_TIME: OnceCell<Instant> = OnceCell::new();
pub async fn run(command: Subcommand) {
pub async fn run(command: Subcommand, running: CancellationToken) {
match command {
#[cfg(feature = "server")]
Subcommand::Server => server::Server::run().await,
Subcommand::Server => server::Server::run(running).await,
#[cfg(feature = "proxy")]
Subcommand::Proxy => proxy::Proxy::run().await,
Subcommand::Proxy => proxy::Proxy::run(running).await,
Subcommand::None => unreachable!(),
}
}
#[async_trait::async_trait]
pub(crate) trait App: Sized {
type Error: std::fmt::Debug;
fn startup_message() -> String;
async fn new(running: CancellationToken) -> Result<Self, Self::Error>;
async fn update(&mut self) -> Result<(), Self::Error>;
async fn shutdown(self) -> Result<(), Self::Error>;
async fn run(running: CancellationToken) {
info!("{}", Self::startup_message());
let mut app = Self::new(running.clone()).await.expect("app to start");
info!(
"Done! Start took {:?}",
crate::START_TIME.get().unwrap().elapsed()
);
// The main loop.
loop {
tokio::select! {
_ = running.cancelled() => {
break;
}
r = app.update() => {
if r.is_err() {
break;
}
}
}
}
// Run shutdown tasks.
match tokio::time::timeout(std::time::Duration::from_secs(10), app.shutdown()).await {
Ok(_) => std::process::exit(0),
Err(_) => std::process::exit(1),
}
}
}

View File

@ -1,3 +1,4 @@
use tokio_util::sync::CancellationToken;
use tracing::{info, warn};
use tracing_subscriber::prelude::*;
@ -65,7 +66,17 @@ pub fn main() {
}
.unwrap()
.block_on(async move {
let running = CancellationToken::new();
// Spawn the ctrl-c task.
let r = running.clone();
tokio::spawn(async move {
tokio::signal::ctrl_c().await.unwrap();
info!("Ctrl-C received, shutting down");
r.cancel();
});
let args = composition::config::Args::instance();
composition::run(args.subcommand).await;
composition::run(args.subcommand, running).await;
});
}

View File

@ -5,15 +5,140 @@ use crate::protocol::{
ClientState,
};
use futures::{stream::StreamExt, SinkExt};
use std::time::{Duration, Instant};
use tokio::{io::BufStream, net::TcpStream};
use std::{
collections::HashMap,
time::{Duration, Instant},
};
use tokio::{io::BufStream, net::TcpStream, sync::mpsc};
use tokio::{
net::{TcpListener, ToSocketAddrs},
task::JoinHandle,
};
use tokio_util::codec::{Decoder, Framed};
use tracing::trace;
use tokio_util::sync::CancellationToken;
use tracing::{error, trace};
#[derive(Debug)]
pub struct ConnectionManager {
clients: HashMap<u128, Connection>,
channel: (
mpsc::UnboundedSender<Connection>,
mpsc::UnboundedReceiver<Connection>,
),
}
impl ConnectionManager {
pub fn new() -> ConnectionManager {
ConnectionManager {
clients: HashMap::new(),
channel: mpsc::unbounded_channel(),
}
}
pub fn client(&self, id: u128) -> Option<&Connection> {
self.clients.get(&id)
}
pub fn client_mut(&mut self, id: u128) -> Option<&mut Connection> {
self.clients.get_mut(&id)
}
pub async fn spawn_listener<A>(
&self,
bind_address: A,
running: CancellationToken,
) -> Result<JoinHandle<()>, std::io::Error>
where
A: 'static + ToSocketAddrs + Send + std::fmt::Debug,
{
trace!("Starting listener task");
let fmt_addr = format!("{:?}", bind_address);
let listener = TcpListener::bind(bind_address)
.await
.inspect_err(|_| error!("Could not bind to {}.", fmt_addr))?;
let sender = self.channel.0.clone();
let join_handle = tokio::spawn(async move {
let mut client_id = 0u128;
loop {
tokio::select! {
_ = running.cancelled() => {
break;
}
result = listener.accept() => {
if let Ok((stream, _)) = result {
trace!("Listener task got connection (id {})", client_id);
let client = Connection::new_client(client_id, stream);
if sender.send(client).is_err() {
trace!("Client receiver disconnected");
break;
}
client_id += 1;
}
}
}
}
trace!("Listener task shutting down");
});
Ok(join_handle)
}
pub fn update(&mut self) -> Result<(), std::io::Error> {
use std::io::{Error, ErrorKind};
// Receive new clients from the sender.
loop {
match self.channel.1.try_recv() {
Ok(connection) => {
let id = connection.id();
self.clients.insert(id, connection);
}
Err(mpsc::error::TryRecvError::Disconnected) => {
return Err(Error::new(
ErrorKind::BrokenPipe,
"all senders disconnected",
))
}
Err(mpsc::error::TryRecvError::Empty) => break,
};
}
// Remove disconnected clients.
self.clients
.retain(|_id, c| c.client_state() != ClientState::Disconnected);
Ok(())
}
pub async fn disconnect(
&mut self,
id: u128,
reason: Option<Chat>,
) -> Option<Result<(), std::io::Error>> {
let client = self.clients.remove(&id)?;
Some(client.disconnect(reason).await)
}
pub async fn shutdown(mut self, reason: Option<Chat>) -> Result<(), std::io::Error> {
let reason = reason.unwrap_or(serde_json::json!({
"text": "You have been disconnected!"
}));
let disconnections = self
.clients
.drain()
.map(|(_, c)| c)
.map(|c| c.disconnect(Some(reason.clone())))
.collect::<Vec<_>>();
// We don't actually care if the disconnections succeed,
// the connection is going to be dropped anyway.
let _disconnections: Vec<Result<(), std::io::Error>> =
futures::future::join_all(disconnections).await;
Ok(())
}
}
#[derive(Debug)]
pub struct Connection {
/// The `Connection`'s unique id.
pub id: u128,
id: u128,
stream: Framed<BufStream<TcpStream>, PacketCodec>,
last_received_data_time: Instant,
last_sent_data_time: Instant,
@ -35,6 +160,9 @@ impl Connection {
pub fn new_server(id: u128, stream: TcpStream) -> Self {
Self::new(id, PacketDirection::Clientbound, stream)
}
pub fn id(&self) -> u128 {
self.id
}
pub fn client_state(&self) -> ClientState {
self.stream.codec().client_state
}

View File

@ -1,101 +0,0 @@
use super::connection::Connection;
use crate::protocol::types::Chat;
use std::{
collections::HashMap,
sync::{Arc, Weak},
};
use tokio::{
net::{TcpListener, ToSocketAddrs},
sync::RwLock,
};
use tokio_util::sync::CancellationToken;
use tracing::{error, trace};
pub type Callback = dyn Fn(u128, Arc<RwLock<Connection>>) + Send;
#[derive(Clone, Debug)]
pub struct NetworkListener {
running: CancellationToken,
clients: Arc<RwLock<HashMap<u128, Arc<RwLock<Connection>>>>>,
}
impl NetworkListener {
pub async fn new<A: 'static + ToSocketAddrs + Send + std::fmt::Debug>(
bind_address: A,
running: CancellationToken,
callback: Option<Box<Callback>>,
) -> Result<NetworkListener, std::io::Error> {
let listener = TcpListener::bind(bind_address)
.await
.inspect_err(|_| error!("Could not bind to given address."))?;
let clients = Arc::new(RwLock::new(HashMap::new()));
let r = running.clone();
let c = clients.clone();
tokio::spawn(async move {
trace!("Starting listener task");
let mut client_id = 0u128;
loop {
tokio::select! {
_ = r.cancelled() => {
break;
}
result = listener.accept() => {
if let Ok((stream, _)) = result {
trace!("Listener task got connection (id {})", client_id);
let client = Arc::new(RwLock::new(Connection::new_client(client_id, stream)));
c.write().await.insert(client_id, client.clone());
if let Some(ref callback) = callback {
callback(client_id, client);
}
client_id += 1;
}
}
}
}
});
Ok(NetworkListener { running, clients })
}
pub async fn get_client(&self, id: u128) -> Option<Weak<RwLock<Connection>>> {
self.clients.read().await.get(&id).map(Arc::downgrade)
}
pub async fn disconnect_client(
&self,
id: u128,
reason: Option<Chat>,
) -> Result<Result<(), std::io::Error>, ()> {
// Remove the client from the hashmap.
let client = self.clients.write().await.remove(&id).ok_or(())?;
let client: Connection = Arc::into_inner(client)
.expect("only one reference")
.into_inner();
// let mut client = client.write().await;
// Send a disconnect packet.
Ok(client.disconnect(reason).await)
}
pub async fn shutdown(self, reason: Option<Chat>) -> Result<(), std::io::Error> {
self.running.cancel();
let reason = reason.unwrap_or(serde_json::json!({
"text": "You have been disconnected!"
}));
let disconnections = self
.clients
.write()
.await
.drain()
.map(|(_, c)| c)
.map(|c| Arc::into_inner(c).expect("only one reference").into_inner())
.map(|c| c.disconnect(Some(reason.clone())))
.collect::<Vec<_>>();
// We don't actually care if the disconnections succeed,
// the connection is going to be dropped anyway.
let _disconnections: Vec<Result<(), std::io::Error>> =
futures::future::join_all(disconnections).await;
Ok(())
}
}

View File

@ -2,4 +2,3 @@
pub mod codec;
pub mod connection;
pub mod listener;

View File

@ -1,83 +1,64 @@
pub mod config;
use crate::{config::Config, net::listener::NetworkListener};
use crate::App;
use crate::{config::Config, net::connection::ConnectionManager};
use config::ProxyConfig;
use tokio::net::ToSocketAddrs;
use tokio::task::JoinHandle;
use tokio_util::sync::CancellationToken;
use tracing::{info, trace};
use tracing::info;
#[derive(Debug)]
pub struct Proxy {
_network_listener: NetworkListener,
running: CancellationToken,
connections: ConnectionManager,
listener: JoinHandle<()>,
}
impl Proxy {
/// Start the proxy.
#[tracing::instrument]
pub async fn run() {
#[async_trait::async_trait]
impl App for Proxy {
type Error = ();
fn startup_message() -> String {
let config = Config::instance();
info!(
format!(
"Starting {} on port {}",
ProxyConfig::default().version,
config.proxy.port
);
let (mut proxy, running) = Self::new(format!("0.0.0.0:{}", config.proxy.port)).await;
info!(
"Done! Start took {:?}",
crate::START_TIME.get().unwrap().elapsed()
);
)
}
#[tracing::instrument]
async fn new(running: CancellationToken) -> Result<Self, Self::Error> {
let config = Config::instance();
let bind_address = format!("0.0.0.0:{}", config.proxy.port);
let connections = ConnectionManager::new();
let listener = connections
.spawn_listener(bind_address, running.child_token())
.await
.map_err(|_| ())?;
info!(
"Upstream server: {}:{}",
config.proxy.upstream_host, config.proxy.upstream_port
);
// Spawn the ctrl-c task.
let r = running.clone();
tokio::spawn(async move {
tokio::signal::ctrl_c().await.unwrap();
info!("Ctrl-C received, shutting down");
r.cancel();
});
// The main loop.
loop {
tokio::select! {
_ = running.cancelled() => {
break;
}
_ = proxy.update() => {}
}
}
match tokio::time::timeout(std::time::Duration::from_secs(10), proxy.shutdown()).await {
Ok(_) => std::process::exit(0),
Err(_) => std::process::exit(1),
}
Ok(Proxy {
running,
connections,
listener,
})
}
#[tracing::instrument]
async fn new<A: 'static + ToSocketAddrs + Send + std::fmt::Debug>(
bind_address: A,
) -> (Proxy, CancellationToken) {
trace!("Proxy::new()");
let running = CancellationToken::new();
let network_listener = NetworkListener::new(bind_address, running.child_token(), None)
.await
.expect("listener to bind properly");
let proxy = Proxy {
_network_listener: network_listener,
};
(proxy, running)
async fn update(&mut self) -> Result<(), Self::Error> {
todo!()
}
#[tracing::instrument]
async fn update(&mut self) -> Result<(), ()> {
// TODO
async fn shutdown(self) -> Result<(), Self::Error> {
// Ensure any child tasks have been shut down.
self.running.cancel();
let _ = self.listener.await.map_err(|_| ())?;
let _ = self.connections.shutdown(None).await.map_err(|_| ())?;
Ok(())
}
#[tracing::instrument]
async fn shutdown(self) {
trace!("Proxy.shutdown()");
// TODO
}
}

View File

@ -5,6 +5,3 @@ pub enum Error {
#[error("the server is not running")]
NotRunning,
}
/// Alias for a Result with the error type `composition_core::server::Error`.
pub type Result<T> = std::result::Result<T, Error>;

View File

@ -7,14 +7,14 @@ pub mod net;
use crate::config::Config;
use crate::protocol::ClientState;
use crate::App;
use config::ServerConfig;
use error::Result;
use net::{NetworkClient, NetworkClientState};
use std::sync::Arc;
use tokio::net::{TcpListener, ToSocketAddrs};
use tokio::{sync::RwLock, task::JoinHandle};
use tokio_util::sync::CancellationToken;
use tracing::{error, info, trace};
use tracing::{error, trace};
/// The main state and logic of the program.
#[derive(Debug)]
@ -23,65 +23,6 @@ pub struct Server {
net_tasks_handle: JoinHandle<()>,
}
impl Server {
/// Start the server.
#[tracing::instrument]
pub async fn run() {
let config = Config::instance();
info!(
"Starting {} on port {}",
ServerConfig::default().version,
config.server.port
);
let (mut server, running) = Self::new(format!("0.0.0.0:{}", config.server.port)).await;
info!(
"Done! Start took {:?}",
crate::START_TIME.get().unwrap().elapsed()
);
// Spawn the ctrl-c task.
let r = running.clone();
tokio::spawn(async move {
tokio::signal::ctrl_c().await.unwrap();
info!("Ctrl-C received, shutting down");
r.cancel();
});
// The main server loop.
loop {
tokio::select! {
_ = running.cancelled() => {
break;
}
_ = server.update() => {}
}
}
match tokio::time::timeout(std::time::Duration::from_secs(10), server.shutdown()).await {
Ok(_) => std::process::exit(0),
Err(_) => std::process::exit(1),
}
}
#[tracing::instrument]
async fn new<A: 'static + ToSocketAddrs + Send + std::fmt::Debug>(
bind_address: A,
) -> (Server, CancellationToken) {
trace!("Server::new()");
let running = CancellationToken::new();
let clients = Arc::new(RwLock::new(vec![]));
let net_tasks_handle = tokio::spawn(Self::create_network_tasks(
bind_address,
clients.clone(),
running.clone(),
));
let server = Server {
clients,
net_tasks_handle,
};
(server, running)
}
#[tracing::instrument]
async fn create_network_tasks<A: 'static + ToSocketAddrs + Send + std::fmt::Debug>(
bind_address: A,
@ -187,10 +128,38 @@ impl Server {
.await
.expect("Disconnection task crashed");
}
#[tracing::instrument]
pub async fn update(&mut self) -> Result<()> {
trace!("Server.update()");
}
#[async_trait::async_trait]
impl App for Server {
type Error = error::Error;
fn startup_message() -> String {
let config = Config::instance();
format!(
"Starting {} on port {}",
ServerConfig::default().version,
config.server.port
)
}
#[tracing::instrument]
async fn new(running: CancellationToken) -> Result<Self, Self::Error> {
let config = Config::instance();
let bind_address = format!("0.0.0.0:{}", config.server.port);
let clients = Arc::new(RwLock::new(vec![]));
let net_tasks_handle = tokio::spawn(Self::create_network_tasks(
bind_address,
clients.clone(),
running.child_token(),
));
Ok(Server {
clients,
net_tasks_handle,
})
}
#[tracing::instrument]
async fn update(&mut self) -> Result<(), Self::Error> {
let mut clients = self.clients.write().await;
// Handle packets from the clients.
@ -321,9 +290,7 @@ impl Server {
Ok(())
}
#[tracing::instrument]
pub async fn shutdown(self) {
trace!("Server.shutdown()");
async fn shutdown(self) -> Result<(), Self::Error> {
// Close the concurrent tasks.
let _ = self.net_tasks_handle.await;
@ -335,5 +302,7 @@ impl Server {
))
.await;
}
Ok(())
}
}