{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeFamilies #-}

-- | "Sinking" is conceptually the opposite of hoisting.  The idea is
-- to take code that looks like this:
--
-- @
-- x = xs[i]
-- y = ys[i]
-- if x != 0 then {
--   y
-- } else {
--   0
-- }
-- @
--
-- and turn it into
--
-- @
-- x = xs[i]
-- if x != 0 then {
--   y = ys[i]
--   y
-- } else {
--   0
-- }
-- @
--
-- The idea is to delay loads from memory until (if) they are actually
-- needed.  Code patterns like the above are particularly common in
-- code that makes use of pattern matching on sum types.
--
-- We are currently quite conservative about when we do this.  In
-- particular, if any consumption is going on in a body, we don't do
-- anything.  This is far too conservative.  Also, we are careful
-- never to duplicate work.
--
-- This pass redundantly computes free-variable information a lot.  If
-- you ever see this pass as being a compilation speed bottleneck,
-- start by caching that a bit.
--
-- This pass is defined on the Kernels representation.  This is not
-- because we do anything kernel-specific here, but simply because
-- more explicit indexing is going on after SOACs are gone.
module Futhark.Optimise.Sink (sink) where

import Control.Monad.State
import Data.List (foldl')
import qualified Data.Map as M
import qualified Data.Set as S
import qualified Futhark.Analysis.Alias as Alias
import qualified Futhark.Analysis.SymbolTable as ST
import Futhark.IR.Aliases
import Futhark.IR.Kernels
import Futhark.Pass

-- | The representation the pass works on internally: 'Kernels'
-- annotated with alias information.
type SinkLore = Aliases Kernels

-- | Names bound by statements that were moved (sunk) into some branch.
-- A statement whose name ends up in this set is dropped from its
-- original position.
type Sunk = S.Set VName

-- | Sinking candidates, keyed by the single name each candidate
-- statement binds.
type Sinking = M.Map VName (Stm SinkLore)

-- | Symbol table over the alias-annotated representation.
type SymbolTable = ST.SymbolTable SinkLore

-- | Given a statement, compute how often each of its free variables
-- is used.  The result is deliberately approximate: callers only
-- care whether a count is exactly 1 or greater than 1.
--
-- * For an @if@, the condition and each branch contribute 1 per free
--   variable.  A variable mentioned in both branches therefore gets a
--   count of 2.
-- * Free variables of 'DoLoop' and 'Op' statements are counted as 2.
--   This keeps them from ever looking single-use, so that loads are
--   never sunk into bodies that may be executed more than once.  (The
--   module header states the pass never duplicates work.)
-- * Every other statement contributes 1 per free variable, however many
--   times the variable is mentioned.
multiplicity :: Stm SinkLore -> M.Map VName Int
multiplicity stm =
  case stmExp stm of
    If cond tbranch fbranch _ ->
      free cond 1 `comb` free tbranch 1 `comb` free fbranch 1
    Op {} -> free stm 2
    DoLoop {} -> free stm 2
    _ -> free stm 1
  where
    -- Assign the count @k@ to every free variable of @x@.
    free x k = M.fromList [(v, k) | v <- namesToList (freeIn x)]
    -- Counts from different parts of the statement add up.
    comb = M.unionWith (+)

-- | Optimise one branch of an @if@.
--
-- A sinking candidate is placed at the start of the branch when both of
-- these hold:
--
-- * the branch (its statements or its result) refers to the candidate's
--   name;
-- * every free variable of the candidate statement is 'ST.available'
--   in the given symbol table.
--
-- All other candidates are passed on to the branch's own statements.
-- Returns the new body, plus the names of every statement sunk into
-- this branch or anywhere inside it.
optimiseBranch ::
  SymbolTable ->
  Sinking ->
  Body SinkLore ->
  (Body SinkLore, Sunk)
optimiseBranch vtable sinking (Body dec stms res) =
  let (stms', stms_sunk) = optimiseStms vtable sinking' stms $ freeIn res
   in ( Body dec (sunk_stms <> stms') res,
        sunk <> stms_sunk
      )
  where
    free_in_stms = freeIn stms <> freeIn res

    -- Candidates placed right here, and those passed further down.
    (sinking_here, sinking') = M.partitionWithKey sunkHere sinking

    -- NOTE(review): 'M.elems' emits the sunk statements in key order,
    -- not in their original program order.  This is fine as long as no
    -- two statements sunk into the same branch depend on each other.
    sunk_stms = stmsFromList $ M.elems sinking_here

    sunkHere v stm =
      v `nameIn` free_in_stms
        && all (`ST.available` vtable) (namesToList (freeIn stm))

    -- Names bound by the statements sunk here; callers use these to
    -- delete the statements from their original position.
    sunk = S.fromList $ concatMap (patternNames . stmPattern) sunk_stms

-- | Optimise a sequence of statements.
--
-- The 'Names' argument holds the variables used by whatever follows the
-- statements, such as a body result.  Each of these counts as one use
-- when computing multiplicities.
--
-- A statement becomes a sinking candidate when all of the following
-- hold:
--
-- * it is an array index ('Index');
-- * it binds exactly one variable, and that variable has a primitive
--   type;
-- * the variable is used at most once, according to 'multiplicity'.
--
-- Candidates are carried forward in the 'Sinking' map.  'optimiseBranch'
-- may place them inside a later @if@ branch; a candidate that was placed
-- somewhere is then removed from its original position.
--
-- Returns the new statements, plus the names of every statement sunk at
-- or below this point.
optimiseStms ::
  SymbolTable ->
  Sinking ->
  Stms SinkLore ->
  Names ->
  (Stms SinkLore, Sunk)
optimiseStms init_vtable init_sinking all_stms free_in_res =
  let (all_stms', sunk) =
        optimiseStms' init_vtable init_sinking $ stmsToList all_stms
   in (stmsFromList all_stms', sunk)
  where
    -- Approximate use count of each variable in this sequence: one for
    -- every name in 'free_in_res', plus what 'multiplicity' reports for
    -- each statement.
    multiplicities =
      foldl'
        (M.unionWith (+))
        (M.fromList (zip (namesToList free_in_res) (repeat 1)))
        (map multiplicity $ stmsToList all_stms)

    optimiseStms' _ _ [] = ([], mempty)
    optimiseStms' vtable sinking (stm : stms)
      -- A single-use load of one primitive value: record it as a
      -- candidate.  If a later branch took it (its name is in the sunk
      -- set), drop it from here; otherwise keep it in place.
      | BasicOp Index {} <- stmExp stm,
        [pe] <- patternElements (stmPattern stm),
        primType $ patElemType pe,
        maybe True (== 1) $ M.lookup (patElemName pe) multiplicities =
        let (stms', sunk) =
              optimiseStms' vtable' (M.insert (patElemName pe) stm sinking) stms
         in if patElemName pe `S.member` sunk
              then (stms', sunk)
              else (stm : stms', sunk)
      -- Branches are where candidates actually get placed.  The branches
      -- are optimised against 'vtable' rather than vtable', because
      -- the names bound by the 'If' are not in scope inside it.
      | If cond tbranch fbranch ret <- stmExp stm =
        let (tbranch', tsunk) = optimiseBranch vtable sinking tbranch
            (fbranch', fsunk) = optimiseBranch vtable sinking fbranch
            (stms', sunk) = optimiseStms' vtable' sinking stms
         in ( stm {stmExp = If cond tbranch' fbranch' ret} : stms',
              tsunk <> fsunk <> sunk
            )
      -- Segmented operations: descend into their lambdas and kernel
      -- bodies, with the names of the segment space in scope.
      | Op (SegOp op) <- stmExp stm =
        let scope = scopeOfSegSpace $ segSpace op
            (stms', stms_sunk) = optimiseStms' vtable' sinking stms
            (op', op_sunk) = runState (mapSegOpM (opMapper scope) op) mempty
         in ( stm {stmExp = Op (SegOp op')} : stms',
              stms_sunk <> op_sunk
            )
      -- Any other expression: descend into whatever nested bodies it
      -- has, using a generic expression 'Mapper'.
      | otherwise =
        let (stms', stms_sunk) = optimiseStms' vtable' sinking stms
            (e', stm_sunk) = runState (mapExpM mapper (stmExp stm)) mempty
         in ( stm {stmExp = e'} : stms',
              stm_sunk <> stms_sunk
            )
      where
        -- Symbol table for the statements that follow 'stm'.
        vtable' = ST.insertStm stm vtable

        -- Optimise every nested body, accumulating the sunk names in
        -- the state.
        mapper =
          identityMapper
            { mapOnBody = \scope body -> do
                let (body', sunk) =
                      optimiseBody (ST.fromScope scope <> vtable) sinking body
                modify (<> sunk)
                return body'
            }

        -- Same as 'mapper', but for the lambdas and kernel bodies
        -- inside a 'SegOp'.
        opMapper scope =
          identitySegOpMapper
            { mapOnSegOpLambda = \lam -> do
                let (body, sunk) =
                      optimiseBody op_vtable sinking $
                        lambdaBody lam
                modify (<> sunk)
                return lam {lambdaBody = body},
              mapOnSegOpBody = \body -> do
                let (body', sunk) =
                      optimiseKernelBody op_vtable sinking body
                modify (<> sunk)
                return body'
            }
          where
            op_vtable = ST.fromScope scope <> vtable

-- | Optimise the statements of a body.  Each variable used in the body
-- result counts as one use.
--
-- Unlike 'optimiseBranch', this places no candidates at the start of
-- the body.  The given 'Sinking' map is handed unchanged to the body's
-- statements.  Returns the new body and the names of all statements
-- sunk inside it.
optimiseBody ::
  SymbolTable ->
  Sinking ->
  Body SinkLore ->
  (Body SinkLore, Sunk)
optimiseBody vtable sinking (Body dec stms res) =
  let (stms', sunk) = optimiseStms vtable sinking stms $ freeIn res
   in (Body dec stms' res, sunk)

-- | 'optimiseBody' for kernel bodies.  The statements are optimised
-- with every variable free in the kernel results counted as one use.
-- Returns the new kernel body and the names of all statements sunk
-- inside it.
optimiseKernelBody ::
  SymbolTable ->
  Sinking ->
  KernelBody SinkLore ->
  (KernelBody SinkLore, Sunk)
optimiseKernelBody vtable sinking (KernelBody dec stms res) =
  let (stms', sunk) = optimiseStms vtable sinking stms $ freeIn res
   in (KernelBody dec stms' res, sunk)

-- | The pass definition.
--
-- The pass proceeds in three steps:
--
-- 1. Annotate the program with alias information.
-- 2. Optimise the top-level constants and each function independently.
-- 3. Strip the alias annotations again.
sink :: Pass Kernels Kernels
sink =
  Pass "sink" "move memory loads closer to their uses" $
    fmap removeProgAliases
      . intraproceduralTransformationWithConsts onConsts onFun
      . Alias.aliasAnalysis
  where
    -- Functions start from a symbol table holding only their
    -- parameters, and with no sinking candidates.  The (optimised)
    -- constants passed as the first argument are not used.
    onFun _ fd = do
      let vtable = ST.insertFParams (funDefParams fd) mempty
          (body, _) = optimiseBody vtable mempty $ funDefBody fd
      return fd {funDefBody = body}

    -- Every name bound by the constants counts as a use of the
    -- "result".  NOTE(review): presumably this is because functions may
    -- refer to any of the constants.
    onConsts consts =
      pure $
        fst $
          optimiseStms mempty mempty consts $
            namesFromList $
              M.keys $
                scopeOf consts