{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}

-- | The simplification engine is only willing to hoist allocations
-- out of loops if the memory block resulting from the allocation is
-- dead at the end of the loop.  If it is not, we may cause data
-- hazards.
--
-- This module rewrites loops with memory block merge parameters such
-- that each memory block is copied at the end of the iteration, thus
-- ensuring that any allocation inside the loop is dead at the end of
-- the loop.  This is only possible for allocations whose size is
-- loop-invariant, although the initial size may differ from the size
-- produced by the loop result.
--
-- Additionally, inside parallel kernels we also copy the initial
-- value.  This has the effect of making the memory block returned by
-- the array non-existential, which is important for later memory
-- expansion to work.
module Futhark.Optimise.DoubleBuffer (doubleBuffer) where

import Control.Monad.Reader
import Control.Monad.State
import Control.Monad.Writer
import Data.List (find)
import qualified Data.Map.Strict as M
import Data.Maybe
import Futhark.Construct
import Futhark.IR
import Futhark.IR.KernelsMem
import qualified Futhark.IR.Mem.IxFun as IxFun
import Futhark.Pass
import Futhark.Pass.ExplicitAllocations (arraySizeInBytesExp)
import Futhark.Pass.ExplicitAllocations.Kernels ()
import Futhark.Util (maybeHead)

-- | The double buffering pass definition.
--
-- The pass runs 'optimiseStms' over the statements handed to it by
-- 'intraproceduralTransformation'.  The initial environment carries a
-- loop optimiser that leaves loops unchanged; 'optimiseOp' replaces
-- it with 'optimiseLoop' when descending into a 'SegOp'.
doubleBuffer :: Pass KernelsMem KernelsMem
doubleBuffer =
  Pass
    { passName = "Double buffer",
      passDescription = "Perform double buffering for merge parameters of sequential loops.",
      passFunction = intraproceduralTransformation optimise
    }
  where
    -- Run the 'DoubleBufferM' computation as a pure transition on the
    -- name source, so that it can be used in any 'MonadFreshNames'.
    optimise scope stms = modifyNameSource $ \src ->
      let m =
            runDoubleBufferM $
              localScope scope $
                fmap stmsFromList $ optimiseStms $ stmsToList stms
       in runState (runReaderT m env) src

    -- Initial environment: empty scope, and loops are not rewritten.
    env = Env mempty doNotTouchLoop

    -- A loop optimiser that emits no extra statements and hands back
    -- the merge parameters and the body unchanged.
    doNotTouchLoop ctx val body = return (mempty, ctx, val, body)

-- | The read-only environment of 'DoubleBufferM'.
data Env = Env
  { -- | The bindings currently in scope; this is what 'askScope'
    -- returns and what 'localScope' extends.
    envScope :: Scope KernelsMem,
    -- | The function applied to every loop encountered by
    -- 'optimiseStm'.
    envOptimiseLoop :: OptimiseLoop
  }

-- | The monad in which the pass runs: a reader over 'Env' on top of
-- a state monad holding the name source.
newtype DoubleBufferM a = DoubleBufferM {runDoubleBufferM :: ReaderT Env (State VNameSource) a}
  deriving (Functor, Applicative, Monad, MonadReader Env, MonadFreshNames)

-- | The scope is read from the 'envScope' field of the environment.
instance HasScope KernelsMem DoubleBufferM where
  askScope = asks envScope

-- | Extending the scope appends the new bindings to 'envScope' for
-- the duration of the given action.
instance LocalScope KernelsMem DoubleBufferM where
  localScope scope = local $ \env -> env {envScope = envScope env <> scope}

-- | Optimise the statements of a body with 'optimiseStms'.  Every
-- other field of the body is kept as it is.
optimiseBody :: Body KernelsMem -> DoubleBufferM (Body KernelsMem)
optimiseBody body = do
  bnds' <- optimiseStms $ stmsToList $ bodyStms body
  return $ body {bodyStms = stmsFromList bnds'}

-- | Optimise a list of statements from first to last.  A single
-- statement may be replaced by several; the names bound by those
-- replacements are added to the scope before the remaining
-- statements are processed.
optimiseStms :: [Stm KernelsMem] -> DoubleBufferM [Stm KernelsMem]
optimiseStms [] = return []
optimiseStms (e : es) = do
  e_es <- optimiseStm e
  es' <- localScope (castScope $ scopeOf e_es) $ optimiseStms es
  return $ e_es ++ es'

-- | Optimise a single statement, possibly producing several.
--
-- For a loop, the body is optimised first (with the loop form and the
-- merge parameters in scope), and then the loop optimiser found in
-- the environment is applied.  Any statements it returns are placed
-- before the rebuilt loop statement.
--
-- For any other expression, nested bodies and operations are
-- traversed with 'optimiseBody' and 'optimiseOp'.
optimiseStm :: Stm KernelsMem -> DoubleBufferM [Stm KernelsMem]
optimiseStm (Let pat aux (DoLoop ctx val form body)) = do
  body' <-
    localScope (scopeOf form <> scopeOfFParams (map fst $ ctx ++ val)) $
      optimiseBody body
  opt_loop <- asks envOptimiseLoop
  (bnds, ctx', val', body'') <- opt_loop ctx val body'
  return $ bnds ++ [Let pat aux $ DoLoop ctx' val' form body'']
optimiseStm (Let pat aux e) =
  pure . Let pat aux <$> mapExpM optimise e
  where
    optimise =
      identityMapper
        { mapOnBody = \_ x ->
            optimiseBody x :: DoubleBufferM (Body KernelsMem),
          mapOnOp = optimiseOp
        }

-- | Optimise an operation.  A 'SegOp' is traversed with
-- 'optimiseLambda' and 'optimiseKernelBody', and during that
-- traversal the environment's loop optimiser is set to
-- 'optimiseLoop'.  Every other operation is returned unchanged.
optimiseOp ::
  Op KernelsMem ->
  DoubleBufferM (Op KernelsMem)
optimiseOp (Inner (SegOp op)) =
  local inSegOp $ Inner . SegOp <$> mapSegOpM mapper op
  where
    mapper =
      identitySegOpMapper
        { mapOnSegOpLambda = optimiseLambda,
          mapOnSegOpBody = optimiseKernelBody
        }
    -- Inside a SegOp, loops are actually rewritten.
    inSegOp env = env {envOptimiseLoop = optimiseLoop}
optimiseOp op = return op

-- | Optimise the statements of a kernel body with 'optimiseStms'.
-- Every other field of the kernel body is kept as it is.
optimiseKernelBody ::
  KernelBody KernelsMem ->
  DoubleBufferM (KernelBody KernelsMem)
optimiseKernelBody kbody = do
  stms' <- optimiseStms $ stmsToList $ kernelBodyStms kbody
  return $ kbody {kernelBodyStms = stmsFromList stms'}

-- | Optimise the body of a lambda, with the scope of the lambda added
-- to the environment while its body is processed.
optimiseLambda :: Lambda KernelsMem -> DoubleBufferM (Lambda KernelsMem)
optimiseLambda lam = do
  body <- localScope (castScope $ scopeOf lam) $ optimiseBody $ lambdaBody lam
  return lam {lambdaBody = body}

-- | The type of a loop optimiser.  It receives the context merge
-- parameters, the value merge parameters (each paired with a
-- 'SubExp'), and the loop body.  It returns a list of statements,
-- which 'optimiseStm' places before the loop, together with the
-- replacement context parameters, value parameters, and body.
type OptimiseLoop =
  [(FParam KernelsMem, SubExp)] ->
  [(FParam KernelsMem, SubExp)] ->
  Body KernelsMem ->
  DoubleBufferM
    ( [Stm KernelsMem],
      [(FParam KernelsMem, SubExp)],
      [(FParam KernelsMem, SubExp)],
      Body KernelsMem
    )

-- | The loop optimiser that performs the double buffering.  It
-- decides how each merge parameter is to be buffered, asks
-- 'allocStms' for the new merge parameters and the statements to put
-- in front of the loop, and rewrites the body with
-- 'doubleBufferResult'.
optimiseLoop :: OptimiseLoop
optimiseLoop ctx val body = do
  -- We start out by figuring out which of the merge variables should
  -- be double-buffered.
  buffered <-
    doubleBufferMergeParams
      (zip (map fst ctx) (bodyResult body))
      (map fst merge)
      (boundInBody body)
  -- Then create the allocations of the buffers and copies of the
  -- initial values.
  (merge', allocs) <- allocStms merge buffered
  -- Modify the loop body to copy buffered result arrays.
  let body' = doubleBufferResult (map fst merge) buffered body
      -- The new merge parameters come back as one list; split them
      -- into context and value parts again.
      (ctx', val') = splitAt (length ctx) merge'
  -- Return the statements to place before the loop along with the
  -- modified merge parameters and body.
  return (allocs, ctx', val', body')
  where
    -- All merge parameters, context first.
    merge = ctx ++ val

-- | How a single merge parameter is to be treated.  The booleans
-- indicate whether we should also play with the initial merge
-- values.
data DoubleBuffer
  = BufferAlloc VName (PrimExp VName) Space Bool
  | -- | First name is the memory block to copy to,
    -- second is the name of the array copy.
    BufferCopy VName IxFun VName Bool
  | NoBuffer
  deriving (Show)

doubleBufferMergeParams ::
  MonadFreshNames m =>
  [(FParam KernelsMem, SubExp)] ->
  [FParam KernelsMem] ->
  Names ->
  m [DoubleBuffer]
doubleBufferMergeParams :: [(FParam KernelsMem, SubExp)]
-> [FParam KernelsMem] -> Names -> m [DoubleBuffer]
doubleBufferMergeParams [(FParam KernelsMem, SubExp)]
ctx_and_res [FParam KernelsMem]
val_params Names
bound_in_loop =
  StateT (Map VName (VName, Bool)) m [DoubleBuffer]
-> Map VName (VName, Bool) -> m [DoubleBuffer]
forall (m :: * -> *) s a. Monad m => StateT s m a -> s -> m a
evalStateT ((Param FParamMem
 -> StateT (Map VName (VName, Bool)) m DoubleBuffer)
-> [Param FParamMem]
-> StateT (Map VName (VName, Bool)) m [DoubleBuffer]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM Param FParamMem -> StateT (Map VName (VName, Bool)) m DoubleBuffer
buffer [FParam KernelsMem]
[Param FParamMem]
val_params) Map VName (VName, Bool)
forall k a. Map k a
M.empty
  where
    loopVariant :: VName -> Bool
loopVariant VName
v =
      VName
v VName -> Names -> Bool
`nameIn` Names
bound_in_loop
        Bool -> Bool -> Bool
|| VName
v VName -> [VName] -> Bool
forall (t :: * -> *) a. (Foldable t, Eq a) => a -> t a -> Bool
`elem` ((Param FParamMem, SubExp) -> VName)
-> [(Param FParamMem, SubExp)] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map (Param FParamMem -> VName
forall dec. Param dec -> VName
paramName (Param FParamMem -> VName)
-> ((Param FParamMem, SubExp) -> Param FParamMem)
-> (Param FParamMem, SubExp)
-> VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Param FParamMem, SubExp) -> Param FParamMem
forall a b. (a, b) -> a
fst) [(FParam KernelsMem, SubExp)]
[(Param FParamMem, SubExp)]
ctx_and_res

    -- Try to express a dimension by something that does not change
    -- across iterations.  The result is 'Nothing' when no such
    -- expression can be found.  The accompanying flag is 'True' when
    -- the dimension is returned as-is (a constant, or a variable that
    -- is not one of the parameters in @ctx_and_res@), and 'False'
    -- when it had to be replaced by the value that @ctx_and_res@
    -- pairs with the parameter of that name.
    loopInvariantSize (Constant v) =
      Just (Constant v, True)
    loopInvariantSize (Var v) =
      case find ((== v) . paramName . fst) ctx_and_res of
        -- The parameter is paired with a constant: use the constant.
        Just (_, Constant val) ->
          Just (Constant val, False)
        -- The parameter is paired with a variable that is itself
        -- loop-invariant: use that variable.
        Just (_, Var v')
          | not $ loopVariant v' ->
            Just (Var v', False)
        -- The parameter is paired with something loop-variant.
        Just _ ->
          Nothing
        -- Not one of the parameters in ctx_and_res at all.
        Nothing ->
          Just (Var v, True)

    -- Compute a size in bytes for the memory block @mem@, based on the
    -- first parameter in @val_params@ that is an array residing
    -- directly in @mem@ and all of whose dimensions pass
    -- 'loopInvariantSize'.  The flag is the disjunction ('or') of the
    -- per-dimension flags produced by 'loopInvariantSize'.  Produces
    -- 'Nothing' if no parameter qualifies.
    sizeForMem mem = maybeHead $ mapMaybe (arrayInMem . paramDec) val_params
      where
        -- Only arrays with a direct index function, loop-invariant
        -- dimensions, and residing in exactly the memory block we are
        -- asking about are considered.
        arrayInMem (MemArray pt shape _ (ArrayIn arraymem ixfun))
          | IxFun.isDirect ixfun,
            Just (dims, b) <-
              mapAndUnzipM loopInvariantSize $ shapeDims shape,
            mem == arraymem =
            Just
              ( arraySizeInBytesExp $
                  Array pt (Shape dims) NoUniqueness,
                or b
              )
        arrayInMem _ = Nothing

    -- Decide how a single parameter is to be treated.  The state is a
    -- map from the name of each double-buffered memory parameter to
    -- the name of its fresh buffer and the flag from 'sizeForMem'.
    -- Because the state is threaded through the parameters in order,
    -- an array parameter is only recognised as buffered if its memory
    -- parameter has been processed before it.
    buffer fparam = case paramType fparam of
      Mem space
        | Just (size, b) <- sizeForMem $ paramName fparam -> do
          -- Let us double buffer this!
          bufname <- lift $ newVName "double_buffer_mem"
          modify $ M.insert (paramName fparam) (bufname, b)
          return $ BufferAlloc bufname size space b
      Array {}
        | MemArray _ _ _ (ArrayIn mem ixfun) <- paramDec fparam -> do
          buffered <- gets $ M.lookup mem
          case buffered of
            -- The array lives in a memory block that is being double
            -- buffered, so it needs a copy into the new buffer.
            Just (bufname, b) -> do
              copyname <- lift $ newVName "double_buffer_array"
              return $ BufferCopy bufname ixfun copyname b
            Nothing ->
              return NoBuffer
      _ -> return NoBuffer

-- | Pair each merge parameter (and its initial value) with its
-- 'DoubleBuffer' decision, producing the possibly modified merge
-- parameters together with the statements that must be inserted to
-- allocate the buffers and copy initial values.
allocStms ::
  [(FParam KernelsMem, SubExp)] ->
  [DoubleBuffer] ->
  DoubleBufferM ([(FParam KernelsMem, SubExp)], [Stm KernelsMem])
allocStms merge = runWriterT . zipWithM allocation merge
  where
    -- A memory parameter to be double-buffered: emit the statements
    -- computing the size and performing the allocation of the new
    -- buffer.  If the flag is set, the parameter is also rewritten to
    -- be of the buffer's space and to start out as the new buffer;
    -- otherwise the parameter and its initial value are left as-is.
    allocation m@(Param pname _, _) (BufferAlloc name size space b) = do
      stms <- lift $
        runBinder_ $ do
          size' <- toSubExp "double_buffer_size" size
          letBindNames [name] $ Op $ Alloc size' space
      tell $ stmsToList stms
      if b
        then return (Param pname $ MemMem space, Var name)
        else return m
    -- An array parameter whose flag is set and whose initial value is
    -- a variable: copy the initial value into the memory block carried
    -- by the 'BufferCopy', using the index function of the initial
    -- value, and use the copy as the new initial value.
    allocation (f, Var v) (BufferCopy mem _ _ b) | b = do
      v_copy <- lift $ newVName $ baseString v ++ "_double_buffer_copy"
      (_v_mem, v_ixfun) <- lift $ lookupArraySummary v
      let bt = elemType $ paramType f
          shape = arrayShape $ paramType f
          bound = MemArray bt shape NoUniqueness $ ArrayIn mem v_ixfun
      tell
        [ Let (Pattern [] [PatElem v_copy bound]) (defAux ()) $
            BasicOp $ Copy v
        ]
      -- It is important that we treat this as a consumption, to
      -- avoid the Copy from being hoisted out of any enclosing
      -- loops.  Since we re-use (=overwrite) memory in the loop,
      -- the copy is critical for initialisation.  See issue #816.
      let uniqueMemInfo (MemArray pt pshape _ ret) =
            MemArray pt pshape Unique ret
          uniqueMemInfo info = info
      return (uniqueMemInfo <$> f, Var v_copy)
    -- Everything else is passed through unchanged.
    allocation (f, se) _ =
      return (f, se)

-- | Rewrite the result of a loop body according to the 'DoubleBuffer'
-- decisions: the trailing @length valparams@ results are matched up
-- with the parameters, the copy statements needed for buffered arrays
-- are appended after the existing body statements, and the affected
-- results are replaced.  Any leading results are left untouched.
doubleBufferResult ::
  [FParam KernelsMem] ->
  [DoubleBuffer] ->
  Body KernelsMem ->
  Body KernelsMem
doubleBufferResult valparams buffered (Body () bnds res) =
  let (ctx_res, val_res) = splitAt (length res - length valparams) res
      (copybnds, val_res') =
        unzip $ zipWith3 buffer valparams buffered val_res
   in Body () (bnds <> stmsFromList (catMaybes copybnds)) $ ctx_res ++ val_res'
  where
    -- A buffered memory block: the result is simply the buffer itself,
    -- and no extra statement is needed.
    buffer _ (BufferAlloc bufname _ _ _) _ =
      (Nothing, Var bufname)
    -- A buffered array whose result is a variable: copy the result
    -- into the buffer and return the copy instead.
    buffer fparam (BufferCopy bufname ixfun copyname _) (Var v) =
      -- To construct the copy we will need to figure out its type
      -- based on the type of the function parameter.
      let t = resultType $ paramType fparam
          summary =
            MemArray (elemType t) (arrayShape t) NoUniqueness $
              ArrayIn bufname ixfun
          copybnd =
            Let (Pattern [] [PatElem copyname summary]) (defAux ()) $
              BasicOp $ Copy v
       in (Just copybnd, Var copyname)
    -- Anything else is returned unchanged.
    buffer _ _ se =
      (Nothing, se)

    -- Maps parameter names to the results they are zipped with.  Note
    -- that the parameters are zipped with the full result list @res@.
    parammap = M.fromList $ zip (map paramName valparams) res

    -- The parameter type, with every dimension that names a parameter
    -- replaced by the corresponding entry of 'parammap'.
    resultType t = t `setArrayDims` map substitute (arrayDims t)

    substitute (Var v)
      | Just replacement <- M.lookup v parammap = replacement
    substitute se =
      se