{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE TypeFamilies #-}

module Futhark.IR.MCMem
  ( MCMem,

    -- * Simplification
    simplifyProg,

    -- * Module re-exports
    module Futhark.IR.Mem,
    module Futhark.IR.SegOp,
    module Futhark.IR.MC.Op,
  )
where

import Futhark.Analysis.PrimExp.Convert
import Futhark.IR.MC.Op
import Futhark.IR.Mem
import Futhark.IR.Mem.Simplify
import Futhark.IR.SegOp
import qualified Futhark.Optimise.Simplify.Engine as Engine
import Futhark.Pass
import Futhark.Pass.ExplicitAllocations (BinderOps (..), mkLetNamesB', mkLetNamesB'')
import qualified Futhark.TypeCheck as TC

-- | Type-level tag for this IR representation.  It has no values; the
-- decorations it selects are given by its 'RepTypes' instance, which
-- pairs the @...Mem@ decoration types with 'MCOp' operations wrapped in
-- 'MemOp'.
data MCMem

-- | Decoration choices for 'MCMem'.
instance RepTypes MCMem where
  -- Operations: an 'MCOp' over this representation, instantiated with
  -- @()@ as its extra-op type, wrapped in 'MemOp'.
  type Op MCMem = MemOp (MCOp MCMem ())

  -- Decorations attached at binding sites (let, function and lambda
  -- parameters).
  type LetDec MCMem = LetDecMem
  type FParamInfo MCMem = FParamMem
  type LParamInfo MCMem = LParamMem

  -- Decorations describing what functions and branches return.
  type RetType MCMem = RetTypeMem
  type BranchType MCMem = BranchTypeMem

-- | Expression types of a pattern are read back from the pattern itself
-- via 'bodyReturnsFromPattern'.  Only the second component of its result
-- is used (presumably the value elements rather than the context — confirm
-- against "Futhark.IR.Mem"), and the names paired with each type are
-- dropped.
instance ASTRep MCMem where
  expTypesFromPattern = pure . map snd . snd . bodyReturnsFromPattern

-- | What each 'MCMem' operation returns:
--
-- * an allocation returns a single 'MemMem' in the allocation's space
--   (the size is not needed);
-- * a parallel op defers to 'segOpReturns' on its 'SegOp'; the first
--   field of 'ParOp' (a @Maybe@ alternative 'SegOp') is ignored;
-- * the @()@-typed 'OtherOp' returns nothing.
instance OpReturns MCMem where
  opReturns (Alloc _ space) = pure [MemMem space]
  opReturns (Inner (ParOp _ op)) = segOpReturns op
  opReturns (Inner (OtherOp ())) = pure []

-- | Pretty-printing uses the class's default methods; nothing
-- representation-specific is defined here.
instance PrettyRep MCMem

-- | Type checking of 'MCMem' operations:
--
-- * an allocation requires its size to have type @Prim int64@;
-- * a multicore op is checked with 'typeCheckMCOp', using 'pure' as the
--   checker for the @()@-typed extra op, so that part always succeeds.
instance TC.CheckableOp MCMem where
  checkOp (Alloc size _) = TC.require [Prim int64] size
  checkOp (Inner op) = typeCheckMCOp pure op

-- | Type checking of 'MCMem' decorations.
--
-- Function-parameter, lambda-parameter and let-bound decorations are all
-- checked with 'checkMemInfo'.  Return types are checked one by one with
-- 'TC.checkExtType' on their 'declExtTypeOf'.  The pattern, return,
-- branch and loop-result matchers delegate to the implementations that
-- are generic over memory representations.
instance TC.Checkable MCMem where
  checkFParamDec = checkMemInfo
  checkLParamDec = checkMemInfo
  checkLetBoundDec = checkMemInfo
  checkRetType = mapM_ (TC.checkExtType . declExtTypeOf)
  -- A primitive-typed function parameter gets a plain 'MemPrim' decoration.
  primFParam name t = pure $ Param name (MemPrim t)
  matchPattern = matchPatternToExp
  matchReturnType = matchFunctionReturnType
  matchBranchType = matchBranchReturnType
  matchLoopResult = matchLoopResultMem

-- | Binder operations for 'MCMem'.  Expression and body decorations are
-- @()@, and let-bindings are built by @mkLetNamesB'@ with a @()@
-- expression decoration.
instance BinderOps MCMem where
  mkExpDecB _ _ = pure ()
  mkBodyB stms res = pure $ Body () stms res
  mkLetNamesB = mkLetNamesB' ()

-- | Binder operations for the simplifier's 'Engine.Wise' wrapper around
-- 'MCMem'.  Expression and body decorations are built from @()@ base
-- decorations with 'Engine.mkWiseExpDec' and 'Engine.mkWiseBody'.
-- Let-bindings use @mkLetNamesB''@.
instance BinderOps (Engine.Wise MCMem) where
  mkExpDecB pat e = pure $ Engine.mkWiseExpDec pat () e
  mkBodyB stms res = pure $ Engine.mkWiseBody () stms res
  mkLetNamesB = mkLetNamesB''

-- | Simplify an 'MCMem' program with 'simplifyProgGeneric', configured by
-- 'simpleMCMem'.
simplifyProg :: Prog MCMem -> PassM (Prog MCMem)
simplifyProg = simplifyProgGeneric simpleMCMem

-- | Simplifier configuration for 'MCMem'.
--
-- Multicore ops contribute an empty usage table (@const mempty@).  When
-- 'simplifyMCOp' reaches the @()@-typed extra op, it returns @()@ unchanged
-- and produces no additional statements.
simpleMCMem :: Engine.SimpleOps MCMem
simpleMCMem =
  simpleGeneric (const mempty) $
    simplifyMCOp $ const $ pure ((), mempty)