{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE ForeignFunctionInterface #-}

-- | Like "Data.TLS.PThread", but this also exports internal functionality
-- not exposed in the public interface.
--
-- There are no API guarantees whatsoever for this module, so use it
-- with caution.
module Data.TLS.PThread.Internal where

import Control.Monad
import Control.Exception
import Data.IORef
import Foreign.Ptr
import Foreign.StablePtr
import Foreign.Storable(Storable(sizeOf))

#if !(MIN_VERSION_base(4,8,0))
import Data.Word (Word)
#endif

#include "../TLS_Sig.hs"
--------------------------------------------------------------------------------

-- | Haskell-side representation of a pthread key, as passed to and returned
-- from the foreign imports in this module.  It is a plain alias for @Word@;
-- @check_error@ verifies that the size reported by @get_pthread_key_size@
-- does not exceed @sizeOf (0::Word)@.
type Key = Word

-- | Size value reported by a C helper, imported as a pure @Int@ (no @IO@).
-- @check_error@ compares it against @sizeOf (0::Word)@, and the error
-- message there describes it as the size of @pthread_key_t@.
-- NOTE(review): presumably this is @sizeof(pthread_key_t)@ in bytes, defined
-- in C stubs that ship with the package -- confirm against the C source.
foreign import ccall unsafe
   get_pthread_key_size :: Int

-- | Raw binding to the C function @pthread_key_create@.  Nothing in the
-- visible part of this module calls it; @mkTLS@ obtains its key through
-- @easy_make_pthread_key@ instead.
-- NOTE(review): in POSIX the second parameter is a destructor function
-- pointer; it is typed here as @Ptr ()@, so presumably callers are expected
-- to pass @nullPtr@ -- confirm before using.
foreign import ccall unsafe
   pthread_key_create :: Ptr Key -> Ptr () -> IO Int

-- | C helper that yields a fresh 'Key' directly as its result; @mkTLS@ uses
-- it to allocate the key for each new TLS variable.
-- NOTE(review): this is not a standard pthread function, so it presumably
-- lives in C stubs shipped with the package and wraps @pthread_key_create@.
-- How it reports a failed key creation is not visible here -- confirm.
foreign import ccall unsafe
   easy_make_pthread_key :: IO Key

-- | Raw binding to the C function @pthread_getspecific@, with the result
-- typed as a @StablePtr a@.  @getTLS@ treats a result whose underlying
-- pointer equals @nullPtr@ as "no value stored yet" and only dereferences
-- non-null results.
-- The type variable @a@ is unconstrained here, so the caller is responsible
-- for reading the slot at the same type it was written with.
foreign import ccall unsafe
   pthread_getspecific :: Key -> IO (StablePtr a)

-- | Raw binding to the C function @pthread_setspecific@.  The result is the
-- C return code; the @setspecific@ wrapper in this module treats any
-- non-zero value as a failure.
-- NOTE(review): POSIX declares the return type as C @int@, while this
-- binding uses Haskell @Int@ -- confirm this is intended (the matching
-- FFI type would be @CInt@ from "Foreign.C.Types").
foreign import ccall unsafe
   pthread_setspecific :: Key -> StablePtr a -> IO Int

-- | Raw binding to the C function @pthread_key_delete@.  The result is the
-- C return code; the @delete@ wrapper in this module treats any non-zero
-- value as a failure.
-- NOTE(review): POSIX declares the return type as C @int@, while this
-- binding uses Haskell @Int@ -- confirm this is intended (the matching
-- FFI type would be @CInt@ from "Foreign.C.Types").
foreign import ccall unsafe
   pthread_key_delete :: Key -> IO Int

-- | Sanity check on the key representation.  Evaluates to @()@ when the size
-- reported by @get_pthread_key_size@ fits in a @Word@, and calls @error@
-- otherwise.  Being a pure value, the check only runs when something forces
-- it (@mkTLS@ does so with @evaluate@).
check_error :: ()
check_error
  -- The comparison is deliberately "fits in a Word" rather than exact
  -- equality (an earlier revision compared with ==), even though the
  -- message below still says "word-sized".
  | keyBytes <= wordBytes = ()
  | otherwise =
      error ("Data.TLS.PThread: internal invariant broken!  Expected pthread_key_t to be word-sized!\n"
             ++ "Instead it was: " ++ show keyBytes)
  where
    keyBytes  = get_pthread_key_size
    wordBytes = sizeOf (0 :: Word)


{-# INLINE setspecific #-}
-- | Checked wrapper around @pthread_setspecific@: stores the stable pointer
-- @p@ under key @k@ and calls @error@, quoting the return code, if the C
-- call reports anything other than zero.
setspecific :: Key -> StablePtr a -> IO ()
setspecific k p =
    pthread_setspecific k p >>= \rc ->
      when (rc /= 0) $
        error ("pthread_setspecific returned error code: " ++ show rc)

{-# INLINE delete #-}
-- | Checked wrapper around @pthread_key_delete@: deletes key @k@ and calls
-- @error@, quoting the return code, if the C call reports anything other
-- than zero.
delete :: Key -> IO ()
delete k =
    pthread_key_delete k >>= \rc ->
      when (rc /= 0) $
        error ("pthread_key_delete returned error code: " ++ show rc)


--------------------------------------------------------------------------------

-- | A thread-local variable of type @a@.
--
-- Fields:
--
--   * @key@: the pthread key under which each per-thread copy is stored;
--     @mkTLS@ obtains it from @easy_make_pthread_key@.
--
--   * @mknew@: the action @getTLS@ runs to build a copy when
--     @pthread_getspecific@ returns a null pointer for @key@.
--
--   * @allCopies@: every @StablePtr@ that @getTLS@ has created for this
--     variable, newest first; @allTLS@ and @freeAllTLS@ read it.
--
-- All three fields are strict, and @key@ and @allCopies@ are unpacked.
data TLS a = TLS { key       :: {-# UNPACK #-} !Key
                 , mknew     :: !(IO a)
                 , allCopies :: {-# UNPACK #-} !(IORef [StablePtr a]) }

-- Build a fresh TLS variable whose per-thread copies are produced by `new`.
--
-- Steps: force the key-size sanity check, allocate a key through the C
-- helper, and start with an empty list of copies.  `new` is only stored
-- (and forced to weak head normal form by the strict field); it is not run
-- here.
-- NOTE(review): no type signature appears in this file; presumably it is
-- supplied by the included TLS_Sig.hs -- confirm.
mkTLS new = do
  evaluate check_error
  freshKey <- easy_make_pthread_key
  copies   <- newIORef []
  return $! TLS { key = freshKey, mknew = new, allCopies = copies }

-- Fetch the value stored under this variable's key, creating it on first
-- use.
--
-- If pthread_getspecific yields a non-null pointer, that pointer is
-- dereferenced and returned.  Otherwise the initializer is run, the result
-- is wrapped in a StablePtr, stored under the key, consed onto the
-- allCopies list, and returned.
-- NOTE(review): the slot consulted belongs to whichever OS thread executes
-- the foreign calls; presumably callers are expected to stay on one OS
-- thread (e.g. bound threads) -- confirm against the public documentation.
getTLS TLS{key,mknew,allCopies} =
  pthread_getspecific key >>= \existing ->
    if castStablePtrToPtr existing /= nullPtr
      then deRefStablePtr existing
      else do
        fresh <- mknew
        slot  <- newStablePtr fresh
        setspecific key slot
        atomicModifyIORef' allCopies $ \copies -> (slot : copies, ())
        return fresh

-- Return every copy created so far for this variable.
--
-- Takes a snapshot of the allCopies list and dereferences each StablePtr in
-- list order (getTLS conses new entries onto the front, so the most
-- recently created copy comes first).
allTLS TLS{allCopies} =
    mapM deRefStablePtr =<< readIORef allCopies

-- Run `fn` on every copy returned by allTLS for this variable, in the order
-- allTLS returns them, discarding the results.  All copies are collected
-- first; `fn` is applied only afterwards.
forEachTLS_ tls fn =
  allTLS tls >>= mapM_ fn

-- Release everything held by this TLS variable: delete the pthread key and
-- free every StablePtr recorded in allCopies.
--
-- The key is deleted first; if that fails, `delete` raises an error and the
-- recorded pointers are left untouched, as before.
--
-- The recorded pointers are then taken out of allCopies atomically, leaving
-- an empty list behind.  Previously the list was only read, so it kept
-- pointing at freed StablePtrs: a later allTLS / forEachTLS_ would
-- dereference them, and a repeated freeAllTLS would free them a second
-- time.  Both are undefined behaviour for StablePtr.
--
-- The variable must still not be used with getTLS after this call, since
-- its key has been deleted.
freeAllTLS TLS{key,allCopies} = do
    delete key
    ls <- atomicModifyIORef' allCopies (\l -> ([], l))
    mapM_ freeStablePtr ls