{-# LANGUAGE CPP #-}
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}
{-# OPTIONS_GHC -fno-full-laziness #-}
{-# OPTIONS_HADDOCK not-home #-}

-----------------------------------------------------------------------------
-- |
-- Copyright   :  (c) Edward Kmett 2012-2021
-- License     :  BSD3
-- Maintainer  :  ekmett@gmail.com
-- Stability   :  experimental
-- Portability :  GHC only
--
-- Reverse-Mode Automatic Differentiation using a single Wengert list (or \"tape\").
--
-- This version uses @Data.Reflection@ to find and update the tape.
--
-- This is asymptotically faster than using @Kahn@, which
-- is forced to reify and topologically sort the graph, but it requires
-- a fairly expensive rendezvous during construction when updated using
-- multiple threads.
--
-----------------------------------------------------------------------------

module Numeric.AD.Internal.Reverse
  ( Reverse(..)
  , Tape(..)
  , Head(..)
  , Cells(..)
  , reifyTape
  , reifyTypeableTape
  , partials
  , partialArrayOf
  , partialMapOf
  , derivativeOf
  , derivativeOf'
  , bind
  , unbind
  , unbindMap
  , unbindWith
  , unbindMapWithDefault
  , var
  , varId
  , primal
  ) where

import Data.Functor
import Control.Monad hiding (mapM)
import Control.Monad.ST
import Control.Monad.Trans.State
import Data.Array.ST
import Data.Array
import Data.Array.Unsafe as Unsafe
import Data.IORef
import Data.IntMap (IntMap, fromDistinctAscList, findWithDefault)
import Data.Number.Erf
import Data.Proxy
import Data.Reflection
import Data.Traversable (mapM)
import Data.Typeable
import Numeric.AD.Internal.Combinators
import Numeric.AD.Internal.Identity
import Numeric.AD.Jacobian
import Numeric.AD.Mode
import Prelude hiding (mapM)
import System.IO.Unsafe (unsafePerformIO)
import Unsafe.Coerce

-- | The evil untyped tape: a linked list of recorded operations, newest
-- cell first (see 'un' and 'bin', which cons onto the front).
--
-- The payload type @a@ of each cell does not appear in the result type, so it
-- is existentially hidden; 'backPropagate' uses 'unsafeCoerce' to read the
-- payloads back at the scalar type it is working with.
data Cells where
  -- | End of the tape.
  Nil    :: Cells
  -- | A cell with one input: the input's index, the payload that
  -- 'backPropagate' multiplies by this cell's sensitivity, and the older cells.
  Unary  :: {-# UNPACK #-} !Int -> a -> Cells -> Cells
  -- | A cell with two inputs: both input indices, one payload per input (in
  -- the same order), and the older cells.
  Binary :: {-# UNPACK #-} !Int -> {-# UNPACK #-} !Int -> a -> a -> Cells -> Cells

-- | Discard the @n@ newest cells of a tape.
--
-- Stops early at 'Nil' if the tape is shorter than @n@. A negative count
-- never reaches the @0@ case, so it discards every cell. The counter is
-- forced at each step (@$!@) so no chain of subtraction thunks builds up.
dropCells :: Int -> Cells -> Cells
dropCells 0 xs = xs
dropCells _ Nil = Nil
dropCells n (Unary _ _ xs)      = (dropCells $! n - 1) xs
dropCells n (Binary _ _ _ _ xs) = (dropCells $! n - 1) xs

-- | The current state of a tape: a strict counter together with the cells
-- recorded so far.
--
-- 'reifyTape' initialises the counter with the number of variables and an
-- empty tape; 'un' and 'bin' increment it by one for every cell they push.
data Head = Head {-# UNPACK #-} !Int Cells

-- | A mutable tape: a reference to the current 'Head'.
--
-- This is the value that is reified into the type-level tag @s@ (see
-- 'reifyTape') and recovered with 'reflect' wherever a @Reifies s Tape@
-- constraint is in scope.
newtype Tape = Tape { getTape :: IORef Head }

-- | Push a 'Unary' cell onto the tape.
--
-- Takes the index @i@ of the input and the payload @di@ to store for it, and
-- returns the new 'Head' paired with the incremented counter. 'unarily' uses
-- that counter as the identifier of the value it builds.
--
-- Both the new head and the new counter are forced before the pair is
-- returned, so forcing the result never leaves a thunk for either of them.
un :: Int -> a -> Head -> (Head, Int)
un i di (Head r t) = h `seq` r' `seq` (h, r') where
  r' = r + 1
  h = Head r' (Unary i di t)
{-# INLINE un #-}

-- | Push a 'Binary' cell onto the tape.
--
-- Takes the indices @i@ and @j@ of the two inputs and the payloads @di@ and
-- @dj@ to store for them (in that order), and returns the new 'Head' paired
-- with the incremented counter. 'binarily' uses that counter as the
-- identifier of the value it builds.
--
-- Both the new head and the new counter are forced before the pair is
-- returned.
bin :: Int -> Int -> a -> a -> Head -> (Head, Int)
bin i j di dj (Head r t) = h `seq` r' `seq` (h, r') where
  r' = r + 1
  h = Head r' (Binary i j di dj t)
{-# INLINE bin #-}

-- | Atomically apply a transition function to the tape identified by @s@.
--
-- The proxy argument only selects which tape to 'reflect'; the function
-- receives the current 'Head' and returns the replacement head together with
-- a result, which is what the action yields.
modifyTape :: Reifies s Tape => p s -> (Head -> (Head, r)) -> IO r
modifyTape p = atomicModifyIORef (getTape (reflect p))
{-# INLINE modifyTape #-}

-- | This is used to create a new entry on the chain given a unary function, its derivative with respect to its input,
-- the variable ID of its input, and the value of its input. Used by 'unary' and 'binary' internally.
--
-- The tape is updated as a side effect through 'unsafePerformIO': the cell is
-- pushed with 'un' and the counter it returns becomes the identifier of the
-- resulting 'Reverse' value. The primal result @f b@ is forced (@$!@) before
-- the constructor is applied.
unarily :: forall s a. Reifies s Tape => (a -> a) -> a -> Int -> a -> Reverse s a
unarily f di i b = Reverse (unsafePerformIO (modifyTape (Proxy :: Proxy s) (un i di))) $! f b
{-# INLINE unarily #-}

-- | This is used to create a new entry on the chain given a binary function, its derivatives with respect to its inputs,
-- their variable IDs and values. Used by 'binary' internally.
--
-- The tape is updated as a side effect through 'unsafePerformIO': the cell is
-- pushed with 'bin' and the counter it returns becomes the identifier of the
-- resulting 'Reverse' value. The primal result @f b c@ is forced (@$!@)
-- before the constructor is applied.
binarily :: forall s a. Reifies s Tape => (a -> a -> a) -> a -> a -> Int -> a -> Int -> a -> Reverse s a
binarily f di dj i b j c = Reverse (unsafePerformIO (modifyTape (Proxy :: Proxy s) (bin i j di dj))) $! f b c
{-# INLINE binarily #-}

-- | A value tracked by reverse-mode AD on the tape tagged by @s@.
data Reverse s a where
  -- | A value known to be zero; 'primal' reads it as @0@.
  Zero :: Reverse s a
  -- | A constant that is not recorded on the tape.
  Lift :: a -> Reverse s a
  -- | A tape entry or input variable: its identifier and its primal value.
  Reverse :: {-# UNPACK #-} !Int -> a -> Reverse s a
  deriving (Show, Typeable)

-- | 'Mode' operations for 'Reverse'. Only the 'Reverse' constructor counts as
-- non-constant; 'Zero' and 'Lift' are both reported as known constants.
instance (Reifies s Tape, Num a) => Mode (Reverse s a) where
  type Scalar (Reverse s a) = a

  isKnownZero Zero = True
  isKnownZero _    = False

  asKnownConstant Zero = Just 0
  asKnownConstant (Lift n) = Just n
  asKnownConstant _ = Nothing

  isKnownConstant Reverse{} = False
  isKnownConstant _ = True

  auto = Lift
  zero = Zero

  -- Scaling by a scalar: the derivative with respect to the AD argument is
  -- the scalar itself (or its reciprocal for division), independent of the
  -- argument's value, hence the ignored lambda parameter.
  a *^ b = lift1 (a *) (\_ -> auto a) b
  a ^* b = lift1 (* b) (\_ -> auto b) a
  a ^/ b = lift1 (/ b) (\_ -> auto (recip b)) a

-- | Addition of two 'Reverse' values: a 'binary' node whose partial
-- derivative with respect to each argument is @1@.
(<+>) :: (Reifies s Tape, Num a) => Reverse s a -> Reverse s a -> Reverse s a
(<+>)  = binary (+) 1 1

-- | Extract the primal (plain) value, discarding any tape identifier.
-- 'Zero' yields @0@.
primal :: Num a => Reverse s a -> a
primal Zero = 0
primal (Lift a) = a
primal (Reverse _ a) = a

-- | 'Jacobian' operations for 'Reverse'.
--
-- Constant arguments ('Zero' and 'Lift') never touch the tape: if every
-- argument is constant the result is a plain 'Lift'. Otherwise a cell is
-- recorded with 'unarily' or 'binarily', mentioning only the arguments that
-- are actual tape entries.
instance (Reifies s Tape, Num a) => Jacobian (Reverse s a) where
  type D (Reverse s a) = Id a

  unary f _         Zero          = Lift (f 0)
  unary f _         (Lift a)      = Lift (f a)
  unary f (Id dadi) (Reverse i b) = unarily f dadi i b

  -- The derivative function is applied to the primal value of the argument.
  lift1 f df b = unary f (df (Id pb)) b where
    pb = primal b

  -- As 'lift1', but the derivative function also receives the result @a@,
  -- which is computed once and reused as the primal via @const a@.
  lift1_ f df b = unary (const a) (df (Id a) (Id pb)) b where
    pb = primal b
    a = f pb

  -- Both arguments constant: nothing is recorded.
  binary f _         _         Zero          Zero          = Lift (f 0 0)
  binary f _         _         Zero          (Lift c)      = Lift (f 0 c)
  binary f _         _         (Lift b)      Zero          = Lift (f b 0)
  binary f _         _         (Lift b)      (Lift c)      = Lift (f b c)

  -- Exactly one argument on the tape: record a unary cell for it, with the
  -- constant argument partially applied into the function.
  binary f _         (Id dadc) Zero          (Reverse i c) = unarily (f 0) dadc i c
  binary f _         (Id dadc) (Lift b)      (Reverse i c) = unarily (f b) dadc i c
  binary f (Id dadb) _         (Reverse i b) Zero          = unarily (`f` 0) dadb i b
  binary f (Id dadb) _         (Reverse i b) (Lift c)      = unarily (`f` c) dadb i b

  -- Both arguments on the tape: record a binary cell.
  binary f (Id dadb) (Id dadc) (Reverse i b) (Reverse j c) = binarily f dadb dadc i b j c

  -- The derivative function is applied to the primal values of both arguments.
  lift2 f df b c = binary f dadb dadc b c where
    (dadb, dadc) = df (Id (primal b)) (Id (primal c))

  -- As 'lift2', but the derivative function also receives the result @a@,
  -- which is computed once and reused as the primal.
  lift2_ f df b c = binary (\_ _ -> a) dadb dadc b c where
    pb = primal b
    pc = primal c
    a = f pb pc
    (dadb, dadc) = df (Id a) (Id pb) (Id pc)

-- | Multiplication of two 'Reverse' values. By the product rule the partial
-- derivative with respect to each argument is the other argument, hence the
-- swapped pair.
mul :: (Reifies s Tape, Num a) => Reverse s a -> Reverse s a -> Reverse s a
mul = lift2 (*) (\x y -> (y, x))

-- Further instances are stamped out by the C preprocessor: "instances.h" is
-- included with the macros below describing this type. The first two macros
-- expand to instance contexts that add @Reifies s Tape@ to the supplied
-- constraint(s); the third expands to the instance head.
-- NOTE(review): which instances are generated is determined by instances.h,
-- which is not visible from this file -- confirm there.
#define BODY1(x) (Reifies s Tape,x) =>
#define BODY2(x,y) (Reifies s Tape,x,y) =>
#define HEAD (Reverse s a)
#include "instances.h"

-- | Helper that extracts the derivative of a chain when the chain was constructed with 1 variable.
--
-- Implemented as the sum of everything returned by 'partials'; the proxy
-- argument is ignored at the value level and only fixes the tape tag @s@.
derivativeOf :: (Reifies s Tape, Num a) => Proxy s -> Reverse s a -> a
derivativeOf _ = sum . partials
{-# INLINE derivativeOf #-}

-- | Helper that extracts both the primal and derivative of a chain when the chain was constructed with 1 variable.
--
-- The result is @(primal, derivative)@, with the derivative computed by
-- 'derivativeOf'.
derivativeOf' :: (Reifies s Tape, Num a) => Proxy s -> Reverse s a -> (a, a)
derivativeOf' p r = (primal r, derivativeOf p r)
{-# INLINE derivativeOf' #-}

-- | Used internally to push sensitivities down the chain.
--
-- Walks the tape from its newest cell towards 'Nil'. @k@ is the index of the
-- cell currently at the front of the tape and is decremented (strictly) for
-- each cell consumed. For every cell the sensitivity stored at slot @k@ is
-- multiplied by the cell's payload(s) and accumulated into the slot(s) of the
-- cell's input(s).
--
-- Returns the value of @k@ reached when the tape runs out.
--
-- The payloads are stored untyped in 'Cells', so they are recovered with
-- 'unsafeCoerce'; this is only sound if they were recorded at the same
-- scalar type as the elements of the sensitivity array.
backPropagate :: Num a => Int -> Cells -> STArray s Int a -> ST s Int
backPropagate k Nil _ = return k
backPropagate k (Unary i g xs) ss = do
  da <- readArray ss k
  db <- readArray ss i
  writeArray ss i $! db + unsafeCoerce g * da
  (backPropagate $! k - 1) xs ss
backPropagate k (Binary i j g h xs) ss = do
  da <- readArray ss k
  db <- readArray ss i
  writeArray ss i $! db + unsafeCoerce g * da
  -- Slot j is read only after slot i has been written, so a cell whose two
  -- inputs coincide (i == j) accumulates both contributions.
  dc <- readArray ss j
  writeArray ss j $! dc + unsafeCoerce h * da
  (backPropagate $! k - 1) xs ss

-- | Extract the partials from the current chain for a given AD variable.
--
-- Constants ('Zero' and 'Lift') have no partials and yield the empty list.
-- For a tape entry with identifier @k@ the current tape is read, every cell
-- newer than @k@ is dropped, a sensitivity array indexed @0..k@ is seeded
-- with @1@ at slot @k@, and 'backPropagate' is run over the remaining cells.
-- The result lists the sensitivities at indices @0..vs@, where @vs@ is the
-- index 'backPropagate' returns once the tape is exhausted.
{-# SPECIALIZE partials :: Reifies s Tape => Reverse s Double -> [Double] #-}
partials :: forall s a. (Reifies s Tape, Num a) => Reverse s a -> [a]
partials Zero        = []
partials (Lift _)    = []
partials (Reverse k _) = map (sensitivities !) [0..vs] where
  -- Snapshot of the tape: n is the current counter, t the recorded cells.
  Head n t = unsafePerformIO $ readIORef (getTape (reflect (Proxy :: Proxy s)))
  -- Discard the cells recorded after entry k, so the front cell is entry k.
  tk = dropCells (n - k) t
  (vs,sensitivities) = runST $ do
    ss <- newArray (0, k) 0
    writeArray ss k 1
    v <- backPropagate k tk ss
    -- The mutable array is not touched again after this point, so freezing
    -- it without a copy is safe.
    as <- Unsafe.unsafeFreeze ss
    return (v, as)

-- | Return an 'Array' of 'partials' given bounds for the variable IDs.
--
-- The partials are numbered from @0@ in the order 'partials' returns them and
-- summed into an array with the given bounds; slots that receive no partial
-- hold @0@. The proxy argument is ignored at the value level.
--
-- NOTE(review): 'accumArray' requires every index to lie within the bounds,
-- so the bounds are presumably expected to cover everything 'partials'
-- returns -- confirm against callers.
partialArrayOf :: (Reifies s Tape, Num a) => Proxy s -> (Int, Int) -> Reverse s a -> Array Int a
partialArrayOf _ vbounds = accumArray (+) 0 vbounds . zip [0..] . partials
{-# INLINE partialArrayOf #-}

-- | Return an 'IntMap' of sparse partials
--
-- The partials are numbered from @0@ in the order 'partials' returns them.
-- Those keys are distinct and ascending by construction, which is what
-- 'fromDistinctAscList' requires. The proxy argument is ignored at the value
-- level.
partialMapOf :: (Reifies s Tape, Num a) => Proxy s -> Reverse s a -> IntMap a
partialMapOf _ = fromDistinctAscList . zip [0..] . partials
{-# INLINE partialMapOf #-}

-- | Construct a tape that starts with @n@ variables.
--
-- A fresh 'IORef' holding an empty tape whose counter is @n@ is allocated
-- (via 'unsafePerformIO') and reified as the type-level tag @s@ for the
-- continuation.
--
-- NOTE(review): the NOINLINE pragma presumably keeps the allocation from
-- being duplicated or shared between calls -- confirm before changing it.
reifyTape :: Int -> (forall s. Reifies s Tape => Proxy s -> r) -> r
reifyTape vs k = unsafePerformIO $ do
  h <- newIORef (Head vs Nil)
  return (reify (Tape h) k)
{-# NOINLINE reifyTape #-}

-- | Construct a tape that starts with @n@ variables, as 'reifyTape' does, but
-- additionally provide a 'Typeable' instance for the tag @s@ to the
-- continuation (using 'reifyTypeable' instead of 'reify').
--
-- NOTE(review): the NOINLINE pragma presumably keeps the allocation from
-- being duplicated or shared between calls -- confirm before changing it.
reifyTypeableTape :: Int -> (forall s. (Typeable s, Reifies s Tape) => Proxy s -> r) -> r
reifyTypeableTape vs k = unsafePerformIO $ do
  h <- newIORef (Head vs Nil)
  return (reifyTypeable (Tape h) k)
{-# NOINLINE reifyTypeableTape #-}

-- | Build the AD variable with the given primal value and identifier.
-- Note the argument order: value first, identifier second.
var :: a -> Int -> Reverse s a
var a v = Reverse v a

-- | Extract the identifier of a variable or tape entry.
--
-- This is partial: it calls 'error' when given 'Zero' or 'Lift', which
-- carry no identifier.
varId :: Reverse s a -> Int
varId (Reverse v _) = v
varId _ = error "varId: not a Var"

-- | Turn every element of a container into a fresh AD variable.
--
-- Variables are numbered consecutively from @0@ in traversal order. The
-- second component of the result is @(0, n)@, where @n@ is the number of
-- elements, i.e. one more than the largest identifier handed out.
bind :: Traversable f => f a -> (f (Reverse s a), (Int,Int))
bind xs = (r,(0,hi)) where
  (r,hi) = runState (mapM freshVar xs) 0
  -- Hand out the current counter as the identifier and force the successor.
  freshVar a = state $ \s -> let s' = s + 1 in s' `seq` (var a s, s')

-- | Replace every variable in a container by the array element stored at its
-- identifier.
--
-- Partial: 'varId' fails on elements that are not variables, and the array
-- lookup fails for identifiers outside the array bounds.
unbind :: Functor f => f (Reverse s a) -> Array Int a -> f a
unbind xs ys = fmap (\v -> ys ! varId v) xs

-- | Like 'unbind', but combine each variable's primal value with the array
-- element stored at its identifier using the supplied function.
--
-- Partial: 'varId' fails on elements that are not variables, and the array
-- lookup fails for identifiers outside the array bounds.
unbindWith :: (Functor f, Num a) => (a -> b -> c) -> f (Reverse s a) -> Array Int b -> f c
unbindWith f xs ys = fmap (\v -> f (primal v) (ys ! varId v)) xs

-- | Replace every variable in a container by the map entry stored under its
-- identifier, using @0@ for identifiers missing from the map.
--
-- Partial: 'varId' fails on elements that are not variables.
unbindMap :: (Functor f, Num a) => f (Reverse s a) -> IntMap a -> f a
unbindMap xs ys = fmap (\v -> findWithDefault 0 (varId v) ys) xs

-- | Like 'unbindMap', but combine each variable's primal value with the map
-- entry stored under its identifier using the supplied function, falling
-- back to the default @z@ for identifiers missing from the map.
--
-- Partial: 'varId' fails on elements that are not variables.
unbindMapWithDefault :: (Functor f, Num a) => b -> (a -> b -> c) -> f (Reverse s a) -> IntMap b -> f c
unbindMapWithDefault z f xs ys = fmap (\v -> f (primal v) $ findWithDefault z (varId v) ys) xs