{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE ConstraintKinds #-}
module Futhark.IR.KernelsMem
  ( KernelsMem

  -- * Simplification
  , simplifyProg
  , simplifyStms
  , simpleKernelsMem

    -- * Module re-exports
  , module Futhark.IR.Mem
  , module Futhark.IR.Kernels.Kernel
  )
  where

import Futhark.Analysis.PrimExp.Convert
import Futhark.MonadFreshNames
import Futhark.Pass
import Futhark.IR.Syntax
import Futhark.IR.Prop
import Futhark.IR.Traversals
import Futhark.IR.Pretty
import Futhark.IR.Kernels.Kernel
import Futhark.IR.Kernels.Simplify (simplifyKernelOp)
import qualified Futhark.TypeCheck as TC
import Futhark.IR.Mem
import Futhark.IR.Mem.Simplify
import Futhark.Pass.ExplicitAllocations (BinderOps(..), mkLetNamesB', mkLetNamesB'')
import qualified Futhark.Optimise.Simplify.Engine as Engine

-- | The lore for kernel programs ('HostOp') carrying explicit memory
-- information ('MemOp', 'LetDecMem', ...).  It has no constructors and is
-- only used as a type-level index.
data KernelsMem

-- | Decorations of 'KernelsMem': the memory-annotated decoration types
-- ('LetDecMem', 'FParamMem', 'LParamMem', 'RetTypeMem', 'BranchTypeMem').
-- Operations are host operations wrapped in 'MemOp', with no further inner
-- operation (@()@).
instance Decorations KernelsMem where
  type Op KernelsMem = MemOp (HostOp KernelsMem ())
  type LetDec KernelsMem = LetDecMem
  type FParamInfo KernelsMem = FParamMem
  type LParamInfo KernelsMem = LParamMem
  type RetType KernelsMem = RetTypeMem
  type BranchType KernelsMem = BranchTypeMem

-- | Branch types are recovered from a pattern.  'bodyReturnsFromPattern'
-- yields a pair of lists of @(name, BodyReturns)@ entries.  Only the second
-- list is used, and its names are discarded.  (Presumably the pair is the
-- context and value parts of the pattern; confirm before relying on it.)
instance ASTLore KernelsMem where
  expTypesFromPattern = return . map snd . snd . bodyReturnsFromPattern

-- | What each 'KernelsMem' operation returns:
--
-- * an allocation returns a single memory block in the requested space;
-- * a segmented operation is delegated to 'segOpReturns';
-- * any other operation is derived from its 'opType' via 'extReturns'.
instance OpReturns KernelsMem where
  opReturns (Alloc _ space) = return [MemMem space]
  opReturns (Inner (SegOp op)) = segOpReturns op
  opReturns k = extReturns <$> opType k

-- | Pretty-printing for 'KernelsMem' relies entirely on the class's default
-- methods.
instance PrettyLore KernelsMem

-- | Type-check a 'KernelsMem' operation.  Checking starts with no enclosing
-- 'SegLevel' ('Nothing').
instance TC.CheckableOp KernelsMem where
  checkOp = typeCheckMemoryOp Nothing
    where
      -- An allocation's size must be an 'int64'.  Its space is not checked
      -- here.
      typeCheckMemoryOp _ (Alloc size _) =
        TC.require [Prim int64] size
      -- Host operations are handed to 'typeCheckHostOp'.  The callback it
      -- receives re-enters this function with a known 'SegLevel' (wrapped in
      -- 'Just'), presumably for operations nested inside a segmented
      -- operation (TODO confirm).  The inner operation is @()@, so checking
      -- it does nothing.
      typeCheckMemoryOp lvl (Inner op) =
        typeCheckHostOp (typeCheckMemoryOp . Just) lvl (const $ return ()) op

-- | Type checker support for 'KernelsMem'.
--
-- Function-parameter, lambda-parameter and let-bound decorations are all
-- validated by 'checkMemInfo'.  Each function return type is checked as an
-- existential type.  Pattern, return and branch matching reuse the matchers
-- that are generic over memory lores.
instance TC.Checkable KernelsMem where
  checkFParamLore = checkMemInfo
  checkLParamLore = checkMemInfo
  checkLetBoundLore = checkMemInfo
  checkRetType = mapM_ $ TC.checkExtType . declExtTypeOf
  -- A primitive parameter is annotated with just its type ('MemPrim').
  primFParam name t = return $ Param name (MemPrim t)
  matchPattern = matchPatternToExp
  matchReturnType = matchFunctionReturnType
  matchBranchType = matchBranchReturnType

-- | Statement construction for 'KernelsMem'.  Expression and body
-- decorations are both @()@.  Let-bindings are built by @mkLetNamesB'@
-- with a @()@ expression decoration.
instance BinderOps KernelsMem where
  mkExpDecB _ _ = return ()
  mkBodyB stms res = return $ Body () stms res
  mkLetNamesB = mkLetNamesB' ()

-- | Statement construction for 'KernelsMem' wrapped in 'Engine.Wise'.  The
-- underlying decorations are @()@, and 'Engine.mkWiseExpDec' and
-- 'Engine.mkWiseBody' add the wisdom on top.  Let-bindings are built by
-- @mkLetNamesB''@.
instance BinderOps (Engine.Wise KernelsMem) where
  mkExpDecB pat e = return $ Engine.mkWiseExpDec pat () e
  mkBodyB stms res = return $ Engine.mkWiseBody () stms res
  mkLetNamesB = mkLetNamesB''

-- | Simplify a whole 'KernelsMem' program.
--
-- Host operations are simplified by 'simplifyKernelOp'.  Their inner
-- operation is @()@, so the inner simplifier returns @()@ and produces no
-- extra statements ('mempty').
simplifyProg :: Prog KernelsMem -> PassM (Prog KernelsMem)
simplifyProg =
  simplifyProgGeneric $ simplifyKernelOp $ const $ return ((), mempty)

-- | Simplify a sequence of 'KernelsMem' statements in the ambient scope.
-- Returns a symbol table together with the simplified statements.
--
-- As in 'simplifyProg', host operations go through 'simplifyKernelOp' with
-- a trivial inner simplifier (the inner operation is @()@).
simplifyStms :: (HasScope KernelsMem m, MonadFreshNames m) =>
                Stms KernelsMem
             -> m (Engine.SymbolTable (Engine.Wise KernelsMem),
                   Stms KernelsMem)
simplifyStms =
  simplifyStmsGeneric $ simplifyKernelOp $ const $ return ((), mempty)

-- | Simplifier operations for 'KernelsMem'.  They are built by
-- 'simpleGeneric' around 'simplifyKernelOp', using a trivial inner
-- simplifier because the inner operation is @()@.
simpleKernelsMem :: Engine.SimpleOps KernelsMem
simpleKernelsMem =
  simpleGeneric $ simplifyKernelOp $ const $ return ((), mempty)