diff --git a/Cargo.lock b/Cargo.lock index 0468a05..b2c23af 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -124,9 +124,9 @@ checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" [[package]] name = "clap" -version = "4.5.22" +version = "4.5.23" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "69371e34337c4c984bbe322360c2547210bf632eb2814bbe78a6e87a2935bd2b" +checksum = "3135e7ec2ef7b10c6ed8950f0f792ed96ee093fa088608f1c76e569722700c84" dependencies = [ "clap_builder", "clap_derive", @@ -134,9 +134,9 @@ dependencies = [ [[package]] name = "clap_builder" -version = "4.5.22" +version = "4.5.23" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6e24c1b4099818523236a8ca881d2b45db98dadfb4625cf6608c12069fcbbde1" +checksum = "30582fc632330df2bd26877bde0c1f4470d57c582bbc070376afcd04d8cb4838" dependencies = [ "anstream", "anstyle", @@ -158,9 +158,9 @@ dependencies = [ [[package]] name = "clap_lex" -version = "0.7.3" +version = "0.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "afb84c814227b90d6895e01398aee0d8033c00e7466aca416fb6a8e0eb19d8a7" +checksum = "f46ad14479a25103f283c0f10005961cf086d8dc42205bb44c46ac563475dca6" [[package]] name = "colorchoice" @@ -176,11 +176,12 @@ dependencies = [ "base64", "clap", "const_format", + "futures", "nom", "once_cell", "serde", "serde_json", - "thiserror 2.0.4", + "thiserror 2.0.5", "tokio", "tokio-util", "toml", @@ -191,18 +192,18 @@ dependencies = [ [[package]] name = "const_format" -version = "0.2.33" +version = "0.2.34" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "50c655d81ff1114fb0dcdea9225ea9f0cc712a6f8d189378e82bdf62a473a64b" +checksum = "126f97965c8ad46d6d9163268ff28432e8f6a1196a55578867832e3049df63dd" dependencies = [ "const_format_proc_macros", ] [[package]] name = "const_format_proc_macros" -version = "0.2.33" +version = "0.2.34" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eff1a44b93f47b1bac19a27932f5c591e43d1ba357ee4f61526c8a25603f0eb1" +checksum = "1d57c2eccfb16dbac1f4e61e206105db5820c9d26c3c472bc17c774259ef7744" dependencies = [ "proc-macro2", "quote", @@ -239,18 +240,95 @@ version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5443807d6dff69373d433ab9ef5378ad8df50ca6298caf15de6e52e24aaf54d5" +[[package]] +name = "futures" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "65bc07b1a8bc7c85c5f2e110c476c7389b4554ba72af57d8445ea63a576b0876" +dependencies = [ + "futures-channel", + "futures-core", + "futures-executor", + "futures-io", + "futures-sink", + "futures-task", + "futures-util", +] + +[[package]] +name = "futures-channel" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2dff15bf788c671c1934e366d07e30c1814a8ef514e1af724a602e8a2fbe1b10" +dependencies = [ + "futures-core", + "futures-sink", +] + [[package]] name = "futures-core" version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "05f29059c0c2090612e8d742178b0580d2dc940c837851ad723096f87af6663e" +[[package]] +name = "futures-executor" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e28d1d997f585e54aebc3f97d39e72338912123a67330d723fdbb564d646c9f" +dependencies = [ + "futures-core", + "futures-task", + "futures-util", +] + +[[package]] +name = "futures-io" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9e5c1b78ca4aae1ac06c48a526a655760685149f0d465d21f37abfe57ce075c6" + +[[package]] +name = "futures-macro" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "162ee34ebcb7c64a8abebc059ce0fee27c2262618d7b60ed8faf72fef13c3650" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "futures-sink" version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e575fab7d1e0dcb8d0c7bcf9a63ee213816ab51902e6d244a95819acacf1d4f7" +[[package]] +name = "futures-task" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f90f7dce0722e95104fcb095585910c0977252f286e354b5e3bd38902cd99988" + +[[package]] +name = "futures-util" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9fa08315bb612088cc391249efdc3bc77536f16c91f6cf495e6fbe85b20a4a81" +dependencies = [ + "futures-channel", + "futures-core", + "futures-io", + "futures-macro", + "futures-sink", + "futures-task", + "memchr", + "pin-project-lite", + "pin-utils", + "slab", +] + [[package]] name = "gimli" version = "0.31.1" @@ -427,6 +505,12 @@ version = "0.2.15" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "915a1e146535de9163f3987b8944ed8cf49a18bb0056bcebcdcece385cece4ff" +[[package]] +name = "pin-utils" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" + [[package]] name = "powerfmt" version = "0.2.0" @@ -537,6 +621,15 @@ dependencies = [ "libc", ] +[[package]] +name = "slab" +version = "0.4.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f92a496fb766b417c996b9c5e57daf2f7ad3b0bebe1ccfca4856390e3d3bb67" +dependencies = [ + "autocfg", +] + [[package]] name = "smallvec" version = "1.13.2" @@ -581,11 +674,11 @@ dependencies = [ [[package]] name = "thiserror" -version = "2.0.4" +version = "2.0.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2f49a1853cf82743e3b7950f77e0f4d622ca36cf4317cba00c767838bac8d490" +checksum = "643caef17e3128658ff44d85923ef2d28af81bb71e0d67bbfe1d76f19a73e053" dependencies = [ - "thiserror-impl 2.0.4", + "thiserror-impl 2.0.5", ] [[package]] @@ -601,9 +694,9 @@ dependencies = [ [[package]] name = "thiserror-impl" -version = "2.0.4" +version = "2.0.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8381894bb3efe0c4acac3ded651301ceee58a15d47c2e34885ed1908ad667061" +checksum = "995d0bbc9995d1f19d28b7215a9352b0fc3cd3a2d2ec95c2cadc485cdedbcdde" dependencies = [ "proc-macro2", "quote", diff --git a/Cargo.toml b/Cargo.toml index c6cdab8..42d7893 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,16 +23,17 @@ update_1_20 = [] [dependencies] async-trait = { version = "0.1.68", optional = true } base64 = { version = "0.22.1", optional = true } -clap = { version = "4.5.22", features = ["derive"] } +clap = { version = "4.5.23", features = ["derive"] } +const_format = "0.2.34" +futures = "0.3.31" +nom = "7.1.3" once_cell = "1.17.1" serde = { version = "1.0.160", features = ["serde_derive"] } serde_json = "1.0.96" -thiserror = "2.0.4" +thiserror = "2.0.5" tokio = { version = "1.42.0", features = ["full"] } -tokio-util = { version = "0.7.13", optional = true } +tokio-util = { version = "0.7.13", features = ["codec"], optional = true } toml = "0.8.19" tracing = { version = "0.1.37", features = ["log"] } tracing-subscriber = { version = "0.3.17", features = ["tracing-log"] } tracing-appender = "0.2.2" -nom = "7.1.3" -const_format = "0.2.33" diff --git a/src/lib.rs b/src/lib.rs index f32d383..ce7661b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,5 +1,7 @@ /// Server configuration and cli options. pub mod config; +/// Network operations. +pub(crate) mod net; /// The Minecraft protocol implemented in a network-agnostic way. pub mod protocol; /// A proxy server. diff --git a/src/main.rs b/src/main.rs index af92fbd..070b0a6 100644 --- a/src/main.rs +++ b/src/main.rs @@ -16,7 +16,7 @@ pub fn main() { .clone() .unwrap_or(PathBuf::from(DEFAULT_LOG_DIR)); let log_path = Path::new(&log_path); - let file_writer = tracing_appender::rolling::daily(&log_path, "log"); + let file_writer = tracing_appender::rolling::daily(log_path, "log"); let (file_writer, _guard) = tracing_appender::non_blocking(file_writer); tracing_subscriber::registry() diff --git a/src/net/codec.rs b/src/net/codec.rs new file mode 100644 index 0000000..060d5d1 --- /dev/null +++ b/src/net/codec.rs @@ -0,0 +1,95 @@ +use crate::protocol::{ + packets::{Packet, PacketDirection}, + parsing::Parsable, + types::VarInt, + ClientState, +}; +use std::io::{Error, ErrorKind}; +use tokio_util::{ + bytes::{Buf, BytesMut}, + codec::{Decoder, Encoder}, +}; + +#[derive(Clone, Copy, Debug)] +pub struct PacketCodec { + pub client_state: ClientState, + pub packet_direction: PacketDirection, +} +impl PacketCodec { + pub fn new(client_state: ClientState, packet_direction: PacketDirection) -> PacketCodec { + PacketCodec { + client_state, + packet_direction, + } + } +} +impl Default for PacketCodec { + fn default() -> Self { + PacketCodec { + client_state: ClientState::Handshake, + packet_direction: PacketDirection::Serverbound, + } + } +} +impl Decoder for PacketCodec { + type Item = Packet; + type Error = std::io::Error; + + fn decode(&mut self, src: &mut BytesMut) -> Result, Self::Error> { + match Packet::parse(self.client_state, self.packet_direction, src) { + Ok((rest, packet)) => { + let bytes_consumed = src.len() - rest.len(); + src.advance(bytes_consumed); + + if let Some(next_state) = packet.state_change() { + self.client_state = next_state; + } + + Ok(Some(packet)) + } + Err(nom::Err::Incomplete(_)) => { + // Try to read the packet length. + match VarInt::parse_usize(src) { + Ok((_, packet_length)) => { + src.reserve(packet_length + 64); + Ok(None) + } + Err(nom::Err::Incomplete(_)) => { + src.reserve(5); + Ok(None) + } + Err(_) => Err(Error::new(ErrorKind::InvalidData, "Nom parsing error")), + } + } + Err(nom::Err::Error(_)) | Err(nom::Err::Failure(_)) => { + Err(Error::new(ErrorKind::InvalidData, "Nom parsing error")) + } + } + } +} +impl Encoder for PacketCodec { + type Error = std::io::Error; + + fn encode(&mut self, item: Packet, dst: &mut BytesMut) -> Result<(), Self::Error> { + let mut out = vec![]; + let (packet_id, packet_body) = item.serialize(); + out.extend(packet_id.serialize().to_vec()); + out.extend(packet_body); + let packet_len = VarInt::from(out.len()); + dst.extend(packet_len.serialize()); + dst.extend(out); + Ok(()) + } +} + +#[cfg(test)] +mod tests { + #[test] + fn packet_decoder_works() { + unimplemented!() + } + #[test] + fn packet_encoder_works() { + unimplemented!() + } +} diff --git a/src/net/connection.rs b/src/net/connection.rs new file mode 100644 index 0000000..5b978a5 --- /dev/null +++ b/src/net/connection.rs @@ -0,0 +1,79 @@ +use super::codec::PacketCodec; +use crate::protocol::{ + packets::{self, Packet, PacketDirection}, + types::Chat, + ClientState, +}; +use futures::{stream::StreamExt, SinkExt}; +use std::time::{Duration, Instant}; +use tokio::{io::BufStream, net::TcpStream}; +use tokio_util::codec::{Decoder, Framed}; +use tracing::trace; + +#[derive(Debug)] +pub struct Connection { + /// The `Connection`'s unique id. + pub id: u128, + stream: Framed, PacketCodec>, + last_received_data_time: Instant, + last_sent_data_time: Instant, +} +impl Connection { + fn new(id: u128, receiving_direction: PacketDirection, stream: TcpStream) -> Self { + let codec = PacketCodec::new(ClientState::Handshake, receiving_direction); + + Connection { + id, + stream: codec.framed(BufStream::new(stream)), + last_received_data_time: Instant::now(), + last_sent_data_time: Instant::now(), + } + } + pub fn new_client(id: u128, stream: TcpStream) -> Self { + Self::new(id, PacketDirection::Serverbound, stream) + } + pub fn new_server(id: u128, stream: TcpStream) -> Self { + Self::new(id, PacketDirection::Clientbound, stream) + } + pub fn client_state(&self) -> ClientState { + self.stream.codec().client_state + } + pub fn received_elapsed(&self) -> Duration { + self.last_received_data_time.elapsed() + } + pub fn sent_elapsed(&self) -> Duration { + self.last_sent_data_time.elapsed() + } + pub async fn read_packet(&mut self) -> Option> { + self.last_received_data_time = Instant::now(); + self.stream.next().await + } + pub async fn send_packet>(&mut self, packet: P) -> Result<(), std::io::Error> { + let packet: Packet = packet.into(); + self.stream.send(packet).await + } + pub async fn disconnect(mut self, reason: Option) -> Result<(), std::io::Error> { + trace!("Connection disconnected (id {})", self.id); + use packets::{login::clientbound::LoginDisconnect, play::clientbound::PlayDisconnect}; + + let reason = reason.unwrap_or(serde_json::json!({ + "text": "You have been disconnected!" + })); + + match self.client_state() { + ClientState::Disconnected | ClientState::Handshake | ClientState::Status => { + // Impossible to send a disconnect in these states. + } + ClientState::Login => { + let _ = self.send_packet(LoginDisconnect { reason }).await; + } + ClientState::Play => { + let _ = self.send_packet(PlayDisconnect { reason }).await; + } + } + + self.stream.flush().await?; + self.stream.codec_mut().client_state = ClientState::Disconnected; + Ok(()) + } +} diff --git a/src/net/listener.rs b/src/net/listener.rs new file mode 100644 index 0000000..04ff000 --- /dev/null +++ b/src/net/listener.rs @@ -0,0 +1,101 @@ +use super::connection::Connection; +use crate::protocol::types::Chat; +use std::{ + collections::HashMap, + sync::{Arc, Weak}, +}; +use tokio::{ + net::{TcpListener, ToSocketAddrs}, + sync::RwLock, +}; +use tokio_util::sync::CancellationToken; +use tracing::{error, trace}; + +pub type Callback = dyn Fn(u128, Arc>) + Send; + +#[derive(Clone, Debug)] +pub struct NetworkListener { + running: CancellationToken, + clients: Arc>>>>, +} +impl NetworkListener { + pub async fn new( + bind_address: A, + running: CancellationToken, + callback: Option>, + ) -> Result { + let listener = TcpListener::bind(bind_address) + .await + .inspect_err(|_| error!("Could not bind to given address."))?; + let clients = Arc::new(RwLock::new(HashMap::new())); + + let r = running.clone(); + let c = clients.clone(); + tokio::spawn(async move { + trace!("Starting listener task"); + let mut client_id = 0u128; + + loop { + tokio::select! { + _ = r.cancelled() => { + break; + } + result = listener.accept() => { + if let Ok((stream, _)) = result { + trace!("Listener task got connection (id {})", client_id); + let client = Arc::new(RwLock::new(Connection::new_client(client_id, stream))); + c.write().await.insert(client_id, client.clone()); + if let Some(ref callback) = callback { + callback(client_id, client); + } + client_id += 1; + } + } + } + } + }); + + Ok(NetworkListener { running, clients }) + } + pub async fn get_client(&self, id: u128) -> Option>> { + self.clients.read().await.get(&id).map(Arc::downgrade) + } + pub async fn disconnect_client( + &self, + id: u128, + reason: Option, + ) -> Result, ()> { + // Remove the client from the hashmap. + let client = self.clients.write().await.remove(&id).ok_or(())?; + let client: Connection = Arc::into_inner(client) + .expect("only one reference") + .into_inner(); + // let mut client = client.write().await; + // Send a disconnect packet. + Ok(client.disconnect(reason).await) + } + pub async fn shutdown(self, reason: Option) -> Result<(), std::io::Error> { + self.running.cancel(); + + let reason = reason.unwrap_or(serde_json::json!({ + "text": "You have been disconnected!" + })); + + let disconnections = self + .clients + .write() + .await + .drain() + .map(|(_, c)| c) + .map(|c| Arc::into_inner(c).expect("only one reference").into_inner()) + .map(|c| c.disconnect(Some(reason.clone()))) + .collect::>(); + + // We don't actually care if the disconnections succeed, + // the connection is going to be dropped anyway. + let _disconnections: Vec> = + futures::future::join_all(disconnections).await; + + Ok(()) + } +} diff --git a/src/net/mod.rs b/src/net/mod.rs new file mode 100644 index 0000000..ff2a2a7 --- /dev/null +++ b/src/net/mod.rs @@ -0,0 +1,5 @@ +#![allow(dead_code)] + +pub mod codec; +pub mod connection; +pub mod listener; diff --git a/src/protocol/packets.rs b/src/protocol/packets.rs index 520ecb9..71fa0a5 100644 --- a/src/protocol/packets.rs +++ b/src/protocol/packets.rs @@ -104,6 +104,15 @@ macro_rules! packets { )*)*)* } } + pub fn state_change(&self) -> Option { + match self { + Packet::Handshake(handshake) => Some(handshake.next_state), + Packet::LoginSuccess(_) => Some(ClientState::Play), + Packet::LoginDisconnect(_) => Some(ClientState::Disconnected), + Packet::PlayDisconnect(_) => Some(ClientState::Disconnected), + _ => None, + } + } } $(pub mod $state { diff --git a/src/proxy/mod.rs b/src/proxy/mod.rs index cda73e7..6a51a53 100644 --- a/src/proxy/mod.rs +++ b/src/proxy/mod.rs @@ -1,13 +1,15 @@ pub mod config; -use crate::config::Config; +use crate::{config::Config, net::listener::NetworkListener}; use config::ProxyConfig; use tokio::net::ToSocketAddrs; use tokio_util::sync::CancellationToken; use tracing::{info, trace}; #[derive(Debug)] -pub struct Proxy {} +pub struct Proxy { + _network_listener: NetworkListener, +} impl Proxy { /// Start the proxy. #[tracing::instrument] @@ -53,12 +55,18 @@ impl Proxy { } #[tracing::instrument] async fn new( - _bind_address: A, + bind_address: A, ) -> (Proxy, CancellationToken) { trace!("Proxy::new()"); - let running = CancellationToken::new(); - let proxy = Proxy {}; + + let network_listener = NetworkListener::new(bind_address, running.child_token(), None) + .await + .expect("listener to bind properly"); + + let proxy = Proxy { + _network_listener: network_listener, + }; (proxy, running) } diff --git a/src/server/config.rs b/src/server/config.rs index 83edb9f..6f8d0ad 100644 --- a/src/server/config.rs +++ b/src/server/config.rs @@ -39,8 +39,7 @@ impl ServerConfig { } pub fn load_args(&mut self) { self.server_icon = ServerArgs::instance() - .map(|s| s.server_icon.clone()) - .flatten() + .and_then(|s| s.server_icon.clone()) .unwrap_or(PathBuf::from(DEFAULT_SERVER_ICON)); self.load_icon(); } @@ -83,6 +82,7 @@ impl ServerConfig { pub struct ServerArgs { pub server_icon: Option, } +#[allow(clippy::derivable_impls)] impl Default for ServerArgs { fn default() -> Self { ServerArgs { server_icon: None } @@ -104,6 +104,7 @@ impl ServerArgs { .default_value(DEFAULT_SERVER_ICON), ) } + #[allow(clippy::field_reassign_with_default)] pub fn parse(m: clap::ArgMatches) -> Self { let mut server_args = ServerArgs::default(); server_args.server_icon = m.get_one("server-icon").cloned();