Create basic proxy (almost working)

This commit is contained in:
Garen Tyler 2024-12-09 00:17:23 -07:00
parent adf4f37536
commit 7f29ac3011
Signed by: garentyler
SSH Key Fingerprint: SHA256:G4ke7blZMdpWPbkescyZ7IQYE4JAtwpI85YoJdq+S7U
9 changed files with 193 additions and 43 deletions

View File

@ -32,9 +32,13 @@ services:
context: .
dockerfile: Dockerfile
target: dev
command: [ "run -- proxy -u reference" ]
depends_on:
reference:
condition: service_healthy
restart: true
command: [ "run -- proxy -U reference -l trace" ]
ports:
- "25566:25565"
- "25566:25566"
volumes:
- .:/app
- .git:/app/.git

14
reference.bin Normal file
View File

@ -0,0 +1,14 @@
00000000 10 00 ff 05 09 6c 6f 63 61 6c 68 6f 73 74 63 dd .....loc alhostc.
00000010 01 .
00000011 01 00 ..
00000000 8c 01 00 89 01 7b 22 76 65 72 73 69 6f 6e 22 3a .....{"v ersion":
00000010 7b 22 6e 61 6d 65 22 3a 22 31 2e 32 31 2e 34 22 {"name": "1.21.4"
00000020 2c 22 70 72 6f 74 6f 63 6f 6c 22 3a 37 36 39 7d ,"protoc ol":769}
00000030 2c 22 65 6e 66 6f 72 63 65 73 53 65 63 75 72 65 ,"enforc esSecure
00000040 43 68 61 74 22 3a 74 72 75 65 2c 22 64 65 73 63 Chat":tr ue,"desc
00000050 72 69 70 74 69 6f 6e 22 3a 22 41 20 4d 69 6e 65 ription" :"A Mine
00000060 63 72 61 66 74 20 53 65 72 76 65 72 22 2c 22 70 craft Se rver","p
00000070 6c 61 79 65 72 73 22 3a 7b 22 6d 61 78 22 3a 32 layers": {"max":2
00000080 30 2c 22 6f 6e 6c 69 6e 65 22 3a 30 7d 7d 0,"onlin e":0}}
00000013 09 01 00 00 00 00 00 3b 3b 51 .......; ;Q
0000008E 09 01 00 00 00 00 00 3b 3b 51 .......; ;Q

View File

@ -18,7 +18,7 @@ use config::Subcommand;
use once_cell::sync::OnceCell;
use std::time::Instant;
use tokio_util::sync::CancellationToken;
use tracing::info;
use tracing::{info, error};
pub const PROTOCOL_VERSION: i32 = 762;
pub const GAME_VERSION: &str = "1.19.4";
@ -62,8 +62,12 @@ pub(crate) trait App: Sized {
break;
}
r = app.update() => {
if r.is_err() {
break;
match r {
Ok(_) => {},
Err(e) => {
error!("{:?}", e);
break;
}
}
}
}

View File

@ -4,11 +4,12 @@ use crate::protocol::{
types::VarInt,
ClientState,
};
use std::io::{Error, ErrorKind};
use tokio_util::{
bytes::{Buf, BytesMut},
codec::{Decoder, Encoder},
};
use super::error::Error;
use tracing::trace;
#[derive(Clone, Copy, Debug)]
pub struct PacketCodec {
@ -33,7 +34,7 @@ impl Default for PacketCodec {
}
impl Decoder for PacketCodec {
type Item = Packet;
type Error = std::io::Error;
type Error = Error;
fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
match Packet::parse(self.client_state, self.packet_direction, src) {
@ -58,17 +59,21 @@ impl Decoder for PacketCodec {
src.reserve(5);
Ok(None)
}
Err(_) => Err(Error::new(ErrorKind::InvalidData, "Nom parsing error")),
Err(_) => Err(Error::Parsing),
}
}
Err(nom::Err::Error(_)) | Err(nom::Err::Failure(_)) => {
Err(Error::new(ErrorKind::InvalidData, "Nom parsing error"))
Err(nom::Err::Error(e)) => {
trace!("parsing error: {:02X?}", e.input);
Err(Error::Parsing)
}
Err(nom::Err::Failure(_)) => {
Err(Error::Parsing)
}
}
}
}
impl Encoder<Packet> for PacketCodec {
type Error = std::io::Error;
type Error = Error;
fn encode(&mut self, item: Packet, dst: &mut BytesMut) -> Result<(), Self::Error> {
let mut out = vec![];

View File

@ -1,4 +1,4 @@
use super::codec::PacketCodec;
use super::{codec::PacketCodec, error::Error};
use crate::protocol::{
packets::{self, Packet, PacketDirection},
types::Chat,
@ -20,6 +20,7 @@ use tracing::{error, trace};
#[derive(Debug)]
pub struct ConnectionManager {
max_clients: Option<usize>,
clients: HashMap<u128, Connection>,
channel: (
mpsc::UnboundedSender<Connection>,
@ -27,8 +28,9 @@ pub struct ConnectionManager {
),
}
impl ConnectionManager {
pub fn new() -> ConnectionManager {
pub fn new(max_clients: Option<usize>) -> ConnectionManager {
ConnectionManager {
max_clients,
clients: HashMap::new(),
channel: mpsc::unbounded_channel(),
}
@ -39,11 +41,17 @@ impl ConnectionManager {
pub fn client_mut(&mut self, id: u128) -> Option<&mut Connection> {
self.clients.get_mut(&id)
}
pub fn clients(&self) -> impl Iterator<Item = &Connection> {
self.clients.iter().map(|(_id, c)| c)
}
pub fn clients_mut(&mut self) -> impl Iterator<Item = &mut Connection> {
self.clients.iter_mut().map(|(_id, c)| c)
}
pub async fn spawn_listener<A>(
&self,
bind_address: A,
running: CancellationToken,
) -> Result<JoinHandle<()>, std::io::Error>
) -> Result<JoinHandle<()>, Error>
where
A: 'static + ToSocketAddrs + Send + std::fmt::Debug,
{
@ -51,6 +59,7 @@ impl ConnectionManager {
let fmt_addr = format!("{:?}", bind_address);
let listener = TcpListener::bind(bind_address)
.await
.map_err(Error::Io)
.inspect_err(|_| error!("Could not bind to {}.", fmt_addr))?;
let sender = self.channel.0.clone();
@ -81,25 +90,49 @@ impl ConnectionManager {
Ok(join_handle)
}
pub fn update(&mut self) -> Result<(), std::io::Error> {
use std::io::{Error, ErrorKind};
pub async fn update(&mut self) -> Result<(), Error> {
// Receive new clients from the sender.
loop {
match self.channel.1.try_recv() {
Ok(connection) => {
let id = connection.id();
self.clients.insert(id, connection);
}
Err(mpsc::error::TryRecvError::Disconnected) => {
return Err(Error::new(
ErrorKind::BrokenPipe,
"all senders disconnected",
))
match self.max_clients {
Some(max) => {
if self.clients.len() >= max {
let _ = connection.disconnect(None).await;
} else {
self.clients.insert(id, connection);
}
}
None => {
self.clients.insert(id, connection);
},
}
}
Err(mpsc::error::TryRecvError::Disconnected) => return Err(Error::ConnectionChannelDisconnnection),
Err(mpsc::error::TryRecvError::Empty) => break,
};
}
// Disconnect any clients that have timed out.
// We don't actually care if the disconnections succeed,
// the connection is going to be dropped anyway.
let _ = futures::future::join_all({
// Workaround until issue #59618 hash_extract_if gets stabilized.
let ids = self.clients.iter()
.filter_map(|(id, c)| {
if c.received_elapsed() > Duration::from_secs(10) {
Some(*id)
} else {
None
}
})
.collect::<Vec<_>>();
ids.into_iter()
.map(|id| self.clients.remove(&id).unwrap())
.map(|client| client.disconnect(None))
}).await;
// Remove disconnected clients.
self.clients
@ -110,11 +143,11 @@ impl ConnectionManager {
&mut self,
id: u128,
reason: Option<Chat>,
) -> Option<Result<(), std::io::Error>> {
) -> Option<Result<(), Error>> {
let client = self.clients.remove(&id)?;
Some(client.disconnect(reason).await)
}
pub async fn shutdown(mut self, reason: Option<Chat>) -> Result<(), std::io::Error> {
pub async fn shutdown(mut self, reason: Option<Chat>) -> Result<(), Error> {
let reason = reason.unwrap_or(serde_json::json!({
"text": "You have been disconnected!"
}));
@ -128,8 +161,7 @@ impl ConnectionManager {
// We don't actually care if the disconnections succeed,
// the connection is going to be dropped anyway.
let _disconnections: Vec<Result<(), std::io::Error>> =
futures::future::join_all(disconnections).await;
let _disconnections = futures::future::join_all(disconnections).await;
Ok(())
}
@ -172,15 +204,15 @@ impl Connection {
pub fn sent_elapsed(&self) -> Duration {
self.last_sent_data_time.elapsed()
}
pub async fn read_packet(&mut self) -> Option<Result<Packet, std::io::Error>> {
pub async fn read_packet(&mut self) -> Option<Result<Packet, Error>> {
self.last_received_data_time = Instant::now();
self.stream.next().await
}
pub async fn send_packet<P: Into<Packet>>(&mut self, packet: P) -> Result<(), std::io::Error> {
pub async fn send_packet<P: Into<Packet>>(&mut self, packet: P) -> Result<(), Error> {
let packet: Packet = packet.into();
self.stream.send(packet).await
}
pub async fn disconnect(mut self, reason: Option<Chat>) -> Result<(), std::io::Error> {
pub async fn disconnect(mut self, reason: Option<Chat>) -> Result<(), Error> {
trace!("Connection disconnected (id {})", self.id);
use packets::{login::clientbound::LoginDisconnect, play::clientbound::PlayDisconnect};

18
src/net/error.rs Normal file
View File

@ -0,0 +1,18 @@
pub use std::io::Error as IoError;
/// This type represents all possible errors that can occur in the network.
#[allow(dead_code)]
#[derive(thiserror::Error, Debug)]
pub enum Error {
#[error(transparent)]
Io(IoError),
#[error("There was an error parsing data")]
Parsing,
#[error("Internal channel disconnected")]
ConnectionChannelDisconnnection,
}
impl From<std::io::Error> for Error {
fn from(value: std::io::Error) -> Self {
Error::Io(value)
}
}

View File

@ -2,3 +2,4 @@
pub mod codec;
pub mod connection;
pub mod error;

15
src/proxy/error.rs Normal file
View File

@ -0,0 +1,15 @@
pub use std::io::Error as IoError;
pub use tokio::task::JoinError as TaskError;
pub use crate::net::error::Error as NetworkError;
/// This type represents all possible errors that can occur when running the proxy.
#[allow(dead_code)]
#[derive(thiserror::Error, Debug)]
pub enum Error {
#[error(transparent)]
Io(IoError),
#[error(transparent)]
Task(TaskError),
#[error(transparent)]
Network(NetworkError),
}

View File

@ -1,21 +1,26 @@
pub mod config;
pub mod error;
use crate::net::connection::Connection;
use crate::App;
use crate::{config::Config, net::connection::ConnectionManager};
use config::ProxyConfig;
use tokio::net::TcpStream;
use tokio::task::JoinHandle;
use tokio_util::sync::CancellationToken;
use tracing::info;
use tracing::{info, trace, error, debug};
use error::{Error, NetworkError};
#[derive(Debug)]
pub struct Proxy {
running: CancellationToken,
connections: ConnectionManager,
listener: JoinHandle<()>,
upstream: Connection,
}
#[async_trait::async_trait]
impl App for Proxy {
type Error = ();
type Error = Error;
fn startup_message() -> String {
let config = Config::instance();
@ -29,35 +34,87 @@ impl App for Proxy {
async fn new(running: CancellationToken) -> Result<Self, Self::Error> {
let config = Config::instance();
let bind_address = format!("0.0.0.0:{}", config.proxy.port);
let connections = ConnectionManager::new();
// Only allow one client to join at a time.
let connections = ConnectionManager::new(Some(1));
let listener = connections
.spawn_listener(bind_address, running.child_token())
.await
.map_err(|_| ())?;
.map_err(Error::Network)?;
info!(
"Upstream server: {}:{}",
config.proxy.upstream_host, config.proxy.upstream_port
);
let upstream_address = format!("{}:{}", config.proxy.upstream_host, config.proxy.upstream_port);
info!("Upstream server: {}", upstream_address);
let upstream = TcpStream::connect(upstream_address).await.map_err(Error::Io)?;
let upstream = Connection::new_server(0, upstream);
Ok(Proxy {
running,
connections,
listener,
upstream,
})
}
#[tracing::instrument]
async fn update(&mut self) -> Result<(), Self::Error> {
todo!()
let _ = self.connections.update().await.map_err(Error::Network)?;
let Some(client) = self.connections.clients_mut().take(1).next() else {
return Ok(());
};
let mut client_parsing_error = false;
// At the same time, try to read packets from the server and client.
// Forward the packet onto the other.
tokio::select! {
packet = client.read_packet() => {
if let Some(packet) = packet {
match packet {
Ok(packet) => {
trace!("Got packet from client: {:?}", packet);
self.upstream.send_packet(packet).await.map_err(Error::Network)?;
}
Err(NetworkError::Parsing) => {
debug!("Got invalid data from client (id {})", client.id());
client_parsing_error = true;
}
Err(e) => return Err(Error::Network(e)),
}
}
}
packet = self.upstream.read_packet() => {
if let Some(packet) = packet {
match packet {
Ok(packet) => {
trace!("Got packet from upstream: {:?}", packet);
client.send_packet(packet).await.map_err(Error::Network)?;
}
Err(NetworkError::Parsing) => {
error!("Got invalid data from upstream");
return Err(Error::Network(NetworkError::Parsing));
},
Err(e) => return Err(Error::Network(e)),
}
}
}
}
if client_parsing_error {
let id = client.id();
// Drop the &mut Connection
let _ = client;
let _ = self.connections.disconnect(id, Some(serde_json::json!({ "text": "Received malformed data." }))).await;
}
Ok(())
}
#[tracing::instrument]
async fn shutdown(self) -> Result<(), Self::Error> {
// Ensure any child tasks have been shut down.
self.running.cancel();
let _ = self.listener.await.map_err(|_| ())?;
let _ = self.connections.shutdown(None).await.map_err(|_| ())?;
let _ = self.listener.await.map_err(Error::Task)?;
let _ = self.connections.shutdown(None).await.map_err(Error::Network)?;
Ok(())
}