module Foreign.Marshal.Array.Guarded.Debug (
   create,
   alloca,
   ) where

import qualified Foreign.Marshal.Array.Guarded.Plain as Plain

import Foreign.Marshal.Array
         (mallocArray, allocaArray, pokeArray, copyArray, advancePtr)
import Foreign.Marshal.Alloc (free)
import Foreign.Storable (Storable, peekByteOff, sizeOf)
import Foreign.Concurrent (newForeignPtr)
import Foreign.ForeignPtr (ForeignPtr)
import Foreign.Ptr (Ptr, castPtr)
import Control.Exception (onException)
import Control.Monad (when)
import Data.Foldable (for_)
import Data.Word (Word8)


{- |
Array creation with additional immutability check,
electrical fence and pollution of uncleaned memory.

The function checks that the array is not altered anymore after creation.

Layout of the underlying allocation (counted in elements of @a@):
a leading fence of 64 elements, the @size@ payload elements handed to
the initializer, and a trailing fence of 64 elements.

* Before the initializer runs, the fences are filled with the byte pattern
  @AB AD CA FE@ and the payload with @DE AD F0 0D@,
  so reads of uninitialized elements are conspicuous.
* After the initializer returns, both fences are verified.
* A private copy of the payload is taken; the finalizer of the returned
  'ForeignPtr' compares the payload byte-by-byte against that copy,
  re-verifies both fences, overwrites the whole allocation with
  @DE AD BE EF@ and releases it.
* If the initializer throws, the allocation is released before the
  exception propagates.

Type and semantics are meant to match 'Plain.create'
(enforced with 'asTypeOf').
-}
create :: (Storable a) => Int -> (Ptr a -> IO b) -> IO (ForeignPtr a, b)
create = flip asTypeOf Plain.create $ \size f -> do
   let border = 64
   let fullSize = size + 2*border
   ptrApre <- mallocArray fullSize
   ptrsA@(_ptrApre, ptrA, _ptrApost) <- fillAll border size ptrApre
   -- Nothing owns the buffer yet, so an exception from the initializer
   -- would otherwise leak it.  A damaged fence (below) is deliberately
   -- not freed: the overflow may have hit malloc's own bookkeeping.
   result <- f ptrA `onException` free ptrApre
   checkAll border ptrsA
   -- Reference copy used by the finalizer to detect later mutation.
   ptrB <- mallocArray size
   copyArray ptrB ptrA size
   fmap (flip (,) result) $ newForeignPtr ptrA $ do
      -- NOTE(review): this runs as a ForeignPtr finalizer; confirm how the
      -- RTS reports exceptions raised from here in the targeted GHC version.
      for_ [0 .. arraySize ptrA size - 1] $ \i -> do
         a <- peekByteOff ptrA i
         b <- peekByteOff ptrB i
         when (a/=(b::Word8)) $
            error $ "immutable array was altered at byte position " ++ show i
      -- Out-of-bounds writes after creation would otherwise go unnoticed.
      checkAll border ptrsA
      trash fullSize ptrApre
      free ptrApre
      free ptrB

{- |
Temporary array allocation with electrical fence
and pollution of uncleaned memory.

The payload of @size@ elements is surrounded by 64-element fences.
After the action returns, both fences are verified and the whole
allocation is overwritten with @DE AD BE EF@ before 'allocaArray'
releases it.
If the action throws, no fence check and no trashing take place.

Type and semantics are meant to match 'Plain.alloca'
(enforced with 'asTypeOf').
-}
alloca :: (Storable a) => Int -> (Ptr a -> IO b) -> IO b
alloca = flip asTypeOf Plain.alloca $ \size f -> do
   let border = 64
   let fullSize = size + 2*border
   allocaArray fullSize $ \ptrPre -> do
      ptrs@(_ptrPre, ptr, _ptrPost) <- fillAll border size ptrPre
      result <- f ptr
      checkAll border ptrs
      trash fullSize ptrPre
      return result

{- |
Split an allocation of @border + size + border@ elements into
(leading fence, payload, trailing fence) and fill them:
fences with @AB AD CA FE@, payload with @DE AD F0 0D@.
-}
fillAll :: (Storable a) => Int -> Int -> Ptr a -> IO (Ptr a, Ptr a, Ptr a)
fillAll border size ptrPre = do
   let ptr = advancePtr ptrPre border
   let ptrPost = advancePtr ptr size
   fill ptrPre border [0xAB,0xAD,0xCA,0xFE]
   fill ptr size [0xDE,0xAD,0xF0,0x0D]
   fill ptrPost border [0xAB,0xAD,0xCA,0xFE]
   return (ptrPre, ptr, ptrPost)

{- |
Verify that both fences produced by 'fillAll' still contain
the @AB AD CA FE@ pattern; fails with 'error' otherwise.
The payload pointer is ignored.
-}
checkAll :: (Storable a) => Int -> (Ptr a, Ptr a, Ptr a) -> IO ()
checkAll border (ptrPre, _ptr, ptrPost) = do
   check "leading" ptrPre border [0xAB,0xAD,0xCA,0xFE]
   check "trailing" ptrPost border [0xAB,0xAD,0xCA,0xFE]

-- | Overwrite @fullSize@ elements starting at the given pointer
-- with the @DE AD BE EF@ pattern, so use-after-free reads are conspicuous.
trash :: (Storable a) => Int -> Ptr a -> IO ()
trash fullSize ptrPre = fill ptrPre fullSize [0xDE,0xAD,0xBE,0xEF]

-- | Fill the memory of @n@ elements of type @a@ byte-wise
-- with the given byte pattern, repeated cyclically.
{-# INLINE fill #-}
fill :: (Storable a) => Ptr a -> Int -> [Word8] -> IO ()
fill ptr n bytes =
   pokeArray (castPtr ptr) $ take (arraySize ptr n) $ cycle bytes

-- | Compare the memory of @n@ elements of type @a@ byte-wise against
-- the cyclically repeated pattern; fails with 'error' naming the fence
-- and the byte offset of the first mismatch.
{-# INLINE check #-}
check :: (Storable a) => String -> Ptr a -> Int -> [Word8] -> IO ()
check name ptr n bytes =
   for_ (take (arraySize ptr n) $ zip [0..] $ cycle bytes) $ \(i,b) -> do
      a <- peekByteOff ptr i
      when (a/=(b::Word8)) $
         error $ "damaged " ++ name ++ " fence at position " ++ show i

-- | Number of bytes occupied by @n@ elements of the pointer's element type.
-- The element passed to 'sizeOf' is never evaluated.
arraySize :: (Storable a) => Ptr a -> Int -> Int
arraySize ptr n = arraySizeAux ptr n $ error "arraySize: undefined element"

{- |
Correct size computation should also respect padding caused by alignment.
However, 'mallocArray' uses this simple arithmetic.
-}
arraySizeAux :: (Storable a) => Ptr a -> Int -> a -> Int
arraySizeAux _ n a = n * sizeOf a