module Futhark.CodeGen.ImpGen.GPU.SegScan (compileSegScan) where
import Futhark.CodeGen.ImpCode.GPU qualified as Imp
import Futhark.CodeGen.ImpGen hiding (compileProg)
import Futhark.CodeGen.ImpGen.GPU.Base
import Futhark.CodeGen.ImpGen.GPU.SegScan.SinglePass qualified as SinglePass
import Futhark.CodeGen.ImpGen.GPU.SegScan.TwoPass qualified as TwoPass
import Futhark.IR.GPUMem
-- | Fuse several scan operators into one operator, so that they can be
-- handed to an implementation that only accepts a single operator.
--
-- Each input lambda's parameter list is treated as its "x" parameters
-- (the first @length (lambdaReturnType lam)@ parameters) followed by its
-- "y" parameters (the rest). The fused lambda takes every operator's x
-- parameters in order, followed by every operator's y parameters in
-- order.
--
-- The fused lambda's statements, results and return types are the
-- concatenation of the constituents' in list order. The neutral elements
-- are concatenated in the same order. The commutativity is the
-- 'Commutativity' monoid combination of all operators.
--
-- The fused operator always gets an empty shape, so any per-operator
-- shape is discarded. The only caller in this module ('canBeSinglePass')
-- first checks that every shape is empty.
combineScans :: [SegBinOp GPUMem] -> SegBinOp GPUMem
combineScans ops =
  SegBinOp
    { segBinOpComm = foldMap segBinOpComm ops,
      segBinOpLambda = fused,
      segBinOpNeutral = ops >>= segBinOpNeutral,
      segBinOpShape = mempty
    }
  where
    lams = map segBinOpLambda ops
    bodies = map lambdaBody lams

    -- Split a lambda's parameters into (x params, y params); the split
    -- point is the number of values the lambda returns.
    splitParams lam = splitAt (length (lambdaReturnType lam)) (lambdaParams lam)
    (xss, yss) = unzip (map splitParams lams)

    fused =
      Lambda
        { lambdaParams = concat xss ++ concat yss,
          lambdaReturnType = lams >>= lambdaReturnType,
          lambdaBody =
            Body
              ()
              (foldMap bodyStms bodies)
              (bodies >>= bodyResult)
        }
-- | Check whether a group of scan operators qualifies for the
-- single-pass scan. If it does, return the operators fused into one via
-- 'combineScans'.
--
-- An operator is rejected if its shape is non-empty, or if any of its
-- lambda's return types is not a primitive type. A single rejected
-- operator makes the whole group ineligible, and the result is 'Nothing'.
canBeSinglePass :: [SegBinOp GPUMem] -> Maybe (SegBinOp GPUMem)
canBeSinglePass ops =
  if any unsupported ops
    then Nothing
    else Just (combineScans ops)
  where
    unsupported op =
      segBinOpShape op /= mempty
        || not (all primType (lambdaReturnType (segBinOpLambda op)))
-- | Compile a 'SegScan' into host-level code that calls the scan kernels.
--
-- The generated code is wrapped in a runtime check, so it only executes
-- when the iteration space (the product of all dimensions of @space@) is
-- non-empty. When the host target is CUDA and 'canBeSinglePass' accepts
-- the operators, the fused operator is compiled with the single-pass
-- implementation. Every other case falls back to the two-pass
-- implementation with the original operator list.
compileSegScan ::
  Pat LetDecMem ->
  SegLevel ->
  SegSpace ->
  [SegBinOp GPUMem] ->
  KernelBody GPUMem ->
  CallKernelGen ()
compileSegScan pat lvl space scans kbody =
  sWhen (0 .<. numElems) $ do
    emit $ Imp.DebugPrint "\n# SegScan" Nothing
    env <- askEnv
    -- The tuple is matched left to right, so the single-pass eligibility
    -- check is only evaluated when the target is CUDA.
    case (hostTarget env, canBeSinglePass scans) of
      (CUDA, Just fused) ->
        SinglePass.compileSegScan pat lvl space fused kbody
      _ ->
        TwoPass.compileSegScan pat lvl space scans kbody
  where
    -- Total iteration count: product of every dimension in the space.
    numElems = product (map pe64 (segSpaceDims space))