From 0cbfe045e360c4f4f4b63f7b20c1cd92f8c6c568 Mon Sep 17 00:00:00 2001 From: Garen Tyler Date: Mon, 9 Dec 2024 01:24:33 -0700 Subject: [PATCH] Working proxy! --- src/net/codec.rs | 12 ----------- src/net/connection.rs | 8 +++++++ src/protocol/packets.rs | 46 ++++++++++++++++++++++++++++++++++++++--- src/proxy/mod.rs | 40 ++++++++++++++++++++++++++++++----- 4 files changed, 86 insertions(+), 20 deletions(-) diff --git a/src/net/codec.rs b/src/net/codec.rs index c9fd077..874c6ee 100644 --- a/src/net/codec.rs +++ b/src/net/codec.rs @@ -86,15 +86,3 @@ impl Encoder for PacketCodec { Ok(()) } } - -#[cfg(test)] -mod tests { - #[test] - fn packet_decoder_works() { - unimplemented!() - } - #[test] - fn packet_encoder_works() { - unimplemented!() - } -} diff --git a/src/net/connection.rs b/src/net/connection.rs index 9bce62a..84e00da 100644 --- a/src/net/connection.rs +++ b/src/net/connection.rs @@ -135,8 +135,13 @@ impl ConnectionManager { }).await; // Remove disconnected clients. + let before = self.clients.len(); self.clients .retain(|_id, c| c.client_state() != ClientState::Disconnected); + let after = self.clients.len(); + if before - after > 0 { + trace!("Removed {} disconnected clients", before - after); + } Ok(()) } pub async fn disconnect( @@ -198,6 +203,9 @@ impl Connection { pub fn client_state(&self) -> ClientState { self.stream.codec().client_state } + pub fn client_state_mut(&mut self) -> &mut ClientState { + &mut self.stream.codec_mut().client_state + } pub fn received_elapsed(&self) -> Duration { self.last_received_data_time.elapsed() } diff --git a/src/protocol/packets.rs b/src/protocol/packets.rs index 71fa0a5..e01fd5d 100644 --- a/src/protocol/packets.rs +++ b/src/protocol/packets.rs @@ -45,8 +45,8 @@ macro_rules! packets { if client_state == ClientState::Disconnected { return nom::combinator::fail(input); } - let (input, packet_len) = VarInt::parse_usize(input)?; + let (input, packet_body) = take(packet_len)(input)?; let (packet_body, packet_id) = verify(VarInt::parse, |v| { match client_state { $(ClientState::$state_name => { @@ -61,8 +61,7 @@ macro_rules! packets { })* ClientState::Disconnected => false, } - })(input)?; - let (input, packet_body) = take(packet_len)(packet_body)?; + })(packet_body)?; let (_, packet) = Packet::body_parser(client_state, direction, packet_id)(packet_body)?; Ok((input, packet)) } @@ -110,6 +109,7 @@ macro_rules! packets { Packet::LoginSuccess(_) => Some(ClientState::Play), Packet::LoginDisconnect(_) => Some(ClientState::Disconnected), Packet::PlayDisconnect(_) => Some(ClientState::Disconnected), + Packet::PingResponse(_) => Some(ClientState::Disconnected), _ => None, } } @@ -239,3 +239,43 @@ packets!( } } ); + +#[cfg(test)] +mod tests { + use crate::protocol::{packets::handshake::serverbound::Handshake, types::VarInt, ClientState}; + use super::{Packet, PacketDirection}; + + fn get_handshake() -> (Handshake, &'static [u8]) { + ( + Handshake { + protocol_version: VarInt::from(767), + host: String::from("localhost"), + port: 25565, + next_state: ClientState::Status, + }, + &[ + // Packet length + 0x10, + // Packet ID + 0x00, + // protocol_version: VarInt + 0xff, 0x05, + // host: String + 0x09, 0x6c, 0x6f, 0x63, 0x61, 0x6c, 0x68, 0x6f, 0x73, 0x74, + // port: u16 + 0x63, 0xdd, + // next_state: ClientState (VarInt) + 0x01, + ] + ) + } + + #[test] + fn packet_parsing_works() { + let (handshake, handshake_bytes) = get_handshake(); + + let (rest, packet) = Packet::parse(ClientState::Handshake, PacketDirection::Serverbound, handshake_bytes).unwrap(); + assert_eq!(packet, Packet::Handshake(handshake)); + assert!(rest.is_empty()); + } +} diff --git a/src/proxy/mod.rs b/src/proxy/mod.rs index 5abd58b..0902a1b 100644 --- a/src/proxy/mod.rs +++ b/src/proxy/mod.rs @@ -2,6 +2,8 @@ pub mod config; pub mod error; use crate::net::connection::Connection; +use crate::protocol::packets::Packet; +use crate::protocol::ClientState; use crate::App; use crate::{config::Config, net::connection::ConnectionManager}; use config::ProxyConfig; @@ -16,8 +18,25 @@ pub struct Proxy { running: CancellationToken, connections: ConnectionManager, listener: JoinHandle<()>, + upstream_address: String, upstream: Connection, } +impl Proxy { + pub async fn connect_upstream(upstream_address: &str) -> Result { + let upstream = TcpStream::connect(upstream_address).await.map_err(Error::Io)?; + Ok(Connection::new_server(0, upstream)) + } + pub fn rewrite_packet(packet: Packet) -> Packet { + match packet { + Packet::StatusResponse(mut status) => { + let new_description = ProxyConfig::default().version.clone(); + *status.response.as_object_mut().unwrap().get_mut("description").unwrap() = serde_json::Value::String(new_description); + Packet::StatusResponse(status) + } + p => p, + } + } +} #[async_trait::async_trait] impl App for Proxy { type Error = Error; @@ -30,7 +49,6 @@ impl App for Proxy { config.proxy.port ) } - #[tracing::instrument] async fn new(running: CancellationToken) -> Result { let config = Config::instance(); let bind_address = format!("0.0.0.0:{}", config.proxy.port); @@ -44,14 +62,14 @@ impl App for Proxy { let upstream_address = format!("{}:{}", config.proxy.upstream_host, config.proxy.upstream_port); info!("Upstream server: {}", upstream_address); - let upstream = TcpStream::connect(upstream_address).await.map_err(Error::Io)?; - let upstream = Connection::new_server(0, upstream); + let upstream = Proxy::connect_upstream(&upstream_address).await?; Ok(Proxy { running, connections, listener, upstream, + upstream_address, }) } #[tracing::instrument] @@ -72,7 +90,11 @@ impl App for Proxy { match packet { Ok(packet) => { trace!("Got packet from client: {:?}", packet); - self.upstream.send_packet(packet).await.map_err(Error::Network)?; + let next_state = packet.state_change(); + self.upstream.send_packet(Proxy::rewrite_packet(packet)).await.map_err(Error::Network)?; + if let Some(next_state) = next_state { + *self.upstream.client_state_mut() = next_state; + } } Err(NetworkError::Parsing) => { debug!("Got invalid data from client (id {})", client.id()); @@ -87,7 +109,11 @@ impl App for Proxy { match packet { Ok(packet) => { trace!("Got packet from upstream: {:?}", packet); - client.send_packet(packet).await.map_err(Error::Network)?; + let next_state = packet.state_change(); + client.send_packet(Proxy::rewrite_packet(packet)).await.map_err(Error::Network)?; + if let Some(next_state) = next_state { + *client.client_state_mut() = next_state; + } } Err(NetworkError::Parsing) => { error!("Got invalid data from upstream"); @@ -105,6 +131,10 @@ impl App for Proxy { let _ = client; let _ = self.connections.disconnect(id, Some(serde_json::json!({ "text": "Received malformed data." }))).await; } + if self.upstream.client_state() == ClientState::Disconnected { + // Start a new connection with the upstream server. + self.upstream = Proxy::connect_upstream(&self.upstream_address).await?; + } Ok(()) }