{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeFamilies #-}

-- | Tries to turn a generalized reduction kernel into
--     a more specialized construct, for example:
--       (a) a map nest with a sequential redomap ripe for tiling
--       (b) a SegRed kernel followed by a smallish accumulation kernel.
--       (c) a histogram (for this we need to track the withAccs)
--   The idea is to identify the first accumulation and
--     to separate the initial kernels into two:
--     1. the code up to and including the accumulation,
--        which is optimized to turn the accumulation either
--        into a map-reduce composition or a histogram, and
--     2. the remaining code, which is recursively optimized.
--   Since this is mostly prototyping, when the accumulation
--     can be rewritten as a map-reduce, we sequentialize the
--     map-reduce, so as to potentially enable tiling opportunities.
module Futhark.Optimise.GenRedOpt (optimiseGenRed) where

import Control.Monad.Reader
import Control.Monad.State
import qualified Data.List as L
import qualified Data.Map.Strict as M
import Data.Maybe
import Futhark.Builder
import Futhark.IR.GPU
import Futhark.Optimise.TileLoops.Shared
import Futhark.Pass
import Futhark.Tools
import Futhark.Transform.Rename

-- | The monad in which this pass runs: a reader over the current GPU
-- 'Scope', layered on a state monad that threads the 'VNameSource'.
type GenRedM = ReaderT (Scope GPU) (State VNameSource)

-- | The pass definition.  Every statement sequence handed over by
-- 'intraproceduralTransformation' is rewritten with 'optimiseStms',
-- starting from a pair of empty environment maps.
optimiseGenRed :: Pass GPU GPU
optimiseGenRed =
  Pass
    "optimise generalized reductions"
    "Specializes generalized reductions into map-reductions or histograms"
    (intraproceduralTransformation onStms)
  where
    -- Run the 'GenRedM' computation for one statement sequence: the
    -- reader layer is discharged with the given scope, and the state
    -- layer is plugged into the pass's name source.
    onStms :: Scope GPU -> Stms GPU -> PassM (Stms GPU)
    onStms scope stms =
      modifyNameSource . runState $
        runReaderT (optimiseStms (M.empty, M.empty) stms) scope

-- | Optimise the statements of a body with 'optimiseStms'; the body's
-- result is passed through unchanged.
optimiseBody :: Env -> Body GPU -> GenRedM (Body GPU)
optimiseBody env (Body () stms res) = do
  stms' <- optimiseStms env stms
  pure $ Body () stms' res

-- | Optimise a sequence of statements from left to right.
--
-- The names bound by @stms@ are added to the local scope for the
-- duration of the traversal.  The environment produced by each
-- statement is passed on to the next one; the final environment is
-- discarded and only the rewritten statements are returned.
optimiseStms :: Env -> Stms GPU -> GenRedM (Stms GPU)
optimiseStms env stms =
  localScope (scopeOf stms) $
    snd <$> foldM step (env, mempty) (stmsToList stms)
  where
    -- Optimise one statement under the current environment and append
    -- whatever it was rewritten into to the statements emitted so far.
    step :: (Env, Stms GPU) -> Stm GPU -> GenRedM (Env, Stms GPU)
    step (cur_env, emitted) stm = do
      (next_env, new_stms) <- optimiseStm cur_env stm
      pure (next_env, emitted <> new_stms)

-- | Optimise a single statement, returning the (possibly updated)
-- environment together with the statements that replace it.
--
-- * A thread-level 'SegMap' is handed to 'genRedOpts'; if no rewrite
--   applies, the statement is kept as is.  The environment is left
--   unchanged in this case.
--
-- * Any other statement first updates the environment via 'changeEnv'
--   and then has the bodies nested in its expression optimised
--   recursively under the updated environment.
optimiseStm :: Env -> Stm GPU -> GenRedM (Env, Stms GPU)
optimiseStm env stm@(Let _ _ (Op (SegOp (SegMap SegThread {} _ _ _)))) = do
  res_genred_opt <- genRedOpts env stm
  pure (env, fromMaybe (oneStm stm) res_genred_opt)
optimiseStm env (Let pat aux e) = do
  -- NOTE(review): 'head' is partial; a statement with an empty pattern
  -- would fail here if 'changeEnv' forces the name.  Confirm that this
  -- cannot happen for the expressions 'changeEnv' inspects.
  env' <- changeEnv env (head $ patNames pat) e
  e' <- mapExpM (optimise env') e
  pure (env', oneStm $ Let pat aux e')
  where
    -- A mapper that only touches bodies: each body is optimised in the
    -- scope supplied by the traversal.
    optimise :: Env -> Mapper GPU GPU GenRedM
    optimise env' =
      identityMapper
        { mapOnBody = \scope -> localScope scope . optimiseBody env'
        }

------------------------

-- | Try to specialise a generalized-reduction kernel.
--
-- 'genRed2Tile2d' is attempted first and, only if it yields 'Nothing',
-- 'genRed2SegRed' is attempted.  A successful rewrite produces some
-- statements together with a residual kernel; the residual kernel is
-- then optimised recursively.  Returns 'Nothing' when neither rewrite
-- applies to @ker@.
genRedOpts :: Env -> Stm GPU -> GenRedM (Maybe (Stms GPU))
genRedOpts env ker = do
  res_tile <- genRed2Tile2d env ker
  case res_tile of
    Nothing -> genRed2SegRed env ker >>= helperGenRed
    _ -> helperGenRed res_tile
  where
    -- Given the outcome of one rewrite, keep optimising the residual
    -- kernel; if it cannot be rewritten any further it is emitted
    -- unchanged after the statements produced so far.
    helperGenRed :: Maybe (Stms GPU, Stm GPU) -> GenRedM (Maybe (Stms GPU))
    helperGenRed Nothing = pure Nothing
    helperGenRed (Just (stms_before, ker_snd)) = do
      mb_stms_after <- genRedOpts env ker_snd
      pure . Just $ stms_before <> fromMaybe (oneStm ker_snd) mb_stms_after

-- | The 64-bit integer constant @1@ as a 'SubExp'.
se1 :: SubExp
se1 = intConst Int64 1

genRed2Tile2d :: Env -> Stm GPU -> GenRedM (Maybe (Stms GPU, Stm GPU))
genRed2Tile2d :: Env -> Stm GPU -> GenRedM (Maybe (Stms GPU, Stm GPU))
genRed2Tile2d Env
env kerstm :: Stm GPU
kerstm@(Let Pat (LetDec GPU)
pat_ker StmAux (ExpDec GPU)
aux (Op (SegOp (SegMap seg_thd seg_space kres_tps old_kbody))))
  | (SegThread Count NumGroups SubExp
_ Count GroupSize SubExp
seg_group_size SegVirt
_novirt) <- SegLevel
seg_thd,
    -- novirt == SegNoVirtFull || novirt == SegNoVirt,
    KernelBody () Stms GPU
kstms [KernelResult]
kres <- KernelBody GPU
old_kbody,
    Just ([VName]
css, [SubExp]
r_ses) <- [KernelResult] -> Maybe ([VName], [SubExp])
allGoodReturns [KernelResult]
kres,
    [VName] -> Bool
forall (t :: * -> *) a. Foldable t => t a -> Bool
null [VName]
css,
    -- build the variance table, that records, for
    -- each variable name, the variables it depends on
    Map VName Names
initial_variance <- (NameInfo Any -> Names)
-> Map VName (NameInfo Any) -> Map VName Names
forall a b k. (a -> b) -> Map k a -> Map k b
M.map NameInfo Any -> Names
forall a. Monoid a => a
mempty (Map VName (NameInfo Any) -> Map VName Names)
-> Map VName (NameInfo Any) -> Map VName Names
forall a b. (a -> b) -> a -> b
$ SegSpace -> Map VName (NameInfo Any)
forall rep. SegSpace -> Scope rep
scopeOfSegSpace SegSpace
seg_space,
    Map VName Names
variance <- Map VName Names -> Stms GPU -> Map VName Names
varianceInStms Map VName Names
initial_variance Stms GPU
kstms,
    -- check that the code fits the pattern having:
    -- some `code1`, followed by one accumulation, followed by some `code2`
    -- UpdateAcc VName [SubExp] [SubExp]
    (Stms GPU
code1, Just Stm GPU
accum_stmt, Stms GPU
code2) <- Stms GPU -> (Stms GPU, Maybe (Stm GPU), Stms GPU)
matchCodeAccumCode Stms GPU
kstms,
    Let Pat (LetDec GPU)
pat_accum StmAux (ExpDec GPU)
_aux_acc (BasicOp (UpdateAcc VName
acc_nm [SubExp]
acc_inds [SubExp]
acc_vals)) <- Stm GPU
accum_stmt,
    [VName
pat_acc_nm] <- Pat Type -> [VName]
forall dec. Pat dec -> [VName]
patNames Pat Type
Pat (LetDec GPU)
pat_accum,
    -- check that the `acc_inds` are invariant to at least one
    -- parallel kernel dimension, and return the innermost such one:
    Just (VName
invar_gid, Int
gid_ind) <- Names
-> SegSpace -> Map VName Names -> [SubExp] -> Maybe (VName, Int)
isInvarToParDim Names
forall a. Monoid a => a
mempty SegSpace
seg_space Map VName Names
variance [SubExp]
acc_inds,
    [(VName, SubExp)]
gid_dims_new_0 <- ((VName, SubExp) -> Bool) -> [(VName, SubExp)] -> [(VName, SubExp)]
forall a. (a -> Bool) -> [a] -> [a]
filter (\(VName, SubExp)
x -> VName
invar_gid VName -> VName -> Bool
forall a. Eq a => a -> a -> Bool
/= (VName, SubExp) -> VName
forall a b. (a, b) -> a
fst (VName, SubExp)
x) (SegSpace -> [(VName, SubExp)]
unSegSpace SegSpace
seg_space),
    -- reorder the variant dimensions such that inner(most) accum-indices
    -- correspond to inner(most) parallel dimensions, so that the babysitter
    -- does not introduce transpositions
    -- gid_dims_new <- gid_dims_new_0,
    [(VName, SubExp)]
gid_dims_new <- Map VName Names
-> [SubExp] -> [(VName, SubExp)] -> [(VName, SubExp)]
forall b.
Map VName Names -> [SubExp] -> [(VName, b)] -> [(VName, b)]
reorderParDims Map VName Names
variance [SubExp]
acc_inds [(VName, SubExp)]
gid_dims_new_0,
    -- check that all global-memory accesses in `code1` on which
    --   `accum_stmt` depends are invariant to at least one of
    --   the remaining parallel dimensions (i.e., excluding `invar_gid`)
    (Stm GPU -> Bool) -> [Stm GPU] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all (VName
-> [(VName, SubExp)] -> Map VName Names -> VName -> Stm GPU -> Bool
isTileable VName
invar_gid [(VName, SubExp)]
gid_dims_new Map VName Names
variance VName
pat_acc_nm) (Stms GPU -> [Stm GPU]
forall rep. Stms rep -> [Stm rep]
stmsToList Stms GPU
code1),
    -- need to establish a cost model for the stms that would now
    --   be redundantly executed by the two kernels. If any recurrence
    --   is redundant then it is a no go. Otherwise we need to look at
    --   memory accesses: if more than two are re-executed, then we
    --   should abort.
    Cost
cost <- Map VName Names -> VName -> [SubExp] -> Stms GPU -> Cost
costRedundantExecution Map VName Names
variance VName
pat_acc_nm [SubExp]
r_ses Stms GPU
kstms,
    Cost -> Cost -> Cost
maxCost Cost
cost (Int -> Cost
Small Int
2) Cost -> Cost -> Bool
forall a. Eq a => a -> a -> Bool
== Int -> Cost
Small Int
2 = do
      -- 1. create the first kernel
      Type
acc_tp <- VName -> ReaderT (Scope GPU) (State VNameSource) Type
forall rep (m :: * -> *). HasScope rep m => VName -> m Type
lookupType VName
acc_nm
      let inv_dim_len :: SubExp
inv_dim_len = SegSpace -> [SubExp]
segSpaceDims SegSpace
seg_space [SubExp] -> Int -> SubExp
forall a. [a] -> Int -> a
!! Int
gid_ind
          -- 1.1. get the accumulation operator
          ((Lambda GPU
redop0, [SubExp]
neutral), [Type]
el_tps) = Type -> ((Lambda GPU, [SubExp]), [Type])
getAccLambda Type
acc_tp
      Lambda GPU
redop <- Lambda GPU -> ReaderT (Scope GPU) (State VNameSource) (Lambda GPU)
forall rep (m :: * -> *).
(Renameable rep, MonadFreshNames m) =>
Lambda rep -> m (Lambda rep)
renameLambda Lambda GPU
redop0
      let red :: Reduce GPU
red =
            Reduce :: forall rep. Commutativity -> Lambda rep -> [SubExp] -> Reduce rep
Reduce
              { redComm :: Commutativity
redComm = Commutativity
Commutative,
                redLambda :: Lambda GPU
redLambda = Lambda GPU
redop,
                redNeutral :: [SubExp]
redNeutral = [SubExp]
neutral
              }
          -- 1.2. build the sequential map-reduce screma
          code1' :: Stms GPU
code1' =
            [Stm GPU] -> Stms GPU
forall rep. [Stm rep] -> Stms rep
stmsFromList ([Stm GPU] -> Stms GPU) -> [Stm GPU] -> Stms GPU
forall a b. (a -> b) -> a -> b
$
              (Stm GPU -> Bool) -> [Stm GPU] -> [Stm GPU]
forall a. (a -> Bool) -> [a] -> [a]
filter (VName -> Map VName Names -> Stm GPU -> Bool
forall k rep. Ord k => k -> Map k Names -> Stm rep -> Bool
dependsOnAcc VName
pat_acc_nm Map VName Names
variance) ([Stm GPU] -> [Stm GPU]) -> [Stm GPU] -> [Stm GPU]
forall a b. (a -> b) -> a -> b
$
                Stms GPU -> [Stm GPU]
forall rep. Stms rep -> [Stm rep]
stmsToList Stms GPU
code1
      (Stms GPU
code1'', Stms GPU
code1_tr_host) <- Names
-> Map VName Names
-> VName
-> Stms GPU
-> GenRedM (Stms GPU, Stms GPU)
transposeFVs (Stm GPU -> Names
forall a. FreeIn a => a -> Names
freeIn Stm GPU
kerstm) Map VName Names
variance VName
invar_gid Stms GPU
code1'
      let map_lam_body :: Body GPU
map_lam_body = Stms GPU -> Result -> Body GPU
forall rep. Buildable rep => Stms rep -> Result -> Body rep
mkBody Stms GPU
code1'' (Result -> Body GPU) -> Result -> Body GPU
forall a b. (a -> b) -> a -> b
$ (SubExp -> SubExpRes) -> [SubExp] -> Result
forall a b. (a -> b) -> [a] -> [b]
map (Certs -> SubExp -> SubExpRes
SubExpRes ([VName] -> Certs
Certs [])) [SubExp]
acc_vals
          map_lam0 :: Lambda GPU
map_lam0 = [LParam GPU] -> Body GPU -> [Type] -> Lambda GPU
forall rep. [LParam rep] -> Body rep -> [Type] -> Lambda rep
Lambda [Attrs -> VName -> Type -> Param Type
forall dec. Attrs -> VName -> dec -> Param dec
Param Attrs
forall a. Monoid a => a
mempty VName
invar_gid (PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64)] Body GPU
map_lam_body [Type]
el_tps
      Lambda GPU
map_lam <- Lambda GPU -> ReaderT (Scope GPU) (State VNameSource) (Lambda GPU)
forall rep (m :: * -> *).
(Renameable rep, MonadFreshNames m) =>
Lambda rep -> m (Lambda rep)
renameLambda Lambda GPU
map_lam0
      (SubExp
k1_res, Stms GPU
ker1_stms) <- BuilderT GPU (ReaderT (Scope GPU) (State VNameSource)) SubExp
-> ReaderT (Scope GPU) (State VNameSource) (SubExp, Stms GPU)
forall (m :: * -> *) somerep rep a.
(MonadFreshNames m, HasScope somerep m, SameScope somerep rep) =>
BuilderT rep m a -> m (a, Stms rep)
runBuilderT' (BuilderT GPU (ReaderT (Scope GPU) (State VNameSource)) SubExp
 -> ReaderT (Scope GPU) (State VNameSource) (SubExp, Stms GPU))
-> BuilderT GPU (ReaderT (Scope GPU) (State VNameSource)) SubExp
-> ReaderT (Scope GPU) (State VNameSource) (SubExp, Stms GPU)
forall a b. (a -> b) -> a -> b
$ do
        VName
iota <- String
-> Exp
     (Rep (BuilderT GPU (ReaderT (Scope GPU) (State VNameSource))))
-> BuilderT GPU (ReaderT (Scope GPU) (State VNameSource)) VName
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m VName
letExp String
"iota" (Exp (Rep (BuilderT GPU (ReaderT (Scope GPU) (State VNameSource))))
 -> BuilderT GPU (ReaderT (Scope GPU) (State VNameSource)) VName)
-> Exp
     (Rep (BuilderT GPU (ReaderT (Scope GPU) (State VNameSource))))
-> BuilderT GPU (ReaderT (Scope GPU) (State VNameSource)) VName
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp GPU
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp GPU) -> BasicOp -> Exp GPU
forall a b. (a -> b) -> a -> b
$ SubExp -> SubExp -> SubExp -> IntType -> BasicOp
Iota SubExp
inv_dim_len (IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
0) (IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
1) IntType
Int64
        let op_exp :: Exp GPU
op_exp = Op GPU -> Exp GPU
forall rep. Op rep -> Exp rep
Op (SOAC GPU -> HostOp GPU (SOAC GPU)
forall rep op. op -> HostOp rep op
OtherOp (SubExp -> [VName] -> ScremaForm GPU -> SOAC GPU
forall rep. SubExp -> [VName] -> ScremaForm rep -> SOAC rep
Screma SubExp
inv_dim_len [VName
iota] ([Scan GPU] -> [Reduce GPU] -> Lambda GPU -> ScremaForm GPU
forall rep.
[Scan rep] -> [Reduce rep] -> Lambda rep -> ScremaForm rep
ScremaForm [] [Reduce GPU
red] Lambda GPU
map_lam)))
        [VName]
res_redmap <- String
-> Exp
     (Rep (BuilderT GPU (ReaderT (Scope GPU) (State VNameSource))))
-> BuilderT GPU (ReaderT (Scope GPU) (State VNameSource)) [VName]
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m [VName]
letTupExp String
"res_mapred" Exp (Rep (BuilderT GPU (ReaderT (Scope GPU) (State VNameSource))))
Exp GPU
op_exp
        String
-> Exp
     (Rep (BuilderT GPU (ReaderT (Scope GPU) (State VNameSource))))
-> BuilderT GPU (ReaderT (Scope GPU) (State VNameSource)) SubExp
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m SubExp
letSubExp (VName -> String
baseString VName
pat_acc_nm String -> String -> String
forall a. [a] -> [a] -> [a]
++ String
"_big_update") (Exp (Rep (BuilderT GPU (ReaderT (Scope GPU) (State VNameSource))))
 -> BuilderT GPU (ReaderT (Scope GPU) (State VNameSource)) SubExp)
-> Exp
     (Rep (BuilderT GPU (ReaderT (Scope GPU) (State VNameSource))))
-> BuilderT GPU (ReaderT (Scope GPU) (State VNameSource)) SubExp
forall a b. (a -> b) -> a -> b
$
          BasicOp -> Exp GPU
forall rep. BasicOp -> Exp rep
BasicOp (VName -> [SubExp] -> [SubExp] -> BasicOp
UpdateAcc VName
acc_nm [SubExp]
acc_inds ([SubExp] -> BasicOp) -> [SubExp] -> BasicOp
forall a b. (a -> b) -> a -> b
$ (VName -> SubExp) -> [VName] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map VName -> SubExp
Var [VName]
res_redmap)

      -- 1.3. build the kernel expression and rename it!
      VName
gid_flat_1 <- String -> ReaderT (Scope GPU) (State VNameSource) VName
forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName String
"gid_flat"
      let space1 :: SegSpace
space1 = VName -> [(VName, SubExp)] -> SegSpace
SegSpace VName
gid_flat_1 [(VName, SubExp)]
gid_dims_new

      (SubExp
grid_size, Stms GPU
host_stms1) <- Builder GPU SubExp
-> ReaderT (Scope GPU) (State VNameSource) (SubExp, Stms GPU)
forall (m :: * -> *) somerep rep a.
(MonadFreshNames m, HasScope somerep m, SameScope somerep rep) =>
Builder rep a -> m (a, Stms rep)
runBuilder (Builder GPU SubExp
 -> ReaderT (Scope GPU) (State VNameSource) (SubExp, Stms GPU))
-> Builder GPU SubExp
-> ReaderT (Scope GPU) (State VNameSource) (SubExp, Stms GPU)
forall a b. (a -> b) -> a -> b
$ do
        let grid_pexp :: TPrimExp Int64 VName
grid_pexp = (TPrimExp Int64 VName -> SubExp -> TPrimExp Int64 VName)
-> TPrimExp Int64 VName -> [SubExp] -> TPrimExp Int64 VName
forall (t :: * -> *) b a.
Foldable t =>
(b -> a -> b) -> b -> t a -> b
foldl (\TPrimExp Int64 VName
x SubExp
d -> TPrimExp Int64 VName
x TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
* SubExp -> TPrimExp Int64 VName
pe64 SubExp
d) (SubExp -> TPrimExp Int64 VName
pe64 SubExp
se1) ([SubExp] -> TPrimExp Int64 VName)
-> [SubExp] -> TPrimExp Int64 VName
forall a b. (a -> b) -> a -> b
$ ((VName, SubExp) -> SubExp) -> [(VName, SubExp)] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map (VName, SubExp) -> SubExp
forall a b. (a, b) -> b
snd [(VName, SubExp)]
gid_dims_new
        SubExp
dim_prod <- String
-> Exp (Rep (BuilderT GPU (State VNameSource)))
-> Builder GPU SubExp
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m SubExp
letSubExp String
"dim_prod" (Exp GPU -> Builder GPU SubExp)
-> BuilderT GPU (State VNameSource) (Exp GPU) -> Builder GPU SubExp
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< TPrimExp Int64 VName
-> BuilderT
     GPU
     (State VNameSource)
     (Exp (Rep (BuilderT GPU (State VNameSource))))
forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
toExp TPrimExp Int64 VName
grid_pexp
        String
-> Exp (Rep (BuilderT GPU (State VNameSource)))
-> Builder GPU SubExp
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m SubExp
letSubExp String
"grid_size" (Exp GPU -> Builder GPU SubExp)
-> BuilderT GPU (State VNameSource) (Exp GPU) -> Builder GPU SubExp
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< SubExp -> SubExp -> BuilderT GPU (State VNameSource) (Exp GPU)
forall (f :: * -> *) rep.
Applicative f =>
SubExp -> SubExp -> f (Exp rep)
ceilDiv SubExp
dim_prod (Count GroupSize SubExp -> SubExp
forall u e. Count u e -> e
unCount Count GroupSize SubExp
seg_group_size)
      let level1 :: SegLevel
level1 = Count NumGroups SubExp
-> Count GroupSize SubExp -> SegVirt -> SegLevel
SegThread (SubExp -> Count NumGroups SubExp
forall u e. e -> Count u e
Count SubExp
grid_size) Count GroupSize SubExp
seg_group_size (SegSeqDims -> SegVirt
SegNoVirtFull ([Int] -> SegSeqDims
SegSeqDims [])) -- novirt ?
          kbody1 :: KernelBody GPU
kbody1 = BodyDec GPU -> Stms GPU -> [KernelResult] -> KernelBody GPU
forall rep.
BodyDec rep -> Stms rep -> [KernelResult] -> KernelBody rep
KernelBody () Stms GPU
ker1_stms [ResultManifest -> Certs -> SubExp -> KernelResult
Returns ResultManifest
ResultMaySimplify ([VName] -> Certs
Certs []) SubExp
k1_res]

      -- is it OK here to use the "aux" from the parent kernel?
      Exp GPU
ker_exp <- Exp GPU -> ReaderT (Scope GPU) (State VNameSource) (Exp GPU)
forall rep (m :: * -> *).
(Renameable rep, MonadFreshNames m) =>
Exp rep -> m (Exp rep)
renameExp (Exp GPU -> ReaderT (Scope GPU) (State VNameSource) (Exp GPU))
-> Exp GPU -> ReaderT (Scope GPU) (State VNameSource) (Exp GPU)
forall a b. (a -> b) -> a -> b
$ Op GPU -> Exp GPU
forall rep. Op rep -> Exp rep
Op (SegOp SegLevel GPU -> HostOp GPU (SOAC GPU)
forall rep op. SegOp SegLevel rep -> HostOp rep op
SegOp (SegLevel
-> SegSpace -> [Type] -> KernelBody GPU -> SegOp SegLevel GPU
forall lvl rep.
lvl -> SegSpace -> [Type] -> KernelBody rep -> SegOp lvl rep
SegMap SegLevel
level1 SegSpace
space1 [Type
acc_tp] KernelBody GPU
kbody1))
      let ker1 :: Stm GPU
ker1 = Pat (LetDec GPU) -> StmAux (ExpDec GPU) -> Exp GPU -> Stm GPU
forall rep.
Pat (LetDec rep) -> StmAux (ExpDec rep) -> Exp rep -> Stm rep
Let Pat (LetDec GPU)
pat_accum StmAux (ExpDec GPU)
aux Exp GPU
ker_exp

      -- 2 build the second kernel
      let ker2_body :: KernelBody GPU
ker2_body = KernelBody GPU
old_kbody {kernelBodyStms :: Stms GPU
kernelBodyStms = Stms GPU
code1 Stms GPU -> Stms GPU -> Stms GPU
forall a. Semigroup a => a -> a -> a
<> Stms GPU
code2}
      Exp GPU
ker2_exp <- Exp GPU -> ReaderT (Scope GPU) (State VNameSource) (Exp GPU)
forall rep (m :: * -> *).
(Renameable rep, MonadFreshNames m) =>
Exp rep -> m (Exp rep)
renameExp (Exp GPU -> ReaderT (Scope GPU) (State VNameSource) (Exp GPU))
-> Exp GPU -> ReaderT (Scope GPU) (State VNameSource) (Exp GPU)
forall a b. (a -> b) -> a -> b
$ Op GPU -> Exp GPU
forall rep. Op rep -> Exp rep
Op (SegOp SegLevel GPU -> HostOp GPU (SOAC GPU)
forall rep op. SegOp SegLevel rep -> HostOp rep op
SegOp (SegLevel
-> SegSpace -> [Type] -> KernelBody GPU -> SegOp SegLevel GPU
forall lvl rep.
lvl -> SegSpace -> [Type] -> KernelBody rep -> SegOp lvl rep
SegMap SegLevel
seg_thd SegSpace
seg_space [Type]
kres_tps KernelBody GPU
ker2_body))
      let ker2 :: Stm GPU
ker2 = Pat (LetDec GPU) -> StmAux (ExpDec GPU) -> Exp GPU -> Stm GPU
forall rep.
Pat (LetDec rep) -> StmAux (ExpDec rep) -> Exp rep -> Stm rep
Let Pat (LetDec GPU)
pat_ker StmAux (ExpDec GPU)
aux Exp GPU
ker2_exp
      Maybe (Stms GPU, Stm GPU) -> GenRedM (Maybe (Stms GPU, Stm GPU))
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Maybe (Stms GPU, Stm GPU) -> GenRedM (Maybe (Stms GPU, Stm GPU)))
-> Maybe (Stms GPU, Stm GPU) -> GenRedM (Maybe (Stms GPU, Stm GPU))
forall a b. (a -> b) -> a -> b
$
        (Stms GPU, Stm GPU) -> Maybe (Stms GPU, Stm GPU)
forall a. a -> Maybe a
Just (Stms GPU
code1_tr_host Stms GPU -> Stms GPU -> Stms GPU
forall a. Semigroup a => a -> a -> a
<> Stms GPU
host_stms1 Stms GPU -> Stms GPU -> Stms GPU
forall a. Semigroup a => a -> a -> a
<> Stm GPU -> Stms GPU
forall rep. Stm rep -> Stms rep
oneStm Stm GPU
ker1, Stm GPU
ker2)
  where
    isIndVarToParDim :: Map VName Names -> SubExp -> (VName, b) -> Bool
isIndVarToParDim Map VName Names
_ (Constant PrimValue
_) (VName, b)
_ = Bool
False
    isIndVarToParDim Map VName Names
variance (Var VName
acc_ind) (VName, b)
par_dim =
      VName
acc_ind VName -> VName -> Bool
forall a. Eq a => a -> a -> Bool
== (VName, b) -> VName
forall a b. (a, b) -> a
fst (VName, b)
par_dim
        Bool -> Bool -> Bool
|| VName -> Names -> Bool
nameIn ((VName, b) -> VName
forall a b. (a, b) -> a
fst (VName, b)
par_dim) (Names -> VName -> Map VName Names -> Names
forall k a. Ord k => a -> k -> Map k a -> a
M.findWithDefault Names
forall a. Monoid a => a
mempty VName
acc_ind Map VName Names
variance)
    foldfunReorder :: Map VName Names
-> ([(VName, b)], [(VName, b)])
-> SubExp
-> ([(VName, b)], [(VName, b)])
foldfunReorder Map VName Names
variance ([(VName, b)]
unused_dims, [(VName, b)]
inner_dims) SubExp
acc_ind =
      case ((VName, b) -> Bool) -> [(VName, b)] -> Maybe Int
forall a. (a -> Bool) -> [a] -> Maybe Int
L.findIndex (Map VName Names -> SubExp -> (VName, b) -> Bool
forall b. Map VName Names -> SubExp -> (VName, b) -> Bool
isIndVarToParDim Map VName Names
variance SubExp
acc_ind) [(VName, b)]
unused_dims of
        Maybe Int
Nothing -> ([(VName, b)]
unused_dims, [(VName, b)]
inner_dims)
        Just Int
i ->
          ( Int -> [(VName, b)] -> [(VName, b)]
forall a. Int -> [a] -> [a]
take Int
i [(VName, b)]
unused_dims [(VName, b)] -> [(VName, b)] -> [(VName, b)]
forall a. [a] -> [a] -> [a]
++ Int -> [(VName, b)] -> [(VName, b)]
forall a. Int -> [a] -> [a]
drop (Int
i Int -> Int -> Int
forall a. Num a => a -> a -> a
+ Int
1) [(VName, b)]
unused_dims,
            ([(VName, b)]
unused_dims [(VName, b)] -> Int -> (VName, b)
forall a. [a] -> Int -> a
!! Int
i) (VName, b) -> [(VName, b)] -> [(VName, b)]
forall a. a -> [a] -> [a]
: [(VName, b)]
inner_dims
          )
    -- Reorders the parallel dimensions @gid_dims_new_0@ by feeding the
    -- accumulator indices, last one first, through 'foldfunReorder'.
    -- The dimensions that were never moved come first in the result,
    -- followed by the ones 'foldfunReorder' collected as inner dimensions.
    reorderParDims variance acc_inds gid_dims_new_0 =
      -- A right fold with the flipped step visits @acc_inds@ from the last
      -- element to the first, i.e. the same order as a left fold over the
      -- reversed list.
      uncurry (++) $
        foldr (flip (foldfunReorder variance)) (gid_dims_new_0, []) acc_inds
    --
    -- Builds (without binding it) the expression that divides @x@ by @y@
    -- with the 64-bit signed 'SDivUp' operator in 'Unsafe' mode.
    ceilDiv x y =
      let div_up = SDivUp Int64 Unsafe
       in pure (BasicOp (BinOp div_up x y))
    -- Given the type of an accumulator, looks up the entry registered for
    -- its type identifier in the first component of @env@ and returns it
    -- together with the accumulator's element types.  Calls 'error' when
    -- the type is not an accumulator or when the identifier is not in
    -- the table.
    getAccLambda acc_tp
      | Acc tp_id _shp el_tps _ <- acc_tp =
          let acc_tab = fst env
              lookupFailed =
                error $
                  "Lookup in environment failed! "
                    ++ pretty tp_id
                    ++ " env: "
                    ++ pretty acc_tab
           in maybe lookupFailed (\lam -> (lam, el_tps)) (M.lookup tp_id acc_tab)
      | otherwise = error "Illegal accumulator type!"
    -- Is a subexpression invariant to the gid of a parallel dimension?
    -- A variable is invariant when it is not the gid itself and the gid
    -- does not occur in its entry of the variance table; anything that is
    -- not a variable is always invariant.
    isSeInvar2 variance gid se =
      case se of
        Var x ->
          not (gid == x || gid `nameIn` M.findWithDefault mempty x variance)
        _ -> True
    -- Is a 'DimIndex' invariant to the gid of a parallel dimension?
    -- This holds when every subexpression it is made of (the single index
    -- of a 'DimFix', or the three components of a 'DimSlice') is invariant
    -- according to 'isSeInvar2'.
    isDimIdxInvar2 variance gid dim_idx =
      all (isSeInvar2 variance gid) $
        case dim_idx of
          DimFix d -> [d]
          DimSlice d1 d2 d3 -> [d1, d2, d3]
    -- Is the entire slice invariant to at least one of the given gids?
    -- A gid qualifies when every dimension index of the slice is
    -- invariant to it (see 'isDimIdxInvar2').
    isSliceInvar2 variance slc gids = any invariantTo gids
      where
        slc_dims = unSlice slc
        invariantTo gid = all (isDimIdxInvar2 variance gid) slc_dims
    -- Are all statements that touch memory invariant to at least one
    -- parallel dimension?
    --
    -- Only a single-result 'Index' statement whose result is among the
    -- dependencies of the accumulator @acc_nm@ is actually inspected: its
    -- slice must be invariant either to one of the parallel gids in
    -- @gid_dims@ or to the sequential gid @seq_gid@.  Every other statement
    -- is accepted; this relies on the cost model, which currently accepts
    -- only global-memory reads and, for example, rejects in-place updates
    -- or loops inside the code that is transformed into a redomap.
    isTileable :: VName -> [(VName, SubExp)] -> VarianceTable -> VName -> Stm GPU -> Bool
    isTileable seq_gid gid_dims variance acc_nm stm =
      case stm of
        Let (Pat [pel]) _ (BasicOp (Index _ slc))
          | patElemName pel `nameIn` M.findWithDefault mempty acc_nm variance ->
              -- parallel gids are tried first, then the sequential one
              any (isSliceInvar2 variance slc) [map fst gid_dims, [seq_gid]]
        _ -> True
    -- Does the to-be-reduced accumulator depend on this statement?
    -- True when at least one name bound by the statement's pattern occurs
    -- in the variance-table entry of @pat_acc_nm@ (an absent entry counts
    -- as having no dependencies).
    dependsOnAcc pat_acc_nm variance stm =
      or [nm `nameIn` acc_deps | nm <- patNames (stmPat stm)]
      where
        acc_deps = M.findWithDefault mempty pat_acc_nm variance
genRed2Tile2d Env
_ Stm GPU
_ =
  Maybe (Stms GPU, Stm GPU) -> GenRedM (Maybe (Stms GPU, Stm GPU))
forall (f :: * -> *) a. Applicative f => a -> f a
pure Maybe (Stms GPU, Stm GPU)
forall a. Maybe a
Nothing

-- | Placeholder for the rewrite of a generalized reduction into a segmented
--   reduction.  As written it ignores both arguments and always reports
--   that no rewrite took place.
genRed2SegRed :: Env -> Stm GPU -> GenRedM (Maybe (Stms GPU, Stm GPU))
genRed2SegRed _env _kerstm = pure Nothing

-- | Redirects certain reads of free-variable arrays to rearranged copies.
--
--   A statement is rewritten when it is an 'Index' into an array @arr@
--   that is a member of @fvs@, the slice consists only of 'DimFix'
--   entries, exactly one of those entries depends on @gid@ (it is @gid@
--   itself, or @gid@ occurs in its @variance@ entry), that entry is not
--   already the last one, and no rearranged copy of @arr@ has been
--   recorded yet.  For such a statement a 'Rearrange' that moves the
--   @gid@-dependent dimension last is generated, passed through 'Opaque',
--   and the statement is changed to index the resulting array with the
--   correspondingly permuted slice.
--
--   Returns the (possibly rewritten) statements together with the
--   statements that build the rearranged arrays.
transposeFVs ::
  Names ->
  VarianceTable ->
  VName ->
  Stms GPU ->
  GenRedM (Stms GPU, Stms GPU)
transposeFVs fvs variance gid stms = do
  (tab, stms') <- foldM rewriteStm (M.empty, mempty) (stmsToList stms)
  -- Concatenate the generated statements starting from the entry with the
  -- largest array name and ending with the smallest one.
  let stms_host = mconcat $ reverse [s | (_, _, s) <- M.elems tab]
  pure (stms', stms_host)
  where
    -- Rewrites one statement and appends it to the ones already processed.
    rewriteStm (tab, done) stm = do
      (tab', stm') <- transposeFV tab stm
      pure (tab', done <> oneStm stm')

    -- ToDo: the original author notes that this currently handles only
    -- 2-dim arrays and should be generalized.
    transposeFV tab (Let pat aux (BasicOp (Index arr slc)))
      | dims <- unSlice slc,
        all isFixDim dims,
        arr `nameIn` fvs,
        [ii] <- L.findIndices depOnGid dims,
        -- ToDo (from the original): generalize this to treat any
        -- rearrange, adding it to the table only if it is not there.
        M.notMember arr tab,
        -- A non-empty @after@ means that dimension @ii@ is not the last.
        (before, d : after@(_ : _)) <- splitAt ii dims = do
          let rank = length dims
              perm = [i | i <- [0 .. rank - 1], i /= ii] ++ [ii]
          (arr_tr, stms_tr) <- runBuilderT' $ do
            arr_trsp <-
              letExp (baseString arr ++ "_trsp") (BasicOp (Rearrange perm arr))
            letExp
              (baseString arr_trsp ++ "_opaque")
              (BasicOp (Opaque OpaqueNil (Var arr_trsp)))
          let tab' = M.insert arr (perm, arr_tr, stms_tr) tab
              -- the slice entries in the order given by @perm@
              slc' = Slice (before ++ after ++ [d])
              stm' = Let pat aux (BasicOp (Index arr_tr slc'))
          pure (tab', stm')
      where
        isFixDim DimFix {} = True
        isFixDim _ = False
        depOnGid (DimFix (Var nm)) =
          nm == gid || gid `nameIn` M.findWithDefault mempty nm variance
        depOnGid _ = False
    transposeFV tab stm = pure (tab, stm)

-- | Tries to identify the following pattern:
--   code, followed by some 'UpdateAcc' statement,
--   followed by more code.
--
--   The result is the statements before the first 'UpdateAcc' statement,
--   that statement itself (if there is one), and all statements after it.
--   When no 'UpdateAcc' statement exists, every statement ends up in the
--   first component and the last component is empty.
matchCodeAccumCode ::
  Stms GPU ->
  (Stms GPU, Maybe (Stm GPU), Stms GPU)
matchCodeAccumCode kstms =
  -- A single 'break' splits the list in linear time; repeatedly appending
  -- singleton lists to the accumulator of a left fold is quadratic.
  case break isUpdateAcc (stmsToList kstms) of
    (code1, acc_stm : code2) ->
      (stmsFromList code1, Just acc_stm, stmsFromList code2)
    (code1, []) ->
      (stmsFromList code1, Nothing, stmsFromList [])
  where
    isUpdateAcc (Let _ _ (BasicOp UpdateAcc {})) = True
    isUpdateAcc _ = False

-- | Checks that there exists a parallel dimension of the kernel space
--     to which all the accumulator indices (@acc_inds@) are invariant.
--   It returns the innermost such parallel dimension, as a tuple
--     of the pardim gid ('VName') and its index ('Int') in the
--     parallel space.
--   The answer is always 'Nothing' when any gid of the kernel space
--     is a member of @branch_variant@.
isInvarToParDim ::
  Names ->
  SegSpace ->
  VarianceTable ->
  [SubExp] ->
  Maybe (VName, Int)
isInvarToParDim branch_variant kspace variance acc_inds
  | any (`nameIn` branch_variant) ker_gids = Nothing
  | otherwise =
      -- the last candidate in kernel-space order is the innermost one
      listToMaybe . reverse $
        [ (kid, k)
          | (kid, k) <- zip ker_gids [0 ..],
            not (kid `nameIn` variant_gids)
        ]
  where
    ker_gids = map fst (unSegSpace kspace)
    -- all gids that at least one accumulator index is variant to
    variant_gids = namesFromList (concatMap variantGids acc_inds)
    -- An index variable is variant to a gid when it is that gid, or when
    -- the gid occurs in its variance entry; other subexpressions are
    -- variant to nothing.
    variantGids (Var ind) =
      let ind_deps = M.findWithDefault mempty ind variance
       in [kid | kid <- ker_gids, kid == ind || kid `nameIn` ind_deps]
    variantGids _ = []

-- | Succeeds only when every kernel result is a 'Returns' with the
--   'ResultMaySimplify' manifest.  In that case it yields the certificates
--   of all results, concatenated in result order, together with the
--   returned subexpressions, also in result order.  Any other kind of
--   result makes the whole answer 'Nothing'.  An empty result list yields
--   @Just ([], [])@.
allGoodReturns :: [KernelResult] -> Maybe ([VName], [SubExp])
allGoodReturns kres = do
  -- One pass both validates and extracts, so no partial helper with an
  -- "impossible" error branch is needed.
  (certs, ses) <- unzip <$> mapM goodReturn kres
  pure (concat certs, ses)
  where
    goodReturn (Returns ResultMaySimplify cs r_se) = Just (unCerts cs, r_se)
    goodReturn _ = Nothing

--------------------------
--- Cost Model Helpers ---
--------------------------

-- | Sums (with 'addCosts', starting from @Small 0@) the cost given by
--   'costRedundantStmt' of every statement in @kstms@ that binds a name
--   on which both the accumulator @pat_acc_nm@ and the results @r_ses@
--   depend.
--
--   The dependencies of the accumulator are read from @variance@.  The
--   dependencies of the results are read from a variance table rebuilt
--   from @kstms@ in which the statements that bind @pat_acc_nm@ are left
--   out.
costRedundantExecution ::
  VarianceTable ->
  VName ->
  [SubExp] ->
  Stms GPU ->
  Cost
costRedundantExecution variance pat_acc_nm r_ses kstms =
  L.foldl'
    addCosts
    (Small 0)
    [costRedundantStmt stm | stm <- stmsToList kstms, bindsCommonDep stm]
  where
    acc_deps = M.findWithDefault mempty pat_acc_nm variance
    vartab_cut_acc = L.foldl' (extendUnlessCut (oneName pat_acc_nm)) mempty kstms
    res_deps =
      foldMap
        (\nm -> M.findWithDefault mempty nm vartab_cut_acc)
        [nm | Var nm <- r_ses]
    common_deps = namesIntersection res_deps acc_deps
    -- does the statement bind one of the common dependencies?
    bindsCommonDep stm =
      namesIntersect (namesFromList (patNames (stmPat stm))) common_deps
    -- Extends the variance table with the names bound by a statement,
    -- unless the statement binds one of the names in @cuts@.  Each bound
    -- name is mapped to the free names of the statement plus whatever the
    -- incoming table already records for those free names.
    extendUnlessCut cuts vartab stm
      | namesIntersect (namesFromList pat_nms) cuts = vartab
      | otherwise =
          L.foldl' (\tab v -> M.insert v binding_variance tab) vartab pat_nms
      where
        pat_nms = patNames (stmPat stm)
        binding_variance =
          foldMap
            (\v -> oneName v <> M.findWithDefault mempty v vartab)
            (namesToList (freeIn stm))

-- | Cost estimate used by the cost-model helpers of this module.
--   'Small' carries an integer count; 'Big' and 'Break' carry none.
data Cost
  = Small Int
  | Big
  | Break
  deriving (Eq)

-- | Combines two costs: two 'Small' costs add up their counts; otherwise
--   the result is 'Break' if either argument is 'Break', and 'Big' in all
--   remaining cases.
addCosts :: Cost -> Cost -> Cost
addCosts c1 c2 =
  case (c1, c2) of
    (Small n1, Small n2) -> Small (n1 + n2)
    _
      | Break `elem` [c1, c2] -> Break
      | otherwise -> Big

-- | The larger of two costs: for two 'Small' costs the larger count is
--   kept; every other combination is delegated to 'addCosts'.
maxCost :: Cost -> Cost -> Cost
maxCost c1 c2 =
  case (c1, c2) of
    (Small n1, Small n2) -> Small (max n1 n2)
    _ -> addCosts c1 c2

-- | The cost of a body: the costs of its statements, as given by
--   'costRedundantStmt', combined with 'addCosts' starting from @Small 0@.
costBody :: Body GPU -> Cost
costBody bdy =
  L.foldl'
    (\acc stm -> addCosts acc (costRedundantStmt stm))
    (Small 0)
    (stmsToList (bodyStms bdy))

-- | Classify a single statement by the cost of (redundantly)
-- executing it, based only on the shape of its expression.
--
--   * Ops, loops, function applications and 'WithAcc' are 'Big'.
--   * An 'If' costs the 'maxCost' of its two branches.
--   * Basic operations are classified by the local @basicOpCost@.
costRedundantStmt :: Stm GPU -> Cost
costRedundantStmt (Let _ _ e) =
  case e of
    Op _ -> Big
    DoLoop {} -> Big
    Apply {} -> Big
    WithAcc {} -> Big
    If _cond tbranch fbranch _ ->
      maxCost (costBody tbranch) (costBody fbranch)
    BasicOp bop -> basicOpCost bop
  where
    -- Array literals: 'Big' when the element type is itself an
    -- array, otherwise a unit cost.  Order of these two matters.
    basicOpCost (ArrayLit _ Array {}) = Big
    basicOpCost (ArrayLit _ _) = Small 1
    -- Indexing costs one unit only when every dimension of the
    -- slice is a fixed index; any other slice is free.
    basicOpCost (Index _ slc)
      | all isFixDim (unSlice slc) = Small 1
      | otherwise = Small 0
    basicOpCost FlatIndex {} = Small 0
    -- Updates (in-place, flat, or on accumulators) yield 'Break'.
    basicOpCost Update {} = Break
    basicOpCost FlatUpdate {} = Break
    -- Operations classified as 'Big'.
    basicOpCost Concat {} = Big
    basicOpCost Copy {} = Big
    basicOpCost Manifest {} = Big
    basicOpCost Replicate {} = Big
    basicOpCost UpdateAcc {} = Break
    -- Every remaining basic operation is considered free.
    basicOpCost _ = Small 0

    -- True exactly for 'DimFix' slice components.
    isFixDim DimFix {} = True
    isFixDim _ = False