module Futhark.Optimise.ArrayLayout
  ( optimiseArrayLayoutGPU,
    optimiseArrayLayoutMC,
  )
where

import Control.Monad.State.Strict
import Futhark.Analysis.AccessPattern (Analyse, analyseDimAccesses)
import Futhark.Analysis.PrimExp.Table (primExpTable)
import Futhark.Builder
import Futhark.IR.GPU (GPU)
import Futhark.IR.MC (MC)
import Futhark.Optimise.ArrayLayout.Layout (layoutTableFromIndexTable)
import Futhark.Optimise.ArrayLayout.Transform (Transform, transformStms)
import Futhark.Pass

-- | Build the array-layout optimisation pass for any representation that
-- supports access-pattern analysis ('Analyse') and layout rewriting
-- ('Transform').  The string argument is appended to the pass name, so
-- @optimiseArrayLayout "gpu"@ yields a pass named
-- @"optimise array layout gpu"@.
--
-- The pass works on the whole program in three steps:
--
--   1. 'analyseDimAccesses' builds an index table describing how arrays
--      are indexed;
--   2. 'primExpTable' computes primExps for all variables, and
--      'layoutTableFromIndexTable' combines both tables into a table of
--      layout permutations;
--   3. the permutations are inserted into the AST with 'transformStms',
--      applied through 'intraproceduralTransformation'.
optimiseArrayLayout :: (Analyse rep, Transform rep, BuilderOps rep) => String -> Pass rep rep
optimiseArrayLayout s =
  Pass
    ("optimise array layout " <> s)
    "Transform array layout for locality optimisations."
    $ \prog -> do
      -- Analyse the program
      let index_table = analyseDimAccesses prog
      -- Compute primExps for all variables
      let table = primExpTable prog
      -- Compute permutations to achieve coalescence for all arrays
      let permutation_table = layoutTableFromIndexTable table index_table
      -- Insert permutations in the AST
      intraproceduralTransformation (onStms permutation_table) prog
  where
    -- Rewrite one sequence of statements according to the layout table,
    -- starting from an empty 'ExpMap' and running the builder in the
    -- pass's name source so that fresh names stay unique.
    onStms layout_table scope stms = do
      let m = transformStms layout_table mempty stms
      -- 'runBuilderT' yields (result, statements added to the builder);
      -- only the result of 'transformStms' is kept.
      -- NOTE(review): this assumes 'transformStms' returns everything it
      -- builds rather than leaving statements in the builder — confirm.
      fmap fst $ modifyNameSource $ runState $ runBuilderT m scope

-- | The optimisation performed on the GPU representation.  This is
-- 'optimiseArrayLayout' instantiated at 'GPU'; the resulting pass is named
-- @"optimise array layout gpu"@.
optimiseArrayLayoutGPU :: Pass GPU GPU
optimiseArrayLayoutGPU = optimiseArrayLayout "gpu"

-- | The optimisation performed on the MC representation.  This is
-- 'optimiseArrayLayout' instantiated at 'MC'; the resulting pass is named
-- @"optimise array layout mc"@.
optimiseArrayLayoutMC :: Pass MC MC
optimiseArrayLayoutMC = optimiseArrayLayout "mc"