Update server to use new streams
This commit is contained in:
parent
deff480665
commit
4cc58fbf81
@ -1,5 +1,8 @@
|
|||||||
use crate::{
|
use crate::{
|
||||||
net::{connection::DownstreamConnection, error::Error},
|
net::{
|
||||||
|
connection::{DownstreamConnection, DownstreamConnectionState},
|
||||||
|
error::Error,
|
||||||
|
},
|
||||||
protocol::{types::Chat, ClientState},
|
protocol::{types::Chat, ClientState},
|
||||||
};
|
};
|
||||||
use std::{collections::HashMap, time::Duration};
|
use std::{collections::HashMap, time::Duration};
|
||||||
@ -83,6 +86,8 @@ impl DownstreamConnectionManager {
|
|||||||
|
|
||||||
Ok(join_handle)
|
Ok(join_handle)
|
||||||
}
|
}
|
||||||
|
/// Receive new connections and remove disconnected clients.
|
||||||
|
/// Reading packets from clients is handled elsewhere.
|
||||||
pub async fn update(&mut self) -> Result<(), Error> {
|
pub async fn update(&mut self) -> Result<(), Error> {
|
||||||
// Receive new clients from the sender.
|
// Receive new clients from the sender.
|
||||||
loop {
|
loop {
|
||||||
@ -124,8 +129,10 @@ impl DownstreamConnectionManager {
|
|||||||
|
|
||||||
// Remove disconnected clients.
|
// Remove disconnected clients.
|
||||||
let before = self.clients.len();
|
let before = self.clients.len();
|
||||||
self.clients
|
self.clients.retain(|_id, c| {
|
||||||
.retain(|_id, c| c.client_state() != ClientState::Disconnected);
|
c.client_state() != DownstreamConnectionState::Disconnected
|
||||||
|
&& c.inner_state() != ClientState::Disconnected
|
||||||
|
});
|
||||||
let after = self.clients.len();
|
let after = self.clients.len();
|
||||||
if before - after > 0 {
|
if before - after > 0 {
|
||||||
trace!("Removed {} disconnected clients", before - after);
|
trace!("Removed {} disconnected clients", before - after);
|
||||||
|
@ -40,6 +40,42 @@ impl DownstreamConnection {
|
|||||||
state: DownstreamConnectionState::Handshake,
|
state: DownstreamConnectionState::Handshake,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
pub fn client_state(&self) -> DownstreamConnectionState {
|
||||||
|
self.state
|
||||||
|
}
|
||||||
|
pub fn client_state_mut(&mut self) -> &mut DownstreamConnectionState {
|
||||||
|
&mut self.state
|
||||||
|
}
|
||||||
|
pub fn inner_state(&self) -> ClientState {
|
||||||
|
self.inner.client_state()
|
||||||
|
}
|
||||||
|
pub fn inner_state_mut(&mut self) -> &mut ClientState {
|
||||||
|
self.inner.client_state_mut()
|
||||||
|
}
|
||||||
|
pub async fn handle_handshake(&mut self) -> Result<(), Error> {
|
||||||
|
use packets::handshake::serverbound::Handshake;
|
||||||
|
|
||||||
|
let handshake = self
|
||||||
|
.read_specific_packet::<Handshake>()
|
||||||
|
.await
|
||||||
|
.ok_or(Error::Unexpected)??;
|
||||||
|
|
||||||
|
match handshake.next_state {
|
||||||
|
ClientState::Status => {
|
||||||
|
*self.client_state_mut() = DownstreamConnectionState::StatusRequest;
|
||||||
|
*self.inner_state_mut() = ClientState::Status;
|
||||||
|
}
|
||||||
|
ClientState::Login => todo!(),
|
||||||
|
_ => {
|
||||||
|
self.disconnect(Some(
|
||||||
|
serde_json::json!({ "text": "Received invalid handshake." }),
|
||||||
|
))
|
||||||
|
.await?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
pub async fn handle_status_ping(&mut self, online_player_count: usize) -> Result<(), Error> {
|
pub async fn handle_status_ping(&mut self, online_player_count: usize) -> Result<(), Error> {
|
||||||
// The state just changed from Handshake to Status.
|
// The state just changed from Handshake to Status.
|
||||||
use base64::Engine;
|
use base64::Engine;
|
||||||
@ -102,7 +138,7 @@ impl DownstreamConnection {
|
|||||||
// }));
|
// }));
|
||||||
|
|
||||||
if let Some(reason) = reason {
|
if let Some(reason) = reason {
|
||||||
match self.client_state() {
|
match self.inner_state() {
|
||||||
ClientState::Disconnected | ClientState::Handshake | ClientState::Status => {
|
ClientState::Disconnected | ClientState::Handshake | ClientState::Status => {
|
||||||
// Impossible to send a disconnect in these states.
|
// Impossible to send a disconnect in these states.
|
||||||
}
|
}
|
||||||
|
@ -3,7 +3,9 @@ mod downstream;
|
|||||||
/// Connections where we're the client.
|
/// Connections where we're the client.
|
||||||
mod upstream;
|
mod upstream;
|
||||||
|
|
||||||
pub use downstream::{manager::DownstreamConnectionManager, DownstreamConnection};
|
pub use downstream::{
|
||||||
|
manager::DownstreamConnectionManager, DownstreamConnection, DownstreamConnectionState,
|
||||||
|
};
|
||||||
pub use upstream::UpstreamConnection;
|
pub use upstream::UpstreamConnection;
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
@ -76,6 +78,15 @@ impl GenericConnection {
|
|||||||
|
|
||||||
packet
|
packet
|
||||||
}
|
}
|
||||||
|
pub async fn read_specific_packet<P: TryFrom<Packet>>(&mut self) -> Option<Result<P, Error>> {
|
||||||
|
self.read_packet()
|
||||||
|
.await
|
||||||
|
.map(|packet| match packet.map(P::try_from) {
|
||||||
|
Ok(Ok(p)) => Ok(p),
|
||||||
|
Ok(Err(_)) => Err(Error::Unexpected),
|
||||||
|
Err(e) => Err(e),
|
||||||
|
})
|
||||||
|
}
|
||||||
pub async fn send_packet<P: Into<Packet>>(&mut self, packet: P) -> Result<(), Error> {
|
pub async fn send_packet<P: Into<Packet>>(&mut self, packet: P) -> Result<(), Error> {
|
||||||
let packet: Packet = packet.into();
|
let packet: Packet = packet.into();
|
||||||
trace!("Sending packet to connection {}: {:?}", self.id, packet);
|
trace!("Sending packet to connection {}: {:?}", self.id, packet);
|
||||||
|
@ -141,7 +141,7 @@ impl App for Proxy {
|
|||||||
client.send_packet(packet).await.map_err(Error::Network)?;
|
client.send_packet(packet).await.map_err(Error::Network)?;
|
||||||
}
|
}
|
||||||
if let Some(next_state) = next_state {
|
if let Some(next_state) = next_state {
|
||||||
*client.client_state_mut() = next_state;
|
*client.inner_state_mut() = next_state;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
|
@ -1,7 +1,20 @@
|
|||||||
|
pub use crate::net::error::Error as NetworkError;
|
||||||
|
pub use std::io::Error as IoError;
|
||||||
|
pub use tokio::task::JoinError as TaskError;
|
||||||
|
|
||||||
/// This type represents all possible errors that can occur when running the server.
|
/// This type represents all possible errors that can occur when running the server.
|
||||||
#[allow(dead_code)]
|
#[allow(dead_code)]
|
||||||
#[derive(thiserror::Error, Clone, Debug, PartialEq)]
|
#[derive(thiserror::Error, Debug)]
|
||||||
pub enum Error {
|
pub enum Error {
|
||||||
#[error("the server is not running")]
|
#[error(transparent)]
|
||||||
NotRunning,
|
Io(IoError),
|
||||||
|
#[error(transparent)]
|
||||||
|
Task(TaskError),
|
||||||
|
#[error(transparent)]
|
||||||
|
Network(NetworkError),
|
||||||
|
}
|
||||||
|
impl From<NetworkError> for Error {
|
||||||
|
fn from(err: NetworkError) -> Self {
|
||||||
|
Error::Network(err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
@ -2,137 +2,26 @@
|
|||||||
pub mod config;
|
pub mod config;
|
||||||
/// When managing the server encounters errors.
|
/// When managing the server encounters errors.
|
||||||
pub mod error;
|
pub mod error;
|
||||||
/// Network operations.
|
|
||||||
pub mod net;
|
|
||||||
|
|
||||||
use crate::config::Config;
|
use crate::{
|
||||||
use crate::protocol::types::Uuid;
|
config::Config,
|
||||||
use crate::protocol::ClientState;
|
net::connection::{DownstreamConnectionManager, DownstreamConnectionState},
|
||||||
use crate::App;
|
server::{config::ServerConfig, error::Error},
|
||||||
use config::ServerConfig;
|
App,
|
||||||
use net::{NetworkClient, NetworkClientState};
|
};
|
||||||
use std::sync::Arc;
|
use tokio::task::JoinHandle;
|
||||||
use tokio::net::{TcpListener, ToSocketAddrs};
|
|
||||||
use tokio::{sync::RwLock, task::JoinHandle};
|
|
||||||
use tokio_util::sync::CancellationToken;
|
use tokio_util::sync::CancellationToken;
|
||||||
use tracing::{error, trace};
|
|
||||||
|
|
||||||
/// The main state and logic of the program.
|
/// The main state and logic of the program.
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct Server {
|
pub struct Server {
|
||||||
clients: Arc<RwLock<Vec<NetworkClient>>>,
|
|
||||||
net_tasks_handle: JoinHandle<()>,
|
|
||||||
}
|
|
||||||
impl Server {
|
|
||||||
#[tracing::instrument]
|
|
||||||
async fn create_network_tasks<A: 'static + ToSocketAddrs + Send + std::fmt::Debug>(
|
|
||||||
bind_address: A,
|
|
||||||
network_clients: Arc<RwLock<Vec<NetworkClient>>>,
|
|
||||||
running: CancellationToken,
|
running: CancellationToken,
|
||||||
) {
|
connections: DownstreamConnectionManager,
|
||||||
// Start a task to receive new clients.
|
listener: JoinHandle<()>,
|
||||||
trace!("Creating listener task");
|
|
||||||
let nc = network_clients.clone();
|
|
||||||
let r = running.clone();
|
|
||||||
let listener_task = tokio::spawn(async move {
|
|
||||||
trace!("Listener task created");
|
|
||||||
let Ok(listener) = TcpListener::bind(bind_address).await else {
|
|
||||||
error!("Could not bind to given address, shutting down.");
|
|
||||||
std::process::exit(1);
|
|
||||||
};
|
|
||||||
|
|
||||||
let mut client_id = 0u128;
|
|
||||||
loop {
|
|
||||||
tokio::select! {
|
|
||||||
_ = r.cancelled() => {
|
|
||||||
trace!("Listener task received shutdown");
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
result = listener.accept() => {
|
|
||||||
if let Ok((stream, _)) = result {
|
|
||||||
trace!("Listener task got client (id {})", client_id);
|
|
||||||
nc.write().await.push(NetworkClient::new(client_id, stream));
|
|
||||||
client_id += 1;
|
|
||||||
} else {
|
|
||||||
trace!("Listener task failed to accept client");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
// Start a task to update existing clients' packet queues.
|
|
||||||
trace!("Creating network task");
|
|
||||||
let nc = network_clients.clone();
|
|
||||||
let r = running.clone();
|
|
||||||
let packet_task = tokio::spawn(async move {
|
|
||||||
trace!("Network task created");
|
|
||||||
loop {
|
|
||||||
// Start tasks to read/write to clients concurrently.
|
|
||||||
tokio::select! {
|
|
||||||
_ = r.cancelled() => {
|
|
||||||
trace!("Network task received shutdown");
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
mut nc = nc.write() => {
|
|
||||||
trace!("Network task updating clients");
|
|
||||||
let tasks: Vec<JoinHandle<NetworkClient>> = nc
|
|
||||||
.drain(..)
|
|
||||||
.map(|mut client: NetworkClient| {
|
|
||||||
tokio::spawn(async move {
|
|
||||||
let _ = client.read_packets().await;
|
|
||||||
if client.send_queued_packets().await.is_err() {
|
|
||||||
client
|
|
||||||
.disconnect(Some(serde_json::json!({ "text": "Error writing packets." })))
|
|
||||||
.await;
|
|
||||||
}
|
|
||||||
client
|
|
||||||
})
|
|
||||||
})
|
|
||||||
.collect();
|
|
||||||
*nc = Vec::with_capacity(tasks.len());
|
|
||||||
for task in tasks {
|
|
||||||
nc.push(task.await.unwrap());
|
|
||||||
}
|
|
||||||
trace!("Network task updated clients");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
// Start a task to remove disconnected clients.
|
|
||||||
trace!("Creating disconnection task");
|
|
||||||
let nc = network_clients.clone();
|
|
||||||
let r = running.clone();
|
|
||||||
let disconnection_task = tokio::spawn(async move {
|
|
||||||
trace!("Disconnection task created");
|
|
||||||
loop {
|
|
||||||
tokio::select! {
|
|
||||||
_ = r.cancelled() => {
|
|
||||||
trace!("Disconnection task received shutdown");
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
mut nc = nc.write() => {
|
|
||||||
let before = nc.len();
|
|
||||||
nc.retain(|client| client.state != NetworkClientState::Disconnected);
|
|
||||||
let after = nc.len();
|
|
||||||
trace!("Disconnection task removed {} clients", before - after);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
// Join the tasks on shutdown.
|
|
||||||
listener_task.await.expect("Listener task crashed");
|
|
||||||
packet_task.await.expect("Packet task crashed");
|
|
||||||
disconnection_task
|
|
||||||
.await
|
|
||||||
.expect("Disconnection task crashed");
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
#[async_trait::async_trait]
|
#[async_trait::async_trait]
|
||||||
impl App for Server {
|
impl App for Server {
|
||||||
type Error = error::Error;
|
type Error = Error;
|
||||||
|
|
||||||
fn startup_message() -> String {
|
fn startup_message() -> String {
|
||||||
let config = Config::instance();
|
let config = Config::instance();
|
||||||
@ -147,162 +36,69 @@ impl App for Server {
|
|||||||
let config = Config::instance();
|
let config = Config::instance();
|
||||||
let bind_address = format!("0.0.0.0:{}", config.server.port);
|
let bind_address = format!("0.0.0.0:{}", config.server.port);
|
||||||
|
|
||||||
let clients = Arc::new(RwLock::new(vec![]));
|
// No limit on connections.
|
||||||
let net_tasks_handle = tokio::spawn(Self::create_network_tasks(
|
let connections = DownstreamConnectionManager::new(None);
|
||||||
bind_address,
|
let listener = connections
|
||||||
clients.clone(),
|
.spawn_listener(bind_address, running.child_token())
|
||||||
running.child_token(),
|
.await
|
||||||
));
|
.map_err(Error::Network)?;
|
||||||
|
|
||||||
Ok(Server {
|
Ok(Server {
|
||||||
clients,
|
running,
|
||||||
net_tasks_handle,
|
connections,
|
||||||
|
listener,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
#[tracing::instrument]
|
#[tracing::instrument]
|
||||||
async fn update(&mut self) -> Result<(), Self::Error> {
|
async fn update(&mut self) -> Result<(), Self::Error> {
|
||||||
let mut clients = self.clients.write().await;
|
let online_player_count = self
|
||||||
|
.connections
|
||||||
// Handle packets from the clients.
|
.clients()
|
||||||
let online_players = clients
|
.filter(|c| matches!(c.client_state(), DownstreamConnectionState::Play))
|
||||||
.iter()
|
|
||||||
.filter(|client| matches!(client.state, NetworkClientState::Play))
|
|
||||||
.count();
|
.count();
|
||||||
'clients: for client in clients.iter_mut() {
|
|
||||||
use crate::protocol::packets;
|
|
||||||
|
|
||||||
'packets: while !client.incoming_packet_queue.is_empty() {
|
// Receive new connections and remove disconnected ones.
|
||||||
// client.read_packet()
|
self.connections.update().await?;
|
||||||
// None: The client doesn't have any more packets.
|
|
||||||
// Some(Err(_)): The client read an unexpected packet. TODO: Handle this error.
|
|
||||||
// Some(Ok(_)): The client read the expected packet.
|
|
||||||
match client.state.clone() {
|
|
||||||
NetworkClientState::Handshake => {
|
|
||||||
use packets::handshake::serverbound::Handshake;
|
|
||||||
|
|
||||||
let handshake = match client.read_packet::<Handshake>() {
|
// Read packets from each connection.
|
||||||
None => continue 'packets,
|
// Handle handshake connections.
|
||||||
Some(Err(_)) => continue 'clients,
|
let _ = futures::future::join_all(
|
||||||
Some(Ok(handshake)) => handshake,
|
self.connections
|
||||||
};
|
.clients_mut()
|
||||||
|
.filter(|c| matches!(c.client_state(), DownstreamConnectionState::Handshake))
|
||||||
if handshake.next_state == ClientState::Status {
|
.map(|c| c.handle_handshake()),
|
||||||
client.state = NetworkClientState::Status {
|
)
|
||||||
received_request: false,
|
|
||||||
received_ping: false,
|
|
||||||
};
|
|
||||||
} else if handshake.next_state == ClientState::Login {
|
|
||||||
client.state = NetworkClientState::Login {
|
|
||||||
received_start: (false, None),
|
|
||||||
};
|
|
||||||
} else {
|
|
||||||
client
|
|
||||||
.disconnect(Some(
|
|
||||||
serde_json::json!({ "text": "Received invalid SH00Handshake packet" }),
|
|
||||||
))
|
|
||||||
.await;
|
.await;
|
||||||
}
|
|
||||||
}
|
|
||||||
// Status !received_request: Read SS00StatusRequest and respond with CS00StatusResponse
|
|
||||||
NetworkClientState::Status {
|
|
||||||
received_request,
|
|
||||||
received_ping,
|
|
||||||
} if !received_request => {
|
|
||||||
use packets::status::{
|
|
||||||
clientbound::StatusResponse, serverbound::StatusRequest,
|
|
||||||
};
|
|
||||||
|
|
||||||
let _status_request = match client.read_packet::<StatusRequest>() {
|
// Handle status connections.
|
||||||
None => continue 'packets,
|
let _ = futures::future::join_all(
|
||||||
Some(Err(_)) => continue 'clients,
|
self.connections
|
||||||
Some(Ok(p)) => p,
|
.clients_mut()
|
||||||
};
|
.filter(|c| matches!(c.client_state(), DownstreamConnectionState::StatusRequest))
|
||||||
client.state = NetworkClientState::Status {
|
.map(|c| c.handle_status_ping(online_player_count)),
|
||||||
received_request: true,
|
)
|
||||||
received_ping,
|
.await;
|
||||||
};
|
|
||||||
let config = Config::instance();
|
|
||||||
use base64::Engine;
|
|
||||||
client.queue_packet(StatusResponse {
|
|
||||||
response: serde_json::json!({
|
|
||||||
"version": {
|
|
||||||
"name": config.global.game_version,
|
|
||||||
"protocol": config.global.protocol_version
|
|
||||||
},
|
|
||||||
"players": {
|
|
||||||
"max": config.server.max_players,
|
|
||||||
"online": online_players,
|
|
||||||
"sample": []
|
|
||||||
},
|
|
||||||
"description": {
|
|
||||||
"text": config.server.motd
|
|
||||||
},
|
|
||||||
"favicon": format!("data:image/png;base64,{}", base64::engine::general_purpose::STANDARD_NO_PAD.encode(&config.server.server_icon_bytes)),
|
|
||||||
"enforcesSecureChat": true
|
|
||||||
}),
|
|
||||||
});
|
|
||||||
}
|
|
||||||
// Status !received_ping: Read SS00StatusRequest and respond with CS00StatusResponse
|
|
||||||
NetworkClientState::Status { received_ping, .. } if !received_ping => {
|
|
||||||
use packets::status::{
|
|
||||||
clientbound::PingResponse, serverbound::PingRequest,
|
|
||||||
};
|
|
||||||
|
|
||||||
let ping = match client.read_packet::<PingRequest>() {
|
// Handle login connections.
|
||||||
None => continue 'packets,
|
// Handle play connection packets.
|
||||||
Some(Err(_)) => continue 'clients,
|
// Process world updates.
|
||||||
Some(Ok(p)) => p,
|
// Send out play connection updates.
|
||||||
};
|
|
||||||
client.queue_packet(PingResponse {
|
|
||||||
payload: ping.payload,
|
|
||||||
});
|
|
||||||
client.state = NetworkClientState::Disconnected;
|
|
||||||
}
|
|
||||||
NetworkClientState::Status { .. } => unreachable!(),
|
|
||||||
NetworkClientState::Login { received_start, .. } if !received_start.0 => {
|
|
||||||
use packets::login::{clientbound::*, serverbound::*};
|
|
||||||
|
|
||||||
let login_start = match client.read_packet::<LoginStart>() {
|
|
||||||
None => continue 'packets,
|
|
||||||
Some(Err(_)) => continue 'clients,
|
|
||||||
Some(Ok(p)) => p,
|
|
||||||
};
|
|
||||||
// TODO: Authenticate the user.
|
|
||||||
// TODO: Get the user from the stored database.
|
|
||||||
// TODO: Encryption/compression.
|
|
||||||
client.queue_packet(LoginSuccess {
|
|
||||||
uuid: login_start.uuid.unwrap_or(Uuid::nil()),
|
|
||||||
username: login_start.name.clone(),
|
|
||||||
properties: vec![],
|
|
||||||
});
|
|
||||||
client.state = NetworkClientState::Login {
|
|
||||||
received_start: (true, Some(login_start)),
|
|
||||||
};
|
|
||||||
}
|
|
||||||
NetworkClientState::Login { .. } => unreachable!(),
|
|
||||||
NetworkClientState::Play => unimplemented!(),
|
|
||||||
NetworkClientState::Disconnected => unimplemented!(),
|
|
||||||
}
|
|
||||||
// If continue was not
|
|
||||||
break 'packets;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
#[tracing::instrument]
|
#[tracing::instrument]
|
||||||
async fn shutdown(self) -> Result<(), Self::Error> {
|
async fn shutdown(self) -> Result<(), Self::Error> {
|
||||||
// Close the concurrent tasks.
|
// Ensure any child tasks have been shut down.
|
||||||
let _ = self.net_tasks_handle.await;
|
self.running.cancel();
|
||||||
|
|
||||||
// Send disconnect messages to the clients.
|
let _ = self.listener.await.map_err(Error::Task)?;
|
||||||
for client in self.clients.write().await.iter_mut() {
|
let _ = self
|
||||||
client
|
.connections
|
||||||
.disconnect(Some(
|
.shutdown(Some(
|
||||||
serde_json::json!({ "text": "The server is shutting down." }),
|
serde_json::json!({ "text": "The server is shutting down." }),
|
||||||
))
|
))
|
||||||
.await;
|
.await
|
||||||
}
|
.map_err(Error::Network)?;
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
@ -1,246 +0,0 @@
|
|||||||
use crate::protocol::{
|
|
||||||
packets::{self, Packet, PacketDirection},
|
|
||||||
parsing::Parsable,
|
|
||||||
ClientState,
|
|
||||||
};
|
|
||||||
use std::{collections::VecDeque, sync::Arc, time::Instant};
|
|
||||||
use tokio::io::AsyncWriteExt;
|
|
||||||
use tokio::{net::TcpStream, sync::RwLock};
|
|
||||||
use tracing::{debug, trace, warn};
|
|
||||||
|
|
||||||
/// Similar to `composition_protocol::ClientState`,
|
|
||||||
/// but contains more useful data for managing the client's state.
|
|
||||||
#[derive(Clone, PartialEq, Debug)]
|
|
||||||
pub(crate) enum NetworkClientState {
|
|
||||||
/// A client has established a connection with the server.
|
|
||||||
///
|
|
||||||
/// See `composition_protocol::ClientState::Handshake` for more details.
|
|
||||||
Handshake,
|
|
||||||
/// The client sent `SH00Handshake` with `next_state = ClientState::Status`
|
|
||||||
/// and is performing [server list ping](https://wiki.vg/Server_List_Ping).
|
|
||||||
Status {
|
|
||||||
/// When the server receives `SS00StatusRequest`, this is set
|
|
||||||
/// to `true` and the server should send `CS00StatusResponse`.
|
|
||||||
received_request: bool,
|
|
||||||
/// When the server receives `SS01PingRequest`, this is set
|
|
||||||
/// to `true` and the server should send `CS01PingResponse`
|
|
||||||
/// and set the connection state to `Disconnected`.
|
|
||||||
received_ping: bool,
|
|
||||||
},
|
|
||||||
/// The client sent `SH00Handshake` with `next_state = ClientState::Login`
|
|
||||||
/// and is attempting to join the server.
|
|
||||||
Login {
|
|
||||||
received_start: (bool, Option<packets::login::serverbound::LoginStart>),
|
|
||||||
},
|
|
||||||
/// The server sent `CL02LoginSuccess` and transitioned to `Play`.
|
|
||||||
#[allow(dead_code)]
|
|
||||||
Play,
|
|
||||||
/// The client has disconnected.
|
|
||||||
///
|
|
||||||
/// No packets should be sent or received,
|
|
||||||
/// and the `NetworkClient` should be queued for removal.
|
|
||||||
Disconnected,
|
|
||||||
}
|
|
||||||
impl From<NetworkClientState> for ClientState {
|
|
||||||
fn from(value: NetworkClientState) -> Self {
|
|
||||||
match value {
|
|
||||||
NetworkClientState::Handshake => ClientState::Handshake,
|
|
||||||
NetworkClientState::Status { .. } => ClientState::Status,
|
|
||||||
NetworkClientState::Login { .. } => ClientState::Login,
|
|
||||||
NetworkClientState::Play => ClientState::Play,
|
|
||||||
NetworkClientState::Disconnected => ClientState::Disconnected,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
impl AsRef<ClientState> for NetworkClientState {
|
|
||||||
fn as_ref(&self) -> &ClientState {
|
|
||||||
match self {
|
|
||||||
NetworkClientState::Handshake => &ClientState::Handshake,
|
|
||||||
NetworkClientState::Status { .. } => &ClientState::Status,
|
|
||||||
NetworkClientState::Login { .. } => &ClientState::Login,
|
|
||||||
NetworkClientState::Play => &ClientState::Play,
|
|
||||||
NetworkClientState::Disconnected => &ClientState::Disconnected,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// A wrapper around the raw `TcpStream` that abstracts away reading/writing packets and bytes.
|
|
||||||
#[derive(Debug, Clone)]
|
|
||||||
pub(crate) struct NetworkClient {
|
|
||||||
/// The `NetworkClient`'s unique id.
|
|
||||||
pub id: u128,
|
|
||||||
pub state: NetworkClientState,
|
|
||||||
stream: Arc<RwLock<TcpStream>>,
|
|
||||||
/// Data gets appended to the back as it gets read,
|
|
||||||
/// and popped from the front as it gets parsed into packets.
|
|
||||||
incoming_data: VecDeque<u8>,
|
|
||||||
/// Packets get appended to the back as they get read,
|
|
||||||
/// and popped from the front as they get handled.
|
|
||||||
pub incoming_packet_queue: VecDeque<Packet>,
|
|
||||||
/// Keeps track of the last time the client sent data.
|
|
||||||
///
|
|
||||||
/// This is useful for removing clients that have timed out.
|
|
||||||
pub last_received_data_time: Instant,
|
|
||||||
/// Packets get appended to the back and get popped from the front as they get sent.
|
|
||||||
pub outgoing_packet_queue: VecDeque<Packet>,
|
|
||||||
}
|
|
||||||
impl NetworkClient {
|
|
||||||
#[tracing::instrument]
|
|
||||||
pub fn new(id: u128, stream: TcpStream) -> NetworkClient {
|
|
||||||
NetworkClient {
|
|
||||||
id,
|
|
||||||
state: NetworkClientState::Handshake,
|
|
||||||
stream: Arc::new(RwLock::new(stream)),
|
|
||||||
incoming_data: VecDeque::new(),
|
|
||||||
incoming_packet_queue: VecDeque::new(),
|
|
||||||
last_received_data_time: Instant::now(),
|
|
||||||
outgoing_packet_queue: VecDeque::new(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
#[tracing::instrument]
|
|
||||||
async fn read_data(&mut self) -> tokio::io::Result<()> {
|
|
||||||
trace!("NetworkClient.read_data() id {}", self.id);
|
|
||||||
let stream = self.stream.read().await;
|
|
||||||
|
|
||||||
// Try to read 4kb at a time until there is no more data.
|
|
||||||
loop {
|
|
||||||
let mut buf = [0; 4096];
|
|
||||||
|
|
||||||
let num_bytes = match stream.try_read(&mut buf) {
|
|
||||||
Ok(0) => break,
|
|
||||||
Ok(n) => n,
|
|
||||||
Err(ref e) if e.kind() == tokio::io::ErrorKind::WouldBlock => {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
Err(e) => return Err(e),
|
|
||||||
};
|
|
||||||
|
|
||||||
debug!("Read {} bytes from client {}", num_bytes, self.id);
|
|
||||||
|
|
||||||
self.last_received_data_time = Instant::now();
|
|
||||||
self.incoming_data.extend(&buf[..num_bytes]);
|
|
||||||
}
|
|
||||||
|
|
||||||
trace!("NetworkClient.read_data() end id {}", self.id);
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
// TODO: Stream compression/encryption.
|
|
||||||
#[tracing::instrument]
|
|
||||||
pub async fn read_packets(&mut self) -> crate::protocol::Result<()> {
|
|
||||||
trace!("NetworkClient.read_packet() id {}", self.id);
|
|
||||||
|
|
||||||
if self.read_data().await.is_err() {
|
|
||||||
self.disconnect(None).await;
|
|
||||||
return Err(crate::protocol::Error::Disconnected);
|
|
||||||
}
|
|
||||||
|
|
||||||
self.incoming_data.make_contiguous();
|
|
||||||
let (mut data, &[..]) = self.incoming_data.as_slices();
|
|
||||||
|
|
||||||
let mut bytes_consumed = 0;
|
|
||||||
while !data.is_empty() {
|
|
||||||
let p = Packet::parse(
|
|
||||||
self.state.clone().into(),
|
|
||||||
PacketDirection::Serverbound,
|
|
||||||
data,
|
|
||||||
);
|
|
||||||
trace!("{} got {:?}", self.id, p);
|
|
||||||
match p {
|
|
||||||
Ok((d, packet)) => {
|
|
||||||
debug!("Got packet {:?} from client {}", packet, self.id);
|
|
||||||
bytes_consumed += data.len() - d.len();
|
|
||||||
data = d;
|
|
||||||
self.incoming_packet_queue.push_back(packet);
|
|
||||||
}
|
|
||||||
Err(nom::Err::Incomplete(_)) => break,
|
|
||||||
Err(_) => {
|
|
||||||
// Remove the valid bytes before this packet.
|
|
||||||
self.incoming_data = self.incoming_data.split_off(bytes_consumed);
|
|
||||||
return Err(crate::protocol::Error::Parsing);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Remove the bytes we just read.
|
|
||||||
self.incoming_data = self.incoming_data.split_off(bytes_consumed);
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
// None: There was no packet to read.
|
|
||||||
// Some(Err(())): The packet was the wrong type.
|
|
||||||
// Some(Ok(_)): The packet was successfully read.
|
|
||||||
#[tracing::instrument]
|
|
||||||
pub fn read_packet<P: std::fmt::Debug + TryFrom<Packet>>(
|
|
||||||
&mut self,
|
|
||||||
) -> Option<std::result::Result<P, Packet>> {
|
|
||||||
if let Some(generic_packet) = self.incoming_packet_queue.pop_back() {
|
|
||||||
if let Ok(packet) = TryInto::<P>::try_into(generic_packet.clone()) {
|
|
||||||
Some(Ok(packet))
|
|
||||||
} else {
|
|
||||||
self.incoming_packet_queue.push_back(generic_packet.clone());
|
|
||||||
Some(Err(generic_packet))
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
}
|
|
||||||
}
|
|
||||||
#[tracing::instrument]
|
|
||||||
pub fn queue_packet<P: std::fmt::Debug + Into<Packet>>(&mut self, packet: P) {
|
|
||||||
self.outgoing_packet_queue.push_back(packet.into());
|
|
||||||
}
|
|
||||||
#[tracing::instrument]
|
|
||||||
pub async fn send_queued_packets(&mut self) -> crate::protocol::Result<()> {
|
|
||||||
let packets: Vec<_> = self.outgoing_packet_queue.drain(..).collect();
|
|
||||||
for packet in packets {
|
|
||||||
self.send_packet(packet)
|
|
||||||
.await
|
|
||||||
.map_err(|_| crate::protocol::Error::Disconnected)?;
|
|
||||||
}
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
#[tracing::instrument]
|
|
||||||
pub async fn send_packet<P: std::fmt::Debug + Into<Packet>>(
|
|
||||||
&self,
|
|
||||||
packet: P,
|
|
||||||
) -> tokio::io::Result<()> {
|
|
||||||
let packet: Packet = packet.into();
|
|
||||||
|
|
||||||
debug!("Sending packet {:?} to client {}", packet, self.id);
|
|
||||||
let (packet_id, mut packet_body) = packet.serialize();
|
|
||||||
let mut packet_id = packet_id.serialize();
|
|
||||||
|
|
||||||
// TODO: Stream compression/encryption.
|
|
||||||
|
|
||||||
let mut b = vec![];
|
|
||||||
b.append(&mut packet_id);
|
|
||||||
b.append(&mut packet_body);
|
|
||||||
|
|
||||||
// bytes: packet length as varint, packet id as varint, packet body
|
|
||||||
let bytes = Parsable::serialize(&b);
|
|
||||||
|
|
||||||
self.stream.write().await.write_all(&bytes).await?;
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
#[tracing::instrument]
|
|
||||||
pub async fn disconnect(&mut self, reason: Option<crate::protocol::types::Chat>) {
|
|
||||||
use packets::{login::clientbound::LoginDisconnect, play::clientbound::PlayDisconnect};
|
|
||||||
|
|
||||||
let reason = reason.unwrap_or(serde_json::json!({
|
|
||||||
"text": "You have been disconnected!"
|
|
||||||
}));
|
|
||||||
|
|
||||||
match self.state.as_ref() {
|
|
||||||
ClientState::Disconnected | ClientState::Handshake | ClientState::Status => {
|
|
||||||
// Impossible to send a disconnect in these states.
|
|
||||||
}
|
|
||||||
ClientState::Login => {
|
|
||||||
let _ = self.send_packet(LoginDisconnect { reason }).await;
|
|
||||||
}
|
|
||||||
ClientState::Play => {
|
|
||||||
let _ = self.send_packet(PlayDisconnect { reason }).await;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
self.state = NetworkClientState::Disconnected;
|
|
||||||
}
|
|
||||||
}
|
|
Loading…
x
Reference in New Issue
Block a user