{-# LANGUAGE BangPatterns #-}

module Crypto.Saltine.Internal.Util (
    module Crypto.Saltine.Internal.Util
  , withCString
  , allocaBytes
)
where

import Data.ByteString          (ByteString)
import Data.ByteString.Unsafe
import Data.Monoid
import Foreign.C
import Foreign.Marshal.Alloc    (mallocBytes, allocaBytes)
import Foreign.Ptr
import GHC.Word                 (Word8)
import System.IO.Unsafe

import qualified Data.ByteString       as S
import qualified Data.ByteString.Char8 as S8

-- | @x \`safeSubtract\` y@ is @Just (x - y)@ when @y <= x@ and @Nothing@
-- otherwise, i.e. whenever the difference would be negative (or would
-- underflow for an unsigned type).
safeSubtract :: (Ord a, Num a) => a -> a -> Maybe a
safeSubtract x y
  | y > x     = Nothing
  | otherwise = Just (x - y)

-- | Successor with wraparound for a 'Bounded', 'Eq' 'Enum'.
--
-- @snd . cycleSucc@ is 'succ', except that 'maxBound' wraps to 'minBound'.
-- @fst . cycleSucc@ reports whether that wraparound happened, i.e.
-- @fst . cycleSucc == (== maxBound)@.
cycleSucc :: (Bounded a, Enum a, Eq a) => a -> (Bool, a)
cycleSucc x = (wrapped, next)
  where
    wrapped = x == maxBound
    -- 'succ' would fail on 'maxBound', so that case is handled first.
    next
      | wrapped   = minBound
      | otherwise = succ x

-- | Treats a 'ByteString' as a little-endian unsigned integer (byte 0 is the
-- least significant) and adds one to it. The length is preserved, so a
-- string of all @0xFF@ bytes wraps around to all zeros.
nudgeBS :: ByteString -> ByteString
nudgeBS = snd . S.mapAccumL bump True
  where
    -- The accumulator is the carry into the current byte. The first byte
    -- always receives a carry; later bytes only when every byte before
    -- them wrapped around.
    bump :: Bool -> Word8 -> (Bool, Word8)
    bump False w = (False, w)
    bump True  w = (w == maxBound, w + 1)  -- Word8 addition wraps 255 to 0

-- | Brute-force orbit of an endomorphism @f@ starting at @a0@.
--
-- Returns @[f a0, f (f a0), ..., a0]@: the iterates of @f@, stopping at
-- (and including) the first one equal to @a0@. Diverges if @a0@ never
-- recurs. Exists just for the property below, which checks that 'nudgeBS'
-- cycles through every @n@-byte string:
--
-- prop> \n -> n >= 0 && n <= 2 ==> length (orbit nudgeBS (S.replicate n 0)) == 256 ^ n
orbit :: Eq a => (a -> a) -> a -> [a]
orbit f a0 = go (f a0)
  where
    go a
      | a == a0   = [a0]
      | otherwise = a : go (f a)

-- | 0-pad a 'ByteString': prepends @n@ zero bytes (none when @n <= 0@).
pad :: Int -> ByteString -> ByteString
pad n bs = S.append (S.replicate n 0) bs

-- | Remove a 0-padding from a 'ByteString' by dropping its first @n@ bytes.
-- The bytes are not checked to be zero; @unpad n . pad n == id@.
unpad :: Int -> ByteString -> ByteString
unpad n bs = S.drop n bs

-- | Converts a C-convention return code into an 'Either'.
--
-- A code of @0@ yields @Right a@. A code of @-1@ yields @Left "failed"@.
-- Any other code yields a 'Left' whose message includes the code.
handleErrno :: CInt -> (a -> Either String a)
handleErrno err a
  | err == 0  = Right a
  | err == -1 = Left "failed"
  | otherwise = Left ("unexpected error code: " ++ show err)

-- | Runs a C-style action via 'unsafePerformIO' when the result is demanded,
-- and reports whether it returned @0@.
unsafeDidSucceed :: IO CInt -> Bool
unsafeDidSucceed action = unsafePerformIO action == 0

-- | Marshals each 'String' to a temporary 'CString' (via 'withCString') and
-- passes them, in input order, to the continuation. The strings are
-- allocated in list order and stay valid only while the continuation runs.
withCStrings :: [String] -> ([CString] -> IO a) -> IO a
withCStrings []       k = k []
withCStrings (s : ss) k =
  withCString s $ \cs ->
    withCStrings ss (k . (cs :))

-- | Marshals each 'String' to a temporary 'CStringLen' (via 'withCStringLen')
-- and passes them, in input order, to the continuation. The buffers stay
-- valid only while the continuation runs.
withCStringLens :: [String] -> ([CStringLen] -> IO a) -> IO a
withCStringLens []       k = k []
withCStringLens (s : ss) k =
  withCStringLen s $ \cs ->
    withCStringLens ss (k . (cs :))

-- | Convenience function for accessing constant C strings.
--
-- Exposes each 'ByteString''s own buffer (no copy, via
-- 'unsafeUseAsCStringLen') to the continuation, in input order. The
-- continuation must not mutate the buffers or use them after it returns.
constByteStrings :: [ByteString] -> ([CStringLen] -> IO b) -> IO b
constByteStrings []         k = k []
constByteStrings (bs : bss) k =
  unsafeUseAsCStringLen bs $ \cs ->
    constByteStrings bss (\css -> k (cs : css))

-- | Slightly safer cousin to 'buildUnsafeByteString' that remains in the
-- 'IO' monad.
--
-- Allocates an uninitialised @n@-byte buffer with 'mallocBytes' and makes it
-- the backing store of the returned 'ByteString', which frees it when
-- collected. The buffer is then handed to @k@ to fill. Writes past @n@
-- bytes are not checked.
buildUnsafeByteString' :: Int -> (Ptr CChar -> IO b) -> IO (b, ByteString)
buildUnsafeByteString' n k = do
  bs <- mallocBytes n >>= \p -> unsafePackMallocCStringLen (p, n)
  -- unsafeUseAsCString keeps 'bs' alive while k writes through the pointer.
  r  <- unsafeUseAsCString bs k
  return (r, bs)


-- | Sometimes we have to deal with variable-length strings.
--
-- Allocates an @n@-byte buffer and lets @k@ write a NUL-terminated string
-- into it. Returns @k@'s result together with the bytes before the first
-- NUL.
--
-- The buffer is owned by a 'ByteString' (freed by its finalizer) before @k@
-- runs, so it is not leaked if @k@ throws. The terminator search is bounded
-- by @n@: if @k@ writes no NUL byte, the result is the whole buffer rather
-- than a read past its end.
buildUnsafeVariableByteString' :: Int -> (Ptr CChar -> IO b) -> IO (b, ByteString)
buildUnsafeVariableByteString' n k = do
  ph  <- mallocBytes n
  buf <- unsafePackMallocCStringLen (ph, n)
  -- unsafeUseAsCString keeps 'buf' (and so the allocation) alive while k
  -- writes through the pointer.
  out <- unsafeUseAsCString buf k
  -- Trim at the first NUL now, while the buffer holds exactly what k wrote.
  -- The result is a slice that shares buf's storage.
  let bs = S.takeWhile (/= 0) buf
  bs `seq` return (out, bs)

-- | Pure wrapper around 'buildUnsafeVariableByteString'' using
-- 'unsafePerformIO'. The same caveats as 'buildUnsafeByteString' apply:
-- when the allocation and @k@ actually run is hard to predict.
buildUnsafeVariableByteString :: Int -> (Ptr CChar -> IO b) -> (b, ByteString)
buildUnsafeVariableByteString n k = unsafePerformIO (buildUnsafeVariableByteString' n k)

-- | Extremely unsafe function, use with utmost care!
--
-- Builds a new @n@-byte 'ByteString' by giving @k@ (typically a ccall)
-- direct access to the raw underlying pointer. Overwrites are UNCHECKED.
-- Because 'unsafePerformIO' is used, it is difficult to predict when the
-- 'ByteString' is actually created.
buildUnsafeByteString :: Int -> (Ptr CChar -> IO b) -> (b, ByteString)
buildUnsafeByteString n k = unsafePerformIO (buildUnsafeByteString' n k)

-- | Build a sized random 'ByteString' using Sodium's bindings to
-- @/dev/urandom@: an @n@-byte buffer filled by @randombytes_buf@.
randomByteString :: Int -> IO ByteString
randomByteString n = do
  (_, bytes) <- buildUnsafeByteString' n fill
  return bytes
  where
    fill p = c_randombytes_buf p (fromIntegral n)

-- | Discards the error of an 'Either', keeping only a successful value.
-- Defined here to prevent a dependency on package 'errors'.
hush :: Either s a -> Maybe a
hush (Left _)  = Nothing
hush (Right x) = Just x

-- | libsodium's @randombytes_buf@: fills the given buffer with the given
-- number of random bytes ('randomByteString' calls it with the buffer size).
--
-- NOTE(review): libsodium's prototype takes the length as @size_t@, while
-- it is marshalled here as 'CInt'. Confirm this is safe on every target ABI.
-- This import is not marked @unsafe@, unlike the others in this module;
-- presumably intentional, but confirm.
foreign import ccall "randombytes_buf"
  c_randombytes_buf :: Ptr CChar -> CInt -> IO ()

-- | Constant time memory comparison (libsodium's @sodium_memcmp@).
-- 'compare' in this module treats a return value of 0 as "equal".
--
-- NOTE(review): libsodium declares the length as @size_t@, marshalled here
-- as 'CInt'. Confirm this is safe on every target ABI.
foreign import ccall unsafe "sodium_memcmp"
  c_sodium_memcmp
    :: Ptr CChar -- first buffer (a)
    -> Ptr CChar -- second buffer (b)
    -> CInt   -- number of bytes to compare
    -> IO CInt

-- | Allocates memory through libsodium's @sodium_malloc@. The memory is
-- meant to be released with 'c_sodium_free'; 'buildUnsafeScrubbedByteString''
-- pairs the two.
--
-- NOTE(review): the result may be 'nullPtr' if allocation fails. Callers
-- should check for it.
foreign import ccall unsafe "sodium_malloc"
  c_sodium_malloc
    :: CSize -> IO (Ptr a)

-- | Releases memory obtained from 'c_sodium_malloc' (libsodium's
-- @sodium_free@). 'buildUnsafeScrubbedByteString'' installs it as the
-- finalizer of the 'ByteString' it builds.
foreign import ccall unsafe "sodium_free"
  c_sodium_free
    :: Ptr Word8 -> IO ()

-- | Variant of 'buildUnsafeByteString'' whose @n@-byte buffer comes from
-- libsodium's @sodium_malloc@ and is released with @sodium_free@ when the
-- returned 'ByteString' is finalized. The buffer is handed to @k@ to fill;
-- writes past @n@ bytes are not checked.
--
-- Throws an 'IOError' if @sodium_malloc@ returns @NULL@. This mirrors
-- 'mallocBytes' (used by 'buildUnsafeByteString''), which fails on
-- allocation failure rather than passing a null pointer on.
--
-- NOTE(review): the original comment read "Not sure yet what to use this
-- for"; no call sites are visible here.
buildUnsafeScrubbedByteString' :: Int -> (Ptr CChar -> IO b) -> IO (b, ByteString)
buildUnsafeScrubbedByteString' n k = do
  p <- c_sodium_malloc (fromIntegral n)
  -- Without this check, a NULL result would be wrapped as an n-byte
  -- ByteString and then written to by k.
  if p == nullPtr
    then ioError (userError "buildUnsafeScrubbedByteString': sodium_malloc failed")
    else do
      bs  <- unsafePackCStringFinalizer p n (c_sodium_free p)
      -- unsafeUseAsCString keeps 'bs' alive while k writes through the pointer.
      out <- unsafeUseAsCString bs k
      pure (out, bs)

-- | Pure wrapper around 'buildUnsafeScrubbedByteString'' using
-- 'unsafePerformIO'. As with 'buildUnsafeByteString', when the allocation
-- and @k@ actually run is hard to predict.
buildUnsafeScrubbedByteString :: Int -> (Ptr CChar -> IO b) -> (b, ByteString)
buildUnsafeScrubbedByteString n k = unsafePerformIO (buildUnsafeScrubbedByteString' n k)

-- | Constant-time comparison of two 'ByteString's for equality.
--
-- Strings of different lengths are unequal and return immediately; only
-- the contents comparison goes through @sodium_memcmp@.
--
-- This name shadows 'Prelude.compare', so import it qualified.
compare :: ByteString -> ByteString -> Bool
compare a b
  | S.length a /= S.length b = False
  | otherwise =
      unsafePerformIO $
        -- Borrow both buffers directly (no copies) for the duration of the call.
        unsafeUseAsCString a $ \pa ->
          unsafeUseAsCString b $ \pb ->
            (== 0) <$> c_sodium_memcmp pa pb (fromIntegral (S.length a))

-- | bin2hex conversion for showing various binary types (libsodium's
-- @sodium_bin2hex@). Writes the hex encoding of the source bytes into the
-- target buffer. 'bin2hex' sizes that buffer as @2 * bin_len + 1@ and drops
-- its final byte; it ignores the returned pointer.
--
-- NOTE(review): libsodium declares both lengths as @size_t@, marshalled
-- here as 'CInt'. Confirm this is safe on every target ABI.
foreign import ccall unsafe "sodium_bin2hex"
  c_sodium_bin2hex
    :: Ptr CChar            -- Target zone
    -> CInt                 -- Max. length of target string (must be min. bin_len * 2 + 1)
    -> Ptr CChar            -- Source
    -> CInt                 -- Source length
    -> IO (Ptr CChar)

-- | Lower-level hex rendering of a 'ByteString' via @sodium_bin2hex@, used
-- for showing various binary types. The result has exactly
-- @2 * S.length bs@ characters.
bin2hex :: ByteString -> String
bin2hex bs = S8.unpack (S.take hexLen out)
  where
    srcLen = S.length bs
    hexLen = 2 * srcLen
    -- The binding requires at least bin_len * 2 + 1 bytes of target space.
    -- The extra final byte (the terminator) is dropped by 'S.take' above.
    bufLen = hexLen + 1
    (_, out) =
      buildUnsafeByteString bufLen $ \dst ->
        unsafeUseAsCString bs $ \src ->
          c_sodium_bin2hex dst (fromIntegral bufLen) src (fromIntegral srcLen)

-- | Three-argument 'uncurry'. The triple is matched lazily, so it is only
-- taken apart if @f@ uses one of its components.
uncurry3 :: (a -> b -> c -> d) -> ((a, b, c) -> d)
uncurry3 f triple = f x y z
  where
    (x, y, z) = triple

-- | Five-argument 'uncurry'. The tuple is matched lazily, so it is only
-- taken apart if @f@ uses one of its components.
uncurry5 :: (a -> b -> c -> d -> e -> f) -> ((a, b, c, d, e) -> f)
uncurry5 f tuple = f v w x y z
  where
    (v, w, x, y, z) = tuple

-- | Conjunction that is strict in both arguments: unlike '&&', it never
-- skips evaluating the second operand.
(!&&!) :: Bool -> Bool -> Bool
a !&&! b = a `seq` b `seq` (a && b)

-- | Disjunction that is strict in both arguments: unlike '||', it never
-- skips evaluating the second operand.
(!||!) :: Bool -> Bool -> Bool
a !||! b = a `seq` b `seq` (a || b)