{-# LANGUAGE CPP #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}
{-# OPTIONS_GHC -fno-full-laziness #-}
{-# OPTIONS_HADDOCK not-home #-}

#ifdef AD_FFI
{-# LANGUAGE ForeignFunctionInterface #-}
#endif

module Numeric.AD.Internal.Reverse.Double
  ( ReverseDouble(..)
  , Tape(..)
  , reifyTape
  , reifyTypeableTape
  , partials
  , partialArrayOf
  , partialMapOf
  , derivativeOf
  , derivativeOf'
  , bind
  , unbind
  , unbindMap
  , unbindWith
  , unbindMapWithDefault
  , var
  , varId
  , primal
  ) where

#ifdef AD_FFI
import Foreign.Ptr
import Foreign.ForeignPtr
import Foreign.C.Types
import qualified Foreign.Marshal.Array as MA
import qualified Foreign.Marshal.Alloc as MA
#else
import Control.Monad.ST
import Data.Array.ST
import Data.Array.Unsafe as Unsafe
import Data.IORef
import Unsafe.Coerce
#endif

import Data.Functor
import Control.Monad hiding (mapM)
import Control.Monad.Trans.State
import Data.Array
import Data.IntMap (IntMap, fromDistinctAscList, findWithDefault)
import Data.Number.Erf
import Data.Proxy
import Data.Reflection
import Data.Traversable (mapM)
import Data.Typeable
import Numeric.AD.Internal.Combinators
import Numeric.AD.Internal.Identity
import Numeric.AD.Jacobian
import Numeric.AD.Mode
import Prelude hiding (mapM)
import System.IO.Unsafe (unsafePerformIO)

#ifdef AD_FFI

-- | Handle to a tape owned by the C implementation. The 'Tape' argument of
-- the 'ForeignPtr' is only a phantom tag for the pointer type.
newtype Tape = Tape { getTape :: ForeignPtr Tape }

-- Raw bindings to the C tape implementation (its sources are not visible here).
-- NOTE(review): 'c_tape_push' returns a Haskell 'Int' while every other
-- integer crossing this boundary is a 'CInt'. If the C function is declared
-- to return @int@, this should be 'CInt' (converted with 'fromIntegral' at
-- the call site) -- confirm against the C sources.
foreign import ccall unsafe "tape_alloc" c_tape_alloc :: CInt -> CInt -> IO (Ptr Tape)
foreign import ccall unsafe "tape_push" c_tape_push :: Ptr Tape -> CInt -> CInt -> Double -> Double -> IO Int
foreign import ccall unsafe "tape_backPropagate" c_tape_backPropagate :: Ptr Tape -> CInt -> Ptr Double -> IO ()
foreign import ccall unsafe "tape_variables" c_tape_variables :: Ptr Tape -> IO CInt
-- Address of @tape_free@, used as the 'ForeignPtr' finalizer in 'newTape'.
foreign import ccall unsafe "&tape_free" c_ref_tape_free :: FinalizerPtr Tape

-- | Record an entry with input ids @i1@, @i2@ and partials @d1@, @d2@ on the
-- tape reflected from @p@, returning the id reported by the C side.
pushTape :: Reifies s Tape => p s -> Int -> Int -> Double -> Double -> IO Int
pushTape p i1 i2 d1 d2 =
  withForeignPtr tapeRef $ \raw ->
    c_tape_push raw (fromIntegral i1) (fromIntegral i2) d1 d2
  where
    tapeRef = getTape (reflect p)
{-# INLINE pushTape #-}

-- | Extract the partials from the current chain for a given AD variable.
--
-- Constants ('Zero' and 'Lift') have no partials. For a tape-backed value
-- with id @k@, the C side back-propagates from @k@ into a scratch buffer
-- whose length is reported by @tape_variables@. That buffer is then copied
-- into a Haskell list.
partials :: forall s. (Reifies s Tape) => ReverseDouble s -> [Double]
partials Zero        = []
partials (Lift _)    = []
partials (ReverseDouble k _) = unsafePerformIO $
  withForeignPtr (getTape (reflect (Proxy :: Proxy s))) $ \tape -> do
    l <- fromIntegral <$> c_tape_variables tape
    -- 'allocaArray' scopes the scratch buffer, so it is released even if
    -- an exception escapes while back-propagating or reading it back.
    MA.allocaArray l $ \arr -> do
      c_tape_backPropagate tape (fromIntegral k) arr
      MA.peekArray l arr
{-# INLINE partials #-}

-- | Allocate a C tape for @vs@ variables and wrap it in a 'ForeignPtr'
-- whose finalizer is @tape_free@.
newTape :: Int -> IO Tape
newTape vs =
  c_tape_alloc (fromIntegral vs) initialSize >>= fmap Tape . newForeignPtr c_ref_tape_free
  where
    -- second argument to tape_alloc; presumably an initial size hint
    -- (TODO confirm against the C sources)
    initialSize = 4 * 1024

-- | Construct a tape that starts with @vs@ variables and hand it, reflected
-- at a fresh type @s@, to the continuation @k@.
reifyTape :: Int -> (forall s. Reifies s Tape => Proxy s -> r) -> r
reifyTape vs k = unsafePerformIO $ do
  tape <- newTape vs
  return (reify tape k)
{-# NOINLINE reifyTape #-}

-- | Construct a tape that starts with @vs@ variables. This is like
-- 'reifyTape', but the type @s@ passed to @k@ is also 'Typeable'.
reifyTypeableTape :: Int -> (forall s. (Reifies s Tape, Typeable s) => Proxy s -> r) -> r
reifyTypeableTape vs k = unsafePerformIO $ do
  tape <- newTape vs
  return (reifyTypeable tape k)
{-# NOINLINE reifyTypeableTape #-}

-- | This is used to create a new entry on the chain given a unary function, its derivative with respect to its input,
-- the variable ID of its input, and the value of its input. Used by 'unary' and 'binary' internally.
--
-- The entry is recorded via 'pushTape' with a second input of id @0@ and
-- partial @0.0@ (presumably a placeholder; how the C side treats it is not
-- visible here). The id that 'pushTape' returns becomes the id of the
-- result. '$!' forces the primal @f b@ before the constructor is applied.
unarily :: forall s. Reifies s Tape => (Double -> Double) -> Double -> Int -> Double -> ReverseDouble s
unarily f di i b = ReverseDouble (unsafePerformIO (pushTape (Proxy :: Proxy s) i 0 di 0.0)) $! f b
{-# INLINE unarily #-}

-- | This is used to create a new entry on the chain given a binary function, its derivatives with respect to its inputs,
-- their variable IDs and values. Used by 'binary' internally.
--
-- Inputs @i@ (value @b@, partial @di@) and @j@ (value @c@, partial @dj@)
-- are recorded via 'pushTape'; the id it returns becomes the id of the
-- result. '$!' forces the primal @f b c@ before the constructor is applied.
binarily :: forall s. Reifies s Tape => (Double -> Double -> Double) -> Double -> Double -> Int -> Double -> Int -> Double -> ReverseDouble s
binarily f di dj i b j c = ReverseDouble (unsafePerformIO (pushTape (Proxy :: Proxy s) i j di dj)) $! f b c
{-# INLINE binarily #-}

#else

-- | The recorded chain of operations, most recent entry first.
data Cells
  = Nil
    -- ^ End of the chain.
  | Unary {-# UNPACK #-} !Int {-# UNPACK #-} !Double !Cells
    -- ^ Input id, partial with respect to that input, rest of the chain.
  | Binary {-# UNPACK #-} !Int {-# UNPACK #-} !Int {-# UNPACK #-} !Double {-# UNPACK #-} !Double !Cells
    -- ^ Two input ids, the partial with respect to each, rest of the chain.

-- | Drop @n@ entries from the front (most recent end) of a chain. Returns
-- 'Nil' if the chain runs out first. The counter is forced at each step so
-- no thunk builds up.
dropCells :: Int -> Cells -> Cells
dropCells 0 xs                  = xs
dropCells _ Nil                 = Nil
dropCells n (Unary _ _ xs)      = (dropCells $! n - 1) xs
dropCells n (Binary _ _ _ _ xs) = (dropCells $! n - 1) xs

-- | Mutable state at the top of a tape: the highest id handed out so far
-- (initially the variable count given to 'reifyTape'), and the recorded
-- 'Cells', most recent entry first.
data Head = Head {-# UNPACK #-} !Int !Cells

-- | A tape is a mutable reference to its current 'Head'. It is made
-- available at the type level with 'reify', so every 'ReverseDouble' tagged
-- with the same @s@ records onto the same reference.
newtype Tape = Tape { getTape :: IORef Head }

-- | Used internally to push sensitivities down the chain.
--
-- Walks the chain from the entry with id @k@ downwards. The adjoint of @k@
-- must already be stored in @ss@. For each entry, the adjoint at the
-- current id is scaled by the recorded partial(s) and added into the
-- input slot(s). Returns the id reached when the chain is exhausted.
--
-- The recorded partials are already 'Double', so no coercion is needed
-- when scaling.
backPropagate :: Int -> Cells -> STArray s Int Double -> ST s Int
backPropagate k Nil _ = return k
backPropagate k (Unary i g xs) ss = do
  da <- readArray ss k
  db <- readArray ss i
  writeArray ss i $! db + g * da
  (backPropagate $! k - 1) xs ss
backPropagate k (Binary i j g h xs) ss = do
  da <- readArray ss k
  db <- readArray ss i
  writeArray ss i $! db + g * da
  -- Slot j is read only after slot i has been written, so if i == j
  -- (e.g. x * x) both contributions accumulate.
  dc <- readArray ss j
  writeArray ss j $! dc + h * da
  (backPropagate $! k - 1) xs ss

-- | Extract the partials from the current chain for a given AD variable.
--
-- Constants ('Zero' and 'Lift') have no partials. For a tape-backed value
-- with id @k@, the steps are:
--
-- 1. Snapshot the tape head.
-- 2. Skip the entries recorded after @k@.
-- 3. Seed @k@'s adjoint with @1@ and back-propagate.
--
-- The result has one sensitivity for each id from @0@ up to and including
-- the id at which back-propagation reached the start of the chain.
partials :: forall s. Reifies s Tape => ReverseDouble s -> [Double]
partials Zero        = []
partials (Lift _)    = []
partials (ReverseDouble k _) = map (sensitivities !) [0..vs] where
  Head n t = unsafePerformIO $ readIORef (getTape (reflect (Proxy :: Proxy s)))
  -- the head holds the newest id @n@, so @k@'s entry is @n - k@ deep
  tk = dropCells (n - k) t
  (vs, sensitivities) = runST $ do
    ss <- newArray (0, k) 0
    writeArray ss k 1
    v <- backPropagate k tk ss
    as <- Unsafe.unsafeFreeze ss
    return (v, as)

-- | Construct a tape that starts with @vs@ variables and hand it, reflected
-- at a fresh type @s@, to the continuation @k@.
reifyTape :: Int -> (forall s. Reifies s Tape => Proxy s -> r) -> r
reifyTape vs k = unsafePerformIO $ do
  h <- newIORef (Head vs Nil)
  return (reify (Tape h) k)
{-# NOINLINE reifyTape #-}

-- | Construct a tape that starts with @vs@ variables. This is like
-- 'reifyTape', but the type @s@ passed to @k@ is also 'Typeable'.
reifyTypeableTape :: Int -> (forall s. (Reifies s Tape, Typeable s) => Proxy s -> r) -> r
reifyTypeableTape vs k = unsafePerformIO $ do
  h <- newIORef (Head vs Nil)
  return (reifyTypeable (Tape h) k)
{-# NOINLINE reifyTypeableTape #-}

-- | Record a unary entry (input id @i@, partial @di@) on a head.
-- Returns the new head and the freshly allocated id. Both are forced
-- before the pair is returned.
un :: Int -> Double -> Head -> (Head, Int)
un i di (Head r t) = h `seq` r' `seq` (h, r') where
  r' = r + 1
  h = Head r' (Unary i di t)
{-# INLINE un #-}

-- | Record a binary entry (input ids @i@ and @j@, partials @di@ and @dj@)
-- on a head. Returns the new head and the freshly allocated id. Both are
-- forced before the pair is returned.
bin :: Int -> Int -> Double -> Double -> Head -> (Head, Int)
bin i j di dj (Head r t) = h `seq` r' `seq` (h, r') where
  r' = r + 1
  h = Head r' (Binary i j di dj t)
{-# INLINE bin #-}

-- | Atomically update the head of the tape reflected from @p@ and return
-- the result produced by the update function.
--
-- The strict 'atomicModifyIORef'' forces both the new head and the result.
-- The reference therefore never holds an unevaluated selector thunk over
-- the update.
modifyTape :: Reifies s Tape => p s -> (Head -> (Head, r)) -> IO r
modifyTape p = atomicModifyIORef' (getTape (reflect p))
{-# INLINE modifyTape #-}

-- | This is used to create a new entry on the chain given a unary function, its derivative with respect to its input,
-- the variable ID of its input, and the value of its input. Used by 'unary' and 'binary' internally.
--
-- The id allocated by 'un' on the tape becomes the id of the result.
-- '$!' forces the primal @f b@ before the constructor is applied.
unarily :: forall s. Reifies s Tape => (Double -> Double) -> Double -> Int -> Double -> ReverseDouble s
unarily f di i b = ReverseDouble (unsafePerformIO (modifyTape (Proxy :: Proxy s) (un i di))) $! f b
{-# INLINE unarily #-}

-- | This is used to create a new entry on the chain given a binary function, its derivatives with respect to its inputs,
-- their variable IDs and values. Used by 'binary' internally.
--
-- The id allocated by 'bin' on the tape becomes the id of the result.
-- '$!' forces the primal @f b c@ before the constructor is applied.
binarily :: forall s. Reifies s Tape => (Double -> Double -> Double) -> Double -> Double -> Int -> Double -> Int -> Double -> ReverseDouble s
binarily f di dj i b j c = ReverseDouble (unsafePerformIO (modifyTape (Proxy :: Proxy s) (bin i j di dj))) $! f b c
{-# INLINE binarily #-}

#endif

-- | A 'Double'-specialised reverse-mode AD value tied to the tape @s@.
data ReverseDouble s where
  -- | The constant zero; it is never recorded on the tape.
  Zero :: ReverseDouble s
  -- | A constant with the given value; it is never recorded on the tape.
  Lift :: {-# UNPACK #-} !Double -> ReverseDouble s
  -- | A tape-backed value: its tape id and its primal value.
  ReverseDouble :: {-# UNPACK #-} !Int -> {-# UNPACK #-} !Double -> ReverseDouble s
  deriving (Show, Typeable)

-- | 'Mode' operations for 'ReverseDouble'. Constants are represented by
-- 'Zero' and 'Lift'. Scalar scaling goes through 'lift1', so a tape entry
-- is recorded only when the operand is tape-backed.
instance Reifies s Tape => Mode (ReverseDouble s) where
  type Scalar (ReverseDouble s) = Double

  isKnownZero Zero     = True
  isKnownZero (Lift 0) = True
  isKnownZero _        = False

  -- Anything without a tape entry is a constant.
  isKnownConstant ReverseDouble{} = False
  isKnownConstant _               = True

  auto = Lift
  zero = Zero

  -- d(a * x)/dx = a
  a *^ b = lift1 (a *) (\_ -> auto a) b
  -- d(x * b)/dx = b
  a ^* b = lift1 (* b) (\_ -> auto b) a
  -- d(x / b)/dx = recip b
  a ^/ b = lift1 (/ b) (\_ -> auto (recip b)) a

-- | Addition of two values. The partial with respect to each operand is @1@.
(<+>) :: Reifies s Tape => ReverseDouble s -> ReverseDouble s -> ReverseDouble s
(<+>) = binary (+) 1 1

-- | The primal (ordinary) value carried by a 'ReverseDouble'.
primal :: ReverseDouble s -> Double
primal Zero                = 0
primal (Lift a)            = a
primal (ReverseDouble _ a) = a

-- | 'Jacobian' operations for 'ReverseDouble'. Constant operands
-- ('Zero' and 'Lift') produce constant results without touching the tape.
-- Each tape-backed operand contributes one input to the recorded entry.
instance Reifies s Tape => Jacobian (ReverseDouble s) where
  type D (ReverseDouble s) = Id Double

  unary f _         Zero                = Lift (f 0)
  unary f _         (Lift a)            = Lift (f a)
  unary f (Id dadi) (ReverseDouble i b) = unarily f dadi i b

  -- The derivative is computed from the primal of the input.
  lift1 f df b = unary f (df (Id pb)) b where
    pb = primal b

  -- Like 'lift1', but the derivative may also use the result @a = f pb@.
  lift1_ f df b = unary (const a) (df (Id a) (Id pb)) b where
    pb = primal b
    a = f pb

  -- Both operands constant: the result is constant.
  binary f _         _         Zero     Zero     = Lift (f 0 0)
  binary f _         _         Zero     (Lift c) = Lift (f 0 c)
  binary f _         _         (Lift b) Zero     = Lift (f b 0)
  binary f _         _         (Lift b) (Lift c) = Lift (f b c)

  -- Only the second operand is tape-backed: record a unary entry.
  binary f _         (Id dadc) Zero     (ReverseDouble i c) = unarily (f 0) dadc i c
  binary f _         (Id dadc) (Lift b) (ReverseDouble i c) = unarily (f b) dadc i c
  -- Only the first operand is tape-backed: record a unary entry.
  binary f (Id dadb) _         (ReverseDouble i b) Zero     = unarily (`f` 0) dadb i b
  binary f (Id dadb) _         (ReverseDouble i b) (Lift c) = unarily (`f` c) dadb i b
  -- Both operands are tape-backed: record a binary entry.
  binary f (Id dadb) (Id dadc) (ReverseDouble i b) (ReverseDouble j c) = binarily f dadb dadc i b j c

  -- The partials are computed from the primals of both inputs.
  lift2 f df b c = binary f dadb dadc b c where
    (dadb, dadc) = df (Id (primal b)) (Id (primal c))

  -- Like 'lift2', but the partials may also use the result @a = f pb pc@.
  lift2_ f df b c = binary (\_ _ -> a) dadb dadc b c where
    pb = primal b
    pc = primal c
    a = f pb pc
    (dadb, dadc) = df (Id a) (Id pb) (Id pc)

-- | Multiplication: d(x * y)/dx = y and d(x * y)/dy = x.
mul :: Reifies s Tape => ReverseDouble s -> ReverseDouble s -> ReverseDouble s
mul = lift2 (*) (\x y -> (y, x))

#define BODY1(x) Reifies s Tape =>
#define BODY2(x,y) Reifies s Tape =>
#define HEAD (ReverseDouble s)
#define NO_Bounded
#include "instances.h"

-- | Helper that extracts the derivative of a chain when the chain was constructed with 1 variable.
--
-- This is the sum of all 'partials' of the result.
derivativeOf :: Reifies s Tape => Proxy s -> ReverseDouble s -> Double
derivativeOf _ = sum . partials
{-# INLINE derivativeOf #-}

-- | Helper that extracts both the primal and derivative of a chain when the chain was constructed with 1 variable.
derivativeOf' :: Reifies s Tape => Proxy s -> ReverseDouble s -> (Double, Double)
derivativeOf' p r = (primal r, derivativeOf p r)
{-# INLINE derivativeOf' #-}


-- | Return an 'Array' of 'partials' given bounds for the variable IDs.
--
-- The partial for id @i@ is placed at index @i@; slots that receive no
-- partial stay @0@.
partialArrayOf :: Reifies s Tape => Proxy s -> (Int, Int) -> ReverseDouble s -> Array Int Double
partialArrayOf _ vbounds = accumArray (+) 0 vbounds . zip [0..] . partials
{-# INLINE partialArrayOf #-}

-- | Return an 'IntMap' of the 'partials', keyed by variable id.
--
-- Every id produced by 'partials' is present, including ids whose partial
-- is zero. The ids are ascending and distinct, which satisfies
-- 'fromDistinctAscList'.
partialMapOf :: Reifies s Tape => Proxy s -> ReverseDouble s -> IntMap Double
partialMapOf _ = fromDistinctAscList . zip [0..] . partials
{-# INLINE partialMapOf #-}

-- | A tape-backed value with primal @a@ and id @v@.
var :: Double -> Int -> ReverseDouble s
var a v = ReverseDouble v a

-- | The tape id of a tape-backed value. Calls 'error' for 'Zero' and 'Lift'.
varId :: ReverseDouble s -> Int
varId (ReverseDouble v _) = v
varId _                   = error "varId: not a Var"

-- | Turn each element of @xs@ into a variable, assigning ids @0, 1, ...@
-- in traversal order.
--
-- Returns the variables together with the bounds @(0, n)@, where @n@ is
-- the number of elements.
bind :: Traversable f => f Double -> (f (ReverseDouble s), (Int,Int))
bind xs = (r, (0, hi)) where
  (r, hi) = runState (mapM freshVar xs) 0
  -- The counter is forced at each step so it does not build a thunk chain.
  freshVar a = state $ \s -> let s' = s + 1 in s' `seq` (var a s, s')

-- | Replace each variable in @xs@ with the entry of @ys@ at its id.
--
-- This is partial: 'varId' errors on constants, and '!' errors on ids
-- outside the bounds of @ys@.
unbind :: Functor f => f (ReverseDouble s) -> Array Int Double -> f Double
unbind xs ys = fmap (\v -> ys ! varId v) xs

-- | Combine each variable's primal with the entry of @ys@ at its id,
-- using @f@.
--
-- This is partial in the same way as 'unbind'.
unbindWith :: Functor f => (Double -> b -> c) -> f (ReverseDouble s) -> Array Int b -> f c
unbindWith f xs ys = fmap (\v -> f (primal v) (ys ! varId v)) xs

-- | Replace each variable in @xs@ with the entry of @ys@ at its id.
-- Ids missing from @ys@ give @0@. 'varId' still errors on constants.
unbindMap :: Functor f => f (ReverseDouble s) -> IntMap Double -> f Double
unbindMap xs ys = fmap (\v -> findWithDefault 0 (varId v) ys) xs

-- | Combine each variable's primal with the entry of @ys@ at its id,
-- using @f@. Ids missing from @ys@ use @z@. 'varId' still errors on
-- constants.
unbindMapWithDefault :: Functor f => b -> (Double -> b -> c) -> f (ReverseDouble s) -> IntMap b -> f c
unbindMapWithDefault z f xs ys = fmap (\v -> f (primal v) $ findWithDefault z (varId v) ys) xs