{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards #-}
module Network.QUIC.Packet.Decrypt (
decryptCrypt
) where
import qualified Data.ByteString as BS
import qualified Data.ByteString.Internal as BS
import Foreign.Ptr
import Network.ByteOrder
import Network.QUIC.Connection
import Network.QUIC.Crypto
import Network.QUIC.Imports
import Network.QUIC.Packet.Frame
import Network.QUIC.Packet.Header
import Network.QUIC.Packet.Number
import Network.QUIC.Types
-- | Remove header protection from a protected packet, decrypt its payload
--   into @decBuf@ and decode the frames found there.
--
--   Steps, in order:
--
--   1. Take a sample of the packet starting 4 bytes after the packet number
--      offset and turn it into a mask with the level's 'Protector'.
--   2. Unmask the first byte (flags) and the encoded packet number.
--   3. Rebuild the unprotected header, which is passed to 'decrypt' as the
--      associated data.
--   4. Decrypt the rest of the packet into @decBuf@ and run 'decodeFrames'
--      on the result.
--
--   Returns 'Nothing' when the packet is too short to contain a complete
--   sample, when the mask comes back empty, or when 'decrypt' reports a
--   negative size. Otherwise returns a 'Plain' whose marks record illegal
--   reserved bits, an undecodable frame sequence (with an empty frame list),
--   or an empty frame list.
--
--   NOTE(review): the buffer size argument is not consulted here; presumably
--   the caller guarantees @decBuf@ is large enough for the plaintext —
--   confirm against 'decrypt'.
decryptCrypt :: Connection
             -> Buffer
             -> BufferSize
             -> Crypt
             -> EncryptionLevel
             -> IO (Maybe Plain)
decryptCrypt conn decBuf _bufsiz Crypt{..} lvl = do
    cipher    <- getCipher conn lvl
    protector <- getProtector conn lvl
    let sampleOffset = cryptPktNumOffset + 4
        sampleLen    = sampleLength cipher
    -- A packet that cannot supply a complete sample must be discarded
    -- (RFC 9001, Section 5.4.2). Checking first also means the first byte
    -- read below exists, assuming cryptPktNumOffset is non-negative.
    if BS.length cryptPacket < sampleOffset + sampleLen
      then return Nothing
      else do
        let proFlags  = Flags (cryptPacket `BS.index` 0)
            sample    = Sample $ BS.take sampleLen $ BS.drop sampleOffset cryptPacket
            Mask mask = unprotect protector sample
        case BS.uncons mask of
          -- An empty mask leaves nothing to unmask the header with.
          Nothing -> return Nothing
          Just (mask1, mask2) -> do
            let rawFlags@(Flags flags) = unprotectFlags proFlags mask1
                epnLen    = decodePktNumLength rawFlags
                epn       = BS.take epnLen $ BS.drop cryptPktNumOffset cryptPacket
                bytePN    = bsXOR mask2 epn
                headerLen = cryptPktNumOffset + epnLen
                (proHeader, ciphertext) = BS.splitAt headerLen cryptPacket
                ilen      = BS.length ciphertext
            -- Only the RTT1Level consults the peer packet number seen so
            -- far; the other levels decode relative to 0.
            peerPN <- if lvl == RTT1Level
                        then getPeerPacketNumber conn
                        else return 0
            let pn = decodePacketNumber peerPN (toEncodedPacketNumber bytePN) epnLen
            -- Unprotected header: the protected header bytes with the first
            -- byte and the packet number bytes overwritten by their unmasked
            -- values.
            header <- BS.create headerLen $ \p -> do
                void $ copy p proHeader
                poke8 flags p 0
                void $ copy (p `plusPtr` cryptPktNumOffset) $ BS.take epnLen bytePN
            -- Bit 2 of the unmasked flags selects the key phase, but only
            -- at RTT1Level.
            let keyPhase
                  | lvl == RTT1Level = flags `testBit` 2
                  | otherwise        = False
            coder <- getCoder conn lvl keyPhase
            siz <- withByteString ciphertext $ \ibuf ->
                     withByteString header $ \abuf ->
                       decrypt coder ibuf ilen abuf headerLen pn decBuf
            -- The reserved-bit mask differs between RTT1Level and the other
            -- levels.
            let rrMask
                  | lvl == RTT1Level = 0x18
                  | otherwise        = 0x0c
                marks
                  | flags .&. rrMask == 0 = defaultPlainMarks
                  | otherwise             = setIllegalReservedBits defaultPlainMarks
            if siz < 0
              then return Nothing
              else do
                mframes <- decodeFrames decBuf siz
                case mframes of
                  Nothing ->
                    return $ Just $ Plain rawFlags pn [] (setUnknownFrame marks)
                  Just frames -> do
                    let marks'
                          | null frames = setNoFrames marks
                          | otherwise   = marks
                    return $ Just $ Plain rawFlags pn frames marks'
-- | Interpret the bytes of a 'ByteString' as a big-endian number: the first
--   byte is the most significant. An empty input yields 0.
toEncodedPacketNumber :: ByteString -> EncodedPacketNumber
toEncodedPacketNumber = BS.foldl' shiftIn 0
  where
    -- Make room for one more byte, then append it as the lowest byte.
    shiftIn acc byte = acc * 256 + fromIntegral byte