{-# LANGUAGE CPP               #-}
{-# LANGUAGE EmptyDataDecls    #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TypeApplications  #-}
{-# LANGUAGE TypeFamilies      #-}
-- |
-- Module      : Data.Array.Accelerate.LLVM.PTX.Target
-- Copyright   : [2014..2020] The Accelerate Team
-- License     : BSD3
--
-- Maintainer  : Trevor L. McDonell <trevor.mcdonell@gmail.com>
-- Stability   : experimental
-- Portability : non-portable (GHC extensions)
--

module Data.Array.Accelerate.LLVM.PTX.Target (

  module Data.Array.Accelerate.LLVM.Target,
  module Data.Array.Accelerate.LLVM.PTX.Target,

) where

-- llvm-hs
import LLVM.AST.AddrSpace
import LLVM.AST.DataLayout
import LLVM.Target                                                  hiding ( Target )
import qualified LLVM.Target                                        as LLVM
import qualified LLVM.Relocation                                    as R
import qualified LLVM.CodeModel                                     as CM
import qualified LLVM.CodeGenOpt                                    as CGO

-- accelerate
import Data.Array.Accelerate.Error

import Data.Array.Accelerate.LLVM.Extra
import Data.Array.Accelerate.LLVM.Target

import Data.Array.Accelerate.LLVM.PTX.Array.Table                   ( MemoryTable )
import Data.Array.Accelerate.LLVM.PTX.Context                       ( Context, deviceProperties )
import Data.Array.Accelerate.LLVM.PTX.Execute.Stream.Reservoir      ( Reservoir )
import Data.Array.Accelerate.LLVM.PTX.Link.Cache                    ( KernelTable )

-- CUDA
import Foreign.CUDA.Analysis.Device

-- standard library
import Data.ByteString                                              ( ByteString )
import Data.ByteString.Short                                        ( ShortByteString )
import Data.String
import Debug.Trace
import System.IO.Unsafe
import Text.Printf
import qualified Data.Map                                           as Map
import qualified Data.Set                                           as Set


-- | Execution state for the PTX backend, which targets NVIDIA GPUs.
--
-- Everything held in this record is specific to a single CUDA execution
-- context: the device memory and execution streams tracked here are
-- implicitly tied to that context.
--
-- Do not add anything here that is independent of the context. For example,
-- state related to [persistent] kernel caching should _not_ go here.
--
data PTX = PTX
  { ptxContext         :: {-# UNPACK #-} !Context
    -- ^ execution context that the remaining fields belong to
  , ptxMemoryTable     :: {-# UNPACK #-} !MemoryTable
    -- ^ memory table (see "Data.Array.Accelerate.LLVM.PTX.Array.Table")
  , ptxKernelTable     :: {-# UNPACK #-} !KernelTable
    -- ^ kernel table (see "Data.Array.Accelerate.LLVM.PTX.Link.Cache")
  , ptxStreamReservoir :: {-# UNPACK #-} !Reservoir
    -- ^ stream reservoir (see "Data.Array.Accelerate.LLVM.PTX.Execute.Stream.Reservoir")
  }

-- The PTX backend always advertises the NVPTX triple. The data layout is only
-- supplied when not compiling via NVVM.
instance Target PTX where
  targetTriple = Just ptxTargetTriple

#if ACCELERATE_USE_NVVM
  -- see note: [NVVM and target data layout]
  targetDataLayout = Nothing
#else
  targetDataLayout = Just ptxDataLayout
#endif


-- | Look up the properties of the device that the given PTX execution state
-- is running on. The properties are read from the state's execution context.
--
ptxDeviceProperties :: PTX -> DeviceProperties
ptxDeviceProperties ptx = deviceProperties (ptxContext ptx)


-- | The data layout properties that may be consulted during optimisation.
-- CUDA supports the following two layouts:
--
-- 32-bit:
--   e-p:32:32:32-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-f32:32:32-f64:64:64-v16:16:16-v32:32:32-v64:64:64-v128:128:128-n16:32:64
--
-- 64-bit:
--   e-p:64:64:64-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-f32:32:32-f64:64:64-v16:16:16-v32:32:32-v64:64:64-v128:128:128-n16:32:64
--
-- The two differ only in the pointer layout, whose size here follows the bit
-- size of the host 'Int'.
--
ptxDataLayout :: DataLayout
ptxDataLayout = DataLayout
  { endianness      = LittleEndian
  , mangling        = Nothing
  , aggregateLayout = AlignmentInfo 0 64
  , stackAlignment  = Nothing
  , pointerLayouts  = Map.fromList [ (AddrSpace 0, (hostBits, natural hostBits)) ]
  , typeLayouts     = Map.fromList $ concat
      [ [ ((IntegerAlign, 1), AlignmentInfo 8 8) ]    -- i1 is byte aligned
      , [ ((IntegerAlign, bits), natural bits) | bits <- [8, 16, 32, 64] ]
      , [ ((VectorAlign,  bits), natural bits) | bits <- [16, 32, 64, 128] ]
      , [ ((FloatAlign,   bits), natural bits) | bits <- [32, 64] ]
      ]
  , nativeSizes     = Just (Set.fromList [16, 32, 64])
  }
  where
    -- Alignment equal to the size of the type, for both components of the
    -- 'AlignmentInfo'.
    natural bits = AlignmentInfo bits bits

    -- Pointer size in bits, taken from the host 'Int'. The 'undefined' only
    -- selects the type; this relies on 'bitSize' not forcing its argument.
    hostBits = bitSize (undefined :: Int)


-- | The LLVM target triple used for code generation. It is chosen from the
-- bit size of the host 'Int': 32 bits selects @nvptx-nvidia-cuda@ and 64 bits
-- selects @nvptx64-nvidia-cuda@. Any other size raises an internal error.
--
ptxTargetTriple :: HasCallStack => ShortByteString
ptxTargetTriple
  | hostBits == 32 = "nvptx-nvidia-cuda"
  | hostBits == 64 = "nvptx64-nvidia-cuda"
  | otherwise      = internalError "I don't know what architecture I am"
  where
    -- The 'undefined' only selects the type whose size is queried.
    hostBits = bitSize (undefined :: Int)


-- | Run an action with an NVPTX 'TargetMachine' configured for the given
-- device. Creation and destruction of the target machine is bracketed around
-- the supplied continuation.
--
-- The processor name and PTX ISA feature are selected from the device's
-- compute capability via 'ptxTargetVersion'. Relocation model, code model and
-- optimisation level are all left at their defaults.
--
withPTXTargetMachine
    :: HasCallStack
    => DeviceProperties
    -> (TargetMachine -> IO a)
    -> IO a
withPTXTargetMachine props action =
  withTargetOptions $ \targetOpts ->
    withTargetMachine
      ptxTarget
      ptxTargetTriple
      processor       -- CPU
      features        -- CPU features
      targetOpts      -- target options
      R.Default       -- relocation model
      CM.Default      -- code model
      CGO.Default     -- optimisation level
      action
  where
    (processor, isa) = ptxTargetVersion (computeCapability props)
    features         = Map.singleton (CPUFeature isa) True

-- | Select the SM processor name and PTX ISA feature string to compile for.
--
-- Compile using the earliest version of the SM target PTX ISA supported by
-- the given compute device and this version of LLVM.
--
-- Note that we require at least ptx40 for some libnvvm device functions.
--
-- The result is a pair @(processor, isa)@, for example @("sm_75", "ptx63")@.
-- Compute capabilities that are not recognised produce a warning via 'trace'
-- and fall back to a generated @sm_<major><minor>@ name with @ptx40@.
--
-- See table NVPTX supported processors:
--
--   https://github.com/llvm-mirror/llvm/blob/master/lib/Target/NVPTX/NVPTX.td
--
-- PTX ISA version history:
--
--   https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#release-notes
--
ptxTargetVersion :: Compute -> (ByteString, ByteString)
ptxTargetVersion compute@(Compute m n)
  -- The "at least X.Y" tests compare (major, minor) lexicographically.
  -- Testing the two components independently (m >= 7 && n >= 5) would demote
  -- any newer major version with a small minor, such as 8.0, to sm_70.
#if MIN_VERSION_llvm_hs(8,0,0)
  | (m, n) >= (7, 5)    = ("sm_75", "ptx63")
#endif
#if MIN_VERSION_llvm_hs(7,0,0)
  | (m, n) >= (7, 2)    = ("sm_72", "ptx61")
#endif
#if MIN_VERSION_llvm_hs(6,0,0)
  | m >= 7              = ("sm_70", "ptx60")
#endif
  -- Devices newer than this LLVM knows about use the newest 6.x target.
  | m >  6              = ("sm_62", "ptx50")  -- fallthrough
  --
  | m == 6 && n == 2    = ("sm_62", "ptx50")
  | m == 6 && n == 1    = ("sm_61", "ptx50")
  | m == 6              = ("sm_60", "ptx50")
  | m == 5 && n == 3    = ("sm_53", "ptx42")
  | m == 5 && n == 2    = ("sm_52", "ptx41")
  | m == 5              = ("sm_50", "ptx40")
  | m == 3 && n == 7    = ("sm_37", "ptx41")
  | m == 3 && n == 5    = ("sm_35", "ptx40")
  | m == 3 && n == 2    = ("sm_32", "ptx40")
  | m == 3              = ("sm_30", "ptx40")
  | m == 2 && n == 1    = ("sm_21", "ptx40")
  | m == 2              = ("sm_20", "ptx40")
  --
  | otherwise
  = trace warning (fromString (printf "sm_%d%d" m n), "ptx40")
  where
    warning = unlines [ "*** Warning: Unhandled CUDA device compute capability: " ++ show compute
                      , "*** Please submit a bug report at https://github.com/AccelerateHS/accelerate/issues" ]

-- | The LLVM NVPTX target for this host, looked up from 'ptxTargetTriple'.
--
-- This is a top-level 'unsafePerformIO' (kept from being inlined by the
-- NOINLINE pragma) so that 'initializeAllTargets' is run once per program
-- execution (although that might not be necessary?)
--
{-# NOINLINE ptxTarget #-}
ptxTarget :: LLVM.Target
ptxTarget = unsafePerformIO $ do
  initializeAllTargets
  -- 'lookupTarget' yields a pair; only the target itself is kept.
  fst <$> lookupTarget Nothing ptxTargetTriple