{-# LANGUAGE OverloadedStrings #-}

module Network.QUIC.Packet.Decode (
    decodePacket
  , decodePackets
  , decodeCryptPackets
  , decodeStatelessResetToken
  ) where

import qualified Data.ByteString as BS
import qualified Data.ByteString.Short as Short
import qualified UnliftIO.Exception as E

import Network.QUIC.Imports
import Network.QUIC.Packet.Header
import Network.QUIC.Types

----------------------------------------------------------------

-- Server uses this.
-- | Decode every packet in a datagram and keep only the protected ones.
--
-- Each result is the protected packet, the encryption level chosen from its
-- header, and the number of bytes it occupied. Entries that are not
-- 'PacketIC' (version negotiation, retry, broken) are dropped.
decodeCryptPackets :: ByteString -> IO [(CryptPacket, EncryptionLevel, Int)]
decodeCryptPackets bs0 = do
    pkts <- decodePackets bs0
    return [ (cpkt, lvl, siz) | PacketIC cpkt lvl siz <- pkts ]

-- Client uses this.
-- | Split a datagram into its (possibly coalesced) packets, front to back.
--
-- Decoding continues until no bytes remain. 'decodePacket' reports an empty
-- remainder after a broken packet, so a parse failure ends the list with
-- @PacketIB BrokenPacket@.
decodePackets :: ByteString -> IO [PacketI]
decodePackets bs
    | BS.null bs = return [] -- fixme (original note; what to fix is not recorded)
    | otherwise  = do
        (pkt, rest) <- decodePacket bs
        (pkt :) <$> decodePackets rest

-- | Decode the first packet in a buffer.
--
-- Returns the decoded packet and the bytes that follow it (the next coalesced
-- packet, or empty). For protected packets, the 'Int' stored in 'PacketIC' is
-- the number of bytes this packet occupied.
--
-- If parsing runs past the end of the buffer, the 'BufferOverrun' is caught
-- and @(PacketIB BrokenPacket, "")@ is returned. A caller looping over the
-- remainder therefore stops there.
decodePacket :: ByteString -> IO (PacketI, ByteString)
decodePacket bs = E.handle onOverrun $ withReadBuffer bs $ \rbuf -> do
    save rbuf
    pflags <- Flags <$> read8 rbuf
    packet <- if isShort pflags
                then decodeShort rbuf
                else decodeLong rbuf pflags
    consumed <- savingSize rbuf
    return (packet, BS.drop consumed bs)
  where
    -- Short header: our own CID follows the first byte; the rest is protected.
    decodeShort rbuf = do
        cid <- makeCID <$> extractShortByteString rbuf myCIDLength
        crypt <- makeShortCrypt bs rbuf
        finishProtected rbuf RTT1Level (Short cid) crypt

    -- Long header: version and both CIDs, then a type-specific body.
    decodeLong rbuf pflags = do
        (ver, dCID, sCID) <- decodeLongHeader rbuf
        case ver of
          Negotiation -> decodeVersionNegotiationPacket rbuf dCID sCID
          _ -> case decodeLongHeaderPacketType ver pflags of
            RetryPacketType ->
                decodeRetryPacket rbuf pflags ver dCID sCID
            RTT0PacketType -> do
                crypt <- makeLongCrypt bs rbuf
                finishProtected rbuf RTT0Level (RTT0 ver dCID sCID) crypt
            InitialPacketType -> do
                -- Initial packets carry a length-prefixed token here.
                tokenLen <- fromIntegral <$> decodeInt' rbuf
                token <- extractByteString rbuf tokenLen
                crypt <- makeLongCrypt bs rbuf
                finishProtected rbuf InitialLevel (Initial ver dCID sCID token) crypt
            HandshakePacketType -> do
                crypt <- makeLongCrypt bs rbuf
                finishProtected rbuf HandshakeLevel (Handshake ver dCID sCID) crypt

    -- Pair a header with its protected part and record the bytes consumed so far.
    finishProtected rbuf level header crypt = do
        consumed <- savingSize rbuf
        return $ PacketIC (CryptPacket header crypt) level consumed

    -- A truncated packet becomes a broken-packet marker with no remainder.
    onOverrun :: BufferOverrun -> IO (PacketI, ByteString)
    onOverrun BufferOverrun = return (PacketIB BrokenPacket, "")

-- | Capture the protected part of a short-header packet.
--
-- The packet is taken to extend to the end of the buffer, so the read
-- position is fast-forwarded over everything that remains. The 'Crypt'
-- records the offset where the protected part starts and refers to the whole
-- input @bs@. Its last two fields are initialised to @0@ and 'Nothing'.
makeShortCrypt :: ByteString -> ReadBuffer -> IO Crypt
makeShortCrypt bs rbuf = do
    offset <- savingSize rbuf
    remaining <- remainingSize rbuf
    ff rbuf remaining
    return $ Crypt offset bs 0 Nothing

-- | Capture the protected part of a long-header packet.
--
-- Reads the Length field and records the offset just after it, where the
-- protected part starts. It then skips @len@ bytes and returns a 'Crypt' over
-- the packet's bytes (@bs@ truncated at offset + @len@). The last two fields
-- are initialised to @0@ and 'Nothing'.
--
-- The Length field comes from the wire and 'ff' does not bounds-check the
-- read position. A Length that exceeds the bytes actually remaining is
-- therefore rejected with 'BufferOverrun'. Without this check the packet
-- would be silently truncated; with it, 'decodePacket' reports a broken
-- packet.
makeLongCrypt :: ByteString -> ReadBuffer -> IO Crypt
makeLongCrypt bs rbuf = do
    len <- fromIntegral <$> decodeInt' rbuf
    remaining <- remainingSize rbuf
    if len > remaining
      then E.throwIO BufferOverrun
      else do
        here <- savingSize rbuf
        ff rbuf len
        let pkt = BS.take (here + len) bs
        return $ Crypt here pkt 0 Nothing

----------------------------------------------------------------

-- | Read the common prefix of a long header, after the first byte: a 32-bit
-- version, then the destination and source connection IDs, each preceded by
-- a one-byte length.
decodeLongHeader :: ReadBuffer -> IO (Version, CID, CID)
decodeLongHeader rbuf =
    (,,) <$> (Version <$> read32 rbuf) <*> lengthPrefixedCID <*> lengthPrefixedCID
  where
    -- One length byte followed by that many CID bytes.
    lengthPrefixedCID = do
        n <- fromIntegral <$> read8 rbuf
        makeCID <$> extractShortByteString rbuf n

-- | Decode the body of a Version Negotiation packet.
--
-- The rest of the buffer is read as a sequence of 32-bit versions, in order.
-- Trailing bytes that do not form a whole 4-byte entry are left unread.
decodeVersionNegotiationPacket :: ReadBuffer -> CID -> CID -> IO PacketI
decodeVersionNegotiationPacket rbuf dCID sCID = do
    vers <- remainingSize rbuf >>= readVersions
    return $ PacketIV $ VersionNegotiationPacket dCID sCID vers
  where
    -- Read one version for every 4 bytes still available.
    readVersions :: Int -> IO [Version]
    readVersions left
      | left >= 4 = do
          v <- Version <$> read32 rbuf
          (v :) <$> readVersions (left - 4)
      | otherwise = return []

-- | Decode the body of a Retry packet.
--
-- Everything after the long header except the trailing 16-byte tag is the
-- token. The bytes from the packet's saved start up to the tag are returned
-- as @pseudo@ together with the tag. They are presumably used later to
-- verify the integrity tag; confirm at the caller.
--
-- A body shorter than the tag cannot be a valid Retry packet. It is rejected
-- with 'BufferOverrun', which 'decodePacket' turns into a broken packet,
-- instead of passing a negative length to 'extractByteString'.
decodeRetryPacket :: ReadBuffer -> Flags Protected -> Version -> CID -> CID -> IO PacketI
decodeRetryPacket rbuf _proFlags version dCID sCID = do
    rsiz <- remainingSize rbuf
    if rsiz < tagLen
      then E.throwIO BufferOverrun
      else do
        token <- extractByteString rbuf (rsiz - tagLen)
        siz <- savingSize rbuf
        -- A negative length extracts the 'siz' bytes before the current
        -- position (network-byte-order semantics; confirm), i.e. the packet
        -- from its saved start up to the tag.
        pseudo <- extractByteString rbuf $ negate siz
        tag <- extractByteString rbuf tagLen
        return $ PacketIR $ RetryPacket version dCID sCID token (Right (pseudo, tag))
  where
    -- Length of the Retry integrity tag at the end of the packet.
    tagLen :: Int
    tagLen = 16

----------------------------------------------------------------

-- | Extract a candidate stateless reset token from a received datagram.
--
-- Datagrams shorter than 21 bytes yield 'Nothing'. Otherwise the candidate
-- token is the trailing 16 bytes of the datagram.
decodeStatelessResetToken :: ByteString -> Maybe StatelessResetToken
decodeStatelessResetToken bs
  | BS.length bs < 21 = Nothing
  | otherwise         = Just (StatelessResetToken (Short.toShort trailer))
  where
    trailer = BS.drop (BS.length bs - 16) bs