{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CApiFFI #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE MultiWayIf #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}

module Sel.Internal where

import Control.Monad.IO.Class (MonadIO, liftIO)
import qualified Data.Base16.Types as Base16
import qualified Data.ByteString.Base16 as Base16
import qualified Data.ByteString.Internal as BS
import Data.Kind (Type)
import Foreign (Ptr, castForeignPtr)
import Foreign.C.Types (CInt (CInt), CSize (CSize))
import Foreign.ForeignPtr (ForeignPtr, withForeignPtr)
import LibSodium.Bindings.SecureMemory (sodiumFree, sodiumMalloc)

-- | Binding to C's @memcmp@ from @string.h@. It is used in lieu of
-- libsodium's @sodium_memcmp@ in cases where the return code is necessary.
--
-- The result is negative, zero or positive depending on whether the first
-- region compares lower than, equal to or greater than the second one
-- over the given number of bytes.
--
-- NOTE(review): C's @memcmp@ makes no constant-time guarantee; confirm that
-- no caller relies on one when comparing secret material.
foreign import capi unsafe "string.h memcmp"
  memcmp :: Ptr a -> Ptr b -> CSize -> IO CInt

-- | Check whether the first @size@ bytes behind two 'ForeignPtr's are equal.
--
-- Both pointers are kept alive with 'withForeignPtr' for the duration of the
-- comparison, which is delegated to C's @memcmp@. The caller is responsible
-- for making sure that both regions are at least @size@ bytes long.
foreignPtrEq
  :: ForeignPtr a
  -- ^ First region
  -> ForeignPtr a
  -- ^ Second region
  -> CSize
  -- ^ Number of bytes to compare
  -> IO Bool
foreignPtrEq fptr1 fptr2 size =
  withForeignPtr fptr1 $ \p ->
    withForeignPtr fptr2 $ \q ->
      -- memcmp returns 0 exactly when the two regions hold the same bytes.
      (== 0) <$> memcmp p q size

-- | Compare the first @size@ bytes behind two 'ForeignPtr's using
-- lexicographical ordering.
--
-- Both pointers are kept alive with 'withForeignPtr' for the duration of the
-- comparison, which is delegated to C's @memcmp@. The caller is responsible
-- for making sure that both regions are at least @size@ bytes long.
foreignPtrOrd
  :: ForeignPtr a
  -- ^ First region
  -> ForeignPtr a
  -- ^ Second region
  -> CSize
  -- ^ Number of bytes to compare
  -> IO Ordering
foreignPtrOrd fptr1 fptr2 size =
  withForeignPtr fptr1 $ \p ->
    withForeignPtr fptr2 $ \q ->
      -- Only the sign of memcmp's result is meaningful, so comparing it
      -- against zero maps negative/zero/positive onto 'LT'/'EQ'/'GT'.
      (`compare` 0) <$> memcmp p q size

-- | Render the first @size@ bytes behind a 'ForeignPtr' as a
-- hexadecimal (Base16) 'String'.
--
-- The pointer is reinterpreted as a pointer to bytes and wrapped in a
-- 'BS.ByteString' starting at offset 0, which is then Base16-encoded.
-- The caller is responsible for making sure that the region is at least
-- @size@ bytes long.
foreignPtrShow
  :: ForeignPtr a
  -- ^ Region to display
  -> CSize
  -- ^ Number of bytes to display
  -> String
foreignPtrShow fptr size =
  BS.unpackChars . Base16.extractBase16 . Base16.encodeBase16' $
    BS.fromForeignPtr
      (Foreign.castForeignPtr fptr)
      0
      (fromIntegral @CSize @Int size)

-- | Securely allocate an amount of memory with 'sodiumMalloc' and pass
-- a pointer to the region to the provided action.
-- The region is deallocated with 'sodiumFree' once the action has returned.
--
-- Do not let the pointer escape the action: the region it refers to has
-- already been freed by the time 'allocateWith' returns.
--
-- Note that the deallocation is a plain sequential step, not a bracket:
-- if the action throws or otherwise aborts in @m@, 'sodiumFree' is not called.
--
-- NOTE(review): the pointer returned by 'sodiumMalloc' is handed to the
-- action unchecked; confirm how allocation failure is expected to be handled.
allocateWith
  :: forall (a :: Type) (b :: Type) (m :: Type -> Type)
   . MonadIO m
  => CSize
  -- ^ Amount of memory to allocate
  -> (Ptr a -> m b)
  -- ^ Action to perform on the memory
  -> m b
allocateWith size action = do
  !ptr <- liftIO $ sodiumMalloc size
  -- Force the result to weak head normal form while the region is still
  -- allocated, i.e. before 'sodiumFree' runs.
  !result <- action ptr
  liftIO $ sodiumFree ptr
  pure result