-- SPDX-FileCopyrightText: 2021 Serokell <https://serokell.io/>
--
-- SPDX-License-Identifier: MPL-2.0

{-# LANGUAGE ConstraintKinds, RankNTypes #-}

-- | The sensitive data type internals.
module Data.SensitiveBytes.Internal
  ( withSecureMemory
  , WithSecureMemory
  , SodiumInitialised
  , SecureMemoryInitException

  , SensitiveBytes (..)
  , allocate
  , free
  , unsafePtr
  , resized

  , withSensitiveBytes
  , SensitiveBytesAllocException
  ) where

import Control.Exception.Safe (Exception, MonadMask, bracket, throwIO)
import Control.Monad.IO.Class (MonadIO, liftIO)
import Data.ByteArray (ByteArrayAccess (length, withByteArray))
import Data.Reflection (Given, give, given)
import Foreign.Ptr (Ptr, castPtr, nullPtr)
import Libsodium (sodium_free, sodium_init, sodium_malloc, sodium_memzero)


-- | Constraint carried by every function that needs access to secure memory.
--
-- The intended (and only supported) way to satisfy it is to run the code
-- inside 'withSecureMemory'.
type WithSecureMemory = Given SodiumInitialised

-- | Data-free evidence that @sodium_init@ has already been called.
data SodiumInitialised = SodiumInitialised


-- | Perform the initialisation that is required before any data can be
-- allocated in secure memory regions, then run the given action.
--
-- Typical usage is to call this function and hand it a block of code
-- that allocates memory for sensitive data. 'withSensitiveBytes' is
-- typed so that it can only be called within such a block.
--
-- Ideally, call 'withSecureMemory' only once and deal with all of your
-- sensitive data inside that single block. This is not a requirement,
-- though: it may be called as many times as you wish, and the only
-- downside of doing so is a tiny performance penalty.
--
-- In some rare circumstances secure memory initialisation may fail,
-- in which case this function throws 'SecureMemoryInitException'.
withSecureMemory
  :: forall m r. MonadIO m
  => (WithSecureMemory => m r)  -- ^ Action to perform.
  -> m r
withSecureMemory act =
    liftIO initialiseSodium *> give SodiumInitialised act
  where
    -- @sodium_init@ answers 0 when it has just initialised the library
    -- and 1 when that had already been done earlier; both are fine.
    -- Every other result is treated as a failure.
    initialiseSodium :: IO ()
    initialiseSodium = do
      rc <- sodium_init
      if rc == 0 || rc == 1
        then pure ()
        else throwIO SodiumInitFailed

-- | Raised by 'withSecureMemory' when the secure memory machinery
-- could not be set up.
data SecureMemoryInitException
  = SodiumInitFailed  -- ^ libsodium failed to initialise.

instance Exception SecureMemoryInitException

instance Show SecureMemoryInitException where
  show e = case e of
    SodiumInitFailed -> "Failed to initialise a secure memory region"


-- | Bytes that will be allocated in a secure memory location
-- such that they will never be moved by the garbage collector
-- and, hopefully, never swapped out to the disk (if the
-- operating system supports this kind of protection).
--
-- The type parameter @s@ is a phantom tag: it does not appear in the
-- representation.
--
-- All fields are strict. The record is only a small handle to a foreign
-- buffer, so deferring the evaluation of its sizes or pointer gains
-- nothing. Strict fields keep stray thunks from being retained alongside it.
data SensitiveBytes s = SensitiveBytes
  { allocSize :: !Int
    -- ^ Size of the allocated buffer (in bytes).
  , dataSize :: !Int
    -- ^ Size of the actual data stored (in bytes). This is what the
    -- 'ByteArrayAccess' instance reports as the length; it is expected
    -- to be at most 'allocSize'.
  , bufPtr :: !(Ptr ())
    -- ^ Pointer to the start of the buffer.
  }

-- | The reported length is 'dataSize', not 'allocSize', so only the
-- recorded data is visible through this interface. The pointer handed to
-- the callback is the start of the buffer, cast to the requested type.
instance ByteArrayAccess (SensitiveBytes s) where
  length = dataSize
  withByteArray SensitiveBytes{ bufPtr = ptr } f = f (castPtr ptr)

-- | Extract the pointer to the underlying buffer.
--
-- This is unsafe: the raw pointer is no longer tied to the scope the
-- bytes were allocated in, so it can escape that scope and be used
-- after the memory has been freed.
unsafePtr :: SensitiveBytes s -> Ptr ()
unsafePtr SensitiveBytes{ bufPtr = ptr } = ptr


-- | Allocate bytes in a protected memory region.
--
-- The returned value has both 'allocSize' and 'dataSize' set to the
-- requested size. This function does not clear the buffer itself, so do
-- not assume its contents start out zeroed.
--
-- Just as regular @malloc@, this function can fail, for example,
-- if there is not enough memory. In this case, it will throw
-- 'SensitiveBytesAllocException'. A negative size is rejected with the
-- same exception rather than being wrapped around to a huge unsigned
-- value.
allocate
  :: forall s m. (MonadIO m, WithSecureMemory)
  => Int  -- ^ Size of the array (in bytes).
  -> m (SensitiveBytes s)
allocate size = requiringSecureMemory (liftIO act)
  where
    act :: IO (SensitiveBytes s)
    act
      -- 'fromIntegral' would silently turn a negative 'Int' into an
      -- enormous 'CSize', so refuse it here instead of relying on
      -- @sodium_malloc@ to reject the wrapped-around value.
      | size < 0 = throwIO SodiumMallocFailed
      | otherwise = do
          res <- sodium_malloc (fromIntegral size)
          if res == nullPtr
            then throwIO SodiumMallocFailed
            else pure (SensitiveBytes size size res)

-- | Release bytes previously allocated in a protected memory region.
--
-- The buffer pointer is passed straight to @sodium_free@. This function
-- does not zero the buffer itself beforehand.
free
  :: forall s m. (MonadIO m, WithSecureMemory)
  => SensitiveBytes s
  -> m ()
free SensitiveBytes{ bufPtr = ptr } =
  requiringSecureMemory (liftIO (sodium_free ptr))

-- | Overwrite the whole allocated buffer with zeroes using
-- @sodium_memzero@.
--
-- The full 'allocSize' bytes are cleared, not only the first 'dataSize'.
memzero
  :: forall s m. (MonadIO m)
  => SensitiveBytes s
  -> m ()
memzero SensitiveBytes{ allocSize = len, bufPtr = ptr } =
  liftIO $ sodium_memzero ptr $ fromIntegral len

-- | Rewrite the recorded size of the data.
--
-- This is a very dangerous internal-only function. It is essentially
-- a hack that allows other functions exported from this library to
-- efficiently read data of unknown size by first allocating a large buffer
-- and then tweaking the 'ByteArrayAccess' instance to return the size that
-- is smaller than what was actually allocated.
--
-- The new size must lie between 0 and 'allocSize' (inclusive). Anything
-- else is a programming error and results in a call to 'error'.
resized
  :: forall s. ()
  => Int  -- ^ New data size.
  -> SensitiveBytes s  -- ^ What to resize.
  -> SensitiveBytes s
resized newSize sb@SensitiveBytes{ allocSize = capacity }
  -- A negative size would make 'length' report a negative byte count,
  -- which no consumer of 'ByteArrayAccess' is prepared to handle.
  | newSize < 0 =
      error "SensitiveBytes.Internal.resized: the new size is negative"
  | newSize <= capacity = sb{ dataSize = newSize }
  | otherwise =
      error "SensitiveBytes.Internal.resized: the new size is too large"


-- | Allocate a byte array in a secure memory region and run an action
-- with it.
--
-- This function guarantees that:
--
-- 1. The garbage collector will not touch the allocated memory and
--    will not try to copy the sensitive data.
-- 2. The memory will be zeroed-out and freed as soon as the computation
--    finishes, whether it returns normally or throws.
--
-- Additionally, it will try its best (subject to the support from
-- the operating system) to do the following:
--
-- 1. Allocate the buffer at the end of a page and make sure that the
--    following page is not mapped, so trying to access past the end of
--    the buffer will crash the program.
-- 2. Place a canary immediately before the buffer, check that it was not
--    modified before deallocating the memory, and crash the program otherwise.
-- 3. @mlock@ the memory to make sure it will not be paged to the disk.
-- 4. Ask the operating system not to include this memory in core dumps.
--
-- Just as with regular @malloc@, allocation can fail, for example,
-- if there is not enough memory. In this case, the function will throw
-- 'SensitiveBytesAllocException'.
withSensitiveBytes
  :: forall s m r. (MonadIO m, MonadMask m, WithSecureMemory)
  => Int  -- ^ Size of the array (in bytes).
  -> (SensitiveBytes s -> m r)  -- ^ Action to perform with memory allocated.
  -> m r
-- TODO: libsodium docs also say something about the allocated size being
-- a multiple of the required alignment, but it is not clear what the
-- implications are (I added a test, just in case).
withSensitiveBytes size action = bracket (allocate size) release action
  where
    -- libsodium guards the logic of @sodium_free@ behind a number of ifdefs,
    -- and without @HAVE_ALIGNED_MALLOC@ it frees the memory without zeroing
    -- it. That makes little sense, so the buffer is wiped here explicitly
    -- before being freed, in case we are running on such a system.
    release sb = do
      memzero sb
      free sb

-- | Raised by 'withSensitiveBytes' (and 'allocate') when secure memory
-- could not be obtained.
data SensitiveBytesAllocException
  = SodiumMallocFailed  -- ^ @sodium_malloc@ returned NULL.

instance Exception SensitiveBytesAllocException

instance Show SensitiveBytesAllocException where
  show e = case e of
    SodiumMallocFailed -> "Failed to allocate secure memory"



-- | Internal helper that pretends to need 'WithSecureMemory'.
--
-- It forces the 'SodiumInitialised' evidence and then returns its
-- argument unchanged. It exists only to silence the unused/redundant
-- constraint warning in functions that carry 'WithSecureMemory' purely
-- as a static guarantee. Hopefully it is optimised away every time.
requiringSecureMemory :: r -> (WithSecureMemory => r)
requiringSecureMemory act =
  case (given :: SodiumInitialised) of
    SodiumInitialised -> act