-- | Code generation for 'SegScan'.  Dispatches to either a
-- single-pass or two-pass implementation, depending on the nature of
-- the scan and the chosen backend.
module Futhark.CodeGen.ImpGen.GPU.SegScan (compileSegScan) where

import qualified Futhark.CodeGen.ImpCode.GPU as Imp
import Futhark.CodeGen.ImpGen hiding (compileProg)
import Futhark.CodeGen.ImpGen.GPU.Base
import qualified Futhark.CodeGen.ImpGen.GPU.SegScan.SinglePass as SinglePass
import qualified Futhark.CodeGen.ImpGen.GPU.SegScan.TwoPass as TwoPass
import Futhark.IR.GPUMem

-- The single-pass scan does not support multiple operators, so jam
-- them together here into one operator.
--
-- The combined operator is assembled field by field:
--
-- * commutativity: the per-operator values combined with 'mconcat';
--
-- * neutral elements: the per-operator lists concatenated in
--   operator order;
--
-- * lambda: the fused lambda @lam'@ defined below;
--
-- * shape: always 'mempty'.  This is not derived from the inputs, so
--   the result is only meaningful when every input operator has an
--   empty shape ('canBeSinglePass' checks this before calling).
combineScans :: [SegBinOp GPUMem] -> SegBinOp GPUMem
combineScans ops =
  SegBinOp
    { segBinOpComm = mconcat (map segBinOpComm ops),
      segBinOpLambda = lam',
      segBinOpNeutral = concatMap segBinOpNeutral ops,
      segBinOpShape = mempty -- Assumed
    }
  where
    lams = map segBinOpLambda ops
    -- A lambda's parameter list is split at the number of values it
    -- returns: the leading parameters are treated as the first
    -- operand ("x"), the remaining ones as the second operand ("y").
    xParams lam = take (length (lambdaReturnType lam)) (lambdaParams lam)
    yParams lam = drop (length (lambdaReturnType lam)) (lambdaParams lam)
    -- The fused lambda takes the x-parameters of all operators
    -- followed by the y-parameters of all operators, so that it has
    -- the same "first operand, then second operand" layout as the
    -- individual lambdas.  Its body is the concatenation of the
    -- individual bodies' statements and results, in operator order.
    lam' =
      Lambda
        { lambdaParams = concatMap xParams lams ++ concatMap yParams lams,
          lambdaReturnType = concatMap lambdaReturnType lams,
          lambdaBody =
            Body
              ()
              (mconcat (map (bodyStms . lambdaBody) lams))
              (concatMap (bodyResult . lambdaBody) lams)
        }

-- Decide whether the given scan operators can be handled by the
-- single-pass implementation.  When every operator passes the @ok@
-- check below, the operators are fused with 'combineScans' and the
-- result is returned in a 'Just'; otherwise the answer is 'Nothing'.
--
-- Note that an empty operator list passes the check vacuously.
canBeSinglePass :: [SegBinOp GPUMem] -> Maybe (SegBinOp GPUMem)
canBeSinglePass ops
  | all ok ops =
    Just $ combineScans ops
  | otherwise =
    Nothing
  where
    -- An operator qualifies when its shape is empty (which
    -- 'combineScans' relies on, as it always produces an empty shape)
    -- and all of its lambda's return types are primitive.
    ok op =
      segBinOpShape op == mempty
        && all primType (lambdaReturnType (segBinOpLambda op))

-- | Compile 'SegScan' instance to host-level code with calls to
-- various kernels.
--
-- All generated code is placed under an 'sWhen' guard on the product
-- of the segment space dimensions being positive.  Inside the guard,
-- a debug print is emitted and the implementation is selected from
-- the host target: the single-pass implementation is used when the
-- target is CUDA and 'canBeSinglePass' accepts the operators, and
-- the two-pass implementation is used in every other case.
compileSegScan ::
  Pat GPUMem ->
  SegLevel ->
  SegSpace ->
  [SegBinOp GPUMem] ->
  KernelBody GPUMem ->
  CallKernelGen ()
compileSegScan pat lvl space scans kbody = sWhen (0 .<. n) $ do
  emit $ Imp.DebugPrint "\n# SegScan" Nothing
  target <- hostTarget <$> askEnv
  case target of
    -- If the pattern guard fails (the operators cannot be fused for
    -- the single-pass scan), matching falls through to the wildcard
    -- alternative below, so CUDA also ends up with the two-pass scan.
    CUDA
      | Just scan' <- canBeSinglePass scans ->
        SinglePass.compileSegScan pat lvl space scan' kbody
    _ -> TwoPass.compileSegScan pat lvl space scans kbody
  where
    -- Product of all dimensions of the segment space, as a 64-bit
    -- integer expression.
    n = product $ map toInt64Exp $ segSpaceDims space