module Botan.Utility
( constantTimeCompare
, scrubMemory
, scrub
, scrubArray
, scrubForeignPtr
, scrubForeignPtrArray
, scrubByteString
, HexCase(..)
, hexEncode
, hexDecode
, base64Encode
, base64Decode
) where

import System.IO.Unsafe

import Data.ByteString.Unsafe as ByteString
import Data.ByteString.Internal as ByteString

import Botan.Low.Utility (HexEncodingFlags(..), pattern HexUpperCase, pattern HexLowerCase)
import qualified Botan.Low.Utility as Low
import qualified Botan.Bindings.Utility as Bindings

import Foreign.Ptr
import Foreign.ForeignPtr
import Foreign.Storable

import Botan.Error
import Botan.Prelude

-- | Compare the first @len@ bytes of two 'ByteString's, returning 'True'
-- when those bytes are equal and 'False' otherwise.
--
-- This is a pure wrapper (via 'unsafePerformIO3') around
-- 'Low.constantTimeCompare', which is intended to take time independent of
-- where the inputs first differ.
--
-- NOTE(review): what happens when @len@ exceeds the length of either input
-- is decided by 'Low.constantTimeCompare' and is not visible here — confirm
-- before passing untrusted lengths.
constantTimeCompare :: ByteString -> ByteString -> Int -> Bool
constantTimeCompare = unsafePerformIO3 Low.constantTimeCompare

-- TODO: randomizeMemory and variants?

-- | Scrub @byteCount@ bytes of raw memory starting at @p@.
--
-- The work is delegated to 'Low.scrubMem', lifted into any 'MonadIO'.
-- NOTE(review): the exact wipe pattern, and how a negative size is treated,
-- are decided by 'Low.scrubMem' and are not visible here.
scrubMemory :: (MonadIO m) => Ptr a -> Int -> m ()
scrubMemory p byteCount = liftIO (Low.scrubMem p byteCount)

-- | Scrub the single 'Storable' value that @ptr@ points to.
--
-- Exactly @'sizeOf'@ of the pointee type @a@ bytes are wiped. This is the
-- size of the value, not the size of the pointer.
scrub :: (MonadIO m, Storable a) => Ptr a -> m ()
scrub ptr = scrubMemory ptr (sizeOf (pointee ptr))
  where
    -- Type witness only: 'sizeOf' must not inspect its argument, so this
    -- bottom value is never forced. It just selects the 'Storable' instance
    -- of the element type @b@ rather than that of @'Ptr' b@.
    pointee :: Ptr b -> b
    pointee _ = undefined

-- | Scrub a contiguous array of @n@ 'Storable' elements starting at @ptr@.
--
-- The region wiped is @n * sizeOf a@ bytes, where @a@ is the element type,
-- not the pointer type. A non-positive @n@ scrubs nothing, so a negative
-- count can never become a nonsensical byte size.
scrubArray :: (MonadIO m, Storable a) => Int -> Ptr a -> m ()
scrubArray n ptr
  | n <= 0    = pure ()
  | otherwise = scrubMemory ptr (n * sizeOf (pointee ptr))
  where
    -- Type witness only: 'sizeOf' must not inspect its argument, so this
    -- bottom value is never forced. It just selects the element type's
    -- 'Storable' instance instead of the pointer's.
    pointee :: Ptr b -> b
    pointee _ = undefined

-- | Scrub the single 'Storable' value owned by a 'ForeignPtr'.
--
-- 'withForeignPtr' keeps the foreign pointer alive while 'scrub' runs on the
-- underlying raw pointer.
scrubForeignPtr :: (MonadIO m, Storable a) => ForeignPtr a -> m ()
scrubForeignPtr foreignPtr = liftIO (foreignPtr `withForeignPtr` scrub)

-- | Scrub an array of @n@ 'Storable' elements owned by a 'ForeignPtr'.
--
-- The foreign pointer is kept alive via 'withForeignPtr' for the duration of
-- the 'scrubArray' call on the underlying raw pointer.
scrubForeignPtrArray :: (MonadIO m, Storable a) => Int -> ForeignPtr a -> m ()
scrubForeignPtrArray count foreignPtr = liftIO $ withForeignPtr foreignPtr wipe
  where
    wipe p = scrubArray count p

-- TODO: Rename scrubByteStringImmediately?
-- | Scrub the bytes of a 'ByteString' in place, immediately.
--
-- The buffer is reached through 'ByteString.unsafeUseAsCStringLen', which
-- exposes the bytestring's own storage without copying. The bytes really are
-- overwritten: @bs@, and any other 'ByteString' sharing the same buffer, will
-- observe the scrubbed contents afterwards. Only call this on a bytestring
-- you own and will not read again.
--
-- Exactly @length bs@ bytes are scrubbed; memory outside the slice is not
-- touched.
scrubByteString :: (MonadIO m) => ByteString -> m ()
scrubByteString bs =
  liftIO $ ByteString.unsafeUseAsCStringLen bs (uncurry scrubMemory)
  -- Scrub by the byte count from the 'CStringLen' directly. Routing through
  -- an element-count helper risks multiplying the length by the wrong
  -- element size and running past the end of the buffer.

-- TODO: Attach a scrubbing finalizer
-- This will require freeing the finalizer funptr from inside itself.
-- SEE: https://mail.haskell.org/pipermail/glasgow-haskell-users/2006-March/009910.html
-- scrubByteStringFinalizer :: (MonadIO m) => ByteString -> m ()
-- scrubByteStringFinalizer bs = liftIO $ addForeignPtrFinalizer _ fptr where
--     (fptr,_,_) = ByteString.toForeignPtr bs

-- | Letter case of the hexadecimal digits produced by 'hexEncode'.
data HexCase
  = Upper -- ^ Upper-case digits (selects 'Low.HexUpperCase').
  | Lower -- ^ Lower-case digits (selects 'Low.HexLowerCase').
  deriving (Eq, Ord, Show, Read, Enum, Bounded)

-- | Translate the high-level 'HexCase' choice into the low-level flag value
-- passed to 'Low.hexEncode'.
hexEncodingFlags :: HexCase -> HexEncodingFlags
hexEncodingFlags hexCase = case hexCase of
  Upper -> Low.HexUpperCase
  Lower -> Low.HexLowerCase

-- TODO: Discuss ergonomics of flipping argument order
-- | Hex-encode a 'ByteString', with the digit case chosen by the 'HexCase'
-- argument.
--
-- This is a pure wrapper around the 'IO' action 'Low.hexEncode'. The
-- @NOINLINE@ pragma stops GHC from inlining, and so duplicating, the
-- 'unsafePerformIO' call.
hexEncode :: ByteString -> HexCase -> Text
hexEncode bytes hexCase =
  let flags = hexEncodingFlags hexCase
  in  unsafePerformIO (Low.hexEncode bytes flags)
{-# NOINLINE hexEncode #-}

-- | Decode hexadecimal 'Text' back into the raw bytes it represents.
--
-- This is a pure wrapper (via 'unsafePerformIO1') around 'Low.hexDecode'.
-- NOTE(review): behaviour on malformed input (odd length, non-hex
-- characters) is decided by 'Low.hexDecode' and is not visible here.
-- Presumably an exception surfaces when the result is forced; confirm.
hexDecode :: Text -> ByteString
hexDecode encoded = unsafePerformIO1 Low.hexDecode encoded

-- | Base64-encode a 'ByteString' into 'Text'.
--
-- This is a pure wrapper (via 'unsafePerformIO1') around the 'IO' action
-- 'Low.base64Encode'.
base64Encode :: ByteString -> Text
base64Encode bytes = unsafePerformIO1 Low.base64Encode bytes

-- | Decode Base64 'Text' back into the raw bytes it represents.
--
-- This is a pure wrapper (via 'unsafePerformIO1') around 'Low.base64Decode'.
-- NOTE(review): handling of invalid Base64 input is decided by
-- 'Low.base64Decode' and is not visible here — confirm.
base64Decode :: Text -> ByteString
base64Decode encoded = unsafePerformIO1 Low.base64Decode encoded