{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE TypeFamilies #-}

module Futhark.IR.SeqMem
  ( SeqMem,

    -- * Simplification
    simplifyProg,
    simpleSeqMem,

    -- * Module re-exports
    module Futhark.IR.Mem,
  )
where

import Futhark.Analysis.PrimExp.Convert
import Futhark.IR.Mem
import Futhark.IR.Mem.Simplify
import qualified Futhark.Optimise.Simplify.Engine as Engine
import Futhark.Pass
import Futhark.Pass.ExplicitAllocations (BuilderOps (..), mkLetNamesB', mkLetNamesB'')
import qualified Futhark.TypeCheck as TC

-- | Type-level tag naming the representation defined in this module.  It
-- has no constructors; within this module it appears only as a type index
-- for the class instances that follow.
data SeqMem

-- | Decoration types for the 'SeqMem' representation: each slot is filled
-- with the correspondingly named @...Mem@ type.
instance RepTypes SeqMem where
  type LetDec SeqMem = LetDecMem
  type FParamInfo SeqMem = FParamMem
  type LParamInfo SeqMem = LParamMem
  type RetType SeqMem = RetTypeMem
  type BranchType SeqMem = BranchTypeMem
  -- The operation type is 'MemOp' with @()@ as its parameter.
  type Op SeqMem = MemOp ()

-- | 'ASTRep' instance for 'SeqMem'.
instance ASTRep SeqMem where
  -- The branch types of a pattern are the second components of the pairs
  -- produced by 'bodyReturnsFromPat', lifted into the surrounding monad.
  expTypesFromPat = return . map snd . bodyReturnsFromPat

-- | 'PrettyRep' instance for 'SeqMem'.  No methods are defined here, so
-- whatever defaults the class provides are used.
instance PrettyRep SeqMem

-- | Type checking of the 'SeqMem' operation type.
instance TC.CheckableOp SeqMem where
  -- An allocation is well-typed when its size operand has type
  -- @Prim int64@; the second field of 'Alloc' is not inspected.
  checkOp (Alloc size _) =
    TC.require [Prim int64] size
  -- The inner operation is @()@, so there is nothing to check.
  checkOp (Inner ()) =
    pure ()

-- | Type checking of the 'SeqMem' representation.  Every method forwards to
-- an existing helper; nothing is checked by code local to this instance.
instance TC.Checkable SeqMem where
  -- Function parameters, lambda parameters and let-bound names are all
  -- checked by the same helper.
  checkFParamDec = checkMemInfo
  checkLParamDec = checkMemInfo
  checkLetBoundDec = checkMemInfo

  -- Each return type is converted with 'declExtTypeOf' and then checked
  -- with 'TC.checkExtType'; the individual results are discarded.
  checkRetType = mapM_ (TC.checkExtType . declExtTypeOf)

  -- A parameter of primitive type @t@ carries a 'MemPrim' decoration.
  primFParam name t = return $ Param name (MemPrim t)

  matchPat = matchPatToExp
  matchReturnType = matchFunctionReturnType
  matchBranchType = matchBranchReturnType
  matchLoopResult = matchLoopResultMem

-- | Builder operations for 'SeqMem'.
instance BuilderOps SeqMem where
  -- The expression decoration is @()@, so both arguments are ignored.
  mkExpDecB _ _ = return ()

  -- Bodies are built with @()@ as the body decoration.
  mkBodyB stms res = return $ Body () stms res

  -- Delegates to @mkLetNamesB'@, passing @()@ as the expression decoration.
  mkLetNamesB = mkLetNamesB' ()

-- | Builder operations for the 'Engine.Wise' variant of 'SeqMem'.
instance BuilderOps (Engine.Wise SeqMem) where
  -- 'Engine.mkWiseExpDec' builds the decoration from the pattern, the
  -- underlying @()@ expression decoration, and the expression.
  mkExpDecB pat e = return $ Engine.mkWiseExpDec pat () e

  -- 'Engine.mkWiseBody' builds the body from the underlying @()@ body
  -- decoration, the statements, and the result.
  mkBodyB stms res = return $ Engine.mkWiseBody () stms res

  mkLetNamesB = mkLetNamesB''

-- | Simplify a 'SeqMem' program by running 'simplifyProgGeneric' with the
-- simplification operations from 'simpleSeqMem'.
simplifyProg :: Prog SeqMem -> PassM (Prog SeqMem)
simplifyProg = simplifyProgGeneric simpleSeqMem

-- | Simplification operations for 'SeqMem', built with 'simpleGeneric'.
--
-- The inner operation is @()@, so both callbacks are constant: the first
-- ignores the operation and returns an empty ('mempty') result, and the
-- second ignores it and returns @()@ together with no statements.
simpleSeqMem :: Engine.SimpleOps SeqMem
simpleSeqMem =
  simpleGeneric (const mempty) $ const $ return ((), mempty)