{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE ScopedTypeVariables #-}
module Futhark.Optimise.DoubleBuffer
( doubleBuffer )
where
import Control.Monad.State
import Control.Monad.Writer
import Control.Monad.Reader
import qualified Data.Map.Strict as M
import Data.Maybe
import Data.List (find)
import Futhark.Construct
import Futhark.Representation.AST
import Futhark.Pass.ExplicitAllocations (arraySizeInBytesExp)
import qualified Futhark.Representation.ExplicitMemory.IndexFunction as IxFun
import Futhark.Representation.ExplicitMemory
hiding (Prog, Body, Stm, Pattern, PatElem,
BasicOp, Exp, Lambda, FunDef, FParam, LParam, RetType)
import Futhark.Pass
import Futhark.Util (maybeHead)
-- | The double-buffering pass.
--
-- Each function's statements are traversed with 'optimiseStms'.  At the
-- top level, sequential loops are left alone ('doNotTouchLoop' below).
-- Only once the traversal enters a 'SegOp' does 'optimiseOp' switch the
-- environment's loop handler over to 'optimiseLoop'.
doubleBuffer :: Pass ExplicitMemory ExplicitMemory
doubleBuffer =
  Pass { passName = "Double buffer"
       , passDescription = "Perform double buffering for merge parameters of sequential loops."
       , passFunction = intraproceduralTransformation optimise
       }
  where
    -- Run the traversal over one function's statements, starting from the
    -- scope handed to us by 'intraproceduralTransformation' and threading
    -- the name source through the underlying 'State'.
    optimise :: MonadFreshNames m =>
                Scope ExplicitMemory -> Stms ExplicitMemory -> m (Stms ExplicitMemory)
    optimise scope stms = modifyNameSource $ \src ->
      let action = runDoubleBufferM $ localScope scope $
                   fmap stmsFromList $ optimiseStms $ stmsToList stms
      in runState (runReaderT action env) src

    -- Initial environment: an empty scope (the function's scope is added
    -- by 'localScope' in 'optimise') and a loop handler that does nothing.
    env :: Env
    env = Env mempty doNotTouchLoop

    -- Keep the loop exactly as it is and emit no preceding statements.
    doNotTouchLoop :: OptimiseLoop
    doNotTouchLoop ctx val body = return (mempty, ctx, val, body)
-- | Read-only environment of the double-buffering traversal.
data Env = Env { envScope :: Scope ExplicitMemory
                 -- ^ Bindings in scope at the current point of the traversal.
               , envOptimiseLoop :: OptimiseLoop
                 -- ^ Transformation applied to each sequential loop that
                 -- 'optimiseStm' encounters.
               }
-- | The traversal monad: a read-only 'Env' layered over a 'VNameSource'
-- state.  All instances are obtained via GeneralizedNewtypeDeriving.
newtype DoubleBufferM a =
  DoubleBufferM { runDoubleBufferM :: ReaderT Env (State VNameSource) a }
  deriving (Functor, Applicative, Monad, MonadReader Env, MonadFreshNames)
-- | The current scope is the 'envScope' field of the environment.
instance HasScope ExplicitMemory DoubleBufferM where
  askScope = asks envScope
-- | Run an action with extra bindings added to 'envScope'.
instance LocalScope ExplicitMemory DoubleBufferM where
  -- The existing scope is the LEFT operand of '<>'.
  -- NOTE(review): if 'Scope' is a left-biased map union, already-bound
  -- names win over the new @scope@ on a collision.  This is presumably
  -- harmless because names are expected to be unique — confirm before
  -- relying on shadowing here.
  localScope scope = local $ \env -> env { envScope = envScope env <> scope }
-- | Optimise the statements of a body.  The body's result is kept as-is.
optimiseBody :: Body ExplicitMemory -> DoubleBufferM (Body ExplicitMemory)
optimiseBody body = do
  stms' <- optimiseStms $ stmsToList $ bodyStms body
  return body { bodyStms = stmsFromList stms' }
-- | Optimise a list of statements in order.
--
-- Each statement may expand into several (see 'optimiseStm').  Everything
-- the expansion binds is brought into scope before the remaining
-- statements are processed.
optimiseStms :: [Stm ExplicitMemory] -> DoubleBufferM [Stm ExplicitMemory]
optimiseStms [] = return []
optimiseStms (stm:stms) = do
  stm_stms <- optimiseStm stm
  stms' <- localScope (castScope $ scopeOf stm_stms) $ optimiseStms stms
  return $ stm_stms ++ stms'
-- | Optimise a single statement.  The result may contain extra statements
-- that must precede the (possibly rewritten) original statement.
--
-- For a 'DoLoop':
--
-- 1. The loop body is optimised first, with the loop form and all merge
--    parameters in scope.
-- 2. The environment's 'envOptimiseLoop' handler is then applied.
-- 3. The handler's statements are emitted before the rebuilt loop.
--
-- Any other expression is traversed with 'mapExpM': nested bodies go
-- through 'optimiseBody' and operations through 'optimiseOp'.
optimiseStm :: Stm ExplicitMemory -> DoubleBufferM [Stm ExplicitMemory]
optimiseStm (Let pat aux (DoLoop ctx val form body)) = do
  body' <- localScope (scopeOf form <> scopeOfFParams (map fst $ ctx ++ val)) $
           optimiseBody body
  opt_loop <- asks envOptimiseLoop
  (stms, ctx', val', body'') <- opt_loop ctx val body'
  return $ stms ++ [Let pat aux $ DoLoop ctx' val' form body'']
optimiseStm (Let pat aux e) =
  pure . Let pat aux <$> mapExpM optimise e
  where
    -- The scope argument that 'mapExpM' passes to 'mapOnBody' is ignored.
    optimise = identityMapper { mapOnBody = \_ x ->
                                  optimiseBody x :: DoubleBufferM (Body ExplicitMemory)
                              , mapOnOp = optimiseOp
                              }
-- | Optimise an operation.
--
-- Only @Inner (SegOp op)@ is traversed.  Its lambdas go through
-- 'optimiseLambda' and its kernel bodies through 'optimiseKernelBody'.
-- Inside it, the loop handler is switched to 'optimiseLoop', so
-- sequential loops nested in a 'SegOp' are double-buffered.  Every other
-- operation is returned unchanged.
optimiseOp :: Op ExplicitMemory -> DoubleBufferM (Op ExplicitMemory)
optimiseOp (Inner (SegOp op)) =
  local inSegOp $ Inner . SegOp <$> mapSegOpM mapper op
  where
    mapper = identitySegOpMapper
             { mapOnSegOpLambda = optimiseLambda
             , mapOnSegOpBody = optimiseKernelBody
             }

    inSegOp :: Env -> Env
    inSegOp env = env { envOptimiseLoop = optimiseLoop }
optimiseOp op = return op
-- | Optimise the statements of a kernel body.  The other fields of the
-- kernel body are kept as-is.
optimiseKernelBody :: KernelBody ExplicitMemory
                   -> DoubleBufferM (KernelBody ExplicitMemory)
optimiseKernelBody kbody = do
  stms' <- optimiseStms $ stmsToList $ kernelBodyStms kbody
  return kbody { kernelBodyStms = stmsFromList stms' }
-- | Optimise the body of a lambda, with the lambda's own bindings
-- ('scopeOf' the lambda) in scope.
optimiseLambda :: Lambda ExplicitMemory -> DoubleBufferM (Lambda ExplicitMemory)
optimiseLambda lam = do
  body <- localScope (castScope $ scopeOf lam) $ optimiseBody $ lambdaBody lam
  return lam { lambdaBody = body }
-- | A transformation of a sequential loop.
--
-- Arguments:
--
-- * the context merge parameters, each paired with its initial value;
-- * the value merge parameters, each paired with its initial value;
-- * the loop body.
--
-- Result: the statements to emit before the loop, the new context merge
-- parameters, the new value merge parameters, and the new body.
type OptimiseLoop =
  [(FParam ExplicitMemory, SubExp)]
  -> [(FParam ExplicitMemory, SubExp)]
  -> Body ExplicitMemory
  -> DoubleBufferM ([Stm ExplicitMemory],
                    [(FParam ExplicitMemory, SubExp)],
                    [(FParam ExplicitMemory, SubExp)],
                    Body ExplicitMemory)
optimiseLoop :: OptimiseLoop
optimiseLoop :: OptimiseLoop
optimiseLoop ctx val body = do
  -- Classify every merge parameter (context parameters first, then value
  -- parameters).  The context parameters are paired with the values the
  -- loop body returns for them, so that sizes bound by context parameters
  -- can be checked against those results.
  buffered <-
    doubleBufferMergeParams
      (zip (map fst ctx) (bodyResult body))
      allMergeParams
      (boundInBody body)
  -- Emit the statements that must precede the loop (buffer allocations and
  -- initial copies) and obtain the possibly rewritten merge list.
  (merge', allocs) <- allocStms merge buffered
  -- Make the body hand back its results in the buffers, and split the
  -- rewritten merge list back into its context and value parts.
  let body' = doubleBufferResult allMergeParams buffered body
      (ctx', val') = splitAt (length ctx) merge'
  return (allocs, ctx', val', body')
  where
    merge = ctx ++ val
    allMergeParams = map fst merge
-- | The double-buffering decision for a single loop merge parameter.
data DoubleBuffer
  = BufferAlloc VName (PrimExp VName) Space Bool
    -- ^ A memory-block parameter to be buffered: the name of the new
    -- block, its size in bytes, its memory space, and a flag.  When the
    -- flag is set, the new block also replaces the parameter's initial
    -- value (see 'allocStms').
  | BufferCopy VName IxFun VName Bool
    -- ^ An array parameter stored in a buffered block: the buffer block,
    -- the index function for the copy made in the loop body, the name the
    -- copy is bound to, and a flag.  When the flag is set, the initial
    -- array is also copied into the buffer before the loop.
  | NoBuffer
    -- ^ The parameter is left alone.
  deriving (Show)
-- | Decide, for each merge parameter in @val_params@ (in order), how it is
-- to be double-buffered.
--
-- A memory parameter becomes a 'BufferAlloc' when some array parameter
-- stored in it has a direct index function and a size whose dimensions
-- can be resolved without loop-variant names.  An array parameter becomes
-- a 'BufferCopy' when its memory block was given a 'BufferAlloc' by an
-- earlier parameter in the list.  Everything else is 'NoBuffer'.
--
-- @ctx_and_res@ pairs context parameters with a 'SubExp' each.  The caller
-- in this module passes the loop body's results.  @bound_in_loop@ holds
-- the names bound inside the loop body.
doubleBufferMergeParams :: MonadFreshNames m =>
                           [(FParam ExplicitMemory, SubExp)]
                        -> [FParam ExplicitMemory] -> Names
                        -> m [DoubleBuffer]
doubleBufferMergeParams ctx_and_res val_params bound_in_loop =
  -- The state maps each buffered memory parameter to its buffer name and
  -- flag, so later array parameters can find it.
  evalStateT (mapM bufferParam val_params) M.empty
  where
    ctxNames = map (paramName . fst) ctx_and_res

    -- Names bound inside the loop, or belonging to context parameters,
    -- may change between iterations.
    isLoopVariant name =
      name `nameIn` bound_in_loop || name `elem` ctxNames

    -- The 'SubExp' paired with a context parameter of the given name, if
    -- any.
    ctxValue name = snd <$> find ((== name) . paramName . fst) ctx_and_res

    -- Resolve one array dimension to a loop-invariant size, or give up.
    -- The flag is True when the dimension is used as-is (a constant, or a
    -- name that is not a context parameter), and False when it was
    -- resolved through a context parameter's paired value.
    invariantDim (Constant c) = Just (Constant c, True)
    invariantDim (Var name) =
      case ctxValue name of
        Nothing -> Just (Var name, True)
        Just (Constant c) -> Just (Constant c, False)
        Just (Var name')
          | isLoopVariant name' -> Nothing
          | otherwise -> Just (Var name', False)

    -- Byte size (and combined flag) of the first suitable array stored in
    -- memory block @mem@.
    memSize mem =
      maybeHead $ mapMaybe (arraySizeIn mem . paramAttr) val_params

    arraySizeIn mem (MemArray pt shape _ (ArrayIn arraymem ixfun))
      | IxFun.isDirect ixfun
      , Just (dims, flags) <- mapAndUnzipM invariantDim (shapeDims shape)
      , mem == arraymem =
          Just ( arraySizeInBytesExp $ Array pt (Shape dims) NoUniqueness
               , or flags )
    arraySizeIn _ _ = Nothing

    bufferParam fparam =
      case paramType fparam of
        Mem space
          | Just (size, flag) <- memSize (paramName fparam) -> do
              bufname <- lift $ newVName "double_buffer_mem"
              modify $ M.insert (paramName fparam) (bufname, flag)
              return $ BufferAlloc bufname size space flag
        Array {}
          | MemArray _ _ _ (ArrayIn mem ixfun) <- paramAttr fparam -> do
              known <- gets $ M.lookup mem
              case known of
                Nothing -> return NoBuffer
                Just (bufname, flag) -> do
                  copyname <- lift $ newVName "double_buffer_array"
                  return $ BufferCopy bufname ixfun copyname flag
        _ -> return NoBuffer
-- | Produce the statements that must run before the loop to set up the
-- buffers, together with the possibly rewritten merge entries.
--
-- * 'BufferAlloc': allocate the buffer block.  If its flag is set, the
--   merge entry becomes a memory parameter in that space whose initial
--   value is the new block.
-- * 'BufferCopy' with its flag set, on an entry whose initial value is a
--   variable: copy the initial array into the buffer block, use the copy
--   as the initial value, and mark the parameter 'Unique'.
-- * Anything else: the merge entry is returned unchanged.
allocStms :: [(FParam ExplicitMemory, SubExp)] -> [DoubleBuffer]
          -> DoubleBufferM ([(FParam ExplicitMemory, SubExp)], [Stm ExplicitMemory])
allocStms merge buffered = runWriterT $ zipWithM allocation merge buffered
  where
    allocation :: (FParam ExplicitMemory, SubExp) -> DoubleBuffer
               -> WriterT [Stm ExplicitMemory] DoubleBufferM
                          (FParam ExplicitMemory, SubExp)
    allocation entry@(Param pname _, _) (BufferAlloc name size space redirect) = do
      setup <- lift $ runBinder_ $ do
        size' <- letSubExp "double_buffer_size" =<< toExp size
        letBindNames_ [name] $ Op $ Alloc size' space
      tell $ stmsToList setup
      return $
        if redirect
          then (Param pname $ MemMem space, Var name)
          else entry
    allocation (fparam, Var v) (BufferCopy mem _ _ True) = do
      v_copy <- lift $ newVName $ baseString v ++ "_double_buffer_copy"
      (_, v_ixfun) <- lift $ lookupArraySummary v
      let ptype = paramType fparam
          copyInfo = MemArray (elemType ptype) (arrayShape ptype) NoUniqueness $
                     ArrayIn mem v_ixfun
      tell [ Let (Pattern [] [PatElem v_copy copyInfo]) (defAux ()) $
             BasicOp $ Copy v ]
      -- NOTE(review): the parameter is marked Unique, presumably so that the
      -- copy is treated as consumed by the loop.  Confirm this before
      -- changing it.
      return (markUnique <$> fparam, Var v_copy)
    allocation entry _ = return entry

    -- Set the uniqueness of an array's memory info to 'Unique'.  Any other
    -- memory info is returned unchanged.
    markUnique (MemArray pt shape _ ret) = MemArray pt shape Unique ret
    markUnique info = info
-- | Rewrite a loop body so that buffered merge parameters receive their
-- buffers as results.
--
-- The last @length valparams@ results are matched, position by position,
-- with @valparams@ and the corresponding 'DoubleBuffer' decisions.  The
-- leading results are passed through unchanged.
--
-- * A 'BufferAlloc' result is replaced by the buffer block itself.
-- * A 'BufferCopy' result that is a variable is copied into the buffer
--   block (using the stored index function), and the copy is returned.
-- * All other results are left alone.
doubleBufferResult :: [FParam ExplicitMemory] -> [DoubleBuffer]
                   -> Body ExplicitMemory -> Body ExplicitMemory
doubleBufferResult valparams buffered (Body () bnds res) =
  Body () (bnds <> stmsFromList (catMaybes copybnds)) $ ctx_res ++ val_res'
  where
    (ctx_res, val_res) = splitAt (length res - length valparams) res
    (copybnds, val_res') = unzip $ zipWith3 buffer valparams buffered val_res

    buffer _ (BufferAlloc bufname _ _ _) _ =
      (Nothing, Var bufname)
    buffer fparam (BufferCopy bufname ixfun copyname _) (Var v) =
      -- The copy's type is derived from the parameter's type, with sizes
      -- that name a parameter replaced by that parameter's result.
      let t = resultType $ paramType fparam
          summary = MemArray (elemType t) (arrayShape t) NoUniqueness $
                    ArrayIn bufname ixfun
          copybnd = Let (Pattern [] [PatElem copyname summary]) (defAux ()) $
                    BasicOp $ Copy v
      in (Just copybnd, Var copyname)
    buffer _ _ se =
      (Nothing, se)

    -- Pair each parameter with its own result.  That is the trailing
    -- val_res, the same alignment zipWith3 uses above.  Zipping with the
    -- full result list would mis-pair names whenever ctx_res is non-empty.
    parammap = M.fromList $ zip (map paramName valparams) val_res

    resultType t = t `setArrayDims` map substitute (arrayDims t)

    substitute (Var v)
      | Just replacement <- M.lookup v parammap = replacement
    substitute se =
      se