{-# LANGUAGE TypeFamilies #-}

module Futhark.IR.MCMem
  ( MCMem,

    -- * Simplification
    simplifyProg,

    -- * Module re-exports
    module Futhark.IR.Mem,
    module Futhark.IR.SegOp,
    module Futhark.IR.MC.Op,
  )
where

import Futhark.Analysis.PrimExp.Convert
import Futhark.IR.Aliases (Aliases)
import Futhark.IR.MC.Op
import Futhark.IR.Mem
import Futhark.IR.Mem.Simplify
import Futhark.IR.SegOp
import Futhark.IR.TypeCheck qualified as TC
import Futhark.Optimise.Simplify.Engine qualified as Engine
import Futhark.Pass
import Futhark.Pass.ExplicitAllocations (BuilderOps (..), mkLetNamesB', mkLetNamesB'')

-- | The multicore representation carrying memory annotations.
-- Operations are multicore 'MCOp's (with no extra inner operation,
-- 'NoOp') wrapped in 'MemOp'. Every kind of binder uses the matching
-- memory-annotated decoration from "Futhark.IR.Mem".
data MCMem

instance RepTypes MCMem where
  -- Operations: memory ops around multicore ops.
  type OpC MCMem = MemOp (MCOp NoOp)
  -- Function return and branch result types.
  type RetType MCMem = RetTypeMem
  type BranchType MCMem = BranchTypeMem
  -- Parameter decorations (function and lambda parameters).
  type FParamInfo MCMem = FParamMem
  type LParamInfo MCMem = LParamMem
  -- Let-bound variable decorations.
  type LetDec MCMem = LetDecMem

-- | The branch types of an expression are computed from the pattern it
-- is bound to: 'bodyReturnsFromPat' pairs each pattern name with a
-- 'BranchTypeMem', and only the types are kept.
instance ASTRep MCMem where
  expTypesFromPat = pure . map snd . bodyReturnsFromPat

-- | Return information for multicore operations in 'MCMem'.
--
-- A 'ParOp' reports the returns of its main 'SegOp'. The optional
-- 'SegOp' stored in its first field is not consulted. 'OtherOp' only
-- wraps 'NoOp' here, so it returns nothing.
instance OpReturns (MCOp NoOp MCMem) where
  opReturns (ParOp _ op) = segOpReturns op
  opReturns (OtherOp NoOp) = pure []

-- | Return information for multicore operations in the aliased form of
-- 'MCMem'. This is the same rule as for plain 'MCMem': the main 'SegOp'
-- of a 'ParOp' determines the result, and 'OtherOp' 'NoOp' returns
-- nothing.
instance OpReturns (MCOp NoOp (Aliases MCMem)) where
  opReturns (ParOp _ op) = segOpReturns op
  opReturns (OtherOp NoOp) = pure []

-- | Return information for multicore operations in the simplifier's
-- 'Engine.Wise' form of 'MCMem'.
--
-- A 'ParOp' reports the returns of its main 'SegOp'. Any other
-- operation falls back to its 'opType', converted with 'extReturns'.
instance OpReturns (MCOp NoOp (Engine.Wise MCMem)) where
  opReturns (ParOp _ op) = segOpReturns op
  opReturns k = extReturns <$> opType k

-- | Pretty-printing for 'MCMem' uses only the class's default methods;
-- no method is overridden here.
instance PrettyRep MCMem

-- | Type checking for 'MCMem'.
--
-- * Operations: an allocation's size must be an @int64@. Inner
--   multicore operations are checked with 'typeCheckMCOp', where the
--   'NoOp' payload has nothing to check.
-- * Parameter and let-bound decorations are checked with 'checkMemInfo'.
-- * Pattern, return, branch and loop-result matching are delegated to
--   the generic memory-representation helpers.
instance TC.Checkable MCMem where
  checkOp = typeCheckMemoryOp
    where
      typeCheckMemoryOp (Alloc size _) =
        TC.require [Prim int64] size
      typeCheckMemoryOp (Inner op) =
        -- The inner operation is always 'NoOp', which is trivially valid.
        typeCheckMCOp (const $ pure ()) op

  checkFParamDec = checkMemInfo
  checkLParamDec = checkMemInfo
  checkLetBoundDec = checkMemInfo

  -- Each return type is checked as an existential type after dropping
  -- its declaration-level memory information.
  checkRetType = mapM_ (TC.checkExtType . declExtTypeOf)

  -- A primitive function parameter has no attributes and a 'MemPrim'
  -- annotation.
  primFParam name t = pure $ Param mempty name (MemPrim t)

  matchPat = matchPatToExp
  matchReturnType = matchFunctionReturnType
  matchBranchType = matchBranchReturnType
  matchLoopResult = matchLoopResultMem

-- | Builder operations for 'MCMem'.
--
-- Expression and body decorations are unit. Let-bindings are created by
-- 'mkLetNamesB'', which is given 'DefaultSpace'. Presumably that is the
-- memory space used for anything it needs to allocate; verify against
-- "Futhark.Pass.ExplicitAllocations".
instance BuilderOps MCMem where
  mkExpDecB _ _ = pure ()
  mkBodyB stms res = pure $ Body () stms res
  mkLetNamesB = mkLetNamesB' DefaultSpace ()

-- | Builder operations for the simplifier's 'Engine.Wise' form of
-- 'MCMem'.
--
-- The underlying 'MCMem' decorations are unit. They are wrapped with
-- 'Engine.mkWiseExpDec' and 'Engine.mkWiseBody'. Let-bindings go
-- through 'mkLetNamesB'' ' with 'DefaultSpace'.
instance BuilderOps (Engine.Wise MCMem) where
  mkExpDecB pat e = pure $ Engine.mkWiseExpDec pat () e
  mkBodyB stms res = pure $ Engine.mkWiseBody () stms res
  mkLetNamesB = mkLetNamesB'' DefaultSpace

-- | Traverse the statements nested inside 'MCMem' operations (in
-- 'Engine.Wise' form). Statements are reached through the 'MemOp' and
-- 'MCOp' layers. The innermost 'NoOp' contains no statements, so its
-- traverser ignores the statement callback and returns the op
-- unchanged.
instance TraverseOpStms (Engine.Wise MCMem) where
  traverseOpStms = traverseMemOpStms (traverseMCOpStms (const pure))

-- | Simplify an 'MCMem' program with the generic memory-aware simplifier.
-- It uses the memory rule book ('memRuleBook') and the 'MCMem'-specific
-- simplifier operations.
simplifyProg :: Prog MCMem -> PassM (Prog MCMem)
simplifyProg = simplifyProgGeneric memRuleBook simpleMCMem

-- | Simplifier operations for 'MCMem'.
--
-- Operations contribute no extra usage information (@const mempty@).
-- Multicore operations are simplified with 'simplifyMCOp'. The inner
-- 'NoOp' simplifies to itself and produces no additional statements.
simpleMCMem :: Engine.SimpleOps MCMem
simpleMCMem =
  simpleGeneric (const mempty) $
    simplifyMCOp $
      const $
        pure (NoOp, mempty)