{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards #-}

module Network.QUIC.Packet.Decrypt (
    decryptCrypt
  ) where

import qualified Data.ByteString as BS
import qualified Data.ByteString.Internal as BS
import Foreign.Ptr
import Network.ByteOrder

import Network.QUIC.Connection
import Network.QUIC.Crypto
import Network.QUIC.Imports
import Network.QUIC.Packet.Frame
import Network.QUIC.Packet.Header
import Network.QUIC.Packet.Number
import Network.QUIC.Types

----------------------------------------------------------------

-- | Remove header protection from a protected packet, decrypt its payload
--   into @decBuf@ and decode the frames found there.
--
--   Steps, in order:
--
--   1. Take a sample of 'sampleLength' bytes starting 4 bytes past
--      'cryptPktNumOffset' and turn it into a mask with 'unprotect'.
--   2. Use the first mask byte to recover the raw flags and the remaining
--      mask bytes to recover the encoded packet number.
--   3. Rebuild the unprotected header (used as associated data) and run
--      'decrypt' on the rest of the packet.
--   4. Decode frames from the decrypted bytes and record anomalies
--      (illegal reserved bits, unknown frame, no frames) in the marks.
--
--   Returns 'Nothing' when the packet is too short to contain a complete
--   sample, when the mask is empty, or when 'decrypt' reports a negative
--   size.  Otherwise returns a 'Plain'; if frame decoding fails the frame
--   list is empty and the marks carry 'setUnknownFrame'.
--
--   The 'BufferSize' argument is currently ignored.
decryptCrypt :: Connection -> Buffer -> BufferSize -> Crypt -> EncryptionLevel -> IO (Maybe Plain)
decryptCrypt conn decBuf _bufsiz Crypt{..} lvl = do -- fixme: bufsiz is not used
    cipher <- getCipher conn lvl
    protector <- getProtector conn lvl
    let proFlags = Flags (cryptPacket `BS.index` 0)
        sampleOffset = cryptPktNumOffset + 4
        sampleLen = sampleLength cipher
        sample = Sample $ BS.take sampleLen $ BS.drop sampleOffset cryptPacket
        makeMask = unprotect protector
        Mask mask = makeMask sample
        -- RFC 9001 section 5.4.2: a packet that is not long enough to
        -- contain a complete sample must be discarded.  Thanks to
        -- laziness neither the sample nor the mask is computed for such
        -- a packet, and the first byte is never indexed.
        maskParts
          | BS.length cryptPacket < sampleOffset + sampleLen = Nothing
          | otherwise = BS.uncons mask
    case maskParts of
      Nothing -> return Nothing
      Just (mask1,mask2) -> do
        let rawFlags@(Flags flags) = unprotectFlags proFlags mask1
            epnLen = decodePktNumLength rawFlags
            epn = BS.take epnLen $ BS.drop cryptPktNumOffset cryptPacket
            bytePN = bsXOR mask2 epn
            headerLen = cryptPktNumOffset + epnLen
            (proHeader, ciphertext) = BS.splitAt headerLen cryptPacket
            ilen = BS.length ciphertext
        -- Only the 1-RTT level consults the connection for the peer's
        -- packet number; every other level decodes relative to 0.
        peerPN <- if lvl == RTT1Level then getPeerPacketNumber conn else return 0
        let pn = decodePacketNumber peerPN (toEncodedPacketNumber bytePN) epnLen
        -- Unprotected header: the protected header bytes with the first
        -- byte replaced by the raw flags and the packet number field
        -- replaced by its unmasked bytes.
        header <- BS.create headerLen $ \p -> do
            void $ copy p proHeader
            poke8 flags p 0
            void $ copy (p `plusPtr` cryptPktNumOffset) $ BS.take epnLen bytePN
        -- Bit 2 of the raw flags selects the coder at the 1-RTT level.
        let keyPhase | lvl == RTT1Level = flags `testBit` 2
                     | otherwise        = False
        coder <- getCoder conn lvl keyPhase
        siz <- withByteString ciphertext $ \ibuf ->
                   withByteString header $ \abuf -> do
            let ilen' = fromIntegral ilen
                alen' = fromIntegral headerLen
            decrypt coder ibuf ilen' abuf alen' pn decBuf
        -- Any bit set under the level-specific mask is recorded as an
        -- illegal reserved bit rather than failing the packet here.
        let rrMask | lvl == RTT1Level = 0x18
                   | otherwise        = 0x0c
            marks | flags .&. rrMask == 0 = defaultPlainMarks
                  | otherwise             = setIllegalReservedBits defaultPlainMarks
        if siz < 0 then
            return Nothing
          else do
            mframes <- decodeFrames decBuf siz
            case mframes of
              Nothing -> do
                  let marks' = setUnknownFrame marks
                  return $ Just $ Plain rawFlags pn [] marks'
              Just frames -> do
                  let marks' | null frames = setNoFrames marks
                             | otherwise   = marks
                  return $ Just $ Plain rawFlags pn frames marks'

-- | Interpret a 'ByteString' as a big-endian unsigned number: the first
--   byte is the most significant.  An empty input yields 0.
toEncodedPacketNumber :: ByteString -> EncodedPacketNumber
toEncodedPacketNumber = BS.foldl' step 0
  where
    -- Shift the accumulator up by one byte and append the next byte.
    step acc w = acc * 256 + fromIntegral w