From 4cc58fbf81cbe2bb85dbfb12f62f6e67e4313561 Mon Sep 17 00:00:00 2001 From: Garen Tyler Date: Fri, 6 Jun 2025 14:42:41 -0600 Subject: [PATCH] Update server to use new streams --- src/net/connection/downstream/manager.rs | 13 +- src/net/connection/downstream/mod.rs | 38 ++- src/net/connection/mod.rs | 13 +- src/proxy/mod.rs | 2 +- src/server/error.rs | 19 +- src/server/mod.rs | 318 ++++------------------- src/server/net.rs | 246 ------------------ 7 files changed, 133 insertions(+), 516 deletions(-) delete mode 100644 src/server/net.rs diff --git a/src/net/connection/downstream/manager.rs b/src/net/connection/downstream/manager.rs index 2ea55ee..c418725 100644 --- a/src/net/connection/downstream/manager.rs +++ b/src/net/connection/downstream/manager.rs @@ -1,5 +1,8 @@ use crate::{ - net::{connection::DownstreamConnection, error::Error}, + net::{ + connection::{DownstreamConnection, DownstreamConnectionState}, + error::Error, + }, protocol::{types::Chat, ClientState}, }; use std::{collections::HashMap, time::Duration}; @@ -83,6 +86,8 @@ impl DownstreamConnectionManager { Ok(join_handle) } + /// Receive new connections and remove disconnected clients. + /// Reading packets from clients is handled elsewhere. pub async fn update(&mut self) -> Result<(), Error> { // Receive new clients from the sender. loop { @@ -124,8 +129,10 @@ impl DownstreamConnectionManager { // Remove disconnected clients. let before = self.clients.len(); - self.clients - .retain(|_id, c| c.client_state() != ClientState::Disconnected); + self.clients.retain(|_id, c| { + c.client_state() != DownstreamConnectionState::Disconnected + && c.inner_state() != ClientState::Disconnected + }); let after = self.clients.len(); if before - after > 0 { trace!("Removed {} disconnected clients", before - after); diff --git a/src/net/connection/downstream/mod.rs b/src/net/connection/downstream/mod.rs index ddb4249..78b6e02 100644 --- a/src/net/connection/downstream/mod.rs +++ b/src/net/connection/downstream/mod.rs @@ -40,6 +40,42 @@ impl DownstreamConnection { state: DownstreamConnectionState::Handshake, } } + pub fn client_state(&self) -> DownstreamConnectionState { + self.state + } + pub fn client_state_mut(&mut self) -> &mut DownstreamConnectionState { + &mut self.state + } + pub fn inner_state(&self) -> ClientState { + self.inner.client_state() + } + pub fn inner_state_mut(&mut self) -> &mut ClientState { + self.inner.client_state_mut() + } + pub async fn handle_handshake(&mut self) -> Result<(), Error> { + use packets::handshake::serverbound::Handshake; + + let handshake = self + .read_specific_packet::() + .await + .ok_or(Error::Unexpected)??; + + match handshake.next_state { + ClientState::Status => { + *self.client_state_mut() = DownstreamConnectionState::StatusRequest; + *self.inner_state_mut() = ClientState::Status; + } + ClientState::Login => todo!(), + _ => { + self.disconnect(Some( + serde_json::json!({ "text": "Received invalid handshake." }), + )) + .await?; + } + } + + Ok(()) + } pub async fn handle_status_ping(&mut self, online_player_count: usize) -> Result<(), Error> { // The state just changed from Handshake to Status. use base64::Engine; @@ -102,7 +138,7 @@ impl DownstreamConnection { // })); if let Some(reason) = reason { - match self.client_state() { + match self.inner_state() { ClientState::Disconnected | ClientState::Handshake | ClientState::Status => { // Impossible to send a disconnect in these states. } diff --git a/src/net/connection/mod.rs b/src/net/connection/mod.rs index ecc7ce1..8aec84b 100644 --- a/src/net/connection/mod.rs +++ b/src/net/connection/mod.rs @@ -3,7 +3,9 @@ mod downstream; /// Connections where we're the client. mod upstream; -pub use downstream::{manager::DownstreamConnectionManager, DownstreamConnection}; +pub use downstream::{ + manager::DownstreamConnectionManager, DownstreamConnection, DownstreamConnectionState, +}; pub use upstream::UpstreamConnection; use crate::{ @@ -76,6 +78,15 @@ impl GenericConnection { packet } + pub async fn read_specific_packet>(&mut self) -> Option> { + self.read_packet() + .await + .map(|packet| match packet.map(P::try_from) { + Ok(Ok(p)) => Ok(p), + Ok(Err(_)) => Err(Error::Unexpected), + Err(e) => Err(e), + }) + } pub async fn send_packet>(&mut self, packet: P) -> Result<(), Error> { let packet: Packet = packet.into(); trace!("Sending packet to connection {}: {:?}", self.id, packet); diff --git a/src/proxy/mod.rs b/src/proxy/mod.rs index f9fa3a1..2dafdc9 100644 --- a/src/proxy/mod.rs +++ b/src/proxy/mod.rs @@ -141,7 +141,7 @@ impl App for Proxy { client.send_packet(packet).await.map_err(Error::Network)?; } if let Some(next_state) = next_state { - *client.client_state_mut() = next_state; + *client.inner_state_mut() = next_state; } } Err(e) => { diff --git a/src/server/error.rs b/src/server/error.rs index c2fd3da..07610ff 100644 --- a/src/server/error.rs +++ b/src/server/error.rs @@ -1,7 +1,20 @@ +pub use crate::net::error::Error as NetworkError; +pub use std::io::Error as IoError; +pub use tokio::task::JoinError as TaskError; + /// This type represents all possible errors that can occur when running the server. #[allow(dead_code)] -#[derive(thiserror::Error, Clone, Debug, PartialEq)] +#[derive(thiserror::Error, Debug)] pub enum Error { - #[error("the server is not running")] - NotRunning, + #[error(transparent)] + Io(IoError), + #[error(transparent)] + Task(TaskError), + #[error(transparent)] + Network(NetworkError), +} +impl From for Error { + fn from(err: NetworkError) -> Self { + Error::Network(err) + } } diff --git a/src/server/mod.rs b/src/server/mod.rs index 93fb3ae..d5c3b7a 100644 --- a/src/server/mod.rs +++ b/src/server/mod.rs @@ -2,137 +2,26 @@ pub mod config; /// When managing the server encounters errors. pub mod error; -/// Network operations. -pub mod net; -use crate::config::Config; -use crate::protocol::types::Uuid; -use crate::protocol::ClientState; -use crate::App; -use config::ServerConfig; -use net::{NetworkClient, NetworkClientState}; -use std::sync::Arc; -use tokio::net::{TcpListener, ToSocketAddrs}; -use tokio::{sync::RwLock, task::JoinHandle}; +use crate::{ + config::Config, + net::connection::{DownstreamConnectionManager, DownstreamConnectionState}, + server::{config::ServerConfig, error::Error}, + App, +}; +use tokio::task::JoinHandle; use tokio_util::sync::CancellationToken; -use tracing::{error, trace}; /// The main state and logic of the program. #[derive(Debug)] pub struct Server { - clients: Arc>>, - net_tasks_handle: JoinHandle<()>, -} -impl Server { - #[tracing::instrument] - async fn create_network_tasks( - bind_address: A, - network_clients: Arc>>, - running: CancellationToken, - ) { - // Start a task to receive new clients. - trace!("Creating listener task"); - let nc = network_clients.clone(); - let r = running.clone(); - let listener_task = tokio::spawn(async move { - trace!("Listener task created"); - let Ok(listener) = TcpListener::bind(bind_address).await else { - error!("Could not bind to given address, shutting down."); - std::process::exit(1); - }; - - let mut client_id = 0u128; - loop { - tokio::select! { - _ = r.cancelled() => { - trace!("Listener task received shutdown"); - break; - } - result = listener.accept() => { - if let Ok((stream, _)) = result { - trace!("Listener task got client (id {})", client_id); - nc.write().await.push(NetworkClient::new(client_id, stream)); - client_id += 1; - } else { - trace!("Listener task failed to accept client"); - } - } - } - } - }); - - // Start a task to update existing clients' packet queues. - trace!("Creating network task"); - let nc = network_clients.clone(); - let r = running.clone(); - let packet_task = tokio::spawn(async move { - trace!("Network task created"); - loop { - // Start tasks to read/write to clients concurrently. - tokio::select! { - _ = r.cancelled() => { - trace!("Network task received shutdown"); - break; - } - mut nc = nc.write() => { - trace!("Network task updating clients"); - let tasks: Vec> = nc - .drain(..) - .map(|mut client: NetworkClient| { - tokio::spawn(async move { - let _ = client.read_packets().await; - if client.send_queued_packets().await.is_err() { - client - .disconnect(Some(serde_json::json!({ "text": "Error writing packets." }))) - .await; - } - client - }) - }) - .collect(); - *nc = Vec::with_capacity(tasks.len()); - for task in tasks { - nc.push(task.await.unwrap()); - } - trace!("Network task updated clients"); - } - } - } - }); - - // Start a task to remove disconnected clients. - trace!("Creating disconnection task"); - let nc = network_clients.clone(); - let r = running.clone(); - let disconnection_task = tokio::spawn(async move { - trace!("Disconnection task created"); - loop { - tokio::select! { - _ = r.cancelled() => { - trace!("Disconnection task received shutdown"); - break; - } - mut nc = nc.write() => { - let before = nc.len(); - nc.retain(|client| client.state != NetworkClientState::Disconnected); - let after = nc.len(); - trace!("Disconnection task removed {} clients", before - after); - } - } - } - }); - - // Join the tasks on shutdown. - listener_task.await.expect("Listener task crashed"); - packet_task.await.expect("Packet task crashed"); - disconnection_task - .await - .expect("Disconnection task crashed"); - } + running: CancellationToken, + connections: DownstreamConnectionManager, + listener: JoinHandle<()>, } #[async_trait::async_trait] impl App for Server { - type Error = error::Error; + type Error = Error; fn startup_message() -> String { let config = Config::instance(); @@ -147,162 +36,69 @@ impl App for Server { let config = Config::instance(); let bind_address = format!("0.0.0.0:{}", config.server.port); - let clients = Arc::new(RwLock::new(vec![])); - let net_tasks_handle = tokio::spawn(Self::create_network_tasks( - bind_address, - clients.clone(), - running.child_token(), - )); + // No limit on connections. + let connections = DownstreamConnectionManager::new(None); + let listener = connections + .spawn_listener(bind_address, running.child_token()) + .await + .map_err(Error::Network)?; Ok(Server { - clients, - net_tasks_handle, + running, + connections, + listener, }) } #[tracing::instrument] async fn update(&mut self) -> Result<(), Self::Error> { - let mut clients = self.clients.write().await; - - // Handle packets from the clients. - let online_players = clients - .iter() - .filter(|client| matches!(client.state, NetworkClientState::Play)) + let online_player_count = self + .connections + .clients() + .filter(|c| matches!(c.client_state(), DownstreamConnectionState::Play)) .count(); - 'clients: for client in clients.iter_mut() { - use crate::protocol::packets; - 'packets: while !client.incoming_packet_queue.is_empty() { - // client.read_packet() - // None: The client doesn't have any more packets. - // Some(Err(_)): The client read an unexpected packet. TODO: Handle this error. - // Some(Ok(_)): The client read the expected packet. - match client.state.clone() { - NetworkClientState::Handshake => { - use packets::handshake::serverbound::Handshake; + // Receive new connections and remove disconnected ones. + self.connections.update().await?; - let handshake = match client.read_packet::() { - None => continue 'packets, - Some(Err(_)) => continue 'clients, - Some(Ok(handshake)) => handshake, - }; + // Read packets from each connection. + // Handle handshake connections. + let _ = futures::future::join_all( + self.connections + .clients_mut() + .filter(|c| matches!(c.client_state(), DownstreamConnectionState::Handshake)) + .map(|c| c.handle_handshake()), + ) + .await; - if handshake.next_state == ClientState::Status { - client.state = NetworkClientState::Status { - received_request: false, - received_ping: false, - }; - } else if handshake.next_state == ClientState::Login { - client.state = NetworkClientState::Login { - received_start: (false, None), - }; - } else { - client - .disconnect(Some( - serde_json::json!({ "text": "Received invalid SH00Handshake packet" }), - )) - .await; - } - } - // Status !received_request: Read SS00StatusRequest and respond with CS00StatusResponse - NetworkClientState::Status { - received_request, - received_ping, - } if !received_request => { - use packets::status::{ - clientbound::StatusResponse, serverbound::StatusRequest, - }; + // Handle status connections. + let _ = futures::future::join_all( + self.connections + .clients_mut() + .filter(|c| matches!(c.client_state(), DownstreamConnectionState::StatusRequest)) + .map(|c| c.handle_status_ping(online_player_count)), + ) + .await; - let _status_request = match client.read_packet::() { - None => continue 'packets, - Some(Err(_)) => continue 'clients, - Some(Ok(p)) => p, - }; - client.state = NetworkClientState::Status { - received_request: true, - received_ping, - }; - let config = Config::instance(); - use base64::Engine; - client.queue_packet(StatusResponse { - response: serde_json::json!({ - "version": { - "name": config.global.game_version, - "protocol": config.global.protocol_version - }, - "players": { - "max": config.server.max_players, - "online": online_players, - "sample": [] - }, - "description": { - "text": config.server.motd - }, - "favicon": format!("data:image/png;base64,{}", base64::engine::general_purpose::STANDARD_NO_PAD.encode(&config.server.server_icon_bytes)), - "enforcesSecureChat": true - }), - }); - } - // Status !received_ping: Read SS00StatusRequest and respond with CS00StatusResponse - NetworkClientState::Status { received_ping, .. } if !received_ping => { - use packets::status::{ - clientbound::PingResponse, serverbound::PingRequest, - }; - - let ping = match client.read_packet::() { - None => continue 'packets, - Some(Err(_)) => continue 'clients, - Some(Ok(p)) => p, - }; - client.queue_packet(PingResponse { - payload: ping.payload, - }); - client.state = NetworkClientState::Disconnected; - } - NetworkClientState::Status { .. } => unreachable!(), - NetworkClientState::Login { received_start, .. } if !received_start.0 => { - use packets::login::{clientbound::*, serverbound::*}; - - let login_start = match client.read_packet::() { - None => continue 'packets, - Some(Err(_)) => continue 'clients, - Some(Ok(p)) => p, - }; - // TODO: Authenticate the user. - // TODO: Get the user from the stored database. - // TODO: Encryption/compression. - client.queue_packet(LoginSuccess { - uuid: login_start.uuid.unwrap_or(Uuid::nil()), - username: login_start.name.clone(), - properties: vec![], - }); - client.state = NetworkClientState::Login { - received_start: (true, Some(login_start)), - }; - } - NetworkClientState::Login { .. } => unreachable!(), - NetworkClientState::Play => unimplemented!(), - NetworkClientState::Disconnected => unimplemented!(), - } - // If continue was not - break 'packets; - } - } + // Handle login connections. + // Handle play connection packets. + // Process world updates. + // Send out play connection updates. Ok(()) } #[tracing::instrument] async fn shutdown(self) -> Result<(), Self::Error> { - // Close the concurrent tasks. - let _ = self.net_tasks_handle.await; + // Ensure any child tasks have been shut down. + self.running.cancel(); - // Send disconnect messages to the clients. - for client in self.clients.write().await.iter_mut() { - client - .disconnect(Some( - serde_json::json!({ "text": "The server is shutting down." }), - )) - .await; - } + let _ = self.listener.await.map_err(Error::Task)?; + let _ = self + .connections + .shutdown(Some( + serde_json::json!({ "text": "The server is shutting down." }), + )) + .await + .map_err(Error::Network)?; Ok(()) } diff --git a/src/server/net.rs b/src/server/net.rs deleted file mode 100644 index b4b68f0..0000000 --- a/src/server/net.rs +++ /dev/null @@ -1,246 +0,0 @@ -use crate::protocol::{ - packets::{self, Packet, PacketDirection}, - parsing::Parsable, - ClientState, -}; -use std::{collections::VecDeque, sync::Arc, time::Instant}; -use tokio::io::AsyncWriteExt; -use tokio::{net::TcpStream, sync::RwLock}; -use tracing::{debug, trace, warn}; - -/// Similar to `composition_protocol::ClientState`, -/// but contains more useful data for managing the client's state. -#[derive(Clone, PartialEq, Debug)] -pub(crate) enum NetworkClientState { - /// A client has established a connection with the server. - /// - /// See `composition_protocol::ClientState::Handshake` for more details. - Handshake, - /// The client sent `SH00Handshake` with `next_state = ClientState::Status` - /// and is performing [server list ping](https://wiki.vg/Server_List_Ping). - Status { - /// When the server receives `SS00StatusRequest`, this is set - /// to `true` and the server should send `CS00StatusResponse`. - received_request: bool, - /// When the server receives `SS01PingRequest`, this is set - /// to `true` and the server should send `CS01PingResponse` - /// and set the connection state to `Disconnected`. - received_ping: bool, - }, - /// The client sent `SH00Handshake` with `next_state = ClientState::Login` - /// and is attempting to join the server. - Login { - received_start: (bool, Option), - }, - /// The server sent `CL02LoginSuccess` and transitioned to `Play`. - #[allow(dead_code)] - Play, - /// The client has disconnected. - /// - /// No packets should be sent or received, - /// and the `NetworkClient` should be queued for removal. - Disconnected, -} -impl From for ClientState { - fn from(value: NetworkClientState) -> Self { - match value { - NetworkClientState::Handshake => ClientState::Handshake, - NetworkClientState::Status { .. } => ClientState::Status, - NetworkClientState::Login { .. } => ClientState::Login, - NetworkClientState::Play => ClientState::Play, - NetworkClientState::Disconnected => ClientState::Disconnected, - } - } -} -impl AsRef for NetworkClientState { - fn as_ref(&self) -> &ClientState { - match self { - NetworkClientState::Handshake => &ClientState::Handshake, - NetworkClientState::Status { .. } => &ClientState::Status, - NetworkClientState::Login { .. } => &ClientState::Login, - NetworkClientState::Play => &ClientState::Play, - NetworkClientState::Disconnected => &ClientState::Disconnected, - } - } -} - -/// A wrapper around the raw `TcpStream` that abstracts away reading/writing packets and bytes. -#[derive(Debug, Clone)] -pub(crate) struct NetworkClient { - /// The `NetworkClient`'s unique id. - pub id: u128, - pub state: NetworkClientState, - stream: Arc>, - /// Data gets appended to the back as it gets read, - /// and popped from the front as it gets parsed into packets. - incoming_data: VecDeque, - /// Packets get appended to the back as they get read, - /// and popped from the front as they get handled. - pub incoming_packet_queue: VecDeque, - /// Keeps track of the last time the client sent data. - /// - /// This is useful for removing clients that have timed out. - pub last_received_data_time: Instant, - /// Packets get appended to the back and get popped from the front as they get sent. - pub outgoing_packet_queue: VecDeque, -} -impl NetworkClient { - #[tracing::instrument] - pub fn new(id: u128, stream: TcpStream) -> NetworkClient { - NetworkClient { - id, - state: NetworkClientState::Handshake, - stream: Arc::new(RwLock::new(stream)), - incoming_data: VecDeque::new(), - incoming_packet_queue: VecDeque::new(), - last_received_data_time: Instant::now(), - outgoing_packet_queue: VecDeque::new(), - } - } - #[tracing::instrument] - async fn read_data(&mut self) -> tokio::io::Result<()> { - trace!("NetworkClient.read_data() id {}", self.id); - let stream = self.stream.read().await; - - // Try to read 4kb at a time until there is no more data. - loop { - let mut buf = [0; 4096]; - - let num_bytes = match stream.try_read(&mut buf) { - Ok(0) => break, - Ok(n) => n, - Err(ref e) if e.kind() == tokio::io::ErrorKind::WouldBlock => { - break; - } - Err(e) => return Err(e), - }; - - debug!("Read {} bytes from client {}", num_bytes, self.id); - - self.last_received_data_time = Instant::now(); - self.incoming_data.extend(&buf[..num_bytes]); - } - - trace!("NetworkClient.read_data() end id {}", self.id); - Ok(()) - } - // TODO: Stream compression/encryption. - #[tracing::instrument] - pub async fn read_packets(&mut self) -> crate::protocol::Result<()> { - trace!("NetworkClient.read_packet() id {}", self.id); - - if self.read_data().await.is_err() { - self.disconnect(None).await; - return Err(crate::protocol::Error::Disconnected); - } - - self.incoming_data.make_contiguous(); - let (mut data, &[..]) = self.incoming_data.as_slices(); - - let mut bytes_consumed = 0; - while !data.is_empty() { - let p = Packet::parse( - self.state.clone().into(), - PacketDirection::Serverbound, - data, - ); - trace!("{} got {:?}", self.id, p); - match p { - Ok((d, packet)) => { - debug!("Got packet {:?} from client {}", packet, self.id); - bytes_consumed += data.len() - d.len(); - data = d; - self.incoming_packet_queue.push_back(packet); - } - Err(nom::Err::Incomplete(_)) => break, - Err(_) => { - // Remove the valid bytes before this packet. - self.incoming_data = self.incoming_data.split_off(bytes_consumed); - return Err(crate::protocol::Error::Parsing); - } - } - } - - // Remove the bytes we just read. - self.incoming_data = self.incoming_data.split_off(bytes_consumed); - - Ok(()) - } - // None: There was no packet to read. - // Some(Err(())): The packet was the wrong type. - // Some(Ok(_)): The packet was successfully read. - #[tracing::instrument] - pub fn read_packet>( - &mut self, - ) -> Option> { - if let Some(generic_packet) = self.incoming_packet_queue.pop_back() { - if let Ok(packet) = TryInto::

::try_into(generic_packet.clone()) { - Some(Ok(packet)) - } else { - self.incoming_packet_queue.push_back(generic_packet.clone()); - Some(Err(generic_packet)) - } - } else { - None - } - } - #[tracing::instrument] - pub fn queue_packet>(&mut self, packet: P) { - self.outgoing_packet_queue.push_back(packet.into()); - } - #[tracing::instrument] - pub async fn send_queued_packets(&mut self) -> crate::protocol::Result<()> { - let packets: Vec<_> = self.outgoing_packet_queue.drain(..).collect(); - for packet in packets { - self.send_packet(packet) - .await - .map_err(|_| crate::protocol::Error::Disconnected)?; - } - Ok(()) - } - #[tracing::instrument] - pub async fn send_packet>( - &self, - packet: P, - ) -> tokio::io::Result<()> { - let packet: Packet = packet.into(); - - debug!("Sending packet {:?} to client {}", packet, self.id); - let (packet_id, mut packet_body) = packet.serialize(); - let mut packet_id = packet_id.serialize(); - - // TODO: Stream compression/encryption. - - let mut b = vec![]; - b.append(&mut packet_id); - b.append(&mut packet_body); - - // bytes: packet length as varint, packet id as varint, packet body - let bytes = Parsable::serialize(&b); - - self.stream.write().await.write_all(&bytes).await?; - Ok(()) - } - #[tracing::instrument] - pub async fn disconnect(&mut self, reason: Option) { - use packets::{login::clientbound::LoginDisconnect, play::clientbound::PlayDisconnect}; - - let reason = reason.unwrap_or(serde_json::json!({ - "text": "You have been disconnected!" - })); - - match self.state.as_ref() { - ClientState::Disconnected | ClientState::Handshake | ClientState::Status => { - // Impossible to send a disconnect in these states. - } - ClientState::Login => { - let _ = self.send_packet(LoginDisconnect { reason }).await; - } - ClientState::Play => { - let _ = self.send_packet(PlayDisconnect { reason }).await; - } - } - - self.state = NetworkClientState::Disconnected; - } -}