-- | Code generation for 'SegScan'.  Dispatches to either a
-- single-pass or two-pass implementation, depending on the nature of
-- the scan and the chosen backend.
module Futhark.CodeGen.ImpGen.GPU.SegScan (compileSegScan) where

import Control.Monad
import Data.Maybe
import Futhark.CodeGen.ImpCode.GPU qualified as Imp
import Futhark.CodeGen.ImpGen hiding (compileProg)
import Futhark.CodeGen.ImpGen.GPU.Base
import Futhark.CodeGen.ImpGen.GPU.SegScan.SinglePass qualified as SinglePass
import Futhark.CodeGen.ImpGen.GPU.SegScan.TwoPass qualified as TwoPass
import Futhark.IR.GPUMem

-- The single-pass scan does not support multiple operators, so jam
-- them together here into one operator.
--
-- The combined operator is built as follows:
--
-- * its commutativity is the 'mconcat' of the individual
--   commutativities;
--
-- * its neutral elements are the concatenation of the individual
--   neutral elements, in operator order;
--
-- * its lambda takes first the "x" parameters of every operator, then
--   the "y" parameters of every operator, and its body is the
--   concatenation of the statements and results of the original
--   bodies;
--
-- * its shape is 'mempty'.  This is an assumption: the only caller
--   visible in this module ('canBeSinglePass') checks that every
--   operator has an empty shape before calling us.
combineScanOps :: [SegBinOp GPUMem] -> SegBinOp GPUMem
combineScanOps ops =
  SegBinOp
    { segBinOpComm = mconcat (map segBinOpComm ops),
      segBinOpLambda = lam',
      segBinOpNeutral = concatMap segBinOpNeutral ops,
      segBinOpShape = mempty -- Assumed
    }
  where
    lams = map segBinOpLambda ops

    -- Number of values a lambda returns; used to split its parameter
    -- list into a leading and a trailing part.
    numResults lam = length (lambdaReturnType lam)

    -- The first 'numResults' parameters of a lambda, and the
    -- remaining ones.  (Presumably the two operands of the binary
    -- operator; this relies on the lambda having twice as many
    -- parameters as results, which is not checked here.)
    xParams lam = take (numResults lam) (lambdaParams lam)
    yParams lam = drop (numResults lam) (lambdaParams lam)

    lam' =
      Lambda
        { lambdaParams = concatMap xParams lams ++ concatMap yParams lams,
          lambdaReturnType = concatMap lambdaReturnType lams,
          lambdaBody =
            Body
              ()
              (mconcat (map (bodyStms . lambdaBody) lams))
              (concatMap (bodyResult . lambdaBody) lams)
        }

-- | Does the predicate hold for the expression of any statement in
-- this body, or in any body nested inside one of those expressions?
--
-- Nested bodies are searched by walking each expression in the
-- 'Maybe' monad with a walker whose body action fails (yields
-- 'Nothing') as soon as a nested body satisfies 'bodyHas'.  A failed
-- walk therefore means "found".
bodyHas :: (Exp GPUMem -> Bool) -> Body GPUMem -> Bool
bodyHas f = any (f' . stmExp) . bodyStms
  where
    -- True if the expression itself matches, or if walking it hits a
    -- nested body that contains a match.
    f' e = f e || isNothing (walkExpM walker e)

    -- The scope argument handed to 'walkOnBody' is ignored; only the
    -- body itself is inspected.
    walker =
      identityWalker
        { walkOnBody = const $ guard . not . bodyHas f
        }

-- | Decide whether a list of scan operators can be handled by the
-- single-pass implementation.  If so, return the operators merged
-- into one by 'combineScanOps'; otherwise return 'Nothing'.
--
-- Every operator must satisfy all of:
--
-- * its shape is empty ('mempty');
--
-- * every type returned by its lambda is a primitive type;
--
-- * its lambda body contains no 'Assert', not even inside nested
--   bodies (as determined by 'bodyHas').
canBeSinglePass :: [SegBinOp GPUMem] -> Maybe (SegBinOp GPUMem)
canBeSinglePass scan_ops
  | all ok scan_ops = Just $ combineScanOps scan_ops
  | otherwise = Nothing
  where
    ok op =
      segBinOpShape op == mempty
        && all primType (lambdaReturnType (segBinOpLambda op))
        && not (bodyHas isAssert (lambdaBody (segBinOpLambda op)))

    isAssert (BasicOp Assert {}) = True
    isAssert _ = False

-- | Compile 'SegScan' instance to host-level code with calls to
-- various kernels.
--
-- The generated code is wrapped in an 'sWhen' on the condition that
-- the product of the segment space dimensions is greater than zero,
-- and is bracketed by two debug-print markers.
--
-- The single-pass implementation is chosen only when the host target
-- is CUDA or HIP /and/ 'canBeSinglePass' accepts the operators (in
-- which case the merged operator it returns is used).  In every other
-- case the two-pass implementation is used with the original operator
-- list.
compileSegScan ::
  Pat LetDecMem ->
  SegLevel ->
  SegSpace ->
  [SegBinOp GPUMem] ->
  KernelBody GPUMem ->
  CallKernelGen ()
compileSegScan pat lvl space scan_ops map_kbody =
  sWhen (0 .<. n) $ do
    emit $ Imp.DebugPrint "\n# SegScan" Nothing
    target <- hostTarget <$> askEnv

    case (targetSupportsSinglePass target, canBeSinglePass scan_ops) of
      (True, Just scan_ops') ->
        SinglePass.compileSegScan pat lvl space scan_ops' map_kbody
      _ ->
        TwoPass.compileSegScan pat lvl space scan_ops map_kbody
    emit $ Imp.DebugPrint "" Nothing
  where
    -- Product of all dimensions of the segment space.
    n = product $ map pe64 $ segSpaceDims space

    -- Only these targets get the single-pass implementation.
    targetSupportsSinglePass CUDA = True
    targetSupportsSinglePass HIP = True
    targetSupportsSinglePass _ = False