{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE TypeFamilies #-}

module Futhark.IR.KernelsMem
  ( KernelsMem,

    -- * Simplification
    simplifyProg,
    simplifyStms,
    simpleKernelsMem,

    -- * Module re-exports
    module Futhark.IR.Mem,
    module Futhark.IR.Kernels.Kernel,
  )
where

import Futhark.Analysis.PrimExp.Convert
import qualified Futhark.Analysis.UsageTable as UT
import Futhark.IR.Kernels.Kernel
import Futhark.IR.Kernels.Simplify (simplifyKernelOp)
import Futhark.IR.Mem
import Futhark.IR.Mem.Simplify
import Futhark.MonadFreshNames
import qualified Futhark.Optimise.Simplify.Engine as Engine
import Futhark.Pass
import Futhark.Pass.ExplicitAllocations (BinderOps (..), mkLetNamesB', mkLetNamesB'')
import qualified Futhark.TypeCheck as TC

-- | Type-level tag for this representation.  The type has no
-- constructors; it is only used as the index of the type family and
-- class instances defined in this module.
data KernelsMem

-- | Decoration types for 'KernelsMem'.  Each decoration is the
-- correspondingly named @...Mem@ type, and the operation type is a
-- 'MemOp' wrapping a 'HostOp' whose own operation parameter is @()@.
instance Decorations KernelsMem where
  type LetDec KernelsMem = LetDecMem
  type FParamInfo KernelsMem = FParamMem
  type LParamInfo KernelsMem = LParamMem
  type RetType KernelsMem = RetTypeMem
  type BranchType KernelsMem = BranchTypeMem
  -- Operations are memory operations ('MemOp') around host-level ops.
  type Op KernelsMem = MemOp (HostOp KernelsMem ())

-- | 'ASTLore' instance for 'KernelsMem'.
--
-- 'expTypesFromPattern' is computed purely from the pattern:
-- 'bodyReturnsFromPattern' yields a pair of @(name, branch type)@
-- lists; the second list is kept, the names are dropped, and the
-- remaining branch types are returned in the monad.
instance ASTLore KernelsMem where
  expTypesFromPattern = return . map snd . snd . bodyReturnsFromPattern

-- | How each 'KernelsMem' operation describes its results.
instance OpReturns KernelsMem where
  -- An allocation has a single result: a 'MemMem' in the requested
  -- space.  The size operand does not influence the return info.
  opReturns (Alloc _ space) =
    return [MemMem space]
  -- Segmented operations are delegated to 'segOpReturns'.
  opReturns (Inner (SegOp op)) = segOpReturns op
  -- Any other operation: convert its 'opType' with 'extReturns'.
  opReturns k = extReturns <$> opType k

-- | No methods are overridden here, so this instance relies entirely
-- on whatever defaults the 'PrettyLore' class provides.
instance PrettyLore KernelsMem

-- | Type checking of 'KernelsMem' operations.
--
-- Checking starts with no enclosing 'SegLevel' ('Nothing').
instance TC.CheckableOp KernelsMem where
  checkOp = typeCheckMemoryOp Nothing
    where
      -- The first argument is the enclosing level, if any
      -- (@Maybe SegLevel@).  No local type signature is given: the
      -- inferred type carries a 'TC.Checkable' constraint and a type
      -- equality on the aliased operation type.

      -- An allocation only requires its size operand to be @int64@;
      -- the level and the space are ignored.
      typeCheckMemoryOp _ (Alloc size _) =
        TC.require [Prim int64] size
      -- A host operation is handed to 'typeCheckHostOp'.  The callback
      -- @typeCheckMemoryOp . Just@ lets it re-enter this checker with a
      -- known level, and the host op's own operation parameter is
      -- accepted unconditionally.
      typeCheckMemoryOp lvl (Inner op) =
        typeCheckHostOp (typeCheckMemoryOp . Just) lvl (const $ return ()) op

-- | Type checker hooks for 'KernelsMem'.  Every method forwards to a
-- generic memory-representation helper.
instance TC.Checkable KernelsMem where
  -- Function parameters, lambda parameters and let-bound names all
  -- carry memory information and share a single checker.
  checkFParamLore = checkMemInfo
  checkLParamLore = checkMemInfo
  checkLetBoundLore = checkMemInfo

  -- Each return type is checked via its declared existential type.
  checkRetType = mapM_ $ TC.checkExtType . declExtTypeOf

  -- A parameter of primitive type is annotated with 'MemPrim'.
  primFParam name t = return $ Param name (MemPrim t)

  matchPattern = matchPatternToExp
  matchReturnType = matchFunctionReturnType
  matchBranchType = matchBranchReturnType
  matchLoopResult = matchLoopResultMem

-- | Construction of expressions, bodies and statements for
-- 'KernelsMem'.  Expression and body decorations are both @()@.
instance BinderOps KernelsMem where
  -- Pattern and expression are ignored; the decoration is always ().
  mkExpDecB _ _ = return ()
  mkBodyB stms res = return $ Body () stms res
  -- Statement construction is delegated to 'mkLetNamesB'', which is
  -- given the unit expression decoration.
  mkLetNamesB = mkLetNamesB' ()

-- | Construction of expressions, bodies and statements for the
-- 'Engine.Wise' version of 'KernelsMem'.  The underlying 'KernelsMem'
-- decorations are @()@; the @Engine.mkWise...@ helpers build the
-- 'Engine.Wise' decorations from them.
instance BinderOps (Engine.Wise KernelsMem) where
  mkExpDecB pat e = return $ Engine.mkWiseExpDec pat () e
  mkBodyB stms res = return $ Engine.mkWiseBody () stms res
  mkLetNamesB = mkLetNamesB''

-- | Simplify an entire 'KernelsMem' program by running
-- 'simplifyProgGeneric' with the 'simpleKernelsMem' operations.
simplifyProg :: Prog KernelsMem -> PassM (Prog KernelsMem)
simplifyProg = simplifyProgGeneric simpleKernelsMem

-- | Simplify a sequence of 'KernelsMem' statements by running
-- 'simplifyStmsGeneric' with the 'simpleKernelsMem' operations.
--
-- Returns a symbol table (over the 'Engine.Wise' representation)
-- together with the simplified statements.
simplifyStms ::
  (HasScope KernelsMem m, MonadFreshNames m) =>
  Stms KernelsMem ->
  m
    ( Engine.SymbolTable (Engine.Wise KernelsMem),
      Stms KernelsMem
    )
simplifyStms = simplifyStmsGeneric simpleKernelsMem

-- | The simplifier operations for 'KernelsMem'.
--
-- Built with 'simpleGeneric' from two pieces: a function reporting
-- extra usages for a host operation (@usage@ below), and
-- 'simplifyKernelOp' as the operation simplifier.  The host op's own
-- operation parameter is @()@, so its simplifier returns @()@ and no
-- statements.
simpleKernelsMem :: Engine.SimpleOps KernelsMem
simpleKernelsMem =
  simpleGeneric usage $ simplifyKernelOp $ const $ return ((), mempty)
  where
    -- Slightly hackily, we look at the inside of SegGroup operations
    -- to figure out the sizes of local memory allocations, and add
    -- usages for those sizes.  This is necessary so the simplifier
    -- will hoist those sizes out as far as possible (most
    -- importantly, past the versioning If).
    --
    -- Only a 'SegMap' at 'SegGroup' level contributes usages; every
    -- other host operation contributes none.
    usage (SegOp (SegMap SegGroup {} _ _ kbody)) = localAllocs kbody
    usage _ = mempty

    -- Combine the usages of all statements of the kernel body.  Only
    -- the statements returned by 'kernelBodyStms' are inspected.
    localAllocs = foldMap stmLocalAlloc . kernelBodyStms

    -- A statement is judged solely by its expression.
    stmLocalAlloc = expLocalAlloc . stmExp

    -- Only an allocation in the "local" space whose size is a
    -- variable registers a size usage; anything else, including a
    -- constant-sized allocation, contributes nothing.
    expLocalAlloc (Op (Alloc (Var v) (Space "local"))) =
      UT.sizeUsage v
    expLocalAlloc _ =
      mempty