Update server to use new streams

This commit is contained in:
Garen Tyler 2025-06-06 14:42:41 -06:00
parent deff480665
commit 4cc58fbf81
Signed by: garentyler
SSH Key Fingerprint: SHA256:G4ke7blZMdpWPbkescyZ7IQYE4JAtwpI85YoJdq+S7U
7 changed files with 133 additions and 516 deletions

View File

@ -1,5 +1,8 @@
use crate::{ use crate::{
net::{connection::DownstreamConnection, error::Error}, net::{
connection::{DownstreamConnection, DownstreamConnectionState},
error::Error,
},
protocol::{types::Chat, ClientState}, protocol::{types::Chat, ClientState},
}; };
use std::{collections::HashMap, time::Duration}; use std::{collections::HashMap, time::Duration};
@ -83,6 +86,8 @@ impl DownstreamConnectionManager {
Ok(join_handle) Ok(join_handle)
} }
/// Receive new connections and remove disconnected clients.
/// Reading packets from clients is handled elsewhere.
pub async fn update(&mut self) -> Result<(), Error> { pub async fn update(&mut self) -> Result<(), Error> {
// Receive new clients from the sender. // Receive new clients from the sender.
loop { loop {
@ -124,8 +129,10 @@ impl DownstreamConnectionManager {
// Remove disconnected clients. // Remove disconnected clients.
let before = self.clients.len(); let before = self.clients.len();
self.clients self.clients.retain(|_id, c| {
.retain(|_id, c| c.client_state() != ClientState::Disconnected); c.client_state() != DownstreamConnectionState::Disconnected
&& c.inner_state() != ClientState::Disconnected
});
let after = self.clients.len(); let after = self.clients.len();
if before - after > 0 { if before - after > 0 {
trace!("Removed {} disconnected clients", before - after); trace!("Removed {} disconnected clients", before - after);

View File

@ -40,6 +40,42 @@ impl DownstreamConnection {
state: DownstreamConnectionState::Handshake, state: DownstreamConnectionState::Handshake,
} }
} }
pub fn client_state(&self) -> DownstreamConnectionState {
self.state
}
pub fn client_state_mut(&mut self) -> &mut DownstreamConnectionState {
&mut self.state
}
pub fn inner_state(&self) -> ClientState {
self.inner.client_state()
}
pub fn inner_state_mut(&mut self) -> &mut ClientState {
self.inner.client_state_mut()
}
pub async fn handle_handshake(&mut self) -> Result<(), Error> {
use packets::handshake::serverbound::Handshake;
let handshake = self
.read_specific_packet::<Handshake>()
.await
.ok_or(Error::Unexpected)??;
match handshake.next_state {
ClientState::Status => {
*self.client_state_mut() = DownstreamConnectionState::StatusRequest;
*self.inner_state_mut() = ClientState::Status;
}
ClientState::Login => todo!(),
_ => {
self.disconnect(Some(
serde_json::json!({ "text": "Received invalid handshake." }),
))
.await?;
}
}
Ok(())
}
pub async fn handle_status_ping(&mut self, online_player_count: usize) -> Result<(), Error> { pub async fn handle_status_ping(&mut self, online_player_count: usize) -> Result<(), Error> {
// The state just changed from Handshake to Status. // The state just changed from Handshake to Status.
use base64::Engine; use base64::Engine;
@ -102,7 +138,7 @@ impl DownstreamConnection {
// })); // }));
if let Some(reason) = reason { if let Some(reason) = reason {
match self.client_state() { match self.inner_state() {
ClientState::Disconnected | ClientState::Handshake | ClientState::Status => { ClientState::Disconnected | ClientState::Handshake | ClientState::Status => {
// Impossible to send a disconnect in these states. // Impossible to send a disconnect in these states.
} }

View File

@ -3,7 +3,9 @@ mod downstream;
/// Connections where we're the client. /// Connections where we're the client.
mod upstream; mod upstream;
pub use downstream::{manager::DownstreamConnectionManager, DownstreamConnection}; pub use downstream::{
manager::DownstreamConnectionManager, DownstreamConnection, DownstreamConnectionState,
};
pub use upstream::UpstreamConnection; pub use upstream::UpstreamConnection;
use crate::{ use crate::{
@ -76,6 +78,15 @@ impl GenericConnection {
packet packet
} }
pub async fn read_specific_packet<P: TryFrom<Packet>>(&mut self) -> Option<Result<P, Error>> {
self.read_packet()
.await
.map(|packet| match packet.map(P::try_from) {
Ok(Ok(p)) => Ok(p),
Ok(Err(_)) => Err(Error::Unexpected),
Err(e) => Err(e),
})
}
pub async fn send_packet<P: Into<Packet>>(&mut self, packet: P) -> Result<(), Error> { pub async fn send_packet<P: Into<Packet>>(&mut self, packet: P) -> Result<(), Error> {
let packet: Packet = packet.into(); let packet: Packet = packet.into();
trace!("Sending packet to connection {}: {:?}", self.id, packet); trace!("Sending packet to connection {}: {:?}", self.id, packet);

View File

@ -141,7 +141,7 @@ impl App for Proxy {
client.send_packet(packet).await.map_err(Error::Network)?; client.send_packet(packet).await.map_err(Error::Network)?;
} }
if let Some(next_state) = next_state { if let Some(next_state) = next_state {
*client.client_state_mut() = next_state; *client.inner_state_mut() = next_state;
} }
} }
Err(e) => { Err(e) => {

View File

@ -1,7 +1,20 @@
pub use crate::net::error::Error as NetworkError;
pub use std::io::Error as IoError;
pub use tokio::task::JoinError as TaskError;
/// This type represents all possible errors that can occur when running the server. /// This type represents all possible errors that can occur when running the server.
#[allow(dead_code)] #[allow(dead_code)]
#[derive(thiserror::Error, Clone, Debug, PartialEq)] #[derive(thiserror::Error, Debug)]
pub enum Error { pub enum Error {
#[error("the server is not running")] #[error(transparent)]
NotRunning, Io(IoError),
#[error(transparent)]
Task(TaskError),
#[error(transparent)]
Network(NetworkError),
}
impl From<NetworkError> for Error {
fn from(err: NetworkError) -> Self {
Error::Network(err)
}
} }

View File

@ -2,137 +2,26 @@
pub mod config; pub mod config;
/// When managing the server encounters errors. /// When managing the server encounters errors.
pub mod error; pub mod error;
/// Network operations.
pub mod net;
use crate::config::Config; use crate::{
use crate::protocol::types::Uuid; config::Config,
use crate::protocol::ClientState; net::connection::{DownstreamConnectionManager, DownstreamConnectionState},
use crate::App; server::{config::ServerConfig, error::Error},
use config::ServerConfig; App,
use net::{NetworkClient, NetworkClientState}; };
use std::sync::Arc; use tokio::task::JoinHandle;
use tokio::net::{TcpListener, ToSocketAddrs};
use tokio::{sync::RwLock, task::JoinHandle};
use tokio_util::sync::CancellationToken; use tokio_util::sync::CancellationToken;
use tracing::{error, trace};
/// The main state and logic of the program. /// The main state and logic of the program.
#[derive(Debug)] #[derive(Debug)]
pub struct Server { pub struct Server {
clients: Arc<RwLock<Vec<NetworkClient>>>,
net_tasks_handle: JoinHandle<()>,
}
impl Server {
#[tracing::instrument]
async fn create_network_tasks<A: 'static + ToSocketAddrs + Send + std::fmt::Debug>(
bind_address: A,
network_clients: Arc<RwLock<Vec<NetworkClient>>>,
running: CancellationToken, running: CancellationToken,
) { connections: DownstreamConnectionManager,
// Start a task to receive new clients. listener: JoinHandle<()>,
trace!("Creating listener task");
let nc = network_clients.clone();
let r = running.clone();
let listener_task = tokio::spawn(async move {
trace!("Listener task created");
let Ok(listener) = TcpListener::bind(bind_address).await else {
error!("Could not bind to given address, shutting down.");
std::process::exit(1);
};
let mut client_id = 0u128;
loop {
tokio::select! {
_ = r.cancelled() => {
trace!("Listener task received shutdown");
break;
}
result = listener.accept() => {
if let Ok((stream, _)) = result {
trace!("Listener task got client (id {})", client_id);
nc.write().await.push(NetworkClient::new(client_id, stream));
client_id += 1;
} else {
trace!("Listener task failed to accept client");
}
}
}
}
});
// Start a task to update existing clients' packet queues.
trace!("Creating network task");
let nc = network_clients.clone();
let r = running.clone();
let packet_task = tokio::spawn(async move {
trace!("Network task created");
loop {
// Start tasks to read/write to clients concurrently.
tokio::select! {
_ = r.cancelled() => {
trace!("Network task received shutdown");
break;
}
mut nc = nc.write() => {
trace!("Network task updating clients");
let tasks: Vec<JoinHandle<NetworkClient>> = nc
.drain(..)
.map(|mut client: NetworkClient| {
tokio::spawn(async move {
let _ = client.read_packets().await;
if client.send_queued_packets().await.is_err() {
client
.disconnect(Some(serde_json::json!({ "text": "Error writing packets." })))
.await;
}
client
})
})
.collect();
*nc = Vec::with_capacity(tasks.len());
for task in tasks {
nc.push(task.await.unwrap());
}
trace!("Network task updated clients");
}
}
}
});
// Start a task to remove disconnected clients.
trace!("Creating disconnection task");
let nc = network_clients.clone();
let r = running.clone();
let disconnection_task = tokio::spawn(async move {
trace!("Disconnection task created");
loop {
tokio::select! {
_ = r.cancelled() => {
trace!("Disconnection task received shutdown");
break;
}
mut nc = nc.write() => {
let before = nc.len();
nc.retain(|client| client.state != NetworkClientState::Disconnected);
let after = nc.len();
trace!("Disconnection task removed {} clients", before - after);
}
}
}
});
// Join the tasks on shutdown.
listener_task.await.expect("Listener task crashed");
packet_task.await.expect("Packet task crashed");
disconnection_task
.await
.expect("Disconnection task crashed");
}
} }
#[async_trait::async_trait] #[async_trait::async_trait]
impl App for Server { impl App for Server {
type Error = error::Error; type Error = Error;
fn startup_message() -> String { fn startup_message() -> String {
let config = Config::instance(); let config = Config::instance();
@ -147,162 +36,69 @@ impl App for Server {
let config = Config::instance(); let config = Config::instance();
let bind_address = format!("0.0.0.0:{}", config.server.port); let bind_address = format!("0.0.0.0:{}", config.server.port);
let clients = Arc::new(RwLock::new(vec![])); // No limit on connections.
let net_tasks_handle = tokio::spawn(Self::create_network_tasks( let connections = DownstreamConnectionManager::new(None);
bind_address, let listener = connections
clients.clone(), .spawn_listener(bind_address, running.child_token())
running.child_token(), .await
)); .map_err(Error::Network)?;
Ok(Server { Ok(Server {
clients, running,
net_tasks_handle, connections,
listener,
}) })
} }
#[tracing::instrument] #[tracing::instrument]
async fn update(&mut self) -> Result<(), Self::Error> { async fn update(&mut self) -> Result<(), Self::Error> {
let mut clients = self.clients.write().await; let online_player_count = self
.connections
// Handle packets from the clients. .clients()
let online_players = clients .filter(|c| matches!(c.client_state(), DownstreamConnectionState::Play))
.iter()
.filter(|client| matches!(client.state, NetworkClientState::Play))
.count(); .count();
'clients: for client in clients.iter_mut() {
use crate::protocol::packets;
'packets: while !client.incoming_packet_queue.is_empty() { // Receive new connections and remove disconnected ones.
// client.read_packet() self.connections.update().await?;
// None: The client doesn't have any more packets.
// Some(Err(_)): The client read an unexpected packet. TODO: Handle this error.
// Some(Ok(_)): The client read the expected packet.
match client.state.clone() {
NetworkClientState::Handshake => {
use packets::handshake::serverbound::Handshake;
let handshake = match client.read_packet::<Handshake>() { // Read packets from each connection.
None => continue 'packets, // Handle handshake connections.
Some(Err(_)) => continue 'clients, let _ = futures::future::join_all(
Some(Ok(handshake)) => handshake, self.connections
}; .clients_mut()
.filter(|c| matches!(c.client_state(), DownstreamConnectionState::Handshake))
if handshake.next_state == ClientState::Status { .map(|c| c.handle_handshake()),
client.state = NetworkClientState::Status { )
received_request: false,
received_ping: false,
};
} else if handshake.next_state == ClientState::Login {
client.state = NetworkClientState::Login {
received_start: (false, None),
};
} else {
client
.disconnect(Some(
serde_json::json!({ "text": "Received invalid SH00Handshake packet" }),
))
.await; .await;
}
}
// Status !received_request: Read SS00StatusRequest and respond with CS00StatusResponse
NetworkClientState::Status {
received_request,
received_ping,
} if !received_request => {
use packets::status::{
clientbound::StatusResponse, serverbound::StatusRequest,
};
let _status_request = match client.read_packet::<StatusRequest>() { // Handle status connections.
None => continue 'packets, let _ = futures::future::join_all(
Some(Err(_)) => continue 'clients, self.connections
Some(Ok(p)) => p, .clients_mut()
}; .filter(|c| matches!(c.client_state(), DownstreamConnectionState::StatusRequest))
client.state = NetworkClientState::Status { .map(|c| c.handle_status_ping(online_player_count)),
received_request: true, )
received_ping, .await;
};
let config = Config::instance();
use base64::Engine;
client.queue_packet(StatusResponse {
response: serde_json::json!({
"version": {
"name": config.global.game_version,
"protocol": config.global.protocol_version
},
"players": {
"max": config.server.max_players,
"online": online_players,
"sample": []
},
"description": {
"text": config.server.motd
},
"favicon": format!("data:image/png;base64,{}", base64::engine::general_purpose::STANDARD_NO_PAD.encode(&config.server.server_icon_bytes)),
"enforcesSecureChat": true
}),
});
}
// Status !received_ping: Read SS00StatusRequest and respond with CS00StatusResponse
NetworkClientState::Status { received_ping, .. } if !received_ping => {
use packets::status::{
clientbound::PingResponse, serverbound::PingRequest,
};
let ping = match client.read_packet::<PingRequest>() { // Handle login connections.
None => continue 'packets, // Handle play connection packets.
Some(Err(_)) => continue 'clients, // Process world updates.
Some(Ok(p)) => p, // Send out play connection updates.
};
client.queue_packet(PingResponse {
payload: ping.payload,
});
client.state = NetworkClientState::Disconnected;
}
NetworkClientState::Status { .. } => unreachable!(),
NetworkClientState::Login { received_start, .. } if !received_start.0 => {
use packets::login::{clientbound::*, serverbound::*};
let login_start = match client.read_packet::<LoginStart>() {
None => continue 'packets,
Some(Err(_)) => continue 'clients,
Some(Ok(p)) => p,
};
// TODO: Authenticate the user.
// TODO: Get the user from the stored database.
// TODO: Encryption/compression.
client.queue_packet(LoginSuccess {
uuid: login_start.uuid.unwrap_or(Uuid::nil()),
username: login_start.name.clone(),
properties: vec![],
});
client.state = NetworkClientState::Login {
received_start: (true, Some(login_start)),
};
}
NetworkClientState::Login { .. } => unreachable!(),
NetworkClientState::Play => unimplemented!(),
NetworkClientState::Disconnected => unimplemented!(),
}
// If continue was not
break 'packets;
}
}
Ok(()) Ok(())
} }
#[tracing::instrument] #[tracing::instrument]
async fn shutdown(self) -> Result<(), Self::Error> { async fn shutdown(self) -> Result<(), Self::Error> {
// Close the concurrent tasks. // Ensure any child tasks have been shut down.
let _ = self.net_tasks_handle.await; self.running.cancel();
// Send disconnect messages to the clients. let _ = self.listener.await.map_err(Error::Task)?;
for client in self.clients.write().await.iter_mut() { let _ = self
client .connections
.disconnect(Some( .shutdown(Some(
serde_json::json!({ "text": "The server is shutting down." }), serde_json::json!({ "text": "The server is shutting down." }),
)) ))
.await; .await
} .map_err(Error::Network)?;
Ok(()) Ok(())
} }

View File

@ -1,246 +0,0 @@
use crate::protocol::{
packets::{self, Packet, PacketDirection},
parsing::Parsable,
ClientState,
};
use std::{collections::VecDeque, sync::Arc, time::Instant};
use tokio::io::AsyncWriteExt;
use tokio::{net::TcpStream, sync::RwLock};
use tracing::{debug, trace, warn};
/// Similar to `composition_protocol::ClientState`,
/// but contains more useful data for managing the client's state.
#[derive(Clone, PartialEq, Debug)]
pub(crate) enum NetworkClientState {
/// A client has established a connection with the server.
///
/// See `composition_protocol::ClientState::Handshake` for more details.
Handshake,
/// The client sent `SH00Handshake` with `next_state = ClientState::Status`
/// and is performing [server list ping](https://wiki.vg/Server_List_Ping).
Status {
/// When the server receives `SS00StatusRequest`, this is set
/// to `true` and the server should send `CS00StatusResponse`.
received_request: bool,
/// When the server receives `SS01PingRequest`, this is set
/// to `true` and the server should send `CS01PingResponse`
/// and set the connection state to `Disconnected`.
received_ping: bool,
},
/// The client sent `SH00Handshake` with `next_state = ClientState::Login`
/// and is attempting to join the server.
Login {
received_start: (bool, Option<packets::login::serverbound::LoginStart>),
},
/// The server sent `CL02LoginSuccess` and transitioned to `Play`.
#[allow(dead_code)]
Play,
/// The client has disconnected.
///
/// No packets should be sent or received,
/// and the `NetworkClient` should be queued for removal.
Disconnected,
}
impl From<NetworkClientState> for ClientState {
fn from(value: NetworkClientState) -> Self {
match value {
NetworkClientState::Handshake => ClientState::Handshake,
NetworkClientState::Status { .. } => ClientState::Status,
NetworkClientState::Login { .. } => ClientState::Login,
NetworkClientState::Play => ClientState::Play,
NetworkClientState::Disconnected => ClientState::Disconnected,
}
}
}
impl AsRef<ClientState> for NetworkClientState {
fn as_ref(&self) -> &ClientState {
match self {
NetworkClientState::Handshake => &ClientState::Handshake,
NetworkClientState::Status { .. } => &ClientState::Status,
NetworkClientState::Login { .. } => &ClientState::Login,
NetworkClientState::Play => &ClientState::Play,
NetworkClientState::Disconnected => &ClientState::Disconnected,
}
}
}
/// A wrapper around the raw `TcpStream` that abstracts away reading/writing packets and bytes.
#[derive(Debug, Clone)]
pub(crate) struct NetworkClient {
/// The `NetworkClient`'s unique id.
pub id: u128,
pub state: NetworkClientState,
stream: Arc<RwLock<TcpStream>>,
/// Data gets appended to the back as it gets read,
/// and popped from the front as it gets parsed into packets.
incoming_data: VecDeque<u8>,
/// Packets get appended to the back as they get read,
/// and popped from the front as they get handled.
pub incoming_packet_queue: VecDeque<Packet>,
/// Keeps track of the last time the client sent data.
///
/// This is useful for removing clients that have timed out.
pub last_received_data_time: Instant,
/// Packets get appended to the back and get popped from the front as they get sent.
pub outgoing_packet_queue: VecDeque<Packet>,
}
impl NetworkClient {
#[tracing::instrument]
pub fn new(id: u128, stream: TcpStream) -> NetworkClient {
NetworkClient {
id,
state: NetworkClientState::Handshake,
stream: Arc::new(RwLock::new(stream)),
incoming_data: VecDeque::new(),
incoming_packet_queue: VecDeque::new(),
last_received_data_time: Instant::now(),
outgoing_packet_queue: VecDeque::new(),
}
}
#[tracing::instrument]
async fn read_data(&mut self) -> tokio::io::Result<()> {
trace!("NetworkClient.read_data() id {}", self.id);
let stream = self.stream.read().await;
// Try to read 4kb at a time until there is no more data.
loop {
let mut buf = [0; 4096];
let num_bytes = match stream.try_read(&mut buf) {
Ok(0) => break,
Ok(n) => n,
Err(ref e) if e.kind() == tokio::io::ErrorKind::WouldBlock => {
break;
}
Err(e) => return Err(e),
};
debug!("Read {} bytes from client {}", num_bytes, self.id);
self.last_received_data_time = Instant::now();
self.incoming_data.extend(&buf[..num_bytes]);
}
trace!("NetworkClient.read_data() end id {}", self.id);
Ok(())
}
// TODO: Stream compression/encryption.
#[tracing::instrument]
pub async fn read_packets(&mut self) -> crate::protocol::Result<()> {
trace!("NetworkClient.read_packet() id {}", self.id);
if self.read_data().await.is_err() {
self.disconnect(None).await;
return Err(crate::protocol::Error::Disconnected);
}
self.incoming_data.make_contiguous();
let (mut data, &[..]) = self.incoming_data.as_slices();
let mut bytes_consumed = 0;
while !data.is_empty() {
let p = Packet::parse(
self.state.clone().into(),
PacketDirection::Serverbound,
data,
);
trace!("{} got {:?}", self.id, p);
match p {
Ok((d, packet)) => {
debug!("Got packet {:?} from client {}", packet, self.id);
bytes_consumed += data.len() - d.len();
data = d;
self.incoming_packet_queue.push_back(packet);
}
Err(nom::Err::Incomplete(_)) => break,
Err(_) => {
// Remove the valid bytes before this packet.
self.incoming_data = self.incoming_data.split_off(bytes_consumed);
return Err(crate::protocol::Error::Parsing);
}
}
}
// Remove the bytes we just read.
self.incoming_data = self.incoming_data.split_off(bytes_consumed);
Ok(())
}
// None: There was no packet to read.
// Some(Err(())): The packet was the wrong type.
// Some(Ok(_)): The packet was successfully read.
#[tracing::instrument]
pub fn read_packet<P: std::fmt::Debug + TryFrom<Packet>>(
&mut self,
) -> Option<std::result::Result<P, Packet>> {
if let Some(generic_packet) = self.incoming_packet_queue.pop_back() {
if let Ok(packet) = TryInto::<P>::try_into(generic_packet.clone()) {
Some(Ok(packet))
} else {
self.incoming_packet_queue.push_back(generic_packet.clone());
Some(Err(generic_packet))
}
} else {
None
}
}
#[tracing::instrument]
pub fn queue_packet<P: std::fmt::Debug + Into<Packet>>(&mut self, packet: P) {
self.outgoing_packet_queue.push_back(packet.into());
}
#[tracing::instrument]
pub async fn send_queued_packets(&mut self) -> crate::protocol::Result<()> {
let packets: Vec<_> = self.outgoing_packet_queue.drain(..).collect();
for packet in packets {
self.send_packet(packet)
.await
.map_err(|_| crate::protocol::Error::Disconnected)?;
}
Ok(())
}
#[tracing::instrument]
pub async fn send_packet<P: std::fmt::Debug + Into<Packet>>(
&self,
packet: P,
) -> tokio::io::Result<()> {
let packet: Packet = packet.into();
debug!("Sending packet {:?} to client {}", packet, self.id);
let (packet_id, mut packet_body) = packet.serialize();
let mut packet_id = packet_id.serialize();
// TODO: Stream compression/encryption.
let mut b = vec![];
b.append(&mut packet_id);
b.append(&mut packet_body);
// bytes: packet length as varint, packet id as varint, packet body
let bytes = Parsable::serialize(&b);
self.stream.write().await.write_all(&bytes).await?;
Ok(())
}
#[tracing::instrument]
pub async fn disconnect(&mut self, reason: Option<crate::protocol::types::Chat>) {
use packets::{login::clientbound::LoginDisconnect, play::clientbound::PlayDisconnect};
let reason = reason.unwrap_or(serde_json::json!({
"text": "You have been disconnected!"
}));
match self.state.as_ref() {
ClientState::Disconnected | ClientState::Handshake | ClientState::Status => {
// Impossible to send a disconnect in these states.
}
ClientState::Login => {
let _ = self.send_packet(LoginDisconnect { reason }).await;
}
ClientState::Play => {
let _ = self.send_packet(PlayDisconnect { reason }).await;
}
}
self.state = NetworkClientState::Disconnected;
}
}