{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards   #-}
{-# LANGUAGE TemplateHaskell   #-}
{-# LANGUAGE ViewPatterns      #-}
-- |
-- Module      : Data.Array.Accelerate.LLVM.PTX.Compile.Libdevice
-- Copyright   : [2014..2020] The Accelerate Team
-- License     : BSD3
--
-- Maintainer  : Trevor L. McDonell <trevor.mcdonell@gmail.com>
-- Stability   : experimental
-- Portability : non-portable (GHC extensions)
--

module Data.Array.Accelerate.LLVM.PTX.Compile.Libdevice (

  withLibdeviceNVVM,
  withLibdeviceNVPTX,

) where

-- llvm-hs
import LLVM.Context
import qualified LLVM.Module                                        as LLVM

import LLVM.AST                                                     as AST
import LLVM.AST.Global                                              as G
import LLVM.AST.Linkage

-- accelerate
import Data.Array.Accelerate.LLVM.PTX.Compile.Libdevice.Load
import qualified Data.Array.Accelerate.LLVM.PTX.Debug               as Debug

-- cuda
import Foreign.CUDA.Analysis

-- standard library
import Control.Monad
import Data.ByteString                                              ( ByteString )
import Data.ByteString.Short.Char8                                  ( ShortByteString )
import Data.HashSet                                                 ( HashSet )
import Data.List
import Data.Maybe
import Text.Printf
import qualified Data.ByteString.Short.Char8                        as S8
import qualified Data.ByteString.Short.Extra                        as BS
import qualified Data.HashSet                                       as Set


-- | Lower an LLVM AST to C++ objects and link it against the libdevice module,
-- iff any libdevice functions are referenced from the base module.
--
-- Note: [Linking with libdevice]
--
-- The CUDA toolkit comes with an LLVM bitcode library called 'libdevice' that
-- implements many common mathematical functions. The library can be used as a
-- high performance math library for targets of the LLVM NVPTX backend, such as
-- this one. To link a module 'foo' with libdevice, the following compilation
-- pipeline is recommended:
--
--   1. Save all external functions in module 'foo'
--
--   2. Link module 'foo' with the appropriate 'libdevice_compute_XX.YY.bc'
--
--   3. Internalise all functions not in the list from (1)
--
--   4. Eliminate all unused internal functions
--
--   5. Run the NVVMReflect pass (see note: [NVVM Reflect Pass])
--
--   6. Run the standard optimisation pipeline
--
withLibdeviceNVPTX
    :: DeviceProperties           -- ^ target device; its compute capability selects the libdevice variant
    -> Context                    -- ^ LLVM context in which every module is created
    -> Module                     -- ^ generated code, as a pure LLVM AST
    -> (LLVM.Module -> IO a)      -- ^ continuation applied to the lowered (and, if needed, linked) module
    -> IO a
withLibdeviceNVPTX dev ctx ast next
  -- Nothing from libdevice is referenced, so lower the module unchanged.
  | Set.null externs
  = LLVM.withModuleFromAST ctx ast next

  -- Otherwise lower the base module, the NVVM-reflect module and the
  -- (internalised) libdevice module, link the latter two into the base module,
  -- and continue with the result.
  | otherwise
  = LLVM.withModuleFromAST ctx ast                          $ \mdl  ->
    LLVM.withModuleFromAST ctx nvvmReflect                  $ \refl ->
    LLVM.withModuleFromAST ctx (internalise externs libdev) $ \libd -> do
      LLVM.linkModules mdl refl
      LLVM.linkModules mdl libd
      Debug.traceIO Debug.dump_cc msg
      next mdl
  where
    -- Replace the target triple and datalayout from the libdevice.bc module
    -- with those of the generated code. This avoids warnings such as "linking
    -- two modules of different target triples..."
    libdev :: Module
    libdev = (libdevice arch)
               { moduleTargetTriple = moduleTargetTriple ast
               , moduleDataLayout   = moduleDataLayout ast
               }

    -- Names of the libdevice functions declared (but not defined) by 'ast'.
    externs :: HashSet ShortByteString
    externs = analyse ast

    arch :: Compute
    arch = computeCapability dev

    -- Debug message listing every libdevice function being linked in.
    msg :: String
    msg = printf "cc: linking with libdevice: %s"
        $ intercalate ", "
        $ map S8.unpack
        $ Set.toList externs


-- | Lower an LLVM AST to C++ objects and prepare it for linking against
-- libdevice using the nvvm bindings, iff any libdevice functions are referenced
-- from the base module.
--
-- Rather than internalise and strip any unused functions ourselves, allow the
-- nvvm library to do so when linking the two modules together.
--
-- TLM: This really should work with the above method, however for some reason
-- we get a "CUDA Exception: function named symbol not found" error, even though
-- the function is clearly visible in the generated code. hmm...
--
withLibdeviceNVVM
    :: DeviceProperties           -- ^ target device; its compute capability selects the libdevice variant
    -> Context                    -- ^ LLVM context in which the module is created
    -> Module                     -- ^ generated code, as a pure LLVM AST
    -> ([(ShortByteString, ByteString)] -> LLVM.Module -> IO a)
                                  -- ^ continuation; receives the extra modules to link
                                  --   (empty when libdevice is not required) and the
                                  --   lowered base module
    -> IO a
withLibdeviceNVVM dev ctx ast next =
  LLVM.withModuleFromAST ctx ast $ \mdl -> do
    -- Only report the link step when libdevice is actually going to be used.
    when withlib $ Debug.traceIO Debug.dump_cc msg
    next lib mdl
  where
    -- Names of the libdevice functions declared (but not defined) by 'ast'.
    externs :: HashSet ShortByteString
    externs = analyse ast

    -- Whether the base module references any libdevice function at all.
    withlib :: Bool
    withlib = not (Set.null externs)

    -- Additional modules handed to the continuation for linking: the
    -- NVVM-reflect module followed by libdevice, or nothing.
    lib :: [(ShortByteString, ByteString)]
    lib | withlib   = [ nvvmReflect, libdevice arch ]
        | otherwise = []

    arch :: Compute
    arch = computeCapability dev

    -- Debug message listing every libdevice function being linked in.
    msg :: String
    msg = printf "cc: linking with libdevice: %s"
        $ intercalate ", "
        $ map S8.unpack
        $ Set.toList externs


-- | Analyse the LLVM AST module and determine if any of the external
-- declarations are intrinsics implemented by libdevice. The set of such
-- functions is returned, and will be used when determining which functions from
-- libdevice to internalise.
--
analyse :: Module -> HashSet ShortByteString
analyse Module{..} =
  Set.fromList (mapMaybe intrinsic moduleDefinitions)
  where
    -- A definition counts as a libdevice intrinsic when it is a function
    -- that has no body (i.e. is only declared here), is referred to by a
    -- textual name rather than a number, and that name starts with "__nv_".
    -- The name is returned in that case.
    intrinsic :: Definition -> Maybe ShortByteString
    intrinsic (GlobalDefinition Function{..})
      | null basicBlocks
      , Name n  <- name
      , "__nv_" <- BS.take 5 n
      = Just n
    intrinsic _
      = Nothing


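-- | Give internal linkage to every function defined in the module that is not
-- in the given set of names (the functions we call directly). This means that
-- unused definitions can be removed as dead code. Declarations, which have no
-- body, are left untouched so that they remain external.
--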
internalise :: HashSet ShortByteString -> Module -> Module
internalise externals Module{..} =
  -- Rebuild the module with every definition passed through 'internal'; all
  -- other module fields are carried over unchanged by the record wildcard.
  Module { moduleDefinitions = map internal moduleDefinitions, .. }
  where
    -- Switch a function to internal linkage when both conditions hold;
    -- anything else (other kinds of definition, numbered functions, functions
    -- named in 'externals', bodiless declarations) is returned as-is.
    internal :: Definition -> Definition
    internal (GlobalDefinition Function{..})
      | Name n <- name
      , not (Set.member n externals)          -- we don't call this function directly; and
      , not (null basicBlocks)                -- it is not an external declaration
      = GlobalDefinition Function { linkage = Internal, .. }
    internal x
      = x