{-# LANGUAGE CPP, MagicHash, UnboxedTuples #-}
  This module assumes that you are familiar with the functions
  of the same name in Foreign.Marshal.Alloc, Foreign.Marshal.Array
  and in Foreign.ForeignPtr.

  The functions are generalized to work in any monad which is an
  instance of MonadCatchIO.
module Control.Monad.CatchIO.Foreign
    ( alloca
    , allocaBytes
    , allocaArray
    , allocaArray0
    , withForeignPtr

-- This is a bit dodgy: we test for the MIN_VERSION_transformers macro to
-- see whether we are being compiled against transformers or mtl.
#ifdef MIN_VERSION_transformers
import Control.Monad.IO.Class(liftIO)
import Control.Monad.Trans (liftIO)

import Control.Monad.CatchIO

import qualified Foreign as F
import Foreign (Ptr,ForeignPtr, sizeOf, alignment, Storable)

import qualified Data.Primitive as P
import GHC.Exts
import GHC.IOBase hiding (liftIO)

-- | Run an action with a pointer to temporary storage for a single value of
-- a 'Storable' type.
--
-- The element type is determined by the pointer type the action expects;
-- the storage is obtained from 'allocaBytesAligned' using that type's
-- 'F.sizeOf' and 'F.alignment'.
alloca :: (F.Storable a, MonadCatchIO m) => (Ptr a -> m b) -> m b
alloca = doAlloca undefined
  where
    -- The dummy value only carries the element type to 'F.sizeOf' and
    -- 'F.alignment'; 'undefined' is passed on the expectation that Storable
    -- instances do not inspect this argument (as "Foreign.Storable" documents).
    doAlloca :: (MonadCatchIO m', Storable a') => a' -> (Ptr a' -> m' b') -> m' b'
    doAlloca dummy = allocaBytesAligned (F.sizeOf dummy) (F.alignment dummy)

-- | Run an action with a pointer to a freshly malloc'ed buffer of the given
-- size in bytes. The buffer is acquired and released through 'bracket', with
-- 'F.free' as the release step.
allocaBytes :: (MonadCatchIO m) => Int -> (Ptr a -> m b) -> m b
allocaBytes size action = bracket acquire release action
  where
    -- Obtain the raw buffer in the underlying IO monad.
    acquire = liftIO (F.mallocBytes size)
    -- Hand the buffer back to the C allocator.
    release buffer = liftIO (F.free buffer)

-- | Run an action with a pointer to a buffer of the given size (first
-- argument, in bytes) and alignment (second argument).
--
-- The buffer is an aligned, pinned byte array from "Data.Primitive"; the
-- action receives the address of its contents.
allocaBytesAligned :: (MonadCatchIO m) => Int -> Int -> (Ptr a -> m b) -> m b
allocaBytesAligned nbytes align action = do
  -- Allocate the pinned array and freeze it so its address can be taken.
  pinned <- liftIO $ do
    fresh <- P.newAlignedPinnedByteArray nbytes align
    P.unsafeFreezeByteArray fresh
  let ptr = case P.byteArrayContents pinned of
              P.Addr addr# -> Ptr addr#
  result <- action ptr
  -- Touch the array after the action has run; presumably this keeps the
  -- buffer reachable while the raw pointer is in use.
  liftIO (touch pinned)
  return result

-- | Lift the 'touch#' primop into 'IO': applies 'touch#' to the given value
-- and the current state token, and returns unit.
touch :: a -> IO ()
touch val = IO (\s0 ->
  case touch# val s0 of
    s1 -> (# s1, () #))

-- Fallback: the requested alignment is ignored and the request is forwarded
-- to 'allocaBytes'. This is wrong, but the FFI doesn't offer anything else.
-- NOTE(review): this equation conflicts with the primitive-based definition
-- of the same name unless the two are selected by a preprocessor
-- conditional -- confirm the guard is present.
allocaBytesAligned size _alignment = allocaBytes size

-- | Run an action with a pointer to temporary storage for the given number
-- of elements of a 'Storable' type.
--
-- The element type is determined by the pointer type the action expects;
-- the storage is obtained from 'allocaBytesAligned' with a byte count of
-- @size * sizeOf element@ and the element's 'alignment'.
allocaArray :: (F.Storable a, MonadCatchIO m) => Int -> (Ptr a -> m b) -> m b
allocaArray = doAlloca undefined
  where
    -- The dummy value only carries the element type to 'sizeOf' and
    -- 'alignment'; 'undefined' is passed on the expectation that Storable
    -- instances do not inspect this argument (as "Foreign.Storable" documents).
    doAlloca :: (Storable a', MonadCatchIO m') => a' -> Int -> (Ptr a' -> m' b') -> m' b'
    doAlloca dummy size = allocaBytesAligned (size * sizeOf dummy) (alignment dummy)

-- | Like 'allocaArray', but reserves room for one extra element beyond the
-- requested count.
allocaArray0 :: (F.Storable a, MonadCatchIO m) => Int -> (Ptr a -> m b) -> m b
allocaArray0 = allocaArray . (+ 1)

-- | Run an action on the raw pointer underlying a 'ForeignPtr'.
--
-- The pointer is obtained with 'F.unsafeForeignPtrToPtr', and the foreign
-- pointer is touched via 'F.touchForeignPtr' once the action has produced
-- its result, which is then returned.
withForeignPtr :: (MonadCatchIO m) => ForeignPtr a -> (Ptr a -> m b) -> m b
withForeignPtr fptr action =
  action (F.unsafeForeignPtrToPtr fptr) >>= \result ->
    liftIO (F.touchForeignPtr fptr) >> return result