{-# LANGUAGE CPP                 #-}
{-# LANGUAGE ScopedTypeVariables #-}
-- |
-- Module      : Data.Array.Accelerate.LLVM.PTX.State
-- Copyright   : [2014..2020] The Accelerate Team
-- License     : BSD3
--
-- Maintainer  : Trevor L. McDonell <trevor.mcdonell@gmail.com>
-- Stability   : experimental
-- Portability : non-portable (GHC extensions)
--

module Data.Array.Accelerate.LLVM.PTX.State (

  evalPTX,
  createTargetForDevice, createTargetFromContext,

  Pool(..),
  withPool, unsafeWithPool,
  defaultTarget,
  defaultTargetPool,

) where

import Data.Array.Accelerate.Error

import Data.Array.Accelerate.LLVM.State
import Data.Array.Accelerate.LLVM.PTX.Target
import qualified Data.Array.Accelerate.LLVM.PTX.Array.Table         as MT
import qualified Data.Array.Accelerate.LLVM.PTX.Context             as CT
import qualified Data.Array.Accelerate.LLVM.PTX.Debug               as Debug
import qualified Data.Array.Accelerate.LLVM.PTX.Execute.Stream      as ST
import qualified Data.Array.Accelerate.LLVM.PTX.Link.Cache          as LC
import qualified Data.Array.Accelerate.LLVM.PTX.Pool                as Pool

import Control.Exception                                            ( try, catch )
import Data.Maybe                                                   ( fromMaybe, catMaybes )
import System.Environment                                           ( lookupEnv )
import System.IO.Unsafe                                             ( unsafePerformIO, unsafeInterleaveIO )
import Text.Printf                                                  ( printf )
import Text.Read                                                    ( readMaybe )
import Foreign.CUDA.Driver.Error
import qualified Foreign.CUDA.Driver                                as CUDA
import qualified Foreign.CUDA.Driver.Context                        as Context


-- | Execute a PTX computation.
--
-- The computation is run with 'evalLLVM' inside 'CT.withContext', using the
-- context stored in the target ('ptxContext'). Any 'CUDAException' raised
-- while doing so is caught and re-raised via 'internalError', with the shown
-- exception as the message.
--
evalPTX :: PTX -> LLVM PTX a -> IO a
evalPTX ptx acc =
  CT.withContext (ptxContext ptx) (evalLLVM ptx acc)
  `catch`
  \e -> internalError (show (e :: CUDAException))


-- | Create a new PTX execution target for the given device.
--
-- A new CUDA context is created on the device with the supplied flags and
-- wrapped into a 'PTX' target by 'createTarget'. 'CUDA.pop' is called before
-- the target is returned.
--
createTargetForDevice
    :: CUDA.Device              -- ^ device to create the context on
    -> CUDA.DeviceProperties    -- ^ properties of that device
    -> [CUDA.ContextFlag]       -- ^ flags passed to 'CUDA.create'
    -> IO PTX
createTargetForDevice dev prp flags = do
  raw <- CUDA.create dev flags
  -- NOTE(review): if 'createTarget' throws, the 'CUDA.pop' below is skipped
  -- and the freshly created context is neither popped nor destroyed here --
  -- confirm whether callers rely on this.
  ptx <- createTarget dev prp raw
  _   <- CUDA.pop
  return ptx


-- | Create a PTX execution target for the given device context.
--
-- The device is looked up with 'Context.device' and its properties with
-- 'CUDA.props'; both are then passed, together with the context, to
-- 'createTarget'.
--
-- NOTE(review): 'Context.device' takes no context argument, so presumably it
-- reports the device of whichever context is current on the calling thread
-- rather than of @raw@ -- confirm that callers make @raw@ current first.
--
createTargetFromContext
    :: CUDA.Context
    -> IO PTX
createTargetFromContext raw = do
  dev <- Context.device
  prp <- CUDA.props dev
  createTarget dev prp raw


-- | Create a PTX execution target.
--
-- Wraps the raw CUDA context with 'CT.raw', then allocates the per-target
-- state: a memory table ('MT.new') and a stream reservoir ('ST.new') for that
-- context, plus a new link cache ('LC.new'). The resulting 'PTX' value is
-- forced to weak head normal form before being returned.
--
createTarget
    :: CUDA.Device
    -> CUDA.DeviceProperties
    -> CUDA.Context
    -> IO PTX
createTarget dev prp raw = do
  ctx <- CT.raw dev prp raw
  mt  <- MT.new ctx
  lc  <- LC.new
  st  <- ST.new ctx
  return $! PTX ctx mt lc st


-- Shared execution contexts
-- -------------------------

-- In order to implement runN, we need to keep track of all available contexts,
-- as well as the managed resource pool.
--
data Pool a = Pool
    { managed   :: {-# UNPACK #-} !(Pool.Pool a)
      -- ^ The managed resource pool; this is what 'withPool' and
      -- 'unsafeWithPool' draw from.
    , unmanaged :: [a]
      -- ^ The same resources as a plain list; 'defaultTarget' takes the first
      -- element.
    }

-- Evaluate an IO action given an execution context taken from the managed
-- part of the supplied pool (delegates to 'Pool.with').
--
withPool :: Pool a -> (a -> IO b) -> IO b
withPool p = Pool.with (managed p)

-- Pure counterpart of 'withPool': apply a function to an execution context
-- taken from the managed part of the supplied pool (delegates to
-- 'Pool.unsafeWith').
--
unsafeWithPool :: Pool a -> (a -> b) -> b
unsafeWithPool p = Pool.unsafeWith (managed p)


-- Top-level mutable state
-- -----------------------
--
-- It is important to keep some information alive for the entire run of the
-- program, not just a single execution. These tokens use 'unsafePerformIO' to
-- ensure they are executed only once, and reused for subsequent invocations.
--

-- | Select a device from the default pool.
--
-- This is the first entry of the unmanaged device list of
-- 'defaultTargetPool'. That pool refuses to initialise with no devices, so
-- the empty case below is not expected to be reachable; it is matched
-- explicitly so that a violation reports a meaningful message rather than a
-- bare @head@ failure.
--
{-# NOINLINE defaultTarget #-}
defaultTarget :: PTX
defaultTarget =
  case unmanaged defaultTargetPool of
    ptx : _ -> ptx
    []      -> error "defaultTarget: the default PTX pool contains no devices"

-- | Create a shared resource pool of the available CUDA devices.
--
-- This globally shared resource pool is initialised the first time it is
-- demanded and then reused for the rest of the program. It will consist of
-- every currently available device, or those specified by the value of the
-- environment variable @ACCELERATE_LLVM_PTX_DEVICES@. The variable is parsed
-- as a Haskell list of device ordinals (for example @[0,2]@); if it is unset
-- or does not parse, all devices are used.
--
-- Devices which fail to initialise are skipped. If no device remains, this
-- calls 'error'.
--
{-# NOINLINE defaultTargetPool #-}
defaultTargetPool :: Pool PTX
defaultTargetPool = unsafePerformIO $! do
  Debug.traceIO Debug.dump_gc "gc: initialise default PTX pool"
  CUDA.initialise []

  -- Figure out which GPUs we should put into the execution pool
  --
  ngpu  <- CUDA.count
  menv  <- (readMaybe =<<) <$> lookupEnv "ACCELERATE_LLVM_PTX_DEVICES"

  let ids :: [Int]
      ids = fromMaybe [0..ngpu-1] menv

      -- Spin up the GPU at the given ordinal.
      --
      -- The device lookup and property query are performed inside the 'try'
      -- as well, so that an invalid ordinal (e.g. from the environment
      -- variable) is reported and skipped like any other initialisation
      -- failure, instead of surfacing as an exception when the lazily
      -- deferred action is eventually forced.
      --
      boot :: Int -> IO (Maybe PTX)
      boot i = unsafeInterleaveIO $ do
        r <- try $ do
               dev <- CUDA.device i
               prp <- CUDA.props dev
               createTargetForDevice dev prp [CUDA.SchedAuto]
        case r of
          Right ptx               -> return (Just ptx)
          Left (e::CUDAException) -> do
            Debug.traceIO Debug.dump_gc (printf "gc: failed to initialise device %d: %s" i (show e))
            return Nothing

  -- Create the pool from the available devices, which get spun-up lazily as
  -- required (due to the implementation of the Pool, we will look ahead by one
  -- each time one device is requested).
  --
  devices <- catMaybes <$> mapM boot ids
  if null devices
    then error "No CUDA-capable devices are available"
    else Pool <$> Pool.create devices
              <*> return devices