{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeFamilies #-}

module Futhark.Pass.ExtractKernels.BlockedKernel
  ( DistRep,
    MkSegLevel,
    ThreadRecommendation (..),
    segRed,
    nonSegRed,
    segScan,
    segHist,
    segMap,
    mapKernel,
    KernelInput (..),
    readKernelInput,
    mkSegSpace,
    dummyDim,
  )
where

import Control.Monad
import Control.Monad.Writer
import Futhark.Analysis.PrimExp
import Futhark.IR
import Futhark.IR.Prop.Aliases
import Futhark.IR.SegOp
import Futhark.MonadFreshNames
import Futhark.Tools
import Futhark.Transform.Rename
import Prelude hiding (quot)

-- | The requirements a representation @rep@ must meet before the
-- distribution/flattening helpers in this module can be used with it.
-- Besides the class constraints, the let-binding decoration must be a
-- plain 'Type' and the expression and body decorations must be unit.
type DistRep rep =
  ( Buildable rep,
    BuilderOps rep,
    HasSegOp rep,
    CanBeAliased (Op rep),
    LetDec rep ~ Type,
    ExpDec rep ~ (),
    BodyDec rep ~ ()
  )

-- | A hint handed to a 'MkSegLevel' function when requesting a level
-- for a new kernel.
data ThreadRecommendation
  = -- | 'mapKernel' passes this when every result type of the kernel
    -- is primitive.
    ManyThreads
  | -- | No particular recommendation; carries the 'SegVirt' to use.
    -- 'mapKernel' passes @NoRecommendation SegVirt@ when some result
    -- type is not primitive.
    NoRecommendation SegVirt

-- | A function that produces the 'SegOpLevel' for a kernel, running in
-- 'BuilderT'.  'mapKernel' calls it with the sizes of the kernel
-- dimensions, a descriptive string (@"segmap"@), and a
-- 'ThreadRecommendation'.
type MkSegLevel rep m =
  [SubExp] ->
  String ->
  ThreadRecommendation ->
  BuilderT rep m (SegOpLevel rep)

-- | Construct a 'SegSpace' over the given @(index name, size)@
-- dimensions.  The first field of the 'SegSpace' is a freshly
-- generated name based on @"phys_tid"@.
mkSegSpace :: MonadFreshNames m => [(VName, SubExp)] -> m SegSpace
mkSegSpace dims = SegSpace <$> newVName "phys_tid" <*> pure dims

-- | Build the 'SegSpace' and 'KernelBody' shared by 'segRed',
-- 'segScan' and 'segMap'.
--
-- The space is @ispace@ extended with one innermost dimension of size
-- @w@, indexed by a fresh @gtid@.  The kernel body first reads the
-- extra inputs @inps@, then (under the certificates @cs@) binds each
-- parameter of @map_lam@ by reading the corresponding array in @arrs@
-- at index @gtid@, then inlines the body of @map_lam@ and returns each
-- of its results as a 'Returns' with 'ResultMaySimplify'.
prepareRedOrScan ::
  (MonadBuilder m, DistRep (Rep m)) =>
  Certs ->
  SubExp ->
  Lambda (Rep m) ->
  [VName] ->
  [(VName, SubExp)] ->
  [KernelInput] ->
  m (SegSpace, KernelBody (Rep m))
prepareRedOrScan cs w map_lam arrs ispace inps = do
  gtid <- newVName "gtid"
  space <- mkSegSpace $ ispace ++ [(gtid, w)]
  -- 'runBuilder' yields (results, statements); 'KernelBody' wants the
  -- statements first, hence the 'flip'.
  kbody <- fmap (uncurry (flip (KernelBody ()))) $
    runBuilder $
      localScope (scopeOfSegSpace space) $ do
        mapM_ readKernelInput inps
        -- One input per lambda parameter, paired positionally with
        -- 'arrs' (list monad over the zipped pairs).
        certifying cs . mapM_ readKernelInput $ do
          (p, arr) <- zip (lambdaParams map_lam) arrs
          pure $ KernelInput (paramName p) (paramType p) arr [Var gtid]
        res <- bodyBind (lambdaBody map_lam)
        forM res $ \(SubExpRes res_cs se) ->
          pure $ Returns ResultMaySimplify res_cs se

  pure (space, kbody)

-- | Produce the statements binding @pat@ to a 'SegRed' at level @lvl@.
-- The thread space and kernel body come from 'prepareRedOrScan'; the
-- result types of the 'SegRed' are the return types of @map_lam@.
segRed ::
  (MonadFreshNames m, DistRep rep, HasScope rep m) =>
  SegOpLevel rep ->
  Pat (LetDec rep) ->
  Certs ->
  SubExp -> -- segment size
  [SegBinOp rep] ->
  Lambda rep ->
  [VName] ->
  [(VName, SubExp)] -> -- ispace = pair of (gtid, size) for the maps on "top" of this reduction
  [KernelInput] -> -- inps = inputs that can be looked up by using the gtids from ispace
  m (Stms rep)
segRed lvl pat cs w ops map_lam arrs ispace inps = runBuilder_ $ do
  (kspace, kbody) <- prepareRedOrScan cs w map_lam arrs ispace inps
  letBind pat $
    Op $
      segOp $
        SegRed lvl kspace ops (lambdaReturnType map_lam) kbody

-- | Produce the statements binding @pat@ to a 'SegScan' at level
-- @lvl@.  The thread space and kernel body come from
-- 'prepareRedOrScan'; the result types of the 'SegScan' are the return
-- types of @map_lam@.
segScan ::
  (MonadFreshNames m, DistRep rep, HasScope rep m) =>
  SegOpLevel rep ->
  Pat (LetDec rep) ->
  Certs ->
  SubExp -> -- segment size
  [SegBinOp rep] ->
  Lambda rep ->
  [VName] ->
  [(VName, SubExp)] -> -- ispace = pair of (gtid, size) for the maps on "top" of this scan
  [KernelInput] -> -- inps = inputs that can be looked up by using the gtids from ispace
  m (Stms rep)
segScan lvl pat cs w ops map_lam arrs ispace inps = runBuilder_ $ do
  (kspace, kbody) <- prepareRedOrScan cs w map_lam arrs ispace inps
  letBind pat $
    Op $
      segOp $
        SegScan lvl kspace ops (lambdaReturnType map_lam) kbody

-- | Produce the statements binding @pat@ to a 'SegMap' at level @lvl@.
-- The thread space and kernel body come from 'prepareRedOrScan',
-- called with empty certificates ('mempty'); the result types of the
-- 'SegMap' are the return types of @map_lam@.
segMap ::
  (MonadFreshNames m, DistRep rep, HasScope rep m) =>
  SegOpLevel rep ->
  Pat (LetDec rep) ->
  SubExp -> -- segment size
  Lambda rep ->
  [VName] ->
  [(VName, SubExp)] -> -- ispace = pair of (gtid, size) for the maps on "top" of this map
  [KernelInput] -> -- inps = inputs that can be looked up by using the gtids from ispace
  m (Stms rep)
segMap lvl pat w map_lam arrs ispace inps = runBuilder_ $ do
  (kspace, kbody) <- prepareRedOrScan mempty w map_lam arrs ispace inps
  letBind pat $
    Op $
      segOp $
        SegMap lvl kspace (lambdaReturnType map_lam) kbody

-- | Prepare a pattern for being produced under one extra unit-size
-- dimension.  Returns a triple of:
--
-- * a renamed copy of @pat@ in which every element type has been
--   wrapped with 'arrayOfRow' using the constant size 1;
--
-- * a single-dimension index space @[(dummy, 1)]@ with a fresh index
--   name;
--
-- * an action that, for each element, binds the original name to the
--   renamed element indexed at 0 in its first dimension.
dummyDim ::
  (MonadFreshNames m, MonadBuilder m) =>
  Pat Type ->
  m (Pat Type, [(VName, SubExp)], m ())
dummyDim pat = do
  -- We add a unit-size segment on top to ensure that the result
  -- of the SegRed is an array, which we then immediately index.
  -- This is useful in the case that the value is used on the
  -- device afterwards, as this may save an expensive
  -- host-device copy (scalars are kept on the host, but arrays
  -- may be on the device).
  let addDummyDim t = t `arrayOfRow` intConst Int64 1
  pat' <- fmap addDummyDim <$> renamePat pat
  dummy <- newVName "dummy"
  let ispace = [(dummy, intConst Int64 1)]

  pure
    ( pat',
      ispace,
      forM_ (zip (patNames pat') (patNames pat)) $ \(from, to) -> do
        from_t <- lookupType from
        letBindNames [to] $
          BasicOp $
            Index from $
              fullSlice from_t [DimFix $ intConst Int64 0]
    )

-- | Produce the statements for a reduction with no enclosing map
-- dimensions.  The pattern is passed through 'dummyDim', a 'segRed'
-- over the resulting one-element index space (with empty certificates
-- and no extra kernel inputs) binds the renamed pattern, and the
-- action returned by 'dummyDim' then binds the names of the original
-- @pat@.
nonSegRed ::
  (MonadFreshNames m, DistRep rep, HasScope rep m) =>
  SegOpLevel rep ->
  Pat Type ->
  SubExp ->
  [SegBinOp rep] ->
  Lambda rep ->
  [VName] ->
  m (Stms rep)
nonSegRed lvl pat w ops map_lam arrs = runBuilder_ $ do
  (pat', ispace, read_dummy) <- dummyDim pat
  addStms
    =<< segRed lvl pat' mempty w ops map_lam arrs ispace []
  read_dummy

-- | Produce the statements binding @pat@ to a 'SegHist' at level
-- @lvl@.
--
-- The thread space is @ispace@ extended with one innermost dimension
-- of size @arr_w@, indexed by a fresh @gtid@.  The kernel body reads
-- the extra inputs @inps@, binds each parameter of @lam@ to the
-- corresponding array in @arrs@ indexed at @gtid@ in its first
-- dimension, inlines the body of @lam@, and returns each of its
-- results as a 'Returns' with 'ResultMaySimplify'.  The result types
-- of the 'SegHist' are the return types of @lam@.
segHist ::
  (DistRep rep, MonadFreshNames m, HasScope rep m) =>
  SegOpLevel rep ->
  Pat Type ->
  SubExp ->
  -- | Segment indexes and sizes.
  [(VName, SubExp)] ->
  [KernelInput] ->
  [HistOp rep] ->
  Lambda rep ->
  [VName] ->
  m (Stms rep)
segHist lvl pat arr_w ispace inps ops lam arrs = runBuilder_ $ do
  gtid <- newVName "gtid"
  space <- mkSegSpace $ ispace ++ [(gtid, arr_w)]

  -- 'runBuilder' yields (results, statements); 'KernelBody' wants the
  -- statements first, hence the 'flip'.
  kbody <- fmap (uncurry (flip $ KernelBody ())) $
    runBuilder $
      localScope (scopeOfSegSpace space) $ do
        mapM_ readKernelInput inps
        forM_ (zip (lambdaParams lam) arrs) $ \(p, arr) -> do
          arr_t <- lookupType arr
          letBindNames [paramName p] $
            BasicOp $ Index arr $ fullSlice arr_t [DimFix $ Var gtid]
        res <- bodyBind (lambdaBody lam)
        forM res $ \(SubExpRes cs se) ->
          pure $ Returns ResultMaySimplify cs se

  letBind pat $ Op $ segOp $ SegHist lvl space ops (lambdaReturnType lam) kbody

-- | Produce the pieces common to every map kernel: the 'SegSpace' for
-- the index space @ispace@, and the statements that result from
-- running 'readKernelInput' on each of @inputs@.
mapKernelSkeleton ::
  (DistRep rep, HasScope rep m, MonadFreshNames m) =>
  [(VName, SubExp)] ->
  [KernelInput] ->
  m (SegSpace, Stms rep)
mapKernelSkeleton ispace inputs = do
  -- Only the emitted statements are wanted, so use 'mapM_' rather than
  -- collecting a list of units.
  read_input_stms <- runBuilder_ $ mapM_ readKernelInput inputs

  space <- mkSegSpace ispace
  pure (space, read_input_stms)

-- | Construct a 'SegMap' over the index space @ispace@ with result
-- types @rts@.  The statements reading @inputs@ are prepended to the
-- statements of the given kernel body.  The level is obtained from
-- @mk_lvl@, called with the dimension sizes of @ispace@, the
-- description @"segmap"@, and a 'ThreadRecommendation'.  Returns the
-- 'SegOp' together with any statements emitted while constructing it.
mapKernel ::
  (DistRep rep, HasScope rep m, MonadFreshNames m) =>
  MkSegLevel rep m ->
  [(VName, SubExp)] ->
  [KernelInput] ->
  [Type] ->
  KernelBody rep ->
  m (SegOp (SegOpLevel rep) rep, Stms rep)
mapKernel mk_lvl ispace inputs rts (KernelBody () kstms krets) = runBuilderT' $ do
  (space, read_input_stms) <- mapKernelSkeleton ispace inputs

  let kbody' = KernelBody () (read_input_stms <> kstms) krets

  -- If the kernel creates arrays (meaning it will require memory
  -- expansion), we want to truncate the amount of threads.
  -- Otherwise, have at it!  This is a bit of a hack - in principle,
  -- we should make this decision later, when we have a clearer idea
  -- of what is happening inside the kernel.
  let r = if all primType rts then ManyThreads else NoRecommendation SegVirt

  lvl <- mk_lvl (map snd ispace) "segmap" r

  pure $ SegMap lvl space rts kbody'

-- | A description of a value that is read from an array (or passed
-- through, for accumulators) at the start of a kernel body.  See
-- 'readKernelInput' for how the fields are used.
data KernelInput = KernelInput
  { -- | The name that the value is bound to.
    kernelInputName :: VName,
    -- | The type of the bound value.
    kernelInputType :: Type,
    -- | The array (or accumulator) that the value is read from.
    kernelInputArray :: VName,
    -- | The leading indices at which 'kernelInputArray' is indexed.
    kernelInputIndices :: [SubExp]
  }
  deriving (Show)

-- | Add a statement to the builder that binds 'kernelInputName' (with
-- type 'kernelInputType') to the value described by the 'KernelInput'.
--
-- If the input type is an accumulator, the name is bound directly to
-- the 'kernelInputArray' variable.  Otherwise it is bound to an 'Index'
-- of 'kernelInputArray', whose slice fixes the 'kernelInputIndices' and
-- then applies 'sliceDim' to each dimension of the input type.
readKernelInput ::
  (DistRep (Rep m), MonadBuilder m) =>
  KernelInput ->
  m ()
readKernelInput inp = do
  let pe = PatElem (kernelInputName inp) $ kernelInputType inp
  letBind (Pat [pe]) . BasicOp $
    case kernelInputType inp of
      -- Accumulators are passed through as-is rather than indexed.
      Acc {} ->
        SubExp $ Var $ kernelInputArray inp
      _ ->
        Index (kernelInputArray inp) . Slice $
          map DimFix (kernelInputIndices inp)
            ++ map sliceDim (arrayDims (kernelInputType inp))