{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE TypeFamilies #-}

module Futhark.IR.GPUMem
  ( GPUMem,

    -- * Simplification
    simplifyProg,
    simplifyStms,
    simpleGPUMem,

    -- * Module re-exports
    module Futhark.IR.Mem,
    module Futhark.IR.GPU.Op,
  )
where

import Futhark.Analysis.PrimExp.Convert
import qualified Futhark.Analysis.UsageTable as UT
import Futhark.IR.GPU.Op
import Futhark.IR.GPU.Simplify (simplifyKernelOp)
import Futhark.IR.Mem
import Futhark.IR.Mem.Simplify
import qualified Futhark.IR.TypeCheck as TC
import Futhark.MonadFreshNames
import qualified Futhark.Optimise.Simplify.Engine as Engine
import Futhark.Pass
import Futhark.Pass.ExplicitAllocations (BuilderOps (..), mkLetNamesB', mkLetNamesB'')

-- | Type-level tag for the GPU representation with memory information.
-- It has no constructors and is used only as a type index; the
-- annotation types it selects are given by the 'RepTypes' instance.
data GPUMem

-- | Every annotation slot uses the corresponding memory-annotated type.
-- Operations are host operations wrapped in 'MemOp'; the @()@ parameter
-- of 'HostOp' means no additional inner operation type is present.
instance RepTypes GPUMem where
  type LetDec GPUMem = LetDecMem
  type FParamInfo GPUMem = FParamMem
  type LParamInfo GPUMem = LParamMem
  type RetType GPUMem = RetTypeMem
  type BranchType GPUMem = BranchTypeMem
  type Op GPUMem = MemOp (HostOp GPUMem ())

-- | Branch return types are read straight off the memory annotations in
-- the pattern: 'bodyReturnsFromPat' pairs each bound name with its
-- 'BodyReturns', and only the second component is kept.
instance ASTRep GPUMem where
  expTypesFromPat = pure . map snd . bodyReturnsFromPat

-- | Memory-aware return types of host operations.  'SegOp's get their
-- return types from 'segOpReturns'; any other host operation falls back
-- to its existential type ('opType'), converted with 'extReturns'.
instance OpReturns (HostOp GPUMem ()) where
  opReturns (SegOp op) = segOpReturns op
  opReturns k = extReturns <$> opType k

-- | Same as the plain 'GPUMem' instance, but for host operations
-- annotated with simplifier wisdom ('Engine.Wise').  'SegOp's use
-- 'segOpReturns'; everything else falls back to 'opType' and
-- 'extReturns'.
instance OpReturns (HostOp (Engine.Wise GPUMem) ()) where
  opReturns (SegOp op) = segOpReturns op
  opReturns k = extReturns <$> opType k

-- | The instance body is empty, so every 'PrettyRep' method comes from
-- the class's default implementations.
instance PrettyRep GPUMem

-- | Type checking of memory operations.  An 'Alloc' must have an @i64@
-- size.  An 'Inner' host operation is delegated to 'typeCheckHostOp',
-- which is given this same checker back for nested operations.
instance TC.CheckableOp GPUMem where
  checkOp = typeCheckMemoryOp Nothing
    where
      -- The 'Maybe SegLevel' argument starts as 'Nothing' for a
      -- top-level operation.  When 'typeCheckHostOp' recurses, it
      -- supplies a 'SegLevel' (presumably that of the enclosing SegOp,
      -- TODO confirm), which is wrapped in 'Just'.
      typeCheckMemoryOp _ (Alloc size _) =
        TC.require [Prim int64] size
      typeCheckMemoryOp lvl (Inner op) =
        -- The extra inner operation type of 'HostOp' is @()@, so there
        -- is nothing further to check for it.
        typeCheckHostOp (typeCheckMemoryOp . Just) lvl (const $ pure ()) op

-- | Type checking for 'GPUMem'.  Every method delegates to the
-- representation-polymorphic checks for memory-annotated
-- representations ('checkMemInfo', 'matchPatToExp', and so on).
instance TC.Checkable GPUMem where
  checkFParamDec = checkMemInfo
  checkLParamDec = checkMemInfo
  checkLetBoundDec = checkMemInfo
  checkRetType = mapM_ $ TC.checkExtType . declExtTypeOf
  -- A primitive-typed parameter has no memory block, only 'MemPrim'.
  primFParam name t = pure $ Param mempty name (MemPrim t)
  matchPat = matchPatToExp
  matchReturnType = matchFunctionReturnType
  matchBranchType = matchBranchReturnType
  matchLoopResult = matchLoopResultMem

-- | AST construction for 'GPUMem'.  Expressions and bodies carry the
-- trivial @()@ decoration.  Let-bindings are built by @mkLetNamesB'@,
-- which receives that same @()@ expression decoration.
instance BuilderOps GPUMem where
  mkExpDecB _ _ = pure ()
  mkBodyB stms res = pure $ Body () stms res
  mkLetNamesB = mkLetNamesB' ()

-- | AST construction for 'GPUMem' wrapped in the simplifier's
-- 'Engine.Wise'.  The @()@ decorations of the underlying representation
-- are extended with wisdom by 'Engine.mkWiseExpDec' and
-- 'Engine.mkWiseBody'.  Let-bindings are built by @mkLetNamesB''@.
instance BuilderOps (Engine.Wise GPUMem) where
  mkExpDecB pat e = pure $ Engine.mkWiseExpDec pat () e
  mkBodyB stms res = pure $ Engine.mkWiseBody () stms res
  mkLetNamesB = mkLetNamesB''

-- | Traversal of statements nested inside operations.  It goes through
-- the 'MemOp' and 'HostOp' layers.  The innermost @()@ operation
-- ignores the traverser and is returned unchanged.
instance TraverseOpStms (Engine.Wise GPUMem) where
  traverseOpStms = traverseMemOpStms (traverseHostOpStms (const pure))

-- | Simplify a whole 'GPUMem' program with the generic memory-aware
-- simplifier, configured by 'simpleGPUMem'.
simplifyProg :: Prog GPUMem -> PassM (Prog GPUMem)
simplifyProg = simplifyProgGeneric simpleGPUMem

-- | Simplify a sequence of 'GPUMem' statements in the scope provided by
-- the monad, using the generic memory-aware simplifier configured by
-- 'simpleGPUMem'.
simplifyStms ::
  (HasScope GPUMem m, MonadFreshNames m) => Stms GPUMem -> m (Stms GPUMem)
simplifyStms = simplifyStmsGeneric simpleGPUMem

-- | Simplifier configuration for 'GPUMem'.  Host operations are
-- simplified by 'simplifyKernelOp'.  Its inner @()@ operation is kept
-- as-is and produces no extra statements.  The usage function is
-- described below.
simpleGPUMem :: Engine.SimpleOps GPUMem
simpleGPUMem =
  simpleGeneric usage $ simplifyKernelOp $ const $ pure ((), mempty)
  where
    -- Slightly hackily and very inefficiently, we look at the inside
    -- of SegOps to figure out the sizes of local memory allocations,
    -- and add usages for those sizes.  This is necessary so the
    -- simplifier will hoist those sizes out as far as possible (most
    -- importantly, past the versioning If, but see also #1569).
    --
    -- Only 'SegMap' is inspected here, both at the top and when
    -- recursing into nested SegOps; other SegOps contribute no usages.
    usage (SegOp (SegMap _ _ _ kbody)) = localAllocs kbody
    usage _ = mempty

    -- Combined size usages of all allocations in a kernel body.
    localAllocs = foldMap stmLocalAlloc . kernelBodyStms

    stmLocalAlloc = expLocalAlloc . stmExp

    -- Only allocations whose size is a variable yield a usage, since a
    -- constant size has no name to record.
    expLocalAlloc (Op (Alloc (Var v) _)) =
      UT.sizeUsage v
    expLocalAlloc (Op (Inner (SegOp (SegMap _ _ _ kbody)))) =
      localAllocs kbody
    expLocalAlloc _ =
      mempty