use std::{
    pin::Pin,
    sync::Arc,
    task::{Context, Poll},
};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio_rustls::{
    rustls::{
        pki_types::CertificateDer, ProtocolVersion, ServerConfig, ServerConnection,
        SupportedCipherSuite,
    },
    TlsAcceptor,
};
/// Owned snapshot of the TLS parameters of a server-side connection.
///
/// Instances are produced by [`MaybeTlsStream::tls_info`], which copies every
/// value out of the underlying `ServerConnection`.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct TlsStreamInfo {
    /// TLS protocol version reported by the connection.
    pub protocol_version: ProtocolVersion,
    /// Cipher suite negotiated for the connection.
    pub negotiated_cipher_suite: SupportedCipherSuite,
    /// Server name reported by the connection, if any.
    pub sni_hostname: Option<String>,
    /// ALPN protocol reported by the connection, as raw bytes, if any.
    pub alpn_protocol: Option<Vec<u8>>,
    /// Peer certificates reported by the connection, if any.
    pub peer_certificates: Option<Vec<CertificateDer<'static>>>,
}
impl TlsStreamInfo {
    /// Returns `true` when the recorded ALPN protocol is exactly `h2`.
    ///
    /// Returns `false` when no ALPN protocol was recorded or when it is any
    /// other byte string.
    #[must_use]
    pub fn is_alpn_h2(&self) -> bool {
        const H2: &[u8] = b"h2";
        self.alpn_protocol.as_deref() == Some(H2)
    }
}
pin_project_lite::pin_project! {
    #[project = MaybeTlsStreamProj]
    /// A stream that is either wrapped in a server-side TLS session or left
    /// as the bare underlying transport.
    ///
    /// Both variants pin their inner stream so the I/O trait implementations
    /// can forward to it through the generated `MaybeTlsStreamProj` projection.
    pub enum MaybeTlsStream<T> {
        /// The transport wrapped in a `tokio_rustls` server TLS stream.
        Secure {
            #[pin]
            stream: tokio_rustls::server::TlsStream<T>
        },
        /// The transport used directly, without TLS.
        Insecure {
            #[pin]
            stream: T,
        },
    }
}
impl<T> MaybeTlsStream<T> {
    /// Borrows the underlying transport, whether or not it is wrapped in TLS.
    pub fn get_ref(&self) -> &T {
        match self {
            Self::Insecure { stream: plain } => plain,
            Self::Secure { stream: tls } => {
                let (io, _conn) = tls.get_ref();
                io
            }
        }
    }

    /// Borrows the TLS server connection state.
    ///
    /// Returns `None` for the [`MaybeTlsStream::Insecure`] variant.
    pub fn get_tls_connection(&self) -> Option<&ServerConnection> {
        if let Self::Secure { stream: tls } = self {
            let (_io, conn) = tls.get_ref();
            Some(conn)
        } else {
            None
        }
    }

    /// Collects an owned [`TlsStreamInfo`] snapshot from the TLS connection.
    ///
    /// Returns `None` for the [`MaybeTlsStream::Insecure`] variant.
    ///
    /// # Panics
    ///
    /// Panics if the connection does not report a protocol version or a
    /// negotiated cipher suite, i.e. if the TLS handshake is not done yet.
    pub fn tls_info(&self) -> Option<TlsStreamInfo> {
        self.get_tls_connection().map(|conn| TlsStreamInfo {
            protocol_version: conn
                .protocol_version()
                .expect("TLS handshake is not done yet"),
            negotiated_cipher_suite: conn
                .negotiated_cipher_suite()
                .expect("TLS handshake is not done yet"),
            sni_hostname: conn.server_name().map(String::from),
            alpn_protocol: conn.alpn_protocol().map(<[u8]>::to_vec),
            // Clone each certificate and detach it from any borrow so the
            // snapshot owns its data.
            peer_certificates: conn.peer_certificates().map(|chain| {
                chain
                    .iter()
                    .map(|cert| cert.clone().into_owned())
                    .collect()
            }),
        })
    }
}
/// Reads are forwarded to whichever inner stream is active.
impl<T> AsyncRead for MaybeTlsStream<T>
where
    T: AsyncRead + AsyncWrite + Unpin,
{
    /// Polls the active inner stream (TLS or plain) for data into `buf`.
    ///
    /// The lifetime parameters of `Context` and `ReadBuf` are spelled out as
    /// `'_`, matching the `AsyncWrite` implementation and avoiding hidden
    /// lifetime parameters in paths.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<std::io::Result<()>> {
        match self.project() {
            MaybeTlsStreamProj::Secure { stream: tls } => tls.poll_read(cx, buf),
            MaybeTlsStreamProj::Insecure { stream: plain } => plain.poll_read(cx, buf),
        }
    }
}
impl<T> AsyncWrite for MaybeTlsStream<T>
where
    T: AsyncRead + AsyncWrite + Unpin,
{
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<std::io::Result<usize>> {
        match self.project() {
            MaybeTlsStreamProj::Secure { stream } => stream.poll_write(cx, buf),
            MaybeTlsStreamProj::Insecure { stream } => stream.poll_write(cx, buf),
        }
    }
    fn poll_write_vectored(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        bufs: &[std::io::IoSlice<'_>],
    ) -> Poll<Result<usize, std::io::Error>> {
        match self.project() {
            MaybeTlsStreamProj::Secure { stream } => stream.poll_write_vectored(cx, bufs),
            MaybeTlsStreamProj::Insecure { stream } => stream.poll_write_vectored(cx, bufs),
        }
    }
    fn is_write_vectored(&self) -> bool {
        match self {
            Self::Secure { stream } => stream.is_write_vectored(),
            Self::Insecure { stream } => stream.is_write_vectored(),
        }
    }
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
        match self.project() {
            MaybeTlsStreamProj::Secure { stream } => stream.poll_flush(cx),
            MaybeTlsStreamProj::Insecure { stream } => stream.poll_flush(cx),
        }
    }
    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
        match self.project() {
            MaybeTlsStreamProj::Secure { stream } => stream.poll_shutdown(cx),
            MaybeTlsStreamProj::Insecure { stream } => stream.poll_shutdown(cx),
        }
    }
}
/// Acceptor that wraps incoming streams in TLS when a server configuration is
/// present, and passes them through unchanged otherwise.
#[derive(Clone)]
pub struct MaybeTlsAcceptor {
    // `Some` enables TLS with this configuration; `None` means plaintext.
    tls_config: Option<Arc<ServerConfig>>,
}
impl MaybeTlsAcceptor {
    /// Creates an acceptor from an optional TLS server configuration.
    ///
    /// `Some(config)` yields a TLS acceptor; `None` yields a plaintext one.
    #[must_use]
    pub fn new(tls_config: Option<Arc<ServerConfig>>) -> Self {
        Self { tls_config }
    }

    /// Creates an acceptor that wraps every accepted stream in TLS using
    /// `tls_config`.
    #[must_use]
    pub fn new_secure(tls_config: Arc<ServerConfig>) -> Self {
        Self::new(Some(tls_config))
    }

    /// Creates an acceptor that leaves accepted streams unencrypted.
    #[must_use]
    pub fn new_insecure() -> Self {
        Self::new(None)
    }

    /// Returns `true` when this acceptor holds a TLS configuration.
    #[must_use]
    pub const fn is_secure(&self) -> bool {
        self.tls_config.is_some()
    }

    /// Accepts `stream`, wrapping it in TLS when a configuration is present.
    ///
    /// Without a configuration the stream is returned immediately as
    /// [`MaybeTlsStream::Insecure`].
    ///
    /// # Errors
    ///
    /// Returns the I/O error produced by the TLS accept future, if any.
    pub async fn accept<T>(&self, stream: T) -> Result<MaybeTlsStream<T>, std::io::Error>
    where
        T: AsyncRead + AsyncWrite + Unpin,
    {
        if let Some(config) = self.tls_config.as_ref() {
            let tls_acceptor = TlsAcceptor::from(Arc::clone(config));
            let tls_stream = tls_acceptor.accept(stream).await?;
            return Ok(MaybeTlsStream::Secure { stream: tls_stream });
        }

        Ok(MaybeTlsStream::Insecure { stream })
    }
}