{-# LANGUAGE MagicHash #-}
module Data.Array.Accelerate.Numeric.LinearAlgebra.LLVM.PTX.Context (
withBLAS
) where
import Data.Array.Accelerate.Lifetime
import Data.Array.Accelerate.LLVM.PTX
import Data.Array.Accelerate.LLVM.PTX.Foreign
import Data.Array.Accelerate.Numeric.LinearAlgebra.LLVM.PTX.Base
import Control.Monad.State
import Control.Concurrent.MVar
import Data.IntMap.Strict ( IntMap )
import System.IO.Unsafe
import qualified Data.IntMap.Strict as IM
import qualified Foreign.CUDA.Driver.Context as CUDA
import qualified Foreign.CUDA.BLAS as BLAS
import GHC.Ptr
import GHC.Base
import Prelude hiding ( lookup )
-- | Run an action with the cuBLAS handle associated with the current device
-- context.
--
-- Handles are cached in 'handles', keyed by the current CUDA context, and
-- reused across calls. On the first call for a given context a new handle is
-- created, its atomics mode is set to 'BLAS.Allowed', and two finalizers are
-- registered:
--
--   * on the handle's lifetime: destroy the cuBLAS handle;
--   * on the device context's lifetime: remove that context's cache entry.
--
-- The destroy finalizer is attached before the handle is configured. If
-- configuration throws, the freshly created handle still has a cleanup path
-- instead of being leaked. ('modifyMVar' restores the old map on exceptions,
-- so nothing else would reference the handle.)
withBLAS :: (BLAS.Handle -> LLVM PTX b) -> LLVM PTX b
withBLAS k = do
  lc <- gets (deviceContext . ptxContext)
  lh <- liftIO $
    withLifetime lc $ \ctx ->
      modifyMVar handles $ \im -> do
        let key = toKey ctx
        case IM.lookup key im of
          -- cache hit: reuse the handle already created for this context
          Just existing -> return (im, existing)
          -- cache miss: create, configure and register a new handle
          Nothing -> do
            hdl <- BLAS.create
            l   <- newLifetime hdl
            -- register cleanup before anything else that may throw, so a
            -- failure in setAtomicsMode does not leak the new handle
            addFinalizer l $ BLAS.destroy hdl
            BLAS.setAtomicsMode hdl BLAS.Allowed
            -- drop the cache entry when the context's lifetime is finalized
            addFinalizer lc $ modifyMVar_ handles (return . IM.delete key)
            return (IM.insert key l im, l)
  withLifetime' lh k
-- | Derive the 'handles' map key for a CUDA context from the numeric address
-- of the context pointer.
--
-- The offset of the pointer from 'nullPtr' is exactly the pointer's address,
-- viewed as an 'Int'.
toKey :: CUDA.Context -> IM.Key
toKey (CUDA.Context ctxPtr) = ctxPtr `minusPtr` nullPtr
-- | Process-wide cache of cuBLAS handle lifetimes, keyed by the address of the
-- CUDA context each handle was created for (see 'toKey'). It starts empty.
--
-- This is a top-level mutable global built with 'unsafePerformIO'. The
-- NOINLINE pragma is required: without it GHC may inline or duplicate the
-- definition, creating several independent MVars and splitting the cache.
handles :: MVar (IntMap (Lifetime BLAS.Handle))
handles = unsafePerformIO (newMVar IM.empty)
{-# NOINLINE handles #-}