-- | The Wire module provides specialized marshalling/unmarshalling
-- helpers for the TLS protocol.  All multi-byte values are
-- encoded big-endian.
module Network.TLS.Wire (
    Get,
    GetResult (..),
    GetContinuation,
    runGet,
    runGetErr,
    runGetMaybe,
    tryGet,
    remaining,
    getWord8,
    getWords8,
    getWord16,
    getWords16,
    getWord24,
    getWord32,
    getWord64,
    getBytes,
    getOpaque8,
    getOpaque16,
    getOpaque24,
    getInteger16,
    getBigNum16,
    getList,
    processBytes,
    isEmpty,
    Put,
    runPut,
    putWord8,
    putWords8,
    putWord16,
    putWords16,
    putWord24,
    putWord32,
    putWord64,
    putBytes,
    putOpaque8,
    putOpaque16,
    putOpaque24,
    putInteger16,
    putBigNum16,
    encodeWord16,
    encodeWord32,
    encodeWord64,
) where

import qualified Data.ByteString as B
import Data.Serialize.Get hiding (runGet)
import qualified Data.Serialize.Get as G
import Data.Serialize.Put
import Network.TLS.Imports
import Network.TLS.Struct
import Network.TLS.Util.Serialization

-- | A suspended parse that needs more input: apply it to the next chunk of
-- bytes to resume parsing.
type GetContinuation a = ByteString -> GetResult a
-- | Outcome of running a parser with 'runGet'.
data GetResult a
    = GotError TLSError
    -- ^ Parsing failed.
    | GotPartial (GetContinuation a)
    -- ^ The input ran out before the parser finished; feed more bytes to the
    -- continuation.
    | GotSuccess a
    -- ^ Parsing succeeded and no input bytes were left over.
    | GotSuccessRemaining a ByteString
    -- ^ Parsing succeeded, but the given bytes were left unconsumed.

-- | Run a labelled parser over a chunk of input, producing a resumable result.
--
-- The label is attached to the parser with 'label' before it is run with
-- cereal's incremental 'G.runGetPartial'.  The cereal result is translated as
-- follows:
--
-- * a failure becomes 'GotError' wrapping 'Error_Packet_Parsing';
-- * a request for more input becomes 'GotPartial', whose continuation keeps
--   translating results the same way;
-- * a success becomes 'GotSuccess' when no bytes are left over, and
--   'GotSuccessRemaining' (carrying the leftover bytes) otherwise.
runGet :: String -> Get a -> ByteString -> GetResult a
runGet lbl f input = fromResult (G.runGetPartial (label lbl f) input)
  where
    fromResult :: G.Result b -> GetResult b
    fromResult result =
        case result of
            G.Fail msg _ -> GotError (Error_Packet_Parsing msg)
            G.Partial resume -> GotPartial (fromResult . resume)
            G.Done value leftover
                | B.null leftover -> GotSuccess value
                | otherwise -> GotSuccessRemaining value leftover

-- | Run a labelled parser that must consume exactly the given input.
--
-- Built on 'runGet'.  A parser error is passed through unchanged.  Running
-- out of input, or finishing with unconsumed bytes, is reported as an
-- 'Error_Packet_Parsing' whose message starts with the label.
runGetErr :: String -> Get a -> ByteString -> Either TLSError a
runGetErr lbl getter b =
    case runGet lbl getter b of
        GotSuccess r -> Right r
        GotError err -> Left err
        GotPartial _ -> Left (parsingError ": parsing error: partial packet")
        GotSuccessRemaining _ _ -> Left (parsingError ": parsing error: remaining bytes")
  where
    parsingError suffix = Error_Packet_Parsing (lbl ++ suffix)

-- | Run a parser with cereal's 'G.runGet' and discard any error message.
--
-- NOTE(review): unlike 'runGetErr', no label is attached and no explicit
-- check for unconsumed input is made here.  Whether trailing bytes are
-- tolerated depends on 'G.runGet'; confirm before relying on it.
runGetMaybe :: Get a -> ByteString -> Maybe a
runGetMaybe f bytes =
    case G.runGet f bytes of
        Left _ -> Nothing
        Right value -> Just value

-- | Run a parser, yielding 'Nothing' on failure.
--
-- Same contract as 'runGetMaybe'.  It is kept as a separate exported name but
-- defined in terms of 'runGetMaybe', so the two entry points cannot drift
-- apart.
tryGet :: Get a -> ByteString -> Maybe a
tryGet = runGetMaybe

-- | Parse a list of bytes preceded by a one-byte element count.
getWords8 :: Get [Word8]
getWords8 = do
    count <- getWord8
    replicateM (fromIntegral count) getWord8

-- | Parse a big-endian 16-bit word.
getWord16 :: Get Word16
getWord16 = G.getWord16be

-- | Parse a vector of big-endian 16-bit words.
--
-- The vector is prefixed by its length in /bytes/, itself encoded as a
-- 16-bit big-endian value.  That byte length must be a multiple of two.  An
-- odd length means the packet is malformed, so the parse fails.  Accepting it
-- would leave a stray byte unconsumed and misalign whatever is parsed next.
getWords16 :: Get [Word16]
getWords16 = do
    lenb <- getWord16
    if odd lenb
        then fail "getWords16: odd byte length for a vector of 16-bit words"
        else replicateM (fromIntegral lenb `div` 2) getWord16

-- | Parse a 24-bit big-endian unsigned integer, widened to 'Int'.
getWord24 :: Get Int
getWord24 = assemble <$> getWord8 <*> getWord8 <*> getWord8
  where
    -- Bytes arrive most significant first.
    assemble :: Word8 -> Word8 -> Word8 -> Int
    assemble hi mid lo =
        (fromIntegral hi `shiftL` 16) .|. (fromIntegral mid `shiftL` 8) .|. fromIntegral lo

-- | Parse a big-endian 32-bit word.
getWord32 :: Get Word32
getWord32 = G.getWord32be

-- | Parse a big-endian 64-bit word.
getWord64 :: Get Word64
getWord64 = G.getWord64be

-- | Parse a byte string prefixed by its length as a single byte.
getOpaque8 :: Get ByteString
getOpaque8 = do
    len <- getWord8
    getBytes (fromIntegral len)

-- | Parse a byte string prefixed by its length as a big-endian 16-bit word.
getOpaque16 :: Get ByteString
getOpaque16 = do
    len <- getWord16
    getBytes (fromIntegral len)

-- | Parse a byte string prefixed by its length as a 24-bit big-endian value.
getOpaque24 :: Get ByteString
getOpaque24 = do
    len <- getWord24
    getBytes len

-- | Parse an 'Integer' stored as a 16-bit-length-prefixed byte string,
-- converted with 'os2ip'.
--
-- NOTE(review): the byte order is whatever 'os2ip' uses (presumably the
-- big-endian OS2IP primitive); confirm in Network.TLS.Util.Serialization.
getInteger16 :: Get Integer
getInteger16 = fmap os2ip getOpaque16

-- | Parse a 16-bit-length-prefixed byte string and wrap it as a 'BigNum'
-- without further conversion.
getBigNum16 :: Get BigNum
getBigNum16 = fmap BigNum getOpaque16

-- | Parse a list that occupies exactly @totalLen@ bytes.
--
-- @getElement@ parses one element and returns it together with the number of
-- bytes it accounts for.  Elements are read until those reported lengths add
-- up to @totalLen@.  The whole parse runs under 'isolate', which confines it
-- to @totalLen@ bytes.
--
-- The reported lengths can overshoot @totalLen@.  That happens only when
-- @getElement@ reports a length different from what it actually consumed.
-- In that case the parse fails through 'fail', which surfaces as an ordinary,
-- recoverable parse error (for example 'GotError' from 'runGet').  It does
-- not throw an 'error' exception out of pure parsing code.
getList :: Int -> Get (Int, a) -> Get [a]
getList totalLen getElement = isolate totalLen (getElements totalLen)
  where
    getElements len
        | len < 0 =
            fail "getList: element lengths exceed the declared list length"
        | len == 0 = return []
        | otherwise = do
            (elementLen, a) <- getElement
            (a :) <$> getElements (len - elementLen)

-- | Run a parser confined to the next @i@ bytes of input.
--
-- This is cereal's 'isolate' under a TLS-specific name.
processBytes :: Int -> Get a -> Get a
processBytes = isolate

-- | Serialise a list of bytes preceded by a one-byte element count.
--
-- NOTE(review): the count is narrowed to 'Word8' with 'fromIntegral'.  A list
-- longer than 255 elements therefore gets a wrapped count; callers must
-- respect that limit.
putWords8 :: [Word8] -> Put
putWords8 l = putWord8 (fromIntegral (length l)) >> mapM_ putWord8 l

-- | Serialise a 16-bit word in big-endian order.
putWord16 :: Word16 -> Put
putWord16 w = putWord16be w

-- | Serialise a 32-bit word in big-endian order.
putWord32 :: Word32 -> Put
putWord32 w = putWord32be w

-- | Serialise a 64-bit word in big-endian order.
putWord64 :: Word64 -> Put
putWord64 w = putWord64be w

-- | Serialise a vector of 16-bit words, prefixed by its length in /bytes/
-- (twice the element count) as a 16-bit word.
--
-- NOTE(review): the byte length is computed in 'Word16', so lists of more
-- than 32767 elements get a wrapped length prefix.
putWords16 :: [Word16] -> Put
putWords16 l = do
    let byteLen = 2 * fromIntegral (length l) :: Word16
    putWord16 byteLen
    mapM_ putWord16 l

-- | Serialise the low 24 bits of an 'Int' as three big-endian bytes.
--
-- Bits above the low 24 are masked off, not checked.
putWord24 :: Int -> Put
putWord24 i = mapM_ (putWord8 . byteAt) [16, 8, 0]
  where
    -- The byte of @i@ starting at the given bit offset.
    byteAt :: Int -> Word8
    byteAt offset = fromIntegral ((i `shiftR` offset) .&. 0xff)

-- | Serialise raw bytes with no length prefix.
putBytes :: ByteString -> Put
putBytes bs = putByteString bs

-- | Serialise a byte string prefixed by its length as a single byte.
--
-- NOTE(review): the length is narrowed to 'Word8' with 'fromIntegral', so
-- inputs longer than 255 bytes get a wrapped length prefix.
putOpaque8 :: ByteString -> Put
putOpaque8 b = do
    putWord8 (fromIntegral (B.length b))
    putBytes b

-- | Serialise a byte string prefixed by its length as a big-endian 16-bit
-- word.
--
-- NOTE(review): the length is narrowed to 'Word16' with 'fromIntegral', so
-- inputs longer than 65535 bytes get a wrapped length prefix.
putOpaque16 :: ByteString -> Put
putOpaque16 b = do
    putWord16 (fromIntegral (B.length b))
    putBytes b

-- | Serialise a byte string prefixed by its length as a 24-bit big-endian
-- value (written with 'putWord24').
putOpaque24 :: ByteString -> Put
putOpaque24 b = do
    putWord24 (B.length b)
    putBytes b

-- | Serialise an 'Integer' as a 16-bit-length-prefixed byte string produced
-- by 'i2osp'.
putInteger16 :: Integer -> Put
putInteger16 n = putOpaque16 (i2osp n)

-- | Serialise the bytes wrapped by a 'BigNum' with a 16-bit length prefix.
putBigNum16 :: BigNum -> Put
putBigNum16 bn =
    case bn of
        BigNum bytes -> putOpaque16 bytes

-- | Encode a 16-bit word as two big-endian bytes.
encodeWord16 :: Word16 -> ByteString
encodeWord16 w = runPut (putWord16 w)

-- | Encode a 32-bit word as four big-endian bytes.
encodeWord32 :: Word32 -> ByteString
encodeWord32 w = runPut (putWord32 w)

-- | Encode a 64-bit word as eight big-endian bytes.
--
-- Goes through the module's own 'putWord64' wrapper, as 'encodeWord16' and
-- 'encodeWord32' do with 'putWord16' and 'putWord32'.  All encoders thus
-- share a single definition of the wire byte order instead of this one
-- calling cereal's 'putWord64be' directly.
encodeWord64 :: Word64 -> ByteString
encodeWord64 = runPut . putWord64