{-# LANGUAGE CPP               #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TemplateHaskell   #-}
{-# LANGUAGE TypeApplications  #-}
-- |
-- Module      : Data.Array.Accelerate.LLVM.PTX.Compile.Libdevice.TH
-- Copyright   : [2017..2020] The Accelerate Team
-- License     : BSD3
--
-- Maintainer  : Trevor L. McDonell <trevor.mcdonell@gmail.com>
-- Stability   : experimental
-- Portability : non-portable (GHC extensions)
--

module Data.Array.Accelerate.LLVM.PTX.Compile.Libdevice.TH (

  nvvmReflectModule, nvvmReflectBitcode,
  libdeviceBitcode,

) where

import qualified LLVM.AST                                           as AST
import qualified LLVM.AST.Attribute                                 as AST
import qualified LLVM.AST.Global                                    as AST.G
import qualified LLVM.Context                                       as LLVM
import qualified LLVM.Module                                        as LLVM

import LLVM.AST.Type.Downcast
import LLVM.AST.Type.Representation

import Data.Array.Accelerate.Error
import Data.Array.Accelerate.LLVM.CodeGen.Base
import Data.Array.Accelerate.LLVM.PTX.Target

import Foreign.CUDA.Analysis
import qualified Foreign.CUDA.Driver                                as CUDA
#if MIN_VERSION_nvvm(0,10,0)
import Foreign.NVVM.Path
#else
import Foreign.CUDA.Path
#endif

import Data.ByteString                                              ( ByteString )
import Data.ByteString.Short                                        ( ShortByteString )
import Data.FileEmbed
import Data.List
import Data.Maybe
import Language.Haskell.TH.Syntax                                   ( Q, TExp )
import System.Directory
import System.FilePath
import Text.Printf
import qualified Data.ByteString.Short                              as BS
import qualified Data.ByteString.Short.Char8                        as S8
import qualified Data.ByteString.Short.Extra                        as BS
import qualified Language.Haskell.TH                                as TH
import qualified Language.Haskell.TH.Syntax                         as TH


-- This is a hacky module that can be linked against in order to provide the
-- same functionality as running the NVVMReflect pass.
--
-- Note: [NVVM Reflect Pass]
--
-- To accommodate various math-related compiler flags that can affect code
-- generation of libdevice code, the library code depends on a special LLVM IR
-- pass (NVVMReflect) to handle conditional compilation within LLVM IR. This
-- pass looks for calls to the @__nvvm_reflect function and replaces them with
-- constants based on the defined reflection parameters.
--
-- libdevice currently uses the following reflection parameters to control code
-- generation:
--
--   * __CUDA_FTZ={0,1}     fast math that flushes denormals to zero
--
-- Since this is currently the only reflection parameter supported, and since we
-- prefer correct results over pure speed, we do not flush denormals to zero. If
-- the list of supported parameters ever changes, we may need to re-evaluate
-- this implementation.
--
-- An LLVM module named "nvvm-reflect" whose only definition is the global
-- function '__nvvm_reflect'.
--
nvvmReflectModule :: AST.Module
nvvmReflectModule =
  AST.Module
    { AST.moduleName            = "nvvm-reflect"
    , AST.moduleSourceFileName  = BS.empty
      -- Data layout and target triple are those of the PTX target.
    , AST.moduleDataLayout      = targetDataLayout @PTX
    , AST.moduleTargetTriple    = targetTriple @PTX
    , AST.moduleDefinitions     = [AST.GlobalDefinition $ AST.G.functionDefaults
      { AST.G.name                = AST.Name "__nvvm_reflect"
        -- The function returns a 32-bit integer...
      , AST.G.returnType          = downcast (integralType :: IntegralType Int32)
        -- ...and takes a single unnamed pointer-to-Int8 argument. The second
        -- component of the pair is the var-args flag of the function.
      , AST.G.parameters          = ( [ptrParameter scalarType (UnName 0 :: Name (Ptr Int8))], False )
      , AST.G.functionAttributes  = map Right [AST.NoUnwind, AST.ReadNone, AST.AlwaysInline]
        -- No basic blocks are given, so no function body is defined here.
      , AST.G.basicBlocks         = []
      }]
    }


-- Lower the given NVVM Reflect module at compile time and embed the result
-- into the program.
--
-- The module is converted from its AST representation inside a fresh LLVM
-- context (run in 'IO' from within the 'Q' monad via 'TH.runIO') and
-- serialised with 'LLVM.moduleLLVMAssembly'. The splice produced is a pair of:
--
--   * the symbol name @__nvvm_reflect@, and
--   * the serialised module, embedded as a 'ByteString' literal.
--
-- NOTE(review): despite the "bitcode" in the name, the serialiser used here is
-- 'LLVM.moduleLLVMAssembly', which presumably yields textual LLVM assembly
-- rather than binary bitcode -- confirm that consumers accept this form.
--
nvvmReflectBitcode :: AST.Module -> Q (TExp (ShortByteString, ByteString))
nvvmReflectBitcode mdl = do
  let name = "__nvvm_reflect"
  --
  bs <- TH.runIO $ LLVM.withContext $ \ctx ->
                     LLVM.withModuleFromAST ctx mdl LLVM.moduleLLVMAssembly
  -- The tuple is assembled as an untyped expression and then coerced to the
  -- typed expression promised by the signature.
  TH.unsafeTExpCoerce $ TH.tupE [ TH.unTypeQ (BS.liftSBS name)
                                , bsToExp bs ]


-- Load the libdevice bitcode file for the given compute architecture and embed
-- it into the program at compile time.
--
-- When 'CUDA.libraryVersion' is less than 9000, the name of the bitcode files
-- is expected to follow the format:
--
--   libdevice.compute_XX.YY.bc
--
-- where XX represents the compute capability, and YY represents a version(?).
-- Otherwise the compute capability is not part of the name, and any file
-- beginning with @libdevice@ and ending in @.bc@ is a candidate.
--
-- We search the libdevice directory for all candidate files and load the
-- "most recent" one, which here means the first in descending sort order of
-- the file names. If there is no candidate at all, an 'internalError' is
-- raised once the file name is demanded.
--
-- The splice produced is a pair of the chosen file name and the contents of
-- that file.
--
libdeviceBitcode :: HasCallStack => Compute -> Q (TExp (ShortByteString, ByteString))
libdeviceBitcode compute = do
  let basename
        | CUDA.libraryVersion < 9000
        , Compute m n <- compute     = printf "libdevice.compute_%d%d" m n
        | otherwise                  = "libdevice"
      --
      err     = internalError (printf "not found: %s.YY.bc" basename)
      -- A candidate has the expected prefix and the ".bc" extension
      best f  = basename `isPrefixOf` f && takeExtension f == ".bc"
#if MIN_VERSION_nvvm(0,10,0)
      base    = nvvmDeviceLibraryPath
#else
      base    = cudaInstallPath </> "nvvm" </> "libdevice"
#endif

  --
  files <- TH.runIO $ getDirectoryContents base
  --
  -- Sort the candidates in descending order and take the first one
  let name  = fromMaybe err . listToMaybe . sortBy (flip compare) $ filter best files
      path  = base </> name
  --
  TH.unsafeTExpCoerce $ TH.tupE [ TH.unTypeQ (BS.liftSBS (S8.pack name))
                                , embedFile path ]