-- | Code generation for 'SegScan'.  Dispatches to either a
-- single-pass or two-pass implementation, depending on the nature of
-- the scan and the chosen backend.
module Futhark.CodeGen.ImpGen.GPU.SegScan (compileSegScan) where

import Control.Monad
import Data.Maybe
import Futhark.CodeGen.ImpCode.GPU qualified as Imp
import Futhark.CodeGen.ImpGen hiding (compileProg)
import Futhark.CodeGen.ImpGen.GPU.Base
import Futhark.CodeGen.ImpGen.GPU.SegScan.SinglePass qualified as SinglePass
import Futhark.CodeGen.ImpGen.GPU.SegScan.TwoPass qualified as TwoPass
import Futhark.IR.GPUMem

-- The single-pass scan does not support multiple operators, so jam
-- them together here.
--
-- The fused lambda takes the "x" parameters of every operator (in
-- operator order) followed by the "y" parameters of every operator,
-- where each lambda's first @length (lambdaReturnType lam)@ parameters
-- are its "x" half.  Return types, neutral elements, body statements
-- and body results are concatenated in operator order, and the
-- commutativity is the 'Monoid' combination of all operators'
-- commutativities.  The shape is always empty; callers must only pass
-- operators whose shape is already empty (as 'canBeSinglePass' checks).
combineScans :: [SegBinOp GPUMem] -> SegBinOp GPUMem
combineScans ops =
  SegBinOp
    { segBinOpComm = foldMap segBinOpComm ops,
      segBinOpLambda = fused,
      segBinOpNeutral = foldMap segBinOpNeutral ops,
      segBinOpShape = mempty -- Assumed
    }
  where
    lambdas = map segBinOpLambda ops

    -- Split a lambda's parameters into its "x" and "y" halves; the
    -- "x" half has one parameter per return value.
    halves lam = splitAt (length (lambdaReturnType lam)) (lambdaParams lam)

    (xParamGroups, yParamGroups) = unzip (map halves lambdas)

    fused =
      Lambda
        { lambdaParams = concat xParamGroups ++ concat yParamGroups,
          lambdaReturnType = foldMap lambdaReturnType lambdas,
          lambdaBody =
            Body
              ()
              (foldMap (bodyStms . lambdaBody) lambdas)
              (foldMap (bodyResult . lambdaBody) lambdas)
        }

-- | Does the body contain an expression satisfying the predicate?
-- Checks the expression of every top-level statement, and recursively
-- every body that 'walkExpM' visits inside those expressions.
bodyHas :: (Exp GPUMem -> Bool) -> Body GPUMem -> Bool
bodyHas wanted b = any (matches . stmExp) (bodyStms b)
  where
    matches e = wanted e || nestedMatch e

    -- The walk is aborted (yields 'Nothing') as soon as it reaches a
    -- nested body containing a match, so a failed walk means "found".
    nestedMatch e = isNothing (walkExpM abortOnMatch e)

    abortOnMatch =
      identityWalker
        { walkOnBody = \_ inner -> when (bodyHas wanted inner) Nothing
        }

-- | Decide whether the scan can use the single-pass implementation.
-- If so, return all operators fused into one with 'combineScans';
-- otherwise return 'Nothing'.
--
-- Every operator must have an empty shape and only primitive return
-- types, and must not contain an 'Assert' in its lambda body (including
-- nested bodies, as searched by 'bodyHas').  The kernel body's
-- statements must not construct fresh arrays.
canBeSinglePass :: [SegBinOp GPUMem] -> KernelBody GPUMem -> Maybe (SegBinOp GPUMem)
canBeSinglePass ops kbody = do
  guard $ all suitableOp ops
  guard $ not $ bodyHas createsArray kernelStmsBody
  pure $ combineScans ops
  where
    -- Wrap the kernel statements in a plain 'Body' so 'bodyHas' can
    -- search them.
    kernelStmsBody = Body () (kernelBodyStms kbody) []

    suitableOp op =
      let lam = segBinOpLambda op
       in segBinOpShape op == mempty
            && all primType (lambdaReturnType lam)
            && not (bodyHas isAssert (lambdaBody lam))

    isAssert (BasicOp bop) = case bop of
      Assert {} -> True
      _ -> False
    isAssert _ = False

    -- XXX: Currently single pass scans cannot handle construction of
    -- arrays in the kernel body (#2013), because of insufficient
    -- memory expansion.  This can in principle be fixed.
    createsArray (BasicOp bop) = case bop of
      Manifest {} -> True
      Iota {} -> True
      Replicate {} -> True
      Scratch {} -> True
      Concat {} -> True
      ArrayLit {} -> True
      _ -> False
    createsArray _ = False

-- | Compile 'SegScan' instance to host-level code with calls to
-- various kernels.
--
-- The generated code is guarded so it only runs when the product of
-- the segment-space dimensions is positive.  On the CUDA and HIP
-- targets the single-pass implementation is used whenever
-- 'canBeSinglePass' accepts the operators; in every other case the
-- two-pass implementation is used.
compileSegScan ::
  Pat LetDecMem ->
  SegLevel ->
  SegSpace ->
  [SegBinOp GPUMem] ->
  KernelBody GPUMem ->
  CallKernelGen ()
compileSegScan pat lvl space scans kbody =
  sWhen (0 .<. totalSize) $ do
    emit $ Imp.DebugPrint "\n# SegScan" Nothing
    tgt <- hostTarget <$> askEnv
    case singlePassOp tgt of
      Just fused -> SinglePass.compileSegScan pat lvl space fused kbody
      Nothing -> TwoPass.compileSegScan pat lvl space scans kbody
    emit $ Imp.DebugPrint "" Nothing
  where
    totalSize = product $ map pe64 $ segSpaceDims space

    -- Only these targets have a single-pass implementation; for them,
    -- defer to 'canBeSinglePass' to decide whether it applies.
    singlePassOp CUDA = canBeSinglePass scans kbody
    singlePassOp HIP = canBeSinglePass scans kbody
    singlePassOp _ = Nothing