{-# LANGUAGE TypeFamilies #-}
module Futhark.IR.GPUMem
( GPUMem,
simplifyProg,
simplifyStms,
simpleGPUMem,
module Futhark.IR.Mem,
module Futhark.IR.GPU.Op,
)
where
import Futhark.Analysis.PrimExp.Convert
import Futhark.Analysis.UsageTable qualified as UT
import Futhark.IR.GPU.Op
import Futhark.IR.GPU.Simplify (simplifyKernelOp)
import Futhark.IR.Mem
import Futhark.IR.Mem.Simplify
import Futhark.IR.TypeCheck qualified as TC
import Futhark.MonadFreshNames
import Futhark.Optimise.Simplify.Engine qualified as Engine
import Futhark.Pass
import Futhark.Pass.ExplicitAllocations (BuilderOps (..), mkLetNamesB', mkLetNamesB'')
-- | Representation tag for the GPU IR with memory decorations.  It has
-- no values; it is only used as a type index, with its decoration
-- types given by its 'RepTypes' instance.
data GPUMem
-- | Decoration types for 'GPUMem'.  Binders, parameters, return types and
-- branch types all use the corresponding @...Mem@ decorations, and
-- operations are memory operations ('MemOp') wrapping host-level GPU
-- operations ('HostOp') with @()@ as the innermost operation type.
instance RepTypes GPUMem where
  type LetDec GPUMem = LetDecMem
  type FParamInfo GPUMem = FParamMem
  type LParamInfo GPUMem = LParamMem
  type RetType GPUMem = RetTypeMem
  type BranchType GPUMem = BranchTypeMem
  type Op GPUMem = MemOp (HostOp GPUMem ())
-- | AST operations for 'GPUMem'.
instance ASTRep GPUMem where
  -- The branch types of a pattern are the second components of the
  -- (name, branch type) pairs computed by 'bodyReturnsFromPat'.
  expTypesFromPat = pure . map snd . bodyReturnsFromPat
-- | Return information for host operations in 'GPUMem'.
instance OpReturns (HostOp GPUMem ()) where
  -- Segmented operations have dedicated return computation.
  opReturns (SegOp op) = segOpReturns op
  -- Any other host operation: convert its (existential) result types
  -- with 'extReturns'.
  opReturns k = extReturns <$> opType k
-- | Return information for host operations in the simplifier's
-- 'Engine.Wise' wrapping of 'GPUMem'; mirrors the plain 'GPUMem' case.
instance OpReturns (HostOp (Engine.Wise GPUMem) ()) where
  -- Segmented operations have dedicated return computation.
  opReturns (SegOp op) = segOpReturns op
  -- Any other host operation: convert its (existential) result types
  -- with 'extReturns'.
  opReturns k = extReturns <$> opType k
-- | Pretty-printing for 'GPUMem', using only the class's default methods.
instance PrettyRep GPUMem
-- | Type checking of 'GPUMem' operations.
instance TC.CheckableOp GPUMem where
  -- Checking starts with no enclosing segment level.
  checkOp = typeCheckMemoryOp Nothing
    where
      -- An allocation is well-typed when its size is an 'int64'.
      typeCheckMemoryOp _ (Alloc size _) =
        TC.require [Prim int64] size
      -- A wrapped host operation is checked by 'typeCheckHostOp'.  It is
      -- given this function back, so nested operations are checked the
      -- same way; 'typeCheckHostOp' supplies a 'SegLevel' for them, which
      -- is wrapped in 'Just'.  The innermost @()@ operation is always
      -- accepted.
      typeCheckMemoryOp lvl (Inner op) =
        typeCheckHostOp (typeCheckMemoryOp . Just) lvl (const $ pure ()) op
-- | Type checking of 'GPUMem' programs, delegating the memory-specific
-- checks to the shared memory-representation helpers.
instance TC.Checkable GPUMem where
  -- Parameter and let-binding decorations are all checked with
  -- 'checkMemInfo'.
  checkFParamDec = checkMemInfo
  checkLParamDec = checkMemInfo
  checkLetBoundDec = checkMemInfo
  -- Each return type is checked as an existential type via its
  -- declared extended type.
  checkRetType = mapM_ $ TC.checkExtType . declExtTypeOf
  -- A scalar function parameter: no attributes, primitive memory info.
  primFParam name t = pure $ Param mempty name (MemPrim t)
  -- Matching of patterns, function results, branch results and loop
  -- results uses the memory-aware matchers.
  matchPat = matchPatToExp
  matchReturnType = matchFunctionReturnType
  matchBranchType = matchBranchReturnType
  matchLoopResult = matchLoopResultMem
-- | Construction of 'GPUMem' expressions, bodies and bindings.
instance BuilderOps GPUMem where
  -- Expressions carry no decoration.
  mkExpDecB _ _ = pure ()
  -- Bodies carry a @()@ decoration around the given statements and result.
  mkBodyB stms res = pure $ Body () stms res
  -- Let-bindings are built by 'mkLetNamesB'' with a @()@ expression
  -- decoration.
  mkLetNamesB = mkLetNamesB' ()
-- | Construction of expressions, bodies and bindings in the simplifier's
-- 'Engine.Wise' wrapping of 'GPUMem'.  The underlying 'GPUMem'
-- decorations are @()@ and are wrapped by the engine's helpers.
instance BuilderOps (Engine.Wise GPUMem) where
  mkExpDecB pat e = pure $ Engine.mkWiseExpDec pat () e
  mkBodyB stms res = pure $ Engine.mkWiseBody () stms res
  mkLetNamesB = mkLetNamesB''
-- | Traversal of the statements nested inside 'Engine.Wise' 'GPUMem'
-- operations.  The memory and host layers are traversed by their
-- respective helpers; the innermost @()@ operation is returned unchanged.
instance TraverseOpStms (Engine.Wise GPUMem) where
  traverseOpStms = traverseMemOpStms (traverseHostOpStms (const pure))
-- | Simplify an entire 'GPUMem' program using 'simpleGPUMem'.
simplifyProg :: Prog GPUMem -> PassM (Prog GPUMem)
simplifyProg = simplifyProgGeneric simpleGPUMem
-- | Simplify a sequence of 'GPUMem' statements using 'simpleGPUMem', in a
-- monad that provides the surrounding scope and fresh names.
simplifyStms ::
  (HasScope GPUMem m, MonadFreshNames m) => Stms GPUMem -> m (Stms GPUMem)
simplifyStms = simplifyStmsGeneric simpleGPUMem
-- | The simplifier operations for 'GPUMem'.
--
-- Built with 'simpleGeneric' from a usage function (@usage@ below) and
-- 'simplifyKernelOp'.  The innermost @()@ operation is "simplified" by
-- returning it unchanged together with no extra statements.
simpleGPUMem :: Engine.SimpleOps GPUMem
simpleGPUMem =
  simpleGeneric usage $ simplifyKernelOp $ const $ pure ((), mempty)
  where
    -- Extra usages contributed by a host operation.  For a 'SegMap',
    -- these are the sizes of the variable-sized allocations found in its
    -- kernel body; any other operation contributes nothing.
    -- NOTE(review): presumably this makes the simplifier treat those
    -- sizes as used so they are kept (and possibly hoisted) outside the
    -- kernel; confirm.  Only 'SegMap' is inspected here, so allocations
    -- inside other 'SegOp' forms are not reported.
    usage (SegOp (SegMap _ _ _ kbody)) = localAllocs kbody
    usage _ = mempty

    -- Combine the allocation-size usages of every statement in a kernel
    -- body.
    localAllocs = foldMap stmLocalAlloc . kernelBodyStms

    stmLocalAlloc = expLocalAlloc . stmExp

    -- An 'Alloc' whose size is a variable records a size usage of that
    -- variable, and a nested 'SegMap' is searched recursively.  Constant
    -- sizes and all other expressions contribute nothing.
    expLocalAlloc (Op (Alloc (Var v) _)) =
      UT.sizeUsage v
    expLocalAlloc (Op (Inner (SegOp (SegMap _ _ _ kbody)))) =
      localAllocs kbody
    expLocalAlloc _ =
      mempty