diff --git a/Cargo.lock b/Cargo.lock index f294530..c111603 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -236,6 +236,7 @@ dependencies = [ "rsa", "serde", "serde_json", + "spki", "thiserror 2.0.12", "tokio", "tokio-util", diff --git a/Cargo.toml b/Cargo.toml index 2d99ce3..9e75262 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -44,3 +44,4 @@ der = { version = "0.7.10", features = ["alloc", "derive"] } aes = "0.8.4" cfb8 = { version = "0.8.1", features = ["alloc"] } generic-array = "0.14.7" +spki = { version = "0.7.3", features = ["std"] } diff --git a/src/net/connection/downstream/mod.rs b/src/net/connection/downstream/mod.rs index 78b6e02..a2ad248 100644 --- a/src/net/connection/downstream/mod.rs +++ b/src/net/connection/downstream/mod.rs @@ -55,17 +55,17 @@ impl DownstreamConnection { pub async fn handle_handshake(&mut self) -> Result<(), Error> { use packets::handshake::serverbound::Handshake; - let handshake = self - .read_specific_packet::() - .await - .ok_or(Error::Unexpected)??; + let handshake = self.read_specific_packet::().await?; match handshake.next_state { ClientState::Status => { *self.client_state_mut() = DownstreamConnectionState::StatusRequest; *self.inner_state_mut() = ClientState::Status; } - ClientState::Login => todo!(), + ClientState::Login => { + *self.client_state_mut() = DownstreamConnectionState::LoginStart; + *self.inner_state_mut() = ClientState::Login; + } _ => { self.disconnect(Some( serde_json::json!({ "text": "Received invalid handshake." }), @@ -79,14 +79,13 @@ impl DownstreamConnection { pub async fn handle_status_ping(&mut self, online_player_count: usize) -> Result<(), Error> { // The state just changed from Handshake to Status. use base64::Engine; - use packets::status::clientbound::{PingResponse, StatusResponse}; + use packets::status::{ + clientbound::{PingResponse, StatusResponse}, + serverbound::{PingRequest, StatusRequest}, + }; // Read the status request packet. - let Packet::StatusRequest(_status_request) = - self.read_packet().await.ok_or(Error::Unexpected)?? - else { - return Err(Error::Unexpected); - }; + let _status_request = self.read_specific_packet::().await?; // Send the status response packet. let config = Config::instance(); @@ -110,19 +109,101 @@ impl DownstreamConnection { }).await?; // Read the ping request packet. - let Packet::PingRequest(ping_request) = - self.read_packet().await.ok_or(Error::Unexpected)?? - else { - return Err(Error::Unexpected); - }; + let payload = self.read_specific_packet::().await?.payload; // Send the ping response packet. - self.send_packet(PingResponse { - payload: ping_request.payload, + self.send_packet(PingResponse { payload }).await?; + + self.disconnect(None).await?; + + Ok(()) + } + pub async fn handle_login(&mut self) -> Result<(), Error> { + // The state just changed from Handshake to Login. + use packets::login::{clientbound::LoginSuccess, serverbound::LoginStart}; + + // Read login start packet. + let login_start = self.read_specific_packet::().await?; + + // Enable encryption and authenticate with Mojang. + self.enable_encryption().await?; + + // Enable compression. + self.enable_compression().await?; + + // Send login success packet. + self.send_packet(LoginSuccess { + // Generate a random UUID if none was provided. + uuid: login_start.uuid.unwrap_or(uuid::Uuid::new_v4()), + username: login_start.name, + properties: vec![], }) .await?; - self.disconnect(None).await + Ok(()) + } + pub async fn enable_encryption(&mut self) -> Result<(), Error> { + use crate::protocol::encryption::*; + use packets::login::{clientbound::EncryptionRequest, serverbound::EncryptionResponse}; + use rand::{rngs::StdRng, Rng, SeedableRng}; + + assert!(matches!(self.inner_state(), ClientState::Login)); + + // RSA keys were generated on startup. + let config = Config::instance(); + let (public_key, private_key) = &config.rsa_key_pair; + tracing::trace!( + "{}", + public_key + .serialize() + .iter() + .map(|b| format!("{b:02X?}")) + .collect::>() + .join("") + ); + + // Generate a verify token. + let mut rng = StdRng::from_entropy(); + let verify_token: [u8; 16] = rng.gen(); + + // Send the encryption request packet. + self.send_packet(EncryptionRequest { + server_id: "".into(), + public_key: public_key.serialize(), + verify_token: verify_token.to_vec(), + // TODO: Implement Mojang authentication. + use_mojang_authentication: false, + }) + .await?; + + // Read the encryption response packet. + let encryption_response = self.read_specific_packet::().await?; + + // Verify the response. + let decrypted_verify_token = private_key + .decrypt(Pkcs1v15Encrypt, &encryption_response.verify_token) + .expect("failed to decrypt verify token"); + if decrypted_verify_token != verify_token { + return Err(Error::Invalid); + } + + // Decrypt the shared secret. + let shared_secret = private_key + .decrypt(Pkcs1v15Encrypt, &encryption_response.shared_secret) + .expect("failed to decrypt shared secret"); + + // Enable encryption on the connection. + let encryptor = + Aes128Cfb8Encryptor::new((&(*shared_secret)).into(), (&(*shared_secret)).into()); + let decryptor = + Aes128Cfb8Decryptor::new((&(*shared_secret)).into(), (&(*shared_secret)).into()); + self.inner.stream.codec_mut().aes_cipher = Some((encryptor, decryptor, 0)); + + Ok(()) + } + pub async fn enable_compression(&mut self) -> Result<(), Error> { + // TODO: Implement compression. + Ok(()) } pub async fn read_packet(&mut self) -> Option> { self.inner.read_packet().await diff --git a/src/net/connection/mod.rs b/src/net/connection/mod.rs index 8aec84b..6641b61 100644 --- a/src/net/connection/mod.rs +++ b/src/net/connection/mod.rs @@ -78,14 +78,9 @@ impl GenericConnection { packet } - pub async fn read_specific_packet>(&mut self) -> Option> { - self.read_packet() - .await - .map(|packet| match packet.map(P::try_from) { - Ok(Ok(p)) => Ok(p), - Ok(Err(_)) => Err(Error::Unexpected), - Err(e) => Err(e), - }) + pub async fn read_specific_packet>(&mut self) -> Result { + let packet = self.read_packet().await.ok_or(Error::Unexpected)??; + P::try_from(packet).map_err(|_| Error::Unexpected) } pub async fn send_packet>(&mut self, packet: P) -> Result<(), Error> { let packet: Packet = packet.into(); diff --git a/src/net/connection/upstream.rs b/src/net/connection/upstream.rs index 9f74b20..30e76f4 100644 --- a/src/net/connection/upstream.rs +++ b/src/net/connection/upstream.rs @@ -25,6 +25,15 @@ impl UpstreamConnection { match packet { Packet::EncryptionRequest(ref packet) => { // Extract the public key from the packet. + tracing::trace!( + "{}", + packet + .public_key + .iter() + .map(|b| format!("{b:02X?}")) + .collect::>() + .join("") + ); let public_key = rsa::RsaPublicKey::parse(&packet.public_key) .expect("Failed to parse RSA public key from packet") .1; diff --git a/src/net/error.rs b/src/net/error.rs index 6c8bd72..eb60d6c 100644 --- a/src/net/error.rs +++ b/src/net/error.rs @@ -12,6 +12,8 @@ pub enum Error { Unexpected, #[error("Internal channel disconnected")] ConnectionChannelDisconnnection, + #[error("Invalid response")] + Invalid, } impl From for Error { fn from(value: std::io::Error) -> Self { diff --git a/src/protocol/encryption.rs b/src/protocol/encryption.rs index 6ade3c2..cee9f49 100644 --- a/src/protocol/encryption.rs +++ b/src/protocol/encryption.rs @@ -1,7 +1,5 @@ -use der::{ - asn1::{AnyRef, ObjectIdentifier}, - Decode, DecodeValue, Encode, EncodeValue, Header, Reader, Sequence, Tag, -}; +use der::Encode; +use spki::{DecodePublicKey, SubjectPublicKeyInfo}; pub use crate::protocol::parsing::Parsable; pub use aes::cipher::{BlockDecryptMut, BlockEncryptMut, KeyIvInit}; @@ -9,85 +7,19 @@ pub use generic_array::{ typenum::{UInt, UTerm, B1}, GenericArray, }; -pub use rsa::{RsaPrivateKey, RsaPublicKey}; +pub use rsa::{Pkcs1v15Encrypt, RsaPrivateKey, RsaPublicKey}; pub type Aes128Cfb8Encryptor = cfb8::Encryptor; pub type Aes128Cfb8Decryptor = cfb8::Decryptor; pub type GenericCFB8BlockArray = GenericArray>; impl Parsable for RsaPublicKey { fn parse(data: &[u8]) -> nom::IResult<&[u8], Self> { - let spki = SubjectPublicKeyInfo::from_der(data).unwrap(); - - let modulus = rsa::BigUint::from_bytes_be(spki.subject_public_key.modulus.as_bytes()); - let exponent = - rsa::BigUint::from_bytes_be(spki.subject_public_key.public_exponent.as_bytes()); - - Ok((&[], RsaPublicKey::new(modulus, exponent).unwrap())) + Ok((&[], RsaPublicKey::from_public_key_der(data).unwrap())) } fn serialize(&self) -> Vec { - use rsa::traits::PublicKeyParts; - let algorithm = PublicKeyAlgorithm::default(); - let subject_public_key = SubjectPublicKey { - modulus: der::asn1::Int::new(&self.n().to_bytes_be()).unwrap(), - public_exponent: der::asn1::Int::new(&self.e().to_bytes_be()).unwrap(), - }; - let spki = SubjectPublicKeyInfo { - algorithm, - subject_public_key, - }; - let mut buf = Vec::new(); - spki.encode(&mut buf).unwrap(); - buf + SubjectPublicKeyInfo::from_key(self.clone()) + .unwrap() + .to_der() + .unwrap() } } - -// Custom decode implementation for SubjectPublicKeyInfo. -#[derive(Debug, Clone, PartialEq, Eq)] -struct SubjectPublicKeyInfo<'a> { - algorithm: PublicKeyAlgorithm<'a>, - subject_public_key: SubjectPublicKey, -} -impl<'a> DecodeValue<'a> for SubjectPublicKeyInfo<'a> { - fn decode_value>(reader: &mut R, _header: Header) -> der::Result { - let algorithm = reader.decode()?; - let spk_der: der::asn1::BitString = reader.decode()?; - let spk_der = spk_der.as_bytes().unwrap(); - let subject_public_key = SubjectPublicKey::from_der(spk_der).unwrap(); - - Ok(Self { - algorithm, - subject_public_key, - }) - } -} -impl EncodeValue for SubjectPublicKeyInfo<'_> { - fn value_len(&self) -> der::Result { - self.algorithm.value_len()? + self.subject_public_key.value_len()? - } - fn encode_value(&self, writer: &mut impl der::Writer) -> der::Result<()> { - self.algorithm.encode_value(writer)?; - self.subject_public_key.encode_value(writer)?; - Ok(()) - } -} -impl<'a> Sequence<'a> for SubjectPublicKeyInfo<'a> {} - -#[derive(Debug, Clone, PartialEq, Eq, Sequence)] -struct PublicKeyAlgorithm<'a> { - pub algorithm: ObjectIdentifier, - pub parameters: Option>, -} -impl Default for PublicKeyAlgorithm<'_> { - fn default() -> Self { - Self { - algorithm: ObjectIdentifier::new_unwrap("1.2.840.113549.1.1.1"), - parameters: Some(AnyRef::new(Tag::Null, &[]).unwrap()), - } - } -} - -#[derive(Debug, Clone, PartialEq, Eq, Sequence)] -struct SubjectPublicKey { - pub modulus: der::asn1::Int, - pub public_exponent: der::asn1::Int, -} diff --git a/src/server/error.rs b/src/server/error.rs index 07610ff..e66ac07 100644 --- a/src/server/error.rs +++ b/src/server/error.rs @@ -7,14 +7,9 @@ pub use tokio::task::JoinError as TaskError; #[derive(thiserror::Error, Debug)] pub enum Error { #[error(transparent)] - Io(IoError), + Io(#[from] IoError), #[error(transparent)] - Task(TaskError), + Task(#[from] TaskError), #[error(transparent)] - Network(NetworkError), -} -impl From for Error { - fn from(err: NetworkError) -> Self { - Error::Network(err) - } + Network(#[from] NetworkError), } diff --git a/src/server/mod.rs b/src/server/mod.rs index d5c3b7a..ad05ef8 100644 --- a/src/server/mod.rs +++ b/src/server/mod.rs @@ -16,7 +16,7 @@ use tokio_util::sync::CancellationToken; #[derive(Debug)] pub struct Server { running: CancellationToken, - connections: DownstreamConnectionManager, + pub connections: DownstreamConnectionManager, listener: JoinHandle<()>, } #[async_trait::async_trait] @@ -80,6 +80,14 @@ impl App for Server { .await; // Handle login connections. + let _ = futures::future::join_all( + self.connections + .clients_mut() + .filter(|c| matches!(c.client_state(), DownstreamConnectionState::LoginStart)) + .map(|c| c.handle_login()), + ) + .await; + // Handle play connection packets. // Process world updates. // Send out play connection updates.