{-# LANGUAGE CPP #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}
{-# OPTIONS_GHC -fno-full-laziness #-}
{-# OPTIONS_HADDOCK not-home #-}
#ifdef AD_FFI
{-# LANGUAGE ForeignFunctionInterface #-}
#endif
module Numeric.AD.Internal.Reverse.Double
( ReverseDouble(..)
, Tape(..)
, reifyTape
, reifyTypeableTape
, partials
, partialArrayOf
, partialMapOf
, derivativeOf
, derivativeOf'
, bind
, unbind
, unbindMap
, unbindWith
, unbindMapWithDefault
, var
, varId
, primal
) where
#ifdef AD_FFI
import Foreign.Ptr
import Foreign.ForeignPtr
import Foreign.C.Types
import qualified Foreign.Marshal.Array as MA
import qualified Foreign.Marshal.Alloc as MA
#else
import Control.Monad.ST
import Data.Array.ST
import Data.Array.Unsafe as Unsafe
import Data.IORef
import Unsafe.Coerce
#endif
import Data.Functor
import Control.Monad hiding (mapM)
import Control.Monad.Trans.State
import Data.Array
import Data.IntMap (IntMap, fromDistinctAscList, findWithDefault)
import Data.Number.Erf
import Data.Proxy
import Data.Reflection
import Data.Traversable (mapM)
import Data.Typeable
import Numeric.AD.Internal.Combinators
import Numeric.AD.Internal.Identity
import Numeric.AD.Jacobian
import Numeric.AD.Mode
import Prelude hiding (mapM)
import System.IO.Unsafe (unsafePerformIO)
#ifdef AD_FFI
-- | Handle to a tape that lives on the C side.  The 'ForeignPtr' carries the
-- allocation (newTape attaches the C finalizer to it); the @Tape@ parameter
-- of the pointer is only a phantom tag.
newtype Tape = Tape
  { getTape :: ForeignPtr Tape
  }
-- FFI bindings to the C tape implementation.  The C side is not visible in
-- this module; the descriptions below are based only on how the Haskell
-- code calls each function.
--
-- tape_alloc: called by newTape with the variable count and 4 * 1024.
-- NOTE(review): the second argument looks like an initial capacity; confirm.
foreign import ccall unsafe "tape_alloc" c_tape_alloc :: CInt -> CInt -> IO (Ptr Tape)
-- tape_push: records a node with two parent indices and their partials.
-- pushTape returns the result, and unarily/binarily use it as the new node's
-- index.
-- NOTE(review): the result is marshalled as Int, not CInt.  If the C
-- prototype returns `int`, this should be `IO CInt` plus fromIntegral.
-- Verify against the C source.
foreign import ccall unsafe "tape_push" c_tape_push :: Ptr Tape -> CInt -> CInt -> Double -> Double -> IO Int
-- tape_backPropagate: given node index k, fills the caller-supplied buffer.
-- partials sizes that buffer with tape_variables.
foreign import ccall unsafe "tape_backPropagate" c_tape_backPropagate :: Ptr Tape -> CInt -> Ptr Double -> IO ()
-- tape_variables: number of Double entries partials reads back after a sweep.
foreign import ccall unsafe "tape_variables" c_tape_variables :: Ptr Tape -> IO CInt
-- Address of the C finalizer, which newTape attaches to the ForeignPtr.
foreign import ccall unsafe "&tape_free" c_ref_tape_free :: FinalizerPtr Tape
-- | Record a node on the C tape reflected from @p@.  The node's parents are
-- @i1@ and @i2@, and @d1@/@d2@ are the partials with respect to them.
-- Indices are narrowed to 'CInt' for the call.  The result is what
-- unarily/binarily store as the new node's index.
pushTape :: Reifies s Tape => p s -> Int -> Int -> Double -> Double -> IO Int
pushTape p i1 i2 d1 d2 = withForeignPtr fp push
  where
    fp = getTape (reflect p)
    push tape = c_tape_push tape (fromIntegral i1) (fromIntegral i2) d1 d2
{-# INLINE pushTape #-}
-- | Back-propagate from the node carried by the argument and return the
-- sensitivities the C tape writes out.  The list has one entry per slot,
-- and 'c_tape_variables' gives the count.  Constants ('Zero', 'Lift') are
-- not on the tape and have no partials.
--
-- The scratch buffer comes from 'MA.allocaArray'.  That frees it even if the
-- sweep or the peek is interrupted, and it handles a zero-length buffer
-- without calling @malloc(0)@.  @mallocBytes@ can throw where @malloc(0)@
-- returns NULL.
partials :: forall s. (Reifies s Tape) => ReverseDouble s -> [Double]
partials Zero = []
partials (Lift _) = []
partials (ReverseDouble k _) = unsafePerformIO $
  withForeignPtr (getTape (reflect (Proxy :: Proxy s))) $ \tape -> do
    l <- fromIntegral <$> c_tape_variables tape
    MA.allocaArray l $ \arr -> do
      c_tape_backPropagate tape (fromIntegral k) arr
      MA.peekArray l arr
{-# INLINE partials #-}
-- | Allocate a fresh C tape, passing @vs@ as the variable count.  The C
-- finalizer is attached so the tape is released with its 'ForeignPtr'.
-- NOTE(review): the 4 * 1024 argument presumably sets an initial capacity;
-- confirm against the C side.
newTape :: Int -> IO Tape
newTape vs =
  c_tape_alloc (fromIntegral vs) (4 * 1024) >>= fmap Tape . newForeignPtr c_ref_tape_free
-- | Run @k@ with a freshly allocated C tape reflected at a new type @s@.
-- NOINLINE keeps the 'unsafePerformIO' allocation from being inlined into,
-- and shared or duplicated across, call sites.
reifyTape :: Int -> (forall s. Reifies s Tape => Proxy s -> r) -> r
reifyTape vs k = unsafePerformIO $ do
  tape <- newTape vs
  return (reify tape k)
{-# NOINLINE reifyTape #-}
-- | Like 'reifyTape', but the reflected type @s@ is also 'Typeable'.
reifyTypeableTape :: Int -> (forall s. (Reifies s Tape, Typeable s) => Proxy s -> r) -> r
reifyTypeableTape vs k = unsafePerformIO $ do
  tape <- newTape vs
  return (reifyTypeable tape k)
{-# NOINLINE reifyTypeableTape #-}
-- | Build a taped node from one parent.  The parent has index @i@ and primal
-- @b@, @di@ is the partial with respect to it, and the primal result is
-- @f b@.  @f b@ is forced first, then the strict index field forces the push.
unarily :: forall s. Reifies s Tape => (Double -> Double) -> Double -> Int -> Double -> ReverseDouble s
unarily f di i b = ReverseDouble node $! f b
  where
    -- The unary case reuses the two-parent push with a dummy second parent:
    -- index 0 with partial 0.0.
    node = unsafePerformIO (pushTape (Proxy :: Proxy s) i 0 di 0.0)
{-# INLINE unarily #-}
-- | Build a taped node from two parents.  The parents have indices @i@ and
-- @j@ and primals @b@ and @c@; @di@ and @dj@ are the partials with respect to
-- them, and the primal result is @f b c@.
binarily :: forall s. Reifies s Tape => (Double -> Double -> Double) -> Double -> Double -> Int -> Double -> Int -> Double -> ReverseDouble s
binarily f di dj i b j c = ReverseDouble node $! f b c
  where
    node = unsafePerformIO (pushTape (Proxy :: Proxy s) i j di dj)
{-# INLINE binarily #-}
#else
-- | The recorded computation, newest node first.  Each cell stores the
-- index of each parent and the partial derivative with respect to it.
-- 'Nil' marks the bottom of the tape.
data Cells
  = Nil
  | Unary {-# UNPACK #-} !Int {-# UNPACK #-} !Double !Cells
  | Binary {-# UNPACK #-} !Int {-# UNPACK #-} !Int {-# UNPACK #-} !Double {-# UNPACK #-} !Double !Cells
-- | @dropCells n cells@ discards the @n@ newest cells, stopping early at
-- 'Nil'.  The counter is forced at every step so no chain of @n - 1@ thunks
-- builds up.
dropCells :: Int -> Cells -> Cells
dropCells 0 xs                  = xs
dropCells _ Nil                 = Nil
dropCells n (Unary _ _ xs)      = (dropCells $! n - 1) xs
dropCells n (Binary _ _ _ _ xs) = (dropCells $! n - 1) xs
-- | Current state of a pure tape.  It holds the index handed out most
-- recently (reifyTape seeds it with the variable count) and the recorded
-- cells.
data Head = Head
  {-# UNPACK #-} !Int -- index of the newest node
  !Cells              -- recorded cells, newest first
-- | A mutable pure tape: a reference to the current 'Head'.  Nodes are
-- pushed by atomically replacing the 'Head'.
newtype Tape = Tape { getTape :: IORef Head }
-- | Reverse sweep starting at node @k@, whose cell must head the given list.
--
-- For each cell, the adjoint of the current node (@ss ! k@) is scaled by the
-- stored partial(s) and added into the parent slot(s).  The walk then moves
-- to node @k - 1@.  Returns the index reached when 'Nil' is hit, i.e. @k@
-- minus the number of cells processed.
--
-- The stored partials are already 'Double', so they are used directly.  The
-- previous @unsafeCoerce@ here was a no-op Double-to-Double cast.
backPropagate :: Int -> Cells -> STArray s Int Double -> ST s Int
backPropagate k Nil _ = return k
backPropagate k (Unary i g xs) ss = do
  da <- readArray ss k
  db <- readArray ss i
  writeArray ss i $! db + g * da
  (backPropagate $! k - 1) xs ss
backPropagate k (Binary i j g h xs) ss = do
  da <- readArray ss k
  db <- readArray ss i
  writeArray ss i $! db + g * da
  -- Read slot j only after writing slot i.  When i == j (e.g. x * x), both
  -- contributions then accumulate.
  dc <- readArray ss j
  writeArray ss j $! dc + h * da
  (backPropagate $! k - 1) xs ss
-- | Reverse sweep from the node carried by the argument.
--
-- The tape head holds the newest index @n@.  The cell for node @m@ lies
-- @n - m@ cells below the head, so dropping @n - k@ cells positions the walk
-- at node @k@.  Node @k@ is seeded with sensitivity 1.  The result lists
-- the sensitivities of indices @0@ up to the index 'backPropagate' reaches
-- at the bottom of the tape.  Constants have no partials.
partials :: forall s. Reifies s Tape => ReverseDouble s -> [Double]
partials Zero = []
partials (Lift _) = []
partials (ReverseDouble k _) = map (sensitivities !) [0 .. vs]
  where
    Head n t = unsafePerformIO $ readIORef (getTape (reflect (Proxy :: Proxy s)))
    tk = dropCells (n - k) t
    (vs, sensitivities) = runST $ do
      ss <- newArray (0, k) 0
      writeArray ss k 1
      v <- backPropagate k tk ss
      frozen <- Unsafe.unsafeFreeze ss
      return (v, frozen)
-- | Run @k@ against a fresh, empty pure tape reflected at a new type @s@.
-- The node counter starts at @vs@, so the first pushed node gets index
-- @vs + 1@.  NOINLINE keeps the 'unsafePerformIO' allocation from being
-- inlined into, and shared or duplicated across, call sites.
reifyTape :: Int -> (forall s. Reifies s Tape => Proxy s -> r) -> r
reifyTape vs k = unsafePerformIO $ do
  h <- newIORef (Head vs Nil)
  return (reify (Tape h) k)
{-# NOINLINE reifyTape #-}
-- | Like 'reifyTape', but the reflected type @s@ is also 'Typeable'.
reifyTypeableTape :: Int -> (forall s. (Reifies s Tape, Typeable s) => Proxy s -> r) -> r
reifyTypeableTape vs k = unsafePerformIO $ do
  h <- newIORef (Head vs Nil)
  return (reifyTypeable (Tape h) k)
{-# NOINLINE reifyTypeableTape #-}
-- | Tape transition that pushes a unary cell: parent @i@, partial @di@.
-- Returns the new head and the index given to the node, which is one past
-- the previous newest index.  Both components are forced before the pair is
-- returned.
un :: Int -> Double -> Head -> (Head, Int)
un i di (Head r t) = h `seq` r' `seq` (h, r')
  where
    r' = r + 1
    h = Head r' (Unary i di t)
{-# INLINE un #-}
-- | Tape transition that pushes a binary cell: parents @i@ and @j@, with
-- partials @di@ and @dj@.  Returns the new head and the index given to the
-- node, which is one past the previous newest index.
bin :: Int -> Int -> Double -> Double -> Head -> (Head, Int)
bin i j di dj (Head r t) = h `seq` r' `seq` (h, r')
  where
    r' = r + 1
    h = Head r' (Binary i j di dj t)
{-# INLINE bin #-}
-- | Atomically apply a tape transition to the tape reflected from @p@ and
-- return the transition's result.
--
-- Uses the strict 'atomicModifyIORef''.  The new 'Head' is then evaluated
-- inside the update rather than left in the 'IORef' as a selector thunk.
modifyTape :: Reifies s Tape => p s -> (Head -> (Head, r)) -> IO r
modifyTape p = atomicModifyIORef' (getTape (reflect p))
{-# INLINE modifyTape #-}
-- | Build a taped node from one parent.  The parent has index @i@ and primal
-- @b@, @di@ is the partial with respect to it, and the primal result is
-- @f b@.  @f b@ is forced first, then the strict index field forces the push.
unarily :: forall s. Reifies s Tape => (Double -> Double) -> Double -> Int -> Double -> ReverseDouble s
unarily f di i b =
  ReverseDouble (unsafePerformIO (modifyTape (Proxy :: Proxy s) (un i di))) $! f b
{-# INLINE unarily #-}
-- | Build a taped node from two parents.  The parents have indices @i@ and
-- @j@ and primals @b@ and @c@; @di@ and @dj@ are the partials with respect to
-- them, and the primal result is @f b c@.
binarily :: forall s. Reifies s Tape => (Double -> Double -> Double) -> Double -> Double -> Int -> Double -> Int -> Double -> ReverseDouble s
binarily f di dj i b j c =
  ReverseDouble (unsafePerformIO (modifyTape (Proxy :: Proxy s) (bin i j di dj))) $! f b c
{-# INLINE binarily #-}
#endif
-- | A reverse-mode value over tape @s@:
--
-- * 'Zero': a known-zero constant;
-- * 'Lift': a constant that is not recorded on the tape;
-- * 'ReverseDouble': a taped node, holding its tape index and primal value.
data ReverseDouble s where
  Zero          :: ReverseDouble s
  Lift          :: {-# UNPACK #-} !Double -> ReverseDouble s
  ReverseDouble :: {-# UNPACK #-} !Int -> {-# UNPACK #-} !Double -> ReverseDouble s
  deriving (Show, Typeable)
-- | Scalars are plain 'Double's.  Only taped 'ReverseDouble' nodes count as
-- non-constant.  Scaling a taped node by a scalar goes through 'lift1'.
-- The recorded partial is the scalar itself, or its reciprocal for '^/'.
instance Reifies s Tape => Mode (ReverseDouble s) where
  type Scalar (ReverseDouble s) = Double

  isKnownZero Zero     = True
  isKnownZero (Lift 0) = True
  isKnownZero _        = False

  isKnownConstant ReverseDouble{} = False
  isKnownConstant _               = True

  auto = Lift
  zero = Zero

  a *^ b = lift1 (a *) (\_ -> auto a) b
  a ^* b = lift1 (* b) (\_ -> auto b) a
  a ^/ b = lift1 (/ b) (\_ -> auto (recip b)) a
-- | Addition: the partial derivative with respect to each argument is 1.
(<+>) :: Reifies s Tape => ReverseDouble s -> ReverseDouble s -> ReverseDouble s
(<+>) = binary (+) 1 1
-- | The ordinary (primal) value.  'Zero' is 0.
primal :: ReverseDouble s -> Double
primal Zero                = 0
primal (Lift a)            = a
primal (ReverseDouble _ a) = a
-- | Reverse-mode Jacobian operations.
--
-- Constants ('Zero', 'Lift') fold eagerly into a new 'Lift' and never touch
-- the tape.  If exactly one argument is a taped node, a unary node is
-- recorded against it.  If both are taped, a binary node is recorded.
instance Reifies s Tape => Jacobian (ReverseDouble s) where
  type D (ReverseDouble s) = Id Double

  unary f _         Zero                = Lift (f 0)
  unary f _         (Lift a)            = Lift (f a)
  unary f (Id dadi) (ReverseDouble i b) = unarily f dadi i b

  lift1 f df b = unary f (df (Id pb)) b
    where
      pb = primal b

  -- Like lift1, but the derivative is also given the result value.
  lift1_ f df b = unary (const a) (df (Id a) (Id pb)) b
    where
      pb = primal b
      a  = f pb

  binary f _         _         Zero                Zero                = Lift (f 0 0)
  binary f _         _         Zero                (Lift c)            = Lift (f 0 c)
  binary f _         _         (Lift b)            Zero                = Lift (f b 0)
  binary f _         _         (Lift b)            (Lift c)            = Lift (f b c)
  binary f _         (Id dadc) Zero                (ReverseDouble i c) = unarily (f 0) dadc i c
  binary f _         (Id dadc) (Lift b)            (ReverseDouble i c) = unarily (f b) dadc i c
  binary f (Id dadb) _         (ReverseDouble i b) Zero                = unarily (`f` 0) dadb i b
  binary f (Id dadb) _         (ReverseDouble i b) (Lift c)            = unarily (`f` c) dadb i b
  binary f (Id dadb) (Id dadc) (ReverseDouble i b) (ReverseDouble j c) = binarily f dadb dadc i b j c

  lift2 f df b c = binary f dadb dadc b c
    where
      (dadb, dadc) = df (Id (primal b)) (Id (primal c))

  -- Like lift2, but the derivative is also given the result value.
  lift2_ f df b c = binary (\_ _ -> a) dadb dadc b c
    where
      pb = primal b
      pc = primal c
      a  = f pb pc
      (dadb, dadc) = df (Id a) (Id pb) (Id pc)
-- | Multiplication: the partial with respect to each factor is the other
-- factor's primal value.
mul :: Reifies s Tape => ReverseDouble s -> ReverseDouble s -> ReverseDouble s
mul = lift2 (*) (\x y -> (y, x))
#define BODY1(x) Reifies s Tape =>
#define BODY2(x,y) Reifies s Tape =>
#define HEAD (ReverseDouble s)
#define NO_Bounded
#include "instances.h"
-- | The sum of all partials of the result: the derivative when a single
-- input is in play.
derivativeOf :: Reifies s Tape => Proxy s -> ReverseDouble s -> Double
derivativeOf _ = sum . partials
{-# INLINE derivativeOf #-}
-- | The primal value paired with 'derivativeOf'.
derivativeOf' :: Reifies s Tape => Proxy s -> ReverseDouble s -> (Double, Double)
derivativeOf' p r = (primal r, derivativeOf p r)
{-# INLINE derivativeOf' #-}
-- | Collect the partials into an array over @vbounds@, indexed by variable
-- index.  Missing slots default to 0.  'accumArray' fails if a partial's
-- index lies outside @vbounds@.
partialArrayOf :: Reifies s Tape => Proxy s -> (Int, Int) -> ReverseDouble s -> Array Int Double
partialArrayOf _ vbounds = accumArray (+) 0 vbounds . zip [0 ..] . partials
{-# INLINE partialArrayOf #-}
-- | The partials keyed by variable index.  Zero entries are kept.
partialMapOf :: Reifies s Tape => Proxy s -> ReverseDouble s -> IntMap Double
partialMapOf _ = fromDistinctAscList . zip [0 ..] . partials
{-# INLINE partialMapOf #-}
-- | A taped input variable with primal value @a@ and index @v@.
var :: Double -> Int -> ReverseDouble s
var a v = ReverseDouble v a
-- | The tape index of a taped value.  Calls 'error' on constants ('Zero',
-- 'Lift').
varId :: ReverseDouble s -> Int
varId (ReverseDouble v _) = v
varId _                   = error "varId: not a Var"
-- | Number the leaves of @xs@ as input variables @0, 1, ...@ in traversal
-- order.  Returns the numbered structure and the bounds @(0, count)@, where
-- @count@ is the number of leaves.
bind :: Traversable f => f Double -> (f (ReverseDouble s), (Int, Int))
bind xs = (r, (0, hi))
  where
    (r, hi) = runState (mapM freshVar xs) 0
    -- Force the counter at each step so the state does not accumulate thunks.
    freshVar a = state $ \s -> let s' = s + 1 in s' `seq` (var a s, s')
-- | Replace each variable in @xs@ with the entry of @ys@ at that variable's
-- index.  Fails via 'varId' on constants and via '!' on indices outside
-- @ys@'s bounds.
unbind :: Functor f => f (ReverseDouble s) -> Array Int Double -> f Double
unbind xs ys = fmap (\v -> ys ! varId v) xs
-- | Combine each variable's primal value with the entry of @ys@ at its index,
-- using @f@.
unbindWith :: Functor f => (Double -> b -> c) -> f (ReverseDouble s) -> Array Int b -> f c
unbindWith f xs ys = fmap (\v -> f (primal v) (ys ! varId v)) xs
-- | Like 'unbind', but looks each index up in an 'IntMap'.  Absent indices
-- yield 0.
unbindMap :: Functor f => f (ReverseDouble s) -> IntMap Double -> f Double
unbindMap xs ys = fmap (\v -> findWithDefault 0 (varId v) ys) xs
-- | Like 'unbindWith', but looks each index up in an 'IntMap'.  Absent
-- indices use the default @z@.
unbindMapWithDefault :: Functor f => b -> (Double -> b -> c) -> f (ReverseDouble s) -> IntMap b -> f c
unbindMapWithDefault z f xs ys =
  fmap (\v -> f (primal v) $ findWithDefault z (varId v) ys) xs