use std::{
    net::{AddrParseError, Ipv4Addr, Ipv6Addr, SocketAddr},
    num::ParseIntError,
    str::Utf8Error,
};
use bytes::Buf;
use thiserror::Error;
/// Connection information carried by a PROXY protocol v1 (text) header line.
///
/// Values are produced by `ProxyProtocolV1Info::parse`; the variant reflects
/// the protocol keyword found right after the `PROXY` preamble.
#[derive(Debug, Clone)]
pub enum ProxyProtocolV1Info {
    /// The header declared `TCP4` or `TCP6`.
    Tcp {
        /// Address and port parsed from the first address field and the
        /// first port field of the header.
        source: SocketAddr,
        /// Address and port parsed from the second address field and the
        /// second port field of the header.
        destination: SocketAddr,
    },
    /// The header declared `UDP4` or `UDP6`.
    Udp {
        /// Address and port parsed from the first address field and the
        /// first port field of the header.
        source: SocketAddr,
        /// Address and port parsed from the second address field and the
        /// second port field of the header.
        destination: SocketAddr,
    },
    /// The header declared `UNKNOWN`; no addresses are recorded and anything
    /// else on the line is ignored by the parser.
    Unknown,
}
/// Errors returned when parsing a PROXY protocol v1 header.
///
/// Variants without their own `#[error(...)]` attribute use the enum-level
/// message below as their `Display` output.
#[derive(Error, Debug)]
#[error("Invalid proxy protocol header")]
pub enum ParseError {
    /// The buffer is too short to decide yet: either it holds fewer than 15
    /// bytes, or no CRLF was found and it holds fewer than 108 bytes.
    #[error("Not enough bytes provided")]
    NotEnoughBytes,
    /// No CRLF was found within the searched prefix of a buffer holding at
    /// least 108 bytes.
    NoCrLf,
    /// The first space-separated field of the line is not `PROXY`.
    NoProxyPreamble,
    /// The line ends right after the `PROXY` preamble.
    NoProtocol,
    /// The protocol field is none of `TCP4`, `TCP6`, `UDP4`, `UDP6` or
    /// `UNKNOWN`.
    InvalidProtocol,
    /// The source address field is missing.
    NoSourceAddress,
    /// The destination address field is missing.
    NoDestinationAddress,
    /// The source port field is missing.
    NoSourcePort,
    /// The destination port field is missing.
    NoDestinationPort,
    /// The line carries more fields than the protocol expects.
    TooManyFields,
    /// An address or port field is not valid UTF-8.
    InvalidUtf8(#[from] Utf8Error),
    /// An address field could not be parsed as an IP address of the family
    /// selected by the protocol field.
    InvalidAddress(#[from] AddrParseError),
    /// A port field could not be parsed as a `u16`.
    InvalidPort(#[from] ParseIntError),
}
impl ParseError {
    /// Reports whether this error is [`ParseError::NotEnoughBytes`], the
    /// variant the parser returns when the buffer is still too short to tell
    /// whether it holds a complete header line.
    pub const fn not_enough_bytes(&self) -> bool {
        matches!(self, Self::NotEnoughBytes)
    }
}
impl ProxyProtocolV1Info {
    /// Parses a PROXY protocol v1 (text) header line from the front of `buf`.
    ///
    /// On success the header line, including its terminating CRLF, is
    /// consumed from `buf` and whatever follows is left in place. On failure
    /// `buf` is not advanced.
    ///
    /// For `UNKNOWN` headers everything between the keyword and the CRLF is
    /// ignored.
    ///
    /// # Errors
    ///
    /// * [`ParseError::NotEnoughBytes`] if `buf` holds fewer than 15 bytes, or
    ///   holds fewer than 108 bytes and no CRLF was found. Check with
    ///   [`ParseError::not_enough_bytes`].
    /// * [`ParseError::NoCrLf`] if `buf` holds at least 108 bytes and no CRLF
    ///   was found in the searched prefix.
    /// * Any other [`ParseError`] variant if the line itself is malformed.
    pub(super) fn parse<B>(buf: &mut B) -> Result<Self, ParseError>
    where
        B: Buf + AsRef<[u8]>,
    {
        use ParseError as E;

        // `PROXY UNKNOWN\r\n`, the shortest line this parser accepts, is
        // 15 bytes long; anything shorter cannot be a complete header.
        if buf.remaining() < 15 {
            return Err(E::NotEnoughBytes);
        }

        // Look for the CRLF terminator, but only within the first 108
        // two-byte windows so that an endless line is rejected early.
        let Some(crlf) = buf
            .as_ref()
            .windows(2)
            .take(108)
            .position(|needle| needle == [0x0D, 0x0A])
        else {
            // A short buffer may simply not have received the CRLF yet.
            return if buf.remaining() < 108 {
                Err(E::NotEnoughBytes)
            } else {
                Err(E::NoCrLf)
            };
        };

        let bytes = &buf.as_ref()[..crlf];

        // A well-formed line has at most six fields: preamble, protocol, two
        // addresses and two ports. Splitting into up to seven pieces keeps
        // any surplus in a separate trailing piece, so it is reported as
        // `TooManyFields` instead of being glued onto the destination port.
        let mut it = bytes.splitn(7, |c| c == &b' ');

        if it.next() != Some(b"PROXY") {
            return Err(E::NoProxyPreamble);
        }

        let result = match it.next() {
            Some(b"TCP4") => {
                let (source, destination) = Self::parse_endpoints::<Ipv4Addr, _>(&mut it)?;
                Self::Tcp {
                    source,
                    destination,
                }
            }
            Some(b"TCP6") => {
                let (source, destination) = Self::parse_endpoints::<Ipv6Addr, _>(&mut it)?;
                Self::Tcp {
                    source,
                    destination,
                }
            }
            Some(b"UDP4") => {
                let (source, destination) = Self::parse_endpoints::<Ipv4Addr, _>(&mut it)?;
                Self::Udp {
                    source,
                    destination,
                }
            }
            Some(b"UDP6") => {
                let (source, destination) = Self::parse_endpoints::<Ipv6Addr, _>(&mut it)?;
                Self::Udp {
                    source,
                    destination,
                }
            }
            // The rest of an UNKNOWN line is deliberately not inspected.
            Some(b"UNKNOWN") => Self::Unknown,
            Some(_) => return Err(E::InvalidProtocol),
            None => return Err(E::NoProtocol),
        };

        // Only consume the header (line + CRLF) once it parsed successfully.
        buf.advance(crlf + 2);

        Ok(result)
    }

    /// Reads `<source addr> <destination addr> <source port> <destination port>`
    /// from `fields`, parsing both addresses as `A`, and checks that no
    /// further field follows.
    fn parse_endpoints<'a, A, I>(fields: &mut I) -> Result<(SocketAddr, SocketAddr), ParseError>
    where
        A: std::str::FromStr<Err = AddrParseError> + Into<std::net::IpAddr>,
        I: Iterator<Item = &'a [u8]>,
    {
        use ParseError as E;

        let source_address: A =
            std::str::from_utf8(fields.next().ok_or(E::NoSourceAddress)?)?.parse()?;
        let destination_address: A =
            std::str::from_utf8(fields.next().ok_or(E::NoDestinationAddress)?)?.parse()?;
        let source_port: u16 =
            std::str::from_utf8(fields.next().ok_or(E::NoSourcePort)?)?.parse()?;
        let destination_port: u16 =
            std::str::from_utf8(fields.next().ok_or(E::NoDestinationPort)?)?.parse()?;

        if fields.next().is_some() {
            return Err(E::TooManyFields);
        }

        Ok((
            SocketAddr::new(source_address.into(), source_port),
            SocketAddr::new(destination_address.into(), destination_port),
        ))
    }

    /// Returns `true` if both the source and the destination address are
    /// IPv4. Always `false` for [`Self::Unknown`].
    #[must_use]
    pub fn is_ipv4(&self) -> bool {
        match self {
            Self::Udp {
                source,
                destination,
            }
            | Self::Tcp {
                source,
                destination,
            } => source.is_ipv4() && destination.is_ipv4(),
            Self::Unknown => false,
        }
    }

    /// Returns `true` if both the source and the destination address are
    /// IPv6. Always `false` for [`Self::Unknown`].
    #[must_use]
    pub fn is_ipv6(&self) -> bool {
        match self {
            Self::Udp {
                source,
                destination,
            }
            | Self::Tcp {
                source,
                destination,
            } => source.is_ipv6() && destination.is_ipv6(),
            Self::Unknown => false,
        }
    }

    /// Returns `true` if this is the [`Self::Tcp`] variant.
    #[must_use]
    pub const fn is_tcp(&self) -> bool {
        matches!(self, Self::Tcp { .. })
    }

    /// Returns `true` if this is the [`Self::Udp`] variant.
    #[must_use]
    pub const fn is_udp(&self) -> bool {
        matches!(self, Self::Udp { .. })
    }

    /// Returns `true` if this is the [`Self::Unknown`] variant.
    #[must_use]
    pub const fn is_unknown(&self) -> bool {
        matches!(self, Self::Unknown)
    }

    /// Returns the source address, or `None` for [`Self::Unknown`].
    #[must_use]
    pub const fn source(&self) -> Option<&SocketAddr> {
        match self {
            Self::Udp { source, .. } | Self::Tcp { source, .. } => Some(source),
            Self::Unknown => None,
        }
    }

    /// Returns the destination address, or `None` for [`Self::Unknown`].
    #[must_use]
    pub const fn destination(&self) -> Option<&SocketAddr> {
        match self {
            Self::Udp { destination, .. } | Self::Tcp { destination, .. } => Some(destination),
            Self::Unknown => None,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `input`, requiring success and that exactly the header line
    /// was consumed, leaving the `hello world` payload behind.
    fn parse_header(mut input: &[u8]) -> ProxyProtocolV1Info {
        let info = ProxyProtocolV1Info::parse(&mut input).unwrap();
        assert_eq!(input, b"hello world");
        info
    }

    /// Compares every classification accessor of `info` against `expected`,
    /// given in the order `[tcp, udp, unknown, ipv4, ipv6]`.
    fn assert_flags(info: &ProxyProtocolV1Info, expected: [bool; 5]) {
        let [tcp, udp, unknown, ipv4, ipv6] = expected;
        assert_eq!(info.is_tcp(), tcp);
        assert_eq!(info.is_udp(), udp);
        assert_eq!(info.is_unknown(), unknown);
        assert_eq!(info.is_ipv4(), ipv4);
        assert_eq!(info.is_ipv6(), ipv6);
    }

    #[test]
    fn test_parse() {
        // Longest possible TCP4 line.
        let tcp4 = parse_header(
            b"PROXY TCP4 255.255.255.255 255.255.255.255 65535 65535\r\nhello world",
        );
        assert_flags(&tcp4, [true, false, false, true, false]);

        // Longest possible TCP6 line.
        let tcp6 = parse_header(
            b"PROXY TCP6 ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff 65535 65535\r\nhello world",
        );
        assert_flags(&tcp6, [true, false, false, false, true]);

        // Bare UNKNOWN line.
        let unknown = parse_header(b"PROXY UNKNOWN\r\nhello world");
        assert_flags(&unknown, [false, false, true, false, false]);

        // UNKNOWN line with trailing fields, which must be ignored.
        let unknown_with_fields = parse_header(
            b"PROXY UNKNOWN ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff 65535 65535\r\nhello world",
        );
        assert_flags(&unknown_with_fields, [false, false, true, false, false]);
    }
}