{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE TypeFamilies #-}
module Futhark.IR.GPUMem
( GPUMem,
simplifyProg,
simplifyStms,
simpleGPUMem,
module Futhark.IR.Mem,
module Futhark.IR.GPU.Kernel,
)
where
import Futhark.Analysis.PrimExp.Convert
import qualified Futhark.Analysis.UsageTable as UT
import Futhark.IR.GPU.Kernel
import Futhark.IR.GPU.Simplify (simplifyKernelOp)
import Futhark.IR.Mem
import Futhark.IR.Mem.Simplify
import Futhark.MonadFreshNames
import qualified Futhark.Optimise.Simplify.Engine as Engine
import Futhark.Pass
import Futhark.Pass.ExplicitAllocations (BinderOps (..), mkLetNamesB', mkLetNamesB'')
import qualified Futhark.TypeCheck as TC
-- | Representation tag for GPU programs whose AST carries memory
-- information.  It has no constructors and is used only at the type
-- level; its decoration types are fixed by its 'RepTypes' instance.
data GPUMem
-- | Decoration types for 'GPUMem'.  Let-bound names, function and lambda
-- parameters, return types and branch types all carry the
-- memory-annotated decorations ('LetDecMem', 'FParamMem', 'LParamMem',
-- 'RetTypeMem', 'BranchTypeMem').  Operations are GPU host operations
-- (with no extra payload, @()@) wrapped in 'MemOp'.
instance RepTypes GPUMem where
  type LetDec GPUMem = LetDecMem
  type FParamInfo GPUMem = FParamMem
  type LParamInfo GPUMem = LParamMem
  type RetType GPUMem = RetTypeMem
  type BranchType GPUMem = BranchTypeMem
  type Op GPUMem = MemOp (HostOp GPUMem ())
-- | Basic AST operations for 'GPUMem'.
instance ASTRep GPUMem where
  -- Branch types are read off the pattern.  'bodyReturnsFromPattern'
  -- yields a pair of @(name, branch type)@ lists; only the second list
  -- is used, and its names are discarded.
  -- NOTE(review): presumably the first list is the pattern context and
  -- the second the value part -- confirm against Futhark.IR.Mem.
  expTypesFromPattern = return . map snd . snd . bodyReturnsFromPattern
-- | What each 'GPUMem' operation returns, including memory information.
instance OpReturns GPUMem where
  -- An allocation returns exactly one memory block in the given space.
  opReturns (Alloc _ space) = return [MemMem space]
  -- Segmented operations are delegated to 'segOpReturns'.
  opReturns (Inner (SegOp op)) = segOpReturns op
  -- Any other operation: derive the returns from its (extended) type.
  opReturns k = extReturns <$> opType k
-- | Pretty-printing for 'GPUMem'.  No methods are overridden, so the
-- class's default implementations are used.
instance PrettyRep GPUMem
-- | Type checking of 'GPUMem' operations.
instance TC.CheckableOp GPUMem where
  -- Top-level operations are checked with no 'SegLevel' (@Nothing@).
  checkOp = typeCheckMemoryOp Nothing
    where
      -- No type signature on purpose: the inferred, generalised type
      -- carries the equality constraint linking 'OpWithAliases' to
      -- 'MemOp' that the recursive call needs.

      -- Allocation sizes must be of type int64.
      typeCheckMemoryOp _ (Alloc size _) =
        TC.require [Prim int64] size
      -- Host operations are checked by 'typeCheckHostOp'.  The callback
      -- it is given receives a 'SegLevel', which re-enters this checker
      -- wrapped in 'Just'.  The extra payload carried by 'HostOp' is
      -- accepted without any checks.
      typeCheckMemoryOp lvl (Inner op) =
        typeCheckHostOp (typeCheckMemoryOp . Just) lvl (const $ return ()) op
-- | Type checking of 'GPUMem' programs.  Every method delegates to a
-- generic checker that is polymorphic over memory representations.
instance TC.Checkable GPUMem where
  -- Function parameters, lambda parameters and let-bound names all
  -- carry 'MemInfo' decorations, validated by the same checker.
  checkFParamDec = checkMemInfo
  checkLParamDec = checkMemInfo
  checkLetBoundDec = checkMemInfo
  -- Return types are checked via their declared extended type only.
  checkRetType = mapM_ $ TC.checkExtType . declExtTypeOf
  -- Primitive-typed function parameters are decorated with 'MemPrim'.
  primFParam name t = return $ Param name (MemPrim t)
  matchPattern = matchPatternToExp
  matchReturnType = matchFunctionReturnType
  matchBranchType = matchBranchReturnType
  matchLoopResult = matchLoopResultMem
-- | Statement construction for 'GPUMem'.  Expression and body
-- decorations are @()@; let-bindings are built by @mkLetNamesB'@ with a
-- @()@ expression decoration.
instance BinderOps GPUMem where
  mkExpDecB _ _ = return ()
  mkBodyB stms res = return $ Body () stms res
  mkLetNamesB = mkLetNamesB' ()
-- | Statement construction for 'GPUMem' wrapped in the simplifier
-- engine's 'Engine.Wise' representation.  The underlying (unwrapped)
-- expression and body decorations are @()@; let-bindings are built by
-- @mkLetNamesB''@.
instance BinderOps (Engine.Wise GPUMem) where
  mkExpDecB pat e = return $ Engine.mkWiseExpDec pat () e
  mkBodyB stms res = return $ Engine.mkWiseBody () stms res
  mkLetNamesB = mkLetNamesB''
-- | Simplify a whole 'GPUMem' program with 'simplifyProgGeneric', using
-- the operations in 'simpleGPUMem'.
simplifyProg :: Prog GPUMem -> PassM (Prog GPUMem)
simplifyProg = simplifyProgGeneric simpleGPUMem
-- | Simplify a sequence of 'GPUMem' statements in the scope provided by
-- the monad, using 'simplifyStmsGeneric' with the operations in
-- 'simpleGPUMem'.  Returns the resulting symbol table together with the
-- simplified statements.
simplifyStms ::
  (HasScope GPUMem m, MonadFreshNames m) =>
  Stms GPUMem ->
  m
    ( Engine.SymbolTable (Engine.Wise GPUMem),
      Stms GPUMem
    )
simplifyStms = simplifyStmsGeneric simpleGPUMem
-- | Simplifier operations for 'GPUMem'.
--
-- Host operations are simplified with 'simplifyKernelOp'.  The @()@
-- payload of 'HostOp' is left as is and produces no extra statements.
-- The usage function reports, as size usages, the variables used as
-- sizes of @"local"@-space allocations that occur directly in the body
-- of a group-level 'SegMap'.
simpleGPUMem :: Engine.SimpleOps GPUMem
simpleGPUMem =
  simpleGeneric usage $ simplifyKernelOp $ const $ return ((), mempty)
  where
    -- Only group-level SegMaps are inspected; every other operation
    -- contributes no usages.
    -- NOTE(review): registering these sizes as used presumably lets the
    -- simplifier hoist them out of the kernel -- confirm against
    -- 'simpleGeneric' and the simplifier engine.
    usage (SegOp (SegMap SegGroup {} _ _ kbody)) = localAllocs kbody
    usage _ = mempty

    -- Only the top-level statements of the kernel body are scanned;
    -- nested bodies are not traversed.
    localAllocs = foldMap stmLocalAlloc . kernelBodyStms

    stmLocalAlloc = expLocalAlloc . stmExp

    -- Allocations with a constant size, or in any other space,
    -- contribute nothing.
    expLocalAlloc (Op (Alloc (Var v) (Space "local"))) = UT.sizeUsage v
    expLocalAlloc _ = mempty