{-# LANGUAGE ScopedTypeVariables    #-}

-----------------------------------------------------------------------------
-- |
-- Module      :  WinDll.Debug.Alloc
-- Copyright   :  (c) Tamar Christina 2009 - 2010
-- License     :  BSD3
-- 
-- Maintainer  :  tamar@zhox.com
-- Stability   :  experimental
-- Portability :  portable
--
-- These re-exported allocation functions allow us to
-- keep track of memory allocations within our program.
-- The recorded allocations can later be analyzed to find memory leaks.
--
-----------------------------------------------------------------------------

module WinDll.Debug.Alloc
 ( alloca
 , malloc
 , realloc
 , record
 , recordM
 , freeDefault
 , freeUnknown
 , free
 ) where

import WinDll.Debug.Records
import WinDll.Debug.Stack ( Stack, Stackable(..) )

import qualified Foreign.Marshal.Alloc as F
import Foreign ( Storable(..), Ptr)

import Data.Time.Clock ( getCurrentTime, UTCTime )
import Data.Time.Format ( formatTime )
import System.Locale

-- | Re-export of the Foreign.Marshal.Alloc.alloca function taking an explicit stack.
--
--   Runs @fn@ on a temporary buffer sized for a value of type @a@ and, once
--   @fn@ has returned, writes an 'Alloc' record for that same buffer. The
--   record carries the buffer's address, the 'sizeOf' of @a@, the given
--   stack and the current time. The result of @fn@ is returned unchanged.
alloca :: forall a b st. (Storable a, Stackable st) => st -> (Ptr a -> IO b) -> IO b
alloca stack fn = F.alloca $ \ptr -> do
    -- Log the very pointer handed to fn. A pointer obtained from a separate
    -- F.alloca call would already be released by the time it is inspected
    -- and need not coincide with the buffer fn works on.
    val  <- fn ptr
    time <- getCurrentTime
    let -- The address is recovered by parsing the pointer's Show rendering.
        start = read (show ptr)
        size  = sizeOf (undefined :: a)
        mem   = MemAlloc { memFun   = Alloc
                         , memStack = toStack stack
                         , memStart = start
                         , memStop  = Just (start + size)
                         , memSize  = Just size
                         , memTime  = formatTime defaultTimeLocale "%s" time
                         }
    writeMemAlloc mem
    return val
                     
-- | Tracked counterpart of Foreign.Marshal.Alloc.malloc that takes an explicit stack.
--
--   Allocates room for one value of type @a@, writes a 'Malloc' record
--   (address, 'sizeOf' of @a@, the given stack and the current time) and
--   returns the freshly allocated pointer.
malloc :: forall a st. (Storable a, Stackable st) => st -> IO (Ptr a)
malloc stack =
    F.malloc >>= \block ->
    getCurrentTime >>= \now -> do
        let bytes = sizeOf (undefined :: a)
            -- The address is recovered by parsing the pointer's Show rendering.
            base  = read (show block)
        writeMemAlloc $ MemAlloc
            { memFun   = Malloc
            , memStack = toStack stack
            , memStart = base
            , memStop  = Just (base + bytes)
            , memSize  = Just bytes
            , memTime  = formatTime defaultTimeLocale "%s" now
            }
        return block

-- | Re-export of the Foreign.Marshal.Alloc.realloc function taking an explicit stack.
--
--   Resizes the block behind @ptr@ to hold a value of type @b@, writes a
--   'ReAlloc' record for the resulting block and returns the new pointer.
--   The record carries the address of the /returned/ pointer together with
--   the 'sizeOf' of @b@, so start, stop and size all describe the same block.
realloc :: forall a b st. (Storable b, Stackable st) => st -> Ptr a -> IO (Ptr b)
realloc stack ptr = do
    ptr' <- F.realloc ptr
    time <- getCurrentTime
    let -- Use the pointer returned by F.realloc: the block may have been
        -- moved, in which case the old address no longer belongs to us.
        -- The address is recovered by parsing the pointer's Show rendering.
        start = read (show ptr')
        size  = sizeOf (undefined :: b)
        mem   = MemAlloc { memFun   = ReAlloc
                         , memStack = toStack stack
                         , memStart = start
                         , memStop  = Just (start + size)
                         , memSize  = Just size
                         , memTime  = formatTime defaultTimeLocale "%s" time
                         }
    writeMemAlloc mem
    return ptr'
                       
-- | Records allocation information for a pointer whose allocation function
--   we could not override.
--
--   Writes a 'Record' entry holding the pointer's address (parsed from its
--   'Show' rendering), the 'sizeOf' of @a@, the given stack and the current
--   time. The pointer itself is left untouched.
record :: forall a st m. (Storable a, Stackable st, Show (m a)) => st -> m a -> IO ()
record stack ptr = getCurrentTime >>= writeMemAlloc . entryAt
  where
    base  = read (show ptr)
    bytes = sizeOf (undefined :: a)

    -- Build the log entry for the given timestamp.
    entryAt now = MemAlloc
        { memFun   = Record
        , memStack = toStack stack
        , memStart = base
        , memStop  = Just (base + bytes)
        , memSize  = Just bytes
        , memTime  = formatTime defaultTimeLocale "%s" now
        }
         
-- | Mirrors /record/, except that the value it is given need not be a
--   pointer yet: @mk@ is applied to @val@ to produce the pointer, which is
--   then passed to /record/ with the given stack and finally returned.
recordM :: forall a st b. (Storable a, Stackable st) => st -> b -> (b -> IO (Ptr a)) -> IO (Ptr a)
recordM stack val mk =
    mk val >>= \made ->
        record stack made >> return made
                       
-- | Free a pointer that is not a plain Ptr (e.g. a ForeignPtr) using the
--   supplied release action.
--
--   A 'MemFree' record (address parsed from the pointer's 'Show' rendering,
--   'sizeOf' of @a@, the given stack and the current time) is written first;
--   only then is @fn@ run on the pointer.
freeDefault :: forall a st m. (Storable a, Stackable st, Show (m a)) => st -> m a -> (m a -> IO ()) -> IO ()
freeDefault stack ptr fn = do
    now <- getCurrentTime
    writeMemAlloc (released now)
    fn ptr
  where
    bytes = sizeOf (undefined :: a)

    -- Build the free-record for the given timestamp.
    released stamp = MemFree
        { memStack = toStack stack
        , memStart = read (show ptr)
        , memSize  = Just bytes
        , memTime  = formatTime defaultTimeLocale "%s" stamp
        }
                              
-- | Free a pointer that is not a plain Ptr (e.g. a ForeignPtr) without
--   recording the size of the value it points to; the 'MemFree' record
--   carries 'Nothing' as its size, so no Storable constraint is needed.
--
--   The record is written first, then @fn@ is run on the pointer.
freeUnknown :: forall a st m. (Stackable st, Show (m a)) => st -> m a -> (m a -> IO ()) -> IO ()
freeUnknown stack ptr fn =
    getCurrentTime >>= writeMemAlloc . freedAt >> fn ptr
  where
    -- Build the size-less free-record for the given timestamp.
    freedAt stamp = MemFree
        { memStack = toStack stack
        , memStart = read (show ptr)
        , memSize  = Nothing
        , memTime  = formatTime defaultTimeLocale "%s" stamp
        }
                       
-- | Tracked counterpart of Foreign.Marshal.Alloc.free that takes an explicit stack.
--   Delegates to /freeDefault/ with the real @free@ as the release action.
free :: forall a st. (Storable a, Stackable st) => st -> Ptr a -> IO ()
free stack ptr = freeDefault stack ptr release
  where
    release = F.free