{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}

module Futhark.AD.Fwd (fwdJVP) where

import Control.Monad
import Control.Monad.RWS.Strict
import Control.Monad.State.Strict
import Data.Bifunctor (second)
import qualified Data.Kind
import Data.List (transpose)
import Data.List.NonEmpty (NonEmpty (..))
import qualified Data.Map as M
import Futhark.AD.Derivatives
import Futhark.Analysis.PrimExp.Convert
import Futhark.Builder
import Futhark.Construct
import Futhark.IR.SOACS

-- | The zero tangent of a primitive type: the blank constant of that type.
--
-- Only primitive types are supported; any other type is treated as an
-- internal error.
zeroTan :: Type -> ADM SubExp
zeroTan (Prim t) = pure $ constant $ blankPrimValue t
zeroTan t = error $ "zeroTan on non-primitive type: " ++ pretty t

-- | An expression computing an all-zero value of the given type.
--
-- * A primitive type yields its blank constant.
-- * An array type yields that constant replicated across the array's shape.
--
-- Any other type is treated as an internal error.
zeroExp :: Type -> Exp SOACS
zeroExp (Prim pt) =
  BasicOp $ SubExp $ Constant $ blankPrimValue pt
zeroExp (Array pt shape _) =
  BasicOp $ Replicate shape $ Constant $ blankPrimValue pt
zeroExp t = error $ "zeroExp: " ++ show t

-- | The type of the tangent of a value of the given type.
--
-- An accumulator's tangent is the same accumulator (same name, index space
-- and uniqueness). Its element types are extended with their tangent types,
-- placed after the primal ones. Every other type is its own tangent type.
tanType :: TypeBase s u -> ADM (TypeBase s u)
tanType (Acc acc ispace ts u) = do
  ts_tan <- mapM tanType ts
  pure $ Acc acc ispace (ts ++ ts_tan) u
tanType t = pure t

-- | Run an action, then restore the primal-to-tangent map to what it was
-- before the action ran. Equivalent to @'slocal' id@.
slocal' :: ADM a -> ADM a
slocal' = slocal id

-- | @slocal f m@ applies @f@ to the state, runs @m@, and then restores the
-- primal-to-tangent map ('stateTans') to its value from before @f@ was
-- applied.
--
-- Only 'stateTans' is restored. The rest of the state, in particular the
-- name source, keeps whatever @m@ left there, so names generated inside @m@
-- are never handed out a second time.
slocal :: (RState -> RState) -> ADM a -> ADM a
slocal f m = do
  tans <- gets stateTans
  modify f
  a <- m
  modify $ \s' -> s' {stateTans = tans}
  pure a

-- | State threaded through forward-mode differentiation.
data RState = RState
  { -- | Maps each primal variable to the variable holding its tangent.
    stateTans :: M.Map VName VName,
    -- | Source of fresh names.
    stateNameSource :: VNameSource
  }

-- | The forward-mode AD monad: a SOACS statement builder running on top of
-- a state monad that holds an 'RState'.
--
-- All listed instances are obtained by newtype deriving from the wrapped
-- 'BuilderT'.
newtype ADM a = ADM (BuilderT SOACS (State RState) a)
  deriving
    ( Functor,
      Applicative,
      Monad,
      MonadState RState,
      MonadFreshNames,
      HasScope SOACS,
      LocalScope SOACS
    )

-- | Every builder operation is delegated to the wrapped 'BuilderT'.
instance MonadBuilder ADM where
  type Rep ADM = SOACS
  mkExpDecM pat e = ADM $ mkExpDecM pat e
  mkBodyM bnds res = ADM $ mkBodyM bnds res
  mkLetNamesM pat e = ADM $ mkLetNamesM pat e

  addStms = ADM . addStms
  collectStms (ADM m) = ADM $ collectStms m

-- | Fresh names are read from, and written back to, 'stateNameSource'.
--
-- 'runBuilderT' requires a 'MonadFreshNames' base monad, which is why the
-- builder underlying 'ADM' needs this instance.
instance MonadFreshNames (State RState) where
  getNameSource = gets stateNameSource
  putNameSource src = modify $ \env -> env {stateNameSource = src}

-- | Run an 'ADM' action in any monad that has a name source.
--
-- The action starts with an empty primal-to-tangent map and an empty
-- scope. The caller's name source is threaded through, so names generated
-- by the action stay fresh afterwards.
--
-- Statements the action adds at its top level are discarded ('fst'); only
-- the action's result is returned.
runADM :: MonadFreshNames m => ADM a -> m a
runADM (ADM m) =
  modifyNameSource $ \vn ->
    second stateNameSource $
      runState (fst <$> runBuilderT m mempty) (RState mempty vn)

-- | A fresh name for the tangent of the given variable. It is formed by
-- appending @_tan@ to the variable's base name.
tanVName :: VName -> ADM VName
tanVName v = newVName $ baseString v <> "_tan"

-- | Record @v'@ as the tangent of @v@, replacing any earlier entry for @v@.
insertTan :: VName -> VName -> ADM ()
insertTan v v' =
  modify $ \env -> env {stateTans = M.insert v v' $ stateTans env}

-- | Binders (pattern elements, parameters, patterns, ...) for which tangent
-- counterparts can be created.
class TanBuilder a where
  -- | Shape of the result of 'bundleNew'. Unless an instance overrides it,
  -- this is a list of the binder type.
  type Bundled a :: Data.Kind.Type

  type Bundled a = [a]

  -- | Create the binders to use in place of the given one. Typically this
  -- is the original binder followed by its newly created tangent.
  bundleNew :: a -> ADM (Bundled a)

  -- | Create the tangent counterpart of a binder.
  newTan :: a -> ADM a

-- | Lists are handled elementwise. The bundles of the individual elements
-- are concatenated with 'mconcat'.
instance (Monoid (Bundled a), TanBuilder a) => TanBuilder [a] where
  type Bundled [a] = Bundled a
  newTan = mapM newTan
  bundleNew = fmap mconcat . mapM bundleNew

-- | How pattern elements get their tangents:
--
-- * An accumulator element is its own tangent binder. It keeps its name and
--   is registered as its own tangent, but takes the extended accumulator
--   type computed by 'tanType'.
-- * Any other element gets a fresh @_tan@ name, registered as its tangent.
--
-- Bundling yields just the tangent element for accumulators and
-- @[primal, tangent]@ otherwise.
instance TanBuilder (PatElem (TypeBase s u)) where
  newTan (PatElem p t) = do
    -- Same effect order as before: name, then registration, then type.
    p' <- if isAcc t then pure p else tanVName p
    insertTan p p'
    PatElem p' <$> tanType t
  bundleNew pe@(PatElem _ t) = do
    pe' <- newTan pe
    pure $ if isAcc t then [pe'] else [pe, pe']

-- | Patterns are handled through their list of elements. Bundling yields a
-- single pattern containing the bundled elements.
instance TanBuilder (Pat (TypeBase s u)) where
  type Bundled (Pat (TypeBase s u)) = Pat (TypeBase s u)
  newTan (Pat pes) = Pat <$> newTan pes
  bundleNew (Pat pes) = Pat <$> bundleNew pes

-- | Parameters reuse the 'PatElem' logic for naming, registration and
-- typing.
--
-- A unit-typed parameter bundles to itself alone; no tangent is created or
-- registered for it. Otherwise bundling yields just the tangent parameter
-- for accumulators and @[primal, tangent]@ for everything else.
instance TanBuilder (Param (TypeBase s u)) where
  newTan (Param _ p t) = do
    PatElem p' t' <- newTan $ PatElem p t
    -- NOTE(review): the primal parameter's attributes are not copied to the
    -- tangent parameter (it gets 'mempty'); confirm this is intended.
    pure $ Param mempty p' t'
  bundleNew param@(Param _ _ (Prim Unit)) =
    pure [param]
  bundleNew param@(Param _ _ t) = do
    param' <- newTan param
    pure $ if isAcc t then [param'] else [param, param']

-- | A parameter paired with a value (presumably its initial value; confirm
-- against callers).
--
-- The tangent pairs the tangent parameter with the tangent of the value.
-- Bundling zips the bundled parameters with @[x, tangent x]@. When the
-- parameter bundles to a single binder (an accumulator or a unit
-- parameter), 'zip' therefore keeps only the primal value @x@. @tangent x@
-- is still computed in that case.
instance Tangent a => TanBuilder (Param (TypeBase s u), a) where
  newTan (p, x) = (,) <$> newTan p <*> tangent x
  bundleNew (p, x) = do
    b <- bundleNew p
    x_tan <- tangent x
    pure $ zip b [x, x_tan]

-- | Values (variables, subexpressions, results, types, ...) whose tangents
-- can be looked up or constructed.
class Tangent a where
  -- | Shape of the result of 'bundleTan'. Unless an instance overrides it,
  -- this is a list of the value type.
  type BundledTan a :: Data.Kind.Type

  type BundledTan a = [a]

  -- | The values to use in place of the given one. Typically this is the
  -- value itself followed by its tangent.
  bundleTan :: a -> ADM (BundledTan a)

  -- | The tangent of a value.
  tangent :: a -> ADM a

-- | The tangent of a type is given by 'tanType'. Bundling yields just the
-- tangent type for accumulators and @[t, tangent t]@ otherwise.
instance Tangent (TypeBase s u) where
  tangent = tanType
  bundleTan t = do
    t' <- tangent t
    pure $ if isAcc t then [t'] else [t, t']

-- | Lists are handled elementwise. The bundles of the individual elements
-- are concatenated with 'mconcat'.
instance (Monoid (BundledTan a), Tangent a) => Tangent [a] where
  type BundledTan [a] = BundledTan a
  tangent = mapM tangent
  bundleTan = fmap mconcat . mapM bundleTan

-- | The tangent of a variable is the one registered with 'insertTan'.
--
-- A variable without a registered tangent gets an implicit zero tangent: a
-- new binding named @<name>_implicit_tan@, bound to 'zeroExp' at the
-- variable's type. This implicit tangent is not recorded in 'stateTans', so
-- every such lookup adds a separate binding.
--
-- Bundling an accumulator yields only the variable itself, since
-- accumulators act as their own tangents. Any other variable bundles to
-- @[v, tangent v]@.
instance Tangent VName where
  tangent v = do
    maybeTan <- gets $ M.lookup v . stateTans
    case maybeTan of
      Just v_tan -> pure v_tan
      Nothing -> do
        t <- lookupType v
        letExp (baseString v <> "_implicit_tan") $ zeroExp t
  bundleTan v = do
    t <- lookupType v
    if isAcc t
      then pure [v]
      else do
        v_tan <- tangent v
        pure [v, v_tan]

-- | A constant's tangent is the zero of its primitive type ('zeroTan'). A
-- variable's tangent comes from the 'VName' instance.
instance Tangent SubExp where
  tangent (Constant c) = zeroTan $ Prim $ primValueType c
  tangent (Var v) = Var <$> tangent v
  bundleTan c@Constant {} = do
    c_tan <- tangent c
    pure [c, c_tan]
  bundleTan (Var v) = fmap Var <$> bundleTan v

-- | Results keep their certificates. Only the wrapped subexpression is
-- differentiated.
instance Tangent SubExpRes where
  tangent (SubExpRes cs se) = SubExpRes cs <$> tangent se
  bundleTan (SubExpRes cs se) = map (SubExpRes cs) <$> bundleTan se

basicFwd :: Pat Type -> StmAux () -> BasicOp -> ADM ()
basicFwd :: Pat Type -> StmAux () -> BasicOp -> ADM ()
basicFwd Pat Type
pat StmAux ()
aux BasicOp
op = do
  Pat Type
pat_tan <- Pat Type -> ADM (Pat Type)
forall a. TanBuilder a => a -> ADM a
newTan Pat Type
pat
  case BasicOp
op of
    SubExp SubExp
se -> do
      SubExp
se_tan <- SubExp -> ADM SubExp
forall a. Tangent a => a -> ADM a
tangent SubExp
se
      Stm (Rep ADM) -> ADM ()
forall (m :: * -> *). MonadBuilder m => Stm (Rep m) -> m ()
addStm (Stm (Rep ADM) -> ADM ()) -> Stm (Rep ADM) -> ADM ()
forall a b. (a -> b) -> a -> b
$ Pat (LetDec SOACS)
-> StmAux (ExpDec SOACS) -> Exp SOACS -> Stm SOACS
forall rep.
Pat (LetDec rep) -> StmAux (ExpDec rep) -> Exp rep -> Stm rep
Let Pat Type
Pat (LetDec SOACS)
pat_tan StmAux ()
StmAux (ExpDec SOACS)
aux (Exp SOACS -> Stm SOACS) -> Exp SOACS -> Stm SOACS
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp SOACS
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp SOACS) -> BasicOp -> Exp SOACS
forall a b. (a -> b) -> a -> b
$ SubExp -> BasicOp
SubExp SubExp
se_tan
    Opaque OpaqueOp
opaqueop SubExp
se -> do
      SubExp
se_tan <- SubExp -> ADM SubExp
forall a. Tangent a => a -> ADM a
tangent SubExp
se
      Stm (Rep ADM) -> ADM ()
forall (m :: * -> *). MonadBuilder m => Stm (Rep m) -> m ()
addStm (Stm (Rep ADM) -> ADM ()) -> Stm (Rep ADM) -> ADM ()
forall a b. (a -> b) -> a -> b
$ Pat (LetDec SOACS)
-> StmAux (ExpDec SOACS) -> Exp SOACS -> Stm SOACS
forall rep.
Pat (LetDec rep) -> StmAux (ExpDec rep) -> Exp rep -> Stm rep
Let Pat Type
Pat (LetDec SOACS)
pat_tan StmAux ()
StmAux (ExpDec SOACS)
aux (Exp SOACS -> Stm SOACS) -> Exp SOACS -> Stm SOACS
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp SOACS
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp SOACS) -> BasicOp -> Exp SOACS
forall a b. (a -> b) -> a -> b
$ OpaqueOp -> SubExp -> BasicOp
Opaque OpaqueOp
opaqueop SubExp
se_tan
    ArrayLit [SubExp]
ses Type
t -> do
      [SubExp]
ses_tan <- [SubExp] -> ADM [SubExp]
forall a. Tangent a => a -> ADM a
tangent [SubExp]
ses
      Stm (Rep ADM) -> ADM ()
forall (m :: * -> *). MonadBuilder m => Stm (Rep m) -> m ()
addStm (Stm (Rep ADM) -> ADM ()) -> Stm (Rep ADM) -> ADM ()
forall a b. (a -> b) -> a -> b
$ Pat (LetDec SOACS)
-> StmAux (ExpDec SOACS) -> Exp SOACS -> Stm SOACS
forall rep.
Pat (LetDec rep) -> StmAux (ExpDec rep) -> Exp rep -> Stm rep
Let Pat Type
Pat (LetDec SOACS)
pat_tan StmAux ()
StmAux (ExpDec SOACS)
aux (Exp SOACS -> Stm SOACS) -> Exp SOACS -> Stm SOACS
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp SOACS
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp SOACS) -> BasicOp -> Exp SOACS
forall a b. (a -> b) -> a -> b
$ [SubExp] -> Type -> BasicOp
ArrayLit [SubExp]
ses_tan Type
t
    UnOp UnOp
unop SubExp
x -> do
      let t :: PrimType
t = UnOp -> PrimType
unOpType UnOp
unop
          x_pe :: PrimExp VName
x_pe = PrimType -> SubExp -> PrimExp VName
primExpFromSubExp PrimType
t SubExp
x
          dx :: PrimExp VName
dx = UnOp -> PrimExp VName -> PrimExp VName
pdUnOp UnOp
unop PrimExp VName
x_pe
      PrimExp VName
x_tan <- PrimType -> SubExp -> PrimExp VName
primExpFromSubExp PrimType
t (SubExp -> PrimExp VName) -> ADM SubExp -> ADM (PrimExp VName)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> SubExp -> ADM SubExp
forall a. Tangent a => a -> ADM a
tangent SubExp
x
      StmAux () -> ADM () -> ADM ()
forall (m :: * -> *) anyrep a.
MonadBuilder m =>
StmAux anyrep -> m a -> m a
auxing StmAux ()
aux (ADM () -> ADM ()) -> ADM () -> ADM ()
forall a b. (a -> b) -> a -> b
$ [VName] -> Exp (Rep ADM) -> ADM ()
forall (m :: * -> *).
MonadBuilder m =>
[VName] -> Exp (Rep m) -> m ()
letBindNames (Pat Type -> [VName]
forall dec. Pat dec -> [VName]
patNames Pat Type
pat_tan) (Exp SOACS -> ADM ())
-> (PrimExp VName -> ADM (Exp SOACS)) -> PrimExp VName -> ADM ()
forall (m :: * -> *) b c a.
Monad m =>
(b -> m c) -> (a -> m b) -> a -> m c
<=< PrimExp VName -> ADM (Exp SOACS)
forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
toExp (PrimExp VName -> ADM ()) -> PrimExp VName -> ADM ()
forall a b. (a -> b) -> a -> b
$ PrimExp VName
x_tan PrimExp VName -> PrimExp VName -> PrimExp VName
forall v. PrimExp v -> PrimExp v -> PrimExp v
~*~ PrimExp VName
dx
    BinOp BinOp
bop SubExp
x SubExp
y -> do
      let t :: PrimType
t = BinOp -> PrimType
binOpType BinOp
bop
      PrimExp VName
x_tan <- PrimType -> SubExp -> PrimExp VName
primExpFromSubExp PrimType
t (SubExp -> PrimExp VName) -> ADM SubExp -> ADM (PrimExp VName)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> SubExp -> ADM SubExp
forall a. Tangent a => a -> ADM a
tangent SubExp
x
      PrimExp VName
y_tan <- PrimType -> SubExp -> PrimExp VName
primExpFromSubExp PrimType
t (SubExp -> PrimExp VName) -> ADM SubExp -> ADM (PrimExp VName)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> SubExp -> ADM SubExp
forall a. Tangent a => a -> ADM a
tangent SubExp
y
      let (PrimExp VName
wrt_x, PrimExp VName
wrt_y) =
            BinOp
-> PrimExp VName -> PrimExp VName -> (PrimExp VName, PrimExp VName)
pdBinOp BinOp
bop (PrimType -> SubExp -> PrimExp VName
primExpFromSubExp PrimType
t SubExp
x) (PrimType -> SubExp -> PrimExp VName
primExpFromSubExp PrimType
t SubExp
y)
      StmAux () -> ADM () -> ADM ()
forall (m :: * -> *) anyrep a.
MonadBuilder m =>
StmAux anyrep -> m a -> m a
auxing StmAux ()
aux (ADM () -> ADM ()) -> ADM () -> ADM ()
forall a b. (a -> b) -> a -> b
$
        [VName] -> Exp (Rep ADM) -> ADM ()
forall (m :: * -> *).
MonadBuilder m =>
[VName] -> Exp (Rep m) -> m ()
letBindNames (Pat Type -> [VName]
forall dec. Pat dec -> [VName]
patNames Pat Type
pat_tan) (Exp SOACS -> ADM ())
-> (PrimExp VName -> ADM (Exp SOACS)) -> PrimExp VName -> ADM ()
forall (m :: * -> *) b c a.
Monad m =>
(b -> m c) -> (a -> m b) -> a -> m c
<=< PrimExp VName -> ADM (Exp SOACS)
forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
toExp (PrimExp VName -> ADM ()) -> PrimExp VName -> ADM ()
forall a b. (a -> b) -> a -> b
$
          PrimExp VName
x_tan PrimExp VName -> PrimExp VName -> PrimExp VName
forall v. PrimExp v -> PrimExp v -> PrimExp v
~*~ PrimExp VName
wrt_x PrimExp VName -> PrimExp VName -> PrimExp VName
forall v. PrimExp v -> PrimExp v -> PrimExp v
~+~ PrimExp VName
y_tan PrimExp VName -> PrimExp VName -> PrimExp VName
forall v. PrimExp v -> PrimExp v -> PrimExp v
~*~ PrimExp VName
wrt_y
    CmpOp {} ->
      Stm (Rep ADM) -> ADM ()
forall (m :: * -> *). MonadBuilder m => Stm (Rep m) -> m ()
addStm (Stm (Rep ADM) -> ADM ()) -> Stm (Rep ADM) -> ADM ()
forall a b. (a -> b) -> a -> b
$ Pat (LetDec SOACS)
-> StmAux (ExpDec SOACS) -> Exp SOACS -> Stm SOACS
forall rep.
Pat (LetDec rep) -> StmAux (ExpDec rep) -> Exp rep -> Stm rep
Let Pat Type
Pat (LetDec SOACS)
pat_tan StmAux ()
StmAux (ExpDec SOACS)
aux (Exp SOACS -> Stm SOACS) -> Exp SOACS -> Stm SOACS
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp SOACS
forall rep. BasicOp -> Exp rep
BasicOp BasicOp
op
    ConvOp ConvOp
cop SubExp
x -> do
      SubExp
x_tan <- SubExp -> ADM SubExp
forall a. Tangent a => a -> ADM a
tangent SubExp
x
      Stm (Rep ADM) -> ADM ()
forall (m :: * -> *). MonadBuilder m => Stm (Rep m) -> m ()
addStm (Stm (Rep ADM) -> ADM ()) -> Stm (Rep ADM) -> ADM ()
forall a b. (a -> b) -> a -> b
$ Pat (LetDec SOACS)
-> StmAux (ExpDec SOACS) -> Exp SOACS -> Stm SOACS
forall rep.
Pat (LetDec rep) -> StmAux (ExpDec rep) -> Exp rep -> Stm rep
Let Pat Type
Pat (LetDec SOACS)
pat_tan StmAux ()
StmAux (ExpDec SOACS)
aux (Exp SOACS -> Stm SOACS) -> Exp SOACS -> Stm SOACS
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp SOACS
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp SOACS) -> BasicOp -> Exp SOACS
forall a b. (a -> b) -> a -> b
$ ConvOp -> SubExp -> BasicOp
ConvOp ConvOp
cop SubExp
x_tan
    Assert {} -> () -> ADM ()
forall (f :: * -> *) a. Applicative f => a -> f a
pure ()
    Index VName
arr Slice SubExp
slice -> do
      VName
arr_tan <- VName -> ADM VName
forall a. Tangent a => a -> ADM a
tangent VName
arr
      Stm (Rep ADM) -> ADM ()
forall (m :: * -> *). MonadBuilder m => Stm (Rep m) -> m ()
addStm (Stm (Rep ADM) -> ADM ()) -> Stm (Rep ADM) -> ADM ()
forall a b. (a -> b) -> a -> b
$ Pat (LetDec SOACS)
-> StmAux (ExpDec SOACS) -> Exp SOACS -> Stm SOACS
forall rep.
Pat (LetDec rep) -> StmAux (ExpDec rep) -> Exp rep -> Stm rep
Let Pat Type
Pat (LetDec SOACS)
pat_tan StmAux ()
StmAux (ExpDec SOACS)
aux (Exp SOACS -> Stm SOACS) -> Exp SOACS -> Stm SOACS
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp SOACS
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp SOACS) -> BasicOp -> Exp SOACS
forall a b. (a -> b) -> a -> b
$ VName -> Slice SubExp -> BasicOp
Index VName
arr_tan Slice SubExp
slice
    Update Safety
safety VName
arr Slice SubExp
slice SubExp
se -> do
      VName
arr_tan <- VName -> ADM VName
forall a. Tangent a => a -> ADM a
tangent VName
arr
      SubExp
se_tan <- SubExp -> ADM SubExp
forall a. Tangent a => a -> ADM a
tangent SubExp
se
      Stm (Rep ADM) -> ADM ()
forall (m :: * -> *). MonadBuilder m => Stm (Rep m) -> m ()
addStm (Stm (Rep ADM) -> ADM ()) -> Stm (Rep ADM) -> ADM ()
forall a b. (a -> b) -> a -> b
$ Pat (LetDec SOACS)
-> StmAux (ExpDec SOACS) -> Exp SOACS -> Stm SOACS
forall rep.
Pat (LetDec rep) -> StmAux (ExpDec rep) -> Exp rep -> Stm rep
Let Pat Type
Pat (LetDec SOACS)
pat_tan StmAux ()
StmAux (ExpDec SOACS)
aux (Exp SOACS -> Stm SOACS) -> Exp SOACS -> Stm SOACS
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp SOACS
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp SOACS) -> BasicOp -> Exp SOACS
forall a b. (a -> b) -> a -> b
$ Safety -> VName -> Slice SubExp -> SubExp -> BasicOp
Update Safety
safety VName
arr_tan Slice SubExp
slice SubExp
se_tan
    Concat Int
d (VName
arr :| [VName]
arrs) SubExp
w -> do
      VName
arr_tan <- VName -> ADM VName
forall a. Tangent a => a -> ADM a
tangent VName
arr
      [VName]
arrs_tans <- [VName] -> ADM [VName]
forall a. Tangent a => a -> ADM a
tangent [VName]
arrs
      Stm (Rep ADM) -> ADM ()
forall (m :: * -> *). MonadBuilder m => Stm (Rep m) -> m ()
addStm (Stm (Rep ADM) -> ADM ()) -> Stm (Rep ADM) -> ADM ()
forall a b. (a -> b) -> a -> b
$ Pat (LetDec SOACS)
-> StmAux (ExpDec SOACS) -> Exp SOACS -> Stm SOACS
forall rep.
Pat (LetDec rep) -> StmAux (ExpDec rep) -> Exp rep -> Stm rep
Let Pat Type
Pat (LetDec SOACS)
pat_tan StmAux ()
StmAux (ExpDec SOACS)
aux (Exp SOACS -> Stm SOACS) -> Exp SOACS -> Stm SOACS
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp SOACS
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp SOACS) -> BasicOp -> Exp SOACS
forall a b. (a -> b) -> a -> b
$ Int -> NonEmpty VName -> SubExp -> BasicOp
Concat Int
d (VName
arr_tan VName -> [VName] -> NonEmpty VName
forall a. a -> [a] -> NonEmpty a
:| [VName]
arrs_tans) SubExp
w
    Copy VName
arr -> do
      VName
arr_tan <- VName -> ADM VName
forall a. Tangent a => a -> ADM a
tangent VName
arr
      Stm (Rep ADM) -> ADM ()
forall (m :: * -> *). MonadBuilder m => Stm (Rep m) -> m ()
addStm (Stm (Rep ADM) -> ADM ()) -> Stm (Rep ADM) -> ADM ()
forall a b. (a -> b) -> a -> b
$ Pat (LetDec SOACS)
-> StmAux (ExpDec SOACS) -> Exp SOACS -> Stm SOACS
forall rep.
Pat (LetDec rep) -> StmAux (ExpDec rep) -> Exp rep -> Stm rep
Let Pat Type
Pat (LetDec SOACS)
pat_tan StmAux ()
StmAux (ExpDec SOACS)
aux (Exp SOACS -> Stm SOACS) -> Exp SOACS -> Stm SOACS
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp SOACS
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp SOACS) -> BasicOp -> Exp SOACS
forall a b. (a -> b) -> a -> b
$ VName -> BasicOp
Copy VName
arr_tan
    Manifest [Int]
ds VName
arr -> do
      VName
arr_tan <- VName -> ADM VName
forall a. Tangent a => a -> ADM a
tangent VName
arr
      Stm (Rep ADM) -> ADM ()
forall (m :: * -> *). MonadBuilder m => Stm (Rep m) -> m ()
addStm (Stm (Rep ADM) -> ADM ()) -> Stm (Rep ADM) -> ADM ()
forall a b. (a -> b) -> a -> b
$ Pat (LetDec SOACS)
-> StmAux (ExpDec SOACS) -> Exp SOACS -> Stm SOACS
forall rep.
Pat (LetDec rep) -> StmAux (ExpDec rep) -> Exp rep -> Stm rep
Let Pat Type
Pat (LetDec SOACS)
pat_tan StmAux ()
StmAux (ExpDec SOACS)
aux (Exp SOACS -> Stm SOACS) -> Exp SOACS -> Stm SOACS
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp SOACS
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp SOACS) -> BasicOp -> Exp SOACS
forall a b. (a -> b) -> a -> b
$ [Int] -> VName -> BasicOp
Manifest [Int]
ds VName
arr_tan
    Iota SubExp
n SubExp
_ SubExp
_ IntType
it -> do
      Stm (Rep ADM) -> ADM ()
forall (m :: * -> *). MonadBuilder m => Stm (Rep m) -> m ()
addStm (Stm (Rep ADM) -> ADM ()) -> Stm (Rep ADM) -> ADM ()
forall a b. (a -> b) -> a -> b
$ Pat (LetDec SOACS)
-> StmAux (ExpDec SOACS) -> Exp SOACS -> Stm SOACS
forall rep.
Pat (LetDec rep) -> StmAux (ExpDec rep) -> Exp rep -> Stm rep
Let Pat Type
Pat (LetDec SOACS)
pat_tan StmAux ()
StmAux (ExpDec SOACS)
aux (Exp SOACS -> Stm SOACS) -> Exp SOACS -> Stm SOACS
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp SOACS
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp SOACS) -> BasicOp -> Exp SOACS
forall a b. (a -> b) -> a -> b
$ Shape -> SubExp -> BasicOp
Replicate ([SubExp] -> Shape
forall d. [d] -> ShapeBase d
Shape [SubExp
n]) (IntType -> Integer -> SubExp
intConst IntType
it Integer
0)
    Replicate Shape
n SubExp
x -> do
      SubExp
x_tan <- SubExp -> ADM SubExp
forall a. Tangent a => a -> ADM a
tangent SubExp
x
      Stm (Rep ADM) -> ADM ()
forall (m :: * -> *). MonadBuilder m => Stm (Rep m) -> m ()
addStm (Stm (Rep ADM) -> ADM ()) -> Stm (Rep ADM) -> ADM ()
forall a b. (a -> b) -> a -> b
$ Pat (LetDec SOACS)
-> StmAux (ExpDec SOACS) -> Exp SOACS -> Stm SOACS
forall rep.
Pat (LetDec rep) -> StmAux (ExpDec rep) -> Exp rep -> Stm rep
Let Pat Type
Pat (LetDec SOACS)
pat_tan StmAux ()
StmAux (ExpDec SOACS)
aux (Exp SOACS -> Stm SOACS) -> Exp SOACS -> Stm SOACS
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp SOACS
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp SOACS) -> BasicOp -> Exp SOACS
forall a b. (a -> b) -> a -> b
$ Shape -> SubExp -> BasicOp
Replicate Shape
n SubExp
x_tan
    Scratch PrimType
t [SubExp]
shape ->
      Stm (Rep ADM) -> ADM ()
forall (m :: * -> *). MonadBuilder m => Stm (Rep m) -> m ()
addStm (Stm (Rep ADM) -> ADM ()) -> Stm (Rep ADM) -> ADM ()
forall a b. (a -> b) -> a -> b
$ Pat (LetDec SOACS)
-> StmAux (ExpDec SOACS) -> Exp SOACS -> Stm SOACS
forall rep.
Pat (LetDec rep) -> StmAux (ExpDec rep) -> Exp rep -> Stm rep
Let Pat Type
Pat (LetDec SOACS)
pat_tan StmAux ()
StmAux (ExpDec SOACS)
aux (Exp SOACS -> Stm SOACS) -> Exp SOACS -> Stm SOACS
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp SOACS
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp SOACS) -> BasicOp -> Exp SOACS
forall a b. (a -> b) -> a -> b
$ PrimType -> [SubExp] -> BasicOp
Scratch PrimType
t [SubExp]
shape
    Reshape ShapeChange SubExp
reshape VName
arr -> do
      VName
arr_tan <- VName -> ADM VName
forall a. Tangent a => a -> ADM a
tangent VName
arr
      Stm (Rep ADM) -> ADM ()
forall (m :: * -> *). MonadBuilder m => Stm (Rep m) -> m ()
addStm (Stm (Rep ADM) -> ADM ()) -> Stm (Rep ADM) -> ADM ()
forall a b. (a -> b) -> a -> b
$ Pat (LetDec SOACS)
-> StmAux (ExpDec SOACS) -> Exp SOACS -> Stm SOACS
forall rep.
Pat (LetDec rep) -> StmAux (ExpDec rep) -> Exp rep -> Stm rep
Let Pat Type
Pat (LetDec SOACS)
pat_tan StmAux ()
StmAux (ExpDec SOACS)
aux (Exp SOACS -> Stm SOACS) -> Exp SOACS -> Stm SOACS
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp SOACS
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp SOACS) -> BasicOp -> Exp SOACS
forall a b. (a -> b) -> a -> b
$ ShapeChange SubExp -> VName -> BasicOp
Reshape ShapeChange SubExp
reshape VName
arr_tan
    Rearrange [Int]
perm VName
arr -> do
      VName
arr_tan <- VName -> ADM VName
forall a. Tangent a => a -> ADM a
tangent VName
arr
      Stm (Rep ADM) -> ADM ()
forall (m :: * -> *). MonadBuilder m => Stm (Rep m) -> m ()
addStm (Stm (Rep ADM) -> ADM ()) -> Stm (Rep ADM) -> ADM ()
forall a b. (a -> b) -> a -> b
$ Pat (LetDec SOACS)
-> StmAux (ExpDec SOACS) -> Exp SOACS -> Stm SOACS
forall rep.
Pat (LetDec rep) -> StmAux (ExpDec rep) -> Exp rep -> Stm rep
Let Pat Type
Pat (LetDec SOACS)
pat_tan StmAux ()
StmAux (ExpDec SOACS)
aux (Exp SOACS -> Stm SOACS) -> Exp SOACS -> Stm SOACS
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp SOACS
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp SOACS) -> BasicOp -> Exp SOACS
forall a b. (a -> b) -> a -> b
$ [Int] -> VName -> BasicOp
Rearrange [Int]
perm VName
arr_tan
    Rotate [SubExp]
rots VName
arr -> do
      VName
arr_tan <- VName -> ADM VName
forall a. Tangent a => a -> ADM a
tangent VName
arr
      Stm (Rep ADM) -> ADM ()
forall (m :: * -> *). MonadBuilder m => Stm (Rep m) -> m ()
addStm (Stm (Rep ADM) -> ADM ()) -> Stm (Rep ADM) -> ADM ()
forall a b. (a -> b) -> a -> b
$ Pat (LetDec SOACS)
-> StmAux (ExpDec SOACS) -> Exp SOACS -> Stm SOACS
forall rep.
Pat (LetDec rep) -> StmAux (ExpDec rep) -> Exp rep -> Stm rep
Let Pat Type
Pat (LetDec SOACS)
pat_tan StmAux ()
StmAux (ExpDec SOACS)
aux (Exp SOACS -> Stm SOACS) -> Exp SOACS -> Stm SOACS
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp SOACS
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp SOACS) -> BasicOp -> Exp SOACS
forall a b. (a -> b) -> a -> b
$ [SubExp] -> VName -> BasicOp
Rotate [SubExp]
rots VName
arr_tan
    BasicOp
_ -> [Char] -> ADM ()
forall a. HasCallStack => [Char] -> a
error ([Char] -> ADM ()) -> [Char] -> ADM ()
forall a b. (a -> b) -> a -> b
$ [Char]
"basicFwd: Unsupported op " [Char] -> [Char] -> [Char]
forall a. [a] -> [a] -> [a]
++ BasicOp -> [Char]
forall a. Pretty a => a -> [Char]
pretty BasicOp
op

-- | Forward-mode transform of an ordinary lambda.
--
-- The parameters are passed through 'bundleNew' (defined elsewhere in this
-- module; presumably it interleaves each parameter with a fresh tangent
-- parameter). The body is differentiated with 'fwdBody' inside the lambda's
-- own scope. The return types are extended via 'bundleTan'.
--
-- The three monadic steps run in this order: parameters, then body, then
-- return types.
fwdLambda :: Lambda SOACS -> ADM (Lambda SOACS)
fwdLambda lam@(Lambda params body ret) = do
  bundled_params <- bundleNew params
  tangent_body <- inScopeOf lam $ fwdBody body
  bundled_ret <- bundleTan ret
  pure $ Lambda bundled_params tangent_body bundled_ret

-- | Forward-mode transform of a stream lambda.
--
-- This is like 'fwdLambda', except that the first parameter is kept as-is
-- and gets no tangent. That parameter is presumably the stream's chunk-size
-- parameter (TODO confirm against the Stream SOAC definition). Only the
-- remaining parameters go through 'bundleNew'.
--
-- The three monadic steps run in this order: parameters, then body, then
-- return types.
fwdStreamLambda :: Lambda SOACS -> ADM (Lambda SOACS)
fwdStreamLambda lam@(Lambda params body ret) = do
  let (leading_params, other_params) = splitAt 1 params
  bundled_others <- bundleNew other_params
  tangent_body <- inScopeOf lam $ fwdBody body
  bundled_ret <- bundleTan ret
  pure $ Lambda (leading_params ++ bundled_others) tangent_body bundled_ret

-- | Alternate the elements of two lists, taking from the first list first:
-- @interleave [x1, x2] [y1, y2] == [x1, y1, x2, y2]@.
--
-- When one list is longer, its leftover elements are appended in order once
-- the shorter list runs out.
interleave :: [a] -> [a] -> [a]
interleave xs ys = [z | column <- transpose [xs, ys], z <- column]

-- | Bind a fresh variable named "zero" holding the zero value of the given
-- subexpression's type, and return that variable's name.
--
-- For a constant, the type is taken from the constant's primitive value, so
-- no scope lookup happens. For a variable, the type is looked up in the
-- current scope. Either way, 'zeroExp' builds the expression that gets bound.
zeroFromSubExp :: SubExp -> ADM VName
zeroFromSubExp se = letExp "zero" . zeroExp =<< typeOfSubExp se
  where
    typeOfSubExp (Constant c) = pure $ Prim $ primValueType c
    typeOfSubExp (Var v) = lookupType v

fwdSOAC :: Pat Type -> StmAux () -> SOAC SOACS -> ADM ()
fwdSOAC :: Pat Type -> StmAux () -> SOAC SOACS -> ADM ()
fwdSOAC Pat Type
pat StmAux ()
aux (Screma SubExp
size [VName]
xs (ScremaForm [Scan SOACS]
scs [Reduce SOACS]
reds Lambda SOACS
f)) = do
  Pat Type
pat' <- Pat Type -> ADM (Bundled (Pat Type))
forall a. TanBuilder a => a -> ADM (Bundled a)
bundleNew Pat Type
pat
  [VName]
xs' <- [VName] -> ADM (BundledTan [VName])
forall a. Tangent a => a -> ADM (BundledTan a)
bundleTan [VName]
xs
  [Scan SOACS]
scs' <- (Scan SOACS -> ADM (Scan SOACS))
-> [Scan SOACS] -> ADM [Scan SOACS]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM Scan SOACS -> ADM (Scan SOACS)
fwdScan [Scan SOACS]
scs
  [Reduce SOACS]
reds' <- (Reduce SOACS -> ADM (Reduce SOACS))
-> [Reduce SOACS] -> ADM [Reduce SOACS]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM Reduce SOACS -> ADM (Reduce SOACS)
fwdRed [Reduce SOACS]
reds
  Lambda SOACS
f' <- Lambda SOACS -> ADM (Lambda SOACS)
fwdLambda Lambda SOACS
f
  Stm (Rep ADM) -> ADM ()
forall (m :: * -> *). MonadBuilder m => Stm (Rep m) -> m ()
addStm (Stm (Rep ADM) -> ADM ()) -> Stm (Rep ADM) -> ADM ()
forall a b. (a -> b) -> a -> b
$ Pat (LetDec SOACS)
-> StmAux (ExpDec SOACS) -> Exp SOACS -> Stm SOACS
forall rep.
Pat (LetDec rep) -> StmAux (ExpDec rep) -> Exp rep -> Stm rep
Let Pat Type
Pat (LetDec SOACS)
pat' StmAux ()
StmAux (ExpDec SOACS)
aux (Exp SOACS -> Stm SOACS) -> Exp SOACS -> Stm SOACS
forall a b. (a -> b) -> a -> b
$ Op SOACS -> Exp SOACS
forall rep. Op rep -> Exp rep
Op (Op SOACS -> Exp SOACS) -> Op SOACS -> Exp SOACS
forall a b. (a -> b) -> a -> b
$ SubExp -> [VName] -> ScremaForm SOACS -> SOAC SOACS
forall rep. SubExp -> [VName] -> ScremaForm rep -> SOAC rep
Screma SubExp
size [VName]
xs' (ScremaForm SOACS -> SOAC SOACS) -> ScremaForm SOACS -> SOAC SOACS
forall a b. (a -> b) -> a -> b
$ [Scan SOACS] -> [Reduce SOACS] -> Lambda SOACS -> ScremaForm SOACS
forall rep.
[Scan rep] -> [Reduce rep] -> Lambda rep -> ScremaForm rep
ScremaForm [Scan SOACS]
scs' [Reduce SOACS]
reds' Lambda SOACS
f'
  where
    fwdScan :: Scan SOACS -> ADM (Scan SOACS)
    fwdScan :: Scan SOACS -> ADM (Scan SOACS)
fwdScan Scan SOACS
sc = do
      Lambda SOACS
op' <- Lambda SOACS -> ADM (Lambda SOACS)
fwdLambda (Lambda SOACS -> ADM (Lambda SOACS))
-> Lambda SOACS -> ADM (Lambda SOACS)
forall a b. (a -> b) -> a -> b
$ Scan SOACS -> Lambda SOACS
forall rep. Scan rep -> Lambda rep
scanLambda Scan SOACS
sc
      [VName]
neutral_tans <- (SubExp -> ADM VName) -> [SubExp] -> ADM [VName]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM SubExp -> ADM VName
zeroFromSubExp ([SubExp] -> ADM [VName]) -> [SubExp] -> ADM [VName]
forall a b. (a -> b) -> a -> b
$ Scan SOACS -> [SubExp]
forall rep. Scan rep -> [SubExp]
scanNeutral Scan SOACS
sc
      Scan SOACS -> ADM (Scan SOACS)
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Scan SOACS -> ADM (Scan SOACS)) -> Scan SOACS -> ADM (Scan SOACS)
forall a b. (a -> b) -> a -> b
$
        Scan :: forall rep. Lambda rep -> [SubExp] -> Scan rep
Scan
          { scanNeutral :: [SubExp]
scanNeutral = Scan SOACS -> [SubExp]
forall rep. Scan rep -> [SubExp]
scanNeutral Scan SOACS
sc [SubExp] -> [SubExp] -> [SubExp]
forall a. [a] -> [a] -> [a]
`interleave` (VName -> SubExp) -> [VName] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map VName -> SubExp
Var [VName]
neutral_tans,
            scanLambda :: Lambda SOACS
scanLambda = Lambda SOACS
op'
          }
    fwdRed :: Reduce SOACS -> ADM (Reduce SOACS)
    fwdRed :: Reduce SOACS -> ADM (Reduce SOACS)
fwdRed Reduce SOACS
red = do
      Lambda SOACS
op' <- Lambda SOACS -> ADM (Lambda SOACS)
fwdLambda (Lambda SOACS -> ADM (Lambda SOACS))
-> Lambda SOACS -> ADM (Lambda SOACS)
forall a b. (a -> b) -> a -> b
$ Reduce SOACS -> Lambda SOACS
forall rep. Reduce rep -> Lambda rep
redLambda Reduce SOACS
red
      [VName]
neutral_tans <- (SubExp -> ADM VName) -> [SubExp] -> ADM [VName]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM SubExp -> ADM VName
zeroFromSubExp ([SubExp] -> ADM [VName]) -> [SubExp] -> ADM [VName]
forall a b. (a -> b) -> a -> b
$ Reduce SOACS -> [SubExp]
forall rep. Reduce rep -> [SubExp]
redNeutral Reduce SOACS
red
      Reduce SOACS -> ADM (Reduce SOACS)
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Reduce SOACS -> ADM (Reduce SOACS))
-> Reduce SOACS -> ADM (Reduce SOACS)
forall a b. (a -> b) -> a -> b
$
        Reduce :: forall rep. Commutativity -> Lambda rep -> [SubExp] -> Reduce rep
Reduce
          { redComm :: Commutativity
redComm = Reduce SOACS -> Commutativity
forall rep. Reduce rep -> Commutativity
redComm Reduce SOACS
red,
            redLambda :: Lambda SOACS
redLambda = Lambda SOACS
op',
            redNeutral :: [SubExp]
redNeutral = Reduce SOACS -> [SubExp]
forall rep. Reduce rep -> [SubExp]
redNeutral Reduce SOACS
red [SubExp] -> [SubExp] -> [SubExp]
forall a. [a] -> [a] -> [a]
`interleave` (VName -> SubExp) -> [VName] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map VName -> SubExp
Var [VName]
neutral_tans
          }
fwdSOAC Pat Type
pat StmAux ()
aux (Stream SubExp
size [VName]
xs StreamForm SOACS
form [SubExp]
nes Lambda SOACS
lam) = do
  Pat Type
pat' <- Pat Type -> ADM (Bundled (Pat Type))
forall a. TanBuilder a => a -> ADM (Bundled a)
bundleNew Pat Type
pat
  Lambda SOACS
lam' <- Lambda SOACS -> ADM (Lambda SOACS)
fwdStreamLambda Lambda SOACS
lam
  [VName]
xs' <- [VName] -> ADM (BundledTan [VName])
forall a. Tangent a => a -> ADM (BundledTan a)
bundleTan [VName]
xs
  [SubExp]
nes_tan <- (SubExp -> ADM SubExp) -> [SubExp] -> ADM [SubExp]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM ((VName -> SubExp) -> ADM VName -> ADM SubExp
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap VName -> SubExp
Var (ADM VName -> ADM SubExp)
-> (SubExp -> ADM VName) -> SubExp -> ADM SubExp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SubExp -> ADM VName
zeroFromSubExp) [SubExp]
nes
  let nes' :: [SubExp]
nes' = [SubExp] -> [SubExp] -> [SubExp]
forall a. [a] -> [a] -> [a]
interleave [SubExp]
nes [SubExp]
nes_tan
  case StreamForm SOACS
form of
    StreamForm SOACS
Sequential ->
      Stm (Rep ADM) -> ADM ()
forall (m :: * -> *). MonadBuilder m => Stm (Rep m) -> m ()
addStm (Stm (Rep ADM) -> ADM ()) -> Stm (Rep ADM) -> ADM ()
forall a b. (a -> b) -> a -> b
$ Pat (LetDec SOACS)
-> StmAux (ExpDec SOACS) -> Exp SOACS -> Stm SOACS
forall rep.
Pat (LetDec rep) -> StmAux (ExpDec rep) -> Exp rep -> Stm rep
Let Pat Type
Pat (LetDec SOACS)
pat' StmAux ()
StmAux (ExpDec SOACS)
aux (Exp SOACS -> Stm SOACS) -> Exp SOACS -> Stm SOACS
forall a b. (a -> b) -> a -> b
$ Op SOACS -> Exp SOACS
forall rep. Op rep -> Exp rep
Op (Op SOACS -> Exp SOACS) -> Op SOACS -> Exp SOACS
forall a b. (a -> b) -> a -> b
$ SubExp
-> [VName]
-> StreamForm SOACS
-> [SubExp]
-> Lambda SOACS
-> SOAC SOACS
forall rep.
SubExp
-> [VName] -> StreamForm rep -> [SubExp] -> Lambda rep -> SOAC rep
Stream SubExp
size [VName]
xs' StreamForm SOACS
forall rep. StreamForm rep
Sequential [SubExp]
nes' Lambda SOACS
lam'
    Parallel StreamOrd
o Commutativity
comm Lambda SOACS
lam0 -> do
      Lambda SOACS
lam0' <- Lambda SOACS -> ADM (Lambda SOACS)
fwdLambda Lambda SOACS
lam0
      let form' :: StreamForm SOACS
form' = StreamOrd -> Commutativity -> Lambda SOACS -> StreamForm SOACS
forall rep.
StreamOrd -> Commutativity -> Lambda rep -> StreamForm rep
Parallel StreamOrd
o Commutativity
comm Lambda SOACS
lam0'
      Stm (Rep ADM) -> ADM ()
forall (m :: * -> *). MonadBuilder m => Stm (Rep m) -> m ()
addStm (Stm (Rep ADM) -> ADM ()) -> Stm (Rep ADM) -> ADM ()
forall a b. (a -> b) -> a -> b
$ Pat (LetDec SOACS)
-> StmAux (ExpDec SOACS) -> Exp SOACS -> Stm SOACS
forall rep.
Pat (LetDec rep) -> StmAux (ExpDec rep) -> Exp rep -> Stm rep
Let Pat Type
Pat (LetDec SOACS)
pat' StmAux ()
StmAux (ExpDec SOACS)
aux (Exp SOACS -> Stm SOACS) -> Exp SOACS -> Stm SOACS
forall a b. (a -> b) -> a -> b
$ Op SOACS -> Exp SOACS
forall rep. Op rep -> Exp rep
Op (Op SOACS -> Exp SOACS) -> Op SOACS -> Exp SOACS
forall a b. (a -> b) -> a -> b
$ SubExp
-> [VName]
-> StreamForm SOACS
-> [SubExp]
-> Lambda SOACS
-> SOAC SOACS
forall rep.
SubExp
-> [VName] -> StreamForm rep -> [SubExp] -> Lambda rep -> SOAC rep
Stream SubExp
size [VName]
xs' StreamForm SOACS
form' [SubExp]
nes' Lambda SOACS
lam'
fwdSOAC Pat Type
pat StmAux ()
aux (Hist SubExp
w [VName]
arrs [HistOp SOACS]
ops Lambda SOACS
bucket_fun) = do
  Pat Type
pat' <- Pat Type -> ADM (Bundled (Pat Type))
forall a. TanBuilder a => a -> ADM (Bundled a)
bundleNew Pat Type
pat
  [HistOp SOACS]
ops' <- (HistOp SOACS -> ADM (HistOp SOACS))
-> [HistOp SOACS] -> ADM [HistOp SOACS]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM HistOp SOACS -> ADM (HistOp SOACS)
fwdHist [HistOp SOACS]
ops
  Lambda SOACS
bucket_fun' <- Lambda SOACS -> ADM (Lambda SOACS)
fwdHistBucket Lambda SOACS
bucket_fun
  [VName]
arrs' <- [VName] -> ADM (BundledTan [VName])
forall a. Tangent a => a -> ADM (BundledTan a)
bundleTan [VName]
arrs
  Stm (Rep ADM) -> ADM ()
forall (m :: * -> *). MonadBuilder m => Stm (Rep m) -> m ()
addStm (Stm (Rep ADM) -> ADM ()) -> Stm (Rep ADM) -> ADM ()
forall a b. (a -> b) -> a -> b
$ Pat (LetDec SOACS)
-> StmAux (ExpDec SOACS) -> Exp SOACS -> Stm SOACS
forall rep.
Pat (LetDec rep) -> StmAux (ExpDec rep) -> Exp rep -> Stm rep
Let Pat Type
Pat (LetDec SOACS)
pat' StmAux ()
StmAux (ExpDec SOACS)
aux (Exp SOACS -> Stm SOACS) -> Exp SOACS -> Stm SOACS
forall a b. (a -> b) -> a -> b
$ Op SOACS -> Exp SOACS
forall rep. Op rep -> Exp rep
Op (Op SOACS -> Exp SOACS) -> Op SOACS -> Exp SOACS
forall a b. (a -> b) -> a -> b
$ SubExp -> [VName] -> [HistOp SOACS] -> Lambda SOACS -> SOAC SOACS
forall rep.
SubExp -> [VName] -> [HistOp rep] -> Lambda rep -> SOAC rep
Hist SubExp
w [VName]
arrs' [HistOp SOACS]
ops' Lambda SOACS
bucket_fun'
  where
    n_indices :: Int
n_indices = [Int] -> Int
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
sum ([Int] -> Int) -> [Int] -> Int
forall a b. (a -> b) -> a -> b
$ (HistOp SOACS -> Int) -> [HistOp SOACS] -> [Int]
forall a b. (a -> b) -> [a] -> [b]
map (Shape -> Int
forall a. ArrayShape a => a -> Int
shapeRank (Shape -> Int) -> (HistOp SOACS -> Shape) -> HistOp SOACS -> Int
forall b c a. (b -> c) -> (a -> b) -> a -> c
. HistOp SOACS -> Shape
forall rep. HistOp rep -> Shape
histShape) [HistOp SOACS]
ops
    fwdBodyHist :: Body SOACS -> ADM (Body (Rep ADM))
fwdBodyHist (Body BodyDec SOACS
_ Stms SOACS
stms Result
res) = ADM Result -> ADM (Body (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
m Result -> m (Body (Rep m))
buildBody_ (ADM Result -> ADM (Body (Rep ADM)))
-> ADM Result -> ADM (Body (Rep ADM))
forall a b. (a -> b) -> a -> b
$ do
      (Stm SOACS -> ADM ()) -> Stms SOACS -> ADM ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
(a -> m b) -> t a -> m ()
mapM_ Stm SOACS -> ADM ()
fwdStm Stms SOACS
stms
      let (Result
res_is, Result
res_vs) = Int -> Result -> (Result, Result)
forall a. Int -> [a] -> ([a], [a])
splitAt Int
n_indices Result
res
      (Result
res_is Result -> Result -> Result
forall a. [a] -> [a] -> [a]
++) (Result -> Result) -> ADM Result -> ADM Result
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> Result -> ADM (BundledTan Result)
forall a. Tangent a => a -> ADM (BundledTan a)
bundleTan Result
res_vs
    fwdHistBucket :: Lambda SOACS -> ADM (Lambda SOACS)
fwdHistBucket l :: Lambda SOACS
l@(Lambda [LParam SOACS]
params Body SOACS
body [Type]
ret) =
      let ([Type]
r_is, [Type]
r_vs) = Int -> [Type] -> ([Type], [Type])
forall a. Int -> [a] -> ([a], [a])
splitAt Int
n_indices [Type]
ret
       in [Param Type] -> Body SOACS -> [Type] -> Lambda SOACS
forall rep. [LParam rep] -> Body rep -> [Type] -> Lambda rep
Lambda
            ([Param Type] -> Body SOACS -> [Type] -> Lambda SOACS)
-> ADM [Param Type] -> ADM (Body SOACS -> [Type] -> Lambda SOACS)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> [Param Type] -> ADM (Bundled [Param Type])
forall a. TanBuilder a => a -> ADM (Bundled a)
bundleNew [Param Type]
[LParam SOACS]
params
            ADM (Body SOACS -> [Type] -> Lambda SOACS)
-> ADM (Body SOACS) -> ADM ([Type] -> Lambda SOACS)
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> Lambda SOACS -> ADM (Body SOACS) -> ADM (Body SOACS)
forall rep a (m :: * -> *) b.
(Scoped rep a, LocalScope rep m) =>
a -> m b -> m b
inScopeOf Lambda SOACS
l (Body SOACS -> ADM (Body (Rep ADM))
fwdBodyHist Body SOACS
body)
            ADM ([Type] -> Lambda SOACS) -> ADM [Type] -> ADM (Lambda SOACS)
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> (([Type]
r_is [Type] -> [Type] -> [Type]
forall a. [a] -> [a] -> [a]
++) ([Type] -> [Type]) -> ADM [Type] -> ADM [Type]
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> [Type] -> ADM (BundledTan [Type])
forall a. Tangent a => a -> ADM (BundledTan a)
bundleTan [Type]
r_vs)

    -- Differentiate one histogram operator: the destinations are bundled with
    -- their tangents, every neutral element is paired with a zero tangent
    -- built by zeroFromSubExp, and the operator lambda is differentiated.
    fwdHist :: HistOp SOACS -> ADM (HistOp SOACS)
    fwdHist (HistOp shape rf dest nes op) = do
      bundled_dest <- bundleTan dest
      zero_nes <- forM nes $ fmap Var . zeroFromSubExp
      op_fwd <- fwdLambda op
      let bundled_nes = interleave nes zero_nes
      pure $ HistOp shape rf bundled_dest bundled_nes op_fwd
-- Scatter: the destination arrays are extended with their tangent arrays, and
-- the lambda is rewritten to also produce the tangents of the written values.
-- The lambda's index results are emitted twice, so each tangent destination is
-- written at the same positions as its primal destination.
fwdSOAC (Pat pes) aux (Scatter w ivs lam as) = do
  as_tan <- forM as $ \(s, n, a) -> do
    a_tan <- tangent a
    pure (s, n, a_tan)
  pes_tan <- newTan pes
  ivs' <- bundleTan ivs
  let (as_ws, as_ns, _as_vs) = unzip3 as
      -- Number of index results preceding the value results in the lambda's
      -- return list: each destination contributes its count times its rank.
      n_indices = sum $ zipWith (*) as_ns $ map length as_ws
  lam' <- fwdScatterLambda n_indices lam
  addStm $ Let (Pat (pes ++ pes_tan)) aux $ Op $ Scatter w ivs' lam' $ as ++ as_tan
  where
    -- Differentiate the scatter lambda. The first n_indices results are
    -- indices; they get no tangents but are repeated for the tangent
    -- destinations.
    fwdScatterLambda :: Int -> Lambda SOACS -> ADM (Lambda SOACS)
    fwdScatterLambda n_indices l@(Lambda params body ret) = do
      params' <- bundleNew params
      let (ret_is, ret_vs) = splitAt n_indices ret
      ret_tan <- tangent ret_vs
      -- Differentiate the body with the lambda's parameters in scope, as is
      -- done for histogram bucket functions and in fwdJVP; otherwise type
      -- lookups on the parameters inside the body cannot succeed.
      body' <- inScopeOf l $ fwdBodyScatter n_indices body
      pure $ Lambda params' body' $ ret_is ++ ret_is ++ ret_vs ++ ret_tan

    -- Differentiate the scatter lambda's body. Result layout: indices,
    -- indices again, primal values, then value tangents.
    fwdBodyScatter :: Int -> Body SOACS -> ADM (Body SOACS)
    fwdBodyScatter n_indices (Body _ stms res) = do
      (res_tan, stms') <- collectStms $ do
        mapM_ fwdStm stms
        tangent $ drop n_indices res
      let (res_is, res_vs) = splitAt n_indices res
      pure $ mkBody stms' $ res_is ++ res_is ++ res_vs ++ res_tan
-- Nested differentiation operators are rejected outright.
fwdSOAC _ _ JVP {} = error "fwdSOAC: nested JVP not allowed."
fwdSOAC _ _ VJP {} = error "fwdSOAC: nested VJP not allowed."

-- | Differentiate a single statement in forward mode, emitting statements
-- that compute both the primal results and their tangents.
fwdStm :: Stm SOACS -> ADM ()
-- Accumulator updates are redirected to the accumulator's tangent, and the
-- written values are bundled together with their tangents. The indices are
-- passed through unchanged.
fwdStm (Let pat aux (BasicOp (UpdateAcc acc i x))) = do
  bundled_pat <- bundleNew pat
  bundled_vals <- bundleTan x
  acc_tan <- tangent acc
  let update = UpdateAcc acc_tan i bundled_vals
  addStm . Let bundled_pat aux $ BasicOp update
-- Other basic operations: the original statement is re-emitted unless it
-- binds an accumulator, and basicFwd (defined elsewhere) is then invoked,
-- presumably to emit the tangent computation.
fwdStm stm@(Let pat aux (BasicOp e)) = do
  -- XXX: this is probably too naive.
  let binds_acc = any isAcc $ patTypes pat
  unless binds_acc $ addStm stm
  basicFwd pat aux e
-- Calls to built-in scalar functions are differentiated with the chain rule.
-- For y = f(x_1, ..., x_n) the tangent is
--
--   dy = sum_i (df/dx_i) * dx_i
--
-- where the partial derivatives df/dx_i come from pdBuiltin (one per
-- argument). The original call is kept, and the tangent is bound separately.
fwdStm stm@(Let pat _ (Apply f args _ _))
  | Just (ret, argts) <- M.lookup f builtInFunctions = do
      addStm stm
      arg_tans <-
        zipWith primExpFromSubExp argts <$> mapM (tangent . fst) args
      pat_tan <- newTan pat
      let arg_pes = zipWith primExpFromSubExp argts (map fst args)
      case pdBuiltin f arg_pes of
        Nothing ->
          error $ "No partial derivative defined for builtin function: " ++ pretty f
        Just derivs -> do
          -- convertTo coerces an argument tangent to the result type so it can
          -- be multiplied with the corresponding partial derivative
          -- (presumably already of the result type).
          let convertTo tt e
                | e_t == tt = e
                | otherwise =
                    case (tt, e_t) of
                      (IntType tt', IntType ft) -> ConvOpExp (SExt ft tt') e
                      (FloatType tt', FloatType ft) -> ConvOpExp (FPConv ft tt') e
                      (Bool, FloatType ft) -> ConvOpExp (FToB ft) e
                      (FloatType tt', Bool) -> ConvOpExp (BToF tt') e
                      _ -> error $ "fwdStm.convertTo: " ++ pretty (f, tt, e_t)
                where
                  e_t = primExpType e
              -- One chain-rule term per argument: dx_i * df/dx_i.
              terms = zipWith (~*~) (map (convertTo ret) arg_tans) derivs
              -- The tangent is the sum of ALL terms (previously only the
              -- first term was bound). A builtin without arguments has a
              -- zero tangent.
              res_tan = case terms of
                [] -> ValueExp $ blankPrimValue ret
                term : terms' -> foldl (~+~) term terms'
          -- A builtin returns a single value of type ret, so the tangent
          -- pattern binds exactly one name.
          letBindNames (patNames pat_tan) =<< toExp res_tan
-- Conditionals: each branch is differentiated inside slocal' (presumably so
-- tangent bookkeeping from one branch does not leak into the other), and
-- the result pattern and branch types are bundled with their tangents.
fwdStm (Let pat aux (If cond t f (IfDec ret ifsort))) = do
  t_fwd <- slocal' (fwdBody t)
  f_fwd <- slocal' (fwdBody f)
  bundled_pat <- bundleNew pat
  bundled_ret <- bundleTan ret
  let dec = IfDec bundled_ret ifsort
  addStm $ Let bundled_pat aux (If cond t_fwd f_fwd dec)
-- While loops: the merge parameters and the result pattern are bundled with
-- their tangents, and the body is differentiated with the original merge
-- parameters and the loop form in scope.
fwdStm (Let pat aux (DoLoop val_pats loop@(WhileLoop v) body)) = do
  bundled_merge <- bundleNew val_pats
  bundled_pat <- bundleNew pat
  let body_scope = scopeOfFParams (map fst val_pats) <> scopeOf loop
  body_fwd <- localScope body_scope . slocal' $ fwdBody body
  addStm . Let bundled_pat aux $ DoLoop bundled_merge (WhileLoop v) body_fwd
-- For loops: the result pattern, the merge parameters and the array loop
-- variables are bundled with their tangents. The iteration variable and
-- bound are kept, and the body is differentiated with the original merge
-- parameters and the loop form in scope.
fwdStm (Let pat aux (DoLoop val_pats loop@(ForLoop i it bound loop_vars) body)) = do
  bundled_pat <- bundleNew pat
  bundled_merge <- bundleNew val_pats
  bundled_loop_vars <- bundleNew loop_vars
  let body_scope = scopeOfFParams (map fst val_pats) <> scopeOf loop
      loop' = ForLoop i it bound bundled_loop_vars
  body_fwd <- localScope body_scope . slocal' $ fwdBody body
  addStm . Let bundled_pat aux $ DoLoop bundled_merge loop' body_fwd
-- Accumulator scopes: each input's arrays are extended with their tangent
-- arrays and its optional combining operator is differentiated. The body
-- lambda is then differentiated as well.
--
-- NOTE(review): the arrays are concatenated (arrs <> arrs_tan) while the
-- neutral elements are interleaved (interleave nes nes_tan). The two layouts
-- only coincide when an input has a single array; confirm which layout the
-- accumulator expects for multi-array inputs.
fwdStm (Let pat aux (WithAcc inputs lam)) = do
  inputs' <- mapM fwdInput inputs
  pat' <- bundleNew pat
  lam' <- fwdLambda lam
  addStm $ Let pat' aux $ WithAcc inputs' lam'
  where
    fwdInput :: WithAccInput SOACS -> ADM (WithAccInput SOACS)
    fwdInput (shape, arrs, op) = do
      arrs_tan <- tangent arrs
      op' <- traverse (fwdOp shape) op
      pure (shape, arrs <> arrs_tan, op')

    -- Differentiate a combining operator. Its neutral elements are paired
    -- with zero tangents built by zeroFromSubExp, and the tangents of the
    -- leading index parameters are removed from the differentiated lambda.
    fwdOp :: Shape -> (Lambda SOACS, [SubExp]) -> ADM (Lambda SOACS, [SubExp])
    fwdOp shape (op_lam, nes) = do
      nes_tan <- mapM (fmap Var . zeroFromSubExp) nes
      op_lam' <- fwdLambda op_lam
      let trimmed = case op_lam' of
            Lambda ps body ret ->
              Lambda (removeIndexTans (shapeRank shape) ps) body ret
      pure (trimmed, interleave nes nes_tan)

    -- For the first k parameters, drop the element that immediately follows
    -- each one (its tangent). Stops early if fewer than two elements remain.
    removeIndexTans :: Int -> [a] -> [a]
    removeIndexTans k ps = case ps of
      p : _ : rest | k /= 0 -> p : removeIndexTans (k - 1) rest
      _ -> ps
-- SOAC statements are delegated to fwdSOAC.
fwdStm (Let pat aux (Op op)) = fwdSOAC pat aux op
-- Any other statement form is unsupported in forward mode. The error shows
-- both the pretty-printed and the raw form of the statement.
fwdStm stm =
  error . concat $
    ["unhandled forward mode AD for Stm: ", pretty stm, "\n", show stm]

-- | Differentiate a body: every statement is differentiated, and the
-- results are bundled with their tangents via 'bundleTan'.
fwdBody :: Body SOACS -> ADM (Body SOACS)
fwdBody (Body _ stms res) =
  buildBody_ $ mapM_ fwdStm stms *> bundleTan res

-- | Differentiate a body so that it returns all of the original results,
-- followed by all of their tangents.
fwdBodyTansLast :: Body SOACS -> ADM (Body SOACS)
fwdBodyTansLast (Body _ stms res) = buildBody_ $ do
  forM_ stms fwdStm
  res_tan <- tangent res
  pure $ res <> res_tan

-- | Differentiate a lambda in forward mode (JVP). The resulting lambda takes
-- the original parameters followed by their tangent parameters (from
-- 'newTan'), and returns the original results followed by their tangents.
-- The given scope and the lambda's own parameters are in scope during
-- differentiation.
fwdJVP :: MonadFreshNames m => Scope SOACS -> Lambda SOACS -> m (Lambda SOACS)
fwdJVP scope l@(Lambda params body ret) =
  runADM $ localScope scope $ inScopeOf l $ do
    tan_params <- newTan params
    body' <- fwdBodyTansLast body
    tan_ret <- tangent ret
    let all_params = params ++ tan_params
        all_ret = ret <> tan_ret
    pure $ Lambda all_params body' all_ret