{-# LANGUAGE ScopedTypeVariables #-}
module Network.TLS.Util
        ( sub
        , takelast
        , partition3
        , partition6
        , fromJust
        , (&&!)
        , bytesEq
        , fmapEither
        , catchException
        , forEitherM
        , mapChunks_
        , getChunks
        , Saved
        , saveMVar
        , restoreMVar
        ) where

import qualified Data.ByteArray as BA
import qualified Data.ByteString as B
import Network.TLS.Imports

import Control.Exception (SomeException)
import Control.Concurrent.Async
import Control.Concurrent.MVar

-- | @sub b offset len@ returns the @len@ bytes of @b@ that start at byte
-- index @offset@, or 'Nothing' when that range does not lie entirely inside
-- @b@.
--
-- Negative offsets and lengths are rejected with 'Nothing' rather than being
-- passed to 'B.drop' / 'B.take', which would silently treat them as 0 and
-- return a slice that was never asked for.
sub :: ByteString -> Int -> Int -> Maybe ByteString
sub b offset len
    | offset < 0 || len < 0     = Nothing
    -- Same test as @offset + len > B.length b@, but written as a subtraction
    -- so it cannot overflow 'Int' (@len >= 0@ at this point).
    | offset > B.length b - len = Nothing
    | otherwise                 = Just $ B.take len $ B.drop offset b

-- | @takelast i b@ returns the final @i@ bytes of @b@, or 'Nothing' when @b@
-- holds fewer than @i@ bytes.  A count of zero or less yields an empty
-- byte string.
takelast :: Int -> ByteString -> Maybe ByteString
takelast i b
    | i > total = Nothing
    | i <= 0    = Just B.empty
    | otherwise = Just (B.drop (total - i) b)
  where
    total = B.length b

-- | Cut @bytes@ into three consecutive pieces of lengths @d1@, @d2@ and @d3@.
--
-- Succeeds only when every length is non-negative and the three lengths add
-- up to exactly @B.length bytes@; otherwise returns 'Nothing'.
partition3 :: ByteString -> (Int,Int,Int) -> Maybe (ByteString, ByteString, ByteString)
partition3 bytes (d1, d2, d3)
    | all (>= 0) sizes && sum sizes == B.length bytes = Just (front, middle, back)
    | otherwise                                       = Nothing
  where
    sizes                 = [d1, d2, d3]
    (front, afterFront)   = B.splitAt d1 bytes
    (middle, afterMiddle) = B.splitAt d2 afterFront
    back                  = B.take d3 afterMiddle

-- | Cut @bytes@ into six consecutive pieces with the given lengths.
--
-- Returns 'Nothing' if any length is negative or if @bytes@ is shorter than
-- the sum of the lengths.  Bytes beyond that sum are ignored.
partition6 :: ByteString
           -> (Int, Int, Int, Int, Int, Int)
           -> Maybe (ByteString, ByteString, ByteString, ByteString, ByteString, ByteString)
partition6 bytes (d1, d2, d3, d4, d5, d6)
    -- Negative lengths must be rejected explicitly: 'B.splitAt' treats them
    -- as 0, yet they still lower @s@, which would let a too-short input pass
    -- the length check and produce pieces shorter than requested.
    | any (< 0) sizes    = Nothing
    | B.length bytes < s = Nothing
    | otherwise          = Just (p1, p2, p3, p4, p5, p6)
  where
    sizes    = [d1, d2, d3, d4, d5, d6]
    s        = sum sizes
    (p1, r1) = B.splitAt d1 bytes
    (p2, r2) = B.splitAt d2 r1
    (p3, r3) = B.splitAt d3 r2
    (p4, r4) = B.splitAt d4 r3
    (p5, r5) = B.splitAt d5 r4
    (p6, _)  = B.splitAt d6 r5

-- | Unwrap a 'Just', or abort with an 'error' that names @what@ when given
-- 'Nothing'.  This is deliberately partial; prefer handling the 'Nothing'
-- case explicitly wherever that is possible.
fromJust :: String -> Maybe a -> a
fromJust what = maybe (error ("fromJust " ++ what ++ ": Nothing")) id

-- | Strict conjunction.  Unlike '&&', the right operand is evaluated even
-- when the left operand is 'False', so both arguments are always forced.
(&&!) :: Bool -> Bool -> Bool
lhs &&! rhs = case lhs of
    True  -> rhs
    False -> rhs `seq` False

-- | Test two byte strings for equality by delegating to 'BA.constEq'.
--
-- The contents are compared in full instead of stopping at the first
-- differing byte, so the comparison time does not reveal where the inputs
-- diverge.  Inputs of different lengths are reported unequal immediately,
-- without their contents being examined.
bytesEq :: ByteString -> ByteString -> Bool
bytesEq lhs rhs = BA.constEq lhs rhs

-- | Transform the 'Right' payload of an 'Either' and pass a 'Left' through
-- unchanged (i.e. 'fmap' specialised to 'Either').
fmapEither :: (a -> b) -> Either l a -> Either l b
fmapEither f = either Left (Right . f)

-- | Run @action@; if it throws, hand the exception to @handler@ and return
-- the handler's result instead.
--
-- The action runs in a separate thread started by 'withAsync', and its
-- outcome is collected with 'waitCatch', so @handler@ receives exceptions
-- raised while executing @action@.  NOTE(review): presumably the extra
-- thread exists so that asynchronous exceptions aimed at the calling thread
-- are not passed to @handler@ -- confirm before relying on it.  Because
-- @action@ runs on another thread, thread-specific state (e.g. the thread
-- id) differs from the caller's.
catchException :: IO a -> (SomeException -> IO a) -> IO a
catchException action handler = do
    outcome <- withAsync action waitCatch
    case outcome of
        Left exc -> handler exc
        Right x  -> return x

-- | Run a monadic action on each list element, left to right, collecting
-- the 'Right' results in order.  Processing stops at the first 'Left': that
-- value is returned and the remaining elements are never visited.
forEitherM :: Monad m => [a] -> (a -> m (Either l b)) -> m (Either l [b])
forEitherM items act = go items
  where
    go []       = return (Right [])
    go (y : ys) = do
        result <- act y
        case result of
            Left err -> return (Left err)
            Right v  -> fmap (v :) <$> go ys

-- | Run @f@ on every chunk that 'getChunks' produces for the given maximum
-- chunk length, in order, discarding the individual results.
mapChunks_ :: Monad m
           => Maybe Int -> (B.ByteString -> m a) -> B.ByteString -> m ()
mapChunks_ len f bs = mapM_ f (getChunks len bs)

-- | Split a byte string into consecutive chunks of at most @len@ bytes; only
-- the last chunk may be shorter.  With 'Nothing' the input is returned as a
-- single chunk.  The result always has at least one element, even for empty
-- input.
--
-- A non-positive @len@ is treated like 'Nothing'.  Splitting at a
-- non-positive index yields an empty chunk and leaves the input unchanged,
-- so the loop below would otherwise produce empty chunks forever.
getChunks :: Maybe Int -> B.ByteString -> [B.ByteString]
getChunks (Just len) | len > 0 = go
  where
    go bs
        | B.length bs > len =
            let (chunk, remain) = B.splitAt len bs
             in chunk : go remain
        | otherwise = [bs]
getChunks _ = (: [])

-- | An opaque wrapper around a value captured by 'saveMVar'.  Only the type,
-- not its constructor, appears in the module's export list, so code outside
-- this module cannot inspect or alter the saved content; among this module's
-- exports, its consumer is 'restoreMVar'.
newtype Saved a = Saved a

-- | Snapshot the current content of an 'MVar' with 'readMVar' (which leaves
-- the 'MVar' full) so that it can be put back later with 'restoreMVar'.
saveMVar :: MVar a -> IO (Saved a)
saveMVar ref = do
    current <- readMVar ref
    pure (Saved current)

-- | Put a previously saved value back into an 'MVar' using 'swapMVar', and
-- return the value it displaced, wrapped as 'Saved' so that it can itself be
-- restored later.
restoreMVar :: MVar a -> Saved a -> IO (Saved a)
restoreMVar ref (Saved val) = do
    previous <- swapMVar ref val
    pure (Saved previous)