{-# LANGUAGE TypeFamilies #-}

-- | A simple representation with SOACs and nested parallelism.
module Futhark.IR.SOACS
  ( SOACS,
    usesAD,

    -- * Module re-exports
    module Futhark.IR.Prop,
    module Futhark.IR.Traversals,
    module Futhark.IR.Pretty,
    module Futhark.IR.Syntax,
    module Futhark.IR.SOACS.SOAC,
  )
where

import Futhark.Builder
import Futhark.Construct
import Futhark.IR.Pretty
import Futhark.IR.Prop
import Futhark.IR.SOACS.SOAC
import Futhark.IR.Syntax
import Futhark.IR.Traversals
import Futhark.IR.TypeCheck qualified as TC

-- | The representation tag for the SOACS IR.  The type has no
-- constructors; it exists only so that the representation classes can be
-- instantiated at it.
data SOACS

-- | The representation-specific operations of 'SOACS' are 'SOAC's.  No
-- other associated type is overridden here (presumably the class defaults
-- are used for the rest).
instance RepTypes SOACS where
  type OpC SOACS = SOAC

-- | The types of an expression are read directly off its pattern with
-- 'expExtTypesFromPat'; the result is wrapped with 'pure', so no monadic
-- effect (such as a scope lookup) is performed.
instance ASTRep SOACS where
  expTypesFromPat = pure . expExtTypesFromPat

-- | Type checking of the representation-specific operations is delegated
-- to 'typeCheckSOAC'.
instance TC.Checkable SOACS where
  checkOp = typeCheckSOAC

-- | Construction of 'SOACS' syntax.  Bodies and expressions carry the unit
-- value as their decoration, and patterns are produced by 'basicPat'.
instance Buildable SOACS where
  -- A body is just its statements and result, decorated with @()@.
  mkBody = Body ()

  -- The expression is ignored; the pattern is derived from the identifiers
  -- alone.
  mkExpPat idents _ = basicPat idents

  -- Neither the pattern nor the expression influences the decoration.
  mkExpDec _ _ = ()

  mkLetNames = simpleMkLetNames

-- | No methods are defined here, so whatever defaults 'BuilderOps'
-- provides are the ones in effect for 'SOACS'.
instance BuilderOps SOACS

-- | No methods are defined here, so whatever defaults 'PrettyRep'
-- provides are the ones in effect for 'SOACS'.
instance PrettyRep SOACS

-- | Does the program contain a 'JVP' or 'VJP' operation anywhere?
--
-- Both the program constants ('progConsts') and the bodies of all functions
-- ('progFuns') are searched.  The search descends into the lambdas of the
-- other SOACs, the branches of a 'Match', the body of a 'DoLoop' and the
-- lambda of a 'WithAcc'.
usesAD :: Prog SOACS -> Bool
usesAD prog = any stmUsesAD (progConsts prog) || any funUsesAD (progFuns prog)
  where
    funUsesAD :: FunDef SOACS -> Bool
    funUsesAD = bodyUsesAD . funDefBody

    bodyUsesAD :: Body SOACS -> Bool
    bodyUsesAD = any stmUsesAD . bodyStms

    stmUsesAD :: Stm SOACS -> Bool
    stmUsesAD = expUsesAD . stmExp

    lamUsesAD :: Lambda SOACS -> Bool
    lamUsesAD = bodyUsesAD . lambdaBody

    expUsesAD :: Exp SOACS -> Bool
    -- The two constructs being searched for.
    expUsesAD (Op JVP {}) = True
    expUsesAD (Op VJP {}) = True
    -- Remaining SOACs: inspect every lambda they carry.
    expUsesAD (Op (Stream _ _ _ lam)) = lamUsesAD lam
    expUsesAD (Op (Screma _ _ (ScremaForm scans reds lam))) =
      lamUsesAD lam
        || any (lamUsesAD . scanLambda) scans
        || any (lamUsesAD . redLambda) reds
    expUsesAD (Op (Hist _ _ ops lam)) =
      lamUsesAD lam || any (lamUsesAD . histOp) ops
    expUsesAD (Op (Scatter _ _ lam _)) =
      lamUsesAD lam
    -- Control flow: inspect the nested bodies.
    expUsesAD (Match _ cases def_case _) =
      any (bodyUsesAD . caseBody) cases || bodyUsesAD def_case
    expUsesAD (DoLoop _ _ body) = bodyUsesAD body
    -- Only the lambda is inspected; the accumulator inputs are ignored.
    expUsesAD (WithAcc _ lam) = lamUsesAD lam
    -- These are always reported as not using AD.
    expUsesAD BasicOp {} = False
    expUsesAD Apply {} = False