module Data.SecureMem
( SecureMem
, secureMemGetSize
, secureMemCopy
, ToSecureMem(..)
, allocateSecureMem
, createSecureMem
, unsafeCreateSecureMem
, finalizeSecureMem
, withSecureMemPtr
, withSecureMemPtrSz
, withSecureMemCopy
, secureMemFromByteString
, secureMemFromByteable
) where
import Foreign.ForeignPtr (withForeignPtr)
import Foreign.Ptr
import Data.Word (Word8)
import Data.Monoid
import Control.Applicative
import Data.Byteable
import Data.ByteString (ByteString)
import Data.ByteArray (ScrubbedBytes)
import qualified Data.ByteArray as B
import qualified Data.Memory.PtrMethods as B (memSet)
import qualified Data.ByteString.Internal as BS
#if MIN_VERSION_base(4,4,0)
import System.IO.Unsafe (unsafeDupablePerformIO)
#else
import System.IO.Unsafe (unsafePerformIO)
#endif
-- | Evaluate an 'IO' action as if it were a pure value.
--
-- On base >= 4.4 this is 'unsafeDupablePerformIO', which is cheaper but
-- may run the action more than once (or abandon it part-way). Any action
-- handed to 'pureIO' must therefore be safe to duplicate. Older bases fall
-- back to 'unsafePerformIO'.
pureIO :: IO a -> a
#if MIN_VERSION_base(4,4,0)
pureIO action = unsafeDupablePerformIO action
#else
pureIO action = unsafePerformIO action
#endif
-- | A buffer for sensitive bytes, represented as a 'ScrubbedBytes'
-- (from "Data.ByteArray"). Its 'Show' instance never prints the contents.
newtype SecureMem = SecureMem ScrubbedBytes
-- | Number of bytes stored in the 'SecureMem'.
secureMemGetSize :: SecureMem -> Int
secureMemGetSize sm = case sm of
    SecureMem bytes -> B.length bytes
-- | Compare two 'SecureMem' values using the 'Eq' instance of the
-- underlying 'ScrubbedBytes'.
--
-- NOTE(review): whether this comparison is constant-time depends entirely
-- on that 'ScrubbedBytes' instance; confirm before relying on it.
secureMemEq :: SecureMem -> SecureMem -> Bool
secureMemEq (SecureMem lhs) (SecureMem rhs) = (==) lhs rhs
-- | Join two 'SecureMem' values via the 'Monoid' instance of the underlying
-- 'ScrubbedBytes'. The first argument is the left operand.
secureMemAppend :: SecureMem -> SecureMem -> SecureMem
secureMemAppend (SecureMem front) (SecureMem back) =
    SecureMem (mappend front back)
-- | Join a list of 'SecureMem' values with a single 'mconcat' over the
-- underlying 'ScrubbedBytes'.
secureMemConcat :: [SecureMem] -> SecureMem
secureMemConcat mems = SecureMem (mconcat [ bytes | SecureMem bytes <- mems ])
-- | Duplicate a 'SecureMem' into a newly allocated one using 'B.copy',
-- with a no-op callback (the copy is left unmodified).
secureMemCopy :: SecureMem -> IO SecureMem
secureMemCopy (SecureMem original) = do
    duplicate <- B.copy original (\_ -> return ())
    return (SecureMem duplicate)
-- | Duplicate a 'SecureMem' with 'B.copy', passing the callback through.
--
-- Per "Data.ByteArray", 'B.copy' applies the callback to the pointer of
-- the new buffer, so it can adjust the copy before it is returned.
withSecureMemCopy :: SecureMem -> (Ptr Word8 -> IO ()) -> IO SecureMem
withSecureMemCopy (SecureMem original) mutate =
    fmap SecureMem (B.copy original mutate)
-- | Never reveal the contents: every value renders as the same placeholder.
instance Show SecureMem where
    show = const "<secure-mem>"
-- | Expose a 'SecureMem' through the 'Byteable' interface.
-- Note that 'toBytes' copies the contents into an ordinary 'ByteString'.
instance Byteable SecureMem where
    byteableLength sm = secureMemGetSize sm
    withBytePtr sm f  = withSecureMemPtr sm f
    toBytes sm        = secureMemToByteString sm
-- | Equality delegates to 'secureMemEq'.
instance Eq SecureMem where
    a == b = secureMemEq a b
#if MIN_VERSION_base(4,11,0)
-- | Concatenation of the underlying bytes.
--
-- Since base 4.11 (GHC 8.4) 'Semigroup' is a superclass of 'Monoid'; without
-- this instance the 'Monoid' instance below does not compile. 'Semigroup'
-- is exported from the Prelude on these versions, so no extra import is
-- needed.
instance Semigroup SecureMem where
    (<>) = secureMemAppend
#endif

-- | 'mempty' is a zero-length 'SecureMem'. Concatenation delegates to
-- 'secureMemAppend' and 'secureMemConcat'.
instance Monoid SecureMem where
    mempty  = unsafeCreateSecureMem 0 (\_ -> return ())
#if !(MIN_VERSION_base(4,11,0))
    -- Before base 4.11 'mappend' has no default derived from '(<>)', so it
    -- must be given explicitly. On newer bases the default
    -- @mappend = (<>)@ resolves to the same 'secureMemAppend'.
    mappend = secureMemAppend
#endif
    mconcat = secureMemConcat
-- | Types that can be converted into a 'SecureMem'.
class ToSecureMem a where
    -- | Produce a 'SecureMem' from a value.
    toSecureMem :: a -> SecureMem
-- | A 'SecureMem' is already secure memory; return it unchanged.
instance ToSecureMem SecureMem where
    toSecureMem = id
-- | Copy the bytes of a 'ByteString' into a new 'SecureMem'.
instance ToSecureMem ByteString where
    toSecureMem = secureMemFromByteString
-- | Allocate a 'SecureMem' of the given size in bytes.
--
-- The initialiser passed to 'B.create' is a no-op, so this function writes
-- nothing into the buffer.
-- NOTE(review): whether the fresh memory is zeroed depends on the
-- 'ScrubbedBytes' allocator; confirm before relying on its contents.
allocateSecureMem :: Int -> IO SecureMem
allocateSecureMem sz = fmap SecureMem (B.create sz (const (return ())))
-- | Allocate a 'SecureMem' of the given size and fill it by running the
-- initialiser on a pointer to the new buffer.
createSecureMem :: Int -> (Ptr Word8 -> IO ()) -> IO SecureMem
createSecureMem sz initialise = do
    bytes <- B.create sz initialise
    return (SecureMem bytes)
-- | Pure variant of 'createSecureMem', run through 'pureIO'.
--
-- On base >= 4.4 'pureIO' is 'unsafeDupablePerformIO', so the initialiser
-- may run more than once. It should do nothing except fill the buffer.
unsafeCreateSecureMem :: Int -> (Ptr Word8 -> IO ()) -> SecureMem
unsafeCreateSecureMem sz = pureIO . createSecureMem sz
-- | Run an action with a raw pointer to the bytes of a 'SecureMem'.
-- The pointer is obtained via 'B.withByteArray' and should not be used
-- after the action returns.
withSecureMemPtr :: SecureMem -> (Ptr Word8 -> IO b) -> IO b
withSecureMemPtr sm f = case sm of
    SecureMem bytes -> B.withByteArray bytes f
-- | Like 'withSecureMemPtr', but the action also receives the buffer size
-- in bytes as its first argument.
withSecureMemPtrSz :: SecureMem -> (Int -> Ptr Word8 -> IO b) -> IO b
withSecureMemPtrSz (SecureMem bytes) f =
    let size = B.length bytes
    in B.withByteArray bytes (\ptr -> f size ptr)
-- | Overwrite every byte of the 'SecureMem' with zero using 'B.memSet'.
-- The buffer itself stays allocated; only its contents are wiped.
finalizeSecureMem :: SecureMem -> IO ()
finalizeSecureMem (SecureMem bytes) =
    B.withByteArray bytes (\ptr -> B.memSet ptr 0 size)
  where
    size = B.length bytes
-- | Copy the contents of a 'SecureMem' into a newly created 'ByteString'.
--
-- The result is an ordinary 'ByteString'. Nothing in this function scrubs
-- it, so the copy should be handled with the same care as the secret itself.
secureMemToByteString :: SecureMem -> ByteString
secureMemToByteString sm = size `seq` BS.unsafeCreate size fill
  where
    size = secureMemGetSize sm
    fill dst =
        withSecureMemPtr sm (\src -> BS.memcpy dst src (fromIntegral size))
-- | Copy the bytes of a 'ByteString' (respecting its offset and length)
-- into a freshly allocated 'SecureMem'.
secureMemFromByteString :: ByteString -> SecureMem
secureMemFromByteString b = pureIO (allocateSecureMem len >>= fill)
  where
    (fp, off, !len) = BS.toForeignPtr b
    fill sm = do
        withSecureMemPtr sm $ \dst ->
            withForeignPtr fp $ \base ->
                BS.memcpy dst (base `plusPtr` off) (fromIntegral len)
        return sm
-- | Copy the bytes of any 'Byteable' value into a freshly allocated
-- 'SecureMem' of size 'byteableLength'.
secureMemFromByteable :: Byteable b => b -> SecureMem
secureMemFromByteable bs = pureIO $ do
    let len = byteableLength bs
        copyInto dst = withBytePtr bs $ \src ->
            BS.memcpy dst src (fromIntegral len)
    sm <- allocateSecureMem len
    withSecureMemPtr sm copyInto
    return sm