{-# LANGUAGE TypeFamilies #-}

module Futhark.IR.GPUMem
  ( GPUMem,

    -- * Simplification
    simplifyProg,
    simplifyStms,
    simpleGPUMem,

    -- * Module re-exports
    module Futhark.IR.Mem,
    module Futhark.IR.GPU.Op,
  )
where

import Futhark.Analysis.PrimExp.Convert
import Futhark.Analysis.UsageTable qualified as UT
import Futhark.IR.Aliases (Aliases)
import Futhark.IR.GPU.Op
import Futhark.IR.GPU.Simplify (simplifyKernelOp)
import Futhark.IR.Mem
import Futhark.IR.Mem.Simplify
import Futhark.IR.TypeCheck qualified as TC
import Futhark.MonadFreshNames
import Futhark.Optimise.Simplify.Engine qualified as Engine
import Futhark.Pass
import Futhark.Pass.ExplicitAllocations (BuilderOps (..), mkLetNamesB', mkLetNamesB'')

-- | Tag type for the GPUMem representation.  It has no constructors and
-- is used only as a type-level index for the instances in this module.
data GPUMem

-- | Associated types of the 'GPUMem' representation.  Every decoration
-- is the corresponding memory-annotated type, and the operation type is
-- a 'HostOp' (with 'NoOp' as its inner operation) wrapped in 'MemOp'.
instance RepTypes GPUMem where
  type LetDec GPUMem = LetDecMem
  type FParamInfo GPUMem = FParamMem
  type LParamInfo GPUMem = LParamMem
  type RetType GPUMem = RetTypeMem
  type BranchType GPUMem = BranchTypeMem
  type OpC GPUMem = MemOp (HostOp NoOp)

-- | 'ASTRep' instance for 'GPUMem'.
instance ASTRep GPUMem where
  -- 'bodyReturnsFromPat' yields (name, branch type) pairs for the
  -- pattern; only the branch types are kept.
  expTypesFromPat = pure . map snd . bodyReturnsFromPat

-- | Return information for host operations in the plain 'GPUMem'
-- representation.
instance OpReturns (HostOp NoOp GPUMem) where
  -- A 'SegOp' is delegated to 'segOpReturns'.
  opReturns (SegOp op) = segOpReturns op
  -- Any other host operation derives its returns from its 'opType',
  -- converted with 'extReturns'.
  opReturns k = extReturns <$> opType k

-- | Return information for host operations in the 'Engine.Wise'
-- variant of the 'GPUMem' representation.  The equations are the same
-- as for the plain representation.
instance OpReturns (HostOp NoOp (Engine.Wise GPUMem)) where
  -- A 'SegOp' is delegated to 'segOpReturns'.
  opReturns (SegOp op) = segOpReturns op
  -- Any other host operation derives its returns from its 'opType',
  -- converted with 'extReturns'.
  opReturns k = extReturns <$> opType k

-- | No methods are overridden here, so the class defaults are used.
instance PrettyRep GPUMem

-- | Type checking for the 'GPUMem' representation.  Apart from
-- 'checkOp', 'checkRetType' and 'primFParam', every method is
-- forwarded directly to a memory-aware helper.
instance TC.Checkable GPUMem where
  -- At the top level no enclosing 'SegLevel' is known, hence 'Nothing'.
  checkOp = typeCheckMemoryOp Nothing
    where
      -- GHC 9.2 goes into an infinite loop without the type annotation.
      typeCheckMemoryOp ::
        Maybe SegLevel ->
        MemOp (HostOp NoOp) (Aliases GPUMem) ->
        TC.TypeM GPUMem ()
      -- An allocation only requires its size operand to be an int64.
      typeCheckMemoryOp _ (Alloc size _) =
        TC.require [Prim int64] size
      -- A wrapped host operation is checked by 'typeCheckHostOp'.  The
      -- first argument is this checker again, now taking a definite
      -- level wrapped in 'Just'; the 'NoOp' checker accepts anything.
      typeCheckMemoryOp lvl (Inner op) =
        typeCheckHostOp (typeCheckMemoryOp . Just) lvl (const $ pure ()) op
  checkFParamDec = checkMemInfo
  checkLParamDec = checkMemInfo
  checkLetBoundDec = checkMemInfo
  -- Each return type is checked through its declared extended type.
  checkRetType = mapM_ $ TC.checkExtType . declExtTypeOf
  -- A primitive function parameter carries no attributes ('mempty') and
  -- a 'MemPrim' decoration.
  primFParam name t = pure $ Param mempty name (MemPrim t)
  matchPat = matchPatToExp
  matchReturnType = matchFunctionReturnType
  matchBranchType = matchBranchReturnType
  matchLoopResult = matchLoopResultMem

-- | Builder operations for the plain 'GPUMem' representation.
instance BuilderOps GPUMem where
  -- Expression decorations are unit in this representation.
  mkExpDecB _ _ = pure ()
  -- Bodies carry a unit decoration as well.
  mkBodyB stms res = pure $ Body () stms res
  -- Statements are built by 'mkLetNamesB'' with the "device" space and
  -- a unit expression decoration.
  -- NOTE(review): presumably this is the space used for any memory the
  -- helper allocates -- confirm against 'mkLetNamesB''.
  mkLetNamesB = mkLetNamesB' (Space "device") ()

-- | Builder operations for the 'Engine.Wise' variant of 'GPUMem'.
-- Decorations are produced by the simplifier engine's @mkWise*@
-- helpers from the underlying unit decorations.
instance BuilderOps (Engine.Wise GPUMem) where
  mkExpDecB pat e = pure $ Engine.mkWiseExpDec pat () e
  mkBodyB stms res = pure $ Engine.mkWiseBody () stms res
  -- Statements are built by 'mkLetNamesB''' with the "device" space.
  -- NOTE(review): presumably this is the space used for any memory the
  -- helper allocates -- confirm against 'mkLetNamesB'''.
  mkLetNamesB = mkLetNamesB'' (Space "device")

-- | Statement traversal for operations of the 'Engine.Wise' variant of
-- 'GPUMem'.
instance TraverseOpStms (Engine.Wise GPUMem) where
  -- Traversal goes through the 'MemOp' layer and then the 'HostOp'
  -- layer.  The innermost traverser (for 'NoOp') ignores its first
  -- argument and returns the operation unchanged.
  traverseOpStms = traverseMemOpStms (traverseHostOpStms (const pure))

-- | Simplify an entire 'GPUMem' program by running the generic memory
-- simplifier with 'simpleGPUMem'.
simplifyProg :: Prog GPUMem -> PassM (Prog GPUMem)
simplifyProg = simplifyProgGeneric simpleGPUMem

-- | Simplify a sequence of 'GPUMem' statements by running the generic
-- memory simplifier with 'simpleGPUMem'.
simplifyStms ::
  (HasScope GPUMem m, MonadFreshNames m) => Stms GPUMem -> m (Stms GPUMem)
simplifyStms = simplifyStmsGeneric simpleGPUMem

-- | The simplifier operations for 'GPUMem'.  They are built by
-- 'simpleGeneric' from a usage function for host operations and from
-- 'simplifyKernelOp'.  The inner-operation simplifier given to
-- 'simplifyKernelOp' ignores its argument and returns 'NoOp' with no
-- extra statements.
simpleGPUMem :: Engine.SimpleOps GPUMem
simpleGPUMem =
  simpleGeneric usage $ simplifyKernelOp $ const $ pure (NoOp, mempty)
  where
    -- Slightly hackily and very inefficiently, we look at the inside
    -- of SegOps to figure out the sizes of local memory allocations,
    -- and add usages for those sizes.  This is necessary so the
    -- simplifier will hoist those sizes out as far as possible (most
    -- importantly, past the versioning If, but see also #1569).
    usage (SegOp (SegMap _ _ _ kbody)) = localAllocs kbody
    -- Only 'SegMap' is inspected; every other host operation
    -- contributes no usages.
    usage _ = mempty

    -- Combine the usages found in every statement of a kernel body.
    localAllocs = foldMap stmLocalAlloc . kernelBodyStms

    -- A statement's usages are those of its expression.
    stmLocalAlloc = expLocalAlloc . stmExp

    -- An allocation whose size is a variable records a size usage of
    -- that variable.
    expLocalAlloc (Op (Alloc (Var v) _)) =
      UT.sizeUsage v
    -- A nested 'SegMap' is searched recursively.
    expLocalAlloc (Op (Inner (SegOp (SegMap _ _ _ kbody)))) =
      localAllocs kbody
    expLocalAlloc _ =
      mempty