-- | Code generation for 'SegScan'.  Dispatches to either a
-- single-pass or two-pass implementation, depending on the nature of
-- the scan and the chosen backend.
module Futhark.CodeGen.ImpGen.GPU.SegScan (compileSegScan) where

import Futhark.CodeGen.ImpCode.GPU qualified as Imp
import Futhark.CodeGen.ImpGen hiding (compileProg)
import Futhark.CodeGen.ImpGen.GPU.Base
import Futhark.CodeGen.ImpGen.GPU.SegScan.SinglePass qualified as SinglePass
import Futhark.CodeGen.ImpGen.GPU.SegScan.TwoPass qualified as TwoPass
import Futhark.IR.GPUMem

-- The single-pass scan does not support multiple operators, so jam
-- them together here.
--
-- The combined operator is built by concatenation:
--
-- * its commutativity is the 'mconcat' of the operators'
--   commutativities;
--
-- * its neutral elements are the operators' neutral elements, in
--   order;
--
-- * its lambda takes first the "x" parameters of every operator and
--   then the "y" parameters of every operator, where the "x"
--   parameters of a lambda are its first @k@ parameters and the "y"
--   parameters are the rest, @k@ being the number of values the lambda
--   returns;
--
-- * its body is the concatenation of the statements of every
--   operator's body, and its result is the concatenation of their
--   results.
--
-- The shape of the combined operator is set to 'mempty' without
-- inspecting the inputs, so callers must make sure that every operator
-- passed in has an empty shape.
combineScans :: [SegBinOp GPUMem] -> SegBinOp GPUMem
combineScans ops =
  SegBinOp
    { segBinOpComm = mconcat (map segBinOpComm ops),
      segBinOpLambda = lam',
      segBinOpNeutral = concatMap segBinOpNeutral ops,
      segBinOpShape = mempty -- Assumed
    }
  where
    lams = map segBinOpLambda ops
    -- The leading parameters of a lambda, one per returned value.
    xParams lam = take (length (lambdaReturnType lam)) (lambdaParams lam)
    -- Whatever parameters remain after the leading ones.
    yParams lam = drop (length (lambdaReturnType lam)) (lambdaParams lam)
    lam' =
      Lambda
        { -- All x-parameters come before all y-parameters.
          lambdaParams = concatMap xParams lams ++ concatMap yParams lams,
          lambdaReturnType = concatMap lambdaReturnType lams,
          lambdaBody =
            Body
              ()
              (mconcat (map (bodyStms . lambdaBody) lams))
              (concatMap (bodyResult . lambdaBody) lams)
        }

-- Decide whether a list of scan operators can be handled by the
-- single-pass implementation.  Returns 'Just' the operators fused by
-- 'combineScans' when every operator has an empty shape and a lambda
-- whose return types all satisfy 'primType'; returns 'Nothing'
-- otherwise.
canBeSinglePass :: [SegBinOp GPUMem] -> Maybe (SegBinOp GPUMem)
canBeSinglePass ops
  | all ok ops =
      Just $ combineScans ops
  | otherwise =
      Nothing
  where
    -- The empty-shape check is what justifies 'combineScans' giving
    -- the fused operator the shape 'mempty'.
    ok op =
      segBinOpShape op == mempty
        && all primType (lambdaReturnType (segBinOpLambda op))

-- | Compile 'SegScan' instance to host-level code with calls to
-- various kernels.
--
-- All generated code is wrapped in an 'sWhen' that tests whether the
-- product of the dimensions of the 'SegSpace' is greater than zero.
-- Inside that guard, a debug print is emitted and the host target is
-- read from the environment.  If the target is @CUDA@ and
-- 'canBeSinglePass' accepts the operators, the single-pass
-- implementation is used with the fused operator; in every other case
-- the two-pass implementation is used with the original list of
-- operators.
compileSegScan ::
  Pat LetDecMem ->
  SegLevel ->
  SegSpace ->
  [SegBinOp GPUMem] ->
  KernelBody GPUMem ->
  CallKernelGen ()
compileSegScan pat lvl space scans kbody = sWhen (0 .<. n) $ do
  emit $ Imp.DebugPrint "\n# SegScan" Nothing
  target <- hostTarget <$> askEnv
  case target of
    -- If the pattern guard fails, matching falls through to the
    -- wildcard alternative below, so CUDA also uses the two-pass
    -- implementation when the operators are unsuitable.
    CUDA
      | Just scan' <- canBeSinglePass scans ->
          SinglePass.compileSegScan pat lvl space scan' kbody
    _ -> TwoPass.compileSegScan pat lvl space scans kbody
  where
    -- Product of all dimensions of the segment space.
    n = product $ map pe64 $ segSpaceDims space