{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeFamilies #-}

-- | "Sinking" is conceptually the opposite of hoisting.  The idea is
-- to take code that looks like this:
--
-- @
-- x = xs[i]
-- y = ys[i]
-- if x != 0 then {
--   y
-- } else {
--   0
-- }
-- @
--
-- and turn it into
--
-- @
-- x = xs[i]
-- if x != 0 then {
--   y = ys[i]
--   y
-- } else {
--   0
-- }
-- @
--
-- The idea is to delay loads from memory until (if) they are actually
-- needed.  Code patterns like the above are particularly common in
-- code that makes use of pattern matching on sum types.
--
-- We are currently quite conservative about when we do this.  In
-- particular, if any consumption is going on in a body, we don't do
-- anything.  This is far too conservative.  Also, we are careful
-- never to duplicate work.
--
-- This pass redundantly computes free-variable information a lot.  If
-- you ever see this pass as being a compilation speed bottleneck,
-- start by caching that a bit.
--
-- This pass is defined on post-SOACS representations.  This is not
-- because we do anything GPU-specific here, but simply because more
-- explicit indexing is going on after SOACs are gone.
module Futhark.Optimise.Sink (sinkGPU, sinkMC) where

import Control.Monad.State
import Data.Bifunctor
import Data.List (foldl')
import qualified Data.Map as M
import qualified Futhark.Analysis.Alias as Alias
import qualified Futhark.Analysis.SymbolTable as ST
import Futhark.IR.Aliases
import Futhark.IR.GPU
import Futhark.IR.MC
import Futhark.Pass

-- | Local shorthand for the symbol table used throughout this pass.
type SymbolTable rep = ST.SymbolTable rep

-- | Statements that are candidates for being moved inward, keyed by
-- the single name that each of them binds.
type Sinking rep = M.Map VName (Stm rep)

-- | Names of the statements that were moved into some inner scope.
type Sunk = Names

-- | A sinking transformation over some syntactic construct @a@.  Given
-- the current symbol table and the pending candidates, it produces the
-- rewritten construct and the names of every candidate it absorbed.
type Sinker rep a =
  SymbolTable rep ->
  Sinking rep ->
  a ->
  (a, Sunk)

-- | What a representation must provide for the pass to run on it.
type Constraints rep = (ASTRep rep, Aliased rep, ST.IndexOp (Op rep))

-- | Given a statement, compute how often each of its free variables
-- is used.  The result is not accurate: all we care about is whether
-- a count is exactly 1 or greater than 1.
--
-- For an @if@, the condition and each branch contribute separately,
-- so a variable mentioned in both branches gets a count of at least 2.
-- Every variable free in an 'Op' or a 'DoLoop' is counted as 2, which
-- keeps it from ever being considered a sinking candidate (candidates
-- must have a count of exactly 1).
multiplicity :: Constraints rep => Stm rep -> M.Map VName Int
multiplicity stm =
  case stmExp stm of
    If cond tbranch fbranch _ ->
      free cond 1 `comb` free tbranch 1 `comb` free fbranch 1
    Op {} -> free stm 2
    DoLoop {} -> free stm 2
    _ -> free stm 1
  where
    -- Assign the count @k@ to every variable free in @x@.
    free x k = M.fromList [(v, k) | v <- namesToList (freeIn x)]
    comb = M.unionWith (+)

-- | Optimise one branch of an @if@.
--
-- Every candidate in the sinking map that the branch uses, and whose
-- own free variables are all available in @vtable@, is placed at the
-- start of the branch body.  The remaining candidates are passed on to
-- 'optimiseStms' so they can be sunk further inward.
--
-- Returns the new body together with the names of all statements sunk
-- into it, both here and in nested scopes.
optimiseBranch ::
  Constraints rep =>
  Sinker rep (Op rep) ->
  Sinker rep (Body rep)
optimiseBranch onOp vtable sinking (Body dec stms res) =
  let (stms', stms_sunk) =
        optimiseStms onOp vtable sinking' (sunk_stms <> stms) $ freeIn res
   in (Body dec stms' res, sunk <> stms_sunk)
  where
    free_in_stms = freeIn stms <> freeIn res

    -- Candidates that land in this branch vs. those passed further in.
    (sinking_here, sinking') = M.partitionWithKey sunkHere sinking

    -- NOTE(review): these come out in map-key order rather than the
    -- original program order; this relies on no two statements sunk
    -- into the same branch depending on each other -- confirm.
    sunk_stms = stmsFromList $ M.elems sinking_here

    -- Sink a candidate here if the branch refers to it and everything
    -- it refers to is available at this point.
    sunkHere v stm =
      v `nameIn` free_in_stms
        && all (`ST.available` vtable) (namesToList (freeIn stm))

    sunk = namesFromList $ foldMap (patNames . stmPat) sunk_stms

-- | Optimise a sequence of statements.
--
-- The final argument is the set of names free in whatever result the
-- statements produce; each of those counts as one use when computing
-- multiplicities.
--
-- A statement becomes a sinking candidate if it is an 'Index' binding
-- a single variable of primitive type that is used at most once.  If a
-- candidate is sunk somewhere deeper, it is removed from this sequence.
-- Candidates are actually placed in @if@ branches ('optimiseBranch').
-- They are also threaded through operations (via the 'Sinker' for ops)
-- and through any other construct containing bodies (via 'mapExpM').
--
-- Returns the rewritten statements and the names of everything sunk.
optimiseStms ::
  Constraints rep =>
  Sinker rep (Op rep) ->
  SymbolTable rep ->
  Sinking rep ->
  Stms rep ->
  Names ->
  (Stms rep, Sunk)
optimiseStms onOp init_vtable init_sinking all_stms free_in_res =
  let (all_stms', sunk) =
        optimiseStms' init_vtable init_sinking $ stmsToList all_stms
   in (stmsFromList all_stms', sunk)
  where
    -- Use counts over all statements in this sequence plus the result.
    multiplicities =
      foldl'
        (M.unionWith (+))
        (M.fromList [(v, 1) | v <- namesToList free_in_res])
        (map multiplicity $ stmsToList all_stms)

    optimiseStms' _ _ [] = ([], mempty)
    optimiseStms' vtable sinking (stm : stms)
      -- A single primitive load with at most one use: record it as a
      -- candidate, and drop it here if something later sank it.
      | BasicOp Index {} <- stmExp stm,
        [pe] <- patElems (stmPat stm),
        primType $ patElemType pe,
        maybe True (== 1) $ M.lookup (patElemName pe) multiplicities =
          let (stms', sunk) =
                optimiseStms'
                  vtable'
                  (M.insert (patElemName pe) stm sinking)
                  stms
           in if patElemName pe `nameIn` sunk
                then (stms', sunk)
                else (stm : stms', sunk)
      -- Branches are where candidates actually get placed.
      | If cond tbranch fbranch ret <- stmExp stm =
          let (tbranch', tsunk) = optimiseBranch onOp vtable sinking tbranch
              (fbranch', fsunk) = optimiseBranch onOp vtable sinking fbranch
              (stms', sunk) = optimiseStms' vtable' sinking stms
           in ( stm {stmExp = If cond tbranch' fbranch' ret} : stms',
                tsunk <> fsunk <> sunk
              )
      -- Operations are handled by the representation-specific sinker.
      | Op op <- stmExp stm =
          let (op', op_sunk) = onOp vtable sinking op
              (stms', stms_sunk) = optimiseStms' vtable' sinking stms
           in ( stm {stmExp = Op op'} : stms',
                stms_sunk <> op_sunk
              )
      -- Anything else: recurse into whatever bodies the expression has.
      | otherwise =
          let (stms', stms_sunk) = optimiseStms' vtable' sinking stms
              (e', stm_sunk) = runState (mapExpM mapper (stmExp stm)) mempty
           in ( stm {stmExp = e'} : stms',
                stm_sunk <> stms_sunk
              )
      where
        vtable' = ST.insertStm stm vtable
        mapper =
          identityMapper
            { mapOnBody = \scope body -> do
                let (body', sunk) =
                      optimiseBody
                        onOp
                        (ST.fromScope scope <> vtable)
                        sinking
                        body
                modify (<> sunk)
                return body'
            }

-- | Optimise the statements of a body.  The names free in the body's
-- result count as uses.  Returns the new body and the names of all
-- statements sunk into it.
optimiseBody ::
  Constraints rep =>
  Sinker rep (Op rep) ->
  Sinker rep (Body rep)
optimiseBody onOp vtable sinking (Body attr stms res) =
  let (stms', sunk) = optimiseStms onOp vtable sinking stms $ freeIn res
   in (Body attr stms' res, sunk)

-- | Like 'optimiseBody', but for the body of a segmented operation.
-- The names free in the kernel results count as uses.
optimiseKernelBody ::
  Constraints rep =>
  Sinker rep (Op rep) ->
  Sinker rep (KernelBody rep)
optimiseKernelBody onOp vtable sinking (KernelBody attr stms res) =
  let (stms', sunk) = optimiseStms onOp vtable sinking stms $ freeIn res
   in (KernelBody attr stms' res, sunk)

-- | Optimise the lambdas and kernel bodies inside a segmented
-- operation.  The names collected while mapping over the operation
-- are accumulated in a 'State' and returned as its sunk set.
optimiseSegOp ::
  Constraints rep =>
  Sinker rep (Op rep) ->
  Sinker rep (SegOp lvl rep)
optimiseSegOp onOp vtable sinking op =
  let scope = scopeOfSegSpace $ segSpace op
   in runState (mapSegOpM (opMapper scope) op) mempty
  where
    opMapper scope =
      identitySegOpMapper
        { mapOnSegOpLambda = \lam -> do
            let (body, sunk) =
                  optimiseBody onOp op_vtable sinking $ lambdaBody lam
            modify (<> sunk)
            return lam {lambdaBody = body},
          mapOnSegOpBody = \body -> do
            let (body', sunk) =
                  optimiseKernelBody onOp op_vtable sinking body
            modify (<> sunk)
            return body'
        }
      where
        -- The segment space's names are brought into scope.
        -- NOTE(review): lambda parameters are not added here.
        op_vtable = ST.fromScope scope <> vtable

-- | The representation the pass works on internally: the input
-- representation extended with alias information, which the symbol
-- table operations used here require.
type SinkRep rep = Aliases rep

-- | Build the sinking pass, given a 'Sinker' for the representation's
-- operations.  The program is alias-analysed, each function body and
-- the top-level constants are optimised, and the alias information is
-- removed again afterwards.
sink ::
  ( ASTRep rep,
    CanBeAliased (Op rep),
    ST.IndexOp (OpWithAliases (Op rep))
  ) =>
  Sinker (SinkRep rep) (Op (SinkRep rep)) ->
  Pass rep rep
sink onOp =
  Pass "sink" "move memory loads closer to their uses" $
    fmap removeProgAliases
      . intraproceduralTransformationWithConsts onConsts onFun
      . Alias.aliasAnalysis
  where
    -- NOTE(review): the symbol table is seeded only with the function
    -- parameters; the constants (first argument) are ignored.  If
    -- 'ST.available' is false for absent names, loads depending on a
    -- constant are never sunk in function bodies -- confirm intent.
    onFun _ fd = do
      let vtable = ST.insertFParams (funDefParams fd) mempty
          (body, _) = optimiseBody onOp vtable mempty $ funDefBody fd
      return fd {funDefBody = body}

    -- Every constant is treated as used by the "result", so any further
    -- use pushes its count above one; presumably this keeps constants at
    -- the top level where functions can still see them.
    onConsts consts =
      pure $
        fst $
          optimiseStms onOp mempty mempty consts $
            namesFromList $
              M.keys $
                scopeOf consts

-- | Sinking in GPU kernels.
sinkGPU :: Pass GPU GPU
sinkGPU = sink onHostOp
  where
    -- Only 'SegOp's are descended into; any other host operation is
    -- returned unchanged and absorbs no candidates.
    onHostOp :: Sinker (SinkRep GPU) (Op (SinkRep GPU))
    onHostOp vtable sinking (SegOp op) =
      first SegOp $ optimiseSegOp onHostOp vtable sinking op
    onHostOp _ _ op = (op, mempty)

-- | Sinking for multicore.
sinkMC :: Pass MC MC
sinkMC = sink onHostOp
  where
    -- For a 'ParOp', the optional first segmented operation and the
    -- mandatory second one are both optimised against the same
    -- candidates, and their sunk names are combined.  Any other
    -- operation is returned unchanged.
    onHostOp :: Sinker (SinkRep MC) (Op (SinkRep MC))
    onHostOp vtable sinking (ParOp par_op op) =
      let (par_op', par_sunk) =
            maybe
              (Nothing, mempty)
              (first Just . optimiseSegOp onHostOp vtable sinking)
              par_op
          (op', sunk) = optimiseSegOp onHostOp vtable sinking op
       in (ParOp par_op' op', par_sunk <> sunk)
    onHostOp _ _ op = (op, mempty)