-- |
-- Module      : Network.TLS.Wire
-- License     : BSD-style
-- Maintainer  : Vincent Hanquez <vincent@snarc.org>
-- Stability   : experimental
-- Portability : unknown
--
-- The Wire module is a specialized marshalling/unmarshalling package for the TLS protocol.
-- All multi-byte values are read and written in big-endian order.
--
module Network.TLS.Wire
    ( Get
    , GetResult(..)
    , GetContinuation
    , runGet
    , runGetErr
    , runGetMaybe
    , tryGet
    , remaining
    , getWord8
    , getWords8
    , getWord16
    , getWords16
    , getWord24
    , getWord32
    , getWord64
    , getBytes
    , getOpaque8
    , getOpaque16
    , getOpaque24
    , getInteger16
    , getBigNum16
    , getList
    , processBytes
    , isEmpty
    , Put
    , runPut
    , putWord8
    , putWords8
    , putWord16
    , putWords16
    , putWord24
    , putWord32
    , putWord64
    , putBytes
    , putOpaque8
    , putOpaque16
    , putOpaque24
    , putInteger16
    , putBigNum16
    , encodeWord16
    , encodeWord32
    , encodeWord64
    ) where

import Data.Serialize.Get hiding (runGet)
import qualified Data.Serialize.Get as G
import Data.Serialize.Put
import qualified Data.ByteString as B
import Network.TLS.Struct
import Network.TLS.Imports
import Network.TLS.Util.Serialization

-- | Continuation that resumes an interrupted parse when handed the next
-- chunk of input.
type GetContinuation a = ByteString -> GetResult a
-- | Outcome of running a parser with 'runGet'.
data GetResult a =
      GotError TLSError                 -- ^ the parser failed
    | GotPartial (GetContinuation a)    -- ^ more input is needed; feed it to the continuation
    | GotSuccess a                      -- ^ the parser succeeded and consumed all of the input
    | GotSuccessRemaining a ByteString  -- ^ the parser succeeded; the bytes it did not consume are returned alongside

-- | Run a parser incrementally over a chunk of wire data.
--
-- The parser is wrapped in 'label' with @lbl@, and cereal's 'Result' is
-- translated into a 'GetResult':
--
-- * a failure becomes 'GotError' carrying 'Error_Packet_Parsing';
-- * a request for more input becomes 'GotPartial', whose continuation keeps
--   translating later results in the same way;
-- * a success becomes 'GotSuccess' when every byte was consumed, and
--   'GotSuccessRemaining' (with the unconsumed bytes) otherwise.
runGet :: String -> Get a -> ByteString -> GetResult a
runGet lbl f = translate . G.runGetPartial (label lbl f)
  where
    translate result = case result of
        G.Fail msg _ -> GotError (Error_Packet_Parsing msg)
        G.Partial k  -> GotPartial (translate . k)
        G.Done v rest
            | B.null rest -> GotSuccess v
            | otherwise   -> GotSuccessRemaining v rest

-- | Parse a complete buffer in one shot.
--
-- Succeeds only when the parser finishes and consumes the whole input.
-- Needing more input, or leaving trailing bytes, is reported as an
-- 'Error_Packet_Parsing' error whose message is prefixed with @lbl@.
runGetErr :: String -> Get a -> ByteString -> Either TLSError a
runGetErr lbl getter b =
    case runGet lbl getter b of
        GotSuccess r            -> Right r
        GotError err            -> Left err
        GotPartial _            ->
            Left (Error_Packet_Parsing (lbl ++ ": parsing error: partial packet"))
        GotSuccessRemaining _ _ ->
            Left (Error_Packet_Parsing (lbl ++ ": parsing error: remaining bytes"))

-- | Run a parser over a strict buffer with cereal's 'G.runGet', discarding
-- the error message on failure.
--
-- Unlike 'runGetErr', no check for leftover input is made here.
runGetMaybe :: Get a -> ByteString -> Maybe a
runGetMaybe f bs =
    case G.runGet f bs of
        Left _  -> Nothing
        Right v -> Just v

-- | Run a parser, returning 'Nothing' on failure.
--
-- Defined in terms of 'runGetMaybe' (its body used to be a verbatim copy)
-- so the two exported names cannot drift apart.
tryGet :: Get a -> ByteString -> Maybe a
tryGet = runGetMaybe

-- | Read a list of bytes preceded by a one-byte element count.
getWords8 :: Get [Word8]
getWords8 = do
    count <- getWord8
    replicateM (fromIntegral count) getWord8

-- | Read a 16-bit unsigned integer in big-endian order.
getWord16 :: Get Word16
getWord16 = getWord16be

-- | Read a list of 16-bit words preceded by a 16-bit /byte/ length.
--
-- The prefix counts bytes, so it must be even. An odd length is rejected
-- as a parse failure. Flooring it to half as many words would leave the
-- final byte in the stream, where the next field's parser would
-- misinterpret it.
getWords16 :: Get [Word16]
getWords16 = do
    byteLen <- fromIntegral <$> getWord16
    if odd byteLen
        then fail ("getWords16: odd byte length " ++ show (byteLen :: Int))
        else replicateM (byteLen `div` 2) getWord16

-- | Read a 24-bit unsigned integer as three bytes, most significant first,
-- and return it as an 'Int'.
getWord24 :: Get Int
getWord24 = assemble <$> getWord8 <*> getWord8 <*> getWord8
  where
    assemble :: Word8 -> Word8 -> Word8 -> Int
    assemble hi mid lo =
        (fromIntegral hi `shiftL` 16) .|. (fromIntegral mid `shiftL` 8) .|. fromIntegral lo

-- | Read a 32-bit unsigned integer in big-endian order.
getWord32 :: Get Word32
getWord32 = getWord32be

-- | Read a 64-bit unsigned integer in big-endian order.
getWord64 :: Get Word64
getWord64 = getWord64be

-- | Read a byte string preceded by its one-byte length.
getOpaque8 :: Get ByteString
getOpaque8 = do
    len <- getWord8
    getBytes (fromIntegral len)

-- | Read a byte string preceded by its 16-bit big-endian length.
getOpaque16 :: Get ByteString
getOpaque16 = do
    len <- getWord16
    getBytes (fromIntegral len)

-- | Read a byte string preceded by its 24-bit big-endian length.
getOpaque24 :: Get ByteString
getOpaque24 = do
    len <- getWord24
    getBytes len

-- | Read a 16-bit-length-prefixed byte string and convert it to an
-- 'Integer' with 'os2ip'.
getInteger16 :: Get Integer
getInteger16 = fmap os2ip getOpaque16

-- | Read a 16-bit-length-prefixed byte string and wrap it as a 'BigNum'.
getBigNum16 :: Get BigNum
getBigNum16 = fmap BigNum getOpaque16

-- | Parse a sequence of elements declared to occupy @totalLen@ bytes.
--
-- The whole parse is wrapped in 'isolate', which confines it to @totalLen@
-- bytes. Each run of @getElement@ yields an element together with the
-- number of bytes it claims to have used. That claim, not a measured
-- count, is what gets subtracted from the remaining budget.
--
-- Inconsistent bookkeeping is therefore reachable and is reported with
-- 'fail' rather than 'error'. Bad input then surfaces as an ordinary parse
-- error instead of an exception thrown from pure code:
--
-- * reported lengths that overshoot @totalLen@ fail;
-- * an element reporting a non-positive length fails, because the
--   remaining budget would otherwise never shrink.
getList :: Int -> Get (Int, a) -> Get [a]
getList totalLen getElement = isolate totalLen (getElements totalLen)
  where
    getElements len
        | len < 0   = fail "getList: list elements consumed more data than the declared list length"
        | len == 0  = return []
        | otherwise = do
            (elementLen, a) <- getElement
            if elementLen <= 0
                then fail "getList: element reported a non-positive length"
                else (a :) <$> getElements (len - elementLen)

-- | Run a parser restricted to the given number of upcoming bytes, using
-- cereal's 'isolate'.
processBytes :: Int -> Get a -> Get a
processBytes = isolate

-- | Write a list of bytes preceded by a one-byte element count.
--
-- NOTE(review): the count is narrowed with 'fromIntegral'. A list longer
-- than 255 elements therefore gets a wrapped, incorrect prefix, so callers
-- must keep lists within that bound.
putWords8 :: [Word8] -> Put
putWords8 ws = putWord8 (fromIntegral (length ws)) >> mapM_ putWord8 ws

-- | Write a 16-bit unsigned integer in big-endian order.
putWord16 :: Word16 -> Put
putWord16 = putWord16be

-- | Write a 32-bit unsigned integer in big-endian order.
putWord32 :: Word32 -> Put
putWord32 = putWord32be

-- | Write a 64-bit unsigned integer in big-endian order.
putWord64 :: Word64 -> Put
putWord64 = putWord64be

-- | Write a list of 16-bit words preceded by their total size in bytes
-- (two per word), itself encoded as a 16-bit word.
--
-- NOTE(review): the size is computed in 'Word16' arithmetic and wraps for
-- lists of more than 32767 words; callers must stay within that bound.
putWords16 :: [Word16] -> Put
putWords16 ws = putWord16 byteLen >> mapM_ putWord16 ws
  where
    byteLen :: Word16
    byteLen = 2 * fromIntegral (length ws)

-- | Write the low 24 bits of an 'Int' as three bytes, most significant
-- first. Any higher bits are discarded by the byte masking.
putWord24 :: Int -> Put
putWord24 i = mapM_ (putWord8 . byteAt) [16, 8, 0]
  where
    byteAt :: Int -> Word8
    byteAt bits = fromIntegral ((i `shiftR` bits) .&. 0xff)

-- | Write raw bytes with no length prefix.
putBytes :: ByteString -> Put
putBytes = putByteString

-- | Write a byte string preceded by its one-byte length.
--
-- NOTE(review): the length is narrowed with 'fromIntegral' and wraps above
-- 255 bytes; callers must respect that bound.
putOpaque8 :: ByteString -> Put
putOpaque8 b = do
    putWord8 (fromIntegral (B.length b))
    putBytes b

-- | Write a byte string preceded by its 16-bit big-endian length.
--
-- NOTE(review): the length is narrowed with 'fromIntegral' and wraps above
-- 65535 bytes; callers must respect that bound.
putOpaque16 :: ByteString -> Put
putOpaque16 b = do
    putWord16 (fromIntegral (B.length b))
    putBytes b

-- | Write a byte string preceded by its 24-bit big-endian length.
putOpaque24 :: ByteString -> Put
putOpaque24 b = do
    putWord24 (B.length b)
    putBytes b

-- | Convert an 'Integer' to bytes with 'i2osp' and write it with a 16-bit
-- length prefix.
putInteger16 :: Integer -> Put
putInteger16 n = putOpaque16 (i2osp n)

-- | Write the bytes wrapped by a 'BigNum' with a 16-bit length prefix.
putBigNum16 :: BigNum -> Put
putBigNum16 bn = case bn of
    BigNum bytes -> putOpaque16 bytes

-- | Serialize a 16-bit word to a 2-byte big-endian 'ByteString'.
encodeWord16 :: Word16 -> ByteString
encodeWord16 w = runPut (putWord16 w)

-- | Serialize a 32-bit word to a 4-byte big-endian 'ByteString'.
encodeWord32 :: Word32 -> ByteString
encodeWord32 w = runPut (putWord32 w)

-- | Serialize a 64-bit word to an 8-byte big-endian 'ByteString'.
--
-- Goes through this module's 'putWord64' (not cereal's 'putWord64be'
-- directly), matching 'encodeWord16' and 'encodeWord32' so the module's
-- byte-order choice lives in one place.
encodeWord64 :: Word64 -> ByteString
encodeWord64 = runPut . putWord64