-- | Imperative code with an OpenCL component.
--
-- Apart from ordinary imperative code, this also carries around an
-- OpenCL program as a string, as well as a list of kernels defined by
-- the OpenCL program.
--
-- The imperative code has been augmented with a 'LaunchKernel'
-- operation that allows one to execute an OpenCL kernel.
module Futhark.CodeGen.ImpCode.OpenCL
       ( Program (..)
       , Function
       , FunctionT (Function)
       , Code
       , KernelName
       , KernelArg (..)
       , OpenCL (..)
       , KernelSafety(..)
       , numFailureParams
       , KernelTarget (..)
       , FailureMsg(..)
       , module Futhark.CodeGen.ImpCode
       , module Futhark.IR.Kernels.Sizes
       )
       where

import qualified Data.Map as M

import Futhark.CodeGen.ImpCode hiding (Function, Code)
import Futhark.IR.Kernels.Sizes
import qualified Futhark.CodeGen.ImpCode as Imp

import Futhark.Util.Pretty

-- | A program calling OpenCL kernels.
data Program = Program { openClProgram :: String
                         -- ^ The OpenCL program, as a string.
                       , openClPrelude :: String
                         -- ^ Must be prepended to the program.
                       , openClKernelNames :: M.Map KernelName KernelSafety
                         -- ^ The kernels defined by the OpenCL
                         -- program, each mapped to its
                         -- 'KernelSafety'.
                       , openClUsedTypes :: [PrimType]
                         -- ^ So we can detect whether the device is capable.
                       , openClSizes :: M.Map Name SizeClass
                         -- ^ Runtime-configurable constants.
                       , openClFailures :: [FailureMsg]
                         -- ^ Assertion failure error messages.
                       , hostDefinitions :: Definitions OpenCL
                         -- ^ The imperative definitions, with
                         -- 'OpenCL' as the operation type.
                       }

-- | Something that can go wrong in a kernel.  Part of the machinery
-- for reporting error messages from within kernels.
data FailureMsg = FailureMsg { failureError :: ErrorMsg Exp
                               -- ^ The error message itself.
                             , failureBacktrace :: String
                               -- ^ The backtrace accompanying the
                               -- error, as a string.
                             }

-- | A function calling OpenCL kernels: an 'Imp.Function' whose
-- operation type is 'OpenCL'.
type Function = Imp.Function OpenCL

-- | A piece of code calling OpenCL: 'Imp.Code' whose operation type
-- is 'OpenCL'.
type Code = Imp.Code OpenCL

-- | The name of a kernel.  This is a plain alias for 'Name', so the
-- two types are interchangeable.
type KernelName = Name

-- | An argument to be passed to a kernel.
data KernelArg = ValueKArg Exp PrimType
                 -- ^ Pass the value of this scalar expression as argument.
               | MemKArg VName
                 -- ^ Pass this pointer as argument.
               | SharedMemoryKArg (Count Bytes Exp)
                 -- ^ Create this much local memory per workgroup.
               deriving (Show)

-- | Whether a kernel can potentially fail (because it contains bounds
-- checks and such).
data MayFail = MayFail | CannotFail
             deriving (Show)

-- | Information about bounds checks and how sensitive it is to
-- errors.  Ordered by least demanding to most.
--
-- The derived 'Ord' instance follows constructor order, so
-- @SafetyNone < SafetyCheap < SafetyFull@.
data KernelSafety
  = SafetyNone
    -- ^ Does not need to know if we are in a failing state, and also
    -- cannot fail.
  | SafetyCheap
    -- ^ Needs to be told if there's a global failure, and that's it,
    -- and cannot fail.
  | SafetyFull
    -- ^ Needs all parameters, may fail itself.
    deriving (Eq, Ord, Show)

-- | How many leading failure arguments we must pass when launching a
-- kernel with these safety characteristics.
--
-- 'SafetyNone' kernels take none, 'SafetyCheap' kernels take one, and
-- 'SafetyFull' kernels take three.
numFailureParams :: KernelSafety -> Int
numFailureParams SafetyNone = 0
numFailureParams SafetyCheap = 1
numFailureParams SafetyFull = 3

-- | Host-level OpenCL operation.
data OpenCL = LaunchKernel KernelSafety KernelName [KernelArg] [Exp] [Exp]
              -- ^ Execute the named kernel with the given safety
              -- level and arguments.  NOTE(review): the two
              -- trailing 'Exp' lists are presumably the launch
              -- dimensions (number of groups and group size) -
              -- confirm against the code generator.
            | GetSize VName Name
            | CmpSizeLe VName Name Exp
            | GetSizeMax VName SizeClass
            deriving (Show)

-- | The target platform when compiling imperative code to a 'Program'.
data KernelTarget = TargetOpenCL
                  | TargetCUDA
                  deriving (Eq)

-- | Pretty-print a host-level 'OpenCL' operation by rendering the
-- output of its derived 'Show' instance as document text.
instance Pretty OpenCL where
  ppr = text . show