{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE TypeFamilies #-}

-- | A simple representation with SOACs and nested parallelism.
module Futhark.IR.SOACS
  ( SOACS,
    usesAD,

    -- * Module re-exports
    module Futhark.IR.Prop,
    module Futhark.IR.Traversals,
    module Futhark.IR.Pretty,
    module Futhark.IR.Syntax,
    module Futhark.IR.SOACS.SOAC,
  )
where

import Futhark.Builder
import Futhark.Construct
import Futhark.IR.Pretty
import Futhark.IR.Prop
import Futhark.IR.SOACS.SOAC
import Futhark.IR.Syntax
import Futhark.IR.Traversals
import qualified Futhark.IR.TypeCheck as TC

-- | The rep for the basic representation: SOACs with nested parallelism.
--
-- 'SOACS' has no constructors and therefore no values; it is used purely
-- as a type-level index that selects this representation in the class
-- instances of this module.
data SOACS

-- | The operations ('Op') of this representation are 'SOAC's, themselves
-- indexed by 'SOACS', so that lambdas nested inside a SOAC are again in the
-- SOACS representation. No other associated type is overridden here, so
-- the class defaults apply to them.
instance RepTypes SOACS where
  type Op SOACS = SOAC SOACS

-- | The result types of an expression bound to a pattern are simply the
-- pattern's own types, converted to 'ExtType's by 'expExtTypesFromPat'.
-- No monadic work is required, so the result is just lifted with 'pure'.
--
-- No instance signature is given: the file does not enable @InstanceSigs@,
-- and the method's type is fully determined by the class anyway.
instance ASTRep SOACS where
  expTypesFromPat = pure . expExtTypesFromPat

-- | Operations of the SOACS representation (aliases-annotated 'SOAC's) are
-- type checked by 'typeCheckSOAC'.
--
-- No instance signature is given: the file does not enable @InstanceSigs@,
-- and the method's type is fully determined by the class.
instance TC.CheckableOp SOACS where
  checkOp = typeCheckSOAC

-- | No methods are defined here, so the default methods of 'TC.Checkable'
-- are used for SOACS.
instance TC.Checkable SOACS

-- | Construction of SOACS code.
--
-- Body and expression decorations are @()@ in this representation.
-- Patterns are built from the given identifiers with 'basicPat', and the
-- expression is ignored. Let-bindings are made with 'simpleMkLetNames'.
--
-- No instance signatures are given: the file does not enable
-- @InstanceSigs@, and the method types are fully determined by the class.
instance Buildable SOACS where
  mkBody = Body ()
  mkExpPat idents _ = basicPat idents
  mkExpDec _ _ = ()
  mkLetNames = simpleMkLetNames

-- | No methods are defined here, so the default methods of 'BuilderOps'
-- are used for SOACS.
instance BuilderOps SOACS

-- | No methods are defined here, so the default methods of 'PrettyRep'
-- are used for SOACS.
instance PrettyRep SOACS

-- | Does the program use automatic differentiation, i.e. does a 'JVP' or
-- 'VJP' SOAC occur anywhere in it?
--
-- Both the top-level constants ('progConsts') and the body of every
-- function definition ('progFuns') are searched. The search descends into:
--
--   * both branches of an 'If';
--   * the body of a 'DoLoop';
--   * the lambda of a 'WithAcc';
--   * the lambdas carried by 'Stream', 'Screma' (its map lambda plus every
--     scan and reduce lambda), 'Hist' (its lambda plus every 'histOp') and
--     'Scatter'.
--
-- 'BasicOp' and 'Apply' never count as AD by themselves. A called function
-- that uses AD is still detected, because every function is searched.
usesAD :: Prog SOACS -> Bool
usesAD prog =
  any stmUsesAD (progConsts prog) || any funUsesAD (progFuns prog)
  where
    funUsesAD :: FunDef SOACS -> Bool
    funUsesAD = bodyUsesAD . funDefBody

    bodyUsesAD :: Body SOACS -> Bool
    bodyUsesAD = any stmUsesAD . bodyStms

    stmUsesAD :: Stm SOACS -> Bool
    stmUsesAD = expUsesAD . stmExp

    lamUsesAD :: Lambda SOACS -> Bool
    lamUsesAD = bodyUsesAD . lambdaBody

    expUsesAD :: Exp SOACS -> Bool
    expUsesAD (Op JVP {}) = True
    expUsesAD (Op VJP {}) = True
    -- NOTE(review): only the trailing lambda of 'Stream' is inspected; its
    -- other fields are ignored. Confirm none of them can carry a lambda
    -- (e.g. a reduction operator) that could itself contain AD.
    expUsesAD (Op (Stream _ _ _ _ lam)) = lamUsesAD lam
    expUsesAD (Op (Screma _ _ (ScremaForm scans reds lam))) =
      lamUsesAD lam
        || any (lamUsesAD . scanLambda) scans
        || any (lamUsesAD . redLambda) reds
    expUsesAD (Op (Hist _ _ ops lam)) =
      lamUsesAD lam || any (lamUsesAD . histOp) ops
    expUsesAD (Op (Scatter _ _ lam _)) = lamUsesAD lam
    expUsesAD (If _ tbody fbody _) = bodyUsesAD tbody || bodyUsesAD fbody
    expUsesAD (DoLoop _ _ body) = bodyUsesAD body
    -- NOTE(review): the 'WithAcc' inputs are ignored. Confirm they cannot
    -- contain lambdas (e.g. accumulator combining operators) that use AD.
    expUsesAD (WithAcc _ lam) = lamUsesAD lam
    expUsesAD BasicOp {} = False
    expUsesAD Apply {} = False