{-# LANGUAGE CPP #-}
{-# LANGUAGE EmptyDataDecls #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
module Data.Array.Accelerate.LLVM.PTX.Target (
module Data.Array.Accelerate.LLVM.Target,
module Data.Array.Accelerate.LLVM.PTX.Target,
) where
import LLVM.AST.AddrSpace
import LLVM.AST.DataLayout
import LLVM.Target hiding ( Target )
import qualified LLVM.Target as LLVM
import qualified LLVM.Relocation as R
import qualified LLVM.CodeModel as CM
import qualified LLVM.CodeGenOpt as CGO
import Data.Array.Accelerate.Error
import Data.Array.Accelerate.LLVM.Extra
import Data.Array.Accelerate.LLVM.Target
import Data.Array.Accelerate.LLVM.PTX.Array.Table ( MemoryTable )
import Data.Array.Accelerate.LLVM.PTX.Context ( Context, deviceProperties )
import Data.Array.Accelerate.LLVM.PTX.Execute.Stream.Reservoir ( Reservoir )
import Data.Array.Accelerate.LLVM.PTX.Link.Cache ( KernelTable )
import Foreign.CUDA.Analysis.Device
import Data.ByteString ( ByteString )
import Data.ByteString.Short ( ShortByteString )
import Data.String
import Debug.Trace
import System.IO.Unsafe
import Text.Printf
import qualified Data.Map as Map
import qualified Data.Set as Set
-- | State carried by the PTX backend.
--
-- All fields are strict and unpacked into the constructor.
--
data PTX = PTX
  { ptxContext          :: {-# UNPACK #-} !Context      -- ^ context this target operates in
  , ptxMemoryTable      :: {-# UNPACK #-} !MemoryTable  -- ^ array memory table (see "Data.Array.Accelerate.LLVM.PTX.Array.Table")
  , ptxKernelTable      :: {-# UNPACK #-} !KernelTable  -- ^ linked-kernel cache (see "Data.Array.Accelerate.LLVM.PTX.Link.Cache")
  , ptxStreamReservoir  :: {-# UNPACK #-} !Reservoir    -- ^ stream reservoir (see "Data.Array.Accelerate.LLVM.PTX.Execute.Stream.Reservoir")
  }
-- | The PTX backend as an Accelerate LLVM 'Target'.
--
-- The target triple is always 'ptxTargetTriple'. The data layout is only
-- supplied when /not/ building against NVVM (@ACCELERATE_USE_NVVM@ unset);
-- with NVVM enabled no data layout is attached to generated modules.
--
instance Target PTX where
  targetTriple     = Just ptxTargetTriple
#if ACCELERATE_USE_NVVM
  targetDataLayout = Nothing
#else
  targetDataLayout = Just ptxDataLayout
#endif
-- | Extract the properties of the device associated with the target's
-- execution context.
--
ptxDeviceProperties :: PTX -> DeviceProperties
ptxDeviceProperties = deviceProperties . ptxContext
-- | LLVM data layout used for generated PTX modules.
--
-- The layout is little-endian with no name mangling and no explicit stack
-- alignment. Pointers in address space 0 have the size and alignment of a
-- host 'Int' (see @wordSize@ below), so the layout follows the word size of
-- the host the program was compiled for.
--
ptxDataLayout :: DataLayout
ptxDataLayout = DataLayout
  { endianness      = LittleEndian
  , mangling        = Nothing
  , aggregateLayout = AlignmentInfo 0 64
  , stackAlignment  = Nothing
  , pointerLayouts  = Map.fromList
      [ (AddrSpace 0, (wordSize, AlignmentInfo wordSize wordSize)) ]
  , typeLayouts     = Map.fromList $
      -- 1-bit integers are aligned as a full byte
      [ ((IntegerAlign, 1), AlignmentInfo 8 8) ] ++
      -- every remaining type is aligned to its own bit width
      [ ((IntegerAlign, i), AlignmentInfo i i) | i <- [8,16,32,64] ] ++
      [ ((VectorAlign,  v), AlignmentInfo v v) | v <- [16,32,64,128] ] ++
      [ ((FloatAlign,   f), AlignmentInfo f f) | f <- [32,64] ]
  , nativeSizes     = Just $ Set.fromList [ 16,32,64 ]
  }
  where
    -- Width in bits of the host 'Int'. The argument is only used to select
    -- the type and is not expected to be evaluated by 'bitSize'.
    wordSize = bitSize (undefined :: Int)
-- | String identifying the NVPTX target to LLVM, selected by the bit width
-- of the host 'Int': @nvptx-nvidia-cuda@ for 32-bit and
-- @nvptx64-nvidia-cuda@ for 64-bit.
--
-- Any other word size raises an internal error.
--
ptxTargetTriple :: HasCallStack => ShortByteString
ptxTargetTriple =
  case bitSize (undefined :: Int) of
    32 -> "nvptx-nvidia-cuda"
    64 -> "nvptx64-nvidia-cuda"
    _  -> internalError "I don't know what architecture I am"
-- | Bracket the creation and destruction of an LLVM target machine for the
-- NVPTX backend, configured for the given device.
--
-- The device's compute capability is translated by 'ptxTargetVersion' into
-- the CPU name (@sm_XY@) and the single enabled CPU feature (the PTX ISA
-- version) passed to LLVM.
--
withPTXTargetMachine
    :: HasCallStack
    => DeviceProperties             -- ^ device to generate code for
    -> (TargetMachine -> IO a)      -- ^ action to run with the target machine
    -> IO a
withPTXTargetMachine dev go =
  withTargetOptions $ \options ->
    withTargetMachine
      ptxTarget
      ptxTargetTriple
      sm                                      -- CPU
      (Map.singleton (CPUFeature isa) True)   -- CPU features
      options                                 -- target options
      R.Default                               -- relocation model
      CM.Default                              -- code model
      CGO.Default                             -- optimisation level
      go
  where
    (sm, isa) = ptxTargetVersion (computeCapability dev)
-- | Map a device compute capability onto the LLVM NVPTX target CPU name
-- (@sm_XY@) and PTX ISA feature string (@ptxNN@) to compile for.
--
-- Which of the newest targets are available depends on the version of
-- llvm-hs this package is built against. A device newer than anything in the
-- table is mapped to the highest target that is available; the open-ended
-- guards compare @(major, minor)@ lexicographically so that, for example,
-- compute 8.0 is recognised as being at least 7.5 rather than falling back to
-- @sm_70@ because its minor version is small.
--
-- A compute capability matching no entry produces a warning via 'trace' and
-- a best-effort @sm_<major><minor>@ / @ptx40@ pair.
--
ptxTargetVersion :: Compute -> (ByteString, ByteString)
ptxTargetVersion compute@(Compute m n)
#if MIN_VERSION_llvm_hs(8,0,0)
  | (m, n) >= (7, 5)    = ("sm_75", "ptx63")
#endif
#if MIN_VERSION_llvm_hs(7,0,0)
  | (m, n) >= (7, 2)    = ("sm_72", "ptx61")
#endif
#if MIN_VERSION_llvm_hs(6,0,0)
  | m >= 7              = ("sm_70", "ptx60")
#endif
  -- Fallthrough for anything at or above 6.2 not handled above
  | (m, n) >= (6, 2)    = ("sm_62", "ptx50")
  | m == 6 && n == 1    = ("sm_61", "ptx50")
  | m == 6              = ("sm_60", "ptx50")
  | m == 5 && n == 3    = ("sm_53", "ptx42")
  | m == 5 && n == 2    = ("sm_52", "ptx41")
  | m == 5              = ("sm_50", "ptx40")
  | m == 3 && n == 7    = ("sm_37", "ptx41")
  | m == 3 && n == 5    = ("sm_35", "ptx40")
  | m == 3 && n == 2    = ("sm_32", "ptx40")
  | m == 3              = ("sm_30", "ptx40")
  | m == 2 && n == 1    = ("sm_21", "ptx40")
  | m == 2              = ("sm_20", "ptx40")
  | otherwise
  = trace warning (fromString (printf "sm_%d%d" m n), "ptx40")
  where
    warning :: String
    warning = unlines
      [ "*** Warning: Unhandled CUDA device compute capability: " ++ show compute
      , "*** Please submit a bug report at https://github.com/AccelerateHS/accelerate/issues" ]
-- | The LLVM target description for 'ptxTargetTriple'.
--
-- The lookup is performed with 'unsafePerformIO': it initialises all LLVM
-- targets and then resolves the triple, keeping only the target and
-- discarding the name returned alongside it. The @NOINLINE@ pragma keeps this
-- a single shared top-level value so the initialisation is not duplicated at
-- use sites.
--
-- NOTE(review): this relies on 'initializeAllTargets' being safe to run from
-- whichever thread first forces this value -- confirm against llvm-hs.
--
{-# NOINLINE ptxTarget #-}
ptxTarget :: LLVM.Target
ptxTarget = unsafePerformIO $ do
  initializeAllTargets
  fst <$> lookupTarget Nothing ptxTargetTriple