From f73fdae2c7f8d016cd00552c7272bdf54d00eaf4 Mon Sep 17 00:00:00 2001 From: Garen Tyler Date: Fri, 19 Mar 2021 19:23:10 -0600 Subject: [PATCH] Disconnect all clients on server shutdown, shutdown on ctrl-c --- Cargo.lock | 29 +++++++++++++++++++++++++++++ Cargo.toml | 1 + src/lib.rs | 10 +++++++++- src/main.rs | 11 ++++++++++- src/mctypes.rs | 31 ++++++++++++++++++------------- src/server/mod.rs | 23 ++++++++++++++++++----- src/server/packets/mod.rs | 7 ++++++- 7 files changed, 91 insertions(+), 21 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 2a8dfd1..a69b62d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -46,6 +46,12 @@ version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b700ce4376041dcd0a327fd0097c41095743c4c8af8887265942faf1100bd040" +[[package]] +name = "cc" +version = "1.0.67" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3c69b077ad434294d3ce9f1f6143a2a4b89a8a2d54ef813d85003a4fd1137fd" + [[package]] name = "cfg-if" version = "0.1.10" @@ -88,6 +94,7 @@ version = "0.1.0" dependencies = [ "async-trait", "chrono", + "ctrlc", "fern", "lazy_static", "log", @@ -98,6 +105,16 @@ dependencies = [ "toml", ] +[[package]] +name = "ctrlc" +version = "3.1.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c15b8ec3b5755a188c141c1f6a98e76de31b936209bf066b647979e2a84764a9" +dependencies = [ + "nix", + "winapi", +] + [[package]] name = "fern" version = "0.6.0" @@ -191,6 +208,18 @@ dependencies = [ "winapi", ] +[[package]] +name = "nix" +version = "0.20.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fa9b4819da1bc61c0ea48b63b7bc8604064dd43013e7cc325df098d49cd7c18a" +dependencies = [ + "bitflags", + "cc", + "cfg-if 1.0.0", + "libc", +] + [[package]] name = "ntapi" version = "0.3.6" diff --git a/Cargo.toml b/Cargo.toml index 2c289e1..679776d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,6 +15,7 @@ radix64 = "0.3.0" tokio = { version = "1", features = ["full"] } async-trait = "0.1.48" lazy_static = "1.4.0" +ctrlc = "3.1.8" # colorful = "0.2.1" # ozelot = "0.9.0" # Ozelot 0.9.0 supports protocol version 578 (1.15.2) # toml = "0.5.6" diff --git a/src/lib.rs b/src/lib.rs index 82b45e5..049c68f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -13,6 +13,7 @@ pub mod world; use log::warn; pub use mctypes::*; use serde::{Deserialize, Serialize}; +use std::sync::mpsc::{self, Receiver}; #[derive(Serialize, Deserialize)] pub struct Config { @@ -60,7 +61,7 @@ lazy_static! { } /// Set up logging, read the config file, etc. -pub fn init() { +pub fn init() -> Receiver<()> { // Set up fern logging. fern::Dispatch::new() .format(move |out, message, record| { @@ -81,6 +82,13 @@ pub fn init() { .chain(fern::log_file("output.log").unwrap()) .apply() .unwrap(); + // Set up the ctrl-c handler. + let (ctrlc_tx, ctrlc_rx) = mpsc::channel(); + ctrlc::set_handler(move || { + ctrlc_tx.send(()).expect("Ctrl-C receiver disconnected"); + }) + .expect("Error setting Ctrl-C handler"); + ctrlc_rx } /// Start the server. diff --git a/src/main.rs b/src/main.rs index ec7ef87..bbb31f3 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,16 +1,25 @@ use log::info; +use std::sync::mpsc::TryRecvError; use std::time::{Duration, Instant}; #[tokio::main] pub async fn main() { let start_time = Instant::now(); - composition::init(); + let ctrlc_rx = composition::init(); info!("Starting server..."); let mut server = composition::start_server().await; info!("Done! Start took {:?}", start_time.elapsed()); // The main server loop. loop { + match ctrlc_rx.try_recv() { + Ok(_) => { + server.shutdown().await; + break; // Exit the loop. + } + Err(TryRecvError::Empty) => {} // Doesn't matter if there's nothing for us + Err(TryRecvError::Disconnected) => panic!("Ctrl-C sender disconnected"), + } server.update().await.unwrap(); std::thread::sleep(Duration::from_millis(2)); } diff --git a/src/mctypes.rs b/src/mctypes.rs index 954e8c8..83d7a7b 100644 --- a/src/mctypes.rs +++ b/src/mctypes.rs @@ -108,14 +108,12 @@ pub mod other { impl TryFrom> for MCBoolean { type Error = &'static str; fn try_from(bytes: Vec) -> Result { - if bytes.len() < 1 { + if bytes.is_empty() { Err("Not enough bytes") + } else if bytes[0] == 1u8 { + Ok(MCBoolean::True) } else { - if bytes[0] == 1u8 { - Ok(MCBoolean::True) - } else { - Ok(MCBoolean::False) - } + Ok(MCBoolean::False) } } } @@ -148,7 +146,7 @@ pub mod other { } impl From for MCString { fn from(s: String) -> MCString { - MCString { value: s.clone() } + MCString { value: s } } } impl Into for MCString { @@ -227,9 +225,7 @@ pub mod other { } impl From for MCChat { fn from(s: String) -> MCChat { - MCChat { - text: s.clone().into(), - } + MCChat { text: s.into() } } } impl Into for MCChat { @@ -345,6 +341,15 @@ pub mod other { Err(io_error("Cannot read MCPosition from stream")) } } + impl Default for MCPosition { + fn default() -> Self { + MCPosition { + x: 0.into(), + y: 0.into(), + z: 0.into(), + } + } + } } /// All the numbers, from `i8` and `u8` to `i64` and `u64`, plus `VarInt`s. @@ -394,7 +399,7 @@ pub mod numbers { impl TryFrom> for MCByte { type Error = &'static str; fn try_from(bytes: Vec) -> Result { - if bytes.len() < 1 { + if bytes.is_empty() { Err("Not enough bytes") } else { let mut a = [0u8; 1]; @@ -452,7 +457,7 @@ pub mod numbers { impl TryFrom> for MCUnsignedByte { type Error = &'static str; fn try_from(bytes: Vec) -> Result { - if bytes.len() < 1 { + if bytes.is_empty() { Err("Not enough bytes") } else { let mut a = [0u8; 1]; @@ -1065,7 +1070,7 @@ pub mod numbers { } out.push(temp); } - return out; + out } } } diff --git a/src/server/mod.rs b/src/server/mod.rs index f60c0fc..e3efa96 100644 --- a/src/server/mod.rs +++ b/src/server/mod.rs @@ -42,6 +42,17 @@ impl Server { } } + /// Shut down the server. + /// + /// Disconnects all clients. + pub async fn shutdown(&mut self) { + info!("Server shutting down."); + for client in self.network_clients.iter_mut() { + let _ = client.disconnect(Some("The server is shutting down")).await; + // We don't care if it doesn't succeed in sending the packet. + } + } + /// Update the network server. /// /// Update each client in `self.network_clients`. @@ -67,7 +78,9 @@ impl Server { } }); for client in self.network_clients.iter_mut() { - client.update(num_players).await?; + if client.update(num_players).await.is_err() { + client.force_disconnect(); + } } // Remove disconnected clients. self.network_clients @@ -87,7 +100,7 @@ impl Server { /// The network client can only be in a few states, /// this enum keeps track of that. -#[derive(PartialEq)] +#[derive(PartialEq, Debug)] pub enum NetworkClientState { Handshake, Status, @@ -98,6 +111,7 @@ pub enum NetworkClientState { /// A wrapper to contain everything related /// to networking for the client. +#[derive(Debug)] pub struct NetworkClient { pub id: u128, pub connected: bool, @@ -126,6 +140,7 @@ impl NetworkClient { /// Updating could mean connecting new clients, reading packets, /// writing packets, or disconnecting clients. pub async fn update(&mut self, num_players: usize) -> tokio::io::Result<()> { + // println!("{:?}", self); match self.state { NetworkClientState::Handshake => { let handshake = self.get_packet::().await?; @@ -284,14 +299,12 @@ impl NetworkClient { let mut disconnect = Disconnect::new(); disconnect.reason.text = reason.unwrap_or("Disconnected").into(); self.send_packet(disconnect).await?; - // Give the client 10 seconds to disconnect before forcing it. - tokio::time::sleep(Duration::from_secs(10)).await; self.force_disconnect(); Ok(()) } /// Force disconnect the client by marking it for cleanup as disconnected. - async fn force_disconnect(&mut self) { + fn force_disconnect(&mut self) { self.connected = false; self.state = NetworkClientState::Disconnected; } diff --git a/src/server/packets/mod.rs b/src/server/packets/mod.rs index f104f78..1f00ec0 100644 --- a/src/server/packets/mod.rs +++ b/src/server/packets/mod.rs @@ -37,9 +37,14 @@ macro_rules! register_packets { } } } + impl Default for Packet { + fn default() -> Self { + Packet::Null + } + } $( impl $name { - pub fn into_packet(&self) -> Packet { + pub fn as_packet(&self) -> Packet { Packet::$name(self.clone()) } }