{-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE RecordWildCards #-} {-# LANGUAGE TemplateHaskell #-} {-# LANGUAGE TupleSections #-} {-# OPTIONS_GHC -fno-warn-orphans #-} -- | -- Module : Data.Array.Accelerate.LLVM.PTX.Compile.Libdevice -- Copyright : [2014..2017] Trevor L. McDonell -- [2014..2014] Vinod Grover (NVIDIA Corporation) -- License : BSD3 -- -- Maintainer : Trevor L. McDonell -- Stability : experimental -- Portability : non-portable (GHC extensions) -- module Data.Array.Accelerate.LLVM.PTX.Compile.Libdevice ( nvvmReflect, libdevice, ) where -- llvm-hs import LLVM.Context import LLVM.Module as LLVM import LLVM.AST as AST ( Module(..), Definition(..) ) import LLVM.AST.Attribute import LLVM.AST.Global as G import qualified LLVM.AST.Name as AST -- accelerate import LLVM.AST.Type.Name ( Label(..) ) import LLVM.AST.Type.Representation import Data.Array.Accelerate.Error import Data.Array.Accelerate.LLVM.CodeGen.Base import Data.Array.Accelerate.LLVM.CodeGen.Downcast import Data.Array.Accelerate.LLVM.CodeGen.Intrinsic import Data.Array.Accelerate.LLVM.PTX.Target -- cuda import Foreign.CUDA.Analysis -- standard library import Control.Monad.Except import Data.ByteString ( ByteString ) import Data.HashMap.Strict ( HashMap ) import Data.List import Data.Maybe import System.Directory import System.FilePath import System.IO.Unsafe import Text.Printf import qualified Data.ByteString as B import qualified Data.ByteString.Char8 as B8 import qualified Data.HashMap.Strict as HashMap -- NVVM Reflect -- ------------ class NVVMReflect a where nvvmReflect :: a instance NVVMReflect AST.Module where nvvmReflect = nvvmReflectPass_mdl instance NVVMReflect (String, ByteString) where nvvmReflect = nvvmReflectPass_bc -- This is a hacky module that can be linked against in order to provide the -- same functionality as running the NVVMReflect pass. -- -- Note: [NVVM Reflect Pass] -- -- To accommodate various math-related compiler flags that can affect code -- generation of libdevice code, the library code depends on a special LLVM IR -- pass (NVVMReflect) to handle conditional compilation within LLVM IR. This -- pass looks for calls to the @__nvvm_reflect function and replaces them with -- constants based on the defined reflection parameters. -- -- libdevice currently uses the following reflection parameters to control code -- generation: -- -- * __CUDA_FTZ={0,1} fast math that flushes denormals to zero -- -- Since this is currently the only reflection parameter supported, and that we -- prefer correct results over pure speed, we do not flush denormals to zero. If -- the list of supported parameters ever changes, we may need to re-evaluate -- this implementation. -- nvvmReflectPass_mdl :: AST.Module nvvmReflectPass_mdl = AST.Module { moduleName = "nvvm-reflect" , moduleSourceFileName = [] , moduleDataLayout = targetDataLayout (undefined::PTX) , moduleTargetTriple = targetTriple (undefined::PTX) , moduleDefinitions = [GlobalDefinition $ functionDefaults { name = AST.Name "__nvvm_reflect" , returnType = downcast (integralType :: IntegralType Int32) , parameters = ( [ptrParameter scalarType (UnName 0 :: Name (Ptr Int8))], False ) , G.functionAttributes = map Right [NoUnwind, ReadNone, AlwaysInline] , basicBlocks = [] }] } {-# NOINLINE nvvmReflectPass_bc #-} nvvmReflectPass_bc :: (String, ByteString) nvvmReflectPass_bc = (name,) . unsafePerformIO $ do withContext $ \ctx -> do runError $ withModuleFromAST ctx nvvmReflectPass_mdl (return . B8.pack <=< moduleLLVMAssembly) where name = "__nvvm_reflect" runError = either ($internalError "nvvmReflectPass") return <=< runExceptT -- libdevice -- --------- -- Compatible version of libdevice for a given compute capability should be -- listed here: -- -- https://github.com/llvm-mirror/llvm/blob/master/lib/Target/NVPTX/NVPTX.td#L72 -- class Libdevice a where libdevice :: Compute -> a instance Libdevice AST.Module where libdevice (Compute n m) = case (n,m) of (2,_) -> libdevice_20_mdl -- 2.0, 2.1 (3,x) | x < 5 -> libdevice_30_mdl -- 3.0, 3.2 | otherwise -> libdevice_35_mdl -- 3.5, 3.7 (5,_) -> libdevice_50_mdl -- 5.x (6,_) -> libdevice_50_mdl -- 6.x _ -> $internalError "libdevice" "no binary for this architecture" instance Libdevice (String, ByteString) where libdevice (Compute n m) = case (n,m) of (2,_) -> libdevice_20_bc -- 2.0, 2.1 (3,x) | x < 5 -> libdevice_30_bc -- 3.0, 3.2 | otherwise -> libdevice_35_bc -- 3.5, 3.7 (5,_) -> libdevice_50_bc -- 5.x (6,_) -> libdevice_50_bc -- 6.x _ -> $internalError "libdevice" "no binary for this architecture" -- Load the libdevice bitcode files as an LLVM AST module. The top-level -- unsafePerformIO ensures that the data is only read from disk once per program -- execution. -- {-# NOINLINE libdevice_20_mdl #-} {-# NOINLINE libdevice_30_mdl #-} {-# NOINLINE libdevice_35_mdl #-} {-# NOINLINE libdevice_50_mdl #-} libdevice_20_mdl, libdevice_30_mdl, libdevice_35_mdl, libdevice_50_mdl :: AST.Module libdevice_20_mdl = unsafePerformIO $ libdeviceModule (Compute 2 0) libdevice_30_mdl = unsafePerformIO $ libdeviceModule (Compute 3 0) libdevice_35_mdl = unsafePerformIO $ libdeviceModule (Compute 3 5) libdevice_50_mdl = unsafePerformIO $ libdeviceModule (Compute 5 0) -- Load the libdevice bitcode files as raw binary data. The top-level -- unsafePerformIO ensures that the data is read only once per program -- execution. -- {-# NOINLINE libdevice_20_bc #-} {-# NOINLINE libdevice_30_bc #-} {-# NOINLINE libdevice_35_bc #-} {-# NOINLINE libdevice_50_bc #-} libdevice_20_bc, libdevice_30_bc, libdevice_35_bc, libdevice_50_bc :: (String,ByteString) libdevice_20_bc = unsafePerformIO $ libdeviceBitcode (Compute 2 0) libdevice_30_bc = unsafePerformIO $ libdeviceBitcode (Compute 3 0) libdevice_35_bc = unsafePerformIO $ libdeviceBitcode (Compute 3 5) libdevice_50_bc = unsafePerformIO $ libdeviceBitcode (Compute 5 0) -- Load the libdevice bitcode file for the given compute architecture, and raise -- it to a Haskell AST that can be kept for future use. The name of the bitcode -- files follows: -- -- libdevice.compute_XX.YY.bc -- -- Where XX represents the compute capability, and YY represents a version(?) We -- search the libdevice PATH for all files of the appropriate compute capability -- and load the most recent. -- libdeviceModule :: Compute -> IO AST.Module libdeviceModule arch = do let bc :: (String, ByteString) bc = libdevice arch -- TLM: we have called 'withContext' again here, although the LLVM state -- already carries a version of the context. We do this so that we can -- fully apply this function that can be lifted out to a CAF and only -- executed once per program execution. -- withContext $ \ctx -> either ($internalError "libdeviceModule") id `fmap` runExceptT (withModuleFromBitcode ctx bc moduleAST) -- Load the libdevice bitcode file for the given compute architecture. The name -- of the bitcode files follows the format: -- -- libdevice.compute_XX.YY.bc -- -- Where XX represents the compute capability, and YY represents a version(?) We -- search the libdevice PATH for all files of the appropriate compute capability -- and load the "most recent" (by sort order). -- libdeviceBitcode :: Compute -> IO (String, ByteString) libdeviceBitcode (Compute m n) = do let arch = printf "libdevice.compute_%d%d" m n err = $internalError "libdevice" (printf "not found: %s.YY.bc" arch) best f = arch `isPrefixOf` f && takeExtension f == ".bc" path <- libdevicePath files <- getDirectoryContents path name <- maybe err return . listToMaybe . sortBy (flip compare) $ filter best files bc <- B.readFile (path name) return (name, bc) -- Determine the location of the libdevice bitcode libraries. We search for the -- location of the 'nvcc' executable in the PATH. From that, we assume the -- location of the libdevice bitcode files. -- libdevicePath :: IO FilePath libdevicePath = do nvcc <- fromMaybe (error "could not find 'nvcc' in PATH") `fmap` findExecutable "nvcc" let ccvn = reverse (splitPath nvcc) dir = "libdevice" : "nvvm" : drop 2 ccvn return (joinPath (reverse dir)) instance Intrinsic PTX where intrinsicForTarget _ = libdeviceIndex -- The list of functions implemented by libdevice. These are all more-or-less -- named consistently based on the standard mathematical functions they -- implement, with the "__nv_" prefix stripped. -- libdeviceIndex :: HashMap String Label libdeviceIndex = let nv base = (base, Label $ "__nv_" ++ base) in HashMap.fromList $ map nv [ "abs" , "acos" , "acosf" , "acosh" , "acoshf" , "asin" , "asinf" , "asinh" , "asinhf" , "atan" , "atan2" , "atan2f" , "atanf" , "atanh" , "atanhf" , "brev" , "brevll" , "byte_perm" , "cbrt" , "cbrtf" , "ceil" , "ceilf" , "clz" , "clzll" , "copysign" , "copysignf" , "cos" , "cosf" , "cosh" , "coshf" , "cospi" , "cospif" , "dadd_rd" , "dadd_rn" , "dadd_ru" , "dadd_rz" , "ddiv_rd" , "ddiv_rn" , "ddiv_ru" , "ddiv_rz" , "dmul_rd" , "dmul_rn" , "dmul_ru" , "dmul_rz" , "double2float_rd" , "double2float_rn" , "double2float_ru" , "double2float_rz" , "double2hiint" , "double2int_rd" , "double2int_rn" , "double2int_ru" , "double2int_rz" , "double2ll_rd" , "double2ll_rn" , "double2ll_ru" , "double2ll_rz" , "double2loint" , "double2uint_rd" , "double2uint_rn" , "double2uint_ru" , "double2uint_rz" , "double2ull_rd" , "double2ull_rn" , "double2ull_ru" , "double2ull_rz" , "double_as_longlong" , "drcp_rd" , "drcp_rn" , "drcp_ru" , "drcp_rz" , "dsqrt_rd" , "dsqrt_rn" , "dsqrt_ru" , "dsqrt_rz" , "erf" , "erfc" , "erfcf" , "erfcinv" , "erfcinvf" , "erfcx" , "erfcxf" , "erff" , "erfinv" , "erfinvf" , "exp" , "exp10" , "exp10f" , "exp2" , "exp2f" , "expf" , "expm1" , "expm1f" , "fabs" , "fabsf" , "fadd_rd" , "fadd_rn" , "fadd_ru" , "fadd_rz" , "fast_cosf" , "fast_exp10f" , "fast_expf" , "fast_fdividef" , "fast_log10f" , "fast_log2f" , "fast_logf" , "fast_powf" , "fast_sincosf" , "fast_sinf" , "fast_tanf" , "fdim" , "fdimf" , "fdiv_rd" , "fdiv_rn" , "fdiv_ru" , "fdiv_rz" , "ffs" , "ffsll" , "finitef" , "float2half_rn" , "float2int_rd" , "float2int_rn" , "float2int_ru" , "float2int_rz" , "float2ll_rd" , "float2ll_rn" , "float2ll_ru" , "float2ll_rz" , "float2uint_rd" , "float2uint_rn" , "float2uint_ru" , "float2uint_rz" , "float2ull_rd" , "float2ull_rn" , "float2ull_ru" , "float2ull_rz" , "float_as_int" , "floor" , "floorf" , "fma" , "fma_rd" , "fma_rn" , "fma_ru" , "fma_rz" , "fmaf" , "fmaf_rd" , "fmaf_rn" , "fmaf_ru" , "fmaf_rz" , "fmax" , "fmaxf" , "fmin" , "fminf" , "fmod" , "fmodf" , "fmul_rd" , "fmul_rn" , "fmul_ru" , "fmul_rz" , "frcp_rd" , "frcp_rn" , "frcp_ru" , "frcp_rz" , "frexp" , "frexpf" , "frsqrt_rn" , "fsqrt_rd" , "fsqrt_rn" , "fsqrt_ru" , "fsqrt_rz" , "fsub_rd" , "fsub_rn" , "fsub_ru" , "fsub_rz" , "hadd" , "half2float" , "hiloint2double" , "hypot" , "hypotf" , "ilogb" , "ilogbf" , "int2double_rn" , "int2float_rd" , "int2float_rn" , "int2float_ru" , "int2float_rz" , "int_as_float" , "isfinited" , "isinfd" , "isinff" , "isnand" , "isnanf" , "j0" , "j0f" , "j1" , "j1f" , "jn" , "jnf" , "ldexp" , "ldexpf" , "lgamma" , "lgammaf" , "ll2double_rd" , "ll2double_rn" , "ll2double_ru" , "ll2double_rz" , "ll2float_rd" , "ll2float_rn" , "ll2float_ru" , "ll2float_rz" , "llabs" , "llmax" , "llmin" , "llrint" , "llrintf" , "llround" , "llroundf" , "log" , "log10" , "log10f" , "log1p" , "log1pf" , "log2" , "log2f" , "logb" , "logbf" , "logf" , "longlong_as_double" , "max" , "min" , "modf" , "modff" , "mul24" , "mul64hi" , "mulhi" , "nan" , "nanf" , "nearbyint" , "nearbyintf" , "nextafter" , "nextafterf" , "normcdf" , "normcdff" , "normcdfinv" , "normcdfinvf" , "popc" , "popcll" , "pow" , "powf" , "powi" , "powif" , "rcbrt" , "rcbrtf" , "remainder" , "remainderf" , "remquo" , "remquof" , "rhadd" , "rint" , "rintf" , "round" , "roundf" , "rsqrt" , "rsqrtf" , "sad" , "saturatef" , "scalbn" , "scalbnf" , "signbitd" , "signbitf" , "sin" , "sincos" , "sincosf" , "sincospi" , "sincospif" , "sinf" , "sinh" , "sinhf" , "sinpi" , "sinpif" , "sqrt" , "sqrtf" , "tan" , "tanf" , "tanh" , "tanhf" , "tgamma" , "tgammaf" , "trunc" , "truncf" , "uhadd" , "uint2double_rn" , "uint2float_rd" , "uint2float_rn" , "uint2float_ru" , "uint2float_rz" , "ull2double_rd" , "ull2double_rn" , "ull2double_ru" , "ull2double_rz" , "ull2float_rd" , "ull2float_rn" , "ull2float_ru" , "ull2float_rz" , "ullmax" , "ullmin" , "umax" , "umin" , "umul24" , "umul64hi" , "umulhi" , "urhadd" , "usad" , "y0" , "y0f" , "y1" , "y1f" , "yn" , "ynf" ]