{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE ScopedTypeVariables #-}
-- | The simplification engine is only willing to hoist allocations
-- out of loops if the memory block resulting from the allocation is
-- dead at the end of the loop.  If it is not, we may cause data
-- hazards.
--
-- This module rewrites loops with memory block merge parameters such
-- that each memory block is copied at the end of the iteration, thus
-- ensuring that any allocation inside the loop is dead at the end of
-- the loop.  This is only possible for allocations whose size is
-- loop-invariant, although the initial size may differ from the size
-- produced by the loop result.
--
-- Additionally, inside parallel kernels we also copy the initial
-- value.  This has the effect of making the memory block of the array
-- returned by the loop non-existential, which is important for later
-- memory expansion to work.
module Futhark.Optimise.DoubleBuffer
       ( doubleBuffer )
       where

import           Control.Monad.State
import           Control.Monad.Writer
import           Control.Monad.Reader
import qualified Data.Map.Strict as M
import           Data.Maybe
import           Data.List (find)

import           Futhark.Construct
import           Futhark.Representation.AST
import           Futhark.Pass.ExplicitAllocations (arraySizeInBytesExp)
import qualified Futhark.Representation.ExplicitMemory.IndexFunction as IxFun
import           Futhark.Representation.ExplicitMemory
                 hiding (Prog, Body, Stm, Pattern, PatElem,
                         BasicOp, Exp, Lambda, FunDef, FParam, LParam, RetType)
import           Futhark.Pass
import           Futhark.Util (maybeHead)

-- | The double-buffering pass, applied to each function via
-- 'intraproceduralTransformation'.
--
-- The traversal starts with 'doNotTouchLoop' as the loop transformation,
-- so loops outside of any 'SegOp' are rebuilt unchanged.  Only when the
-- traversal enters a 'SegOp' (see 'optimiseOp') is the transformation
-- switched to 'optimiseLoop'.
doubleBuffer :: Pass ExplicitMemory ExplicitMemory
doubleBuffer =
  Pass { passName = "Double buffer"
       , passDescription = "Perform double buffering for merge parameters of sequential loops."
       , passFunction = intraproceduralTransformation optimise
       }
  where
    -- Run the traversal over one statement sequence.  The name source is
    -- threaded through so that any fresh names created by the loop
    -- transformation remain unique.
    optimise :: MonadFreshNames m =>
                Scope ExplicitMemory -> Stms ExplicitMemory
             -> m (Stms ExplicitMemory)
    optimise scope stms = modifyNameSource $ \src ->
      let traversal =
            runDoubleBufferM $ localScope scope $
              stmsFromList <$> optimiseStms (stmsToList stms)
      in runState (runReaderT traversal env) src

    -- Starting environment: nothing in scope yet, loops left alone.
    env :: Env
    env = Env mempty doNotTouchLoop

    -- Identity loop transformation: adds no statements before the loop
    -- and returns the merge parameters and body as given.
    doNotTouchLoop :: OptimiseLoop
    doNotTouchLoop ctx val body = return (mempty, ctx, val, body)

-- | Read-only environment of the double-buffering traversal.
data Env = Env { envScope :: Scope ExplicitMemory
                 -- ^ Names in scope at the current point of the traversal.
               , envOptimiseLoop :: OptimiseLoop
                 -- ^ Transformation applied to every 'DoLoop' that is
                 -- encountered; 'optimiseOp' replaces it with
                 -- 'optimiseLoop' when descending into a 'SegOp'.
               }

-- | The traversal monad: a reader over 'Env' stacked on a state monad
-- that carries the 'VNameSource', so fresh names can be generated.
-- All instances are obtained by newtype deriving from the underlying
-- transformer stack.
newtype DoubleBufferM a =
  DoubleBufferM { runDoubleBufferM :: ReaderT Env (State VNameSource) a }
  deriving (Functor, Applicative, Monad, MonadReader Env, MonadFreshNames)

-- | The current scope is whatever the environment carries in 'envScope'.
instance HasScope ExplicitMemory DoubleBufferM where
  askScope = asks envScope

-- | Run an action with additional names in scope.
--
-- A 'Scope' is a map, and its '<>' is a left-biased union.  The newly
-- introduced bindings are therefore placed on the left so that they
-- shadow any existing binding of the same name, which is the same
-- behaviour as the conventional @local . M.union@.
instance LocalScope ExplicitMemory DoubleBufferM where
  localScope scope = local $ \env -> env { envScope = scope <> envScope env }

-- | Optimise all statements of a body.  The body's result is kept as-is.
optimiseBody :: Body ExplicitMemory -> DoubleBufferM (Body ExplicitMemory)
optimiseBody body = do
  stms' <- optimiseStms $ stmsToList $ bodyStms body
  return body { bodyStms = stmsFromList stms' }

-- | Optimise a list of statements in order.
--
-- A single statement may expand into several, because a loop
-- transformation can prepend statements to a loop.  Everything bound by
-- the expanded statements is added to the scope before the remaining
-- statements are visited.
optimiseStms :: [Stm ExplicitMemory] -> DoubleBufferM [Stm ExplicitMemory]
optimiseStms [] = return []
optimiseStms (stm : rest) = do
  expanded <- optimiseStm stm
  rest' <- localScope (castScope $ scopeOf expanded) $ optimiseStms rest
  return $ expanded ++ rest'

-- | Optimise a single statement, possibly producing several.
--
-- * A 'DoLoop' first has its body optimised, with the loop form and all
--   merge parameters in scope.  The loop is then passed to the
--   environment's 'envOptimiseLoop'.  Any statements that transformation
--   returns are placed before the rebuilt loop.
--
-- * Any other expression is traversed structurally, so nested bodies and
--   'Op's are optimised as well.
optimiseStm :: Stm ExplicitMemory -> DoubleBufferM [Stm ExplicitMemory]
optimiseStm (Let pat aux (DoLoop ctx val form body)) = do
  body' <- localScope (scopeOf form <> scopeOfFParams (map fst $ ctx ++ val)) $
           optimiseBody body
  opt_loop <- asks envOptimiseLoop
  (pre_stms, ctx', val', body'') <- opt_loop ctx val body'
  return $ pre_stms ++ [Let pat aux $ DoLoop ctx' val' form body'']
optimiseStm (Let pat aux e) =
  pure . Let pat aux <$> mapExpM optimise e
  where optimise =
          identityMapper
          { -- The mapper supplies the names the expression binds around
            -- this body.  Bring them into scope before descending, as the
            -- DoLoop case above and 'optimiseLambda' also do.
            mapOnBody = \scope x ->
              localScope scope (optimiseBody x) :: DoubleBufferM (Body ExplicitMemory)
          , mapOnOp = optimiseOp
          }

-- | Optimise an operation.
--
-- Only 'SegOp's are descended into, through their lambdas and kernel
-- bodies.  While inside one, the environment's loop transformation is
-- 'optimiseLoop', so sequential loops nested in a 'SegOp' are handed to
-- it.  Every other operation is returned unchanged.
optimiseOp :: Op ExplicitMemory
           -> DoubleBufferM (Op ExplicitMemory)
optimiseOp (Inner (SegOp op)) =
  local inSegOp $ Inner . SegOp <$> mapSegOpM mapper op
  where
    mapper =
      identitySegOpMapper
      { mapOnSegOpLambda = optimiseLambda
      , mapOnSegOpBody = optimiseKernelBody
      }
    inSegOp env = env { envOptimiseLoop = optimiseLoop }
optimiseOp op = return op

-- | Optimise all statements of a kernel body.  The kernel results are
-- kept as-is.
optimiseKernelBody :: KernelBody ExplicitMemory
                   -> DoubleBufferM (KernelBody ExplicitMemory)
optimiseKernelBody kbody = do
  stms' <- optimiseStms $ stmsToList $ kernelBodyStms kbody
  return kbody { kernelBodyStms = stmsFromList stms' }

-- | Optimise the body of a lambda, with the lambda's parameters in scope.
optimiseLambda :: Lambda ExplicitMemory -> DoubleBufferM (Lambda ExplicitMemory)
optimiseLambda lam = do
  body' <- localScope (castScope $ scopeOf lam) $ optimiseBody $ lambdaBody lam
  return lam { lambdaBody = body' }

-- | The type of a loop optimiser.
--
-- Arguments: the context merge parameters and the value merge
-- parameters (each paired with its initial value), and the loop body.
--
-- Result: statements produced for the loop (buffer allocations and
-- copies of initial values; presumably bound before the loop by the
-- caller), the new context merge parameters, the new value merge
-- parameters, and the new loop body.
type OptimiseLoop
  =  [(FParam ExplicitMemory, SubExp)]
  -> [(FParam ExplicitMemory, SubExp)]
  -> Body ExplicitMemory
  -> DoubleBufferM ( [Stm ExplicitMemory]
                   , [(FParam ExplicitMemory, SubExp)]
                   , [(FParam ExplicitMemory, SubExp)]
                   , Body ExplicitMemory
                   )

-- | Double-buffer the memory-block merge parameters of a loop.
--
-- All merge parameters (context followed by value) are classified by
-- 'doubleBufferMergeParams'. 'allocStms' then produces the buffer
-- allocations and initial-value copies. 'doubleBufferResult' rewrites
-- the body so that buffered results are copied at the end of each
-- iteration.
optimiseLoop :: OptimiseLoop
optimiseLoop ctx val body = do
  let mergeParams = map fst merge
      -- Each context parameter is paired with the value the body
      -- returns for it, so that loop-variance of sizes can be judged
      -- against what the next iteration will see.
      ctxWithRes  = zip (map fst ctx) (bodyResult body)
  -- Decide which merge variables should be double-buffered.
  buffered <- doubleBufferMergeParams ctxWithRes mergeParams (boundInBody body)
  -- Allocate the buffers and copy the initial values where requested.
  (merge', allocs) <- allocStms merge buffered
  -- Copy buffered result arrays at the end of the body, and split the
  -- rewritten merge list back into its context and value parts.
  let body'        = doubleBufferResult mergeParams buffered body
      (ctx', val') = splitAt (length ctx) merge'
  return (allocs, ctx', val', body')
  where merge = ctx ++ val

-- | How a single loop merge parameter is to be treated.  The 'Bool'
-- carried by the first two constructors indicates whether we should
-- also play with the initial merge value.
data DoubleBuffer
  = BufferAlloc VName (PrimExp VName) Space Bool
    -- ^ A memory-block parameter to buffer: the name of the fresh
    -- buffer, its size in bytes, and the space to allocate it in.
  | BufferCopy VName IxFun VName Bool
    -- ^ An array parameter living in a buffered block.  The first name
    -- is the memory block to copy to, the second is the name of the
    -- array copy; the index function is the parameter's own.
  | NoBuffer
    -- ^ The parameter is left alone.
  deriving (Show)

-- | Decide, for each of the given merge parameters, whether and how it
-- is to be double-buffered.
--
--  * A memory parameter becomes a 'BufferAlloc' when some parameter
--    array with a direct index function lives in it and every
--    dimension of that array can be expressed loop-invariantly.  The
--    buffer size is that array's size in bytes.
--
--  * An array parameter becomes a 'BufferCopy' when its memory block
--    was given a buffer by an earlier parameter in the list.
--
--  * Everything else is 'NoBuffer'.
--
-- Arguments:
--
--  * The first argument pairs each context parameter with the value
--    the loop body returns for it.
--  * The second lists the parameters to classify, processed in order.
--  * The third is the set of names bound inside the loop body.
doubleBufferMergeParams :: MonadFreshNames m =>
                           [(FParam ExplicitMemory, SubExp)]
                        -> [FParam ExplicitMemory] -> Names
                        -> m [DoubleBuffer]
doubleBufferMergeParams ctx_and_res val_params bound_in_loop =
  flip evalStateT M.empty $ traverse classify val_params
  where ctxNames = map (paramName . fst) ctx_and_res

        -- A name varies across iterations if the body binds it or it
        -- is itself a context merge parameter.
        loopVariant :: VName -> Bool
        loopVariant v = v `nameIn` bound_in_loop || v `elem` ctxNames

        -- Express a dimension in loop-invariant terms, if possible.
        -- The flag is 'True' for constants and for names that are not
        -- context parameters.  It is 'False' when the dimension was
        -- replaced by the (invariant) value the body returns for that
        -- context parameter.
        loopInvariantSize :: SubExp -> Maybe (SubExp, Bool)
        loopInvariantSize se@Constant{} = Just (se, True)
        loopInvariantSize (Var v) =
          case snd <$> find ((== v) . paramName . fst) ctx_and_res of
            Nothing             -> Just (Var v, True)
            Just res@Constant{} -> Just (res, False)
            Just (Var v')
              | loopVariant v'  -> Nothing
              | otherwise       -> Just (Var v', False)

        -- Byte size of the first parameter array stored in the given
        -- memory block with a direct index function and loop-invariant
        -- dimensions.  The flag is the disjunction of the per-dimension
        -- flags.
        sizeForMem :: VName -> Maybe (PrimExp VName, Bool)
        sizeForMem mem =
          maybeHead
            [ (arraySizeInBytesExp $ Array pt (Shape dims) NoUniqueness, or flags)
            | MemArray pt shape _ (ArrayIn arraymem ixfun) <- map paramAttr val_params
            , IxFun.isDirect ixfun
            , Just (dims, flags) <- [mapAndUnzipM loopInvariantSize $ shapeDims shape]
            , mem == arraymem
            ]

        -- Classify one parameter.  The state records which memory
        -- blocks received a buffer (and that buffer's flag).  Later
        -- array parameters in those blocks can then be turned into
        -- copies.
        classify fparam =
          case (paramType fparam, paramAttr fparam) of
            (Mem space, _)
              | Just (size, b) <- sizeForMem (paramName fparam) -> do
                  -- Let us double buffer this!
                  bufname <- lift $ newVName "double_buffer_mem"
                  modify $ M.insert (paramName fparam) (bufname, b)
                  pure $ BufferAlloc bufname size space b
            (Array{}, MemArray _ _ _ (ArrayIn mem ixfun)) -> do
              known <- gets $ M.lookup mem
              case known of
                Nothing -> pure NoBuffer
                Just (bufname, b) -> do
                  copyname <- lift $ newVName "double_buffer_array"
                  pure $ BufferCopy bufname ixfun copyname b
            _ -> pure NoBuffer

-- | Turn the buffering decisions for a loop into concrete statements
-- and updated merge parameters.  The merge list and the decisions are
-- walked in lock-step:
--
--  * 'BufferAlloc': emit statements that compute the buffer size and
--    allocate the buffer in the given space.  If the flag is set, the
--    merge parameter is replaced by a memory parameter of the same
--    name whose initial value is the new buffer; otherwise the
--    parameter is kept as it was.
--
--  * 'BufferCopy' with its flag set, on a parameter whose initial value
--    is a variable: emit a copy of that variable into the buffer, use
--    the copy as the new initial value, and mark the parameter unique.
--
--  * Anything else: the merge parameter is returned unchanged.
--
-- The emitted statements are returned, in order, next to the new
-- merge list.
allocStms :: [(FParam ExplicitMemory, SubExp)] -> [DoubleBuffer]
          -> DoubleBufferM ([(FParam ExplicitMemory, SubExp)], [Stm ExplicitMemory])
allocStms merge buffers = runWriterT $ zipWithM allocation merge buffers
  where allocation (fparam@(Param pname _), initial) (BufferAlloc bufname size space b) = do
          sizeAndAlloc <- lift $ runBinder_ $ do
            bytes <- letSubExp "double_buffer_size" =<< toExp size
            letBindNames_ [bufname] $ Op $ Alloc bytes space
          tell $ stmsToList sizeAndAlloc
          return $
            if b
              then (Param pname $ MemMem space, Var bufname)
              else (fparam, initial)

        allocation (fparam, Var v) (BufferCopy mem _ _ True) = do
          v_copy <- lift $ newVName $ baseString v ++ "_double_buffer_copy"
          (_, v_ixfun) <- lift $ lookupArraySummary v
          let ptype    = paramType fparam
              copyInfo = MemArray (elemType ptype) (arrayShape ptype) NoUniqueness $
                         ArrayIn mem v_ixfun
          tell [Let (Pattern [] [PatElem v_copy copyInfo]) (defAux ()) $
                BasicOp $ Copy v]
          -- It is important that we treat this as a consumption, to
          -- avoid the Copy from being hoisted out of any enclosing
          -- loops.  Since we re-use (=overwrite) memory in the loop,
          -- the copy is critical for initialisation.  See issue #816.
          return (markUnique <$> fparam, Var v_copy)

        allocation unchanged _ =
          return unchanged

        -- Give an array's memory info unique ownership; any other kind
        -- of memory info is passed through untouched.
        markUnique (MemArray pt pshape _ ret) = MemArray pt pshape Unique ret
        markUnique info = info

doubleBufferResult :: [FParam ExplicitMemory] -> [DoubleBuffer]
                   -> Body ExplicitMemory -> Body ExplicitMemory
doubleBufferResult :: [FParam ExplicitMemory]
-> [DoubleBuffer] -> Body ExplicitMemory -> Body ExplicitMemory
doubleBufferResult [FParam ExplicitMemory]
valparams [DoubleBuffer]
buffered (Body () Stms ExplicitMemory
bnds [SubExp]
res) =
  let ([SubExp]
ctx_res, [SubExp]
val_res) = Int -> [SubExp] -> ([SubExp], [SubExp])
forall a. Int -> [a] -> ([a], [a])
splitAt ([SubExp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
res Int -> Int -> Int
forall a. Num a => a -> a -> a
- [Param (MemInfo SubExp Uniqueness MemBind)] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [FParam ExplicitMemory]
[Param (MemInfo SubExp Uniqueness MemBind)]
valparams) [SubExp]
res
      ([Maybe (Stm ExplicitMemory)]
copybnds,[SubExp]
val_res') =
        [(Maybe (Stm ExplicitMemory), SubExp)]
-> ([Maybe (Stm ExplicitMemory)], [SubExp])
forall a b. [(a, b)] -> ([a], [b])
unzip ([(Maybe (Stm ExplicitMemory), SubExp)]
 -> ([Maybe (Stm ExplicitMemory)], [SubExp]))
-> [(Maybe (Stm ExplicitMemory), SubExp)]
-> ([Maybe (Stm ExplicitMemory)], [SubExp])
forall a b. (a -> b) -> a -> b
$ (Param (MemInfo SubExp Uniqueness MemBind)
 -> DoubleBuffer -> SubExp -> (Maybe (Stm ExplicitMemory), SubExp))
-> [Param (MemInfo SubExp Uniqueness MemBind)]
-> [DoubleBuffer]
-> [SubExp]
-> [(Maybe (Stm ExplicitMemory), SubExp)]
forall a b c d. (a -> b -> c -> d) -> [a] -> [b] -> [c] -> [d]
zipWith3 Param (MemInfo SubExp Uniqueness MemBind)
-> DoubleBuffer -> SubExp -> (Maybe (Stm ExplicitMemory), SubExp)
buffer [FParam ExplicitMemory]
[Param (MemInfo SubExp Uniqueness MemBind)]
valparams [DoubleBuffer]
buffered [SubExp]
val_res
  in BodyAttr ExplicitMemory
-> Stms ExplicitMemory -> [SubExp] -> Body ExplicitMemory
forall lore. BodyAttr lore -> Stms lore -> [SubExp] -> BodyT lore
Body () (Stms ExplicitMemory
bndsStms ExplicitMemory -> Stms ExplicitMemory -> Stms ExplicitMemory
forall a. Semigroup a => a -> a -> a
<>[Stm ExplicitMemory] -> Stms ExplicitMemory
forall lore. [Stm lore] -> Stms lore
stmsFromList ([Maybe (Stm ExplicitMemory)] -> [Stm ExplicitMemory]
forall a. [Maybe a] -> [a]
catMaybes [Maybe (Stm ExplicitMemory)]
copybnds)) ([SubExp] -> Body ExplicitMemory)
-> [SubExp] -> Body ExplicitMemory
forall a b. (a -> b) -> a -> b
$ [SubExp]
ctx_res [SubExp] -> [SubExp] -> [SubExp]
forall a. [a] -> [a] -> [a]
++ [SubExp]
val_res'
  where buffer :: Param (MemInfo SubExp Uniqueness MemBind)
-> DoubleBuffer -> SubExp -> (Maybe (Stm ExplicitMemory), SubExp)
buffer Param (MemInfo SubExp Uniqueness MemBind)
_ (BufferAlloc VName
bufname PrimExp VName
_ Space
_ Bool
_) SubExp
_ =
          (Maybe (Stm ExplicitMemory)
forall a. Maybe a
Nothing, VName -> SubExp
Var VName
bufname)

        buffer Param (MemInfo SubExp Uniqueness MemBind)
fparam (BufferCopy VName
bufname IxFun
ixfun VName
copyname Bool
_) (Var VName
v) =
          -- To construct the copy we will need to figure out its type
          -- based on the type of the function parameter.
          let t :: Type
t = Type -> Type
resultType (Type -> Type) -> Type -> Type
forall a b. (a -> b) -> a -> b
$ Param (MemInfo SubExp Uniqueness MemBind) -> Type
forall attr. Typed attr => Param attr -> Type
paramType Param (MemInfo SubExp Uniqueness MemBind)
fparam
              summary :: MemInfo SubExp NoUniqueness MemBind
summary = PrimType
-> ShapeBase SubExp
-> NoUniqueness
-> MemBind
-> MemInfo SubExp NoUniqueness MemBind
forall d u ret.
PrimType -> ShapeBase d -> u -> ret -> MemInfo d u ret
MemArray (Type -> PrimType
forall shape u. TypeBase shape u -> PrimType
elemType Type
t) (Type -> ShapeBase SubExp
forall shape u. ArrayShape shape => TypeBase shape u -> shape
arrayShape Type
t) NoUniqueness
NoUniqueness (MemBind -> MemInfo SubExp NoUniqueness MemBind)
-> MemBind -> MemInfo SubExp NoUniqueness MemBind
forall a b. (a -> b) -> a -> b
$ VName -> IxFun -> MemBind
ArrayIn VName
bufname IxFun
ixfun
              copybnd :: Stm ExplicitMemory
copybnd = Pattern ExplicitMemory
-> StmAux (ExpAttr ExplicitMemory)
-> ExpT ExplicitMemory
-> Stm ExplicitMemory
forall lore.
Pattern lore -> StmAux (ExpAttr lore) -> Exp lore -> Stm lore
Let ([PatElemT (MemInfo SubExp NoUniqueness MemBind)]
-> [PatElemT (MemInfo SubExp NoUniqueness MemBind)]
-> PatternT (MemInfo SubExp NoUniqueness MemBind)
forall attr. [PatElemT attr] -> [PatElemT attr] -> PatternT attr
Pattern [] [VName
-> MemInfo SubExp NoUniqueness MemBind
-> PatElemT (MemInfo SubExp NoUniqueness MemBind)
forall attr. VName -> attr -> PatElemT attr
PatElem VName
copyname MemInfo SubExp NoUniqueness MemBind
summary]) (() -> StmAux ()
forall attr. attr -> StmAux attr
defAux ()) (ExpT ExplicitMemory -> Stm ExplicitMemory)
-> ExpT ExplicitMemory -> Stm ExplicitMemory
forall a b. (a -> b) -> a -> b
$
                        BasicOp ExplicitMemory -> ExpT ExplicitMemory
forall lore. BasicOp lore -> ExpT lore
BasicOp (BasicOp ExplicitMemory -> ExpT ExplicitMemory)
-> BasicOp ExplicitMemory -> ExpT ExplicitMemory
forall a b. (a -> b) -> a -> b
$ VName -> BasicOp ExplicitMemory
forall lore. VName -> BasicOp lore
Copy VName
v
          in (Stm ExplicitMemory -> Maybe (Stm ExplicitMemory)
forall a. a -> Maybe a
Just Stm ExplicitMemory
copybnd, VName -> SubExp
Var VName
copyname)

        buffer Param (MemInfo SubExp Uniqueness MemBind)
_ DoubleBuffer
_ SubExp
se =
          (Maybe (Stm ExplicitMemory)
forall a. Maybe a
Nothing, SubExp
se)

        parammap :: Map VName SubExp
parammap = [(VName, SubExp)] -> Map VName SubExp
forall k a. Ord k => [(k, a)] -> Map k a
M.fromList ([(VName, SubExp)] -> Map VName SubExp)
-> [(VName, SubExp)] -> Map VName SubExp
forall a b. (a -> b) -> a -> b
$ [VName] -> [SubExp] -> [(VName, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip ((Param (MemInfo SubExp Uniqueness MemBind) -> VName)
-> [Param (MemInfo SubExp Uniqueness MemBind)] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map Param (MemInfo SubExp Uniqueness MemBind) -> VName
forall attr. Param attr -> VName
paramName [FParam ExplicitMemory]
[Param (MemInfo SubExp Uniqueness MemBind)]
valparams) [SubExp]
res

        resultType :: Type -> Type
resultType Type
t = Type
t Type -> [SubExp] -> Type
forall oldshape u.
TypeBase oldshape u -> [SubExp] -> TypeBase (ShapeBase SubExp) u
`setArrayDims` (SubExp -> SubExp) -> [SubExp] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> SubExp
substitute (Type -> [SubExp]
forall u. TypeBase (ShapeBase SubExp) u -> [SubExp]
arrayDims Type
t)

        substitute :: SubExp -> SubExp
substitute (Var VName
v)
          | Just SubExp
replacement <- VName -> Map VName SubExp -> Maybe SubExp
forall k a. Ord k => k -> Map k a -> Maybe a
M.lookup VName
v Map VName SubExp
parammap = SubExp
replacement
        substitute SubExp
se =
          SubExp
se