{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}

-- | The simplification engine is only willing to hoist allocations
-- out of loops if the memory block resulting from the allocation is
-- dead at the end of the loop.  If it is not, we may cause data
-- hazards.
--
-- This pass tries to rewrite loops with memory parameters.
-- Specifically, it takes loops of this form:
--
-- @
-- loop {..., A_mem, ..., A, ...} ... do {
--   ...
--   let A_out_mem = alloc(...) -- stores A_out
--   in {..., A_out_mem, ..., A_out, ...}
-- }
-- @
--
-- and turns them into
--
-- @
-- let A_in_mem = alloc(...)
-- let A_out_mem = alloc(...)
-- let A_in = copy A -- in A_in_mem
-- loop {..., A_in_mem, A_out_mem, ..., A=A_in, ...} ... do {
--   ...
--   in {..., A_out_mem, A_mem, ..., A_out, ...}
-- }
-- @
--
-- The result is essentially "pointer swapping" between the two
-- initial memory blocks @A_mem@ and @A_out_mem@.  The invariant is that the
-- array is always stored in the "first" memory block at the beginning
-- of the loop (and also in the final result).  We do need to add an
-- extra element to the pattern, however.  The initial copy of @A@
-- could be elided if @A@ is unique (thus @A_in_mem=A_mem@).  This is
-- because only then is it safe to use @A_mem@ to store loop results.
-- We don't currently do this.
--
-- Unfortunately, not all loops fit the pattern above.  In particular,
-- a nested loop that has been transformed as such does not!
-- Therefore we also have another double buffering strategy, that
-- turns
--
-- @
-- loop {..., A_mem, ..., A, ...} ... do {
--   ...
--   let A_out_mem = alloc(...)
--   -- A in A_out_mem
--   in {..., A_out_mem, ..., A, ...}
-- }
-- @
--
-- into
--
-- @
-- let A_res_mem = alloc(...)
-- loop {..., A_mem, ..., A, ...} ... do {
--   ...
--   let A_out_mem = alloc(...)
--   -- A in A_out_mem
--   let A' = copy A
--   -- A' in A_res_mem
--   in {..., A_res_mem, ..., A, ...}
-- }
-- @
--
-- The allocation of A_out_mem can then be hoisted out because it is
-- dead at the end of the loop.  This always works as long as
-- A_out_mem has a loop-invariant allocation size, but requires a copy
-- per iteration (and an initial one, elided above).
module Futhark.Optimise.DoubleBuffer (doubleBufferGPU, doubleBufferMC) where

import Control.Monad.Reader
import Control.Monad.State
import Control.Monad.Writer
import Data.Bifunctor
import Data.List (find)
import qualified Data.Map.Strict as M
import Data.Maybe
import Futhark.Construct
import Futhark.IR.GPUMem as GPU
import Futhark.IR.MCMem as MC
import qualified Futhark.IR.Mem.IxFun as IxFun
import Futhark.Pass
import Futhark.Pass.ExplicitAllocations (arraySizeInBytesExp)
import Futhark.Pass.ExplicitAllocations.GPU ()
import Futhark.Transform.Substitute
import Futhark.Util (mapAccumLM, maybeHead)

-- | The pass for GPU kernels.  This is 'doubleBuffer' instantiated
-- with 'optimiseGPUOp' as the handler for operations.
doubleBufferGPU :: Pass GPUMem GPUMem
doubleBufferGPU = doubleBuffer optimiseGPUOp

-- | The pass for multicore.  This is 'doubleBuffer' instantiated
-- with 'optimiseMCOp' as the handler for operations.
doubleBufferMC :: Pass MCMem MCMem
doubleBufferMC = doubleBuffer optimiseMCOp

-- | The double buffering pass definition.  The argument is the
-- representation-specific handler that is applied to every operation
-- encountered during the traversal.
doubleBuffer :: Mem rep inner => OptimiseOp rep -> Pass rep rep
doubleBuffer onOp =
  Pass
    { passName = "Double buffer",
      passDescription = "Perform double buffering for merge parameters of sequential loops.",
      passFunction = intraproceduralTransformation optimise
    }
  where
    -- Traverse a sequence of statements in the given scope.  The
    -- traversal runs in 'DoubleBufferM'; its name source is taken from
    -- and handed back to the surrounding monad via 'modifyNameSource'.
    optimise scope stms = modifyNameSource $ \src ->
      let m =
            runDoubleBufferM $ localScope scope $ optimiseStms $ stmsToList stms
       in runState (runReaderT m env) src

    -- Initial environment: an empty scope, a loop handler that changes
    -- nothing, and the operation handler supplied by the caller.
    env = Env mempty doNotTouchLoop onOp

    -- Loop handler that produces no extra statements and returns the
    -- pattern, merge parameters and body exactly as given.
    doNotTouchLoop pat merge body = pure (mempty, pat, merge, body)

-- | A handler for loops.  It is given the pattern, the merge list
-- (parameters paired with 'SubExp's) and the body of a loop, and
-- returns a sequence of statements together with a pattern, merge
-- list and body.  'optimiseStm' places the returned statements in
-- front of the loop rebuilt from the other three components.
type OptimiseLoop rep =
  Pat (LetDec rep) ->
  [(FParam rep, SubExp)] ->
  Body rep ->
  DoubleBufferM
    rep
    ( Stms rep,
      Pat (LetDec rep),
      [(FParam rep, SubExp)],
      Body rep
    )

-- | A handler for the representation-specific operations ('Op') that
-- are encountered while traversing expressions.
type OptimiseOp rep =
  Op rep -> DoubleBufferM rep (Op rep)

-- | The read-only environment carried by 'DoubleBufferM'.
data Env rep = Env
  { -- | The scope accumulated so far; extended via 'localScope'.
    envScope :: Scope rep,
    -- | The handler that 'optimiseStm' applies to loops.
    envOptimiseLoop :: OptimiseLoop rep,
    -- | The handler that 'optimiseStm' applies to operations.
    envOptimiseOp :: OptimiseOp rep
  }

-- | The monad in which the traversal runs: a reader over 'Env' on top
-- of a state monad holding the 'VNameSource' used for fresh names.
newtype DoubleBufferM rep a = DoubleBufferM
  { runDoubleBufferM :: ReaderT (Env rep) (State VNameSource) a
  }
  deriving (Functor, Applicative, Monad, MonadReader (Env rep), MonadFreshNames)

-- The current scope is the one stored in the environment.
instance ASTRep rep => HasScope rep (DoubleBufferM rep) where
  askScope = asks envScope

-- Extending the scope appends the new bindings to the scope stored in
-- the environment for the duration of the given action.
instance ASTRep rep => LocalScope rep (DoubleBufferM rep) where
  localScope scope = local $ \env -> env {envScope = envScope env <> scope}

-- | Traverse the statements of a body with 'optimiseStms' and put the
-- result back into the body; the other parts of the body are
-- untouched.
optimiseBody :: ASTRep rep => Body rep -> DoubleBufferM rep (Body rep)
optimiseBody body = do
  stms' <- optimiseStms $ stmsToList $ bodyStms body
  pure $ body {bodyStms = stms'}

-- | Traverse a list of statements from first to last.  Each statement
-- may expand to several statements; whatever it expands to is brought
-- into scope while the remaining statements are processed.
optimiseStms :: ASTRep rep => [Stm rep] -> DoubleBufferM rep (Stms rep)
optimiseStms [] = pure mempty
optimiseStms (e : es) = do
  e_es <- optimiseStm e
  es' <- localScope (castScope $ scopeOf e_es) $ optimiseStms es
  pure $ e_es <> es'

-- | Traverse a single statement, producing the statements that should
-- replace it.
--
-- For a loop, the body is traversed first (with the names bound by
-- the loop form and the merge parameters in scope), and then the loop
-- handler from the environment ('envOptimiseLoop') is applied.  The
-- statements returned by the handler are placed in front of the
-- rebuilt loop.
--
-- For any other expression, nested bodies are traversed with
-- 'optimiseBody' and operations are passed to the operation handler
-- from the environment ('envOptimiseOp').
optimiseStm :: forall rep. ASTRep rep => Stm rep -> DoubleBufferM rep (Stms rep)
optimiseStm (Let pat aux (DoLoop merge form body)) = do
  body' <-
    localScope (scopeOf form <> scopeOfFParams (map fst merge)) $
      optimiseBody body
  opt_loop <- asks envOptimiseLoop
  (stms, pat', merge', body'') <- opt_loop pat merge body'
  pure $ stms <> oneStm (Let pat' aux $ DoLoop merge' form body'')
optimiseStm (Let pat aux e) = do
  onOp <- asks envOptimiseOp
  oneStm . Let pat aux <$> mapExpM (optimise onOp) e
  where
    -- A mapper that only touches bodies and operations.  The type
    -- annotation refers to the @rep@ bound by the @forall@ in the
    -- signature of 'optimiseStm'.
    optimise onOp =
      identityMapper
        { mapOnBody = \_ x ->
            optimiseBody x :: DoubleBufferM rep (Body rep),
          mapOnOp = onOp
        }

-- | The operation handler for 'GPUMem'.  The lambdas and kernel body
-- of a 'SegOp' are traversed with 'optimiseLoop' installed as the
-- loop handler.  Any other operation is returned unchanged.
optimiseGPUOp :: OptimiseOp GPUMem
optimiseGPUOp (Inner (SegOp op)) =
  local inSegOp $ Inner . SegOp <$> mapSegOpM mapper op
  where
    -- Only lambdas and kernel bodies are rewritten; everything else
    -- in the SegOp is left to the identity mapper.
    mapper =
      identitySegOpMapper
        { mapOnSegOpLambda = optimiseLambda,
          mapOnSegOpBody = optimiseKernelBody
        }
    -- Replace the loop handler for everything inside the SegOp.
    inSegOp env = env {envOptimiseLoop = optimiseLoop}
optimiseGPUOp op = pure op

-- | The operation handler for 'MCMem'.  Both components of a 'ParOp'
-- (the optional first 'SegOp' and the second one) have their lambdas
-- and kernel bodies traversed with 'optimiseLoop' installed as the
-- loop handler.  Any other operation is returned unchanged.
optimiseMCOp :: OptimiseOp MCMem
optimiseMCOp (Inner (ParOp par_op op)) =
  local inSegOp $
    Inner
      <$> (ParOp <$> traverse (mapSegOpM mapper) par_op <*> mapSegOpM mapper op)
  where
    -- Only lambdas and kernel bodies are rewritten; everything else
    -- in the SegOp is left to the identity mapper.
    mapper =
      identitySegOpMapper
        { mapOnSegOpLambda = optimiseLambda,
          mapOnSegOpBody = optimiseKernelBody
        }
    -- Replace the loop handler for everything inside the ParOp.
    inSegOp env = env {envOptimiseLoop = optimiseLoop}
optimiseMCOp op = pure op

-- | Traverse the statements of a kernel body with 'optimiseStms' and
-- put the result back; the other parts of the kernel body are
-- untouched.
optimiseKernelBody ::
  ASTRep rep =>
  KernelBody rep ->
  DoubleBufferM rep (KernelBody rep)
optimiseKernelBody kbody = do
  stms' <- optimiseStms $ stmsToList $ kernelBodyStms kbody
  pure $ kbody {kernelBodyStms = stms'}

-- | Traverse the body of a lambda with 'optimiseBody', with the scope
-- of the lambda added to the current scope while doing so.
optimiseLambda ::
  ASTRep rep =>
  Lambda rep ->
  DoubleBufferM rep (Lambda rep)
optimiseLambda lam = do
  body <- localScope (castScope $ scopeOf lam) $ optimiseBody $ lambdaBody lam
  pure lam {lambdaBody = body}

-- | The constraints shared by the loop-rewriting functions in this
-- module: a memory representation @rep@ with inner operation type
-- @inner@ that supports 'BuilderOps', has unit expression and body
-- decorations, and uses 'LetDecMem' as its let-binding decoration.
type Constraints rep inner =
  ( Mem rep inner,
    BuilderOps rep,
    ExpDec rep ~ (),
    BodyDec rep ~ (),
    LetDec rep ~ LetDecMem
  )

-- | @extractAllocOf bound needle stms@ searches @stms@ from the front
-- for a statement that binds exactly @needle@ to an 'Alloc' whose
-- size is invariant with respect to @bound@.  On success, the result
-- is that statement together with the statements with the allocation
-- removed (relative order otherwise preserved).
--
-- A size is considered invariant if it is a constant, or a variable
-- that is not in the set of bound names.  The names bound by every
-- statement that is skipped over are added to that set, so a size
-- computed earlier in @stms@ disqualifies the allocation.
--
-- The result is 'Nothing' when 'stmsHead' yields 'Nothing' before a
-- suitable allocation is found (presumably when the statements are
-- exhausted).
extractAllocOf :: Constraints rep inner => Names -> VName -> Stms rep -> Maybe (Stm rep, Stms rep)
extractAllocOf bound needle stms = do
  (stm, stms') <- stmsHead stms
  case stm of
    Let (Pat [pe]) _ (Op (Alloc size _))
      | patElemName pe == needle,
        invariant size ->
          Just (stm, stms')
    _ ->
      -- Not the allocation we want: keep the statement in front of
      -- whatever remains, and treat its names as bound from here on.
      let bound' = namesFromList (patNames (stmPat stm)) <> bound
       in second (oneStm stm <>) <$> extractAllocOf bound' needle stms'
  where
    invariant Constant {} = True
    invariant (Var v) = not $ v `nameIn` bound

-- | Double-buffer a loop by applying both strategies in sequence:
-- first 'optimiseLoopBySwitching', then 'optimiseLoopByCopying' on
-- its result.  The second step is run with the statements produced by
-- the first in scope.  The statements to be inserted before the loop
-- are those of the first step followed by those of the second.
optimiseLoop :: Constraints rep inner => OptimiseLoop rep
optimiseLoop pat merge body = do
  (outer_stms_1, pat', merge', body') <-
    optimiseLoopBySwitching pat merge body
  (outer_stms_2, pat'', merge'', body'') <-
    inScopeOf outer_stms_1 $ optimiseLoopByCopying pat' merge' body'
  pure (outer_stms_1 <> outer_stms_2, pat'', merge'', body'')

-- | @isArrayIn x param@ is 'True' exactly when @param@ is an array
-- parameter ('MemArray') whose 'ArrayIn' memory block is named @x@.
-- Any other kind of parameter gives 'False'.
isArrayIn :: VName -> Param FParamMem -> Bool
isArrayIn x (Param _ _ (MemArray _ _ _ (ArrayIn y _))) = x == y
isArrayIn _ _ = False

-- | Double buffering by switching between two memory blocks.
--
-- Every (pattern element, merge parameter, body result) triple is
-- inspected by @check@.  For a memory merge parameter that qualifies,
-- the allocation of the result memory is moved out of the loop body,
-- a fresh input block is allocated before the loop, and an additional
-- merge parameter / pattern element / result is introduced so that
-- the two blocks are exchanged on every iteration.  Afterwards,
-- @maybeCopyInitial@ copies the initial value of any array living in
-- a rewritten memory block into the fresh input block.
--
-- The result is the statements to insert before the loop, followed by
-- the new pattern, merge parameters, and body.
optimiseLoopBySwitching :: Constraints rep inner => OptimiseLoop rep
optimiseLoopBySwitching (Pat pes) merge (Body _ body_stms body_res) = do
  ((pat', merge', body'), outer_stms) <- runBuilder $ do
    -- Each triple expands to one or two entries, hence the lists of
    -- lists that are flattened below.
    ((buffered, body_stms'), (pes', merge_groups, body_res')) <-
      second unzip3 <$> mapAccumLM check (mempty, body_stms) (zip3 pes merge body_res)
    merge'' <- mapM (maybeCopyInitial buffered) $ mconcat merge_groups
    pure (Pat $ mconcat pes', merge'', Body () body_stms' $ mconcat body_res')
  pure (outer_stms, pat', merge', body')
  where
    -- Names of all merge parameters of the original loop.
    merge_bound = namesFromList $ map (paramName . fst) merge

    -- The accumulator is a map from rewritten memory parameters to
    -- their freshly allocated input block, and the loop body
    -- statements that remain after extracting allocations.
    check (buffered, body_stms') (pe, (param, arg), res)
      | Mem space <- paramType param,
        Var arg_v <- arg,
        -- XXX: what happens if there are multiple arrays in the same
        -- memory block?
        [arr_param] <- filter (isArrayIn (paramName param)) $ map fst merge,
        MemArray pt _ _ (ArrayIn _ ixfun) <- paramDec arr_param,
        -- The size computed below is placed before the loop, so it
        -- must not mention any merge parameter.
        not $ merge_bound `namesIntersect` freeIn (IxFun.base ixfun),
        Var res_v <- resSubExp res,
        Just (res_v_alloc, body_stms'') <- extractAllocOf merge_bound res_v body_stms' = do
          num_bytes <-
            letSubExp "num_bytes" =<< toExp (product $ primByteSize pt : IxFun.base ixfun)
          arr_mem_in <-
            letExp (baseString arg_v <> "_in") $ Op $ Alloc num_bytes space
          pe_unused <-
            PatElem
              <$> newVName (baseString (patElemName pe) <> "_unused")
              <*> pure (MemMem space)
          param_out <-
            newParam (baseString (paramName param) <> "_out") (MemMem space)
          -- Re-emit the allocation that was removed from the body as
          -- one of the statements produced by the enclosing builder.
          addStm res_v_alloc
          pure
            ( ( M.insert (paramName param) arr_mem_in buffered,
                -- Inside the body, the formerly allocated block is
                -- now the new merge parameter.
                substituteNames (M.singleton res_v (paramName param_out)) body_stms''
              ),
              ( [pe, pe_unused],
                [(param, Var arr_mem_in), (param_out, resSubExp res)],
                -- Swap: the output block becomes the next iteration's
                -- input block, and vice versa.
                [ res {resSubExp = Var $ paramName param_out},
                  subExpRes $ Var $ paramName param
                ]
              )
            )
      | otherwise =
          pure
            ( (buffered, body_stms'),
              ([pe], [(param, arg)], [res])
            )

    -- For an array merge parameter whose memory block was rewritten
    -- by 'check' and whose initial value is a variable, make the
    -- initial value a copy residing in the new input block.
    maybeCopyInitial buffered (param@(Param _ _ (MemArray _ _ _ (ArrayIn mem _))), Var arg)
      | Just mem' <- mem `M.lookup` buffered = do
          arg_info <- lookupMemInfo arg
          case arg_info of
            MemArray pt shape u (ArrayIn _ arg_ixfun) -> do
              arg_copy <- newVName (baseString arg <> "_dbcopy")
              letBind (Pat [PatElem arg_copy $ MemArray pt shape u $ ArrayIn mem' arg_ixfun]) $
                BasicOp $
                  Copy arg
              -- We need to make this parameter unique to avoid invalid
              -- hoisting (see #1533), because we are invalidating the
              -- underlying memory.
              pure (fmap mkUnique param, Var arg_copy)
            _ -> pure (fmap mkUnique param, Var arg)
    maybeCopyInitial _ (param, arg) = pure (param, arg)

    -- Set the uniqueness of an array to 'Unique'; anything else is
    -- returned as is.
    mkUnique (MemArray bt shape _ ret) = MemArray bt shape Unique ret
    mkUnique x = x

-- | Double buffering by copying.  The loop pattern is returned
-- unchanged; the merge parameters and body are rewritten, and the
-- statements produced by 'allocStms' are returned for insertion
-- before the loop.
optimiseLoopByCopying :: Constraints rep inner => OptimiseLoop rep
optimiseLoopByCopying pat merge body = do
  -- We start out by figuring out which of the merge variables should
  -- be double-buffered.
  buffered <-
    doubleBufferMergeParams
      (zip (map fst merge) (bodyResult body))
      (boundInBody body)
  -- Then create the allocations of the buffers and copies of the
  -- initial values.
  (merge', allocs) <- allocStms merge buffered
  -- Modify the loop body to copy buffered result arrays.
  let body' = doubleBufferResult (map fst merge) buffered body
  pure (stmsFromList allocs, pat, merge', body')

-- | The booleans indicate whether we should also play with the
-- initial merge values.
data DoubleBuffer
  = BufferAlloc VName (PrimExp VName) Space Bool
  | -- | First name is the memory block to copy to,
    -- second is the name of the array copy.
    BufferCopy VName IxFun VName Bool
  | NoBuffer
  deriving (Show)

doubleBufferMergeParams ::
  MonadFreshNames m =>
  [(Param FParamMem, SubExpRes)] ->
  Names ->
  m [DoubleBuffer]
doubleBufferMergeParams :: [(Param FParamMem, SubExpRes)] -> Names -> m [DoubleBuffer]
doubleBufferMergeParams [(Param FParamMem, SubExpRes)]
ctx_and_res Names
bound_in_loop =
  StateT (Map VName (VName, Bool)) m [DoubleBuffer]
-> Map VName (VName, Bool) -> m [DoubleBuffer]
forall (m :: * -> *) s a. Monad m => StateT s m a -> s -> m a
evalStateT (((Param FParamMem, SubExpRes)
 -> StateT (Map VName (VName, Bool)) m DoubleBuffer)
-> [(Param FParamMem, SubExpRes)]
-> StateT (Map VName (VName, Bool)) m [DoubleBuffer]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (Param FParamMem, SubExpRes)
-> StateT (Map VName (VName, Bool)) m DoubleBuffer
buffer [(Param FParamMem, SubExpRes)]
ctx_and_res) Map VName (VName, Bool)
forall k a. Map k a
M.empty
  where
    params :: [Param FParamMem]
params = ((Param FParamMem, SubExpRes) -> Param FParamMem)
-> [(Param FParamMem, SubExpRes)] -> [Param FParamMem]
forall a b. (a -> b) -> [a] -> [b]
map (Param FParamMem, SubExpRes) -> Param FParamMem
forall a b. (a, b) -> a
fst [(Param FParamMem, SubExpRes)]
ctx_and_res
    loopVariant :: VName -> Bool
loopVariant VName
v =
      VName
v VName -> Names -> Bool
`nameIn` Names
bound_in_loop
        Bool -> Bool -> Bool
|| VName
v VName -> [VName] -> Bool
forall (t :: * -> *) a. (Foldable t, Eq a) => a -> t a -> Bool
`elem` ((Param FParamMem, SubExpRes) -> VName)
-> [(Param FParamMem, SubExpRes)] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map (Param FParamMem -> VName
forall dec. Param dec -> VName
paramName (Param FParamMem -> VName)
-> ((Param FParamMem, SubExpRes) -> Param FParamMem)
-> (Param FParamMem, SubExpRes)
-> VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Param FParamMem, SubExpRes) -> Param FParamMem
forall a b. (a, b) -> a
fst) [(Param FParamMem, SubExpRes)]
ctx_and_res

    loopInvariantSize :: SubExp -> Maybe (SubExp, Bool)
loopInvariantSize (Constant PrimValue
v) =
      (SubExp, Bool) -> Maybe (SubExp, Bool)
forall a. a -> Maybe a
Just (PrimValue -> SubExp
Constant PrimValue
v, Bool
True)
    loopInvariantSize (Var VName
v) =
      case ((Param FParamMem, SubExpRes) -> Bool)
-> [(Param FParamMem, SubExpRes)]
-> Maybe (Param FParamMem, SubExpRes)
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Maybe a
find ((VName -> VName -> Bool
forall a. Eq a => a -> a -> Bool
== VName
v) (VName -> Bool)
-> ((Param FParamMem, SubExpRes) -> VName)
-> (Param FParamMem, SubExpRes)
-> Bool
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Param FParamMem -> VName
forall dec. Param dec -> VName
paramName (Param FParamMem -> VName)
-> ((Param FParamMem, SubExpRes) -> Param FParamMem)
-> (Param FParamMem, SubExpRes)
-> VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Param FParamMem, SubExpRes) -> Param FParamMem
forall a b. (a, b) -> a
fst) [(Param FParamMem, SubExpRes)]
ctx_and_res of
        Just (Param FParamMem
_, SubExpRes Certs
_ (Constant PrimValue
val)) ->
          (SubExp, Bool) -> Maybe (SubExp, Bool)
forall a. a -> Maybe a
Just (PrimValue -> SubExp
Constant PrimValue
val, Bool
False)
        Just (Param FParamMem
_, SubExpRes Certs
_ (Var VName
v'))
          | Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$ VName -> Bool
loopVariant VName
v' ->
              (SubExp, Bool) -> Maybe (SubExp, Bool)
forall a. a -> Maybe a
Just (VName -> SubExp
Var VName
v', Bool
False)
        Just (Param FParamMem, SubExpRes)
_ ->
          Maybe (SubExp, Bool)
forall a. Maybe a
Nothing
        Maybe (Param FParamMem, SubExpRes)
Nothing ->
          (SubExp, Bool) -> Maybe (SubExp, Bool)
forall a. a -> Maybe a
Just (VName -> SubExp
Var VName
v, Bool
True)

    -- Size in bytes of the first loop parameter array stored in memory
    -- block @mem@, paired with the disjunction of the flags that
    -- 'loopInvariantSize' reports for its dimensions.
    --
    -- Inferred type: @VName -> Maybe (PrimExp VName, Bool)@.
    --
    -- Yields 'Nothing' when no parameter in @params@ is an array in
    -- @mem@ with a direct index function whose dimensions are all
    -- accepted by 'loopInvariantSize'.
    sizeForMem mem = maybeHead $ mapMaybe (arrayInMem . paramDec) params
      where
        -- Only array parameters can give a size; everything else is
        -- skipped by the fall-through equation.
        arrayInMem (MemArray pt shape _ (ArrayIn arraymem ixfun))
          | -- Cheapest test first: the array must live in the block we
            -- are asking about.  The guards are pure, so their order
            -- does not affect the result.
            mem == arraymem,
            IxFun.isDirect ixfun,
            Just (dims, flags) <-
              mapAndUnzipM loopInvariantSize (shapeDims shape) =
              Just
                ( arraySizeInBytesExp $
                    Array pt (Shape dims) NoUniqueness,
                  or flags
                )
        arrayInMem _ = Nothing

    -- Decide how a single (loop parameter, body result) pair is to be
    -- double buffered.  Runs in a 'StateT' over a map from memory
    -- parameter names to @(buffer name, flag)@, so that array
    -- parameters can find the buffer chosen for their memory block.
    --
    -- * Memory parameter: if 'sizeForMem' can give a size for it and
    --   the corresponding result is a variable in @bound_in_loop@, a
    --   fresh buffer name is generated and recorded in the state, and
    --   'BufferAlloc' is returned.
    --
    -- * Array parameter: if its memory block has an entry in the
    --   state, a fresh name for the copy is generated and 'BufferCopy'
    --   is returned.
    --
    -- * Anything else: 'NoBuffer'.
    buffer (fparam, res) = case paramType fparam of
      Mem space
        | Just (size, b) <- sizeForMem $ paramName fparam,
          Var res_v <- resSubExp res,
          res_v `nameIn` bound_in_loop -> do
            -- Let us double buffer this!
            bufname <- lift $ newVName "double_buffer_mem"
            modify $ M.insert (paramName fparam) (bufname, b)
            pure $ BufferAlloc bufname size space b
      Array {}
        | MemArray _ _ _ (ArrayIn mem ixfun) <- paramDec fparam -> do
            buffered <- gets $ M.lookup mem
            case buffered of
              Just (bufname, b) -> do
                copyname <- lift $ newVName "double_buffer_array"
                pure $ BufferCopy bufname ixfun copyname b
              Nothing ->
                pure NoBuffer
      _ -> pure NoBuffer

-- | Produce the statements that must precede the loop, and the
-- adjusted merge parameters with their initial values.
--
-- The merge list and the buffering decisions are walked in lockstep
-- (via 'zipWithM', so the shorter list determines the length).  The
-- statements are accumulated with a 'WriterT' and returned alongside
-- the new merge list.
allocStms ::
  Constraints rep inner =>
  [(FParam rep, SubExp)] ->
  [DoubleBuffer] ->
  DoubleBufferM rep ([(FParam rep, SubExp)], [Stm rep])
allocStms merge buffered = runWriterT $ zipWithM allocation merge buffered
  where
    -- A memory parameter that is to be buffered: emit the allocation
    -- of the buffer @name@.  When the flag @b@ is set, the parameter
    -- is rebuilt as a memory parameter in @space@ (same name and
    -- attributes) and its initial value becomes the new buffer;
    -- otherwise the merge entry is returned untouched.
    allocation m@(Param attrs pname _, _) (BufferAlloc name size space b) = do
      stms <- lift $
        runBuilder_ $ do
          size' <- toSubExp "double_buffer_size" size
          letBindNames [name] $ Op $ Alloc size' space
      tell $ stmsToList stms
      pure $
        if b
          then (Param attrs pname $ MemMem space, Var name)
          else m
    -- An array parameter whose memory is buffered and whose flag is
    -- set: copy the initial value @v@ into the buffer @mem@ (keeping
    -- the index function of @v@) and use the copy as initial value.
    allocation (f, Var v) (BufferCopy mem _ _ b) | b = do
      v_copy <- lift $ newVName $ baseString v ++ "_double_buffer_copy"
      (_v_mem, v_ixfun) <- lift $ lookupArraySummary v
      let bt = elemType $ paramType f
          shape = arrayShape $ paramType f
          bound = MemArray bt shape NoUniqueness $ ArrayIn mem v_ixfun
      tell [Let (Pat [PatElem v_copy bound]) (defAux ()) $ BasicOp $ Copy v]
      -- It is important that we treat this as a consumption, to
      -- avoid the Copy from being hoisted out of any enclosing
      -- loops.  Since we re-use (=overwrite) memory in the loop,
      -- the copy is critical for initialisation.  See issue #816.
      let uniqueMemInfo (MemArray pt pshape _ ret) =
            MemArray pt pshape Unique ret
          uniqueMemInfo info = info
      pure (uniqueMemInfo <$> f, Var v_copy)
    -- Everything else is passed through unchanged.
    allocation (f, se) _ =
      pure (f, se)

-- | Rewrite the result of a loop body according to the buffering
-- decisions.
--
-- The last @length valparams@ results of the body are taken to
-- correspond to @valparams@; any results before them are passed
-- through unchanged.  For a 'BufferAlloc' the result is replaced by
-- the buffer name.  For a 'BufferCopy' whose result is a variable, a
-- copy of that variable into the buffer is appended to the body and
-- the result is replaced by the copy.  All other results are left as
-- they are.
doubleBufferResult ::
  Constraints rep inner =>
  [FParam rep] ->
  [DoubleBuffer] ->
  Body rep ->
  Body rep
doubleBufferResult valparams buffered (Body _ stms res) =
  Body () (stms <> stmsFromList (catMaybes copystms)) $ ctx_res ++ val_res'
  where
    -- Leading results that have no counterpart in @valparams@, and
    -- the trailing results that do.
    (ctx_res, val_res) = splitAt (length res - length valparams) res

    (copystms, val_res') =
      unzip $ zipWith3 buffer valparams buffered val_res

    buffer _ (BufferAlloc bufname _ _ _) se =
      (Nothing, se {resSubExp = Var bufname})
    buffer fparam (BufferCopy bufname ixfun copyname _) (SubExpRes cs (Var v)) =
      -- To construct the copy we will need to figure out its type
      -- based on the type of the function parameter.
      let t = resultType $ paramType fparam
          summary =
            MemArray (elemType t) (arrayShape t) NoUniqueness $
              ArrayIn bufname ixfun
          copystm =
            Let (Pat [PatElem copyname summary]) (defAux ()) $
              BasicOp $ Copy v
       in (Just copystm, SubExpRes cs (Var copyname))
    buffer _ _ se =
      (Nothing, se)

    -- Maps each parameter name to the body result in the same
    -- position.  This is built from @val_res@ (not the whole @res@)
    -- so that the pairing agrees with the one used by @buffer@ even
    -- if there are leading results without a parameter.
    parammap =
      M.fromList $ zip (map paramName valparams) (map resSubExp val_res)

    -- The parameter type with any dimension that names a parameter
    -- replaced by that parameter's result.
    resultType t = t `setArrayDims` map substitute (arrayDims t)

    substitute (Var v)
      | Just replacement <- M.lookup v parammap = replacement
    substitute se =
      se