diff --git a/docker-compose.yml b/docker-compose.yml index 42d803f..315eff1 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -32,9 +32,13 @@ services: context: . dockerfile: Dockerfile target: dev - command: [ "run -- proxy -u reference" ] + depends_on: + reference: + condition: service_healthy + restart: true + command: [ "run -- proxy -U reference -l trace" ] ports: - - "25566:25565" + - "25566:25566" volumes: - .:/app - .git:/app/.git diff --git a/reference.bin b/reference.bin new file mode 100644 index 0000000..3d0e7a8 --- /dev/null +++ b/reference.bin @@ -0,0 +1,14 @@ +00000000 10 00 ff 05 09 6c 6f 63 61 6c 68 6f 73 74 63 dd .....loc alhostc. +00000010 01 . +00000011 01 00 .. + 00000000 8c 01 00 89 01 7b 22 76 65 72 73 69 6f 6e 22 3a .....{"v ersion": + 00000010 7b 22 6e 61 6d 65 22 3a 22 31 2e 32 31 2e 34 22 {"name": "1.21.4" + 00000020 2c 22 70 72 6f 74 6f 63 6f 6c 22 3a 37 36 39 7d ,"protoc ol":769} + 00000030 2c 22 65 6e 66 6f 72 63 65 73 53 65 63 75 72 65 ,"enforc esSecure + 00000040 43 68 61 74 22 3a 74 72 75 65 2c 22 64 65 73 63 Chat":tr ue,"desc + 00000050 72 69 70 74 69 6f 6e 22 3a 22 41 20 4d 69 6e 65 ription" :"A Mine + 00000060 63 72 61 66 74 20 53 65 72 76 65 72 22 2c 22 70 craft Se rver","p + 00000070 6c 61 79 65 72 73 22 3a 7b 22 6d 61 78 22 3a 32 layers": {"max":2 + 00000080 30 2c 22 6f 6e 6c 69 6e 65 22 3a 30 7d 7d 0,"onlin e":0}} +00000013 09 01 00 00 00 00 00 3b 3b 51 .......; ;Q + 0000008E 09 01 00 00 00 00 00 3b 3b 51 .......; ;Q diff --git a/src/lib.rs b/src/lib.rs index b624e43..bbd3a65 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -18,7 +18,7 @@ use config::Subcommand; use once_cell::sync::OnceCell; use std::time::Instant; use tokio_util::sync::CancellationToken; -use tracing::info; +use tracing::{info, error}; pub const PROTOCOL_VERSION: i32 = 762; pub const GAME_VERSION: &str = "1.19.4"; @@ -62,8 +62,12 @@ pub(crate) trait App: Sized { break; } r = app.update() => { - if r.is_err() { - break; + match r { + Ok(_) => {}, + Err(e) => { + error!("{:?}", e); + break; + } } } } diff --git a/src/net/codec.rs b/src/net/codec.rs index 060d5d1..c9fd077 100644 --- a/src/net/codec.rs +++ b/src/net/codec.rs @@ -4,11 +4,12 @@ use crate::protocol::{ types::VarInt, ClientState, }; -use std::io::{Error, ErrorKind}; use tokio_util::{ bytes::{Buf, BytesMut}, codec::{Decoder, Encoder}, }; +use super::error::Error; +use tracing::trace; #[derive(Clone, Copy, Debug)] pub struct PacketCodec { @@ -33,7 +34,7 @@ impl Default for PacketCodec { } impl Decoder for PacketCodec { type Item = Packet; - type Error = std::io::Error; + type Error = Error; fn decode(&mut self, src: &mut BytesMut) -> Result, Self::Error> { match Packet::parse(self.client_state, self.packet_direction, src) { @@ -58,17 +59,21 @@ impl Decoder for PacketCodec { src.reserve(5); Ok(None) } - Err(_) => Err(Error::new(ErrorKind::InvalidData, "Nom parsing error")), + Err(_) => Err(Error::Parsing), } } - Err(nom::Err::Error(_)) | Err(nom::Err::Failure(_)) => { - Err(Error::new(ErrorKind::InvalidData, "Nom parsing error")) + Err(nom::Err::Error(e)) => { + trace!("parsing error: {:02X?}", e.input); + Err(Error::Parsing) + } + Err(nom::Err::Failure(_)) => { + Err(Error::Parsing) } } } } impl Encoder for PacketCodec { - type Error = std::io::Error; + type Error = Error; fn encode(&mut self, item: Packet, dst: &mut BytesMut) -> Result<(), Self::Error> { let mut out = vec![]; diff --git a/src/net/connection.rs b/src/net/connection.rs index 78f9e71..9bce62a 100644 --- a/src/net/connection.rs +++ b/src/net/connection.rs @@ -1,4 +1,4 @@ -use super::codec::PacketCodec; +use super::{codec::PacketCodec, error::Error}; use crate::protocol::{ packets::{self, Packet, PacketDirection}, types::Chat, @@ -20,6 +20,7 @@ use tracing::{error, trace}; #[derive(Debug)] pub struct ConnectionManager { + max_clients: Option, clients: HashMap, channel: ( mpsc::UnboundedSender, @@ -27,8 +28,9 @@ pub struct ConnectionManager { ), } impl ConnectionManager { - pub fn new() -> ConnectionManager { + pub fn new(max_clients: Option) -> ConnectionManager { ConnectionManager { + max_clients, clients: HashMap::new(), channel: mpsc::unbounded_channel(), } @@ -39,11 +41,17 @@ impl ConnectionManager { pub fn client_mut(&mut self, id: u128) -> Option<&mut Connection> { self.clients.get_mut(&id) } + pub fn clients(&self) -> impl Iterator { + self.clients.iter().map(|(_id, c)| c) + } + pub fn clients_mut(&mut self) -> impl Iterator { + self.clients.iter_mut().map(|(_id, c)| c) + } pub async fn spawn_listener( &self, bind_address: A, running: CancellationToken, - ) -> Result, std::io::Error> + ) -> Result, Error> where A: 'static + ToSocketAddrs + Send + std::fmt::Debug, { @@ -51,6 +59,7 @@ impl ConnectionManager { let fmt_addr = format!("{:?}", bind_address); let listener = TcpListener::bind(bind_address) .await + .map_err(Error::Io) .inspect_err(|_| error!("Could not bind to {}.", fmt_addr))?; let sender = self.channel.0.clone(); @@ -81,25 +90,49 @@ impl ConnectionManager { Ok(join_handle) } - pub fn update(&mut self) -> Result<(), std::io::Error> { - use std::io::{Error, ErrorKind}; - + pub async fn update(&mut self) -> Result<(), Error> { // Receive new clients from the sender. loop { match self.channel.1.try_recv() { Ok(connection) => { let id = connection.id(); - self.clients.insert(id, connection); - } - Err(mpsc::error::TryRecvError::Disconnected) => { - return Err(Error::new( - ErrorKind::BrokenPipe, - "all senders disconnected", - )) + + match self.max_clients { + Some(max) => { + if self.clients.len() >= max { + let _ = connection.disconnect(None).await; + } else { + self.clients.insert(id, connection); + } + } + None => { + self.clients.insert(id, connection); + }, + } } + Err(mpsc::error::TryRecvError::Disconnected) => return Err(Error::ConnectionChannelDisconnnection), Err(mpsc::error::TryRecvError::Empty) => break, }; } + + // Disconnect any clients that have timed out. + // We don't actually care if the disconnections succeed, + // the connection is going to be dropped anyway. + let _ = futures::future::join_all({ + // Workaround until issue #59618 hash_extract_if gets stabilized. + let ids = self.clients.iter() + .filter_map(|(id, c)| { + if c.received_elapsed() > Duration::from_secs(10) { + Some(*id) + } else { + None + } + }) + .collect::>(); + ids.into_iter() + .map(|id| self.clients.remove(&id).unwrap()) + .map(|client| client.disconnect(None)) + }).await; // Remove disconnected clients. self.clients @@ -110,11 +143,11 @@ impl ConnectionManager { &mut self, id: u128, reason: Option, - ) -> Option> { + ) -> Option> { let client = self.clients.remove(&id)?; Some(client.disconnect(reason).await) } - pub async fn shutdown(mut self, reason: Option) -> Result<(), std::io::Error> { + pub async fn shutdown(mut self, reason: Option) -> Result<(), Error> { let reason = reason.unwrap_or(serde_json::json!({ "text": "You have been disconnected!" })); @@ -128,8 +161,7 @@ impl ConnectionManager { // We don't actually care if the disconnections succeed, // the connection is going to be dropped anyway. - let _disconnections: Vec> = - futures::future::join_all(disconnections).await; + let _disconnections = futures::future::join_all(disconnections).await; Ok(()) } @@ -172,15 +204,15 @@ impl Connection { pub fn sent_elapsed(&self) -> Duration { self.last_sent_data_time.elapsed() } - pub async fn read_packet(&mut self) -> Option> { + pub async fn read_packet(&mut self) -> Option> { self.last_received_data_time = Instant::now(); self.stream.next().await } - pub async fn send_packet>(&mut self, packet: P) -> Result<(), std::io::Error> { + pub async fn send_packet>(&mut self, packet: P) -> Result<(), Error> { let packet: Packet = packet.into(); self.stream.send(packet).await } - pub async fn disconnect(mut self, reason: Option) -> Result<(), std::io::Error> { + pub async fn disconnect(mut self, reason: Option) -> Result<(), Error> { trace!("Connection disconnected (id {})", self.id); use packets::{login::clientbound::LoginDisconnect, play::clientbound::PlayDisconnect}; diff --git a/src/net/error.rs b/src/net/error.rs new file mode 100644 index 0000000..b9840c5 --- /dev/null +++ b/src/net/error.rs @@ -0,0 +1,18 @@ +pub use std::io::Error as IoError; + +/// This type represents all possible errors that can occur in the network. +#[allow(dead_code)] +#[derive(thiserror::Error, Debug)] +pub enum Error { + #[error(transparent)] + Io(IoError), + #[error("There was an error parsing data")] + Parsing, + #[error("Internal channel disconnected")] + ConnectionChannelDisconnnection, +} +impl From for Error { + fn from(value: std::io::Error) -> Self { + Error::Io(value) + } +} diff --git a/src/net/mod.rs b/src/net/mod.rs index 91ff74e..9eda11f 100644 --- a/src/net/mod.rs +++ b/src/net/mod.rs @@ -2,3 +2,4 @@ pub mod codec; pub mod connection; +pub mod error; diff --git a/src/proxy/error.rs b/src/proxy/error.rs new file mode 100644 index 0000000..4fb59e2 --- /dev/null +++ b/src/proxy/error.rs @@ -0,0 +1,15 @@ +pub use std::io::Error as IoError; +pub use tokio::task::JoinError as TaskError; +pub use crate::net::error::Error as NetworkError; + +/// This type represents all possible errors that can occur when running the proxy. +#[allow(dead_code)] +#[derive(thiserror::Error, Debug)] +pub enum Error { + #[error(transparent)] + Io(IoError), + #[error(transparent)] + Task(TaskError), + #[error(transparent)] + Network(NetworkError), +} diff --git a/src/proxy/mod.rs b/src/proxy/mod.rs index a477723..5abd58b 100644 --- a/src/proxy/mod.rs +++ b/src/proxy/mod.rs @@ -1,21 +1,26 @@ pub mod config; +pub mod error; +use crate::net::connection::Connection; use crate::App; use crate::{config::Config, net::connection::ConnectionManager}; use config::ProxyConfig; +use tokio::net::TcpStream; use tokio::task::JoinHandle; use tokio_util::sync::CancellationToken; -use tracing::info; +use tracing::{info, trace, error, debug}; +use error::{Error, NetworkError}; #[derive(Debug)] pub struct Proxy { running: CancellationToken, connections: ConnectionManager, listener: JoinHandle<()>, + upstream: Connection, } #[async_trait::async_trait] impl App for Proxy { - type Error = (); + type Error = Error; fn startup_message() -> String { let config = Config::instance(); @@ -29,35 +34,87 @@ impl App for Proxy { async fn new(running: CancellationToken) -> Result { let config = Config::instance(); let bind_address = format!("0.0.0.0:{}", config.proxy.port); - - let connections = ConnectionManager::new(); + + // Only allow one client to join at a time. + let connections = ConnectionManager::new(Some(1)); let listener = connections .spawn_listener(bind_address, running.child_token()) .await - .map_err(|_| ())?; + .map_err(Error::Network)?; - info!( - "Upstream server: {}:{}", - config.proxy.upstream_host, config.proxy.upstream_port - ); + let upstream_address = format!("{}:{}", config.proxy.upstream_host, config.proxy.upstream_port); + info!("Upstream server: {}", upstream_address); + let upstream = TcpStream::connect(upstream_address).await.map_err(Error::Io)?; + let upstream = Connection::new_server(0, upstream); Ok(Proxy { running, connections, listener, + upstream, }) } #[tracing::instrument] async fn update(&mut self) -> Result<(), Self::Error> { - todo!() + let _ = self.connections.update().await.map_err(Error::Network)?; + + let Some(client) = self.connections.clients_mut().take(1).next() else { + return Ok(()); + }; + + let mut client_parsing_error = false; + + // At the same time, try to read packets from the server and client. + // Forward the packet onto the other. + tokio::select! { + packet = client.read_packet() => { + if let Some(packet) = packet { + match packet { + Ok(packet) => { + trace!("Got packet from client: {:?}", packet); + self.upstream.send_packet(packet).await.map_err(Error::Network)?; + } + Err(NetworkError::Parsing) => { + debug!("Got invalid data from client (id {})", client.id()); + client_parsing_error = true; + } + Err(e) => return Err(Error::Network(e)), + } + } + } + packet = self.upstream.read_packet() => { + if let Some(packet) = packet { + match packet { + Ok(packet) => { + trace!("Got packet from upstream: {:?}", packet); + client.send_packet(packet).await.map_err(Error::Network)?; + } + Err(NetworkError::Parsing) => { + error!("Got invalid data from upstream"); + return Err(Error::Network(NetworkError::Parsing)); + }, + Err(e) => return Err(Error::Network(e)), + } + } + } + } + + if client_parsing_error { + let id = client.id(); + // Drop the &mut Connection + let _ = client; + let _ = self.connections.disconnect(id, Some(serde_json::json!({ "text": "Received malformed data." }))).await; + } + + Ok(()) } #[tracing::instrument] async fn shutdown(self) -> Result<(), Self::Error> { // Ensure any child tasks have been shut down. self.running.cancel(); - let _ = self.listener.await.map_err(|_| ())?; - let _ = self.connections.shutdown(None).await.map_err(|_| ())?; + let _ = self.listener.await.map_err(Error::Task)?; + let _ = self.connections.shutdown(None).await.map_err(Error::Network)?; Ok(()) }