module Futhark.CodeGen.ImpGen.GPU.SegScan (compileSegScan) where
import qualified Futhark.CodeGen.ImpCode.GPU as Imp
import Futhark.CodeGen.ImpGen hiding (compileProg)
import Futhark.CodeGen.ImpGen.GPU.Base
import qualified Futhark.CodeGen.ImpGen.GPU.SegScan.SinglePass as SinglePass
import qualified Futhark.CodeGen.ImpGen.GPU.SegScan.TwoPass as TwoPass
import Futhark.IR.GPUMem
-- | Merge a list of scan operators into a single 'SegBinOp'.
--
-- The resulting operator is assembled as follows:
--
-- * its commutativity is the 'mconcat' of the commutativity of every
--   operator in @ops@;
--
-- * its neutral elements are the neutral elements of all operators,
--   concatenated in order;
--
-- * its shape is always 'mempty', regardless of the shapes of the
--   inputs (the visible caller, 'canBeSinglePass', only passes
--   operators whose shape is already 'mempty');
--
-- * its lambda is the combined lambda @lam'@ described below.
combineScans :: [SegBinOp GPUMem] -> SegBinOp GPUMem
combineScans ops =
  SegBinOp
    { segBinOpComm = mconcat (map segBinOpComm ops),
      segBinOpLambda = lam',
      segBinOpNeutral = concatMap segBinOpNeutral ops,
      segBinOpShape = mempty
    }
  where
    -- The lambdas of the individual operators, in the order given.
    lams = map segBinOpLambda ops

    -- A lambda's parameter list is split at the number of values it
    -- returns: 'xParams' is the first group and 'yParams' is the rest.
    -- NOTE(review): presumably these are the left and right operands of
    -- the binary operator -- confirm against the 'SegBinOp' definition.
    xParams lam = take (length (lambdaReturnType lam)) (lambdaParams lam)
    yParams lam = drop (length (lambdaReturnType lam)) (lambdaParams lam)

    -- The combined lambda takes every first-group parameter of every
    -- lambda, followed by every second-group parameter, so the two
    -- groups stay contiguous.  Its body is the concatenation of the
    -- statements of all lambda bodies, and its result (and return type)
    -- is the concatenation of their results (and return types).
    lam' =
      Lambda
        { lambdaParams = concatMap xParams lams ++ concatMap yParams lams,
          lambdaReturnType = concatMap lambdaReturnType lams,
          lambdaBody =
            Body
              ()
              (mconcat (map (bodyStms . lambdaBody) lams))
              (concatMap (bodyResult . lambdaBody) lams)
        }
-- | Decide whether a list of scan operators can be handled as a single
-- combined operator.
--
-- Returns @'Just' op@, where @op@ is the result of 'combineScans' on
-- the whole list, when every operator satisfies both conditions:
--
-- * its 'segBinOpShape' equals 'mempty', and
--
-- * every return type of its lambda satisfies 'primType'.
--
-- Returns 'Nothing' as soon as any operator fails either condition.
-- Note that an empty list trivially satisfies 'all', so it yields
-- 'Just' the combination of zero operators.
canBeSinglePass :: [SegBinOp GPUMem] -> Maybe (SegBinOp GPUMem)
canBeSinglePass ops
  | all ok ops =
      Just $ combineScans ops
  | otherwise =
      Nothing
  where
    -- An operator qualifies when it has the empty shape and returns
    -- only primitive-typed values.
    ok op =
      segBinOpShape op == mempty
        && all primType (lambdaReturnType (segBinOpLambda op))
-- | Generate host-level code for a scan over the given segment space.
--
-- All generated code is wrapped in an 'sWhen' guard on @0 < n@, where
-- @n@ is the product of the dimensions of @space@.  Inside the guard a
-- debug-print marker is emitted first, and then one of two
-- implementations is chosen:
--
-- * if the host target read from the environment is 'CUDA' and
--   'canBeSinglePass' accepts @scans@, the combined operator is handed
--   to 'SinglePass.compileSegScan';
--
-- * in every other case (a different target, or operators that cannot
--   be combined) the original operator list is handed to
--   'TwoPass.compileSegScan'.
compileSegScan ::
  Pat GPUMem ->
  SegLevel ->
  SegSpace ->
  [SegBinOp GPUMem] ->
  KernelBody GPUMem ->
  CallKernelGen ()
compileSegScan pat lvl space scans kbody = sWhen (0 .<. n) $ do
  emit $ Imp.DebugPrint "\n# SegScan" Nothing
  target <- hostTarget <$> askEnv
  case target of
    -- The pattern guard falls through to the catch-all alternative
    -- below when the operators cannot be combined.
    CUDA
      | Just scan' <- canBeSinglePass scans ->
          SinglePass.compileSegScan pat lvl space scan' kbody
    _ -> TwoPass.compileSegScan pat lvl space scans kbody
  where
    -- Product of all dimensions of the segment space, as a 64-bit
    -- integer expression.
    n = product $ map toInt64Exp $ segSpaceDims space