{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeFamilies #-}
-- | "Sinking" is conceptually the opposite of hoisting.  The idea is
-- to take code that looks like this:
--
-- @
-- x = xs[i]
-- y = ys[i]
-- if x != 0 then {
--   y
-- } else {
--   0
-- }
-- @
--
-- and turn it into
--
-- @
-- x = xs[i]
-- if x != 0 then {
--   y = ys[i]
--   y
-- } else {
--   0
-- }
-- @
--
-- The idea is to delay loads from memory until (if) they are actually
-- needed.  Code patterns like the above are particularly common in
-- code that makes use of pattern matching on sum types.
--
-- We are currently quite conservative about when we do this.  In
-- particular, if any consumption is going on in a body, we don't do
-- anything.  This is far too conservative.  Also, we are careful
-- never to duplicate work.
--
-- This pass redundantly computes free-variable information a lot.  If
-- you ever see this pass as being a compilation speed bottleneck,
-- start by caching that a bit.
--
-- This pass is defined on the Kernels representation.  This is not
-- because we do anything kernel-specific here, but simply because
-- more explicit indexing is going on after SOACs are gone.

module Futhark.Optimise.Sink (sink) where

import Control.Monad.State
import Data.List (foldl')
import qualified Data.Map as M
import qualified Data.Set as S

import qualified Futhark.Analysis.Alias as Alias
import qualified Futhark.Analysis.SymbolTable as ST
import Futhark.IR.Aliases
import Futhark.IR.Kernels
import Futhark.Pass

-- | The representation this pass operates on: 'Kernels' wrapped in
-- 'Aliases' (the pass runs alias analysis before doing anything else).
type SinkLore = Aliases Kernels
-- | The symbol table type, specialised to 'SinkLore'.
type SymbolTable = ST.SymbolTable SinkLore
-- | Statements that are candidates for sinking, keyed by the (single)
-- name that the statement binds.
type Sinking = M.Map VName (Stm SinkLore)
-- | The names bound by statements that have actually been sunk into
-- some branch.
type Sunk = S.Set VName

-- | Given a statement, compute how often each of its free variables
-- is used.  The counts are not exact: all we care about is whether a
-- variable is used exactly once or more than once.
--
-- * For an 'If', the uses in the condition and in each of the two
--   branches are counted once each and then summed.
--
-- * For an 'Op' or a 'DoLoop', every free variable is counted as
--   being used twice, so nothing used by them is treated as
--   single-use.
--
-- * For every other statement each free variable counts once.
multiplicity :: Stm SinkLore -> M.Map VName Int
multiplicity stm =
  case stmExp stm of
    If cond tbranch fbranch _ ->
      -- All three maps must be *summed*.  Combining the condition
      -- with '<>' would be a left-biased union, which silently drops
      -- the branch count for a variable that is used both as the
      -- condition and inside a branch.  Such a variable would then
      -- look single-use and could be sunk into the branch, leaving
      -- the condition referring to a name that is no longer bound.
      M.unionsWith (+) [free cond 1, free tbranch 1, free fbranch 1]
    Op{} -> free stm 2
    DoLoop{} -> free stm 2
    _ -> free stm 1
  where
    -- Map every free variable of @x@ to the count @k@.
    free x k = M.fromList [(v, k) | v <- namesToList (freeIn x)]

-- | Optimise one branch of an 'If'.
--
-- Every candidate in the 'Sinking' map whose bound name is free in
-- the branch (its statements or its result), and whose own free
-- variables all satisfy 'ST.available' in the given symbol table, is
-- placed at the front of the branch.  The remaining candidates are
-- passed on to 'optimiseStms' for the statements of the branch, so
-- they may still be sunk into more deeply nested branches.
--
-- Returns the rewritten branch together with the names bound by all
-- statements that were sunk here or further down.
optimiseBranch :: SymbolTable -> Sinking -> Body SinkLore
               -> (Body SinkLore, Sunk)
optimiseBranch vtable sinking (Body dec stms res) =
  let (stms', stms_sunk) = optimiseStms vtable sinking' stms $ freeIn res
  in (Body dec (sunk_stms <> stms') res,
      sunk <> stms_sunk)
  where
    -- Everything the branch refers to from the outside.
    free_in_stms = freeIn stms <> freeIn res

    -- Split the candidates into those that land in this branch and
    -- those that are handed further down.
    (sinking_here, sinking') = M.partitionWithKey sunkHere sinking

    -- The statements to prepend, in the key order of the map.
    sunk_stms = stmsFromList $ M.elems sinking_here

    -- A candidate is sunk here if the branch uses the name it binds
    -- and everything the candidate itself needs is 'ST.available'.
    sunkHere v stm =
      v `nameIn` free_in_stms &&
      all (`ST.available` vtable) (namesToList (freeIn stm))

    -- The names bound by the statements sunk into this branch.
    sunk = S.fromList $ concatMap (patternNames . stmPattern) sunk_stms

-- | Optimise a sequence of statements.
--
-- The arguments are the symbol table and the sinking candidates
-- inherited from the enclosing scope, the statements themselves, and
-- the names that are free in the result of the enclosing body.
--
-- Returns the rewritten statements, and the names bound by all
-- statements that ended up being sunk into some branch.
optimiseStms :: SymbolTable -> Sinking -> Stms SinkLore -> Names
             -> (Stms SinkLore, Sunk)
optimiseStms init_vtable init_sinking all_stms free_in_res =
  let (all_stms', sunk) =
        optimiseStms' init_vtable init_sinking $ stmsToList all_stms
  in (stmsFromList all_stms', sunk)
  where
    -- How often each variable is used by the statements and the
    -- result.  A use in the result counts once, like a use in an
    -- ordinary statement (see 'multiplicity'); only the distinction
    -- between "once" and "more than once" matters below.
    multiplicities =
      foldl' (M.unionWith (+))
      (M.fromList [(v, 1) | v <- namesToList free_in_res])
      (map multiplicity $ stmsToList all_stms)

    optimiseStms' :: SymbolTable -> Sinking -> [Stm SinkLore]
                  -> ([Stm SinkLore], Sunk)
    optimiseStms' _ _ [] = ([], mempty)

    optimiseStms' vtable sinking (stm : stms)
      -- An index expression that binds a single name of primitive
      -- type, where that name is used at most once: record it as a
      -- candidate while processing the remaining statements.  If it
      -- turns out to have been sunk, it is dropped from this level;
      -- otherwise it stays where it was.
      | BasicOp Index{} <- stmExp stm,
        [pe] <- patternElements (stmPattern stm),
        primType $ patElemType pe,
        maybe True (== 1) $ M.lookup (patElemName pe) multiplicities =
          let (stms', sunk) =
                optimiseStms' vtable' (M.insert (patElemName pe) stm sinking) stms
          in if patElemName pe `S.member` sunk
             then (stms', sunk)
             else (stm : stms', sunk)

      -- A conditional: this is where candidates may actually be sunk.
      -- The branches are optimised with the symbol table as it is
      -- *before* this statement.
      | If cond tbranch fbranch ret <- stmExp stm =
          let (tbranch', tsunk) = optimiseBranch vtable sinking tbranch
              (fbranch', fsunk) = optimiseBranch vtable sinking fbranch
              (stms', sunk) = optimiseStms' vtable' sinking stms
          in (stm { stmExp = If cond tbranch' fbranch' ret } : stms',
              tsunk <> fsunk <> sunk)

      -- A segmented operation: optimise its lambdas and kernel body
      -- with the names of its index space in scope.
      | Op (SegOp op) <- stmExp stm =
          let scope = scopeOfSegSpace $ segSpace op
              (stms', stms_sunk) = optimiseStms' vtable' sinking stms
              (op', op_sunk) = runState (mapSegOpM (opMapper scope) op) mempty
          in (stm { stmExp = Op (SegOp op') } : stms',
              stms_sunk <> op_sunk)

      -- Anything else: optimise whatever bodies the expression has.
      | otherwise =
          let (stms', stms_sunk) = optimiseStms' vtable' sinking stms
              (e', stm_sunk) = runState (mapExpM mapper (stmExp stm)) mempty
          in (stm { stmExp = e' } : stms',
              stm_sunk <> stms_sunk)

      where
        -- The symbol table as seen by the statements following @stm@.
        vtable' = ST.insertStm stm vtable

        -- Optimises every body inside an expression, accumulating
        -- the set of sunk names in the state.
        mapper =
          identityMapper
            { mapOnBody = \scope body -> do
                let (body', sunk) =
                      optimiseBody (ST.fromScope scope <> vtable) sinking body
                modify (<> sunk)
                return body'
            }

        -- Like 'mapper', but for the lambdas and the kernel body of
        -- a segmented operation with the given scope.
        opMapper scope =
          identitySegOpMapper
            { mapOnSegOpLambda = \lam -> do
                let (body, sunk) =
                      optimiseBody op_vtable sinking $ lambdaBody lam
                modify (<> sunk)
                return lam { lambdaBody = body }
            , mapOnSegOpBody = \body -> do
                let (body', sunk) =
                      optimiseKernelBody op_vtable sinking body
                modify (<> sunk)
                return body'
            }
          where
            op_vtable = ST.fromScope scope <> vtable

-- | Optimise the statements of an ordinary body.  Nothing is sunk
-- directly into the body itself (contrast with 'optimiseBranch'); the
-- candidates are merely passed on to 'optimiseStms', with the names
-- free in the body result counted as uses.
--
-- Returns the rewritten body and the names bound by the statements
-- that were sunk somewhere inside it.
optimiseBody :: SymbolTable -> Sinking -> Body SinkLore
             -> (Body SinkLore, Sunk)
optimiseBody vtable sinking (Body dec stms res) =
  let (stms', sunk) = optimiseStms vtable sinking stms $ freeIn res
  in (Body dec stms' res, sunk)

-- | Like 'optimiseBody', but for the body of a kernel: the statements
-- are passed to 'optimiseStms', with the names free in the kernel
-- results counted as uses.
--
-- Returns the rewritten kernel body and the names bound by the
-- statements that were sunk somewhere inside it.
optimiseKernelBody :: SymbolTable -> Sinking -> KernelBody SinkLore
                   -> (KernelBody SinkLore, Sunk)
optimiseKernelBody vtable sinking (KernelBody dec stms res) =
  let (stms', sunk) = optimiseStms vtable sinking stms $ freeIn res
  in (KernelBody dec stms' res, sunk)

-- | The pass definition.
--
-- The program is first decorated with alias information, then the
-- constants and every function body are optimised, and finally the
-- alias information is removed again.
sink :: Pass Kernels Kernels
sink = Pass "sink" "move memory loads closer to their uses" $
       fmap removeProgAliases .
       intraproceduralTransformationWithConsts onConsts onFun .
       Alias.aliasAnalysis
  where
    -- Optimise a function body, starting from a symbol table that
    -- holds only the parameters of the function and with no sinking
    -- candidates.  The first argument (the constants) is not used.
    onFun _ fd = do
      let vtable = ST.insertFParams (funDefParams fd) mempty
          (body, _) = optimiseBody vtable mempty $ funDefBody fd
      return fd { funDefBody = body }

    -- Optimise the constants.  Every name they bind is passed as if
    -- it were free in a result, so each of them counts as used.
    onConsts consts =
      pure $ fst $ optimiseStms mempty mempty consts $
      namesFromList $ M.keys $ scopeOf consts