{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}

-- | The simplification engine is only willing to hoist allocations
-- out of loops if the memory block resulting from the allocation is
-- dead at the end of the loop.  If it is not, we may cause data
-- hazards.
--
-- This module rewrites loops with memory block merge parameters such
-- that each memory block is copied at the end of the iteration, thus
-- ensuring that any allocation inside the loop is dead at the end of
-- the loop.  This is only possible for allocations whose size is
-- loop-invariant, although the initial size may differ from the size
-- produced by the loop result.
--
-- Additionally, inside parallel kernels we also copy the initial
-- value.  This has the effect of making the memory block of the array
-- returned by the loop non-existential, which is important for later
-- memory expansion to work.
module Futhark.Optimise.DoubleBuffer (doubleBufferKernels, doubleBufferMC) where

import Control.Monad.Reader
import Control.Monad.State
import Control.Monad.Writer
import Data.List (find)
import qualified Data.Map.Strict as M
import Data.Maybe
import Futhark.Construct
import Futhark.IR.KernelsMem as Kernels
import Futhark.IR.MCMem as MC
import qualified Futhark.IR.Mem.IxFun as IxFun
import Futhark.Pass
import Futhark.Pass.ExplicitAllocations (arraySizeInBytesExp)
import Futhark.Pass.ExplicitAllocations.Kernels ()
import Futhark.Util (maybeHead)

-- | The double buffering pass for the GPU memory representation.
-- Sequential loops are only rewritten when they occur inside a
-- 'SegOp'; see 'optimiseKernelsOp'.
doubleBufferKernels :: Pass KernelsMem KernelsMem
doubleBufferKernels = doubleBuffer optimiseKernelsOp

-- | The double buffering pass for the multicore memory representation.
-- Sequential loops are only rewritten when they occur inside a
-- 'ParOp'; see 'optimiseMCOp'.
doubleBufferMC :: Pass MCMem MCMem
doubleBufferMC = doubleBuffer optimiseMCOp

-- | The double buffering pass definition, parameterised over how the
-- operations ('Op') of the representation are traversed.
--
-- At the top level, loops are left untouched ('doNotTouchLoop').  The
-- @onOp@ callback may locally enable loop rewriting, which is what
-- 'optimiseKernelsOp' and 'optimiseMCOp' do inside parallel operations.
doubleBuffer :: forall lore. Mem lore => OptimiseOp lore -> Pass lore lore
doubleBuffer onOp =
  Pass
    { passName = "Double buffer",
      passDescription = "Perform double buffering for merge parameters of sequential loops.",
      passFunction = intraproceduralTransformation optimise
    }
  where
    -- Run the statement traversal with the given scope, threading the
    -- pass's name source through the 'State' layer.
    optimise :: Scope lore -> Stms lore -> PassM (Stms lore)
    optimise scope stms =
      modifyNameSource $ runState $ runReaderT (runDoubleBufferM m) env
      where
        m = localScope scope $ stmsFromList <$> optimiseStms (stmsToList stms)

    env :: Env lore
    env = Env mempty doNotTouchLoop onOp

    -- Keep the loop as it is and insert no statements before it.
    doNotTouchLoop :: OptimiseLoop lore
    doNotTouchLoop ctx val body = return (mempty, ctx, val, body)

-- | A transformation of a sequential loop.  It receives the context
-- and value merge parameters, each paired with a 'SubExp' as in
-- 'DoLoop', along with the loop body.  It returns the statements to
-- place before the loop, the new context and value merge parameters,
-- and the new body.
type OptimiseLoop lore =
  [(FParam lore, SubExp)] ->
  [(FParam lore, SubExp)] ->
  Body lore ->
  DoubleBufferM lore ([Stm lore], [(FParam lore, SubExp)], [(FParam lore, SubExp)], Body lore)

-- | A transformation applied to each 'Op' met while traversing an expression.
type OptimiseOp lore = Op lore -> DoubleBufferM lore (Op lore)

-- | The read-only environment of 'DoubleBufferM'.
data Env lore = Env
  { -- | The variables currently in scope.
    envScope :: Scope lore,
    -- | How to transform a sequential loop.  This is a no-op at the top
    -- level and is replaced locally inside parallel operations.
    envOptimiseLoop :: OptimiseLoop lore,
    -- | How to transform an 'Op'.
    envOptimiseOp :: OptimiseOp lore
  }

-- | The monad the pass runs in.  It provides a read-only 'Env' on top
-- of a 'VNameSource' for generating fresh names.  All instances are
-- obtained by newtype deriving from the underlying transformer stack.
newtype DoubleBufferM lore a = DoubleBufferM
  { runDoubleBufferM :: ReaderT (Env lore) (State VNameSource) a
  }
  deriving
    ( Functor,
      Applicative,
      Monad,
      MonadReader (Env lore),
      MonadFreshNames
    )

-- | The current scope is the 'envScope' field of the environment.
instance ASTLore lore => HasScope lore (DoubleBufferM lore) where
  askScope = asks envScope

-- | Extend 'envScope' for the duration of an action.
--
-- 'Scope' is a 'Map', whose '<>' is a left-biased union.  The new
-- scope is therefore placed on the left, so an inner binding shadows
-- an outer binding of the same name.
instance ASTLore lore => LocalScope lore (DoubleBufferM lore) where
  localScope scope = local $ \env -> env {envScope = scope <> envScope env}

-- | Apply the pass to every statement of a body.  The body's result
-- is left unchanged.
optimiseBody :: ASTLore lore => Body lore -> DoubleBufferM lore (Body lore)
optimiseBody body = do
  stms' <- optimiseStms $ stmsToList $ bodyStms body
  return $ body {bodyStms = stmsFromList stms'}

-- | Apply the pass to a sequence of statements.
--
-- A single statement may expand into several, for example statements
-- inserted before a loop.  Everything these statements bind is brought
-- into scope before the remaining statements are processed.
optimiseStms :: ASTLore lore => [Stm lore] -> DoubleBufferM lore [Stm lore]
optimiseStms [] = return []
optimiseStms (stm : stms) = do
  stm_stms <- optimiseStm stm
  stms' <- localScope (castScope $ scopeOf stm_stms) $ optimiseStms stms
  return $ stm_stms ++ stms'

-- | Apply the pass to a single statement.
--
-- For a 'DoLoop', the body is optimised first, with the loop form and
-- merge parameters in scope.  The current 'envOptimiseLoop' is then
-- applied to the loop, and any statements it returns are placed before
-- the loop.  For every other expression, nested bodies are optimised
-- and each 'Op' is passed to 'envOptimiseOp'.
optimiseStm :: forall lore. ASTLore lore => Stm lore -> DoubleBufferM lore [Stm lore]
optimiseStm (Let pat aux (DoLoop ctx val form body)) = do
  body' <-
    localScope (scopeOf form <> scopeOfFParams (map fst $ ctx ++ val)) $
      optimiseBody body
  opt_loop <- asks envOptimiseLoop
  (stms, ctx', val', body'') <- opt_loop ctx val body'
  return $ stms ++ [Let pat aux $ DoLoop ctx' val' form body'']
optimiseStm (Let pat aux e) = do
  onOp <- asks envOptimiseOp
  pure . Let pat aux <$> mapExpM (optimise onOp) e
  where
    -- 'mapExpM' supplies the scope that the construct introduces for a
    -- nested body.  That scope is brought into view before optimising
    -- the body, just as the 'DoLoop' case above does.
    optimise onOp =
      identityMapper
        { mapOnBody = \scope body ->
            localScope scope (optimiseBody body :: DoubleBufferM lore (Body lore)),
          mapOnOp = onOp
        }

-- | Traverse a 'KernelsMem' operation.
--
-- Sequential loops are double-buffered only inside a 'SegOp'.  While
-- the 'SegOp's lambdas and kernel bodies are optimised,
-- 'envOptimiseLoop' is locally set to 'optimiseLoop'.  All other
-- operations are returned unchanged.
optimiseKernelsOp :: OptimiseOp KernelsMem
optimiseKernelsOp (Inner (SegOp op)) =
  local inSegOp $ Inner . SegOp <$> mapSegOpM mapper op
  where
    mapper =
      identitySegOpMapper
        { mapOnSegOpLambda = optimiseLambda,
          mapOnSegOpBody = optimiseKernelBody
        }

    inSegOp :: Env KernelsMem -> Env KernelsMem
    inSegOp env = env {envOptimiseLoop = optimiseLoop}
optimiseKernelsOp op = return op

-- | Traverse an 'MCMem' operation.
--
-- Sequential loops are double-buffered only inside a 'ParOp'.  While
-- both of its 'SegOp's are optimised (the optional one and the
-- mandatory one), 'envOptimiseLoop' is locally set to 'optimiseLoop'.
-- All other operations are returned unchanged.
optimiseMCOp :: OptimiseOp MCMem
optimiseMCOp (Inner (ParOp par_op op)) =
  local inSegOp $
    Inner
      <$> (ParOp <$> traverse (mapSegOpM mapper) par_op <*> mapSegOpM mapper op)
  where
    mapper =
      identitySegOpMapper
        { mapOnSegOpLambda = optimiseLambda,
          mapOnSegOpBody = optimiseKernelBody
        }

    inSegOp :: Env MCMem -> Env MCMem
    inSegOp env = env {envOptimiseLoop = optimiseLoop}
optimiseMCOp op = return op

-- | Apply the pass to the statements of a kernel body.  The body's
-- results are left unchanged.
optimiseKernelBody ::
  ASTLore lore =>
  KernelBody lore ->
  DoubleBufferM lore (KernelBody lore)
optimiseKernelBody kbody = do
  stms' <- optimiseStms $ stmsToList $ kernelBodyStms kbody
  return $ kbody {kernelBodyStms = stmsFromList stms'}

-- | Apply the pass to the body of a lambda, with the scope of the
-- lambda (as given by 'scopeOf') in view.
optimiseLambda ::
  ASTLore lore =>
  Lambda lore ->
  DoubleBufferM lore (Lambda lore)
optimiseLambda lam = do
  body <- localScope (castScope $ scopeOf lam) $ optimiseBody $ lambdaBody lam
  return lam {lambdaBody = body}

-- | Requirements on a lore for the loop rewriting in this module: the
-- AST classes and 'OpReturns', memory annotations ('FParamMem',
-- 'LParamMem', 'LetDecMem', 'RetTypeMem', 'BranchTypeMem') on
-- parameters, pattern elements, return types and branch types, and
-- unit annotations on expressions and bodies.
type Constraints lore =
  ( ASTLore lore,
    OpReturns lore,
    FParamInfo lore ~ FParamMem,
    LParamInfo lore ~ LParamMem,
    LetDec lore ~ LetDecMem,
    RetType lore ~ RetTypeMem,
    BranchType lore ~ BranchTypeMem,
    ExpDec lore ~ (),
    BodyDec lore ~ ()
  )

-- | Double-buffer the memory merge parameters of a loop.
--
-- Takes the context merge parameters and the value merge parameters,
-- each paired with its initial value, and the loop body.  Returns:
--
-- * the statements that allocate the buffers and copy initial values;
--
-- * the new context merge parameters with their initial values;
--
-- * the new value merge parameters with their initial values;
--
-- * the new loop body, which copies buffered results into the buffers.
optimiseLoop :: (Constraints lore, Op lore ~ MemOp inner, BinderOps lore) => OptimiseLoop lore
optimiseLoop ctx val body = do
  -- We start out by figuring out which of the merge variables should
  -- be double-buffered.  Every merge parameter (context and value) is
  -- classified, and each context parameter is paired with the body
  -- result it receives.
  buffered <-
    doubleBufferMergeParams
      (zip (map fst ctx) (bodyResult body))
      (map fst merge)
      (boundInBody body)
  -- Then create the allocations of the buffers and copies of the
  -- initial values.
  (merge', allocs) <- allocStms merge buffered
  -- Modify the loop body to copy buffered result arrays.
  let body' = doubleBufferResult (map fst merge) buffered body
      -- Split the modified merge parameters (with their possibly
      -- replaced initial values) back into context and value parts.
      (ctx', val') = splitAt (length ctx) merge'
  pure (allocs, ctx', val', body')
  where
    merge = ctx ++ val

-- | How a single merge parameter of a loop is treated by double
-- buffering.  In the constructors that carry a 'Bool', the flag says
-- whether the initial merge value is also replaced (by the new buffer,
-- or by a copy placed in it), and not only the per-iteration result.
data DoubleBuffer
  = -- | A memory block parameter that gets a fresh buffer: the name of
    -- the buffer, its size in bytes, and the memory space to allocate in.
    BufferAlloc VName (PrimExp VName) Space Bool
  | -- | An array parameter whose memory block is buffered.  First name
    -- is the memory block to copy to, second is the name of the array
    -- copy; the index function is the one of the parameter's array.
    BufferCopy VName IxFun VName Bool
  | -- | The parameter is left untouched.
    NoBuffer
  deriving (Show)

-- | Decide how each merge parameter of a loop is double-buffered.
--
-- Arguments:
--
-- * the context merge parameters, each paired with the loop body
--   result it receives;
--
-- * the parameters to classify; one 'DoubleBuffer' is produced per
--   parameter, in the same order;
--
-- * the names bound inside the loop body.
--
-- A memory block parameter is buffered ('BufferAlloc') when some
-- classified array parameter is stored in it with a direct index
-- function and with dimensions that can be expressed without
-- loop-variant names.  An array parameter is buffered by copy
-- ('BufferCopy') when its memory block was buffered by a parameter
-- occurring earlier in the list.  Everything else is 'NoBuffer'.
doubleBufferMergeParams ::
  MonadFreshNames m =>
  [(Param FParamMem, SubExp)] ->
  [Param FParamMem] ->
  Names ->
  m [DoubleBuffer]
doubleBufferMergeParams ctx_and_res val_params bound_in_loop =
  evalStateT (traverse classify val_params) M.empty
  where
    ctx_names = map (paramName . fst) ctx_and_res

    -- A name may change between iterations if it is bound in the loop
    -- body or is itself a context merge parameter.
    isLoopVariant name =
      name `nameIn` bound_in_loop || name `elem` ctx_names

    -- Try to express a dimension without loop-variant names: a context
    -- parameter is replaced by the result it receives, provided that
    -- result is a constant or a loop-invariant variable.  The flag is
    -- 'True' for constants and for variables that are not context
    -- parameters, and 'False' for dimensions resolved through a
    -- context parameter.
    invariantDim dim@(Constant _) = Just (dim, True)
    invariantDim dim@(Var name) =
      case lookup name [(paramName p, r) | (p, r) <- ctx_and_res] of
        Nothing -> Just (dim, True)
        Just r@(Constant _) -> Just (r, False)
        Just r@(Var r_name)
          | isLoopVariant r_name -> Nothing
          | otherwise -> Just (r, False)

    -- Size in bytes of memory block @mem@, taken from the first
    -- classified array parameter stored directly in it whose
    -- dimensions can all be expressed by 'invariantDim'.
    -- NOTE(review): the per-dimension flags are combined with 'or';
    -- confirm that 'and' is not what was intended.
    sizeForMem mem =
      listToMaybe
        [ (arraySizeInBytesExp $ Array pt (Shape dims) NoUniqueness, or flags)
          | MemArray pt shape _ (ArrayIn array_mem ixfun) <- map paramDec val_params,
            IxFun.isDirect ixfun,
            Just (dims, flags) <- [mapAndUnzipM invariantDim (shapeDims shape)],
            array_mem == mem
        ]

    -- The state maps each buffered memory parameter to its buffer name
    -- and flag, so that later array parameters in that memory can be
    -- buffered by copy.
    classify fparam =
      case (paramType fparam, paramDec fparam) of
        (Mem space, _)
          | Just (size, flag) <- sizeForMem (paramName fparam) -> do
            -- Let us double buffer this!
            buf_mem <- lift $ newVName "double_buffer_mem"
            modify $ M.insert (paramName fparam) (buf_mem, flag)
            pure $ BufferAlloc buf_mem size space flag
        (Array {}, MemArray _ _ _ (ArrayIn mem ixfun)) -> do
          known <- gets $ M.lookup mem
          case known of
            Nothing -> pure NoBuffer
            Just (buf_mem, flag) -> do
              array_copy <- lift $ newVName "double_buffer_array"
              pure $ BufferCopy buf_mem ixfun array_copy flag
        _ -> pure NoBuffer

-- | Build the statements that set up double buffering, and compute the
-- merge parameters and initial values to use with them.
--
-- Each merge parameter is handled according to its 'DoubleBuffer':
--
-- * 'BufferAlloc': an allocation of the buffer (of the given size, in
--   the given space) is always emitted.  If the flag is set, the memory
--   parameter gets the new buffer as its initial value.
--
-- * 'BufferCopy' with the flag set and a variable as initial value: the
--   initial array is copied into the buffer memory, the copy becomes
--   the initial value, and the parameter is marked 'Unique'.
--
-- * Otherwise the parameter and its initial value are unchanged.
--
-- The emitted statements are returned, in order, next to the updated
-- merge list.
allocStms ::
  (Constraints lore, Op lore ~ MemOp inner, BinderOps lore) =>
  [(FParam lore, SubExp)] ->
  [DoubleBuffer] ->
  DoubleBufferM lore ([(FParam lore, SubExp)], [Stm lore])
allocStms merge buffered =
  runWriterT $ zipWithM prepare merge buffered
  where
    prepare orig@(Param pname _, _) (BufferAlloc buf_mem size space replace_init) = do
      alloc_stms <- lift . runBinder_ $ do
        size_se <- toSubExp "double_buffer_size" size
        letBindNames [buf_mem] . Op $ Alloc size_se space
      tell $ stmsToList alloc_stms
      pure $
        if replace_init
          then (Param pname (MemMem space), Var buf_mem)
          else orig
    prepare (param, Var arr) (BufferCopy buf_mem _ _ True) = do
      arr_copy <- lift . newVName $ baseString arr <> "_double_buffer_copy"
      (_, arr_ixfun) <- lift $ lookupArraySummary arr
      let param_t = paramType param
          copy_dec =
            MemArray (elemType param_t) (arrayShape param_t) NoUniqueness $
              ArrayIn buf_mem arr_ixfun
      tell [Let (Pattern [] [PatElem arr_copy copy_dec]) (defAux ()) (BasicOp (Copy arr))]
      -- It is important that we treat this as a consumption, to
      -- avoid the Copy from being hoisted out of any enclosing
      -- loops.  Since we re-use (=overwrite) memory in the loop,
      -- the copy is critical for initialisation.  See issue #816.
      pure (markUnique <$> param, Var arr_copy)
    prepare unchanged _ =
      pure unchanged

    -- Mark an array annotation as unique; other annotations are kept.
    markUnique (MemArray pt shape _ ret) = MemArray pt shape Unique ret
    markUnique info = info

-- | Rewrite a loop body so that its results are placed in the double
-- buffers.
--
-- @valparams@ and @buffered@ describe the trailing results of the body.
-- 'optimiseLoop' passes every merge parameter, so in practice they
-- cover all results.  Any leading results are kept unchanged.  For each
-- covered result:
--
-- * for a 'BufferAlloc' parameter, the buffer memory block is returned
--   instead;
--
-- * for a 'BufferCopy' parameter whose result is a variable, a 'Copy'
--   of it into the buffer memory is appended to the body and the copy
--   is returned;
--
-- * anything else is returned as is.
doubleBufferResult ::
  (Constraints lore) =>
  [FParam lore] ->
  [DoubleBuffer] ->
  Body lore ->
  Body lore
doubleBufferResult valparams buffered (Body _ stms res) =
  Body () (stms <> stmsFromList (catMaybes copy_stms)) $ ctx_res ++ val_res'
  where
    -- The parameters correspond to the trailing results of the body.
    (ctx_res, val_res) = splitAt (length res - length valparams) res
    (copy_stms, val_res') = unzip $ zipWith3 buffer valparams buffered val_res

    buffer _ (BufferAlloc bufname _ _ _) _ =
      (Nothing, Var bufname)
    buffer fparam (BufferCopy bufname ixfun copyname _) (Var v) =
      -- To construct the copy we will need to figure out its type
      -- based on the type of the function parameter.
      let t = resultType $ paramType fparam
          summary =
            MemArray (elemType t) (arrayShape t) NoUniqueness $
              ArrayIn bufname ixfun
          copy_stm =
            Let (Pattern [] [PatElem copyname summary]) (defAux ()) $
              BasicOp $ Copy v
       in (Just copy_stm, Var copyname)
    buffer _ _ se =
      (Nothing, se)

    -- Each parameter's name mapped to the result it is bound to in the
    -- next iteration.  Zipped with 'val_res', the results aligned with
    -- @valparams@.  Zipping with the full result list would pair names
    -- with the wrong results whenever @valparams@ covers only a suffix.
    parammap = M.fromList $ zip (map paramName valparams) val_res

    -- The parameter type with every dimension that names a parameter
    -- replaced by that parameter's next-iteration result.
    resultType t = t `setArrayDims` map substitute (arrayDims t)

    substitute (Var v)
      | Just replacement <- M.lookup v parammap = replacement
    substitute se =
      se