-- | Code generation for 'SegScan'.  Dispatches to either a
-- single-pass or two-pass implementation, depending on the nature of
-- the scan and the chosen backend.
module Futhark.CodeGen.ImpGen.GPU.SegScan (compileSegScan) where

import qualified Futhark.CodeGen.ImpCode.GPU as Imp
import Futhark.CodeGen.ImpGen hiding (compileProg)
import Futhark.CodeGen.ImpGen.GPU.Base
import qualified Futhark.CodeGen.ImpGen.GPU.SegScan.SinglePass as SinglePass
import qualified Futhark.CodeGen.ImpGen.GPU.SegScan.TwoPass as TwoPass
import Futhark.IR.GPUMem

-- | Fuse several scan operators into a single operator.
--
-- The single-pass scan does not support multiple operators, so jam
-- them together here.  The combined operator:
--
--   * has the 'mconcat' of the individual commutativities,
--
--   * has the concatenation of the individual neutral elements,
--
--   * runs the statements of every operator body in sequence and
--     returns the concatenation of their results.
--
-- The shape of the result is always 'mempty'; the shapes of the
-- input operators are ignored, so callers must only pass operators
-- whose shape is already empty (as 'canBeSinglePass' checks).
combineScans :: [SegBinOp GPUMem] -> SegBinOp GPUMem
combineScans ops =
  SegBinOp
    { segBinOpComm = mconcat (map segBinOpComm ops),
      segBinOpLambda = lam',
      segBinOpNeutral = concatMap segBinOpNeutral ops,
      segBinOpShape = mempty -- Assumed
    }
  where
    lams = map segBinOpLambda ops

    -- The leading parameters of an operator lambda, one per returned
    -- value.
    xParams lam = take (length (lambdaReturnType lam)) (lambdaParams lam)

    -- Whatever parameters remain after the leading ones.
    yParams lam = drop (length (lambdaReturnType lam)) (lambdaParams lam)

    -- The fused lambda keeps the same parameter layout as each
    -- individual one: all leading parameters of all operators first,
    -- then all remaining parameters.
    lam' =
      Lambda
        { lambdaParams = concatMap xParams lams ++ concatMap yParams lams,
          lambdaReturnType = concatMap lambdaReturnType lams,
          lambdaBody =
            Body
              ()
              (mconcat (map (bodyStms . lambdaBody) lams))
              (concatMap (bodyResult . lambdaBody) lams)
        }

-- | Decide whether the given scan operators can be handled by the
-- single-pass implementation.
--
-- Returns 'Just' the operators fused with 'combineScans' when every
-- operator has an empty ('mempty') shape and a lambda returning only
-- primitive types; otherwise returns 'Nothing'.
canBeSinglePass :: [SegBinOp GPUMem] -> Maybe (SegBinOp GPUMem)
canBeSinglePass ops
  | all ok ops =
      Just $ combineScans ops
  | otherwise =
      Nothing
  where
    -- An operator qualifies when its shape is empty and all of its
    -- result types are primitive.
    ok op =
      segBinOpShape op == mempty
        && all primType (lambdaReturnType (segBinOpLambda op))

-- | Compile 'SegScan' instance to host-level code with calls to
-- various kernels.
--
-- All generated code is guarded by a check that the product of the
-- segment space dimensions is positive.  When the host target is
-- 'CUDA' and 'canBeSinglePass' accepts the operators, the single-pass
-- implementation is used with the fused operator; in every other case
-- the two-pass implementation is used with the original operators.
compileSegScan ::
  Pat LetDecMem ->
  SegLevel ->
  SegSpace ->
  [SegBinOp GPUMem] ->
  KernelBody GPUMem ->
  CallKernelGen ()
compileSegScan pat lvl space scans kbody = sWhen (0 .<. n) $ do
  emit $ Imp.DebugPrint "\n# SegScan" Nothing
  target <- hostTarget <$> askEnv
  case target of
    CUDA
      | Just scan' <- canBeSinglePass scans ->
          SinglePass.compileSegScan pat lvl space scan' kbody
    -- Also reached for CUDA when the guard above fails.
    _ -> TwoPass.compileSegScan pat lvl space scans kbody
  where
    -- Product of all dimensions of the segment space.
    n = product $ map toInt64Exp $ segSpaceDims space