{-# LANGUAGE MagicHash       #-}
{-# LANGUAGE RecordWildCards #-}
-- |
-- Module      : Data.Array.Accelerate.LLVM.PTX.Context
-- Copyright   : [2014..2020] The Accelerate Team
-- License     : BSD3
--
-- Maintainer  : Trevor L. McDonell <trevor.mcdonell@gmail.com>
-- Stability   : experimental
-- Portability : non-portable (GHC extensions)
--

module Data.Array.Accelerate.LLVM.PTX.Context (

  Context(..),
  new, raw, withContext,

) where

import Data.Array.Accelerate.Lifetime
import Data.Array.Accelerate.LLVM.PTX.Analysis.Device
import qualified Data.Array.Accelerate.LLVM.PTX.Debug           as Debug

import qualified Foreign.CUDA.Driver.Device                     as CUDA
import qualified Foreign.CUDA.Driver.Context                    as CUDA

import Control.Concurrent
import Control.Exception
import Control.Monad
import Data.Hashable
import Text.PrettyPrint
import Prelude                                                  hiding ( (<>) )

import GHC.Base                                                 ( Int(..), addr2Int#, )
import GHC.Ptr                                                  ( Ptr(..) )


-- | An execution context: a CUDA execution context bound to one particular
-- device, paired with that device's properties.
--
data Context = Context
  { deviceProperties :: {-# UNPACK #-} !CUDA.DeviceProperties
    -- ^ information on the device's hardware resources
  , deviceContext    :: {-# UNPACK #-} !(Lifetime CUDA.Context)
    -- ^ the device execution context, wrapped in a 'Lifetime' so that a
    -- finalizer can be attached to it (see 'raw')
  }

-- | Two contexts are equal exactly when their managed CUDA contexts
-- ('deviceContext') are equal; the device properties are not compared.
instance Eq Context where
  Context _ lhs == Context _ rhs = lhs == rhs

-- | Hash a context by the address of its underlying raw CUDA context handle.
instance Hashable Context where
  hashWithSalt salt ctx = hashWithSalt salt (addressOf cuCtx)
    where
      -- The raw handle, read without touching the lifetime's finalizer state.
      cuCtx = CUDA.useContext (unsafeGetValue (deviceContext ctx))

      -- Reinterpret a pointer's address as a machine 'Int'.
      addressOf :: Ptr a -> Int
      addressOf (Ptr a#) = I# (addr2Int# a#)


-- | Create a new CUDA execution context on the given device.
--
-- 'CUDA.create' leaves the new context pushed on the calling OS thread's
-- context stack, and this function pops it again before returning. Both
-- calls therefore have to happen on the same OS thread. When the RTS supports
-- bound threads, the whole sequence runs under 'runInBoundThread' (as
-- 'withContext' does) so the Haskell thread cannot migrate between them.
--
-- If 'raw' fails, the context is still popped before the exception
-- propagates, so the thread's context stack is left as it was.
--
-- On success the returned context is no longer current on the creating
-- thread; use 'withContext' to run actions in it.
--
new :: CUDA.Device
    -> CUDA.DeviceProperties
    -> [CUDA.ContextFlag]
    -> IO Context
new dev prp flags = inBoundThread $ do
  cuCtx <- CUDA.create dev flags
  -- 'raw' configures the cache of the *current* context, so it must run
  -- before the pop. Pop even when it throws, so no stale context stays
  -- pushed on this thread.
  ctx   <- raw dev prp cuCtx `onException` CUDA.pop
  _     <- CUDA.pop
  return ctx
  where
    -- The non-threaded RTS has a single OS thread, so no migration can
    -- happen there, and 'runInBoundThread' would throw; run directly instead.
    inBoundThread act
      | rtsSupportsBoundThreads = runInBoundThread act
      | otherwise               = act

-- | Wrap an existing raw CUDA execution context.
--
-- Ownership of the raw context passes to the result: a finalizer is attached
-- that calls 'CUDA.destroy' on it once the returned 'Context' is finalised.
--
-- The function also:
--
-- * requests the PreferL1 cache configuration on compute capability 2.0+
--   devices ('CUDA.setCache' is not passed the context, so presumably it
--   applies to whichever context is current on the calling thread — TODO
--   confirm the caller has made this context current);
-- * traces a device summary under the @dump_phases@ debug flag.
--
raw :: CUDA.Device
    -> CUDA.DeviceProperties
    -> CUDA.Context
    -> IO Context
raw dev prp ctx = do
  lifetime <- newLifetime ctx
  addFinalizer lifetime finalise
  preferL1
  Debug.traceIO Debug.dump_phases (deviceInfo dev prp)
  return $! Context prp lifetime
  where
    -- Destroy the raw context once its lifetime ends.
    finalise = do
      message ("finalise context " ++ showContext ctx)
      CUDA.destroy ctx

    -- The kernels don't use much shared memory, so for devices that support
    -- it prefer using those memory banks as an L1 cache.
    --
    -- TLM: Is this a good idea? For example, external libraries such as
    -- cuBLAS rely heavily on shared memory and thus this could adversely
    -- affect performance. Perhaps we should use 'setCacheConfigFun' for
    -- individual functions which might benefit from this.
    preferL1 =
      if CUDA.computeCapability prp >= CUDA.Compute 2 0
        then CUDA.setCache CUDA.PreferL1
        else return ()


-- | Run an action with this context pushed onto the calling OS thread's stack
-- of current CUDA contexts.
--
-- The context is popped again afterwards, even if the action throws
-- ('bracket_'). Everything runs in a bound thread ('runInBoundThread'), so
-- the push and the pop happen on the same OS thread.
--
-- The raw context is accessed through 'withLifetime', which presumably keeps
-- it from being finalised while the action runs.
--
{-# INLINE withContext #-}
withContext :: Context -> IO a -> IO a
withContext (Context _ lifetime) action =
  runInBoundThread . withLifetime lifetime $ \cuCtx ->
    bracket_ (push cuCtx) pop action

-- | Log, then make the given raw context current on the calling OS thread by
-- pushing it onto that thread's context stack.
{-# INLINE push #-}
push :: CUDA.Context -> IO ()
push cuCtx =
  message ("push context: " ++ showContext cuCtx) >> CUDA.push cuCtx

-- | Pop the top context off the calling OS thread's context stack, then log
-- which context was removed.
{-# INLINE pop #-}
pop :: IO ()
pop =
  CUDA.pop >>= \popped -> message ("pop context: " ++ showContext popped)


-- Debugging
-- ---------

-- | Render a one-line summary of the selected CUDA device, in the form
--
-- > device <id>: <name> (compute capability <cc>), <n> multiprocessors @ <clock> (<m> cores), <mem> global memory
--
-- The clock and memory figures are formatted by 'Debug.showFFloatSIBase':
-- the clock uses base 1000 with the unit "Hz", and the memory uses base 1024
-- with the unit "B".
--
deviceInfo :: CUDA.Device -> CUDA.DeviceProperties -> String
deviceInfo dev prp = render $
  devID <> colon <+> name <+> parens compute
        <> comma <+> processors <+> at <+> text clock <+> parens cores
        <> comma <+> memory
  where
    name :: Doc
    name        = text (CUDA.deviceName prp)

    compute :: Doc
    compute     = text "compute capability" <+> text (show $ CUDA.computeCapability prp)

    devID :: Doc
    devID       = text "device" <+> int (fromIntegral $ CUDA.useDevice dev)

    processors :: Doc
    processors  = int (CUDA.multiProcessorCount prp)                              <+> text "multiprocessors"

    cores :: Doc
    cores       = int (CUDA.multiProcessorCount prp * coresPerMultiProcessor prp) <+> text "cores"

    memory :: Doc
    memory      = text mem <+> text "global memory"

    -- 'clockRate' is scaled by 1000 to give Hz (CUDA reports it in kHz).
    -- Convert to 'Double' *before* scaling: doing the multiplication in 'Int'
    -- overflows a 32-bit 'Int' for clocks of 2.15 GHz and above.
    clock :: String
    clock       = Debug.showFFloatSIBase (Just 2) 1000 (fromIntegral (CUDA.clockRate prp) * 1000 :: Double) "Hz"

    mem :: String
    mem         = Debug.showFFloatSIBase (Just 0) 1024 (fromIntegral (CUDA.totalGlobalMem prp) :: Double) "B"

    at :: Doc
    at          = char '@'


-- | Emit a message, prefixed with @"gc: "@, under the @dump_gc@ debug flag,
-- then continue with the given action.
{-# INLINE trace #-}
trace :: String -> IO a -> IO a
trace msg next = Debug.traceIO Debug.dump_gc ("gc: " ++ msg) >> next

-- | Emit a debug message via 'trace' and do nothing else.
{-# INLINE message #-}
message :: String -> IO ()
message msg = trace msg (pure ())

-- | Show a raw CUDA context as the address of its underlying handle.
{-# INLINE showContext #-}
showContext :: CUDA.Context -> String
showContext = show . CUDA.useContext