{-# LANGUAGE TypeFamilies #-}

-- | Apply all AD operators in the program, leaving AD-free code.
module Futhark.Pass.AD (applyAD, applyADInnermost) where

import Control.Monad
import Control.Monad.Reader
import Futhark.AD.Fwd (fwdJVP)
import Futhark.AD.Rev (revVJP)
import Futhark.Builder
import Futhark.IR.SOACS
import Futhark.IR.SOACS.Simplify (simplifyLambda)
import Futhark.Pass

-- | Whether we apply only the innermost AD operators, or all of them.
-- The former is very useful for debugging, but probably not useful
-- for actual compilation.
data Mode
  = -- | Eliminate an AD operator only when recursively processing its
    -- lambda left that lambda unchanged.
    Innermost
  | -- | Eliminate every AD operator.
    All
  deriving (Eq)

-- | Inline the application of a lambda to some arguments, binding the
-- lambda's results to the names of the given pattern.
--
-- Each lambda parameter is first bound to its argument. An array-typed
-- parameter whose argument is a variable is bound to a 'Copy' of that
-- variable. Every other parameter is bound to the argument as a plain
-- 'SubExp'. The statement attributes in @aux@ are applied only to these
-- parameter bindings.
--
-- The lambda body is then bound in place. Each of its results is bound to
-- the corresponding name of @pat@, under the certificates carried by that
-- result.
--
-- Parameters are paired with arguments, and pattern names with results,
-- positionally as by 'zip'. Surplus elements on either side are ignored.
bindLambda ::
  (MonadBuilder m, Rep m ~ SOACS) =>
  Pat Type ->
  StmAux (ExpDec SOACS) ->
  Lambda SOACS ->
  [SubExp] ->
  m ()
bindLambda pat aux (Lambda params body _) args = do
  auxing aux $ zipWithM_ bindParam params args
  results <- bodyBind body
  zipWithM_ bindResult (patNames pat) results
  where
    -- NOTE(review): array arguments given by name are copied rather than
    -- aliased; presumably so the inlined body cannot clobber the caller's
    -- array -- confirm.
    bindParam param arg =
      letBindNames [paramName param] . BasicOp $
        case (paramType param, arg) of
          (Array {}, Var v) -> Copy v
          _ -> SubExp arg
    bindResult name (SubExpRes cs se) =
      certifying cs . letBindNames [name] . BasicOp $ SubExp se

-- | Process a single statement, returning the statements that replace it.
--
-- 'VJP' and 'JVP' operators are handled by 'onADStm'. For 'VJP', the lambda
-- produced by 'revVJP' is additionally run through 'simplifyLambda' under
-- @scope@. For 'JVP', 'fwdJVP' is used directly.
--
-- Any other statement is rebuilt with the pass applied recursively to its
-- nested bodies (in a scope extended with the body's own bindings) and to
-- the lambdas of any SOAC it contains.
onStm :: Mode -> Scope SOACS -> Stm SOACS -> PassM (Stms SOACS)
onStm mode scope (Let pat aux (Op (VJP lam args vec))) =
  onADStm mode scope pat aux VJP lam args vec $ \lam' ->
    (`runReaderT` scope) . simplifyLambda =<< revVJP scope lam'
onStm mode scope (Let pat aux (Op (JVP lam args vec))) =
  onADStm mode scope pat aux JVP lam args vec $ fwdJVP scope
onStm mode scope (Let pat aux e) = oneStm . Let pat aux <$> mapExpM mapper e
  where
    mapper :: Mapper SOACS SOACS PassM
    mapper =
      identityMapper
        { mapOnBody = \bscope -> onBody mode (bscope <> scope),
          mapOnOp = mapSOACM soac_mapper
        }
    soac_mapper :: SOACMapper SOACS SOACS PassM
    soac_mapper = identitySOACMapper {mapOnSOACLambda = onLambda mode scope}

-- | Shared handling of an AD operator statement (either 'VJP' or 'JVP').
--
-- The operator's lambda is first processed recursively with 'onLambda'.
-- In 'All' mode, or whenever that recursive processing left the lambda
-- unchanged, the processed lambda is transformed by @differentiate@. Its
-- body is then inlined with 'bindLambda', applied to @args ++ vec@ and
-- bound to @pat@.
--
-- Otherwise the operator is kept, rebuilt with @mkOp@ around the processed
-- lambda.
onADStm ::
  Mode ->
  Scope SOACS ->
  -- | Pattern of the original statement.
  Pat Type ->
  -- | Attributes of the original statement.
  StmAux (ExpDec SOACS) ->
  -- | Constructor used to rebuild the operator when it is not eliminated.
  (Lambda SOACS -> [SubExp] -> [SubExp] -> SOAC SOACS) ->
  -- | The operator's lambda.
  Lambda SOACS ->
  -- | The operator's first argument list.
  [SubExp] ->
  -- | The operator's second argument list.
  [SubExp] ->
  -- | Transformation producing the lambda whose body gets inlined.
  (Lambda SOACS -> PassM (Lambda SOACS)) ->
  PassM (Stms SOACS)
onADStm mode scope pat aux mkOp lam args vec differentiate = do
  lam' <- onLambda mode scope lam
  if mode == All || lam == lam'
    then do
      lam'' <- differentiate lam'
      runBuilderT_ (bindLambda pat aux lam'' $ args ++ vec) scope
    else pure $ oneStm $ Let pat aux $ Op $ mkOp lam' args vec

-- | Process every statement of a sequence and concatenate the replacement
-- statements in order.
--
-- Each statement is processed in @scope@ extended with the bindings of the
-- whole sequence. Those bindings take precedence over @scope@.
onStms :: Mode -> Scope SOACS -> Stms SOACS -> PassM (Stms SOACS)
onStms mode scope stms =
  fmap mconcat . traverse (onStm mode seqScope) $ stmsToList stms
  where
    seqScope :: Scope SOACS
    seqScope = scopeOf stms <> scope

-- | Process the statements of a body in the given scope. The body's result
-- is left as is.
onBody :: Mode -> Scope SOACS -> Body SOACS -> PassM (Body SOACS)
onBody mode scope body =
  withStms <$> onStms mode scope (bodyStms body)
  where
    withStms newStms = body {bodyStms = newStms}

-- | Process the body of a lambda.
--
-- The body is processed in @scope@ extended with the lambda's parameters.
-- The parameters take precedence over @scope@.
onLambda :: Mode -> Scope SOACS -> Lambda SOACS -> PassM (Lambda SOACS)
onLambda mode scope lam =
  fmap withBody . onBody mode lamScope $ lambdaBody lam
  where
    lamScope = scopeOfLParams (lambdaParams lam) <> scope
    withBody newBody = lam {lambdaBody = newBody}

-- | Process the body of a function definition.
--
-- The body is processed in the scope of the constants @consts@, combined
-- with the scope of the definition itself (as given by 'scopeOf'). The
-- constants' scope is on the left of '<>'.
onFun :: Mode -> Stms SOACS -> FunDef SOACS -> PassM (FunDef SOACS)
onFun mode consts fd =
  withBody <$> onBody mode funScope (funDefBody fd)
  where
    funScope = scopeOf consts <> scopeOf fd
    withBody newBody = fd {funDefBody = newBody}

-- | Build the AD-elimination pass for a given 'Mode'.
--
-- The pass runs over the program's constants and over every function
-- definition. Constants are processed starting from an empty scope.
modePass :: Mode -> String -> String -> Pass SOACS SOACS
modePass mode name description =
  Pass
    { passName = name,
      passDescription = description,
      passFunction =
        intraproceduralTransformationWithConsts
          (onStms mode mempty)
          (onFun mode)
    }

-- | Apply all AD operators in the program.
applyAD :: Pass SOACS SOACS
applyAD = modePass All "ad" "Apply AD operators"

-- | Apply only the innermost AD operators: those whose lambda is left
-- unchanged by recursively processing it.
applyADInnermost :: Pass SOACS SOACS
applyADInnermost =
  modePass Innermost "ad innermost" "Apply innermost AD operators"