diff --git a/src/protocol/mod.rs b/src/protocol/mod.rs index d2ef210..1d4852b 100644 --- a/src/protocol/mod.rs +++ b/src/protocol/mod.rs @@ -21,6 +21,7 @@ pub mod parsing; pub mod types; pub use error::{Error, Result}; +use types::VarInt; /// Enum representation of the connection's current state. /// @@ -45,3 +46,23 @@ pub enum ClientState { /// The client has disconnected, and the connection struct should be removed. No packets should be sent or received. Disconnected, } +impl parsing::Parsable for ClientState { + fn parse(data: &[u8]) -> nom::IResult<&[u8], Self> + where + Self: Sized, + { + nom::combinator::map_res(VarInt::parse, |next_state: VarInt| match *next_state { + 1 => Ok(ClientState::Status), + 2 => Ok(ClientState::Login), + _ => Err(()), + })(data) + } + fn serialize(&self) -> Vec { + let byte = match &self { + &ClientState::Status => 1, + &ClientState::Login => 2, + _ => 0, + }; + vec![byte] + } +} diff --git a/src/protocol/packets/serverbound/handshake.rs b/src/protocol/packets/serverbound/handshake.rs index 1042744..f1d4102 100644 --- a/src/protocol/packets/serverbound/handshake.rs +++ b/src/protocol/packets/serverbound/handshake.rs @@ -1,5 +1,4 @@ use crate::protocol::{types::VarInt, ClientState}; -use nom::combinator::map_res; #[derive(Clone, Debug, PartialEq)] pub struct SH00Handshake { @@ -17,12 +16,8 @@ crate::protocol::packets::packet!( let (data, protocol_version) = VarInt::parse(data)?; let (data, server_address) = String::parse(data)?; let (data, server_port) = u16::parse(data)?; + let (data, next_state) = ClientState::parse(data)?; // let (data, next_state) = VarInt::parse(data)?; - let (data, next_state) = map_res(VarInt::parse, |next_state: VarInt| match *next_state { - 1 => Ok(ClientState::Status), - 2 => Ok(ClientState::Login), - _ => Err(()), - })(data)?; Ok(( data, @@ -39,14 +34,7 @@ crate::protocol::packets::packet!( output.extend(packet.protocol_version.serialize()); output.extend(packet.server_address.serialize()); output.extend(packet.server_port.serialize()); - output.extend( - VarInt::from(match packet.next_state { - ClientState::Status => 0x01, - ClientState::Login => 0x02, - _ => panic!("invalid SH00Handshake next_state"), - }) - .serialize(), - ); + output.extend(packet.next_state.serialize()); output } ); diff --git a/src/protocol/parsing.rs b/src/protocol/parsing.rs index a9d52a1..c756a0a 100644 --- a/src/protocol/parsing.rs +++ b/src/protocol/parsing.rs @@ -2,7 +2,7 @@ pub use nom::IResult; use nom::{ bytes::streaming::{take, take_while_m_n}, combinator::map_res, - number::streaming as nom_nums, + number::streaming as nom_nums, Parser, }; /// Implementation of the protocol's VarInt type. @@ -154,10 +154,17 @@ impl Parsable for VarInt { fn parse(data: &[u8]) -> IResult<&[u8], Self> { let mut output = 0u32; - let (rest, bytes) = take_while_m_n(1, 5, |byte| byte & 0x80 == 0x80)(data)?; - for (i, &b) in bytes.iter().enumerate() { + // 0-4 bytes with the most significant bit set, + // followed by one with the bit unset. + let start_parser = take_while_m_n(0, 4, |byte| byte & 0x80 == 0x80); + let end_parser = take_while_m_n(1, 1, |byte| byte & 0x80 != 0x80); + let mut parser = start_parser.and(end_parser); + let (rest, (start, end)) = parser.parse(data)?; + + for (i, &b) in start.iter().enumerate() { output |= ((b & 0x7f) as u32) << (7 * i); } + output |= ((end[0] & 0x7f) as u32) << (7 * start.len()); Ok((rest, VarInt(output as i32))) } #[tracing::instrument] @@ -354,6 +361,8 @@ mod tests { for (value, bytes) in get_varints() { assert_eq!(value, *VarInt::parse(&bytes).unwrap().1); } + // Check if the VarInt is too long (>5 bytes). + assert!(VarInt::parse(&[0x80, 0x80, 0x80, 0x80, 0x80, 0x08]).is_err()); } #[test] fn serialize_varint_works() {