{-# LANGUAGE TypeFamilies #-}

-- | The simplification engine is only willing to hoist allocations
-- out of loops if the memory block resulting from the allocation is
-- dead at the end of the loop.  If it is not, we may cause data
-- hazards.
--
-- This pass tries to rewrite loops with memory parameters.
-- Specifically, it takes loops of this form:
--
-- @
-- loop {..., A_mem, ..., A, ...} ... do {
--   ...
--   let A_out_mem = alloc(...) -- stores A_out
--   in {..., A_out_mem, ..., A_out, ...}
-- }
-- @
--
-- and turns them into
--
-- @
-- let A_in_mem = alloc(...)
-- let A_out_mem = alloc(...)
-- let A_in = copy A -- in A_in_mem
-- loop {..., A_in_mem, A_out_mem, ..., A=A_in, ...} ... do {
--   ...
--   in {..., A_out_mem, A_mem, ..., A_out, ...}
-- }
-- @
--
-- The result is essentially "pointer swapping" between the two initial
-- memory blocks @A_mem@ and @A_out_mem@.  The invariant is that the
-- array is always stored in the "first" memory block at the beginning
-- of the loop (and also in the final result).  We do need to add an
-- extra element to the pattern, however.  The initial copy of @A@
-- could be elided if @A@ is unique (thus @A_in_mem=A_mem@).  This is
-- because only then is it safe to use @A_mem@ to store loop results.
-- We don't currently do this.
--
-- Unfortunately, not all loops fit the pattern above.  In particular,
-- a nested loop that has been transformed as such does not!
-- Therefore we also have another double buffering strategy, that
-- turns
--
-- @
-- loop {..., A_mem, ..., A, ...} ... do {
--   ...
--   let A_out_mem = alloc(...)
--   -- A in A_out_mem
--   in {..., A_out_mem, ..., A, ...}
-- }
-- @
--
-- into
--
-- @
-- let A_res_mem = alloc(...)
-- loop {..., A_mem, ..., A, ...} ... do {
--   ...
--   let A_out_mem = alloc(...)
--   -- A in A_out_mem
--   let A' = copy A
--   -- A' in A_res_mem
--   in {..., A_res_mem, ..., A, ...}
-- }
-- @
--
-- The allocation of A_out_mem can then be hoisted out because it is
-- dead at the end of the loop.  This always works as long as
-- A_out_mem has a loop-invariant allocation size, but requires a copy
-- per iteration (and an initial one, elided above).
module Futhark.Optimise.DoubleBuffer (doubleBufferGPU, doubleBufferMC) where

import Control.Monad.Reader
import Control.Monad.State
import Control.Monad.Writer
import Data.Bifunctor
import Data.List (find)
import Data.Map.Strict qualified as M
import Data.Maybe
import Futhark.Construct
import Futhark.IR.GPUMem as GPU
import Futhark.IR.MCMem as MC
import Futhark.IR.Mem.IxFun qualified as IxFun
import Futhark.Pass
import Futhark.Pass.ExplicitAllocations (arraySizeInBytesExp)
import Futhark.Pass.ExplicitAllocations.GPU ()
import Futhark.Transform.Substitute
import Futhark.Util (mapAccumLM, maybeHead)

-- | The double buffering pass for the GPU memory representation.
-- It is 'doubleBuffer' specialised with 'optimiseGPUOp' as the handler
-- for representation-specific operations.
doubleBufferGPU :: Pass GPUMem GPUMem
doubleBufferGPU = doubleBuffer optimiseGPUOp

-- | The double buffering pass for the multicore memory representation.
-- It is 'doubleBuffer' specialised with 'optimiseMCOp' as the handler
-- for representation-specific operations.
doubleBufferMC :: Pass MCMem MCMem
doubleBufferMC = doubleBuffer optimiseMCOp

-- | The double buffering pass definition, parameterised over the
-- handler used for representation-specific operations.
--
-- The pass is applied with 'intraproceduralTransformation'.  The
-- initial environment has an empty scope and a loop optimiser that
-- leaves every loop unchanged; the op handlers in this module
-- ('optimiseGPUOp' and 'optimiseMCOp') replace it with 'optimiseLoop'
-- while traversing the ops they match.
doubleBuffer :: Mem rep inner => OptimiseOp rep -> Pass rep rep
doubleBuffer onOp =
  Pass
    { passName = "Double buffer",
      passDescription = "Perform double buffering for merge parameters of sequential loops.",
      passFunction = intraproceduralTransformation optimise
    }
  where
    -- Run the 'DoubleBufferM' traversal over the statements in the given
    -- scope, threading the name source through the underlying state
    -- monad so that freshly generated names are recorded.
    optimise scope stms = modifyNameSource $ \src ->
      let m =
            runDoubleBufferM $ localScope scope $ optimiseStms $ stmsToList stms
       in runState (runReaderT m env) src

    -- Initial environment: nothing in scope, loops are not rewritten.
    env = Env mempty doNotTouchLoop onOp

    -- Loop optimiser that returns the loop as-is and emits no extra
    -- statements.
    doNotTouchLoop pat merge body = pure (mempty, pat, merge, body)

-- | A loop optimiser.  It receives the pattern, the merge parameters
-- (each paired with its initial value) and the body of a loop, and
-- returns a possibly rewritten pattern, merge parameters and body,
-- together with statements that 'optimiseStm' places immediately
-- before the loop.
type OptimiseLoop rep =
  Pat (LetDec rep) ->
  [(FParam rep, SubExp)] ->
  Body rep ->
  DoubleBufferM
    rep
    ( Stms rep,
      Pat (LetDec rep),
      [(FParam rep, SubExp)],
      Body rep
    )

-- | A transformation of the representation-specific operation type,
-- run in 'DoubleBufferM'.  'optimiseStm' installs it as the @mapOnOp@
-- field of the expression mapper.
type OptimiseOp rep =
  Op rep -> DoubleBufferM rep (Op rep)

-- | The read-only environment of 'DoubleBufferM'.
data Env rep = Env
  { -- | The bindings currently in scope; extended by 'localScope'.
    envScope :: Scope rep,
    -- | How loops are rewritten at the current point of the traversal.
    envOptimiseLoop :: OptimiseLoop rep,
    -- | How representation-specific operations are transformed.
    envOptimiseOp :: OptimiseOp rep
  }

-- | The monad in which the pass runs: a reader over 'Env' on top of a
-- state monad carrying the 'VNameSource' used for fresh names.
newtype DoubleBufferM rep a = DoubleBufferM
  { runDoubleBufferM :: ReaderT (Env rep) (State VNameSource) a
  }
  deriving (Functor, Applicative, Monad, MonadReader (Env rep), MonadFreshNames)

-- | The current scope is the one stored in the environment.
instance ASTRep rep => HasScope rep (DoubleBufferM rep) where
  askScope = asks envScope

-- | Extending the scope appends the new bindings to the scope stored
-- in the environment for the duration of the given action.
instance ASTRep rep => LocalScope rep (DoubleBufferM rep) where
  localScope scope = local $ \env -> env {envScope = envScope env <> scope}

-- | Optimise the statements of a body with 'optimiseStms'; every other
-- part of the body is kept as it is.
optimiseBody :: ASTRep rep => Body rep -> DoubleBufferM rep (Body rep)
optimiseBody body = do
  stms' <- optimiseStms $ stmsToList $ bodyStms body
  pure $ body {bodyStms = stms'}

-- | Optimise a list of statements in order.  Each statement may expand
-- into several statements; the remaining statements are processed with
-- the names bound by that expansion added to the scope.
optimiseStms :: ASTRep rep => [Stm rep] -> DoubleBufferM rep (Stms rep)
optimiseStms [] = pure mempty
optimiseStms (e : es) = do
  e_es <- optimiseStm e
  es' <- localScope (castScope $ scopeOf e_es) $ optimiseStms es
  pure $ e_es <> es'

-- | Optimise a single statement, possibly producing several.
--
-- For a loop, the body is optimised first (with the loop form and the
-- merge parameters in scope), and then the loop optimiser found in the
-- environment is applied; the statements it returns are placed in
-- front of the rebuilt loop.
--
-- For any other expression, nested bodies are optimised with
-- 'optimiseBody' and operations with the op handler found in the
-- environment.
optimiseStm :: forall rep. ASTRep rep => Stm rep -> DoubleBufferM rep (Stms rep)
optimiseStm (Let pat aux (DoLoop merge form body)) = do
  body' <-
    localScope (scopeOf form <> scopeOfFParams (map fst merge)) $
      optimiseBody body
  opt_loop <- asks envOptimiseLoop
  (stms, pat', merge', body'') <- opt_loop pat merge body'
  pure $ stms <> oneStm (Let pat' aux $ DoLoop merge' form body'')
optimiseStm (Let pat aux e) = do
  onOp <- asks envOptimiseOp
  oneStm . Let pat aux <$> mapExpM (optimise onOp) e
  where
    -- The scope argument passed to 'mapOnBody' is ignored here.
    optimise onOp =
      identityMapper
        { mapOnBody = \_ x ->
            optimiseBody x :: DoubleBufferM rep (Body rep),
          mapOnOp = onOp
        }

-- | The op handler for the GPU representation.  A 'SegOp' is traversed
-- with its lambdas and kernel body optimised, and with 'optimiseLoop'
-- installed as the loop optimiser for the duration of that traversal.
-- Any other operation is returned unchanged.
optimiseGPUOp :: OptimiseOp GPUMem
optimiseGPUOp (Inner (SegOp op)) =
  local inSegOp $ Inner . SegOp <$> mapSegOpM mapper op
  where
    mapper =
      identitySegOpMapper
        { mapOnSegOpLambda = optimiseLambda,
          mapOnSegOpBody = optimiseKernelBody
        }
    -- Switch the environment to actually rewrite loops.
    inSegOp env = env {envOptimiseLoop = optimiseLoop}
optimiseGPUOp op = pure op

-- | The op handler for the multicore representation.  For a 'ParOp',
-- both the optional first 'SegOp' and the mandatory second one are
-- traversed with their lambdas and kernel bodies optimised, and with
-- 'optimiseLoop' installed as the loop optimiser for the duration of
-- that traversal.  Any other operation is returned unchanged.
optimiseMCOp :: OptimiseOp MCMem
optimiseMCOp (Inner (ParOp par_op op)) =
  local inSegOp $
    Inner
      <$> (ParOp <$> traverse (mapSegOpM mapper) par_op <*> mapSegOpM mapper op)
  where
    mapper =
      identitySegOpMapper
        { mapOnSegOpLambda = optimiseLambda,
          mapOnSegOpBody = optimiseKernelBody
        }
    -- Switch the environment to actually rewrite loops.
    inSegOp env = env {envOptimiseLoop = optimiseLoop}
optimiseMCOp op = pure op

-- | Optimise the statements of a kernel body with 'optimiseStms';
-- every other part of the kernel body is kept as it is.
optimiseKernelBody ::
  ASTRep rep =>
  KernelBody rep ->
  DoubleBufferM rep (KernelBody rep)
optimiseKernelBody kbody = do
  stms' <- optimiseStms $ stmsToList $ kernelBodyStms kbody
  pure $ kbody {kernelBodyStms = stms'}

-- | Optimise the body of a lambda, with the scope of the lambda
-- (as given by 'scopeOf') added while the body is processed.
optimiseLambda ::
  ASTRep rep =>
  Lambda rep ->
  DoubleBufferM rep (Lambda rep)
optimiseLambda lam = do
  body <- localScope (castScope $ scopeOf lam) $ optimiseBody $ lambdaBody lam
  pure lam {lambdaBody = body}

-- | The constraints required by the loop rewriting functions of this
-- module ('extractAllocOf', 'optimiseLoop', 'optimiseLoopBySwitching'):
-- a memory representation with builder support whose expression and
-- body decorations are unit and whose let-decoration is 'LetDecMem'.
type Constraints rep inner =
  ( Mem rep inner,
    BuilderOps rep,
    ExpDec rep ~ (),
    BodyDec rep ~ (),
    LetDec rep ~ LetDecMem
  )

-- | @extractAllocOf bound needle stms@ searches @stms@ from the front
-- for the statement that binds exactly @needle@ to an allocation.
--
-- On success the result is that statement together with all the other
-- statements, in their original order.  The result is 'Nothing' when
-- the statements run out before such an allocation is found.  An
-- allocation is accepted only when its size is a constant or a
-- variable that is not in @bound@, where @bound@ is extended with the
-- names bound by every statement that was stepped over; a matching
-- allocation whose size fails this test is stepped over like any other
-- statement.
extractAllocOf :: Constraints rep inner => Names -> VName -> Stms rep -> Maybe (Stm rep, Stms rep)
extractAllocOf bound needle stms = do
  (stm, stms') <- stmsHead stms
  case stm of
    Let (Pat [pe]) _ (Op (Alloc size _))
      | patElemName pe == needle,
        invariant size ->
          Just (stm, stms')
    _ ->
      -- Not the allocation we want: remember what this statement
      -- binds, keep looking, and put the statement back in front of
      -- whatever remains.
      let bound' = namesFromList (patNames (stmPat stm)) <> bound
       in second (oneStm stm <>) <$> extractAllocOf bound' needle stms'
  where
    invariant Constant {} = True
    invariant (Var v) = v `notNameIn` bound

-- | The loop optimiser used inside ops.  It applies
-- 'optimiseLoopBySwitching' first and then 'optimiseLoopByCopying' to
-- the result, with the statements produced by the first step in scope
-- during the second.  The statements of both steps are returned, those
-- of the first step first.
optimiseLoop :: Constraints rep inner => OptimiseLoop rep
optimiseLoop pat merge body = do
  (outer_stms_1, pat', merge', body') <-
    optimiseLoopBySwitching pat merge body
  (outer_stms_2, pat'', merge'', body'') <-
    inScopeOf outer_stms_1 $ optimiseLoopByCopying pat' merge' body'
  pure (outer_stms_1 <> outer_stms_2, pat'', merge'', body'')

-- | @isArrayIn x param@ is 'True' exactly when @param@ is an array
-- parameter whose memory annotation ('ArrayIn') names the block @x@.
-- Every other kind of parameter yields 'False'.
isArrayIn :: VName -> Param FParamMem -> Bool
isArrayIn x (Param _ _ (MemArray _ _ _ (ArrayIn y _))) = x == y
isArrayIn _ _ = False

optimiseLoopBySwitching :: Constraints rep inner => OptimiseLoop rep
optimiseLoopBySwitching :: forall rep (inner :: * -> *).
Constraints rep inner =>
OptimiseLoop rep
optimiseLoopBySwitching (Pat [PatElem (LetDec rep)]
pes) [(FParam rep, SubExp)]
merge (Body BodyDec rep
_ Stms rep
body_stms Result
body_res) = do
  ((Pat LParamMem
pat', [(Param FParamMem, SubExp)]
merge', Body rep
body'), Stms rep
outer_stms) <- forall (m :: * -> *) somerep rep a.
(MonadFreshNames m, HasScope somerep m, SameScope somerep rep) =>
Builder rep a -> m (a, Stms rep)
runBuilder forall a b. (a -> b) -> a -> b
$ do
    ((Map VName VName
buffered, Stms rep
body_stms'), ([[PatElem LParamMem]]
pes', [[(Param FParamMem, SubExp)]]
merge', [Result]
body_res')) <-
      forall (p :: * -> * -> *) b c a.
Bifunctor p =>
(b -> c) -> p a b -> p a c
second forall a b c. [(a, b, c)] -> ([a], [b], [c])
unzip3 forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> forall (m :: * -> *) acc x y.
Monad m =>
(acc -> x -> m (acc, y)) -> acc -> [x] -> m (acc, [y])
mapAccumLM (Map VName VName, Stms rep)
-> (PatElem LParamMem, (Param FParamMem, SubExp), SubExpRes)
-> BuilderT
     rep
     (State VNameSource)
     ((Map VName VName, Stms rep),
      ([PatElem LParamMem], [(Param FParamMem, SubExp)], Result))
check (forall a. Monoid a => a
mempty, Stms rep
body_stms) (forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 [PatElem (LetDec rep)]
pes [(FParam rep, SubExp)]
merge Result
body_res)
    [(Param FParamMem, SubExp)]
merge'' <- forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (forall {m :: * -> *} {inner :: * -> *} {d}.
(LParamInfo (Rep m) ~ LParamMem, LetDec (Rep m) ~ LParamMem,
 FParamInfo (Rep m) ~ FParamMem, BranchType (Rep m) ~ BranchTypeMem,
 RetType (Rep m) ~ RetTypeMem, OpC (Rep m) ~ MemOp inner,
 MonadBuilder m, RephraseOp inner, HasLetDecMem (LetDec (Rep m)),
 OpReturns (inner (Rep m))) =>
Map VName VName
-> (Param (MemInfo d Uniqueness MemBind), SubExp)
-> m (Param (MemInfo d Uniqueness MemBind), SubExp)
maybeCopyInitial Map VName VName
buffered) forall a b. (a -> b) -> a -> b
$ forall a. Monoid a => [a] -> a
mconcat [[(Param FParamMem, SubExp)]]
merge'
    forall (f :: * -> *) a. Applicative f => a -> f a
pure (forall dec. [PatElem dec] -> Pat dec
Pat forall a b. (a -> b) -> a -> b
$ forall a. Monoid a => [a] -> a
mconcat [[PatElem LParamMem]]
pes', [(Param FParamMem, SubExp)]
merge'', forall rep. BodyDec rep -> Stms rep -> Result -> Body rep
Body () Stms rep
body_stms' forall a b. (a -> b) -> a -> b
$ forall a. Monoid a => [a] -> a
mconcat [Result]
body_res')
  forall (f :: * -> *) a. Applicative f => a -> f a
pure (Stms rep
outer_stms, Pat LParamMem
pat', [(Param FParamMem, SubExp)]
merge', Body rep
body')
  where
    merge_bound :: Names
merge_bound = [VName] -> Names
namesFromList forall a b. (a -> b) -> a -> b
$ forall a b. (a -> b) -> [a] -> [b]
map (forall dec. Param dec -> VName
paramName forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall a b. (a, b) -> a
fst) [(FParam rep, SubExp)]
merge

    check :: (Map VName VName, Stms rep)
-> (PatElem LParamMem, (Param FParamMem, SubExp), SubExpRes)
-> BuilderT
     rep
     (State VNameSource)
     ((Map VName VName, Stms rep),
      ([PatElem LParamMem], [(Param FParamMem, SubExp)], Result))
check (Map VName VName
buffered, Stms rep
body_stms') (PatElem LParamMem
pe, (Param FParamMem
param, SubExp
arg), SubExpRes
res)
      | Mem Space
space <- forall dec. Typed dec => Param dec -> Type
paramType Param FParamMem
param,
        Var VName
arg_v <- SubExp
arg,
        -- XXX: what happens if there are multiple arrays in the same
        -- memory block?
        [Param FParamMem
arr_param] <- forall a. (a -> Bool) -> [a] -> [a]
filter (VName -> Param FParamMem -> Bool
isArrayIn (forall dec. Param dec -> VName
paramName Param FParamMem
param)) forall a b. (a -> b) -> a -> b
$ forall a b. (a -> b) -> [a] -> [b]
map forall a b. (a, b) -> a
fst [(FParam rep, SubExp)]
merge,
        MemArray PrimType
pt ShapeBase SubExp
_ Uniqueness
_ (ArrayIn VName
_ IxFun
ixfun) <- forall dec. Param dec -> dec
paramDec Param FParamMem
arr_param,
        Bool -> Bool
not forall a b. (a -> b) -> a -> b
$ Names
merge_bound Names -> Names -> Bool
`namesIntersect` forall a. FreeIn a => a -> Names
freeIn (forall num. IxFun num -> Shape num
IxFun.base IxFun
ixfun),
        Var VName
res_v <- SubExpRes -> SubExp
resSubExp SubExpRes
res,
        Just (Stm rep
res_v_alloc, Stms rep
body_stms'') <- forall rep (inner :: * -> *).
Constraints rep inner =>
Names -> VName -> Stms rep -> Maybe (Stm rep, Stms rep)
extractAllocOf Names
merge_bound VName
res_v Stms rep
body_stms' = do
          SubExp
num_bytes <-
            forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m SubExp
letSubExp String
"num_bytes" forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
toExp (forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
product forall a b. (a -> b) -> a -> b
$ forall a. Num a => PrimType -> a
primByteSize PrimType
pt forall a. a -> [a] -> [a]
: forall num. IxFun num -> Shape num
IxFun.base IxFun
ixfun)
          VName
arr_mem_in <-
            forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m VName
letExp (VName -> String
baseString VName
arg_v forall a. Semigroup a => a -> a -> a
<> String
"_in") forall a b. (a -> b) -> a -> b
$ forall rep. Op rep -> Exp rep
Op forall a b. (a -> b) -> a -> b
$ forall (inner :: * -> *) rep. SubExp -> Space -> MemOp inner rep
Alloc SubExp
num_bytes Space
space
          PatElem LParamMem
pe_unused <-
            forall dec. VName -> dec -> PatElem dec
PatElem
              forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName (VName -> String
baseString (forall dec. PatElem dec -> VName
patElemName PatElem LParamMem
pe) forall a. Semigroup a => a -> a -> a
<> String
"_unused")
              forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> forall (f :: * -> *) a. Applicative f => a -> f a
pure (forall d u ret. Space -> MemInfo d u ret
MemMem Space
space)
          Param FParamMem
param_out <-
            forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam (VName -> String
baseString (forall dec. Param dec -> VName
paramName Param FParamMem
param) forall a. Semigroup a => a -> a -> a
<> String
"_out") (forall d u ret. Space -> MemInfo d u ret
MemMem Space
space)
          forall (m :: * -> *). MonadBuilder m => Stm (Rep m) -> m ()
addStm Stm rep
res_v_alloc
          forall (f :: * -> *) a. Applicative f => a -> f a
pure
            ( ( forall k a. Ord k => k -> a -> Map k a -> Map k a
M.insert (forall dec. Param dec -> VName
paramName Param FParamMem
param) VName
arr_mem_in Map VName VName
buffered,
                forall a. Substitute a => Map VName VName -> a -> a
substituteNames (forall k a. k -> a -> Map k a
M.singleton VName
res_v (forall dec. Param dec -> VName
paramName Param FParamMem
param_out)) Stms rep
body_stms''
              ),
              ( [PatElem LParamMem
pe, PatElem LParamMem
pe_unused],
                [(Param FParamMem
param, VName -> SubExp
Var VName
arr_mem_in), (Param FParamMem
param_out, SubExpRes -> SubExp
resSubExp SubExpRes
res)],
                [ SubExpRes
res {resSubExp :: SubExp
resSubExp = VName -> SubExp
Var forall a b. (a -> b) -> a -> b
$ forall dec. Param dec -> VName
paramName Param FParamMem
param_out},
                  SubExp -> SubExpRes
subExpRes forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var forall a b. (a -> b) -> a -> b
$ forall dec. Param dec -> VName
paramName Param FParamMem
param
                ]
              )
            )
      | Bool
otherwise =
          forall (f :: * -> *) a. Applicative f => a -> f a
pure
            ( (Map VName VName
buffered, Stms rep
body_stms'),
              ([PatElem LParamMem
pe], [(Param FParamMem
param, SubExp
arg)], [SubExpRes
res])
            )

    -- Given the map of memory blocks that have been replaced (old block
    -- name to new block name) and a merge parameter paired with its
    -- initial value: if the parameter is an array residing in a replaced
    -- block and the initial value is a variable, bind a copy of the
    -- initial value in the replacement block and use that copy as the
    -- new initial value.  Any other parameter is returned unchanged.
    maybeCopyInitial replaced (p@(Param _ _ (MemArray _ _ _ (ArrayIn p_mem _))), Var init_v)
      | Just new_mem <- M.lookup p_mem replaced = do
          init_info <- lookupMemInfo init_v
          -- We need to make this parameter unique to avoid invalid
          -- hoisting (see #1533), because we are invalidating the
          -- underlying memory.  This is done whether or not a copy is
          -- inserted below.
          let p' = fmap mkUnique p
          case init_info of
            MemArray pt shape u (ArrayIn _ init_ixfun) -> do
              copy_v <- newVName (baseString init_v <> "_dbcopy")
              -- The copy keeps the index function of the initial value,
              -- but lives in the replacement memory block.
              let copy_dec = MemArray pt shape u (ArrayIn new_mem init_ixfun)
              letBind (Pat [PatElem copy_v copy_dec]) (BasicOp (Copy init_v))
              pure (p', Var copy_v)
            _ -> pure (p', Var init_v)
    maybeCopyInitial _ unchanged = pure unchanged

    -- Set the uniqueness of an array annotation to 'Unique'; every other
    -- kind of annotation is returned as it is.
    mkUnique info = case info of
      MemArray pt shp _ ret -> MemArray pt shp Unique ret
      _ -> info

-- | Double-buffer a loop by copying results into dedicated buffers.
--
-- Works in three steps: decide for each merge parameter how it should
-- be buffered ('doubleBufferMergeParams'), produce the buffer
-- allocations and copies of the initial values together with the
-- updated merge parameters ('allocStms'), and finally rewrite the loop
-- body so that buffered results are copied ('doubleBufferResult').
--
-- The result holds the generated allocation/copy statements, the
-- (unchanged) loop pattern, the new merge parameters and the new body.
optimiseLoopByCopying :: Constraints rep inner => OptimiseLoop rep
optimiseLoopByCopying pat merge body = do
  let params = map fst merge
  -- Figure out which of the merge variables should be double-buffered.
  plan <-
    doubleBufferMergeParams
      (zip params $ bodyResult body)
      (boundInBody body)
  -- Allocate the buffers and copy the initial values.
  (merge', prelude) <- allocStms merge plan
  -- The body is modified to copy the buffered result arrays.
  pure
    ( stmsFromList prelude,
      pat,
      merge',
      doubleBufferResult params plan body
    )

-- | The double-buffering decision taken for a single merge parameter.
-- The booleans indicate whether we should also play with the initial
-- merge values.
data DoubleBuffer
  = -- | A memory parameter that gets a fresh buffer: the name of the
    -- buffer, its size in bytes, and the space it is allocated in.
    BufferAlloc VName (PrimExp VName) Space Bool
  | -- | An array parameter whose result must be copied.  First name is
    -- the memory block to copy to, second is the name of the array
    -- copy.  The index function is the one used for the copy.
    BufferCopy VName IxFun VName Bool
  | -- | The parameter is left alone.
    NoBuffer
  deriving (Show)

-- | Decide, for every merge parameter (paired with the corresponding
-- result of the loop body), whether and how it should be
-- double-buffered.  The second argument is the set of names bound in
-- the loop body.  The resulting list has one entry per input pair, in
-- the same order.
doubleBufferMergeParams ::
  MonadFreshNames m =>
  [(Param FParamMem, SubExpRes)] ->
  Names ->
  m [DoubleBuffer]
doubleBufferMergeParams ctx_and_res bound_in_loop =
  -- The state maps each memory parameter chosen for double buffering
  -- to its fresh buffer name and flag, so that array parameters
  -- visited afterwards can find out whether their memory was buffered.
  evalStateT (traverse decide ctx_and_res) M.empty
  where
    merge_names = map (paramName . fst) ctx_and_res

    -- A name is loop-variant if it is bound inside the loop body or is
    -- itself a merge parameter.
    loopVariant :: VName -> Bool
    loopVariant v = v `nameIn` bound_in_loop || v `elem` merge_names

    -- Try to express a dimension by something that is not loop-variant.
    -- Constants and variables that are not merge parameters are kept
    -- (flag 'True').  A merge parameter is replaced by its loop result
    -- when that result is a constant or a variable that is not
    -- loop-variant (flag 'False'); otherwise there is no answer.
    loopInvariantSize :: SubExp -> Maybe (SubExp, Bool)
    loopInvariantSize se@(Constant _) = Just (se, True)
    loopInvariantSize se@(Var v) =
      case resSubExp . snd <$> find ((== v) . paramName . fst) ctx_and_res of
        Nothing -> Just (se, True)
        Just (Constant val) -> Just (Constant val, False)
        Just (Var v')
          | not (loopVariant v') -> Just (Var v', False)
        Just _ -> Nothing

    -- The byte size of the first array merge parameter that has a
    -- direct index function, dimensions acceptable to
    -- 'loopInvariantSize', and resides in the given memory block.  The
    -- flag is the disjunction of the flags of the dimensions.
    sizeForMem :: VName -> Maybe (PrimExp VName, Bool)
    sizeForMem mem =
      maybeHead $ mapMaybe (arrayInMem . paramDec . fst) ctx_and_res
      where
        arrayInMem (MemArray pt shape _ (ArrayIn arraymem ixfun))
          | IxFun.isDirect ixfun,
            Just (dims, flags) <- mapAndUnzipM loopInvariantSize (shapeDims shape),
            mem == arraymem =
              Just
                ( arraySizeInBytesExp (Array pt (Shape dims) NoUniqueness),
                  or flags
                )
        arrayInMem _ = Nothing

    decide (fparam, res) = case paramType fparam of
      Mem space
        | Just (size, flag) <- sizeForMem (paramName fparam),
          Var res_v <- resSubExp res,
          res_v `nameIn` bound_in_loop -> do
            -- Let us double buffer this!
            bufname <- lift (newVName "double_buffer_mem")
            modify (M.insert (paramName fparam) (bufname, flag))
            pure (BufferAlloc bufname size space flag)
      Array {}
        | MemArray _ _ _ (ArrayIn mem ixfun) <- paramDec fparam -> do
            -- Only arrays whose memory parameter was buffered above
            -- need a copy.
            known <- gets (M.lookup mem)
            case known of
              Nothing -> pure NoBuffer
              Just (bufname, flag) -> do
                copyname <- lift (newVName "double_buffer_array")
                pure (BufferCopy bufname ixfun copyname flag)
      _ -> pure NoBuffer

-- | Given the merge parameters (with their initial values) and the
-- buffering decision for each of them, produce the updated merge
-- parameters together with the statements that allocate the buffers
-- and copy initial values.
allocStms ::
  Constraints rep inner =>
  [(FParam rep, SubExp)] ->
  [DoubleBuffer] ->
  DoubleBufferM rep ([(FParam rep, SubExp)], [Stm rep])
allocStms merge buffered = runWriterT $ zipWithM allocation merge buffered
  where
    -- A buffered memory parameter: the allocation is always emitted.
    -- When the flag is set the parameter keeps its name and attributes
    -- but starts out in the fresh buffer; otherwise the parameter and
    -- its initial value are left as they were.
    allocation old@(Param attrs pname _, _) (BufferAlloc buf size space flag) = do
      alloc_stms <- lift $
        runBuilder_ $ do
          size_se <- toSubExp "double_buffer_size" size
          letBindNames [buf] (Op (Alloc size_se space))
      tell (stmsToList alloc_stms)
      pure $
        if flag
          then (Param attrs pname (MemMem space), Var buf)
          else old
    -- A buffered array whose flag is set and whose initial value is a
    -- variable: copy the initial value into the buffer, reusing the
    -- index function of the initial value, and start from the copy.
    allocation (param, Var v) (BufferCopy mem _ _ True) = do
      v_copy <- lift $ newVName $ baseString v ++ "_double_buffer_copy"
      (_, v_ixfun) <- lift $ lookupArraySummary v
      let param_t = paramType param
          copy_dec =
            MemArray (elemType param_t) (arrayShape param_t) NoUniqueness $
              ArrayIn mem v_ixfun
      tell [Let (Pat [PatElem v_copy copy_dec]) (defAux ()) (BasicOp (Copy v))]
      -- It is important that we treat this as a consumption, to
      -- avoid the Copy from being hoisted out of any enclosing
      -- loops.  Since we re-use (=overwrite) memory in the loop,
      -- the copy is critical for initialisation.  See issue #816.
      pure (uniqueMemInfo <$> param, Var v_copy)
    -- Everything else is passed through untouched.
    allocation (param, se) _ =
      pure (param, se)

    -- Mark an array annotation as unique; leave other annotations alone.
    uniqueMemInfo (MemArray pt pshape _ ret) = MemArray pt pshape Unique ret
    uniqueMemInfo info = info

-- | Rewrite the result of a loop body according to the buffering
-- decisions: a buffered memory result becomes the buffer itself, and a
-- buffered array result becomes a copy (appended to the body) that
-- resides in the buffer.  Results without a decision are kept.
doubleBufferResult ::
  Constraints rep inner =>
  [FParam rep] ->
  [DoubleBuffer] ->
  Body rep ->
  Body rep
doubleBufferResult valparams buffered (Body _ stms res) =
  Body () (stms <> stmsFromList (catMaybes copies)) (ctx_res ++ val_res')
  where
    -- Only the trailing results, one per value parameter, are rewritten.
    (ctx_res, val_res) = splitAt (length res - length valparams) res
    (copies, val_res') = unzip (zipWith3 rewrite valparams buffered val_res)

    rewrite _ (BufferAlloc bufname _ _ _) se =
      (Nothing, se {resSubExp = Var bufname})
    rewrite fparam (BufferCopy bufname ixfun copyname _) (SubExpRes cs (Var v)) =
      (Just copystm, SubExpRes cs (Var copyname))
      where
        -- To construct the copy we will need to figure out its type
        -- based on the type of the function parameter.
        t = resultType (paramType fparam)
        summary =
          MemArray (elemType t) (arrayShape t) NoUniqueness (ArrayIn bufname ixfun)
        copystm =
          Let (Pat [PatElem copyname summary]) (defAux ()) (BasicOp (Copy v))
    rewrite _ _ se =
      (Nothing, se)

    -- Maps the name of each parameter to the corresponding body result.
    parammap = M.fromList $ zip (map paramName valparams) (map resSubExp res)

    -- The type with every dimension that names a parameter replaced by
    -- the result computed for that parameter.
    resultType t = t `setArrayDims` map substitute (arrayDims t)

    substitute :: SubExp -> SubExp
    substitute se@(Var v) = M.findWithDefault se v parammap
    substitute se = se