{-# LANGUAGE TypeFamilies #-}

-- | We require that entry points return arrays with zero offset in
-- row-major order.  "Futhark.Pass.ExplicitAllocations" is
-- conservative and inserts copies to ensure this is the case.  After
-- simplification, it may turn out that those copies are redundant.
-- This pass removes them.  It's a pretty simple pass, as it only has
-- to look at the top level of entry points.
module Futhark.Optimise.EntryPointMem
  ( entryPointMemGPU,
    entryPointMemMC,
    entryPointMemSeq,
  )
where

import Data.List (find)
import Data.Map.Strict qualified as M
import Futhark.IR.GPUMem (GPUMem)
import Futhark.IR.MCMem (MCMem)
import Futhark.IR.Mem
import Futhark.IR.SeqMem (SeqMem)
import Futhark.Pass
import Futhark.Pass.ExplicitAllocations.GPU ()
import Futhark.Transform.Substitute

-- | Lookup table from every name bound by a statement's pattern to the
-- statement that binds it.
type Table r = M.Map VName (Stm r)

-- | Index a sequence of statements by the names they bind.  Each name
-- in a statement's pattern maps to that whole statement, so a statement
-- binding several names appears once per bound name.  The per-statement
-- maps are combined with the (left-biased) 'Map' monoid.
mkTable :: Stms rep -> Table rep
mkTable = foldMap bindings
  where
    -- All (name, statement) pairs contributed by a single statement.
    bindings stm = M.fromList [(v, stm) | v <- patNames (stmPat stm)]

-- | Look up a let-bound variable in the table.  On success, return the
-- memory annotation of the pattern element that binds the variable,
-- together with the expression of the binding statement.  Returns
-- 'Nothing' if no statement in the table binds the variable.
varInfo :: Mem rep inner => VName -> Table rep -> Maybe (LetDecMem, Exp rep)
varInfo v table = do
  Let pat _ e <- M.lookup v table
  -- The statement may bind several names; pick the element for @v@.
  PatElem _ info <- find ((== v) . patElemName) (patElems pat)
  Just (letDecMem info, e)

-- | Remove redundant copies from the result of a function body.
--
-- A result variable @v0@ is considered redundant when it is bound by
-- @Manifest _ v1@ and both @v0@ and @v1@ are arrays with the same
-- index function.  In that case the result is rewritten to refer to
-- @v1@, and @v0@'s memory block to @v1@'s memory block.  Only the
-- body's result is rewritten.  The copy statements themselves stay in
-- place and are presumably removed later as dead code.
--
-- The first argument indexes the program's top-level constants.  It is
-- merged with the bindings of the function body, so @v0@ and @v1@ may
-- be bound either in the body or among the constants.
--
-- NOTE(review): after rewriting, the result shares @v1@'s memory
-- instead of being a fresh copy.  Confirm this cannot break uniqueness
-- or aliasing assumptions about returned arrays.
optimiseFun :: Mem rep inner => Table rep -> FunDef rep -> FunDef rep
optimiseFun consts_table fd =
  fd {funDefBody = onBody $ funDefBody fd}
  where
    -- Every binding visible at the top level of the function body.
    table = consts_table <> mkTable (bodyStms (funDefBody fd))

    -- Substitution redirecting a redundantly copied result, and its
    -- memory block, to the source of the copy.  Empty otherwise.
    mkSubst (Var v0)
      | Just (MemArray _ _ _ (ArrayIn mem0 ixfun0), BasicOp (Manifest _ v1)) <-
          varInfo v0 table,
        Just (MemArray _ _ _ (ArrayIn mem1 ixfun1), _) <-
          varInfo v1 table,
        ixfun0 == ixfun1 =
          M.fromList [(mem0, mem1), (v0, v1)]
    mkSubst _ = mempty

    -- Rewrite only the result; statements are left untouched.
    onBody (Body dec stms res) =
      Body dec stms $ substituteNames (foldMap (mkSubst . resSubExp) res) res

-- | The representation-generic pass.  Top-level constants are passed
-- through unchanged (@pure@).  Each function definition is handed to
-- 'optimiseFun' together with a table of the constants' bindings.
--
-- NOTE(review): no check on the function being an entry point is made
-- here.  Confirm that applying the rewrite to every function is intended.
entryPointMem :: Mem rep inner => Pass rep rep
entryPointMem =
  Pass
    { passName = "Entry point memory optimisation",
      passDescription = "Remove redundant copies of entry point results.",
      passFunction = intraproceduralTransformationWithConsts pure onFun
    }
  where
    onFun consts = pure . optimiseFun (mkTable consts)

-- | The pass for GPU representation.
entryPointMemGPU :: Pass GPUMem GPUMem
entryPointMemGPU = entryPointMem

-- | The pass for MC representation.
entryPointMemMC :: Pass MCMem MCMem
entryPointMemMC = entryPointMem

-- | The pass for Seq representation.
entryPointMemSeq :: Pass SeqMem SeqMem
entryPointMemSeq = entryPointMem